PaddleOCR 文字检测/文字块检测的模型训练过程
创始人
2025-05-28 21:46:35
0

文章目录

  • 1、环境搭建
  • 2、数据集
  • 3、下载预训练模型
  • 4、配置文件
    • DecodeImage
    • DetLabelEncode
    • IaaAugment
    • EastRandomCropData
    • MakeBorderMap
  • 5、开启训练
  • 6、纯记录,我在我服务器做的事情
  • 7、显示label
  • 8、推导模型的导出与预测
  • 9、转到onnx模型和mnn模型
    • 到onnx
    • 到mnn
  • 10、在线可视化

1、环境搭建

官网环境准备参考:https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/environment.md#1.3
paddle官网:https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html

linux cpu用于调试:

conda create -n paddle python=3.7 -y
conda activate paddle
pip install -r requirements.txt
python -m pip install paddlepaddle==2.1.3 -i https://pypi.tuna.tsinghua.edu.cn/simple

如果是gpu,docker用起来:

nvidia-docker run -v $PWD:/paddle -v /ssd/xiedong/datasets/ICDAR2015/:/ICDAR2015 --shm-size=64G --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda11.2-cudnn8 /bin/bash # 我这里把数据也挂载进去
python -m pip install -i https://pypi.douban.com/simple --upgrade pip && pip config set global.index-url https://pypi.douban.com/simple
cd /paddle
pip install -r requirements.txt

2、数据集

官网的数据集参考:https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/dataset/ocr_datasets.md

以ICDAR 2015数据集来说,可以直接下载PaddleOCR给的标注,也可以自行使用ppocr/utils/gen_label.py进行标签转换:

python gen_label.py --mode="det" --root_path="/ssd/xiedong/datasets/ICDAR2015/ch4_training_images/"  \--input_path="/ssd/xiedong/datasets/ICDAR2015/ch4_training_localization_transcription_gt" \--output_label="/ssd/xiedong/datasets/ICDAR2015/train_icdar2015_label.txt"

3、下载预训练模型

官网参考:https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/detection.md

下载预训练模型:
https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.0/README_cn.md#resnet%E5%8F%8A%E5%85%B6vd%E7%B3%BB%E5%88%97

我这里下载后放这里:
在这里插入图片描述
修改label文件中所描述的路径:

# 需要重写label文件中的label
import oslb_file = r"C:\Users\dong.xie\Desktop\test_icdar2015_label.txt"
lb_file_dst = lb_file
# 读取文件所有行
with open(lb_file, 'r') as f:lines = f.readlines()res_list = []
for line in lines:pathname, label = line.split('\t')basename = os.path.basename(pathname)pathname_new = r"/ICDAR2015/ch4_test_images/" + basename  # 改为绝对路径res_list.append(pathname_new + '\t' + label)# 重写label文件
with open(lb_file_dst, 'w') as f:f.write(''.join(res_list))

4、配置文件

配置文件修改为:

Global:use_gpu: trueuse_xpu: falseuse_mlu: falseepoch_num: 1200log_smooth_window: 20print_batch_step: 10save_model_dir: ./output/db_mv3/save_epoch_step: 1200# evaluation is run every 2000 iterationseval_batch_step: [ 0, 2000 ]cal_metric_during_train: Falsepretrained_model: ./pretrain_models/MobileNetV3_small_x0_35_ssld_pretrainedcheckpoints:save_inference_dir:use_visualdl: trueinfer_img: doc/imgs_en/img_10.jpgsave_res_path: ./output/det_db/predicts_db.txt  #设置检测模型的结果保存地址Architecture:model_type: detalgorithm: DBTransform:Backbone:name: MobileNetV3scale: 0.35model_name: smallNeck:name: DBFPNout_channels: 256Head:name: DBHeadk: 50Loss:name: DBLossbalance_loss: truemain_loss_type: DiceLossalpha: 5beta: 10ohem_ratio: 3Optimizer:name: Adambeta1: 0.9beta2: 0.999lr:learning_rate: 0.001regularizer:name: 'L2'factor: 0PostProcess:name: DBPostProcessthresh: 0.3box_thresh: 0.6max_candidates: 1000unclip_ratio: 1.5Metric:name: DetMetricmain_indicator: hmeanTrain:dataset:name: SimpleDataSetdata_dir: /ICDAR2015/ch4_training_images/label_file_list:- /ICDAR2015/train_icdar2015_label.txtratio_list: [ 1.0 ]transforms:- DecodeImage: # load imageimg_mode: BGRchannel_first: False- DetLabelEncode: # Class handling label- IaaAugment:augmenter_args:- { 'type': Fliplr, 'args': { 'p': 0.5 } }- { 'type': Affine, 'args': { 'rotate': [ -10, 10 ] } }- { 'type': Resize, 'args': { 'size': [ 0.5, 3 ] } }- EastRandomCropData:size: [ 640, 640 ]max_tries: 50keep_ratio: true- MakeBorderMap:shrink_ratio: 0.4thresh_min: 0.3thresh_max: 0.7- MakeShrinkMap:shrink_ratio: 0.4min_text_size: 8- NormalizeImage:scale: 1./255.mean: [ 0.485, 0.456, 0.406 ]std: [ 0.229, 0.224, 0.225 ]order: 'hwc'- ToCHWImage:- KeepKeys:keep_keys: [ 'image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask' ] # the order of the dataloader listloader:shuffle: Truedrop_last: Falsebatch_size_per_card: 16num_workers: 8use_shared_memory: TrueEval:dataset:name: SimpleDataSetdata_dir: /ICDAR2015/ch4_test_images/label_file_list:- /ICDAR2015/test_icdar2015_label.txttransforms:- DecodeImage: # load imageimg_mode: BGRchannel_first: False- DetLabelEncode: # Class handling label- DetResizeForTest:image_shape: [ 736, 1280 ]- NormalizeImage:scale: 1./255.mean: [ 0.485, 0.456, 0.406 ]std: [ 0.229, 0.224, 0.225 ]order: 'hwc'- ToCHWImage:- KeepKeys:keep_keys: [ 'image', 'shape', 'polys', 'ignore_tags' ]loader:shuffle: Falsedrop_last: Falsebatch_size_per_card: 1 # must be 1num_workers: 8use_shared_memory: True

选择模型的配置是下面,在初始构建模型的代码中会拉起:
Architecture:
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.35
model_name: small
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50

DecodeImage

从文件路径加载图片到内存:

     - DecodeImage: # load imageimg_mode: BGRchannel_first: False

DetLabelEncode

处理label文本,可看出文字标记为"*“或者”###"的,tag都会是True,后续将忽略此标记。

      - DetLabelEncode: # Class handling label

在这里插入图片描述

IaaAugment

这段代码定义了一个名为 IaaAugment 的类,用于进行图像增强操作。该类的实例化方法 init 接受一个参数 augmenter_args 和其他可选参数 **kwargs。如果没有提供 augmenter_args,则默认使用三个图像增强器:水平翻转、仿射变换和尺寸调整,每个增强器都有不同的参数设置。

call 方法接受一个数据字典 data,其中包括一个 image 键,对应于输入的图像。方法首先获取输入图像的形状,然后将图像传递给 augmenter 对象进行增强操作。增强器对象首先被转换为确定性的,然后对输入图像进行增强,返回增强后的图像。此外,方法还调用 may_augment_annotation 方法对标注进行相同的增强操作,并将增强后的数据字典返回。

may_augment_annotation 方法接受 aug 增强器对象、 data 数据字典和 shape 输入图像的形状。该方法使用 may_augment_poly 方法对 polys 键中的多边形进行增强操作,并将增强后的多边形数组存储在 data[‘polys’] 中。

may_augment_poly 方法接受 aug 增强器对象、img_shape 输入图像的形状和 poly 多边形数组。该方法将多边形的每个点转换为 imgaug.Keypoint 对象,并将它们作为一个 imgaug.KeypointsOnImage 对象传递给增强器的 augment_keypoints 方法。该方法返回一个包含增强后的关键点的 imgaug.KeypointsOnImage 对象,然后将这些点的坐标转换回原始多边形的形式,最后返回增强后的多边形。

      - IaaAugment:augmenter_args: # 水平翻转、仿射变换和尺寸调整- { 'type': Fliplr, 'args': { 'p': 0.5 } }- { 'type': Affine, 'args': { 'rotate': [ -10, 10 ] } }- { 'type': Resize, 'args': { 'size': [ 0.5, 3 ] } }

在这里插入图片描述

EastRandomCropData

这段代码定义了两个类 EastRandomCropData 和 RandomCropImgMask,它们都是用于数据增强(data augmentation)的。

EastRandomCropData 类实现了一种基于 EAST(Efficient and Accurate Scene Text detection)文本检测算法的数据增强方法,主要作用是对原始图片进行随机裁剪,使得裁剪出的区域包含至少一个文本框,并保持裁剪后的图片与原始图片比例一致。具体实现过程如下:

首先,通过调用 crop_area 函数计算出一个合适的裁剪区域,该函数会根据输入的 all_care_polys 参数(即所有不被标记为忽略的文本框)和裁剪区域的最小边长比例 min_crop_side_ratio,随机生成多个裁剪区域并返回其中包含至少一个文本框的那个区域。如果经过 max_tries 次尝试后还是没有找到合适的裁剪区域,则直接返回原始图片。

接着,根据裁剪区域的大小和目标大小(即 self.size 参数),计算出缩放比例 scale,并将裁剪后的图片缩放到目标大小。如果 keep_ratio 参数为 True,则先将目标大小的全零矩阵 padimg 创建出来,将缩放后的图片居中填充到 padimg 中,最后返回 padimg;否则直接将缩放后的图片调整为目标大小并返回。

最后,对原始文本框进行相应的裁剪和缩放操作,以保证其与裁剪后的图片大小相匹配,并返回增强后的数据。

RandomCropImgMask 类实现了一种随机裁剪的数据增强方法,主要作用是对输入的图像和掩码进行随机裁剪。具体实现过程如下:

首先,根据输入图像的大小和目标大小,生成一个随机裁剪区域,并将该区域应用于输入图像和掩码上。具体来说,如果掩码中存在非零像素(即存在文本区域),则保证裁剪区域至少包含一个文本区域;否则随机生成一个裁剪区域。

接着,对输入数据中包含在 self.crop_keys 列表中的数据(如图像、掩码等)进行裁剪操作,并将裁剪后的数据更新到原始数据中。

最后,返回增强后的数据。

      - EastRandomCropData:size: [ 640, 640 ]max_tries: 50keep_ratio: true

在这里插入图片描述

MakeBorderMap

      - MakeBorderMap:shrink_ratio: 0.4thresh_min: 0.3thresh_max: 0.7

5、开启训练

开启训练的指令:

# 单机单卡训练 mv3_db 模型
python3 tools/train.py -c configs/det/det_mv3_db.yml \-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained

det_mv3_db.yml的含义:https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/doc/doc_ch/config.md

6、纯记录,我在我服务器做的事情

在19服务器将docker镜像拉下来:

docker pull kevinchina/deeplearning:paddle2.1.3paddleocr

使用rsync同步文件:

rsync -avz xiedong@10.20.31.16:/ssd/xiedong/workplace/PaddleOCR-release-2.6/ ./PaddleOCR-release-2.6/

启动容器:

nvidia-docker run -v $PWD:/paddle -v /ssd/xiedong/datasets/ICDAR2015/:/ICDAR2015 --shm-size=64G --network=host -it  kevinchina/deeplearning:paddle2.1.3paddleocr /bin/bash

单机多卡训练:

python3 -m paddle.distributed.launch --gpus '0,1,2' tools/train.py -c configs/det/det_mv3_db_para_det_paddle.yml

