|
"""Modified from https://github.com/JaidedAI/EasyOCR/blob/803b907/easyocr/detection.py. |
|
1. Disable DataParallel. |
|
""" |
|
import torch |
|
import torch.backends.cudnn as cudnn |
|
from torch.autograd import Variable |
|
from PIL import Image |
|
from collections import OrderedDict |
|
|
|
import cv2 |
|
import numpy as np |
|
from .craft_utils import getDetBoxes, adjustResultCoordinates |
|
from .imgproc import resize_aspect_ratio, normalizeMeanVariance |
|
from .craft import CRAFT |
|
|
|
def copyStateDict(state_dict):
    """Strip the ``module.`` prefix (added by ``DataParallel``) from keys.

    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped model prefix
    every parameter name with ``module.``; this normalizes such a state dict
    so it loads into a bare model.

    Args:
        state_dict: Mapping of parameter names to tensors.

    Returns:
        OrderedDict with the same values and de-prefixed keys (unchanged
        keys when no prefix is present). Empty input yields an empty dict
        instead of raising IndexError.
    """
    if not state_dict:
        # Guard: the original indexed the first key and crashed on {}.
        return OrderedDict()
    # Detect the DataParallel prefix on the first key only — all keys share it.
    start_idx = 1 if next(iter(state_dict)).startswith("module") else 0
    return OrderedDict(
        (".".join(k.split(".")[start_idx:]), v) for k, v in state_dict.items()
    )
|
|
|
def test_net(canvas_size, mag_ratio, net, image, text_threshold, link_threshold, low_text, poly, device, estimate_num_chars=False):
    """Run the CRAFT network on one image (or a pre-batched array of images).

    Args:
        canvas_size: Target long-side size used by ``resize_aspect_ratio``.
        mag_ratio: Magnification ratio applied during resizing.
        net: CRAFT model, already on ``device`` and in eval mode.
        image: Single HWC ndarray, or a 4-D ndarray treated as a batch.
        text_threshold, link_threshold, low_text: Score thresholds passed to
            ``getDetBoxes``.
        poly: If True, return free-form polygons where available.
        device: Torch device string/object for the forward pass.
        estimate_num_chars: If True, pair each box with its character-count
            estimate from ``getDetBoxes``'s mapper.

    Returns:
        (boxes_list, polys_list): one entry per input image, coordinates
        mapped back to the original image scale.
    """
    if isinstance(image, np.ndarray) and len(image.shape) == 4:
        image_arrs = image
    else:
        image_arrs = [image]

    img_resized_list = []
    # BUG FIX: the original overwrote a single ratio inside this loop, so a
    # batch of different-sized images had the LAST image's ratio applied to
    # every image's coordinates. Track one inverse ratio per image instead.
    inv_ratios = []
    for img in image_arrs:
        img_resized, target_ratio, size_heatmap = resize_aspect_ratio(
            img, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
        img_resized_list.append(img_resized)
        inv_ratios.append(1 / target_ratio)

    # Normalize and convert HWC -> CHW, then stack into one batch tensor.
    x = np.array([np.transpose(normalizeMeanVariance(n_img), (2, 0, 1))
                  for n_img in img_resized_list])
    x = torch.from_numpy(x).to(device)

    with torch.no_grad():
        y, feature = net(x)

    boxes_list, polys_list = [], []
    for out, ratio in zip(y, inv_ratios):
        # Channel 0 is the region (text) score, channel 1 the affinity (link)
        # score. ``.data`` dropped: output is grad-free under no_grad above.
        score_text = out[:, :, 0].cpu().numpy()
        score_link = out[:, :, 1].cpu().numpy()

        boxes, polys, mapper = getDetBoxes(
            score_text, score_link, text_threshold, link_threshold, low_text, poly, estimate_num_chars)

        # Map heatmap-space coordinates back to the original image scale,
        # using THIS image's ratio (width and height scale identically since
        # resize_aspect_ratio preserves aspect ratio).
        boxes = adjustResultCoordinates(boxes, ratio, ratio)
        polys = adjustResultCoordinates(polys, ratio, ratio)
        if estimate_num_chars:
            boxes = list(boxes)
            polys = list(polys)
        for k in range(len(polys)):
            if estimate_num_chars:
                # Pair each box with its estimated character count.
                boxes[k] = (boxes[k], mapper[k])
            if polys[k] is None:
                # No free-form polygon found: fall back to the rectangle.
                polys[k] = boxes[k]
        boxes_list.append(boxes)
        polys_list.append(polys)

    return boxes_list, polys_list
|
|
|
def get_detector(trained_model, device='cpu', quantize=True, cudnn_benchmark=False):
    """Load a CRAFT checkpoint and prepare the model for inference.

    Args:
        trained_model: Path to the serialized state dict.
        device: Target device ('cpu', 'cuda', ...). ``map_location`` keeps
            GPU-saved checkpoints loadable on CPU.
        quantize: On CPU only, apply best-effort dynamic int8 quantization.
        cudnn_benchmark: Value assigned to ``cudnn.benchmark`` (global toggle).

    Returns:
        The CRAFT model on ``device``, in eval mode.
    """
    net = CRAFT()

    # The load call was previously duplicated in both branches; it does not
    # depend on the device beyond map_location, so do it once.
    net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))

    if device == 'cpu' and quantize:
        try:
            # Dynamic quantization shrinks the model and speeds up CPU
            # inference; it is best-effort because not every torch build
            # ships the quantization backend.
            torch.quantization.quantize_dynamic(net, dtype=torch.qint8, inplace=True)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; quantization failure falls back to float.
            pass

    net = net.to(device)
    cudnn.benchmark = cudnn_benchmark

    net.eval()
    return net
|
|
|
def get_textbox(detector, image, canvas_size, mag_ratio, text_threshold, link_threshold, low_text, poly, device, optimal_num_chars=None, **kwargs):
    """Detect text regions and return them as flattened polygon coordinates.

    Args:
        detector: CRAFT model from ``get_detector``.
        image: Single HWC ndarray or a 4-D batch (forwarded to ``test_net``).
        canvas_size, mag_ratio, text_threshold, link_threshold, low_text,
        poly, device: Forwarded to ``test_net``.
        optimal_num_chars: If given, boxes are sorted so those whose
            estimated character count is closest to this value come first.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        List (one entry per image) of lists of 1-D int32 arrays, each a
        flattened ``[x0, y0, x1, y1, ...]`` polygon.
    """
    estimate_num_chars = optimal_num_chars is not None
    bboxes_list, polys_list = test_net(canvas_size, mag_ratio, detector,
                                       image, text_threshold,
                                       link_threshold, low_text, poly,
                                       device, estimate_num_chars)
    if estimate_num_chars:
        # Each entry is (polygon, estimated_char_count); keep polygons ordered
        # by closeness of the estimate to the requested count.
        polys_list = [[p for p, _ in sorted(polys, key=lambda x: abs(optimal_num_chars - x[1]))]
                      for polys in polys_list]

    # Flatten each polygon to a 1-D int32 coordinate array. The loop variable
    # is named `box` (the original reused `poly`, shadowing the parameter).
    result = []
    for polys in polys_list:
        single_img_result = [np.array(box).astype(np.int32).reshape((-1))
                             for box in polys]
        result.append(single_img_result)

    return result
|
|