topdu's picture
openocr demo
29f689c
raw
history blame
5.5 kB
import io
import cv2
import numpy as np
from PIL import Image
from .abinet_label_encode import ABINetLabelEncode
from .ar_label_encode import ARLabelEncode
from .ce_label_encode import CELabelEncode
from .char_label_encode import CharLabelEncode
from .cppd_label_encode import CPPDLabelEncode
from .ctc_label_encode import CTCLabelEncode
from .ep_label_encode import EPLabelEncode
from .igtr_label_encode import IGTRLabelEncode
from .mgp_label_encode import MGPLabelEncode
from .rec_aug import ABINetAug
from .rec_aug import BaseDataAugmentation as BDA
from .rec_aug import PARSeqAug, PARSeqAugPIL, SVTRAug
from .resize import (ABINetResize, CDistNetResize, LongResize, RecTVResize,
RobustScannerRecResizeImg, SliceResize, SliceTVResize,
SRNRecResizeImg, SVTRResize, VisionLANResize,
RecDynamicResize)
from .smtr_label_encode import SMTRLabelEncode
from .srn_label_encode import SRNLabelEncode
from .visionlan_label_encode import VisionLANLabelEncode
from .cam_label_encode import CAMLabelEncode
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
def __call__(self, data):
data_list = []
for key in self.keep_keys:
data_list.append(data[key])
return data_list
def transform(data, ops=None):
"""transform."""
if ops is None:
ops = []
for op in ops:
data = op(data)
if data is None:
return None
return data
class Fasttext(object):
def __init__(self, path='None', **kwargs):
# pip install fasttext==0.9.1
import fasttext
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
label = data['label']
fast_label = self.fast_model[label]
data['fast_label'] = fast_label
return data
class DecodeImage(object):
"""decode image."""
def __init__(self,
img_mode='RGB',
channel_first=False,
ignore_orientation=False,
**kwargs):
self.img_mode = img_mode
self.channel_first = channel_first
self.ignore_orientation = ignore_orientation
def __call__(self, data):
img = data['image']
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8')
if self.ignore_orientation:
img = cv2.imdecode(
img, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
else:
img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB':
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
data['image'] = img
return data
class DecodeImagePIL(object):
"""decode image."""
def __init__(self, img_mode='RGB', **kwargs):
self.img_mode = img_mode
def __call__(self, data):
img = data['image']
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
img = data['image']
buf = io.BytesIO(img)
img = Image.open(buf).convert('RGB')
if self.img_mode == 'Gray':
img = img.convert('L')
elif self.img_mode == 'BGR':
img = np.array(img)[:, :, ::-1] # 将图片转为numpy格式,并将最后一维通道倒序
img = Image.fromarray(np.uint8(img))
data['image'] = img
return data
def create_operators(op_param_list, global_config=None):
"""create operators based on the config.
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(op_param_list, list), 'operator config should be a list'
ops = []
for operator in op_param_list:
assert isinstance(operator,
dict) and len(operator) == 1, 'yaml format error'
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
if global_config is not None:
param.update(global_config)
op = eval(op_name)(**param)
ops.append(op)
return ops
class GTCLabelEncode():
"""Convert between text-label and text-index."""
def __init__(self,
gtc_label_encode,
max_text_length,
character_dict_path=None,
use_space_char=False,
**kwargs):
self.gtc_label_encode = eval(gtc_label_encode['name'])(
max_text_length=max_text_length,
character_dict_path=character_dict_path,
use_space_char=use_space_char,
**gtc_label_encode)
self.ctc_label_encode = CTCLabelEncode(max_text_length,
character_dict_path,
use_space_char)
def __call__(self, data):
data_ctc = self.ctc_label_encode({'label': data['label']})
data = self.gtc_label_encode(data)
if data_ctc is None or data is None:
return None
data['ctc_label'] = data_ctc['label']
data['ctc_length'] = data_ctc['length']
return data