Spaces:
Runtime error
Runtime error
# Copyright (C) 2021-2024, Mindee. | |
# This program is licensed under the Apache License 2.0. | |
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details. | |
import random | |
from typing import Any, Callable, List, Optional, Tuple, Union | |
from PIL import Image, ImageDraw | |
from doctr.io.image import tensor_from_pil | |
from doctr.utils.fonts import get_font | |
from ..datasets import AbstractDataset | |
def synthesize_text_img(
    text: str,
    font_size: int = 32,
    font_family: Optional[str] = None,
    background_color: Optional[Tuple[int, int, int]] = None,
    text_color: Optional[Tuple[int, int, int]] = None,
) -> Image.Image:
    """Generate a synthetic text image

    Args:
    ----
        text: the text to render as an image
        font_size: the size of the font
        font_family: the font family (has to be installed on your system)
        background_color: background color of the final image, black if None
        text_color: text color on the final image, white if None

    Returns:
    -------
        PIL image of the text
    """
    background_color = (0, 0, 0) if background_color is None else background_color
    text_color = (255, 255, 255) if text_color is None else text_color

    font = get_font(font_family, font_size)
    # The bbox is expressed relative to the drawing anchor: text drawn at (x, y) actually
    # covers (x + left, y + top) -> (x + right, y + bottom).
    left, top, right, bottom = font.getbbox(text)
    text_w, text_h = right - left, bottom - top
    # Pad the text box by 30% in height and 10% in width
    h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w))
    # If single letter, make the image square, otherwise expand to meet the text size
    img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w))

    # img_size is (height, width) while PIL expects (width, height)
    img = Image.new("RGB", img_size[::-1], color=background_color)
    d = ImageDraw.Draw(img)

    # Offset so that the text is centered. The bbox origin (left, top) must be removed,
    # otherwise the glyphs are shifted right/down and may be clipped at the bottom edge.
    text_pos = (
        int(round((img_size[1] - text_w) / 2 - left)),
        int(round((img_size[0] - text_h) / 2 - top)),
    )
    # Draw the text
    d.text(text_pos, text, font=font, fill=text_color)
    return img
class _CharacterGenerator(AbstractDataset):
    """Dataset of synthetic images, each showing a single character of a vocab.

    The target of a sample is the index, in ``vocab``, of the rendered character.

    Args:
    ----
        vocab: characters to render
        num_samples: value reported by ``len(dataset)``
        cache_samples: if True, pre-render every (character, font) pair once at construction
        font_family: a font name, a list of font names, or None for the default font
        img_transforms: stored as ``self.img_transforms`` (presumably applied by the base class)
        sample_transforms: stored as ``self.sample_transforms`` (presumably applied by the base class)

    Raises:
    ------
        ValueError: if one of the requested fonts cannot be located
    """

    def __init__(
        self,
        vocab: str,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self._num_samples = num_samples
        # Always keep a list so that sampling can use `random.choice`
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts: every explicitly named font is checked (a single name as well as a list),
        # so a missing font is reported here rather than when a sample is first rendered.
        # `None` selects the default font and is left to `get_font`.
        for font in self.font_family:
            if font is None:
                continue
            try:
                _ = get_font(font, 10)
            except OSError as err:
                raise ValueError(f"unable to locate font: {font}") from err

        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        # Pre-rendered (image, character index) pairs, ordered by character then by font
        self._data: List[Tuple[Image.Image, int]] = []
        if cache_samples:
            self._data = [
                (synthesize_text_img(char, font_family=font), idx)
                for idx, char in enumerate(self.vocab)
                for font in self.font_family
            ]

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, int]:
        """Return the image tensor and character index for sample ``index``."""
        # Samples are already cached: wrap around the pre-rendered list
        if len(self._data) > 0:
            idx = index % len(self._data)
            pil_img, target = self._data[idx]
        else:
            # Render on the fly with a randomly picked font
            target = index % len(self.vocab)
            pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)

        return img, target
class _WordGenerator(AbstractDataset):
    """Dataset of synthetic images of random words built from a vocab.

    The target of a sample is the rendered string itself.

    Args:
    ----
        vocab: characters the words are drawn from
        min_chars: minimum word length (inclusive)
        max_chars: maximum word length (inclusive)
        num_samples: value reported by ``len(dataset)``
        cache_samples: if True, pre-render ``num_samples`` random words once at construction
        font_family: a font name, a list of font names, or None for the default font
        img_transforms: stored as ``self.img_transforms`` (presumably applied by the base class)
        sample_transforms: stored as ``self.sample_transforms`` (presumably applied by the base class)

    Raises:
    ------
        ValueError: if one of the requested fonts cannot be located
    """

    def __init__(
        self,
        vocab: str,
        min_chars: int,
        max_chars: int,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self.wordlen_range = (min_chars, max_chars)
        self._num_samples = num_samples
        # Always keep a list so that sampling can use `random.choice`
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts: every explicitly named font is checked (a single name as well as a list),
        # so a missing font is reported here rather than when a sample is first rendered.
        # `None` selects the default font and is left to `get_font`.
        for font in self.font_family:
            if font is None:
                continue
            try:
                _ = get_font(font, 10)
            except OSError as err:
                raise ValueError(f"unable to locate font: {font}") from err

        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        # Pre-rendered (image, word) pairs, one per sample
        self._data: List[Tuple[Image.Image, str]] = []
        if cache_samples:
            _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
            self._data = [
                (synthesize_text_img(text, font_family=random.choice(self.font_family)), text)
                for text in _words
            ]

    def _generate_string(self, min_chars: int, max_chars: int) -> str:
        """Return a random string of ``min_chars`` to ``max_chars`` (inclusive) vocab characters."""
        num_chars = random.randint(min_chars, max_chars)
        return "".join(random.choice(self.vocab) for _ in range(num_chars))

    def __len__(self) -> int:
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, str]:
        """Return the image tensor and its text for sample ``index``."""
        # Samples are already cached
        if len(self._data) > 0:
            pil_img, target = self._data[index]
        else:
            # Not cached: a fresh random word is generated and `index` is not used
            target = self._generate_string(*self.wordlen_range)
            pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)

        return img, target