w~深度学习~合集4

embedded/2025/1/12 20:06:34/

 我自己的原文哦~   https://blog.51cto.com/whaosoft/12998003

#FlashAttention

FlashAttention新升级!斯坦福博士一人重写算法,第二代实现了最高9倍速提升。Transformer上下文长度史诗级提升

继超快且省内存的注意力算法FlashAttention爆火后,升级版的2代来了。

FlashAttention-2是一种从头编写的算法,可以加快注意力并减少其内存占用,且没有任何近似值。

比起第一代,FlashAttention-2速度提升了2倍。

甚至,相较于PyTorch的标准注意力,其运行速度最高可达9倍。

一年前,StanfordAILab博士Tri Dao发布了FlashAttention,让注意力快了2到4倍,如今,FlashAttention已经被许多企业和研究室采用,广泛应用于大多数LLM库。

如今,随着长文档查询、编写故事等新用例的需要,大语言模型的上下文以前比过去变长了许多——GPT-4的上下文长度是32k,MosaicML的MPT上下文长度是65k,Anthropic的Claude上下文长度是100k。

但是,扩大Transformer的上下文长度是一项极大的挑战,因为作为其核心的注意力层的运行时间和内存要求,是输入序列长度的二次方。

Tri Dao一直在研究FlashAttention-2,它比v1快2倍,比标准的注意力快5到9倍,在A100上已经达到了225 TFLOP/s的训练速度!

论文地址:https://tridao.me/publications/flash2/flash2.pdf

项目地址:https://github.com/Dao-AILab/flash-attention

FlashAttention-2:更好的算法、并行性和工作分区

端到端训练GPT模型,速度高达225 TFLOP/s

虽说FlashAttention在发布时就已经比优化的基线快了2-4倍,但还是有相当大的进步空间。

比方说,FlashAttention仍然不如优化矩阵乘法(GEMM)运算快,仅能达到理论最大FLOPs/s的25-40%(例如,在A100 GPU上的速度可达124 TFLOPs/s)。

GEMM如何用于卷积

在过去的几个月里,研究人员一直在开发FlashAttention-2,它的性能指标比第一代更强。

研究人员表示,2代相当于完全从头重写,使用英伟达的CUTLASS 3.x及其核心库CuTe。从速度上看,FlashAttention-2比之前的版本快了2倍,在A100 GPU上的速度可达230 TFLOPs/s。

当使用端到端来训练GPT之类的语言模型时,研究人员的训练速度高达225 TFLOPs/s(模型的FLOP利用率为72%)。

对注意力计算重新排序

我们知道,FlashAttention是一种对注意力计算进行重新排序的算法,利用平铺、重新计算来显著加快计算速度,并将序列长度的内存使用量从二次减少到线性。

研究人员将输入块从HBM(GPU内存)加载到SRAM(快速缓存),并对该模块执行注意,更新HBM中的输出。

由于没有将大型中间注意力矩阵写入HBM,内存的读/写量也跟着减少,进而带来了2-4倍的执行时间加速。

下图是FlashAttention的前向传递图:通过平铺和softmax重新缩放,研究人员人员按模块进行操作,避免从HBM读取或是写入,同时获得正确输出,无需近似。

然而,FlashAttention仍然存在一些低效率的问题,这是由于不同线程块之间的工作划分并不理想,以及GPU上的warp——导致低占用率或不必要的共享内存读写。

更少的non-matmul FLOP(非矩阵乘法浮点计算数)

研究人员通过调整FlashAttention的算法来减少non-matmul FLOP的次数。这非常重要,因为现代GPU有专门的计算单元(比如英伟达GPU上的张量核心),这就使得matmul的速度更快。

例如,A100 GPU FP16/BF16 matmul的最大理论吞吐量为312 TFLOPs/s,但non-matmul FP32的理论吞吐量仅为 19.5 TFLOPs/s。

另外,每个非matmul FLOP比matmul FLOP要贵16倍。

所以为了保持高吞吐量,研究人员希望在matmul FLOP上花尽可能多的时间。

研究人员还重新编写了FlashAttention中使用的在线softmax技巧,以减少重新缩放操作的数量,以及边界检查和因果掩码操作,而无需更改输出。

更好的并行性

FlashAttention v1在批大小和部数量上进行并行化处理。研究人员使用1个线程块来处理一个注意力头,共有 (batch_size * head number) 个线程块。

在前向处理(左图)中,研究者将Worker(线程块)并行化,每个Worker负责处理注意力矩阵的一个行块。在后向处理过程中(右图),每个Worker处理注意力矩阵的一个列块

每个线程块都在流式多处理器 (SM)运行,例如,A100 GPU上有108个这样的处理器。当这个数字很大(比如 ≥80)时,这种调度是有效的,因为在这种情况下,可以有效地使用GPU上几乎所有的计算资源。

在长序列的情况下(通常意味着更小批或更少的头),为了更好地利用GPU上的多处理器,研究人员在序列长度的维度上另外进行了并行化,使得该机制获得了显著加速。

更好的工作分区

即使在每个线程块内,研究人员也必须决定如何在不同的warp(线程束)之间划分工作(一组32个线程一起工作)。研究人员通常在每个线程块使用4或8个warp,分区方案如下图所示。

研究人员在FlashAttention-2中改进了这种分区,减少了不同warp之间的同步和通信量,从而减少共享内存读/写。

对于每个块,FlashAttention将K和V分割到4个warp上,同时保持Q可被所有warp访问。这称为「sliced-K」方案。

然而,这样做的效率并不高,因为所有warp都需要将其中间结果写入共享内存,进行同步,然后再将中间结果相加。

而这些共享内存读/写会减慢FlashAttention中的前向传播速度。

在FlashAttention-2中,研究人员将Q拆分为4个warp,同时保持所有warp都可以访问K和V。

在每个warp执行矩阵乘法得到Q K^T的一个切片后,它们只需与共享的V切片相乘,即可得到相应的输出切片。

这样一来,warp之间就不再需要通信。共享内存读写的减少就可以提高速度。

新功能:头的维度高达256,多查询注意力

FlashAttention仅支持最大128的头的维度,虽说适用于大多数模型,但还是有一些模型被排除在外。

FlashAttention-2现在支持256的头的维度,这意味着GPT-J、CodeGen、CodeGen2以及Stable Diffusion 1.x等模型都可以使用FlashAttention-2来获得加速和节省内存。

v2还支持多查询注意力(MQA)以及分组查询注意力(GQA)。

GQA为每组查询头共享单个key和value的头,在多头和多查询注意之间进行插值

这些都是注意力的变体,其中多个查询头会指向key和value的同一个头,以减少推理过程中KV缓存的大小,并可以显著提高推理的吞吐量。

注意力基准

研究人员人员在A100 80GB SXM4 GPU 上测量不同设置(有无因果掩码、头的维度是64或128)下不同注意力方法的运行时间。

研究人员发现FlashAttention-2比第一代快大约2倍(包括在xformers库和Triton中的其他实现)。与PyTorch中的标准注意力实现相比,FlashAttention-2的速度最高可达其9倍。

A100 GPU上的前向+后向速度只需在H100 GPU上运行相同的实现(不需要使用特殊指令来利用TMA和第四代Tensor Core等新硬件功能),研究人员就可以获得高达335 TFLOPs/s的速度。

H100 GPU上的前向+后向速度

当用于端到端训练GPT类模型时,FlashAttention-2能在A100 GPU上实现高达225TFLOPs/s的速度(模型FLOPs利用率为72%)。

与已经非常优化的FlashAttention模型相比,端到端的加速进一步提高了1.3倍。

未来的工作

速度上快2倍,意味着研究人员可以用与之前训练8k上下文模型相同的成本,来训练16k上下文长度的模型。这些模型可以理解长篇书籍和报告、高分辨率图像、音频和视频。

同时,FlashAttention-2还将加速现有模型的训练、微调和推理。

在不久的将来,研究人员还计划扩大合作,使FlashAttention广泛适用于不同类型的设备(例如H100 GPU、AMD GPU)以及新的数据类型(例如fp8)。

下一步,研究人员计划针对H100 GPU进一步优化FlashAttention-2,以使用新的硬件功能(TMA、第四代Tensor Core、fp8等等)。

将FlashAttention-2中的低级优化与高级算法更改(例如局部、扩张、块稀疏注意力)相结合,可以让研究人员用更长的上下文来训练AI模型。

研究人员也很高兴与编译器研究人员合作,使这些优化技术更好地应用于编程。

参考资料:

​​https://princeton-nlp.github.io/flash-atttention-2/​​

别再「浪费」GPU了,FlashAttention重磅升级,实现长文本推理速度8倍提升

处理小说、法律文件等长文本是大模型的一个重要应用方向,但也面临速度上的挑战。FlashAttention 作者 Tri Dao 等人提出的「Flash-Decoding」通过充分利用 GPU,可以将大模型的长上下文推理速度提高至 8 倍。

最近,像 ChatGPT 或 Llama 这样的大型语言模型(LLM)引起了前所未有的关注。然而,它们的运行成本仍然极高。虽然生成单个响应可能仅需 0.01 美元(在 AWS 上的 8xA100 实例上运行几秒钟),但当扩大规模以满足数十亿用户的需求时,成本会迅速累积。而且,这些用户可能每天与 LLM 进行多次互动。某些用例的成本更高,例如代码自动生成,因为它会随着每次输入新字符而运行。随着 LLM 应用的不断增加,即使在生成时间方面实现细微的效率提升,也将产生巨大的影响。

LLM 推理(或「解码」)是一个迭代的过程:token 逐个生成。生成包含 N 个 token 的完整句子需要通过模型进行 N 次前向传递。幸运的是,我们可以缓存先前计算的 token:这意味着单个生成步骤不依赖于上下文长度,除了一个单独的操作 —— 注意力。这个操作导致上下文长度不能很好地扩展。

在 LLM 的重要新兴用例中,有一些需要利用更长的上下文。只有拥有了更长的上下文窗口,LLM 才能对更长的文档进行推理,无论是总结文档还是回答其中的问题。此外,它们还可以保持更长的对话历史,甚至在编写代码之前处理整个代码库。举个例子,在 2022 年,大多数 LLM 的上下文长度最多为 2k(例如 GPT-3),但现在,有些开源 LLM 已经可以扩展到 32k(比如 Llama-2-32k),甚至有些模型已经达到了 100k(比如 CodeLlama)。在这些情境中,注意力操作在推理过程中占据了相当大的时间比例。

在扩展 batch size 维度时,即使上下文相对较短,注意力也可能成为一个瓶颈。这是因为随着 batch 维度的增加,需要读取的内存量也会增加,而对于模型的其余部分,内存需求只取决于模型的大小。

为了解决上述问题,FlashAttention 的作者 Tri Dao 等人提出了一项名为「Flash-Decoding」的技术,它显著加速了推理过程中的注意力计算,使长序列的处理生成速度提高到了原来的 8 倍。其主要思想是以最快的速度并行加载键和值,然后分别重新缩放和合并结果,以维持正确的注意力输出。

解码时的多头注意力

在解码期间,生成的每个新 token 都需要关注所有先前的 token,以计算:softmax (queries @ keys.transpose) @ values

这个操作已经在训练阶段通过 FlashAttention 进行了优化(包括最近的 v1 和 v2 版本),瓶颈是读写中间结果的内存带宽(如 Q @ K^T)。然而,这些优化并不直接适用于推理情况,因为瓶颈不同。在训练中,FlashAttention 并行处理 batch size 和查询长度两个维度。而在推理过程中,查询长度通常为 1:这意味着,如果 batch size 小于 GPU 上的流多处理器(streaming multiprocessor,SM)数量(例如 A100 有 108 个),该操作只会利用 GPU 的一小部分!特别是在处理长上下文时,情况尤为明显,因为它需要较小的 batch size 以适应 GPU 内存。当 batch size 为 1 时,FlashAttention 将使用不到 1% 的 GPU!

FlashAttention 只在查询块和 batch size 之间并行,并且在解码期间不会设法占用整个 GPU。

使用矩阵乘法基元也能执行注意力计算,这样就不需要使用 FlashAttention 了。在这种情况下,该操作会占用整个 GPU,但会启动许多写入和读取中间结果的内核,因此并不是最优的做法。

更快的注意力解码:Flash-Decoding

新方法 Flash-Decoding 基于 FlashAttention,同时引入了一个新的并行维度:键值序列的长度。它综合了上述两种方法的优点。与 FlashAttention 类似,它在全局内存中存储的额外数据很少。然而,只要上下文足够长,即使 batch size 较小,它也能充分利用 GPU。

