【单层神经网络】基于MXNet的线性回归实现(底层实现)

ops/2025/2/5 23:15:57/

写在前面

  1. 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程
  2. 下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法

整体流程

  1. 生成训练数据集(实际工程中,需要从实际对象身上采集数据)
  2. 确定模型及其参数(输入输出个数、阶次,偏置等)
  3. 确定学习方式(损失函数、优化算法,学习率,训练次数,终止条件等)
  4. 读取数据集(不同的读取方式会影响最终的训练效果)
  5. 训练模型

完整程序及注释

python">from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random'''
获取(生成)训练集
'''
input_num = 2				# 输入个数
examples_num = 1000			# 生成样本个数
# 确定真实模型参数
real_W = [10.9, -8.7]		
real_bias = 6.5	features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 标准差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根据特征和参数生成对应标签
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 为标签附加噪声,模拟真实情况# 绘制标签和特征的散点图(矢量图)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()# 创建一个迭代器(确定从数据集获取数据的方式)
def data_iter(batch_size, features, labels):num = len(features)indices = list(range(num))                                  # 生成索引数组random.shuffle(indices)                                     # 打乱indices# 该遍历方式同时确保了随机采样和无遗漏for i in range(0, num, batch_size):j = nd.array(indices[i: min(i+batch_size, num)])        # 对indices从i开始取,取batch_size个样本,并转换为列表yield features.take(j), labels.take(j)                  # take方法使用索引数组,从features和labels提取所需数据"""
训练的基础准备
"""
# 声明训练变量,并赋高斯随机初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同写法,等价于上面的
w.attach_grad()         # 为需要迭代的参数申请求梯度空间
b.attach_grad()# 定义模型
def linreg(X, w, b):return nd.dot(X,w)+b# 定义损失函数
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) **2 /2# 定义寻优算法
def sgd(params, learning_rate, batch_size):for param in params:# 新参数 = 原参数 - 学习率*当前批量的参数梯度/当前批量的大小param[:] = param - learning_rate * param.grad / batch_size# 确定超参数和学习方式
lr = 0.03
num_iterations = 5
net = linreg				# 目标模型
loss = squared_loss			# 代价函数(损失函数)
batch_size = 10				# 每次随机小批量的大小'''
开始训练
'''
for iteration in range(num_iterations):		# 确定迭代次数for x, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(x,w,b), y)			# 求当前小批量的总损失l.backward()						# 求梯度sgd([w,b], lr, batch_size)			# 梯度更新参数train_l = loss(net(features,w,b), labels)print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比较真实参数和训练得到的参数
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具体程序解释

param[:] = param - learning_rate * param.grad / batch_size
将batch_size与参数调整相关联的原因,是为了使得每次更新的步长不受批次大小的影响
具体来说,当计算一批数据的损失函数的梯度时,实际上是将这批数据中每个样本对损失函数的贡献累加起来。这意味着如果批次较大,梯度的模也会相应增大
故更新权值时,使用的是数据集的平均梯度,而不是总和


http://www.ppmy.cn/ops/155996.html

相关文章

Node.js与嵌入式开发:打破界限的创新结合

文章目录 一、Node.js的本质与核心优势1.1 什么是Node.js?1.2 嵌入式开发的范式转变二、Node.js与嵌入式结合的四大技术路径2.1 硬件交互层2.2 物联网协议栈2.3 边缘计算架构2.4 轻量化运行时方案三、实战案例:智能农业监测系统3.1 硬件配置3.2 软件架构3.3 核心代码片段四、…

深入探究 Spring 中 FactoryBean 注册服务的实现与原理

深入探究 Spring 中 FactoryBean 注册服务的实现与原理 引言 在 Spring 框架里,FactoryBean 是一个强大且重要的特性。它允许开发者以一种更加灵活的方式来创建和管理 Bean 对象。FactoryBeanRegistrySupport 和 AbstractBeanFactory 这两个类在处理 FactoryBean …

vue3中el-input无法获得焦点的问题

文章目录 现象两次nextTick()加setTimeout()解决结论 现象 el-input被外层div包裹了&#xff0c;设置autofocus不起作用&#xff1a; <el-dialog v-model"visible" :title"title" :append-to-bodytrue width"50%"><el-form v-model&q…

C++滑动窗口技术深度解析:核心原理、高效实现与高阶应用实践

目录 一、滑动窗口的核心原理 二、滑动窗口的两种类型 1. 固定大小的窗口 2. 可变大小的窗口 三、实现细节与关键点 1. 窗口的初始化 2. 窗口的移动策略 3. 结果的更新时机 四、经典问题与代码示例 示例 1&#xff1a;和 ≥ target 的最短子数组&#xff08;可变窗口…

python学opencv|读取图像(五十六)使用cv2.GaussianBlur()函数实现图像像素高斯滤波处理

【1】引言 前序学习了均值滤波和中值滤波&#xff0c;对图像的滤波处理有了基础认知&#xff0c;相关文章链接为&#xff1a; python学opencv|读取图像&#xff08;五十四&#xff09;使用cv2.blur()函数实现图像像素均值处理-CSDN博客 python学opencv|读取图像&#xff08;…

DeepSeek 的含金量还在上升

大家好啊&#xff0c;我是董董灿。 最近 DeepSeek 越来越火了。 网上有很多针对 DeepSeek 的推理测评&#xff0c;除此之外&#xff0c;也有很多人从技术的角度来探讨 DeepSeek 带给行业的影响。 比如今天就看到了一篇文章&#xff0c;探讨 DeepSeek 在使用 GPU 进行模型训练…

CNN的各种知识点(四): 非极大值抑制(Non-Maximum Suppression, NMS)

非极大值抑制&#xff08;Non-Maximum Suppression, NMS&#xff09; 1. 非极大值抑制&#xff08;Non-Maximum Suppression, NMS&#xff09;概念&#xff1a;算法步骤&#xff1a;具体例子&#xff1a;PyTorch实现&#xff1a; 总结&#xff1a; 1. 非极大值抑制&#xff08;…

Linux环境下的Java项目部署技巧:Nginx 详解

Nginx 的启动 Nginx 启动会生成 2 个进程&#xff1a;主进程与守护进程 主进程&#xff1a;常用于提供反向代理服务。特点&#xff1a;占内存大守护进程&#xff1a;防止主进程以外关闭。特点&#xff1a;占内存小 Nginx 启动需要占用 80 端口: 当 Ngnix 启动失败时&#xff0…