[Compile Series] torch.compile() Training Compilation: Operator Fusion Logic and Engineering

2025/2/4 9:03:00

1. Background:

In torch.compile(), Dynamo serves as the frontend responsible for capturing the computation graph, while backends such as Inductor and TVM perform compilation and optimization.

  • Dynamo: injects passes at the Python bytecode level to perform bytecode-to-bytecode optimization, parsing the bytecode instruction by instruction to build the FX Graph
  • Inductor: runs AOTAutograd on the FX Graph to produce a joint graph, decomposes ops into PrimTorch primitive ops, and generates the corresponding kernel code for the target hardware

This post covers the following topics:

  • The operator-fusion logic inside TorchInductor
  • How to define custom fused operators with torch.compile

2. TorchInductor Operator Fusion Logic:

Three mechanisms in TorchInductor are responsible for operator fusion:

  1. Fusion on the FX Graph: at this level the FX IR targets may still be torch.ops-level ops, so this is relatively coarse-grained fusion. It usually only takes effect in inference; for example, Conv+BN fusion is applied in inference but not in training, because the weights are still being updated.
  2. Inlining during GraphLowering: GraphLowering converts the FX Graph into Inductor IR; during this conversion, purely computational intermediate results are inlined, which achieves a fusion effect.
  3. Fusion on the Inductor IR: after GraphLowering, the Scheduler fuses those memory-allocated Inductor IR nodes (called buffers in Inductor) that have shared memory accesses.

2.1. Operator Fusion on the FX Graph:

At this stage some ops have not yet been decomposed, and AOTAutograd still has to build the backward graph, so fusion in the training scenario is fairly constrained (it usually requires providing a backward function for the fused operator). On the other hand, fusing at this higher level brings more visible gains. Take the Conv+BN code below as an example:

# Conv+BN
# code
import torch
import torch.nn as nn
import torch._dynamo as dynamo

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.conv_1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=2)  # assuming input is 3-channel image
        self.bn_1 = nn.BatchNorm2d(16)
        self.relu_1 = nn.ReLU()

    def forward(self, x):
        out = self.conv_1(x)
        out = self.bn_1(out)
        out = self.relu_1(out)
        return out

def test():
    data = torch.randn((2, 3, 128, 128), requires_grad=True, device="cuda")
    model = ConvNet().to("cuda")
    model.eval()  # remove this line to capture the training-scenario graph
    model = dynamo.optimize("inductor")(model)
    output = model(data)

test()

In the training scenario, no operator fusion at all happens on the FX IR: the graph simply computes the conv → computes the mean and variance → computes BN.