Flash-Decoding 也在键和值之间并行化,代价是一个小的最终归约(reduction 步骤。

Flash-Decoding 主要有三个工作步骤:

  1. 首先,将键 / 值分成更小的块;
  2. 使用 FlashAttention 并行计算查询与每个这些分块的注意力,为每行和每个分块额外写入一个标量值:注意力值的 log-sum-exp
  3. 最后,通过对所有分块进行归约来计算实际输出,使用 log-sum-exp 来调整每个分块的贡献。

这一切之所以可行,都是因为注意力 /softmax 可以进行迭代计算。在 Flash-Decoding 中,它在两个级别上被使用:在分块内部(类似 FlashAttention),以及跨分块进行最终的归约计算。

实际操作中,步骤(1)不涉及任何 GPU 操作,因为键 / 值块是完整键 / 值张量的视图。然后,有两个独立的核函数,分别用于执行步骤(2)和(3)。

在 CodeLlama 34B 上进行的基准测试

为了验证上述新方法,研究者对 CodeLLaMa-34b 的解码吞吐量进行了基准测试。该模型与 Llama 2 具有相同的架构,一般来说,结果应该适用于许多大型语言模型。研究者在不同序列长度下(从 512 到 64k),以 tok/s 为单位来测量解码速度,并比较了多种计算注意力的方式:

  • Pytorch:使用纯粹的 PyTorch 基元来运行注意力计算(不使用 FlashAttention);
  • FlashAttention v2;
  • FasterTransformer:使用 FasterTransformer 的注意力内核;
  • Flash-Decoding;
  • 以及一个上限值,该值计算了从内存中读取整个模型和 KV-cache 所需的时间

对于非常大的序列,Flash-Decoding 可以将解码速度提高至 8 倍,并且比其他方法的扩展性要好得多。

在 prompt 比较小时,所有方法表现接近。但是当序列长度从 512 增加到 64k 时,除了 Flash-Decoding,其他方法的可扩展性都很差。在 Flash-Decoding 的这种模式下(batch size 为 1),扩展序列长度对生成速度的影响很小。

组件级微基准测试

研究者还在 A100 上对多头注意力进行了微基准测试,输入为 f16,考虑了不同的序列长度和 batch size。他们将 batch size 设置为 1,并且使用 16 个 128 维的查询头,以及 2 个键 / 值头(分组查询注意力),这与在 4 个 GPU 上运行的 CodeLLaMa-34b 使用的维度相匹配。

上述微基准测试展示了多头注意力的运行时间,单位为微秒。Flash-Decoding 在序列长度扩展到高达 64k 时,几乎实现了恒定的运行时间。

之前测量的高达 8 倍的端到端加速是可能的,因为注意力本身的速度比 FlashAttention 快高达 50 倍。在序列长度达到 32k 之前,注意力的时间大致是恒定的,因为 Flash-Decoding 能够完全利用 GPU。

使用 Flash-Decoding

Flash-decoding 可以在以下链接中找到:

  • FlashAttention 包,从 v2.2 开始:https://github.com/Dao-AILab/flash-attention/tree/main
  • xFormers 包(搜索 xformers.ops.memory_efficient_attention),从 0.0.22 开始:调度程序将根据问题的大小自动使用 Flash-Decoding 或 FlashAttention 方法。当这些方法不受支持时,它可以调度到一个高效的 triton 内核,该内核实现了 Flash-Decoding 算法。

一个完整的使用 LLaMa v2 / CodeLLaMa 的解码示例可以在 FlashAttention  repo 和 xFormers  repo 中找到。此外,作者还提供了一个简单的 LLaMa v1/v2 模型的高效解码代码示例,旨在快速、易读、有教育意义和易于修改。

参考链接:https://princeton-nlp.github.io/flash-decoding/

#FlashAttention~2

这是一个简单的学习总结,核心逻辑以及V1 V2差异总结,主要从计算的角度总结FlashAttention怎么做到save memory & perf speedup的,讲一些其他文章提的比较少的点,不提供等效计算变换的公式证明。

摸了一下FlashAttention的CUTLASS实现(https//github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h)和Triton实现(https//github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py),之前做过FlashAttention V1和V2算法下,这两种框架最终的Kernel Perf benchmark(A100 & H100)所以对V1和V2的差异很好奇,本文是一个简单的学习总结,主要从计算的角度总结FlashAttention怎么做到save memory & perf speedup的,讲一些其他文章提的比较少的点,不提供等效计算变换的公式证明(这部分知乎其他大佬写的非常详细清晰了)。

理解FlashAttention核心逻辑

本节列举从0搞懂FlashAttention的核心步骤

首先需要理解Naive Attention是怎么计算的:

  1. Google Research的工作重点在减少整个过程的memory footprint;FlashAttention重点在减少memory reads/writes次数。可以说FlashAttention主要是从GPU block/thread并行度的视角对访存进行了优化。
  2. Google Research的工作每个block会产出一份中间结果,所有block执行完毕之后,再将他们的中间结果计算获得一个最终结果;FlashAttention则采用类似滑动窗口的方式,第i个block会将累积的中间结果传递给第 i+1 个block,也就是说最后一个block计算完毕后,可以保证整行的Softmax逻辑计算正确性。锐评:我认为这个点并没有什么独创性,Google Research这么考虑的原因也大概率是因为TPU的计算逻辑粒度适合沿sequence length切并行,切的越小越有利于TPU并行,最后再有一个逻辑来处理中间数据很正常。
  3. Google Research的工作在后向backward的时候做了一些冗余计算,FlashAttention把后向的计算简化了,减少了backward阶段的memory traffic。

FlashAttention V1的公式和推导不细嗦了,其他文章讲得非常好。列一下个人觉得从0开始的最佳学习路线:

首先看文章,公式推导和Tiling细节强烈这篇文章:From online softmax to FlashAttention(https//courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf),写的非常好,由浅入深,公式推导这块看完这篇其他都不用看了。然后辅助看一些知乎文章把不明白的地方搞清楚。

理解From online softmax to FlashAttention需要四个步骤

  1. softmax
  2. safe softmax
  3. online softmax Tiling
  4. FlashAttention Tiling

总之:FlashAttention之所以可以省显存(显存开销随Seq length线性增加),是因为解开了softmax以及后面GEMM的行方向依赖,并且通过辅助数组保存的辅助信息re-scale到正确的数值。

其次,了解一些背景信息,这里附一下其他可能便于理解FlashAttention项目发展的背景信息:

  • FlashAttention V1 在NVIDIA apex fmha基础上实现(最早的FlashAttention Alpha Realease(https//github.com/Dao-AILab/flash-attention/blob/1fcbe6f0d088d807ba585ddb850eb3497cd7b65b/csrc/stream_attn/src/fmha_kernel.h)),V2基于CUTLASS 3.0 & CUTE 重构(CUTE真是个好东西)
  • FlashAttention目前最方便的调库途径主要有两个
  • 最新实现在ops里面(https//github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py)
  • 稳定的实现在tutorial里面(https//github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py)
  • 【更新】最近测了一下,H100上tutorial速度也很快,看上去tutorial的kernel算法也经常有小优化
  • pip install flash-attn,官方的库,编译时间稍长,基于CUTLASS有大量的模板,如果想进一步魔改(比如加bias或者加mask,或者稀疏化等)学习和Debug成本比较大
  • 使用Triton的实现,性能实测非常不错

最后,看代码,跑代码,Profile Kernel with Nsight Compute,改代码...

这里我推荐基于Triton FlashAttention(https//github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py)上手,原因如下:

  1. Tri-Dao的FlashAttention基于CUTLASS 3.0重构,新手小白想要编译跑起来首先要搞定几个环境问题;里面一个头文件几千行,想看懂需要首先搞懂CUTLASS的基本用法,想改代码更需要一些模板猿编程debug技巧,想优化性能的话...你得学学CUTE,再学学各种GPU Features...如果你只是想学习一下FlashAttention,或者只是想基于这个Fusion idea定制你自己的kernel,没必要从CUTLASS开始。CUTLASS适合CUDA熟手并且期望拿到Peak Performance的玩家。
  2. Triton语法很容易上手,方便魔改你自己的Attention Kernel,或者你有其他的想法也很容易实践实验。例子:FlagAttention(https//github.com/FlagOpen/FlagAttention),Sparse Flash Attention(https//github.com/epfml/dynamic-sparse-flash-attention) (所以非常适合发paper啦,至少迭代CUDA kernel速度直接飙升,加快idea的反馈。从实验时间成本来说,用CUDA写半个月的Kernel+调一个月性能 >> 用CUTLASS写一周Kenrel+调两天性能 >> 用Triton写3天Kernel+调1天性能)
  3. Triton FlashAttention在Hopper-Triton PR(https//github.com/openai/triton/commit/f1512bded1934e34f104bf1ac8547e97e24b2fe8)之后,目前main分支已经集成了大量的Hopper相关的优化Pass,相比官方库还没有稳定实现Hopper features来说,在某些problem size下可能有优势。
  4. 关于Triton推荐阅读:杨:谈谈对OpenAI Triton的一些理解(https://zhuanlan.zhihu.com/p/613244988),杨:OpenAI Triton Conference参会随感兼谈Triton Hopper(https://zhuanlan.zhihu.com/p/659348024)

OK,进入正题

FlashAttention V2相比V1有哪些改进

FlashAttention V1 前向后向 Kernel示意草图

FlashAttention V2 前向后向 Kernel示意草图

V2主要从两个方面改进:

算法的改进

fwd和bwd都简化了非matmul计算,这里也是对rescale重新优化了一下;其中bwd不需要m了,只需要logsumexp L 即可。

其实FlashAttention不管V1还是V2都有一个缺点,就是为了rescale方便并行,需要把很多计算逻辑顺序排在后面(尤其是浮点数的乘除),这会改变计算的数值精度稳定性,造成在某些使用到Attention结构的网络中收敛不了的问题。

fwd和bwd都根据casual mask的特性尽可能减少冗余计算和访存:

  1. 右侧上三角block无须计算,直接跳过;
  2. 每行只用对最后一个block设定casual mask的逻辑即可。

FlashAttention V1 & V2 forward

FlashAttention V1 & V2 backward

红框的就是这里算法部分的优化和改动(左图多了mask和droupout的逻辑,忽略即可)

这个优化其实不是critical path,所以提升并不大。fwd做2个GEMM,bwd做5个GEMM,整个Kernel fwd & bwd都是memory bound,此时应该优化的是GEMM相关的memory dependency,做multi-stages,更灵活的异步调度(比如warp specialization),最后可能还需要考虑优化data reuse,优化L2 cache等等,当然一切都需要基于Nsight Compute结果分析,不然都是幻觉。

Sequence Length 并行

非常赞同@方佳瑞(https://www.zhihu.com/people/8c89d6f733cb2b81ce36a2daf0a81a82) 方佳瑞:大模型训练加速之FlashAttention系列:爆款工作背后的产品观(https://zhuanlan.zhihu.com/p/664061672#:~:text=我觉得V2最重要的提升点是参考Phil Tillet的Tirton版本,更改了Tiling循环的顺序,也就是笔者本文图1的效果。) 提到的,V2 能够把特定输入下的一个CUDA Kernel提升2X,这只能说明baseline(V1)选的太好了(笑),总之,就是因为改变了Tiling循环的顺序,把Q循环挪到了最外层,所以刚好就可以把Q循环直接给到Thread Block并行维度来计算了,本来这个方向没有依赖就是可以并行的。话说我最开始的也很纳闷,这个idea其实最早就有了,PyTorch的实现以及NV Apex FMHA的实现都有这个版本的kernel。

考虑一下,为什么K/V上的seq length方向不给到Thread Block做并行?答案是,如果可以在Q seq length上拆block并行了,那么一般来说GPU occupancy已经够了,再多拆K/V的话也不是不行,但是会额外带来通信开销;Flash Decoding其实就是在inference阶段,面对Q的seq length=1的情况,在K/V方向做了block并行,来提高GPU Utilization从而加速的。

FlashAttention V1 - Tile and 2D-Loop

Thread Block Level 并行

交换了Q loop顺序到最外层之后,最大的好处是可以把这一维度的并行度从串行的loop改成并行的thread block。

所以,FlashAttention V2的实现中,fwd除了在Batch和Head上分配Thread Block并行,还在seq length上增加了一维并行度(之前是需要M N方向做loop的,现在只在N方向loop了,横着切),注意:bwd没有改变这里的循环,跟V1一样,但是也在seq length上增加了一维并行度(N方向并行,竖着切)。

造成fwd和bwd区别的主要原因是:

  1. fwd的目的是计算QK GEMM之后沿着行方向online softmax,所以需要沿着行方向loop,不然就需要额外的reduce逻辑了。因此fwd kernel选择一行Tile为一个block。如下左图一行同色块为一个block。

细节1:从V1 (KV外循环,QO内循环) 到V2 (Q外循环,KV内循环,O在 smem初始化,最后只写出一次), memory traffic是否降低了? memory traffic of V1:

细节2:现在确定了fwd kernel要在B, H, Q_N_CTX三个维度Launch Kernel了,有两种选择:grid_dim = [Q_N_CTX, B, H], grid_dim = [B, H, Q_N_CTX],哪种更好?

答案是第一种更好,因为Q_N_CTX放ThreadBlock.X维度的话,对于同一个B和H的Q_N_CTX是连续调度的,也就是说算第一行用到的K/V Tile大概率还在L2上,第二行计算可以直接从L2拿到,这样可以显著提高L2 cache hit rate。这个优化在大seq_length的时候优化很明显。

Warp Level 并行

说完了thread block的并行,再来看一个block内的warp怎么并行的

把V横着画,有一种上下对称的美感

首先看fwd,相比V1,V2改进了Warp Partition:4个warp会从smem的K/V tile load同样的数据做mma计算,但是load 不同Q,把V1 sliced-K sliced-V 改成了v2 sliced-Q,V1的做法是需要warp之间产生同步通信的,因为在计算QK结果乘V的时候,如图所示需要跨warp reduction得到O的结果,而且fwd的目的是沿着行方向计算softmax,行方向信息最后要汇总的,这也需要跨warp不同。V2就不需要了,这样可以减少同步开销。

对于bwd来说,如果按照右图做warp partition:1、QK的结果是P,dV=P x dO,计算dV也是需要cross warp sync的;2、dO x V的结果是dP,跟P是对称的,计算dS = P o(点乘) dP的时候不需要cross warp sync;3、计算dQ是dS x K,不需要cross warp sync;4、计算dK = dS x Q,需要cross warp sync;

如果按照左图来做warp partition,那么:1、计算dV不需要cross warp sync;2、计算dS不需要cross warp sync;3、计算dQ,需要cross warp sync;4、计算dK,不需要cross warp sync;

这里有个疑问,对于bwd来说,左图不需要cross warp sync的场景是更多的,如果Br Bc d这三个reduction的维度差不多,按道理来说bwd kernel更应该采用V1的方式做warp partition,原文:

Similarly for the backward pass, we choose to partition the warps to avoid the "split-K" scheme. However, it still requires some synchronization due to the more complicated dependency between all the different inputs and gradients Q, K, V, O, dO, dQ, dK, dV. Nevertheless, avoiding "split-K" reduces shared memory reads/writes and again yields speedup

但作者并没有细说,也没有实验数据证明这里确实提升了bwd kernel性能,感到困惑。以后有空再测一下这里的策略是不是负优化(笑

其他优化

Sequence Parallel

https//github.com/openai/triton/blob/main/python/triton/ops/flash_attention.py

这是一个没有放在FlashAttentionV2 Release的优化点(应该,反正V2 paper没提到)。之前提到的bwd中,对dQ的计算是需要跨Block做全局Atomic Reduction的,如果Block数太多,就会产生Atomic竞争,效率很低(AtomicAdd指令吞吐和延迟都比正常的LDG慢2X,而且如果dQ的数据类型是fp16或者bf16,进行Atomic操作将是性能灾难);所以如果想避免使用Atomic指令做Reduction,有两种方案:

Loop n

最简单的方式,就是bwd不要seq length方向并行度了,直接串行循环就ok;

Sequence Parallel

把dQ开大一维度,N方向有多少列block就开几个buffer,并且N方向纯block并行,最后开另外一个kernel做一个reduce,这个逻辑体现在Triton代码里就是:

https//github.com/openai/triton/blob/f9b2b822dfc0980df4c39286713414dfcd27cf8e/python/triton/ops/flash_attention.py%23L323C1-L361C1

num_block_n = tl.cdiv(N_CTX, BLOCK_N)if not SEQUENCE_PARALLEL:for start_n in range(0, num_block_n):_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO,  #DQ, DK, DV,  #L,  #D,  #Q_block_ptr, K_block_ptr, V_block_ptr,  #DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #stride_kz, stride_kh, stride_kn, stride_kk,  #stride_vz, stride_vh, stride_vn, stride_vk,  #Z, H, N_CTX,  #off_h, off_z, off_hz, start_n, num_block_n,  #BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,  #BLOCK_N=BLOCK_N,  #SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,  #CAUSAL=CAUSAL,  #MMA_V3=MMA_V3  #)else:start_n = tl.program_id(1)_bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO,  #DQ, DK, DV,  #L,  #D,  #Q_block_ptr, K_block_ptr, V_block_ptr,  #DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #stride_kz, stride_kh, stride_kn, stride_kk,  #stride_vz, stride_vh, stride_vn, stride_vk,  #Z, H, N_CTX,  #off_h, off_z, off_hz, start_n, num_block_n,  #BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,  #BLOCK_N=BLOCK_N,  #SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,  #CAUSAL=CAUSAL,  #MMA_V3=MMA_V3  #)

Atomic和Sequence Parallel肯定是两个极端了,所以根据不同的problem size取个折中肯定会有机会拿到更好的性能,根据具体的计算任务和GPU型号,还有些取巧的方式来避免Atomic,比如把即将冲突的Atomic Tile安排在不同的wave上,这样GPU也不会因为Atomic带来性能损耗。

Flash Decoding优化的核心跟这里的Sequence Parallel很类似:K/V Seq Length 方向切并行度做forward softmax并且最后使用一个reduction kernel对output进行累加和rescale。

指令级别的优化

基于FlashAttentionV2继续优化下去(假设把mma相关指令和memory latency掩盖的很好的话),最后大概率会bound在softmax相关的指令上,这个时候细扣这些浮点数计算指令也是有帮助的:

  1. FMUL+MUFU.EX2指令替换为fastmath指令expf
  2. 把scale相关的数据在cpu提前计算好,省一条GPU指令
  3. 把FMUL和FADDs合成FFMA指令

其实还有很多点可以优化FlashAttention的性能,不过都是些没有profile kernel的拍脑袋幻觉,就不讲了,之后有空做了实验可以再写一篇。

#IMMA~~

搬来自斯坦福的研究者提出了 IMMA, 一种利用隐空间多层图 (multiplex latent graphs) 来表征多种独立的交互类型,并使用一种新型的多层图注意力机制 (multiplex attention mechanism) 来描述个体间交互强度的行为及轨迹预测模型。该方法不仅大幅提升了预测的准确度,同时也具有很强的可解释性 (interpretability) 和泛化能力 (zero-shot generalizability)。

  • 论文链接:https://arxiv.org/abs/2208.10660 
  • 代码链接:https://github.com/fanyun-sun/IMMA

一.研究背景

对多智能体系统 (multi-agent systems) 的建模在很多领域和应用中起到重要作用,包括但不限于自动驾驶,移动机器人导航,以及人机协作。由于个体的行为会受到不同类型社会性交互 (social interactions) 的影响,多智能体系统的动力学建模面临着极高的挑战。

过去已有的方法例如 NRI[1]和 EvolveGraph[2]利用图神经网络 (GNN) 来推测每一对智能体之间的关系类型,但是这样并不能显式地对在多智能体系统中出现的不同层级的交互关系进行建模,从而导致模型的预测效果和可解释性下降。下面 Figure 1 介绍了生活中人与人之间交互的一个实例。

该研究提出了 IMMA(Interaction Modeling with Multiplex Attention)方法,通过利用隐空间中的多层关系图结构 (multiplex latent interaction graphs), 对不同层级中不同类型的交互关系进行推理,同时,该研究还设计了一种新的多层图注意力机制 (multiplex attention mechanism) 来学习每种交互关系的强度。另外,该研究还提出了一种逐层训练 (progressive layer training) 的方法来加强不同层的关系图之间的解耦,从而进一步提升了模型的可解释性 (interpretability) 和泛化能力 (zero-shot generalizability)。本文方法在多种不同领域的问题中都取得了最优效果,包括 social navigation, cooperative task achievement 和 team sports.

二.研究方法

问题描述:假设场景中包含 N 个智能体,模型的输入包含这 N 个智能体的轨迹,任务目标是根据过去一段时间内的轨迹观测来预测未来一段时间内的轨迹,同时要对智能体之间的交互关系进行建模和推理。

核心观点:在已有的方法中 ([1][2]),隐变量 z 中的每个元素表示交互关系图中对应的 edge 属于每一种可能的关系的概率,意味着用于 Decoder 的关系图中每一对智能体的关系只能是其中的一种。然而,复杂的交互系统中可能存在某些智能体之间同时存在多种相互独立的关系的情况,并且每种关系的强度也可能有所区别,仅通过一层关系图不能准确描述具有这类性质的多体系统。因此,该研究提出利用多层关系图拓扑结构来进行更精准的建模,不仅可以提升预测效果,同时也能用模型学到的多层关系图提供对模型预测的解释,一定程度上可以分析各智能体行为之间的因果关系。另外,由于多层关系图也可以提升模型在训练数据中没有包含的场景中取得更好的效果,增强泛化能力和训练样本效率(sample efficiency)。

模型介绍:该研究提出的模型 IMMA 由一个 Encoder 和一个 Decoder 组成。Encoder 的输入是一系列轨迹观测,输出是一个在隐空间内的交互关系图。Decoder 通过对推测出的关系图里面的信息进行传递和整合来生成对每个智能体未来轨迹的预测。下面 Figure 2 提供了模型示意图和进一步解释。模型整体基于 Conditional Variational Autoencoder 框架,并通过以下目标函数来训练模型参数:

三.实验结果和分析

本文的实验试图回答以下四个问题:

1. 本文方法 (IMMA w/ PLT) 是否在各种 social multiagent systems 数据集测试中始终优于已有的基准方法?

2.Multiplex attentional latent graph 的使用是否给模型提供了更多可解释性?

3. 模型中的每个模块对模型效果的提升有多大贡献?

4. 本文方法相比于基准方法是否可以提升 sample efficiency?它是否可以很好地泛化到新环境或新场景?

1. 该研究主要用以下三个数据集进行实验:Social Navigation Environment  (基于 ORCA), PHASE dataset 和 NBA dataset,所有数据集上的结果如 Table 1 所示。该研究发现使用 Multiplex attentional latent graph 和渐进层训练 (MG w/ PLT) ,结果在所有三个数据集上都优于已有的最强基线模型。 

对于 Social Navigation dataset 的可视化结果如下图所示。第一排表示预测轨迹。圆圈越小表示离当前时间点越远。最左边是真实的未来轨迹和交互关系图。第二排表示预测的 latent graph。智能体 i 和 j 之间的 relation 由 heatmap 中的第 i 行第 j 列的元素表示。RFM 错误地预测了智能体之间的关系——用箭头突出显示,绿色 agent 被错误地赋予了比蓝色 agent 更高的权重。因此,预测的轨迹偏离了事实。相反,本文方法准确地预测了交互关系和未来轨迹。

本文方法在 PHASE(左)和 NBA 数据集 (右) 的结果可视化如下图所示。在右侧的 NBA 图中,橙色代表篮球,不同轨迹颜色代表不同球队。    

 2. 以上实验证明本文方法可以更准确地预测运动轨迹,之后,该研究进一步探究了关系推理能力和对轨迹预测的影响。首先,本文对于关系推理更加准确 (见 Table 2),这不仅帮助模型预测运动轨迹,也提供了更好的 disentanglement 和可解释性。如下图所示,IMMA 中改变 agent 的 leader 会显著改变预测的轨迹,以新 leader 为目标,同时保持对其他 agent 的预测不变。而 RFM 生成的轨迹包括不切实际的转弯 (如红色 agent) 并且对其他 agent 的轨迹预测变差。

3. Ablation study 结果如 Table 3 所示,实验结果证明 multiplex attention graph 在模型中起到了至关重要的作用,逐层训练进一步提升了轨迹预测和关系推理准确度。

4. Table 4 显示 IMMA 的 zero-shot 泛化能力比基线方法更好。 

另外,下图显示相比最优基线方法 (RFM), IMMA 需要更少的训练数据就可以得到更好的结果。

四.结论

由于存在潜在的多层社会交互 (social interactions) 关系,多智能体系统 (multi-agent systems) 的动力学 (dynamics) 通常很复杂。智能体 (agent) 的行为可能会受到与其他每个智能体多种独立关系类型的影响,例如在物理系统中通常不存在的复杂属性 (意向性或合作关系)。本文提出了一种包含交互建模的预测方法 (IMMA),该方法使用 multiplex latent graph 作为隐空间表征 (latent representation) 来建模这种多层交互类型可能产生的行为。本文方法在行为和轨迹建模以及关系推理方面均优于其他最先进的方法,并有很强的可解释性 (interpretability) 和泛化能力 (zero-shot generalizability)。

#深度强化学习中SAC算法

数学原理、网络架构及其PyTorch实现

本文详细介绍了深度强化学习中的软演员-评论家算法(SAC),包括其数学原理、网络架构设计以及PyTorch实现。

深度强化学习是人工智能领域最具挑战性的研究方向之一,其设计理念源于生物学习系统从经验中优化决策的机制。在众多深度强化学习算法中,软演员-评论家算法(Soft Actor-Critic, SAC)因其在样本效率、探索效果和训练稳定性等方面的优异表现而备受关注。

传统的深度强化学习算法往往在探索-利用权衡、训练稳定性等方面面临挑战。SAC算法通过引入最大熵强化学习框架,在策略优化过程中自动调节探索程度,有效解决了这些问题。其核心创新在于将熵最大化作为策略优化的额外目标,在保证收敛性的同时维持策略的多样性。

本文将系统阐述SAC算法的技术细节,主要包括:

  1. 基于最大熵框架的SAC算法数学原理
  2. 演员网络与评论家网络的具体架构设计
  3. 基于PyTorch的详细实现方案
  4. 网络训练的关键技术要点

SAC算法采用演员-评论家架构,演员网络负责生成动作策略,评论家网络评估动作价值。通过两个网络的协同优化,实现策略的逐步改进。整个训练过程中,演员网络致力于最大化评论家网络预测的Q值,同时保持适度的策略探索;评论家网络则不断优化其Q值估计的准确性。

接下来,我们将从演员网络的数学原理开始,详细分析SAC算法的各个技术组件:​

演员(策略)网络

演员是由参数φ确定的策略网络,表示为:

图片

这是一个基于状态输出动作的随机策略。它使用神经网络估计均值和对数标准差,从而得到给定状态下动作的分布及其对数概率。对数概率用于熵正则化,即目标函数中包含一个用于最大化概率分布广度(熵)的项,以促进智能体的探索行为。关于熵正则化的具体内容将在后文详述。演员网络的架构如图所示:

图片

均值μ(s)和对数σ(s)用于动作采样:

图片

其中N表示正态分布。但这个操作存在梯度不可微的问题,需要通过重参数化技巧来解决。

图片

这里d表示动作空间维度,每个分量ε_i从标准正态分布(均值0,标准差1)中采样。应用重参数化技巧:

图片

这样就解决了梯度截断问题。接下来通过激活函数将x_t转换为标准化动作:

图片

该转换确保动作被限制在[-1,1]区间内。​

动作对数概率计算

完成动作计算后,就可以计算奖励和预期回报。演员的损失函数中还包含熵正则化项,用于最大化分布的广度。计算采样动作𝑎_t的对数概率Log(π_ϕ)时,从预tanh变换x_t开始分析更为便利。

由于x_t来自均值μ(s)和标准差σ(s)的高斯分布,其概率密度函数(PDF)为:

图片

其中各独立分量x_t,i的分布为:

图片

对两边取对数可简化PDF:

图片

要将其转换为log(π_ϕ),需要考虑x_t到a_t的tanh变换,这可通过微分链式法则实现:

图片

这个关系的推导基于概率守恒原理:两个变量在给定区间内的概率必须相等:

图片

其中a_i = tanh(x_i)。将区间缩小到无穷小的dx和da:

图片

tanh的导数形式为:

图片

代入得到:

图片

最终可得完整表达式:

图片

至此完成了演员部分的推导,这里有动作又有对数概率,就可以进行损失函数的计算。下面是这些数学表达式的PyTorch实现:

import gymnasium as gym    from src.utils.logger import logger    from src.models.callback import PolicyGradientLossCallback    from pydantic import Field, BaseModel, ConfigDict    from typing import Dict, List    import numpy as np    import os    from pathlib import Path    import torch    import torch.nn as nn    import torch.optim as optim    import torch.nn.functional as F    from torch.distributions import Normal   '''演员网络:估计均值和对数标准差用于熵正则化计算'''    class Actor(nn.Module):    def __init__(self,state_dim,action_dim):  super(Actor,self).__init__()  self.net = nn.Sequential(  nn.Linear(state_dim, 100),  nn.ReLU(),  nn.Linear(100,100),  nn.ReLU()  )  self.mean_linear = nn.Linear(100, action_dim)  self.log_std_linear = nn.Linear(100, action_dim)  def forward(self, state):  x = self.net(state)  mean = self.mean_linear(x)  log_std =self.log_std_linear(x)  log_std = torch.clamp(log_std, min=-20, max=2)  return mean, log_std  def sample(self, state):  mean, log_std = self.forward(state)  std = log_std.exp()  normal = Normal(mean, std)  x_t = normal.rsample() # 重参数化技巧  y_t = torch.tanh(x_t)  action = y_t  log_prob = normal.log_prob(x_t)  log_prob -= torch.log(1-y_t.pow(2)+1e-6)  log_prob = log_prob.sum(dim=1, keepdim =True)  return action, log_prob

在讨论损失函数定义和演员网络的训练过程之前,需要先介绍评论家网络的数学原理。​

评论家网络

评论家网络的核心功能是估计状态-动作对的预期回报(Q值)。这些估计值在训练过程中为演员网络提供指导。评论家网络采用双网络结构,分别提供预期回报的两个独立估计,并选取较小值作为最终估计。这种设计可以有效避免过度估计偏差,同时提升训练稳定性。其结构如图所示:

图片

需要说明的是,此时的示意图是简化版本,主要用于理解演员和评论家网络的基本角色,暂不考虑训练稳定性的细节。另外,"智能体"实际上是演员和评论家网络的统称而非独立实体,图中分开表示只是为了清晰展示结构。假设评论家网络暂不需要训练,因为这样可以专注于如何利用评论家网络估计的Q值来训练演员网络。演员网络的损失函数表达式为:

图片

更常见的形式是:

图片

其中ρ_D表示状态分布。损失函数通过对所有动作空间和状态空间的熵项与Q值进行积分得到。但在实际应用中,无法直接获取完整的状态分布,因此ρ_D实际上是基于重放缓冲区样本的经验状态分布,期望其能较好地表征整体状态分布特征。

基于该损失函数可以通过反向传播对演员网络进行训练。以下是评论家网络的PyTorch实现:

'''评论家网络:定义q1和q2'''    class Critic(nn.Module):    def __init__(self, state_dim, action_dim):  super(Critic, self).__init__()  # Q1网络架构  self.q1_net = nn.Sequential(  nn.Linear(state_dim + action_dim, 256),  nn.ReLU(),  nn.Linear(256, 256),  nn.ReLU(),  nn.Linear(256, 1),  )  # Q2网络架构  self.q2_net = nn.Sequential(  nn.Linear(state_dim + action_dim, 256),  nn.ReLU(),  nn.Linear(256, 256),  nn.ReLU(),  nn.Linear(256, 1),  )  def forward(self, state, action):  sa = torch.cat([state, action], dim=1)  q1 = self.q1_net(sa)  q2 = self.q2_net(sa)  return q1, q2

前述内容尚未涉及评论家网络自身的训练机制。从重放缓冲区采样的每个数据点包含[s_t, s_{t+1}, a_t, R]。对于状态-动作对的Q值,我们可以获得两种不同的估计。

第一种方法是直接将a_t和s_t输入评论家网络:

图片

第二种方法是基于贝尔曼方程:

图片

这种方法使用s_t+1、a_t+1以及执行动作a_t获得的奖励来重新估计。这里使用目标网络而非第一种方法中的评论家网络进行估计。采用目标评论家网络的主要目的是解决训练不稳定性问题。如果同一个评论家网络同时用于生成当前状态和下一状态的Q值(用于目标Q值),这种耦合会导致网络更新在目标计算的两端产生不一致的传播,从而引起训练不稳定。因此引入独立的目标网络为下一状态的Q值提供稳定估计。目标网络作为评论家网络的缓慢更新版本,确保目标Q值能够平稳演化。具体结构如图所示:

图片

评论家网络的损失函数定义为:

图片

通过该损失函数可以利用反向传播更新评论家网络,而目标网络则采用软更新机制:

图片

其中ε是一个较小的常数,用于限制目标评论家的更新幅度,从而维持训练稳定性。​

完整流程

以上内容完整阐述了SAC智能体的各个组件。下图展示了完整SAC智能体的结构及其计算流程:

图片

下面是一个综合了前述演员网络、评论家网络及其更新机制的完整SAC智能体实现

'''SAC智能体的实现:整合演员网络和评论家网络'''    class SACAgent:    def __init__(self, state_dim, action_dim, learning_rate, device):  self.device = device  self.actor = Actor(state_dim, action_dim).to(device)  self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)  self.critic = Critic(state_dim, action_dim).to(device)  self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate)  # 目标网络初始化  self.critic_target = Critic(state_dim, action_dim).to(device)  self.critic_target.load_state_dict(self.critic.state_dict())  # 熵温度参数  self.target_entropy = -action_dim   self.log_alpha = torch.zeros(1, requires_grad=True, device=device)  self.alpha_optimizer = optim.Adam([self.log_alpha], lr=learning_rate)  def select_action(self, state, evaluate=False):  state = torch.FloatTensor(state).to(self.device).unsqueeze(0)  if evaluate:  with torch.no_grad():  mean, _ = self.actor(state)  action = torch.tanh(mean)  return action.cpu().numpy().flatten()  else:  with torch.no_grad():  action, _ = self.actor.sample(state)  return action.cpu().numpy().flatten()  def update(self, replay_buffer, batch_size=256, gamma=0.99, tau=0.005):  # 从经验回放中采样训练数据  batch = replay_buffer.sample_batch(batch_size)  state = torch.FloatTensor(batch['state']).to(self.device)  action = torch.FloatTensor(batch['action']).to(self.device)  reward = torch.FloatTensor(batch['reward']).to(self.device)  next_state = torch.FloatTensor(batch['next_state']).to(self.device)  done = torch.FloatTensor(batch['done']).to(self.device)  # 评论家网络更新  with torch.no_grad():  next_action, next_log_prob = self.actor.sample(next_state)  q1_next, q2_next = self.critic_target(next_state, next_action)  q_next = torch.min(q1_next, q2_next) - torch.exp(self.log_alpha) * next_log_prob  target_q = reward + (1 - done) * gamma * q_next  q1_current, q2_current = self.critic(state, action)  critic_loss = F.mse_loss(q1_current, target_q) + F.mse_loss(q2_current, target_q)  self.critic_optimizer.zero_grad()  critic_loss.backward()  self.critic_optimizer.step()  # 演员网络更新  action_new, log_prob = self.actor.sample(state)  q1_new, q2_new = self.critic(state, action_new)  q_new = torch.min(q1_new, q2_new)  actor_loss = (torch.exp(self.log_alpha) * log_prob - q_new).mean()  self.actor_optimizer.zero_grad()  actor_loss.backward()  self.actor_optimizer.step()  # 温度参数更新  alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()  self.alpha_optimizer.zero_grad()  alpha_loss.backward()  self.alpha_optimizer.step()  # 目标网络软更新  for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):  target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

总结

本文系统地阐述了SAC算法的数学基础和实现细节。通过对演员网络和评论家网络的深入分析,我们可以看到SAC算法在以下几个方面具有显著优势:

理论框架

  • 基于最大熵强化学习的理论基础保证了算法的收敛性
  • 双Q网络设计有效降低了值函数估计的过度偏差
  • 自适应温度参数实现了探索-利用的动态平衡

实现特点

  • 采用重参数化技巧确保了策略梯度的连续性
  • 软更新机制提升了训练稳定性
  • 基于PyTorch的向量化实现提高了计算效率

实践价值

  • 算法在连续动作空间中表现优异
  • 样本效率高,适合实际应用场景
  • 训练过程稳定,调参难度相对较小

未来研究可以在以下方向继续深化:

  • 探索更高效的策略表达方式
  • 研究多智能体场景下的SAC算法扩展
  • 结合迁移学习提升算法的泛化能力
  • 针对大规模状态空间优化网络架构

强化学习作为人工智能的核心研究方向之一,其理论体系和应用场景都在持续发展。深入理解算法的数学原理和实现细节,将有助于我们在这个快速演进的领域中把握技术本质,开发更有效的解决方案。作者:Najib Sharifi, Ph.D

#SEA~~

抹平One-Stage与Two-Stage目标检测之间的差距 ,重新讨论了单阶段和两阶段的检测器蒸馏任务

在这讨论了单阶段和两阶段的检测器蒸馏任务,并提出了一个简单而有效的语义感知框架来填补它们之间的空白。作者通过设计类别Anchor来生成每个类别的代表性模式,并规范像素级的拓扑距离和类别Anchor之间的拓扑距离,以进一步加强它们的语义联系,从而解决像素级的语义失衡问题。

作者将本文的方法命名为SEA(SEmantic-aware Alignment)蒸馏,因为通过语义依赖提取密集的细粒度信息的本质,以很好地促进蒸馏效果。

SEA很好地适用于两种检测管道,并在单阶段和两阶段检测器上的具有挑战性的COCO目标检测任务上取得了新的最先进的结果。其在实例分割上的优越性能进一步体现了其泛化能力。2x-distilled RetinaNet和使用ResNet50-FPN的FCOS均优于相应的3x ResNet101-FPN教师,分别达到40.64和43.06 AP。

在目标检测中,设计良好的骨干网络为强大的目标检测器提供了强大的支持,以解决具有挑战性的任务。延迟和准确性之间的平衡是目标检测中不可避免的权衡,特别是当在移动设备上部署模型时。与改变原始模型权重和存储位的剪枝和量化不同,知识蒸馏通过将知识从训练有素的教师模型转移到相对较小的学生模型来保持目标模型的完整,这现在成为模型加速的常见做法。

近年来,由于没有专门设计的算子,单阶段检测器有利于移动应用。然而,单阶段检测器的蒸馏性能仍落后于两阶段检测器。其原因是大多数用于目标检测的蒸馏方法主要是为两阶段检测器设计的。例如,DSIG将两阶段检测器提高了2%以上。但当应用于FCOS时,其性能降低了0.8%。

这种不一致性主要来自于单阶段和两阶段检测管道之间的差异(图1)。在两阶段管道中,先前的proposals由Region Proposal Network生成,以提取感兴趣区域(RoI)特征,使预测来自像素集合。而在没有建议的单阶段检测器(也称为密集检测器)中,密集特征映射是通过连续卷积处理的。单个像素被直接用于产生预测。

在单阶段检测器中存在的一个潜在问题是,“感兴趣的区域”(即选定的单个像素)与密集检测器中真正感兴趣的区域(即边界框)不对齐。相比之下,在两阶段设计中采用的区域建议自然适合于实例级的框预测。此外,在单阶段教师和学生模型之间的像素到像素蒸馏过程中,对不重要或错误像素的预测可能会引入不必要的噪声,从而导致基于关系的方法可能错误地利用元素之间的相关性,从而导致性能推断。

为了解决这个问题,作者提出了一个新的视角来理解密集检测器的知识蒸馏。如前所述,密集检测器由多个层的特征图组成,其中像素级训练样本极不平衡,并且不能从密集特征图中利用任何显式实例级关系。此外,密集检测器还有另一个并行分支来处理定位信息,这可以被更好地利用。

解决不平衡问题

不同于将前景(FG)样本与背景(BG)样本的比例正规化的两阶段检测器,只有单阶段的检测器。这导致了FG和BG之间更高水平的不平衡,甚至是不同类别的前景之间的不平衡。Two stage oriented distillers倾向于将FG与BG分离,并以不同的方式调整其损失重量。在单阶段检测器中,FG和BG很难在密集的特征图上进行分割和调谐。如果施加正则化,剧烈的不平衡会使蒸馏的效果降低。为了解决这个问题,作者编码了一种名为类别Anchor的新知识表示,它被设计为每个现有类别的通用模式。它还可以作为场景图像的密集特征图的语义摘要。当涉及到提取类别锚时,类别之间的不平衡(包括BG)得到了缓和。

2、在像素级中挖掘稀疏关系

两阶段风格的实例级关系很难转移到单阶段管道中。uniformly-sampled dense relation distillation甚至降低了检测性能。这说明需要更好地定义每个语义像素内的丰富信息之间的关系。此外,如果能筛选出不相关的密集关系,那将会更好。

为此,作者将空间像素表示映射到一个单位超球体中,其中保留了与附近元素的空间局部相似性,而不失去对那些具有明显特征的元素的识别能力。这里不是测量所有元素之间的关系,而是只量化这些元素和类别Anchor之间的距离。它允许大量毫无意义的关系被丢弃,并自动完成稀疏化。作者将这种策略命名为拓扑距离蒸馏法(_Topological Distance Distillation_)。

解纠缠分支中的不同语义

分类和边界框回归在经验上并不是互补的。因此,它们是由两个独立的分支独立完成的,以解开特征,并使它们专注于优化自己的任务。此外,实验结果(表5)表明,仅仅通过边界框分支中的MSE损失来匹配像素并不利于其性能。

如图4所示,高激活通常聚集在分支中的对象周围,用于回归边界框,而分类分支的激活更有可能从对象区域开始广泛传播。考虑到这种差异,作者没有采用简单的像素-像素匹配,而是开发了定位分布对齐法来通过概率分布匹配学生和教师之间的相对定位回归。这样就避免了提取每个激活图的绝对值,并找到了另一种方法来建模空间域中相对位置的信息。

主要贡献:

解决了上述问题,并提出了很少研究的单阶段检测器蒸馏的最终解决方案。为了在多个尺度和层次中利用密集的特征映射,作者设计了一个合适的知识表示类别Anchor来总结场景中的一般语义模式。在此基础上,引入了拓扑距离来保持像素级样本中的语义联系。最后,将定位蒸馏作为一个分布对齐问题,以有效地传递定位信息的知识。作者将本文的方法命名为语义感知的对齐蒸馏,这涉及到从语义感知的角度利用细粒度知识的本质。

本文的方法在COCO检测任务上的性能大大优于之前所有最先进的目标检测蒸馏器,并显示了其在COCO实例分割任务上的通用性。此外,我们在两阶段检测器上取得了与最先进水平相当的强大性能。

虽然两阶段和一阶段的检测器采用不同的处理方案,但它们都有一个共同点:大部分参数都用于生成高级语义特征图,作为后续任务的基础。在这个交叉点上蒸馏是弥合它们之间差距的关键。在消融研究中分析了每个组件的影响,并测试了超参数的敏感性。

本文的方法是稳定的训练,并不引入额外的参数来进行蒸馏训练。据作者描述,这是第一次尝试实现单阶段目标检测器蒸馏。

相关方法目标检测

基于深度学习的目标检测器分为两种,即两阶段检测器和单阶段检测器。前者生成提取区域特征的prior proposals,而后者在没有先验的情况下完成目标检测。

两阶段的方法主要源于R-CNN,在下一步工作进一步细化R-CNN管道,以生成更精确的proposals并保证实时性。还有许多优秀的方法来进一步改进两阶段检测器。它们都按区域包围提取的特征,并将其视为分类和定位中的实例。

在没有生成proposals的情况下,单阶段检测器设计为低延迟,同时保持高精度。YOLO最初引入Anchor来从特征图中预测类和边界框。为了缓解密集检测中的前景背景不平衡问题,RetinaNet提出了Focal Loss来减轻分类良好的简单示例的权重。

此外,ATSS改进了标签分配算法,GFL提出了优化的Focal Loss和box distribution loss。也有Anchor-free检测器丢弃先前的边界框,例如FCOS,它仅从点进行检测。在本文中对这些检测器进行了实验,包括两阶段、单阶段 Anchor -based和Anchor-free 的检测器。

2.2 知识蒸馏

知识蒸馏将知识从大的教师模型转移到轻量级的学生模型。Hinton等人首先将输出层的soft logits视为需要提炼的知识表示。此外,中间特征表示和注意力映射也被于匹配学生和教师。最近,一些知识蒸馏框架进一步挖掘了这种潜力,如Mutual Learning、Noisy Student 和TAKD。

目标检测从知识蒸馏中获益很多,因为它以更低的推理成本和内存占用实现了相当大的准确性。这里将检测蒸馏方法分为4组:

Region-based

Chen等人首先提出了一个用于两阶段检测器的检测知识蒸馏框架,并专门重新设计了用于检测的logits distillation loss。Wang等人从GT中生成细粒度Mask以提取特征图中选定的近目标区域。Sun等人提出了一种高斯Mask以增加对目标中心附近区域的关注,并采用学习率衰减策略来提高泛化性。最近,Guo等人将这些特征分为前景和背景,处理方式不同。这些方法都首先提取注意力区域,然后对这些选定的像素进行蒸馏。

Relation-based

Chen等和Dai等在检测中利用实例关系将学生与教师相结合。这些方法还使用先前的区域(即建议和GT框)来选择区域特征作为元素,以建立关系,以进行进一步的蒸馏。

Backbone-based

Zhang等人利用了注意力模块,并生成了一个level的注意力图,用于在检测器中进行主干蒸馏。但是,它没有使用以下的检测头,它也包含了足够的信息。

  1. Segmentation-based

Liu等人提出了一种结构化的方法来转移规则像素块之间的关系。请注意,在检测中,这种关系由于背景的比例较高而带来了很大的噪声。Shu等人提取的通道信息-检测器的激活图与分割的不同。

语义感知对齐蒸馏

在本节中介绍了整体蒸馏管道,该管道由类别Anchor蒸馏、拓扑距离蒸馏和定位分布对齐组成。类别Anchor表示从多尺度特征图中挖掘的全局类别信息。它描述了语义原型和单个像素之间的分布,而不会受到类间不平衡的不利影响。因此,在潜在嵌入空间中可以更好地说明拓扑距离。定位分布对齐有效地利用分布匹配损失来提取从教师到学生的回归特征图,其中像素在空间相对位置上表现出语义。

Category Anchor Distillation

为了避免来自密集像素匹配的一系列卷积层的不平衡信息,作者设计了类别Anchor作为一个图像批处理中现有实例的分类摘要(图3)。categorical summary functions类似于两阶段检测器,它提取感兴趣的区域以辅助检测。作者选择了使用注意力像素来建立类别Anchor,而不是像素级模仿,而是对类别区域进行学生和教师模型之间的模仿。

考虑到正方形的边框并不代表真实物体的实际形状,作者将边框分为中心部分和边缘部分。因此,每个类别都拥有两个类别Anchor,它们收集了整个图像批处理中属于分类区域的所有像素。

Multi-level Anchor Matching

不同level的Anchor集中了从多分辨率特征图中收集到的信息。提取学生和教师Anchor在所有level的特征图:

Topological Distance Distillation

在起始点检测框架中,我们将每个像素视为一个单独的样本,其中每个像素都拥有足够的信息来支持以下分类和定位回归任务。同样,样本在嵌入空间中分散,且与Structured knowledge distillation之间存在相关性。这里不是建模所有像素之间的密集相关性,而是测量像素和类别Anchor之间的距离。这些距离构建了一个拓扑结构,用于根据Anchor来校准单个样本。训练有素的教师对距离有更精确的测量方法。因此,它规范了学生的拓扑结构。作者设计了学生和教师之间的拓扑距离蒸馏损失为:

Localization Distribution Alignment 

如图4所示,bbox特征图在图像中具有更被激活的像素,这实际上表明了它自身和GT中心之间的相对位置。受此发现的启发在空间域建立了分布对齐模型,并有效KL散度损失将bbox层提取为  

 实验 消融实验

组件分析

超参数灵敏度 

Loss penalty coefficients

主要实验 

Faster RCNN与Cascade R-CNN 

Mask RCNN与SOLOv2 

局限与总结 局限

一般的限制在于提炼的本质,教师模型不可避免地需要将其知识传递给学生模型。虽然蒸馏主要是针对小的学生模型,但对于大的学生模型很难找到合适的教师模型。

 总结

在本文中提出了用于目标探测器的SEA(SEmantic-Aware Alignment)蒸馏方法。为了弥合单阶段和两阶段检测器蒸馏之间的差距,SEA将每个像素作为实例,设计类别Anchor来总结场景图像中的分类信息,处理密集像素中的剧烈不平衡。在此基础上,对语义关系进行建模,并对其进行稀疏化,使蒸馏更加结构化和完整。此外,还有效地对齐了学生和教师之间的未被充分研究的边界框分支中的定位分布。大量的实验证明了SEA方法在目标检测和实例分割蒸馏任务方面的有效性和鲁棒性。

#PointNet++ 2

这也是第二集了啊

这里通过对模型训练和缩放策略的系统研究重新审视了经典的PointNet++,并提供了两个主要贡献,进而提出PointNeXt,表现SOTA!性能优于PointMLP、Point Transformer等网络

PointNeXt: Revisiting PointNet++ with Improved Training and Scaling Strategies

单位:KAUST, 微软

代码:https://github.com/guochengqian/pointnext

论文:https://arxiv.org/abs/2206.04670

PointNet++ 是用于点云理解的最有影响力的神经架构之一。尽管 PointNet++ 的准确性已被 PointMLP 和 Point Transformer 等最近的网络在很大程度上超越,但我们发现很大一部分性能提升是由于改进了训练策略,即数据增强和优化技术,以及增加了模型大小而不是架构创新。因此,PointNet++ 的全部潜力还有待探索。

在这项工作中,我们通过对模型训练和缩放策略的系统研究重新审视了经典的 PointNet++,并提供了两个主要贡献。

首先,我们提出了一组改进的训练策略,显著提高了 PointNet++ 的性能。例如,我们表明,在不改变架构的情况下,PointNet++ 在 ScanObjectNN 对象分类上的整体准确率(OA)可以从 77.9% 提高到 86.1%,甚至优于最先进的 PointMLP。

其次,我们将倒置残差瓶颈设计和可分离 MLP 引入 PointNet++,以实现高效且有效的模型缩放,并提出 PointNeXt,即下一版本的 PointNets。

PointNeXt 可以灵活扩展,在 3D 分类和分割任务上都优于最先进的方法。

图一 PointNeXt网络结构

算法细节

在这一节,我们展示了通过更先进的训练策略以及模型缩放策略提升PointNet++ 的性能。我们从两个小节分别介绍他们:(1)训练策略现代化;(2)网络架构现代化。

训练策略现代化

本章节中,我们简述我们的研究方法, 具体的训练策略可见后续的消融实验章节。

数据增强

数据增强是提升神经网络性能的最重要的方法之一,而PointNet++ 使用了简单的数据增强组合如随机旋转,缩放,平移,抖动(jitter)并应用于不同的数据集。最新的一些方法使用了更强的数据增强方法。例如, KPConv在训练时随机的失活(drop)部分颜色信息。在这篇工作中,我们收集了近期方法中用到的常见数据增强方法,并通过叠加实验定量地研究每个数据集上每种数据增强方法的效果。针对每一个数据集,我们提出了一组改进的数据增强方法,其可以大幅度提升了PointNet++ 的性能。

优化策略

优化技术主要包含损失函数(loss function),优化器(optimizer),学习率计划器(learning rate schedulers),和超参数(hyperparmeters)。随着机器学习理论的发展,现代化的神经网络可以被理论上更好的优化器(如AdamW)和更好的损失函数(CrossEntropy with label smoothing)训练。Cosine learning rate decay也在近年被大量使用,因为相比 step decay,它的调参更为简单而且效果不会差。在这篇工作中,我们通过叠加实验量化了每种优化策略对PointNet++的影响。同样的,针对每一个数据集,我们提出了一组改进的优化技术可以进一步提高网络性能。

模型架构现代化:小修改 → 大改进

感受野缩放

在点云网络中,使用不同的ball query radius (查询半径)会影响模型的感受野,进而影响性能。我们发现初始半径对于网络性能有很大程度上的影响,并且不同数据集上最佳查询半径不同。此外,我们发现相对坐标 使得网络优化更难,导致性能下降。因此,我们提出利用相对坐标处以查询半径以实现的归一化:

如果没有归一化,相对坐标的值会非常小(小于半径)。这就要求网络能学习到更大的权重应用于 。这使得优化变得困难,特别是考虑到权重衰减的正则化手段限制了网络权重的大小。

模型缩放

PointNet++ 用于分类和分割的模型规模均小于2M。而现在的网络参数普遍在10M以上[3,4]。有趣的是,我们发现无论是使用更多的SA模块还是使用更大的channel size都不会显著提高准确性,却反而导致thoughput显著下降。这主要是梯度消失和过度拟合导致的。在本小节中,我们提出了Inverted Residual MLP (InvResMLP)模块以实现高效实用的模型缩放。该模块建立在SA模块上,如图一的中部所示。InvResMLP和SA模块的不同点有三个:

  • 在模块的输入和输出之间添加了残差连接, 以缓解梯度消失问题
  • 引入了可分离的MLP 以减少计算量,并增强逐点的特征提取
  • 引入inverted bottleneck的设计,以提高特征提取的能力

在PointNet++基础上结合InvResMLP 和图一所示的宏观架构变化,我们提出了PointNeXt。我们将 stem MLP 的channel大小表示为 C,将 InvResMLP 模块的数量表示为 B。我们 PointNeXt 系列的配置总结如下:

  • PointNeXt-S: C = 32, B = 0
  • PointNeXt-B: C = 32, B = (1, 2, 1, 1)
  • PointNeXt-L: C = 32, B = (2, 4, 2, 2)
  • PointNeXt-XL: C = 64, B = (3, 6, 3, 3)

实验

在S3DIS语义分割上,PointNeXt-XL以mIoU/OA/mACC=74.9%/90.3%/83.0%超越了Point Transformer取得SOTA性能且在推理速度上更快。在ScanObjectNN分类上,PointNeXt-S超越目前的SOTA方法PointMLP,且推理速度快十倍。在ShapeNetPart部分分割上,加宽后的模型PointNeXt-S(C=160)达到87.2 Instance mIoU, 超越SOTA CurNet。

#PointNet~3

马上又更新了 这次是PointNet相关的全集

主要对Pointnet、PointNet++和F-PointNet三种模型行全面的解析,包括基本思路、网络结构、模型效果等各个方面

PointNet是由斯坦福大学的Charles R. Qi等人在《PointNet:Deep Learning on Point Sets for 3D Classification and Segmentation》一文中提出的模型,它可以直接对点云进行处理的,对输入点云中的每一个点,学习其对应的空间编码,之后再利用所有点的特征得到一个全局的点云特征。Pointnet提取的全局特征能够很好地完成分类任务,但局部特征提取能力较差,这使得它很难对复杂场景进行分析。

PointNet++是Charles R. Qi团队在PointNet论文基础上改进版本,其核心是提出了多层次特征提取结构,有效提取局部特征提取,和全局特征。

F-PointNet将PointNet的应用拓展到了3D目标检测上,可以使用PointNet或PointNet++进行点云处理。它在进行点云处理之前,先使用图像信息得到一些先验搜索范围,这样既能提高效率,又能增加准确率。

论文地址:https://arxiv.org/abs/1612.00593

开源代码-原论文实现:https://github.com/charlesq34/pointnet

开源代码-Pytorch实现:https://github.com/fxia22/pointnet.pytorch

1.1  PointNet思路流程

1)输入为一帧的全部点云数据的集合,表示为一个nx3的2d tensor,其中n代表点云数量,3对应xyz坐标。

2)输入数据先通过和一个T-Net学习到的转换矩阵相乘来对齐,保证了模型的对特定空间转换的不变性。

3)通过多次mlp对各点云数据进行特征提取后,再用一个T-Net对特征进行对齐。

4)在特征的各个维度上执行maxpooling操作来得到最终的全局特征。

5)对分类任务,将全局特征通过mlp来预测最后的分类分数;对分割任务,将全局特征和之前学习到的各点云的局部特征进行串联,再通过mlp得到每个数据点的分类结果。

1.2 PointNet网络结构

它提取的“全局特征”能够很好地完成分类任务。下面看一下PointNet的框架结构:

下面解释一个网络中各个部件的作用。

1)transform:第一次,T-Net 3x3,对输入点云进行对齐:位姿改变,使改变后的位姿更适合分类/分割;第二次,T-Net 64x64,对64维特征进行对齐。

2)mlp:多层感知机,用于提取点云的特征,这里使用共享权重的卷积。

3)max pooling:汇总所有点云的信息,进行最大池化,得到点云的全局信息。

4)分割部分:局部和全局信息组合结构(concate,语义分割)。

5)分类loss:交叉熵:分割loss:分类+分割+L2(transform,原图的正交变换)。

1.3 T-Net网络结构

将输入的点云数据作为nx3x1单通道图像,接三次卷积和一次池化后,再reshape为1024个节点,然后接两层全连接,网络除最后一层外都使用了ReLU激活函数和批标准化。

1.4 模型效果

ModelNet40 上的分类结果:

 ShapeNet部分数据集上的分割结果:

 不足:缺乏在不同尺度上提取局部信息的能力。

PointNet++  

论文地址:https://arxiv.org/abs/1706.02413

开源代码地址:https://github.com/charlesq34/pointnet2

Pointnet提取的全局特征能够很好地完成分类任务,由于模型基本上都是单点采样,代码底层用的是2Dconv,只有maxpooling整合了整体特征,所以局部特征提取能力较差,这使得它很难对复杂场景进行分析。
PointNet++的核心是提出了多层次特征提取结构,有效提取局部特征提取,和全局特征。

2.1  思路流程

先在输入点集中选择一些点作为中心点,然后围绕每个中心点选择周围的点组成一个区域,之后每个区域作为PointNet的一个输入样本,得到一组特征,这个特征就是这个区域的特征。

之后中心点不变,扩大区域,把上一步得到的那些特征作为输入送入PointNet,以此类推,这个过程就是不断的提取局部特征,然后扩大局部范围,最后得到一组全局的特征,然后进行分类。

2.2 整体网络结构

PointNet++ 在不同尺度提取局部特征,通过多层网络结构得到深层特征。PointNet++按照任务也分为 classification (分类网络)和 segmentation (分割网络)两种,输入和输出分别与PointNet中的两个网络一致。

 PointNet++会先对点云进行采样(sampling)和划分区域(grouping),在各个小区域内用基础的PointNet网络进行特征提取(MSG、MRG),不断迭代。

对于分类问题,直接用PointNet提取全局特征,采用全连接得到每个类别评分。对于分割问题,将高维的点反距离插值得到与低维相同的点数,再特征融合,再使用PointNet提取特征 。

比较PointNet++两个任务网络的区别:

在得到最高层的 feature 之后,分类网络使用了一个小型的 PointNet + FCN 网络提取得到最后的分类 score;

分割网络通过“跳跃连接” 操作不断与底层 “低层特征图”信息融合,最终得到逐点分分类语义分割结果。(“跳跃连接”对应上图的 skip link connection;低层特征图 具有分辨率较大,保留较丰富的信息,虽然整体语义信息较弱。)

2.3 网络结构组件

1)采样层(sampling)

