AlphaFold3 递归函数 dict_map
和 tree_map
,用于对嵌套的数据结构(如字典、列表、元组等)中的每个“叶子节点”应用指定的操作。最后,通过 partial
函数创建了 tensor_tree_map
,专门用于对包含 torch.Tensor
的树形结构进行操作。
源代码:
# With tree_map, a poor man's JAX tree_map
def dict_map(fn, dic, leaf_type):new_dict = {}for k, v in dic.items():if type(v) is dict:new_dict[k] = dict_map(fn, v, leaf_type)else:new_dict[k] = tree_map(fn, v, leaf_type)return new_dictdef tree_map(fn, tree, leaf_type):if isinstance(tree, dict):return dict_map(fn, tree, leaf_type)elif isinstance(tree, list):return [tree_map(fn, x, leaf_type) for x in tree]elif isinstance(tree, tuple):return tuple([tree_map(fn, x, leaf_type) for x in tree])elif isinstance(tree, leaf_type):return fn(tree)else:print(type(tree))raise ValueError("Not supported")tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
代码解读:
核心概念
- 树形数据结构:可以是嵌套的字典、列表、元组等,叶子节点是具体的数据类型(如
torch.Tensor
)。 - 递归操作:对每个叶子节点递归地应用指定的函数
fn
。 - 树遍历:函数通过递归实现对整个树结构的深度优先遍历。
1. dict_map
函数
def dict_map(fn, dic, leaf_type):new_dict = {}for k, v in dic.items():if type(v) is dict:new_dict[k] = dict_map(fn, v, leaf_type)else:new_dict[k] = tree_map(fn, v, leaf_type)return new_dict
功能
- 递归地对字典
dic
中的每个值v
进行操作:- 如果
v
是另一个字典,则递归调用dict_map
。 - 否则,调用
tree_map
处理v
。
- 如果
场景
- 当树的某部分是嵌套字典时,
dict_map
能递归地处理这些嵌套结构。
2. tree_map
函数
def tree_map(fn, tree, leaf_type):if isinstance(tree, dict):return dict_map(fn, tree, leaf_type)elif isinstance(tree, list):return [tree_map(fn, x, leaf_type) for x in tree]elif isinstance(tree, tuple):return tuple([tree_map(fn, x, leaf_type) for x in tree])elif isinstance(tree, leaf_type):return fn(tree)else:print(type(tree))raise ValueError("Not supported")
功能
- 递归遍历整个树结构
tree
,对其每个节点进行如下分类处理:- 字典:调用
dict_map
。 - 列表:对列表中的每个元素递归调用
tree_map
。 - 元组:对元组中的每个元素递归调用
tree_map
,并将结果重新封装为元组。 - 叶子节点:如果节点是
leaf_type
类型,应用函数fn
。 - 其他类型:抛出错误,提示不支持的类型。
- 字典:调用
场景
- 当树中包含多种数据结构(如字典、列表、元组等)时,
tree_map
能通用地递归处理每个叶子节点。
3. tensor_tree_map
函数
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
功能
- 使用
functools.partial
固定leaf_type
为torch.Tensor
,得到一个专门处理torch.Tensor
的tree_map
函数。
场景
- 在深度学习模型中,数据往往以
torch.Tensor
形式存在于复杂的嵌套结构(如模型参数、梯度等)中,tensor_tree_map
可以轻松对每个torch.Tensor
应用特定操作(如转换、归一化等)。
总结
dict_map
:专门处理嵌套字典。tree_map
:通用递归遍历和处理嵌套树形结构(包括字典、列表、元组等)。tensor_tree_map
:专用于对树形结构中的torch.Tensor
应用指定操作。