# Non-code residue captured together with this file (kept as a comment):
# "Spaces:" / "Sleeping" / "Sleeping"
import random
from typing import Optional, Tuple

import cv2
import numpy as np

from dataloader_iam import Batch
class Preprocessor:
    """Turn raw grayscale word/line images into network-ready inputs.

    Depending on configuration it can glue several word images into a
    synthetic text line (``line_mode``), apply random photometric and
    geometric augmentation (``data_augmentation``), and fit each image into a
    canvas of fixed size or fixed height (``dynamic_width``). Processed images
    are transposed and scaled from [0, 255] to [-0.5, 0.5].
    """

    def __init__(self,
                 img_size: Tuple[int, int],
                 padding: int = 0,
                 dynamic_width: bool = False,
                 data_augmentation: bool = False,
                 line_mode: bool = False) -> None:
        """
        Args:
            img_size: target size as (width, height); ``img_size[::-1]`` is
                used as the numpy canvas shape and ``img_size`` as cv2 dsize.
            padding: extra pixels added to the computed width; only allowed
                together with ``dynamic_width``.
            dynamic_width: keep the target height fixed and derive the width
                from the aspect ratio (not supported with augmentation).
            data_augmentation: apply random augmentation in ``process_img``
                and random word counts/gaps in line mode.
            line_mode: build synthetic text lines from several word images.

        Raises:
            AssertionError: on an unsupported option combination.
        """
        # dynamic width only supported when no data augmentation happens
        assert not (dynamic_width and data_augmentation), \
            'dynamic_width is not supported together with data_augmentation'
        # when padding is on, we need dynamic width enabled
        assert not (padding > 0 and not dynamic_width), \
            'padding > 0 requires dynamic_width=True'

        self.img_size = img_size
        self.padding = padding
        self.dynamic_width = dynamic_width
        self.data_augmentation = data_augmentation
        self.line_mode = line_mode

    @staticmethod
    def _truncate_label(text: str, max_text_len: int) -> str:
        """Cut ``text`` so its CTC alignment cost fits into ``max_text_len``.

        ctc_loss cannot compute a loss if there is no mapping between the text
        label and the input sequence. A repeated letter costs double because a
        blank symbol must be inserted between the two occurrences. If a
        too-long label is provided, ctc_loss returns an infinite gradient.

        Returns:
            The longest prefix of ``text`` whose cost does not exceed
            ``max_text_len`` (``text`` itself if it already fits).
        """
        cost = 0
        for i, char in enumerate(text):
            # A repeat of the previous character needs a blank in between.
            cost += 2 if i != 0 and char == text[i - 1] else 1
            if cost > max_text_len:
                return text[:i]
        return text

    def _simulate_text_line(self, batch: 'Batch') -> 'Batch':
        """Create images of text lines by pasting several word images together.

        For batch element ``i`` the words ``i, i+1, ...`` (wrapping around the
        batch) are placed left to right on a white (255) uint8 canvas, each one
        vertically centered. Their ground-truth texts are joined with spaces.
        Without augmentation 5 words with 30 px gaps are used. With
        augmentation, 1-8 words with random 20-50 px gaps are used.
        Word images are assumed to be 2-D grayscale arrays -- TODO confirm
        with the loader.
        """
        default_word_sep = 30
        default_num_words = 5

        res_imgs = []
        res_gt_texts = []
        for i in range(batch.batch_size):
            # number of words to put into current line
            num_words = random.randint(1, 8) if self.data_augmentation else default_num_words

            # concat ground truth texts
            curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
            res_gt_texts.append(curr_gt)

            # put selected word images into list, compute target image size
            sel_imgs = []
            word_seps = [0]  # no gap before the first word
            h = 0
            w = 0
            for j in range(num_words):
                curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
                # A gap is drawn for every word (the last included), but only
                # the gaps between consecutive words are used.
                curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
                h = max(h, curr_sel_img.shape[0])
                w += curr_sel_img.shape[1]
                sel_imgs.append(curr_sel_img)
                if j + 1 < num_words:
                    w += curr_word_sep
                    word_seps.append(curr_word_sep)

            # put all selected word images into a white target image
            target = np.ones([h, w], np.uint8) * 255
            x = 0
            for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
                x += curr_word_sep
                y = (h - curr_sel_img.shape[0]) // 2
                target[y:y + curr_sel_img.shape[0], x:x + curr_sel_img.shape[1]] = curr_sel_img
                x += curr_sel_img.shape[1]

            res_imgs.append(target)

        return Batch(res_imgs, res_gt_texts, batch.batch_size)

    def process_img(self, img: Optional[np.ndarray]) -> np.ndarray:
        """Resize to target size, optionally augment, transpose and normalize.

        Args:
            img: 2-D grayscale image with values in [0, 255], or None for an
                unreadable file (replaced by a black image of ``img_size``).

        Returns:
            Float image, transposed (width along axis 0), with values
            nominally in [-0.5, 0.5].
        """
        # there are damaged files in IAM dataset - just use black image instead
        if img is None:
            img = np.zeros(self.img_size[::-1])

        img = img.astype(float)
        if self.data_augmentation:
            img = self._augment_img(img)
        else:
            img = self._fit_img(img)

        # transpose for TF
        img = cv2.transpose(img)

        # convert from [0, 255] to [-0.5, 0.5]
        return img / 255 - 0.5

    def _augment_img(self, img: np.ndarray) -> np.ndarray:
        """Randomly blur/dilate/erode, scale and shift into the canvas, then
        randomly darken, add noise and invert ``img`` (float, [0, 255])."""
        # photometric data augmentation
        if random.random() < 0.25:
            def rand_odd():
                # odd kernel size in {3, 5, 7}
                return random.randint(1, 3) * 2 + 1
            img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
        if random.random() < 0.25:
            img = cv2.dilate(img, np.ones((3, 3)))
        if random.random() < 0.25:
            img = cv2.erode(img, np.ones((3, 3)))

        # geometric data augmentation: shrink by 0.75-1.05 of the fitting
        # factor, independently per axis
        wt, ht = self.img_size
        h, w = img.shape
        f = min(wt / w, ht / h)
        fx = f * np.random.uniform(0.75, 1.05)
        fy = f * np.random.uniform(0.75, 1.05)

        # random position around center
        txc = (wt - w * fx) / 2
        tyc = (ht - h * fy) / 2
        freedom_x = max((wt - fx * w) / 2, 0)
        freedom_y = max((ht - fy * h) / 2, 0)
        tx = txc + np.random.uniform(-freedom_x, freedom_x)
        ty = tyc + np.random.uniform(-freedom_y, freedom_y)

        # map image into a white target image; BORDER_TRANSPARENT keeps the
        # white background wherever the source does not land
        M = np.float32([[fx, 0, tx], [0, fy, ty]])
        target = np.ones(self.img_size[::-1]) * 255
        img = cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT)

        # photometric data augmentation
        if random.random() < 0.5:
            # darken by a factor in [0.25, 1.0)
            img = img * (0.25 + random.random() * 0.75)
        if random.random() < 0.25:
            img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
        if random.random() < 0.1:
            img = 255 - img
        return img

    def _fit_img(self, img: np.ndarray) -> np.ndarray:
        """Deterministically scale ``img`` into a white canvas (no augmentation)."""
        if self.dynamic_width:
            # fixed height; width follows the aspect ratio plus padding
            ht = self.img_size[1]
            h, w = img.shape
            f = ht / h
            wt = int(f * w + self.padding)
            # round the width up to the next multiple of 4
            wt = wt + (4 - wt) % 4
            tx = (wt - w * f) / 2
            ty = 0
        else:
            # fit into the fixed canvas, keeping aspect ratio, centered
            wt, ht = self.img_size
            h, w = img.shape
            f = min(wt / w, ht / h)
            tx = (wt - w * f) / 2
            ty = (ht - h * f) / 2

        # map image into target image
        M = np.float32([[f, 0, tx], [0, f, ty]])
        target = np.ones([ht, wt]) * 255
        return cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)

    def process_batch(self, batch: 'Batch') -> 'Batch':
        """Preprocess every image of ``batch`` and truncate its labels to fit.

        In line mode the word images are first combined into text lines.
        Labels are truncated with ``_truncate_label`` to ``width // 4`` of the
        first processed image. This presumably matches the model's time-step
        downsampling factor -- TODO confirm against the model.
        """
        if self.line_mode:
            batch = self._simulate_text_line(batch)

        res_imgs = [self.process_img(img) for img in batch.imgs]
        if not res_imgs:
            # empty batch: no image to size the labels against
            return Batch(res_imgs, list(batch.gt_texts), batch.batch_size)

        # after the transpose, shape[0] is the image width
        max_text_len = res_imgs[0].shape[0] // 4
        res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
        return Batch(res_imgs, res_gt_texts, batch.batch_size)
def main(img_path: str = '../data/test.png') -> None:
    """Show a grayscale image next to one randomly augmented version of it.

    Args:
        img_path: image to load; the default is resolved relative to the
            current working directory.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    import matplotlib.pyplot as plt

    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising
        raise FileNotFoundError(f'could not read image: {img_path}')

    img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img)

    plt.subplot(121)
    plt.imshow(img, cmap='gray')
    plt.subplot(122)
    # process_img returns a transposed image in [-0.5, 0.5]; undo both for display
    plt.imshow(cv2.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
    plt.show()


if __name__ == '__main__':
    main()