深度学习学习经验——全连接神经网络(FCNN)

news/2024/9/18 14:55:43/ 标签: 深度学习, 学习, 神经网络

什么是全连接神经网络

全连接神经网络(FCNN)是最基础的神经网络结构,它由多个神经元组成,这些神经元按照层级顺序连接在一起。每一层的每个神经元都与前一层的每个神经元连接。

想象你在参加一个盛大的晚会,晚会上有三个区域:接待区交流区结果区

  • 接待区(输入层):负责接收来宾(数据),每个来宾代表一个特征。
  • 交流区(隐藏层):每个来宾在交流区与其他来宾交流,交换信息。这里的每个交流区的来宾(神经元)都会与其他来宾进行对话,以获得更深层次的理解。
  • 结果区(输出层):最后,在结果区,来宾们会得出他们的总结(预测结果),然后将其提供给晚会的组织者(输出)。

神经网络的结构

输入层(Input Layer)
  • 功能:接受原始数据。每个神经元代表一个特征,例如图片的像素值、语音信号的特征等。
  • 示例:如果你有一张28x28像素的灰度图片,那么输入层会有784个神经元(28x28=784),每个神经元代表一个像素值。
隐藏层(Hidden Layer)
  • 功能:处理和提取数据特征。每个神经元通过加权和激活函数来处理输入数据,然后将结果传递到下一层。
  • 示例:隐藏层的神经元数目可以是任意的,例如100个神经元。隐藏层能够提取数据的复杂特征,如图像中的边缘或形状。
输出层(Output Layer)
  • 功能:给出最终的预测结果。输出层的神经元数目等于任务的类别数。例如,在数字分类任务中,输出层有10个神经元(分别代表0到9这10个数字)。
  • 示例:如果你要识别图片中的数字(0-9),输出层的每个神经元会输出一个数字的概率。
激活函数(Activation Function)
  • 功能:引入非线性,使得神经网络能够处理复杂的模式。激活函数决定了神经元是否激活。
  • 常见激活函数
    • ReLU(Rectified Linear Unit):f(x) = max(0, x)。用于隐藏层,能够引入非线性,提升模型表现。
    • Sigmoidf(x) = 1 / (1 + exp(-x))。用于输出层,特别是在二分类任务中。

下面为代码案例

我们使用MNIST数据集来训练一个简单的全连接神经网络模型。MNIST数据集包含了手写数字的图像,每个图像的大小为28x28像素,共10个类别(0-9)。

如果在下载 MNIST 数据集时遇到了问题。尝试手动下载数据集(经尝试官方的似乎会被墙,说没有权限下载

手动下载数据集

可以手动下载 MNIST 数据集,并将其放到合适的目录下。

  1. 下载数据集

    • MNIST 数据集 主页提供了所有必要的文件。可以直接从这里下载:
      • train-images-idx3-ubyte.gz
      • train-labels-idx1-ubyte.gz
      • t10k-images-idx3-ubyte.gz
      • t10k-labels-idx1-ubyte.gz
    • 如果官方的渠道无法下载 可以从我分享的 MNIST数据集下载
  2. 解压文件

    • 下载后,可以使用工具解压这些 .gz 文件。例如,在终端中运行:
      gunzip train-images-idx3-ubyte.gz
      gunzip train-labels-idx1-ubyte.gz
      gunzip t10k-images-idx3-ubyte.gz
      gunzip t10k-labels-idx1-ubyte.gz
      
  3. 移动文件

    • 将解压后的文件放到项目目录下的 ./data/MNIST/raw/ 目录中。

1. 数据预处理

首先,我们需要对 MNIST 数据集进行预处理,将图像转换为 Tensor,并进行归一化处理。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 1. 数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensor,并归一化到[0, 1]transforms.Normalize((0.5,), (0.5,))  # 标准化
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

讲解

  • 使用 torchvision.transforms 对 MNIST 数据集进行转换,将图像转换为 Tensor,并归一化到 [0, 1] 范围。
  • 将数据集加载到 DataLoader 中,设置批量大小和是否打乱数据。

2. 定义全连接神经网络模型

接下来,定义一个简单的全连接神经网络模型。这个模型包括三个全连接层和两个激活函数(ReLU 和 Softmax)。

# 2. 定义全连接神经网络模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28*28, 128)  # 第一层:输入层到隐藏层self.fc2 = nn.Linear(128, 64)     # 第二层:隐藏层到隐藏层self.fc3 = nn.Linear(64, 10)      # 第三层:隐藏层到输出层self.relu = nn.ReLU()             # ReLU激活函数self.softmax = nn.Softmax(dim=1)  # Softmax激活函数,用于输出层def forward(self, x):x = x.view(-1, 28*28)  # 将每张图片展平为一维向量x = self.relu(self.fc1(x))  # 第一层到ReLUx = self.relu(self.fc2(x))  # 第二层到ReLUx = self.fc3(x)            # 第三层,输出层return self.softmax(x)     # Softmax输出model = SimpleNN()

讲解

  • 定义了一个名为 SimpleNN神经网络类,继承自 nn.Module
  • __init__ 方法中定义了三个全连接层(fc1, fc2, fc3)和两个激活函数(ReLUSoftmax)。
  • forward 方法中定义了数据如何流经网络层,包括展平、激活函数和最终的 Softmax 输出。

3. 定义损失函数和优化器

然后,我们定义损失函数和优化器。这里使用交叉熵损失函数和 Adam 优化器。

# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器

讲解

  • 使用 nn.CrossEntropyLoss 作为损失函数,这适用于分类任务。
  • 使用 optim.Adam 作为优化器来更新网络参数,设置学习率为 0.001。

4. 训练模型

接下来,我们定义一个函数来训练模型。在每个 epoch 中,模型会遍历训练数据,进行前向传播、计算损失、进行反向传播并更新参数。

# 4. 训练模型
def train(model, criterion, optimizer, train_loader, epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")

讲解

  • train 函数通过指定的 epochs 训练模型。每个 epoch 中,模型会遍历训练数据,进行前向传播、计算损失、进行反向传播并更新参数。
  • 每个 epoch 结束后,输出当前的平均损失。

5. 评估模型

最后,我们定义一个函数来评估模型的准确性。在测试数据上评估模型的表现。

# 5. 评估模型
def evaluate(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint(f"Test Accuracy: {accuracy:.4f}")

讲解

  • evaluate 函数在测试数据上评估模型的准确率。模型进入评估模式,不计算梯度,直接进行预测并计算准确率。
  • 输出模型在测试数据上的准确性。

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 1. 数据预处理
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensor,并归一化到[0, 1]transforms.Normalize((0.5,), (0.5,))  # 标准化
])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)# 2. 定义全连接神经网络模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28*28, 128)  # 第一层:输入层到隐藏层self.fc2 = nn.Linear(128, 64)     # 第二层:隐藏层到隐藏层self.fc3 = nn.Linear(64, 10)      # 第三层:隐藏层到输出层self.relu = nn.ReLU()             # ReLU激活函数self.softmax = nn.Softmax(dim=1)  # Softmax激活函数,用于输出层def forward(self, x):x = x.view(-1, 28*28)  # 将每张图片展平为一维向量x = self.relu(self.fc1(x))  # 第一层到ReLUx = self.relu(self.fc2(x))  # 第二层到ReLUx = self.fc3(x)            # 第三层,输出层return self.softmax(x)     # Softmax输出model = SimpleNN()# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器# 4. 训练模型
def train(model, criterion, optimizer, train_loader, epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()  # 清空梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}")# 5. 评估模型
def evaluate(model, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint(f"Test Accuracy: {accuracy:.4f}")# 执行训练和评估
train(model, criterion, optimizer, train_loader)
evaluate(model, test_loader)

这个完整的代码示例展示了如何使用 PyTorch 构建、训练和评估一个简单的全连接神经网络模型。

总结

全连接神经网络是最基础的神经网络结构,通过输入层接收数据,通过隐藏层进行特征提取和学习,最后通过输出层给出预测结果。激活函数为网络引入非线性,使其能够学习和处理复杂的模式。通过逐步构建和训练模型,我们可以解决各种数据分类和回归问题。


http://www.ppmy.cn/news/1516713.html

相关文章

Vue中的this.$emit()方法详解【父子组件传值常用】

​在Vue中,this.$emit()方法用于触发自定义事件。它是Vue实例的一个方法,可以在组件内部使用。 使用this.$emit()方法,你可以向父组件发送自定义事件,并传递数据给父组件。父组件可以通过监听这个自定义事件来执行相应的逻辑。 …

问界M7 Pro这招太狠了,直击理想L6/L7要害

文 | AUTO芯球 作者 | 雷慢 李想的理想估计要失眠了,为什么啊? 前有L6悬架薄如铁片被曝光,被车主们骂了个狗血淋头, 现在又来个问界M7 Pro版, 24.98万的后驱智驾版就上华为ADS主视觉智驾了, 两个后驱&…

TMDOG的微服务之路_07——初入微服务,NestJS微服务快速入门

TMDOG的微服务之路_07——初入微服务,NestJS微服务快速入门 博客地址:TMDOG的博客 在前几篇博客中,我们探讨了如何在 NestJS 中的一些基础功能,并可以使用NestJS实现一个简单的单体架构后端应用。本篇博客,我们将进入…

基于改进YOLOv8的景区行人检测算法

贵向泉, 刘世清, 李立, 秦庆松, 李唐艳. 基于改进YOLOv8的景区行人检测算法[J]. 计算机工程, 2024, 50(7): 342-351. DOI: 10.19678/j.issn.10 原文链接如下:基于改进YOLOv8的景区行人检测算法https://www.ecice06.com/CN/rich_html/10.19678/j.issn.1000-3428.006…

解决Element-plus中Carousel(走马灯)图片无法正常加载的bug

前言&#xff1a; 最近帮助朋友解决了一个使用Element-plus中Carousel&#xff08;走马灯&#xff09;图片无法正常加载的bug&#xff0c;经过笔者的不断努力终于实现了&#xff0c;现在跟大家分享一下&#xff1a; 朋友原来的代码是这样的&#xff1a; <template><…

【计算机网络】电路交换、报文交换、分组交换

电路交换&#xff08;Circuit Switching&#xff09;&#xff1a;通过物理线路的连接&#xff0c;动态地分配传输线路资源 ​​​​

依靠 VPN 生存——探索 VPN 后利用技术

执行摘要 在这篇博文中,Akamai 研究人员强调了被忽视的 VPN 后利用威胁;也就是说,我们讨论了威胁行为者在入侵 VPN 服务器后可以用来进一步升级入侵的技术。 我们的发现包括影响 Ivanti Connect Secure 和 FortiGate VPN 的几个漏洞。 除了漏洞之外,我们还详细介绍了一组…

SpringBoot集成kafka-获取生产者发送的消息(阻塞式和非阻塞式获取)

说明 CompletableFuture对象需要的SpringBoot版本为3.X.X以上&#xff0c;需要的kafka依赖版本为3.X.X以上&#xff0c;需要的jdk版本17以上。 1、阻塞式&#xff08;等待式&#xff09;获取生产者发送的消息 生产者&#xff1a; package com.power.producer;import org.ap…

Linux的进程详解(进程创建函数fork和vfork的区别,资源回收函数wait,进程的状态(孤儿进程,僵尸进程),加载进程函数popen)

目录 什么是进程 Linux下操作进程的相关命令 进程的状态&#xff08;生老病死&#xff09; 创建进程系统api介绍&#xff1a; fork() 父进程和子进程的区别 vfork() 进程的状态补充&#xff1a; 孤儿进程 僵尸进程 回收进程资源api介绍&#xff1a; wait() waitpid…

VastBase——全局性能调优

目录 一、系统资源调优 1.内存和CPU 2.网络 3.I/O 二、查询最耗性能的SQL 三、分析作业是否被阻塞 背景&#xff1a;影响性能的因素 系统资源 数据库性能在很大程度上依赖于磁盘的I/O和内存使用情况。为了准确设置性能指标&#xff0c;用户需要了解Vastbase部署硬件的基本…

深信服研发面试经验分享

吉祥知识星球http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247485367&idx1&sn837891059c360ad60db7e9ac980a3321&chksmc0e47eebf793f7fdb8fcd7eed8ce29160cf79ba303b59858ba3a6660c6dac536774afb2a6330#rd 《网安面试指南》http://mp.weixin.qq.com/s?…

在Spring Boot项目中集成Geth(Go Ethereum)

在Spring Boot项目中集成Geth&#xff08;Go Ethereum&#xff09;客户端&#xff0c;通常是为了与以太坊区块链进行交互。以下是一些基本的步骤和考虑因素&#xff0c;帮助你在Spring Boot应用程序中集成Geth。 安装Geth 首先&#xff0c;你需要在你的机器上安装Geth。你可以从…

k8s备份etcd3.5

一、思路 1、创建nfs存储类,用作存储备份数据<略> 2、制作用于备份的镜像文件 3、指定cronjob 二、制作镜像 ## dockerfile文件# cat Dockerfile FROM dhub.kubesre.xyz/centos:7 ADD etcdv359.tar / RUN mkdir /snapshot# docker build -t registry.k8s.io/etcd:3.…

ST表模板

P3865 【模板】ST 表 && RMQ 问题 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 思路:区间最大值&#xff0c;模板题。 int n,m; int arr[100005]; int f[100005][25]; (1<<20)1e6 void init(){ o(nlogn)for(int i1;i<…

爆改YOLOv8|利用SENetV2改进yolov8,暴力涨点

1&#xff0c;本文介绍 本文探讨了将 SENetV2 的稠密聚合层与 SE 模块结合&#xff0c;应用于 YOLOv8&#xff0c;以提升特征表达能力和目标检测性能。SENetV2 通过 Squeeze-and-Excitation&#xff08;SE&#xff09;模块优化通道和全局特征&#xff0c;从而提高分类准确率。…

UE5.4内容示例(5)UI_CommonUI - 学习笔记

https://www.unrealengine.com/marketplace/zh-CN/product/content-examples 《内容示例》是学习UE5的基础示例&#xff0c;可以用此熟悉一遍UE5的功能 UI_CommonUI可以看这个视频学习&#xff0c;此插件处于Beta状态&#xff0c;应用UI游戏方面&#xff0c;支持手柄等多输入端…

sap 开发工具 jdbc odbc 驱动 下载地址

SAP Development Tools (ondemand.com) sap 开发工具 jdbc odbc 驱动 下载地址

【系统架构设计师-2018年】综合知识-答案及详解

文章目录 【第1题】【第2~3题】【第4题】【第5~6题】【第7题】【第8题】【第9题】【第10题】【第11题】【第12题】【第13题】【第14题】【第15题】【第16~17题】【第18~21题】【第22题】【第23题】【第24题】【第25题】【第26题】【第27~28题】【第29~30题】【第31题】【第32~3…

学习前端面试知识(16)

computed和watch 参考文章vue computed 计算属性&#xff0c;有缓存功能&#xff0c;底层通过dirty来判断是否重新计算&#xff0c;只有在依赖数据发生变化时才会重新计算&#xff0c;性能更好。不能进行异步操作。缓存属性受多个属性影响&#xff0c;比如购物车商品结算函数…

OSPF-基础多区域实验

1.ENSP下载 阿里云盘分享 ⭐/*无需密钥 免费下载 安装不成功&#xff0c;可关注并私信博主*/ 2.OSPF的基础需求和规则 实验规则&#xff1a; 1.接口地址→XY.XY.XY.R /24 X:两者之间最小的 Y:两者之间最大的 R:谁的接口就是谁的编号 以R1和R2之间的连接为例&#xff0…