python-pytorch 如何使用python库Netron查看模型结构(以pytorch官网模型为例)0.9.2

news/2024/10/21 10:20:22/

Netron查看模型结构

    • 参照模型
    • 安装Netron
    • 写netron代码
    • 运行查看结果
    • 需要关注的地方

  • 2024年4月27日14:32:30----0.9.2

参照模型

pytorch官网的tutorial为观察对象,链接是https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html

模型代码如下

python">import torch.nn as nn
import torch.nn.functional as Fclass RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.i2h = nn.Linear(input_size, hidden_size)self.h2h = nn.Linear(hidden_size, hidden_size)self.h2o = nn.Linear(hidden_size, output_size)self.softmax = nn.LogSoftmax(dim=1)def forward(self, input, hidden):hidden = F.tanh(self.i2h(input) + self.h2h(hidden))output = self.h2o(hidden)output = self.softmax(output)return output, hiddendef initHidden(self):return torch.zeros(1, self.hidden_size)n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)

安装Netron

pip install netron即可

其他安装方式参考链接
https://blog.csdn.net/m0_49963403/article/details/136242313

写netron代码

随便找一个地方打个点,如sample方法中

python">import netron
max_length = 20# Sample from a category and starting letter
def sample(category, start_letter='A'):with torch.no_grad():  # no need to track history in samplingcategory_tensor = categoryTensor(category)input = inputTensor(start_letter)hidden = rnn.initHidden()output_name = start_letterfor i in range(max_length):
#             print("category_tensor",category_tensor.size())
#             print("input[0]",input[0].size())
#             print("hidden",hidden.size())output, hidden = rnn(category_tensor, input[0], hidden)torch.onnx.export(rnn,(category_tensor, input[0], hidden) , f='AlexNet1.onnx')   #导出 .onnx 文件netron.start('AlexNet1.onnx') #展示结构图break
#             print("output",output.size())
#             print("hidden",hidden.size())
#             print("====================")topv, topi = output.topk(1)topi = topi[0][0]if topi == n_letters - 1:breakelse:letter = all_letters[topi]output_name += letterinput = inputTensor(letter)return output_name# Get multiple samples from one category and multiple starting letters
def samples(category, start_letters='ABC'):for start_letter in start_letters:print(sample(category, start_letter))breaksamples('Russian', 'RUS')

运行查看结果

结果是在浏览器中,运行成功后会显示:
Serving ‘AlexNet.onnx’ at http://localhost:8080

打开这个网页就可以看见模型结构,如下图

在这里插入图片描述

需要关注的地方

  1. 关于参数
    如果模型是一个参数的情况下,如下使用就可以了
python">import torch
from torchvision.models import AlexNet
import netron
model = AlexNet()
input = torch.ones((1,3,224,224))
torch.onnx.export(model, input, f='AlexNet.onnx')
netron.start('AlexNet.onnx')

如果模型有多个参数的情况下,则需要如下用括号括起来,如本文中的例子

python">torch.onnx.export(rnn,(category_tensor, input[0], hidden) , f='AlexNet1.onnx')   #导出 .onnx 文件
netron.start('AlexNet1.onnx') #展示结构图
  1. 如果运行过程中发现报错找不到模型
    有可能是你手动删除了生成的模型,最好的方法是重新生成这个模型,再运行

http://www.ppmy.cn/news/1442627.html

相关文章

应用实战|只需几步,即可享有外卖订餐小程序

本示例是一个简单的外卖查看店铺点菜的外卖微信小程序,小程序后端服务使用了MemFire Cloud,其中使用到的MemFire Cloud功能包括: 其中使用到的MemFire Cloud功能包括: 云数据库:存储外卖微信小程序所有数据表的信息。…

【需求案例】博客需求

第一章 需求调研 1.Blog系统趋势 1.1 个人博客 市面上存在很多的个人博客,且个人博客一般用于记录自己所遇到的问题及解决方案和技术分享。个人博客的优势是可以快速的记录问题与回归问题,可将自己掌握的技术知识展现给互联网中的用户进行技术吸纳。个人博客的劣势是,站点…

华为ensp中链路聚合两种(lacp-static)模式配置方法

作者主页:点击! ENSP专栏:点击! 创作时间:2024年4月26日11点54分 链路聚合(Link Aggregation),又称为端口聚合(Port Trunking),是一种将多条物理…

【Flutter 面试题】 Dart 当中的 .. 表示什么?

【Flutter 面试题】 Dart 当中的 … 表示什么? 文章目录 写在前面口述回答补充说明写在前面 🙋 关于我 ,小雨青年 👉 CSDN博客专家,GitChat专栏作者,阿里云社区专家博主,51CTO专家博主。2023博客之星TOP153。 👏🏻 正在学 Flutter 的同学,你好! 😊 Flutter…

服务器端口映射到另一台服务器

在服务器管理的日常工作中,端口映射是一项常见且关键的任务。通过端口映射,我们可以将一台服务器上的特定端口流量重定向到另一台服务器,实现服务的灵活部署和管理。 一、端口映射的基本概念 端口映射,也称为端口转发或端口重定向…

算法模板——数据结构篇

声明:参考自acwing 目录 1.单链表 2.双链表 3.数组栈与队列 4.单调栈 1.单链表 int head,e[N],ne[N],idx;void init(){head-1;idx0; } void add_head(int x){ //head有实值e[idx]x,ne[idx]head,headidx; } void add(int k,int x){ e[idx]x,…

浓眉大眼的Apple开源OpenELM模型;IDM-VTON试衣抱抱脸免费使用;先进的语音技术,能够轻松克隆任何人的声音

✨ 1: openelm OpenELM是苹果机器学习研究团队发布的高效开源语言模型家族 OpenELM是苹果机器学习研究团队开发的一种高效的语言模型,旨在推动开放研究、确保结果的可信赖性、允许对数据和模型偏见以及潜在风险进行调查。其特色在于采用了一种分层缩放策略&#x…

LeetCode 2385.感染二叉树需要的总时间:两次搜索(深搜 + 广搜)

【LetMeFly】2385.感染二叉树需要的总时间:两次搜索(深搜 广搜) 力扣题目链接:https://leetcode.cn/problems/amount-of-time-for-binary-tree-to-be-infected/ 给你一棵二叉树的根节点 root ,二叉树中节点的值 互不…