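# NOTE(review): the text "Spaces: Runtime error" preceded this file; it looks
# like a hosting-page status banner captured with the source, not program text.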
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
import os | |
import sys | |
__dir__ = os.path.dirname(os.path.abspath(__file__)) | |
sys.path.append(__dir__) | |
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) | |
os.environ["FLAGS_allocator_strategy"] = 'auto_growth' | |
import cv2 | |
import json | |
import paddle | |
from ppocr.data import create_operators, transform | |
from ppocr.modeling.architectures import build_model | |
from ppocr.postprocess import build_post_process | |
from ppocr.utils.save_load import load_model | |
from ppocr.utils.visual import draw_ser_results | |
from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps | |
import tools.program as program | |
def to_tensor(data):
    """Wrap every element of ``data`` in a one-item batch.

    Each element is placed in its own single-item list. Elements that are an
    ``np.ndarray``, a ``paddle.Tensor`` or a ``numbers.Number`` then have that
    single-item list passed through ``paddle.to_tensor``; every other element
    is returned as the plain single-item list.

    Args:
        data: Iterable of fields for one sample. It is iterated exactly once.

    Returns:
        list: One entry per input element, in input order. An empty input
        yields an empty list.
    """
    import numbers

    batched = []
    tensor_slots = []
    for slot, value in enumerate(data):
        batched.append([value])
        # enumerate() never repeats an index, so no membership check is
        # needed before recording the slot.
        if isinstance(value, (np.ndarray, paddle.Tensor, numbers.Number)):
            tensor_slots.append(slot)

    # Conversion happens only after the whole input has been consumed, which
    # keeps the original collect-then-convert ordering.
    for slot in tensor_slots:
        batched[slot] = paddle.to_tensor(batched[slot])
    return batched
class SerPredictor(object):
    """Callable wrapper that runs SER inference on one image file.

    Construction builds the post-processor, the model and the data operators
    from ``config``; calling the instance with an image path runs the
    transform pipeline, the model and the post-processor on that image.
    """

    # Fields requested from the ``KeepKeys`` transform. ``__call__`` reads the
    # batch by position, so the indices below are derived from this single
    # ordering instead of being hard-coded separately.
    _KEEP_KEYS = (
        'input_ids', 'labels', 'bbox', 'image', 'attention_mask',
        'token_type_ids', 'segment_offset_id', 'ocr_info', 'entities',
    )
    _ATTENTION_MASK_IDX = _KEEP_KEYS.index('attention_mask')
    _SEGMENT_OFFSET_ID_IDX = _KEEP_KEYS.index('segment_offset_id')
    _OCR_INFO_IDX = _KEEP_KEYS.index('ocr_info')

    def __init__(self, config):
        """Build the model, post-processor, OCR engine and data operators.

        Args:
            config: Mapping with at least the ``Global``, ``PostProcess``,
                ``Architecture`` and ``Eval`` sections read below. It is
                modified in place: ``Global['infer_mode']`` is set to ``True``
                and the entries of ``Eval.dataset.transforms`` are patched.
        """
        global_config = config['Global']

        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)

        # build model
        self.model = build_model(config['Architecture'])
        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])

        # Imported here rather than at module level, so paddleocr is only
        # needed once a predictor is actually constructed.
        from paddleocr import PaddleOCR
        self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False)

        # create data ops: the transform configs are patched in place, so the
        # list handed to create_operators() below already carries the changes.
        for op in config['Eval']['dataset']['transforms']:
            op_name = list(op)[0]
            if 'Label' in op_name:
                op[op_name]['ocr_engine'] = self.ocr_engine
            elif op_name == 'KeepKeys':
                op[op_name]['keep_keys'] = list(self._KEEP_KEYS)

        global_config['infer_mode'] = True
        self.ops = create_operators(config['Eval']['dataset']['transforms'],
                                    global_config)
        self.model.eval()

    def __call__(self, img_path):
        """Run inference on the image stored at ``img_path``.

        Args:
            img_path: Path of the image file; it is read as raw bytes.

        Returns:
            tuple: ``(post_result, batch)`` where ``post_result`` is whatever
            the post-processor returns and ``batch`` is the list produced by
            ``to_tensor`` that was fed to the model.
        """
        with open(img_path, 'rb') as f:
            img = f.read()
        data = {'image': img}
        batch = transform(data, self.ops)
        batch = to_tensor(batch)
        preds = self.model(batch)
        # NOTE(review): positional access assumes the transform pipeline ends
        # with KeepKeys so batch positions follow _KEEP_KEYS - confirm against
        # the config in use.
        post_result = self.post_process_class(
            preds,
            attention_masks=batch[self._ATTENTION_MASK_IDX],
            segment_offset_ids=batch[self._SEGMENT_OFFSET_ID_IDX],
            ocr_infos=batch[self._OCR_INFO_IDX])
        return post_result, batch
if __name__ == '__main__':
    # Script entry point: run the predictor over every input image, append
    # one tab-separated JSON record per image to "infer_results.txt", and
    # save a rendered "<name>_ser.jpg" image next to it.
    config, device, logger, vdl_writer = program.preprocess()
    global_cfg = config['Global']
    os.makedirs(global_cfg['save_res_path'], exist_ok=True)

    ser_engine = SerPredictor(config)

    infer_imgs = get_image_file_list(global_cfg['infer_img'])
    total = len(infer_imgs)
    results_file = os.path.join(global_cfg['save_res_path'],
                                "infer_results.txt")
    with open(results_file, "w", encoding='utf-8') as fout:
        for idx, img_path in enumerate(infer_imgs):
            stem = os.path.splitext(os.path.basename(img_path))[0]
            save_img_path = os.path.join(global_cfg['save_res_path'],
                                         stem + "_ser.jpg")
            logger.info("process: [{}/{}], save result to {}".format(
                idx, total, save_img_path))

            result, _ = ser_engine(img_path)
            # Only the first entry of the post-processed output is used.
            result = result[0]

            record = json.dumps({"ocr_info": result}, ensure_ascii=False)
            fout.write(img_path + "\t" + record + "\n")

            img_res = draw_ser_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)