AF3 tensor_tree_map函数解读

news/2025/1/10 14:57:29/

AlphaFold3 递归函数 dict_map 和 tree_map,用于对嵌套的数据结构(如字典、列表、元组等)中的每个“叶子节点”应用指定的操作。最后,通过 partial 函数创建了 tensor_tree_map,专门用于对包含 torch.Tensor 的树形结构进行操作。

源代码:

# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):new_dict = {}for k, v in dic.items():if type(v) is dict:new_dict[k] = dict_map(fn, v, leaf_type)else:new_dict[k] = tree_map(fn, v, leaf_type)return new_dictdef tree_map(fn, tree, leaf_type):if isinstance(tree, dict):return dict_map(fn, tree, leaf_type)elif isinstance(tree, list):return [tree_map(fn, x, leaf_type) for x in tree]elif isinstance(tree, tuple):return tuple([tree_map(fn, x, leaf_type) for x in tree])elif isinstance(tree, leaf_type):return fn(tree)else:print(type(tree))raise ValueError("Not supported")tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

代码解读:

核心概念
  • 树形数据结构:可以是嵌套的字典、列表、元组等,叶子节点是具体的数据类型(如 torch.Tensor)。
  • 递归操作:对每个叶子节点递归地应用指定的函数 fn
  • 树遍历:函数通过递归实现对整个树结构的深度优先遍历。
1. dict_map 函数
def dict_map(fn, dic, leaf_type):new_dict = {}for k, v in dic.items():if type(v) is dict:new_dict[k] = dict_map(fn, v, leaf_type)else:new_dict[k] = tree_map(fn, v, leaf_type)return new_dict
功能
  • 递归地对字典 dic 中的每个值 v 进行操作:
    • 如果 v 是另一个字典,则递归调用 dict_map
    • 否则,调用 tree_map 处理 v
场景
  • 当树的某部分是嵌套字典时,dict_map 能递归地处理这些嵌套结构。

2. tree_map 函数

def tree_map(fn, tree, leaf_type):if isinstance(tree, dict):return dict_map(fn, tree, leaf_type)elif isinstance(tree, list):return [tree_map(fn, x, leaf_type) for x in tree]elif isinstance(tree, tuple):return tuple([tree_map(fn, x, leaf_type) for x in tree])elif isinstance(tree, leaf_type):return fn(tree)else:print(type(tree))raise ValueError("Not supported")
功能
  • 递归遍历整个树结构 tree,对其每个节点进行如下分类处理:
    1. 字典:调用 dict_map
    2. 列表:对列表中的每个元素递归调用 tree_map
    3. 元组:对元组中的每个元素递归调用 tree_map,并将结果重新封装为元组。
    4. 叶子节点:如果节点是 leaf_type 类型,应用函数 fn
    5. 其他类型:抛出错误,提示不支持的类型。
场景
  • 当树中包含多种数据结构(如字典、列表、元组等)时,tree_map 能通用地递归处理每个叶子节点。

3. tensor_tree_map 函数

tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
功能
  • 使用 functools.partial 固定 leaf_type 为 torch.Tensor,得到一个专门处理 torch.Tensor 的 tree_map函数。
场景
  • 深度学习模型中,数据往往以 torch.Tensor 形式存在于复杂的嵌套结构(如模型参数、梯度等)中,tensor_tree_map 可以轻松对每个 torch.Tensor 应用特定操作(如转换、归一化等)。

总结

  • dict_map:专门处理嵌套字典。
  • tree_map:通用递归遍历和处理嵌套树形结构(包括字典、列表、元组等)。
  • tensor_tree_map:专用于对树形结构中的 torch.Tensor 应用指定操作。


http://www.ppmy.cn/news/1562001.html

相关文章

使用 Flask 和 PaddleOCR 的车牌识别项目安装与打包教程

文章目录 一、环境安装1. 安装 Python2. 创建虚拟环境3. 安装依赖 二、项目实现代码文件 app.py 三、打包流程1. 安装 pyinstaller2. 生成可执行文件3. 验证可执行文件 四、测试车牌识别服务1. 使用 Postman 或 curl 测试2. 校验车牌 五、提供给运维的打包文件 本教程基于 Flas…

边缘计算应用十大领域

边缘计算解决了互联网的网速问题,作为实现边缘计算的基础,那边缘计算是5G与产业互联网、物联网时代的重要技术支撑,也正迎来广阔的增长空间。那么现在我们生活中有哪些领域正在使用边缘计算呢?今天我们来盘点一下我们身边正在使用…

MySQL派生表合并优化的原理和实现

在MySQL中,派生表(Derived Table)是一个常用的技术,用于在SQL查询中临时创建一个表。派生表通常通过子查询实现。然而,派生表可能会导致性能问题,因为它们在执行过程中可能会创建临时表。在优化SQL查询时&a…

Web网页制作之JavaScript的应用

---------------📡🔍K学啦 更多学习资料📕 免费获取--------------- 实现的功能:1.通过登录界面跳转至主页面,用户名统一为“admin”,密码统一为“admin123”,密码可显示或隐藏,输入…

【AI游戏】使用强化学习玩 Flappy Bird:从零实现 Q-Learning 算法(附完整资源)

1. 引言 Flappy Bird 是一款经典的休闲游戏,玩家需要控制小鸟穿过管道,避免碰撞。虽然游戏规则简单,但实现一个 AI 来自动玩 Flappy Bird 却是一个有趣的挑战。本文将介绍如何使用 Q-Learning 强化学习算法来训练一个 AI,使其能够…

SwiftUI 是如何改变 iOS 开发游戏规则的?

SwiftUI 是 Apple 推出的现代化声明式 UI 框架,适用于 iOS、macOS、watchOS 和 tvOS 开发。 SwiftUI 与传统 UIKit(Swift 和 Objective-C) 的优劣势对比: SwiftUI 的优势 一. 声明式编程 优势: SwiftUI 使用声明式语法&#xff…

线性回归的改进-岭回归

2.10 线性回归的改进-岭回归 学习目标 知道岭回归api的具体使用 1 API sklearn.linear_model.Ridge(alpha1.0, fit_interceptTrue,solver"auto", normalizeFalse) 具有l2正则化的线性回归alpha:正则化力度,也叫 λ λ取值:0~1 1~10solver:会根…

爬取电影数据结合Flask实现数据可视化

网站:Scrape | Movie 本案例(爬虫)所需要的模块 requests (网络请求模块)pandas (数据保存模块)parsel (数据解析模块)lxml (数据解析模块) pyecharts (可视化库)flask(框架) 以上的模块均需要通过 指令 pip install 模块名 安装 Explain: 分析此页面的数据为静态的…