激光雷达单帧的数据点可以多达100k个,如果对每一个点都提取局部特征,计算量是非常巨大的。因此,作者提出了先对数据点进行采样。作者使用的采样算法是最远点采样(farthest point sampling, FPS),相对于随机采样,这种采样算法能够更好地覆盖整个采样空间。

2)组合层(grouping)

为了提取一个点的局部特征,首先需要定义这个点的“局部”是什么。一个图片像素点的局部是其周围一定曼哈顿距离下的像素点,通常由卷积层的卷积核大小确定。同理,点云数据中的一个点的局部由其周围给定半径划出的球形空间内的其他点构成。组合层的作用就是找出通过采样层后的每一个点的所有构成其局部的点,以方便后续对每个局部提取特征。

3)特征提取层(feature learning)

因为PointNet给出了一个基于点云数据的特征提取网络,因此可以用PointNet对组合层给出的各个局部进行特征提取来得到局部特征。值得注意的是,虽然组合层给出的各个局部可能由不同数量的点构成,但是通过PointNet后都能得到维度一致的特征(由上述K值决定)。

2.4 不均匀点云组合grouping方法

不同于图片数据分布在规则的像素网格上且有均匀的数据密度,点云数据在空间中的分布是不规则且不均匀的。当点云不均匀时,每个子区域中如果在分区的时候使用相同的球半径,会导致部分稀疏区域采样点过小。作者提出多尺度成组 (MSG)和多分辨率成组 (MRG)两种解决办法。

1)多尺度组合MSG:对于选取的一个中心点设置多个半径进行成组,并将经过PointNet对每个区域抽取后的特征进行拼接(concat)来当做该中心点的特征,这种做法会产生很多特征重叠,结果会可以保留和突出(边际叠加)更多局部关键的特征,但是这种方式不同范围内计算的权值却很难共享,计算量会变大很多。

2)多分辨率组合MRG:MRG避免了大量的计算,但仍然保留了根据点的分布特性自适应地聚合信息的能力。对不同特征层上(分辨率)提取的特征再进行concat,以b图为例,最后的concat包含左右两个部分特征,分别来自底层和高层的特征抽取,对于low level点云成组后经过一个pointnet和high level的进行concat,思想是特征的抽取中的跳层连接。

当局部点云区域较稀疏时,上层提取到的特征可靠性可能比底层更差,因此考虑对底层特征提升权重。当然,点云密度较高时能够提取到的特征也会更多。这种方法优化了直接在稀疏点云上进行特征抽取产生的问题,且相对于MSG的效率也较高。

选择哪一种?

当局部区域的密度低时,第一矢量可能不如第二矢量可靠,因为计算第一矢量的子区域包含更稀疏的点并且更多地受到采样不足的影响。在这种情况下,第二个矢量应该加权更高。另一方面,当局部区域的密度高时,第一矢量提供更精细细节的信息,因为它具有以较低水平递归地表达较高分辨率检查的能力。

2.5 模型效果

分类对比:

分割对比:

 小结复杂场景点云一般采用PointNet++进行处理,而简单场景点云则采用PointNet。如果只从点云分类和分割两个任务角度分析,分类任务只需要max pooling操作之后的特征信息就可完成,而分割任务则需要更加详细的local context信息。

