论文笔记CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX

news/2024/11/17 22:48:15/

目录

  • Gumbel-Softmax分布
  • Gumbel-Softmax Estimator
  • Straight-Through (ST) Gumbel-Softmax Estimator
    • Straight-Through Estimator (STE)
    • Straight-Through (ST) Gumbel-Softmax Estimator
  • 参考

Gumbel-Softmax分布

Gumbel-Softmax分布是一个定义在单纯形(simplex)上的连续分布。
Gumbel-Softmax分布可以近似categorical分布。
zzz表示为服从π=(π1,…,πk)\pi = (\pi_1,\ldots,\pi_k)π=(π1,,πk)的categorical随机变量。categorical分布的样本表示为kkk维的one-hot向量,在k−1k-1k1维的单纯形空间△k−1\bigtriangleup^{k-1}k1中。

Gumbel-Max trick是reparametrization tricks的一个特例,其提供了一个简单有效的从categorical分布采样的方法:
z=one-hot(argmax⁡i[gi+log⁡πi])(1)z = \text{one-hot}\left(\operatorname{argmax}_i[g_i + \log \pi_i]\right) \tag{1} z=one-hot(argmaxi[gi+logπi])(1)其中gi∼Gumbel(0,1)g_i \sim Gumbel(0,1)giGumbel(0,1)。Gumbel分布用于对各种分布的多个样本的最大值(或最小值)的分布进行建模。Gumbel分布的概率密度是:
Gumbel(μ,β)=1βexp⁡(−x−μβ+exp⁡(−x−μβ))Gumbel(\mu, \beta) = \frac{1}{\beta}\exp(-\frac{x - \mu}{\beta} + \exp(-\frac{x - \mu}{\beta})) Gumbel(μ,β)=β1exp(βxμ+exp(βxμ))
softmax函数是可导的,使用softmax函数去近似公式(1)中的argmax,可以得到样本y∈△k−1y\in\bigtriangleup^{k-1}yk1
yi=softmax⁡[gi+log⁡πi]=exp⁡(log⁡πi+giτ)∑j=1kexp⁡(log⁡πj+gjτ)y_i = \operatorname{softmax}[g_i + \log \pi_i] = \frac{\exp(\frac{\log \pi_i + g_i}{\tau})}{\sum_{j = 1}^k \exp(\frac{\log \pi_j + g_j}{\tau})} yi=softmax[gi+logπi]=j=1kexp(τlogπj+gj)exp(τlogπi+gi)Gumbel-Softmax分布的概率密度函数是:
在这里插入图片描述
随着τ\tauτ趋近于0,Gumbel-Softmax分布的样本逐渐变成one-hot的,Gumbel-Softmax分布也逐渐变成了categorical分布。如下图所示:

在这里插入图片描述

Gumbel-Softmax Estimator

Gumbel-Softmax分布的∂y∂π\frac{\partial y}{\partial \pi}πy是有定义的。
通过用Gumbel-Softmax样本替换categorical样本,我们可以使用反向传播来计算梯度。
把在训练阶段,用可导的Gumbel-Softmax样本替代不可导的categorical样本的过程称为Gumbel-Softmax Estimator。

在温度τ\tauτ小时,样本接近单热但梯度方差大,在温度τ\tauτ大时,样本平滑但梯度方差小。
实际中,我们从高温τ\tauτ开始,然后退火到一个很小但非零的温度。

Straight-Through (ST) Gumbel-Softmax Estimator

Straight-Through Estimator (STE)

