# OpenOCR-Demo / tools/infer_det.py
# Source snapshot: commit 695a4a4 ("update app") by topdu, 17.4 kB.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pathlib import Path
import time
import numpy as np
import os
import sys
# Make this directory (tools/) and its parent importable so the `tools.*`
# imports below resolve when this file is run directly as a script.
# NOTE: deliberately not named `__dir__`: a module-level `__dir__` is what
# `dir(module)` calls (PEP 562), so binding a string to it makes `dir()` on
# this module raise TypeError.
_TOOLS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_TOOLS_DIR)
sys.path.insert(0, os.path.abspath(os.path.join(_TOOLS_DIR, '..')))
# NOTE(review): FLAGS_allocator_strategy looks like a PaddlePaddle allocator
# flag; nothing visible in this torch-based file reads it — presumably a
# leftover, confirm before removing.
os.environ['FLAGS_allocator_strategy'] = 'auto_growth'
import cv2
import json
import torch
from tools.engine import Config
from tools.utility import ArgsParser
from tools.utils.ckpt import load_ckpt
from tools.utils.logging import get_logger
from tools.utils.utility import get_image_file_list
# Module-wide logger (from tools.utils.logging).
logger = get_logger()
# Directory containing this file (tools/); the default config path is relative to it.
root_dir = Path(__file__).resolve().parent
# Config loaded by OpenDetector() when it is constructed without a config.
DEFAULT_CFG_PATH_DET = str(root_dir / '../configs/det/dbnet/repvit_db.yml')
# Default checkpoint: looked up as given, then under ~/.cache/openocr, and
# downloaded from DOWNLOAD_URL_DET otherwise (see check_and_download_model).
MODEL_NAME_DET = './openocr_det_repvit_ch.pth' # model file name
DOWNLOAD_URL_DET = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/openocr_det_repvit_ch.pth' # model file URL
def check_and_download_model(model_name: str, url: str):
    """
    Return a local path to a pretrained model, downloading it if necessary.

    Lookup order:
      1. ``model_name`` itself (absolute, or relative to the current working
         directory) -- returned unchanged when it exists.
      2. ``~/.cache/openocr/<model_name>`` -- returned when it exists.
      3. Otherwise the file is downloaded from ``url`` into that cache path.

    The download is streamed into a sibling ``<name>.part`` file that is
    renamed to the final name only after it completed. An interrupted or
    failed download therefore never leaves a truncated file behind that a
    later call would mistake for a valid cached model.

    Args:
        model_name (str): File name of the model, e.g. "model.pt".
        url (str): Download URL of the model file.

    Returns:
        str: Full path of the model file.

    Raises:
        RuntimeError: If the model is not available locally and the
            download fails (the original error is chained as the cause).
    """
    import shutil
    import urllib.request

    if os.path.exists(model_name):
        return model_name

    # Fixed cache location under the user's home directory: ~/.cache/openocr
    cache_dir = Path.home() / '.cache' / 'openocr'
    model_path = cache_dir / model_name
    if model_path.exists():
        logger.info(f'Model already exists at: {model_path}')
        return str(model_path)

    logger.info(f'Model not found. Downloading from {url}...')
    # Create the file's own parent (not only cache_dir) so names containing
    # sub-directories also work.
    model_path.parent.mkdir(parents=True, exist_ok=True)
    part_path = model_path.with_name(model_path.name + '.part')
    try:
        # Stream to disk instead of buffering the whole checkpoint in memory.
        with urllib.request.urlopen(url) as response, open(part_path,
                                                           'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        # Publish the file under its final name only once it is complete.
        os.replace(part_path, model_path)
    except Exception as e:
        # Drop any partial download so the next call retries it instead of
        # loading a truncated checkpoint.
        try:
            part_path.unlink()
        except OSError:
            pass
        logger.error(f'Error downloading the model: {e}')
        # Ask the user to download the model manually.
        logger.error(
            f'Unable to download the model automatically. '
            f'Please download the model manually from the following URL:\n{url}\n'
            f'and save it to: {model_name} or {model_path}')
        raise RuntimeError(
            f'Failed to download the model. Please download it manually from {url} '
            f'and save it to {model_path}') from e
    logger.info(f'Model downloaded and saved at: {model_path}')
    return str(model_path)
def replace_batchnorm(net):
    """
    Recursively prepare ``net`` for inference by fusing or removing BatchNorm.

    Each direct child of ``net`` is handled as follows:
      * it has a ``fuse()`` method: it is replaced by the module that
        ``fuse()`` returns, and that replacement is processed recursively;
      * it is a ``torch.nn.BatchNorm2d``: it is replaced by
        ``torch.nn.Identity`` (its normalization is dropped -- presumably
        relies on ``fuse()`` having folded it into a neighbouring layer);
      * anything else: the child itself is processed recursively.

    ``net`` is modified in place; nothing is returned.
    """
    for name, module in net.named_children():
        if not hasattr(module, 'fuse'):
            if isinstance(module, torch.nn.BatchNorm2d):
                setattr(net, name, torch.nn.Identity())
            else:
                replace_batchnorm(module)
            continue
        replacement = module.fuse()
        setattr(net, name, replacement)
        replace_batchnorm(replacement)
def padding_image(img, size=(640, 640)):
    """
    Zero-pad an image on the bottom and right up to ``size``.

    The original pixels stay in the top-left corner, so pixel coordinates in
    the padded image equal those in the input.

    :param img: Image array indexed as (H, W, ...); any trailing (channel)
        axes are left untouched.
    :param size: Target size as ``(width, height)`` (default 640x640).
    :return: A new array with the same dtype and shape
        ``(max(H, height), max(W, width), ...)``. The added area is filled with
        zeros (black). A side that already reaches the target size is left
        as is rather than cropped.
    """
    img_height, img_width = img.shape[:2]
    target_width, target_height = size
    # Clamp at 0: a side that is already large enough must not receive a
    # negative border.
    pad_bottom = max(0, target_height - img_height)
    pad_right = max(0, target_width - img_width)
    # Pad only the two spatial axes (bottom and right edges).
    pad_width = [(0, pad_bottom), (0, pad_right)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_width, mode='constant', constant_values=0)
def resize_image(img, size=(640, 640), over_lap=64):
    """
    Split an image into fixed-size tiles, padding where necessary.

    Despite its name, this function does not rescale the image:

    - If the image fits inside ``size``, it is returned as a single tile.
      Unless it already has exactly that size, the tile is zero-padded on the
      bottom/right with padding_image.
    - Otherwise it is cut into overlapping tiles of ``size``. Tiles start
      every ``size - over_lap`` pixels. The last tile of each row/column is
      shifted back (not below 0) so that it ends at the image border. Tiles
      smaller than ``size`` (image side shorter than a tile) are zero-padded.

    :param img: Image array indexed as (H, W, ...).
    :param size: Tile size as ``(width, height)`` (default 640x640).
    :param over_lap: Overlap in pixels between neighbouring tiles (default 64).
    :return: ``(tiles, positions)``. ``tiles`` is a list of arrays of exactly
        ``size``. ``positions[i]`` is ``[left, top, right, bottom]`` of the
        un-padded content of ``tiles[i]``, in ``img`` coordinates.
    :raises ValueError: If tiling is needed and ``over_lap`` is not smaller
        than both tile sides.
    """
    img_height, img_width = img.shape[:2]
    target_width, target_height = size

    # Small enough for a single tile: pad (if needed) instead of cropping.
    if img_width <= target_width and img_height <= target_height:
        position = [[0, 0, img_width, img_height]]
        if img_width == target_width and img_height == target_height:
            return [img], position
        return [padding_image(img, size)], position

    stride_x = target_width - over_lap
    stride_y = target_height - over_lap
    if stride_x <= 0 or stride_y <= 0:
        raise ValueError(
            f'over_lap ({over_lap}) must be smaller than the tile size {size}')

    cropped_img_list = []
    crop_positions = []
    # max(..., 1) guarantees at least one row/column of tiles even when an
    # image side is not longer than the overlap. Otherwise the range would be
    # empty and the image would yield no tiles at all.
    for tile_top in range(0, max(img_height - over_lap, 1), stride_y):
        for tile_left in range(0, max(img_width - over_lap, 1), stride_x):
            right = min(tile_left + target_width, img_width)
            bottom = min(tile_top + target_height, img_height)
            # A tile touching the border is shifted back so it spans a full
            # tile wherever the image is large enough.
            if right >= img_width:
                left = max(0, right - target_width)
            else:
                left = tile_left
            if bottom >= img_height:
                top = max(0, bottom - target_height)
            else:
                top = tile_top
            tile = img[top:bottom, left:right]
            if bottom - top < target_height or right - left < target_width:
                tile = padding_image(tile, size)
            cropped_img_list.append(tile)
            crop_positions.append([left, top, right, bottom])
    return cropped_img_list, crop_positions
def restore_preds(preds, crop_positions, original_size, thresh=0.3):
    """
    Stitch per-tile prediction maps back into one full-image map.

    Each tile's prediction is cropped to the tile's un-padded extent and
    binarized with ``> thresh``. The binary mask is then added into an
    all-zero canvas at the tile's position. Where tiles overlap, masks are
    summed, so values there can exceed 1.

    Args:
        preds (torch.Tensor): One prediction per entry of ``crop_positions``,
            each indexed as ``pred[:, :h, :w]`` — presumably of shape
            (N, 1, H_tile, W_tile); confirm against the model output.
        crop_positions (list): ``[left, top, right, bottom]`` of each tile's
            content in the full image (as returned by resize_image).
        original_size (tuple): ``(height, width)`` of the stitched map.
        thresh (float): Binarization threshold. Defaults to 0.3.

    Returns:
        torch.Tensor: Map of shape (1, 1, height, width) with the dtype and
        device of ``preds``.
    """
    height, width = original_size
    restored_pred = torch.zeros((1, 1, height, width),
                                dtype=preds.dtype,
                                device=preds.device)
    for tile_pred, (left, top, right, bottom) in zip(preds, crop_positions):
        # Discard the zero padding that was added below/right of the tile.
        mask = tile_pred[:, :bottom - top, :right - left] > thresh
        restored_pred[:, :, top:bottom, left:right] += mask.to(preds.dtype)
    return restored_pred
def draw_det_res(dt_boxes, img, img_name, save_path):
    """
    Draw detected polygons on an image and save it.

    Each box is converted to int32 points and drawn as a closed polyline in
    color (255, 255, 0) with thickness 2. Drawing happens directly on
    ``img`` (no copy is made). The result is written to
    ``save_path/<basename of img_name>``, and ``save_path`` is created if it
    does not exist yet.
    """
    for points in dt_boxes:
        polygon = np.array(points).astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(img, [polygon], True, color=(255, 255, 0), thickness=2)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    out_file = os.path.join(save_path, os.path.basename(img_name))
    cv2.imwrite(out_file, img)
def set_device(device, numId=0):
    """
    Map the configured device name to a ``torch.device``.

    Args:
        device (str): 'gpu' requests CUDA; any other value selects the CPU.
        numId (int): CUDA device index used when a GPU is requested.

    Returns:
        torch.device: ``cuda:<numId>`` if 'gpu' was requested and CUDA is
        available, otherwise ``cpu``.
    """
    if device == 'gpu':
        if torch.cuda.is_available():
            return torch.device(f'cuda:{numId}')
        # Only report the fallback when a GPU was actually requested; an
        # explicit CPU choice is not a "GPU unavailable" situation.
        logger.info('GPU is not available, using CPU.')
    return torch.device('cpu')
class OpenDetector(object):
    """
    Text detector built from an OpenOCR detection config.

    Call the instance for whole-image inference, or use ``crop_infer`` for
    tiled inference on a resized copy of each image.
    """

    def __init__(self, config=None, numId=0):
        """
        Build the model, the pre-processing ops and the post-processor.

        Args:
            config (dict, optional): Config dict with 'Global',
                'Architecture', 'Eval' and 'PostProcess' sections. If None,
                the config at DEFAULT_CFG_PATH_DET is loaded.
            numId (int, optional): CUDA device index. Defaults to 0.

        Note:
            If ``config['Global']['pretrained_model']`` does not exist, it is
            replaced by the default checkpoint from check_and_download_model.
            NOTE(review): that default is the RepViT checkpoint, which may
            not match a non-default 'Architecture' — confirm this is intended.
            The 'KeepKeys' entry of ``config['Eval']['dataset']['transforms']``
            is modified in place.
        """
        if config is None:
            config = Config(DEFAULT_CFG_PATH_DET).cfg
        if not os.path.exists(config['Global']['pretrained_model']):
            config['Global']['pretrained_model'] = check_and_download_model(
                MODEL_NAME_DET, DOWNLOAD_URL_DET)
        from opendet.modeling import build_model as build_det_model
        from opendet.postprocess import build_post_process
        from opendet.preprocess import create_operators, transform
        self.transform = transform
        global_config = config['Global']

        # build model
        self.model = build_det_model(config['Architecture'])
        self.model.eval()
        load_ckpt(self.model, config)
        # Fuse / remove the backbone's BatchNorm layers for inference.
        replace_batchnorm(self.model.backbone)
        self.device = set_device(config['Global']['device'], numId=numId)
        self.model.to(device=self.device)

        # Build data ops from the eval pipeline. Label ops are skipped, and
        # KeepKeys is restricted to the image and its shape info.
        transforms = []
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                continue
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = ['image', 'shape']
            transforms.append(op)
        self.ops = create_operators(transforms, global_config)

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

    @staticmethod
    def _load_raw_images(img_path=None, img_numpy_list=None, img_numpy=None):
        """
        Yield one ``{'image': ...}`` dict per input image.

        Exactly one input source is used, in this priority order:
        ``img_numpy`` (a single array), ``img_path`` (a sequence of file
        paths whose raw bytes are read), ``img_numpy_list`` (a sequence of
        arrays). Using a single source keeps the number of images and the
        image data consistent when several sources are passed.

        Raises:
            ValueError: If no input source is given (raised when iteration
                starts).
        """
        if img_numpy is not None:
            yield {'image': img_numpy}
        elif img_path is not None:
            for path in img_path:
                with open(path, 'rb') as f:
                    raw = f.read()
                yield {'image': raw}
        elif img_numpy_list is not None:
            for img in img_numpy_list:
                yield {'image': img}
        else:
            raise ValueError('No input image path or numpy array.')

    def crop_infer(
        self,
        img_path=None,
        img_numpy_list=None,
        img_numpy=None,
        target_size=640,
        over_lap=64,
    ):
        """
        Detect text by running the model on overlapping tiles of each image.

        Each image is resized (keeping its aspect ratio) so that its long side
        is ``2 * target_size - over_lap`` pixels. That is exactly two tiles
        that overlap by ``over_lap``. The resized image is split into
        ``target_size`` x ``target_size`` tiles (resize_image), which run
        through the model as one batch. The thresholded tile maps are
        stitched back together (restore_preds) before post-processing.

        Args:
            img_path (str | list, optional): A path string, expanded with
                get_image_file_list as in ``__call__``, or a list of image
                file paths.
            img_numpy_list (list, optional): List of images as numpy arrays.
            img_numpy (numpy.ndarray, optional): A single image array.
            target_size (int, optional): Tile side length. Defaults to 640.
            over_lap (int, optional): Tile overlap in pixels. Defaults to 64.

        Returns:
            list: One dict per image with 'boxes' (post-processed points)
            and 'elapse' (seconds spent in the model forward pass only).
            NOTE(review): no CUDA synchronization is done, so on GPU
            'elapse' presumably under-counts.

        Raises:
            ValueError: If no input image path or array is given.
        """
        if img_numpy is None and isinstance(img_path, str):
            # A bare string would otherwise be iterated character by character.
            img_path = get_image_file_list(img_path)
        long_side = target_size * 2 - over_lap
        results = []
        for data in self._load_raw_images(img_path, img_numpy_list,
                                          img_numpy):
            # self.ops[:1]: presumably the image-decode op — confirm in config.
            data = self.transform(data, self.ops[:1])
            src_img_ori = data['image']
            ori_height, ori_width = src_img_ori.shape[:2]
            # max(1, ...) keeps extremely elongated images from collapsing to
            # a zero-sized side, which cv2.resize rejects.
            if ori_height > ori_width:
                r_h = long_side
                r_w = max(1, ori_width * long_side // ori_height)
            else:
                r_w = long_side
                r_h = max(1, ori_height * long_side // ori_width)
            src_img = cv2.resize(src_img_ori, (r_w, r_h))
            # Original size plus resize ratios, used by the post-processor to
            # map boxes back to the input image.
            shape_list_ori = np.array([[
                ori_height, ori_width,
                float(r_h) / ori_height,
                float(r_w) / ori_width
            ]])
            res_height, res_width = src_img.shape[:2]
            tiles, crop_positions = resize_image(src_img,
                                                 size=(target_size,
                                                       target_size),
                                                 over_lap=over_lap)
            # self.ops[-3:-1]: the ops between decoding and KeepKeys,
            # presumably normalization/layout — confirm in config.
            images = np.array([
                self.transform({'image': tile}, self.ops[-3:-1])['image']
                for tile in tiles
            ])
            images = torch.from_numpy(images).to(device=self.device)
            with torch.no_grad():
                t_start = time.perf_counter()
                preds = self.model(images)
                t_cost = time.perf_counter() - t_start
                preds['maps'] = restore_preds(preds['maps'], crop_positions,
                                              (res_height, res_width))
                post_result = self.post_process_class(preds, shape_list_ori)
            results.append({
                'boxes': post_result[0]['points'],
                'elapse': t_cost
            })
        return results

    def __call__(self,
                 img_path=None,
                 img_numpy_list=None,
                 img_numpy=None,
                 return_mask=False,
                 **kwargs):
        """
        Run whole-image detection on the given inputs.

        Args:
            img_path (str, optional): Image path (or anything accepted by
                get_image_file_list, which expands it to a list of files).
            img_numpy_list (list, optional): List of images as numpy arrays.
            img_numpy (numpy.ndarray, optional): A single image array.
            return_mask (bool, optional): Also return the raw
                ``preds['maps']`` under 'mask' (as a numpy array when it is a
                tensor). Defaults to False.
            **kwargs: Forwarded to the post-processor.

        If several inputs are given, only one is used, with priority
        ``img_numpy`` > ``img_path`` > ``img_numpy_list``.

        Returns:
            list: One dict per image with 'boxes' (detected point sets) and
            'elapse' (seconds spent in the model forward pass only), plus
            'mask' when ``return_mask`` is True.

        Raises:
            ValueError: If no input image path or array is given.
        """
        if img_numpy is None and img_path is not None:
            img_path = get_image_file_list(img_path)
        results = []
        for data in self._load_raw_images(img_path, img_numpy_list,
                                          img_numpy):
            data = self.transform(data, self.ops[:1])
            # The remaining ops end with KeepKeys -> [image, shape].
            batch = self.transform(data, self.ops[1:])
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = torch.from_numpy(images).to(device=self.device)
            with torch.no_grad():
                t_start = time.perf_counter()
                preds = self.model(images)
                t_cost = time.perf_counter() - t_start
                post_result = self.post_process_class(preds, shape_list,
                                                      **kwargs)
            info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
            if return_mask:
                maps = preds['maps']
                if isinstance(maps, torch.Tensor):
                    maps = maps.detach().cpu().numpy()
                info['mask'] = maps
            results.append(info)
        return results
@torch.no_grad()
def main(cfg):
    """
    Run text detection over ``cfg['Global']['infer_img']`` and save results.

    Each image returned by get_image_file_list is handled as follows:

    - Its detected boxes are logged.
    - They are appended to ``./det_results/det_results.txt`` as one line per
      image: the image path, a tab, and a JSON list of ``{"points": ...}``
      objects (the box points as nested lists).
    - If ``cfg['Global']['is_visualize']`` is true, a copy of the image with
      the boxes drawn is also written to ``./det_results/``.

    Args:
        cfg (dict): Config dict, e.g. ``Config(path).cfg``.
    """
    is_visualize = cfg['Global'].get('is_visualize', False)
    model = OpenDetector(cfg)
    save_res_path = './det_results/'
    os.makedirs(save_res_path, exist_ok=True)
    result_file = os.path.join(save_res_path, 'det_results.txt')
    with open(result_file, 'wb') as fout:
        image_files = get_image_file_list(cfg['Global']['infer_img'])
        for sample_num, file in enumerate(image_files):
            preds_result = model(img_path=file)[0]
            logger.info('{} infer_img: {}, time cost: {}'.format(
                sample_num, file, preds_result['elapse']))
            boxes = preds_result['boxes']
            dt_boxes_json = [{'points': np.array(box).tolist()} for box in boxes]
            if is_visualize:
                src_img = cv2.imread(file)
                draw_det_res(boxes, src_img, file, save_res_path)
                logger.info('The detected Image saved in {}'.format(
                    os.path.join(save_res_path, os.path.basename(file))))
            # Serialize once; the same string is logged and written.
            boxes_str = json.dumps(dt_boxes_json)
            logger.info('results: {}'.format(boxes_str))
            fout.write((file + '\t' + boxes_str + '\n').encode())
    logger.info(f'Results saved to {result_file}.')
    logger.info('success!')
if __name__ == '__main__':
    # Parse the command line, load the config file named by --config, then
    # merge in the parsed flags followed by the `opt` overrides.
    cli_flags = ArgsParser().parse_args()
    config = Config(cli_flags.config)
    flag_dict = vars(cli_flags)
    overrides = flag_dict.pop('opt')
    config.merge_dict(flag_dict)
    config.merge_dict(overrides)
    main(config.cfg)