第四章 ResNet网络详解

news/2025/2/1 9:58:16/

系列文章目录

第一章 AlexNet网络详解

第二章 VGG网络详解

第三章 GoogLeNet网络详解 

第四章 ResNet网络详解 

第五章 ResNeXt网络详解 

第六章 MobileNetv1网络详解 

第七章 MobileNetv2网络详解 

第八章 MobileNetv3网络详解 

第九章 ShuffleNetv1网络详解 

第十章 ShuffleNetv2网络详解 

第十一章 EfficientNetv1网络详解 

第十二章 EfficientNetv2网络详解 

第十三章 Transformer注意力机制

第十四章 Vision Transformer网络详解 

第十五章 Swin-Transformer网络详解 

第十六章 ConvNeXt网络详解 

第十七章 RepVGG网络详解 

第十八章 MobileViT网络详解 

 


文章目录

  • ResNet网络详解
  • 0. 前言
  • 1. 摘要
  • 2. ResNet网络详解网络架构
    • 1. ResNet_Model.py(pytorch实现)
    • 2.
  • 总结


0、前言


1、摘要

       更深的神经网络更难训练。我们提出了一个残差学习框架,以便训练比以前使用的网络更深。我们将层明确重组为以层输入为参考的学习残差函数,而不是学习无关的函数。我们提供了全面的实证证据,表明这些残差网络更容易优化,并且可以通过大幅增加深度获得更高的准确性。在ImageNet数据集上,我们评估了深度达152层的残差网络,比VGG网络[40]深8倍,但复杂度仍然较低。这些残差网络的组合在ImageNet测试集上实现3.57%的错误率。这个结果赢得了ILSVRC 2015分类任务的第一名。我们还对具有100和1000层的CIFAR-10进行了分析。表示的深度对于许多视觉识别任务具有重要意义。仅仅由于我们极其深的表示,我们在COCO物体检测数据集上获得了28%的相对改进。深层残差网络是我们提交给ILSVRC和COCO 2015比赛的基础,其中我们还赢得了ImageNet检测、ImageNet定位、COCO检测和COCO分割任务的第一名。

2、ResNet网络结构

1.本文介绍了一种深度神经网络的学习框架,使得更深层次的网络更容易训练和优化。

2.本文的研究背景是深度神经网络在训练和优化方面的困难。

3.本文的主要论点是,通过引入残差学习框架,可以更轻松地训练和优化深度神经网络,并在多个视觉识别任务中获得更高的准确率。

4.过去的研究主要采用传统的前向传播方法来训练和优化神经网络,但在深度增加时很难解决梯度消失和梯度爆炸的问题。这些方法也容易导致网络过拟合和训练缓慢。

5.本文提出的方法是引入残差学习框架,重定义网络层的学习方式为学习残差函数,从而实现更深层次网络的训练和优化。

6.研究发现,在多个视觉识别任务中,通过残差学习框架训练的深度神经网络可以获得更高的准确率。该方法在2015年的多个比赛中获得了第一名的好成绩。但是,由于本文研究主要关注残差学习框架在视觉识别任务中的应用,因此在其他领域的适用性还需要进一步探讨。

ResNet网络解决了深度神经网络训练过程中的梯度消失和梯度爆炸问题。梯度消失问题是由于当神经网络过深时,反向传播算法中的梯度值会变得非常小,可能会接近于0。这使得神经网络无法更新权重,从而无法继续学习。而梯度爆炸则是相反的问题,即梯度值过大,导致权重更新过快,从而使得网络失去稳定性。 ResNet的创新点是通过引入残差连接(residual connections)解决了梯度消失和梯度爆炸问题。残差连接直接连接了网络中前一层和后一层的输出,使得后一层不仅学习新的特征,还能保留前一层学习到的特征。这种方法使得神经网络的训练变得更加稳定,深度网络也变得更易于训练。此外,ResNet还创新性地使用了1x1卷积来降低特征图的维度,从而减少了模型参数,提高了效率。

1.ResNet_Model.py(pytorch实现)

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self,num_classes=1000,init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(48, 128, kernel_size=5, stride=1, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(128, 192, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3,stride=2))self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128*6*6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes))if init_weights:self._initialize_weights()def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return xdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

2.train.py

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = 'D:/100_DataSets/'image_path = os.path.join(data_root, "03_flower_data")assert os.path.exists(image_path), "{} path does not exits.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform = data_transform['train'])train_num = len(train_dataset)flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())json_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 6nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])print('Using {} dataloder workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform['val'])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=4,shuffle=False,num_workers=nw)print("using {} image for train, {} images for validation.".format(train_num, val_num))net = AlexNet(num_classes=5, init_weights=True)net.to(device)loss_fuction = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0002)epochs = 10save_path = './AlexNet.pth'best_acc = 0.0train_steps = len(train_loader)for epoch in range(epochs):net.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_fuction(outputs, labels.to(device))loss.backward()optimizer.step()running_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:,.3f}".format(epoch+1, epochs, loss)net.eval()acc = 0.0with torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch % d] train_loss: %.3f val_accuracy: %.3f' %(epoch+1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(),save_path)print("Finished Training")if __name__ == '__main__':main()

