一起深度学习24/04/30——ResNet

server/2024/9/24 10:17:06/

ResNet神经网络

  • 定义ResNet Block
  • 定义ResNet18
  • 加载数据集并训练、测试

定义ResNet Block

ResNet Block 的作用:
是一个残差块,用于构建ResNet
主要是为了解决神经网络中的梯度爆炸和梯度消失问题,以及缓解训练过程中的退化问题。
在传统的神经网络中,每层的输出会直接作为下一层的输入,可能会导致梯度在反向传播过程中逐渐减小,当层数比较深时,就可能导致梯度消失。故引入了跳跃连接,将每一层的输出与最初的x进行相加,当你对其进行求导,能发现比传统的多了一项对x的求导,也就是因为该项,避免了梯度消失的问题。

class ResBlk(nn.Module):"""resnet Block"""def __init__(self,ch_in,ch_out,stride):super(ResBlk,self).__init__()self.conv1 = nn.Conv2d(in_channels=ch_in,out_channels=ch_out,kernel_size=3,stride=stride,padding=1)print(self.conv1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(in_channels=ch_out, out_channels=ch_out, kernel_size=3, stride=1, padding=1)print(self.conv2)self.bn2 = nn.BatchNorm2d(ch_out)self.extra =nn.Sequential()#当输入通道数并不等于输出通道数的时候,进行转换。if ch_out != ch_in:self.extra = nn.Sequential(# [b,ch_in,h,w] =>[b,ch_out,h,w]nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(ch_out))def forward(self,x):""":param x: [b,ch,h,w]:return:"""out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))#shor cut# x :[b,ch_in,h,w]  而out [b,ch_out,h,w]out = self.extra(x) +out #resNet的精髓所在,能够避免过拟合,梯度爆炸,梯度消失,return out

运行测试一下:

def main():blk = ResBlk(64,128,stride=4)tmp = torch.randn(2,64,32,32)out = blk(tmp)print(out.shape)
if __name__ == '__main__':main()

在这里说明一下其中的疑惑,在做该模块的时候
blk = ResBlk(64,128,stride=4) #64是输入通道数,128表示输出通道数。
tmp = torch.randn(2,64,32,32) # 2是样本数量,64是输入通道数,32是形状。
out = blk(tmp) #将其传入到ResBlok中,进行运算。
输出为torch.Size([2, 128, 8, 8])。

定义ResNet18

class ResNet18(nn.Module):def __init__(self):super(ResNet18,self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3,64,kernel_size=3,stride=3,padding=0),nn.BatchNorm2d(64))# followed 4 blocks# [b,64,h,w] => [b,128,h,w]self.blk1 =  ResBlk(64,128,stride=2)# [b,128,h,w] => [b,256,h,w]self.blk2 = ResBlk(128,256,stride=2)# [b,256,h,w] => [b,512,h,w]self.blk3 = ResBlk(256, 512,stride=2)# [b,512,h,w] => [b,1024,h,w]self.blk4 = ResBlk(512, 512,stride=2)self.outlayer = nn.Linear(512,10)def forward(self,x):x = F.relu(self.conv1(x))x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)x = F.adaptive_avg_pool2d(x,[1,1])x = x.view(x.size(0), -1)x = self.outlayer(x)return x

加载数据集并训练、测试

import torch
import torchvision.transforms
from torch import nn, optim
from torchvision import datasets
from torch.utils.data import DataLoader
# from lenet5 import Lenet5
from learing_resnet import ResNet18
def main():batchsz = 32cifar_train= datasets.CIFAR10('data',train=True,transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()]),download=True)cifar_train = DataLoader(cifar_train,batch_size=batchsz,shuffle=True)cifar_test= datasets.CIFAR10('data',train=False,transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()]),download=True)cifar_test = DataLoader(cifar_test,batch_size=batchsz,shuffle=True)# x, label = iter(cifar_train)# print("x:",x.shape,"label:",label.shape)device  = torch.device('cuda')# model = Lenet5().to(device)model = ResNet18().to(device)criten = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(),lr=1e-3)for epoch in range(1000):for batchidx,(x,lable) in enumerate(cifar_train):x,lable = x.to(device),lable.to(device)logits = model(x)loss = criten(logits,lable)optimizer.zero_grad()loss.backward()optimizer.step()print(epoch,loss.item())total_correct = 0total_num = 0model.eval()with torch.no_grad():for x,label in cifar_test:x,label = x.to(device),label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred,label).float().sum().item()total_num += x.size(0)acc = total_correct /total_numprint(epoch,acc)if __name__ == '__main__':main()

http://www.ppmy.cn/server/37345.html

相关文章

书生·浦语大模型全链路开源开放体系

书生浦语大模型全链路开源体系InternLM2技术报告 InternLM 实战营第二期(第一节) 为了帮助社区用户高效掌握和广泛应用大模型技术,我们重磅推出书生浦语大模型实战营系列活动,旨在为开发者们提供全面而系统的大模型技术学习课程。…

智能电视遇险?海信、TCL、小米“上下求索”

众所周知,智能化已经成为各行各业的发展新方向,家电行业更是早早掀起了智能化浪潮,越来越多的智能家电产品涌现出来。在客厅场景,电视毫无疑问是众多家电中的C位,而在智能化浪潮的助推下,电视这一产品同样开…

crossover怎么打开软件 mac怎么下载steam crossover下载的软件怎么运行

CrossOver是一款Mac和Linux平台上的类虚拟机软件,通过CrossOver可以运行Windows的可执行文件。如果你是Mac用户且需要使用CrossOver,但是不知道CrossOver怎么打开软件,如果你想在Mac电脑上玩Windows游戏,但不知道怎么下载Steam&am…

vue3使用tsx/jsx时报错:JSX 元素隐式具有类型 “any“,因为不存在接口 “JSX.IntrinsicElements“。

vue3使用tsx/jsx时报错:JSX 元素隐式具有类型 "any",因为不存在接口 "JSX.IntrinsicElements"。 在项目中安装:npm install types/react npm install types/react

使用Axios从前端上传文件并且下载后端返回的文件

前端代码: function uploadAndDownload(){showLoading();const fileInput document.querySelector(#uploadFile);const file fileInput.files[0];const formData new FormData()formData.append(file, file)return new Promise((resolve, reject) > {axios({…

shell常用文件处理命令

1. 解压 1.1 tar 和 gz 文件 如果你有一个 .tar 文件,你可以使用以下命令来解压: tar -xvf your_file.tar在这个命令中,-x 表示解压缩,-v 表示详细输出(可选),-f 后面跟着要解压的文件名。 如果你的 .tar 文件同时被 gzip 压缩了(即 .tar.gz 文件),你可以使用以下…

我们该如何看待AIGC(人工智能)

引言 人工智能(AI)是当今世界科技发展的前沿领域之一,它正在以前所未有的速度和规模影响着我们的生活、工作和思考方式。AIGC,即人工智能生成内容(Artificial Intelligence Generated Content)&#xff0c…

【一刷《剑指Offer》】面试题 16:反转链表

力扣对应题目链接:206. 反转链表 - 力扣(LeetCode) 牛客对应题目链接:反转链表_牛客题霸_牛客网 (nowcoder.com) 核心考点 :链表操作,思维缜密程度。 一、《剑指 Offer》内容 二、分析题目 解题思路&#…