第六章 番外篇:webdataset

news/2024/11/30 9:48:32/

参考教程:
https://github.com/pytorch/pytorch/issues/38419
https://zhuanlan.zhihu.com/p/412772439
https://webdataset.github.io/webdataset/gettingstarted/


文章目录

  • 背景
  • WebDataset
  • webdataset的生成
  • webdataset的加载
  • 示例代码

背景

训练数据通常是以个体的方式存储的,就像我们在第一章下载并处理成png格式后的cifar10数据,它以’xxx.png’的文件形式存放在一个一个独立的空间中。
随着数据集变得越来越大,这样的存放形式就不是那么高效和便捷。在进行模型训练时,也会因为数据的IO瓶颈拖慢训练的速度。
在使用Dataset中的数据时,我们的__getitem__(self, idx)函数会根据数据的index检索数据。在训练时,我们一般都会使用shuffle = True来完成数据的随机读取,这样索引的index也是无效的,当图片数据直接存放在系统上时,对文件的访问需要花费大量的代价。
这个问题可以使用sequential storage formats and sharding来解决。就像tensorflow中使用的TFRecord格式,它将训练集/测试集打包在一起使用,文件里存储的就是序列化的tf.Example。Pytorch是没有这种专属的数据存储格式的。

WebDataset

WebDataset提供了一种序列化存储大规模数据的方法,它将数据保存在tar包中,但是在使用时不需要对tar包进行解压。这种形式提供了高效的I/O,并且不管是在本地还是云端数据上都表现很不错。

webdataset的生成

webdataset是一个tar文件,所以你直接使用tar命令就可以进行文件的生成。

tar --sort=name -cf dataset.tar dataset/

我们也可以使用python调用webdataset的包,来进行文件的写入操作。
以下面的代码为例,下方的代码想要将现有的MNIST数据存放到’mnist.tar’文件中,因此它按照顺序将数据一个一个多写入了文件里。

dataset = torchvision.datasets.MNIST(root="./temp", download=True) # 获得MNIST数据
sink = wds.TarWriter("mnist.tar") # 使用TarWriter,准备将数据写入mnist.tar
for index, (input, output) in enumerate(dataset):if index%1000==0:print(f"{index:6d}", end="\r", flush=True, file=sys.stderr) # 每写入1000个数据,输出一些状态sink.write({"__key__": "sample%06d" % index, # 当前的数据的index"input.pyd": input, # 数据的input"output.pyd": output, # 数据的target})
sink.close() # 关闭当前文件。

这里的sink_write写入了是一个dict,其中’key’这一项决定了你想保存的数据的前缀名,’input.pyd’是你的input的数据的后缀,它同时也决定了你的数据存放的格式。
比如说这里使用的’pyd’,就是我们之前说过的pickle格式,它可以保证数据的完整性,以不压缩的形式存储数据,缺点是不能被其它的语言读取。
在你明确知道数据的类型的情况下,你也可以使用别的格式来存放数据,比如说对于图片,你可以使用‘ppm’,‘png’,'jpg’等格式,对于图片的标签,已知数据标签是整数的形式时,可以使用’cls’格式。

webdataset的加载

对于一个存入tar的webdataset的数据,你可以通过它的url对它进行读取,这个url可以是云端地址,也可以是本地路径。

import webdataset as wds
dataset = wds.WebDataset(url)

我们在讲数据存入tar时,writer根据我们定义的数据格式对数据进行了encode,所以我们直接读取到的数据是还没有decode的数据。
在教程中给了这样一个例子。
在这里插入图片描述
直接获取到的数据格式是bytes的格式。
你可以数据进行一些处理,webdataset提供一种链式的数据处理方法,比如上面的数据,你就可以使用下面的方法处理。

dataset = (wds.WebDataset(url).shuffle(100).decode("rgb").to_tuple("jpg;png", "json")
)

这里的decode传入的’rgb’属于headler,webdataset提供了一些自带的imageheadler。帮助使用者进行数据类型转换。imagespecs = { "l8": ("numpy", "uint8", "l"), "rgb8": ("numpy", "uint8", "rgb"), "rgba8": ("numpy", "uint8", "rgba"), "l": ("numpy", "float", "l"), "rgb": ("numpy", "float", "rgb"), "rgba": ("numpy", "float", "rgba"), "torchl8": ("torch", "uint8", "l"), "torchrgb8": ("torch", "uint8", "rgb"), "torchrgba8": ("torch", "uint8", "rgba"), "torchl": ("torch", "float", "l"), "torchrgb": ("torch", "float", "rgb"), "torch": ("torch", "float", "rgb"), "torchrgba": ("torch", "float", "rgba"), "pill": ("pil", None, "l"), "pil": ("pil", None, "rgb"), "pilrgb": ("pil", None, "rgb"), "pilrgba": ("pil", None, "rgba"), }
webdataset提供了多种数据的decode方式的示例,你也可以自定义decode的方法。具体的源码可以查看https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py。

