Open-Sora代码详细解读(2):时空3D VAE

news/2024/9/18 11:12:47/ 标签: AIGC, 扩散模型, 视频生成, sora

Diffusion Models视频生成

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

3D VAE原理

代码剖析

2D VAE

时间VAE

因果3D卷积


3D VAE原理

之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。

Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。

代码剖析

2D VAE

来自华为pixart sdxl vae:

    vae_2d = dict(type="VideoAutoencoderKL",from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",subfolder="vae",micro_batch_size=micro_batch_size,local_files_only=local_files_only,)

时间VAE

    vae_temporal = dict(type="VAE_Temporal_SD",from_pretrained=None,)
@MODELS.register_module()
class VAE_Temporal(nn.Module):def __init__(self,in_out_channels=4,latent_embed_dim=4,embed_dim=4,filters=128,num_res_blocks=4,channel_multipliers=(1, 2, 2, 4),temporal_downsample=(True, True, False),num_groups=32,  # for nn.GroupNormactivation_fn="swish",):super().__init__()self.time_downsample_factor = 2 ** sum(temporal_downsample)# self.time_padding = self.time_downsample_factor - 1self.patch_size = (self.time_downsample_factor, 1, 1)self.out_channels = in_out_channels# NOTE: following MAGVIT, conv in bias=False in encoder first convself.encoder = Encoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim * 2,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)self.decoder = Decoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)def get_latent_size(self, input_size):latent_size = []for i in range(3):if input_size[i] is None:lsize = Noneelif i == 0:time_padding = (0if (input_size[i] % self.time_downsample_factor == 0)else self.time_downsample_factor - input_size[i] % self.time_downsample_factor)lsize = (input_size[i] + time_padding) // self.patch_size[i]else:lsize = input_size[i] // self.patch_size[i]latent_size.append(lsize)return latent_sizedef encode(self, x):time_padding = (0if (x.shape[2] % self.time_downsample_factor == 0)else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor)x = pad_at_dim(x, (time_padding, 0), dim=2)encoded_feature = self.encoder(x)moments = self.quant_conv(encoded_feature).to(x.dtype)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z, num_frames=None):time_padding = (0if (num_frames % self.time_downsample_factor == 0)else self.time_downsample_factor - num_frames % self.time_downsample_factor)z = self.post_quant_conv(z)x = self.decoder(z)x = x[:, :, time_padding:]return xdef forward(self, x, sample_posterior=True):posterior = self.encode(x)if sample_posterior:z = posterior.sample()else:z = posterior.mode()recon_video = self.decode(z, num_frames=x.shape[2])return recon_video, posterior, z

因果3D卷积

class CausalConv3d(nn.Module):def __init__(self,chan_in,chan_out,kernel_size: Union[int, Tuple[int, int, int]],pad_mode="constant",strides=None,  # allow custom stride**kwargs,):super().__init__()kernel_size = cast_tuple(kernel_size, 3)time_kernel_size, height_kernel_size, width_kernel_size = kernel_sizeassert is_odd(height_kernel_size) and is_odd(width_kernel_size)dilation = kwargs.pop("dilation", 1)stride = strides[0] if strides is not None else kwargs.pop("stride", 1)self.pad_mode = pad_modetime_pad = dilation * (time_kernel_size - 1) + (1 - stride)height_pad = height_kernel_size // 2width_pad = width_kernel_size // 2self.time_pad = time_padself.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)stride = strides if strides is not None else (stride, 1, 1)dilation = (dilation, 1, 1)self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)def forward(self, x):x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)x = self.conv(x)return x


http://www.ppmy.cn/news/1525570.html

相关文章

解锁数字信任之门:SSL证书的安全之旅

在当今这个数字化时代,互联网已成为我们生活、工作、学习不可或缺的一部分。然而,随着网络活动的日益频繁,信息安全问题也日益凸显。如何确保在线数据传输的安全性、完整性和私密性,成为了每一个网络用户和企业必须面对的重要课题…

掌握ChatGPT:高效利用AI助手

2023 年 3 月 15 日,ChatGPT-4 的诞生标志着人类进入了一个全新的 人机协作时代。这个时代就像一个混沌初开的新世界,而 ChatGPT 则是这个新世界里诞生的一个新物种。 这个新物种的心智如同一个四五岁的小孩,在与它频繁互动中,人…

基于 TDMQ for Apache Pulsar 的跨地域复制实践

导语 自2024年9月6日起,TDMQ Pulsar 版专业集群支持消息、元数据两级跨地域复制功能,消息级复制解决用户全球地域的数据统一归档问题,元数据级复制提供解决用户核心业务跨地域容灾的场景。 用户在跨地域场景遇到的疑问和挑战 在跨地域相关…

中国电商三十年,阿里的时代结束了吗?

9月12日,淘宝正式开放微信支付。 这既是阿里三年整改期结束以来的第一个大动作,更是中国电商格局迎来重塑的标志性事件。淘宝与微信互联,一方面代表着阿里与腾讯从“水火不容”走向互联互通,另一方面也正式宣告了中国电商从阿里京…

Ton的编译过程(上)

系列文章目录 FunC编写初始准备 文章目录 系列文章目录预先准备第一个FunC合约深入compileFunc的内部compileFunc初探艾丽卡的疑惑package.json 初览index.js 预先准备 首先请大家跟着艾丽卡一步一步的完成FunC编写初始准备 这里面环境的搭建。 接下来,请做好下面…

不用禁用 iptables 来解决 UFW 和 Docker 的安全问题

UFW 是 Ubuntu 上很流行的一个 iptables 前端,可以非常方便的管理防火墙的规则。但是当安装了 Docker,UFW 无法管理 Docker 发布出来的端口了。 解决 UFW 和 Docker 的问题 目前新的解决方案只需要修改一个 UFW 配置文件即可,Docker 的所有…

堆叠沙漏网络(stacked hourglass network)学习

定义 Stacked Hourglass Networks是2016年密歇根大学提出的经典网络架构。是曾经最具代表性的姿态识别SOTA之一。 hourglass network hourglass network 本身其实可以理解成是一个encoder-decoder的结构,encoder最大程度的提取图像在每一个scale的特征以及空间信…

漫谈设计模式 [21]:备忘录模式

引导性开场 菜鸟:老鸟,我最近在一个项目中遇到了一个问题。我需要实现一个功能,能够让用户在修改数据后撤销或恢复到之前的状态。你有什么好的建议吗? 老鸟:这听起来像是一个很经典的问题。你有没有听说过设计模式中…

个性化、持续性阅读 学生英语词汇量自然超越标准

2024年秋季新学年,根据2022版《义务教育英语课程标准》全新修订的英语新版教材开始投入使用,标志着我国英语教育迈入了一个以应用为导向、注重综合素养培养的新阶段。 新版教材的变革不仅仅是一次词汇量的简单增加,更是一场从应试到应用的深…

Windows Python 指令补全方法

网络上搜集的补全代码 # python startup file import sys import readline import rlcompleter import atexit import os# tab completion readline.parse_and_bind(tab: complete) # history file histfile os.path.join(os.environ[HOMEPATH], .pythonhistory) try:readline…

数学分析原理答案——第三章 习题18

【第三章 习题18】 把习题16中的递推公式换成 x n 1 p − 1 p x n α p x n − p 1 x_{n 1} \frac{p - 1}{p}x_{n} \frac{\alpha}{p}x_{n}^{- p 1} xn1​pp−1​xn​pα​xn−p1​ 这里 p p p是固定的正整数,描述该序列的性质 【解】 若 x 1 > x p x…

Linux命令分享 三 (ubuntu 16.04)

1、‘>’ >>输出重定向 用法:命令 参数 > 文件 ls > a.txt ‘>’ 将一个命令的结果不输出到屏幕上,输出到文件中,如果文件不存在就创建文件,如果存在就覆盖文件。 ls >> a.txt ‘>>’ 如果文件不存…

注册安全分析报告:熊猫频道

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞…

JS设计模式之装饰者模式:优雅的给对象增添“魔法”

引言 在前端开发中,我们经常会遇到需要在不修改已有代码的基础上给对象添加新的行为或功能的情况。而传统的继承方式并不适合这种需求,因为继承会导致类的数量急剧增加,且每一个子类都会固定地实现一种特定的功能扩展。 装饰者模式则提供了…

使用Let’s Encrypt 配置 SSL 证书去除浏览器不安全告警

Let’s Encrypt是什么 https://letsencrypt.org/zh-cn/about/如何操作进行配置实现ssl认证 使用 certbot 获取 Let’s Encrypt 的免费 SSL 证书 更新系统软件包 sudo yum update -y安装 EPEL 仓库(Certbot 通常位于 EPEL 仓库中): sudo yum

使用Pandas高效读取和处理Excel数据

目录 引言 安装必要的库 示例代码 注 引言 在数据科学和数据分析领域,Excel文件是一种常见的数据存储格式。由于其易于编辑和分享的特点,Excel成为了许多企业和组织中数据记录的标准工具。然而,在进行大规模的数据分析时,手动处理…

栈OJ题——用栈实现队列

文章目录 一、题目链接二、解题思路三、解题代码 一、题目链接 用栈实现队列 二、解题思路 三、解题代码 class MyQueue {public Stack<Integer> stack1 ;public Stack<Integer> stack2;public MyQueue() {stack1 new Stack<>();stack2 new Stack<&g…

网络药理学:2、文章基本思路、各个数据库汇总与比对、其他相关资料(推荐复现的文章、推荐学习视频、论文基本框架、文献基本知识及知网检索入门)

一、文章基本思路&#xff08;待更&#xff09; 一篇不含分子对接和实验的纯网络药理学文章思路如下&#xff1a; 即如下&#xff1a; 二、 各个数据库&#xff08;待更&#xff09; 三、其他相关资料 1.推荐复现的文章 纯网络药理学分子对接&#xff1a;知网&#xff1…

[项目] - Calc计算器

前言 各位师傅大家好&#xff0c;我是qmx_07&#xff0c;今天来尝试模拟windows 下的clac计算器 绘制计算器 拖动工具箱的Edit Control输入框、Button按钮 制作计算器界面需要将Edit Control输入框 拉长&#xff0c;将多行、只读 设置为True整体计算机的控件ID&#xff1a;I…

算法题:找出一个数组中,出现次数为奇数次的数。详细解析,加个人理解。

数组中&#xff0c;出现奇数次的数 从一个数组中找出出现了奇数次的数字&#xff0c;要求&#xff1a; 时间复杂度&#xff1a;O(n)空间复杂度&#xff1a;O(1) 题目 从数组中找出一个出现奇数次的数字&#xff0c;如&#xff1a; {1, 2, 2, 1, 3, 1, 1, 3, 3}&#xff0c;结…