【博士每天一篇文献-算法】持续学习之GEM算法:Gradient Episodic Memory for Continual Learning

embedded/2024/10/21 3:07:41/

1 介绍

年份:2017

期刊: Advances in neural information processing systems

引用量:2829

Lopez-Paz D, Ranzato M A. Gradient episodic memory for continual learning[J]. Advances in neural information processing systems, 2017, 30.

本文提出的算法是Gradient Episodic Memory (GEM),它通过维护一个存储先前任务样本的情节记忆,并利用梯度信息来避免在持续学习过程中对先前任务性能的损害,同时允许对先前任务有益的知识传递。

2 创新点

  1. 新的评价指标:提出了一套新的度量标准来评估模型在连续数据流上的性能,不仅包括测试准确率,还包括模型在不同任务间的知识迁移能力。
  2. Gradient Episodic Memory (GEM)模型:提出了一种新的持续学习模型GEM,它通过情节记忆来减轻遗忘问题,同时允许对之前任务的知识进行有益的传递。
  3. 非独立同分布数据的处理:GEM模型能够处理非独立同分布的数据,即数据是按任务顺序观察到的,而不是从固定的概率分布中独立抽取的。
  4. 遗忘和迁移学习的优化:GEM通过优化策略在减少对先前任务性能的负面影响(即减轻遗忘)的同时,还试图实现对新任务的正向迁移。
  5. 梯度投影方法:GEM使用梯度投影方法来确保模型在更新参数时不会增加先前任务的损失,从而避免灾难性遗忘。
  6. 高效的计算:GEM的计算效率来自于它优化的是任务数量级别的变量,而不是模型参数数量级别的变量,这大大减少了计算量。

3 算法

3.1 算法原理

  1. 情节记忆(Episodic Memory):
    • GEM算法维护一个情节记忆,该记忆存储了之前任务的样本。这些样本有助于在后续任务中避免遗忘先前学到的知识。
  2. 任务描述符(Task Descriptors):
    • 算法使用任务描述符来识别每个样本所属的任务。这有助于模型区分和处理不同的任务。
  3. 非独立同分布数据(Non-iid Data):
    • 算法设计来处理非独立同分布的数据,即数据是按任务顺序观察的,而不是随机抽取的。
  4. 梯度约束(Gradient Constraints):
    • 在训练过程中,GEM使用梯度信息来确保新任务的学习不会增加先前任务的损失,从而避免灾难性遗忘。
  5. 梯度投影(Gradient Projection):
    • GEM通过投影梯度到满足所有梯度约束的方向上来更新模型参数。这确保了对先前任务性能的保护。
  6. 正向和反向迁移(Forward and Backward Transfer):
    • GEM旨在最小化对先前任务的负面影响(反向迁移),同时允许对新任务的正向迁移。
  7. 优化问题(Optimization Problem):
    • GEM算法通过解决一个优化问题来更新模型参数,该问题包括最小化当前任务的损失,同时满足先前任务损失不增加的约束。
  8. 二次规划(Quadratic Programming):
    • GEM算法使用二次规划来有效地解决梯度投影问题,这涉及到将原始问题转化为对任务数量(而不是参数数量)进行优化的对偶问题。
  9. 评估指标(Evaluation Metrics):
    • 算法定义了一套评估指标,包括平均准确率(ACC)、反向迁移(BWT)和正向迁移(FWT),来衡量模型在连续学习任务中的表现。

3.4 算法步骤

  1. 初始化
    • 对于每个任务t,初始化一个空的情节记忆$ M_t $,用于存储该任务的部分样本。
  2. 观察数据
    • 模型按顺序观察数据,每个数据点由输入特征向量$ x_i 、任务描述符 、任务描述符 、任务描述符 t_i 和目标向量 和目标向量 和目标向量 y_i $组成。
  3. 更新情节记忆
    • 对于每个新观察到的样本,根据设定的存储策略(如存储最新的m个样本)更新对应任务的情节记忆。
  4. 计算当前任务的梯度
    • 对当前任务的样本计算梯度g,该梯度指向模型参数更新的方向。
  5. 计算先前任务的梯度
    • 对于所有先前任务k < t,计算在先前任务的样本在当前参数更新下的梯度$ g_k $ 。
  6. 梯度投影
    • 将当前梯度g投影到满足所有先前任务梯度约束的方向上,即找到一个更新方向$ \tilde{g} $,使得对所有先前任务的损失都不会增加。
    • 梯度投影通过计算当前梯度与先前任务梯度之间的夹角,来确定参数更新方向,使得新任务的学习不会损害旧任务的性能。
    • 梯度投影问题转化为一个二次规划(Quadratic Programming, QP)问题。目标是最小化 g~ 与原始梯度 g 之间的欧几里得距离,同时满足它与所有先前任务的梯度 gkgk 之间的夹角满足一定条件(通常是非负条件),以保证先前任务的损失不会增加。
      • $ min⁡ \hat{g} \frac{1}{2}∥\hat{g}−g∥^2_2 \
        subject to⟨ \hat{g},g_k⟩≥ 0 $
    • 二次规划问题可以通过多种数值方法求解,如内点法、梯度投影法或使用现有的QP求解器。在GEM算法中,作者提出了一种基于对偶问题的求解方法,通过求解对偶问题来找到原始问题的解。
  7. 参数更新
    • 使用投影后的梯度$ \tilde{g} $更新模型参数。
  8. 评估模型
    • 在每个任务学习完成后,使用测试集评估模型在所有任务上的性能,并记录评估结果。
  9. 计算评估指标
    • 根据评估结果计算平均准确率(ACC)、反向迁移(BWT)和正向迁移(FWT)。
  10. 迭代学习
    • 重复步骤2至9,直到所有任务的数据都被观察完毕。
  11. 返回模型和评估结果
    • 返回训练好的模型和评估矩阵( R ),其中包含了模型在连续学习过程中的性能指标。

