Abstract
本文将多任务学习中的梯度组合步骤视为一种讨价还价式博弈(bargaining game),通过游戏,各个任务协商出共识梯度更新方向。
在一定条件下,这种问题具有唯一解(Nash Bargaining Solution),可以作为多任务学习中的一种原则方法。
本文提出Nash-MTL,推导了其收敛性的理论保证。
1 Introduction
大部分MTL优化算法遵循一个通用方案。
- 计算所有任务的梯度 g 1 , g 2 , ⋯ , g K g_1,g_2,\cdots,g_K g1,g2,⋯,gK
- 使用某种聚合算法 A \mathcal{A} A,聚合梯度,得到联合梯度 Δ = A ( g 1 , ⋯ , g K ) \Delta=\mathcal{A}(g_1,\cdots,g_K) Δ=A(g1,⋯,gK)。最后采用单梯度优化算法更新模型参数。
目前还没有原则性的、公理化的聚合方法。
本文将梯度组合视为一个合作讨价还价式博弈(cooperative bargaining game)来解决,每个玩家代表多任务中的一个任务,每个玩家的收益(utility)是梯度,所有玩家通过协商找到彼此达成一致的方向。
这个情景让讨价还价式博弈可以使用,从公理化的角度分析该问题。
在一定的公理下,讨价还价式博弈有唯一的解,称为纳什讨价还价解(Nash Bargaining Solution),这个解式最公平的,是最优的。
贡献:
- 本文刻画了MTL的纳什讨价还价解,推导了一个有效的算法逼近这个值。
- 从理论分析了本文方法,在凸和非凸的情况下建立了收敛性保证。
- 实验表明Nash-MTL取得了最先进的效果。
2 Background
2.1 Pareto Optimality
MTL优化问题是多目标优化问题(multiple-objective optimization, MOO)的一个特例。
给定目标函数 ℓ 1 , ⋯ , ℓ K \ell_1,\cdots,\ell_K ℓ1,⋯,ℓK,一个解 x x x的效果可以通过目标值 ( ℓ 1 ( x ) , ⋯ , ℓ K ( x ) ) (\ell_1(x),\cdots,\ell_K(x)) (ℓ1(x),⋯,ℓK(x))向量来表示。
MOO的主要性质是:由于向量上不存在自然的线性排序,因此并不总是可以比较解,因此没有明确的最优值。
我们说一个解 x x x优于 x ′ x' x′,当且仅当 x x x在一个或多个任务上更好,而在其他任务上不差。
没有其他解更优的解,称为Pareto optimal,所有这样的解的集合成为Pareto front。在没有额外假设或用户偏好先验的情况下,无法从Pareto optimal中挑出最优解。
对于非凸问题,如果某点在包含它的某个开集内是Pareto最优的,则定义该点为局部Pareto最优。
如果某点存在梯度的凸组合,且梯度为0,则该点为Pareto stationary。Pareto stationary是Pareto optimal的必要条件。
2.2 Nash Bargaining Solution
在一个讨价还价博弈问题中,有 K K K个玩家,每个玩家的收益函数 u i : A ∪ { D } → R u_i:A\cup\{D\}\rightarrow\mathbb{R} ui:A∪{D}→R,这个是每个玩家都希望最大化的。 其中 A A A是可能达成的协议的集合, D D D是不能达成协议的谈判破裂点,如果玩家没能达成协议,则玩家会默认 D D D。
定义可能的收益集 U = { ( u 1 ( x ) , ⋯ , u K ( x ) ) : x ∈ A ⊂ R K U=\{(u_1(x),\cdots,u_K(x)):x\in A\subset\mathbb{R}^K U={(u1(x),⋯,uK(x)):x∈A⊂RK, d = ( u 1 ( D ) , ⋯ , u K ( D ) ) d=(u_1(D),\cdots,u_K(D)) d=(u1(D),⋯,uK(D))。
假设 U U U是凸紧的, U U U中存在一个点严格优于 d d d,称为存在 u ∈ U u\in U u∈U,使得 ∀ i : u i > d i \forall i: u_i>d_i ∀i:ui>di。
对于这样的收益集 U U U,两人讨价还价问题存在唯一解,该解满足以下性质或公理:Pareto optimality,对称性,无关方案独立性,仿射变换不变性。
Axiom 2.1 Pareto optimality
被认同的方案不能劣于其他方案。
Axiom 2.2 Symmetry
交换玩家的顺序后,最优解应当不变。
Axiom 2.3 Independence of irrelevant alternatives(IIA)
将收益集 U U U扩大到 U ~ ⊋ U \tilde{U}\supsetneq U U~⊋U,解决方案在原始集合 U U U中, u ∗ ∈ U u^*\in U u∗∈U,那么最优解将仍是 u ∗ u^* u∗。
Axiom 2.4 Invariance to affine transformation
将收益函数 u i ( x ) u_i(x) ui(x)变换成 u ~ i ( x ) = c i ⋅ u i ( x ) + b i \tilde{u}_i(x)=c_i\cdot u_i(x)+b_i u~i(x)=ci⋅ui(x)+bi, c i > 0 c_i>0 ci>0,如果原始最优解的收益为 ( y 1 , ⋯ , y k ) (y_1,\cdots,y_k) (y1,⋯,yk),那么变换后的最优解是 ( c 1 y 1 + b 1 , ⋯ , c k y k + b k ) (c_1y_1+b_1,\cdots,c_ky_k+b_k) (c1y1+b1,⋯,ckyk+bk)。
满足以上公理的唯一点被称为Nash bargaining solution,为
u ∗ = arg max u ∈ U ∑ i log ( u i − d i ) s . t . ∀ i : u i > d i (1) u*=\arg\max_{u\in U}\sum_i\log(u_i-d_i) \ s.t. \forall i:u_i>d_i\tag{1} u∗=argu∈Umaxi∑log(ui−di) s.t.∀i:ui>di(1)
3 Method
3.1 Nash Bargaining Multi-Task Learning
给定一个MTL优化问题和模型参数 θ \theta θ,目标是在以零点为中心,半径为 ϵ \epsilon ϵ的球 B ϵ B_\epsilon Bϵ内找到一个更新向量 Δ θ \Delta\theta Δθ。
在讨价还价博弈情景下,可达成的协议为 B ϵ B_\epsilon Bϵ集合,谈判破裂点在零点(原参数 θ \theta θ不更新)。
定义每个玩家的收益函数为 u i ( Δ θ ) = g i T Δ θ u_i(\Delta\theta)=g_i^T\Delta\theta ui(Δθ)=giTΔθ,其中 g i g_i gi是模型参数为 θ \theta θ时任务 i i i的损失梯度。由于收益集是凸紧的,且收益是线性的,可以得出:可能的收益集合也是凸紧的。
基于主要假设,如果 θ \theta θ不是Pareto stationary,那么梯度是线性无关的。
在此猜想下,谈判崩裂点 Δ θ = 0 \Delta\theta=0 Δθ=0是列于 B ϵ B_\epsilon Bϵ中其他的解的。
如果 θ \theta θ不在Pareto front中,那么Nash bargaining solution具有如下形式:
Claim 3.1
令 G G G为一个 d × K d\times K d×K的矩阵,该矩阵第 i i i列为梯度 g i g_i gi。
arg max Δ θ ∈ B ϵ ∑ i log ( Δ θ T g i ) \arg\max_{\Delta\theta\in B\epsilon}\sum_i\log(\Delta\theta^T g_i) argmaxΔθ∈Bϵ∑ilog(ΔθTgi)的解是 ∑ i α i g i \sum_i \alpha_ig_i ∑iαigi,其中 α ∈ R + K \alpha\in\mathbb{R}_+^K α∈R+K是 G T G α = 1 / α G^TG\alpha=1/\alpha GTGα=1/α的解, 1 / α 1/\alpha 1/α是逐元素倒数操作。
proof
该目标函数的导数是 ∑ i = 1 K 1 Δ θ T g i g i \sum_{i=1}^K \frac{1}{\Delta\theta^T g_i}g_i ∑i=1KΔθTgi1gi。对于所有 Δ θ \Delta\theta Δθ向量, ∀ i : Δ θ T g i > 0 \forall i:\Delta\theta^T g_i>0 ∀i:ΔθTgi>0,每个任务的收益函数以 Δ θ \Delta\theta Δθ的范数单调递增,显然 B ϵ B_\epsilon Bϵ球面上的解肯定是最优的。因此,最优点上的梯度 ∑ i = 1 K 1 Δ θ T g i g i \sum_{i=1}^K \frac{1}{\Delta\theta^T g_i}g_i ∑i=1KΔθTgi1gi一定是径向的,如 ∑ i = 1 K 1 Δ θ T g i g i = λ Δ θ \sum_{i=1}^K\frac{1}{\Delta\theta^T g_i}g_i=\lambda \Delta\theta ∑i=1KΔθTgi1gi=λΔθ。
由于梯度之间互相独立,有 Δ θ = ∑ i α i g i \Delta\theta=\sum_i\alpha_i g_i Δθ=∑iαigi, ∀ i : 1 Δ θ T g i = λ α i \forall i:\frac{1}{\Delta\theta^T g_i}=\lambda \alpha_i ∀i:ΔθTgi1=λαi。(向量之间线性无关)
下降方向内积为正,因此可以得到 λ > 0 \lambda>0 λ>0。设定 λ = 1 \lambda=1 λ=1来确定 Δ θ \Delta\theta Δθ(范数可能更大)的方向。
现在找到bargaining solution的问题已经简化为找到一个 α ∈ R K \alpha\in\mathbb{R}^K α∈RK, α i > 0 \alpha_i>0 αi>0,使得 ∀ i : Δ θ T g i = ∑ j α j g j T g i = 1 α i \forall i:\Delta\theta^T g_i=\sum_j\alpha_j g_j^T g_i=\frac{1}{\alpha_i} ∀i:ΔθTgi=∑jαjgjTgi=αi1,这等价于 G T G α = 1 / α G^TG\alpha=1/\alpha GTGα=1/α,其中 1 / α 1/\alpha 1/α是逐元素取倒数。
现在为该解提供一些直观的说明。
首先,如果所有的 g i g_i gi是互相正交的,则有 α i = 1 / ∣ ∣ g i ∣ ∣ \alpha_i=1/||g_i|| αi=1/∣∣gi∣∣, Δ θ = ∑ g i ∣ ∣ g i ∣ ∣ \Delta\theta=\sum \frac{g_i}{||g_i||} Δθ=∑∣∣gi∣∣gi。这是明显的尺度不变解。
如果非相互正交,可得:
α i ∣ ∣ g i ∣ ∣ 2 + ∑ j ≠ i α j g j T g i = 1 α i (2) \alpha_i||g_i||^2+\sum_{j\neq i}\alpha_j g_j^T g_i=\frac{1}{\alpha_i}\tag{2} αi∣∣gi∣∣2+j=i∑αjgjTgi=αi1(2)
∑ j ≠ i α j g j T g i = ( ∑ j ≠ i α j g j ) T g i \sum_{j\neq i}\alpha_j g_j^T g_i=(\sum_{j\neq i}\alpha_j g_j)^T g_i ∑j=iαjgjTgi=(∑j=iαjgj)Tgi可以被认为是任务 i i i对其他任务的影响。
- 如果这个值是正值,说明存在正向影响,其他任务的梯度有助于第 i i i项任务。
- 如果这个值是负值,说明存在负面影响,其他任务的梯度有碍于第 i i i项任务。
当该值为负值时,Eq.2等式左边变小,需要通过 α i \alpha_i αi变大来补偿。
当该值为正值时, α i \alpha_i αi变小。
3.2 Solving G T G α = 1 / α G^T G\alpha=1/\alpha GTGα=1/α
本节描述如何通过一系列凸优化问题有效逼近 G T G α = 1 / α G^TG\alpha=1/\alpha GTGα=1/α的最优解。
定义 β i ( α ) = g i T G α \beta_i(\alpha)=g_i^TG\alpha βi(α)=giTGα,希望找到一个 α \alpha α使得 ∀ i , α i = 1 / β i \forall i, \alpha_i=1/\beta_i ∀i,αi=1/βi,或等价于 log ( α i ) + log ( β i ( α i ) ) = 0 \log(\alpha_i)+\log(\beta_i(\alpha_i))=0 log(αi)+log(βi(αi))=0。
令 φ i ( α ) = log ( α i ) + log ( β i ( α ) ) \varphi_i(\alpha)=\log(\alpha_i)+\log(\beta_i(\alpha)) φi(α)=log(αi)+log(βi(α)), φ ( α ) = ∑ i φ i ( α ) \varphi(\alpha)=\sum_i\varphi_i(\alpha) φ(α)=∑iφi(α),目标是找到非负 α \alpha α使得 ∀ i , φ i ( α ) = 0 \forall i,\varphi_i(\alpha)=0 ∀i,φi(α)=0。于是优化问题变成:
min α ∑ i φ i ( α ) , s . t . ∀ i , − φ i ( α ) ≤ 0 , α i > 0 (3) \min_\alpha\sum_i\varphi_i(\alpha),\ s.t.\forall i, -\varphi_i(\alpha)\leq 0, \ \alpha_i>0\tag{3} αmini∑φi(α), s.t.∀i,−φi(α)≤0, αi>0(3)
约束是凸的且线性的,但是目标函数是凹的。首先尝试解决下面的凸目标函数:
min α ∑ i β i ( α ) , s . t . ∀ i , − φ i ( α ) ≤ 0 , α i > 0 (4) \min_\alpha\sum_i\beta_i(\alpha),\ s.t.\forall i, -\varphi_i(\alpha)\leq 0, \ \alpha_i>0\tag{4} αmini∑βi(α), s.t.∀i,−φi(α)≤0, αi>0(4)
这里最小化 β i = g i T G α ≥ 1 / α i \beta_i=g_i^TG\alpha\geq 1/\alpha_i βi=giTGα≥1/αi约束下的 ∑ i β i \sum_i\beta_i ∑iβi。虽然这个目标函数并不等价于原始问题,但却非常有效。很多情况下,得到的 φ ( α ) = 0 \varphi(\alpha)=0 φ(α)=0,符合需求。
为了进一步近似,考虑下面的问题:
min α ∑ i β i ( α ) + φ ( α ) , s . t . ∀ i , − φ i ( α ) ≤ 0 , α i > 0 (5) \min_\alpha\sum_i\beta_i(\alpha)+\varphi(\alpha),\ s.t.\forall i, -\varphi_i(\alpha)\leq 0, \ \alpha_i>0\tag{5} αmini∑βi(α)+φ(α), s.t.∀i,−φi(α)≤0, αi>0(5)
在目标函数中加入 φ ( α ) \varphi(\alpha) φ(α)可以进一步减小 φ ( α ) \varphi(\alpha) φ(α),虽然这可能导致问题是非凸的。但此时解可以被迭代地改进,通过将凹项 φ ( α ) \varphi(\alpha) φ(α)替换为其一阶近似 φ ~ τ ( α ) = φ ( α ( τ ) ) + ∇ φ ( α ( τ ) ) T ( α − α ( τ ) ) \tilde{\varphi}_\tau(\alpha)=\varphi(\alpha^{(\tau)})+\nabla\varphi(\alpha^{(\tau)})^T(\alpha-\alpha^{(\tau)}) φ~τ(α)=φ(α(τ))+∇φ(α(τ))T(α−α(τ))(泰勒展开)。其中, α ( τ ) \alpha^{(\tau)} α(τ)是第 τ \tau τ轮迭代的解。这里只替代目标函数中的 φ \varphi φ,不替代约束中的。由于没有改变约束,对任意的 τ \tau τ, α ( τ ) \alpha^{(\tau)} α(τ)总是满足原问题的约束。
最后,下面的命题表明,原始目标随 τ \tau τ单调递减:
Proposition 3.2
在Eq.5的优化问题中,将目标函数表示为 φ ( α ) = ∑ i β i ( α ) + φ ( α ) \varphi(\alpha)=\sum_i\beta_i(\alpha)+\varphi(\alpha) φ(α)=∑iβi(α)+φ(α)。于是对于所有 τ > 1 \tau>1 τ>1, φ ( α ( τ + 1 ) ) ≤ φ ( α τ ) \varphi(\alpha^{(\tau+1)})\leq \varphi(\alpha^{\tau}) φ(α(τ+1))≤φ(ατ)。
3.3 Practical Speedup
许多主流MTL方法的缺点是需要所有任务梯度来获取联合更新的方向。当任务数量 K K K很大时,非常耗费计算资源。
实际操作中发现,使用特征级梯度作为共享参数的替代会显著降低本文方法的性能。
本文提议:每隔几次迭代,更新一次梯度权重 α ( t ) \alpha^{(t)} α(t),而不是每次迭代。这种方法在维持原有效果的同时显著降低运行时间。
Algorithm 1 Nash-MTL
输入:初始参数向量 θ ( 0 ) \theta^{(0)} θ(0),可微损失函数 { ℓ i } i = 1 K \{\ell_i\}_{i=1}^K {ℓi}i=1K,学习率 η \eta η。
对于每一轮迭代 t = 1 , ⋯ , T t=1,\cdots,T t=1,⋯,T:
计算任务梯度 g i ( t ) = ∇ θ ( t − 1 ) ℓ i g_i^{(t)}=\nabla_{\theta^{(t-1)}}\ell_i gi(t)=∇θ(t−1)ℓi
将矩阵 G ( t ) G^{(t)} G(t)的每一列设置为 g i ( t ) g_i^{(t)} gi(t)
通过 ( G ( t ) ) T G ( t ) α = 1 / α (G^{(t)})^TG(t)\alpha=1/\alpha (G(t))TG(t)α=1/α获得 α ( t ) \alpha^{(t)} α(t)
更新参数 θ ( t ) = θ ( t ) − η G ( t ) α ( t ) \theta^{(t)}=\theta^{(t)}-\eta G^{(t)}\alpha^{(t)} θ(t)=θ(t)−ηG(t)α(t)
返回: θ ( T ) \theta^{(T)} θ(T)
5 Analysis
现在分析本文方法在凸和非凸情况下的收敛性。
即使是单任务,非凸优化也可能只收敛到一个稳定点,因此需要证明本文方法可以收敛到Pareto stationary点,即梯度的某个凸组合为0的点。如前所述,仍然假设在非Pareto stationary点时,梯度之间互相独立。这个假设排除了如两个相同任务的边缘情况。
通过将Assumption 5.1中的Pareto stationary替换成局部Pareto optimality,可以证明算法收敛到局部Pareto optimal point。
这一假设具有重要意义,意味着可以避免任意特定任务中的局部最大值和鞍点。
Assumption 5.1
对于由本文算法得到的序列 { θ ( t ) } t = 1 ∞ \{\theta^{(t)}\}_{t=1}^\infty {θ(t)}t=1∞,集合中任意一点和任意极限处的梯度向量 g 1 ( t ) , ⋯ , g K ( t ) g_1^{(t)},\cdots,g_K^{(t)} g1(t),⋯,gK(t)都是线性无关的,除非该点是Pareto stationary。
Assumption 5.2
假设所有损失函数都是可微的,有下界,并且所有的次级集合都是有界的。输入域是开放且凸的。
Assumption 5.3
假设所有损失函数都是光滑的:
∣ ∣ ∇ ℓ i ( x ) − ∇ ℓ i ( y ) ∣ ∣ ≤ L ∣ ∣ x − y ∣ ∣ (6) ||\nabla\ell_i(x)-\nabla\ell_i(y)||\leq L||x-y||\tag{6} ∣∣∇ℓi(x)−∇ℓi(y)∣∣≤L∣∣x−y∣∣(6)
Theorem 5.4
令 { θ ( t ) } t = 1 ∞ \{\theta^{(t)}\}_{t=1}^\infty {θ(t)}t=1∞为由 θ ( t + 1 ) = θ ( t ) − μ ( t ) Δ θ ( t ) \theta^{(t+1)}=\theta^{(t)}-\mu^{(t)}\Delta\theta^{(t)} θ(t+1)=θ(t)−μ(t)Δθ(t)生成的参数序列, Δ θ ( t ) = ∑ i = 1 K α i ( t ) g i ( t ) \Delta\theta^{(t)}=\sum_{i=1}^K\alpha_i^{(t)}g_i^{(t)} Δθ(t)=∑i=1Kαi(t)gi(t)是Nash bargaining solution ( G ( t ) ) T G ( t ) α ( t ) = 1 / α ( t ) (G^{(t)})^T G^{(t)}\alpha^{(t)}=1/\alpha^{(t)} (G(t))TG(t)α(t)=1/α(t)的解。
设 μ ( t ) = min i ∈ [ K ] 1 L K α i ( t ) \mu^{(t)}=\min_{i\in[K]}\frac{1}{LK\alpha_i^{(t)}} μ(t)=mini∈[K]LKαi(t)1。于是,序列 { θ ( t ) } t = 1 ∞ \{\theta^{(t)}\}_{t=1}^\infty {θ(t)}t=1∞存在一个子序列收敛于Pareto stationary point θ ∗ \theta^* θ∗。进一步地,所有的损失函数 ( ℓ 1 ( θ ( t ) ) , ⋯ , ℓ K ( θ ( t ) ) ) (\ell_1(\theta^{(t)}),\cdots,\ell_K(\theta^{(t)})) (ℓ1(θ(t)),⋯,ℓK(θ(t)))也收敛到 ( ℓ 1 ( θ ∗ ( t ) ∗ ) , ⋯ , ℓ K ( θ ∗ ( t ) ∗ ) ) (\ell_1(\theta^*{(t)}*),\cdots,\ell_K(\theta^*{(t)}*)) (ℓ1(θ∗(t)∗),⋯,ℓK(θ∗(t)∗))。