PyTorch小技巧:使用Hook可视化网络层激活(各层输出)

server/2024/10/18 0:33:07/

这篇文章将演示如何可视化PyTorch激活层。可视化激活,即模型内各层的输出,对于理解深度神经网络如何处理视觉信息至关重要,这有助于诊断模型行为并激发改进。

我们先安装必要的库:

 pip install torch torchvision matplotlib

加载CIFAR-10数据集并可视化一些图像。这有助于理解模型处理的输入。

 importtorchvisionimporttorchvision.transformsastransformsimportmatplotlib.pyplotasplt# Transformations for the imagestransform=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# Load CIFAR-10 datasettrainset=torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)trainloader=torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)# Function to show imagesdefimshow(img):img=img.numpy().transpose((1, 2, 0))mean=np.array([0.485, 0.456, 0.406])std=np.array([0.229, 0.224, 0.225])img=std*img+mean  # unnormalizeplt.imshow(img)plt.show()# Get some imagesdataiter=iter(trainloader)images, labels=next(dataiter)# Display imagesimshow(torchvision.utils.make_grid(images))

看着很模糊的原因是我们使用的CIFAR-10图像32x32的,很小 。因为对于小图像,处理速度很快,所以CIFAR-10称为研究的首选。

然后我们加载一个预训练的ResNet模型,并在特定的层上设置钩子函数,以在向前传递期间捕获激活。

 import torchfrom torchvision.models import resnet18# Load pretrained ResNet18model = resnet18(pretrained=True)model.eval()  # Set the model to evaluation mode# Hook setupactivations = {}def get_activation(name):def hook(model, input, output):activations[name] = output.detach()return hook# Register hooksmodel.layer1[0].conv1.register_forward_hook(get_activation('layer1_0_conv1'))model.layer4[0].conv1.register_forward_hook(get_activation('layer4_0_conv1'))

这样,在通过模型处理图像时就能捕获到激活。

 # Run the modelwith torch.no_grad():output = model(images)

通过上面钩子函数我们获得了激活下面就可以进行可视化

 # Visualization function for activationsdef plot_activations(layer, num_cols=4, num_activations=16):num_kernels = layer.shape[1]fig, axes = plt.subplots(nrows=(num_activations + num_cols - 1) // num_cols, ncols=num_cols, figsize=(12, 12))for i, ax in enumerate(axes.flat):if i < num_kernels:ax.imshow(layer[0, i].cpu().numpy(), cmap='twilight')ax.axis('off')plt.tight_layout()plt.show()# Display a subset of activationsplot_activations(activations['layer1_0_conv1'], num_cols=4, num_activations=16)

结果如下:

 plot_activations(activations['layer4_0_conv1'], num_cols=4, num_activations=16)

PyTorch的钩子函数(hooks)是一种非常有用的特性,它们允许你在训练的前向传播和反向传播过程中插入自定义操作。这对于调试、修改梯度或者理解网络的内部运作非常有帮助。

利用 PyTorch 钩子函数来可视化网络中的激活是一种很好的方式,尤其是想要理解不同层如何响应不同输入的情况下。在这个过程中,我们可以捕捉到网络各层的输出,并将其可视化以获得直观的理解。

可视化激活有助于理解卷积神经网络中的各个层如何响应输入图像中的不同特征。通过可视化不同的层,可以评估早期层是否捕获边缘和纹理等基本特征,而较深的层是否捕获更复杂的特征。这些知识对于诊断问题、调整层架构和改进整体模型性能是非常宝贵的。

https://avoid.overfit.cn/post/c63b9b1130fe425ea5b7d0bedf209b2e


http://www.ppmy.cn/server/7529.html

相关文章

浏览器CSS兼容性问题解决方案整理

1、CSS Hack 使用 hacker 可以把浏览器分为3类&#xff1a;IE6&#xff1b;IE7和遨游&#xff1b;其他&#xff08;IE8 Chrome ff Safari opera等&#xff09; &#xff08;1&#xff09;IE6认识的 hacker 是 下划线 _ 和星号 * &#xff08;2&#xff09;IE7和遨游认识的 hac…

npm常用命令详解

前言 npm&#xff08;Node Package Manager&#xff09;是Node.js的包管理器&#xff0c;它允许开发者安装、分享、更新和管理JavaScript库和工具。以下是一些常用的npm命令及其详细解释&#xff1a; 基础命令 1. 初始化一个新项目 npm init这个命令会引导你创建一个新的pack…

安全狗云眼的主要功能有哪些?

"安全狗云眼"是一款综合性的网络安全产品&#xff0c;主要用于实时监控和保护企业的网络安全。其核心功能包括威胁检测、漏洞扫描、日志管理和合规性检查等。 以下是安全狗云眼的主要功能详细介绍&#xff1a; 1、资产管理 定期获取并记录主机上的Web站点、Web容器、…

vite+vue3+antDesignVue 记录-持续记录

记录学习过程 持续补充 每天的学习点滴 开始时间2024-04-12 1&#xff0c;报错记录 &#xff08;1&#xff09;env.d.ts文件 解决方法&#xff1a; 在env.d.ts文件中添加以下代码&#xff08;可以看一下B站尚硅谷的讲解视频&#xff09; declare module *.vue {import { Defi…

Android startForegroundService与startForeground

启动service service启动有四种形式。 1.显示启动(如直接按service的全路径启动) 2.隐示启动(如通过intent-filter的action标签启动) 3.通过bindservice显示启动。 4.通过bindservice隐示启动。 Demo 创建一个service的子类&#xff0c;如 import android.app.Notifica…

力扣HOT100 - 25. K 个一组翻转链表

解题思路&#xff1a; class Solution {public ListNode reverseKGroup(ListNode head, int k) {ListNode dum new ListNode(0, head);ListNode pre dum;ListNode end dum;while (end.next ! null) {for (int i 0; i < k && end ! null; i) {end end.next;}if …

Vue模版语法(初学Vue之v-指令语法)

目录 一、介绍 1.概念 2.常见指令语法及用法 1.v-bind: 2.v-model: 3.v-if / v-else-if / v-else: 4.v-for: 5.v-on: 6.v-show: 7.v-pre: 8.v-cloak: 二、使用 1.Mustache插值语法 2.v-once指令使用 3.v-text指令使用 4.v-html指令使用 5.v-pre指令使用 6.v-…

springboot 从mysql 迁移人大金仓 -kingbase

一、配置方法修改 1、添加maven依赖 <!-- 人大金仓 --><dependency><groupId>cn.com.kingbase</groupId><artifactId>kingbase8</artifactId><version>8.6.0</version></dependency> 2、连接配置&#xff0c;修改 .y…