OpenOCR-Demo / tools /infer_rec.py
topdu's picture
update app
4842f28
raw
history blame
15.3 kB
import os
from pathlib import Path
import sys
import time
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
import numpy as np
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from tools.engine import Config
from tools.utility import ArgsParser
from tools.utils.ckpt import load_ckpt
from tools.utils.logging import get_logger
from tools.utils.utility import get_image_file_list
from tools.infer_det import replace_batchnorm
# Module-level logger (from the project's logging helper) used throughout this file.
logger = get_logger()
# Directory containing this script; the config and dictionary paths below
# are resolved relative to it.
root_dir = Path(__file__).resolve().parent

# Default recognition configs: server model (SVTRv2) and mobile model (RepSVTR).
DEFAULT_CFG_PATH_REC_SERVER = str(root_dir / '../configs/rec/svtrv2/svtrv2_ch.yml')
DEFAULT_CFG_PATH_REC = str(root_dir / '../configs/rec/svtrv2/repsvtr_ch.yml')
# Character dictionary assigned to the default models.
DEFAULT_DICT_PATH_REC = str(root_dir / './utils/ppocr_keys_v1.txt')

# Common prefix of the GitHub release assets hosting the pretrained weights.
_RELEASE_URL = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/'

# Pretrained weight file names and the URLs they are downloaded from.
MODEL_NAME_REC = './openocr_repsvtr_ch.pth'
DOWNLOAD_URL_REC = _RELEASE_URL + 'openocr_repsvtr_ch.pth'
MODEL_NAME_REC_SERVER = './openocr_svtrv2_ch.pth'
DOWNLOAD_URL_REC_SERVER = _RELEASE_URL + 'openocr_svtrv2_ch.pth'
def check_and_download_model(model_name: str, url: str):
    """Return a local path to a pretrained model, downloading it if needed.

    Resolution order:
      1. ``model_name`` itself, if that path exists;
      2. ``~/.cache/openocr/<model_name>``, if it exists;
      3. otherwise the file is downloaded from ``url`` into that cache path.

    The download is streamed into a temporary file next to the cache path
    and moved into place only once it has completed. An interrupted or
    failed download therefore never leaves a truncated file at the cache
    path, which step 2 would otherwise accept as a valid model later.

    Args:
        model_name (str): model file name, e.g. ``"model.pt"``.
        url (str): download URL for the model file.

    Returns:
        str: path to the model file.

    Raises:
        RuntimeError: if the download fails.
    """
    import shutil
    import tempfile
    import urllib.request

    if os.path.exists(model_name):
        return model_name
    # Fixed cache location under the user's home directory.
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)
    logger.info(f'Model not found. Downloading from {url}...')
    model_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = None
    try:
        # Download into a sibling temp file so the final rename is on the
        # same filesystem (os.replace is then atomic).
        with tempfile.NamedTemporaryFile(dir=str(model_path.parent),
                                         prefix=model_path.name + '.',
                                         suffix='.part',
                                         delete=False) as out_file:
            tmp_path = out_file.name
            with urllib.request.urlopen(url) as response:
                # Stream in chunks instead of holding the whole file in memory.
                shutil.copyfileobj(response, out_file)
        os.replace(tmp_path, model_path)
        tmp_path = None  # ownership moved to model_path; nothing to clean up
    except Exception as e:
        logger.error(f'Error downloading the model: {e}')
        # Tell the user how to download the file manually.
        logger.error(
            f'Unable to download the model automatically. '
            f'Please download the model manually from the following URL:\n{url}\n'
            f'and save it to: {model_name} or {model_path}')
        raise RuntimeError(
            f'Failed to download the model. Please download it manually from {url} '
            f'and save it to {model_path}') from e
    finally:
        # Best-effort removal of a partial download.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
    logger.info(f'Model downloaded and saved at: {model_path}')
    return str(model_path)
