现象: loss不下降
过程如下:
1.减少层数,准备最小复现环境
2.dropout设置为0,重复运行二次,对比loss是否一致
3.第二次迭代开始loss不一致
4.对比backward之后的梯度,发现某一个梯度不一致
5.dump得到所有算子的规模,单算子测试功能正常
6.怀疑是内存越界导致
7.排除通信库的问题,逐算子bypass
8.dump reduce_scatter的输入,发现每次都不样
9.在异常的时候pause进程,在python调用reduce_scatter的位置打印调用栈
10.定位到有问题的模块,是一个融合算子
11.用普通算子替换,结果一致
12.复测这个规模的融合算子功能正常
13.怀疑算子内部有内存踩踏行为
14.将输入类型从fp16改为fp32,结果正常
15.review该算子内部实现,确实有几行代码将输入当fp32处理