Kornia:GPU加速Dataload

news/2024/12/2 16:34:32/

会使用多种数据增强提高模型的泛化性。在输入分辨率大的task(如医疗诊断辅助)上,消耗的时间更大。为了提高augment的效率,故使用Kornia进行数据增强。

效果

效果还是比较好的,下面是其他人做的对比实验:

https://blog.csdn.net/OTZ_2333/article/details/118655925

我使用数据集测试了一下提速前后遍历数据集的耗时。测试方法就是将正常训练model的代码去掉前向传播、计算loss、反向传播等操作,只保留数据的加载、预处理、转移到GPU的操作。是用的数据集中总共有1万张图片。

原始的dataload在10epoch下总耗时11904s(下图一),加速后的dataload在10epoch下耗时791s(下图二)。此外,可以看到原始的dataload各个epcoh的耗时很不稳定,短的能有150s,长的能有4000s;而加速后的dataload耗时基本上都在80s左右。

图一:
在这里插入图片描述
图二:
在这里插入图片描述

Code

先定义一个transformer类

import torch
import kornia.augmentation as K
class DataAugmentation(torch.nn.Module):def __init__(self,):super().__init__()self.flip = torch.nn.Sequential(K.RandomHorizontalFlip(p=0.5),K.RandomVerticalFlip(p=0.5),)p=0.8self.transform_geometry = K.ImageSequential(K.RandomAffine(degrees=20, translate=0.1, scale=[0.8,1.2], shear=20, p=p),K.RandomThinPlateSpline(scale=0.25, p=p),random_apply=1, #choose 1)p=0.5self.transform_intensity = K.ImageSequential(K.RandomGamma(gamma=(0.5, 1.5), gain=(0.5, 1.2), p=p),K.RandomContrast(contrast=(0.8,1.2), p=p),K.RandomBrightness(brightness=(0.8,1.2), p=p),random_apply=1, #choose 1)# p=0.5# self.transform_other = K.ImageSequential(#     K.MyRoll(p=0.1), #Mosaic Augmentation using only one image, implemented by using pytorch roll , i.e. cyclic shift#     K.MyCutOut(num_block=5, block_size=[0.1, 0.2], fill='constant', p=0.1),#     random_apply=1, #choose 1# )@torch.no_grad()  # disable gradients for effiencydef forward(self, x):x = self.flip(x)  # BxCxHxWx = self.transform_geometry(x)x = self.transform_intensity(x)# x = self.transform_other(x)return xif __name__=="__main__":input = torch.rand(4,3,255,255)dataaugmentation = DataAugmentation()input = dataaugmentation(input)print(input.shape)

在训练的时候调用

if config.KORNIA:kornia_aug = DataAugmentation()
for batch_idx, data in enumerate(train_progress):X, y_cancer = data[0].to(DEVICE),data[1]optim.zero_grad()# Using mixed precision trainingwith autocast():if config.KORNIA:X = kornia_aug(X)y_cancer_pred, aux_loss = model.forward(X)loss.backward()optim.step()scheduler.step()

注意点

  1. @torch.no_grad()
    用于在数据增强时禁用梯度以提高效率
  2. 加速失败,使用kornia之后反而变慢了
    要先将输入加载到cuda上,再进行数据增强
  3. tensor.cuda的精度冲突
    先用with autocast(),再进行数据增强
  4. 注意要归一化!!!
  5. 可以自己写一些相关数据增强的方式,如我注释掉的K.MyRoll,K.MyCutOut,但是我还不是很懂要咋写,求大佬教
  6. 因为kornia的这种方式输入是B,C,H,W四维的tensor,所以放在__getitem__多半是不行的
  7. 显存开销比较大,有点难蚌┭┮﹏┭┮

http://www.ppmy.cn/news/590447.html

相关文章

Manjaro(kde) 安装nvidia显卡驱动(optimus-manager管理)

1、查看内核版本:系统设置-内核(System Settings->Kernel) 2、安装显卡驱动 sudo pacman -S nvidia 这里会出现很多版本的显卡驱动,选择与你内核版本一致的版本,数字越大代表驱动越新,比如我的就选择&#xff1…

计算机中CPU、内存、缓存的关系

CPU(Central Processing Unit,中央处理器) 内存(Random Access Memory,随机存取存储器) 缓存(Cache) CPU、内存和缓存之间有着密切的关系,它们共同构成了计算机系统的核…

简单写写Puppet的安装配置和使用

Puppet的安装 Puppet是一款开源的配置管理工具,可以自动化管理和部署服务器上的软件和配置。在进行Puppet的安装之前,需要确保系统已安装Ruby和RubyGems。 步骤1:安装Puppet服务器 1.1 在服务器上添加Puppet的软件源 在Ubuntu系统中&…

jquery展开收起(手风琴)

时隔多月又写到手风琴了,不过这次使用jquery的,很简单的三句话。这里就当记录下: 点击一个列表展开,将其他的列表关闭,若展开后再点击则关闭 html结构如下: js如下: 解释:以上jque…

【Jquery手风琴】

让它兄弟元素的子元素收起来 <body><div class"box"><ul class"outerUl"><li><h4>软件教学</h4><ul class"childUl"><li>java</li><li>web前端</li><li>安卓开发</…

html手风琴案例

我们在前端开发的时候经常遇到鼠标经过的时候发生事件&#xff0c;鼠标离开的时候发生事件的效果 下面看看效果图 这下来看看我们的代码吧 html部分 <div id"box"><ul><li><a href"#"><img src"./images/1.jpg" a…

制作手风琴

开发工具与关键技术&#xff1a;DW 隐藏 作者&#xff1a;魏钦 撰写时间&#xff1a;11.28 进入DW界面&#xff0c;ctrln新建HTML骨架&#xff0c;在body标签放入一个div标签&#xff0c;给一个类名&#xff0c;然后div标签里面再嵌套四个div标签作为子元素&#xff0c;四…

手风琴jQuery

静态页面 <script src"https://cdn.bootcdn.net/ajax/libs/jquery/3.6.1/jquery.min.js"></script><style>* {margin: 0;padding: 0;}.box {width: 500px;height: auto;margin: auto;}ul {list-style: none;}.outerUl>li {text-align: center;b…