断点resume训练:
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:

python3 -m paddle.distributed.launch --gpus '0,1,2' tools/train.py -c configs/det/det_mv3_db_para_det_paddle.yml -o Global.checkpoints=./output/db_mv3_para_det_paddle/latest.pdparams

7、显示label

将图片中的文字画框后显示出来:

import osimport PIL.Image as Image
import PIL.ImageDraw as ImageDraw
import cv2src = r"E:\q\ICDAR2015"
labelfilename = r"E:\q\ICDAR2015\train_icdar2015_label.txt"# ch4_test_images/img_61.jpg	[{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {"transcription": "###", "points": [[1197, 126], [1252, 118], [1257, 136], [1203, 144]]}, {"transcription": "###", "points": [[1137, 140], [1177, 132], [1180, 148], [1140, 156]]}, {"transcription": "###", "points": [[1096, 152], [1130, 145], [1133, 158], [1100, 165]]}, {"transcription": "###", "points": [[1061, 161], [1092, 154], [1093, 168], [1062, 175]]}, {"transcription": "###", "points": [[1030, 168], [1055, 162], [1056, 177], [1030, 183]]}, {"transcription": "###", "points": [[1000, 173], [1023, 168], [1025, 184], [1002, 189]]}, {"transcription": "###", "points": [[223, 293], [313, 288], [313, 311], [222, 316]]}]
with open(labelfilename, "r", encoding="utf-8") as f:lines = f.read().splitlines()for line in lines:line = line.split("\t")imgpath = os.path.join(src, line[0])img = Image.open(imgpath)img = img.convert("RGB")draw = ImageDraw.Draw(img)for d in eval(line[1]):points = d["points"]draw.line(points[0] + points[1] + points[2] + points[3] + points[0], fill=(255, 0, 0), width=3)img.show()

8、推导模型的导出与预测

检测模型转inference 模型方式:

python3 tools/export_model.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model="./output/db_mv3_3/best_accuracy" Global.save_inference_dir="./output/db_mv3_3_inference/"

在这里插入图片描述
段落框:

python3 tools/export_model.py -c configs/det/det_mv3_db_para_det_paddle.yml -o Global.pretrained_model="./output/db_mv3_para_det_paddle/best_accuracy" Global.save_inference_dir="./output/db_mv3_para_det_paddle_inference/"

用于预测:

python3 tools/infer/predict_det.py --det_algorithm="DB" --det_model_dir="./output/det_db_inference/" --image_dir="./doc/imgs/" --use_gpu=True

用于预测,段落框:

python3 tools/infer/predict_det.py --det_algorithm="DB" --det_model_dir="./output/db_mv3_para_det_paddle_inference/" --image_dir="./doc/imgs/" --use_gpu=True

9、转到onnx模型和mnn模型

到onnx

安装:pip install paddle2onnx onnx onnx-simplifier onnxruntime-gpu
导出模型:paddle2onnx --model_dir ./output/db_mv3_para_det_paddle_inference/ --model_filename inference.pdmodel --params_filename inference.pdiparams --save_file ./output/db_mv3_para_det_paddle_inference/model.onnx --opset_version 10 --enable_dev_version True --enable_onnx_checker True

参数选项

参数参数说明
–model_dir配置包含Paddle模型的目录路径
–model_filename[可选] 配置位于--model_dir下存储网络结构的文件名
–params_filename[可选] 配置位于--model_dir下存储模型参数的文件名称
–save_file指定转换后的模型保存目录路径
–opset_version[可选] 配置转换为ONNX的OpSet版本,目前支持7~15等多个版本,默认为9
–enable_dev_version[可选] 是否使用新版本Paddle2ONNX(推荐使用),默认为False
–enable_onnx_checker[可选] 配置是否检查导出为ONNX模型的正确性, 建议打开此开关。若指定为True, 默认为False
–enable_auto_update_opset[可选] 是否开启opset version自动升级,当低版本opset无法转换时,自动选择更高版本的opset 默认为True
–input_shape_dict[可选] 配置输入的shape, 默认为空; 此参数即将移除,如需要固定Paddle模型输入Shape,请使用此工具处理
–version[可选] 查看paddle2onnx版本
  • 使用onnxruntime验证转换模型, 请注意安装最新版本(最低要求1.10.0):

