# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from abc import abstractmethod
from typing import Dict, List, Optional

from mmengine import mkdir_or_exist


class BaseDatasetConfigGenerator:
    """Base class for dataset config generator.

    Args:
        data_root (str): The root path of the dataset.
        task (str): The task of the dataset.
        dataset_name (str): The name of the dataset.
        overwrite_cfg (bool): Whether to overwrite the dataset config file if
            it already exists. If False, the config generator will not
            generate a new config for datasets whose configs already exist
            under ``_base_``. Defaults to False.
        train_anns (List[Dict], optional): A list of train annotation files
            to appear in the base configs. Defaults to None.
            Each element is typically a dict with the following fields:
            - ann_file (str): The path to the annotation file relative to
              data_root.
            - dataset_postfix (str, optional): Affects the postfix of the
              resulting variable in the generated config. If specified, the
              dataset variable will be named in the form of
              ``{dataset_name}_{dataset_postfix}_{task}_{split}``. Defaults to
              None.
        val_anns (List[Dict], optional): A list of val annotation files
            to appear in the base configs, similar to ``train_anns``. Defaults
            to None.
        test_anns (List[Dict], optional): A list of test annotation files
            to appear in the base configs, similar to ``train_anns``. Defaults
            to None.
        config_path (str): Path to the configs. Defaults to 'configs/'.
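
    Example:
        A hypothetical ``train_anns`` mapping two annotation files to
        differently named dataset variables (the file names below are
        placeholders, not tied to any real dataset)::

            train_anns = [
                dict(ann_file='textdet_train.json'),
                dict(ann_file='textdet_train_extra.json',
                     dataset_postfix='extra'),
            ]

        With ``dataset_name='toy'`` and ``task='textdet'``, the generated
        config would expose the variables ``toy_textdet_train`` and
        ``toy_extra_textdet_train``.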
    """

    def __init__(
        self,
        data_root: str,
        task: str,
        dataset_name: str,
        overwrite_cfg: bool = False,
        train_anns: Optional[List[Dict]] = None,
        val_anns: Optional[List[Dict]] = None,
        test_anns: Optional[List[Dict]] = None,
        config_path: str = 'configs/',
    ) -> None:
        self.config_path = config_path
        self.data_root = data_root
        self.task = task
        self.dataset_name = dataset_name
        self.overwrite_cfg = overwrite_cfg
        self._prepare_anns(train_anns, val_anns, test_anns)

    def _prepare_anns(self, train_anns: Optional[List[Dict]],
                      val_anns: Optional[List[Dict]],
                      test_anns: Optional[List[Dict]]) -> None:
        """Preprocess input arguments and stores these information into
        ``self.anns``.

        ``self.anns`` is a dict that maps the name of a dataset config variable
        to a dict, which contains the following fields:
        - ann_file (str): The path to the annotation file relative to
          data_root.
        - split (str): The split the annotation belongs to. Usually
          one of 'train', 'val' or 'test'.
        - dataset_postfix (str, optional): Affects the postfix of the
          resulting variable in the generated config. If specified, the
          dataset variable will be named in the form of
          ``{dataset_name}_{dataset_postfix}_{task}_{split}``. Defaults to
          None.
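
        For example, assuming a generator built with ``dataset_name='toy'``
        and ``task='textdet'`` (hypothetical values) and
        ``train_anns=[dict(ann_file='textdet_train.json')]``, ``self.anns``
        would become::

            {
                'toy_textdet_train': dict(
                    ann_file='textdet_train.json',
                    dataset_type='OCRDataset',
                    split='train'),
            }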
        """
        self.anns = {}
        for split, ann_list in zip(('train', 'val', 'test'),
                                   (train_anns, val_anns, test_anns)):
            if ann_list is None:
                continue
            if not isinstance(ann_list, list):
                raise ValueError(f'{split}_anns must be either a list or'
                                 ' None!')
            for ann_dict in ann_list:
                assert 'ann_file' in ann_dict
                suffix = ann_dict['ann_file'].split('.')[-1]
                if suffix == 'json':
                    dataset_type = 'OCRDataset'
                elif suffix == 'lmdb':
                    assert self.task == 'textrecog', \
                        'LMDB format only works for textrecog now.'
                    dataset_type = 'RecogLMDBDataset'
                else:
                    raise NotImplementedError(
                        'ann_file only supports JSON or LMDB files')
                ann_dict['dataset_type'] = dataset_type
                if ann_dict.get('dataset_postfix', ''):
                    key = f'{self.dataset_name}_{ann_dict["dataset_postfix"]}_{self.task}_{split}'  # noqa
                else:
                    key = f'{self.dataset_name}_{self.task}_{split}'
                ann_dict['split'] = split
                if key in self.anns:
                    raise ValueError(
                        f'Duplicate dataset variable {key} found! '
                        'Please use different dataset_postfix to avoid '
                        'conflict.')
                self.anns[key] = ann_dict

    def __call__(self) -> None:
        """Generates the base dataset config."""

        dataset_config = self._gen_dataset_config()

        cfg_path = osp.join(self.config_path, self.task, '_base_', 'datasets',
                            f'{self.dataset_name}.py')
        if osp.exists(cfg_path) and not self.overwrite_cfg:
            print(f'{cfg_path} found, skipping.')
            return
        mkdir_or_exist(osp.dirname(cfg_path))
        with open(cfg_path, 'w') as f:
            f.write(
                f'{self.dataset_name}_{self.task}_data_root = \'{self.data_root}\'\n'  # noqa: E501
            )
            f.write(dataset_config)

    @abstractmethod
    def _gen_dataset_config(self) -> str:
        """Generate a full dataset config based on the annotation file
        dictionary.

        Returns:
            str: The generated dataset config.
        """