pytorch中的hook机制register_forward_hook

ops/2024/10/18 16:47:59/

上篇文章主要介绍了hook钩子函数的大致使用流程,本篇文章主要介绍pytorch中的hook机制register_forward_hook,手动在forward之前注册hook,hook在forward执行以后被自动执行。

1、hook背景
Hook被成为钩子机制,pytorch中包含forward和backward两个钩子注册函数,用于获取forward和backward中输入和输出,按照自己不全面的理解,应该目的是“不改变网络的定义代码,也不需要在forward函数中return某个感兴趣层的输出,这样代码太冗杂了”。

2、源码阅读
register_forward_hook()函数必须在forward()函数调用之前被使用,因为该函数源码注释显示这个函数“ it will not have effect on forward since this is called after :func:forward is called”,也就是这个函数在forward()之后就没有作用了!):
作用:获取forward过程中每层的输入和输出,用于对比hook是不是正确记录。

def register_forward_hook(self, hook):r"""Registers a forward hook on the module.The hook will be called every time after :func:`forward` has computed an output.It should have the following signature::hook(module, input, output) -> None or modified outputThe hook can modify the output. It can modify the input inplace butit will not have effect on forward since this is called after:func:`forward` is called.Returns::class:`torch.utils.hooks.RemovableHandle`:a handle that can be used to remove the added hook by calling``handle.remove()``"""handle = hooks.RemovableHandle(self._forward_hooks)self._forward_hooks[handle.id] = hookreturn handle

3、定义一个用于测试hook的类
如果随机的初始化每个层,那么就无法测试出自己获取的输入输出是不是forward中的输入输出了,所以需要将每一层的权重和偏置设置为可识别的值(比如全部初始化为1)。网络包含两层(Linear有需要求导的参数被称为一个层,而ReLU没有需要求导的参数不被称作一层),init()中调用initialize函数对所有层进行初始化。

**注意:**在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。

class TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1 = nn.Linear(in_features=2, out_features=2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self, x):linear_1 = self.linear_1(x)linear_2 = self.linear_2(linear_1)relu = self.relu(linear_2)relu_6 = self.relu6(relu)layers_in = (x, linear_1, linear_2)layers_out = (linear_1, linear_2, relu)return relu_6, layers_in, layers_outdef initialize(self):""" 定义特殊的初始化,用于验证是不是获取了权重"""self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))return True

4、定义hook函数
hook()函数是register_forward_hook()函数必须提供的参数,首先定义几个容器用于记录:
定义用于获取网络各层输入输出tensor的容器:

# 同时定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []
hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数:

hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字。

def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None

5、对需要的层注册hook
注册钩子必须在forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):
注册钩子可以对某些层单独进行:

net = TestForHook()
net_chilren = net.children()
for child in net_chilren:if not isinstance(child, nn.ReLU6):child.register_forward_hook(hook=hook)

6、测试forward()返回的特征和hook记录的是否一致
6.1 测试forward()提供的输入输出特征

由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:

out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

输出如下:

*****forward return features*****
(tensor([[0.1000, 0.1000],[0.1000, 0.1000]]), tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****

6.2 hook记录的输入特征和输出特征
hook通过list结构进行记录,所以可以直接print。

测试features_in是否存储了输入:

print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

得到和forward一样的结果:

*****hook record features*****
[(tensor([[0.1000, 0.1000],[0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],[3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>, 
<class 'torch.nn.modules.linear.Linear'>,<class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****

6.3 把hook记录的和forward做减法
如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook记录的特征和forward记录的特征做减法:
测试forward返回的feautes_in是不是和hook记录的一致:

print("sub result'")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):print(forward_return-hook_record[0])

得到的全部都是0,说明hook没问题:

sub result
tensor([[0., 0.],[0., 0.]])
tensor([[0., 0.],[0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],[0.]], grad_fn=<SubBackward0>)

7、完整代码

import torch
import torch.nn as nnclass TestForHook(nn.Module):def __init__(self):super().__init__()self.linear_1 = nn.Linear(in_features=2, out_features=2)self.linear_2 = nn.Linear(in_features=2, out_features=1)self.relu = nn.ReLU()self.relu6 = nn.ReLU6()self.initialize()def forward(self, x):linear_1 = self.linear_1(x)linear_2 = self.linear_2(linear_1)relu = self.relu(linear_2)relu_6 = self.relu6(relu)layers_in = (x, linear_1, linear_2)layers_out = (linear_1, linear_2, relu)return relu_6, layers_in, layers_outdef initialize(self):""" 定义特殊的初始化,用于验证是不是获取了权重"""self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))return True# 定义用于获取网络各层输入输出tensor的容器,并定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []# hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字
def hook(module, fea_in, fea_out):print("hooker working")module_name.append(module.__class__)features_in_hook.append(fea_in)features_out_hook.append(fea_out)return None# 定义全部是1的输入:
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])# 注册钩子可以对某些层单独进行:
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:if not isinstance(child, nn.ReLU6):child.register_forward_hook(hook=hook)# 测试网络输出:
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)# 测试features_in是不是存储了输入:
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)# 测试forward返回的feautes_in是不是和hook记录的一致:
print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):print(forward_return-hook_record[0])

