# Scene Text Recognition Model Hub
# Copyright 2022 Darwin Bautista
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import PurePath
from typing import Callable, Optional, Sequence

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms as T

from IndicPhotoOCR.utils.strhub.data.dataset import LmdbDataset, build_tree_dataset


class SceneTextDataModule(pl.LightningDataModule):
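    """Lightning data module for scene text recognition datasets.

    Expects the following layout under ``root_dir``: ``train/<train_dir>``,
    ``val``, and ``test/<name>``, each holding LMDB-formatted data.
    """

    # Standard STR benchmarks: the 857/1811 counts are the commonly used
    # filtered subsets of IC13/IC15; 1015/2077 are the fuller versions.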
TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))

    def __init__(
self,
root_dir: str,
train_dir: str,
img_size: Sequence[int],
max_label_length: int,
charset_train: str,
charset_test: str,
batch_size: int,
num_workers: int,
augment: bool,
remove_whitespace: bool = True,
normalize_unicode: bool = True,
min_image_dim: int = 0,
rotation: int = 0,
collate_fn: Optional[Callable] = None,
):
super().__init__()
self.root_dir = root_dir
self.train_dir = train_dir
self.img_size = tuple(img_size)
self.max_label_length = max_label_length
self.charset_train = charset_train
self.charset_test = charset_test
self.batch_size = batch_size
self.num_workers = num_workers
self.augment = augment
self.remove_whitespace = remove_whitespace
self.normalize_unicode = normalize_unicode
self.min_image_dim = min_image_dim
self.rotation = rotation
self.collate_fn = collate_fn
self._train_dataset = None
self._val_dataset = None

    @staticmethod
    def get_transform(img_size: tuple[int, int], augment: bool = False, rotation: int = 0):
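        """Build the image transform: optional RandAugment and rotation,
        then bicubic resize, tensor conversion, and normalization to [-1, 1].
        """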
transforms = []
        if augment:
            # Imported lazily so augmentation dependencies are only loaded
            # when augment=True.
            from .augment import rand_augment_transform

            transforms.append(rand_augment_transform())
if rotation:
transforms.append(lambda img: img.rotate(rotation, expand=True))
transforms.extend([
T.Resize(img_size, T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(0.5, 0.5),
])
return T.Compose(transforms)

    @property
def train_dataset(self):
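        """Training dataset, built lazily on first access and then cached."""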
if self._train_dataset is None:
transform = self.get_transform(self.img_size, self.augment)
root = PurePath(self.root_dir, 'train', self.train_dir)
self._train_dataset = build_tree_dataset(
root,
self.charset_train,
self.max_label_length,
self.min_image_dim,
self.remove_whitespace,
self.normalize_unicode,
transform=transform,
)
return self._train_dataset

    @property
def val_dataset(self):
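        """Validation dataset, built lazily on first access and then cached."""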
if self._val_dataset is None:
transform = self.get_transform(self.img_size)
root = PurePath(self.root_dir, 'val')
self._val_dataset = build_tree_dataset(
root,
self.charset_test,
self.max_label_length,
self.min_image_dim,
self.remove_whitespace,
self.normalize_unicode,
transform=transform,
)
return self._val_dataset

    def train_dataloader(self):
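        """Shuffled loader over the training set; workers persist across
        epochs whenever ``num_workers > 0`` to avoid respawn overhead.
        """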
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
persistent_workers=self.num_workers > 0,
pin_memory=True,
collate_fn=self.collate_fn,
)

    def val_dataloader(self):
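        """Sequential (unshuffled) loader over the validation set."""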
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=self.num_workers > 0,
pin_memory=True,
collate_fn=self.collate_fn,
)

    def test_dataloaders(self, subset):
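        """Build one LMDB dataset and loader per benchmark name in ``subset``
        (e.g. ``TEST_BENCHMARK`` or ``TEST_ALL``) and return them as a dict
        keyed by benchmark name.
        """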
transform = self.get_transform(self.img_size, rotation=self.rotation)
root = PurePath(self.root_dir, 'test')
datasets = {
s: LmdbDataset(
str(root / s),
self.charset_test,
self.max_label_length,
self.min_image_dim,
self.remove_whitespace,
self.normalize_unicode,
transform=transform,
)
for s in subset
}
return {
k: DataLoader(
v, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=self.collate_fn
)
for k, v in datasets.items()
}
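

# A minimal usage sketch. The paths and charsets below are illustrative
# placeholders (not shipped with this module); point root_dir at a directory
# containing train/<train_dir>, val/ and test/ LMDB folders.
if __name__ == '__main__':
    dm = SceneTextDataModule(
        root_dir='data',
        train_dir='real',  # hypothetical training subfolder under data/train/
        img_size=(32, 128),
        max_label_length=25,
        charset_train='0123456789abcdefghijklmnopqrstuvwxyz',
        charset_test='0123456789abcdefghijklmnopqrstuvwxyz',
        batch_size=64,
        num_workers=4,
        augment=True,
    )
    images, labels = next(iter(dm.train_dataloader()))
    print(images.shape, labels[:4])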