pytorch-手写数字识别之全连接层实现

news/2024/9/23 8:54:00/

目录

  • 1. 背景
  • 2. nn.Linear线性层
  • 2. 实现MLP网络
  • 3. train
  • 4. 完整代码

1. 背景

上一篇https://blog.csdn.net/wyw0000/article/details/137622977?spm=1001.2014.3001.5502中实现手撸代码的方式实现了手写数字识别,本文将使用pytorch的API实现。

2. nn.Linear线性层

相当于x = x@w1.t() + b1
因此可以使用nn.Linear代替,上一篇文中中的

python">w1, b1 = torch.randn(200, 784, requires_grad=True),\torch.zeros(200, requires_grad=True)
x = x@w1.t() + b1

使用nn.Linear定义的三层网络,如下图所示:
在这里插入图片描述
增加激活函数relu
在这里插入图片描述

2. 实现MLP网络

  • 实现__init__函数
  • 将网络各层放到序列化容器Sequential中
    见下图使用nn.Sequential创建一个小的model,运行时输入首先传给nn.Linear(784, 200),nn.Linear(784, 200)的输出再传给nn.ReLU(inplace=True),这样依次传递下去,直至结束。
  • 实现forward
    调用将输入x作为model的参数调用model,并将结果返回。
    在这里插入图片描述

3. train

在这里插入图片描述

4. 完整代码

python">import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.ReLU(inplace=True),nn.Linear(200, 200),nn.ReLU(inplace=True),nn.Linear(200, 10),nn.ReLU(inplace=True),)def forward(self, x):x = self.model(x)return xnet = MLP()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)logits = net(data)test_loss += criteon(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

http://www.ppmy.cn/news/1426050.html

相关文章

T31开发笔记: 移动侦测

若该文为原创文章,转载请注明原文出处。 最近在测试创安源IPC时发现摄像头的视频流有移动侦测功能 ,拆解后发现使用的是T31,刚好手头上有淘宝买50多点的T31摄像头,就自己现在了个简易DEMO测试一下。 一、硬件和开发环境 1、硬件:…

go并发编程以及socket通信的理解

go并发编程以及socket通信的理解 文章目录 go并发编程以及socket通信的理解一、管道的简单使用二、go中的socket实现通信 一、管道的简单使用 " golang不是通过共享内存来通信,而是通过通信来共享内存 " 1、go简单初始化 // golang不是通过共享内存来通…

机器学习方法在测井解释上的应用-以岩性分类为例

机器学习在测井解释上的应用越来越广泛,主要用于提高油气勘探和开发的效率和精度。通过使用机器学习算法,可以从测井数据中自动识别地质特征,预测岩石物理性质,以及优化油气储层的评估和管理。 以下是机器学习在测井解释中的一些…

最新AI创作系统ChatGPT网站源码AI绘画,GPTs,AI换脸支持,GPT联网提问、DALL-E3文生图

一、前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,那么如何搭建部署AI创作ChatGPT?小编这里写一个详细图文教程吧。已支持GPT…

python爬虫之爬取携程景点评价(5)

一、景点部分评价爬取 【携程攻略】携程旅游攻略,自助游,自驾游,出游,自由行攻略指南 (ctrip.com) import requests from bs4 import BeautifulSoupif __name__ __main__:url https://m.ctrip.com/webapp/you/commentWeb/commentList?seo0&businessId22176&busines…

阿药陪你学Java(第三讲)

第三讲:输入和输出 Java中的输入和输出是用户可以直观的和程序进行交互的一种方式。 所谓输出就是程序打印信息,比如打印计算结果到控制台显示;输入就是用户手动提供给程序某些数据,比如根据程序提示从控制台输入数字或字符串等…

js-pytorch:开启前端+AI新世界

嗨, 大家好, 我是 徐小夕。最近在 github 上发现一款非常有意思的框架—— js-pytorch。它可以让前端轻松使用 javascript 来运行深度学习框架。作为一名资深前端技术玩家, 今天就和大家分享一下这款框架。 往期精彩 Nocode/Doc,可…

Java作业6-Java类的基本概念三

编程1 import java.util.*;abstract class Rodent//抽象类 {public abstract String findFood();//抽象方法public abstract String chewFood(); } class Mouse extends Rodent {public String findFood(){ return "大米"; }public String chewFood(){ return "…