class RatioRecTVReisze(object):
    """Aspect-ratio-bucketed resize plus normalization for recognition inference.

    The image's width/height ratio is snapped to an integer bucket, capped at
    ``max_ratio``. Buckets 1-4 use the fixed ``[W, H]`` pairs in
    ``base_shape``. Larger buckets use height ``base_h`` and width
    ``base_h * ratio``. The resized image is converted with ``ToTensor`` and
    normalized with mean 0.5 / std 0.5.

    (The misspelled class name is kept for compatibility with callers.)
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (dict): full config; reads ``Eval.loader.max_ratio`` and
                ``Eval.dataset.{base_shape, base_h, ceil}``.
        """
        loader_cfg = cfg['Eval']['loader']
        dataset_cfg = cfg['Eval']['dataset']
        self.max_ratio = loader_cfg.get('max_ratio', 12)
        self.base_shape = dataset_cfg.get(
            'base_shape', [[64, 64], [96, 48], [112, 40], [128, 32]])
        self.base_h = dataset_cfg.get('base_h', 32)
        # No trailing comma here: ``x = v,`` would store the 1-tuple ``(v,)``,
        # which is always truthy and silently forces the ceil branch.
        self.ceil = bool(dataset_cfg.get('ceil', False))
        self.interpolation = T.InterpolationMode.BICUBIC
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def _target_size(self, w, h):
        """Return the target ``(width, height)`` for an image of size w x h."""
        if self.ceil:
            ratio = int(float(w) / float(h)) + 1
        else:
            ratio = max(1, round(float(w) / float(h)))
        ratio = min(ratio, self.max_ratio)
        if ratio <= 4:
            # base_shape entries are [W, H].
            img_w, img_h = self.base_shape[ratio - 1]
        else:
            img_w, img_h = self.base_h * ratio, self.base_h
        return img_w, img_h

    def __call__(self, data):
        """Resize and normalize ``data['image']`` in place and return ``data``.

        ``data['image']`` must expose ``.size`` as ``(width, height)``
        (presumably a PIL image -- confirm against the preceding ops).
        """
        img = data['image']
        w, h = img.size
        img_w, img_h = self._target_size(w, h)
        resized_image = F.resize(img, (img_h, img_w),
                                 interpolation=self.interpolation)
        data['image'] = self.transforms(resized_image)
        return data
def build_rec_process(cfg):
    """Derive the inference-time preprocessing op list from the eval config.

    Walks ``cfg['Eval']['dataset']['transforms']``, a list of
    ``{op_name: op_params}`` dicts whose first key is the op name, and:

    * drops every op whose name contains ``'Label'``;
    * sets ``infer_mode = True`` on ``RecResizeImg``;
    * reduces ``KeepKeys.keep_keys`` to ``['image']``, or to
      ``['image', 'valid_ratio']`` for SAR/RobustScanner when
      ``'valid_ratio'`` was already listed.

    The op parameter dicts are modified in place.

    Returns:
        tuple: ``(transforms, ratio_resize_flag)``. ``ratio_resize_flag`` is
        False when any op name (including dropped Label ops) contains
        ``'Resize'``, True otherwise.
    """
    kept_ops = []
    saw_resize = False
    for op in cfg['Eval']['dataset']['transforms']:
        name = tuple(op)[0]
        saw_resize = saw_resize or 'Resize' in name
        if 'Label' in name:
            continue
        if name == 'RecResizeImg':
            op[name]['infer_mode'] = True
        elif name == 'KeepKeys':
            params = op[name]
            wants_ratio = (
                cfg['Architecture']['algorithm'] in ('SAR', 'RobustScanner')
                and 'valid_ratio' in params['keep_keys'])
            params['keep_keys'] = (['image', 'valid_ratio']
                                   if wants_ratio else ['image'])
        kept_ops.append(op)
    return kept_ops, not saw_resize
def set_device(device, numId=0):
    """Resolve a device string to a ``torch.device``.

    Args:
        device (str): ``'gpu'`` to request CUDA; any other value selects CPU.
        numId (int): CUDA device index used when a GPU is selected.

    Returns:
        torch.device: ``cuda:{numId}`` if a GPU was requested and CUDA is
        available, otherwise ``cpu``.
    """
    if device == 'gpu':
        if torch.cuda.is_available():
            return torch.device(f'cuda:{numId}')
        # Only warn about the fallback when a GPU was actually requested.
        logger.info('GPU is not available, using CPU.')
    return torch.device('cpu')
