前言
以前使用 Keras 时有一个很方便的提前终止(EarlyStopping)类,而 PyTorch 每次都需要自己实现一遍。因此我整理了一份简单通用的代码,需要提前终止功能时,只需复制粘贴即可,避免每次重复编写的麻烦。
代码
class EarlyStopping(object):
    """Keras-style early stopping helper for PyTorch training loops.

    Tracks a monitored metric across epochs and signals when training
    should stop because the metric has not improved for ``patience``
    consecutive checks.
    """

    def __init__(self, monitor: str = 'val_loss', mode: str = 'min', patience: int = 1):
        """
        :param monitor: key looked up when a metrics dict is passed to ``__call__``;
            ignored for scalar inputs
        :param mode: 'min' (smaller is better) or 'max' (larger is better)
        :param patience: number of consecutive non-improving checks to tolerate
        :raises ValueError: if ``mode`` is neither 'min' nor 'max'

        Example::

            # Initialize
            earlystopping = EarlyStopping(mode='max', patience=5)
            # call
            if earlystopping(val_accuracy):
                return
            # save checkpoint
            state = {'model': model,
                     'earlystopping': earlystopping.state_dict(),
                     'optimizer': optimizer}
            torch.save(state, 'checkpoint.pth')
            checkpoint = torch.load('checkpoint.pth')
            earlystopping.load_state_dict(checkpoint['earlystopping'])
        """
        if mode not in ('min', 'max'):
            # Previously an invalid mode silently made every call count as
            # "no improvement"; fail fast instead.
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.monitor = monitor
        self.mode = mode
        self.patience = patience
        # float('inf') avoids the original's dependency on an unimported
        # `math` module (math.inf raised NameError as written).
        self.__value = float('-inf') if mode == 'max' else float('inf')
        self.__times = 0  # consecutive checks without improvement

    def state_dict(self) -> dict:
        """Return a serializable snapshot, e.g. for ``torch.save(state, path)``."""
        return {
            'monitor': self.monitor,
            'mode': self.mode,
            'patience': self.patience,
            'value': self.__value,
            'times': self.__times,
        }

    def load_state_dict(self, state_dict: dict):
        """Restore state previously produced by :meth:`state_dict`.

        :param state_dict: the saved state
        """
        self.monitor = state_dict['monitor']
        self.mode = state_dict['mode']
        self.patience = state_dict['patience']
        self.__value = state_dict['value']
        self.__times = state_dict['times']

    def reset(self):
        """Reset the non-improvement counter to zero."""
        self.__times = 0

    def __call__(self, metrics) -> bool:
        """Record one observation of the monitored metric.

        :param metrics: a scalar value, or a dict containing ``self.monitor``
        :return: True when the metric failed to improve ``patience`` times
            in a row (i.e. training should stop)
        """
        if isinstance(metrics, dict):
            metrics = metrics[self.monitor]
        # NOTE: <=/>= means a plateau (equal value) counts as improvement
        # and resets the counter — this matches the original behavior.
        improved = ((self.mode == 'min' and metrics <= self.__value)
                    or (self.mode == 'max' and metrics >= self.__value))
        if improved:
            self.__value = metrics
            self.__times = 0
        else:
            self.__times += 1
            if self.__times >= self.patience:
                return True
        return False
使用方法
# 初始化,监测模式为最大,最多容忍5次
early_stop = EarlyStopping(mode='max', patience=5)# 整体结构如下:
for epoch in range(1, 21):train_loss, train_acc = train_one_epoch(...)val_loss, val_acc = validate(...)# 如果触发终止条件,就结束训练if early_stop(val_acc):return
- 保存与恢复状态
def save_checkpoint(..., early_stop, ...):state = {... # 其他保存的东西... 'early_stop': early_stop.state_dict(),...}torch.save(state, 'checkpoint.pth')def load_checkpoint(..., early_stop, ...):checkpoint = torch.load('checkpoint.pth')...early_stop.load_state_dict(checkpoint['early_stop'])...# 使用early_stop = EarlyStopping(mode='max', patience=5)
load_checkpoint(..., early_stop, ...)
# 如果需要的话,重置次数
early_stop.reset()