Spaces:
Running
Running
import logging | |
import os | |
import cv2 | |
import numpy as np | |
import importlib.util | |
import sys | |
import subprocess | |
def get_check_global_params(mode):
    """Return the global config parameter names associated with ``mode``.

    A fixed base list is always included. ``"train_eval"`` appends both the
    train and test per-card batch-size keys, ``"test"`` appends only the test
    per-card batch-size key, and any other mode returns the base list alone.
    """
    # NOTE: "image_shape" appears twice; the duplicate is kept so the
    # returned list is unchanged for existing callers.
    base_params = [
        "use_gpu",
        "max_text_length",
        "image_shape",
        "image_shape",
        "character_type",
        "loss_type",
    ]
    mode_specific = []
    if mode == "train_eval":
        mode_specific = [
            "train_batch_size_per_card",
            "test_batch_size_per_card",
        ]
    elif mode == "test":
        mode_specific = ["test_batch_size_per_card"]
    return base_params + mode_specific
def _check_image_file(path): | |
img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"} | |
return any([path.lower().endswith(e) for e in img_end]) | |
def get_image_file_list(img_file):
    """Return a sorted list of image paths found at ``img_file``.

    ``img_file`` may be a single file or a directory. A file is returned as a
    one-element list when ``_check_image_file`` accepts it; a directory is
    scanned one level deep (no recursion) for accepted regular files.

    Raises:
        Exception: if ``img_file`` is None, does not exist, or yields no
            accepted image file.
    """
    if img_file is None or not os.path.exists(img_file):
        raise Exception("not found any img file in {}".format(img_file))

    if os.path.isfile(img_file):
        found = [img_file] if _check_image_file(img_file) else []
    elif os.path.isdir(img_file):
        candidates = (
            os.path.join(img_file, entry) for entry in os.listdir(img_file)
        )
        found = [
            candidate
            for candidate in candidates
            if os.path.isfile(candidate) and _check_image_file(candidate)
        ]
    else:
        found = []

    if not found:
        raise Exception("not found any img file in {}".format(img_file))
    found.sort()
    return found
def binarize_img(img):
    """Binarize a 3-channel image with an Otsu threshold.

    Images that are not 3-dimensional with exactly 3 channels are returned
    unchanged. Otherwise the image is converted to grayscale
    (``COLOR_BGR2GRAY``), thresholded with ``THRESH_BINARY + THRESH_OTSU``
    and converted back to 3 channels (``COLOR_GRAY2BGR``).
    """
    has_three_channels = len(img.shape) == 3 and img.shape[2] == 3
    if not has_three_channels:
        return img

    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # The threshold argument 0 is passed together with the THRESH_OTSU flag.
    threshold_mode = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _, binary = cv2.threshold(grayscale, 0, 255, threshold_mode)
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
def alpha_to_color(img, alpha_color=(255, 255, 255)):
    """Blend a 4-channel image onto a solid colour and drop the alpha plane.

    Images that are not 3-dimensional with exactly 4 channels are returned
    unchanged. The four planes are split as B, G, R, A; ``alpha_color[0]`` is
    blended into the R plane, ``alpha_color[1]`` into G and
    ``alpha_color[2]`` into B, and the result is merged back as (B, G, R)
    with ``uint8`` planes.
    """
    if len(img.shape) != 3 or img.shape[2] != 4:
        return img

    blue, green, red, alpha_plane = cv2.split(img)
    opacity = alpha_plane / 255  # scale alpha from 0..255 to 0..1
    red = (alpha_color[0] * (1 - opacity) + red * opacity).astype(np.uint8)
    green = (alpha_color[1] * (1 - opacity) + green * opacity).astype(np.uint8)
    blue = (alpha_color[2] * (1 - opacity) + blue * opacity).astype(np.uint8)
    return cv2.merge((blue, green, red))
