python处理MNIST数据集

您所在的位置:网站首页 tensorflow处理多张图片 python处理MNIST数据集

python处理MNIST数据集

2023-09-12 16:52| 来源: 网络整理| 查看: 265

1. MNIST数据集 1.1 MNIST数据集获取

MNIST数据集是入门机器学习/模式识别的最经典数据集之一。最早于1998年Yan Lecun在论文:

Gradient-based learning applied to document recognition.

中提出。经典的LeNet-5 CNN网络也是在该论文中提出的。 数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图。每张图片中像素值大小在0-255之间,其中0是黑色背景,255是白色前景。如下图所示:

MNIST共包含70000张手写数字图片,其中有60000张用作训练集,10000张用作测试集。原始数据集可在MNIST官网下载。

下载之后得到4个压缩文件:

train-images-idx3-ubyte.gz #60000张训练集图片 train-labels-idx1-ubyte.gz #60000张训练集图片对应的标签 t10k-images-idx3-ubyte.gz #10000张测试集图片 t10k-labels-idx1-ubyte.gz #10000张测试集图片对应的标签

将其解压,得到

train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte 1.2 MNIST二进制文件的存储格式

解压得到的四个文件都是二进制格式,我们如何获取其中的信息呢?这得首先了解MNIST二进制文件的存储格式(官网底部有介绍),以训练集图像文件train-images-idx3-ubyte为例:

图像文件的

第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051; 第5-8个byte存的是number of images,即图像数量60000; 第9-12个byte存的是每张图片行数/高度,即28; 第13-16个byte存的是每张图片的列数/宽度,即28。 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

因为train-images-idx3-ubyte文件总共包含了60000张图片数据,按照以上的存储方式,我们算一下该文件的大小:

一张图片包含28x28=784个像素点,需要784bytes的存储空间; 60000张图片则需要784x60000=47040000 bytes的存储空间; 此外,文件开始处使用了16个bytes用于存储magic number、图像数量、图像高度和图像宽度,因此,训练集图像文件的大小应该是47040000+16=47040016 bytes。

我们查看解压后的train-images-idx3-ubyte文件的属性:

文件实际大小和我们计算的结果一致。

类似地,我们查看训练集标签文件train-labels-idx1-ubyte的存储格式:

和图像文件类似:

第1-4个byte存的是文件的magic number,对应的十进制大小是2049; 第5-8个byte存的是number of items,即label数量60000; 从第9个byte开始,每个byte存一个图片的label信息,即数字0-9中的一个。

计算一下训练集标签文件train-labels-idx1-ubyte的文件大小:

1x60000+8=60008 bytes。

与该文件实际的大小一致:

另外两个文件,即测试集图像文件、测试集标签文件的存储方式和训练图像文件、训练标签文件相似,只是图像数量由60000变为10000。

1.3 使用python访问MNIST数据集文件内容

知道了MNIST二进制文件的存储方式,下面介绍如何使用python访问文件内容。同样以训练集图像文件train-images-idx3-ubyte为例:

首先,使用open()函数打开文件,并使用read()方法将所有的文件数据读入到一个字符串中:

yan@yanubuntu:~/codes/Deep-Learning-21-Examples/chapter_1/MNIST_data$ python Python 2.7.12 (default, Nov 12 2018, 14:36:49) [GCC 5.4.0 20160609] on linux2 Type "help", "copyright", "credits" or "license" for more information. >>> with open('train-images.idx3-ubyte', 'rb') as f: ... file = f.read() ... >>>

file是str类型,其中的每个元素就存储的1个字节的内容。我们现在查看前4个字节,即magic number的内容,看下是否是前面说的2051:

>>> magic_number=file[:4] >>> magic_number '\x00\x00\x08\x03' >>> magic_number.encode('hex') '00000803' >>> int(magic_number.encode('hex'),16) 2051

可以看出前4个byte的值确实是2051,但是不能直接输出magic number的内容,需要将其编码,然后才能转成十进制的int类型(有关字节编码的知识暂时没懂,先用着)。 同样的方式,查看图像数量、图像高度和图像宽度信息:

>>> num_images = int(file[4:8].encode('hex'),16) >>> num_images 60000 >>> h_image = int(file[8:12].encode('hex'),16) >>> h_image 28 >>> w_image = int(file[12:16].encode('hex'),16) >>> w_image 28

现在获取第1张图片的像素信息,然后利用numpy和cv2模块转换其格式,并保存成.jpg格式的图片:

>>> image1 = [int(item.encode('hex'), 16) for item in file[16:16+784]] >>> len(image1) 784 >>> import numpy as np >>> import cv2 >>> image1_np = np.array(image1, dtype=np.uint8).reshape(28,28,1) >>> image1_np.shape (28, 28, 1) >>> cv2.imwrite('image1.jpg', image1_np) True >>>

保存下来的图片image1.jpg如下图所示:

