import os
from abc import ABCMeta
from typing import Optional

import numpy as np

from .base_dataset import BaseDataset
from .builder import DATASETS


@DATASETS.register_module()
class MeshDataset(BaseDataset, metaclass=ABCMeta):
    """Mesh dataset. This dataset only contains SMPL data.

    Args:
        data_prefix (str): The prefix of the data path.
        pipeline (list): A list of dicts, where each element represents
            an operation defined in `detrsmpl.datasets.pipelines`.
        dataset_name (str): The name of the dataset. It is used to
            identify the type of evaluation metric.
        ann_file (str | None, optional): The annotation file name, resolved
            relative to ``<data_prefix>/preprocessed_datasets``. When
            ann_file is a str, the subclass is expected to read from
            ann_file. When ann_file is None, the subclass is expected to
            read according to data_prefix. Note that
            ``MeshDataset.load_annotations`` itself requires a str.
            Default: None.
        test_mode (bool, optional): Whether the dataset is in test mode
            (True) or train mode (False). Default: False.
    """
    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 dataset_name: str,
                 ann_file: Optional[str] = None,
                 test_mode: bool = False):
        self.dataset_name = dataset_name
        super().__init__(data_prefix=data_prefix,
                         pipeline=pipeline,
                         ann_file=ann_file,
                         test_mode=test_mode)

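    def get_annotation_file(self):
        """Resolve ``self.ann_file`` to a full path under
        ``<data_prefix>/preprocessed_datasets``.

        This updates ``self.ann_file`` in place.
        """
        ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets')
        self.ann_file = os.path.join(ann_prefix, self.ann_file)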

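    def load_annotations(self):
        """Load SMPL annotations and build ``self.data_infos``.

        The annotation file is a NumPy ``.npz`` archive whose ``'smpl'``
        entry stores a dict mapping SMPL parameter names to arrays indexed
        by sample. Each sample becomes one info dict whose keys are the
        parameter names prefixed with ``'smpl_'``.
        """
        self.get_annotation_file()
        data = np.load(self.ann_file, allow_pickle=True)

        # The 'smpl' entry is a pickled dict stored as a single-element
        # object array, so ``.item()`` recovers the dict.
        self.smpl = data['smpl'].item()
        num_data = self.smpl['global_orient'].shape[0]
        # Fill in a zero translation when the file does not provide one.
        if 'transl' not in self.smpl:
            self.smpl['transl'] = np.zeros((num_data, 3))
        # Every sample in this dataset has SMPL annotations.
        self.has_smpl = np.ones(num_data)

        # Split the per-parameter arrays into one info dict per sample.
        data_infos = []
        for idx in range(num_data):
            info = {}
            for k, v in self.smpl.items():
                info['smpl_' + k] = v[idx]
            data_infos.append(info)

        self.num_data = len(data_infos)
        self.data_infos = data_infos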
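

# Illustrative sketch, not part of the original module: this hypothetical
# helper is not used by ``MeshDataset``. It writes a minimal annotation file
# in the layout that ``MeshDataset.load_annotations`` reads, i.e. a NumPy
# ``.npz`` archive whose ``'smpl'`` entry holds a dict mapping SMPL
# parameter names to arrays that share the same first (sample) dimension.
# The parameter names and shapes are assumptions based on common SMPL
# conventions (axis-angle ``global_orient`` and ``body_pose``, 10 shape
# ``betas``); ``load_annotations`` itself only requires ``global_orient``.
# To be found by ``get_annotation_file``, the file must be placed under
# ``<data_prefix>/preprocessed_datasets/``.
def _write_example_smpl_annotation(path: str, num_data: int = 2) -> None:
    """Write a hypothetical SMPL annotation file with all-zero parameters."""
    smpl = {
        'global_orient': np.zeros((num_data, 3), dtype=np.float32),
        'body_pose': np.zeros((num_data, 69), dtype=np.float32),
        'betas': np.zeros((num_data, 10), dtype=np.float32),
    }
    # ``np.savez`` stores the dict as a pickled single-element object array,
    # which is why the loader passes ``allow_pickle=True`` and calls
    # ``.item()``.
    np.savez(path, smpl=smpl)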