3.predict.py

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNetdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])img_path = "D:/20_Models/01_AlexNet_pytorch/image_predict/tulip.jpg"assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)img = data_transform(img)img = torch.unsqueeze(img, dim=0)json_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)with open(json_path,"r") as f:class_indict = json.load(f)model = AlexNet(num_classes=5).to(device)weights_path = "./AlexNet.pth"assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path))model.eval()with torch.no_grad():output = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {} prob: {:.3f}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)for i in range(len(predict)):print("class: {:10} prob: {:.3}".format(class_indict[str(i)],predict[i].numpy()))plt.show()if __name__ == '__main__':main()

4.predict.py

import os
from shutil import copy, rmtree
import randomdef mk_file(file_path: str):if os.path.exists(file_path):rmtree(file_path)os.makedirs(file_path)def main():random.seed(0)split_rate = 0.1#cwd = os.getcwd()#data_root = os.path.join(cwd, "flower_data")data_root = 'D:/100_DataSets/03_flower_data'origin_flower_path = os.path.join(data_root, "flower_photos")assert os.path.exists(origin_flower_path), "path '{}' does not exist".format(origin_flower_path)flower_class = [cla for cla in os.listdir(origin_flower_path) if os.path.isdir(os.path.join(origin_flower_path, cla))]train_root = os.path.join(data_root,"train")mk_file(train_root)for cla in flower_class:mk_file(os.path.join(train_root, cla))val_root = os.path.join(data_root, "val")mk_file(val_root)for cla in flower_class:mk_file(os.path.join(val_root,cla))for cla in flower_class:cla_path = os.path.join(origin_flower_path,cla)images = os.listdir(cla_path)num = len(images)eval_index = random.sample(images, k=int(num*split_rate))for index, image in enumerate(images):if image in eval_index:image_path = os.path.join(cla_path, image)new_path = os.path.join(val_root, cla)copy(image_path, new_path)else:image_path = os.path.join(cla_path, image)new_path = os.path.join(train_root, cla)copy(image_path, new_path)print("\r[{}] processing [{} / {}]".format(cla, index+1, num), end="")print()print("processing done!")if __name__ == "__main__":main()

总结

提示:这里对文章进行总结:

每天一个网络,网络的学习往往具有连贯性,新的网络往往是基于旧的网络进行不断改进。


http://www.ppmy.cn/news/510562.html

相关文章

VIAVI MTS-5800手持式网络测试仪支持10G的万兆网络测试

MTS-5800 手持式网络测试仪是网络技术人员和工程师在安装、开通和维护网络时所必需的一种工具。它支持各种传统技术和新兴技术,可自如应对各种各样的网络应用,包括城域/核心、移动回传和企业服务网络。业界最小的手持式仪表,可在整个业务周期…

什么软件可以用蓝牙测试信号,litepoint IQview蓝牙测试仪/无线wifi网络信号测试仪...

无线测试系统 LitePoint IQview提供强大工具,能取得并测量WiFi及蓝牙讯号,并以分析功能提供装置运作说明,简化了WiFi与Bluetooth装置的开发过程。IQview与IQflex製造解决方桉完全相容,使得以IQview开发的测试软体也能在已生产线使…

网络综合测试仪

1.都有什么功能,他都可以测什么数据。 这里我们以TFN TT70 FT100智能网络测试平为例子,讲一下台提供全方位的通信技术连接和服务测试功能,可选,支持OTN, SDH/SONET, MSTP, PDH/DSn, PTN/IP RAN, SyncE, IEEE1588v2 PTP, OTDR, 以…

无线网络测试

进行无线网络测试,测试工具netstumbler , 有时候由于无线网络广播太大,测试软件无法下常运行,都遇到过好几次。具体现场是:软件打开后,无法显示或都无线测试时软件图形界面有无数的断线状。以前认为是软件问…

WiFi以及天线测试项目详解

1.相关术语: 天线增益 天线增益就是某天线在最大辐射方向上的辐射能量跟点源天线(dBi)或偶极子天线(dBd)在同方向上的辐射能量的比值. 天线规格书的几个参数 Gain(dBi):在相同的输入功率下,天线在空间某点的辐射功率…

iq2010wifi测试软件,【IQ2010 WiFi综合测试仪 无线网络分析仪】价格_厂家 - 中国供应商...

itepoint iq2010无线设备测试仪 产品名称: IQ2010测试系统 产品规格: WiFi/BT/GPS/NFC 型号: IQ2010 产品类别: 无线通信测试仪->LitePoint 仪器介绍 Litepoint的IQ2010特别为降低多重通讯产品的测试成本而设计,是一个***兼容多…

无线网络工程建设、升级及维护要注意什么——TFN-100A系列天馈线频谱测试仪

一、无线网络 随着科学技术的不断提高,无线通信网络技术快速发展。无线网络的建设、升级及维护是无线网络工程中的重点,无线传输网络的质量保障是重中之重。 天馈线测试仪和频谱分析仪是无线网络维护中常用的仪表。射频网络的发射和接收信号都是通过天馈…

无线网络安全测试软件

看来还是有督促作用的,以前多次想搞清楚的问题最近思路也清晰了很多,先总结一下软件相关的。也不知道Csdn_Blog以后能不能支持本地备份,不然以后无法登陆不是惨了。 扫描类NetStumblerNetStumbler是WINDOWS下面基于802.11的无线网络扫描工具…