File size: 2,478 Bytes
e4bd7f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import logging

import torch.distributed as dist

import bubogpt.common.utils as utils
from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process
from bubogpt.common.registry import registry
from bubogpt.datasets.builders import load_dataset_config
from bubogpt.processors.base_processor import BaseProcessor


class MultimodalBaseDatasetBuilder():
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        super().__init__()

        if cfg is None:
            # help to create datasets from default config.
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            # when called from task.build_dataset()
            self.config = cfg

        self.data_type = self.config.data_type.split("_")
        # It will be a list like ["audio", "image"], etc.

        # Add "text" manually here.

        self.processors = {modal: {"train": BaseProcessor(), "eval": BaseProcessor()}
                           for modal in [*self.data_type, "text"]}

    def build_datasets(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed

        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
        logging.info("Building datasets...")
        datasets = self.build()  # dataset['train'/'val'/'test']

        return datasets

    def build_processors(self):
        for modal in [*self.data_type, "text"]:
            proc_cfg = self.config.get("{}_processor".format(modal))
            if proc_cfg is not None:
                train_cfg = proc_cfg.get("train")
                eval_cfg = proc_cfg.get("eval")
                self.processors[modal]["train"] = self._build_proc_from_cfg(train_cfg)
                self.processors[modal]["eval"] = self._build_proc_from_cfg(eval_cfg)


    @staticmethod
    def _build_proc_from_cfg(cfg):
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        pass