torch.nn.Parameter()

news/2025/3/22 13:05:58/

一文通俗理解torch.nn.Parameter()

一、起源

首先,我写这篇文章的起源是因为,我突然看到了一段有关torch.nn.Parameter()的代码。

因此就去了解了一下这个函数,把自己的一些理解记录下来,希望可以帮到你。

二、官方文档

网址如下:https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html#torch.nn.parameter.Parameter

具体参数解释如下:

torch.nn.parameter.Parameter(data=None, requires_grad=True)

  • data:代表一个tensor类似的数据
  • requires_grad:是否需要进行梯度计算,默认为True

三、个人理解

  • 这个函数的主要作用就是把一个不可训练的Tensor数据转换成可以训练的Tensor数据。

那么这个这个函数怎么实现的呢。

这个函数可以将你输入的数据(你想训练的数据)加入到你模型的参数里面(因为requires_grad=True,如果为False就是不加入),跟着你模型的参数一起训练,一起学习,逐渐达到最优解。

代码实现

self.w = nn.Parameter(torch.tensor(0.5, dtype=torch.float), requires_grad=True)
"""
初始数据为0.5,且为float类型,进行训练。
"""

四、示例程序

# -*- coding: UTF-8 -*-
# Project :python 
# File    :test_1.py
# IDE     :PyCharm 
# Author  :小李同学
# Date    :2023/10/21 13:44import torch
import torch.optim as optim
import matplotlib.pyplot as plt# 创建一个可学习的权重参数,初始值为0.5
weight = torch.nn.Parameter(torch.tensor(0.5, requires_grad=True))
# 定义一个优化器,用于更新权重
optimizer = optim.SGD([weight], lr=0.01)
# 目标值
target = torch.tensor(20.0)
# 存储损失和权重的列表,用于绘制学习曲线
losses = []
weights = []
# 训练循环
for epoch in range(10):# 模型的预测值prediction = weight * 5.0  # 假设模型的预测是输入值乘以权重# 计算损失,这里使用均方误差损失loss = (prediction - target) ** 2losses.append(loss.item())weights.append(weight.item())# 梯度清零optimizer.zero_grad()# 反向传播和权重更新loss.backward()optimizer.step()print(f'Epoch {epoch + 1}: Loss={loss.item():.2f}, Weight={weight.item():.2f}')# 绘制学习曲线
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Curve')plt.subplot(1, 2, 2)
plt.plot(weights)
plt.xlabel('Epoch')
plt.ylabel('Weight')
plt.title('Weight Curve')plt.show()

输出结果如下:

Epoch 1: Loss=306.25, Weight=2.25
Epoch 2: Loss=76.56, Weight=3.12
Epoch 3: Loss=19.14, Weight=3.56
Epoch 4: Loss=4.79, Weight=3.78
Epoch 5: Loss=1.20, Weight=3.89
Epoch 6: Loss=0.30, Weight=3.95
Epoch 7: Loss=0.07, Weight=3.97
Epoch 8: Loss=0.02, Weight=3.99
Epoch 9: Loss=0.00, Weight=3.99
Epoch 10: Loss=0.00, Weight=4.00

学习曲线如下:

在这里插入图片描述

如果想获得本文的的pdf,请在公众号“冬天的李同学”上回复“2023.10.22”即可获得。

参考文章:

1.https://mp.weixin.qq.com/s/ryfSof2OrGQdJauqmTpK0A

2.https://blog.csdn.net/weixin_44878336/article/details/124733598?

ps://mp.weixin.qq.com/s/ryfSof2OrGQdJauqmTpK0A

2.https://blog.csdn.net/weixin_44878336/article/details/124733598?

3.https://blog.csdn.net/weixin_43145941/article/details/114757673?


http://www.ppmy.cn/news/1171893.html

相关文章

windows本地搭建mmlspark分布式机器平台流程

文章目录 windows本地搭建mmlspark分布式机器平台流程安装环境pyspark环境spark环境java环境hadoop环境1.修改hadoop配置文件下的jdk地址为自己的实际地址2.修改bin文件离线环境jar包环境1mmlsprk第三方包jar包环境2参考代码我有话说其他问题记录概要参考文献windows本地搭建mm…

低代码平台如何实现快速开发应用?

目录 一、低代码“快”在哪里? 下面分享低代码低代码平台实现快速开发的一些主要方式: 1.图形化编程: 2.预构建组件: 3.模板和插件: 4.自动化流程: 5.集成和扩展: 6.多端适配: 7.快速…

第九章:最新版零基础学习 PYTHON 教程—Python 元组(第四节 -Python连接元组的方法)

很多时候,在处理记录时,我们可能会遇到需要添加两条记录并将它们存储在一起的问题。这需要串联。由于元组是不可变的,因此这个任务变得不太复杂。让我们讨论执行此任务的某些方法。 目录

C语言指针详解——必备7大知识点

Part1指针是什么? 1.1 浅谈指针 理解指针的 两个要点: 指针是内存中一个最小单元的编号,也就是地址; 平时口语中说的指针,通常指的是指针变量,是用来存放内存地址的变量。 总结:指针就是地址&#xff…

Android-Framework 应用间跳转时,提供 Android Broadcast 通知

一、环境 高通865 Android 10 二、情景 应用跳转时,通过广播发送源app的包名和目标app的包名 三、代码实现 frameworks/base/services/core/java/com/android/server/wm/ActivityStarter.java -132,6 132,14 import java.io.PrintWriter;import java.text.DateFormat;imp…

践行国策,男性生育力保护与修复新启航

金秋送爽,丹桂飘香!值2023年男性健康日即将到来之时,10月22日,由中国优生优育协会生育力保护与修复专业委员会、南京大学医学院附属鼓楼医院联合举办的“首届男性生育力保护与修复诊疗技术培训班”暨“中国优生优育男性生育力保护…

ATA-2161高压放大器在压电薄膜传感器心脏监测研究中的应用

近年来,随着医疗技术的不断进步和人们对健康关注的增加,心脏疾病的早期监测与预防成为了研究的热点。压电薄膜传感器作为一种重要的生物传感器,在心脏监测领域发挥着重要的作用。而高压放大器作为压电薄膜传感器的关键驱动设备,对…

WinCC趋势跨度设置(时间范围)

控件:输入输出域、组合框、按钮、实时趋势控件 输入输出域 对象名称:IOI 域类型:输入 组合框 对象名称:cb 索引与文本一一对应 按钮VB Sub OnClick(Byval Item) …