libtorch---day03[自定义导数]

news/2024/9/18 23:17:30/ 标签: c++, pytorch

参考pytorch

背景

希望使用勒让德多项式拟合一个周期内的正弦函数。
真值: y = s i n ( x ) , x ∈ [ − π , π ] y=sin(x),x\in\left[-\pi,\pi\right] y=sin(x),x[π,π]

torch::Tensor x = torch::linspace(-M_PI, M_PI, 2000, torch::kFloat);
torch::Tensor y = torch::sin(x);

预测值是 n = 3 n=3 n=3的勒让德多多项式: y ^ = a + b × P 3 ( x ) × ( c + d x ) \hat{y} = a+b\times P_3(x)\times(c+dx) y^=a+b×P3(x)×(c+dx),其中 P 3 ( x ) = 1 2 ( 5 x 3 − 3 x ) P_3(x) = \frac{1}{2}(5x^3-3x) P3(x)=21(5x33x)

构造自动求导类

torch提供了一种可以让开发者自主定义前向传播和后向求导的机制:

1、写一个类,继承torch::autograd::Function
2、在类中定义静态的forwardbackward函数,必须是静态的,这样在调用torch::autograd::Function::applytorch::autograd::Function::backward的时候,会自动调用上述两个静态函数;

struct LegenderPolynominal3 : public torch::autograd::Function<LegenderPolynominal3>
{static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input){ctx->save_for_backward({ input });return 0.5 * (5 * torch::pow(input, 3) - 3 * input);}static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext* ctx, std::vector<torch::Tensor> grad_output){auto saved = ctx->get_saved_variables();torch::Tensor input = saved[0];torch::Tensor grad_input = grad_output[0] * 1.5 * (5 * torch::pow(input, 2) - 1);return { grad_input };}
};

关键点

  • 必须显式调用**ctx->save_for_backward({ input });保存节点信息、调用auto saved = ctx->get_saved_variables();**获取保存的节点信息;
  • forward函数计算的是预测值,这个和认知里的forward的功能相同;
  • backward函数的输入是grad_output,是损失项关于输出的梯度 ∂ L ∂ y \frac{\partial L}{\partial y} yL,而backward计算的是损失函数关于输入的梯度 ∂ L ∂ x \frac{\partial L}{\partial x} xL,因此需要计算 ∂ L ∂ x = ∂ L ∂ y × ∂ y ∂ x \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\times \frac{\partial y}{\partial x} xL=yL×xy
  • 必须要注意backwardforward的参数列表必须固定;

全部代码

