读取MNIST数据
Python依赖库:numpy matplotlib
数据集下载地址:http://yann.lecun.com/exdb/mnist/
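数据集的数据格式:MNIST 使用 IDX 二进制格式,文件头中的整数均为大端(big-endian)32 位整数。
- 标签文件:魔数 2049(0x00000801)、标签数量,随后每个标签占 1 个无符号字节(取值 0–9)。
- 图像文件:魔数 2051(0x00000803)、图像数量、行数、列数,随后按行优先顺序存放像素,每个像素占 1 个无符号字节(取值 0–255)。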
源码实现:
import struct
import numpy as np


def unpack_mnist(filepath):
    """Parse an MNIST IDX file (labels or images) into plain Python lists.

    The kind of file is detected from the big-endian 32-bit magic number
    stored in its first four bytes: 2049 marks a label file and 2051 marks
    an image file.

    Args:
        filepath: Path of the uncompressed IDX file to read.

    Returns:
        For a label file, ``(labelNum, labels)`` where ``labels`` is a list
        of ints, one per label byte.
        For an image file, ``(imgNum, rows, columns, imgs)`` where ``imgs``
        is a list of tuples; each tuple holds the ``rows * columns`` pixel
        bytes of one image in file order.
        For any other magic number, ``input file error!`` is printed and
        ``None`` is returned.

    Raises:
        OSError: If the file cannot be opened or read.
        struct.error: If the file is shorter than its header announces.
    """
    # The context manager closes the file even when reading fails.
    with open(filepath, 'rb') as f:
        buf = f.read()

    # Identify the file type from the magic number.
    index = 0
    (magicNum,) = struct.unpack_from('>I', buf, index)
    index += struct.calcsize('>I')

    # Label file: item count followed by one unsigned byte per label.
    if magicNum == 2049:
        (labelNum,) = struct.unpack_from('>I', buf, index)
        index += struct.calcsize('>I')
        # Decode every label in a single call instead of one call per byte.
        labels = list(struct.unpack_from('>%dB' % labelNum, buf, index))
        return labelNum, labels

    # Image file: image count, rows, columns, then the raw pixel bytes.
    if magicNum == 2051:
        imgNum, rows, columns = struct.unpack_from('>III', buf, index)
        index += struct.calcsize('>III')
        # Size each image from the header rather than assuming 28x28 (784).
        pixels = struct.Struct('>%dB' % (rows * columns))
        imgs = [
            pixels.unpack_from(buf, index + i * pixels.size)
            for i in range(imgNum)
        ]
        return imgNum, rows, columns, imgs

    # Unknown magic number: report it and hand back nothing.
    print("input file error!")
    return None


if __name__ == "__main__":
    # Imported here rather than at the top so that importing this module
    # for unpack_mnist alone does not require matplotlib.
    import matplotlib.pyplot as plt

    # Read the training data.
    imgNum, rows, columns, imgs = unpack_mnist("../data/train-images.idx3-ubyte")
    labelNum, labels = unpack_mnist("../data/train-labels.idx1-ubyte")

    # Read the test data (this replaces the training data loaded above).
    imgNum, rows, columns, imgs = unpack_mnist("../data/t10k-images.idx3-ubyte")
    labelNum, labels = unpack_mnist("../data/t10k-labels.idx1-ubyte")

    # Index of the image to display as a sanity check.
    index = 5
    img = np.array(imgs[index])
    # Use the dimensions read from the file header.
    img = img.reshape(rows, columns)

    fig = plt.figure()
    plotwindow = fig.add_subplot(111)
    plt.title(str(labels[index]))
    plt.imshow(img, cmap='gray')
    plt.show()