decoders = {"txt": lambda data: data.decode("utf-8"),"text": lambda data: data.decode("utf-8"),"transcript": lambda data: data.decode("utf-8"),"cls": lambda data: int(data),"cls2": lambda data: int(data),"index": lambda data: int(data),"inx": lambda data: int(data),"id": lambda data: int(data),"json": lambda data: json.loads(data),"jsn": lambda data: json.loads(data),"pyd": lambda data: pickle.loads(data),"pickle": lambda data: pickle.loads(data),"pth": lambda data: torch_loads(data),"ten": tenbin_loads,"tb": tenbin_loads,"mp": msgpack_loads,"msg": msgpack_loads,"npy": npy_loads,"npz": lambda data: np.load(io.BytesIO(data)),"cbor": cbor_loads,
}

如果是想要自己定义decode的方法,可以使用以下类似的方法。以下的方法中定义了my_decoder方法,这方法会判断dataset中sample的key是否为jpg,如果不是则忽略,是的话才会返回结果。要注意这里直接获得的数据类型都是bytes,你可以使用类似于**imageio.imread(io.BytesIO(value))**处理数据,将它转为图片。

def my_decoder(key, value):if not key.endswith(".jpg"):return Noneassert isinstance(value, bytes)return valuedataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)

示例代码

最后给出一个简单的webdataset多进程存储的方法,这里使用的dataset中返回sample是dict形式,最后以pickle的形式存放到指定数量的tar中。

import multiprocessing as mp
import webdataset as wds
import pickle
import osdef write_samples(dataset, tar_index, sample_index,save_dir):for t_idx, s_idx in zip(tar_index, sample_index):fname = os.path.join(save_dir,str(t_idx)+'.tar')stream = wds.TarWriter(fname)for idx in s_idx:data = dataset[idx]sample = {}sample['__key__'] = "sample%06d" % idxfor key, value in data.items():sample[key +'.pyd'] = valuestream.write(sample)stream.close()def dataset2tar(dataset, save_dir,num_tars, num_workers):num_len = len(dataset)data_index = [i for i in range(num_len)]samples = [data_index[i::num_tars] for i in range(num_tars)]tar_index = list(range(num_tars))jobs = []for i in range(num_workers):job = mp.Process(target = write_samples,args=(dataset,tar_index[i::num_workers],samples[i::num_workers],save_dir))job.start()jobs.append(job)for job in jobs:job.join()def pyd_decoder(key, data):if not key.endswith(".pyd"):return Noneresult = pickle.loads(data)return result

http://www.ppmy.cn/news/420736.html

相关文章

WWDC2023 Metal swift 头显ARKit支持c c++ 开发

1 今年WWDC,我们看见了苹果的空间计算设备,visionOS也支持了c c API. 这有什么好处呢,不是说能够吸引更多c c开发者加入苹果开发者阵营,而是我们过去的很多软件,可以轻松对接到苹果的头显设备,让我们的软件…

大模型 Transformer介绍-Part1

众所周知,transformer 架构是自然语言处理 (NLP) 领域的一项突破。它克服了 seq-to-seq 模型(如 RNN 等)无法捕获文本中的长期依赖性的局限性。事实证明,transformer 架构是 BERT、GPT 和 T5 及其变体等革命性架构的基石。正如许多…

手机客户端添加设备时需要扫描二维码,如何查找二维码

扫描设备机身上的二维码。如果设备机身上没有二维码,可登录设备扫描: 摄像机网页界面: 配置>网络配置>宇视云 网络视频录像机显示器/网页界面: 网络配置>宇视云 一体机:系统配置>网络配置>宇视云

深入理解Java虚拟机jvm-运行时数据区域(基于OpenJDK12)

运行时数据区域 运行时数据区域程序计数器Java虚拟机栈本地方法栈Java堆方法区运行时常量池直接内存 运行时数据区域 Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域。这些区域 有各自的用途,以及创建和销毁的时间,有的…

Python二维码扫描

模块准备 1.pyzbar pip install pyzbar 2.PIL 注意:PIL只支持Python2,所以我们需要安装Pillow pip install Pillow 代码示例 from PIL import Image import pyzbar.pyzbar as pyzbardef QRcode_message(image):img Image.open(image) # 读取图片…

【01Studio MaixPy AI K210】14.二维码识别

目录 导包: image库 例程: 导包: import sensor,lcd,time image库 #查找 roi 区域内的所有二维码并返回一个 image.qrcode 的对象列表。 image.find_qrcodes([roi])#返回一个矩形元组(x,y,w,h) qrcode.rect()#返…

js打开手机摄像头实现扫描二维码功能

js打开手机摄像头 在js中使用navigator.getUserMedia这个api 可以点击查看api的使用navigator.getUserMedia 这个api是结合https协议使用的,在http协议中摄像头是无法打开的 var video document.querySelector(video); navigator.getUserMedia({audio: false,vid…

Android使用ZXing实现二维码的扫描和创建

一、引用依赖 1、zxing 生成二维码的依赖 implementation com.google.zxing:core:3.3.3 implementation com.journeyapps:zxing-android-embedded:3.6.0 2、zxing 扫码二维码依赖 implementation pub.devrel:easypermissions:1.0.1 implementation cn.bingoogolapple:bga-p…