在跑自己的课题时,突然发现自己的模型的最终输出全是[Nan,Nan,Nan…],这时候就开始了逐步的排查:
参考链接:https://blog.csdn.net/qq_41682740/article/details/126304613。
一般模型输出为Nan有这么几个原因:
- 1.数据问题:输入数据中可能存在NaN或Inf(无穷大)的值,这可能会导致计算结果为NaN。建议检查输入数据,确保其数值范围正确。
- 2.模型结构问题:模型中的某些层或操作可能会导致NaN的输出。例如,当计算log或sqrt等函数的负数时,会出现NaN。建议检查模型结构,确保模型中的所有操作都是数值稳定的。
- 3.梯度问题:当模型的梯度变得非常大时,会导致计算结果为NaN。这可能是由于学习率设置过高或梯度爆炸导致的。建议检查学习率设置,或尝试使用梯度剪裁来控制梯度范围。
- 4.数据类型问题:当数据类型不匹配时,会出现NaN。例如,在计算浮点数时使用整数类型的数据,或在计算整数时使用浮点数类型的数据。建议检查数据类型,确保输入和模型参数的类型匹配。
顺着思路一个一个找就可以了,我这里直接检查模型的梯度:
for name, param in model.named_parameters():if param.requires_grad:print(name, param.data)if param.data!=param.data: #Nan数据的判断方法之一print('Nan')
发现梯度中有很多的Nan的产生。使用torch.autograd.detect_anomaly()
可以方便定位报错的位置,使用方式如下:
with torch.autograd.detect_anomaly():#反向传播的部分代码optimizer.zero_grad()loss_all.backward()
该方式会大幅度减慢模型的运行速度,但pytorch会自动定位至初步产生Nan的地方,方便进行进一步的检查。
最终我遇到的问题是模型的数据类型不对,我用到的clip模型是一个FP16的,这种训练很容易导致溢出产生Nan的数据,故转成FP32即可:
model.to(torch.float32)