#3.6.1 MNIST数据集
MNIST数据集是由0到9的数字图像构成的。训练图像有6万张,测试图像有1万张,这些图像可以用于学习和推理。MNIST数据集的一般使用方法是,先用训练图像进行学习,再用学习到的模型度量能在多大程度上对测试图像进行正确的分类。
可以按下述方式读入MNIST数据。
import sys,os
sys.path.append("D:\学习资料\深度学习")
from mnist import load_mnist
(x_train,t_train),(x_test,t_test)=load_mnist(flatten=True,normalize=False)#第一次调用花费的时间较长
输出各个数据的形状
print(x_train.shape)#(60000, 784)print(t_train.shape)#(60000,)print(x_test.shape)#(10000, 784)print(t_test.shape)#(10000,)
load_mnist函数以“(训练图像 ,训练标签 ),(测试图像,测试标签 )”的形式返回读入的MNIST数据。此外,还可以像load_mnist(normalize=True, flatten=True, one_hot_label=False) 这 样,设 置 3 个 参 数。第 1 个参数normalize设置是否将输入图像正规化为0.0~1.0的值。如果将该参数设置为False,则输入图像的像素会保持原来的0~255。第2个参数flatten设置是否展开输入图像(变成一维数组)。如果将该参数设置为False,则输图
像为1 × 28 × 28的三维数组;若设置为True,则输入图像会保存为由784个元素构成的一维数组。第3个参数one_hot_label设置是否将标签保存为one-hot表示(one-hot representation)。one-hot表示是仅正确解标签为1,其余皆为0的数组,就像[0,0,1,0,0,0,0,0,0,0]这样。当one_hot_label为False时,只是像7、2这样简单保存正确解标签;当one_hot_label为True时,标签则保存为one-hot表示。
现在,我们试着显示MNIST图像,同时也确认一下数据。图像的显示使用PIL(Python Image Library)模块。执行下述代码后,训练图像的第一张就会显示出来。
import sys,os
sys.path.append("D:\学习资料\深度学习")
import numpy as np
from mnist import load_mnist
from PIL import Imagedef img_show(img):pil_img = Image.fromarray(np.uint8(img))pil_img.show()(x_train,t_train),(x_test,t_test)=load_mnist(flatten=True,normalize=False)
img = x_train[0]
label=t_train[0]
print(label)print(img.shape)
img = img.reshape(28,28)
print(img.shape)img_show(img)
这里需要注意的是,flatten=True时读入的图像是以一列(一维)NumPy数组的形式保存的。因此,显示图像时,需要把它变为原来的28像素 × 28像素的形状。可以通过reshape()方法的参数指定期望的形状,更改NumPy数组的形状。此外,还需要把保存为NumPy数组的图像数据转换为PIL用的数据对象,这个转换处理由Image.fromarray()来完成。