QueryEncoding
类用于在输入张量 x
上添加一种查询序列的特殊编码。这里的查询编码将第一个序列标记为查询序列,并将其与其他序列区分开。以下是代码中的细节和每一步的作用。
源码:
class QueryEncoding(nn.Module):def __init__(self, d_model):super(QueryEncoding, self).__init__()self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)def forward(self, x):B, N, L, K = x.shapeidx = torch.ones((B, N, L), device=x.device).long()idx[:,0,:] = 0 # first sequence is the queryx = x + self.pe(idx)return x
代码解读:
class QueryEncoding(nn.Module):def __init__(self, d_model):super(QueryEncoding, self).__init__()self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)def forward(self, x):B, N, L, K = x.shapeidx = torch.ones((B, N, L), device=x.device).long()idx[:,0,:] = 0 # first sequence is the queryx = x + self.pe(idx)return x
参数说明
d_model
:每个嵌入向量的维度,表示在编码中要用的特征数量。x
:输入张量,维度为(B, N, L, K)
,其中:B
表示 batch 大小,N
表示多序列对齐(MSA)中的序列数(其中第一个序列通常是查询序列),L
表示序列长度,K
表示特征维度(应与d_model
相同)。
初始化部分
self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)
nn.Embedding(2, d_model)
创建了一个大小为(2, d_model)
的嵌入矩阵。该矩阵有两行,每一行是d_model
维度的向量:- 行 0:为查询序列准备的编码向量,
- 行 1:为其他序列准备的编码向量。
- 该嵌入矩阵可以根据输入中的序列类型(查询或非查询)来生成不同的嵌入。
forward
方法解析
1. idx
索引张量的创建
B, N, L, K = x.shape
idx = torch.ones((B, N, L), device=x.device).long()
idx[:,0,:] = 0 # first sequence is the query
idx
是一个张量,形状为(B, N, L)
,用于标记序列类型:- 初始时,
idx
中的值全为1
,即所有位置都默认被标记为非查询序列。 idx[:, 0, :] = 0
将第一个序列的所有位置标记为查询序列,所有 batch 中的第一个序列都是查询序列。
- 初始时,
2. 生成嵌入并相加
x = x + self.pe(idx)
self.pe(idx)
会为idx
中的每个值生成对应的嵌入向量。- 如果
idx
的某个位置是0
,则生成查询编码(第 0 行的嵌入向量), - 如果
idx
的某个位置是1
,则生成非查询编码(第 1 行的嵌入向量)。
- 如果
self.pe(idx)
的形状与x
相同,为(B, N, L, d_model)
。x + self.pe(idx)
:将生成的查询或非查询嵌入向量添加到输入张量x
上,增加了序列的查询信息。
3. 返回带查询编码的信息
最终的 x
(维度 (B, N, L, d_model)
)包含了查询序列和其他序列的特殊编码信息。这种编码有助于模型区分查询序列和其他对齐序列,从而根据查询序列的特征信息做出预测或聚合。