论文:Detecting and Correcting for Label Shift with Black Box Predictors(BBSE)

news/2025/1/16 17:59:25/

前言

如果你对这篇文章感兴趣,可以点击「【访客必读 - 指引页】一文囊括主页内所有高质量博客」,查看完整博客分类与对应链接。


概述

首先从一个流感的例子讲起,医院在八月根据当月数据训练了模型 f f f,假设其特征 x \bm{x} x 为「有无咳嗽」,预测标签 y y y 为「有无得流感」。

后续几个月模型 f f f 运转良好,但到第二年二月时,医院发现 f f f 预测为「得流感」的人数大幅增加,此时我们知道这与「冬季是流感高发期」有关。但一个问题随即出现了,用八月数据训出的 f f f 是否在二月也能有效预测,其在八月数据上学得的先验是否会影响二月时的判断。

将问题形式化,我们可以发现八月和二月的 p ( y ∣ x ) = p ( p(y\mid \bm{x})=p( p(yx)=p(流感 ∣ \mid 咳嗽 ) ) ) p ( y ) = p ( p(y)=p( p(y)=p(流感 ) ) ) 明显发生了变化,因此过往在「covariate shift」上的研究不再适用。

继续深入,我们可以发现 p ( x ∣ y ) = p ( p(\bm{x}\mid y)=p( p(xy)=p(咳嗽 ∣ \mid 流感 ) ) ) 似乎并没有发生太大的变化,由此引入本篇文章所关注的「label shift」问题,其代表下述这种情况:

  • 标签边际分布 p ( y ) p(y) p(y) 发生变化,但条件分布 p ( x ∣ y ) p(\bm{x}\mid y) p(xy) 不变

随后文中提出「Black Box Shift Estimation (BBSE)」方法,利用「黑盒预测器」来估计变化的 p ( y ) p(y) p(y),且仅要求其对应「混淆矩阵 (confusion matrices)」是可逆的,即使预测器是 biased,inaccurate 或 uncalibrated。


问题设定

源域 X × Y \mathcal{X}\times \mathcal{Y} X×Y 上的分布 P P P D = { ( x i , y i ) } i = 1 n D=\{(\bm{x}_i, y_i)\}_{i=1}^n D={(xi,yi)}i=1n,基于 D D D 训练得到的黑盒模型 f : X → Y f:\mathcal{X}\rightarrow \mathcal{Y} f:XY

目标域 X × Y \mathcal{X}\times \mathcal{Y} X×Y 上的分布 Q Q Q X ′ = [ x 1 ′ ; . . . ; x m ′ ] X'=[\bm{x}_1';...;\bm{x}_m'] X=[x1;...;xm]

目标:检测 P → Q P\rightarrow Q PQ 是否发生了「label shift」,若发生了则重新训练模型,使其适应分布 Q Q Q

三大假设

  • 「label shift / target shift」假设:
    p ( x ∣ y ) = q ( x ∣ y ) ∀ x ∈ X , y ∈ Y p(\boldsymbol{x} \mid y)=q(\boldsymbol{x} \mid y) \quad \forall x \in \mathcal{X}, y \in \mathcal{Y} p(xy)=q(xy)xX,yY
  • ∀ y ∈ Y \forall y\in \mathcal{Y} yY,若 q ( y ) > 0 q(y)>0 q(y)>0 p ( y ) > 0 p(y)>0 p(y)>0
  • f f f 对应的混淆矩阵 (confusion matrix) C p ( f ) \mathrm{C}_p(f) Cp(f) 可逆,矩阵定义如下:
    C P ( f ) : = p ( f ( x ) , y ) ∈ R ∣ Y ∣ × ∣ Y ∣ \mathbf{C}_P(f):=p(f(x), y) \in \mathbb{R}^{|\mathcal{Y}| \times|\mathcal{Y}|} CP(f):=p(f(x),y)RY×Y

BBSE

「Black Box Shift Estimation (BBSE)」方法主要用于估计 w ( y ) = q ( y ) / p ( y ) w(y)=q(y)/p(y) w(y)=q(y)/p(y),其核心思路如下:
q ( y ^ ) = ∑ y ∈ Y q ( y ^ ∣ y ) q ( y ) = ∑ y ∈ Y p ( y ^ ∣ y ) q ( y ) = ∑ y ∈ Y p ( y ^ , y ) q ( y ) p ( y ) \begin{aligned} q(\hat{y}) &=\sum_{y \in \mathcal{Y}} q(\hat{y} \mid y) q(y) \\ &=\sum_{y \in \mathcal{Y}} p(\hat{y} \mid y) q(y)=\sum_{y \in \mathcal{Y}} p(\hat{y}, y) \frac{q(y)}{p(y)} \end{aligned} q(y^)=yYq(y^y)q(y)=yYp(y^y)q(y)=yYp(y^,y)p(y)q(y)