该图片的标签是5,我们可以验证一下训练集标签文件train-labels-idx1-ubyte文件的第一个标签是否和图像内容一一对应:

>>> with open('train-labels.idx1-ubyte', 'rb') as f: ... label_file = f.read() ... >>> label1 = int(label_file[8].encode('hex'), 16) >>> label1 5 >>>

训练标签文件的第一张图片标签是第9个byte(索引从0开始,所以第9个byte是label_file[8]),结果没问题。

1.4 将MNIST数据集保存成.jpg图片格式

因为使用上面得到的file和label_file文件是str类型,因此可以使用迭代的方式,将所有训练和测试集的二进制文件格式转成.jpg图片格式。转换脚本mnist2jpg.py如下:

# coding=utf-8 '''将二进制格式的MNIST数据集转成.jpg图片格式并保存,图片标签包含在图片名中''' import numpy as np import cv2 import os def save_mnist_to_jpg(mnist_image_file, mnist_label_file, save_dir): if 'train' in os.path.basename(mnist_image_file): num_file = 60000 prefix = 'train' else: num_file = 10000 prefix = 'test' with open(mnist_image_file, 'rb') as f1: image_file = f1.read() with open(mnist_label_file, 'rb') as f2: label_file = f2.read() image_file = image_file[16:] label_file = label_file[8:] for i in range(num_file): label = int(label_file[i].encode('hex'), 16) image_list = [int(item.encode('hex'), 16) for item in image_file[i*784:i*784+784]] image_np = np.array(image_list, dtype=np.uint8).reshape(28,28,1) save_name = os.path.join(save_dir, '{}_{}_{}.jpg'.format(prefix, i, label)) cv2.imwrite(save_name, image_np) print '{} ==> {}_{}_{}.jpg'.format(i, prefix, i, label) if __name__ == '__main__': train_image_file = './train-images.idx3-ubyte' train_label_file = './train-labels.idx1-ubyte' test_image_file = 't10k-images.idx3-ubyte' test_label_file = './t10k-labels.idx1-ubyte' save_train_dir = './train_images/' save_test_dir ='./test_images/' if not os.path.exists(save_train_dir): os.makedirs(save_train_dir) if not os.path.exists(save_test_dir): os.makedirs(save_test_dir) save_mnist_to_jpg(train_image_file, train_label_file, save_train_dir) save_mnist_to_jpg(test_image_file, test_label_file, save_test_dir) 2. Tensorflow处理MNIST数据集的方式

上面读取MNIST的代码可能效率不高,Tensorflow库中专门有处理MNIST数据集的API接口,源代码涉及到几个python文件,我将其整理到一个read_mnist.py文件中:

