数据加载
Dataset
from torch.utils.data import Dataset
- 用法:继承 Dataset,并自定义 `__init__`、`__getitem__`、`__len__` 三个方法
class MyClass(Dataset):
    """Image dataset in which every sample carries the same label.

    Every entry found in ``root_dir`` is treated as an image file. Each
    sample is returned as ``(img, label)``, where ``label`` is the
    ``label_dir`` value passed to the constructor.

    Args:
        root_dir: directory containing the image files.
        label_dir: label attached to every sample.
    """

    def __init__(self, root_dir, label_dir):
        self.root = root_dir
        self.label = label_dir
        # os.listdir returns entries in arbitrary order; sort them so the
        # index -> file mapping is reproducible across runs and machines.
        self.pic = sorted(os.listdir(self.root))

    def __getitem__(self, index):
        """Return ``(img, label)`` for the file at position ``index``.

        ``img`` is a PIL image; ``np.array(img)`` converts it to an ndarray
        laid out as (H, W, C) for multi-channel images. PNG files may load
        with an alpha channel (mode "RGBA"); use ``img.convert("RGB")`` if
        exactly three channels are required.
        """
        # os.listdir yields bare file names, so join them with the directory;
        # opening the bare name would resolve against the current working dir.
        img_path = os.path.join(self.root, self.pic[index])
        img = Image.open(img_path)
        label = self.label
        return img, label

    def __len__(self):
        """Return the number of entries found in ``root_dir``."""
        return len(self.pic)
# root_dir and label_dir are placeholders: define them first, e.g. the path
# of an image folder and the label to attach to its images.
dataset_1 = MyClass(root_dir, label_dir)
img, label = dataset_1[0]
- 补充
两个 Dataset 实例可以直接用 `+` 相加(例如 `dataset_1 + dataset_2`),得到一个 ConcatDataset,从而实现数据集的拼接。
DataLoader
from torch.utils.data import DataLoader
# DataLoader can read from a custom Dataset as well as from the ready-made
# datasets in torchvision.datasets; CIFAR10 is used as the example below.
test_data = torchvision.datasets.CIFAR10(
    root="./datasets",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    # download=False expects the files to already exist under root;
    # set it to True to fetch them on the first run.
    download=False,
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=4,
    shuffle=True,
    num_workers=0,  # 0: load batches in the main process
    drop_last=False,  # keep the final batch even if it is smaller
)
# Each `data` is one batch: the batch_size images stacked into one tensor,
# plus the batch_size corresponding labels. With ToTensor the image batch
# is expected to be (batch_size, C, H, W).
for data in test_loader:
    imgs, labels = data
    print(imgs.shape)
    print(labels)