【单层神经网络】基于MXNet库简化实现线性回归

news/2025/2/9 12:09:48/

写在前面

同最开始的两篇文章

完整程序及注释

'''
导入使用的库
'''
# 基本
from mxnet import autograd, nd, gluon
# 模型、网络
from mxnet.gluon import nn                     
from mxnet import init
# 学习
from mxnet.gluon import loss as gloss
# 数据集
from mxnet.gluon import data as gdata
'''
生成测试数据集
'''
# 被拟合参数
true_w = [2, -3.4]      # 特征的权重系数
true_b = 4.2            # 整体模型的偏置
# 创建训练数据集
num_inputs = 2          
num_examples = 1000
features = nd.random.normal(loc=0, scale=1, shape=(num_examples, num_inputs))  # 均值为0,标准差为1
labels = true_w[0]*features[:,0] + true_w[1]*features[:,1] + true_b
labels_noise = labels + nd.random.normal()
'''
确定模型
'''
net = nn.Sequential()                       # 声明一个Sequential容器,存放Neural Network
net.add(nn.Dense(1))                        # 向容器中添加一个全连接层,且不使用激活函数,“1”表示该全连接层的输出神经元有1个
net.initialize(init.Normal(sigma=0.01))     # 权重参数随机取自均值=0,标准差=0.01的高斯分布,bias默认=0
'''
确定学习方式
'''
loss = gloss.L2Loss()       # L2范数损失 等价于 平方损失
# .collect_params()方法获取net实例的全部参数,并提供给trainer
# 选择小批量随机梯度下降法(sgd)寻优,学习率为0.03
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})
'''
数据集采样
'''
batch_size = 10
dataset = gdata.ArrayDataset(features, labels_noise)        # 将标签和特征组合成完整数据集
# DataLoader返回一个迭代器,每次从数据集中提取一个长度为batch_size的子集出来
data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True) # shuffle=True 打乱数据集(随机采样)
'''
开始训练
'''
num_epoch = 3       # 训练轮次
for epoch in range(0, num_epoch):for x, y in data_iter:          # 随机取出一组小批量,同时做到遍历with autograd.record():     # 自动保存梯度数据l = loss(net(x), y)     # 将得到的一组特征放入网络,求得到的输出与对应的标签(含噪声)的损失l.backward()                # 计算该次损失的梯度trainer.step(batch_size)    # 反向传播,基于l.backward()得到的梯度来更新模型的参数l = loss(net(features), labels_noise)     # 该轮训练结束后,求网络对数据集特征的输出,再求输出和含噪声标签的损失print('epoch %d, mean loss: %f' % (epoch+1, l.mean().asnumpy()))  # 展示训练轮次和数据集损失的平均

具体函数解释

trainer.step(batch_size):batch_size指定了当前批的大小,用于计算这次梯度下降的步长

with autograd.record():这行代码的作用是在其作用域内的计算将会被记录下来,以便自动求导


http://www.ppmy.cn/news/1570571.html

相关文章

服务器重启后报Predis_ServerException: Client sent AUTH, but no password is set

Redis问题产生后,处理办法 2025/02/08 11:21:43 [error] [exception.Predis_ServerException] Predis_ServerException: Client sent AUTH, but no password is set in /www/wwwroot/er/protected/extensions/redis/Predis.php:573 Stack trace: #0 /www/wwwroot/er/protected…

Vue3 ref属性

ref() 接受一个内部值&#xff0c;返回一个响应式的、可更改的 ref 对象&#xff0c;此对象只有一个指向其内部值的属性 .value。 function ref<T>(value: T): Ref<UnwrapRef<T>>interface Ref<T> {value: T } 详细信息 ref 对象是可更改的&#xff…

修剪二叉搜索树(力扣669)

这道题还是比较复杂&#xff0c;在递归上与之前写过的二叉树的题目都有所不同。如果当前递归到的子树的父节点不在范围中&#xff0c;我们根据节点数值的大小选择进行左递归还是右递归。为什么找到了不满足要求的节点之后&#xff0c;还要进行递归呢&#xff1f;因为该不满足要…

DeepSeek使用技巧大全(含本地部署教程)

在人工智能技术日新月异的今天&#xff0c;DeepSeek 作为一款极具创新性和实用性的 AI&#xff0c;在众多同类产品中崭露头角&#xff0c;凭借其卓越的性能和丰富的功能&#xff0c;吸引了大量用户的关注。 DeepSeek 是一款由国内顶尖团队研发的人工智能&#xff0c;它基于先进…

java将list转成树结构

首先是实体类 public class DwdCusPtlSelectDto {//idprivate String key;//值private String value;//中文名private String title;private List<DwdCusPtlSelectDto> children;private String parentId;public void addChild(DwdCusPtlSelectDto child) {if(this.chil…

2025年最新Stable Diffusion 新手入门教程,安装使用及模型下载

一、安装要求&#xff1a; ① 操作系统&#xff1a;Windows10以后的系统 ② CPU&#xff1a;不做强制性要求 ③ 内存&#xff1a;推荐8G以上 ④ 显卡&#xff1a;必须是Nvidia的独立显卡&#xff0c;显存最低4G&#xff0c;推荐20系以后&#xff1b;A卡、核显只能用CPU跑 …

余数相同问题(信息学奥赛一本通-1080)

【题目描述】 已知三个正整数a&#xff0c;b&#xff0c;c。现有一个大于1的整数x&#xff0c;将其作为除数分别除a&#xff0c;b&#xff0c;c&#xff0c;得到的余数相同。请问满足上述条件的x的最小值是多少&#xff1f;数据保证x有解。 【输入】 一行&#xff0c;三个不大于…

【AIGC魔童】DeepSeek v3推理部署:华为昇腾NPU/TRT-LLM

【AIGC魔童】DeepSeek v3推理部署&#xff1a;华为昇腾NPU/TRT-LLM &#xff08;1&#xff09;使用华为昇腾NPU推理部署DeepSeek&#xff08;2&#xff09;使用TRT-LLM推理部署DeepSeek &#xff08;1&#xff09;使用华为昇腾NPU推理部署DeepSeek 参考博客&#xff1a;华为昇…