5 实验分析

  1. 数据集
    • 使用了MNIST Permutations、MNIST Rotations和Incremental CIFAR100等变体数据集进行实验。
  2. 模型架构
    • 对于MNIST任务,使用了两层100个ReLU单元的全连接神经网络。
    • 对于CIFAR100任务,使用了较小版本的ResNet18,并为每个任务添加了一个最终的线性分类器。
  3. 与现有技术的比较
    • 比较了GEM与五种其他方法:单一预测器、每个任务一个独立预测器、多峰值预测器、EWC(弹性权重合并)和iCARL。
  4. 性能指标
    • 考虑了平均准确率(ACC)、反向迁移(BWT)和正向迁移(FWT)作为性能指标。
  5. 实验结果
    • GEM在所有考虑的数据集上都显示出与多峰值模型相似或更好的性能,并且在反向迁移方面表现得更好,同时显示出轻微的或积极的正向迁移。
  6. 遗忘和迁移
    • GEM在CIFAR100数据集上展示了最小的遗忘,并在多个任务中对第一个任务的测试准确度表现出积极的反向迁移。
  7. 计算效率
    • GEM在计算上比其他持续学习方法(如EWC)更有效,并且在MNIST实验中的CPU训练时间更少。
  8. 记忆大小的影响
    • 在CIFAR100实验中,GEM的最终ACC随着情节记忆大小的增加而增加,表明GEM对记忆大小的调整更为鲁棒。
  9. 训练次数的影响
    • 在MNIST Rotations实验中,与没有记忆的方法相比,基于记忆的方法(如EWC和GEM)在数据上进行多次训练时表现出更高的ACC和更低的负面BWT。
  10. 与理想性能的比较
    • GEM在MNIST Rotations实验中的表现与通过所有任务的iid数据训练的单一预测器相当,达到了“理想性能上限”。

6 思考

(1)是先求当前任务样本的梯度?还是先求先前任务样本的梯度

先求计算当前任务样本的梯度,再计算先前任务样本的梯度

(2)本文中计算先前任务的梯度,是根据当前先前任务样本在当前模型上的计算得到。而不是训练之前任务存储的。


http://www.ppmy.cn/embedded/129164.html

相关文章

python中堆的用法

Python 堆&#xff08;Headp&#xff09; Python中堆是一种基于二叉树存储的数据结构。 主要应用场景&#xff1a; 对一个序列数据的操作基于排序的操作场景&#xff0c;例如序列数据基于最大值最小值进行的操作。 堆的数据结构&#xff1a; Python 中堆是一颗平衡二叉树&am…

K8s高级调度--CronJob与污点容忍及亲和力

文章目录 CronJobCronJob 的核心概念Job调度时间表并发策略启动历史保留 CronJob YAML 配置示例Cron 表达式 CronJob 实际应用场景定期数据备份日志清理任务 污点和容忍污点的概念污点的三种效应污点和容忍的工作流程设置污点和容忍1. 给节点添加污点2. 给 Pod 添加容忍 实际应…

LeetCode搜索插入位置

题目描述 给定一个排序数组和一个目标值&#xff0c;在数组中找到目标值&#xff0c;并返回其索引。如果目标值不存在于数组中&#xff0c;返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 示例 1: 输入: nums [1,3,5,6], target 5 输出: 2 …

六、存储过程和触发器及视图和临时表

一. 存储过程和触发器是数据库中用于实现复杂业务逻辑和自动化操作的重要工具。 下面是对存储过程和触发器的详细讲解和示例说明&#xff1a;存储过程&#xff1a; 存储过程是一组预定义的SQL语句&#xff0c;封装在数据库中并可通过名称调用。存储过程可以接受输入参数和输出…

Win10+Python3.8+GPU版tensorflow2.x环境搭建最简流程(转载学习用)

在开始之前&#xff0c;请确保你的计算机已经安装了Windows 10操作系统&#xff0c;并且具备一个支持CUDA的NVIDIA显卡。 步骤1&#xff1a;安装Python 3.8 你可以选择从Python官网下载Python 3.8的安装包。在下载过程中&#xff0c;请确保勾选“Add Python to PATH”选项&…

gc cr/current block 2-way

官方文档描述 14.9.4 Analyzing Cache Fusion Transfer Impact Using GCS Statistics Describes how to monitor GCS performance by identifying objects read and modified frequently and the service times imposed by the remote access. Waiting for blocks to arrive ma…

java通过模板实现导出

/*** 导出作业票角度统计*/Log(title "导出作业票角度统计", businessType BusinessType.EXPORT)PostMapping("/export")public void export(HttpServletResponse response, PlanWiDto dto) throws IOException {try {ExcelUtil.createExcel(response, &…

线性可分支持向量机的原理推导 9-19基于拉格朗日函数L(w,b,α) 对b求偏导 公式解析

本文是将文章《线性可分支持向量机的原理推导》中的公式单独拿出来做一个详细的解析&#xff0c;便于初学者更好的理解。 公式 9-19 是对拉格朗日函数 L ( w , b , α ) L(\mathbf{w}, b, \alpha) L(w,b,α) 中的偏导数进行求解&#xff0c;目的是找到拉格朗日函数对 b b b 的…