其中 y ^ \hat{y} y^ f f f 给出的伪标记,而 q ( y ^ ∣ y ) = p ( y ^ ∣ y ) q(\hat{y}\mid y)=p(\hat{y}\mid y) q(y^y)=p(y^y) 则来自于下述推导:
q ( y ^ ∣ y ) = ∑ y ∈ Y q ( y ^ ∣ x , y ) q ( x ∣ y ) = ∑ y ∈ Y q ( y ^ ∣ x , y ) p ( x ∣ y ) = ∑ y ∈ Y p f ( y ^ ∣ x ) p ( x ∣ y ) = ∑ y ∈ Y p ( y ^ ∣ x , y ) p ( x ∣ y ) = p ( y ^ ∣ y ) \begin{aligned} &q(\hat{y} \mid y)=\sum_{y \in \mathcal{Y}} q(\hat{y} \mid \boldsymbol{x}, y) q(\boldsymbol{x} \mid y)=\sum_{y \in \mathcal{Y}} q(\hat{y} \mid \boldsymbol{x}, y) p(\boldsymbol{x} \mid y) \\ &=\sum_{y \in \mathcal{Y}} p_f(\hat{y} \mid \boldsymbol{x}) p(\boldsymbol{x} \mid y)=\sum_{y \in \mathcal{Y}} p(\hat{y} \mid \boldsymbol{x}, y) p(\boldsymbol{x} \mid y)=p(\hat{y} \mid y) \end{aligned} q(y^y)=yYq(y^x,y)q(xy)=yYq(y^x,y)p(xy)=yYpf(y^x)p(xy)=yYp(y^x,y)p(xy)=p(y^y)

其关键部分在于 q ( x ∣ y ) = p ( x ∣ y ) q(\bm{x}\mid y)=p(\bm{x}\mid y) q(xy)=p(xy) 的假设。随后便可以得到:
μ y ^ = C y ^ ∣ y μ y = C y ^ , y w w ^ = C ^ y ^ , y − 1 μ ^ y ^ μ ^ y = diag ⁡ ( ν ^ y ) w ^ \begin{gathered} \mu_{\hat{y}}=\mathrm{C}_{\hat{y} \mid y} \mu_y=\mathrm{C}_{\hat{y}, y} w \\ \hat{\boldsymbol{w}}=\hat{\mathbf{C}}_{\hat{y}, y}^{-1} \hat{\boldsymbol{\mu}}_{\hat{y}} \\ \hat{\boldsymbol{\mu}}_y=\operatorname{diag}\left(\hat{\boldsymbol{\nu}}_y\right) \hat{\boldsymbol{w}} \end{gathered} μy^=Cy^yμy=Cy^,yww^=C^y^,y1μ^y^μ^y=diag(ν^y)w^

其中各符号定义如下,其核心思想就是本节最开头的公式,只不过为了严谨而引入了大量符号,但实质不变。
在这里插入图片描述

理论保障

首先是「Consistency」的保证:
在这里插入图片描述
其次是「Error bounds」方面的保证:
在这里插入图片描述
根据上述「Error bounds」的结果,可以发现在选择黑盒模型时,「 C y ^ , y \mathrm{C}_{\hat{y}, y} Cy^,y 最小奇异值」越大的模型越合适。


Label-Shift 检测

在最开头的三大假设下, q ( y ) = p ( y ) ⇔ p ( y ^ ) = q ( y ^ ) q(y)=p(y)\Leftrightarrow p(\hat{y})=q(\hat{y}) q(y)=p(y)p(y^)=q(y^),因此使用「two-sample tests」对 p ( y ^ ) = q ( y ^ ) p(\hat{y})=q(\hat{y}) p(y^)=q(y^) 进行检测即可。


让模型适应新分布

计算出 w ^ \hat{\bm{w}} w^ 后,采用「importance weighted ERM」在源域数据集 D \mathcal{D} D 上重新训练模型即可,具体训练目标如下:
L = ∑ i = 1 n w ^ i ⋅ ℓ ( y i , x i ) \mathcal{L}=\sum_{i=1}^n \hat{w}_i\cdot \ell\left(y_i, \bm{x}_i\right) L=i=1nw^i(yi,xi)