如你有ONNX模型优化的需求,推荐使用onnx-simplifier,也可使用如下命令对模型进行优化:

python -m paddle2onnx.optimize --input_model model.onnx --output_model new_model.onnx

如需要修改导出的模型输入形状,如改为静态shape:

python -m paddle2onnx.optimize --input_model model.onnx \--output_model new_model.onnx \--input_shape_dict "{'x':[1,3,224,224]}"

到mnn

MNNConvert -f ONNX --modelFile model.onnx --MNNModel model.mnn --bizCode MNN

10、在线可视化

百度提供的visualdl 对log可视化:

visualdl --logdir="vdl"

下面这图是我resume训练了,看起来loss还在波动。
在这里插入图片描述

也可以将inference.pdmodel模型放入后查看模型结构:
在这里插入图片描述

相关内容

热门资讯

linux入门---制作进度条 了解缓冲区 我们首先来看看下面的操作: 我们首先创建了一个文件并在这个文件里面添加了...
C++ 机房预约系统(六):学... 8、 学生模块 8.1 学生子菜单、登录和注销 实现步骤: 在Student.cpp的...
JAVA多线程知识整理 Java多线程基础 线程的创建和启动 继承Thread类来创建并启动 自定义Thread类的子类&#...
【洛谷 P1090】[NOIP... [NOIP2004 提高组] 合并果子 / [USACO06NOV] Fence Repair G ...
国民技术LPUART介绍 低功耗通用异步接收器(LPUART) 简介 低功耗通用异步收发器...
城乡供水一体化平台-助力乡村振... 城乡供水一体化管理系统建设方案 城乡供水一体化管理系统是运用云计算、大数据等信息化手段࿰...
程序的循环结构和random库...   第三个参数就是步长     引入文件时记得指明字符格式,否则读入不了 ...
中国版ChatGPT在哪些方面... 目录 一、中国巨大的市场需求 二、中国企业加速创新 三、中国的人工智能发展 四、企业愿景的推进 五、...
报名开启 | 共赴一场 Flu... 2023 年 1 月 25 日,Flutter Forward 大会在肯尼亚首都内罗毕...
汇编00-MASM 和 Vis... Qt源码解析 索引 汇编逆向--- MASM 和 Visual Studio入门 前提知识ÿ...
【简陋Web应用3】实现人脸比... 文章目录🍉 前情提要🌷 效果演示🥝 实现过程1. u...
前缀和与对数器与二分法 1. 前缀和 假设有一个数组,我们想大量频繁的去访问L到R这个区间的和,...
windows安装JDK步骤 一、 下载JDK安装包 下载地址:https://www.oracle.com/jav...
分治法实现合并排序(归并排序)... 🎊【数据结构与算法】专题正在持续更新中,各种数据结构的创建原理与运用✨...
在linux上安装配置node... 目录前言1,关于nodejs2,配置环境变量3,总结 前言...
Linux学习之端口、网络协议... 端口:设备与外界通讯交流的出口 网络协议:   网络协议是指计算机通信网...
Linux内核进程管理并发同步... 并发同步并发 是指在某一时间段内能够处理多个任务的能力,而 并行 是指同一时间能够处理...
opencv学习-HOG LO... 目录1. HOG(Histogram of Oriented Gradients,方向梯度直方图)1...
EEG微状态的功能意义 导读大脑的瞬时全局功能状态反映在其电场结构上。聚类分析方法一致地提取了四种头表面脑电场结构ÿ...
【Unity 手写PBR】Bu... 写在前面 前期积累: GAMES101作业7提高-实现微表面模型你需要了解的知识 【技...