http://www.ppmy.cn/ops/90359.html

相关文章

Oceanbase 执行计划

test100 CREATE TABLE `test100` ( `GRNT_CTR_NO` varchar(32) COLLATE utf8mb4_bin NOT NULL COMMENT 担保合同编号, `GRNT_CTR_TYP` varchar(3) COLLATE utf8mb4_bin NOT NULL COMMENT 担保合同类型, `COLC_GRNT_IND` varchar(1) COLLATE utf8mb4_bin DEFAULT NULL …

ELK 之logstash filter grok常见内置模式

Logstash 的 grok 过滤器是一种用于解析非结构化日志数据的强大工具。grok 使用了一系列模式来解析和提取日志消息中的数据。这些模式可以用来识别并提取特定的数据字段&#xff0c;如 IP 地址、日期、URL、数字等。grok 提供了一些常见的内置模式&#xff0c;以下是一些最常见…

oracle数据库监控数据库中某个表是否正常生成

事情经过&#xff1a; 公司某业务系统每月25日0点会自动生成下个月的表&#xff0c;表名字是tabname_202407的格式。由于7月25日0点做系统保养的时候重启了应用系统服务&#xff0c;导致8月份的表没有生成。最终操作业务影响&#xff0c;为此决定对这个表进行监控&#xff0c;…

【开源移植】MultiButton_小型按键驱动模块移植

MultiButton 简介 MultiButton 是一个小巧简单易用的事件驱动型按键驱动模块&#xff0c;可无限量扩展按键&#xff0c;按键事件的回调异步处理方式可以简化你的程序结构&#xff0c;去除冗余的按键处理硬编码&#xff0c;让你的按键业务逻辑更清晰。 使用方法 1.先申请一个…

科普文:微服务之Spring Cloud OpenFeign服务调用调用过程分析

概叙 Feign和OpenFeign的关系 其实到现在&#xff0c;至少是2018年之后&#xff0c;我们说Feign&#xff0c;或者说OpenFeign其实是一个东西&#xff0c;就是OpenFeign&#xff0c;是2018.12 Netflix停止维护后&#xff0c;Spring cloud整合Netflix生态系统的延续。后面其实都…

PostgreSQL-01-入门篇-简介

文章目录 1. PostgreSQL是什么?2. PostgreSQL 历史2.1. 伯克利 POSTGRES 项目2.2. Postgres952.3. PostgreSQL来了 3. PostgreSQL vs MySQL4. 安装4.1 Windows 安装4.2 linux 安装4.3 docker安装 1. PostgreSQL是什么? PostgreSQL 是一个基于加州大学伯克利分校计算机系开发…

2、HarmonyOS鸿蒙开发--ArkUI界面开发基础

ArkUI基础 组件&#xff1a;ArkUI构建页面的最小单元 组件分类&#xff1a;容器组件(column&#xff0c;row)&#xff0c;基础组件 布局思路&#xff1a;先排版&#xff0c;放内容,后美化 注意&#xff1a;build是容器组件&#xff0c;且只能有一个根元素(不可并列两个row/colu…

Nextjs——国际化那些事儿

背景&#xff1a; 某一天&#xff0c;产品经理跟我说&#xff0c;我们的产品需要搞国际化 国际化的需求说白了就是把项目中的文案翻译成不同的语言&#xff0c;用户想用啥语言来浏览网页就用啥语言&#xff0c;虽然说英语是通用语言&#xff0c;但国际化了嘛&#xff0c;产品才…