整体算法如下:
在这里插入图片描述


检测 Label-Shift 假设成立

采用「kernel two-sample tests」检测下述式子是否成立:
E p [ w ( y ) k ( ϕ ( x ) , ⋅ ) ] = E q [ k ( ϕ ( x ) , ⋅ ) ] \mathbb{E}_p[\boldsymbol{w}(y) k(\phi(\boldsymbol{x}), \cdot)]=\mathbb{E}_q[k(\phi(\boldsymbol{x}), \cdot)] Ep[w(y)k(ϕ(x),)]=Eq[k(ϕ(x),)]

即转化为下述 MMD 距离的计算:
∥ 1 n ∑ i = 1 n [ w ^ ( y i ) k ( ϕ ( x i ) , ⋅ ) ] − 1 m ∑ j = 1 m k ( ϕ ( x j ′ ) , ⋅ ) ∥ H 2 \left\|\frac{1}{n} \sum_{i=1}^n\left[\hat{\boldsymbol{w}}\left(y_i\right) k\left(\phi\left(\boldsymbol{x}_i\right), \cdot\right)\right]-\frac{1}{m} \sum_{j=1}^m k\left(\phi\left(\boldsymbol{x}_j^{\prime}\right), \cdot\right)\right\|_{\mathcal{H}}^2 n1i=1n[w^(yi)k(ϕ(xi),)]m1j=1mk(ϕ(xj),) H2


参考资料

  • [ICML18 - Zachary C. Lipton] Detecting and Correcting for Label Shift with Black Box Predictors

http://www.ppmy.cn/news/628637.html

相关文章

论文解读《Co-Correcting:Noise-tolerant Medical Image Classification via mutual Label Correction》

论文解读《Co-Correcting:Noise-tolerant Medical Image Classification via mutual Label Correction》 论文解读:协同校正:通过相互标签校正的抗噪声医学图像分类 期刊名: IEEE TRANSACTIONS ON MEDICAL IMAGING (医学影像学报&#xff0…

纠错输出编码(Error-Correcting Output Codes: ECOC)

最近在利用Error-Correcting Output Codes做论文,发现网上没有一种讲的比较清楚的,那我今天就花点时间大致上讲一下这种方法。最初提出ECOC方法的是如下的文章 Solving Multiclass Learning Problems via Error-Correcting Output Codes --Thomas G. Die…

纠错输出码(Error Correcting Output Code, ECOC)

纠错输出码流程 1、编码 对N个类别做M次划分,每次划分将一部分类别划为正类,一部分划为反类(M个训练集)。 如例(a) 则N4,M5,每次划分为1或者-1(二分类) 2、解码 测试示例交给M个…

ChatGPT伪原创文章的应用与发展

ChatGPT是一种基于人工智能技术的自然语言处理模型,它能够生成逼真的、具有上下文连贯性的文本。近年来,ChatGPT在各个领域的应用越来越广泛,其发展潜力也逐渐被人们所认识。本文将从多个方面对ChatGPT的应用与发展进行详细阐述。 ChatGPT在…

数据质量管理—3、数据修正(Data Correcting)

前面的两篇文章——分析的前提—数据质量1和分析的前提—数据质量2分别介绍了通过Data Profiling的方法获取数据的统计信息,并使用Data Auditing来评估数据是否存在质量问题,数据的质量问题可以通过完整性、准确性和一致性三个方面进行审核。这篇文章介绍…

常用校验方式以及优缺点(奇偶校验,CRC校验,校验和)

一、差错产生的原因 在原始的物理传输线路上传输数据信号是有差错的,存在一定的误码率,数据链路层存在的目的就是给原始二进制位流增加一些控制信息 ,实现如何在有差错的线路上进行无差错传输 信道的电气特性引起信号幅度,频率&a…

论文解读:Correcting Chinese Spelling Errors with Phonetic Pre-training

论文解读:Correcting Chinese Spelling Errors with Phonetic Pre-training(ACL2021) 中文拼写纠错CSC任务具有挑战性,目前的SOTA方法是仅使用语言模型,或将语音信息作为外部知识;本文将提出一种新的端到端…

jdk代理和cglib代理(实例推导)

目录 jdk代理和cglib代理(实例推导)jdk动态代理Cglib动态代理总结 jdk代理和cglib代理(实例推导) 更深层的探究jdk和cglib动态代理的原理 jdk动态代理 jdk动态代理(简单实现) 定义一个House的房源类型接口…