# Copyright (C) 2021-2024, Mindee. # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. import random from typing import Any, Callable, List, Optional, Tuple, Union from PIL import Image, ImageDraw from doctr.io.image import tensor_from_pil from doctr.utils.fonts import get_font from ..datasets import AbstractDataset def synthesize_text_img( text: str, font_size: int = 32, font_family: Optional[str] = None, background_color: Optional[Tuple[int, int, int]] = None, text_color: Optional[Tuple[int, int, int]] = None, ) -> Image.Image: """Generate a synthetic text image Args: ---- text: the text to render as an image font_size: the size of the font font_family: the font family (has to be installed on your system) background_color: background color of the final image text_color: text color on the final image Returns: ------- PIL image of the text """ background_color = (0, 0, 0) if background_color is None else background_color text_color = (255, 255, 255) if text_color is None else text_color font = get_font(font_family, font_size) left, top, right, bottom = font.getbbox(text) text_w, text_h = right - left, bottom - top h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w)) # If single letter, make the image square, otherwise expand to meet the text size img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w)) img = Image.new("RGB", img_size[::-1], color=background_color) d = ImageDraw.Draw(img) # Offset so that the text is centered text_pos = (int(round((img_size[1] - text_w) / 2)), int(round((img_size[0] - text_h) / 2))) # Draw the text d.text(text_pos, text, font=font, fill=text_color) return img class _CharacterGenerator(AbstractDataset): def __init__( self, vocab: str, num_samples: int, cache_samples: bool = False, font_family: Optional[Union[str, List[str]]] = None, img_transforms: Optional[Callable[[Any], Any]] = None, sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, ) -> None: self.vocab = vocab self._num_samples = num_samples self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: ignore[list-item] # Validate fonts if isinstance(font_family, list): for font in self.font_family: try: _ = get_font(font, 10) except OSError: raise ValueError(f"unable to locate font: {font}") self.img_transforms = img_transforms self.sample_transforms = sample_transforms self._data: List[Image.Image] = [] if cache_samples: self._data = [ (synthesize_text_img(char, font_family=font), idx) # type: ignore[misc] for idx, char in enumerate(self.vocab) for font in self.font_family ] def __len__(self) -> int: return self._num_samples def _read_sample(self, index: int) -> Tuple[Any, int]: # Samples are already cached if len(self._data) > 0: idx = index % len(self._data) pil_img, target = self._data[idx] # type: ignore[misc] else: target = index % len(self.vocab) pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family)) img = tensor_from_pil(pil_img) return img, target class _WordGenerator(AbstractDataset): def __init__( self, vocab: str, min_chars: int, max_chars: int, num_samples: int, cache_samples: bool = False, font_family: Optional[Union[str, List[str]]] = None, img_transforms: Optional[Callable[[Any], Any]] = None, sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, ) -> None: self.vocab = vocab self.wordlen_range = (min_chars, max_chars) self._num_samples = num_samples self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: ignore[list-item] # Validate fonts if isinstance(font_family, list): for font in self.font_family: try: _ = get_font(font, 10) except OSError: raise ValueError(f"unable to locate font: {font}") self.img_transforms = img_transforms self.sample_transforms = sample_transforms self._data: List[Image.Image] = [] if cache_samples: _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)] self._data = [ (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) # type: ignore[misc] for text in _words ] def _generate_string(self, min_chars: int, max_chars: int) -> str: num_chars = random.randint(min_chars, max_chars) return "".join(random.choice(self.vocab) for _ in range(num_chars)) def __len__(self) -> int: return self._num_samples def _read_sample(self, index: int) -> Tuple[Any, str]: # Samples are already cached if len(self._data) > 0: pil_img, target = self._data[index] # type: ignore[misc] else: target = self._generate_string(*self.wordlen_range) pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family)) img = tensor_from_pil(pil_img) return img, target