FisherTrainer
的自定义 Trainer
类:累积梯度的平方并求平均来近似计算 Fisher 信息矩阵
用于计算模型参数的 Fisher 信息矩阵的近似值
整体目标
Fisher 信息矩阵用于衡量模型参数的不确定性,其在优化问题中可以帮助我们更准确地更新模型参数,避免陷入局部最优。在代码中,我们通过累积梯度的平方并求平均来近似计算 Fisher 信息矩阵。
代码各部分数学原理分析
1. 初始化部分
self.gradient_squared_sum = {name: torch.zeros_like(param)