NV GPU FMA指令测试
- 一.小结
- 二.复现步骤
- 1.获取FMA指令的峰值性能、启动开销
- 2.假设固定开销为120个cycle,希望fma pipe利用率超过95%,需要多少条指令呢,求解以下不等式:
- 3.采用1140条fma指令测试
- 4.生成fatbin
- 5.修改SASS指令,删除掉STG.E.STRONG.SYS指令,重新生成fatbin
- 6.准备测试程序,加载fatbin并运行里面的Kernel
- 7.ncu profiling
- 8.将Kernel里的FMA指令增加4倍,一个smsp一个warp能打满利用率吗【不行】
本文测试了NV GPU FMA指令的行为
一.小结
-
哪怕一个空的Kernel,也有ULDC指令,从Constant Memory加载Context(>700cycle)和等待指令加载的stall(>100cycle)
根据fma的峰值性能,smsp的一个active cycle跟fma pipe cycle的比为1:2
如果一个smsp的fma pipe要达到峰值性能的95%,根据以下不等式:
(2*fma_inst) / ((fma_inst[eligible]+fma_inst[issued]) + 上面的开销[>800cycle]) > 0.95
得fma_inst>7600条指令 -
相同的指令条数,拆到4个warp里执行比放在同一个warp里执行,fma pipe利用率高2倍(本次实验的规模)
怀疑每一个warp slot里可以提前准备指令
如果只有一个warp slot在工作,指令准备与执行是串行的,导致 fma pipe工作不饱和
因此,一个warp里哪怕持续发射7600条fma指令,也打不满fma pipe -
测试以下二个规模(1, 1, 1)x(512, 1, 1) 和 (112, 1, 1)x(128, 1, 1),smsp.max的metrics一样
512=32(warpsize)*4(smsp)*4(每个smsp的4个warp slot都放上warp)
128=32(warpsize)*4(smsp)
112=28(sm个数)*4(每个sm放4个block)
也就说,对某一个smsp而言,二种方案都分到了4个warp,warp slot是不区分warp来自哪一个block
只要能放在warp slot中,性能都一样 -
对算子开发的启示:
启动开销远大于执行一条fma指令需要的cycle数,使得执行一个小kernel无法充分发挥SM的性能
每个SM至少放置(smsp个数*4*warpsize)个线程,才能充分隐藏smsp指令调度的latency
二.复现步骤
1.获取FMA指令的峰值性能、启动开销
tee ncu_get_gpu_peak_sustained.cu<<-'EOF'
#include <iostream>
#include <cuda_runtime.h>

// Micro-benchmark variant 0: one inline-PTX fma.rn.f32 per thread, result
// written back with an ordinary C++ global store indexed by the flat thread id.
//
// Launch contract: exactly one thread per element of `input`; there is no
// bounds guard, so the grid must not exceed the allocation. `d_out` is unused.
//
// NOTE(review): `d0` is read by the asm before it is ever written and `c` is
// never used. Both are left untouched so the generated instruction stream stays
// the one that was profiled; the numeric result is irrelevant to the measurement.
__global__ void fma_kernel_v0(float *input, float *d_out) {
    float a = clock();
    float b = clock();
    float c = clock();
    float d0;
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d0) : "f"(a), "f"(b), "f"(d0));
    input[tid] = d0;
}

// Micro-benchmark variant 1: same single fma.rn.f32, but the result is stored
// through an inline-PTX 128-bit vector store to the base of `input`.
//
// Every thread stores to the same address (`input` itself), so `input` must be
// 16-byte aligned and at least four floats long. `d_out` is unused.
//
// NOTE(review): d0..d3 are not initialised before being read; only the fact
// that the instructions are emitted matters here.
__global__ void fma_kernel_v1(float *input, float *d_out) {
    float d0;
    float d1;
    float d2;
    float d3;
    float a = clock();
    float b = clock();
    float c = clock();
    #pragma unroll
    for (int i = 0; i < 1; i++) {
        __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d0) : "f"(a), "f"(b), "f"(d0));
    }
    __asm__ __volatile__("st.global.v4.f32 [%0],{%1,%2,%3,%4};" :: "l"(input), "f"(d0), "f"(d1), "f"(d2), "f"(d3) : "memory");
}