F-PointNet  

论文地址:https://arxiv.org/pdf/1711.08488.pdf

开源代码地址:https://github.com/charlesq34/frustum-pointnets

F-PointNet 也是直接处理点云数据的方案,但这种方式面临着挑战,比如:如何有效地在三维空间中定位目标的可能位置,即如何产生 3D 候选框,假如全局搜索将会耗费大量算力与时间。
F-PointNet是在进行点云处理之前,先使用图像信息得到一些先验搜索范围,这样既能提高效率,又能增加准确率。

3.1 基本思路 

首先使用在 RGB 图像上运行的 2D 检测器,其中每个2D边界框定义一个3D锥体区域。然后基于这些视锥区域中的 3D 点云,我们使用 PointNet/PointNet++ 网络实现了 3D实例分割和非模态 3D 边界框估计。总结一下思路,如下:

  1. 基于图像2D目标检测。
  2. 基于图像生成锥体区域。
  3. 在锥体内,使用 PointNet/PointNet++ 网络进行点云实例分割。

它是在进行点云处理之前,先使用图像信息得到一些先验搜索范围,这样既能提高效率,又能增加准确率。先看看下面这张图:

在这张图里,左上角的意思是先把图像和点云信息标定好(这个属于传感器的外参标定,在感知之前进行;获取两个传感器之间旋转矩阵和平移向量,就可以得到相互的位置关系)。

左下角是用目标检测算法检测出物体的边界框(BoundingBox),有了边界框之后,以相机为原点,沿边界框方向延伸过去就会形成一个锥体(上图的右半部分),该论文题目里frustum这个词就是锥体的意思。然后用点云对该物体进行识别的时候,只需要在这个锥体内识别就行了,大大减小了搜索范围。

3.2 模型框架

模型结构如下:(可以点击图片放大查看)

网络共分为三部分,第一部分是使用图像进行目标检测并生成锥体区域,第二部分是在锥体内的点云实例分割,第三部分是点云物体边界框的回归。

3.3 基于图像生成锥体区域

由于检测到的目标不一定在图像的正中心,所以生成的锥体的轴心就不一定和相机的坐标轴重合,如下图中(a)所示。为了使网络具有更好的旋转不变性,我们需要做一次旋转,使相机的Z轴和锥体的轴心重合。如下图中(b)所示。 

 3.4 在锥体内进行点云实例分割

实例分割使用PointNet。一个锥体内只提取一个物体,因为这个锥体是图像中的边界框产生的,一个边界框内也只有一个完整物体。

在生成锥体的时候提到了旋转不变性,此处完成分割这一步之后,还需要考虑平移不变性,因为点云分割之后,分割的物体的原点和相机的原点必不重合,而我们处理的对象是点云,所以应该把原点平移到物体中去,如下图中(c)所示。

3.5 生成精确边界框

生成精确边界框的网络结构:

从这个结构里可以看出,在生成边界框之前,需要经过一个T-Net,这个东西的作用是生成一个平移量,之所以要做这一步,是因为在上一步得到的物体中心并不完全准确,所以为了更精确地估计边界框,在此处对物体的质心做进一步的调整,如下图中(d)所示。 

下面就是边界框回归了,对一个边界框来讲,一共有七个参数,包括: 

最后总的残差就是以上目标检测、T-Net和边界框残差之和,可以据此构建损失函数。 

3.6 PointNet 关键点

(1) F-PointNet使用2D RGB图像

F-PointNet使用2D RGB图像原因是:1.当时基于纯3D点云数据的3D目标检测对小目标检测效果不佳。所以F-PointNet先基于2D RGB做2D的目标检测来定位目标,再基于2d目标检测结果用其对应的点云数据视锥进行bbox回归的方法来实现3D目标检测。2.使用纯3D的点云数据,计算量也会特别大,效率也是这个方法的优点之一。使用成熟的2D CNN目标检测器(Mask RCNN)生成2D检测框,并输出one-hot 分类向量(即基于2D RGB图像的分类)。

(2)锥体框生成        

2D检测框结合深度信息,找到最近和最远的包含检测框的平面来定义3D视锥区域frustum proposal。然后在该frustum proposal里收集所有的3D点来组成视锥点云(frustum point cloud)。

3.7 实验结果

与其他模型对比:

模型效果:

3.8 优点

(1)舍弃了global fusion,提高了检测效率;并且通过2D detector和3D Instance Segmentation PointNet对3D proposal实现了逐维(2D-3D)的精准定位,大大缩短了对点云的搜索时间。下图是通过3d instance segmentation将搜索范围从9m~55m缩减到12m~16m。

(2)相比于在BEV(Bird's Eye view)中进行3D detection,F-PointNet直接处理raw point cloud,没有任何维度的信息损失,使用PointNet能够学习更全面的空间几何信息,特别是在小物体的检测上有很好的表现。下图是来自Hao Su 2018年初的课程,现在的KITTI榜有细微的变动。

(3)利用成熟的2D detector对proposal进行分类(one-hot class vector,打标签),起到了一定的指导作用,能够大大降低PointNet对三维空间物体的学习难度。

3.9 模型代码

开源代码:GitHub - charlesq34/frustum-pointnets: Frustum PointNets for 3D Object Detection from RGB-D Data

作者代码的运行环境:

系统:Ubuntu 14.04 或 Ubuntu 16.04

深度框架:TensorFlow1.2(GPU 版本)或 TensorFlow1.4(GPU 版本)

其他依赖库:cv2、mayavi等。

#DiJiang

“又西三百五十里曰天山,多金玉,有青雄黄,英水出焉,而西南流注于汤谷。有神鸟,其状如黄囊,赤如丹火,六足四翼,浑敦无面目,是识歌舞,实惟帝江也。”——《山海经》,华为诺亚频域LLM「帝江」:仅需1/50训练成本,7B模型媲美LLaMA,推理加速5倍

基于 Transformer 架构的大语言模型在 NLP 领域取得了令人惊艳的效果,然而,Transformer 中自注意力带来的二次复杂度使得大模型的推理成本和内存占用十分巨大,特别是在长序列的场景中。

此前,研究者们提出了线性 Transformer、Mamba、RetNet 等。这些方案可以大幅降低 Transformer 计算成本,并且取得媲美原有模型的精度,但是由于架构更换,模型重训练带来的巨大成本令人望而却步。

本文着眼于大语言模型的训练和使用代价,提出一种从频域角度降低 LLM 的成本的 [帝江] 大语言模型。减少 Transformer 架构使用成本的一种常见方法是基于线性注意力机制 (Linear Attention),但是构建带有线性注意力机制的 Transformer 就需要重新训练整个模型,花费的计算代价太高了,对于巨量参数的 LLM 显然不切实际。

因此,本文提出频域核化 (Frequency Domain Kernelization) 方法:使用离散余弦变换 (DCT) 有效且精准地将 Transformer 的 Query 和 Key 映射到频域。这种映射能够有效地消除 Self-Attention 机制中的 Softmax 操作,使注意力计算复杂度转化为线性。而且,作者从理论上证明,这种频域映射是与原始注意力机制等效近似,允许预训练的原始 Transformer 模型在训练成本很小的情况下,转换为线性复杂度模型。

本文提出的加权准蒙特卡罗方法提供了优越的逼近效率。为了进一步降低计算复杂度,核化方法基于离散余弦变换 (Discrete Cosine Transform, DCT) 操作。本文所提出的方法实现了与原始 Transformer 相当的性能,但大大降低了训练成本 (约 1/10) 和更快的推理速度 (最快约 10 倍) 。

为了解决这一问题,最近的一篇论文提出了一种基于频域的大语言模型架构 — 帝江(源于山海经的一种神话生物,以跑得快而闻名),同时解决了现有大模型的两大痛点:推理成本和训练成本。

  • 论文地址:https://arxiv.org/abs/2403.19928
  • 开源链接:https://github.com/YuchuanTian/DiJiang

该论文基于频域自注意力变换核,寻找到一种原始自注意力的线性逼近,使得原有的 Transformer 模型可以经过少量数据(1/10-1/50)的微调,可以近乎无损地变形为论文提出的帝江模型。具体来说,在 LLaMA2-7B 上仅仅需要使用 40B 左右的训练数据,就可以取得最多 5 倍的推理加速,且在各个评测集上取得相当的精度。

DiJIang-7B 模型和 LLaMA-7B 的精度对比

DiJIang-7B 模型和 LLaMA-7B 的速度对比

研究背景

Transformer 架构自从推出以来,彻底革新了自然语言处理(NLP)领域,并在多种任务中取得了杰出成果。这一成功导致了大型语言模型(LLMs)主导的时代的到来,在这个时代中,Transformer 结构被放大以处理越来越复杂的任务。然而,这种规模的扩大也带来了巨大的计算需求,特别是由于需要每个 token 之间的计算的自注意力机制。

面对更高效 Transformer 模型的迫切需求,研究者们提出了线性 Transformer、Mamba、RetNet 等方案,虽然这些方案可以大幅降低 Transformer 计算成本,并且取得媲美原有模型的精度,但是由于架构更换,模型重训练带来的巨大成本令人望而却步。

然而,大多数现有的优化 Transformers 方法,特别是与优化注意力机制有关的,需要对模型从头重新训练。这一重新训练过程是一个巨大的挑战,特别是对于参数庞大的模型,需要大量的计算资源和时间投入。例如,像 LLaMA-7B 这样的大型模型的训练需要大约 8 万多 GPU hours。尽管有部分研究如 Performer 努力寻找注意力机制的快速近似方法,但这些方法在大型语言模型中还没有得到彻底的验证。

为了解决大型语言模型中快速注意力近似的问题,论文对现有的线性注意力方案和自注意力近似方案进行了彻底的分析。论文发现,这些方法中近似误差的主要来源是基于蒙特卡洛方法的采样。因此,论文提出采用加权拟蒙特卡洛采样来代替蒙特卡洛采样进行映射,论文进一步引入频域离散余弦变换(DCT)来作为拟蒙特卡洛采样的值,从而高效且准确地将 Transformer 的 query 和 key 映射到频域。使得注意力机制中的 softmax 操作可以被去除,达到线性的计算复杂度。论文还从理论上证明了,这种频域映射是与原始注意力机制的一个近似等效,从而使得帝江模型可以不需要从头开始训练,只需要少量数据就可以从 Transformer 的参数中进行微调继承。论文的实验表明,论文的方法达到了与原始 Transformer 相当的性能,但训练成本大大减少(<1/10),同时也受益于更快的推理速度(在不同模型上最高约 10 倍)。

方法介绍

论文首先回顾了 Attention 的计算方式:

论文提供了理论证明,来表明提出的 WPFF 映射核是一种更优的映射方式,具体的证明内容详见论文附录:

帝江模型和传统自注意力计算的区别

上图展示了帝江模型和传统自注意力计算的区别,在 Transformer 的注意力机制中,key 和 value 的计算通过快速离散余弦变换(DCT)高效地映射到频域。这种映射有效地消除了 softmax 操作,从而显著降低了 Transformer 的计算复杂度。

实验结果

不同模型大小的对比

上表展示了提出的帝江模型在不同大小的 scale 上的结果,可以看到,提出的帝江模型可以取得和原始模型基本相同的精度,并且拥有更快的推理速度和更低的训练成本,显著解决了现有 LLM 遇到的训推成本过大的问题。此外,模型在 1B 的模型量级上超越了 1.3B 大小的 Mamba 模型。需要注意的是,尽管传统 Transformer 可以通过 Flash Attention 的方式进行进一步加速,但由于针对帝江模型的加速框架尚未开发,为了公平对比模型本身的速度,推理速度的测试都是在模型都不使用加速框架的前提下进行的。   

与不同 Transformer 改进方案精度对比

论文还展示了帝江和其他 Transformer 模型的改进方案进行了进一步的对比,可以发现,帝江模型具有比其他模型更好的效果,这得益于其通过更好的核映射近似了原始的 Transformer 模型计

论文还同时提供了帝江 - 7B 模型的续写样例展示,可以看到,帝江 - 7B 的续写结果,和 LLaMA2-7B 相比毫不逊色,甚至条理性上要略胜一筹。

总结

论文提出了一种新的 LLM 架构:帝江,在 7B 以下的模型量级,所提出的模型可以大幅降低 LLM 所需的训练和计算成本,为未来 LLM 的高效部署提出了一种新的思路。帝江架构是否会在更大的模型与多模态 VLM 等其他 Transformer 的应用领域中大放光彩,让我们拭目以待。

大语言模型需要极简注意力机制

Transformer 彻底改变了自然语言处理的领域,也带来了大语言模型 (LLM) 所主导的时代。LLM 可以处理很复杂的任务,但同时也带来了大量的计算需求:显著的推理成本和能耗,使得在手机和机器人这类端侧设备的部署显著受阻。

在大量的模型压缩策略中,简化注意力机制 (simplifying the attention mechanism) 是一种极具前景的方法。比如 Linear Transformer[1],Performer[2]。还有很多经典的改进注意力机制复杂度的技术路线,比如:

  • RWKV:RWKV: Reinventing RNNs for the Transformer Era
  • RetNet:Retentive Network: A Successor to Transformer for Large Language Models
  • Mamba:Mamba: Linear-Time Sequence Modeling with Selective State Spaces

但是,大多数现有的优化 Transformer 的方法,通常需要对模型架构进行重大修改,且通常需要从头训练整个的模型以实现最佳性能。这样的重新训练过程对于 LLM 这种参数量巨大的模型而言,的确是个不小的挑战。比如,训练一个 LLaMA-7B[3] 量级的模型需要 82,432 GPU-hours,总功耗约为 36 MWh。对于这种量级的模型,再训练不仅会带来比较可观的经济问题,还会引发不小的环境问题。那么就需要更有效的方法来适应和优化这些大模型。简化的注意力机制建模方法在大语言模型上面还没有得到很完善的验证。

频域的核化注意力机制

这种频域中的核注意力机制不仅提高了 Transformer 的可扩展性,使其能够轻松处理更大的数据集和序列,而且还显着加快了训练和推理阶段。

1.3 不同尺寸的结果测评

作者使用 Pythia[6]来验证本文方法,这是一个具有完全公共数据集和训练过程的模型,从而可以实现公平比较。作者遵循 Pythia 使用的训练策略,包括学习率、优化器和其他超参数,并使用 Pile 数据集。Pile[7]数据集是一个 825 GiB 大小的英文文本语料库,专为训练大语言模型而设计。它由 22 个不同的高质量子集组成,其中许多来自于学术或者专业资源。这个全面多样的数据集是开发和微调大模型的基础。本文的 DiJiang 模型是从预训练的 Pythia 模型微调得到的。

作者在 Pythia 使用的6个公共数据集上评估了本文方法:

  • PIQA[8]
  • WinoGrande
  • WSC[9]
  • ARC-E
  • ARC-C[10]
  • LogiQA[11]

Pythia 的模型来自 HuggingFace[12]。实验结果如下图3所示。本文方法在从 70M 到 2.8B 参数的不同大小的模型中都取得了不错的结果。6个数据集上的平均性能与原始 Pythia 的性能几乎相同,但训练成本只有其约 1/16。而且,DiJiang 模型的推理速度明显快于原始的 Pythia。这些结果证实了本文方法可以在不影响性能的情况下提高大型语言模型的效率。

图3:不同尺寸模型的实验结果。训练时长在 A800 上测得,推理时使用 2048 的 token 长度

1.4 不同模型的结果测评

为了评估在不同模型的有效性,作者进一步将本文方法应用于 OPT-350M[13][14]和 TinyLLaMA-1.1B[15]模型。需要注意的是,由于它们的训练数据不能完全访问,因此作者继续使用 Pile 数据集来做微调。

最后,作者对著名的大语言模型 LLAMA2-7B 进行了实验,并将其微调到 DiJiang-7B 模型。图4为实验结果,可以看到 DiJiang-7B 模型在各种基准测试中实现了与原始 LLAMA2-7B 几乎相同的结果。值得注意的是,DiJiang 模型只需要 40B tokens 的训练数据,远远小于 LLAMA2-7B 的 2T tokens。这也证明了本文的方法可以扩展到 7B 参数量级别的模型中。

有趣的是,尽管使用的数据集十分有限,但本文方法的结果与原始模型相似,而且训练成本显著降低,速度也更快。这一结果进一步证明了本文方法的泛化性和灵活性。而且,在有些原始训练数据集不可用的情况下,也有一定潜在适用性。

图4:不同 Benchmark 上与 LLaMA2-7B 的对比

1.5 与线性 Transformer 的对比

为了比较本文方法与其他线性复杂度自注意力 Transformer 模型,作者验证了 Pythia-400M 在不同模型 (包括 Linformer、Performer、RetNet 和 Cosformer) 上的微调结果。

图5:与线性 Transformer 微调 Pythoia-410M 的结果对比

为了公平比较,作者采用了相同的训练设置和数据。如图5所示为比较结果。虽然现有的方法可以通过重新训练获得良好的结果,但在不重新训练,仅仅微调的情况下,大多数会遭受显著的精度损失。这主要是因为这些方法难以准确地逼近原始注意力机制,导致无法以最小的训练代价恢复原始模型的精度。

作者还可视化了不同方法的训练曲线,如图6所示。本文方法展示出最快的下降率,最终也实现了最低的损失值。这种快速收敛性也说明了本文可以快速达到与原始 Transformer 相似的性能水平,验证了本文方法在逼近注意力机制方面的有效性。这个结果进一步巩固了本文的结论,即本文方案作为 Transformer 线性替代方案的可行性。

图6:不同方法的训练曲线对比

1.6 推理时间对比

作者还评估了本文方法与 Transformer 模型相比的内存使用量和吞吐量。作者选择 Pythia-410M 模型分析。结果如图7所示。随着 tokens 长度的增加,本文模型的内存占用和推理速度不会变化。这一结果可以归结为本文方法的线性复杂度的注意力机制,表明其更有利于长序列的推理。

由于注意力机制呈二次方的计算复杂度,随 tokens 长度的增加,原始 Transformer 模型在推理时间和所需内存方面都持续增加。这个比较结果突出了本文方案的效率和实用性,尤其是当计算资源是瓶颈问题的长序列推理的情况下。

图7:DiJiang 和原始 Transformer 模型的内存使用量和吞吐量对比

1.7 可视化

为了进一步证明本文提出的模型和近似方案对注意力机制近似的有效性,作者绘制了不同方法得到的注意力图的可视化结果,如下图8所示。原始 Transformer 模型的注意力图 (图8(a)) 有丰富的信息,这为其鲁棒能力奠定了基础。相比之下,线性注意力机制 (例如 Performer (图8(b))) 产生的注意力图很难捕捉到 token 之间的关系,导致其映射与原始 Transformer 不同,最终导致模型精度下降。

本文方法 (图8(c)) 通过使用加权准蒙特卡罗方案,非常接近原始注意力机制。这就允许它有效地建模不同 token 之间的关系,实现与原始 Transformer 模型的结果几乎相同。这个比较结果突出了其他线性注意力方法在捕获 token 相互依赖性方面的不足,也展示了本文方法在准确逼近注意力机制的同时提高了计算效率。

图8:不同架构的注意力图可视化结果

#SoftLabel

来从标签平滑和知识蒸馏理解,先探讨一下hard label和soft label之间的关系,然后介绍一下如何用可靠的方法得到蕴含更多信息的soft label,其中主要包含标签平滑和知识蒸馏两种经典方法。

深度学习领域中,通常将数据标注为hard label,但事实上同一个数据包含不同类别的信息,直接标注为hard label会导致大量信息的损失,进而影响最终模型取得的效果。本文首先探讨一下hard label和soft label之间的关系,然后介绍一下如何用可靠的方法得到蕴含更多信息的soft label,其中主要包含Label SmoothingKnowledge Distillation两种经典方法。

Hard Label vs Soft Label

hard label更容易标注,但是会丢失类内、类间的关联,并且引入噪声。

soft label给模型带来更强的泛化能力,携带更多的信息,对噪声更加鲁棒,但是获取难度大。

Label Smoothing

Softmax Cross Entropy不仅可以做分类任务(目标为one-hot label),还可以做回归任务(目标为soft label)。设网络输出的softmax prob为p,soft label为q,那Softmax Cross Entropy定义为: 

InfoNCE可以拆分成两个部分,alignment和uniformity。 

如上图所示,alignment部分只跟positive pair相关,希望positive pair的feature拉近,uniformity部分只跟negative pair相关,希望所有点的feature尽可能均匀分布在unit hypersphere上。从softmax和InfoNCE损失函数上理解,把InoNCE公式的分母想象成soft label的所有位置相加,也就是最大值的那个位置可以看成是positive pair,其他位置都可以看成是negative pair,softmax的损失函数不是跟InfoNCE损失函数一模一样了吗,异曲同工!也就是说hard label可以认为只有positive pair,而soft label仍然保留negative pair。因此,soft label更容易避免退化解问题。 

上图是sigmoid曲线。Softmax Cross Entropy 的loss曲线其实跟sigmoid类似,越靠近1的时候,loss曲线会越平缓,这里以sigmoid曲线图为例。

从softmax的损失函数曲线上理解,hard label监督下,由于softmax的作用,one-hot的最大值位置无限往1进行优化,但是永远不可能等于1,从上图可知优化到达一定程度时,优化效率就会很低,到达饱和区。而soft label可以保证优化过程始终处于优化效率最高的中间区域,避免进入饱和区。

Knowledge Distillation

knowledge distillation相比于label smoothing,最主要的差别在于,知识蒸馏的soft label是通过网络推理得到的,而label smoothing的soft label是人为设置的。

原始训练模型的做法是让模型的softmax分布与真实标签进行匹配,而知识蒸馏方法是让student模型与teacher模型的softmax分布进行匹配。直观来看,后者比前者具有这样一个优势:经过训练后的原模型,其softmax分布包含有一定的知识——真实标签只能告诉我们,某个图像样本是一辆宝马,不是一辆垃圾车,也不是一颗萝卜;而经过训练的softmax可能会告诉我们,它最可能是一辆宝马,不大可能是一辆垃圾车,但绝不可能是一颗萝卜。

知识蒸馏得到的soft label相当于对数据集的有效信息进行了统计,保留了类间的关联信息,剔除部分无效的冗余信息。 相比于label smoothing,模型在数据集上训练得到的soft label更加可靠。

比较短哦 ~ 

#Bayesian Flow Networks(BFN)合集5

主要对 BFN 核心部分的代码实现进行详细的解析以及介绍了工程实现上的经验性问题-分布式训练时随机种子的设置

原本计划上一篇就会 kill 掉整个系列,因为不确定作者是否会开源。于是在这忐忑之期就到北京溜达了一圈,回来后惊喜(恐)地发现作者还真的开源了!“喜”当然是因为我终究能够续上一直以来在文章中贯彻的“不无聊风格”——源码解析;至于“恐”嘛~ 就是我又得费脑和费手指了..

本文是整个系列的终结篇(CW 很认真,不开玩笑!),主要内容是对 BFN 核心部分的代码实现进行解析,主要包括(按顺序):模型输入输出并计算 loss、采样生成样本(生成模型的天职)、BFN 的核心——贝叶斯流(bayesian flow)的实现、模型训练的武功秘籍——损失函数的实现、BFN 建模的关键——输出分布的实现、神经网络(model)本身的实现、数据加载和预处理、整体训练流程 以及 一个工程上实现的问题——分布式训练时随机种子的设置。

在最后一章,CW 难免要吹吹水,于是先简单总结下 BFN 的玩法,然后将其与扩散模型进行比较,最后发自内心谈谈自己对这个方法论的理解与看法。

如果仅关注 BFN 算法本身的代码实现,那么可以只看前五章;否则,如果你连 BFN 本身是什么都不知道(请问您是怎么进来的..),那么就直接跳转到最后一章吧,或许能令你对 BFN 有个浅浅的认识(并体会到不无聊的风格);又或者,你不小心手抖点进来了,也可以看看倒数第二章,即第九章,那是个工程实现上的经验性问题,只要是在 Pytorch 的分布式框架下玩都适用;若这些情况都不是,那么 CW 懂了——你是要看完全文!辛苦了您,感恩~!

附:BFN 官方源码:https://github.com/nnaisense/bayesian-flow-networks

一、Loss 计算

作者将 loss 的计算流程封装在了 BFN 这个类里,同时,还在其中封装了采样生成的过程。所以,要注意在代码实现中,这个类并非代表神经网络(model)本身的实现,而是 BFN work 的整体逻辑:loss 计算对应训练过程、采样生成则对应推理过程。

整体流程

class BFN(nn.Module):def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: Loss):super().__init__()self.net = netself.bayesian_flow = bayesian_flowself.loss = lossdef forward(self, data: Tensor, t: Optional[Tensor] = None, n_steps: Optional[int] = None) -> tuple[Tensor, dict[str, Tensor], Tensor, Tensor]:"""Compute an MC estimate of the continuous (when n_steps=None or 0) or discrete time KL loss.t is sampled randomly if None. If t is not None, expect t.shape == data.shape.使用蒙特卡洛方法估计发送者分布和接收者分布之间的 KL 散度损失:-采样时间变量;-从贝叶斯流分布中采样得到输入分布的参数(后验更新);-将输入分布的参数喂给模型;-模型返回输出分布;-计算连续/离散时间 loss."""t = self.sample_t(data, n_steps) if t is None else t# sample input parameter flow# 从贝叶斯流分布中采样出输入分布的参数(代表已完成后验更新).input_params = self.bayesian_flow(data, t)# 在输入模型前转换为适合于模型输入的形式(如有必要的话)net_inputs = self.bayesian_flow.params_to_net_inputs(input_params)# compute output distribution parameters# 注意, 这里模型输出的通常不是输出分布的参数, 而是某些变量(比如估计的噪声),# 它们经过后处理才最终成为输出分布的参数.output_params: Tensor = self.net(net_inputs, t)# compute KL loss in float32with torch.autocast(device_type=data.device.type if data.device.type != "mps" else "cpu", enabled=False):if n_steps == 0 or n_steps is None:loss = self.loss.cts_time_loss(data, output_params.float(), input_params, t)else:loss = self.loss.discrete_time_loss(data, output_params.float(), input_params, t, n_steps)# loss shape is (batch_size, 1)return loss.mean()

