【笔记】扩散模型(八):DALL-E 2 (unCLIP) 论文解读与代码实现

server/2024/9/22 16:23:04/

论文链接:Hierarchical Text-Conditional Image Generation with CLIP Latents

非官方实现:lucidrains/DALLE2-pytorch

DALL-E 2 是一个比较经典的文生图模型,虽然和 Stable Diffusion 的架构有些区别,但是也利用了 CLIP 的文本-图像对齐能力实现了用文本作为条件进行图像生成。由于 CLIP 是输入文本和图像获得相应的特征,而 DALL-E 2 是将输入的文本转化为特征再转换为图像,相当于把 CLIP 中的图像编码器反转了过来,所以这个方法也被称为 unCLIP。这个模型主要由三个部分组成:

  • CLIP 模型:负责将条件文本转换到文本-图像的统一特征空间中;
  • prior 模型:将文本特征转换为图像特征,用于后续的图像生成;
  • decoer 模型:将从 prior 获得的图像特征转换为具体的生成图像,相当于反转了 CLIP 中的图像 encoder。

模型的架构图如下图所示,虚线的上方是 CLIP 模型,下方是 prior 和 decoder 模型。

DALL-E 2 模型架构

DALL-E 2 的训练与采样

由于 DALL-E 2 由三个不同的部分组成,这三个模型都需要分别进行训练。

训练的第一步是训练 CLIP 模型,这部分和 CLIP 原本的训练过程是一样的,因此 DALL-E 2 可以直接使用已经训练好的 CLIP 模型。

第二步是训练 prior 模型,这个模型的作用是将 CLIP 的文本特征转换为图像特征,用于后续的生成步骤。作者个人感觉这一步不一定是必须的,因为 CLIP 中的文本特征与图像特征是对齐的,而且在 Stable Diffusion 中实际上也是直接用 CLIP 的文本特征和 latent 做交叉注意力。不过这里还是用 prior 模型做了一步转换,直观上来说可能转换一步之后可以弥补原先在 CLIP 中文本和图像特征没有对齐的那一部分。

这里的 prior 模型有两种可能的选择:

  • 自回归模型(autoregressive prior):将图像的特征转换为一系列离散的序列,用自回归的方式生成。(应该比较类似于用 Transformer 做 next token prediction 的任务)
  • 扩散模型(diffusion prior):相当于用文本特征作为条件,并用扩散模型生成图像特征。

由于两种模型的效果差不多并且扩散模型的效率更高,所以最后使用的是扩散模型。不过这里用的不是普通的基于 UNet 的扩散模型,而是使用了一个 decoder-only 的 Transformer 模型,并且预测的内容也是从预测噪声变成了直接预测 embedding。

由于 prior 模型是要将文本特征转换为图像特征,训练目标也是将输出与 CLIP 原本的图像特征对齐,如图所示:

DALL-E 2 prior 模型的训练

最后一步是训练 decoder 模型,这个模型需要以图像为条件,生成最终的目标图像。decoder 模型使用的是一个改进的 GLIDE(也是 diffusion model),训练流程和 GLIDE 是一致的。

在采样时,首先使用 CLIP 将文本进行编码,然后用 prior 将文本特征转换为图像特征,最后用 decoder 生图。

DALL-E 2 代码解读

因为 OpenAI 官方没有放出 DALL-E 2 的完整代码,这里主要参考的是文章最开始给出的非官方实现。这个模型的层次结构也很清晰:

class DALLE2(nn.Module):def __init__(self,*,prior: DiffusionPrior,decoder: Decoder,prior_num_samples = 2):super().__init__()self.prior = priorself.decoder = decoderself.prior_num_samples = prior_num_samplesself.decoder_need_text_cond = self.decoder.condition_on_text_encodings@torch.no_grad()@eval_decoratordef forward(self,text,cond_scale = 1.,prior_cond_scale = 1.,return_pil_images = False):device = module_device(self)# 预处理文本,将文本进行 tokenizationone_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)if isinstance(text, str) or is_list_str(text):text = [text] if not isinstance(text, (list, tuple)) else texttext = tokenizer.tokenize(text).to(device)# 这里相当于两步合一:CLIP 提取文本特征+生成图像特征image_embed = self.prior.sample(text, num_samples_per_batch=self.prior_num_samples, cond_scale=prior_cond_scale)text_cond = text if self.decoder_need_text_cond else None# 使用 decoder 生成图像,可以看到不仅可以用图像特征进行 condition,# 也可以使用文本特征进行 conditionimages = self.decoder.sample(image_embed=image_embed, text=text_cond, cond_scale=cond_scale)return images

这里的 diffusion prior 大部分都和一般的 diffusion model 一样,不过主要需要关注两个方法。第一个是采样方法,和上述的流程一样,不过有一个上边没有介绍的细节,就是实际上采样了两个图像的 embedding,但是只使用了与文本最匹配的一个:

class DiffusionPrior(nn.Module):...@torch.no_grad()@eval_decoratordef sample(self,text,num_samples_per_batch = 2,cond_scale = 1.,timesteps = None):# 初始化时间步timesteps = default(timesteps, self.sample_timesteps)# 原文的做法是采样两个 image embedding 然后选 CLIP 匹配分数较高的一个text = repeat(text, 'b ... -> (b r) ...', r=num_samples_per_batch)batch_size = text.shape[0]image_embed_dim = self.image_embed_dim# 使用 CLIP 进行 embeddingtext_embed, text_encodings = self.clip.embed_text(text)text_cond = dict(text_embed=text_embed)if self.condition_on_text_encodings:text_cond = {**text_cond, 'text_encodings': text_encodings}# 生成图像 embeddingimage_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond=text_cond, cond_scale=cond_scale, timesteps=timesteps)# 匹配一个比较好的图像 embedding 返回text_embeds = text_cond['text_embed']text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r=num_samples_per_batch)image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r=num_samples_per_batch)text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))top_sim_indices = text_image_sims.topk(k=1).indicestop_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d=image_embed_dim)top_image_embeds = image_embeds.gather(1, top_sim_indices)return rearrange(top_image_embeds, 'b 1 d -> b d')

