1.简介
MNIST 数据集来自美国国家标准与技术研究所, 是NIST(National Institute of Standards and Technology)的缩小版,训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员,测试集(test set) 也是同样比例的手写数字数据.
MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 图片是以字节的形式进行存储,它包含了四个部分:
- Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
- Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
- Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
- Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
此数据集中,训练样本:共60000个,其中55000个用于训练,另外5000个用于验证。测试样本:共10000个,验证数据比例相同。
数据集中像素值
a)使用python读取二进制文件方法读取mnist数据集,则读进来的图像像素值为0-255之间;标签是0-9的数值。
b)采用TensorFlow的封装的函数读取mnist,则读进来的图像像素值为0-1之间;标签是0-1值组成的大小为1*10的行向量。
2.读取mnist到numpy
load_mnist 函数返回两个数组, 第一个是一个 n x m 维的 NumPy array(images), 这里的 n 是样本数(行数), m 是特征数(列数). 训练数据集包含 60,000 个样本, 测试数据集包含 10,000 样本.
在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示. 在这里, 我们将 28 x 28 的像素展开为一个一维的行向量, 这些行向量就是图片数组里的行(每行 784 个值, 或者说每行就是代表了一张图片).
load_mnist 函数返回的第二个数组(labels) 包含了相应的目标变量, 也就是手写数字的类标签(整数 0-9).
1 | mport os |
1 | 1. |
3.查看tensorflow集成的mnist
1 | from tensorflow.examples.tutorials.mnist import input_data |
4.可视化
4.1 plt的方法
从 feature matrix 中将 784-像素值 的向量 reshape 为之前的 28*28 的形状, 然后通过 matplotlib 的 imshow 函数进行绘制,不能进行one-hot编码:
- 读单个图片
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29import matplotlib.pyplot as plt
#from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
#mnist = read_data_sets('MNIST_data', one_hot=False)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./MNIST_data",one_hot=False)
x, y = mnist.test.next_batch(1)
x = x.reshape([28, 28])
fig = plt.figure()
# Method1
ax1 = fig.add_subplot(221)
ax1.imshow(x, cmap=plt.cm.gray)
# Method2: 反转色
ax2 = fig.add_subplot(222)
ax2.imshow(x, cmap=plt.cm.gray_r) # r表示reverse
# Method3(等价于Method1)
ax3 = fig.add_subplot(223)
ax3.imshow(x, cmap='gray')
# Method4(等价于Method2)
ax4 = fig.add_subplot(224)
ax4.imshow(x, cmap='gray_r')
plt.show() - 读多个图片
1 | import matplotlib.pyplot as plt |
4.2 torchvision&scipy方法
其实数据集里的图片就是一个带有像素值的二维数组,可以画出这个数组的库有很多。包括机器学习库torch,其中的torchvision也可以。具体方法如下:
1 | import torchvision |
结果输出如下: