On the UNet2DConditionModel
Reading the UNet source code, we find that the types of the UNet's down, mid, and up blocks are:
down_block_types: Tuple[str] = (
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D",
)
mid_block_type: str = "UNetMidBlock2DCrossAttn"
up_block_types: Tuple[str] = (
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
)
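These defaults are easy to verify yourself. A minimal sketch, assuming the runwayml/stable-diffusion-v1-5 checkpoint (any Stable Diffusion v1.x repo with a "unet" subfolder should work the same way):

    from diffusers import UNet2DConditionModel

    # Load only the UNet subfolder of a Stable Diffusion checkpoint
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    print(unet.config.down_block_types)  # ('CrossAttnDownBlock2D', ..., 'DownBlock2D')
    print(unet.config.mid_block_type)    # 'UNetMidBlock2DCrossAttn'
    print(unet.config.up_block_types)    # ('UpBlock2D', 'CrossAttnUpBlock2D', ...)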
Next, let's look at the get_down_block factory that builds the downsampling blocks:
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
):
    down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
    if down_block_type == "DownBlock2D":
        return DownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "ResnetDownsampleBlock2D":
        return ResnetDownsampleBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnDownBlock2D":
        return AttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attn_num_head_channels=attn_num_head_channels,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "SimpleCrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
        return SimpleCrossAttnDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "SkipDownBlock2D":
        return SkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnSkipDownBlock2D":
        return AttnSkipDownBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            downsample_padding=downsample_padding,
            attn_num_head_channels=attn_num_head_channels,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "DownEncoderBlock2D":
        return DownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    elif down_block_type == "AttnDownEncoderBlock2D":
        return AttnDownEncoderBlock2D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            attn_num_head_channels=attn_num_head_channels,
            resnet_time_scale_shift=resnet_time_scale_shift,
        )
    raise ValueError(f"{down_block_type} does not exist.")
Now let's look at this UNet's forward function:
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    Args:
        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

    Returns:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    # By default samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # prepare attention_mask
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    # 4. mid
    sample = self.mid_block(
        sample,
        emb,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
        cross_attention_kwargs=cross_attention_kwargs,
    )

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
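To see forward() in action, here is a minimal smoke test. The checkpoint id and tensor shapes below assume Stable Diffusion v1.x (4-channel 64x64 latents and 77x768 CLIP text embeddings):

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    sample = torch.randn(1, 4, 64, 64)               # noisy latents (batch, channel, h, w)
    timestep = torch.tensor([10])                    # diffusion timestep(s)
    encoder_hidden_states = torch.randn(1, 77, 768)  # text-encoder output

    with torch.no_grad():
        out = unet(sample, timestep, encoder_hidden_states).sample
    print(out.shape)  # torch.Size([1, 4, 64, 64]) -- same spatial size as the input

The UNet predicts noise in the same shape as its latent input, which is why the output shape matches the input shape.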
In other words, the down, mid, and up blocks (more precisely, their cross-attention variants) all receive the text-embedding information encoder_hidden_states as well as the cross-attention controls in cross_attention_kwargs; the plain blocks without cross attention (DownBlock2D, UpBlock2D) only take the hidden states and the time embedding temb.
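One place cross_attention_kwargs shows up in practice is LoRA: with LoRA attention processors loaded, the pipeline forwards a scale through every cross-attention block to weight the LoRA contribution. A hedged sketch, where the checkpoint id and the weights path are assumptions:

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # load LoRA attention processors into the UNet (hypothetical local path)
    pipe.unet.load_attn_procs("path/to/lora_weights")
    # "scale" is passed down to each cross-attention processor via cross_attention_kwargs
    image = pipe("a photo of a cat", cross_attention_kwargs={"scale": 0.7}).images[0]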
For the implementation of each individual block, see the source code.