从零开始大模型开发与微调:基于PyTorch与ChatGLM
上QQ阅读APP看书,第一时间看更新

2.4.2 MNIST数据集的特征和标签介绍

对于数据库的获取,前面介绍了两种不同的MNIST数据集的获取方式,本小节推荐使用本书配套源码包中的MNIST数据集进行数据的读取,代码如下:

     import numpy as np
     x_train = np.load("./dataset/mnist/x_train.npy")
     y_train_label = np.load("./dataset/mnist/y_train_label.npy")

这里numpy库函数会根据输入的地址对数据进行处理,并自动将其分解成训练集和验证集。打印训练集的维度如下:

     (60000, 28, 28)
     (60000, )

这是进行数据处理的第一步,有兴趣的读者可以进一步完成数据的训练集和测试集的划分。

回到MNIST数据集,每个MNIST实例数据单元也是由两部分构成的,分别是一幅包含手写数字的图片和一个与其相对应的标签。可以将其中的标签特征设置成y,而图片特征矩阵以x来代替,所有的训练集和测试集中都包含x和y。

图2-30用更为一般化的形式解释了MNIST数据实例的展开形式。在这里,图片数据被展开成矩阵的形式,矩阵的大小为28×28。至于如何处理这个矩阵,常用的方法是将其展开,而展开的方式和顺序并不重要,只需要将其按同样的方式展开即可。

图2-30 图片转换为向量模式

下面回到对数据的读取,前面已经介绍了,MNIST数据集实际上就是一个包含着60 000幅图片的60 000×28×28大小的矩阵张量[60000,28,28],如图2-31所示。

图2-31 MNIST数据集的矩阵表示

矩阵中行数指的是图片的索引,用以对图片进行提取,而后面的28×28个向量用以对图片特征进行标注。实际上,这些特征向量就是图片中的像素点,每幅手写图片是[28,28]的大小,每个像素转化为一个0~1的浮点数,构成矩阵。