构建模型三要素与权重初始化

news/2025/3/5 6:02:22/

1、模型三要素

三要素其实很简单:

  1. 必须要继承nn.Module这个类,要让PyTorch知道这个类是一个Module
  2. __init__(self)中设置好需要的组件,比如conv,pooling,Linear,BatchNorm等等。
  3. 最后在forward(self,x)中用定义好的组件进行组装,就像搭积木,把网络结构搭建出来,这样一个模型就定义好了。

我们先看一个例子:
先看__init__(self)函数

# class Net(nn.Module):
def __init__(self):super(Net,self).__init__()self.conv1 = nn.Conv2d(3,6,5)self.pool1 = nn.MaxPool2d(2,2)self.conv2 = nn.Conv2d(6,16,5)self.pool2 = nn.MaxPool2d(2,2)self.fc1 = nn.Linear(16*5*5,120)self.fc2 = nn.Linear(120,84)self.fc3 = nn.Linear(84,10)

第一行是初始化,往后定义了一系列组件。nn.Conv2d 就是一般图片处理的卷积模块,然后池化层,全连接层等等。

定义完这些,再定义forward函数

def forward(self,x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1,16*5*5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

x为模型的输入,第一行表示x经过conv1,然后经过激活函数relu,然后经过pool1操作
第三行表示对x进行reshape,为后面的全连接层做准备

至此,对一个模型的定义完毕,如何使用呢?
例如:

net = Net()
outputs = net(inputs)

其实net(inputs),就是类似于使用了net.forward(inputs)这个函数。

2、参数初始化

简单地说就是设定什么层用什么初始方法,初始化的方法会在torch.nn.init中

# 定义权值初始化
def initialize_weights(self):for m in self.modules():if isinstance(m,nn.Conv2d):torch.nn.init.xavier_normal_(m.weight.data)if m.bias is not None:m.bias.data.zero_()elif isinstance(m,nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()elif isinstance(m,nn.Linear):torch.nn.init.normal_(m.weight.data,0,0.01)# m.weight.data.normal_(0,0.01)m.bias.data.zero_()

这段代码的基本流程就是,先从self.modules()中遍历每一层,然后判断更曾属于什么类型,是否是Conv2d,是否是BatchNorm2d,是否是Linear的,然后根据不同类型的层,设定不同的权值初始化方法,例如Xavierkaimingnormal_等等。kaiming也是MSRA初始化,是何恺明大佬在微软亚洲研究院的时候,因此得名。

上面代码中用到了self.modules(),这个是什么东西呢?

# self.modules的源码
def modules(self):for name,module in self.named_modules():yield module

功能就是:能依次返回模型中的各层,yield是让一个函数可以像迭代器一样可以用for循环不断从里面遍历(可能说的不太明确)。

3、完整运行代码

我们用下面的例子来更深入的理解self.modules(),同时也把上面的内容都串起来(下面的代码块可以运行):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoaderclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):torch.nn.init.xavier_normal_(m.weight.data)if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()elif isinstance(m, nn.Linear):torch.nn.init.normal_(m.weight.data, 0, 0.01)# m.weight.data.normal_(0,0.01)m.bias.data.zero_()net = Net()
net.initialize_weights()
print(net.modules())
for m in net.modules():print(m)

运行结果:

# 这个是print(net.modules())的输出
<generator object Module.modules at 0x0000023BDCA23258>
# 这个是第一次从net.modules()取出来的东西,是整个网络的结构
Net((conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)
# 从net.modules()第二次开始取得东西就是每一层了
Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Linear(in_features=400, out_features=120, bias=True)
Linear(in_features=120, out_features=84, bias=True)
Linear(in_features=84, out_features=10, bias=True)

其中呢,并不是每一层都有偏执bias的,有的卷积层可以设置成不要bias的,所以对于卷积网络参数的初始化,需要判断一下是否有bias,(不过我好像记得bias默认初始化为0?不确定,有知道的朋友可以交流)

torch.nn.init.xavier_normal(m.weight.data)
if m.bias is not None:m.bias.data.zero_()

上面代码表示用xavier_normal方法对该层的weight初始化,并判断是否存在偏执bias,若存在,将bias初始化为0。


http://www.ppmy.cn/news/1112824.html

相关文章

WebRTC 如何指定 H265解码器

WebRTC 本身支持多种视频编解码器&#xff0c;但 H.265/HEVC 编解码器的支持主要取决于浏览器或应用的实现。不过&#xff0c;如果你确定你的 WebRTC 实现和对端支持 H.265&#xff0c;可以通过修改 SDP 来优先选择 H.265 编解码器。 以下是如何指定 H.265 作为优先解码器的基…

elasticsearch6-RestClient操作文档

个人名片&#xff1a; 博主&#xff1a;酒徒ᝰ. 个人简介&#xff1a;沉醉在酒中&#xff0c;借着一股酒劲&#xff0c;去拼搏一个未来。 本篇励志&#xff1a;三人行&#xff0c;必有我师焉。 本项目基于B站黑马程序员Java《SpringCloud微服务技术栈》&#xff0c;SpringCloud…

数组相关面试题

1、原地移除数组中所有的元素val&#xff0c;要求时间复杂度为O(N),空间复杂度为O(1)。 OJ链接&#xff1a;27. 移除元素 - 力扣&#xff08;LeetCode&#xff09; 分析&#xff1a; 法1&#xff1a;挪到数据&#xff0c;思路如顺序表的头删&#xff0c;将后面的数据向前挪动将…

Windows【工具 04】WinSW官网使用说明及实例分享(将exe和jar注册成服务)实现服务器重启后的服务自动重启

官方Github&#xff1b;官方下载地址。没有Git加速的话很难下载&#xff0c;分享一下发布日期为2023.01.29的当前最新稳定版v2.12.0网盘连接。 包含文件&#xff1a; WinSW-x64.exesample-minimal.xmlsample-allOptions.xml 链接&#xff1a;https://pan.baidu.com/s/1sN3hL5H…

机器学习第六课--朴素贝叶斯

朴素贝叶斯广泛地应用在文本分类任务中&#xff0c;其中最为经典的场景为垃圾文本分类(如垃圾邮件分类:给定一个邮件&#xff0c;把它自动分类为垃圾或者正常邮件)。这个任务本身是属于文本分析任务&#xff0c;因为对应的数据均为文本类型&#xff0c;所以对于此类任务我们首先…

十天学完基础数据结构-第一天(绪论)

1. 数据结构的研究内容 数据结构的研究主要包括以下核心内容和目标&#xff1a; 存储和组织数据&#xff1a;数据结构研究如何高效地存储和组织数据&#xff0c;以便于访问和操作。这包括了在内存或磁盘上的数据存储方式&#xff0c;如何将数据元素组织成有序或无序的集合&…

[NLP] LLM---<训练中文LLama2(三)>对LLama2进行中文预料预训练

预训练 预训练部分可以为两个阶段&#xff1a; 第一阶段&#xff1a;冻结transformer参数&#xff0c;仅训练embedding&#xff0c;在尽量不干扰原模型的情况下适配新增的中文词向量。第二阶段&#xff1a;使用 LoRA 技术&#xff0c;为模型添加LoRA权重&#xff08;adapter&…

Android 格式化存储之Formatter

格式化存储相关的数值时&#xff0c;可以用 android.text.format.Formatter 。 Formatter.formatFileSize(Context context, long sizeBytes) 源码说明&#xff0c;在 Android O 后&#xff0c;存储单位的进制是 1000 &#xff0c;Android N 之前单位进制是 1024 。 /*** Fo…