从参考数据集到查询数据集的一致性标签转移是单细胞研究的基础。与传统的注释方法相比,基于深度学习的方法速度更快,自动化程度更高。基于自编码器架构的一系列单细胞分析工具已经开发出来,但这些工具难以在深度和可解释性之间取得平衡。作者提出了TOSICA,它可以使用生物学上可理解的实体(如通路pathways,调控regulons)来解释细胞类型的注释。TOSICA实现了快速准确的一站式注释和批次整合,同时为理解发育和疾病进展期间的细胞行为提供了生物学上可解释的见解。通过将 TOSICA 应用于肿瘤浸润性免疫细胞(tumor-infiltrating immune cells)和 COVID-19 中的 CD14+ 单核细胞(CD14+ monocytes)的 scRNA-seq 数据来展示 TOSICA 的优势,以揭示与疾病进展和严重程度相关的稀有细胞类型、异质性和动态轨迹。
来自:Transformer for one stop interpretable cell type annotation
目录
- 背景概述
- 方法
- TOSICA概述
- TOSICA模型
- 细节
- 其他任务
背景概述
scRNA-seq分析的一个重要步骤是通过聚类来识别细胞群体或类型,细胞类型注释可以帮助我们发现细胞异质性。目前已经提出了许多无监督的scRNA-seq聚类方法,随后便是耗时的注释。传统方法通常包括预处理、降维、聚类、差异分析和基于先验知识的手动注释。当基于一小组标记基因手动注释亚型时,由于差异细微,同一亚型可能会被注释为其他类型。此外,当不能同时获得所有样本时,人们希望对第一批次数据上的细胞类型进行分类,并使用模型对将来获得的批次数据进行注释(不需要再次修改模型)。因此,将细胞类型注释从参考数据集转移到新的具有一致性的查询数据集越来越重要。大多数现有的工具虽然可以处理大型数据集,但是它们涉及信息组合和层之间的非线性激活,使最终学习的特征变得抽象,无法追溯到输入特征。
深度学习单细胞分析方法参考:https://github.com/OmicsML/awesome-deep-learning-single-cell-papers
例如,在整个自编码器的深度处理阶段,尺寸的变化和特征的非线性压缩导致了不可解释的潜在空间,以及特征分辨率的损失。然而,Transformer框架不涉及维度减少,从而使所有注意力层都可追溯到原始输入特征,使模型具有可解释性。因此,作者选择Transformer作为框架,在参考数据集和查询数据集之间开发一种新的标签转移工具,命名为Transformer for One Stop Interpredictable Celltype Annotation(TOSICA)。
TOSICA用于可解释细胞类型注释。通过将注意力与先前的生物学知识联系起来,在没有任何批次信息的情况下,TOSICA以批次不敏感的方式可解释地整合批次,并注释单细胞数据,同时保留生物异质性。当在许多数据集上测试时,TOSICA提供了基于注意力的特征基因和途径pathways,额外的,它还自动消除了批次效应,这可能是细胞类型直接映射到基因的结果。
方法
TOSICA概述
TOSICA是一种基于多头自注意力的自动细胞型注释器。通过监督训练,模型学习了从基因表达到细胞类型的映射函数,同时将高维稀疏的基因表达空间转移到低维和密集的特征空间。
TOSICA由三部分组成:细胞嵌入层、多头自注意力层和细胞类型分类器。TOSICA的第一步是细胞embedding,它将基因变换为tokens,其变换矩阵最初是一个完全连接的权重矩阵。但是,变换矩阵随后被基于专家知识的矩阵mask(例如,基因对通路的隶属度),在mask后的转换矩阵中只保留基因和通路之间的稀疏连接,用于训练和学习。因此,一个token只接收来自特定genes的信息,该信息代表一条pathway。该操作被并行地重复m次,并且所有m个token向量被合并在一起。然后,该token矩阵被附加class token(CLS),然后在接下来的网络层期间提取信息,并用于预测细胞类型。接下来,这个新的合并矩阵成为多头自注意力层的输入,其中查询(Q)、键(K)和值(V)通过线性投影得到。由于生物过程是复杂和相互作用的,通路之间存在微妙的关系,这些关系由Q和K计算,称为注意力得分(A)。
CLS和通路token之间的注意力得分意味着后者对细胞类型的分类和鉴定的重要性。输出矩阵(O)是A和V运算的结果,代表每条通路及其相互作用伙伴的综合得分。此时,O中的CLS已经收集了各种通路的信息,然后转换为细胞类型概率的向量。Transformer在可解释性方面取得了成功,得益于自注意机制,该机制计算token之间的关系(称为注意力)。TOSICA计算细胞类型分类token(CLS)和细胞的签名(例如通路token)之间的注意力。此外,CLS和通路token之间的注意力得分,用作细胞的注意力embedding,可以进行各种下游分析。
- 模型架构。该模型是根据scRNA-seq数据和每个细胞的细胞类型标签进行训练的。基于数据库或专家知识,使用带有mask的可学习embedding来将参考输入数据(n个基因,n HVG)转换为表示每个基因集(GS,gene set)的k个输入token,其中添加了class token(CLS)。在注意力函数中,查询(Q)、键(K)和值(V)矩阵是从这些GS和CLS组合的token线性投影得到的,并且权重(注意力,A)是通过Q与相应K的兼容性函数计算的,然后分配给每个V以计算输出(O)。在每个多头自注意层中,注意力函数并行执行H次。O的CLS被认为是每个细胞的潜在空间,被用作细胞类型分类器的输入。同时,CLS对基因集(GS)token的注意力被称为注意力得分,并用于cell embedding。
TOSICA模型
对于每个细胞,nnn个基因的表达量e∈Rne\in R^{n}e∈Rn首先经过变换矩阵WWW编码为kkk个token t∈Rkt\in R^{k}t∈Rk,变换矩阵在训练期间是可学习的。为了实现每个token代表不同的通路pathway,线性变换的权重矩阵被mask,只有属于该通路的基因,才能保存连接。因此,作者利用专家知识生成一个掩码矩阵MMM, MMM由0和1组成,与WWW具有相同的维数。经过mask的变换矩阵W′W'W′是WWW和MMM对应位置的乘积:W′=W∗Mt=W′⋅eW'=W*M\\t=W'\cdot eW′=W∗Mt=W′⋅e然后并行地重复mmm次嵌入操作,以增加嵌入空间的维数,其中mmm是可手动设置的超参数,默认值为48。然后将所有的数据按列连接起来:T=columnbind(t1,t2,...,tm)∈Rk×mT=columnbind(t_{1},t_{2},...,t_{m})\in R^{k\times m}T=columnbind(t1,t2,...,tm)∈Rk×m其中,TTT代表pathway token matrix。TTT中的每一行,即一个token,代表一条pathway。
接下来,一个可学习的parameter class token(CLS)按行排列到TTT的顶部,并生成输入矩阵III:I=rowbind(CLS,T),CLS∈Rm,I∈R(1+k)×mI=rowbind(CLS,T),CLS\in R^{m},I\in R^{(1+k)\times m}I=rowbind(CLS,T),CLS∈Rm,I∈R(1+k)×m注意力函数可以描述为将query和一组key-value pairs映射到输出。在多头注意力层中,query,key,value矩阵分别从III线性投影,三个投影矩阵为Wq,k,vW_{q,k,v}Wq,k,v:Q,K,V=Wq,k,v⋅IQ,K,V∈R(1+k)×mQ,K,V=W_{q,k,v}\cdot I\\Q,K,V\in R^{(1+k)\times m}Q,K,V=Wq,k,v⋅IQ,K,V∈R(1+k)×m注意力矩阵AAA被QQQ和对应的KKK计算:A=softmax(Q⋅KTdK)A=softmax(\frac{Q\cdot K^{T}}{\sqrt{d_{K}}})A=softmax(dKQ⋅KT)其中,dK=md_{K}=mdK=m,以及:softmax(zi)=exp(zi)∑jexp(zj)softmax(z_{i})=\frac{exp(z_{i})}{\sum_{j}exp(z_{j})}softmax(zi)=∑jexp(zj)exp(zi)然后AAA被分配到VVV输出OOO:O=Attn(Q,K,V)=A⋅VO=Attn(Q,K,V)=A\cdot VO=Attn(Q,K,V)=A⋅V上述操作执行HHH次,再拼接:O=WO⋅columnbind(head1,..,headH),O∈R(1+k)×mheadi=Attn(WiQ⋅I,WiK⋅I,WiV⋅I)O=W^{O}\cdot columnbind(head_{1},..,head_{H}),O\in R^{(1+k)\times m}\\head_{i}=Attn(W_{i}^{Q}\cdot I,W_{i}^{K}\cdot I,W_{i}^{V}\cdot I)O=WO⋅columnbind(head1,..,headH),O∈R(1+k)×mheadi=Attn(WiQ⋅I,WiK⋅I,WiV⋅I)用OOO的CLS作为全连通网络的输入,然后用softmax函数得到细胞类型的概率:p=softmax(Wp⋅CLS)p=softmax(W_{p}\cdot CLS)p=softmax(Wp⋅CLS)此外,CLS对pathway的注意力权重被抽象为cell的低维特征。
细节
本工作使用的mask矩阵是基于GSEA的知识数据集(http://www.gsea-msigdb.org/gsea/downloads.jsp),对于MMM,行代表基因,列代表基因集合(或pathway),如果基因iii属于基因集合jjj,则Mi,j=1M_{i,j}=1Mi,j=1。在重复mmm次时,代表要设计mmm个不同的变换矩阵WWW。
对于比较的方法,作者为它们提供了相同的训练(参考)数据集和测试(查询)数据集。它们使用推荐的默认参数运行。
作者使用KL散度衡量参考数据集和查询数据集的不平衡度:DKL=∑i=1nclog2(qi)pi−∑i=1nclog2(pi)piD_{KL}=\sum_{i=1}^{nc}log_{2}(q_{i})p_{i}-\sum_{i=1}^{nc}log_{2}(p_{i})p_{i}DKL=i=1∑nclog2(qi)pi−i=1∑nclog2(pi)pi其中,ncncnc为细胞类型数,pip_{i}pi为训练集中被标记为细胞类型iii的样本数占训练样本总数的比,qiq_{i}qi为测试集中被标记为细胞类型iii的样本数占测试样本总数的比。
注意力矩阵的处理类似于scRNA-seq的预处理,首先,对注意力矩阵规范化,然后,将注意力矩阵作为输入,进行PCA分析(选最大主成分抽象到1维),以及基于PCA结果构建最近邻图进行UMAP可视化。
对于未知细胞的识别,如果最高预测概率值低于95%,则该样本被标记为Unknown。
其他任务
对于批次整合,利用scIB平台比较各个方法,对于现有的方法,scIB中输入full features,对于TOSICA,将注意力嵌入作为scIB平台的输入。
包括轨迹分析,作者使用注意力矩阵作为输入,我们需要知道,CLS对pathway的注意力权重被表示为cell的低维特征。最终每个样本的embedding维度为(1+k)×1(1+k)\times 1(1+k)×1。