# OpenOCR-Demo/tools/data/__init__.py
import os
import sys
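# Add the repository root to sys.path so that the `tools.data.*` imports below resolve.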
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import copy
from torch.utils.data import DataLoader, DistributedSampler
from tools.data.lmdb_dataset import LMDBDataSet
from tools.data.lmdb_dataset_test import LMDBDataSetTest
from tools.data.multi_scale_sampler import MultiScaleSampler
from tools.data.ratio_dataset import RatioDataSet
from tools.data.ratio_dataset_test import RatioDataSetTest
from tools.data.ratio_dataset_tvresize_test import RatioDataSetTVResizeTest
from tools.data.ratio_dataset_tvresize import RatioDataSetTVResize
from tools.data.ratio_sampler import RatioSampler
from tools.data.simple_dataset import MultiScaleDataSet, SimpleDataSet
from tools.data.strlmdb_dataset import STRLMDBDataSet
__all__ = [
    'build_dataloader',
    'transform',
    'create_operators',
]
def build_dataloader(config, mode, logger, seed=None, epoch=3):
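    """Build a torch DataLoader for the given mode ('Train', 'Eval' or 'Test').

    The dataset class, loader options and optional sampler are read from
    config[mode]; config['Global']['distributed'] enables a DistributedSampler
    for training.
    """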
    config = copy.deepcopy(config)
    support_dict = [
        'SimpleDataSet', 'LMDBDataSet', 'MultiScaleDataSet', 'STRLMDBDataSet',
        'LMDBDataSetTest', 'RatioDataSet', 'RatioDataSetTest',
        'RatioDataSetTVResize', 'RatioDataSetTVResizeTest'
    ]
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, (
        'DataSet only supports {}, but got {}'.format(support_dict, module_name))
    assert mode in ['Train', 'Eval', 'Test'], 'Mode should be Train, Eval or Test.'
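    # Instantiate the dataset class named in the config; eval() resolves one of
    # the dataset classes imported above, already validated against support_dict.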
    dataset = eval(module_name)(config, mode, logger, seed, epoch=epoch)
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = loader_config['num_workers']
    # Read 'pin_memory' if present, otherwise fall back to 'use_shared_memory';
    # default to False when neither key is set.
    pin_memory = loader_config.get('pin_memory',
                                   loader_config.get('use_shared_memory', False))
    sampler = None
    batch_sampler = None
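    # A sampler configured under config[mode]['sampler'] is used as a batch
    # sampler; otherwise fall back to a DistributedSampler for distributed training.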
    if 'sampler' in config[mode]:
        config_sampler = config[mode]['sampler']
        sampler_name = config_sampler.pop('name')
        batch_sampler = eval(sampler_name)(dataset, **config_sampler)
    elif config['Global']['distributed'] and mode == 'Train':
        sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)
    if 'collate_fn' in loader_config:
        from . import collate_fn as collate_fn_module
        collate_fn = getattr(collate_fn_module, loader_config['collate_fn'])()
    else:
        collate_fn = None
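    # A batch_sampler fully determines batch composition, so batch_size, shuffle
    # and drop_last are only passed in the plain-sampler branch below.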
    if batch_sampler is None:
        data_loader = DataLoader(
            dataset=dataset,
            sampler=sampler,
            # shuffle conflicts with an explicit sampler, so only enable it
            # when no sampler was created above.
            shuffle=shuffle if sampler is None else False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
            batch_size=batch_size,
            drop_last=drop_last,
        )
    else:
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        )
    if len(data_loader) == 0:
        logger.error(
            f'No images in the {mode.lower()} dataloader. Please check that:\n'
            '\t1. The number of images in the label_file_list is greater than or equal to the batch size.\n'
            '\t2. The annotation file and data paths in the configuration file are correct.\n'
            '\t3. The batch size is not larger than the number of images.')
        sys.exit(1)
    return data_loader
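

# Minimal usage sketch (illustrative only): the keys below mirror what
# build_dataloader reads above; the dataset-specific options (e.g. data_dir,
# label_file_list, transforms) depend on the chosen DataSet class and are
# omitted here, so the actual call is left commented out.
if __name__ == '__main__':
    import logging

    logging.basicConfig(level=logging.INFO)
    _logger = logging.getLogger('openocr.data')
    _example_config = {
        'Global': {'distributed': False},
        'Eval': {
            'dataset': {'name': 'SimpleDataSet'},  # plus dataset-specific options
            'loader': {
                'batch_size_per_card': 8,
                'drop_last': False,
                'shuffle': False,
                'num_workers': 0,
            },
        },
    }
    # eval_loader = build_dataloader(_example_config, 'Eval', _logger)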