自定义数据集使用框架的线性回归方法对其进行拟合

ops/2025/2/2 16:13:51/

代码

python">import torch
import numpy as np
import torch.nn as nncriterion = nn.MSELoss()data = np.array([[-0.5, 7.7],[1.8, 98.5],[0.9, 57.8],[0.4, 39.2],[-1.4, -15.7],[-1.4, -37.3],[-1.8, -49.1],[1.5, 75.6],[0.4, 34.0],[0.8, 62.3]])x_data = data[:, 0]
y_data = data[:, 1]x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)model = nn.Sequential(nn.Linear(1, 1))optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epoch = 500
for n in range(1,epoch+1):y_prd = model(x_train.unsqueeze(1))loss = criterion(y_prd.squeeze(1), y_train)optimizer.zero_grad()loss.backward()optimizer.step()if n % 10 == 0 or n == 1:print(f'epoch: {n}, loss: {loss}')

运行结果:

python">epoch: 1, loss: 2957.510986328125
epoch: 10, loss: 1785.2095947265625
epoch: 20, loss: 1031.366455078125
epoch: 30, loss: 606.9622192382812
epoch: 40, loss: 366.8883361816406
epoch: 50, loss: 230.34017944335938
epoch: 60, loss: 152.19020080566406
epoch: 70, loss: 107.1496810913086
epoch: 80, loss: 80.98954772949219
epoch: 90, loss: 65.66667175292969
epoch: 100, loss: 56.6099967956543
epoch: 110, loss: 51.205726623535156
epoch: 120, loss: 47.949058532714844
epoch: 130, loss: 45.966827392578125
epoch: 140, loss: 44.74833297729492
epoch: 150, loss: 43.99201583862305
epoch: 160, loss: 43.51824188232422
epoch: 170, loss: 43.218894958496094
epoch: 180, loss: 43.02821731567383
epoch: 190, loss: 42.90589141845703
epoch: 200, loss: 42.82688903808594
epoch: 210, loss: 42.77559280395508
epoch: 220, loss: 42.74211120605469
epoch: 230, loss: 42.72018051147461
epoch: 240, loss: 42.705726623535156
epoch: 250, loss: 42.696205139160156
epoch: 260, loss: 42.68989181518555
epoch: 270, loss: 42.68572235107422
epoch: 280, loss: 42.68293380737305
epoch: 290, loss: 42.68109130859375
epoch: 300, loss: 42.67984390258789
epoch: 310, loss: 42.679046630859375
epoch: 320, loss: 42.678489685058594
epoch: 330, loss: 42.67811584472656
epoch: 340, loss: 42.67786407470703
epoch: 350, loss: 42.67770004272461
epoch: 360, loss: 42.677608489990234
epoch: 370, loss: 42.677520751953125
epoch: 380, loss: 42.67747116088867
epoch: 390, loss: 42.677452087402344
epoch: 400, loss: 42.67742156982422
epoch: 410, loss: 42.67741775512695
epoch: 420, loss: 42.67739486694336
epoch: 430, loss: 42.67738342285156
epoch: 440, loss: 42.67737579345703
epoch: 450, loss: 42.677391052246094
epoch: 460, loss: 42.67737579345703
epoch: 470, loss: 42.67738723754883
epoch: 480, loss: 42.67738342285156
epoch: 490, loss: 42.67738342285156
epoch: 500, loss: 42.67737579345703


http://www.ppmy.cn/ops/155074.html

相关文章

边缘计算与ROS结合:如何实现分布式机器人智能决策?

前言 在现代机器人系统中,分布式决策能力正成为实现群体协作任务的关键需求。传统集中式架构存在决策延迟、通信瓶颈以及容错性低等问题,而边缘计算结合 ROS(Robot Operating System)为分布式机器人智能决策提供了全新的解决方案…

【ArcGIS_Python】使用arcpy脚本将shape数据转换为三维白膜数据

说明: 该专栏之前的文章中python脚本使用的是ArcMap10.6自带的arcpy(好几年前的文章),从本篇开始使用的是ArcGIS Pro 3.3版本自带的arcpy,需要注意不同版本对应的arcpy函数是存在差异的 数据准备:准备一个…

实现网站内容快速被搜索引擎收录的方法

本文转自:百万收录网 原文链接:https://www.baiwanshoulu.com/6.html 实现网站内容快速被搜索引擎收录,是网站运营和推广的重要目标之一。以下是一些有效的方法,可以帮助网站内容更快地被搜索引擎发现和收录: 一、确…

跨境支付领域中常用的英文单词(持续更新)

### **1. 支付方式 (Payment Methods)** 1. Credit Card 2. Debit Card 3. Bank Transfer 4. Wire Transfer 5. PayPal 6. Alipay 7. WeChat Pay 8. Apple Pay 9. Google Pay 10. Cryptocurrency 11. Digital Wallet 12. Mobile Payment 13. Cash on D…

12.udp

12.udp **1. UDP特性****2. UDP编程框架(C/S模式)****3. UDP发送接收函数****4. UDP编程练习** 1. UDP特性 连接特性:无链接,通信前无需像TCP那样建立连接。可靠性:不可靠,不保证数据按序到达、不保证数据…

Oracle Primavera P6 最新版 v24.12 更新 2/2

目录 一. 引言 二. P6 EPPM 更新内容 1. 用户管理改进 2. 更轻松地标准化用户设置 3. 摘要栏标签汇总数据字段 4. 将里程碑和剩余最早开始日期拖到甘特图上 5. 轻松访问审计数据 6. 粘贴数据时排除安全代码 7. 改进了状态更新卡片视图中的筛选功能 8. 直接从活动电子…

独立游戏RPG回顾:高成本

刚看了某纪录片, 内容是rpg项目的回顾。也是这个以钱为核心话题的系列的最后一集。 对这期特别有代入感,因为主角是曾经的同事,曾经在某天晚上听过其项目组的争论。 对其这些年的起伏特别的能体会。 主角是制作人,在访谈中透露这…

android studio搭建NDK环境,使用JNI (Java 调 C) 二 通过CMake

在Java文件中写native方法 package com.example.myapplication;public class JNI {//加载so库static {System.loadLibrary("FirstJni");}//native方法public static native String sayHello();}在Termial控制台中cd到项目目录中的源码目录下,并执行 java…