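The MSAModule class in AlphaFold3 processes the multiple sequence alignment (MSA). It stacks several MSAModuleBlocks to update the MSA representation and the pair representation together, and it uses gradient checkpointing to keep memory use manageable during training. Calling the module returns the updated pair representation z, which now carries information from the MSA features as well as from the target protein sequence.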
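In AlphaFold-style MSA modules, one well-known way MSA information reaches z is the outer product mean: each MSA column is projected twice, the outer products of those projections are averaged over the sequence axis, and the result is added to z. Below is a minimal PyTorch sketch of that operation. It assumes an unbatched MSA tensor of shape [n_seq, n_token, c_msa] and uses no masking. The class name OuterProductMeanSketch and the layer sizes are hypothetical, and the real MSAModuleBlock may differ in details such as masking, batching, and normalization.

```python
import torch
from torch import nn


class OuterProductMeanSketch(nn.Module):
    """Hypothetical, simplified outer product mean: MSA m -> update for pair z."""

    def __init__(self, c_msa: int = 64, c_z: int = 128, c_hidden: int = 32):
        super().__init__()
        self.norm = nn.LayerNorm(c_msa)
        self.proj_a = nn.Linear(c_msa, c_hidden)
        self.proj_b = nn.Linear(c_msa, c_hidden)
        self.proj_out = nn.Linear(c_hidden * c_hidden, c_z)

    def forward(self, m: torch.Tensor) -> torch.Tensor:
        # m: [n_seq, n_token, c_msa]
        m = self.norm(m)
        a = self.proj_a(m)  # [n_seq, n_token, c_hidden]
        b = self.proj_b(m)  # [n_seq, n_token, c_hidden]
        # Average outer products over sequences s:
        # o[i, j, c, d] = mean_s a[s, i, c] * b[s, j, d]
        o = torch.einsum("sic,sjd->ijcd", a, b) / m.shape[0]
        o = o.reshape(o.shape[0], o.shape[1], -1)  # [n_token, n_token, c_hidden**2]
        return self.proj_out(o)  # [n_token, n_token, c_z]


torch.manual_seed(0)
m = torch.randn(16, 10, 64)          # 16 aligned sequences, 10 tokens
z = torch.zeros(10, 10, 128)         # pair representation
z = z + OuterProductMeanSketch()(m)  # residual update: z now depends on the MSA
print(z.shape)  # torch.Size([10, 10, 128])
```

Because every entry z[i, j] receives a summary of how columns i and j co-vary across the aligned sequences, the pair representation picks up co-evolution signal from the MSA.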
Source code:
from typing import Optional

from torch import nn

# MSAModuleBlock is defined elsewhere in the same codebase (its import is not shown in this excerpt).


class MSAModule(nn.Module):
    def __init__(
        self,
        no_blocks: int = 4,
        c_msa: int = 64,
        c_token: int = 384,
        c_z: int = 128,
        c_hidden: int = 32,
        no_heads: int = 8,
        c_hidden_tri_mul: int = 128,
        c_hidden_pair_attn: int = 32,
        no_heads_tri_attn: int = 4,
        transition_n: int = 4,
        pair_dropout: float = 0.25,
        fuse_projection_weights: bool = False,
        clear_cache_between_blocks: bool = False,
        blocks_per_ckpt: Optional[int] = 1,
        inf: float = 1e8,
    ):
        """Initialize the MSA module.

        Args:
            no_blocks:
                number of MSAModuleBlocks
            c_msa:
                MSA representation dim
            c_token:
                Single representation dim
            c_z:
                pair representation dim
            c_hidden:
                hidden representation dim
            no_heads:
                number of heads in the pair averaging
            c_hidden_tri_mul:
                hidden dimensionality of triangular multiplicative updates
            c_hidden_pair_attn:
                hidden dimensionality of triangular attention
            no_heads_tri_attn:
                number of heads in triangular attention
            transition_n:
                multiplication factor for the hidden dim during the transition
            pair_dropout:
                dropout rate within the pair stack
            fuse_projection_weights:
                whether to use FusedTriangleMultiplicativeUpdate or not
            blocks_per_ckpt:
                Number of blocks per checkpoint. If None, no checkpointing is used.
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(MSAModule, self).__init__()
        # Stack of no_blocks identically configured MSAModuleBlocks
        self.blocks = nn.ModuleList([
            MSAModuleBlock(
                c_msa=c_msa,
                c_z=c_z,
                c_hidden=c_hidden,
                no_heads=no_heads,
                c_hidden_tri_mul=c_hidden_tri_mul,
                c_hidden_pair_attn=c_hidden_pair_attn,
                no_heads_tri_attn=no_heads_tri_attn,
                transition_n=transition_n,
                pair_dropout=pair_dropout,
                fuse_projection_weights=fuse_projection_weights,
                inf=inf,
            )
            for _ in range(no_blocks)
        ])
        # The remainder of the class (including forward()) is not part of this excerpt.
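The excerpt stops after self.blocks is built, so the forward pass is not shown. As a rough illustration of how blocks_per_ckpt and clear_cache_between_blocks could be used, the sketch below runs the blocks in groups and wraps each group in torch.utils.checkpoint.checkpoint, so activations inside a group are recomputed during the backward pass instead of being stored. ToyBlock and run_blocks are hypothetical stand-ins (ToyBlock is not MSAModuleBlock), and the actual implementation may organize this differently.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for MSAModuleBlock: updates m and z, returns both."""

    def __init__(self, c_msa: int, c_z: int):
        super().__init__()
        self.msa_update = nn.Linear(c_msa, c_msa)
        self.pair_update = nn.Linear(c_msa, c_z)

    def forward(self, m: torch.Tensor, z: torch.Tensor):
        # m: [n_seq, n_token, c_msa], z: [n_token, n_token, c_z]
        m = m + torch.tanh(self.msa_update(m))
        pooled = self.pair_update(m.mean(dim=0))  # [n_token, c_z]
        z = z + pooled[:, None, :] + pooled[None, :, :]
        return m, z


def run_blocks(blocks, m, z, blocks_per_ckpt=1, clear_cache_between_blocks=False):
    """Apply blocks in order, checkpointing groups of blocks_per_ckpt blocks."""

    def make_chunk(chunk):
        def run_chunk(m, z):
            for block in chunk:
                m, z = block(m, z)
                if clear_cache_between_blocks and torch.cuda.is_available():
                    torch.cuda.empty_cache()  # trade speed for less fragmentation
            return m, z
        return run_chunk

    # No checkpointing requested, or no backward pass will happen (e.g. inference)
    if blocks_per_ckpt is None or not torch.is_grad_enabled():
        m, z = make_chunk(blocks)(m, z)
        return z

    for start in range(0, len(blocks), blocks_per_ckpt):
        chunk_fn = make_chunk(blocks[start:start + blocks_per_ckpt])
        # Only the chunk inputs are kept; inner activations are recomputed in backward
        m, z = checkpoint(chunk_fn, m, z, use_reentrant=False)
    return z  # like MSAModule, return only the updated pair representation


torch.manual_seed(0)
blocks = nn.ModuleList([ToyBlock(c_msa=8, c_z=4) for _ in range(4)])
m = torch.randn(5, 6, 8)
z = torch.zeros(6, 6, 4, requires_grad=True)
z_out = run_blocks(blocks, m, z, blocks_per_ckpt=2)
z_out.sum().backward()
print(z_out.shape)  # torch.Size([6, 6, 4])
```

Setting blocks_per_ckpt to None skips checkpointing entirely. Larger groups store fewer checkpoint boundaries but recompute more blocks at once during backward, so the group size trades peak memory against recomputation.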