首先介绍下Straight-Through Estimator (STE)。
STE是量化(quantization)中常见的求导方式。
比如有sign函数:
wb=sign⁡(w)={+1,if w≥0−1,otherwise w_{b}=\operatorname{sign}(w)=\left\{ \begin{array}{ll}{+1,}{\text { if } w \geq 0} \\ {-1,}{\text { otherwise }}\end{array}\right. wb=sign(w)={+1, if w01, otherwise 这个sign函数在定义域范围内导数都是0。STE就是用来解决sign函数梯度无法反传的问题的。
二值网络训练过程可以是这样:模型中每个参数其实都是一个浮点型的数,每次迭代其实都是在更新这个浮点型数。但是,在前向传播的过程中,先用sign函数对浮点型参数二值化处理然后再参与到运算,而此时并没有把这个浮点型数值抛弃掉,而是暂时在内存中保存起来。前向传播完之后,网络得到一个输出,就可以接着通过反向传播算出二值参数的梯度,再直接用这个梯度来更新对应的浮点型参数。这样,前向反向就跑通了。等训练的差不多了,就最后对模型的这些浮点型参数做一次二值化处理形成最终的二值网络,此时浮点型的参数就完成了任务,可以被抛弃掉了。

Straight-Through (ST) Gumbel-Softmax Estimator

在前向的时候使用argmax离散化yyy,但在梯度反传的时候,使用连续近似∇θz≈∇θy\nabla_\theta z \approx \nabla_\theta yθzθy

参考

ICLR 2017 Categorical Reparameterization with Gumbel-Softmax
Emma Benjaminson blog
二值网络,围绕STE的那些事儿
gumbel-max-trick的数学证明


http://www.ppmy.cn/news/4205.html

相关文章

最全的SpringMVC教程,终于让我找到了

1. 为啥要学 SpringMVC&#xff1f; 1.1 SpringMVC 简介 在学习 SpringMVC 之前我们先看看在使用 Servlet 的时候我们是如何处理用户请求的&#xff1a; 配置web.xml <?xml version"1.0" encoding"UTF-8"?> <web-app xmlns"http://xmln…

【Anime.js】——用Anime.js实现动画效果

目录 目标&#xff1a; ​编辑1、确定思路 2、创建网格 3、设置随机位置 4、创建时间轴动画 完整代码&#xff1a; 目标&#xff1a; 实现自动选点&#xff0c;对该点进行先缩小后放大如何回到比其他点大一点的状态&#xff0c;并以该点从外向内放大 1、确定思路 2、创建网…

SOFA Weekly|Tongsuo 8.3.2 版本发布、C 位大咖说、本周 Contributor QA

SOFA WEEKLY | 每周精选 筛选每周精华问答&#xff0c;同步开源进展欢迎留言互动&#xff5e;SOFAStack&#xff08;Scalable Open Financial Architecture Stack&#xff09;是蚂蚁集团自主研发的金融级云原生架构&#xff0c;包含了构建金融级云原生架构所需的各个组件&#…

古典乐器网页设计成品 大学生音乐网站制作模板 大学生静态音乐HTML网页源码 dreamweaver网页作业 简单网页课程成品

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

马上跨年了,如何用代码写一个“跨年倒计时”呢?

前言 大家好&#xff0c;我是陈橘又青&#xff0c;再过两周就是新的一年了&#xff0c;作为一名有仪式感的程序员&#xff0c;今天我们就来制作一个简单的跨年倒计时小网页&#xff0c;祝看到的所有人新年快乐&#xff01;&#xff08;附上完整源码&#xff0c;需要的小伙伴自取…

【生信算法】利用HMM纠正测序错误(Viterbi算法的python实现)

利用HMM纠正测序错误&#xff08;Viterbi算法的python实现&#xff09; 问题背景 对两个纯系个体M和Z的二倍体后代进行约~0.05x的低覆盖度测序&#xff0c;以期获得后代个体的基因型&#xff0c;即后代中哪些片段分别来源于M和Z。已知&#xff1a; 后代中基因型为MM、MZ&…

华为机试真题 Java 实现【开放日活动】【2022.11 Q4 新题】

目录 题目 思路 考点 Code 题目 题目描述 某部门开展Family Day开放日活动,其中有个从桶里取球的游戏,游戏规则如下:有N个容量一样的小桶等距排开,且每个小桶都默认装了数量不等的小球, 每个小桶装的小球数量记录在数组 bucketBallNums 中,游戏开始时,要求所有桶的小球…

【10秒在圣诞节做出温馨的圣诞树】

&#x1f935;‍♂️ 个人主页老虎也淘气 个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f44d;&#x1f3fb; 收藏…