Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
from abc import abstractmethod | |
from typing import Dict, List, Optional | |
from mmengine import mkdir_or_exist | |
class BaseDatasetConfigGenerator: | |
"""Base class for dataset config generator. | |
Args: | |
data_root (str): The root path of the dataset. | |
task (str): The task of the dataset. | |
dataset_name (str): The name of the dataset. | |
overwrite_cfg (bool): Whether to overwrite the dataset config file if | |
it already exists. If False, config generator will not generate new | |
config for datasets whose configs are already in base. | |
train_anns (List[Dict], optional): A list of train annotation files | |
to appear in the base configs. Defaults to None. | |
Each element is typically a dict with the following fields: | |
- ann_file (str): The path to the annotation file relative to | |
data_root. | |
- dataset_postfix (str, optional): Affects the postfix of the | |
resulting variable in the generated config. If specified, the | |
dataset variable will be named in the form of | |
``{dataset_name}_{dataset_postfix}_{task}_{split}``. Defaults to | |
None. | |
val_anns (List[Dict], optional): A list of val annotation files | |
to appear in the base configs, similar to ``train_anns``. Defaults | |
to None. | |
test_anns (List[Dict], optional): A list of test annotation files | |
to appear in the base configs, similar to ``train_anns``. Defaults | |
to None. | |
config_path (str): Path to the configs. Defaults to 'configs/'. | |
""" | |
def __init__( | |
self, | |
data_root: str, | |
task: str, | |
dataset_name: str, | |
overwrite_cfg: bool = False, | |
train_anns: Optional[List[Dict]] = None, | |
val_anns: Optional[List[Dict]] = None, | |
test_anns: Optional[List[Dict]] = None, | |
config_path: str = 'configs/', | |
) -> None: | |
self.config_path = config_path | |
self.data_root = data_root | |
self.task = task | |
self.dataset_name = dataset_name | |
self.overwrite_cfg = overwrite_cfg | |
self._prepare_anns(train_anns, val_anns, test_anns) | |
def _prepare_anns(self, train_anns: Optional[List[Dict]], | |
val_anns: Optional[List[Dict]], | |
test_anns: Optional[List[Dict]]) -> None: | |
"""Preprocess input arguments and stores these information into | |
``self.anns``. | |
``self.anns`` is a dict that maps the name of a dataset config variable | |
to a dict, which contains the following fields: | |
- ann_file (str): The path to the annotation file relative to | |
data_root. | |
- split (str): The split the annotation belongs to. Usually | |
it can be 'train', 'val' and 'test'. | |
- dataset_postfix (str, optional): Affects the postfix of the | |
resulting variable in the generated config. If specified, the | |
dataset variable will be named in the form of | |
``{dataset_name}_{dataset_postfix}_{task}_{split}``. Defaults to | |
None. | |
""" | |
self.anns = {} | |
for split, ann_list in zip(('train', 'val', 'test'), | |
(train_anns, val_anns, test_anns)): | |
if ann_list is None: | |
continue | |
if not isinstance(ann_list, list): | |
raise ValueError(f'{split}_anns must be either a list or' | |
' None!') | |
for ann_dict in ann_list: | |
assert 'ann_file' in ann_dict | |
suffix = ann_dict['ann_file'].split('.')[-1] | |
if suffix == 'json': | |
dataset_type = 'OCRDataset' | |
elif suffix == 'lmdb': | |
assert self.task == 'textrecog', \ | |
'LMDB format only works for textrecog now.' | |
dataset_type = 'RecogLMDBDataset' | |
else: | |
raise NotImplementedError( | |
'ann file only supports JSON file or LMDB file') | |
ann_dict['dataset_type'] = dataset_type | |
if ann_dict.get('dataset_postfix', ''): | |
key = f'{self.dataset_name}_{ann_dict["dataset_postfix"]}_{self.task}_{split}' # noqa | |
else: | |
key = f'{self.dataset_name}_{self.task}_{split}' | |
ann_dict['split'] = split | |
if key in self.anns: | |
raise ValueError( | |
f'Duplicate dataset variable {key} found! ' | |
'Please use different dataset_postfix to avoid ' | |
'conflict.') | |
self.anns[key] = ann_dict | |
def __call__(self) -> None: | |
"""Generates the base dataset config.""" | |
dataset_config = self._gen_dataset_config() | |
cfg_path = osp.join(self.config_path, self.task, '_base_', 'datasets', | |
f'{self.dataset_name}.py') | |
if osp.exists(cfg_path) and not self.overwrite_cfg: | |
print(f'{cfg_path} found, skipping.') | |
return | |
mkdir_or_exist(osp.dirname(cfg_path)) | |
with open(cfg_path, 'w') as f: | |
f.write( | |
f'{self.dataset_name}_{self.task}_data_root = \'{self.data_root}\'\n' # noqa: E501 | |
) | |
f.write(dataset_config) | |
def _gen_dataset_config(self) -> str: | |
"""Generate a full dataset config based on the annotation file | |
dictionary. | |
Returns: | |
str: The generated dataset config. | |
""" | |