MLP 多层感知机

devtools/2024/10/18 12:28:53/

为了拟合更特殊的函数,在网络中加入多个隐藏层,克服线性的限制。最后一层可以看作线性predictor。

一、

1.最简单流程

输入x矩阵,含有n个样本,每个样本有d个特征。经过隐藏层H将维度转化为h,在经过最后的输出层O将维度转化为q。

2.当我们添加了多个隐藏层时,如果只是对上一层的输出做一个简单的映射,可以发现:

合并隐藏层后其实等价于单层的模型。

所以,需要在每个隐藏单元输出应用激活函数σ(常用包括 relu 0~1,sigmoid 0~1,tanh -1~1),这样就避免了上述的退化情况。

3.如果是全连接的网络,每个神经元都依赖于所有输入的值。所以理论上只有一个隐藏层也可以通过足够的神经元和权重,拟合任意函数。

不过,使用更深(而不是更广)的网络,可以更容易的拟合函数。

4.代码实现

· 初始化w、b

· def relu(X):

    a = torch.zeros_like(X) # 创建一个与X形状相同且元素全为0的张量

    return torch.max(X, a)

· def net(X):

    X = X.reshape((-1, num_inputs)) # 将每张图片都拉平成一个一维的向量

    H = relu(X@W1 + b1)  # 这里“@”代表矩阵乘法

    return (H@W2 + b2)

· loss = nn.CrossEntropyLoss(reduction='none')

· updater = torch.optim.SGD(params, lr=lr)

· d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

5.总结

· 对于相同的分类问题,多层感知机的实现与softmax回归的实现相同,只是多层感知机的实现里增加了带有激活函数的隐藏层。

· 不同的层数、激活函数、权重,都会影响模型acc。

二、过拟合 欠拟合

1.在监督学习中(有监督学习指的是 我们知道每个样本的结果 如回归,无监督学习指的是 不知道/没有样本的结果 如聚类降维),我们假设train set和test set是独立同分布的。

2.介绍几个倾向于影响模型泛化的因素

· 可调整参数的数量。当可调整参数的数量(有时称为自由度)很大时,模型往往更容易过拟合。因为容易受到噪声的影响而拟合歪了。

· 参数采用的值。当权重的取值范围较大时,模型可能更容易过拟合。

· 训练样本的数量。即使模型很简单,也很容易过拟合只包含一两个样本的数据集。而过拟合一个有数百万个样本的数据集则需要一个极其灵活的模型。

3.验证集

实际应用中,测试集只会使用一次。所以我们会通过验证集确定一个最好的超参数,最后再测试。

我记得验证集是从训练集里分出来的,测试集是单独的。

4.K折交叉验证

当训练数据稀缺时,我们甚至可能无法提供足够的数据来构成一个合适的验证集。这个问题的一个流行的解决方案是采用K折交叉验证。

这里,原始训练数据被分成K个不重叠的子集。然后执行K次模型训练和验证,每次在K-1个子集上进行训练,并在剩余的一个子集(在该轮中没有用于训练的子集)上进行验证。最后,通过对K次实验的结果取平均来估计训练和验证误差。

5.生成数据集的代码

features = np.random.normal(size=(n_train + n_test, 1)) # 生成特征

np.random.shuffle(features)

poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))  # 多项式特征

for i in range(max_degree): # gamma函数重新缩放

    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!

# labels的维度:(n_train+n_test,)

labels = np.dot(poly_features, true_w) #点乘

labels += np.random.normal(scale=0.1, size=labels.shape) # 添加噪声

当使用reshape(1, -1)时,NumPy会根据原始数组的形状和1这个参数,自动计算出合适的列数,使得改变形状后的数组元素个数不变。

三、范数与权重衰减

注解:为了防止过拟合 提高泛化性,使用权重衰减的方法。

它是通过给损失函数增加模型权重L2范数的惩罚(penalty)来让模型权重不要太大,以此来减小模型的复杂度,从而抑制模型的过拟合。  因为上文提过,权重的取值过大也会导致过拟合。

简洁实现:

DL将权重衰减集成到优化器中

def train_concise(wd):

    net = nn.Sequential(nn.Linear(num_inputs, 1)) # 定义模型

    for param in net.parameters():

        param.data.normal_() # 初始化模型参数

    loss = nn.MSELoss(reduction='none')

    num_epochs, lr = 100, 0.003

    # 偏置参数没有衰减

    trainer = torch.optim.SGD([

        {"params":net[0].weight,'weight_decay': wd},  # 指定权重衰减

        {"params":net[0].bias}], lr=lr)   # 偏置参数没有衰减

    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',

                            xlim=[5, num_epochs], legend=['train', 'test'])   # 绘图

    for epoch in range(num_epochs):

        for X, y in train_iter:

            trainer.zero_grad()

            l = loss(net(X), y)

            l.mean().backward() #在l上进行反向传播

            trainer.step()  # 更新参数

        if (epoch + 1) % 5 == 0:

            animator.add(epoch + 1,

                         (d2l.evaluate_loss(net, train_iter, loss),

                          d2l.evaluate_loss(net, test_iter, loss)))

    print('w的L2范数:', net[0].weight.norm().item())


http://www.ppmy.cn/devtools/110000.html

相关文章

螺旋模型例题

答案:D 知识点:螺旋模型是瀑布模型和演化模型的结合。适用于庞大,复杂并且具有高风险的系统 螺旋模型阶段: 制定计划:决定目标,方案和限制 风险分析:评价方案,识别风险&#xff…

支持向量机 (Support Vector Machines, SVM)

支持向量机 (Support Vector Machines, SVM) 通俗易懂算法 支持向量机(SVM)是一种用于分类和回归任务的机器学习算法。在最简单的情况下,SVM是一种线性分类器,适用于二分类问题。它的基本思想是找到一个超平面(在二维…

Mysql | 知识 | 事务隔离级别

转账案例缘起 我的钱包,共有 100 元。 今天我心情好,我决定给你的转账99元,最后的结果肯定是我的余额变为 1元,你的余额多了99元。 转账这一动作在程序里会涉及到一系列的操作,假设我向你转账 99元 的过程是有下面这…

初始JESD204B高速接口协议(JESD204B一)

本文参考 B B B站尤老师 J E S D 204 B JESD204B JESD204B视频,图片来自 J E S D JESD JESD手册或者 A D I / T I ADI/TI ADI/TI官方文档。 1、对比 L V D S LVDS LVDS与 J E S D 204 JESD204 JESD204 J E S D 204 B JESD204B JESD204B是逻辑器件和高速 A D C / D …

Mysql基础练习题 1757.可回收且低脂的产品(力扣)

编写解决方案找出既是低脂又是可回收的产品编号。 题目链接: https://leetcode.cn/problems/recyclable-and-low-fat-products/description/ 建表插入数据: Create table If Not Exists Products (product_id int, low_fats ENUM(Y, N), recyclable …

信号与槽,QMainWindow中常用类的使用

QMainWindow菜单栏和工具栏 菜单栏,工具栏,状态栏,中心部件,铆接部件(浮动窗口) 菜单栏 //创建菜单栏QMenuBar *bar menuBar();//指定父组件this->setMenuBar(bar);this->resize(600,400);this-&g…

linux top命令介绍以及使用

文章目录 介绍 top 命令1. top 的基本功能2. 如何启动 top3. top 的输出解释系统概况任务和 CPU 使用情况内存和交换空间进程信息 4. 常用操作 总结查看逻辑CPU的个数查看系统运行时间 介绍 top 命令 top 是一个在类 Unix 系统中广泛使用的命令行工具,用于实时显示…

【JavaWeb】JDBCDruidTomcat入门使用

本章使用技术版本: Tomcatv10.1.25 关于javaweb相关的其他技术,比如tomcat和maven,在我的主页记录了笔记,ajax我用的是本地笔记以后再考虑上传,前端三板斧我用的菜鸟教程文档 JDBC 初识 JDBC概念 JDBC 就是使用Jav…