class OpenRecognizer(object):
    """Text-line recognizer built on OpenOCR's ``openrec`` package.

    Builds the model, post-processor and preprocessing pipeline from a config,
    or from the bundled default mobile/server config (fetching the default
    weights when they are not found locally). Calling the instance runs
    batched recognition on image files or in-memory images.
    """

    def __init__(self, config=None, mode='mobile', numId=0):
        """Build the recognizer.

        Args:
            config (dict, optional): full config dict. When None, the bundled
                default config for ``mode`` is loaded. Default: None.
            mode (str, optional): ``'server'`` or ``'mobile'``. Only used when
                ``config`` is None; any value other than ``'server'`` selects
                the mobile config. Default: ``'mobile'``.
            numId (int, optional): CUDA device index used when
                ``Global.device`` is ``'gpu'``. Default: 0.

        Note:
            ``config`` is modified in place. This covers the pretrained model
            path, character dict path, decoder ``out_channels`` and
            ``Global.infer_mode``.
        """
        if config is None:
            if mode == 'server':
                config = Config(DEFAULT_CFG_PATH_REC_SERVER).cfg  # server model
                self._ensure_pretrained(config, MODEL_NAME_REC_SERVER,
                                        DOWNLOAD_URL_REC_SERVER)
            else:
                config = Config(DEFAULT_CFG_PATH_REC).cfg  # mobile model
                self._ensure_pretrained(config, MODEL_NAME_REC,
                                        DOWNLOAD_URL_REC)
        else:
            algorithm = config['Architecture']['algorithm']
            if algorithm == 'SVTRv2_mobile':
                self._ensure_pretrained(config, MODEL_NAME_REC,
                                        DOWNLOAD_URL_REC)
            elif algorithm == 'SVTRv2_server':
                self._ensure_pretrained(config, MODEL_NAME_REC_SERVER,
                                        DOWNLOAD_URL_REC_SERVER)
        global_config = config['Global']
        self.cfg = config
        if global_config['pretrained_model'] is None:
            global_config[
                'pretrained_model'] = global_config['output_dir'] + '/best.pth'

        # Deferred imports of the openrec model/pre/post-processing builders.
        from openrec.modeling import build_model as build_rec_model
        from openrec.postprocess import build_post_process
        from openrec.preprocess import create_operators, transform
        self.transform = transform

        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)
        # Size the decoder's output layer to the post-processor's character set.
        config['Architecture']['Decoder'][
            'out_channels'] = self.post_process_class.get_character_num()
        self.model = build_rec_model(config['Architecture'])
        load_ckpt(self.model, config)
        self.device = set_device(global_config['device'], numId=numId)
        self.model.eval()
        replace_batchnorm(self.model.encoder)
        self.model.to(device=self.device)

        transforms, ratio_resize_flag = build_rec_process(self.cfg)
        global_config['infer_mode'] = True
        self.ops = create_operators(transforms, global_config)
        if ratio_resize_flag:
            # No resize op in the config: insert the ratio-bucketed resize
            # just before the last op.
            self.ops.insert(-1, RatioRecTVReisze(cfg=self.cfg))

    @staticmethod
    def _ensure_pretrained(config, model_name, url):
        """Use the default weights/dictionary unless local weights are configured.

        If ``Global.pretrained_model`` does not exist on disk, it is replaced
        with the path returned by ``check_and_download_model``. An existing
        path is left as is. ``Global.character_dict_path`` is always set to
        the default dictionary.
        """
        global_cfg = config['Global']
        if not os.path.exists(global_cfg['pretrained_model']):
            global_cfg['pretrained_model'] = check_and_download_model(
                model_name, url)
        global_cfg['character_dict_path'] = DEFAULT_DICT_PATH_REC

    @staticmethod
    def _pad_batch(images):
        """Stack images into one float32 ``(N, 3, H, W)`` array.

        Each image's last two dims are treated as (height, width). Images are
        zero-padded at the bottom/right up to the largest height and width in
        the batch.
        """
        max_h = max(img.shape[-2] for img in images)
        max_w = max(img.shape[-1] for img in images)
        padded = np.zeros([len(images), 3, max_h, max_w], dtype=np.float32)
        for i, img in enumerate(images):
            h, w = img.shape[-2:]
            padded[i:i + 1, :, :h, :w] = img
        return padded

    def __call__(self,
                 img_path=None,
                 img_numpy_list=None,
                 img_numpy=None,
                 batch_num=1):
        """Recognize text in the given images.

        Exactly one input source is used, in priority order ``img_numpy``,
        ``img_path``, ``img_numpy_list``. The others are ignored.

        File inputs are read as bytes and run through every preprocessing op.
        In-memory images skip the first op (presumably the decode step).

        Args:
            img_path (str, optional): path passed to ``get_image_file_list``
                to obtain the image files. Default: None.
            img_numpy_list (list, optional): in-memory images. Default: None.
            img_numpy (numpy.ndarray, optional): a single in-memory image.
                Default: None.
            batch_num (int, optional): images per model forward pass.
                Default: 1.

        Returns:
            list[dict]: one dict per image, with these keys:

            * ``'text'`` and ``'score'``;
            * ``'elapse'``: the model forward time in seconds for that
              image's batch;
            * ``'file'`` (first key), only when the input came from
              ``img_path``.

        Raises:
            Exception: if no input image path or array is provided.
        """
        # Pick one input source up front so counting, loading and result
        # labelling all agree even when several arguments are passed.
        if img_numpy is not None:
            file_list, arrays = None, [img_numpy]
        elif img_path is not None:
            file_list, arrays = get_image_file_list(img_path), None
        elif img_numpy_list is not None:
            file_list, arrays = None, img_numpy_list
        else:
            raise Exception('No input image path or numpy array.')
        num_img = len(file_list) if file_list is not None else len(arrays)

        results = []
        for start_idx in range(0, num_img, batch_num):
            batch_data = []
            batch_others = []
            batch_file_names = []
            for img_idx in range(start_idx, min(start_idx + batch_num,
                                                num_img)):
                if file_list is not None:
                    file_name = file_list[img_idx]
                    with open(file_name, 'rb') as f:
                        data = {'image': f.read()}
                    data = self.transform(data, self.ops[:1])
                    batch_file_names.append(file_name)
                else:
                    data = {'image': arrays[img_idx]}
                batch = self.transform(data, self.ops[1:])
                if self.cfg['Architecture']['algorithm'] in [
                        'SAR', 'RobustScanner'
                ]:
                    # valid_ratio is the last kept key (see build_rec_process).
                    # NOTE(review): passed to the model as a numpy array, not
                    # a tensor -- confirm the SAR/RobustScanner models accept it.
                    batch_others.append(np.expand_dims(batch[-1], axis=0))
                batch_data.append(batch[0])

            images = torch.from_numpy(self._pad_batch(batch_data)).to(
                device=self.device)
            others = (np.concatenate(batch_others, axis=0)
                      if batch_others else None)
            with torch.no_grad():
                t_start = time.time()
                preds = self.model(images, others)
                t_cost = time.time() - t_start
            post_results = self.post_process_class(preds)

            for i, post_result in enumerate(post_results):
                info = {
                    'text': post_result[0],
                    'score': post_result[1],
                    'elapse': t_cost
                }
                if file_list is not None:
                    info = {'file': batch_file_names[i], **info}
                results.append(info)
        return results
