1.PyTorch中的DataSet类
因为在up主这里讲了DataSet类的基本使用,那我就想去把DataSet具体学习一下在会过来看up住的内容。
在开发Pytorch项目的时候,项目代码会被分为数据处理模块、模型构建模块、训练模块。其中的数据处理模块的主要任务是构建数据集,而DataSet类就是为了我们方便构建数据集的。而我们的数据集就是一组{x:y}的集合,也就是训练数据和标签,那么我们在使用DataSet类的时候只需要定义好这里的训练数据和标签即可。
我们在编写自己的DataSet类的时候,需要继承DataSet类,并且继承其中的__getitem__() 和 __len__()方法。其中:
- __getitem__():获取样本对,就是获得对应的数据和标签。
- __len()__():得到数据集的长度。
后面用up主的代码来说明:
from torch.utils.data import Dataset
from PIL import Image
import osclass MyData(Dataset): # 继承def __init__(self, root_dir, label_dir): # 初始化self.root_dir = root_dirself.label_dir = label_dirself.path = os.path.join(self.root_dir, self.label_dir)self.img_path = os.listdir(self.path)def __getitem__(self, idx): # 通过索引获取图片img_name = self.img_path[idx]img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)img = Image.open(img_item_path)label = self.label_dirreturn img, labeldef __len__(self):return len(self.img_path)root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset=MyData(root_dir,ants_lable_dir)
img,lable=ants_dataset[0]
img.show()
- 在__init__()函数中,我们进行了一些初始化,其中root_dir表示的是数据集的路径,label_dir表示的是标签集的路径,这里的img_path是一个列表,包含了self.path中的条目的名称。
- 在__getitem__()函数是在数据加载的时候被调用的,根据传入的索引idx来获取需要加载的数据,这里返回了数据和标签。这里使用了os库里的path中的join方法,这是让两个路径合并在一起,这样避免了系统复制的路径出现问题(win中需要加一个‘\‘)。
- listdir:listdir函数会返回一个列表,其中包含由path指定的目录中的条目的名称。listdir传递的参数必须是绝对路径。