loss 计算的整个流程 CW 已在上述注释中写明。在连续时间的情况下是不需要指定总时间步 n_steps 的,因此当 n_steps = 0 或未指定时就使用连续时间的损失函数进行计算;否则,就使用离散时间的损失函数。至于损失函数的实现,后文会详细解析。

以上前向过程 forward() 的第一步就是采样出时间变量,下面来看看这一步的具体实现。

时间变量的采样

@staticmethod@torch.no_grad()def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor:"""采样时间变量 t, 包括连续时间和离散时间两种情况."""# 连续时间情况不需要指定总步数, 从 U(0,1) 连续型均匀分布中采样.if n_steps == 0 or n_steps is None:# (B,1)t = torch.rand(data.size(0), device=data.device).unsqueeze(-1)# 离散时间情况则先从 U{0,n-1} 离散型均匀分布采样出时间步,然后再除总步数 n 计算出对应的时间变量值: t = \frac{i-1}{n}# 注意, 这是每个区间起始时刻的值.else:# (B,1)t = torch.randint(0, n_steps, (data.size(0),), device=data.device).unsqueeze(-1) / n_steps# 扩展至和数据同样的维度, 不同的数据样本的时间变量不一致, 同一个样本内所有维度上所对应的时间变量则相同.t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data)return t

这个 sample_t() 方法也是封装在 BFN 这个类里的,但从程序设计的逻辑上来看,它并不专属于某个特定的类,而是可以作为通用方法来使用的,因此用 @staticmethod 修饰器使其成为静态方法。

时间变量仅在不同数据样本之间存在差异,而同一个样本在所有维度上都应该拥有相同的时间变量值,于是采样的时间变量个数只需与数据样本的数量相等即可,这个数对应于 batch_size,也就是 data.size(0),采样完成后再将维度扩充至与数据相同。

二、采样生成

以下是采样生成样本的过程,整体可概括为:

  1. 设置先验参数 input_params
  2. 根据当前时间步计算对应的时间变量 t
  3. 将先验和时间变量输入模型令其返回输出分布的参数 output_params
  4. 从输出分布中采样,采样结果当作当前步骤的生成样本 output_sample
  5. 根据当前时间步计算出对应的精度 alpha
  6. 以输出分布的样本和精度为参数,从发送者分布中采样出观测样本 y
  7. 利用观测样本根据贝叶斯更新函数(贝叶斯定理)计算后验,从而对先验进行更新 update_input_params(...)
  8. 不断重复 2~7,待完成至规定的总步数 n_steps 后(那时 t=1)再根据 3 ~ 4 生成最终的样本
@torch.inference_mode()def sample(self, data_shape: tuple, n_steps: int) -> Tensor:device = next(self.parameters()).device# 起始时刻的先验input_params = self.bayesian_flow.get_prior_input_params(data_shape, device)distribution_factory = self.loss.distribution_factoryfor i in range(1, n_steps):# t_{i-1} = \frac{i-1}{n}t = torch.ones(*data_shape, device=device) * (i - 1) / n_steps# 模型接收输入分布的参数并预测,形成输出分布的参数后,再从其中采样作为预测(生成)的数据样本.output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)output_sample = distribution_factory.get_dist(output_params, input_params, t).sample()output_sample = output_sample.reshape(*data_shape)# 计算精度 \alpha_ialpha = self.bayesian_flow.get_alpha(i, n_steps)# 采样观测样本y = self.bayesian_flow.get_sender_dist(output_sample, alpha).sample()# 后验更新input_params = self.bayesian_flow.update_input_params(input_params, y, alpha)# 最后时刻 t=1t = torch.ones(*data_shape, device=device)output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t)# 概率分布的众数(mode)作为样本.output_sample = distribution_factory.get_dist(output_params, input_params, t).modeoutput_sample = output_sample.reshape(*data_shape)return output_sample

以上需要特别注意的是,最终是使用输出分布的 mode:也就是一个概率分布的众数(概率最大处所对应的样本,即最有可能出现的结果)作为生成结果;而在前面迭代的过程中,使用的是输出分布的常规采样结果作为当前步骤生成的样本。

三、贝叶斯流的实现

贝叶斯流的目标是计算后验,从而对先验进行更新。但与基于贝叶斯定理来计算后验的单步更新不同,它能够根据原始数据样本和任意时间变量计算出对应时刻的后验,而不依赖于由起始时刻至今过程中的那些观测样本。

作者实现了一个抽象基类 BayesianFlow,其中定义了贝叶斯流会用到的一些方法(定义为抽象方法 abstractmethod),而建模不同类型数据时所对应的贝叶斯流都要继承这个基类,并且将抽象方法都真正地实现(overwrite)。

class BayesianFlow(nn.Module, ABC):def __init__(self):super().__init__()@abstractmethoddef get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, ...]:"""Returns the initial input params (for a batch) at t=0. Used during sampling.For discrete data, the tuple has length 1 and contains the initial class probabilities.For continuous data, the tuple has length 2 and contains the mean and precision.返回起始时刻的先验参数, 作为模型的输入, 方法用于采样过程的开端."""pass@abstractmethoddef params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor:"""Utility method to convert input distribution params to network inputs if needed.如果有必要的话, 将输入分布的参数转换为适合模型输入的形式.比如在建模离散化数据时, 输入分布的参数代表概率, 取值范围在[0,1], 于是在输入模型前会将其 scale 至[-1,1],从而与其他类型的数据场景兼容, 并且避免让模型永远只接收非负值."""pass@abstractmethoddef get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float:"""Returns the alpha at step i of total n_steps according to the flow schedule. Used:a) during sampling, when i and alpha are the same for all samples in the batch.b) during discrete time loss computation, when i and alpha are different for samples in the batch.计算某个离散时间步所对应的精度: \alpha_i = \beta(t_i) - \beta(t_{i-1}), 用于采样过程或离散时间的损失函数. """pass@abstractmethoddef get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:"""Returns the sender distribution with accuracy alpha obtained by adding appropriate noise to the data x. Used:a) during sampling (same alpha for whole batch) to sample from the output distribution produced by the net.b) during discrete time loss computation when alpha are different for samples in the batch.返回指定精度 \alpha 下的输入分布. """pass@abstractmethoddef update_input_params(self, input_params: tuple[Tensor, ...], y: Tensor, alpha: float) -> tuple[Tensor, ...]:"""Updates the distribution parameters using Bayes' theorem in light of noisy sample y.Used during sampling when alpha is the same for the whole batch.根据贝叶斯定理利用观测样本 y 计算后验, 从而更新先验. """pass@abstractmethoddef forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]:"""Returns a sample from the Bayesian Flow distribution over input parameters at time t conditioned on data.Used during training when t (and thus accuracies) are different for different samples in the batch.For discrete data, the returned tuple has length 1 and contains the class probabilities.For continuous data, the returned tuple has length 2 and contains the mean and precision.从贝叶斯流分布中采样得到后验, 代表对输入分布参数的更新. """pass

由上可以看到,作者将单步更新的贝叶斯更新函数也封装在了 BayesianFlow 这个类里,可能考虑到两者的目标都一致吧(都是计算后验、更新先验)。

建模连续和离散化数据的贝叶斯流被作者实现为 CtsBayesianFlow 类,因为离散化数据就是由连续数据经过离散化操作而得到的,所以两者共用一套逻辑;而建模离散数据的贝叶斯流则实现为 DiscreteBayesianFlow 类。

接下来,我们就分别深入到两者的内部去一探究竟吧~!

建模连续和离散化数据

在建模连续和离散化数据时,贝叶斯流分布为:

class CtsBayesianFlow(BayesianFlow):"""建模连续/离散化数据的贝叶斯流."""def __init__(self,min_variance: float = 1e-6,):super().__init__()self.min_variance = min_variance@torch.no_grad()def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]:"""返回贝叶斯流分布的采样结果, 即经过后验更新的输入分布的均值向量: \mu."""# \omega_1^{2t}post_var = torch.pow(self.min_variance, t)# \gamma(t)alpha_t = 1 - post_var# \gamma(t)(1-\gamma(t))mean_var = alpha_t * post_var# 贝叶斯流分布的均值: \gamma(t)xmean_mean = alpha_t * data# 贝叶斯流分布的标准差: \sqrt{\gamma(t)(1-\gamma(t))}mean_std_dev = mean_var.sqrt()# 标准高斯噪声noise = torch.randn(mean_mean.shape, device=mean_mean.device)# 利用重参数化技术构造贝叶斯流分布的样本mean = mean_mean + (mean_std_dev * noise)# We don't need to compute the variance because it is not needed by the network, so set it to Noneinput_params = (mean, None)return input_params

另外,以上并非直接从贝叶斯流分布中进行采样,而是使用了重参数化技术——先从标准正态分布中采样出高斯噪声,然后再通过 scale & shift 获得目标分布的采样结果,即:

这个过程对应于以下代码中的 update_input_params() 方法。

至于其它就比较琐碎且简单了,各位客官自行看下面代码即可:

def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:# 仅取输入分布的均值向量作为 BFN 的输入# Only the mean is used by the networkreturn params[0]def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, float]:# 起始时刻的先验是标准高斯分布, 均值为0, 方差为1(协方差矩阵是对角元均为1的对角阵)return torch.zeros(*data_shape, device=device), 1.0def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:# 根据 \beta(t_i) - \beta(t_{i-1}) 计算, 其中 t_i = \frac{i}{n}.sigma_1 = math.sqrt(self.min_variance)return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps))def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:# 返回输入分布, 精度 \alpha 是方差的倒数.dist = D.Normal(x, 1.0 / alpha**0.5)return distdef update_input_params(self, input_params: tuple[Tensor, float], y: Tensor, alpha: float) -> tuple[Tensor, float]:"""贝叶斯更新函数, 对输入分布的参数进行后验更新."""input_mean, input_precision = input_params# \rho_i = \rho_{i-1} + \alphanew_precision = input_precision + alpha# 根据贝叶斯定理计算: \mu_i = \frac{ \rho_{i-1} \mu_{i-1} + \alpha y }{\rho_i}new_mean = ((input_precision * input_mean) + (alpha * y)) / new_precisionreturn new_mean, new_precision

建模离散数据

与连续和离散化数据的场景不同,当建模离散数据时,贝叶斯流分布则为:

class DiscreteBayesianFlow(BayesianFlow):def __init__(self,n_classes: int,min_sqrt_beta: float = 1e-10,discretize: bool = False,epsilon: float = 1e-6,max_sqrt_beta: float = 1,):super().__init__()# Kself.n_classes = n_classes# 一个极小值, 用于将传入贝叶斯流分布的时间变量最大值限制至 1-epsilon.# 因为贝叶斯流分布是用于最终时刻前的, 所以需要 t < 1.self.epsilon = epsilon# 是否进行离散化操作self.discretize = discretize# \sqrt{\beta} 的下限self.min_sqrt_beta = min_sqrt_beta# \sqrt{\beta(1)}self.max_sqrt_beta = max_sqrt_beta# 均匀分布的期望熵: H = - \sum_{i=1}^K{p(x_i)ln(p(x_i))}, p(x_i)=\frac{1}{K}self.uniform_entropy = math.log(self.n_classes)@torch.no_grad()def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]:"""根据贝叶斯流分布完成后验更新."""if self.discretize:# 若要进行离散化操作, 则将数据以对应的离散化区间索引表示.data = float_to_idx(data, self.n_classes)# \sqrt{\beta(t)}sqrt_beta = self.t_to_sqrt_beta(t.clamp(max=1 - self.epsilon))lo_beta = sqrt_beta < self.min_sqrt_betasqrt_beta = sqrt_beta.clamp(min=self.min_sqrt_beta)# \beta(t)beta = sqrt_beta.square().unsqueeze(-1)# 从精度参数为 \beta(t) 的发送者分布中采样观测样本以作为贝叶斯流分布的 logits.logits = self.count_sample(data, beta)probs = F.softmax(logits, -1)# 将精度太小的部分所对应的后验以均匀先验 \frac{1}{K} 代替.# 这是因为精度太小, 那么对应的观测样本也"不靠谱"——所包含真实数据的信息太少,# 将其作为 logits 就不靠谱, 即以此为根据而实现的后验更新意义不大.probs = torch.where(lo_beta.unsqueeze(-1), torch.ones_like(probs) / self.n_classes, probs)if self.n_classes == 2:# 如果是二分类则只取其中一类的概率即可.probs = probs[..., :1]probs = probs.reshape_as(data)input_params = (probs,)return input_paramsdef t_to_sqrt_beta(self, t):"""计算当前时刻的 accuracy schedule: \beta(t) 的开根:sqrt{\beta(t)} = t \sqrt{\beta(1)}."""return t * self.max_sqrt_betadef count_dist(self, x, beta=None) -> D.Distribution:"""贝叶斯流分布中的期望部分所对应的发送者分布."""# Ke_x - 1mean = (self.n_classes * F.one_hot(x.long(), self.n_classes)) - 1# \sqrt{K}std_dev = math.sqrt(self.n_classes)if beta is not None:# \beta(t)(Ke_x - 1)mean = mean * beta# \sqrt{\beta(t)K}std_dev = std_dev * beta.sqrt()return D.Normal(mean, std_dev, validate_args=False)def count_sample(self, x, beta):"""利用重参数化采样技术(rsample())采样出观测样本作为贝叶斯流分布的 logits 源(下一步将其输入 softmax 以实现后验更新)."""return self.count_dist(x, beta).rsample()

利用贝叶斯流分布更新先验的整个过程即以上的前向过程 forward(),其中代码对应的释义 CW 都已详细注解。在上面的代码实现中,需要注意的细节有几个:

关于以上最后一点,通常使用重参数化采样是因为要使得梯度流能通过要学习的参数,但是以上这部分却没有需要学习的参数,之所以还这样做可能是考虑到在高维空间中从标准高斯分布中采样会相对高效。另外,对数据进行离散操作的 float_to_idx() 方法会在后文“数据加载与预处理”那章进行解析。

与前面一节建模连续和离散化数据时一样,这个 DiscreteBayesianFlow 类还封装了许多有用的方法,比如:

  • 根据贝叶斯定理来计算后验的贝叶斯更新函数

@torch.no_grad()def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor]:"""初始先验: 各类别概率相等的均匀分布 U{1, K}."""# 注意返回的是元组, 这是为了与连续/离散化数据的场景保持一致性.return (torch.ones(*data_shape, self.n_classes, device=device) / self.n_classes,)@torch.no_grad()def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor:params = params[0]if self.n_classes == 2:# 作者使用的 MNIST 数据集是经过二值化处理的, 因此这部分针对 MNIST 操作,# 将模型输入的范围缩放至 [-1,1]params = params * 2 - 1  # We scale-shift here for MNIST instead of in the network like for text# 因为总共只有两个类别, 所以取其中一类所对应的概率即可.params = params[..., :1]return paramsdef get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]:# 计算离散时间步所对应的精度: \alpha_i = \beta(1) \frac{2i-1}{n^2}return ((self.max_sqrt_beta / n_steps) ** 2) * (2 * i - 1)def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution:e_x = F.one_hot(x.long(), self.n_classes)alpha = alpha.unsqueeze(-1) if isinstance(alpha, Tensor) else alphadist = D.Normal(alpha * ((self.n_classes * e_x) - 1), (self.n_classes * alpha) ** 0.5)return distdef update_input_params(self, input_params: tuple[Tensor], y: Tensor, alpha: float) -> tuple[Tensor]:"""贝叶斯更新函数: 利用贝叶斯定理计算后验."""new_input_params = input_params[0] * y.exp()new_input_params /= new_input_params.sum(-1, keepdims=True)# 注意返回的是元组return (new_input_params,)

四、损失函数的实现

作者将损失函数封装在了一个抽象基类 Loss 里,其中包含了三种具体的 loss 计算:连续时间下发送者分布和接收者分布的 KL loss、离散时间下发送者分布和接收者分布的 KL loss 以及 实际不参与训练的重构 loss。

无论是针对 连续、离散化 亦或是 离散数据的损失函数,都要继承这个基类,并实现以上三种 loss 计算的逻辑。

class Loss(nn.Module, ABC):def __init__(self):super().__init__()@abstractmethoddef cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor) -> Tensor:"""Returns the continuous time KL loss (and any other losses) at time t (between 0 and 1).The input params are only used when the network is parameterized to predict the noise for continuous data.连续时间的损失函数. """pass@abstractmethoddef discrete_time_loss(self, data: Tensor,output_params: Tensor, input_params: Tensor,t: Tensor, n_steps: int, n_samples: int = 20) -> Tensor:"""Returns the discrete time KL loss for n_steps total of communication at time t (between 0 and 1) usingn_samples for Monte Carlo estimation of the discrete loss.The input params are only used when the network is parameterized to predict the noise for continuous data.离散时间的损失函数, 当所需计算的 KL 散度没有解析形式时, 使用蒙特卡洛方法来近似估计. """pass@abstractmethoddef reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:"""Returns the reconstruction loss, i.e. the final cost of transmitting clean data.The input params are only used when the network is parameterized to predict the noise for continuous data.重构损失, 不参与训练. """pass

连续和离散化数据的损失函数

连续和离散化数据的 loss 计算共用一套逻辑,被封装在 CtsBayesianFlowLoss 这个类中。

def sandwich(x: Tensor):return x.reshape(x.size(0), -1, x.size(-1))class CtsBayesianFlowLoss(Loss):"""建模连续/离散化数据场景时所用的损失函数, 包括:-离散时间损失函数;-连续时间损失函数;-重构损失"""def __init__(self,bayesian_flow: CtsBayesianFlow,distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],min_loss_variance: float = -1,noise_pred: bool = True,):super().__init__()self.bayesian_flow = bayesian_flow# 返回输出分布的工厂对象self.distribution_factory = distribution_factory# \sigma_1^{2} 的下限, 以防用作分母时溢出.self.min_loss_variance = min_loss_variance# -ln(\sigma_1)self.C = -0.5 * math.log(bayesian_flow.min_variance)# 是否预测噪声(亦或是直接预测数据)self.noise_pred = noise_predif self.noise_pred:self.distribution_factory.log_dev = False# 在预测噪声的情况下, 将预测的噪声(或噪声分布相关的参数)转换为对应数据分布(输出分布)的参数.self.distribution_factory = PredDistToDataDistFactory(self.distribution_factory, self.bayesian_flow.min_variance)

CW 在本系列的​​第二​​​、​​三篇​​文章中解析过,在建模连续和离散化数据时,模型预测的分别是高斯噪声 和 高斯噪声分布相关的参数:均值  和 对数标准差 ,因此需要对模型的输出进行一些后处理,以便将其转换为目标数据分布(输出分布)所对应的参数。

  • 连续时间的损失函数

def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:# 模型输出# reshape 成3维:(B, -1, D)output_params = sandwich(output_params)t = t.flatten(start_dim=1).float()flat_target = data.flatten(start_dim=1)# \sigma_1^{2t}posterior_var = torch.pow(self.bayesian_flow.min_variance, t)if self.min_loss_variance > 0:# 做最小值截断, 以防其作分母时防止溢出posterior_var = posterior_var.clamp(min=self.min_loss_variance)# 输出分布pred_dist = self.distribution_factory.get_dist(output_params, input_params, t)# 输出分布的均值 E[P(\theta, t)]pred_mean = pred_dist.meanmse_loss = (pred_mean - flat_target).square()# 连续时间的损失函数计算公式: -ln(\sigma_1) \sigma_1{-2t} || x - E[P(\theta, t)] ||^2loss = self.C * mse_loss / posterior_varreturn loss
  • 离散时间的损失函数

在离散时间的条件下,连续(continuous)数据的 KL loss 依然是 mse,只不过 scale 系数相比于连续时间的情况有些不同;而离散化(discretized)数据的就稍微复杂些了,由于没有解析形式,因此需要使用蒙特卡洛方法从发送者分布中进行采样去近似估计 KL 散度:

def discrete_time_loss(self, data: Tensor,output_params: Tensor, input_params: Tensor,t: Tensor, n_steps: int, n_samples=10) -> Tensor:# (B,-1,D)output_params = sandwich(output_params)t = t.flatten(start_dim=1).float()output_dist = self.distribution_factory.get_dist(output_params, input_params, t)# 离散化数据的场景if hasattr(output_dist, "probs"):  # output distribution is discretized normalt = t.flatten(start_dim=1)i = t * n_steps + 1  # since t = (i - 1) / nalpha = self.bayesian_flow.get_alpha(i, n_steps)flat_target = data.flatten(start_dim=1)# 发送者分布sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)# 因为使用蒙特卡洛方法来估计发送者分布与接收者分布之间的 KL 散度,所以要从发送者分布中采样观测样本 y,# 采样的样本数默认为10.y = sender_dist.sample(torch.Size([n_samples]))# 模型输出的分配到各离散化区间的概率值. #(B,D,K)receiver_mix_wts = sandwich(output_dist.probs)# 输出分布是类别分布, 在每个离散化区间都分配一定概率.= D.Categorical(probs=receiver_mix_wts, validate_args=False)# 以各离散化区间的中心为均值构造多个一维高斯分布,其中每个都与发送者分布的形式一致(噪声强度相等, 即方差一致).\receiver_components = D.Normal(output_dist.class_centres, (1.0 / alpha.sqrt()).unsqueeze(-1), validate_args=False)# 接收者分布, 在数据的每个维度上都是混合高斯分布.receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components, validate_args=False)# (B,1)loss = ((sender_dist.log_prob(y) - receiver_dist.log_prob(y))  # 发送者分布和接收者分布的概率密度对数差.mean(0)  # 在蒙特卡洛采样的样本数上做平均.flatten(start_dim=1).mean(1, keepdims=True))# 连续数据的场景else:  # output distribution is normalpred_mean = output_dist.meanflat_target = data.flatten(start_dim=1)mse_loss = (pred_mean - flat_target).square()i = t * n_steps + 1alpha = self.bayesian_flow.get_alpha(i, n_steps)loss = alpha * mse_loss / 2return n_steps * loss

代码咋一看有些复杂,但认真看后实际上你会发现还好,和论文中的公式是能够完美对上的,至于公式推导的细节,请参考本系列​​第二​​​、​​三篇​​文章。

在实现时需要特别注意的是在建模离散化数据时,输出分布在每一维上都是混合高斯分布,以上使用了 torch.distributions.MixtureSameFamily 来实现,其中每个子高斯分布 receiver_components 的权重就对应输出分布在每个离散化区间上所分配的概率,注意要用这批概率去实例化一个类别分布(torch.distributions.Categorical)对象 receiver_mix_dist 并传到 MixtureSameFamily 中,并且子高斯分布 receiver_components 的个数要和类别分布 receiver_mix_dist 的类别数一致。

  • 重构损失

def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:output_params = sandwich(output_params)flat_data = data.flatten(start_dim=1)# 重构损失只发生在最后时刻,于是 t=1.t = torch.ones_like(data).flatten(start_dim=1).float()output_dist = self.distribution_factory.get_dist(output_params, input_params, t)if hasattr(output_dist, "probs"):  # output distribution is discretized normalreconstruction_loss = -output_dist.log_prob(flat_data)else:  # output distribution is normal, but we use discretized normal to make results comparable (see Sec. 7.2)if self.bayesian_flow.min_variance == 1e-3:  # used for 16 bin CIFAR10noise_dev = 0.7 * math.sqrt(self.bayesian_flow.min_variance)num_bins = 16else:noise_dev = math.sqrt(self.bayesian_flow.min_variance)num_bins = 256mean = output_dist.mean.flatten(start_dim=1)final_dist = D.Normal(mean, noise_dev)# 离散化的正态分布final_dist = DiscretizedCtsDistribution(final_dist, num_bins, device=t.device, batch_dims=mean.ndim - 1)reconstruction_loss = -final_dist.log_prob(flat_data)return reconstruction_loss

离散数据的损失函数

针对离散数据的损失函数被封装在 DiscreteBayesianFlowLoss 这个类里,它也是 Loss 的子类,同样需要实现以上提到的三项 loss 的计算逻辑。