# coding=utf-8 """Tensorflow中用于读取MNIST数据集的简化代码""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import collections import gzip import numpy #带名字的tuple,方便使用train/validation/test区分不同数据集 Datasets = collections.namedtuple('Datasets', ['train', 'validation', 'test']) def _read32(bytestream): #numpy.dtype.newbyteorder()函数返回一种指定字节序的dtype #参数('>')表示big endian的字节序,MNIST官网底部有提到MNIST二进制数据使用这种字节序,具体不太懂 dt = numpy.dtype(numpy.uint32).newbyteorder('>') #numpy.frombuffer()一次读取bytestream中的4个byte,返回一个一维数组,所以需要使用索引[0]取其中的元素 return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] #从网上下载的mnist数据集图像文件(.gz)中读取数据 #返回一个4-D、np.uint8类型的ndarray,shape=[num_iamges, h, w, channels] # 比如针对训练集图像文件,返回值shape=[60000, 28, 28, 1] def extract_images(f): """Extract the images into a 4D uint8 numpy array [index, y, x, depth]. Args: f: A file object that can be passed into a gzip reader. Returns: data: A 4D uint8 numpy array [index, y, x, depth]. Raises: ValueError: If the bytestream does not start with 2051. """ print('Extracting', f.name) with gzip.GzipFile(fileobj=f) as bytestream: #图像文件的前4个byte记录magic number,数值是2051 magic = _read32(bytestream) if magic != 2051: raise ValueError('Invalid magic number %d in MNIST image file: %s' % (magic, f.name)) #图像文件的第4-8个byte记录的是图像数量,训练集文件是60000,测试集文件是10000 num_images = _read32(bytestream) #图像文件的第8-12个byte记录的是图像的h rows = _read32(bytestream) #图像文件的第12-16个byte记录的是图像的w cols = _read32(bytestream) #文件剩下的内容记录所有图片中的像素值,将其全部读取到一维数组data中,dtype=np.uint8 buf = bytestream.read(rows * cols * num_images) data = numpy.frombuffer(buf, dtype=numpy.uint8) #reshape成[num_images, h, w, channels] data = data.reshape(num_images, rows, cols, 1) return data #将label转成one-hot向量 def dense_to_one_hot(labels_dense, num_classes): """Convert class labels from scalars to one-hot vectors.""" num_labels = labels_dense.shape[0] index_offset = numpy.arange(num_labels) * num_classes labels_one_hot = numpy.zeros((num_labels, num_classes)) labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 return labels_one_hot #读取从网上下载的mnist数据集标签文件的内容 #如果one_hot=True: 返回一个2-D、shape=[num_images, 10]、dtype=np.uint8的ndarray #如果one_hot=False: 返回一个1-D、shape=[num_images]、dtype=np.uint8的ndarray def extract_labels(f, one_hot=False, num_classes=10): """Extract the labels into a 1D uint8 numpy array [index]. Args: f: A file object that can be passed into a gzip reader. one_hot: Does one hot encoding for the result. num_classes: Number of classes for the one hot encoding. Returns: labels: a 1D uint8 numpy array. Raises: ValueError: If the bystream doesn't start with 2049. """ print('Extracting', f.name) with gzip.GzipFile(fileobj=f) as bytestream: magic = _read32(bytestream) if magic != 2049: raise ValueError('Invalid magic number %d in MNIST label file: %s' % (magic, f.name)) num_items = _read32(bytestream) buf = bytestream.read(num_items) labels = numpy.frombuffer(buf, dtype=numpy.uint8) if one_hot: return dense_to_one_hot(labels, num_classes) return labels class DataSet(object): def __init__(self, images, labels, dtype=numpy.float32, reshape=True): assert images.shape[0] == labels.shape[0], ( 'images.shape: %s labels.shape: %s' % (images.shape, labels.shape)) self._num_examples = images.shape[0] # Convert shape from [num examples, rows, columns, depth] # to [num examples, rows*columns] (assuming depth == 1) if reshape: assert images.shape[3] == 1 images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]) if dtype == numpy.float32: # Convert from [0, 255] -> [0.0, 1.0]. images = images.astype(numpy.float32) images = numpy.multiply(images, 1.0 / 255.0) self._images = images self._labels = labels self._epochs_completed = 0 self._index_in_epoch = 0 @property def images(self): return self._images @property def labels(self): return self._labels @property def num_examples(self): return self._num_examples @property def epochs_completed(self): return self._epochs_completed def next_batch(self, batch_size, shuffle=True): """Return the next `batch_size` examples from this data set.""" start = self._index_in_epoch # Shuffle for the first epoch if self._epochs_completed == 0 and start == 0 and shuffle: perm0 = numpy.arange(self._num_examples) numpy.random.shuffle(perm0) self._images = self.images[perm0] self._labels = self.labels[perm0] # Go to the next epoch if start + batch_size > self._num_examples: # Finished epoch self._epochs_completed += 1 # Get the rest examples in this epoch rest_num_examples = self._num_examples - start images_rest_part = self._images[start:self._num_examples] labels_rest_part = self._labels[start:self._num_examples] # Shuffle the data if shuffle: perm = numpy.arange(self._num_examples) numpy.random.shuffle(perm) self._images = self.images[perm] self._labels = self.labels[perm] # Start next epoch start = 0 self._index_in_epoch = batch_size - rest_num_examples end = self._index_in_epoch images_new_part = self._images[start:end] labels_new_part = self._labels[start:end] return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0) else: self._index_in_epoch += batch_size end = self._index_in_epoch return self._images[start:end], self._labels[start:end] def read_data_sets(mnist_dir, one_hot=False, dtype=numpy.float32, reshape=True, validation_size=5000): '''读取MNIST数据集 Args: mnist_dir: 存放4个MNIST数据集压缩文件的文件夹,数据集文件从网址http://yann.lecun.com/exdb/mnist/下载 one_hot: 如果one_hot=True, 返回的labels是one_hot编码 reshape: 如果reshape=True,返回的images将展开成784维的向量 Return: 一个Datasets对象,是一个namedtuple: Datasets.train包含训练集数据 Datasets.validation包含验证集数据 Datasets.test包含测试集数据 ''' TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' TEST_IMAGES = 't10k-images-idx3-ubyte.gz' TEST_LABELS = 't10k-labels-idx1-ubyte.gz' #读取训练集图像,train_images.shape=[60000, 28, 28, 1], dtype=np.uint8 local_file = os.path.join(mnist_dir, TRAIN_IMAGES) with open(local_file, 'rb') as f: train_images = extract_images(f) #读取训练集标签,如果one_hot=False, train_labels.shape=[60000,] #如果one_hot=True, train_labels.shape=[60000,10] local_file = os.path.join(mnist_dir, TRAIN_LABELS) with open(local_file, 'rb') as f: train_labels = extract_labels(f, one_hot=one_hot) local_file = os.path.join(mnist_dir, TEST_IMAGES) with open(local_file, 'rb') as f: test_images = extract_images(f) local_file = os.path.join(mnist_dir, TEST_LABELS) with open(local_file, 'rb') as f: test_labels = extract_labels(f, one_hot=one_hot) if not 0


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3