当前位置: 首页 > news >正文

拆分PPOCRLabel标注的数据集并生成识别数据集

拆分PPOCRLabel标注的数据集并生成识别数据集

说明

  • 首次发表日期:2024-10-31
  • 参考:
    • https://github.com/PFCCLab/PPOCRLabel/blob/main/README_ch.md

关于PPOCRLabel以及本文缘起

PPOCRLabel是OCR领域的标注工具,其本身自带导出识别数据和拆分数据集的功能。其中:

PPOCRLabel本身自带导出识别数据的功能,但是保存检测框图片时会自动旋转图片,具体见其saveRecResult函数实现代码: https://github.com/PFCCLab/PPOCRLabel/blob/81a9c550b7b625bd003a16681fcc7d782184d1f4/PPOCRLabel.py#L3371

    def saveRecResult(self):if {} in [self.PPlabelpath, self.PPlabel, self.fileStatedict]:QMessageBox.information(self, "Information", "Check the image first")returnbase_dir = os.path.dirname(self.PPlabelpath)rec_gt_dir = base_dir + "/rec_gt.txt"crop_img_dir = base_dir + "/crop_img/"ques_img = []if not os.path.exists(crop_img_dir):os.mkdir(crop_img_dir)with open(rec_gt_dir, "w", encoding="utf-8") as f:for key in self.fileStatedict:idx = self.getImglabelidx(key)try:img_path = os.path.dirname(base_dir) + "/" + keyimg = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)for i, label in enumerate(self.PPlabel[idx]):if label["difficult"]:continueimg_crop = get_rotate_crop_image(img, np.array(label["points"], np.float32))img_name = (os.path.splitext(os.path.basename(idx))[0]+ "_crop_"+ str(i)+ ".jpg")cv2.imencode(".jpg", img_crop)[1].tofile(crop_img_dir + img_name)f.write("crop_img/" + img_name + "\t")f.write(label["transcription"] + "\n")except KeyError as e:passexcept Exception as e:ques_img.append(key)traceback.print_exc()if ques_img:QMessageBox.information(self,"Information","The following images can not be saved, please check the image path and labels.\n"+ "".join(str(i) + "\n" for i in ques_img),)QMessageBox.information(self,"Information","Cropped images have been saved in " + str(crop_img_dir),)

其中get_rotate_crop_image函数定义: https://github.com/PFCCLab/PPOCRLabel/blob/81a9c550b7b625bd003a16681fcc7d782184d1f4/libs/utils.py#L137

def get_rotate_crop_image(img, points):# Use Green's theory to judge clockwise or counterclockwise# author: biyanhuad = 0.0for index in range(-1, 3):d += (-0.5* (points[index + 1][1] + points[index][1])* (points[index + 1][0] - points[index][0]))if d < 0:  # counterclockwisetmp = np.array(points)points[1], points[3] = tmp[3], tmp[1]try:img_crop_width = int(max(np.linalg.norm(points[0] - points[1]),np.linalg.norm(points[2] - points[3]),))img_crop_height = int(max(np.linalg.norm(points[0] - points[3]),np.linalg.norm(points[1] - points[2]),))pts_std = np.float32([[0, 0],[img_crop_width, 0],[img_crop_width, img_crop_height],[0, img_crop_height],])M = cv2.getPerspectiveTransform(points, pts_std)dst_img = cv2.warpPerspective(img,M,(img_crop_width, img_crop_height),borderMode=cv2.BORDER_REPLICATE,flags=cv2.INTER_CUBIC,)dst_img_height, dst_img_width = dst_img.shape[0:2]if dst_img_height * 1.0 / dst_img_width >= 1.5:dst_img = np.rot90(dst_img)return dst_imgexcept Exception as e:print(e)

但是,有的场景是不需要在将裁剪的检测框旋转后再保存的。

另外,PPOCRLabel官方自带脚本可以用于拆分数据集:

python gen_ocr_train_val_test.py --trainValTestRatio 9:1:0 --datasetRootPath dataset/handwritten_digits/images --detRootPath ./train_data/det --recRootPath ./train_data/rec

拆分数据集并生成识别数据集

标注文件格式

假设我们有数据集及其标注文件:

data_dir = "data/"
label_file = 'data/Label_det.txt'

PPOCRLabel的标注文件是 PaddleOCR 文字检测数据格式。

PaddleOCR 中的文本检测算法支持的标注文件格式如下,中间用"\t"分隔:

" 图像文件名                    json.dumps编码的图像标注信息"
ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]

这儿假设data_dir加上标注文件中的图像文件名将构成图片的路径。

读取标注文件

读取标注文件内容,并检查下标注文件中的图片是否都存在:

def get_label_lines(data_dir: str, label_file:str):with open(label_file, 'r') as f:label_lines = f.readlines()for line in label_lines:img_path, img_label = line.split("\t")img_rel_path = os.path.join(data_dir, img_path)if not os.path.exists(img_rel_path):print(f'{img_rel_path} not exists!')return label_lineslabel_lines = get_label_lines(parent_dir, label_file)

