自定义数据集使用框架的线性回归方法对其进行拟合

embedded/2025/2/2 8:30:48/

代码

python">import torch
import numpy as np
import torch.nn as nncriterion = nn.MSELoss()data = np.array([[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]])x_data = data[:, 0]
y_data = data[:, 1]x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)model = nn.Sequential(nn.Linear(1, 1))optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epoch = 500
for n in range(1,epoch+1):y_prd = model(x_train.unsqueeze(1))loss = criterion(y_prd.squeeze(1), y_train)optimizer.zero_grad()loss.backward()optimizer.step()if n % 10 == 0 or n == 1:print(f'epoch: {n}, loss: {loss}')

运行结果:

python">epoch: 1, loss: 2957.510986328125
epoch: 10, loss: 1785.2095947265625
epoch: 20, loss: 1031.366455078125
epoch: 30, loss: 606.9622192382812
epoch: 40, loss: 366.8883361816406
epoch: 50, loss: 230.34017944335938
epoch: 60, loss: 152.19020080566406
epoch: 70, loss: 107.1496810913086
epoch: 80, loss: 80.98954772949219
epoch: 90, loss: 65.66667175292969
epoch: 100, loss: 56.6099967956543
epoch: 110, loss: 51.205726623535156
epoch: 120, loss: 47.949058532714844
epoch: 130, loss: 45.966827392578125
epoch: 140, loss: 44.74833297729492
epoch: 150, loss: 43.99201583862305
epoch: 160, loss: 43.51824188232422
epoch: 170, loss: 43.218894958496094
epoch: 180, loss: 43.02821731567383
epoch: 190, loss: 42.90589141845703
epoch: 200, loss: 42.82688903808594
epoch: 210, loss: 42.77559280395508
epoch: 220, loss: 42.74211120605469
epoch: 230, loss: 42.72018051147461
epoch: 240, loss: 42.705726623535156
epoch: 250, loss: 42.696205139160156
epoch: 260, loss: 42.68989181518555
epoch: 270, loss: 42.68572235107422
epoch: 280, loss: 42.68293380737305
epoch: 290, loss: 42.68109130859375
epoch: 300, loss: 42.67984390258789
epoch: 310, loss: 42.679046630859375
epoch: 320, loss: 42.678489685058594
epoch: 330, loss: 42.67811584472656
epoch: 340, loss: 42.67786407470703
epoch: 350, loss: 42.67770004272461
epoch: 360, loss: 42.677608489990234
epoch: 370, loss: 42.677520751953125
epoch: 380, loss: 42.67747116088867
epoch: 390, loss: 42.677452087402344
epoch: 400, loss: 42.67742156982422
epoch: 410, loss: 42.67741775512695
epoch: 420, loss: 42.67739486694336
epoch: 430, loss: 42.67738342285156
epoch: 440, loss: 42.67737579345703
epoch: 450, loss: 42.677391052246094
epoch: 460, loss: 42.67737579345703
epoch: 470, loss: 42.67738723754883
epoch: 480, loss: 42.67738342285156
epoch: 490, loss: 42.67738342285156
epoch: 500, loss: 42.67737579345703


http://www.ppmy.cn/embedded/158857.html

相关文章

10 Flink CDC

10 Flink CDC 1. CDC是什么2. CDC 的种类3. 传统CDC与Flink CDC对比4. Flink-CDC 案例5. Flink SQL 方式的案例 1. CDC是什么 CDC 是 Change Data Capture(变更数据获取)的简称。核心思想是,监测并捕获数据库的变动(包括数据或数…

初识c语言(关键字)

前言: 注意: 变量的名称不能是关键字 变量的命名: 1、有意义 int age; flat salary; 2、名字必须是字母、数字、下划线组成,不能有特殊字符, 同时不能以数字开头 3、变量的命名不能是关键字 内容: …

75-《倒提壶》

倒提壶 倒提壶(学名:Cynoglossum amabile Stapf et Drumm.):紫草科,琉璃草属多年生草本植物,高可达60厘米。茎密生贴伏短柔毛。基生叶,长圆状披针形或披针形,茎生叶长圆形或披针形&a…

分布式系统相关面试题收集

目录 什么是分布式系统,以及它有哪些主要特性? 分布式系统中如何保证数据的一致性? 解释一下CAP理论,并说明在分布式系统中如何权衡CAP三者? 什么是分布式事务,以及它的实现方式有哪些? 什么是…

Spring Boot + Facade Pattern : 通过统一接口简化多模块业务

文章目录 Pre概述在编程中,外观模式是如何工作的?外观设计模式 UML 类图外观类和子系统的关系优点案例外观模式在复杂业务中的应用实战运用1. 项目搭建与基础配置2. 构建子系统组件航班服务酒店服务旅游套餐服务 3. 创建外观类4. 在 Controller 中使用外…

oracle 19C RAC打补丁到19.26

oracle 19CRAC打补丁到19.26 本文只保留简介命令和每个命令大概的用时,方便大家直接copy使用,并了解每个命令的预期时间,减少命令执行期的等待焦虑。 1.本次所需的补丁如下 p6880880_190000_Linux-x86-64.zip (.43的opatch&…

linux 扩容

tmpfs tmpfs 82M 0 82M 0% /run/user/1002 tmpfs tmpfs 82M 0 82M 0% /run/user/0 [输入命令]# fdisk -lu Disk /dev/vda: 40 GiB, 42949672960 bytes, 83886080 sectors Units: sectors of 1 * 512 512 bytes Sector size (logi…

【MySQL】数据类型与表约束

目录 数据类型分类 数值类型 tinyint类型 bit类型 小数类型 字符串类型 日期和时间类型 enum和set 表的约束 空属性 默认值 列描述 zerofill 主键 自增长 唯一键 外键 数据类型分类 数值类型 tinyint类型 MySQL中,整形可以是有符号和无符号的&…