import copy
from abc import ABCMeta, abstractmethod
from typing import Optional

from torch.utils.data import Dataset

from .pipelines import Compose


class BaseDataset(Dataset, metaclass=ABCMeta):
    """Base dataset.

    Args:
        data_prefix (str): the prefix of the data path.
        pipeline (list): a list of dicts, where each element represents
            an operation defined in `mmhuman3d.datasets.pipelines`.
        ann_file (str | None, optional): the annotation file. When ann_file
            is a str, the subclass is expected to read from the ann_file.
            When ann_file is None, the subclass is expected to read
            according to data_prefix. Default: None.
        test_mode (bool, optional): whether the dataset is in test mode.
            Default: False.
        dataset_name (str | None, optional): the name of the dataset. It is
            used to identify the type of evaluation metric. Default: None.
    """
    # metric
    ALLOWED_METRICS = {
        'mpjpe', 'pa-mpjpe', 'pve', '3dpck', 'pa-3dpck', '3dauc', 'pa-3dauc',
        '3DRMSE', 'pa-pve'
    }

    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 ann_file: Optional[str] = None,
                 test_mode: bool = False,
                 dataset_name: Optional[str] = None):
        super(BaseDataset, self).__init__()

        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.test_mode = test_mode
        self.pipeline = Compose(pipeline)
        if dataset_name is not None:
            self.dataset_name = dataset_name

        self.load_annotations()

    @abstractmethod
    def load_annotations(self):
        """Load annotations from ``ann_file``"""
        pass

    def prepare_data(self, idx: int):
        """"Prepare raw data for the f'{idx'}-th data."""
        results = copy.deepcopy(self.data_infos[idx])
        results['dataset_name'] = self.dataset_name
        results['sample_idx'] = idx
        return self.pipeline(results)

    def __len__(self):
        """Return the length of current dataset."""
        return self.num_data

    def __getitem__(self, idx: int):
        """Prepare data for the ``idx``-th data.

        For a video dataset, the raw data can first be parsed for each frame
        and the per-frame annotations then combined. This interface is used
        to simplify the logic of video datasets and other special datasets.
        """
        return self.prepare_data(idx)
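

# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): a hypothetical
# subclass demonstrating what ``load_annotations`` is expected to populate,
# namely the ``data_infos``, ``num_data`` and ``dataset_name`` attributes
# that ``prepare_data`` and ``__len__`` rely on. The class name, annotation
# format and fallback dataset name below are assumptions for demonstration.
# ---------------------------------------------------------------------------
class _ToyJsonDataset(BaseDataset):
    """Hypothetical dataset whose ``ann_file`` is a JSON list of samples."""

    def load_annotations(self):
        """Read a JSON annotation file into per-sample dicts."""
        import json

        with open(self.ann_file) as f:
            # Assumed format: [{'image_path': ..., 'keypoints3d': ...}, ...]
            samples = json.load(f)
        # ``prepare_data`` deep-copies one entry of ``data_infos`` per index.
        self.data_infos = samples
        # ``__len__`` returns ``num_data``.
        self.num_data = len(self.data_infos)
        # ``prepare_data`` reads ``dataset_name``; provide a fallback when
        # the constructor did not receive one.
        if not hasattr(self, 'dataset_name'):
            self.dataset_name = 'toy_json'

# Example usage (assumes 'data/toy_ann.json' exists and the pipeline list is
# valid for ``Compose``; an empty pipeline returns the raw results dict):
#   dataset = _ToyJsonDataset(
#       data_prefix='data/', pipeline=[], ann_file='data/toy_ann.json')
#   sample = dataset[0]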