nn.Module类可用操作
1. model.named_parameters()
# 遍历模型的所有参数并打印它们的名称和形状
for name, param in model.named_parameters():print(f"Parameter Name: {name}, Parameter Shape: {param.shape}")
输出示例:
Parameter Name: conv1.weight, Parameter Shape: torch.Size([64, 3, 3, 3])
Parameter Name: conv1.bias, Parameter Shape: torch.Size([64])
Parameter Name: conv2.weight, Parameter Shape: torch.Size([64, 64, 3, 3])
Parameter Name: conv2.bias, Parameter Shape: torch.Size([64])
Parameter Name: fc.weight, Parameter Shape: torch.Size([10, 64])
Parameter Name: fc.bias, Parameter Shape: torch.Size([10])
2. model.named_modules()
# 遍历模型的所有模块并打印它们的名称和类型
for name, module in model.named_modules():print(f"Module Name: {name}, Module Type: {module.__class__.__name__}")
输出示例:
Module Name: , Module Type: MyModel
Module Name: conv1, Module Type: Conv2d
Module Name: relu, Module Type: ReLU
Module Name: conv2, Module Type: Conv2d
Module Name: fc, Module Type: Linear
3. model.get_submodule()
# 遍历模型的所有模块并打印它们的名称和类型
for name, module in model.named_modules():# 通过name获取子模块sub_module = model.get_submodule(name)print(f"Module Name: {name}, Module Type: {sub_module.__class__.__name__}")
输出示例:
Module Name: , Module Type: MyModel
Module Name: conv1, Module Type: Conv2d
Module Name: relu, Module Type: ReLU
Module Name: conv2, Module Type: Conv2d
Module Name: fc, Module Type: Linear