Diffusion Models Code Walkthrough: From Basics to Practice
Preface: A previous post, "Paper and Code Walkthrough: RF-Inversion for Image/Video Editing", covered the RF-Inversion technique, but the original code targets FLUX and SD3. This post explains how to implement RF-Inversion on Hunyuan Video.
Contents
How It Works
Step 1: Forward
Step 2: Reverse
How It Works
For the theory, please refer to the previous post; the main thing to understand is the paper's two algorithm figures (Algorithm 1: Controlled Forward ODE, and Algorithm 2: the controlled reverse dynamics).
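In place of the figures, the two controlled vector fields they describe can be written out directly (my transcription from the paper, arXiv:2410.10792, and the code comments below; $y_1$ is the sampled noise latent, $y_0$ the reference latent, and $\gamma$, $\eta$ the controller strengths):

Algorithm 1, controlled forward ODE (Eq. 8):

$$\mathrm{d}Y_t = \big[\,u_t(Y_t) + \gamma\,\big(u_t(Y_t \mid y_1) - u_t(Y_t)\big)\big]\,\mathrm{d}t, \qquad u_t(Y_t \mid y_1) = \frac{y_1 - Y_t}{1 - t}$$

Algorithm 2, controlled reverse dynamics (Eq. 17 in the paper):

$$\mathrm{d}X_t = \big[\,v_t(X_t) + \eta\,\big(v_t(X_t \mid y_0) - v_t(X_t)\big)\big]\,\mathrm{d}t, \qquad v_t(X_t \mid y_0) = \frac{y_0 - X_t}{1 - t}$$

Both loops below discretize these with plain Euler steps over the sigma schedule.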
Step 1: Forward
The first step inverts the latents of the original video back toward noise, which is why it is called the Forward pass:
```python
latents, _ = self.invert(
    latents=latents.to(device),
    prompt_embeds=prompt_embeds.to(device),
    prompt_embeds_2=prompt_embeds_2.to(device),
    guidance_expand=guidance_expand.to(device),
    freqs_cis=freqs_cis,
    prompt_mask=prompt_mask,
    n_tokens=n_tokens,
    num_inversion_steps=30,
    strength=1.0,
    gamma=0.01,
)
```
This is Algorithm 1 of the paper. Note that gamma is passed as 0.01 here, much smaller than the signature's default of 0.5, so the inversion stays close to the unconditional forward ODE:
```python
@torch.no_grad()
def invert(
    self,
    latents,
    prompt_embeds,
    prompt_embeds_2,
    guidance_expand,
    freqs_cis,
    prompt_mask,
    n_tokens=None,
    num_inversion_steps: int = 28,
    strength: float = 1.0,
    gamma: float = 0.5,
    height: Optional[int] = None,
    width: Optional[int] = None,
    timesteps: List[int] = None,
    dtype: Optional[torch.dtype] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""Performs Algorithm 1: Controlled Forward ODE from https://arxiv.org/pdf/2410.10792"""
    dtype = dtype or self.text_encoder.dtype
    batch_size = 1
    device = self._execution_device
    # (The image-encoding / latent-preparation steps of the original FLUX pipeline are
    # omitted here: the video latents are passed in directly.)

    # 1. prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inversion_steps, num_inversion_steps)
    extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
        self.scheduler.set_timesteps, {"n_tokens": n_tokens}
    )
    timesteps, num_inversion_steps = retrieve_timesteps(
        self.scheduler,
        num_inversion_steps,
        device,
        timesteps,
        sigmas=None,
        **extra_set_timesteps_kwargs,
    )
    timesteps, sigmas, num_inversion_steps = self.get_timesteps(num_inversion_steps, strength)

    # Eq. 8: dY_t = [u_t(Y_t) + gamma * (u_t(Y_t | y_1) - u_t(Y_t))] dt
    Y_t = latents
    y_1 = torch.randn_like(Y_t)
    N = len(sigmas)

    # forward ODE loop
    with self.progress_bar(total=N - 1) as progress_bar:
        for i in range(N - 1):
            t_i = torch.tensor(i / N, dtype=Y_t.dtype, device=device)
            timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size)

            # get the unconditional vector field
            with torch.autocast(device_type="cuda", dtype=Y_t.dtype):
                u_t_i = self.transformer(  # for an input video of (129, 192, 336) / image of (1, 256, 256)
                    Y_t,                            # [2, 16, 33, 24, 42]
                    timestep,                       # [2]
                    text_states=prompt_embeds,      # [2, 256, 4096]
                    text_mask=prompt_mask,          # [2, 256]
                    text_states_2=prompt_embeds_2,  # [2, 768]
                    freqs_cos=freqs_cis[0],         # [seqlen, head_dim]
                    freqs_sin=freqs_cis[1],         # [seqlen, head_dim]
                    guidance=guidance_expand,
                    return_dict=True,
                )["x"]

            # get the conditional vector field pointing toward y_1
            u_t_i_cond = (y_1 - Y_t) / (1 - t_i)

            # controlled vector field (Eq. 8), then one Euler step over the sigma schedule
            u_hat_t_i = u_t_i + gamma * (u_t_i_cond - u_t_i)
            Y_t = Y_t + u_hat_t_i * (sigmas[i] - sigmas[i + 1])
            progress_bar.update()

    # return the inverted latents (start point for the denoising loop) and the input latents
    return Y_t, latents
```
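To build intuition for the controller term, here is a minimal toy sketch (my own illustration, not pipeline code: the transformer's velocity field is replaced with zeros, and the name `controlled_forward` is made up):

```python
import torch

def controlled_forward(y_0: torch.Tensor, y_1: torch.Tensor, gamma: float, steps: int = 100):
    # same schedule shape as the pipeline: sigmas decrease from 1.0 to 1/steps
    sigmas = torch.linspace(1.0, 1.0 / steps, steps)
    Y_t = y_0.clone()
    for i in range(steps - 1):
        t_i = i / steps
        u_t = torch.zeros_like(Y_t)             # stand-in for the transformer's velocity
        u_t_cond = (y_1 - Y_t) / (1 - t_i)      # conditional field toward y_1
        u_hat = u_t + gamma * (u_t_cond - u_t)  # Eq. 8
        Y_t = Y_t + u_hat * (sigmas[i] - sigmas[i + 1])
    return Y_t

y_0, y_1 = torch.zeros(4), torch.ones(4)
print(controlled_forward(y_0, y_1, gamma=1.0))  # ~= y_1 (residual error shrinks like 1/steps)
print(controlled_forward(y_0, y_1, gamma=0.0))  # stays at y_0: the dummy field is zero
```

With gamma = 1 the Euler steps transport Y_t essentially onto y_1, so gamma interpolates between the unconditional inversion (gamma = 0) and an exact mapping to the sampled noise (gamma = 1).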
Step 2: Reverse
The second step takes the inverted latents produced by the Forward pass and runs the controlled Reverse (denoising) loop on them:
```python
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        if do_rf_inversion:
            # t_i (current timestep) as annotated in Algorithm 2, roughly i / num_inference_steps
            t_i = 1 - t / 1000
            dt = torch.tensor(1 / (len(timesteps) - 1), device=device)

        if self.interrupt:
            continue

        # expand the latents if we are doing classifier-free guidance
        latent_model_input = (
            torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
        )
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        t_expand = t.repeat(latent_model_input.shape[0])
        guidance_expand = (
            torch.tensor(
                [embedded_guidance_scale] * latent_model_input.shape[0],
                dtype=torch.float32,
                device=device,
            ).to(target_dtype)
            * 1000.0
            if embedded_guidance_scale is not None
            else None
        )

        # predict the noise residual
        with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
            noise_pred = self.transformer(  # for an input video of (129, 192, 336) / image of (1, 256, 256)
                latent_model_input,             # [2, 16, 33, 24, 42]
                t_expand,                       # [2]
                text_states=prompt_embeds,      # [2, 256, 4096]
                text_mask=prompt_mask,          # [2, 256]
                text_states_2=prompt_embeds_2,  # [2, 768]
                freqs_cos=freqs_cis[0],         # [seqlen, head_dim]
                freqs_sin=freqs_cis[1],         # [seqlen, head_dim]
                guidance=guidance_expand,
                return_dict=True,
            )["x"]

        # perform classifier-free guidance
        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
            # based on Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(
                noise_pred,
                noise_pred_text,
                guidance_rescale=self.guidance_rescale,
            )

        if do_rf_inversion:
            v_t = -noise_pred
            v_t_cond = (y_0 - latents) / (1 - t_i)
            eta_t = eta if start_timestep <= i < stop_timestep else 0.0
            if decay_eta:
                eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power  # decay eta over the loop
            v_hat_t = v_t + eta_t * (v_t_cond - v_t)

            # controlled reverse update, Eq. 17 from https://arxiv.org/pdf/2410.10792
            latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1])
        else:
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs, return_dict=False
            )[0]

        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]
            callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            negative_prompt_embeds = callback_outputs.pop(
                "negative_prompt_embeds", negative_prompt_embeds
            )

        # call the callback, if provided
        if i == len(timesteps) - 1 or (
            (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
        ):
            if progress_bar is not None:
                progress_bar.update()
            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)
```
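One knob worth isolating is the eta schedule: the controller is only active for step indices in [start_timestep, stop_timestep), and can optionally decay inside that window. A minimal self-contained sketch (the helper name and argument values are mine, purely for illustration; the logic mirrors the loop above):

```python
# Illustrative helper (hypothetical name/values) reproducing the eta logic from
# the reverse loop: active only in [start_timestep, stop_timestep), optionally
# decayed by (1 - i / num_inference_steps) ** eta_decay_power.
def eta_schedule(num_inference_steps, eta=0.9, start_timestep=0,
                 stop_timestep=6, decay_eta=True, eta_decay_power=1.0):
    etas = []
    for i in range(num_inference_steps):
        eta_t = eta if start_timestep <= i < stop_timestep else 0.0
        if decay_eta:
            eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power
        etas.append(eta_t)
    return etas

print(eta_schedule(num_inference_steps=10))
# ~= [0.9, 0.81, 0.72, 0.63, 0.54, 0.45, 0.0, 0.0, 0.0, 0.0]
```

Larger eta pulls the sample harder toward the reference y_0 (stronger faithfulness to the source video), while a smaller or earlier-stopped eta leaves more room for the edit prompt.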