04_0. 图片分类数据集
%matplotlib inline import torch import torchvision from torch.utils import data from torchvision import transforms from d2l import torch as d2l d2l.use_svg_display()#使用svg来显示图片,清晰度高一些
"""我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。""" # 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,并除以255使得所有像素的数值均在0到1之间 trans=transforms.ToTensor()#首先把图片转成pytorch的tensor mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)#train=True表示是训练数据 mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True) print(len(mnist_train)) len(mnist_test) 输出: 60000 10000
print(mnist_train[0][0].shape)#拿到第0个example的第一个图片;之所以是[0][0],因为mnist_train[i]表第i个例子,mnist_train[i][0]表第i个例子的图片 #mnist_train[i][1]表第i个图片的标签label。mnist_train[i]是(image,label)。 # print(mnist_train[0][0])#输出全是值的矩阵;图片的矩阵值; mnist_train[0][1]#输出类别,此模块两个语句只是为了让你看一下。 输出: torch.Size([1, 28, 28]) 9
def get_fashion_mnist_labels(labels): """返回FashionMNIST数据集的文本标签""" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels] def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5): """Plot a list of images""" figsize=(num_cols*scale,num_rows*scale) _,axes=d2l.plt.subplots(num_rows,num_cols,figsize=figsize) axes=axes.flatten() for i,(ax,img) in enumerate(zip(axes,imgs)): if torch.is_tensor(img): #图片张量 ax.imshow(img.numpy()) else: #PIL图片 ax.imshow(img) ax.axes.get_xaxis().set_visible(False) ax.axes.get_xaxis().set_visible(False) if titles: ax.set_title(titles[i]) return axes
X,y = next(iter(data.DataLoader(mnist_train,batch_size=18)))#y是shape为torch.Size([18])的torch.LongTensor列表,为对应标签的数值标号, #如下pirnt所示; print(y,y.shape,y.type()) show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))#把y标签列表送给get_fashion_mnist_labels函数,得到对应的标签, #再返回作为每个图像的title.
输出:
for i in [9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0]: print(i,end=' ')#打印的只是列表中的数据
batch_size=256 def get_dataloader_workers(): """使用4个进程来读取数据;也就是每一次读一个图片不容易,因为一般来说图片放在了硬盘上,我们可能需要多个进程来进行数据的读取、操作 以及做一些预读取;当然具体数值你可以根据自己的CPU来定。""" return 4 #网友:如果在pycharm中跑的话,必须把这些代码放到main函数中,不然num_workers只能设置为0; train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers()) #上述的shuffle=True,训练集打乱顺序,测试集就不一定了。 timer=d2l.Timer() for X,y in train_iter: continue print(f'{timer.stop():.2f} sec') #此处之所以测一下读取数据的时间,是因为经常会遇到的性能瓶颈是,很有可能你的模型是训练的挺快的,但是你的数据读取很慢/读不过来, #所以通常我们会在训练之前,会来看一下,我们的数据读取有多快,需要读的至少要比训练要快,一般要快一些,快很多肯定是最好的。 #此处benchmark一下我们的数据读取。
输出: --------------------------------------------------------------------------- Empty Traceback (most recent call last) D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout) 989 try: --> 990 data = self._data_queue.get(timeout=timeout) 991 return (True, data) D:\ProgramData\Anaconda3\envs\dlt\lib\multiprocessing\queues.py in get(self, block, timeout) 107 if not self._poll(timeout): --> 108 raise Empty 109 elif not self._poll(): Empty: The above exception was the direct cause of the following exception: RuntimeError Traceback (most recent call last) ~\AppData\Local\Temp/ipykernel_59332/2931189052.py in9 10 timer=d2l.Timer() ---> 11 for X,y in train_iter: 12 continue 13 print(f'{timer.stop():.2f} sec') D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in __next__(self) 519 if self._sampler_iter is None: 520 self._reset() --> 521 data = self._next_data() 522 self._num_yielded += 1 523 if self._dataset_kind == _DatasetKind.Iterable and \ D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self) 1184 1185 assert not self._shutdown and self._tasks_outstanding > 0 -> 1186 idx, data = self._get_data() 1187 self._tasks_outstanding -= 1 1188 if self._dataset_kind == _DatasetKind.Iterable: D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in _get_data(self) 1150 else: 1151 while True: -> 1152 success, data = self._try_get_data() 1153 if success: 1154 return data D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout) 1001 if len(failed_workers) > 0: 1002 pids_str = ', '.join(str(w.pid) for w in failed_workers) -> 1003 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e 1004 if isinstance(e, queue.Empty): 1005 return (False, None) RuntimeError: DataLoader worker (pid(s) 59456, 48440, 58580, 58408) exited unexpectedly
"""整合上述所有组件。现在我们定义load_data_fashion_mnist函数,使得之后能够重用,?于获取和读取Fashion-MNIST数据集。 它返回训练集和验证集的数据迭代器。此外,它还接受?个可选参数,?来将图像?小调整为另?种形状。""" def load_data_fashion_mnist(batch_size, resize=None): #此处的resize是为了方便后续之后的模型可能 #需要更大的输入,而非现在的28*28;那么我们可以再通过在下面的trans中假如resize的操作来进行调整, #可以把图片变得更大一些。 """下载Fashion-MNIST数据集,然后将其加载到内存中。""" trans = [transforms.ToTensor()] if resize: trans.insert(0, transforms.Resize(resize)) trans = transforms.Compose(trans) mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True) return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()), data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))