读取MNIST数据


Python依赖库:numpy matplotlib

数据集下载地址:http://yann.lecun.com/exdb/mnist/

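数据集的数据格式:

- 文件头中的所有整数均为大端(高位在前)的 32 位无符号整数。
- 标签文件(idx1-ubyte):魔数 2049、标签数量,随后每个标签占 1 个无符号字节。
- 图像文件(idx3-ubyte):魔数 2051、图像数量、行数、列数,随后每个像素占 1 个无符号字节,按行依次存储。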

源码实现:

import struct

import numpy as np
import matplotlib.pyplot as plt

def unpack_mnist(filepath):
    """Parse an MNIST IDX file (labels or images) into Python lists.

    The file type is detected from the big-endian 32-bit magic number at
    the start of the file: 2049 selects the label layout, 2051 selects the
    image layout.

    Args:
        filepath: Path of the IDX file to read.

    Returns:
        For a label file, ``(labelNum, labels)`` where ``labels`` is a list
        of ints. For an image file, ``(imgNum, rows, columns, imgs)`` where
        ``imgs`` is a list of tuples, each holding ``rows * columns`` pixel
        values. ``None`` (after printing "input file error!") when the
        magic number is neither 2049 nor 2051.

    Raises:
        struct.error: If the file holds fewer bytes than its header
            announces.
    """
    # The context manager closes the file even if reading raises.
    with open(filepath, 'rb') as f:
        buf = f.read()

    # Identify the file type from the magic number.
    index = 0
    (magicNum, ) = struct.unpack_from('>I', buf, index)
    index += struct.calcsize('>I')

    # Label file: item count followed by one unsigned byte per label.
    if magicNum == 2049:
        (labelNum, ) = struct.unpack_from('>I', buf, index)
        index += struct.calcsize('>I')
        # Unpack every label in one call instead of one byte at a time.
        labels = list(struct.unpack_from('>%dB' % labelNum, buf, index))
        return labelNum, labels

    # Image file: count, rows, columns, then one unsigned byte per pixel.
    if magicNum == 2051:
        imgNum, rows, columns = struct.unpack_from('>III', buf, index)
        index += struct.calcsize('>III')
        # Size each image from the header rather than assuming 28x28.
        imgFormat = '>%dB' % (rows * columns)
        imgSize = struct.calcsize(imgFormat)
        imgs = [
            struct.unpack_from(imgFormat, buf, index + i * imgSize)
            for i in range(imgNum)
        ]
        return imgNum, rows, columns, imgs

    # Unknown magic number: report it and return None, as before.
    print("input file error!")
    return None


# Load the training set.
imgNum, rows, columns, imgs = unpack_mnist("../data/train-images.idx3-ubyte")
labelNum, labels = unpack_mnist("../data/train-labels.idx1-ubyte")

# Load the test set. NOTE: this rebinds the same names, so from here on
# imgs/labels refer to the test data rather than the training data.
imgNum, rows, columns, imgs = unpack_mnist("../data/t10k-images.idx3-ubyte")
labelNum, labels = unpack_mnist("../data/t10k-labels.idx1-ubyte")

# Index of the sample to display.
index = 5

# Turn the flat pixel tuple into a 2-D array using the dimensions read from
# the image file header.
img = np.array(imgs[index])
img = img.reshape(rows, columns)

# Show the sample in grayscale with its label as the title.
fig = plt.figure()
plotwindow = fig.add_subplot(111)
plt.title(str(labels[index]))
plt.imshow(img, cmap='gray')
plt.show()