【深度学习】基于MXNet的多层感知机的实现

devtools/2025/2/12 6:43:54/

多层感知机

结构组成

大致由三层组成:输入层-隐藏层-输出层,其中隐藏层大于等于一层

其中,隐藏层和输出层都是全连接

隐藏层的层数和神经元个数也是超参数

多层隐藏层,在本质上仍等价于单层神经网络(可从输出方程简单推得),
但是增加网络的深度可以更加有效地提高网络对深层抽象概念的理解,降低训练难度

激活函数

目前Sigmoid函数正在被逐渐淘汰,目前仅在二分类问题上仍有用武之地

目前最主流的激活函数是ReLU函数及其变种,它使模型更加简单高效,没有梯度消失问题,对输入的敏感程度更高,迭代速度更快

具体实现

  • 完整版本
python">import d2lzh as d2l
from mxnet import nd
from mxnet.gluon import loss as gloss'''
基础准备工作
'''
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)num_inputs, num_outputs, num_hiddens = 784, 10, 256W1 = nd.random.normal(scale=0.01, shape=(num_inputs, num_hiddens))      # 形状等于 输入*输出
b1 = nd.zeros(num_hiddens)
W2 = nd.random.normal(scale=0.01, shape=(num_hiddens, num_outputs))      # 形状等于 输入*输出
b2 = nd.zeros(num_outputs)
params = [W1, W2, b1, b2]
for param in params:param.attach_grad()     # 统一申请梯度空间# 激活函数
def relu(X):return nd.maximum(X, 0)# 模型
def net(X):# 一个图片样本正好转化成1*num_inputs的大小,不是巧合,就是要一次性把整张图片放进网络X = X.reshape((-1, num_inputs)) H = relu(nd.dot(X, W1)+b1)      # 隐藏层需要应用激活函数return nd.dot(H, W2) + b2       # 输出层不需要用激活函数# 损失
loss = gloss.SoftmaxCrossEntropyLoss()'''
开始训练
'''
num_epochs, lr = 20, 0.2
d2l.train_ch3(net, test_iter, test_iter, loss, num_epochs, batch_size, params, lr)
  • 简化版本
python">import d2lzh as d2l
from mxnet import gluon, init
from mxnet.gluon import loss as gloss, nnnet = nn.Sequential()
# 添加一层256个节点的全连接层,并使用ReLU激活函数
# 再添加一层10个节点的全连接层,不使用激活函数(输出层)
net.add(nn.Dense(256, activation='relu'), nn.Dense(10))
net.initialize(init.Normal(sigma=0.01))batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)loss = gloss.SoftmaxCrossEntropyLoss()trainer = gluon.Trainer(net.collect_params(), 'sgd', {"learning_rate": 0.2})
num_epochs = 20
d2l.train_ch3(net, test_iter, test_iter, loss, num_epochs, batch_size, None, None, trainer)

实际上只简化了训练器的构建,由此也可以发现,实现一个网络的训练是一件非常简单的事情,复杂的主要是训练前后的各种处理,训练只是手段,不是目的

注意事项

尝试将隐藏层的数量改成1024,再增加训练次数,此时可以发现,模型对训练集的误差一直在缩小,但是对测试集的误差不降反增,此时发生了过拟合


http://www.ppmy.cn/devtools/157798.html

相关文章

java-异常家族梳理(流程图)

前言: 使用流程图梳理异常,便于理解 梳理: Throwable ├── Error(严重错误,无需捕获) │ ├── OutOfMemoryError │ ├── StackOverflowError │ └── ... ├── Exception(可捕获处理) │ ├── RuntimeException(非检查异常/Unchecked) │ …

SpringMVC常用的注解

Spring MVC 提供了丰富的注解,这些注解可以简化开发过程,提高开发效率,使代码结构更加清晰。以下是 Spring MVC 中一些常用注解的详细介绍: 1. 控制器相关注解 1.1 Controller 作用:用于标记一个类为控制器类&#…

26~31.ppt

目录 26.北京主要的景点 题目 解析 27.创新产品展示及说明会 题目​ 解析 28.《小企业会计准则》 题目​ 解析 29.学习型社会的学习理念 题目​ 解析 30.小王-产品展示信息 题目​ 解析 31.小王-办公理念-信息工作者的每一天 题目​ 解析 26.北京主要的景点…

C++ labmbd表达式

文章目录 C++ Lambda 表达式详解1. Lambda 表达式的组成部分:2. Lambda 语法示例(1) 最简单的 Lambda(2) 带参数的 Lambda(3) 指定返回类型的 Lambda3. 捕获外部变量(1) 值捕获(复制)(2) 引用捕获(3) 捕获所有变量4. Lambda 在 STL 中的应用5. Lambda 作为 `std::function`6…

DeepSeek R1技术报告关键解析(8/10):DeepSeek-R1 的“aha 时刻”,AI 自主学习的新突破

1. 什么是 AI 的“aha 时刻”? 在强化学习过程中,AI 的推理能力并不是线性增长的,而是会经历一些关键的“顿悟”时刻,研究人员将其称为“aha 时刻”。 这是 AI 在训练过程中突然学会了一种新的推理方式,或者能够主动…

20240206 adb 连不上手机解决办法

Step 1: lsusb 确认电脑 usb 端口能识别设备 lsusb不知道设备有没有连上,就插拔一下,对比观察多了/少了哪个设备。 Step 2: 重启 adb server sudo adb kill-serversudo adb start-serveradb devices基本上就可以了~ Reference https://b…

SQL自学,mysql从入门到精通 --- 第 14天,主键、外键的使用

1.主键 PRIMARY KEY 主键的使用 字段值不允许重复,且不允许赋NULL值 创建主键 root@mysqldb 10:11: [d1]> CREATE TABLE t3(-> name varchar(10) PRIMARY KEY,-> age int,-> class varchar(8)-> ); Query OK, 0 rows affected (0.01 sec)root@mysqldb 10:…

Google安装vue插件多种解决方案

方式1:google-极简插件 (1)通过谷歌应用商店安装 (国外网站) (2)极简插件: 下载 → 开发者模式 → 拖拽安装 → 插件详情允许访问文件 极简插件 地址: https://chrome.zzzmh.cn/in…