MMOCR环境配置及训练测试(详细)
- 1,环境配置
- 2,MMOCR相关结构介绍
- (1)Config配置
- (2)文本识别数据集
- 3,训练测试(sar)
1,环境配置
查看cuda版本安装对应的pytorch
# 先查看本地的cuda版本(此处为11.1,即cu111),torch版本要与之对应
pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
# 安装mmdet(对应cuda、torch)
pip install "mmdet<3.2.0" -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9/index.html
# 安装mmcv(对应cuda、torch)
pip install "mmcv<2.1.0" -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9/index.html
# 从 github 上下载最新的 mmocr 源代码
git clone https://github.com/open-mmlab/mmocr.git
# 进入 mmocr 主目录
import os
os.chdir('mmocr')
# 安装相应的库
pip install -r requirements.txt
# mmocr要注册
pip install -v -e .
# 最终查看安装后结果如下:
pip list
2,MMOCR相关结构介绍
(1)Config配置
Config目录下包含文本识别、文本检测、关键信息提取的各个算法配置文件
以sar为例如下
对应的数据集配置
(2)文本识别数据集
3,训练测试(sar)
这里使用其内部的sar算法进行训练测试
(1)标签转换
这里要对之前用opencv生成的初步的ocr识别数据集再进行一次格式转换
(1-1)标签转换脚本labelme2mmocr.py
import cv2
import os
import csv
import json

# Folder holding the source .jpg images and their same-named .txt labels.
input_path = "E:/tupian/ocr2/ocr_data"
# Folder that receives the converted images and label files.
output_path = "./data"


# Helper for splitting a list
def split_list(lst, ratios, num_splits):
    """Split ``lst`` into ``num_splits`` consecutive sub-lists by ratio.

    Every element of ``lst`` ends up in exactly one sub-list: each slice but
    the last holds ``int(len(lst) * ratio)`` elements and the last slice takes
    whatever remains, so truncation never drops trailing elements.

    :param lst: list to split
    :param ratios: fraction of ``lst`` for each sub-list, as a list of floats
    :param num_splits: number of sub-lists to produce
    :return: list of the resulting sub-lists, in the original order
    :raises ValueError: if ``len(ratios) != num_splits`` or the ratios do not
        sum to 1 (within floating-point tolerance)
    """
    if len(ratios) != num_splits:
        raise ValueError("The length of ratios must equal to num_splits.")
    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly 1.
    if abs(sum(ratios) - 1) > 1e-9:
        raise ValueError("The sum of ratios must be equal to 1.")

    n = len(lst)
    result = []
    start = 0
    for i in range(num_splits):
        if i == num_splits - 1:
            # Last slice absorbs the elements lost to int() truncation.
            end = n
        else:
            end = start + int(n * ratios[i])
        result.append(lst[start:end])
        start = end
    return result


def read_path(input_path, output_path, ratio):
    """Convert a folder of ``.jpg``/``.txt`` pairs into MMOCR label files.

    For every ``*.jpg`` in ``input_path`` the image is re-written to
    ``<output_path>/imgs`` and the text of the same-named ``.txt`` file (with
    newlines removed) becomes its label. The entries are split with
    ``split_list`` using ratios ``[1 - ratio, ratio]`` and written to
    ``<output_path>/labels_train.json`` and ``<output_path>/labels_val.json``.

    The file ``tests/data/rec_toy_dataset/labels.json`` (relative to the
    current working directory) is loaded as a template; only its
    ``data_list`` entry is replaced in the written files.

    :param input_path: folder containing the images and label text files
    :param output_path: folder that receives ``imgs/`` and the label files
    :param ratio: fraction of the entries written to the validation file
    """
    # Load the existing label file as a template for the output files.
    with open('tests/data/rec_toy_dataset/labels.json', 'r', encoding="utf-8") as f:
        data_label = json.load(f)
    print(data_label['data_list'])

    # Create the image output folder once instead of checking per file.
    img_dir = output_path + "/imgs"
    os.makedirs(img_dir, exist_ok=True)

    # Walk over every image file in the input folder.
    train_list = []
    for filename in os.listdir(input_path):
        # Match on the extension only, not on '.jpg' anywhere in the name.
        if not filename.endswith('.jpg'):
            continue
        img_path = input_path + '/' + filename
        txt_path = input_path + '/' + os.path.splitext(filename)[0] + '.txt'
        img_output_path = img_dir + "/" + filename
        print(img_path)
        print(txt_path)

        # Read the image and save it into the output folder.
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for unreadable files; skip them rather
            # than crash inside cv2.imwrite.
            print('skip unreadable image:', img_path)
            continue
        cv2.imwrite(img_output_path, img)

        # The whole text file, minus newlines, is the label of this image.
        with open(txt_path, "r", encoding='utf-8') as f:
            label = f.read().replace('\n', '')

        i_dict = {'instances': [{'text': label}], 'img_path': filename}
        print(i_dict)
        train_list.append(i_dict)
    print(train_list)

    # Split the entries into train / validation parts by ratio.
    ratios = [1 - ratio, ratio]
    print(ratios)
    num_splits = 2
    result = split_list(train_list, ratios, num_splits)
    print(result[0])
    print(result[1])

    # First part goes to the training labels, second part to validation.
    data_label['data_list'] = result[0]
    with open(output_path + '/labels_train.json', 'w', encoding="utf-8", newline='') as f:
        content = json.dumps(data_label)
        f.write(content)
    data_label['data_list'] = result[1]
    with open(output_path + '/labels_val.json', 'w', encoding="utf-8", newline='') as f:
        content = json.dumps(data_label)
        f.write(content)


# NOTE: a path that contains the home directory must not use "~" as a
# shorthand; spell it out (e.g. "/home/...") or the folder will not be found.
#必须要写成"/home"的格式,否则会报错说找不到对应的目录
#读取的目录
read_path(input_path, output_path, 0.2)
#print(os.getcwd())
(1-2)转换为mmocr数据集
会生成三个文件imgs中为对应图片,labels_train.json训练集标签,labels_val.json测试集标签
(1-3)将生成的数据集拷贝到rec_toy_dataset路径下
mmocr-main\tests\data\rec_toy_dataset
(1-4)修改相应的数据集配置文件toy_data.py
(2)训练
(2-1)训练脚本C1_TEST.py
from mmengine import Config
from mmengine.runner import Runner
import time
import multiprocessing
if __name__ == '__main__':
    # Called first, before any training work starts.
    multiprocessing.freeze_support()

    # Train with the SAR algorithm (results were not very good).
    sar_config = 'configs/textrecog/sar/sar_resnet31_parallel-decoder_5e_toy.py'
    cfg = Config.fromfile(sar_config)
    # Pre-trained (parallel-decoder) weights to start from.
    cfg.load_from = 'sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real_20220915_171910-04eb4e75.pth'
    cfg.work_dir = 'work_dirs/sar_resnet31_parallel-decoder_5e_toy/'

    # Alternative: try training with the CRNN algorithm instead.
    # cfg = Config.fromfile('configs/textrecog/crnn/crnn_mini-vgg_5e_toy.py')
    # # Pre-trained weights
    # cfg.load_from = 'crnn_mini-vgg_5e_mj_20220826_224120-8afbedbb.pth'
    # cfg.work_dir = 'work_dirs/crnn_mini-vgg_5e_toy/'

    # Overrides applied on top of the loaded config.
    cfg.optim_wrapper.optimizer.lr = 1e-3
    cfg.train_dataloader.batch_size = 1
    cfg.train_cfg.max_epochs = 1000  # total number of training epochs
    cfg.default_hooks.checkpoint.interval = 100  # save a checkpoint every N epochs
    cfg.param_scheduler = None
    cfg.randomness = {'seed': 0}

    # The config is printed before the visualizer name is overwritten below.
    print(cfg.pretty_text)
    cfg.visualizer.name = str(time.localtime())

    Runner.from_cfg(cfg).train()
(2-2)预训练模型下载
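下载地址可以通过mmocr自带的sar算法下的readme.md得到:
https://download.openmmlab.com/mmocr/textrecog/sar/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real/sar_resnet31_sequential-decoder_5e_st-sub_mj-sub_sa_real_20220915_185451-1fd6b1fc.pth
注意:上面的链接是sequential-decoder版本的权重,而训练脚本中cfg.load_from加载的是parallel-decoder版本(sar_resnet31_parallel-decoder_5e_st-sub_mj-sub_sa_real_20220915_171910-04eb4e75.pth),请在同一readme.md中找到并下载对应的parallel-decoder权重。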
(2-3)修改配置文件sar_resnet31_parallel-decoder_5e_toy
此处是为了保证训练时不报错
(2-4)训练结果如下
训练如下
(3)测试
(3-1)测试脚本C2_TEST.py
from mmocr.apis import TextRecInferencer
import matplotlib.pyplot as plt
# %matplotlib inline
import cv2
import os


# Model dedicated to English text
def detect_imgs(input_path, output_path):
    """Run text recognition on every image in ``input_path``.

    Each ``.jpg``/``.png``/``.bmp`` file is passed to a ``TextRecInferencer``
    built from the SAR toy config and the ``epoch_1000.pth`` checkpoint. The
    image is copied to ``<output_path>/outimgs`` and the recognised text is
    written next to it as ``<image stem>__<text>.txt``.

    :param input_path: folder containing the images to recognise
    :param output_path: folder under which ``outimgs/`` is created
    """
    ckpt_path = "G:/pyproject/MMOCR/mmocr-main/work_dirs/sar_resnet31_parallel-decoder_5e_toy/epoch_1000.pth"
    config_path = "configs/textrecog/sar/sar_resnet31_parallel-decoder_5e_toy.py"
    # ckpt_path = "G:/pyproject/MMOCR/mmocr-main/work_dirs/crnn_mini-vgg_5e_toy/epoch_1000.pth"
    # config_path = "configs/textrecog/crnn/crnn_mini-vgg_5e_toy.py"
    infer = TextRecInferencer(config_path, ckpt_path, device='cuda:0')

    # Create the output folder once instead of checking per file.
    out_dir = output_path + "/outimgs"
    os.makedirs(out_dir, exist_ok=True)

    for filename in os.listdir(input_path):
        # Match on the extension only, not on the text anywhere in the name.
        if not filename.endswith(('.jpg', '.png', '.bmp')):
            continue
        img_path = input_path + '/' + filename
        img_output_path = out_dir + "/" + filename
        print(img_path)

        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            # cv2.imread returns None for unreadable files; skip them rather
            # than crash inside cv2.imwrite.
            print('skip unreadable image:', img_path)
            continue
        # plt.imshow(img_bgr[:, :, ::-1])
        # plt.show()

        result = infer(img_path)
        # 'predictions' is a list: one image may yield several results.
        predictions = result['predictions']
        print(predictions)
        print(type(predictions))
        # Only the first result of each image is used.
        text = predictions[0]['text']
        # A result such as 12<UKN>05BN cannot be used in a file name, so the
        # unknown-character token is replaced.
        text = text.replace('<UKN>', '-')
        print(text)

        cv2.imwrite(img_output_path, img_bgr)
        # Strip whatever extension the image has (not only '.jpg') before
        # appending the recognised text.
        string_txt = os.path.splitext(img_output_path)[0] + '__' + text + '.txt'
        print(string_txt)
        with open(string_txt, 'w', encoding='utf-8') as tsvfile:
            tsvfile.write(text)
        # plt.imshow(result['visualization'][0])
        # # plt.show()
        # plt.savefig(img_output_path)


input_path = 'G:/pyproject/MMOCR/mmocr-main/data/imgs'
output_path = 'G:/pyproject/MMOCR/mmocr-main/data'
detect_imgs(input_path, output_path)