AttentionPairBias
AttentionPairBias is an attention module in AlphaFold3 that performs full self-attention over tokens while adding a bias derived from the pair representation (pair bias). It plays an important role in the AlphaFold3 architecture, particularly in tasks that involve protein sequences and spatial symmetry.
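Before looking at the full implementation, the sketch below shows the core computation in plain PyTorch: per-head attention logits are the usual scaled dot products plus a per-head bias projected from the pair representation, masked key positions are pushed toward a large negative value before the softmax, and the output is gated elementwise with a sigmoid. This is only an illustrative outline under assumed names and sizes (NaivePairBiasAttention, c_s=64, c_z=16); the real module shown afterwards additionally uses AdaLN conditioning of the single representation, a residual-aware output initialization, and an extra N_seq (samples-per-trunk) dimension.

import math
from typing import Optional

import torch
import torch.nn as nn


class NaivePairBiasAttention(nn.Module):
    """Illustrative pair-biased self-attention; not the repository's implementation."""

    def __init__(self, c_s: int = 64, c_z: int = 16, no_heads: int = 8, inf: float = 1e8):
        super().__init__()
        assert c_s % no_heads == 0
        self.no_heads = no_heads
        self.c_hidden = c_s // no_heads
        self.inf = inf
        self.to_qkv = nn.Linear(c_s, 3 * c_s, bias=False)
        # Pair representation -> one scalar bias per head and per (i, j) token pair
        self.to_pair_bias = nn.Sequential(nn.LayerNorm(c_z), nn.Linear(c_z, no_heads, bias=False))
        self.to_gate = nn.Linear(c_s, c_s)  # sigmoid gate on the attention output
        self.to_out = nn.Linear(c_s, c_s)

    def forward(self, s: torch.Tensor, z: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # s: (bs, n, c_s) single repr, z: (bs, n, n, c_z) pair repr, mask: (bs, n) with 1 = keep
        bs, n, _ = s.shape
        q, k, v = self.to_qkv(s).chunk(3, dim=-1)
        # Reshape each to (bs, no_heads, n, c_hidden)
        q, k, v = (t.view(bs, n, self.no_heads, self.c_hidden).transpose(1, 2) for t in (q, k, v))

        # Scaled dot-product logits plus the per-head pair bias: (bs, no_heads, n, n)
        logits = q @ k.transpose(-1, -2) / math.sqrt(self.c_hidden)
        logits = logits + self.to_pair_bias(z).permute(0, 3, 1, 2)

        if mask is not None:
            # 0 for kept keys, roughly -inf for masked keys (same trick as _prep_biases below)
            mask = mask.to(logits.dtype)
            logits = logits + (self.inf * (mask - 1))[:, None, None, :]

        out = torch.softmax(logits, dim=-1) @ v  # (bs, no_heads, n, c_hidden)
        out = out.transpose(1, 2).reshape(bs, n, -1)  # (bs, n, c_s)
        return self.to_out(torch.sigmoid(self.to_gate(s)) * out)  # gated, projected output


# Example usage (shapes only):
# s = torch.randn(2, 10, 64); z = torch.randn(2, 10, 10, 16); mask = torch.ones(2, 10)
# out = NaivePairBiasAttention()(s, z, mask)  # out: (2, 10, 64)

The additive mask trick and the LayerNorm plus per-head linear projection of the pair representation correspond directly to the self.inf * (mask - 1) expression and the proj_pair_bias block in the source below.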
Source code (note that AdaLN, Attention, Linear, LinearNoBias and LayerNorm are the repository's own building blocks, not classes from torch.nn; the listing breaks off at the start of forward):
class AttentionPairBias(nn.Module):
    """Full self-attention with pair bias."""

    def __init__(
            self,
            dim: int,
            c_pair: int = 16,
            no_heads: int = 8,
            dropout: float = 0.0,
            input_gating: bool = True,
            residual: bool = True,
            inf: float = 1e8,
    ):
        """Initialize the AttentionPairBias module.

        Args:
            dim:
                Total dimension of the model.
            c_pair:
                The number of channels for the pair representation. Defaults to 16.
            no_heads:
                Number of parallel attention heads. Note that c_atom will be split across no_heads
                (i.e. each head will have dimension c_atom // no_heads).
            dropout:
                Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
            residual:
                Whether the module is used as a residual block. Default: True. This affects the
                initialization of the final projection layer of the MHA attention.
            input_gating:
                Whether the single representation should be gated with another single-like
                representation using adaptive layer normalization. Default: True.
        """
        super().__init__()
        self.dim = dim
        self.c_pair = c_pair
        self.num_heads = no_heads
        self.dropout = dropout
        self.input_gating = input_gating
        self.inf = inf

        # Perform check for dimensionality
        assert dim % no_heads == 0, f"the model dimensionality ({dim}) should be divisible by the " \
                                    f"number of heads ({no_heads}) "

        # Projections
        self.input_proj = None
        self.output_proj_linear = None
        if input_gating:
            self.input_proj = AdaLN(dim)

            # Output projection from AdaLN
            self.output_proj_linear = Linear(dim, dim, init='gating')
            self.output_proj_linear.bias = nn.Parameter(torch.ones(dim) * -2.0)  # gate values will be ~0.11
        else:
            self.input_proj = LayerNorm(dim)

        # Attention
        self.attention = Attention(
            c_q=dim,
            c_k=dim,
            c_v=dim,
            c_hidden=dim // no_heads,
            no_heads=no_heads,
            gating=True,
            residual=residual,
            proj_q_w_bias=True,
        )

        # Pair bias
        self.proj_pair_bias = nn.Sequential(
            LayerNorm(self.c_pair),
            LinearNoBias(self.c_pair, self.num_heads, init='normal')
        )

    def _prep_biases(
            self,
            single_repr: torch.Tensor,  # (*, S, N, c_s)
            pair_repr: torch.Tensor,  # (*, N, N, c_z)
            mask: Optional[torch.Tensor] = None,  # (*, N)
    ):
        """Prepares the mask and pair biases in the shapes expected by the DS4Science attention.

        Expected shapes for the DS4Science kernel:
        # Q, K, V: [Batch, N_seq, N_res, Head, Dim]
        # res_mask: [Batch, N_seq, 1, 1, N_res]
        # pair_bias: [Batch, 1, Head, N_res, N_res]
        """
        # Compute the single mask
        n_seq, n_res, _ = single_repr.shape[-3:]
        if mask is None:
            # [*, N_seq, N_res]
            mask = single_repr.new_ones(
                single_repr.shape[:-3] + (n_seq, n_res),
            )
        else:
            # Expand mask by N_seq (or samples per trunk)
            new_shape = (mask.shape[:-1] + (n_seq, n_res))  # (*, N_seq, N_res)
            mask = mask.unsqueeze(-2).expand(new_shape)
            mask = mask.to(single_repr.dtype)

        # [*, N_seq, 1, 1, N_res]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # Project pair biases per head from pair representation
        pair_bias = self.proj_pair_bias(pair_repr)  # (bs, n_tokens, n_tokens, n_heads)
        pair_bias = rearrange(pair_bias, 'b i j h -> b h i j')  # (bs, h, n, n)
        pair_bias = pair_bias.unsqueeze(-4)
        return mask_bias, pair_bias

    def forward(
            self,
            singl