利用前向勾子获取神经网络中间层的输出并将其进行保存(示例详解)

news/2024/10/30 17:14:34/

代码示例:

python"># 激活字典,用于保存每次的中间特征
activation = {}# 将 forward_hook 函数定义在 upsample_v2 外部
def forward_hook(name):def hook(module, input, output):activation[name] = output.detach()return hookdef upsample_v2(in_channels, out_channels, upscale, kernel_size=3):layers = []# Define mid channel stages (three times reduction)mid_channels = [256, 128, 64]  # 512 32 32 -> 256 64 64 -> 128 128 128 -> 64 256 256 -> 2 256 256scale_factor_per_step = upscale ** (1/3)  # Calculate the scaling for each stepcurrent_in_channels = in_channels# Upsample and reduce channels in 3 stepsfor step, mid_channel in enumerate(mid_channels):# Conv layer to reduce number of channelsconv = nn.Conv2d(current_in_channels, mid_channel, kernel_size=kernel_size, padding=1, bias=False)nn.init.kaiming_normal_(conv.weight.data, nonlinearity='relu')layers.append(conv)# ReLU activationrelu = nn.ReLU()layers.append(relu)# Upsampling layerup = nn.Upsample(scale_factor=scale_factor_per_step, mode='bilinear', align_corners=True)layers.append(up)layers[-1].register_forward_hook(forward_hook(f'step_{step}'))# Update current in_channels for the next layercurrent_in_channels = mid_channelconv = nn.Conv2d(current_in_channels, out_channels, kernel_size=kernel_size, padding=1, bias=False)nn.init.kaiming_normal_(conv.weight.data, nonlinearity='relu')layers.append(conv)return nn.Sequential(*layers)
python">def forward_hook(name):def hook(module, input, output):activation[name] = output.detach()return hook

forward_hook布置了抓取函数。其中,module代表你下面勾的那一层,input代表那一层的输入,output定义那一层的输出,我们常常只使用output。

python">layers[-1].register_forward_hook(forward_hook(f'step_{step}'))

这里定义了我需要捕获的那一层,layers[-1]代表我要捕获当前layers的最后一层,即上采用层,由于循环了三次,所以最后勾取的应当是三份中间层输出。


http://www.ppmy.cn/news/1543146.html

相关文章

抓取和分析JSON数据:使用Python构建数据处理管道

引言 在大数据时代,电商网站如亚马逊、京东等已成为数据采集的重要来源。获取并分析这些平台的产品信息可为市场分析、价格比较等提供数据支持。然而,由于网站数据通常以JSON格式动态加载,且限制较多(如IP限制、反爬机制&#xf…

echarts实现 水库高程模拟图表

需求背景解决思路解决效果index.vue 需求背景 需要做一个水库高程模拟的图表&#xff0c;x轴是水平距离&#xff0c;y轴是高程&#xff0c;需要模拟改水库的形状 echarts 图表集链接 解决思路 配合ui切图&#xff0c;模拟水库形状 解决效果 index.vue <!--/*** author:…

【Python爬虫实战】网络爬虫完整指南:网络协议OSI模型

网络爬虫完整指南&#xff1a;从协议基础到实践应用 什么是网络协议&#xff1f; **网络协议&#xff08;Network Protocol&#xff09;**是指计算机网络中设备和设备之间进行通信的规则和约定。它定义了数据传输的格式、顺序、传输方法和错误处理机制&#xff0c;使不同设备和…

Ubuntu 22.04系统启动时自动运行ROS2节点

在 Ubuntu 启动时自动运行 ROS2 节点的方法 环境&#xff1a;Ubuntu 系统&#xff0c;ROS2 Humble&#xff0c;使用系统自带的 启动应用程序 目标&#xff1a;在系统启动时自动运行指定的 ROS2 节点 效果展示 系统启动后&#xff0c;自动运行小乌龟节点和键盘控制节点。 实践…

uniapp写抖音小程序阻止右滑返回上一个页面

最近用uniapp写小程序遇到一个问题因为内部用到右滑的业务&#xff0c;但是只要右滑就会回到上一页面&#xff0c;用了event.preventDeafult()没有用&#xff0c;看了文档找到了解决办法 1.在最外层view加上touchstart事件 <view class"container" touchstart&q…

ChartCheck: Explainable Fact-Checking over Real-World Chart Images

论文地址: https://aclanthology.org/2024.findings-acl.828.pdfhttps://aclanthology.org/2024.findings-acl.828.pdf 1.概述 事实验证技术在自然语言处理领域获得了广泛关注,尤其是在针对误导性陈述的检查方面。然而,利用图表等数据可视化来传播信息误导的情况却很少受到…

一篇快速入门Jmeter

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 为什么要撰写这样一个教程呢&#xff1f; 深入学习Jmeter 温故而知新。尽管我已经使用JMeter很长时间&#xff0c;但还有许多元件我并不十分了解&#xff0c;…

C++:STL

STL的定义&#xff1a;包括了三类&#xff0c;算法容器和迭代器。 算法&#xff1a;包括排序、复制等常用算法&#xff0c;以及不同容器特定的算法。 容器&#xff1a;数据存放的形式&#xff0c;包括序列式容器和关联式容器。序列式容器就是list、vector等。关联式容器就是s…