class DiscreteBayesianFlowLoss(Loss):def __init__(self,bayesian_flow: DiscreteBayesianFlow,distribution_factory: DiscreteDistributionFactory,):super().__init__()self.bayesian_flow = bayesian_flowself.distribution_factory = distribution_factory# 离散数据的输出分布建模为类别分布,这个变量就代表类别数量.self.K = self.bayesian_flow.n_classes
  • 连续时间的损失函数

建模离散数据时,连续时间的 KL loss 为:

(具体推导过程请见本系列​​第四篇文章​​)

其中作差的两项都是 one-hot 形式。

def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor:flat_output = sandwich(output_params)# 输出分布在各类别上分配的概率pred_probs = self.distribution_factory.get_dist(flat_output).probsflat_target = data.flatten(start_dim=1)if self.bayesian_flow.discretize:flat_target = float_to_idx(flat_target, self.K)tgt_mean = torch.nn.functional.one_hot(flat_target.long(), self.K)kl = self.K * ((tgt_mean - pred_probs).square()).sum(-1)t = t.flatten(start_dim=1).float()loss = t * (self.bayesian_flow.max_sqrt_beta**2) * klreturn loss
  • 离散时间的损失函数

(具体推导过程详见本系列​​第四篇文章​​)

def discrete_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10) -> Tensor:flat_target = data.flatten(start_dim=1)if self.bayesian_flow.discretize:flat_target = float_to_idx(flat_target, self.K)# 根据 t = \frac{i-1}{n} 反过来计算 i i = t * n_steps + 1# \alpha_ialpha = self.bayesian_flow.get_alpha(i, n_steps).flatten(start_dim=1)# (B,D,K)flat_output = sandwich(output_params)# 模型预测的在各个类别上的概率.receiver_mix_wts = self.distribution_factory.get_dist(flat_output).probs# 这里之所以要在倒数第2个维度上加一维是因为以下 components 在每个类别上的均值向量都是 K 维 one-hot,# 从而在每个类别上生成的是 K 个相互独立的正态分布. 总共有 K 类, 于是就有 K x K 个分布.# 因此这里增加维度是为了让 categorical 权重 与 components 对齐.receiver_mix_dist = D.Categorical(probs=receiver_mix_wts.unsqueeze(-2))# 增加2个维度是为了对应 batch dim: B 和 data dim: D.classes = torch.arange(self.K, device=flat_target.device).long().unsqueeze(0).unsqueeze(0)receiver_components = self.bayesian_flow.get_sender_dist(classes, alpha.unsqueeze(-1))# 接收者分布, 它是多个混合高斯分布的联合分布, 其中每个数据维度都是混合高斯分布.receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components)sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha)# 从发送者分布中采样, 以蒙特卡洛方法近似估计其与接收者分布之间的 KL lossy = sender_dist.sample(torch.Size([n_samples]))# (B,1)loss = n_steps * (sender_dist.log_prob(y) - receiver_dist.log_prob(y)).mean(0).sum(-1).mean(1, keepdims=True)return loss

在实现时,需要特别注意的是在构造混合高斯分布(也就是接收者分布)时,传到 receiver_mix_dist 中 receiver_mix_wts 的 shape 和 传给 receiver_components 的 classes 的 shape,具体细节 CW 都在以上做了注解,这里就不再重复阐述。

  • 重构损失

def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor:flat_outputs = sandwich(output_params)flat_data = data.flatten(start_dim=1)output_dist = self.distribution_factory.get_dist(flat_outputs)return -output_dist.log_prob(flat_data)

回顾论文,你会发现重构损失是在贝叶斯流分布的期望下计算的:

五、输出分布

在前面的代码中我们经常看到 distribution_factory 这个工厂对象的出现,它的大招就是返回输出分布,不同类型的输出分布会由对应类型的工厂对象返回。经过本系列前面几篇的理论铺垫,我们知道建模不同类型的数据所用的输出分布类型也是不同的,那么这一章就一起来扒扒这些分布的代码实现。

连续型与离散型分布

首先,不论是哪种分布,都可被归类为连续型分布或离散型分布,作者分别用了两个类来表示,它们作为其余分布的基类,里面包含了作为一个分布理应具备的一些属性与方法。

CONST_log_min = 1e-10def safe_log(data: Tensor):return data.clamp(min=CONST_log_min).log()class CtsDistribution:@abstractmethoddef log_prob(self, x):pass@abstractmethoddef sample(self):passclass DiscreteDistribution:@property@abstractmethoddef probs(self):pass@functools.cached_propertydef log_probs(self):return safe_log(self.probs)@functools.cached_propertydef mean(self):pass@functools.cached_propertydef mode(self):pass@abstractmethoddef log_prob(self, x):pass@abstractmethoddef sample(self):pass

离散化分布

另外还有一种比较基本的分布就是离散化(discretized)分布了,它代表将一个连续型分布离散化为离散型分布,于是它最终是离散型表示,从而继承了离散型分布 DiscreteDistribution。

class DiscretizedDistribution(DiscreteDistribution):def __init__(self, num_bins, device):# 离散区间数量: Kself.num_bins = num_bins# 原数据取值范围是[-1,1], 如今划分为 K 个区间, 因此每个区间宽度是 2/K.self.bin_width = 2.0 / num_binsself.half_bin_width = self.bin_width / 2.0self.device = device@functools.cached_propertydef class_centres(self):# 类别中心的取值范围: [-1 + 1/K, 1 - 1/K]return torch.arange(self.half_bin_width - 1, 1, self.bin_width, device=self.device)@functools.cached_propertydef class_boundaries(self):# 各类别之间的边界: [-1 + 2/K, 1 - 2/K], 共 K-1 个.return torch.arange(self.bin_width - 1, 1 - self.half_bin_width, self.bin_width, device=self.device)@functools.cached_propertydef mean(self):# 将各类别中心用它们各自所对应的概率加权求和: \sum_{k=1}^K{p_k * k_c}return (self.probs * self.class_centres).sum(-1)@functools.cached_propertydef mode(self):"""概率分布的 mode, 代表众数, 即概率最高处所对应的样本."""# 因为 class_centres 是1维的, 所以这里需要将索引展平.mode_idx = self.probs.argmax(-1).flatten()return self.class_centres[mode_idx].reshape(self.probs.shape[:-1])

上面那个类是离散化分布的父类,而要对一个连续型分布实现离散化,那么你得将它作为参数接收进来然后进行处理,于是就有了下面这个子类,它继承了上面那家伙。

class DiscretizedCtsDistribution(DiscretizedDistribution):"""将一个连续型分布离散化."""def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5):super().__init__(num_bins, device)# 原来的连续型分布, 要对其进行离散化处理.self.cts_dist = cts_dist# log(2/K)self.log_bin_width = log(self.bin_width)# Bself.batch_dims = batch_dims# 是否要对原来连续型分布的 CDF 做截断.self.clip = clip# 用作概率的极小值self.min_prob = min_prob@functools.cached_propertydef probs(self):"""计算数据位于各离散区间的概率."""# shape: [K-1] + [1] * Bbdry_cdfs = self.cts_dist.cdf(self.class_boundaries.reshape([-1] + ([1] * self.batch_dims)))# shape: [1] + [1] * Bbdry_slice = bdry_cdfs[:1]if self.clip:'''对原来连续型分布的 CDF 做截断: 小于第一个区间的左端概率置0、小于等于最后一个区间右端的概率置1.'''cdf_min = torch.zeros_like(bdry_slice)cdf_max = torch.ones_like(bdry_slice)# shape: [K+1] + [1] * Bbdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)# 利用 CDF(k_r) - CDF(k_l) 得到位于各区间的概率.# shape: [1] * B + [K]return (bdry_cdfs[1:] - bdry_cdfs[:-1]).moveaxis(0, -1)else:'''以条件概率的思想来计算数据位于各区间的概率,其中的条件就是数据位于 [-1,1] 取值范围内.先计算原连续型分布在 1 和 -1 处的 CDF 值,将两者作差从而得到位于 [-1,1] 内的概率,以此作为条件对各区间的概率进行缩放.'''# CDF(-1)cdf_min = self.cts_dist.cdf(torch.zeros_like(bdry_slice) - 1)# CDF(1)cdf_max = self.cts_dist.cdf(torch.ones_like(bdry_slice))# shape: [K+1] + [1] * Bbdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0)# p_{-1 < x <= 1}cdf_range = cdf_max - cdf_mincdf_mask = cdf_range < self.min_prob# 当 cdf_range 小于就以 1 代替, 避免作为分母时造成结果溢出.cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)# shape: [K] + [1] * Bprobs = (bdry_cdfs[1:] - bdry_cdfs[:-1]) / cdf_range# 若整个 cdf_range 太小, 说明各区间的概率差异微不足道, 因此干脆将每个区间的概率都用 1/K 即均等的概率代替.probs = torch.where(cdf_mask, (probs * 0) + (1 / self.num_bins), probs)# shape: [1] * B + [K]return probs.moveaxis(0, -1)def prob(self, x):# 区间索引 k \in [0, K-1]class_idx = float_to_idx(x, self.num_bins)# 区间中心 k_ccentre = idx_to_float(class_idx, self.num_bins)# CDF(k_l), 其中 k_l 代表区间左端点.cdf_lo = self.cts_dist.cdf(centre - self.half_bin_width)# CDF(k_r), 其中 k_r 代表区间右端点.cdf_hi = self.cts_dist.cdf(centre + self.half_bin_width)if self.clip:'''对原来连续型分布的 CDF 做截断, 使得:CDF(k <= 0) = 0;CDF(k >= K-1) = 1'''cdf_lo = torch.where(class_idx <= 0, torch.zeros_like(centre), cdf_lo)cdf_hi = torch.where(class_idx >= (self.num_bins - 1), torch.ones_like(centre), cdf_hi)return cdf_hi - cdf_loelse:'''以条件概率的思想来计算数据位于某个离散区间内的概率,其中的条件就是数据位于 [-1,1] 取值范围内.先计算原连续型分布在 1 和 -1 处的 CDF 值,将两者作差从而得到位于 [-1,1] 内的概率,以此作为条件对区间的概率进行缩放.'''cdf_min = self.cts_dist.cdf(torch.zeros_like(centre) - 1)cdf_max = self.cts_dist.cdf(torch.ones_like(centre))cdf_range = cdf_max - cdf_min# 若 cdf_range 太小,则设置 mask,并将其以1代替,即不对区间的概率进行缩放, 否则会使得计算出来的采样概率非常接近于1.# 两个非常小的值相除, 由于它们都很小、非常接近,因此商接近于1.cdf_mask = cdf_range < self.min_probcdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range)prob = (cdf_hi - cdf_lo) / cdf_range# 若整个 cdf_range 太小, 说明各区间的概率差异微不足道, 因此干脆将区间的概率都用 1/K 即均等的概率代替.return torch.where(cdf_mask, (prob * 0) + (1 / self.num_bins), prob)def log_prob(self, x):prob = self.prob(x)return torch.where(prob < self.min_prob,# 将 x 以对应区间的中点 k_c 表示并计算出其在原来连续分布中的对数概率密度: log(p(k_c)).# 这里加上 log(2/K) 相当于将 k_c 乘以 2/K 再取对数.self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width,safe_log(prob),)def sample(self, sample_shape=torch.Size([])):if self.clip:# 直接从原来的连续型分布中采样, 然后将其量化至对应的离散化区间.# 此处, clip 的意思是:# 若小于第一个区间,则以第一个区间中点表示;# 同理,若大于最后一个区间,则以最后一个区间的中点表示.return quantize(self.cts_dist.sample(sample_shape), self.num_bins)else:# 要求原来连续型分布的 CDF 存在反函数, 即可以根据概率值逆向求出对应的样本.assert hasattr(self.cts_dist, "icdf")# 数据的取值范围是 [-1,1], 先根据原来的连续型分布计算出 CDF(-1) 和 CDF(1),# 然后利用 CDF 的反函数仅在这个 range 内考虑采样.cdf_min = self.cts_dist.cdf(torch.zeros_like(self.cts_dist.mean) - 1)cdf_max = self.cts_dist.cdf(torch.ones_like(cdf_min))# 由于 CDF 是服从均匀分布的, 因此从均匀分布中采样出 CDF 值并利用反函数求出对应样本就等价于从目标分布中采样.u = Uniform(cdf_min, cdf_max, validate_args=False).sample(sample_shape)cts_samp = self.cts_dist.icdf(u)# 最后将样本量化至对应的离散化区间.# 注意, 与前面 clip 的方式不同, 此处在量化前样本已经处于有效的离散化区间内了, 因为采样区间是在[-1,1]内考虑的.return quantize(cts_samp, self.num_bins)

别被吓倒,虽然代码看起来很复杂,但 CW 已经在上面做了详细的注解,结合对数据进行离散化的知识,理解上面的代码应该是 NO problem 的!

(关于离散化操作的知识背景,可参考本系列​​第三篇​​文章)

Discretized Normal Distribution

在建模离散化数据时,输出分布是离散的正态分布:

以下就是这个离散正态分布的实现,它继承了前面展示的 DiscretizedCtsDistributon 类。得益于它的爸爸(帮它将核心功能都实现完了),你可以看到这个类非常躺平(实现得非常简单)~

CONST_exp_range = 10def safe_exp(data: Tensor):return data.clamp(min=-CONST_exp_range, max=CONST_exp_range).exp()class DiscretizedNormal(DiscretizedCtsDistribution):def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):assert params.size(-1) == 2if min_std_dev < 0:min_std_dev = 1.0 / (num_bins * 5)mean, std_dev = params.split(1, -1)[:2]if log_dev:# 若传入的是对数标准差, 那么此处就需要取自然指数进行还原.std_dev = safe_exp(std_dev)std_dev = std_dev.clamp(min=min_std_dev, max=max_std_dev)super().__init__(cts_dist=Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False),num_bins=num_bins,device=params.device,# 注意所谓的 batch dims 并非指数据的 batch size,# 而是除离散化区间数量以外与分布本身关系不大的其它维度.batch_dims=params.ndim - 1,clip=clip,min_prob=min_prob,)

Delta Distribution

建模连续数据时,输出分布为 Delta 分布,它像是个“山寨分布”似的,只有一个单点,实现起来比起前面那位离散的正态分布,有过之而无不及,人家靠的是爸爸才躺平,而它仅靠自己也很躺..

class DeltaDistribution(CtsDistribution):def __init__(self, mean, clip_range=1.0):if clip_range > 0:mean = mean.clip(min=-clip_range, max=clip_range)self.mean = meandef mode(self):return self.meandef mean(self):return self.meandef sample(self, sample_shape=torch.Size([])):return self.mean

既然你这么躺那么我 CW 也小躺一下——注解我就不做了(傲娇脸)。

Bernoulli Distribution

对于二值的离散数据,输出分布为伯努利分布,作者在 MNIST 的实验中就是这么玩的,他对 MNIST 数据集进行了二值化处理(关于二值化的实现 CW 在后文会详细解析),使其变身为动态二值化(dynamically binarized)的 MNIST。

from torch.distributions.bernoulli import Bernoulli as torch_Bernoulliclass Bernoulli(DiscreteDistribution):def __init__(self, logits):self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False)@functools.cached_propertydef probs(self):p = self.bernoulli.probs.unsqueeze(-1)return torch.cat([1 - p, p], -1)@functools.cached_propertydef mode(self):return self.bernoulli.modedef log_prob(self, x):return self.bernoulli.log_prob(x.float())def sample(self, sample_shape=torch.Size([])):return self.bernoulli.sample(sample_shape)

从上面的代码可以看出,这个分布实现起来也非常 easy,主要靠的是 Pytorch 内置的实现,只不过 torch 的实现在返回 probs 时仅返回了单类概率,作者则额外补充了剩下一类的概率并将两者拼接起来返回。

Categorical Distribution

当面对多类别的离散数据时,输出分布理应就是类别分布了,作者对于 text8 数据集的建模就采取了这种玩法。

from torch.distributions.categorical import Categorical as torch_Categoricalclass Categorical(DiscreteDistribution):def __init__(self, logits):self.categorical = torch_Categorical(logits=logits, validate_args=False)self.n_classes = logits.size(-1)@functools.cached_propertydef probs(self):return self.categorical.probs@functools.cached_propertydef mode(self):return self.categorical.modedef log_prob(self, x):return self.categorical.log_prob(x)def sample(self, sample_shape=torch.Size([])):return self.categorical.sample(sample_shape)

这里的实现完全依赖了 Pytorch 内置的实现,只不过额外记录了类别数 n_classes 这一属性。

好家伙,在建模离散数据时居然靠着开源的力量躺平~!

Distribution Factory

前面提到,输出分布是由对应的工厂对象制作出来的,以下就是各种工厂类的实现,它们都分别继承连续型分布的工厂类或离散型分布的工厂类,这两个类都是抽象基类,定义了子类必须实现的方法 get_dist() —— 返回一个指定参数的分布。

class CtsDistributionFactory:@abstractmethoddef get_dist(self, params: torch.Tensor, input_params=None, t=None) -> CtsDistribution:"""Note: input_params and t are not used but kept here to be consistency with DiscreteDistributionFactory."""passclass DiscreteDistributionFactory:@abstractmethoddef get_dist(self, params: torch.Tensor, input_params=None, t=None) -> DiscreteDistribution:"""Note: input_params and t are only required by PredDistToDataDistFactory."""passclass DiscretizedNormalFactory(DiscreteDistributionFactory):def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True):self.num_bins = num_binsself.clip = clipself.min_std_dev = min_std_devself.max_std_dev = max_std_devself.min_prob = min_probself.log_dev = log_devdef get_dist(self, params, input_params=None, t=None):return DiscretizedNormal(params,num_bins=self.num_bins,clip=self.clip,min_std_dev=self.min_std_dev,max_std_dev=self.max_std_dev,min_prob=self.min_prob,log_dev=self.log_dev,)class DeltaFactory(CtsDistributionFactory):def __init__(self, clip_range=1.0):self.clip_range = clip_rangedef get_dist(self, params, input_params=None, t=None):return DeltaDistribution(params.squeeze(-1), self.clip_range)class BernoulliFactory(DiscreteDistributionFactory):def get_dist(self, params, input_params=None, t=None):return Bernoulli(logits=params.squeeze(-1))class CategoricalFactory(DiscreteDistributionFactory):def get_dist(self, params, input_params=None, t=None):return Categorical(logits=params)
  • 将噪声分布转换为数据分布

前面在解析连续和离散化数据的损失函数 CtsBayesianFlowLoss 时,有以下这样一段代码:

class CtsBayesianFlowLoss(Loss):"""建模连续/离散化数据场景时所用的损失函数, 包括:-离散时间损失函数;-连续时间损失函数;-重构损失"""def __init__(self,bayesian_flow: CtsBayesianFlow,distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory],min_loss_variance: float = -1,noise_pred: bool = True,):  ...  # 此处省略一大段# 是否预测噪声(亦或是直接预测数据)self.noise_pred = noise_predif self.noise_pred:self.distribution_factory.log_dev = False# 在预测噪声的情况下, 将预测的噪声(或噪声分布相关的参数)转换为对应数据分布(输出分布)的参数.self.distribution_factory = PredDistToDataDistFactory(self.distribution_factory, self.bayesian_flow.min_variance)

也就是说,当模型输出(预测)的是噪声变量或噪声分布的参数时,需要将其转换为对应生成的目标数据或目标数据分布(输出分布)所对应的参数。 而前面已经说过,目标数据分布都是由对应的工厂类制造的,于是这个转换过程就由工厂类去实现,这个工厂类就是 PredDistToDataDistFactory。

class PredDistToDataDistFactory(DiscreteDistributionFactory):def __init__(self, data_dist_factory, min_variance, min_t=1e-6):self.data_dist_factory = data_dist_factory# 之所以设为 False 是因为在以下 noise_pred_params_to_data_pred_params() 方法中会将对数标准差使用自然指数进行转换,# 而无需原数据分布的工厂自行转换.self.data_dist_factory.log_dev = Falseself.min_variance = min_varianceself.min_t = min_tdef get_dist(self, params, input_params, t):data_pred_params = noise_pred_params_to_data_pred_params(params, input_params[0], t, self.min_variance, self.min_t)return self.data_dist_factory.get_dist(data_pred_params)

可以看到,它将目标数据分布的工厂对象 data_dist_factory 作为属性,并依靠后者来返回目标数据分布,只不过预先调用了一个将噪声分布相关的参数转换为数据分布相关参数的方法 noise_pred_params_to_data_pred_params()。既然如此,我们就顺藤摸瓜地深入到这个方法中去寻找真理叭~

哦,在看代码之前,还是一起先来回顾下在建模连续和离散化数据时噪声(分布)转换为目标数据分布(即输出分布)的公式,以便和接下来的代码进行对照。

在建模连续数据时,模型预测(输出)的是单个噪声变量(其实是噪声分布的均值向量),对应转换成单点的输出分布(Delta 分布):

(对于上面公式的具体推导过程可参考本系列第二、三篇文章)     

def noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tensor, input_mean: torch.Tensor,t: torch.Tensor, min_variance: float, min_t=1e-6
):"""Convert output parameters that predict the noise added to data, to parameters that predict the data.将模型预测的噪声分布的参数转换为数据分布的参数."""# (B,L,D)data_shape = list(noise_pred_params.shape)[:-1]# (B,L*D,NP), NP: num parameters per datanoise_pred_params = sandwich(noise_pred_params)# (B,L*D)input_mean = input_mean.flatten(start_dim=1)if torch.is_tensor(t):t = t.flatten(start_dim=1)else:t = (input_mean * 0) + t# (B,L*D,1)alpha_mask = (t < min_t).unsqueeze(-1)# \sigma_1^{2t}posterior_var = torch.pow(min_variance, t.clamp(min=min_t))# \gamma(t) = 1 - \sigma_1^{2t}gamma = 1 - posterior_var# \frac{\mu}{\gamma(t)}A = (input_mean / gamma).unsqueeze(-1)# \sqrt{\frac{1-\gamma(t)}{\gamma(t)}}B = (posterior_var / gamma).sqrt().unsqueeze(-1)data_pred_params = []# 对应建模连续数据的场景: 模型预测的是噪声向量.if noise_pred_params.size(-1) == 1:noise_pred_mean = noise_pred_params# 对应建模离散化数据的场景: 模型预测的是噪声分布的均值与对数标准差. elif noise_pred_params.size(-1) == 2:noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(2, -1)else:assert noise_pred_params.size(-1) % 3 == 0mix_wt_logits, noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(3, -1)data_pred_params.append(mix_wt_logits)# 连续数据: x = \frac{\mu}{\gamma(t)} - \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} \epsilon# 离散化数据: \mu_{x} = \frac{\mu}{\gamma(t)} - \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} \mu_{\epsilon}data_pred_mean = A - (B * noise_pred_mean)# 时间变量的值过小则被认为是起始时刻, 等同于先验形式, 即标准高斯分布, 于是将预测的均值置0data_pred_mean = torch.where(alpha_mask, 0 * data_pred_mean, data_pred_mean)data_pred_params.append(data_pred_mean)if noise_pred_params.size(-1) >= 2:# 将对数标准差取自然指数复原: exp(ln(\sigma_{\epsilon})) -> \sigma_{\epsilon}noise_pred_dev = safe_exp(noise_pred_log_dev)# 将噪声分布的标准差转换为目标数据分布的标准差: \sqrt{\frac{1-\gamma(t)}{\gamma(t)}} exp(ln(\sigma_{\epsilon})) -> \mu_xdata_pred_dev = B * noise_pred_dev# 时间变量的值过小则被认为是起始时刻, 等同于先验形式, 即标准高斯分布, 于是将预测的标准差置1data_pred_dev = torch.where(alpha_mask, 1 + (0 * data_pred_dev), data_pred_dev)data_pred_params.append(data_pred_dev)# (B,L*D,NP)data_pred_params = torch.cat(data_pred_params, -1)# (B,L,D,NP)data_pred_params = data_pred_params.reshape(data_shape + [-1])return data_pred_params

六、神经网络的实现

BFN 更多的是一种方法论,而与具体的模型架构无关,于是对神经网络的架构是没有限制的,作者一共实现了3种模型:UNet-VDM(变分扩散模型)、UNet 以及 GPT(参考了 Karpathy 大神的 nano-GPT)。CW 比较懒,实在不愿意将3种模型全部写了,所以干脆就挑最新的 VDM 在这里吹吹水吧~

Model

