利用register_forward_hook()精确定位到模型某一层的输入和输出

news/2024/12/1 0:26:07/

在论文中偶然读到一些方法会用到模型中间的隐藏层作为分类器,与模型最后一层作为分类器的性能进行对比,故而思考如何能够简便快捷地实现将模型某一层的输出输出拉取出来的方法,发现有现成hook函数可以做到这一点。

hook

hook就是一个钩子,用来把网络中的某一层的输入输出或者其他信息钩出来,如果想知道网络中某一层的详细信息,不用在定义网络时单独写一个print,直接写一个hook函数即可。

register_forward_hook

源代码里说明,hook只能用在forward()函数运行之前,写在forward函数运行之后是没用的,意思是想要运行hook,先把hook的函数写好,然后再实例化网络

def register_forward_hook(self, hook):r'''Registers a forward hook on the module.The hook will be called every time after :func:`forward` has computed an output.It should have the following signature::hook(module, input, output) -> None or modified outputThe hook can modify the output. It can modify the input inplace butit will not have effect on forward since this is called after:func:`forward` is called.Returns::class:`torch.utils.hooks.RemovableHandle`:a handle that can be used to remove the added hook by calling``handle.remove()``'''handle = hooks.RemovableHandle(self._forward_hooks)self._forward_hooks[handle.id] = hookreturn handle

问题

模型中有时会出现多个Linear层,但net.children()提取出来的所有类型一致的模块其名称也一致,故根据当前Linear层的输入和输出维度进行判断,精确锁定到该层,其他模块也依然适用

代码部分

import torch
import torch.nn as nn
class TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1=nn.Linear(2,2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.add_module('linear1', self.linear_1)self.add_module('linear2', self.linear_2)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self,x):linear_1=self.linear_1(x)linear_2=self.linear_2(linear_1)relu=self.relu(linear_2)relu_6 = self.relu6(relu)layers_in=(x,linear_1,linear_2)layers_out=(linear_1,linear_2,relu)return relu_6,layers_in,layers_outdef initialize(self):'''定义特殊的初始化,用于验证hook作用时是否获取了权重'''self.linear_1.weight=torch.nn.Parameter(torch.FloatTensor([[1,1],[1,1]]))self.linear_1.bias=torch.nn.Parameter(torch.FloatTensor([1,1]))self.linear_2.weight=torch.nn.Parameter(torch.FloatTensor([[1,1]]))self.linear_2.bias=torch.nn.Parameter(torch.FloatTensor([1]))return True
#定义hook函数用来决定勾出来的网络信息用来做什么
#定义用于获取网络各层输入输出的tensor容器
#定义nodule_name用于记录相应的module名字
module_name=[]
features_in_hook=[]
features_out_hook=[]
#hook函数需要3个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数
#hook函数负责将获取的输入输出添加到feature列表中 并提供相应的module名字
def hook(module,input,output):print("hooker working")module_name.append(module.__class__)features_in_hook.append(input)features_out_hook.append(output)return None
#对需要的层register hook
#register hook必须在forward()函数被执行之前,也就是实例化网络之前,下面的代码对网络除了ReLU以外的层都register了
#也可以选定其中的某些层进行register
net=TestForHook()
net_children=net.children()
#不同Linear层的参数in_features和out_features通常不同,可以用这些信息来判断
for child in net_children:if isinstance(child, nn.Linear) and child.in_features == 2 and child.out_features == 2:# if isinstance(child, nn.Linear):child.register_forward_hook(hook=hook)
#测试forward()提供的输入输出特征
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])
out, features_in_forward, features_out_forward = net(x)
# print("*"*5+"forward return features"+"*"*5)
# print(features_in_forward)
# print(features_out_forward)
# print("*"*5+"forward return features"+"*"*5)
#hook通过list结构进行记录,所以可以直接print
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

http://www.ppmy.cn/news/1054651.html

相关文章

【学习FreeRTOS】第11章——FreeRTOS中任务相关的其他API函数

1.函数总览 序号函数描述1uxTaskPriorityGet()获取任务优先级2vTaskPrioritySet()设置任务优先级3uxTaskGetNumberOfTasks()获取系统中任务的数量4uxTaskGetSystemState()获取所有任务的状态信息5vTaskGetInfo()获取单个任务的状态信息6xTaskGetCurrentTaskHandle()获取当前任…

基于LSTM深度学习网络的时间序列分析matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 % 随机打乱数据集并划分训练集和测试集 index_list randperm(size(wdata, 1)); ind …

LeetCode 786. 第 K 个最小的素数分数

&#x1f517; 原题链接&#xff1a;786. 第 K 个最小的素数分数 本题可以暴力求解&#xff1a; class Solution { public:vector<int> kthSmallestPrimeFraction(vector<int>& arr, int k) {int n arr.size();vector<pair<int, int>> frac;for …

flutter TARGET_SDK_VERSION和android 13

config.gradle ext{SDK_VERSION 33MIN_SDK_VERSION 23TARGET_SDK_VERSION 33COMPILE_SDK_VERSION SDK_VERSIONBUILD_TOOL_VERSION "33.0.0"//兼容库版本SUPPORT_LIB_VERSION "33.0.0"}app/build.gradle里面的 defaultConfig {// TODO: Specify your…

云服务器和虚拟主机区别

虚拟主机和云服务器是常见的网站托管方式&#xff0c;都可以让网站在互联网上运行&#xff0c;但是它们有很大的区别。本文将从使用场景、性能、安全性、灵活性、价格等方面详细介绍虚拟主机和云服务器的区别。 一、使用场景 虚拟主机是一个物理服务器通过虚拟化技术划分成多…

配置NTP时间服务器

1.配置ntp时间服务器&#xff0c;确保客户端主机能和服务主机同步时间 ​ 客户端主机 同步成功 2.配置ssh免密登陆&#xff0c;能够通过客户端主机通过redhat用户和服务端主机基于公钥验证方式进行远程连接

攻防世界-Web_php_include

原题 解题思路 php://被替换了&#xff0c;但是只做了一次比对&#xff0c;改大小写就可以绕过。 用burp抓包&#xff0c;看看有哪些文件 flag明显在第一个PHP文件里&#xff0c;直接看

高等数学:线性代数-第二章

文章目录 第2章 矩阵及其运算2.1 线性方程组和矩阵2.2 矩阵的运算2.3 逆矩阵2.4 Cramer法则 第2章 矩阵及其运算 2.1 线性方程组和矩阵 n \bm{n} n 元线性方程组 设有 n 个未知数 m 个方程的线性方程组 { a 11 x 1 a 12 x 2 ⋯ a 1 n x n b 1 a 21 x 1 a 22 x 2 ⋯ a …