用TensorFlow实现线性回归

devtools/2024/9/25 15:23:01/

说明

本文采用TensorFlow框架进行讲解,虽然之前的文章都采用mxnet,但是我发现tensorflow提供了免费的gpu可供使用,所以果断开始改为tensorflow,若要实现文章代码,可以使用colaboratory进行运行,当然,如果您已经安装了tensorflow,可以采用python直接运行。

贡献

学习时采取动手学深度学习第二版作为教材,但由于本书通过引入d2l(著者自写库)进行深度学习,我希望将d2l的影响去掉,即不使用d2l,使用tensorflow,这一点通过查询GitHub中d2l库提供的相关函数尝试进行实现。

如果本系列文章具有良好表现,将译为英文版上传至Github。

预备知识

学习本篇文章之前,您最好具有以下基础知识:

  1. 回归>线性回归的基础知识
  2. python的基础知识

基本原理 

使用一个仿射变换,通过y=wx+b的模型来对数据进行预测(w和x均为矩阵,大小取决于输入规模),反向传播采用随机梯度下降对参数进行更新,参数包括w和b,即权重和偏差。

实现过程

生成数据集

只需要引入tensorflow即可,synthetic_data()函数将初始化X和Y,即通过真实的权重和偏差值生成数据集。

import tensorflow as tfdef synthetic_data(w, b, num_examples):X = tf.zeros((num_examples, w.shape[0]))X += tf.random.normal(shape=X.shape)y = tf.matmul(X, tf.reshape(w, (-1, 1))) + by += tf.random.normal(shape=y.shape, stddev=0.01)y = tf.reshape(y, (-1, 1))return X, ytrue_w = tf.constant([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

读取数据集

加载刚刚生成的数据集,is_train表示是否进行打乱,默认对数据进行打乱处理,使用load_array函数加载数据集。

def load_array(data_arrays, batch_size, is_train=True):dataset = tf.data.Dataset.from_tensor_slices(data_arrays)if is_train:dataset = dataset.shuffle(buffer_size=1000)dataset = dataset.batch(batch_size)return datasetbatch_size = 10
data_iter = load_array((features, labels), batch_size)

定义模型

模型使用keras API实现,keras是tensorflow中机器学习相关的库。先使用Sequential类定义承载容器,之后添加一个单神经元的全连接层。在TensorFlow中,Sequential表示容器相关的类,layer表示层相关的类。回归>线性回归只需要通过keras中的单神经元的全连接层即可实现,神经元的值即为输出结果。

net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1))

示例的回归>线性回归仅有一个输入X,实际在其他回归>线性回归过程中,很有可能有多个x及其对应的w,但keras的代码均不会发生改变,因为keras的Dense类可以自动判断输入的个数。 

初始化模型参数 

stddev表示标准差,initializer生成一个标准差为1,均值为0的正态分布。在构建全连接层时,使用该正态分布进行初始化。

initializer = tf.initializers.RandomNormal(stddev=0.01)
net = tf.keras.Sequential()
net.add(tf.keras.layers.Dense(1, kernel_initializer=initializer))

定义损失函数和优化算法 

损失函数使用平方损失函数进行计算,训练时使用小批量随机梯度下降SGD方法进行训练,学习率为0.03。

loss = tf.keras.losses.MeanSquaredError()
trainer = tf.keras.optimizers.SGD(learning_rate=0.03)

训练

运行以下代码可以观察训练结果。运行轮次为3轮,每一轮对所有训练集数据进行学习。计算w和b的梯度值,使用梯度下降更新权重w和偏差b。每一轮输出损失函数的值,最终显示权重和偏差的估计误差。

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:with tf.GradientTape() as tape:l = loss(net(X, training=True), y)grads = tape.gradient(l, net.trainable_variables)trainer.apply_gradients(zip(grads, net.trainable_variables))l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')
w = net.get_weights()[0]
print('w的估计误差:', true_w - tf.reshape(w, true_w.shape))
b = net.get_weights()[1]
print('b的估计误差:', true_b - b)

运行结果

epoch 1, loss 0.000194

epoch 2, loss 0.000091

epoch 3, loss 0.000091

w的估计误差: tf.Tensor([-0.00026917 0.00094557], shape=(2,), dtype=float32)

b的估计误差: [4.7683716e-06]

 改进尝试

  1. 更改SGD优化算法为Adam
  2. 更改MeanSquaredError为其他损失函数

对于上述改进,损失均有显著增加,表明原有方法已为最好方法。


http://www.ppmy.cn/devtools/98679.html

相关文章

在AI时代,程序员如何保持核心竞争力?

随着AIGC(如ChatGPT、MidJourney、Claude等)大语言模型的不断涌现,AI辅助编程工具正在迅速普及,程序员的工作方式也正在发生深刻变革。这一趋势引发了广泛的讨论:AI是否会取代部分编程工作?程序员应该如何应…

基于Spark计算网络图中节点之间的Jaccard相似性

基于Spark计算网络图中节点之间的Jaccard相似性 Jaccard 相似度是一种较为常用的衡量两个集合相似性的指标,用于计算两个集合的交集与并集的比率。具体来说,它的计算公式为: 在网络图中同样经常使用Jaccard来计算节点之间的相似性&#xff…

梧桐数据库(WuTongDB):数据库技术中LR算法详解

LR(Left-to-Right, Rightmost Derivation)算法是一种自底向上的语法分析方法,用于解析上下文无关文法。与 LL 分析器的自顶向下分析方式不同,LR 分析器从输入的最左侧开始读取符号,但通过“最右推导”来构建语法树。这…

vue.js的设计与实现(权衡的的艺术-命令式和声明式)

权衡的的艺术 什么是命令式和声明式呢?性能与可维护性的权衡那么,问题又来了,为什么vue.js不选择性能更好的命令式,而选择声明式呢? 虚拟DOM的性能到底如何总结 什么是命令式和声明式呢? 我们来看一下jQue…

MyBatis核心机制

实现MyBatis核心机制环境搭建 1.核心框架示意图 2.模块搭建 1.创建maven项目 2.引入依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSc…

C++ 内嵌 python 解释器

AI 提供 #include <Python.h> #include <map> #include <string>int main() {// 初始化 Python 解释器Py_Initialize();// 创建一个 C std::mapstd::map<std::string, int> myMap {{"apple", 3},{"banana", 5},{"orange&quo…

Linux系统性能调优指南-定期维护

目录 定期维护 日志管理 示例 磁盘维护 示例 示例代码 日志管理示例 磁盘维护示例 定期维护 定期维护对于保持Linux系统的稳定性和性能至关重要。这包括日志管理以及磁盘维护等方面。下面详细介绍这些方面的配置和优化方法。 日志管理 日志文件随着时间的积累可能会占用大量的磁…

Lumos学习王佩丰Excel第十二讲:Match与Index

一、函数语法 1、vlookup的局限 举个栗子&#xff0c;VLOOKUP不能做到从右推左&#xff1a; 由此看来&#xff0c;使用vlookup函数&#xff0c;表格范围要遵循从左到右的顺序&#xff0c;左为自变量&#xff0c;右为因变量&#xff1b;而要解决这种场景的弊端&#xff0c;可以…