发表博客之:weight only int8 详细讲解,小白都可以看得懂,不懂请来打我!
- 考虑一个模型中有一个Gemm Op,有两个输入,假设都是fp16数据类型吧!
- input0是 [ M , K ] [M,K] [M,K],input1是 [ K , N ] [K,N] [K,N]且input1是个不变的权重,
- 比如下面这样的case, M = 85 ; K = 5120 , N = 15360 ; M = 85; K = 5120, N = 15360; M=85;K=5120,N=15360;,这样的case中,input1这个权重正常都是很大的,因此很占显存,要是能有啥办法减少这部分的显存就好了。
- weight only int8就可以将input1的权重变成int8,这样就可以减少一半显存了。
- 实现方式很简单啊,那就是对 [ K , N ] f p 16 [K,N]_{fp16} [K,N]fp16这个矩阵,然后每列每列的方式进行量化到int8即可,这样每列都被搞成了int8数据类型了
- 同时需要记录下每列的scale,方便运算的时候反量化到fp16参与运算!
- 至于scale数据类型是fp16还是fp32呢,这随你啦!
- 比如暂定为fp32吧!
- 运算的时候,除了有 [ K , N ] i n t 8 [K,N]_{int8} [K,N]int8这个矩阵外,还有一个 [ N ] f p 32 [N]_{fp32} [N]fp32的scale!