在训练 深度学习 模型之前,样本集的制作是非常重要的环节。在pytorch中,提供了一些接口和类,方便我们定义自己的数据集合,按照对应格式自定义数据集,才可以使用DataLoader加载数据,下面是自定义样本集的整个流程。

“三步走”的策略

Pytorch输入数据PipeLine一般遵循“三步走”的策略,一般pytorch 的数据加载到模型的操作顺序是这样的:

① 创建一个 Dataset 对象。必须实现__len__()、__getitem__()这两个方法,这里面会用到transform对数据集进行扩充。

② 创建一个 DataLoader 对象。它是对DataSet对象进行迭代的,一般不需要事先里面的其他方法了。

③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练。

代码示例

1、基本格式

from torch.utils import data

class MyDataset(data.Dataset): # 需要继承data.Dataset
    def __init__(self):
        pass

    def __len__(self):
        # 获取数据集的大小
        pass

    def __getitem__(self, index):
        # 通过索引获取数据
        # 如果是图片,可以返回PIL.image、numpy数组或者Tensor
        # 如果是PIL.image,需要使用transform转化成Tensor
        pass

2、初始化和获取数据集长度

def __init__(self, root):
    self.train_data = []
    train_file = os.path.join(root, 'train/data.txt')
    with open(train_file) as f:
        for line in f:
            path, label =  line.strip().split(' ')
            file_path = os.path.join(root, 'train', path)
            self.train_data.append((file_path, label))

def __len__(self):
    # 获取数据集的大小
    return len(self.train_data)

创建文件 data.txt,并编辑文件路径和label,例如 cat/1.png 0,之后导入图片。

3、迭代读取文件
def __getitem__(self, index):
    from PIL import Image
    file_path, label = self.train_data[index]
    # 读取图片,并转化为np数组
    img = Image.open(file_path)
    return np.array(img), label

4、添加transform参数

dataset = MyDataset(file_path, transform=transforms.ToTensor())
loader = data.DataLoader(dataset, batch_size=10)

for l in loader:
    print(l)

# class MyDataset():
def __init__(self, root, transform):
    self.transform = transform

def __getitem__(self, index):
    ......
    if self.transform:
        img = self.transform(img)
    return img, label

参考文档

https://www.jianshu.com/p/2d9927a70594

https://pytorch.org/docs/stable/data.html

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

本文为 陈华 原创,欢迎转载,但请注明出处:http://www.ichenhua.cn/read/236