在使用fvcore模块计算模型的flops时,遇到了标题中的问题,记录一下解决方案。
首先是在jit_analysis.py的589行出错。经过调试发现,op_counts.values()的类型是int32,但是计算要求的类型只能是int、float、np.float64和np.int64,因此需要做如下修改:
inputs, outputs = list(node.inputs()), list(node.outputs())op_counts = self._op_handles[kind](inputs, outputs)if isinstance(op_counts, Number):op_counts = float(op_counts) # 手动进行强制转换op_counts = Counter({self._simplify_op_name(kind): op_counts})for v in op_counts.values():if not isinstance(v, (int, float, np.float64, np.int64)):raise ValueError(f"Invalid type {type(v)} for the flop count! ""Please use a wider type to avoid overflow.")
添加了注释的部分为我手动修改。
然后,在jit_handles.py处有一个警告,意思是:变量的范围不够大,值出现了溢出(这可能会导致最终计算结果为0)。解决方案在https://github.com/facebookresearch/fvcore/issues/104
如果进不去github,改进代码如下:
try:from math import prod
except ImportError:from numpy import prod as prodnp # 修改def prod(x): # 新增return int(prodnp(x))