|
import os |
|
import copy |
|
import torch |
|
|
|
from .build import DATASETS_REGISTRY |
|
from uniperceiver.config import configurable |
|
|
|
|
|
|
|
import numpy as np |
|
""" |
|
|
|
only standard dataset support now |
|
class dataset: |
|
def __len__(self): |
|
pass |
|
|
|
def __getitem__(self, index): |
|
pass |
|
|
|
""" |
|
|
|
class UnifiedDataset: |
|
|
|
def __init__(self, cfg, task_cfg, stage, **kwargs): |
|
self.cfg = cfg |
|
self.task_cfg =task_cfg |
|
assert stage == 'train', f'now only training dataset is supported' |
|
|
|
datasets = dict() |
|
for name, new_cfg in self.task_cfg.items(): |
|
datasets[name] = self.build_unit_dataset(new_cfg, |
|
new_cfg.DATASETS.TRAIN, |
|
stage=stage) |
|
self.datasets = datasets |
|
|
|
self.dataset_name = list(self.datasets.keys()) |
|
self.dataset_list = list(self.datasets.values()) |
|
|
|
|
|
self.dataset_length = np.array([len(ds) for ds in self.datasets.values()]) |
|
|
|
self.dataset_scale = np.array([0] + np.cumsum(self.dataset_length).tolist()[:-1]) |
|
|
|
|
|
|
|
pass |
|
|
|
def build_unit_dataset(self, cfg, name, stage): |
|
dataset_mapper = DATASETS_REGISTRY.get(name)(cfg, stage) |
|
return dataset_mapper |
|
|
|
def __len__(self): |
|
return np.cumsum(self.dataset_length).tolist()[-1] |
|
|
|
|
|
def __getitem__(self, index): |
|
dataset_index = (index >= self.dataset_scale).sum() - 1 |
|
offset = self.dataset_scale[dataset_index] |
|
ret = self.dataset_list[dataset_index][index-offset] |
|
ret.update({"task_name": self.dataset_name[dataset_index]}) |
|
return ret |
|
|