当前位置: 首页 > news >正文

digit_eye开发记录(2): Python读取MNIST数据集

在上一篇博客 digit_eye开发记录(1): C++读取MNIST数据集 中解读了 IDX 文件格式,并使用 C++ 语言完成了 MNIST 数据集的解析,第6小节给出的完整代码有146行之多。使用 Python 读取则可以省略70%的代码,只用不到50行代码完成相同功能。

读取 buffer

np.frombuffer(buf, dtype, count, offset)

说明:

  • buf: buffer,从文件读出来的
  • dtype: 从buf读取时,按什么类型读取数据,或者说,读取的基本单位是什么
  • count: 从buf读取时,读取多少个基本单位
  • offset: 从buf读取时,指针首先偏移多少个字节

读取 magic number

magic number 是 mnist 文件的前4个字节。 以二进制形式打开后,读取4字节即可:

import numpy as npwith open(filename, 'rb') as fin:buf = bytearray(fin.read())
magic = np.frombuffer(buf, np.uint8, count=4)
print(magic)

读取维度信息

回忆一下 magic numbers 的构成: 前两个字节是0,第三个字节是类型,第四个字节是维度数量 num_dims。
根据 num_dims 的取值,读取对应数量的字节,得到对应的维度信息。每个维度都是一个 int32 大小。

注意 MSB 到 LSB 的转换,通过 dtype=np.dtype('>u4') 指定, >u4 意思是:以MSB序,读取4个byte.

对于图像数据:

num_dims = magic[3]
dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)

对于label数据:

dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)
num_labels = dims[0]

读取图像像素

很容易想到使用 OOP 方式,定义 DataSet 类,在成员 self.images 中保存图像;于是乎,很“毛躁”的写出如下糟糕代码:

class DataSet:def __init__(self):self.images = []self.labels = []def load_images(self, filename):...for i in range(num_images):self.images.append(...)

存在的问题:

  • self.images 的类型一定是 list 吗?其实可以是 numpy 数组
  • self.images 的每个元素,和其他元素,一定是独立的吗? 可以是同一个内存上连续的分布
  • self.images 的每个元素,内存可以和读取文件得到的 buffer 复用吗?可以!
class DataSet:def __init__(self):self.images = Noneself.labels = Noneself.buf = Nonedef load_images(self, filename):with open(filename, 'rb') as fin:self.buf = bytearray(fin.read())magic = np.frombuffer(self.buf, np.uint8, count=4)num_dims = magic[3]dims = np.frombuffer(self.buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)num_images, rows, cols = dimsself.images = np.frombuffer(self.buf, dtype=np.uint8, offset=4+4*num_dims).reshape(num_images, rows, cols)...train_set = DataSet()
train_set.load_images('data/train-images.idx3-ubyte')
print("Images and buffer share memory:", np.shares_memory(train_set.images, train_set.buf))

解释:self.buf 的类型,如果直接用 fin.read() 则得到 bytes 类型,是不可变的;转为 bytearray 类型后,是可变的,就可以保持和 self.images( ) 共享。

遗憾的是, self.buf = bytearray(fin.read()) 这句本身就发生了内存拷贝。

改进 - 避免内存拷贝

with open(filename, 'rb') as fin:  self.buf = bytearray(fin.read())  # 当前实现,存在两次内存分配  

改为

with open(filename, 'rb') as fin:  self.buf = fin.read()  # 读取为 bytes  self.buf = memoryview(self.buf)  # 直接使用 memoryview  

就可以避免 bytes 对象的中间拷贝过程。

完整代码

import numpy as np
import cv2class DataSet:def __init__(self):self.images = Noneself.labels = Noneself.buf = Nonedef load_images(self, filename):with open(filename, 'rb') as fin:#self.buf = bytearray(fin.read())self.buf = fin.read()self.buf = memoryview(self.buf)magic = np.frombuffer(self.buf, np.uint8, count=4)num_dims = magic[3]dims = np.frombuffer(self.buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)num_images, rows, cols = dimsself.images = np.frombuffer(self.buf, dtype=np.uint8, offset=4+4*num_dims).reshape(num_images, rows, cols)def load_labels(self, filename):with open(filename, 'rb') as fin:buf = fin.read()magic = np.frombuffer(buf, np.uint8, count=4)num_dims = magic[3]dims = np.frombuffer(buf, dtype=np.dtype('>u4'), count=num_dims, offset=4)num_labels = dims[0]assert num_labels == len(self.images)self.labels = np.frombuffer(buf, dtype=np.uint8, offset=4+4*num_dims)def show_image(self, index):cv2.imshow('image', self.images[index])print('label:', self.labels[index])cv2.waitKey(0)cv2.destroyAllWindows()def main():train_set = DataSet()train_set.load_images('data/train-images.idx3-ubyte')train_set.load_labels('data/train-labels.idx1-ubyte')# train_set.show_image(0)# train_set.show_image(2)# train_set.show_image(5)print("Images and buffer share memory:", np.shares_memory(train_set.images, train_set.buf))if __name__ == '__main__':main()

总结

在前一篇,我们解析了MNIST数据集的IDX格式并用C++做了文件读取的实现,在本篇则切换到 Python 语言,在降低70%代码量的情况下实现了相同功能,并且避免了不必要的内存拷贝。这份工程之美,建立在对 IDX 格式有所了解的前提之下,对于 Python 的熟悉也是必不可少的,对于C++的经验也促使了复用内存这一条件的达成。


http://www.mrgr.cn/news/78298.html

相关文章:

  • 基于Springboot的流浪宠物管理系统
  • 高效制作定期Excel报表:自动化与模板化的策略
  • 【设计模式】【行为型模式(Behavioral Patterns)】之责任链模式(Chain of Responsibility Pattern)
  • NUXT3学习日记四(路由中间件、导航守卫)
  • SAP 零售方案 CAR 系统的介绍与研究
  • JVM中TLAB(线程本地分配缓存区)是什么
  • 大语言模型LLM的微调中 QA 转换的小工具 txt2excel.py
  • Java AQS(AbstractQueuedSynchronizer):深入剖析
  • v-for产生 You may have an infinite update loop in a component render function
  • 直言抖音电商环境恶化,叶国富也想指点张一鸣
  • 【拥抱AI】RAG如何提高向量化的质量
  • 关于node全栈项目打包发布linux项目问题总集
  • SQL基础入门—— 简单查询与条件筛选
  • ubuntu 安装docker
  • Linux下的火墙管理及优化
  • C语言蓝桥杯组题目
  • WonderJourney 学习笔记
  • Qt获取文件夹下的文件个数(过滤和不过滤的区别)
  • 第 4 章 Java 并发包中原子操作类原理剖析
  • 【Jenkins】docker 部署 Jenkins 踩坑笔记
  • 类和对象--中--初始化列表(重要)、隐式类型转化(理解)、最后两个默认成员函数
  • Android 布局菜单或按钮图标或Menu/Item设置可见和不可见
  • 《Vue 初印象:快速上手 Vue 基础语法》
  • PostgreSQL详细安装教程
  • 基于SpringBoot共享汽车管理系统【附源码】
  • Docker容器运行CentOS镜像,执行yum命令提示“Failed to set locale, defaulting to C.UTF-8”