// Reports a failed CUDA runtime call on stderr; returns true on success.
static bool cudaOk(cudaError_t status, const char *what) {
    if (status == cudaSuccess) {
        return true;
    }
    std::cerr << what << " failed: " << cudaGetErrorString(status) << std::endl;
    return false;
}

// Launches both variants with a single warp on a single block so that the
// profiler sees exactly one warp on one SM sub-partition.
// Returns 0 on success and 1 if any allocation, launch or execution failed.
int main() {
    float *d_in = nullptr;
    float *d_out = nullptr;
    int sm_count = 1;
    int smsp_count = 1;
    int warpsize = 32;
    int total_count = sm_count * smsp_count * warpsize;

    if (!cudaOk(cudaMalloc((void**)&d_in, total_count * sizeof(float)), "cudaMalloc(d_in)")) {
        return 1;
    }
    if (!cudaOk(cudaMalloc((void**)&d_out, total_count * sizeof(float)), "cudaMalloc(d_out)")) {
        cudaFree(d_in);
        return 1;
    }

    // A launch-configuration error only shows up through cudaGetLastError();
    // an in-kernel fault only shows up at the next synchronising call.
    fma_kernel_v0<<<sm_count, warpsize * smsp_count>>>(d_in, d_out);
    bool ok = cudaOk(cudaGetLastError(), "fma_kernel_v0 launch") &&
              cudaOk(cudaDeviceSynchronize(), "fma_kernel_v0 execution");

    if (ok) {
        fma_kernel_v1<<<sm_count, warpsize * smsp_count>>>(d_in, d_out);
        ok = cudaOk(cudaGetLastError(), "fma_kernel_v1 launch") &&
             cudaOk(cudaDeviceSynchronize(), "fma_kernel_v1 execution");
    }

    cudaFree(d_in);
    cudaFree(d_out);
    return ok ? 0 : 1;
}
EOF
# Build the executable that ncu will profile
/usr/local/cuda/bin/nvcc -std=c++17 -lineinfo ncu_get_gpu_peak_sustained.cu -o ncu_get_gpu_peak_sustained
# Generate PTX
/usr/local/cuda/bin/nvcc -std=c++17 -dc -lineinfo -arch=sm_86 -ptx ncu_get_gpu_peak_sustained.cu -o ncu_get_gpu_peak_sustained.ptx
# Generate cubin
/usr/local/cuda/bin/nvcc -arch=sm_86 ncu_get_gpu_peak_sustained.ptx -cubin -o ncu_get_gpu_peak_sustained.cubin
# Generate fatbin
/usr/local/cuda/bin/nvcc -arch=sm_86 ncu_get_gpu_peak_sustained.cubin -fatbin -o ncu_get_gpu_peak_sustained.fatbin
# Dump the SASS of the fatbin
/usr/local/cuda/bin/cuobjdump --dump-sass ncu_get_gpu_peak_sustained.fatbin

