调试代码
此阶段主要测试各阶段代码是否有问题。
快速测试
fast_dev_run项可以配置train/val/test阶段的循环次数,跑完就停止代码,快速查看各流程代码正确性,避免train调试后训练又在val/test阶段出错,白白浪费时间和计算成本。
Trainer(fast_dev_run=7)# 每个阶段只循环7次,也可以设置为True,只循环5次。
使用部分数据测试
功能与上面类似,但是运行指定epochs的周期过程,只是train/val/test流程使用部分数据。
# 使用10%训练集 和 1% 验证集
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01)# use 10 batches of train and 5 batches of val
trainer = Trainer(limit_train_batches=10, limit_val_batches=5)
validation_step()提前检查
设置num_sanity_val_steps,lightning会在开始训练前默认先执行再次validation_step,避免训练后验证阶段出错。
trainer = Trainer(num_sanity_val_steps=2)
打印网络模型信息
调用train.fit()后,lightning会自动打印模型信息,如下:
| Name | Type | Params
----------------------------------
0 | net | Sequential | 132 K
1 | net.0 | Linear | 131 K
2 | net.1 | BatchNorm1d | 1.0 K
也可以利用内置的callback ModelSummary打印子模块的信息。需要配置好callback后传入Trainer。
trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])# 打印深度为所有。
打印各模块输入输出的尺寸
在LightningModule中设置example_input_array属性
class LitModel(LightningModule):def __init__(self, *args, **kwargs):self.example_input_array = torch.Tensor(32, 1, 28, 28)
| Name | Type | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512]