AF3 Rotation类的map_tensor_fn 方法解读

server/2025/3/30 4:19:06/

AlphaFold3 rigid_utils 模块Rotation类的 map_tensor_fn方法主要作用是对旋转矩阵或四元数上的最后一维应用一个函数 (fn) ,并返回一个新的 Rotation 对象。

源代码:

    def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation:"""Apply a Tensor -> Tensor function to underlying rotation tensors,mapping over the rotation dimension(s). Can be used e.g. to sum outa one-hot batch dimension.Args:fn:A Tensor -> Tensor function to be mapped over the Rotation Returns:The transformed Rotation object""" if(self._rot_mats is not None):rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1)rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))return Rotation(rot_mats=rot_mats, quats=None)elif(self._quats is not None):quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)return Rotation(rot_mats=None, quats=quats, normalize_quats=False)else:raise ValueError("Both rotations are None")

代码解读:

方法签名
def map_tensor_fn(self, fn: Callable[torch.Tensor, torch.Tensor]) -> Rotation:
  • fn:接收一个 Tensor,返回一个 Tensor,典型用途是对旋转的某个维度做变换,比如求和、加权平均等。

  • 返回值:一个新的 Rotation 对象,里面装着变换后的旋转矩阵 (rot_mats) 或四元数 (quats)。

处理旋转矩阵 (_rot_mats)

如果 self._rot_mats 存在,就走这条分支:

if self._rot_mats is not None:# 把 (batch_size, ..., 3, 3) reshape 成 (batch_size, ..., 9)rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))

✅ 解释
view() 是为了把 3x3 的旋转矩阵摊平成 9 维向量,方便对最后一维应用函数。

rot_mats = torch.stack(list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
)

✅ 解释

  1. torch.unbind():沿最后一维解开成 9 个独立的张量。

  2. map(fn, ...):对每个解开的张量应用 fn

  3. torch.stack():把变换后的 9 个张量重新堆叠回去。

注: torch.unbind 维度 -1 ,torch.stack 维度 +1, 并且都处理相同的维度(-1)。

rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
return Rotation(rot_mats=rot_mats, quats=None)

✅ 解释
把 9 维向量重新 reshaped 成 (3, 3) 矩阵,并用它创建一个新的 Rotation 对象。

处理四元数 (_quats)

如果矩阵不存在,走四元数分支:

elif self._quats is not None:quats = torch.stack(list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1)return Rotation(rot_mats=None, quats=quats, normalize_quats=False)

✅ 解释

  • 逻辑和矩阵类似,先 unbind() 分解四元数的最后一维,对每个部分应用 fn(),再 stack() 堆叠回来。

  • 创建新 Rotation 对象时加了 normalize_quats=False,说明这一步不需要再归一化。

 防错处理

如果两个旋转表示都没有,抛出异常:

else:raise ValueError("Both rotations are None")

总结

map_tensor_fn() 是一种 高阶函数,它能灵活地对旋转矩阵或四元数的最后一维执行各种操作(比如求和、加权、归一化、剪裁等)。

核心逻辑:

  • 矩阵路径 → reshape(9维) → 分解 → 应用函数 → 堆叠 → 恢复3x3

  • 四元数路径 → 分解 → 应用函数 → 堆叠


http://www.ppmy.cn/server/179564.html

相关文章

学习爬虫的第二天——分页爬取并存入表中

阅读提示:我现在还在尝试爬静态页面 一、分页爬取模式 以豆瓣Top250为例: 基础url:豆瓣电影 Top 250https://movie.douban.com/top250 分页参数:?start0(第一页)、?start25(第二页)等 每页显示25条数…

ESP32驱动BMP280和MQ4传感器

文章目录 前言 一、硬件准备 所需组件 连接方式: 二、软件实现 1.所需库 2.代码实现 效果演示 三、上传Qt端 前言 在物联网和环境监测应用中,传感器是获取环境数据的关键组件。本文将详细介绍如何使用ESP32微控制器同时驱动BMP280大气压力传感器…

【机器学习】imagenet2012 数据预处理数据预处理

【机器学习】数据预处理 1. 下载/解压数据2. 数据预处理3. 加载以及训练代码3.1 使用PIL等加载代码3.2 使用OpenCV的方式来一张张加载代码3.3 h5的方式来加载大文件 最后总结 这个数据大约 140个G,128w的训练集 1. 下载/解压数据 首先需要下载数据: 数据最后处理…

React 组件之间的通信

React 组件通信 对于 React 组件之间的通信,我们首先了解一下 React 组件通信的设计理念。 单向数据流(Unidirectional Data Flow) 数据流向明确: 在 React 中,数据总是从父组件流向子组件(通过 Props 传…

3ds Max 2026 新功能全面解析

一、视口性能与交互体验升级 1. Hydra 2.0 视口渲染引擎 3ds Max 2026 引入了 Hydra 2.0,大幅优化了视口渲染性能,尤其是在处理复杂场景和高质量实时预览时,流畅度提升显著。 支持USD(通用场景描述)格式&#xff0c…

【GPUStack】【dify】【RAGflow】:本地部署GPUStack并集成到dify和RAGflow

目录 Nvidia-Driver CUDA NVIDIA Container Toolkit(新版本的docker不用安装,自带) Docker 部署GPUStack Text Embeddings 部署模型库模型 测试 部署开源模型(modelscope) dify 集成 RAGflow集成 Nvidia-Dri…

WEB安全--SQL注入--利用log写入webshell

一、原理: 这也是对之前文章的补充:WEB安全--SQL注入--INTO OUTFILE-CSDN博客 我们可以通过修改MySQL的log文件,用select关键字写入木马文件放在服务器物理地址中,通过访问物理地址getshell。 二、条件: 用户有写入权限…

手绘的思维导图怎么转成电子版思维导图?分享今年刚测试出来的方法

看到一份思维导图很好看,但是有点看不清了,怎么可以复制出新的思维导图?手绘的思维导图怎么转成电子版思维导图? 这时候可以把图片复制成一份可编辑的思维导图文件,我说下解决图片转换成思维导图的思路。只要原图片比…