深度学习02 神经网络实现手写数字案例

ops/2025/2/21 4:24:27/

目录

下载手写数字图像(图像+标签)

展示手写数字图片

数据打包

判断当前设备是否支持GPU

建立神经网络模型

设置训练集与测试集

创建损失函数、优化器

开始训练


下载手写数字图像(图像+标签)
python">training_data=datasets.MNIST(root='data',train=True,download=True,transform=ToTensor(),
)
test_data=datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),
)
展示手写数字图片
python">from matplotlib import pyplot as plt
figure=plt.figure()
for i in range(16):img,label=training_data[i+59000]figure.add_subplot(4,4,i+1)plt.title(label)plt.axis('off')plt.imshow(img.squeeze(),cmap='gray')a=img.squeeze()
plt.show()

数据打包
python">train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)

判断当前设备是否支持GPU
python">device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(device)
建立神经网络模型
python">class  Neturalwork(nn.Module):def __init__(self):super().__init__()self.flatten=nn.Flatten()self.hidden1=nn.Linear(28*28,128)self.hidden2=nn.Linear(128,256)self.out=nn.Linear(256,10)def forward(self,x):x=self.flatten(x)x=self.hidden1(x)x=torch.sigmoid(x)x=self.hidden2(x)x=torch.sigmoid(x)x=self.out(x)return x
​
model=Neturalwork().to(device)
print(model)

 

设置训练集与测试集
python">def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num=1for x,y in dataloader:x,y=x.to(device),y.to(device)pred=model.forward(x)loss=loss_fn(pred,y)
​optimizer.zero_grad()loss.backward()optimizer.step()loss_value=loss.item()if batch_size_num%100==0:print(f'loss:{loss_value:>7f} [number:{batch_size_num}]')batch_size_num+=1
​def test(dataloader,model,loss_fn):size=len(dataloader.dataset)num_batches=len(dataloader)model.eval()test_loss,correct=0,0with torch.no_grad():for x,y in dataloader:x,y=x.to(device),y.to(device)pred=model.forward(x)test_loss+=loss_fn(pred,y).item()correct+=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)b=(pred.argmax(1)==y).type(torch.float)test_loss/=num_batchescorrect/=sizeprint(f'Test result:\n Accuracy:{(100*correct)}%,Avg loss:{test_loss}')
创建损失函数、优化器
python">loss_fn=nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 学习率可以根据需要调整
开始训练
python">epochs=15
for t in range(epochs):print(f'EPOCH {t+1}\n-----------')train(train_dataloader,model,loss_fn,optimizer)
print('结束')
test(test_dataloader,model,loss_fn)


http://www.ppmy.cn/ops/160142.html

相关文章

瑞萨RA-T系列芯片ADCGPT功能模块的配合使用

在马达或电源工程中,往往需要采集多路AD信号,且这些信号的优先级和采样时机不相同。本篇介绍在使用RA-T系列芯片建立马达或电源工程时,如何根据需求来设置主要功能模块ADC&GPT,包括采样通道打包和分组,GPT触发启动…

审计级别未启用扩展模式导致查询 DBA_AUDIT_TRAIL 时 SQL_TEXT 列为空

如果查询 DBA_AUDIT_TRAIL 时发现 SQL_TEXT 列为空,但其他字段(如 OS_USERNAME、USERNAME、TIMESTAMP 等)有数据,可能是由于以下原因之一。以下是可能的原因及解决方法: 1. 审计级别未启用扩展模式 默认情况下&#x…

青少年编程与数学 02-009 Django 5 Web 编程 17课题、中间件

青少年编程与数学 02-009 Django 5 Web 编程 17课题、中间件 一、中间件中间件的特点中间件的作用 二、应用场景1. **消息传递**2. **数据库管理**3. **负载均衡**4. **缓存**5. **身份验证和授权**6. **企业应用集成**7. **云计算平台** 三、Django中的中间件中间件的工作原理…

二级指针略解【C语言】

以int** a为例 1.二级指针的声明 a 是一个指向 int*(指向整型的指针)的指针,即二级指针。 通俗的讲,a是一个指向指针的指针,对a解引用会是一个指针。 它可以用于操作动态分配的二维数组、指针数组或需要间接修改指针…

macOS部署DeepSeek-r1

好奇,跟着网友们的操作试了一下 网上方案很多,主要参考的是这篇 DeepSeek 接入 PyCharm,轻松助力编程_pycharm deepseek-CSDN博客 方案是:PyCharm CodeGPT插件 DeepSeek-r1:1.5b 假设已经安装好了PyCharm PyCharm: the Pyth…

力扣-二叉树-501 二叉搜索树的众数

思路 二叉搜索树的特性就是中序遍历有序&#xff0c;所以思考时可以先按照有序数组思考 代码 class Solution { public:vector<int> result;TreeNode* pre nullptr;int count 1;int maxCount 0;void travesl(TreeNode* node){if(node nullptr) return;travesl(nod…

Linux系统:ubuntu系统几款系统界面化工具的区别(xubuntu-desktop和ubuntu-desktop和kubuntu-desktop)

‌Xubuntu-desktop、Ubuntu-desktop和Kubuntu-desktop的主要区别在于它们使用的桌面环境、系统资源占用以及用户界面设计。 主要区别 桌面环境系统资源占用用户界面设计 桌面环境 ‌Xubuntu-desktop‌&#xff1a;使用Xfce桌面环境&#xff0c;这是一个轻量级的桌面环境&#…

从 0 到 1:Spring Boot 构建高效应用指南

感兴趣的可以先收藏起来&#xff0c;还有大家在毕设选题&#xff0c;项目以及论文编写等相关问题都可以给我留言咨询&#xff0c;我会一一回复&#xff0c;希望帮助更多的人。 在当下竞争激烈且技术飞速发展的 Java 开发领域&#xff0c;各类框架层出不穷。而 Spring Boot 凭借…