pytorch数据集和数据处理部分dataset自定义、继承


https://blog.csdn.net/zhenaoxi1077/article/details/80953227  

一、数据加载

在Pytorch 中,数据加载可以通过自己定义的数据集对象来实现。数据集对象被抽象为Dataset类,实现自己定义的数据集需要继承Dataset,并实现两个Python魔法方法。

__getitem__: 返回一条数据或一个样本。   obj[index]等价于obj.__getitem__(index).

__len__: 返回样本的数量。len(obj)等价于obj.__len__().
import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np

class DogCat(data.Dataset):
    def __init__(self,root):
        imgs=os.listdir(root)
        #所有图片的绝对路径
        #这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
        self.imgs=[os.path.join(root, img) for img in imgs]

    def __getitem__(self, index):
        img_path=self.imgs[index]
        #dog->1, cat->0
        label=1 if 'dog' in img_path.split("/")[-1] else 0
        pil_img=Image.open(img_path)
        array=np.asarray(pil_img)
        data=t.from_numpy(array)
        return data,label

    def __len__(self):
        return len(self.image)

dataset=DogCat('N:/百度网盘/kaggle/DogCat')
img,label=dataset[0]#相当于调用dataset.__getitem__(0)
for img,label in dataset:
    print(img.size(),img.float().mean(),label)

二、数据处理transforms

ytorch提供了torchvision。它是一个视觉工具包,提供了很多视觉图像处理的工具。
其中transforms模块提供了对PIL Image对象和Tensor对象的常用操作。

对PIL Image的常见操作如下:

(1)Scale/Resize: 调整尺寸,长宽比保持不变; #Resize
(2)CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片;
(3)Pad: 填充;
(4)ToTensor: 将PIL Image对象转换成Tensor,会自动将【0,255】归一化至【0,1】。
(5)对Tensor的常见操作如下:Normalize: 标准化,即减均值,除以标准差;ToPILImage:将Tensor转为PIL Image.

如果要对图片进行多个操作,可通过Compose将这些操作拼接起来,类似于nn.Sequential.这些操作定义之后是以对象的形式存在,真正使用时需要调用它的__call__方法,类似于nn.Mudule.

例如:要将图片调整为224*224,首先应构建操作trans=Scale((224,224)),然后调用trans(img).

import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
transforms=T.Compose([
    T.Resize(224),  #缩放图片(Image),保持长宽比不变,最短边为224像素
    T.CenterCrop(224), #从图片中间裁剪出224*224的图片
    T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1】
    T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])  #标准化至【-1,1】,规定均值和方差
])

class DogCat(data.Dataset):
    def __init__(self,root, transforms=None):
        imgs=os.listdir(root)
        self.imgs=[os.path.join(root, img) for img in imgs]
        self.transforms=transforms

    def __getitem__(self, index):
        img_path=self.imgs[index]
        #dog->1, cat->0
        label=1 if 'dog' in img_path.split("/")[-1] else 0
        data=Image.open(img_path)
        if self.transforms:
            data=self.transforms(data)
        return data,label

    def __len__(self):
        return len(self.imgs)        

dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms)
img,label=dataset[0]#相当于调用dataset.__getitem__(0)
for img,label in dataset:
    print(img.size(),label)

三、ImageFolder

下面介绍一个会经常使用到的Dataset——ImageFolder,它的实现和上述DogCat很相似。

四、DataLoader

DataLoader加载数据

Dateset只负责数据的抽象,一次调用__getitem__只返回一个样本。
在训练神经网络时,是对一个batch的数据进行操作,同时还要进行shuffle和并行加速等。
对此,pytorch提供了DataLoader帮助我们实现这些功能。

dataloader=DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)

dataloader是一个可以迭代的对象

五、sampler采样模块