1、module:
_modules
可以直接获取特定名称的子模块对象,但无法直接获取名称以及遍历所有子模块,所以需要递归,是一个字典,所以其好处就是可以直接用键值索引,坏处是无法直接迭代需要用items()
,其中键是子模块的名称,值是对应的子模块对象。
_moudules.items()
是迭代字典的基本用法,返回可迭代对象,基操如下:
for key, value in my_dict.items():print(key, value)
如果是:
for key in my_dict:print(key)
则迭代一个字典的键。
所以看下面这段代码:
for name, module in model._modules.items():if hasattr(module, "_modules"):.....
name: layer1 (key)
module: Sequential() (value)
2、
children()
方法返回一个迭代器,用于遍历模型的直接子模块。它只返回子模块的对象,不包含子模块的名称。
model.named_children()
返回一个迭代器,该迭代器依次生成由子模块名称和子模块对象组成的元组。
每个元组的结构是 (name, child),其中 name 是子模块的名称,child 是子模块对象本身。
for name, child in model.named_children():
这和上面的效果是一样的。
而且:model._modules[name]
和child
是一样的
3、
modules(): 返回一个迭代器,提供模型及其所有子模块(包括子模块的子模块,这样就不用递归了)
named_modules(): 返回一个迭代器,提供模型及其所有子模块的名称和对象
4、
懂了上面之后我们可以写一个hook
函数
Version1:用直接子模块迭代器
def add_IF_hook(model):for name, child in model.named_children():if isinstance(child, ScaledNeuron):child.register_forward_hook(save_spikes_number)else:add_IF_hook(child) # 需要递归children对应的是直接子模块
Version2:用属性_modules,items()转成迭代器
def add_IF_hook(model):for name, module in model._modules.items():if isinstance(module, ScaledNeuron):model._modules[name],register_forward_hook(save_spikes_number)#或者 module.register_forward_hook(save_spikes_number)else:add_IF_hook(module) #需要递归
上面两个输出的形式是这样的:pipeline是先从layer1—sequential下去,然后到sequential里面找到ScaledNeuron:layer1.2
===> name: 2 child: ScaledNeuron(
(neuron): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
)
Version3:用全部子模块,最简形式:
def add_IF_hook(model):for name, module in model.named_modules():if isinstance(module, ScaledNeuron):print(f"===> name: {name} child: {module}")module.register_forward_hook(save_spikes_number)
最后输出的形式是这样的:
===> name: layer1.2 child: ScaledNeuron(
(neuron): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
)
===> name: layer1.6 child: ScaledNeuron(
(neuron): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
)