拆分标注数据并保存

拆分label_lines:

from sklearn.model_selection import train_test_splittrain_set_label_lines, test_set_label_lines = train_test_split(label_lines, test_size = 0.2, random_state = 42)

保存为具体的数据集(图片和标注文件):

def save_split_data(split_label_lines,data_dir,dest_dir = "dataset",split_name = "train",
):new_label_lines = []first_img_path = split_label_lines[0].split("\t")[0]parent_dir_name = os.path.split(os.path.dirname(os.path.join(dest_dir, first_img_path)))[-1]rel_dest_img_path = "_".join([parent_dir_name, split_name])dest_dir = os.path.join(dest_dir, rel_dest_img_path)os.makedirs(dest_dir, exist_ok=True)for line in split_label_lines:img_path, label_text = line.split("\t")label_text = label_text.replace("\n", "")assert parent_dir_name == os.path.split(os.path.dirname(os.path.join(dest_dir, img_path)))[-1]new_label_lines.append("\t".join([os.path.join(rel_dest_img_path, os.path.basename(img_path)), label_text]))shutil.copy2(os.path.join(data_dir, img_path), os.path.join(dest_dir, os.path.basename(img_path)))label_file_path = os.path.join(dest_dir, "_".join(["Label", parent_dir_name, split_name]) + ".txt")with open(label_file_path, "w") as f:f.write("\n".join(new_label_lines))return dest_dir, label_file_path
train_img_dir, train_det_label_file = save_split_data(train_set_label_lines,data_dir = data_dir,dest_dir = "dataset",split_name = "train",
)test_img_dir, test_det_label_file = save_split_data(test_set_label_lines,data_dir = data_dir,dest_dir = "dataset",split_name = "test",
)

生成识别图片和标签

def generate_rec_img_label(label_file_path, parent_dir, do_crop = True):with open(label_file_path, 'r') as f:label_lines = f.readlines()rec_label_lines = []for line in label_lines:img_path, label_text = line.split("\t")label_text = label_text.replace("\n", "")label_list = json.loads(label_text)img = cv2.imread(os.path.join(parent_dir, img_path))parent_img_dir = os.path.split(os.path.dirname(img_path))[-1]dest_img_dir = dest_img_path = os.path.join(parent_dir, parent_img_dir, "crop_img")os.makedirs(dest_img_dir, exist_ok=True)for idx, label in enumerate(label_list):crop_img_name = os.path.splitext(os.path.basename(img_path))[0] + "_crop_" + str(idx) + ".jpg"rec_label_lines.append("\t".join([os.path.join(parent_img_dir, "crop_img", crop_img_name), label["transcription"]]))dest_img_path = os.path.join(dest_img_dir, crop_img_name)if do_crop:pt0, pt1, pt2, pt3 = label["points"]crop_img = img[pt0[1]:pt2[1], pt0[0]:pt2[0]]cv2.imwrite(dest_img_path, crop_img)else:shutil.copy2(os.path.join(parent_dir, img_path), dest_img_path)with open(os.path.join(parent_dir, "_".join([os.path.splitext(os.path.basename(label_file_path))[0], "rec"]) + ".txt"), 'w') as f:f.write("\n".join(rec_label_lines))
generate_rec_img_label(train_det_label_file, "dataset")
generate_rec_img_label(test_det_label_file, "dataset")

http://www.mrgr.cn/news/62622.html

相关文章:

  • C++标准模板库--栈和队列
  • Ubuntu20.04安装VM tools并实现主机和虚拟机之间文件夹共享
  • 硬盘的管理
  • 【荔枝网】请求头参数逆向分析 异步回调层层嵌套 | wasm实例化 |
  • 图书管理系统汇报
  • GDB(GNU Debugger)的使用教程
  • 动态规划-回文串问题——647.回文子串
  • Python使用 try-except 捕获与处理异常
  • 从安装到实战:Spring Boot与RabbitMQ的终极整合指南
  • Go 语言解析 yaml 文件的方法
  • ES聚合(仅供自己参考)
  • 【安全性分析】BAN逻辑 (BAN Logic)之详细介绍
  • 天润融通邀您参加AI破局·聚力增长行业论坛
  • 去人声留伴奏免费软件,这四款软件可别错过
  • 智能码二维码zhinengma.cn如何赋能工业产品质量安全追溯
  • 【深度学习】实验 — 动手实现 GPT【二】:注意力机制、注意力掩码、多头注意力机制
  • ABAP RFC SQL 模糊查询和多个区间条件
  • 一些老程序员不愿透露的工作小技巧…
  • 【HDRP下实现视差效果_CubeMap和九宫格ArrayMap形式】
  • 2024年“炫转青春”山东省飞盘联赛盛大开赛——临沭县青少年飞盘运动迅速升温
  • 隐私保护下的数据提取策略
  • USC H5S支持大华ICC平台对接
  • QT:QThread:重写run函数
  • python函数连续
  • ARM base instruction -- adc
  • 2181、合并零之间的节点