#include <torch/torch.h>
#include <iostream>
#include "matplotlibcpp.h"struct LegenderPolynominal3 : public torch::autograd::Function<LegenderPolynominal3>
{static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input){ctx->save_for_backward({ input });return 0.5 * (5 * torch::pow(input, 3) - 3 * input);}static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext* ctx, std::vector<torch::Tensor> grad_output){auto saved = ctx->get_saved_variables();torch::Tensor input = saved[0];torch::Tensor grad_input = grad_output[0] * 1.5 * (5 * torch::pow(input, 2) - 1);return { grad_input };}
};
void plot_tensor_xy_compare(const torch::Tensor x, const torch::Tensor y, const torch::Tensor predict)
{auto data_ptr = x.data_ptr<float>();std::vector<float> x_vector(data_ptr, data_ptr + x.numel());data_ptr = y.data_ptr<float>();std::vector<float> y_vector(data_ptr, data_ptr + y.numel());data_ptr = predict.data_ptr<float>();std::vector<float> predict_vector(data_ptr, data_ptr + predict.numel());std::map<std::string, std::string> key_words({ {"label", "ground_true"}, {"color", "blue"}, {"linestyle", "-"}});matplotlibcpp::plot(x_vector, y_vector, key_words);key_words["color"] = "red";key_words["linestyle"] = "--";key_words["label"] = "prediction";matplotlibcpp::plot(x_vector, predict_vector, key_words);matplotlibcpp::grid(true);matplotlibcpp::legend();matplotlibcpp::show();
}
int main()
{torch::Tensor x = torch::linspace(-M_PI, M_PI, 1000, torch::kFloat);torch::Tensor y = torch::sin(x);torch::Tensor a = torch::full({}, 0., torch::kFloat).set_requires_grad(true);torch::Tensor b = torch::full({}, -1., torch::kFloat).set_requires_grad(true);torch::Tensor c = torch::full({}, 0., torch::kFloat).set_requires_grad(true);torch::Tensor d = torch::full({}, 0.3, torch::kFloat).set_requires_grad(true);double learning_rate = 5e-6;torch::nn::MSELoss criterion;torch::optim::SGD optimizer({a, b, c, d}, torch::optim::SGDOptions(learning_rate));for (int i = 0; i < 2000; i++){auto P3 = LegenderPolynominal3::apply(c + d * x);torch::Tensor predict = a + b * P3;torch::Tensor loss = (predict - y).pow(2).sum();// auto loss = criterion(predict, y);loss.backward();optimizer.step();optimizer.zero_grad();std::cout << "iteration: " << i + 1 << "/2000" << ", loss: " << loss.item<double>() << std::endl;}auto P3 = LegenderPolynominal3::apply(c + d * x);torch::Tensor predict = a + b * P3;plot_tensor_xy_compare(x, y, predict);return 0;
}

结果

在这里插入图片描述


http://www.ppmy.cn/news/1518866.html

相关文章

前端配置环境

工具类配置 一、下载Git Bash 下载地址 二、下载google浏览器 下载地址 三、下载微信开发者工具 下载地址 四、下载vscode 下载地址 1、安装中文包 安装中文包 教程 2、安装插件 3、vscode中使用git 教程 4、setting.json 我自己常用的&#xff1a; {"editor.fontSiz…

分布式中间件

1.Nacos 服务注册和服务发现原理图&#xff1a; 1.服务提供方将集群信息注册到Nacos&#xff0c;并定期心跳包提供健康信息&#xff0c;宕机即剔除 2.服务消费方定期拉取订阅信息&#xff0c;获取服务实例列表 3.服务集群的负载均衡是在消费者一方进行选择 负载均衡&#xf…

代理 IP 在工业物联网中的大作用

随着科技的飞速发展&#xff0c;工业物联网&#xff08;IIoT&#xff09;已经成为现代工业的重要组成部分&#xff0c;它通过将各种物理设备、传感器、控制系统等通过互联网连接起来&#xff0c;实现了工业生产的智能化、自动化和远程监控。而在这个庞大的网络体系中&#xff0…

【RabbitMQ】快速上手

目 录 一. RabbitMQ 安装二. RabbitMQ 核心概念2.1 Producer 和 Consumer2.2 Connection 和 Channel2.3 Virtual host2.4 Queue2.5 Exchange2.6 RabbitMQ 工作流程 三. AMQP四. web界面操作4.1 用户相关操作4.2 虚拟主机相关操作 五. RabbitMQ 快速入门5.1 引入依赖5.2 编写生产…

C# Default.aspx 中文乱码解决方案

Language: C#(CSharp) IDE: Notepad 程序文件内容一摸一样&#xff0c;后缀改为 Default.aspx 就会乱码&#xff0c;改成 Default.php 或 Default.html 一切正常。 尝试使用服务器端指定编码 Response.ContentType "text/html; charsetutf-8"; 客户端也指定编码 …

vim 修改文件

在 Vim 中修改文件是一个常见的任务。以下是一些基本步骤和命令&#xff0c;帮助你在 Vim 中编辑和保存文件。 打开文件 使用以下命令在终端中打开一个文件&#xff1a; vim filename基本模式 Vim 有三种基本模式&#xff1a; 正常模式&#xff08;Normal mode&#xff09…

Linux 下查找运行中的 Java 进程及 .jar 文件位置

在 Linux 环境中&#xff0c;有时我们需要查找正在运行的 Java 进程以及它们对应的 .jar 文件位置。本文将介绍如何使用命令行工具来实现这一目标。 前言 在 Linux 系统中&#xff0c;我们经常需要监控正在运行的应用程序&#xff0c;特别是在出现问题时&#xff0c;了解应用程…

乐凡三防:工业界的硬核产品——重新定义三防平板的极限

在工业4.0的浪潮中&#xff0c;科技与制造业的深度融合催生了一系列高性能、高耐用的智能产品。乐凡三防平板&#xff0c;作为工业界的新宠&#xff0c;正以其卓越的防护性能和强大的功能&#xff0c;重新定义了三防平板的极限&#xff0c;成为硬核科技的代表。 硬核防护&#…

时空图卷积网络:用于交通流量预测的深度学习框架-1

摘要 准确的交通预测对于城市交通控制和引导至关重要。由于交通流的高度非线性和复杂性&#xff0c;传统方法无法满足中长期预测任务的需求&#xff0c;且往往忽略了空间和时间的依赖关系。本文提出一种新的深度学习框架——时空图卷积网络(STGCN)来解决交通领域的时间序列预测…

在Ubuntu 18.04上安装MySQL的方法

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。 介绍 MySQL 是一个开源的数据库管理系统&#xff0c;通常作为流行的 LAMP&#xff08;Linux、Apache、MySQL、PHP/Python/Perl&#xf…

discuz Upload Failed.

baidu搜索关键词 NGINX上传文件大小限制 在Nginx中设置文件上传大小限制&#xff0c;你需要修改client_max_body_size指令。默认情况下&#xff0c;该指令的值为1M&#xff0c;意味着上传文件大小不能超过1MB。 打开Nginx配置文件&#xff08;通常是nginx.conf或者位于/etc/…

pyautogui的一些自动化示例,附代码

以下为您提供一些 pyautogui 的自动化示例及代码&#xff1a; 模拟鼠标点击和移动&#xff1a;import pyautogui # 获取屏幕的宽度和高度 screen_width, screen_height pyautogui.size() # 将鼠标移动到屏幕中心 pyautogui.moveTo(screen_width / 2, screen_height / 2) # 在…

Linux的常见指令

前言 Hello,今天我们继续学习Liunx&#xff0c;上期我们简单了解了Linux的基本用处&#xff0c;并了解了Linux的重要性&#xff0c;今天我们就继续更加深入的学习Linux&#xff0c;进行指令方面的学习&#xff0c;我们可以通过先学习简单的基础命令来学习Linux&#xff0c;并在…

css设置让整个盒子的内容渐变透明(非颜色渐变透明)

css设置让整个盒子的内容渐变透明&#xff08;非颜色渐变透明&#xff09; 效果核心css代码 效果 核心css代码 /* 设置蒙版上下左右渐变显示 */ mask-image: linear-gradient(to right, rgba(0, 0, 0, 0) 0%, rgba(0, 0, 0, 1) 10%, rgba(0, 0, 0, 1) 90%, rgba(0, 0, 0, 0) 1…

LuaJit分析(一)LuaJit交叉编译

​​​​​​Android 使用ndk版本 r16b 在luajit2.1.0-beta3目录下创建一个脚本文件&#xff0c;armv7编译代码如下&#xff1a; make clean NDKE:/android-ndk-r16b #ndk路径 NDKABI21 NDKTRIPLEarm-linux-androideabi NDKVER$NDK/toolchains/$NDKTRIPLE-4.9 NDKP$NDKVER/…

QT基础之【模块】

QT基础之【模块】 写在前面版本信息内容全部模块QT基本模块QT附加模块增值模块技术预览模块QT工具 补充模块路径网络资料简要描述 摘要&#xff1a; 1.本文介绍了QT5.12.9的模块&#xff0c;主要核心内容来源于帮助文档&#xff0c;少量整理网络中的资料 2.分析查看安装中径中的…

代码随想录——回文子串(Leetcode 647)

题目链接 我的题解&#xff08;双指针&#xff09; 思路&#xff1a; 当然&#xff0c;以下是对您提供的代码的解释&#xff1a; class Solution {public int countSubstrings(String s) {// 初始化回文子字符串的数量int count 0;// 遍历字符串的每个字符&#xff0c;使用…

嵌入式Linux C应用编程指南-进程、线程(速记版)

第九章 进程 9.1 进程与程序 9.1.1 main()函数由谁调用&#xff1f; C 语言程序总是从 main 函数开始执行&#xff0c;main()函数的原型是&#xff1a; int main(void) 或 int main(int argc, char *argv[])。 操作系统下的应用程序在运行 main()函数之前需要先执行一段引导代…

深入解析HarmonyOS Image组件的使用与优化

在现代移动应用开发中&#xff0c;图像处理是一个至关重要的部分。HarmonyOS 提供了功能强大的图像组件&#xff0c;允许开发者从多种来源显示图像&#xff0c;如本地资源、网络资源、资源文件、媒体库和 Base64图像编码。本篇博客将深入探讨如何接地使用图像组件&#xff0c;并…

Golang | Leetcode Golang题解之第385题迷你语法分析器

题目&#xff1a; 题解&#xff1a; func deserialize(s string) *NestedInteger {index : 0var dfs func() *NestedIntegerdfs func() *NestedInteger {ni : &NestedInteger{}if s[index] [ {indexfor s[index] ! ] {ni.Add(*dfs())if s[index] , {index}}indexreturn…