libtorch---day03[自定义导数]

ops/2024/9/23 8:29:24/

参考pytorch

背景

希望使用勒让德多项式拟合一个周期内的正弦函数。
真值: y = s i n ( x ) , x ∈ [ − π , π ] y=sin(x),x\in\left[-\pi,\pi\right] y=sin(x),x[π,π]

torch::Tensor x = torch::linspace(-M_PI, M_PI, 2000, torch::kFloat);
torch::Tensor y = torch::sin(x);

预测值是 n = 3 n=3 n=3的勒让德多多项式: y ^ = a + b × P 3 ( x ) × ( c + d x ) \hat{y} = a+b\times P_3(x)\times(c+dx) y^=a+b×P3(x)×(c+dx),其中 P 3 ( x ) = 1 2 ( 5 x 3 − 3 x ) P_3(x) = \frac{1}{2}(5x^3-3x) P3(x)=21(5x33x)

构造自动求导类

torch提供了一种可以让开发者自主定义前向传播和后向求导的机制:

1、写一个类,继承torch::autograd::Function
2、在类中定义静态的forwardbackward函数,必须是静态的,这样在调用torch::autograd::Function::applytorch::autograd::Function::backward的时候,会自动调用上述两个静态函数;

struct LegenderPolynominal3 : public torch::autograd::Function<LegenderPolynominal3>
{static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input){ctx->save_for_backward({ input });return 0.5 * (5 * torch::pow(input, 3) - 3 * input);}static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext* ctx, std::vector<torch::Tensor> grad_output){auto saved = ctx->get_saved_variables();torch::Tensor input = saved[0];torch::Tensor grad_input = grad_output[0] * 1.5 * (5 * torch::pow(input, 2) - 1);return { grad_input };}
};

关键点

  • 必须显式调用**ctx->save_for_backward({ input });保存节点信息、调用auto saved = ctx->get_saved_variables();**获取保存的节点信息;
  • forward函数计算的是预测值,这个和认知里的forward的功能相同;
  • backward函数的输入是grad_output,是损失项关于输出的梯度 ∂ L ∂ y \frac{\partial L}{\partial y} yL,而backward计算的是损失函数关于输入的梯度 ∂ L ∂ x \frac{\partial L}{\partial x} xL,因此需要计算 ∂ L ∂ x = ∂ L ∂ y × ∂ y ∂ x \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\times \frac{\partial y}{\partial x} xL=yL×xy
  • 必须要注意backwardforward的参数列表必须固定;

全部代码

