pytorch数据集和数据处理部分dataset自定义、继承
https://blog.csdn.net/zhenaoxi1077/article/details/80953227
一、数据加载
在Pytorch 中,数据加载可以通过自己定义的数据集对象来实现。数据集对象被抽象为Dataset类,实现自己定义的数据集需要继承Dataset,并实现两个Python魔法方法。
__getitem__: 返回一条数据或一个样本。 obj[index]等价于obj.__getitem__(index). __len__: 返回样本的数量。len(obj)等价于obj.__len__().
import torch as t from torch.utils import data import os from PIL import Image import numpy as np class DogCat(data.Dataset): def __init__(self,root): imgs=os.listdir(root) #所有图片的绝对路径 #这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片 self.imgs=[os.path.join(root, img) for img in imgs] def __getitem__(self, index): img_path=self.imgs[index] #dog->1, cat->0 label=1 if 'dog' in img_path.split("/")[-1] else 0 pil_img=Image.open(img_path) array=np.asarray(pil_img) data=t.from_numpy(array) return data,label def __len__(self): return len(self.image) dataset=DogCat('N:/百度网盘/kaggle/DogCat') img,label=dataset[0]#相当于调用dataset.__getitem__(0) for img,label in dataset: print(img.size(),img.float().mean(),label)
二、数据处理transforms
ytorch提供了torchvision。它是一个视觉工具包,提供了很多视觉图像处理的工具。
其中transforms模块提供了对PIL Image对象和Tensor对象的常用操作。
对PIL Image的常见操作如下:
(1)Scale/Resize: 调整尺寸,长宽比保持不变; #Resize
(2)CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片;
(3)Pad: 填充;
(4)ToTensor: 将PIL Image对象转换成Tensor,会自动将【0,255】归一化至【0,1】。
(5)对Tensor的常见操作如下:Normalize: 标准化,即减均值,除以标准差;ToPILImage:将Tensor转为PIL Image.
如果要对图片进行多个操作,可通过Compose将这些操作拼接起来,类似于nn.Sequential.这些操作定义之后是以对象的形式存在,真正使用时需要调用它的__call__方法,类似于nn.Mudule.
例如:要将图片调整为224*224,首先应构建操作trans=Scale((224,224)),然后调用trans(img).
import os from PIL import Image import numpy as np from torchvision import transforms as T transforms=T.Compose([ T.Resize(224), #缩放图片(Image),保持长宽比不变,最短边为224像素 T.CenterCrop(224), #从图片中间裁剪出224*224的图片 T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1】 T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至【-1,1】,规定均值和方差 ]) class DogCat(data.Dataset): def __init__(self,root, transforms=None): imgs=os.listdir(root) self.imgs=[os.path.join(root, img) for img in imgs] self.transforms=transforms def __getitem__(self, index): img_path=self.imgs[index] #dog->1, cat->0 label=1 if 'dog' in img_path.split("/")[-1] else 0 data=Image.open(img_path) if self.transforms: data=self.transforms(data) return data,label def __len__(self): return len(self.imgs) dataset=DogCat('N:/百度网盘/kaggle/DogCat/', transforms=transforms) img,label=dataset[0]#相当于调用dataset.__getitem__(0) for img,label in dataset: print(img.size(),label)
三、ImageFolder
下面介绍一个会经常使用到的Dataset——ImageFolder,它的实现和上述DogCat很相似。
四、DataLoader
DataLoader加载数据
Dateset只负责数据的抽象,一次调用__getitem__
只返回一个样本。
在训练神经网络时,是对一个batch的数据进行操作,同时还要进行shuffle和并行加速等。
对此,pytorch
提供了DataLoader
帮助我们实现这些功能。
dataloader=DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataloader是一个可以迭代的对象