def check_and_read(img_path):
    """Load a GIF or PDF file as image data; report which kind was read.

    The file kind is decided from the last three characters of the file name
    (case-insensitive).

    Args:
        img_path: Path of the file to inspect.

    Returns:
        A 3-tuple ``(data, is_gif, is_pdf)``:

        * GIF that could be read: the first frame with its channel axis
          reversed, ``True``, ``False``.
        * PDF: a list with one BGR array per page, ``False``, ``True``.
        * Anything else, or a GIF whose first frame cannot be read:
          ``(None, False, False)``.
    """
    suffix = os.path.basename(img_path)[-3:].lower()

    if suffix == "gif":
        gif = cv2.VideoCapture(img_path)
        try:
            ret, frame = gif.read()
        finally:
            # Always free the capture handle, even if read() raises.
            gif.release()
        if not ret:
            logger = logging.getLogger("openrec")
            logger.info(
                "Cannot read %s. This gif image maybe corrupted.", img_path)
            # Keep the same arity as every other return path so callers can
            # always unpack three values.
            return None, False, False
        if len(frame.shape) == 2 or frame.shape[-1] == 1:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        imgvalue = frame[:, :, ::-1]  # reverse the channel order
        return imgvalue, True, False

    if suffix == "pdf":
        # Imported lazily so the PDF dependencies are only needed for PDFs.
        import fitz
        from PIL import Image

        imgs = []
        with fitz.open(img_path) as pdf:
            for pg in range(0, pdf.page_count):
                page = pdf[pg]
                mat = fitz.Matrix(2, 2)
                pm = page.get_pixmap(matrix=mat, alpha=False)
                # if width or height > 2000 pixels, don't enlarge the image
                if pm.width > 2000 or pm.height > 2000:
                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
                img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
                img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
                imgs.append(img)
        return imgs, False, True

    return None, False, False
def load_vqa_bio_label_maps(label_map_path):
    """Build BIO label <-> id maps from a label file with one name per line.

    Each line is stripped of surrounding whitespace. Lines equal
    (case-insensitively) to ``OTHER``, ``OTHERS`` or ``IGNORE`` are skipped
    because the ``O`` tag, always at id 0, already stands for them. Every
    remaining name contributes a ``B-<name>`` and an ``I-<name>`` tag, in
    file order. All tags are upper-cased in both maps.

    Returns:
        ``(label2id_map, id2label_map)``.
    """
    skipped_names = ("OTHER", "OTHERS", "IGNORE")
    with open(label_map_path, "r", encoding="utf-8") as fin:
        stripped_names = [raw_line.strip() for raw_line in fin]

    tags = ["O"]
    for name in stripped_names:
        if name.upper() in skipped_names:
            continue
        tags.extend(("B-" + name, "I-" + name))

    upper_tags = [tag.upper() for tag in tags]
    label2id_map = {tag: index for index, tag in enumerate(upper_tags)}
    id2label_map = dict(enumerate(upper_tags))
    return label2id_map, id2label_map
def check_install(module_name, install_name):
    """Ensure ``module_name`` is importable, installing it with pip if not.

    Args:
        module_name: Import name looked up with ``importlib.util.find_spec``.
        install_name: Package name passed to ``pip install`` when the module
            cannot be found.

    Raises:
        Exception: if the ``pip install`` subprocess exits with a non-zero
            status. The original ``CalledProcessError`` is attached as the
            cause.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # For a dotted name find_spec imports the parent package first and
        # raises when that parent is missing; treat it as "not installed".
        spec = None

    if spec is not None:
        print(f"{module_name} has been installed.")
        return

    print(f"Warnning! The {module_name} module is NOT installed")
    print(
        f"Try install {module_name} module automatically. You can also try to install manually by pip install {install_name}."
    )
    # Use the running interpreter so the package lands in its environment.
    python = sys.executable
    try:
        subprocess.check_call(
            [python, "-m", "pip", "install", install_name],
            stdout=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError as exc:
        raise Exception(
            f"Install {module_name} failed, please install manually") from exc
    print(f"The {module_name} module is now installed")
class AverageMeter:
    """Track the latest value, weighted sum, count and mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set ``val``, ``avg``, ``sum`` and ``count`` back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean.

        Args:
            val: The new value; also stored as ``self.val``.
            n: Weight (number of samples) that ``val`` represents.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total count (e.g. update(..., n=0) right after reset) would
        # otherwise divide by zero; keep the mean at 0 in that case.
        self.avg = self.sum / self.count if self.count else 0