TTA(Test-Time Augmentation) ,即测试时的数据增强
实现步骤如下:
- 将1个batch的数据通过flips, rotation, scale, etc.等操作生成batches
- 将各个batch分别输入网络
- 每个batch的masks/labels反向转换
- 通过mean, max, gmean, etc.合并各个batch预测的结果
- 最后输出最终的masks/labels
Input| # input batch of images / / /|\ \ \ # apply augmentations (flips, rotation, scale, etc.)| | | | | | | # pass augmented batches through model| | | | | | | # reverse transformations for each batch of masks/labels\ \ \ / / / # merge predictions (mean, max, gmean, etc.)| # output batch of masks/labelsOutput
安装
$ pip install ttach
使用方法如下
import ttach as tta
...
model.load_state_dict(torch.load('models/%s/model.pth' %args.name))
model.eval()
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
...
更多有关使用方法,可以看下面的参考链接
reference
https://github.com/qubvel/ttach