本篇文章是博主在AI、强化学习等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在AI学习:
AI学习笔记(5)---《LetNet、AlexNet、ResNet网络模型实现手写数字识别》
LetNet、AlexNet、ResNet网络模型实现手写数字识别
目录
前言:
1 LetNet5层网路模型
2 AlexNet网路模型
3 ResNet网路模型
前言:
本篇文章主要分享使用 LetNet、AlexNet、ResNet网络模型实现手写数字识别,项目的代码在下面的链接中:
如果CSDN下载不了的话,可以关注公众号免费获取:小趴菜只卖红薯
1 LetNet5层网路模型
1.1 代码如下:
class LetNet(torch.nn.Module):def __init__(self):super(LetNet, self).__init__()self.conv1 = torch.nn.Sequential( # 1*28*28torch.nn.Conv2d(1, 10, kernel_size=3), # 10*26*26torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2), # 10*13*13)self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(10, 20, kernel_size=2), # 10*12*12torch.nn.ReLU(),)self.conv3 = torch.nn.Sequential(torch.nn.Conv2d(20, 20, kernel_size=5), # 20*8*8torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2), # 20*4*4)self.fc = torch.nn.Sequential(torch.nn.Linear(320, 50),torch.nn.Linear(50, 10),)def forward(self, x):batch_size = x.size(0)x = self.conv1(x) # 一层卷积层,一层池化层,一层激活层(图是先卷积后激活再池化,差别不大)x = self.conv2(x) # 再来一次x = self.conv3(x) # 再来一次x = x.view(batch_size, -1) # flatten 变成全连接网络需要的输入 (batch, 20,4,4) ==> (batch,320), -1 此处自动算出的是320x = self.fc(x)return x # 最后输出的是维度为10的,也就是(对应数学符号的0~9)
1.2 测试结果:
在超参数为batch_size = 64,learning_rate = 0.01 ,momentum = 0.5 ,EPOCH = 10的情况下,随着测试样本变化,准确率变化如下图所示,准确率为98.2%,有较高的准确率。
2 AlexNet网路模型
2.1 代码如下:
class AlexNet(nn.Module):def __init__(self, width_mult=1):super(AlexNet, self).__init__()self.layer1 = nn.Sequential( # 输入1*28*28nn.Conv2d(1, 32, kernel_size=3, padding=1), # 32*28*28nn.MaxPool2d(kernel_size=2, stride=2), # 32*14*14nn.ReLU(inplace=True),)self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=1), # 64*14*14nn.MaxPool2d(kernel_size=2, stride=2), # 64*7*7nn.ReLU(inplace=True),)self.layer3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1), # 128*7*7)self.layer4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, padding=1), # 256*7*7)self.layer5 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, padding=1), # 256*7*7nn.MaxPool2d(kernel_size=3, stride=2), # 256*3*3nn.ReLU(inplace=True),)self.fc1 = nn.Linear(256 * 3 * 3, 1024)self.fc2 = nn.Linear(1024, 512)self.fc3 = nn.Linear(512, 10)def forward(self, x):x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.layer5(x)x = x.view(-1, 256 * 3 * 3)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x
2.2 测试结果:
在超参数为batch_size = 64,learning_rate = 0.01 ,momentum = 0.5 ,EPOCH = 10的情况下,随着测试样本变化,准确率变化如下图所示,准确率始终保持在99%附近,有较高的准确率。
3 ResNet网路模型
3.1 代码如下:
class Residual(nn.Module):def __init__(self, input_channels, num_channels,use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels,kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels,kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))def resnet_block(input_channels, num_channels, num_residuals,first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkb2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))ResNet = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))
3.2 测试结果:
在超参数为batch_size = 64,learning_rate = 0.01 ,momentum = 0.5 ,EPOCH = 10的情况下,随着测试样本变化,准确率变化如下图所示,准确率为99.2%,相比于LetNet、AlexNet略高,平均准确率为99.05%;但是ResNet网络层数较多,运算相对较慢。
文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。