# FX Graph of Conv+BN in the training scenario
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    convolution = torch.ops.aten.convolution.default(primals_8, primals_1, primals_2, [1, 1], [2, 2], [1, 1], False, [0, 0], 1);  primals_2 = None
    convert_element_type = torch.ops.prims.convert_element_type.default(primals_5, torch.float32)
    convert_element_type_1 = torch.ops.prims.convert_element_type.default(primals_6, torch.float32)
    add = torch.ops.aten.add.Tensor(convert_element_type_1, 1e-05);  convert_element_type_1 = None
    sqrt = torch.ops.aten.sqrt.default(add);  add = None
    reciprocal = torch.ops.aten.reciprocal.default(sqrt);  sqrt = None
    mul = torch.ops.aten.mul.Tensor(reciprocal, 1);  reciprocal = None
    unsqueeze = torch.ops.aten.unsqueeze.default(convert_element_type, -1);  convert_element_type = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, -1);  unsqueeze = None
    unsqueeze_2 = torch.ops.aten.unsqueeze.default(mul, -1);  mul = None
    unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, -1);  unsqueeze_2 = None
    sub = torch.ops.aten.sub.Tensor(convolution, unsqueeze_1);  unsqueeze_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(sub, unsqueeze_3);  sub = unsqueeze_3 = None
    unsqueeze_4 = torch.ops.aten.unsqueeze.default(primals_3, -1)
    unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, -1);  unsqueeze_4 = None
    mul_2 = torch.ops.aten.mul.Tensor(mul_1, unsqueeze_5);  mul_1 = unsqueeze_5 = None
    unsqueeze_6 = torch.ops.aten.unsqueeze.default(primals_4, -1);  primals_4 = None
    unsqueeze_7 = torch.ops.aten.unsqueeze.default(unsqueeze_6, -1);  unsqueeze_6 = None
    add_1 = torch.ops.aten.add.Tensor(mul_2, unsqueeze_7);  mul_2 = unsqueeze_7 = None
    relu = torch.ops.aten.relu.default(add_1);  add_1 = None
    alias = torch.ops.aten.alias.default(relu)
    alias_1 = torch.ops.aten.alias.default(alias);  alias = None
    alias_2 = torch.ops.aten.alias.default(alias_1);  alias_1 = None
    alias_3 = torch.ops.aten.alias.default(alias_2);  alias_2 = None
    le = torch.ops.aten.le.Scalar(alias_3, 0);  alias_3 = None
    scalar_tensor = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0))
    where = torch.ops.aten.where.self(le, scalar_tensor, tangents_1);  le = scalar_tensor = tangents_1 = None
    add_2 = torch.ops.aten.add.Tensor(primals_6, 1e-05);  primals_6 = None
    rsqrt = torch.ops.aten.rsqrt.default(add_2);  add_2 = None
    unsqueeze_8 = torch.ops.aten.unsqueeze.default(primals_5, 0);  primals_5 = None
    unsqueeze_9 = torch.ops.aten.unsqueeze.default(unsqueeze_8, 2);  unsqueeze_8 = None
    unsqueeze_10 = torch.ops.aten.unsqueeze.default(unsqueeze_9, 3);  unsqueeze_9 = None
    sum_1 = torch.ops.aten.sum.dim_IntList(where, [0, 2, 3])
    sub_1 = torch.ops.aten.sub.Tensor(convolution, unsqueeze_10);  convolution = unsqueeze_10 = None
    mul_3 = torch.ops.aten.mul.Tensor(where, sub_1);  sub_1 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(mul_3, [0, 2, 3]);  mul_3 = None
    mul_8 = torch.ops.aten.mul.Tensor(rsqrt, primals_3);  primals_3 = None
    unsqueeze_17 = torch.ops.aten.unsqueeze.default(mul_8, 0);  mul_8 = None
    unsqueeze_18 = torch.ops.aten.unsqueeze.default(unsqueeze_17, 2);  unsqueeze_17 = None
    unsqueeze_19 = torch.ops.aten.unsqueeze.default(unsqueeze_18, 3);  unsqueeze_18 = None
    mul_9 = torch.ops.aten.mul.Tensor(where, unsqueeze_19);  where = unsqueeze_19 = None
    mul_10 = torch.ops.aten.mul.Tensor(sum_2, rsqrt);  sum_2 = rsqrt = None
    sum_3 = torch.ops.aten.sum.dim_IntList(mul_9, [0, 2, 3])
    convolution_backward = torch.ops.aten.convolution_backward.default(mul_9, primals_8, primals_1, [16], [1, 1], [2, 2], [1, 1], False, [0, 0], 1, [True, True, False]);  mul_9 = primals_8 = primals_1 = None
    getitem = convolution_backward[0]
    getitem_1 = convolution_backward[1];  convolution_backward = None
    return pytree.tree_unflatten([relu, getitem_1, sum_3, mul_10, sum_1, None, None, None, getitem], self._out_spec)

In the inference scenario, the FX Graph changes: it first computes the mean and variance → folds them into the Conv weight and bias → then computes the conv. This corresponds exactly to the fusion that folds the BN weights into the Conv.

# FX Graph of Conv+BN in the inference scenario
def forward(self, primals, tangents):
    primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    add = torch.ops.aten.add.Tensor(primals_6, 1e-05);  primals_6 = None
    rsqrt = torch.ops.aten.rsqrt.default(add);  add = None
    view = torch.ops.aten.view.default(rsqrt, [-1, 1, 1, 1]);  rsqrt = None
    view_1 = torch.ops.aten.view.default(primals_3, [16, 1, 1, 1]);  primals_3 = None
    mul = torch.ops.aten.mul.Tensor(view_1, view);  view_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(primals_1, mul)
    view_2 = torch.ops.aten.view.default(mul, [16])
    sub = torch.ops.aten.sub.Tensor(primals_2, primals_5);  primals_2 = primals_5 = None
    mul_2 = torch.ops.aten.mul.Tensor(view_2, sub)
    add_1 = torch.ops.aten.add.Tensor(primals_4, mul_2);  primals_4 = mul_2 = None
    convolution = torch.ops.aten.convolution.default(primals_8, mul_1, add_1, [1, 1], [2, 2], [1, 1], False, [0, 0], 1);  add_1 = None
    relu = torch.ops.aten.relu.default(convolution);  convolution = None
    alias = torch.ops.aten.alias.default(relu)
    alias_1 = torch.ops.aten.alias.default(alias);  alias

