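The ConditionedTransitionBlock class applies a nonlinear transformation to the input features and uses the conditioning input (s) to adaptively adjust the feature representation.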
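Its main modules are:
- Adaptive layer normalization (ada_ln): normalizes the input features and dynamically adjusts their distribution based on the conditioning input s.
- Two hidden linear transformations (hidden_gating_linear and hidden_linear): each expands the channels by a factor of n; the SiLU-activated output of the first is multiplied element-wise by the output of the second, which gives the SwiGLU nonlinearity.
- Output linear transformations (output_linear and output_gating_linear): output_linear projects the hidden features back to input_dim, and output_gating_linear computes a sigmoid gate from s that scales the final output (the gating mechanism).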
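Read together with the forward method in the source code, the modules listed above compose as follows. This is only a compact restatement of that code; the symbols \tilde{a} and a_{\text{out}} are labels introduced here for illustration and do not appear in the source, and \odot denotes element-wise multiplication.

```latex
\begin{aligned}
\tilde{a} &= \mathrm{AdaLN}(a, s) \\
b &= \mathrm{SiLU}\big(\mathrm{hidden\_gating\_linear}(\tilde{a})\big) \odot \mathrm{hidden\_linear}(\tilde{a}) \\
a_{\text{out}} &= \sigma\big(\mathrm{output\_gating\_linear}(s)\big) \odot \mathrm{output\_linear}(b)
\end{aligned}
\qquad \text{where } \mathrm{SiLU}(x) = x\,\sigma(x), \quad \sigma(x) = \frac{1}{1 + e^{-x}}
```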
Source code:
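import torch
import torch.nn as nn
import torch.nn.functional as F

# AdaLN, Linear, and LinearNoBias are project-specific modules defined elsewhere in the codebase.


class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition block with adaptive layer norm."""

    def __init__(self, input_dim: int, n: int = 2):
        """
        Args:
            input_dim:
                Channels of the input tensor
            n:
                channel expansion factor for hidden dimensions
        """
        super().__init__()
        self.ada_ln = AdaLN(input_dim)
        # Two parallel expansions to n * input_dim channels (gate branch and value branch).
        self.hidden_gating_linear = LinearNoBias(input_dim, n * input_dim, init='relu')
        self.hidden_linear = LinearNoBias(input_dim, n * input_dim, init='default')
        # Projection back to input_dim, plus a gate computed from the conditioning input s.
        self.output_linear = Linear(input_dim * n, input_dim, init='default')
        self.output_gating_linear = Linear(input_dim, input_dim, init='gating')
        # With a bias of -2.0, the gate starts near sigmoid(-2) ~ 0.12 when the gating weights are close to zero.
        self.output_gating_linear.bias = nn.Parameter(torch.ones(input_dim) * -2.0)

    def forward(self, a, s):
        # Normalize a, conditioned on s.
        a = self.ada_ln(a, s)
        # SwiGLU: SiLU-activated gate branch times linear value branch.
        b = F.silu(self.hidden_gating_linear(a)) * self.hidden_linear(a)
        # Output projection (from adaLN-Zero)
        a = torch.sigmoid(self.output_gating_linear(s)) * self.output_linear(b)
        return a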
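To make the two gating steps in forward concrete, here is a minimal standard-library sketch of the element-wise arithmetic only. It leaves out AdaLN and the learned linear layers entirely; the helper names (sigmoid, silu, swiglu, output_gate) and the input values are hypothetical and chosen for illustration. It also assumes the weights of output_gating_linear start at or near zero, which the init='gating' name suggests but the source does not show, so that the pre-activation of the output gate equals the bias of -2.0.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    """Logistic function 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def silu(x: float) -> float:
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x * sigmoid(x)


def swiglu(gate_pre: Sequence[float], value: Sequence[float]) -> List[float]:
    """Element-wise silu(gate) * value.

    Mirrors F.silu(hidden_gating_linear(a)) * hidden_linear(a), with the two
    linear outputs passed in directly as plain lists.
    """
    return [silu(g) * v for g, v in zip(gate_pre, value)]


def output_gate(gate_pre: Sequence[float], projected: Sequence[float]) -> List[float]:
    """Element-wise sigmoid(gate) * projected.

    Mirrors sigmoid(output_gating_linear(s)) * output_linear(b).
    """
    return [sigmoid(g) * p for g, p in zip(gate_pre, projected)]


# Assumed initial state: gating weights near zero, so the gate pre-activation
# is just the bias of -2.0 for every channel.
initial_gate = sigmoid(-2.0)

# Hypothetical hidden pre-activations for two channels.
hidden = swiglu([0.0, 1.0], [3.0, 2.0])

# Hypothetical projected output of 1.0 per channel, scaled by the initial gate.
gated = output_gate([-2.0, -2.0], [1.0, 1.0])
```

Under that assumption, the block's output is scaled to roughly 12% of the projected value at the start of training, so the block initially contributes only a small update and the gate can open as training proceeds.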