Implementing RF-Inversion on Hunyuan Video

Diffusion models code walkthrough: getting started and hands-on practice

Preface: A previous post, "Paper and Code Walkthrough: RF-Inversion Image/Video Editing", covered the method itself, but the original code is implemented on FLUX and SD3. This post explains how to implement RF-Inversion on Hunyuan Video.

Contents

How It Works

Step 1: Forward

Step 2: Reverse


How It Works

For the theory, please read the previous post; the main thing is to understand its two algorithm figures: Algorithm 1 (the Controlled Forward ODE used for inversion) and Algorithm 2 (the controlled reverse process used for editing).
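Since the figures themselves are not reproduced here, a brief recap of the two update rules, transcribed from the equation comments in the code below (Eq. 8 and Eq. 17 of the paper). Algorithm 1 nudges the inversion ODE toward a random endpoint y_1 with strength γ; Algorithm 2 nudges the denoising trajectory toward the clean source latents y_0 with strength η:

    % Algorithm 1 (Controlled Forward ODE, Eq. 8)
    \mathrm{d}Y_t = \left[\, u_t(Y_t) + \gamma \left( u_t(Y_t \mid y_1) - u_t(Y_t) \right) \right] \mathrm{d}t,
    \qquad u_t(Y_t \mid y_1) = \frac{y_1 - Y_t}{1 - t}

    % Algorithm 2 (controlled reverse step, cf. Eq. 17)
    \hat{v}_t = v_t + \eta_t \left( v_t(X_t \mid y_0) - v_t \right),
    \qquad v_t(X_t \mid y_0) = \frac{y_0 - X_t}{1 - t_i}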

Step 1: Forward

Step 1 recovers the latents from the original video, which is why it is called the Forward operation:

    latents, _ = self.invert(
        latents=latents.to(device),
        prompt_embeds=prompt_embeds.to(device),
        prompt_embeds_2=prompt_embeds_2.to(device),
        guidance_expand=guidance_expand.to(device),
        freqs_cis=freqs_cis,
        prompt_mask=prompt_mask,
        n_tokens=n_tokens,
        num_inversion_steps=30,
        strength=1.0,
        gamma=0.01,
    )
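The `latents` argument here is the VAE-encoded source video; that preparation step is not shown above. A minimal sketch of what it might look like, assuming a diffusers-style HunyuanVideo causal 3D VAE is available as `self.vae` (the helper name and tensor layout are illustrative, not the author's actual code):

    import torch

    @torch.no_grad()
    def encode_video_to_latents(self, video: torch.Tensor) -> torch.Tensor:
        # video: [B, C, F, H, W] pixel tensor scaled to [-1, 1]
        video = video.to(device=self.vae.device, dtype=self.vae.dtype)
        # encode with the VAE and sample from the posterior
        latents = self.vae.encode(video).latent_dist.sample()
        # apply the VAE scaling factor expected by the transformer
        latents = latents * self.vae.config.scaling_factor
        return latents  # e.g. [1, 16, f, h, w], matching the shapes noted below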

This implements Algorithm 1:

    @torch.no_grad()
    def invert(
        self,
        latents,
        prompt_embeds,
        prompt_embeds_2,
        guidance_expand,
        freqs_cis,
        prompt_mask,
        n_tokens=None,
        num_inversion_steps: int = 28,
        strength: float = 1.0,
        gamma: float = 0.5,
        height: Optional[int] = None,
        width: Optional[int] = None,
        timesteps: List[int] = None,
        dtype: Optional[torch.dtype] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""Performs Algorithm 1: Controlled Forward ODE from https://arxiv.org/pdf/2410.10792"""
        dtype = dtype or self.text_encoder.dtype
        batch_size = 1
        # num_channels_latents = self.transformer.config.in_channels // 4
        # height = height or self.default_sample_size * self.vae_scale_factor
        # width = width or self.default_sample_size * self.vae_scale_factor
        device = self._execution_device

        # 1. prepare image
        # image_latents, _ = self.encode_image(image, height=height, width=width, dtype=dtype)
        # image_latents, latent_image_ids = self.prepare_latents_inversion(
        #     batch_size, num_channels_latents, height, width, dtype, device, image_latents
        # )

        # 2. prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps)
        extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
            self.scheduler.set_timesteps, {"n_tokens": n_tokens}
        )
        timesteps, num_inversion_steps = retrieve_timesteps(
            self.scheduler,
            num_inversion_steps,
            device,
            timesteps,
            sigmas=None,
            **extra_set_timesteps_kwargs,
        )
        timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength)

        # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt
        Y_t = latents
        y_1 = torch.randn_like(Y_t)
        N = len(sigmas)

        # forward ODE loop
        with self.progress_bar(total=N - 1) as progress_bar:
            for i in range(N - 1):
                t_i = torch.tensor(i / N, dtype=Y_t.dtype, device=device)
                timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size)

                # get the unconditional vector field
                with torch.autocast(device_type="cuda", dtype=Y_t.dtype):
                    u_t_i = self.transformer(  # For an input image (129, 192, 336) (1, 256, 256)
                        Y_t,  # [2, 16, 33, 24, 42]
                        timestep,  # [2]
                        text_states=prompt_embeds,  # [2, 256, 4096]
                        text_mask=prompt_mask,  # [2, 256]
                        text_states_2=prompt_embeds_2,  # [2, 768]
                        freqs_cos=freqs_cis[0],  # [seqlen, head_dim]
                        freqs_sin=freqs_cis[1],  # [seqlen, head_dim]
                        guidance=guidance_expand,
                        return_dict=True,
                    )["x"]

                # get the conditional vector field
                u_t_i_cond = (y_1 - Y_t) / (1 - t_i)

                # controlled vector field
                # Eq 8 dY_t = [u_t(Y_t) + γ(u_t(Y_t|y_1) - u_t(Y_t))]dt
                u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i)
                Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1])
                progress_bar.update()

        # return the inverted latents (the start point for the denoising loop) and the input latents
        return Y_t, latents
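Two details are worth flagging. First, `gamma` interpolates between the unconditional field (`gamma = 0`, a plain rectified-flow inversion) and the straight-line field toward the random endpoint `y_1` (`gamma = 1`); the `gamma = 0.01` used in the call above therefore keeps the forward ODE very close to a faithful inversion of the source video. Second, `invert` feeds the transformer a `t_i` in [0, 1), while the reverse loop below passes scheduler timesteps on the 0-1000 scale (`t_expand`); whether these two conventions are consistent depends on how the transformer wrapper normalizes its timestep input, so if your HunyuanVideo transformer expects the 0-1000 convention, `t_i` would need to be scaled by 1000 here.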

Step 2: Reverse

Step 2 takes the latents produced by the Forward pass and runs the Reverse (denoising) operation on them:

    # if is_progress_bar:
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if do_rf_inversion:
                # ti (current timestep) as annotated in algorithm 2 - i/num_inference_steps.
                t_i = 1 - t / 1000
                dt = torch.tensor(1 / (len(timesteps) - 1), device=device)

            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            t_expand = t.repeat(latent_model_input.shape[0])
            guidance_expand = (
                torch.tensor(
                    [embedded_guidance_scale] * latent_model_input.shape[0],
                    dtype=torch.float32,
                    device=device,
                ).to(target_dtype)
                * 1000.0
                if embedded_guidance_scale is not None
                else None
            )

            # predict the noise residual
            with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
                noise_pred = self.transformer(  # For an input image (129, 192, 336) (1, 256, 256)
                    latent_model_input,  # [2, 16, 33, 24, 42]
                    t_expand,  # [2]
                    text_states=prompt_embeds,  # [2, 256, 4096]
                    text_mask=prompt_mask,  # [2, 256]
                    text_states_2=prompt_embeds_2,  # [2, 768]
                    freqs_cos=freqs_cis[0],  # [seqlen, head_dim]
                    freqs_sin=freqs_cis[1],  # [seqlen, head_dim]
                    guidance=guidance_expand,
                    return_dict=True,
                )["x"]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(
                    noise_pred,
                    noise_pred_text,
                    guidance_rescale=self.guidance_rescale,
                )

            if do_rf_inversion:
                v_t = -noise_pred
                v_t_cond = (y_0 - latents) / (1 - t_i)
                eta_t = eta if start_timestep <= i < stop_timestep else 0.0
                if decay_eta:
                    eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power  # Decay eta over the loop
                v_hat_t = v_t + eta_t * (v_t_cond - v_t)

                # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792
                latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
            else:
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop(
                    "negative_prompt_embeds", negative_prompt_embeds
                )

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                if progress_bar is not None:
                    progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)
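In the `do_rf_inversion` branch, `y_0` is the clean source latents, `eta` sets the controller strength, and `start_timestep`/`stop_timestep` gate the step indices on which the controller is active; a larger `eta` and a wider window keep the result closer to the source video, while smaller values give the edit prompt more freedom. A hedged end-to-end sketch (the pipeline call signature and the parameter values are assumptions for illustration, following the variable names used in the snippets above):

    # 1) Algorithm 1: invert the VAE-encoded source video
    inverted_latents, _ = pipe.invert(
        latents=source_latents,
        prompt_embeds=prompt_embeds,
        prompt_embeds_2=prompt_embeds_2,
        guidance_expand=guidance_expand,
        freqs_cis=freqs_cis,
        prompt_mask=prompt_mask,
        num_inversion_steps=30,
        strength=1.0,
        gamma=0.01,
    )

    # 2) Algorithm 2: denoise from the inverted latents under the edit prompt,
    #    pulling toward y_0 (the clean source latents) for the first few steps
    video = pipe(
        prompt="a cat wearing sunglasses",  # edit prompt
        latents=inverted_latents,
        do_rf_inversion=True,
        y_0=source_latents,
        eta=0.9,            # controller strength: higher stays closer to the source
        start_timestep=0,   # controller active from step 0 ...
        stop_timestep=6,    # ... up to (but not including) step 6
        num_inference_steps=30,
    )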
