Spaces:
Runtime error
Runtime error
File size: 5,769 Bytes
153628e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import random
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image, ImageDraw
from doctr.io.image import tensor_from_pil
from doctr.utils.fonts import get_font
from ..datasets import AbstractDataset
def synthesize_text_img(
    text: str,
    font_size: int = 32,
    font_family: Optional[str] = None,
    background_color: Optional[Tuple[int, int, int]] = None,
    text_color: Optional[Tuple[int, int, int]] = None,
) -> Image.Image:
    """Generate a synthetic text image

    The canvas is sized from the text bounding box (1.3x its height, 1.1x its width) and is made
    square when `text` is a single character. The text is drawn centered on that canvas.

    Args:
    ----
        text: the text to render as an image
        font_size: the size of the font
        font_family: the font family (has to be installed on your system)
        background_color: background color of the final image (black when None)
        text_color: text color on the final image (white when None)

    Returns:
    -------
        PIL image of the text
    """
    background_color = (0, 0, 0) if background_color is None else background_color
    text_color = (255, 255, 255) if text_color is None else text_color

    font = get_font(font_family, font_size)
    # The bounding box is expressed relative to the drawing anchor, so `left` / `top` are
    # generally non-zero (e.g. the gap between the ascender line and the top of the glyphs).
    left, top, right, bottom = font.getbbox(text)
    text_w, text_h = right - left, bottom - top
    h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w))
    # If single letter, make the image square, otherwise expand to meet the text size
    img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w))

    img = Image.new("RGB", img_size[::-1], color=background_color)
    d = ImageDraw.Draw(img)

    # Offset so that the text is centered. The bbox origin has to be subtracted: drawing at (x, y)
    # puts the ink at (x + left, y + top), so ignoring it shifts the text off-center and can clip
    # it against the canvas edge.
    text_pos = (
        int(round((img_size[1] - text_w) / 2 - left)),
        int(round((img_size[0] - text_h) / 2 - top)),
    )
    # Draw the text
    d.text(text_pos, text, font=font, fill=text_color)
    return img
class _CharacterGenerator(AbstractDataset):
    """Dataset of synthetic single-character images, labelled with the character's index in `vocab`

    Args:
    ----
        vocab: characters to render; the target of a sample is the index of its character in this string
        num_samples: number of samples reported by `__len__`
        cache_samples: whether to render every (character, font) pair once at construction time
        font_family: font, or list of fonts, used for rendering (passed to `get_font`)
        img_transforms: stored on the instance; not applied by this class itself
        sample_transforms: stored on the instance; not applied by this class itself

    Raises:
    ------
        ValueError: if `font_family` is a list and one of its fonts cannot be loaded
    """

    # NOTE(review): the parent initializer is not called here - presumably on purpose since there
    # is no on-disk data; confirm against AbstractDataset.
    def __init__(
        self,
        vocab: str,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self._num_samples = num_samples
        # Always store the fonts as a list so that sampling code can treat both cases alike
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts (only done when an explicit list was provided)
        if isinstance(font_family, list):
            for font in self.font_family:
                try:
                    _ = get_font(font, 10)
                except OSError as err:
                    # Chain the original error so the underlying loading failure stays visible
                    raise ValueError(f"unable to locate font: {font}") from err
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        # Cached (image, vocab index) pairs; left empty when caching is disabled
        self._data: List[Tuple[Image.Image, int]] = []
        if cache_samples:
            self._data = [
                (synthesize_text_img(char, font_family=font), idx)
                for idx, char in enumerate(self.vocab)
                for font in self.font_family
            ]

    def __len__(self) -> int:
        """Return the number of samples requested at construction"""
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, int]:
        """Return the image (converted with `tensor_from_pil`) and the vocab index for a sample

        Args:
        ----
            index: index of the sample; wrapped around the cache size or the vocab size

        Returns:
        -------
            the converted image and the index of the rendered character in `vocab`
        """
        if self._data:
            # Samples are already cached
            pil_img, target = self._data[index % len(self._data)]
        else:
            # Render on the fly, with a font picked at random
            target = index % len(self.vocab)
            pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)

        return img, target
class _WordGenerator(AbstractDataset):
    """Dataset of synthetic word images, labelled with the rendered string

    Args:
    ----
        vocab: characters from which the random words are drawn
        min_chars: minimum number of characters in a word (inclusive)
        max_chars: maximum number of characters in a word (inclusive)
        num_samples: number of samples reported by `__len__`
        cache_samples: whether to generate and render `num_samples` words once at construction time
        font_family: font, or list of fonts, used for rendering (passed to `get_font`)
        img_transforms: stored on the instance; not applied by this class itself
        sample_transforms: stored on the instance; not applied by this class itself

    Raises:
    ------
        ValueError: if `font_family` is a list and one of its fonts cannot be loaded
    """

    # NOTE(review): the parent initializer is not called here - presumably on purpose since there
    # is no on-disk data; confirm against AbstractDataset.
    def __init__(
        self,
        vocab: str,
        min_chars: int,
        max_chars: int,
        num_samples: int,
        cache_samples: bool = False,
        font_family: Optional[Union[str, List[str]]] = None,
        img_transforms: Optional[Callable[[Any], Any]] = None,
        sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None,
    ) -> None:
        self.vocab = vocab
        self.wordlen_range = (min_chars, max_chars)
        self._num_samples = num_samples
        # Always store the fonts as a list so that sampling code can treat both cases alike
        self.font_family = font_family if isinstance(font_family, list) else [font_family]  # type: ignore[list-item]
        # Validate fonts (only done when an explicit list was provided)
        if isinstance(font_family, list):
            for font in self.font_family:
                try:
                    _ = get_font(font, 10)
                except OSError as err:
                    # Chain the original error so the underlying loading failure stays visible
                    raise ValueError(f"unable to locate font: {font}") from err
        self.img_transforms = img_transforms
        self.sample_transforms = sample_transforms

        # Cached (image, text) pairs; left empty when caching is disabled
        self._data: List[Tuple[Image.Image, str]] = []
        if cache_samples:
            _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)]
            self._data = [
                (synthesize_text_img(text, font_family=random.choice(self.font_family)), text)
                for text in _words
            ]

    def _generate_string(self, min_chars: int, max_chars: int) -> str:
        """Build a random string of `min_chars` to `max_chars` characters (inclusive) drawn from `vocab`"""
        num_chars = random.randint(min_chars, max_chars)
        return "".join(random.choice(self.vocab) for _ in range(num_chars))

    def __len__(self) -> int:
        """Return the number of samples requested at construction"""
        return self._num_samples

    def _read_sample(self, index: int) -> Tuple[Any, str]:
        """Return the image (converted with `tensor_from_pil`) and the text for a sample

        Args:
        ----
            index: index of the sample; only used to look up the cache, ignored otherwise

        Returns:
        -------
            the converted image and the string it renders
        """
        if self._data:
            # Samples are already cached
            pil_img, target = self._data[index]
        else:
            # Generate a fresh random word and render it with a font picked at random
            target = self._generate_string(*self.wordlen_range)
            pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family))
        img = tensor_from_pil(pil_img)

        return img, target
|