Pytorch学习记录(一)数据加载


数据加载

Dataset

  • 导入Dataset
from torch.utils.data import Dataset
  • 用法(继承Dataset,其中,init、getitem、len需要自定义)
class MyClass(Dataset):
    def __init__ (self, root_dir, label_dir):
        self.root = root_dir
        self.label = label_dir
        self.pic = os.listdir(self.root)
    def __getitem__ (self, index):
        # 注:img的格式为PIL,维度为H W C
        # 如需转为ndarray,可以使用np.array(img)转换
        # 注意!PNG读到是四通道,要用img = img.convert("RGB")转换为三通道
        img = Image.open(self.pic[index])
        label = self.label
        return img, label
    def __len__ (self):
        return len(items)
  • 实例化
dataset_1 = Myclass(root_dir, label_dir)
img, label = dataset_1[0]
  • 补充
    相同的dataset实例可以进行“加法”操作,实现数据集的拼接。

DataLoader

  • 导入DataLoader
from torch.utils.data import DataLoader
  • 用法
# DataLoader可以从自定义Datasets中读取数据,也可以从torchvision.datasets中读取数据
# 下面以CIFAR10为例
test_data = torchvision.datasets.CIFAR10(
    root = "./datasets",
    train = False,
    transform = torchvision.transforms.ToTensor(),
    download = False)

test_loader = DataLoader(
    dataset = test_data,
    batch_size = 4,
    shuffle = True,
    numworkers = 0,
    drop_last = False
)
  • DataLoader按包读取data
# datas是一个包,存放batch_size个img和label
for data in test_loader:
    imgs, labels = data
    print(imgs.shape)
    print(labels)