@torch.no_grad()
def zero_init(module: nn.Module) -> nn.Module:"""Sets to zero all the parameters of a module, and returns the module."""for p in module.parameters():nn.init.zeros_(p.data)return moduleclass UNetVDM(nn.Module):def __init__(self,data_adapters,embedding_dim: int = 128,n_blocks: int = 32,n_attention_heads: int = 1,dropout_prob: float = 0.1,norm_groups: int = 32,input_channels: int = 3,use_fourier_features: bool = True,attention_everywhere: bool = False,image_size: int = 32,):super().__init__()# 对输入进行前置处理, 比如加入位置编码.self.input_adapter = data_adapters["input_adapter"]# 将输出转换为目标 形式, 通常是将维度数 project 到指定数.self.output_adapter = data_adapters["output_adapter"]attention_params = dict(n_heads=n_attention_heads,n_channels=embedding_dim,norm_groups=norm_groups,)resnet_params = dict(ch_in=embedding_dim,ch_out=embedding_dim,condition_dim=4 * embedding_dim,dropout_prob=dropout_prob,norm_groups=norm_groups,)self.embed_conditioning = nn.Sequential(nn.Linear(embedding_dim, embedding_dim * 4),nn.SiLU(),nn.Linear(embedding_dim * 4, embedding_dim * 4),nn.SiLU(),)total_input_ch = input_channelsif use_fourier_features:self.fourier_features = FourierFeatures()# C = (2F + 1)C, 其中 2F 代表傅里叶特征数(sin & cos 各占 F).# 经过傅里叶特征变换所输出的通道数为 2FC, 而这部分特征会和原特征拼接起来,# 于是通道数总共就为 (2F+1)C.total_input_ch *= 1 + self.fourier_features.num_featuresself.conv_in = nn.Conv2d(total_input_ch, embedding_dim, 3, padding=1)# Down path: n_blocks blocks with a resnet block and maybe attention.self.down_blocks = nn.ModuleList(# 注意, 实际上并没有下采样, 分辨率保持不变.UpDownBlock(resnet_block=ResnetBlock(**resnet_params),attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,)for _ in range(n_blocks))self.mid_resnet_block_1 = ResnetBlock(**resnet_params)self.mid_attn_block = AttentionBlock(**attention_params)self.mid_resnet_block_2 = ResnetBlock(**resnet_params)# Up path: n_blocks+1 blocks with a resnet block and maybe attention.resnet_params["ch_in"] *= 2  # double input channels due to skip connectionsself.up_blocks = nn.ModuleList(# 注意, 实际上并没有上采样, 分辨率保持不变.UpDownBlock(resnet_block=ResnetBlock(**resnet_params),attention_block=AttentionBlock(**attention_params) if attention_everywhere else None,)for _ in range(n_blocks + 1))self.conv_out = nn.Sequential(nn.GroupNorm(num_groups=norm_groups, num_channels=embedding_dim),nn.SiLU(),# 将最后的输出卷积层初始化为全0.zero_init(nn.Conv2d(embedding_dim, embedding_dim, 3, padding=1)),)self.embedding_dim = embedding_dimself.input_channels = input_channelsself.image_size = image_sizeself.use_fourier_features = use_fourier_featuresdef forward(self,data: torch.Tensor,t: torch.Tensor,) -> torch.Tensor:# (B,H*W,C)flat_x = self.input_adapter(data, t)# (B,H,W,C)x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.input_channels)# (B,) 因为同一个数据样本在各维度上所对应的时间变量一致, 所以只需要取同的样本的其中1个维度即可.t = t.float().flatten(start_dim=1)[:, 0]# (B,D) 这里 + 0.001 代表小于 0.001 即看作是起始时刻(因此起始时刻不为0), 与 paper 中的描述一致.t_embedding = get_timestep_embedding(t + 0.001, self.embedding_dim)# We will condition on time embedding.# (B,4D)cond = self.embed_conditioning(t_embedding)# (B,C,H,W)x_perm = x.permute(0, 3, 1, 2).contiguous()# 若设定了要使用傅里叶特征, 则将傅里叶特征拼接过来.# (B,(2F+1)C,H,W), 其中 2FC 是傅里叶特征变换模块输出的通道数.h = self.maybe_concat_fourier(x_perm)# (B,D,H,W)h = self.conv_in(h)hs = [h]for down_block in self.down_blocks:  # n_blocks timesh = down_block(h, cond)hs.append(h)h = self.mid_resnet_block_1(h, cond)h = self.mid_attn_block(h)h = self.mid_resnet_block_2(h, cond)for up_block in self.up_blocks:  # n_blocks+1 timesh = torch.cat([h, hs.pop()], dim=1)h = up_block(h, cond)# (B,H*W,D)# 这个最后的卷积层初始化为全0, 因此在参数更新前这个输出特征不起作用,# 于是以下才将网络的输入也一并拼接在一起再输入到最后的 linear projection.out = sandwich(self.conv_out(h).permute(0, 2, 3, 1).contiguous())# (B,H*W,C+D)out = torch.cat([sandwich(x), out], -1)# (B,H*W,out_channels,out_height)out = self.output_adapter(out)return outdef maybe_concat_fourier(self, z):if self.use_fourier_features:return torch.cat([z, self.fourier_features(z)], dim=1)return z

代码也比较好懂,模型的编解码方式遵循 U-Net 的玩法,但在这里却没有真正地进行上、下采样,分辨率一直是保持不变的。 另外,在编解码的同时加入了时间变量(将其处理为 embeddings),以使得模型对于时间拥有感知能力。至于提取特征的基本组件也是老套路了:resnet blcok & self attention。

与常规 VDM 相比,这里的实现有几点比较特殊:

Input Aadapter

作者实现了两种 input adpater,分别是用于语言的 TextInputAdapter 和用于图像的 FourierImageInputAdapter,两者的实质其实都是制作 position embeddings 和 time embeddings 并且附加(element-wise add or concat)在原输入变量上。但这里的 time embeddings 和后文即将展示的 get_timestep_embedding() 中的概念不同,这里主要是对时间变量进行 scale(从而将其取值范围从 [0,1] 缩放至 [-1,1],与输入数据一致),而并不一定对它再进行额外的 projection。

TextInputAdapter 中的 position embeddings 是我们比较熟悉的方式:可学习的 embeddings 或 正弦位置编码;而 FourierImageInputAdapter 在可学习的 embeddings 之余还可能使用傅里叶位置编码,如其名。

def pe_encode(sequence_length: int, embedding_size: int) -> Tensor:"""Positional encoding as described in original attention is all you need paper"""pos = torch.arange(sequence_length).unsqueeze(1)pe = torch.zeros((sequence_length, embedding_size))    pe[:, 0::2] = torch.sin(pos / torch.pow(1000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size))pe[:, 1::2] = torch.cos(pos / torch.pow(1000, torch.arange(1, embedding_size, 2, dtype=torch.float32) / embedding_size))return peclass TextInputAdapter(nn.Module):"""A module to convert sequences of text class tokens to embedding tokens with learned positional embeddings."""def __init__(self,vocab_size: int,seq_len: int,output_size: int = 256,learn_pos_embedding: bool = False,):super().__init__()self.learn_pos_embedding = learn_pos_embeddingif learn_pos_embedding:self.pos_embedding = nn.Embedding(seq_len, output_size)else:self.register_buffer("pos_embedding", pe_encode(seq_len, output_size))self.inp_embedding = nn.Linear(vocab_size, output_size)self.t_embedding = nn.Linear(1, output_size)def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor:# 将概率值从[0,1]范围缩放至[-1,1]inp_emb = self.inp_embedding(2 * probs - 1)if self.learn_pos_embedding:pos_emb = self.pos_embedding(torch.arange(0, probs.size(1)).to(probs.device))else:pos_emb = self.pos_embedding# (B,L,output_size)pos_emb = pos_emb.unsqueeze(0).expand(inp_emb.size(0), -1, -1)# 同样将时间变量的范围从[0,1]缩放至[-1,1]t_emb = self.t_embedding((2 * t - 1).unsqueeze(-1))output = inp_emb + pos_emb + t_embreturn outputdef pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> Tensor:pos = (((x + 1) / 2) * max_freq).unsqueeze(-1)pe = torch.zeros(list(x.shape) + [embedding_size], device=x.device)pe[..., 0::2] = torch.sin(pos/ torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size))pe[..., 1::2] = torch.cos(pos/ torch.pow(10000, torch.arange(1, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size))return peclass FourierImageInputAdapter(nn.Module):"""A module to convert 2D image coordinates into a set of vectors represented as a matrix, with fourier position codes."""def __init__(self,input_channels: int = 3,input_shape: Tuple[int, int] = (224, 224),n_freq_bands: int = 64,output_height: int = 256,value_res: int = -1,mask_res: int = -1,add_pos_feats: bool = True,add_mask: bool = True,learn_pos_feats: bool = False,pos_embed_size: int = 32,init_scale: float = 0.02,):super().__init__()self.input_shape = input_shapeself.n_freq_bands = n_freq_bandsself.value_res = value_resself.mask_res = mask_resself.add_pos_feats = add_pos_featsself.add_mask = add_maskif learn_pos_feats:pos_feats = nn.Parameter(init_scale* torch.randn(1, input_shape[0] * input_shape[1], pos_embed_size))self.register_parameter("pos_feats", pos_feats)else:x = torch.linspace(-1.0, 1.0, steps=input_shape[0])y = torch.linspace(-1.0, 1.0, steps=input_shape[1])# (input_shape[0],input_shape[1])x_pos, y_pos = torch.meshgrid(x, y, indexing="ij")# (input_shape[0],input_shape[1],2)pos = torch.stack((x_pos, y_pos), dim=-1)# (L=H*W,2)pos = pos.reshape(-1, 2)x_bands = torch.linspace(1.0, input_shape[0] / 2, steps=n_freq_bands)y_bands = torch.linspace(1.0, input_shape[1] / 2, steps=n_freq_bands)# (2,n_freq_bands)bands = torch.stack((x_bands, y_bands), dim=0)# (L,2,n_freq_bands)vals = pos[:, :, None] * bands[None, :, :]# (L,2*n_freq_bands)vals = math.pi * vals.reshape(vals.shape[0], -1)# (L,4*n_freq_bands)pos_feats = torch.cat([vals.sin(), vals.cos()], dim=-1)# (L,4*n_freq_bands+2)pos_feats = torch.cat([pos_feats, pos], dim=-1)self.register_buffer("pos_feats", pos_feats)img_feat_height = input_channelspos_feat_height = pos_feats.size(-1)if self.mask_res > 0:mask_feat_height = (n_freq_bands * 2) + 1else:mask_feat_height = 1all_feat_height = img_feat_heightif add_mask:all_feat_height += mask_feat_heightif add_pos_feats:all_feat_height += pos_feat_heightself.output_projection = Noneif output_height != all_feat_height:self.output_projection = nn.Linear(all_feat_height, output_height)def forward(self, img: Tensor, t: Tensor) -> Tensor:# (B,H*W,C)flat_img = sandwich(img)# (B,H*W,C)flat_t = sandwich(t)# [0,1] -> [-1,1]t_feats = (flat_t.float()[..., :1] * 2) - 1if self.mask_res > 0:t_feats = torch.cat([t_feats,pe_encode_float(t_feats, self.mask_res, self.n_freq_bands * 2).flatten(start_dim=2),],-1,)# (B, H*W, )fourier_feats = self.pos_feats.expand(img.size(0), -1, -1)all_feat_list = [flat_img]if self.add_mask:all_feat_list.append(t_feats)if self.add_pos_feats:all_feat_list.append(fourier_feats)all_feats = torch.cat(all_feat_list, dim=-1)if self.output_projection is None:output = all_featselse:output = self.output_projection(all_feats)return output

Output Adapter

output adapter 实质上就是 projection layer(Linear 层),用于将特征的最后一维映射至指定数目,以满足指定的输出分布形式。

class OutputAdapter(nn.Module):def __init__(self, input_height: int, output_channels: int, output_height: int):super().__init__()self.output_channels = output_channelsself.output_height = output_heightself.output_projection = nn.Linear(input_height, output_channels * output_height)def forward(self, inp: torch.Tensor) -> torch.Tensor:output = self.output_projection(inp)return output.reshape(output.size(0), -1, self.output_channels, self.output_height)

Time Embedding

最后,再将序列分别送进正余弦函数后拼接起来从而恢复为原来 embedding 的维度。

def get_timestep_embedding(timesteps,embedding_dim: int,dtype=torch.float32,max_timescale=10_000,min_timescale=1,
):"""正弦位置编码, 相当于将时间变量的值看作是位置."""# Adapted from tensor2tensor and VDM codebase.assert timesteps.ndim == 1assert embedding_dim % 2 == 0num_timescales = embedding_dim // 2# num_timescales 个等比元素, 由 1/min_timescale 到 1/max_timescale(包含).# logspace 的底默认为 10, 其输入的前两个参数代表起始和终止的幂inv_timescales = torch.logspace(  # or exp(-linspace(log(min), log(max), n))-np.log10(min_timescale),-np.log10(max_timescale),num_timescales,device=timesteps.device,)timesteps *= 1000.0  # In DDPM the time step is in [0, 1000], here [0, 1]emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :]  # (T, D/2)# sin(t * \frac{1}{10000^{i/d}}), cos(t * \frac{1}{10000^{i/d}})return torch.cat([emb.sin(), emb.cos()], dim=1)  # (T, D)

Others

剩下的 modules 主要包括:负责提取傅里叶特征的 FourierFeatures 、实现自注意力的 Attention、ResNet 的套路 ResnetBlock 以及将 ResnetBlock & Attention 放在一起玩以模仿 UNet 但实际并未进行上下采样的 UpDownBlock。这里就不再逐一详细解析了,直接看代码就能 get 到对应的意思。

  • FourierFeatures
class FourierFeatures(nn.Module):def __init__(self, first=5.0, last=6.0, step=1.0):super().__init__()self.freqs_exponent = torch.arange(first, last + 1e-8, step)@propertydef num_features(self):return len(self.freqs_exponent) * 2def forward(self, x):assert len(x.shape) >= 2# Compute (2pi * 2^n) for n in freqs.freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device)  # (F, )freqs = 2.0**freqs_exponent * 2 * pi  # (F, )freqs = freqs.view(-1, *([1] * (x.dim() - 1)))  # (F, 1, 1, ...)# Compute (2pi * 2^n * x) for n in freqs.features = freqs * x.unsqueeze(1)  # (B, F, X1, X2, ...)features = features.flatten(1, 2)  # (B, F * C, X1, X2, ...)# Output features are cos and sin of above. Shape (B, 2 * F * C, H, W).return torch.cat([features.sin(), features.cos()], dim=1)
  • Attention
def attention_inner_heads(qkv, num_heads):"""Computes attention with heads inside of qkv in the channel dimension.Args:qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where:H = number of heads,C = number of channels per head.num_heads: number of heads.Returns:Attention output of shape (B, H*C, T)."""bs, width, length = qkv.shapech = width // (3 * num_heads)# Split into (q, k, v) of shape (B, H*C, T).q, k, v = qkv.chunk(3, dim=1)# 对 Q, K 各自缩放 1/d^{1/4} 相当于 Q, K 矩阵相乘后的结果缩放了 1/(\sqrt{d})# Rescale q and k. This makes them contiguous in memory.scale = ch ** (-1 / 4)  # scale with 4th root = scaling output by sqrtq = q * scalek = k * scale# Reshape qkv to (B*H, C, T).new_shape = (bs * num_heads, ch, length)q = q.view(*new_shape)k = k.view(*new_shape)v = v.reshape(*new_shape)# Compute attention.weight = einsum("bct,bcs->bts", q, k)  # (B*H, T, T)weight = softmax(weight.float(), dim=-1).to(weight.dtype)  # (B*H, T, T)out = einsum("bts,bcs->bct", weight, v)  # (B*H, C, T)return out.reshape(bs, num_heads * ch, length)  # (B, H*C, T)class Attention(nn.Module):"""Based on https://github.com/openai/guided-diffusion."""def __init__(self, n_heads):super().__init__()self.n_heads = n_headsdef forward(self, qkv):assert qkv.dim() >= 3, qkv.dim()assert qkv.shape[1] % (3 * self.n_heads) == 0spatial_dims = qkv.shape[2:]qkv = qkv.view(*qkv.shape[:2], -1)  # (B, 3*n_heads*C, T)out = attention_inner_heads(qkv, self.n_heads)  # (B, n_heads*C, T)return out.view(*out.shape[:2], *spatial_dims).contiguous()class AttentionBlock(nn.Module):"""Self-attention residual block."""def __init__(self, n_heads, n_channels, norm_groups):super().__init__()assert n_channels % n_heads == 0self.layers = nn.Sequential(nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels),# 之所以将通道数扩展3倍是因为后续要输入到 Attention 模块, 为 Q, K ,V 各分配数量一致的通道数.nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1),  # (B, 3 * C, H, W)Attention(n_heads),# 输出卷积层初始化为全0,因此在参数更新前这部分输出特征相当于不起作用.zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)),)def forward(self, x):return self.layers(x) + x
  • ResnetBlock
class ResnetBlock(nn.Module):def __init__(self,ch_in,ch_out=None,condition_dim=None,dropout_prob=0.0,norm_groups=32,):super().__init__()ch_out = ch_in if ch_out is None else ch_outself.ch_out = ch_outself.condition_dim = condition_dimself.net1 = nn.Sequential(nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in),nn.SiLU(),nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),)if condition_dim is not None:self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False))self.net2 = nn.Sequential(nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out),nn.SiLU(),nn.Dropout(dropout_prob),zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)),)if ch_in != ch_out:self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)def forward(self, x, condition):h = self.net1(x)if condition is not None:assert condition.shape == (x.shape[0], self.condition_dim)# 这个条件映射层(全连接层)初始化为全0, 因此在参数更新前条件变量不起作用.condition = self.cond_proj(condition)# (B,D,1,1)condition = condition[:, :, None, None]h = h + conditionh = self.net2(h)if x.shape[1] != self.ch_out:x = self.skip_conv(x)assert x.shape == h.shapereturn x + h
  • UpDownBlock
class UpDownBlock(nn.Module):def __init__(self, resnet_block, attention_block=None):super().__init__()self.resnet_block = resnet_blockself.attention_block = attention_blockdef forward(self, x, cond):x = self.resnet_block(x, cond)if self.attention_block is not None:x = self.attention_block(x)return x

七、数据加载与预处理

接下来将对数据集加载及预处理的代码实现进行解析,作者在实验中使用的数据集有3种:CIFAR-10、MNIST 以及 TEXT8,除后者是 NLP 玩家专享之外,前两者都是 CV 玩家们的宝宝。

至于我将重点解析的预处理,则都是针对 CV 数据集的,主要是对于 CIFAR-10(RGB 彩图) 的离散化操作 和 专门为 MNIST(单通道灰度图) 设计的动态二值化操作。

数据集加载

CV 的数据集用的是 torchvision 自带的,由于这里不需要图像标签,因此只返回图像本身,这点可从以下 CIFAR10 和 MNIST 的 __getitem__() 方法里看出。至于 TEXT8 数据集,则是从 URL 下载下来后再做一些字符串处理操作。

import numpy as npimport torch
import torchvisionfrom torchvision import transforms
from torch.utils.data import Dataset, random_splitclass MyLambda(torchvision.transforms.Lambda):def __init__(self, lambd, arg1):super().__init__(lambd)self.arg1 = arg1def __call__(self, x):return self.lambd(x, self.arg1)class CIFAR10(torchvision.datasets.CIFAR10):def __getitem__(self, idx):return super().__getitem__(idx)[0]class MNIST(torchvision.datasets.MNIST):def __getitem__(self, idx):return super().__getitem__(idx)[0]def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]:"""Mandatory keys: dataset (must be cifar10, mnist, bin_mnist, bin_mnist_cts or text8), data_dirOptional for vision: num_bins (default 256), val_frac (default 0.01), horizontal_flip (default: False)Mandatory for text: seq_len"""num_bins = cfg.get("num_bins", 256)if cfg.dataset == "cifar10":train_transform_list = [transforms.ToTensor()]if cfg.get("horizontal_flip", False):train_transform_list.append(transforms.RandomHorizontalFlip())train_transform_list.append(MyLambda(rgb_image_transform, num_bins))train_transform = transforms.Compose(train_transform_list)test_transform = transforms.Compose([transforms.ToTensor(), MyLambda(rgb_image_transform, num_bins)])train_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=train_transform)val_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=test_transform)test_set = CIFAR10(root=cfg.data_dir, train=False, download=True, transform=test_transform)elif cfg.dataset == "mnist":transform = transforms.Compose([transforms.ToTensor(),MyLambda(rgb_image_transform, num_bins),])train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)elif cfg.dataset == "bin_mnist":transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_transform)])train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)elif cfg.dataset == "bin_mnist_cts":transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_cts_transform)])train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform)test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform)elif cfg.dataset == "text8":train_set = Text8Dataset(cfg.data_dir, "train", download=True, seq_len=cfg.seq_len)val_set = Text8Dataset(cfg.data_dir, "val", download=True, seq_len=cfg.seq_len)test_set = Text8Dataset(cfg.data_dir, "test", download=True, seq_len=cfg.seq_len)else:raise NotImplementedError(cfg.dataset)if cfg.dataset != "text8":# For vision datasets we split the train set into train and val# 因为上面划分的 train_set 和 val_set 实际上都是训练集,只不过应用了不同的 transforms,# 所以这里需要按比例从训练集中真正划分出验证集.val_frac = cfg.get("val_frac", 0.01)train_val_split = [1.0 - val_frac, val_frac]# 固定随机种子使得两个 random_split 划分的结果一致, 这样 train_set 和 val_set 就不用有交集.seed = 2147483647train_set = random_split(train_set, train_val_split, generator=torch.Generator().manual_seed(seed))[0]val_set = random_split(val_set, train_val_split, generator=torch.Generator().manual_seed(seed))[1]return train_set, val_set, test_setdef prepare_text8(data_dir: pathlib.Path):data_dir.mkdir(parents=True, exist_ok=True)data_url = "http://mattmahoney.net/dc/text8.zip"with open(data_dir / "text8.zip", "wb") as f:print("Downloading text8")f.write(requests.get(data_url).content)print("Done")with zipfile.ZipFile(data_dir / "text8.zip") as f:f.extractall(data_dir)os.remove(data_dir / "text8.zip")data = (data_dir / "text8").read_text()# get all the unique characters that occur in this textchars = sorted(list(set(data)))vocab_size = len(chars)print("all the unique characters:", "".join(chars))print(f"vocab size: {vocab_size:,}")# create a mapping from characters to integersstoi = {ch: i for i, ch in enumerate(chars)}itos = {i: ch for i, ch in enumerate(chars)}def encode(s):return [stoi[c] for c in s]  # encoder: take a string, output a list of integers# encode both to integersn = len(data)# List[int]train_data = data[: int(n * 0.9)]val_data = data[int(n * 0.9) : int(n * 0.95)]test_data = data[int(n * 0.95) :]train_ids = encode(train_data)val_ids = encode(val_data)test_ids = encode(test_data)print(f"train has {len(train_ids):,} tokens")print(f"val has {len(val_ids):,} tokens")print(f"test has {len(test_ids):,} tokens")# export to bin filestrain_ids = np.array(train_ids, dtype=np.uint16)val_ids = np.array(val_ids, dtype=np.uint16)test_ids = np.array(test_ids, dtype=np.uint16)train_ids.tofile(data_dir / "train.bin")val_ids.tofile(data_dir / "val.bin")test_ids.tofile(data_dir / "test.bin")print(f"Saved to {data_dir / 'train.bin'}, {data_dir / 'val.bin'}, {data_dir / 'test.bin'}")# Save the meta information as well, to help us encode/decode latermeta = {"vocab_size": vocab_size,"itos": itos,"stoi": stoi,}with open(os.path.join(data_dir / "meta.pkl"), "wb") as f:pickle.dump(meta, f)print(f"text8 dataset downloaded and prepared in dir {data_dir}")class Text8Dataset(Dataset):def __init__(self, data_dir: Union[str, pathlib.Path], split: str, download: bool, seq_len: int):"""seq_len should include context length. Example: seq_len=512 for modeling 256 chars with 256 char of context.context is only used for correct preparation of val/test sets."""self.seq_len = seq_lenself.split = splitassert self.split in ["train", "val", "test"]fname = {"train": "train.bin", "val": "val.bin", "test": "test.bin"}[self.split]self.root_dir = pathlib.Path(data_dir)data_dir = self.root_dir / "text8"if not os.path.exists(data_dir):if download:prepare_text8(data_dir)else:raise NotADirectoryError(f"dir {data_dir} does not exist and download is False")# memmap() 将磁盘上的大型二进制文件当作内存中的数组进行处理, shape 若未指定, 则返回的数组将是一维的.# order 参数指定数组内存布局的顺序, 可以是 C(行优先) 或 F(列优先), 默认是行优先, 这个参数仅在数组大于1维时有效.# 还支持 offset 参数, 加载的数组数据从此偏移量开始. 偏移量应该是 dtype 的字节大小的倍数, 默认为 0.self.data = np.memmap(data_dir / fname, np.uint16, "r")def __getitem__(self, index) -> torch.Tensor:seq = torch.from_numpy(self.data[index : index + self.seq_len].astype(np.int64))return seqdef __len__(self):return self.data.size - self.seq_len

离散化操作

以下是针对连续数据的离散化操作,即将其“分配”至对应的离散区间,然后使用区间中点值来表示,本质上属于一种量化的过程,这也是以下 quantize() 方法的命名原因。

刚才说到,量化就是将一个连续的浮点值分配至对应的离散区间,然后再用那个区间的中点值来表示。于是,quantize() 方法就是先调用 float_to_idx() 再调用 idx_to_float()。

但是,在调用 quantize() 方法前,由于数据经过了 torchvision.transforms.ToTensor() 的处理,因此数据值位于 [0,1] 区间,于是要先将其 scale 至 [-1,1] 区间内,如下 rgb_image_transform() 的代码所示。

