# strhub/data/module.py
# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import PurePath
from typing import Optional, Callable, Sequence, Tuple

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms as T

from .dataset import build_tree_dataset, LmdbDataset


class SceneTextDataModule(pl.LightningDataModule):
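    """Data module for scene text recognition.

    Lazily builds the train/val splits from a tree of LMDB datasets under
    ``root_dir`` and exposes one test dataloader per benchmark.
    """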
    TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
    TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
    TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
    TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))

    def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int,
                 charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool,
                 remove_whitespace: bool = True, normalize_unicode: bool = True,
                 min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None):
        super().__init__()
        self.root_dir = root_dir
        self.train_dir = train_dir
        self.img_size = tuple(img_size)
        self.max_label_length = max_label_length
        self.charset_train = charset_train
        self.charset_test = charset_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.augment = augment
        self.remove_whitespace = remove_whitespace
        self.normalize_unicode = normalize_unicode
        self.min_image_dim = min_image_dim
        self.rotation = rotation
        self.collate_fn = collate_fn
        # Built on first access (see the train_dataset/val_dataset properties).
        self._train_dataset = None
        self._val_dataset = None

    @staticmethod
    def get_transform(img_size: Tuple[int, int], augment: bool = False, rotation: int = 0):
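        """Build the image preprocessing pipeline: optional RandAugment and
        rotation, then resize to ``img_size``, conversion to a tensor, and
        normalization to [-1, 1].
        """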
        transforms = []
        if augment:
            from .augment import rand_augment_transform
            transforms.append(rand_augment_transform())
        if rotation:
            transforms.append(lambda img: img.rotate(rotation, expand=True))
        transforms.extend([
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])
        return T.Compose(transforms)

    @property
    def train_dataset(self):
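        """Training split, built lazily on first access from ``root_dir/train/train_dir``."""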
        if self._train_dataset is None:
            transform = self.get_transform(self.img_size, self.augment)
            root = PurePath(self.root_dir, 'train', self.train_dir)
            self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length,
                                                     self.min_image_dim, self.remove_whitespace,
                                                     self.normalize_unicode, transform=transform)
        return self._train_dataset

    @property
    def val_dataset(self):
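        """Validation split, built lazily on first access from ``root_dir/val``."""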
        if self._val_dataset is None:
            transform = self.get_transform(self.img_size)
            root = PurePath(self.root_dir, 'val')
            self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length,
                                                   self.min_image_dim, self.remove_whitespace,
                                                   self.normalize_unicode, transform=transform)
        return self._val_dataset

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
                          pin_memory=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
                          pin_memory=True, collate_fn=self.collate_fn)

    def test_dataloaders(self, subset: Sequence[str]):
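        """Return a dict mapping each benchmark name in ``subset`` to its test DataLoader."""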
        transform = self.get_transform(self.img_size, rotation=self.rotation)
        root = PurePath(self.root_dir, 'test')
        datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length,
                                   self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
                                   transform=transform) for s in subset}
        return {k: DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers,
                              pin_memory=True, collate_fn=self.collate_fn)
                for k, v in datasets.items()}
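

# Minimal usage sketch (illustrative only: the data root layout, charset, and
# hyperparameters below are assumptions, not defined by this module; the code
# itself only requires 'train/<train_dir>', 'val', and 'test' under root_dir):
#
#     dm = SceneTextDataModule(
#         root_dir='data', train_dir='real', img_size=(32, 128), max_label_length=25,
#         charset_train='0123456789abcdefghijklmnopqrstuvwxyz',
#         charset_test='0123456789abcdefghijklmnopqrstuvwxyz',
#         batch_size=384, num_workers=4, augment=True)
#     train_loader = dm.train_dataloader()
#     test_loaders = dm.test_dataloaders(SceneTextDataModule.TEST_BENCHMARK)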