Deploying Tora, Alibaba's Latest Open-Source Video Generation Framework

2024/12/29 10:23:35

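Tora is an AI video generation framework from an Alibaba team, built on a trajectory-oriented Diffusion Transformer (DiT).

During generation, Tora accepts several kinds of input, including text descriptions, images, and object motion paths, and uses them to produce videos that are both realistic and smooth.

Its trajectory control mechanism lets Tora control how objects in the video move much more precisely, addressing the difficulty existing models have in generating precise, consistent motion.

Tora is trained in two stages: it is first trained on dense optical flow and then fine-tuned on sparse trajectories, which makes it more adaptable to different kinds of trajectory data.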
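To make the dense-versus-sparse distinction concrete, the sketch below turns point tracks (one (x, y) position per frame) into per-frame sparse flow maps: each map is zero everywhere except at the tracked points, where it stores the displacement to the next frame. This is a hypothetical helper (`tracks_to_sparse_flow`), not Tora's own `process_traj`; the real pipeline may also smooth the sparse flow before encoding it.

```python
from typing import List, Tuple

Point = Tuple[int, int]  # (x, y) pixel position
Flow = List[List[List[Tuple[float, float]]]]  # flows[t][y][x] -> (dx, dy)


def tracks_to_sparse_flow(tracks: List[List[Point]], height: int, width: int) -> Flow:
    """Convert point tracks into sparse flow maps (hypothetical illustration).

    flows[t][y][x] holds the (dx, dy) displacement from frame t to frame t + 1
    at each tracked point's position in frame t; every other pixel stays (0.0, 0.0).
    All tracks are assumed to have the same number of frames.
    """
    num_frames = len(tracks[0])
    flows = [
        [[(0.0, 0.0) for _ in range(width)] for _ in range(height)]
        for _ in range(num_frames - 1)
    ]
    for track in tracks:
        for t in range(num_frames - 1):
            (x0, y0), (x1, y1) = track[t], track[t + 1]
            if 0 <= x0 < width and 0 <= y0 < height:  # ignore points outside the frame
                flows[t][y0][x0] = (float(x1 - x0), float(y1 - y0))
    return flows
```

A dense optical-flow field, by contrast, would have a meaningful vector at every pixel; the sparse version carries only the user-specified motion.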

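The Tora model can generate videos of up to 204 frames at 720p resolution, making it suitable for film production, animation, virtual reality (VR), augmented reality (AR), game development, and other fields.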

GitHub repository: https://github.com/alibaba/Tora

I. Environment Setup

1. Python environment

Python 3.10 or later is recommended.
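A quick way to confirm the interpreter meets that recommendation before installing anything (a minimal sketch; `python_ok` is a hypothetical helper name):

```python
import sys

MIN_PYTHON = (3, 10)  # the article recommends Python 3.10 or later


def python_ok(version_info=sys.version_info, minimum=MIN_PYTHON) -> bool:
    """Return True if the (major, minor) version is at least the recommended minimum."""
    return tuple(version_info[:2]) >= minimum


if __name__ == "__main__":
    current = sys.version.split()[0]
    if python_ok():
        print("Python version OK:", current)
    else:
        print(f"Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+ is recommended, found {current}")
```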

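2. Installing pip packages

Install PyTorch 2.4.0 built for CUDA 11.8, then, from the root of the cloned Tora repository, install the bundled SwissArmyTransformer in editable mode and the requirements of the sat directory: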

pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 torchaudio==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu118

cd modules/SwissArmyTransformer

pip install -e .

cd ../../sat

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
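After installation, the sketch below reports which versions actually ended up in the environment, using only the standard library's `importlib.metadata`. It compares the public version (for example `2.4.0`) and prints the local build tag (for example `cu118`) separately, because pip may resolve `torchaudio==2.4.0` to a `+cu118` build from the extra index. The distribution name `SwissArmyTransformer` for the editable install is an assumption.

```python
from importlib import metadata
from typing import Optional, Tuple

# Public versions pinned by the pip command above; the "+cu118" build tag is reported separately.
EXPECTED = {"torch": "2.4.0", "torchvision": "0.19.0", "torchaudio": "2.4.0"}


def split_local_version(version: str) -> Tuple[str, Optional[str]]:
    """Split '2.4.0+cu118' into ('2.4.0', 'cu118'); a version without '+' has no build tag."""
    public, sep, local = version.partition("+")
    return public, (local if sep else None)


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for dist, wanted in EXPECTED.items():
        found = installed_version(dist)
        if found is None:
            print(f"{dist}: not installed")
            continue
        public, tag = split_local_version(found)
        status = "OK" if public == wanted else f"expected {wanted}"
        print(f"{dist}: {public} (build tag: {tag}) {status}")
    # Assumed distribution name of the editable install in modules/SwissArmyTransformer.
    print("SwissArmyTransformer:", installed_version("SwissArmyTransformer"))
```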

3. Downloading the CogVideoX-5b model

git lfs install

git clone https://www.modelscope.cn/AI-ModelScope/CogVideoX-5b.git
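If `git lfs install` was skipped or the LFS fetch failed, the clone contains small text "pointer" files instead of the real weights, and loading fails later with confusing errors. The sketch below (a hypothetical helper, `find_lfs_pointers`) lists such stubs by checking for the first line every Git LFS pointer file starts with; pointer files are under 1 KB, so larger files are skipped.

```python
from pathlib import Path
from typing import List

# First line of every Git LFS pointer file.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def find_lfs_pointers(repo_dir: str, max_pointer_size: int = 1024) -> List[Path]:
    """List files under repo_dir that are still LFS pointer stubs instead of real content."""
    root = Path(repo_dir)
    pointers = []
    for path in root.rglob("*"):
        if ".git" in path.relative_to(root).parts or not path.is_file():
            continue  # skip git internals and directories
        if path.stat().st_size > max_pointer_size:
            continue  # real weight files are far larger than a pointer stub
        with path.open("rb") as fh:
            if fh.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX:
                pointers.append(path)
    return sorted(pointers)
```

For example, `find_lfs_pointers("CogVideoX-5b")` should return an empty list after a complete download; if it does not, running `git lfs pull` inside the repository fetches the missing objects.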

4. Downloading the Tora t2v model

Download the checkpoint from:

https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/Tora_t2v/mp_rank_00_model_states.pt
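The checkpoint is large, so a streamed download that writes to a temporary `.part` file and renames it only on success avoids mistaking an interrupted transfer for a complete file. This is a minimal standard-library sketch; `download` is a hypothetical helper, and where the file should be placed depends on your Tora configuration, which the article does not state.

```python
import shutil
import urllib.request
from pathlib import Path

CKPT_URL = (
    "https://cloudbook-public-daily.oss-cn-hangzhou.aliyuncs.com/"
    "Tora_t2v/mp_rank_00_model_states.pt"
)


def download(url: str, dest: str, chunk_size: int = 1 << 20) -> Path:
    """Stream url to dest through a .part file, renaming it only after the copy finishes."""
    dest_path = Path(dest)
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = dest_path.with_name(dest_path.name + ".part")
    with urllib.request.urlopen(url) as resp, tmp_path.open("wb") as out:
        shutil.copyfileobj(resp, out, chunk_size)  # copy in 1 MiB chunks
    tmp_path.replace(dest_path)  # atomic rename on the same filesystem
    return dest_path
```

Usage: `download(CKPT_URL, "<checkpoint dir>/mp_rank_00_model_states.pt")`, with the directory adjusted to whatever your config loads from.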

II. Functional Testing

1. Running the test

(1) Testing via Python code

# Run from the repo's sat/ directory: arguments, diffusion_video and utils.* are local modules there.
import argparse
import gc
import json
import math
import os
import pickle
from pathlib import Path
from typing import List, Union

import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from arguments import get_args
from diffusion_video import SATVideoDiffusionEngine
from einops import rearrange, repeat
from omegaconf import ListConfig
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
from torchvision.utils import flow_to_image
from tqdm import tqdm
from utils.flow_utils import process_traj
from utils.misc import vis_tensor

from sat import mpu
from sat.arguments import set_random_seed
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint


def read_from_cli():
    # Interactive mode: read prompts from stdin until Ctrl-D.
    cnt = 0
    try:
        while True:
            x = input("Please input English text (Ctrl-D quit): ")
            yield x.strip(), cnt
            cnt += 1
    except EOFError:
        pass


def read_from_file(p, rank=0, world_size=1):
    # Each data-parallel rank takes every world_size-th line of the prompt file.
    with open(p, "r") as fin:
        cnt = -1
        for l in fin:
            cnt += 1
            if cnt % world_size != rank:
                continue
            yield l.strip(), cnt


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


def draw_points(video, points):
    """
    Draw points onto video frames.

    Parameters:
        video (torch.tensor): Video tensor with shape [T, H, W, C], where T is the number of frames,
                              H is the height, W is the width, and C is the number of channels.
        points (list): Positions of points to be drawn as a tensor with shape [N, T, 2],
                       each point contains x and y coordinates.

    Returns:
        torch.tensor: The video tensor after drawing points, maintaining the same shape [T, H, W, C].
    """
    T = video.shape[0]
    N = len(points)
    device = video.device
    dtype = video.dtype
    video = video.cpu().numpy().copy()
    traj = np.zeros(video.shape[-3:], dtype=np.uint8)  # [H, W, C]
    # Draw every trajectory as a polyline on a separate canvas.
    for n in range(N):
        for t in range(1, T):
            cv2.line(traj, tuple(points[n][t - 1]), tuple(points[n][t]), (255, 1, 1), 2)
    for t in range(T):
        # Alpha-blend the trajectory canvas into each frame, then mark the current point positions.
        mask = traj[..., -1] > 0
        mask = repeat(mask, "h w -> h w c", c=3)
        alpha = 0.7
        video[t][mask] = video[t][mask] * (1 - alpha) + traj[mask] * alpha
        for n in range(N):
            cv2.circle(video[t], tuple(points[n][t]), 3, (160, 230, 100), -1)
    video = torch.from_numpy(video).to(device, dtype)
    return video


def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor,
    save_path: str,
    name: str,
    fps: int = 5,
    args=None,
    key=None,
    traj_points=None,
    prompt="",
):
    os.makedirs(save_path, exist_ok=True)
    p = Path(save_path)

    for i, vid in enumerate(video_batch):
        x = rearrange(vid, "t c h w -> t h w c")
        x = x.mul(255).add(0.5).clamp(0, 255).to("cpu", torch.uint8)  # [T H W C]
        os.makedirs(p / "video", exist_ok=True)
        os.makedirs(p / "prompt", exist_ok=True)
        if traj_points is not None:
            # Save the plain video, the raw trajectory points, and a video with the trajectory drawn on it.
            os.makedirs(p / "traj", exist_ok=True)
            os.makedirs(p / "traj_video", exist_ok=True)
            write_video(
                p / "video" / f"{name}_{i:06d}.mp4",
                x,
                fps=fps,
                video_codec="libx264",
                options={"crf": "18"},
            )
            with open(p / "traj" / f"{name}_{i:06d}.pkl", "wb") as f:
                pickle.dump(traj_points, f)
            x = draw_points(x, traj_points)
            write_video(
                p / "traj_video" / f"{name}_{i:06d}.mp4",
                x,
                fps=fps,
                video_codec="libx264",
                options={"crf": "18"},
            )
        else:
            write_video(
                p / "video" / f"{name}_{i:06d}.mp4",
                x,
                fps=fps,
                video_codec="libx264",
                options={"crf": "18"},
            )
        with open(p / "prompt" / f"{name}_{i:06d}.txt", "w") as f:
            f.write(prompt)


def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    # Resize so the target rectangle fits, then crop it out (random or centered).
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr


def sampling_main(args, model_cls):
    if isinstance(model_cls, type):
        model = get_model(args, model_cls)
    else:
        model = model_cls

    load_checkpoint(model, args)
    model.eval()

    if args.input_type == "cli":
        rank = 0  # single interactive process; rank is only used in the log line below
        data_iter = read_from_cli()
    elif args.input_type == "txt":
        rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
        print("rank and world_size", rank, world_size)
        data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
    else:
        raise NotImplementedError

    image_size = [480, 720]

    sample_func = model.sample
    T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
    num_samples = [1]
    force_uc_zero_embeddings = ["txt"]
    device = model.device
    with torch.no_grad():
        for text, cnt in tqdm(data_iter):
            set_random_seed(args.seed)
            if args.flow_from_prompt:
                text, flow_files = text.split("\t")
            # T is the number of latent frames; e.g. T = 13 gives (13 - 1) * 4 + 1 = 49 frames.
            total_num_frames = (T - 1) * 4 + 1

            # Pick the motion source: precomputed flow (per prompt or single file), trajectory points, or none.
            if args.no_flow_injection:
                video_flow = None
            elif args.flow_from_prompt:
                assert args.flow_path is not None, "Flow path must be provided if flow_from_prompt is True"
                p = os.path.join(args.flow_path, flow_files)
                print(f"Flow path: {p}")
                video_flow = (
                    torch.load(p, map_location="cpu", weights_only=True)[:total_num_frames].unsqueeze_(0).cuda()
                )
            elif args.flow_path:
                print(f"Flow path: {args.flow_path}")
                video_flow = torch.load(args.flow_path, map_location=device, weights_only=True)[
                    :total_num_frames
                ].unsqueeze_(0)
            elif args.point_path:
                if type(args.point_path) == str:
                    args.point_path = json.loads(args.point_path)
                print(f"Point path: {args.point_path}")
                video_flow, points = process_traj(args.point_path, total_num_frames, image_size, device=device)
                video_flow = video_flow.unsqueeze_(0)
            else:
                print("No flow injection")
                video_flow = None

            if video_flow is not None:
                model.to("cpu")  # move model to cpu, run vae on gpu only.
                tmp = rearrange(video_flow[0], "T H W C -> T C H W")
                video_flow = flow_to_image(tmp).unsqueeze_(0).to("cuda")  # [1 T C H W]
                if args.vis_traj_features:
                    os.makedirs("samples/flow", exist_ok=True)
                    vis_tensor(tmp, *tmp.shape[-2:], "samples/flow/flow1_vis.gif")
                    imageio.mimwrite(
                        "samples/flow/flow2_vis.gif",
                        rearrange(video_flow[0], "T C H W -> T H W C").cpu(),
                        fps=8,
                        loop=0,
                    )
                del tmp
                # Map the flow image from [0, 255] to [-1, 1] and encode it with the VAE.
                video_flow = (
                    rearrange(video_flow / 255.0 * 2 - 1, "B T C H W -> B C T H W").contiguous().to(torch.bfloat16)
                )
                torch.cuda.empty_cache()
                video_flow = video_flow.repeat(2, 1, 1, 1, 1).contiguous()  # for uncondition
                model.first_stage_model.to(device)
                video_flow = model.encode_first_stage(video_flow, None)
                video_flow = video_flow.permute(0, 2, 1, 3, 4).contiguous()
                model.to(device)

            print("rank:", rank, "start to process", text, cnt)
            # TODO: broadcast image2video
            value_dict = {
                "prompt": text,
                "negative_prompt": "",
                "num_frames": torch.tensor(T).unsqueeze(0),
            }

            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
            )
            for key in batch:
                if isinstance(batch[key], torch.Tensor):
                    print(key, batch[key].shape)
                elif isinstance(batch[key], list):
                    print(key, [len(l) for l in batch[key]])
                else:
                    print(key, batch[key])
            c, uc = model.conditioner.get_unconditional_conditioning(
                batch,
                batch_uc=batch_uc,
                force_uc_zero_embeddings=force_uc_zero_embeddings,
            )

            for k in c:
                if not k == "crossattn":
                    c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))

            for index in range(args.batch_size):
                # reload model on GPU
                model.to(device)
                samples_z = sample_func(
                    c,
                    uc=uc,
                    batch_size=1,
                    shape=(T, C, H // F, W // F),
                    video_flow=video_flow,
                )
                samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()

                # Unload the model from GPU to save GPU memory
                model.to("cpu")
                torch.cuda.empty_cache()
                first_stage_model = model.first_stage_model
                first_stage_model = first_stage_model.to(device)

                latent = 1.0 / model.scale_factor * samples_z

                # Decode latent serial to save GPU memory
                recons = []
                loop_num = (T - 1) // 2
                for i in range(loop_num):
                    if i == 0:
                        start_frame, end_frame = 0, 3
                    else:
                        start_frame, end_frame = i * 2 + 1, i * 2 + 3
                    if i == loop_num - 1:
                        clear_fake_cp_cache = True
                    else:
                        clear_fake_cp_cache = False
                    with torch.no_grad():
                        recon = first_stage_model.decode(
                            latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
                        )

                    recons.append(recon)

                recon = torch.cat(recons, dim=2).to(torch.float32)
                samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()

                save_path = args.output_dir
                name = str(cnt) + "_" + text.replace(" ", "_").replace("/", "")[:60] + f"_{index}_seed{args.seed}"
                if args.flow_from_prompt:
                    name = Path(flow_files).stem
                if mpu.get_model_parallel_rank() == 0:
                    save_video_as_grid_and_mp4(
                        samples,
                        save_path,
                        name,
                        fps=args.sampling_fps,
                        traj_points=locals().get("points", None),
                        prompt=text,
                    )

            # Free per-prompt tensors only after all batch_size samples for this prompt are done.
            del samples_z, samples_x, samples, video_flow, latent, recon, recons, c, uc, batch, batch_uc
            gc.collect()


if __name__ == "__main__":
    # Map Open MPI launcher variables to the names torch.distributed expects.
    if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
        os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    py_parser = argparse.ArgumentParser(add_help=False)
    known, args_list = py_parser.parse_known_args()

    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    del args.deepspeed_config
    # Inference-only overrides: no context/model parallelism, no activation checkpointing.
    args.model_config.first_stage_config.params.cp_size = 1
    args.model_config.network_config.params.transformer_args.model_parallel_size = 1
    args.model_config.network_config.params.transformer_args.checkpoint_activations = False
    args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
    args.model_config.en_and_decode_n_samples_a_time = 1

    sampling_main(args, model_cls=SATVideoDiffusionEngine)
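The shape bookkeeping in `sampling_main` is easy to lose in the listing. The sketch below re-expresses three of its formulas as plain functions: the pixel frame count `(T - 1) * 4 + 1`, the latent shape `(T, C, H // F, W // F)` passed to `sample_func`, and the latent-frame windows used by the serial VAE decode loop. The function names are hypothetical, and T = 13 and C = 16 in the example are illustrative values, not read from Tora's config.

```python
from typing import List, Tuple


def pixel_frames(latent_frames: int, temporal_factor: int = 4) -> int:
    """Frames produced from T latent frames, as in sampling_main: (T - 1) * 4 + 1."""
    return (latent_frames - 1) * temporal_factor + 1


def latent_shape(
    latent_frames: int, channels: int, height: int, width: int, spatial_factor: int = 8
) -> Tuple[int, int, int, int]:
    """Shape given to sample_func: (T, C, H // F, W // F)."""
    return (latent_frames, channels, height // spatial_factor, width // spatial_factor)


def decode_windows(latent_frames: int) -> List[Tuple[int, int]]:
    """[start, end) latent slices decoded one at a time by the serial decode loop."""
    windows = []
    for i in range((latent_frames - 1) // 2):
        if i == 0:
            windows.append((0, 3))  # first window takes three latent frames
        else:
            windows.append((i * 2 + 1, i * 2 + 3))  # later windows take two
    return windows


if __name__ == "__main__":
    T = 13  # illustrative latent frame count
    print(pixel_frames(T), latent_shape(T, 16, 480, 720), decode_windows(T))
```

For an odd T the windows tile the latent time axis exactly, with no gaps or overlaps, so decoding two or three latent frames at a time keeps peak GPU memory low while still reconstructing every frame.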

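When prompts come from a text file, `read_from_file` splits the lines round-robin across data-parallel ranks, and with `flow_from_prompt` each line carries a prompt and a flow file name separated by a tab. The sketch below mirrors that logic without any model code, which makes it easy to check how a prompt file will be divided. `shard_lines` and `split_prompt` are hypothetical names.

```python
from typing import Iterable, Iterator, Optional, Tuple


def shard_lines(lines: Iterable[str], rank: int = 0, world_size: int = 1) -> Iterator[Tuple[str, int]]:
    """Round-robin split mirroring read_from_file: line number cnt goes to rank cnt % world_size."""
    for cnt, line in enumerate(lines):
        if cnt % world_size == rank:
            yield line.strip(), cnt


def split_prompt(line: str, flow_from_prompt: bool) -> Tuple[str, Optional[str]]:
    """With flow_from_prompt, a line is '<prompt><TAB><flow file>', split on the tab as in sampling_main."""
    if not flow_from_prompt:
        return line, None
    text, flow_file = line.split("\t")
    return text, flow_file
```

Because the global line number `cnt` is kept, output names built from it stay unique across ranks.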
(To be continued...)

For more details, follow: 杰哥新技术