# Collect issue/execute counters, FMA pipe counters and every warp-stall reason;
# metrics that report "n/a" are filtered out of the listing
/usr/local/NVIDIA-Nsight-Compute/ncu --clock-control=none --metrics \
smsp__inst_issued.max,\
smsp__inst_executed.max,\
smsp__warps_eligible.max,\
smsp__cycles_elapsed.avg.per_second,\
smsp__cycles_elapsed.max,\
smsp__warps_active.max,\
smsp__issue_active.max,\
smsp__cycles_active.max,\
sm__cycles_active.max,\
sm__inst_executed_pipe_fma.max,\
smsp__inst_executed_pipe_fma.max,\
sm__sass_thread_inst_executed_op_ffma_pred_on.max,\
sm__pipe_fma_cycles_active.max,\
smsp__pipe_fma_cycles_active.max,\
sm__thread_inst_executed_pipe_fma_pred_on.max,\
smsp__pipe_fma_cycles_active.sum.peak_sustained,\
smsp__pipe_fma_cycles_active.avg.peak_sustained,\
smsp__pipe_fma_cycles_active.max.peak_sustained,\
sm__sass_thread_inst_executed_op_ffma_pred_on.sum.peak_sustained,\
sm__sass_thread_inst_executed_op_ffma_pred_on.avg.peak_sustained,\
smsp__inst_executed_pipe_fma.sum.peak_sustained,\
smsp__warps_issue_stalled_barrier.max,\
smsp__warps_issue_stalled_branch_resolving.max,\
smsp__warps_issue_stalled_dispatch_stall.max,\
smsp__warps_issue_stalled_drain.max,\
smsp__warps_issue_stalled_imc_miss.max,\
smsp__warps_issue_stalled_lg_throttle.max,\
smsp__warps_issue_stalled_long_scoreboard.max,\
smsp__warps_issue_stalled_long_scoreboard_pipe_l1tex.max,\
smsp__warps_issue_stalled_math_pipe_throttle.max,\
smsp__warps_issue_stalled_membar.max,\
smsp__warps_issue_stalled_mio_throttle.max,\
smsp__warps_issue_stalled_mio_throttle_pipe_mio.max,\
smsp__warps_issue_stalled_misc.max,\
smsp__warps_issue_stalled_no_instruction.max,\
smsp__warps_issue_stalled_not_selected.max,\
smsp__warps_issue_stalled_short_scoreboard.max,\
smsp__warps_issue_stalled_sleeping.max,\
smsp__warps_issue_stalled_tex_throttle.max,\
smsp__warps_issue_stalled_wait.max,\
smsp__warps_issue_stalled_selected.max,\
smsp__inst_executed_pipe_fma.avg.peak_sustained ./ncu_get_gpu_peak_sustained | grep -v "n/a"
输出
fma_kernel_v0(float *, float *) (1, 1, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
---------------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
---------------------------------------------------------------- ----------- ------------
sm__cycles_active.max cycle 998
sm__sass_thread_inst_executed_op_ffma_pred_on.avg.peak_sustained inst/cycle 128
sm__sass_thread_inst_executed_op_ffma_pred_on.max inst 32 # 1.实际只有一个warp,且只有一条fma sass指令
sm__sass_thread_inst_executed_op_ffma_pred_on.sum.peak_sustained inst/cycle 3,584 # fma峰值性能
sm__thread_inst_executed_pipe_fma_pred_on.max inst 96
smsp__cycles_active.max cycle 972
smsp__cycles_elapsed.avg.per_second Ghz 1.88
smsp__cycles_elapsed.max cycle 2,704
smsp__inst_executed.max inst 14
smsp__inst_executed_pipe_fma.avg.peak_sustained inst/cycle 1
smsp__inst_executed_pipe_fma.max inst 3 # 2.实际执行了3条fma warp指令
smsp__inst_executed_pipe_fma.sum.peak_sustained inst/cycle 112
smsp__inst_issued.max inst 18
smsp__issue_active.max cycle 18 # 发射条数比实际执行的多
smsp__pipe_fma_cycles_active.avg.peak_sustained 2
smsp__pipe_fma_cycles_active.max cycle 8 # 实际上3条fma用8个cycle
smsp__pipe_fma_cycles_active.max.peak_sustained 2 # 理论上一条fma指令需要2个cycle
smsp__pipe_fma_cycles_active.sum.peak_sustained 224 # 2*28(sm)*4(smsp)
smsp__warps_active.max warp 972
smsp__warps_eligible.max warp 18
smsp__warps_issue_stalled_branch_resolving.max warp 8
smsp__warps_issue_stalled_dispatch_stall.max warp 2
smsp__warps_issue_stalled_drain.max warp 21
smsp__warps_issue_stalled_imc_miss.max warp 750 #等待加载context数据
smsp__warps_issue_stalled_misc.max warp 2
smsp__warps_issue_stalled_no_instruction.max warp 114 #等待加载指令
smsp__warps_issue_stalled_selected.max warp 18
smsp__warps_issue_stalled_short_scoreboard.max warp 46 #等待从share memory加载数据
smsp__warps_issue_stalled_wait.max warp 39
---------------------------------------------------------------- ----------- ------------
fma_kernel_v1(float *, float *) (1, 1, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
Warning: Data collection happened without fixed GPU frequencies. Profiling results may be inconsistent.
Section: Command line profiler metrics
---------------------------------------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
---------------------------------------------------------------- ----------- ------------
sm__cycles_active.max cycle 1,031
sm__inst_executed_pipe_fma.max inst 1 # 正常了,一条fma指令2个cycle
sm__pipe_fma_cycles_active.max cycle 2
sm__sass_thread_inst_executed_op_ffma_pred_on.avg.peak_sustained inst/cycle 128
sm__sass_thread_inst_executed_op_ffma_pred_on.max inst 32
sm__sass_thread_inst_executed_op_ffma_pred_on.sum.peak_sustained inst/cycle 3,584
sm__thread_inst_executed_pipe_fma_pred_on.max inst 32
smsp__cycles_active.max cycle 1,000
smsp__cycles_elapsed.avg.per_second Ghz 1.88
smsp__cycles_elapsed.max cycle 2,711
smsp__inst_executed.max inst 11
smsp__inst_executed_pipe_fma.avg.peak_sustained inst/cycle 1
smsp__inst_executed_pipe_fma.max inst 1
smsp__inst_executed_pipe_fma.sum.peak_sustained inst/cycle 112
smsp__inst_issued.max inst 16
smsp__issue_active.max cycle 16
smsp__pipe_fma_cycles_active.avg.peak_sustained 2
smsp__pipe_fma_cycles_active.max cycle 2
smsp__pipe_fma_cycles_active.max.peak_sustained 2
smsp__pipe_fma_cycles_active.sum.peak_sustained 224
smsp__warps_active.max warp 1,000
smsp__warps_eligible.max warp 16
smsp__warps_issue_stalled_branch_resolving.max warp 8
smsp__warps_issue_stalled_drain.max warp 33
smsp__warps_issue_stalled_imc_miss.max warp 743
smsp__warps_issue_stalled_misc.max warp 1
smsp__warps_issue_stalled_no_instruction.max warp 157
smsp__warps_issue_stalled_selected.max warp 16
smsp__warps_issue_stalled_short_scoreboard.max warp 6
smsp__warps_issue_stalled_wait.max warp 33
---------------------------------------------------------------- ----------- ------------
小结
1.不同的使用方式,可能会导致执行重复发射
2.加载context和指令的开销不可避免,远大于执行一条fma指令需要的cycle数
3.FMA PIPE利用率100%时,每个smsp cycle,fma_cycles为2个cycle,即fma pipe需要二个cycle
4.smsp__pipe_fma_cycles_active(2)=smsp__cycles_active(1000)*2 时才能达到峰值性能
5.smsp__cycles_active(1000)=smsp__warps_active(1000)+其它开销(0)
6.一个warp可能同时处于多个smsp__warps_issue_stalled状态,因此不能准确知道一共stall了多长时间
7.smsp__warps_active(1000)=smsp__issue_active(16)+smsp__warps_eligible(16)+smsp__warps_issue_stalled*(>743)
8.假设去掉加载context的时间(实际不能去掉).这个简单的kernel,加载指令也需要100多个cycle,视它为固定开销
2.假设固定开销为120个cycle,希望fma pipe利用率超过95%,需要多少条指令呢,求解以下不等式:
tee solve.py<<-'EOF'
import sympy as sp
from sympy import Symbol, And
# Unknown: how many FMA instructions the kernel issues.
n = sp.symbols('n', positive=True)

# Model parameters taken from the surrounding text.
fixed_overhead_cycles = 120
target_utilization = 0.95

# Numerator: FMA-pipe cycles; denominator: n + n cycles plus the fixed overhead.
fma_pipe_cycles = 2*n
smsp_cycles = (n+n) + fixed_overhead_cycles
utilization_bound = fma_pipe_cycles / smsp_cycles > target_utilization

print(sp.solve([utilization_bound]))
EOF
python solve.py
输出
1140 < n #最少需要1140条fma指令
3.采用1140条fma指令测试
tee fma_kernel.cu<<-'EOF'
#include <iostream>
#include <cuda_runtime.h>
// FMA throughput kernel: every thread issues 4 * 4 * 72 = 1152 inline-PTX
// fma.rn.f32 instructions, spread over four accumulator arrays of COUNT
// elements, then stores the accumulators so the FMAs cannot be discarded.
//
// Launch contract: 1-D grid and block; `input` and `output` must cover indices
// up to (COUNT-1)*32 + (total threads - 1). There is no bounds guard.
//
// NOTE(review): d1..d3 are read by the asm without being initialised, and `c`
// is unused; both are left as-is so the emitted instruction stream is unchanged.
// NOTE(review): the 128-bit st.global.v4.f32 targets &output[i*32+tid], which
// is 16-byte aligned only when tid is a multiple of 4. The surrounding workflow
// deletes the STG instructions from the SASS before running, so the store never
// executes there; running this kernel unedited would need an aligned address.
__global__ void fma_kernel(float *input,float *output) {
#define COUNT 4
    float d0[COUNT];
    float d1[COUNT];
    float d2[COUNT];
    float d3[COUNT];
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    float a = clock();
    float b = clock();
    float c = clock();
    // 4 * 4 * 72 = 1152 FMA instructions
    #pragma unroll
    for (int j = 0; j < 72; j++) {
        #pragma unroll
        for (int i = 0; i < COUNT; i++) {
            d0[i] = input[i*32+tid];
            __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d0[i]) : "f"(a), "f"(b), "f"(d0[i]));
            __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d1[i]) : "f"(a), "f"(b), "f"(d1[i]));
            __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d2[i]) : "f"(a), "f"(b), "f"(d2[i]));
            __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d3[i]) : "f"(a), "f"(b), "f"(d3[i]));
        }
    }
    #pragma unroll
    for (int i = 0; i < COUNT; i++) {
        __asm__ __volatile__("st.global.v4.f32 [%0],{%1,%2,%3,%4};" :: "l"(&output[i*32+tid]), "f"(d0[i]), "f"(d1[i]), "f"(d2[i]), "f"(d3[i]) : "memory");
    }
}
EOF
4.生成fatbin
# Step 1: CUDA source -> PTX
/usr/local/cuda/bin/nvcc --std c++17 --device-c --generate-line-info --gpu-architecture=sm_86 --ptx fma_kernel.cu --output-file fma_kernel.ptx
# Step 2: PTX -> cubin
/usr/local/cuda/bin/nvcc --gpu-architecture=sm_86 fma_kernel.ptx --cubin --output-file fma_kernel.cubin
# Step 3: cubin -> fatbin
/usr/local/cuda/bin/nvcc --gpu-architecture=sm_86 fma_kernel.cubin --fatbin --output-file fma_kernel.fatbin
# Inspect the PTX
cat fma_kernel.ptx
# Inspect the SASS
/usr/local/cuda/bin/cuobjdump -sass fma_kernel.fatbin
5.修改SASS指令,删除掉STG.E.STRONG.SYS指令,重新生成fatbin
# Convert the cubin into an editable .cuasm listing
cuasm.py fma_kernel.cubin fma_kernel.cuasm
# Keep only the FMA instructions: delete every line that mentions one of the
# opcodes below (a pattern also removes lines containing it as a substring,
# e.g. S2R matches CS2R as well)
sed '/MOV/d' -i fma_kernel.cuasm
sed '/ULDC/d' -i fma_kernel.cuasm
sed '/STG/d' -i fma_kernel.cuasm
sed '/I2F/d' -i fma_kernel.cuasm
sed '/CS2R/d' -i fma_kernel.cuasm
sed '/BRA/d' -i fma_kernel.cuasm
sed '/LDG/d' -i fma_kernel.cuasm
sed '/IMAD/d' -i fma_kernel.cuasm
sed '/S2R/d' -i fma_kernel.cuasm
# Regenerate the cubin from the edited listing
cuasm.py fma_kernel.cuasm
# Generate fatbin
/usr/local/cuda/bin/nvcc -arch=sm_86 fma_kernel.cubin -fatbin -o fma_kernel.fatbin
6.准备测试程序,加载fatbin并运行里面的Kernel
tee fma_kernel_main.cpp<<-'EOF'
#include <stdio.h>
#include <string.h>
#include <cuda_runtime.h>
#include <cuda.h>int main(int argc,char *argv[])
{CUresult error;CUdevice cuDevice;cuInit(0);int deviceCount = 0;error = cuDeviceGetCount(&deviceCount);error = cuDeviceGet(&cuDevice, 0);if(error!=CUDA_SUCCESS){printf("Error happened in get device!\n");}CUcontext cuContext;error = cuCtxCreate(&cuContext, 0, cuDevice);if(error!=CUDA_SUCCESS){printf("Error happened in create context!\n");}CUmodule module;CUfunction function;const char* module_file = "fma_kernel.fatbin";const char* kernel_name = "_Z10fma_kernelPfS_";error = cuModuleLoad(&module, module_file);if(error!=CUDA_SUCCESS){printf("Error happened in load moudle %d!\n",error);}error = cuModuleGetFunction(&function, module, kernel_name);if(error!=CUDA_SUCCESS){printf("get function error!\n");}int data_size=sizeof(float)*8192;float *output_ptr=nullptr;float *input_ptr=nullptr;int cudaStatus=0;cudaStatus = cudaMalloc((void**)&input_ptr, data_size);cudaStatus = cudaMalloc((void**)&output_ptr, data_size);void *kernelParams[]= {(void*)&output_ptr, (void*)&input_ptr};cuLaunchKernel(function,1, 1, 1,32, 1, 1,0,0,kernelParams, 0);cuLaunchKernel(function,1, 1, 1,32*4, 1, 1,0,0,kernelParams, 0);cuLaunchKernel(function,1, 1, 1,32*4*2, 1, 1,0,0,kernelParams, 0);cuLaunchKernel(function,28*2, 1, 1,32*4, 1, 1,0,0,kernelParams, 0); cuLaunchKernel(function,1, 1, 1,32*4*4, 1, 1,0,0,kernelParams, 0);cuLaunchKernel(function,28*4, 1, 1,32*4, 1, 1,0,0,kernelParams, 0); cudaFree(output_ptr);cudaFree(input_ptr);cuModuleUnload(module);cuCtxDestroy(cuContext);return 0;
}
EOF
# Build the host-side loader; it links against both the CUDA runtime (cudart) and the driver (cuda) libraries
g++ fma_kernel_main.cpp -o fma_kernel_main -I /usr/local/cuda/include -L /usr/local/cuda/lib64 -lcudart -lcuda
7.ncu profiling
# Profile every kernel launched by fma_kernel_main, collecting the FMA-pipe active cycles and the SM / SMSP / warp active counters
/usr/local/NVIDIA-Nsight-Compute/ncu --clock-control=none --metrics \
smsp__pipe_fma_cycles_active.max,\
sm__pipe_fma_cycles_active.max,\
sm__cycles_active.max,\
smsp__warps_active.max,\
smsp__cycles_active.max ./fma_kernel_main
输出
# 只有一个smsp 且只有一个warp
fma_kernel(float *, float *) (1, 1, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
-------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------- ----------- ------------
sm__cycles_active.max cycle 1,999
sm__pipe_fma_cycles_active.max cycle 1,736
smsp__cycles_active.max cycle 1,981
smsp__pipe_fma_cycles_active.max cycle 1,736
smsp__warps_active.max warp 1,991
-------------------------------- ----------- ------------
fma_kernel(float *, float *) (1, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
-------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------- ----------- ------------
sm__cycles_active.max cycle 1,988
sm__pipe_fma_cycles_active.max cycle 6,944 # 4个smsp分别执行一个warp,同样的sm__cycles_active下fma性能提升了4倍
smsp__cycles_active.max cycle 1,964
smsp__pipe_fma_cycles_active.max cycle 1,736
smsp__warps_active.max warp 2,007
-------------------------------- ----------- ------------
# 1个sm 4个smsp,每个上分配2个warp
fma_kernel(float *, float *) (1, 1, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
-------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------- ----------- ------------
sm__cycles_active.max cycle 2,034
sm__pipe_fma_cycles_active.max cycle 13,888
smsp__cycles_active.max cycle 2,012 # 2012-1964=只增加了48个cycle,但FMA的性能翻倍(3472/1736) 但只有理论值的86% (3472/2012/2)
smsp__pipe_fma_cycles_active.max cycle 3,472
smsp__warps_active.max warp 3,951
-------------------------------- ----------- ------------
# 多个block跟
fma_kernel(float *, float *) (56, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
-------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------- ----------- ------------
sm__cycles_active.max cycle 2,000
sm__pipe_fma_cycles_active.max cycle 13,888
smsp__cycles_active.max cycle 1,982
smsp__pipe_fma_cycles_active.max cycle 3,472
smsp__warps_active.max warp 3,972
-------------------------------- ----------- ------------
# 1个sm 4个smsp,每个上分配4个warp
fma_kernel(float *, float *) (1, 1, 1)x(512, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
-------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------- ----------- ------------
sm__cycles_active.max cycle 3,657
sm__pipe_fma_cycles_active.max cycle 27,776
smsp__cycles_active.max cycle 3,634 #3634-2012=增加了1622个cycle,FMA性能翻倍(6944/3472) 达到理论性能的95%(6944/3634/2)
smsp__pipe_fma_cycles_active.max cycle 6,944
smsp__warps_active.max warp 11,444
-------------------------------- ----------- ------------
# 每个smsp 4个warp
fma_kernel(float *, float *) (112, 1, 1)x(128, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
-------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------- ----------- ------------
sm__cycles_active.max cycle 3,669
sm__pipe_fma_cycles_active.max cycle 27,776
smsp__cycles_active.max cycle 3,649
smsp__pipe_fma_cycles_active.max cycle 6,944
smsp__warps_active.max warp 10,981
-------------------------------- ----------- ------------
8.将Kernel里的FMA指令增加4倍,一个smsp一个warp能打满利用率吗【不行】
tee fma_kernel.cu<<-'EOF'
#include <iostream>
#include <cuda_runtime.h>

// FMA throughput kernel, 4x longer variant: every thread issues
// 4 * 4 * (72*4) = 4608 inline-PTX fma.rn.f32 instructions, spread over four
// accumulator arrays of COUNT elements, then stores the accumulators so the
// FMAs cannot be discarded.
//
// Launch contract: 1-D grid and block; `input` and `output` must cover indices
// up to (COUNT-1)*32 + (total threads - 1). There is no bounds guard.
//
// NOTE(review): d1..d3 are read by the asm without being initialised, and `c`
// is unused; both are left as-is so the emitted instruction stream is unchanged.
// NOTE(review): the 128-bit st.global.v4.f32 targets &output[i*32+tid], which
// is 16-byte aligned only when tid is a multiple of 4. The surrounding workflow
// deletes the STG instructions from the SASS before running, so the store never
// executes there; running this kernel unedited would need an aligned address.
__global__ void fma_kernel(float *input,float *output) {
#define COUNT 4
    float d0[COUNT];
    float d1[COUNT];
    float d2[COUNT];
    float d3[COUNT];
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    float a = clock();
    float b = clock();
    float c = clock();
    #pragma unroll
    for (int j = 0; j < 72*4; j++) {
        #pragma unroll
        for (int i = 0; i < COUNT; i++) {
            d0[i] = input[i*32+tid];
            __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d0[i]) : "f"(a), "f"(b), "f"(d0[i]));
            __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d1[i]) : "f"(a), "f"(b), "f"(d1[i]));
            __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d2[i]) : "f"(a), "f"(b), "f"(d2[i]));
            __asm__ __volatile__("fma.rn.f32 %0,%1,%2,%3;" : "=f"(d3[i]) : "f"(a), "f"(b), "f"(d3[i]));
        }
    }
    #pragma unroll
    for (int i = 0; i < COUNT; i++) {
        __asm__ __volatile__("st.global.v4.f32 [%0],{%1,%2,%3,%4};" :: "l"(&output[i*32+tid]), "f"(d0[i]), "f"(d1[i]), "f"(d2[i]), "f"(d3[i]) : "memory");
    }
}
EOF
# Generate PTX
/usr/local/cuda/bin/nvcc -std=c++17 -dc -lineinfo -arch=sm_86 -ptx fma_kernel.cu -o fma_kernel.ptx
# Generate cubin
/usr/local/cuda/bin/nvcc -arch=sm_86 fma_kernel.ptx -cubin -o fma_kernel.cubin
# Generate fatbin
/usr/local/cuda/bin/nvcc -arch=sm_86 fma_kernel.cubin -fatbin -o fma_kernel.fatbin
# Convert the cubin into an editable .cuasm listing
cuasm.py fma_kernel.cubin fma_kernel.cuasm
# Keep only the FMA instructions: delete every line that mentions one of the
# opcodes below (a pattern also removes lines containing it as a substring,
# e.g. S2R matches CS2R as well)
sed '/MOV/d' -i fma_kernel.cuasm
sed '/ULDC/d' -i fma_kernel.cuasm
sed '/STG/d' -i fma_kernel.cuasm
sed '/I2F/d' -i fma_kernel.cuasm
sed '/CS2R/d' -i fma_kernel.cuasm
sed '/BRA/d' -i fma_kernel.cuasm
sed '/LDG/d' -i fma_kernel.cuasm
sed '/IMAD/d' -i fma_kernel.cuasm
sed '/S2R/d' -i fma_kernel.cuasm
# Regenerate the cubin from the edited listing
cuasm.py fma_kernel.cuasm
# Generate fatbin
/usr/local/cuda/bin/nvcc -arch=sm_86 fma_kernel.cubin -fatbin -o fma_kernel.fatbin
# Profile the rebuilt kernel with the same metric set as before
/usr/local/NVIDIA-Nsight-Compute/ncu --clock-control=none --metrics \
smsp__pipe_fma_cycles_active.max,\
sm__pipe_fma_cycles_active.max,\
sm__cycles_active.max,\
smsp__warps_active.max,\
smsp__cycles_active.max ./fma_kernel_main
输出
fma_kernel(float *, float *) (1, 1, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.6
-------------------------------- ----------- ------------
Metric Name Metric Unit Metric Value
-------------------------------- ----------- ------------
sm__cycles_active.max cycle 7,677
sm__pipe_fma_cycles_active.max cycle 6,920
smsp__cycles_active.max cycle 7,659
smsp__pipe_fma_cycles_active.max cycle 6,920 # 不行
smsp__warps_active.max warp 7,399
-------------------------------- ----------- ------------