adirathor07's picture
added doctr folder
153628e
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import random
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image, ImageDraw
from doctr.io.image import tensor_from_pil
from doctr.utils.fonts import get_font
from ..datasets import AbstractDataset
def synthesize_text_img(
    text: str,
    font_size: int = 32,
    font_family: Optional[str] = None,
    background_color: Optional[Tuple[int, int, int]] = None,
    text_color: Optional[Tuple[int, int, int]] = None,
) -> Image.Image:
    """Generate a synthetic text image

    The canvas is 1.3x the height and 1.1x the width of the text's bounding box (made square
    when `text` has at most one character), and the text is drawn centered on it.

    Args:
    ----
        text: the text to render as an image
        font_size: the size of the font
        font_family: the font family (has to be installed on your system)
        background_color: background color of the final image, black when None
        text_color: text color on the final image, white when None

    Returns:
    -------
        PIL image of the text
    """
    background_color = (0, 0, 0) if background_color is None else background_color
    text_color = (255, 255, 255) if text_color is None else text_color

    font = get_font(font_family, font_size)
    # getbbox is relative to the drawing anchor, so (left, top) is generally not (0, 0)
    left, top, right, bottom = font.getbbox(text)
    text_w, text_h = right - left, bottom - top
    h, w = round(1.3 * text_h), round(1.1 * text_w)
    # If single letter, make the image square, otherwise expand to meet the text size
    img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w))

    img = Image.new("RGB", img_size[::-1], color=background_color)
    d = ImageDraw.Draw(img)

    # Offset so that the text's bounding box (not the anchor point) is centered: the glyphs are
    # drawn at anchor + (left, top), so that offset must be subtracted or the text shifts and may clip
    text_pos = (
        round((img_size[1] - text_w) / 2 - left),
        round((img_size[0] - text_h) / 2 - top),
    )
    # Draw the text
    d.text(text_pos, text, font=font, fill=text_color)
    return img
class _CharacterGenerator(AbstractDataset):
    """Dataset of synthetic single-character images, labelled with the character's index in `vocab`

    Args:
    ----
        vocab: the characters to render; a sample's target is the index of its character in this string
        num_samples: number of samples reported by `__len__`
        cache_samples: if True, render every (character, font) pair once at construction and reuse them
        font_family: a font or a list of fonts to render with; a list is validated at construction
        img_transforms: optional image transformation, stored as `self.img_transforms`
        sample_transforms: optional (image, target) transformation, stored as `self.sample_transforms`

    Raises:
    ------
        ValueError: if `font_family` is a list containing a font that cannot be loaded
    """

    def __init__(
        self,
        vocab: str,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self._num_samples = num_samples
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts (only when an explicit list is provided)
        if isinstance(font_family, list):
            for font in self.font_family:
                try:
                    get_font(font, 10)
                except OSError as e:
                    raise ValueError(f"unable to locate font: {font}") from e
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        # Cached (image, vocab index) pairs; empty when caching is disabled
        self._data: List[Tuple[Image.Image, int]] = []
        if cache_samples:
            # One entry per (character, font) pair, ordered character-major
            self._data = [
                (synthesize_text_img(char, font_family=font), idx)
                for idx, char in enumerate(self.vocab)
                for font in self.font_family
            ]

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, int]:
        # Samples are already cached: wrap around, as num_samples may exceed the cache size
        if self._data:
            pil_img, target = self._data[index % len(self._data)]
        else:
            # Character is chosen deterministically from the index, the font randomly
            target = index % len(self.vocab)
            pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)
        return img, target
class _WordGenerator(AbstractDataset):
    """Dataset of synthetic images of random words drawn from `vocab`, labelled with the word itself

    Args:
    ----
        vocab: the characters words are sampled from (uniformly, with replacement)
        min_chars: minimum number of characters per word (inclusive)
        max_chars: maximum number of characters per word (inclusive)
        num_samples: number of samples reported by `__len__`
        cache_samples: if True, generate and render `num_samples` words once at construction
        font_family: a font or a list of fonts to render with; a list is validated at construction
        img_transforms: optional image transformation, stored as `self.img_transforms`
        sample_transforms: optional (image, target) transformation, stored as `self.sample_transforms`

    Raises:
    ------
        ValueError: if `font_family` is a list containing a font that cannot be loaded
    """

    def __init__(
        self,
        vocab: str,
        min_chars: int,
        max_chars: int,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self.wordlen_range = (min_chars, max_chars)
        self._num_samples = num_samples
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts (only when an explicit list is provided)
        if isinstance(font_family, list):
            for font in self.font_family:
                try:
                    get_font(font, 10)
                except OSError as e:
                    raise ValueError(f"unable to locate font: {font}") from e
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        # Cached (image, word) pairs; empty when caching is disabled
        self._data: List[Tuple[Image.Image, str]] = []
        if cache_samples:
            _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
            self._data = [
                (synthesize_text_img(text, font_family=random.choice(self.font_family)), text)
                for text in _words
            ]

    def _generate_string(self, min_chars: int, max_chars: int) -> str:
        """Return a random string from `vocab` whose length is drawn from [min_chars, max_chars]"""
        num_chars = random.randint(min_chars, max_chars)
        return "".join(random.choice(self.vocab) for _ in range(num_chars))

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, str]:
        # Samples are already cached (exactly num_samples of them)
        if self._data:
            pil_img, target = self._data[index]
        else:
            # Without a cache, every access draws a fresh random word: the index is not used
            target = self._generate_string(*self.wordlen_range)
            pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)
        return img, target