参考教程:
https://github.com/pytorch/pytorch/issues/38419
https://zhuanlan.zhihu.com/p/412772439
https://webdataset.github.io/webdataset/gettingstarted/
文章目录
- 背景
- WebDataset
- webdataset的生成
- webdataset的加载
- 示例代码
背景
训练数据通常是以个体的方式存储的,就像我们在第一章下载并处理成png格式后的cifar10数据,它以’xxx.png’的文件形式存放在一个一个独立的空间中。
随着数据集变得越来越大,这样的存放形式就不是那么高效和便捷。在进行模型训练时,也会因为数据的IO瓶颈拖慢训练的速度。
在使用Dataset中的数据时,我们的__getitem__(self, idx)函数会根据数据的index检索数据。在训练时,我们一般都会使用shuffle = True来完成数据的随机读取,这样索引的index也是无效的,当图片数据直接存放在系统上时,对文件的访问需要花费大量的代价。
这个问题可以使用sequential storage formats and sharding来解决。就像tensorflow中使用的TFRecord格式,它将训练集/测试集打包在一起使用,文件里存储的就是序列化的tf.Example。Pytorch是没有这种专属的数据存储格式的。
WebDataset
WebDataset提供了一种序列化存储大规模数据的方法,它将数据保存在tar包中,但是在使用时不需要对tar包进行解压。这种形式提供了高效的I/O,并且不管是在本地还是云端数据上都表现很不错。
webdataset的生成
webdataset是一个tar文件,所以你直接使用tar命令就可以进行文件的生成。
tar --sort=name -cf dataset.tar dataset/
我们也可以使用python调用webdataset的包,来进行文件的写入操作。
以下面的代码为例,下方的代码想要将现有的MNIST数据存放到’mnist.tar’文件中,因此它按照顺序将数据一个一个多写入了文件里。
dataset = torchvision.datasets.MNIST(root="./temp", download=True) # 获得MNIST数据
sink = wds.TarWriter("mnist.tar") # 使用TarWriter,准备将数据写入mnist.tar
for index, (input, output) in enumerate(dataset):if index%1000==0:print(f"{index:6d}", end="\r", flush=True, file=sys.stderr) # 每写入1000个数据,输出一些状态sink.write({"__key__": "sample%06d" % index, # 当前的数据的index"input.pyd": input, # 数据的input"output.pyd": output, # 数据的target})
sink.close() # 关闭当前文件。
这里的sink_write写入了是一个dict,其中’key’这一项决定了你想保存的数据的前缀名,’input.pyd’是你的input的数据的后缀,它同时也决定了你的数据存放的格式。
比如说这里使用的’pyd’,就是我们之前说过的pickle格式,它可以保证数据的完整性,以不压缩的形式存储数据,缺点是不能被其它的语言读取。
在你明确知道数据的类型的情况下,你也可以使用别的格式来存放数据,比如说对于图片,你可以使用‘ppm’,‘png’,'jpg’等格式,对于图片的标签,已知数据标签是整数的形式时,可以使用’cls’格式。
webdataset的加载
对于一个存入tar的webdataset的数据,你可以通过它的url对它进行读取,这个url可以是云端地址,也可以是本地路径。
import webdataset as wds
dataset = wds.WebDataset(url)
我们在讲数据存入tar时,writer根据我们定义的数据格式对数据进行了encode,所以我们直接读取到的数据是还没有decode的数据。
在教程中给了这样一个例子。
直接获取到的数据格式是bytes的格式。
你可以数据进行一些处理,webdataset提供一种链式的数据处理方法,比如上面的数据,你就可以使用下面的方法处理。
dataset = (wds.WebDataset(url).shuffle(100).decode("rgb").to_tuple("jpg;png", "json")
)
这里的decode传入的’rgb’属于headler,webdataset提供了一些自带的imageheadler。帮助使用者进行数据类型转换。imagespecs = { "l8": ("numpy", "uint8", "l"), "rgb8": ("numpy", "uint8", "rgb"), "rgba8": ("numpy", "uint8", "rgba"), "l": ("numpy", "float", "l"), "rgb": ("numpy", "float", "rgb"), "rgba": ("numpy", "float", "rgba"), "torchl8": ("torch", "uint8", "l"), "torchrgb8": ("torch", "uint8", "rgb"), "torchrgba8": ("torch", "uint8", "rgba"), "torchl": ("torch", "float", "l"), "torchrgb": ("torch", "float", "rgb"), "torch": ("torch", "float", "rgb"), "torchrgba": ("torch", "float", "rgba"), "pill": ("pil", None, "l"), "pil": ("pil", None, "rgb"), "pilrgb": ("pil", None, "rgb"), "pilrgba": ("pil", None, "rgba"), }
webdataset提供了多种数据的decode方式的示例,你也可以自定义decode的方法。具体的源码可以查看https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py。
decoders = {"txt": lambda data: data.decode("utf-8"),"text": lambda data: data.decode("utf-8"),"transcript": lambda data: data.decode("utf-8"),"cls": lambda data: int(data),"cls2": lambda data: int(data),"index": lambda data: int(data),"inx": lambda data: int(data),"id": lambda data: int(data),"json": lambda data: json.loads(data),"jsn": lambda data: json.loads(data),"pyd": lambda data: pickle.loads(data),"pickle": lambda data: pickle.loads(data),"pth": lambda data: torch_loads(data),"ten": tenbin_loads,"tb": tenbin_loads,"mp": msgpack_loads,"msg": msgpack_loads,"npy": npy_loads,"npz": lambda data: np.load(io.BytesIO(data)),"cbor": cbor_loads,
}
如果是想要自己定义decode的方法,可以使用以下类似的方法。以下的方法中定义了my_decoder方法,这方法会判断dataset中sample的key是否为jpg,如果不是则忽略,是的话才会返回结果。要注意这里直接获得的数据类型都是bytes,你可以使用类似于**imageio.imread(io.BytesIO(value))**处理数据,将它转为图片。
def my_decoder(key, value):if not key.endswith(".jpg"):return Noneassert isinstance(value, bytes)return valuedataset = wds.WebDataset(url).shuffle(1000).decode(my_decoder)
示例代码
最后给出一个简单的webdataset多进程存储的方法,这里使用的dataset中返回sample是dict形式,最后以pickle的形式存放到指定数量的tar中。
import multiprocessing as mp
import webdataset as wds
import pickle
import osdef write_samples(dataset, tar_index, sample_index,save_dir):for t_idx, s_idx in zip(tar_index, sample_index):fname = os.path.join(save_dir,str(t_idx)+'.tar')stream = wds.TarWriter(fname)for idx in s_idx:data = dataset[idx]sample = {}sample['__key__'] = "sample%06d" % idxfor key, value in data.items():sample[key +'.pyd'] = valuestream.write(sample)stream.close()def dataset2tar(dataset, save_dir,num_tars, num_workers):num_len = len(dataset)data_index = [i for i in range(num_len)]samples = [data_index[i::num_tars] for i in range(num_tars)]tar_index = list(range(num_tars))jobs = []for i in range(num_workers):job = mp.Process(target = write_samples,args=(dataset,tar_index[i::num_workers],samples[i::num_workers],save_dir))job.start()jobs.append(job)for job in jobs:job.join()def pyd_decoder(key, data):if not key.endswith(".pyd"):return Noneresult = pickle.loads(data)return result