def main(cfg):
    """Run recognition over ``Global.infer_img`` and report timing statistics.

    Each image's path, text and score are written tab-separated (UTF-8) to
    ``<output_dir>/rec_results.txt``. Model time is also bucketed by
    recognized text length. Lengths ``>= max_text_length - 1`` share the last
    bucket. Two averages are reported: the per-sample mean and the mean of
    the per-bucket averages.

    Args:
        cfg (dict): merged config; reads ``Global.output_dir``,
            ``Global.max_text_length`` and ``Global.infer_img``.
    """
    model = OpenRecognizer(cfg)
    global_cfg = cfg['Global']
    save_res_path = global_cfg['output_dir']
    os.makedirs(save_res_path, exist_ok=True)
    max_len = global_cfg['max_text_length']
    # Accumulated model time and sample count per text-length bucket.
    text_len_time = [0] * max_len
    text_len_num = [0] * max_len
    t_sum = 0
    sample_num = 0
    with open(os.path.join(save_res_path, 'rec_results.txt'), 'wb') as fout:
        for file in get_image_file_list(global_cfg['infer_img']):
            preds_result = model(img_path=file, batch_num=1)[0]
            rec_text = preds_result['text']
            score = preds_result['score']
            t_cost = preds_result['elapse']
            info = rec_text + '\t' + str(score)
            bucket = min(max_len - 1, len(rec_text))
            text_len_num[bucket] += 1
            text_len_time[bucket] += t_cost
            logger.info(
                f'{sample_num} {file}\t result: {info}, time cost: {t_cost}')
            fout.write((file + '\t' + info + '\n').encode('utf-8'))
            t_sum += t_cost
            sample_num += 1
    print(text_len_num)
    bucket_avg_t_cost = [
        l_t_cost / l_num for l_t_cost, l_num in zip(text_len_time, text_len_num)
        if l_num != 0
    ]
    print(bucket_avg_t_cost)
    if sample_num == 0:
        # Avoid ZeroDivisionError when no images were found.
        logger.info('No images were processed; skipping timing statistics.')
        return
    w_avg_t_cost = sum(bucket_avg_t_cost) / len(bucket_avg_t_cost)
    logger.info(
        f'Sample num: {sample_num}, Weighted Avg time cost: {t_sum/sample_num}, Avg time cost: {w_avg_t_cost}'
    )
    logger.info('success!')
if __name__ == '__main__':
    # Parse CLI flags and load the YAML config they name. Then merge the
    # remaining parsed flags, followed by the 'opt' overrides, into it.
    args = ArgsParser().parse_args()
    config = Config(args.config)
    arg_dict = vars(args)
    overrides = arg_dict.pop('opt')
    config.merge_dict(arg_dict)
    config.merge_dict(overrides)
    main(config.cfg)