Spaces:
Sleeping
Sleeping
File size: 5,693 Bytes
9bf4bd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
from abc import abstractmethod
from typing import Dict, List, Optional

from mmengine import mkdir_or_exist
class BaseDatasetConfigGenerator:
    """Base class for dataset config generator.

    Subclasses implement :meth:`_gen_dataset_config`, which turns
    ``self.anns`` into the text of a base dataset config. Calling the
    instance writes that text (prefixed by a ``*_data_root`` variable) to
    ``{config_path}/{task}/_base_/datasets/{dataset_name}.py``.

    Args:
        data_root (str): The root path of the dataset.
        task (str): The task of the dataset.
        dataset_name (str): The name of the dataset.
        overwrite_cfg (bool): Whether to overwrite the dataset config file if
            it already exists. If False, config generator will not generate new
            config for datasets whose configs are already in base.
        train_anns (List[Dict], optional): A list of train annotation files
            to appear in the base configs. Defaults to None.
            Each element is typically a dict with the following fields:

            - ann_file (str): The path to the annotation file relative to
              data_root.
            - dataset_postfix (str, optional): Affects the postfix of the
              resulting variable in the generated config. If specified, the
              dataset variable will be named in the form of
              ``{dataset_name}_{dataset_postfix}_{task}_{split}``. Defaults to
              None.

            The dicts are copied before being augmented, so the caller's
            objects are never modified.
        val_anns (List[Dict], optional): A list of val annotation files
            to appear in the base configs, similar to ``train_anns``. Defaults
            to None.
        test_anns (List[Dict], optional): A list of test annotation files
            to appear in the base configs, similar to ``train_anns``. Defaults
            to None.
        config_path (str): Path to the configs. Defaults to 'configs/'.
    """

    def __init__(
        self,
        data_root: str,
        task: str,
        dataset_name: str,
        overwrite_cfg: bool = False,
        train_anns: Optional[List[Dict]] = None,
        val_anns: Optional[List[Dict]] = None,
        test_anns: Optional[List[Dict]] = None,
        config_path: str = 'configs/',
    ) -> None:
        self.config_path = config_path
        self.data_root = data_root
        self.task = task
        self.dataset_name = dataset_name
        self.overwrite_cfg = overwrite_cfg
        self._prepare_anns(train_anns, val_anns, test_anns)

    def _prepare_anns(self, train_anns: Optional[List[Dict]],
                      val_anns: Optional[List[Dict]],
                      test_anns: Optional[List[Dict]]) -> None:
        """Preprocess input arguments and store the information in
        ``self.anns``.

        ``self.anns`` is a dict that maps the name of a dataset config variable
        to a dict, which contains the following fields:

        - ann_file (str): The path to the annotation file relative to
          data_root.
        - split (str): The split the annotation belongs to. Usually
          it can be 'train', 'val' and 'test'.
        - dataset_type (str): ``'OCRDataset'`` for ``.json`` annotation
          files, ``'RecogLMDBDataset'`` for ``.lmdb`` ones.
        - dataset_postfix (str, optional): Affects the postfix of the
          resulting variable in the generated config. If specified, the
          dataset variable will be named in the form of
          ``{dataset_name}_{dataset_postfix}_{task}_{split}``. Defaults to
          None.

        Each stored dict is a shallow copy of the corresponding input dict.

        Raises:
            ValueError: If a ``*_anns`` argument is neither a list nor None,
                or if two annotations resolve to the same variable name.
            NotImplementedError: If an annotation file is neither a JSON
                nor an LMDB file.
            AssertionError: If an element has no ``ann_file`` key, or an
                LMDB file is given for a task other than ``'textrecog'``.
        """
        self.anns: Dict[str, Dict] = {}
        for split, ann_list in zip(('train', 'val', 'test'),
                                   (train_anns, val_anns, test_anns)):
            if ann_list is None:
                continue
            if not isinstance(ann_list, list):
                raise ValueError(f'{split}_anns must be either a list or'
                                 ' None!')
            for ann_dict in ann_list:
                assert 'ann_file' in ann_dict
                # Work on a copy: the fields added below must not leak into
                # the caller's config, and a dict reused across splits must
                # not have its ``split`` overwritten by a later split.
                ann_dict = copy.copy(ann_dict)
                suffix = ann_dict['ann_file'].split('.')[-1]
                if suffix == 'json':
                    dataset_type = 'OCRDataset'
                elif suffix == 'lmdb':
                    assert self.task == 'textrecog', \
                        'LMDB format only works for textrecog now.'
                    dataset_type = 'RecogLMDBDataset'
                else:
                    raise NotImplementedError(
                        'ann file only supports JSON file or LMDB file')
                ann_dict['dataset_type'] = dataset_type
                if ann_dict.get('dataset_postfix', ''):
                    key = f'{self.dataset_name}_{ann_dict["dataset_postfix"]}_{self.task}_{split}'  # noqa
                else:
                    key = f'{self.dataset_name}_{self.task}_{split}'
                ann_dict['split'] = split
                if key in self.anns:
                    raise ValueError(
                        f'Duplicate dataset variable {key} found! '
                        'Please use different dataset_postfix to avoid '
                        'conflict.')
                self.anns[key] = ann_dict

    def __call__(self) -> None:
        """Generate the base dataset config and write it to disk.

        The config is written to
        ``{config_path}/{task}/_base_/datasets/{dataset_name}.py``. If that
        file already exists and ``overwrite_cfg`` is False, a message is
        printed and nothing is written.
        """
        dataset_config = self._gen_dataset_config()
        cfg_path = osp.join(self.config_path, self.task, '_base_', 'datasets',
                            f'{self.dataset_name}.py')
        if osp.exists(cfg_path) and not self.overwrite_cfg:
            print(f'{cfg_path} found, skipping.')
            return
        mkdir_or_exist(osp.dirname(cfg_path))
        # The generated file is Python source, whose default encoding is
        # UTF-8; don't depend on the platform's locale encoding.
        with open(cfg_path, 'w', encoding='utf-8') as f:
            # repr() yields a valid Python string literal even when the path
            # contains quotes or backslashes (e.g. Windows paths); for plain
            # paths the output is identical to wrapping it in single quotes.
            f.write(f'{self.dataset_name}_{self.task}_data_root = '
                    f'{str(self.data_root)!r}\n')
            f.write(dataset_config)

    @abstractmethod
    def _gen_dataset_config(self) -> str:
        """Generate a full dataset config based on the annotation file
        dictionary.

        Returns:
            str: The generated dataset config.

        Raises:
            NotImplementedError: Always, in the base class; subclasses must
                override this method.
        """
        # The class does not use ABCMeta, so ``@abstractmethod`` does not
        # prevent instantiation. Fail loudly here instead of returning None,
        # which would otherwise surface as a confusing error in ``f.write``.
        raise NotImplementedError(
            f'{type(self).__name__} must implement _gen_dataset_config()')
|