【示例】如何使用Pytorch堆叠一个神经网络

news/2025/2/22 3:43:57/

本文主要从大致步骤上讲述如何从零开始构建一个网络,仅提供一个思路,具体实现以实际情况为准。

一、构建网络

class 网络模型(nn.Module):#----初始化函数----##主要用来构建网络单元,类似于类定义def __init__(self,需要传入的参数列表): super(网络模型,self).__init__()#----定义一些相关的神经单元----#self.backbone=BackBone(...)#----将神经单元堆叠成网络模块----#self.block=nn.Sequential(#将标准模块堆叠)self.cls_conv=nn.Conv2d(256,numclass,1,stride=1)#----前向传递函数----##----主要用来实现网络计算,需要按严格的顺序堆叠def forward(self,x):    # x为输入特征#----堆叠网络                                                                                       return x

        以类的形式来构造神经网络的主结构,同时可以通过多个类的叠加来区分基干网络等各个部分,但是每个网络类中主要分为两个部分:

        ①类定义:主要负责声明神经模组,如卷积,池化,BN层等;以及参数的注入。同时将一些标准模块堆叠成神经模块(可以构造多个);类似于声明,顺序没有特殊要求(小模块内有顺序)

        ②前向传递函数:将类定义中构造的各模块进行顺序构造。x为传入的特征向量,通过让x按顺序经过各模块来实现网络的构造。类似于实现,需要遵守严格的顺序

二、网络训练

#----各种参数的定义,预处理----#
# 包括一些超参数,路径,cuda等参数的设定,以及卷积核尺寸等参数的计算#----网络模型实体化----#
model = 类定义(参数列表)#----权重的加载/初始化----#
model_dict = model.state_dict()    #获取网络结构
pretrained_dict = torch.load(modePath,map_location=device)    #从文件中加载权重
for k,v in pretrained_dict.items():#逐层判断网络结构是否一致
model_dict.update(temp_dict)         #上传权重参数
model.load_state_dict(model_dict)    #加载权重#----如有必要,进行cuda型的转换----##----数据集的加载----#
dataset = 
dataloader = #----进行训练----#

        由于训练的步数较多,建议将其分为预处理和具体训练分开编写。预处理包括网络的实体化,参数和超参数的处理及填充,数据集的加载和转换

        而训练部分则包括:优化器的实例化和参数的填充训练和验证

        训练是指分步将训练数据从dataloader中取出并执行以下步骤:         

                        ①前向传递

                        ②计算损失函数

                        ③梯度清零

                        ④前向传递

                        ⑤优化参数

        随后将测试集从dataloader中取出按训练同样的步骤进行预测,并计算损失函数 

        训练步骤

loss_fn=   #设置交叉函数
learing_rate=   #学习率
optimizer = torch.optim   #设置优化器#----开始训练----#
for i in range(epoch):#----将训练集从dataloader中解包----#for data in train_loader:imgs,targes=data    outputs=mynet(imgs) #网络前向传递loss=loss_fn(outputs,targes)    #计算损失函数optimizer.zero_grad()   #梯度清零loss.backward()  #前向传递optimizer.step()    #逐步优化total_train_step+=1 #训练计数#----开始测试----#
with torch.no_grad():    #不设置梯度(保证不进行调优)#----将测试集从dataloader中拆包----#for data in test_loader:imgs,targets = data outputs = mynet(imgs)    #进行预测loss = loss_fn(outputs,targets)    #计算损失函数#----保存每轮的模型----#
#torch.save(mynet,"MyNerNet_Ver{}.pth".format(total_train_step))

三、进行预测

        进行预测总体和训练类似,但是不需要将数据送入dataset和dataloader中,一般也不需要计算损失函数,仅需要调用网络对数据进行预测即可。是神经网络的应用环节。

    def ImgDetect(self,img,count=False,nameClasses=None,outType=0):#----图片的预处理----## 主要包括图像的参数计算,resize,和添加batch_size维度#----使用网络预测----#with torch.no_grad():# 类型转换imgs = torch.from_numpy(img_data)# 传入网络并得到结果Img = self.net(imgs)# 进行后续处理#----返回结果----#return Img

http://www.ppmy.cn/news/827.html

相关文章

使用PyQt5界面设计

一、环境搭建 直接pip安装即可: pip install PyQt5 pip install pyqt5-tools 二、Qt Designer设计GUI Qt Designer 是通过拖拽的方式放置控件,并实时查看控件效果进行快速UI设计。最终生成.ui文件,可以通过pyuic5工具转换成.py文件。 打开d…

RabbitMQ中的集群架构介绍

文章目录前言一、普通集群(副本集群)1.架构图二、镜像集群1.架构图前言 在之前我们是以单节点的形式来运行mq。在真正的生产实践中,mq主要用来完成两个应用系统间的通信,如果在某一时刻mq宕机了,会导致系统瘫痪,就是无法进行通信…

操作系统02_进程管理_同步互斥信号量_PV操作_死锁---软考高级系统架构师007

存储管理可以分为固定存储管理和分页存储管理。 现在固定存储管理已经不用也不考,但要知道因为固定存储管理指的是整存整取 也就是把一整个程序,比如说10G的游戏全部都存到内存里 这样的话是非常占用内存的,这个固定存储管理现在已经不用了。 然后这里我们主要看分页存储管: …

全国职业院校技能大赛 - ruijie网络模块 - 样卷三解析

目录 一.说明 二.项目背景 三.项目规划和设计 四.网络项目实施

基于51单片机农业大棚温控系统

资料编号:197 大棚种植温控系统概述: 本文介绍的是一个由单片机构成的温度控制系统,主要用来提供测温的解决方案,同时还能实时监控温度变化趋势,以及报警功能。它利用STC89C52RC单片机,DS18B20&#xff0c…

掌握docker这几招,你也能搞云计算了

Docker的好处 容器技术出现十多年了,已经在测试和生产环境得到普遍应用。几个好处: 便携性、隔离性封装性,可复用方便做集群部署和资源调度 … 所谓云计算,就是所有计算、服务、产品都云化,部署在云上,你…

01-go基础-06-切片(声明切片、初始化切片、切片赋值、切片长度、切片容量、空切片、append、copy)

文章目录1. 声明切片2. 初始化切片2.1 切片长度2.1.1 初始化指定长度的切片2.1.2 查看切片长度 len()2.2 切片容量2.2.1 初始化指定容量的切片2.2.2 查看切片长度 len()3 切片赋值3.1 直接赋值3.2 引用数组给切片赋值3.3 引用数组某区间给切片赋值3.3.1 从数组位置N个取到第M个…

【目标检测】Faster R-CNN论文代码复现过程解读(含源代码)

目录:Faster R-CNN论文代码复现过程解读Faster R-CNN代码使用说明书(分享在github上)一、代码的地址二、我的配置环境三、参数值文件下载四、VOC数据集下载五、模型训练步骤(1)训练VOC0712数据集1.数据集的准备2.数据集…