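The distogram loss in AlphaFold3 is used during training to measure the difference between the predicted distance distribution (represented by logits) and the true distance distribution. In protein structure prediction, a distogram gives, for each pair of residues, the probability that the distance between them falls into each distance interval (bin). The loss uses cross-entropy to measure the gap between the predicted distribution and the true distribution, where the true distribution is determined by computing the Euclidean distance between residues.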
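To make the binning and cross-entropy steps concrete, here is a minimal, dependency-free sketch in plain Python. It is not the implementation discussed in this post: the helper names `bin_index` and `cross_entropy`, the two example coordinates, and the rule that a distance's bin is the number of boundaries it exceeds are assumptions made for illustration. The bin parameters reuse the values `min_bin=0.0`, `max_bin=32.0`, `no_bins=64`.

```python
import math


def bin_index(distance, min_bin=0.0, max_bin=32.0, no_bins=64):
    """Return the index of the distance bin containing `distance`.

    Assumption: there are no_bins - 1 evenly spaced boundaries between
    min_bin and max_bin, and the bin index is the number of boundaries
    the distance exceeds. Distances beyond max_bin land in the last bin.
    """
    step = (max_bin - min_bin) / (no_bins - 2)
    boundaries = [min_bin + i * step for i in range(no_bins - 1)]
    return sum(distance > b for b in boundaries)


def cross_entropy(logits, target_index):
    """Softmax cross-entropy of `logits` against a one-hot target bin."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_norm = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_norm - logits[target_index]


# Two hypothetical residue positions 3.8 angstroms apart along the x axis.
pos_a = (0.0, 0.0, 0.0)
pos_b = (3.8, 0.0, 0.0)
target = bin_index(math.dist(pos_a, pos_b))

# An uninformed prediction: equal logits for every bin.
loss_uniform = cross_entropy([0.0] * 64, target)

# A confident prediction: most probability mass on the true bin.
confident_logits = [0.0] * 64
confident_logits[target] = 10.0
loss_confident = cross_entropy(confident_logits, target)

print(target, loss_uniform, loss_confident)
```

With equal logits the loss equals log(64), the cost of knowing nothing about the distance. It shrinks toward zero as the predicted distribution concentrates on the true bin, which is the behavior training rewards.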
Source code:
def softmax_cross_entropy(logits, labels):
    loss = -1 * torch.sum(
        labels * F.log_softmax(logits, dim=-1),
        dim=-1,
    )
    return loss


def distogram_loss(
    logits: Tensor,  # (bs, n_tokens, n_tokens, n_bins)
    all_atom_positions,  # (bs, n_tokens * 4, 3)
    token_mask,  # (bs, n_tokens)
    min_bin: float = 0.0,
    max_bin: float = 32.0,
    no_bins: int = 64,
    eps: float = 1e-6,
    **kwargs,
) -> Tensor:  # (bs,)
    # TODO: this is an inelegant implementation, integrate with the data pipeline
    batch_size, n_tokens = token_mask.shape

    # Compute pseudo beta and mask
    all_atom_positions = all_atom_positions.reshape(batch_size, n_tokens, 4, 3)
    ca_pos = residue_constants.atom_order["CA"]
    pseudo_beta = all_atom_positions[..., ca_pos, :]  # (bs, n_tokens, 3)
    pseudo_beta_mask = token_mask  # (bs, n_tokens)

    boundaries = torch.linspace(
        min_bin,
        max_bin,
        no_bins - 1,
        device=logits.device,
    )
    boundaries = boundaries ** 2
    dists = torch.sum(
        (pseudo_beta[..., :, None, :] - pseudo_beta[..., None, :, :]) ** 2,
        dim=-1,
        keepdim=True,
    )
    true_bi