def idx_to_float(idx: np.ndarray, num_bins: int):"""将离散化区间索引 k 转换为对应的区间中心值 k_c.注意, 此处 k 的取值范围与论文中的不同, 论文中 k 的取值范围是 1~K, 而这里:k_c = \frac{2k+1}{K} - 1, where k \in [0, K-1]."""flt_zero_one = (idx + 0.5) / num_binsreturn (2.0 * flt_zero_one) - 1.0def float_to_idx(flt: np.ndarray, num_bins: int):"""根据离散化值 k_c 计算出对应的区间索引 k, 是 float_to_idx() 的逆向操作."""flt_zero_one = (flt / 2.0) + 0.5return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long()def quantize(flt, num_bins: int):"""将浮点值量化以对应的离散化区间中点 k_c 表示, 因此看作是一个量化的过程."""return idx_to_float(float_to_idx(flt, num_bins), num_bins)def rgb_image_transform(x, num_bins=256):"""将 RGB 图像进行离散化, 其中 x \in [0,1]"""return quantize((x * 2) - 1, num_bins).permute(1, 2, 0).contiguous()

MNIST 的动态二值化

def bin_mnist_transform(x):return torch.bernoulli(x.permute(1, 2, 0).contiguous()).int()def bin_mnist_cts_transform(x):return torch.bernoulli(x.permute(1, 2, 0).contiguous()) - 0.5

以上还有个二值化后变为浮点数(-0.5 or 0.5) 的版本,如 bin_mnist_cts_transform() 所示。

八、训练流程

前面讲的是算法实现和数据处理,现在是时候解析将它们串起来的整个训练流程(https://github.com/nnaisense/bayesian-flow-networks/blob/main/train.py%23L178)了。

import copy
import logging
import mathfrom collections import defaultdict
from pathlib import Path
from typing import Optional, Tupleimport torch
import neptunefrom accelerate import Accelerator
from accelerate.logging import get_loggerfrom omegaconf import OmegaConffrom rich.logging import RichHandler
from rich.progress import Progressfrom torch import nn, optim
from torch.utils.data import DataLoaderfrom model import BFN
from utils_train import (seed_everything, log_cfg,checkpoint_training_state,init_checkpointing,log,update_ema,ddict,make_infinite,make_progress_bar, make_config, make_dataloaders, make_bfn,
)torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = Truelogging.basicConfig(level=logging.INFO,format="%(message)s",datefmt="[%X]",handlers=[RichHandler(rich_tracebacks=True, show_time=False)],
)logger = get_logger(__name__)def ddict():"""Infinite default dict to fake neptune run on non-main processes"""return defaultdict(ddict)def main(cfg):acc = Accelerator(gradient_accumulation_steps=cfg.training.accumulate)cfg.training.seed = seed_everything(cfg.training.seed)logger.info(f"Seeded everything with seed {cfg.training.seed}", main_process_only=True)with acc.main_process_first():model, dataloaders, optimizer = setup(cfg)ema = copy.deepcopy(model) if acc.is_main_process and cfg.training.ema_decay > 0 else None  # EMA on main proc onlymodel, optimizer, dataloaders["train"] = acc.prepare(model, optimizer, dataloaders["train"])# 这个 ddict() 对象是一个无限嵌套的 defaultdict,将其视作假的 neptune run 对象,# 用于主进程之外的其它进程,类似一种 placeholder 的角色,而主进程会重新对 run 变量进行赋值,使其成为真正的neptune run 对象。run = ddict()if acc.is_main_process:ema.to(acc.device)try:if cfg.meta.neptune:import neptunerun = neptune.init_run(project=cfg.meta.neptune, mode="debug" if cfg.meta.debug else None)run["accelerate"] = dict(amp=acc.mixed_precision, nproc=acc.num_processes)log_cfg(cfg, run)except ImportError:logger.info("Did not find neptune installed. Logging will be disabled.")train(cfg.training, acc, model, ema, dataloaders, optimizer, run)if __name__ == "__main__":cfg_file = OmegaConf.from_cli()['config_file']main(make_config(cfg_file))

作者使用了 OmegaConf(https://github.com/omry/omegaconf) 这个配置管理系统,它支持 YAML 格式的文件来定义配置项,并且可以将多个来源(系统环境变量、命令行参数、文件等)的配置项进行合并。

对于实验数据的管理,作者则使用了 neptune(https://neptune.ai/),它可以对实验数据进行追踪、过滤、分组、排序、可视化等,还支持将这些实验结果分享给多人以便合作。

训练是支持分布式的,依赖于大名鼎鼎的 accelerate(https://github.com/huggingface/accelerate),这年头相信大多数人对它已经很熟悉了(不熟悉我也懒得说了~)。

现在来理一理以上 main() 函数的整个流程:首先对分布式相关的东西进行初始化(实例化 Accelerator);然后设置随机种子(seed_everything());接着设置 model, dataloader, optimizer 老三样(setup())并且将它们用 Accelerator 对象(acc) wrap 起来,以便支持分布式训练;哦,这里还对模型做了指数平均移动 EMA(Exponential Moving Average),相当于额外对模型做动量更新,EMA 主要用在评估(validate)阶段;下一步就是对 neptune 做初始化然后将实验配置项记录在其中,注意仅在主进程(main process)上进行即可;最后就是调用 train() 函数开启真正的训练过程了。

接下来先看看 setup() 函数是如何实例化 model, dataloader 以及 optimizer 对象的。

def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]:"""Create the model, dataloader and optimizer"""dataloaders = make_dataloaders(cfg)model = make_bfn(cfg.model)    if "weight_decay" in cfg.optimizer.keys() and hasattr(model.net, "get_optim_groups"):# 区分了 decay 与不 decay 的参数.params = model.net.get_optim_groups(cfg.optimizer.weight_decay)else:params = model.net.parameters()# Instantiate the optimizer using the hyper-parameters in the configoptimizer = optim.AdamW(params=params, **cfg.optimizer)return model, dataloaders, optimizer

可以看到,dataloader 和 model 都是通过调用其它函数来完成实例化的,optimizer 则直接使用 Pytorch 内置的 AdamW。此外,还支持对模型参数是否要进行 weight decay 做了区分(但在作者的实现中并非所有模型都支持 get_optim_groups() 这个方法,只有其实现的 GPT 才支持该方法)。

下面进一步来看看 make_bfn() 方法,它被定义在另外的文件 utils_train.py(https://github.com/nnaisense/bayesian-flow-networks/blob/main/utils_train.py#L153) 里。

import model
import networks
import probabilityfrom networks import adaptersdef make_from_cfg(module, cfg, **parameters):return getattr(module, cfg.class_name)(**cfg.parameters, **parameters) if cfg is not None else Nonedef make_bfn(cfg: DictConfig):data_adapters = {"input_adapter": make_from_cfg(adapters, cfg.input_adapter),"output_adapter": make_from_cfg(adapters, cfg.output_adapter),}net = make_from_cfg(networks, cfg.net, data_adapters=data_adapters)bayesian_flow = make_from_cfg(model, cfg.bayesian_flow)distribution_factory = make_from_cfg(probability, cfg.distribution_factory)loss = make_from_cfg(model, cfg.loss, bayesian_flow=bayesian_flow, distribution_factory=distribution_factory)bfn = model.BFN(net=net, bayesian_flow=bayesian_flow, loss=loss)return bfn

哦!原来是这样的招数——需要实例化哪个类,就从定义它的文件里将其取出然后再传入对应参数,所以一开始需要先导入包含各个类定义的模块(文件)。

你以为我下一步要给你看 make_dataloaders() 长什么样?不好意思,你误会了。CW 打算先将其晾一晾,搞个熟成,待风味足够时再好好拿出来分享~

现在先回到刚刚训练流程的文件里,看看真正的训练过程 train() 函数是怎么玩的。

def train(cfg,accelerator: Accelerator,model: BFN,ema_model: Optional[nn.Module],dataloaders: dict,optimizer: optim.Optimizer,run: "neptune.Run",# run: neptune.Run
):is_main = accelerator.is_main_processpbar = make_progress_bar(is_main)run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()train_id = pbar.add_task(f"Training {run_id}", start=cfg.start_step, total=cfg.n_training_steps, loss=math.nan)checkpoint_root_dir = init_checkpointing(cfg.checkpoint_dir, run_id) if is_main else Nonebest_val_loss = math.inftrain_iter = make_infinite(dataloaders["train"])model.train()with pbar:for step in range(cfg.start_step, cfg.n_training_steps + 1):step_loss = 0.0for _ in range(cfg.accumulate):with accelerator.accumulate(model):train_batch = next(train_iter)loss = model(train_batch)accelerator.backward(loss)if accelerator.sync_gradients and cfg.grad_clip_norm > 0:accelerator.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm)optimizer.step()optimizer.zero_grad(set_to_none=True)step_loss += loss.item()update_ema(ema_model, model, cfg.ema_decay)if is_main and (step % cfg.checkpoint_interval == 0):checkpoint_training_state(checkpoint_root_dir / "last", accelerator, ema_model, step, run_id)run["checkpoints/last"].track_files(str(checkpoint_root_dir / "last"))log(run["metrics"]["train"]["loss"], step_loss / cfg.accumulate, step, cond=is_main and step % cfg.log_interval == 0)log(run["metrics"]["epoch"], step // len(dataloaders["train"]), step, cond=is_main)if is_main and (step % cfg.val_interval == 0) and "val" in dataloaders:val_loss = validate(cfg=cfg,model=model,ema_model=ema_model,val_dataloader=dataloaders["val"],step=step,run=run,pbar=pbar,best_val_loss=best_val_loss,checkpoint_root_dir=checkpoint_root_dir,accelerator=accelerator,)best_val_loss = min(val_loss, best_val_loss)# advance=1 代表任务完成度+1pbar.update(train_id, advance=1, loss=loss.item())

在这里,首先作者使用了 rich 库的 Progress(https://rich.readthedocs.io/en/stable/reference/progress.html) 对象来用作进度条的显示(可能是他嫌弃 tqdm 太 low 叭~),而 make_progress() 方法就是实例化这个对象并将其返回;然后设置了 checkpoint 的目录(init_checkpointing()),以便记录训练期间的模型权重,免得白搞一场;接着,他将 dataloder 变成一个无限迭代的生成器(make_infinite()),待到达指定步数 n_training_steps 后,就停止训练;最后就是常规的训练迭代了,包括:dataloder 吐数据、模型吃数据进行预测并计算 loss、反向传播更新权重(可能有梯度累积和裁剪)、更新 EMA(update_ema())、每隔一定周期记录 checkpoint、记录 loss 与当前进度、周期性地对模型效果进行评估(若设置了 EMA 则是拿它来做 validation)。

OK,我知道你们可能好奇 make_infinite() 和 update_ema() 具体是怎么做的,没问题,我现在就 show 出来,它们被定义在另外的 utils_train.py(https://github.com/nnaisense/bayesian-flow-networks/blob/main/utils_train.py%23L115) 文件里。

@torch.no_grad()
def update_ema(ema_model, model, ema_decay):if ema_model is not None and ema_decay > 0:for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):# ema_i = ema_decay * ema_{i-1} + (1-ema_decay) * model_paramema_param.sub_((1 - ema_decay) * (ema_param - model_param))def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]:while True:for data in dataloader:yield data

make_infinite() 实际就是在 dataloader 循环取数据的外面包了一层 while True 无限循环,而 update_ema() 实质上就是将 EMA 模型与当前 model 的权重做加权求和。

现在来看看对模型效果进行评估的过程。

@torch.no_grad()
def validate(cfg,model: BFN,ema_model: nn.Module,val_dataloader: DataLoader,step: int,run: "neptune.Run",pbar: Optional[Progress],best_val_loss: float,checkpoint_root_dir: Optional[Path],accelerator: Accelerator,
) -> float:"""Evaluate model on validation data and save checkpoint if loss improves."""dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[accelerator.mixed_precision]model_to_eval = ema_model if ema_model is not None else modelmodel_to_eval.eval()pbar = pbar or Progress()max_steps = cfg.max_val_batches if cfg.max_val_batches > 0 else len(val_dataloader)val_id = pbar.add_task("Validating", visible=True, total=cfg.val_repeats * max_steps, transient=True, loss=math.nan)loss, count = 0.0, 0for _ in range(cfg.val_repeats):for idx, eval_batch in enumerate(val_dataloader):enabled = True if dtype in [torch.float16, torch.bfloat16] else Falsewith torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):loss += model_to_eval(eval_batch.to(accelerator.device)).item()count += 1pbar.update(val_id, advance=1, loss=loss / count)if (idx + 1) >= max_steps:breakloss /= countpbar.remove_task(val_id)log(run["metrics"]["val"]["loss"], loss, step)if checkpoint_root_dir is not None and (loss < best_val_loss or math.isinf(best_val_loss)):logger.info(f"loss improved: new value is {loss}")step_checkpoint_path = checkpoint_root_dir / "best"run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch()checkpoint_training_state(step_checkpoint_path, accelerator, ema_model, step, run_id)run["metrics/best/loss/metric"] = lossrun["metrics/best/loss/step"] = stepmodel.train()return loss

其实就像是个简化版的训练过程,主要做的事情就是取数据输入到模型中完成预测,然后计算 loss,最后看 loss 比起之前有无好转(会预先记录历史最优 loss),有的话就保存下 checkpoint 和其它重要信息(checkpoint_training_state())。

import jsondef checkpoint_training_state(checkpoint_dir, accelerator, ema_model, step: int, run_id: str):if checkpoint_dir is None:returnlogger.info(f"Checkpointing training state to {checkpoint_dir} at step {step}")accelerator.save_state(checkpoint_dir)with open(checkpoint_dir / "info.json", "w") as f:json.dump({"step": step, "run_id": run_id}, f)if ema_model is not None:ema_checkpoint_path = checkpoint_dir / "ema_model.pt"torch.save(ema_model.state_dict(), ema_checkpoint_path)

九、分布式训练的随机种子

OK,现在来填一填前面埋下的坑 —— make_dataloaders()。之所以先将其腌制下放到最后来享用,是因为 CW 想将其与随机种子的设置放在一起来好好吹下水,先看代码:

def seed_everything(seed: Optional[int]) -> int:if seed is None:seed = random.randrange(np.iinfo(np.int32).max)random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)return seeddef worker_init_function(worker_id: int) -> None:"""https://pytorch.org/docs/stable/notes/randomness.html#dataloader"""worker_seed = torch.initial_seed() % 2**32np.random.seed(worker_seed)random.seed(worker_seed)def get_generator(seed: int):g = torch.Generator()g.manual_seed(seed)return gdef make_dataloaders(cfg: DictConfig):train_set, val_set, _ = make_datasets(cfg.data)dataloaders = {"train": DataLoader(dataset=train_set,worker_init_fn=worker_init_function,generator=get_generator(cfg.training.seed),**cfg.train_loader,),"val": DataLoader(dataset=val_set,worker_init_fn=worker_init_function,generator=get_generator(cfg.training.seed),**cfg.val_loader,),}return dataloaders

看起来一切都挺正常,在 seed_everything() 中进行全局的随机种子设定,包括:Python, Numpy 和 Pytroch。然后,在 dataloader 里通过 worker_init_fn 对加载数据的 workers 也设置了额外的 worker seed(每个 worker 由独立的 worker_id 识别,会开启额外的 worker 进程),这是为了让每个 worker 拥有不同的随机性,当存在类似数据增强这种操作时能够使得增强后的数据拥有多样性(即各 worker 对应 augment 后的数据呈现不一样)。

但是!在分布式(多 GPUs)训练的情况下,以上实现并不能真正达成 "每个 worker 拥有不同随机性" 这种效果,而是会使得不同 gpu 上拥有相同 worker_id(取值范围通常是 0 ~ num_workers - 1) 的 worker 都有完全一致的随机种子,从而丧失了真正意义上的随机性。

造成这个 bug 的原因是 worker_init_function() 里的 torch.initial_seed() 取决于 get_generator() 里 Generator 对象的 seed + worker_id,Generator 对象的 seed 又由固定的配置项 cfg.training.seed 指定,于是所有 gpu 上的这个值都一样,从而造成不同 gpu 上相同 worker_id 的 worker 最终得到相同的 worker seed。

既然分析出原因,那么解决办法也很简单——让每个 gpu 里 Generator 对象的 seed 不一样即可,比如像这样:

import distdef get_generator(seed: int):import torch.distributed as distrank = dist.get_rank() if dist.is_initialized() else 0seed += rankg = torch.Generator()g.manual_seed(seed)return g

这么做之后,你甚至可以不用定义 worker_init_fn,这点在后面较新版的 Pytorch 中已经支持。

对于这个问题,CW 也向作者提了 issue(顺便刷刷存在感),但作者的解决方法是将每个 gpu 的全局随机种子都设得不一样,如下所示:

def seed_everything(seed: Optional[int]):assert seed is not Noneseed += torch.distributed.get_rank() if torch.distributed.is_initialized() else 0random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)

但这种做法会导致的一个现象就是所有 gpu 在一开始随机初始化模型参数时,会得到不同的随机参数值。 不过如果是使用 DDP(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) 进行分布式训练,那么就不需要担心。因为在 DDP 的机制下,一开始各个 gpu 的模型参数都会由主卡分配以同步。但是,如果用的不是 DDP 或者希望所有 gpu 在全局都拥有一致的随机性时,这个方法就不适用了。

十、吹水

既然走到了本系列的尾端,出于不舍之情当然得好好吹个水!

再谈与扩散模型的比较

BFN 生成样本是一种迭代过程:首先由一个简单的先验分布开始,将这个分布的参数输入到模型中从而输出数据分布;然后从该分布中采样并加噪,将所得的噪声样本作为观测样本来计算后验以更新先验参数;最后将更新的先验参数再次输入到模型中以输出新的数据分布。就这样不断迭代地更新先验分布和模型的输出分布,待一定步骤后,再从输出分布中采样作为最终生成的样本

可以看出,它与扩散模型类似,也是迭代生成的过程并且使用了噪声。所不同的是,扩散模型在训练时是正向的扩散过程、在推理即采样生成时是反向的去噪过程。而我们的主角 BFN 无论是在训练还是采样生成时都是同一种过程,并且没有显式地进行所谓的去噪。在这里,噪声的角色是作为贝叶斯推理(bayesian inference)的 bridge,制作噪声样本是为了计算后验从而更新先验,而这又恰好达到了去噪的效果,因为先验变得更接近真实数据分布了,可谓是“隐式去噪”

另外,大家也知道,扩散模型在理论上有个强硬的限制——正向过程的扩散步数几乎要达到无穷才能变为纯高斯噪声,从而才能与反向的去噪过程的开头(即纯高斯噪声)完美地对应起來。而 BFN 在根本上则没有这种限制,它在采样生成时的初始先验与训练时初始化的输入分布是一致的(完美对应)。

细说 BFN 的噪声设置

那么,更进一步:在 BFN 的玩法中,为何噪声强度要随着时间增加而减小,即:使用了多个噪声强度,如果仅使用一个噪声强度会怎样?毕竟这样可以省去了设置噪声方案这项麻烦事。

在 BFN 的玩法里,并非直接去定义噪声强度而是定义所谓的精度,作为衡量噪声样本(观测样本)与真实数据的接近程度,代表噪声样本所含的有效数据信息量(密度),于是间接控制了噪声强度,所以我们得从源头——精度出发去分析这个事情。

如本系列第二篇文章所述,精度的设置是出于“有效数据信息能够以恒定速率注入到输入分布中”这一宗旨,从而输入分布中所包含的有效信息量越来越多。这样,模型接收的输入(即输入分布的参数)变得越来靠谱,其对应的输出也会更靠谱。而要贯彻这个宗旨,本质上就是要输入分布的期望熵随时间线性递减。当时在文章中就是基于这个出发点在数学上进行分析,最终推导出精度会随时间递增,从而噪声强度就会递减,这也就是为何要用多个噪声强度的原因。

其实,对于这个问题,就算跳脱数学分析,我们也能看出一些苗头。如果噪声强度不变,即精度不变,那么输入分布中所含的有效数据信息就永远是那么丁点儿,从而模型接收的输入(即输入分布的参数)“含金量”就不高,进而就会导致其输出的质量也就不高了。将模型内部的过程看作是去噪(方才说了,相当于隐式去噪),那么即使它完美地去噪了,还原出来的有效数据也就那么点,避免不了成为“劣质品”。

特点与挑战

CW 认为 BFN 最大也是最亮眼(最容易被看到)的特点就是模型的输入不是数据样本本身,而是数据分布的参数!正是这个优秀的基因,导致其天然地能够在连续型(continuous)输入的基础上愉快地玩转离散型(discrete)数据,而无需施加额外的约束。并且,能够在统一的方法框架下适配图像和语言数据的生成,无需专门针对二者做架构上的修改。另外,由于模型输出的是数据分布,因此能够直接计算似然。

但是,BFN 也面临着诸多挑战与不确定性。在最大层面上来讲,其收敛性、稳定性 以及 泛化性还有待检验。它的计算资源也比起一般的模型更多,因为单次前向过程的同时还需要额外进行贝叶斯推理。另外,输入分布很关键,模型对其依赖性不小。若在面临复杂多样的数据时将先验设置地过于简单,可能会导致最终效果不好。

再细节一些,BFN 的精度(噪声)设置也是件棘手的事情。作者在建模离散数据的实验中,就是发现 accuracy schedule 次优导致效果不佳。最后,CW 一直没挖出来 BFN 在生成多样性方面有什么可取之处(比起其他生成模型),或者,我去骚扰下作者看看叭~

完结撒花

水已吹干,真的该结束了。BFN 的建模方法对于许多朋友来说可能比较难懂,由其是当中涉及的数学推导比较多,我看到外网很多人也表示看不懂 paper,绝望地呼出 "not interesting" 的惨叫.. 也正是因为这样,CW 才决定肝出这个系列,毕竟 BFN 确实属于不无聊的风格。特别是当今满街都是扩散模型像行尸走肉般大肆虐杀,能够有只不一样的东西蹦出来难道不觉得很有意思吗!?


http://www.ppmy.cn/embedded/153372.html

相关文章

RS-232串口和普通串口介绍

RS-232串口和普通串口的区别主要体现在标准和信号电平的不同,虽然“串口”通常指的是基于串行通信的接口,但不同的串口标准在硬件实现和使用场景上有些不同。 RS-232串口 vs 普通串口的区别 RS-232 是一种具体的串行通信协议标准,而“普通串口”这个词通常是指没有明确标准定…

C语言gdb调试

目录 1.gdb介绍 2.设置断点 2.1.测试代码 2.2.设置函数断点 2.3.设置文件行号断点 2.4.设置条件断点 2.5.多线程调试 3.删除断点 3.1.删除指定断点 3.2.删除全部断点 4.查看变量信息 4.1.p命令 4.2.display命令 4.3.watch命令 5.coredump日志 6.总结 1.gdb介绍…

服务器与机顶盒

在PCDN&#xff08;P2PCDN&#xff0c;即点对点内容分发网络&#xff09;中&#xff0c;服务器相比盒子具有更高的风险&#xff0c;这主要是由于它们在性能、资源利用、应用场景以及运营方式上的差异所导致的。以下是对这一问题的详细分析&#xff1a; 一、性能与资源利用差异…

牛客网刷题 ——C语言初阶(6指针)——BC106 上三角矩阵判定

1. 题目描述——BC106 上三角矩阵判定 牛客网OJ题链接 描述 KiKi想知道一个n阶方矩是否为上三角矩阵&#xff0c;请帮他编程判定。上三角矩阵即主对角线以下的元素都为0的矩阵&#xff0c;主对角线为从矩阵的左上角至右下角的连线。 示例 输入&#xff1a; 3 1 2 3 0 4 5 0 0…

Linux 内核中的 netif_start_queue 函数:启动网络接口发送队列的关键

在 Linux 内核的网络子系统中,netif_start_queue 函数扮演着至关重要的角色。这个函数的主要功能是启动(或启用)网络接口的发送队列,标志着网络接口已经准备好开始发送数据包。本文将深入探讨 netif_start_queue 函数的用途、工作原理以及在实际网络驱动代码中的应用。 函…

使用 Python 的 pyttsx3 库进行文本转语音

1. 什么是 pyttsx3&#xff1f; 1.1 pyttsx3 是一个 Python 库&#xff0c;它可以将文本转换为语音。与其他文本转语音库&#xff08;如 gTTS&#xff09;不同&#xff0c;pyttsx3 不依赖于网络服务&#xff0c;它使用本地的 TTS&#xff08;Text-to-Speech&#xff09;引擎&a…

STM32使用ITM调试_通过仿真器实现串口打印

IDE&#xff1a;CLion MCU: STM32F407VET6 工具&#xff1a;OpenOCD Telnet 一、简介 调试单片机时&#xff0c;如果要打印数据往往需要另接一根线通过USB转TTL接到电脑上。但这样做往往并不方便&#xff0c;尤其是身边没有USB转TTL工具时。这时可以使用单片机自带的ITM单元…

Java聊天小程序

拟设计一个基于 Java 技术的局域网在线聊天系统,实现客户端与服务器之间的实时通信。系统分为客户端和服务器端两类,客户端用于发送和接收消息,服务器端负责接收客户端请求并处理消息。客户端通过图形界面提供用户友好的操作界面,服务器端监听多个客户端的连接并管理消息通…