#include <torch/torch.h>
#include <iostream>
#include "matplotlibcpp.h"struct LegenderPolynominal3 : public torch::autograd::Function<LegenderPolynominal3>
{static torch::Tensor forward(torch::autograd::AutogradContext* ctx, torch::Tensor input){ctx->save_for_backward({ input });return 0.5 * (5 * torch::pow(input, 3) - 3 * input);}static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext* ctx, std::vector<torch::Tensor> grad_output){auto saved = ctx->get_saved_variables();torch::Tensor input = saved[0];torch::Tensor grad_input = grad_output[0] * 1.5 * (5 * torch::pow(input, 2) - 1);return { grad_input };}
};
void plot_tensor_xy_compare(const torch::Tensor x, const torch::Tensor y, const torch::Tensor predict)
{auto data_ptr = x.data_ptr<float>();std::vector<float> x_vector(data_ptr, data_ptr + x.numel());data_ptr = y.data_ptr<float>();std::vector<float> y_vector(data_ptr, data_ptr + y.numel());data_ptr = predict.data_ptr<float>();std::vector<float> predict_vector(data_ptr, data_ptr + predict.numel());std::map<std::string, std::string> key_words({ {"label", "ground_true"}, {"color", "blue"}, {"linestyle", "-"}});matplotlibcpp::plot(x_vector, y_vector, key_words);key_words["color"] = "red";key_words["linestyle"] = "--";key_words["label"] = "prediction";matplotlibcpp::plot(x_vector, predict_vector, key_words);matplotlibcpp::grid(true);matplotlibcpp::legend();matplotlibcpp::show();
}
int main()
{torch::Tensor x = torch::linspace(-M_PI, M_PI, 1000, torch::kFloat);torch::Tensor y = torch::sin(x);torch::Tensor a = torch::full({}, 0., torch::kFloat).set_requires_grad(true);torch::Tensor b = torch::full({}, -1., torch::kFloat).set_requires_grad(true);torch::Tensor c = torch::full({}, 0., torch::kFloat).set_requires_grad(true);torch::Tensor d = torch::full({}, 0.3, torch::kFloat).set_requires_grad(true);double learning_rate = 5e-6;torch::nn::MSELoss criterion;torch::optim::SGD optimizer({a, b, c, d}, torch::optim::SGDOptions(learning_rate));for (int i = 0; i < 2000; i++){auto P3 = LegenderPolynominal3::apply(c + d * x);torch::Tensor predict = a + b * P3;torch::Tensor loss = (predict - y).pow(2).sum();// auto loss = criterion(predict, y);loss.backward();optimizer.step();optimizer.zero_grad();std::cout << "iteration: " << i + 1 << "/2000" << ", loss: " << loss.item<double>() << std::endl;}auto P3 = LegenderPolynominal3::apply(c + d * x);torch::Tensor predict = a + b * P3;plot_tensor_xy_compare(x, y, predict);return 0;
}

结果

在这里插入图片描述


http://www.ppmy.cn/ops/103961.html

相关文章

Leetcode面试经典150题-82.删除排序链表中的重复元素II前序-83.删除排序链表中的重复元素

解法都在代码里&#xff0c;不懂就留言或者私信&#xff0c;比第一题稍微难点 题目比较简单&#xff0c;真实面试中82和83都出现过&#xff0c;83偏多&#xff0c;先有个基础&#xff0c;马上分析82 /*** Definition for singly-linked list.* public class ListNode {* …

【王树森】Few-Shot Learning小样本学习 (1/3): 基本概念(个人向笔记)

前言 下面是犰狳和穿山甲的一些图片。现在要你判断右边给定的图片是犰狳和穿山甲。我相信应该不知道犰狳和穿山甲长啥样&#xff0c;但是在看了左边的 Support Set 之后&#xff0c;你就有能力从两者之间辨别出来。既然人可以通过这四张图片分辨出犰狳和穿山甲。那么计算机能不…

鸿蒙OS试题(6)

下面持续交付&持续部署描述哪个是正确的: A.持续交付(CD,Continuous Delivery):指的是&#xff0c;频繁的将软件的新版本&#xff0c;交付给质量团队或者用户&#xff0c;以供评审。如果评审通过&#xff0c;代码就进入生产阶段。它强调的是&#xff0c;不管怎么更新&…

滑动窗口系列(定长滑动窗口长度)8/31

1.长度为K子数组中的最大和 给你一个整数数组 nums 和一个整数 k 。请你从 nums 中满足下述条件的全部子数组中找出最大子数组和&#xff1a; 子数组的长度是 k&#xff0c;且子数组中的所有元素 各不相同 题意&#xff1a; 在之前题目的基础上添加了一个条件&#xff1a;…

【Python系列】text二进制方式写入文件

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

斑马线识别检测系统源码分享

斑马线识别检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Visio…

Spring MVC RESTful API - 修改状态接口示例

前言 在许多应用程序中&#xff0c;更新资源的状态是一项常见的需求。例如&#xff0c;在任务管理系统中&#xff0c;用户可能需要更新任务的状态&#xff0c;如从“待办”变为“完成”。为了实现这一功能&#xff0c;我们可以使用Spring MVC框架结合MyBatis Plus来创建一个简…

vim 简易配置

set nocompatible set backspace2 "--------------display----------------- set nu "行号 syntax on "语法高亮 set ruler "显示当前行和列 set showcmd "显示部分命令 set showmode "最后一行显示当前模式 "set match "显示括号匹配…