第二个需要关注的是训练时的损失,这里预测的对象和普通的 diffusion model 有所不同:

class DiffusionPrior(nn.Module):...def p_losses(self, image_embed, times, text_cond, noise=None):noise = default(noise, lambda: torch.randn_like(image_embed))image_embed_noisy = self.noise_scheduler.q_sample(x_start=image_embed, t=times, noise=noise)self_cond = Noneif self.net.self_cond and random.random() < 0.5:with torch.no_grad():self_cond = self.net(image_embed_noisy, times, **text_cond).detach()# 正常的 diffusion model 这里预测的是噪声,但这里直接预测了 embeddingpred = self.net(image_embed_noisy,times,self_cond = self_cond,text_cond_drop_prob = self.text_cond_drop_prob,image_cond_drop_prob = self.image_cond_drop_prob,**text_cond)if self.predict_x_start and self.training_clamp_l2norm:pred = self.l2norm_clamp_embed(pred)if self.predict_v:target = self.noise_scheduler.calculate_v(image_embed, times, noise)elif self.predict_x_start:target = image_embedelse:target = noise# 计算损失也是直接用 embedding 进行计算loss = self.noise_scheduler.loss_fn(pred, target)return loss

decoder 的采样过程也没有什么特别的地方,就是普通的 diffusion model 采样过程,这里就不展开介绍了。

总结

DALL-E 2 刚出的时候也算非常火,不过这个模型也有 diffusion model 的一些通病,比如会出现不同主体的属性混淆、文本的生成效果比较差等情况。总体来说,个人感觉这个模型不如 Stable Diffusion 优雅,从后续的很多工作也可以看出,基于 Stable Diffusion 继续进行拓展的方法才是主流,基于 DALL-E 2 的方法还是比较少的。

参考资料:

  1. DALL·E 2 解读 | 结合预训练CLIP和扩散模型实现文本-图像生成

本文原文以 CC BY-NC-SA 4.0 许可协议发布于 笔记|扩散模型(八):DALL-E 2 (unCLIP) 理论与实现,转载请注明出处。


http://www.ppmy.cn/server/119087.html

相关文章

Selenium with Python学习笔记整理(网课+网站)

本篇是根据学习网站和网课结合自己做的学习笔记&#xff0c;后续会一边学习一边补齐和整理笔记 官方学习网站在这获取&#xff1a; https://selenium-python.readthedocs.io/getting-started.html#simple-usage WEB UI自动化环境配置 (推荐靠谱的博客文章来进行环境配置,具…

孙怡带你深度学习(1)--神经网络

文章目录 深度学习神经网络1. 感知器2. 多层感知器偏置 3. 神经网络的构造4. 模型训练损失函数 总结 深度学习 深度学习(DL, Deep Learning)是机器学习(ML, Machine Learning)领域中一个新的研究方向。 从上方的内容包含结果&#xff0c;我们可以知道&#xff0c;在学习深度学…

Redis 配置

一、关系型数据库与非关系型数据库 1. 关系型数据库 关系型数据库是一种结构化数据库&#xff0c;基于关系模型&#xff08;二维表格模型&#xff09;&#xff0c;适合记录数据。通过 SQL&#xff08;结构化查询语言&#xff09;进行数据的检索和操作。主流的关系型数据库包括…

Qt 学习第十天:小项目:QListWidget的使用

一、页面布局 二、命名按钮 双击按钮可以修改显示中的文字&#xff08;例如&#xff1a;改成“全选”&#xff09;&#xff0c;objectName是要改成程序员所熟悉的名字&#xff08;英文&#xff0c;符合代码规范&#xff09;方便修改和书写代码&#xff0c;一看就能看懂的 三、…

git rev-parse

git rev-parse 是 Git 中一个非常有用的命令&#xff0c;用于解析并返回与 Git 对象&#xff08;如提交、分支、标签等&#xff09;相关的信息。它可以帮助我们从给定的引用&#xff08;ref&#xff09;中解析出 SHA-1 哈希值、路径信息等。这个命令在编写 Git 脚本时尤其有用&…

深度剖析iOS渲染

iOS App 图形图像渲染的基本流程&#xff1a; 1.CPU&#xff1a;完成对象的创建和销毁、对象属性的调整、布局计算、文本的计算和排版、图片的格式转换和解码、图像的绘制。 2.GPU&#xff1a;GPU拿到CPU计算好的显示内容&#xff0c;完成纹理的渲染&#xff0c; 渲染完成后将渲…

大数据-134 - ClickHouse 集群三节点 安装配置启动

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff08;已更完&#xff09;HDFS&#xff08;已更完&#xff09;MapReduce&#xff08;已更完&am…

淳博叶子:品牌、平台与广告公司,最终都是为消费者创造价值 | SMARTIES CHINA 2024终审报道④

淳博传播董事长兼总裁 叶子 近日&#xff0c;SMARTIES CHINA 2024终审活动在苏州音昱水中天落下帷幕。来自各行业的40位品牌广告主代表&#xff0c;历时两天时间&#xff0c;通过紧张的评审和精彩的讨论&#xff0c;从178个优秀入围案例中评选出了每个类别的金银铜大奖及全场大…