04_0. 图片分类数据集


%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()#使用svg来显示图片,清晰度高一些
"""我们可以通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中。"""
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,并除以255使得所有像素的数值均在0到1之间
trans=transforms.ToTensor()#首先把图片转成pytorch的tensor
mnist_train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=trans,download=True)#train=True表示是训练数据
mnist_test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=trans,download=True)
print(len(mnist_train))
len(mnist_test)
输出:
60000
10000
print(mnist_train[0][0].shape)#拿到第0个example的第一个图片;之所以是[0][0],因为mnist_train[i]表第i个例子,mnist_train[i][0]表第i个例子的图片
#mnist_train[i][1]表第i个图片的标签label。mnist_train[i]是(image,label)。
# print(mnist_train[0][0])#输出全是值的矩阵;图片的矩阵值;
mnist_train[0][1]#输出类别,此模块两个语句只是为了让你看一下。
输出:
torch.Size([1, 28, 28])
9
def get_fashion_mnist_labels(labels):
    """返回FashionMNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs,num_rows,num_cols,titles=None,scale=1.5):
    """Plot a list of images"""
    figsize=(num_cols*scale,num_rows*scale)
    _,axes=d2l.plt.subplots(num_rows,num_cols,figsize=figsize)
    axes=axes.flatten()
    for i,(ax,img) in enumerate(zip(axes,imgs)):
        if torch.is_tensor(img):
            #图片张量
            ax.imshow(img.numpy())
        else:
            #PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_xaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
X,y = next(iter(data.DataLoader(mnist_train,batch_size=18)))#y是shape为torch.Size([18])的torch.LongTensor列表,为对应标签的数值标号,
#如下pirnt所示;
print(y,y.shape,y.type())
show_images(X.reshape(18,28,28),2,9,titles=get_fashion_mnist_labels(y))#把y标签列表送给get_fashion_mnist_labels函数,得到对应的标签,
#再返回作为每个图像的title.

输出:

for i in [9, 0, 0, 3, 0, 2, 7, 2, 5, 5, 0, 9, 5, 5, 7, 9, 1, 0]:
    print(i,end=' ')#打印的只是列表中的数据
batch_size=256

def get_dataloader_workers():
    """使用4个进程来读取数据;也就是每一次读一个图片不容易,因为一般来说图片放在了硬盘上,我们可能需要多个进程来进行数据的读取、操作
    以及做一些预读取;当然具体数值你可以根据自己的CPU来定。"""
    return 4 #网友:如果在pycharm中跑的话,必须把这些代码放到main函数中,不然num_workers只能设置为0;
train_iter=data.DataLoader(mnist_train,batch_size,shuffle=True,num_workers=get_dataloader_workers())
#上述的shuffle=True,训练集打乱顺序,测试集就不一定了。

timer=d2l.Timer()
for X,y in train_iter: 
    continue
print(f'{timer.stop():.2f} sec')
#此处之所以测一下读取数据的时间,是因为经常会遇到的性能瓶颈是,很有可能你的模型是训练的挺快的,但是你的数据读取很慢/读不过来,
#所以通常我们会在训练之前,会来看一下,我们的数据读取有多快,需要读的至少要比训练要快,一般要快一些,快很多肯定是最好的。
#此处benchmark一下我们的数据读取。
输出:
---------------------------------------------------------------------------
Empty                                     Traceback (most recent call last)
D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
    989         try:
--> 990             data = self._data_queue.get(timeout=timeout)
    991             return (True, data)

D:\ProgramData\Anaconda3\envs\dlt\lib\multiprocessing\queues.py in get(self, block, timeout)
    107                     if not self._poll(timeout):
--> 108                         raise Empty
    109                 elif not self._poll():

Empty: 

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_59332/2931189052.py in 
      9 
     10 timer=d2l.Timer()
---> 11 for X,y in train_iter:
     12     continue
     13 print(f'{timer.stop():.2f} sec')

D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
   1184 
   1185             assert not self._shutdown and self._tasks_outstanding > 0
-> 1186             idx, data = self._get_data()
   1187             self._tasks_outstanding -= 1
   1188             if self._dataset_kind == _DatasetKind.Iterable:

D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in _get_data(self)
   1150         else:
   1151             while True:
-> 1152                 success, data = self._try_get_data()
   1153                 if success:
   1154                     return data

D:\ProgramData\Anaconda3\envs\dlt\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
   1001             if len(failed_workers) > 0:
   1002                 pids_str = ', '.join(str(w.pid) for w in failed_workers)
-> 1003                 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
   1004             if isinstance(e, queue.Empty):
   1005                 return (False, None)

RuntimeError: DataLoader worker (pid(s) 59456, 48440, 58580, 58408) exited unexpectedly
"""整合上述所有组件。现在我们定义load_data_fashion_mnist函数,使得之后能够重用,?于获取和读取Fashion-MNIST数据集。
它返回训练集和验证集的数据迭代器。此外,它还接受?个可选参数,?来将图像?小调整为另?种形状。"""
def load_data_fashion_mnist(batch_size, resize=None): #此处的resize是为了方便后续之后的模型可能
    #需要更大的输入,而非现在的28*28;那么我们可以再通过在下面的trans中假如resize的操作来进行调整,
    #可以把图片变得更大一些。
    """下载Fashion-MNIST数据集,然后将其加载到内存中。"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))