import os
import copy
import torch
from .build import DATASETS_REGISTRY
from uniperceiver.config import configurable
import numpy as np
"""
only standard dataset support now
class dataset:
def __len__(self):
pass
def __getitem__(self, index):
pass
"""


class UnifiedDataset:
    def __init__(self, cfg, task_cfg, stage, **kwargs):
        self.cfg = cfg
        self.task_cfg = task_cfg
        assert stage == 'train', 'only the training stage is supported for now'

        # Build one dataset per task from that task's own config.
        datasets = dict()
        for name, new_cfg in self.task_cfg.items():
            datasets[name] = self.build_unit_dataset(new_cfg,
                                                     new_cfg.DATASETS.TRAIN,
                                                     stage=stage)
        self.datasets = datasets
        self.dataset_name = list(self.datasets.keys())
        self.dataset_list = list(self.datasets.values())
        self.dataset_length = np.array([len(ds) for ds in self.datasets.values()])
        # Global start offset of each dataset in the concatenated index
        # space, e.g. [0, 876, 128000, ...].
        self.dataset_scale = np.array([0] + np.cumsum(self.dataset_length).tolist()[:-1])
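        # Worked example (lengths assumed for illustration): if
        # dataset_length = [876, 127124, 1000], then
        # dataset_scale = [0, 876, 128000], so global indices 876..127999
        # all fall into the second dataset.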

    def build_unit_dataset(self, cfg, name, stage):
        # Look up the dataset class registered under `name` and instantiate it.
        dataset_mapper = DATASETS_REGISTRY.get(name)(cfg, stage)
        return dataset_mapper
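
    # For reference, datasets typically enter the registry via the standard
    # decorator pattern (a sketch; "MyDataset" is a placeholder name):
    #
    #     @DATASETS_REGISTRY.register()
    #     class MyDataset:
    #         ...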

    def __len__(self):
        # Total number of samples across all constituent datasets.
        return int(self.dataset_length.sum())

    def __getitem__(self, index):
        # The dataset holding `index` is the last one whose start offset is
        # <= index; counting the offsets already passed yields its position.
        dataset_index = (index >= self.dataset_scale).sum() - 1
        offset = self.dataset_scale[dataset_index]
        ret = self.dataset_list[dataset_index][index - offset]
        ret.update({"task_name": self.dataset_name[dataset_index]})
        return ret
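

# Quick sanity check of the index arithmetic, using a synthetic scale (an
# assumed example; no real cfg or registered datasets are required):
if __name__ == "__main__":
    scale = np.array([0, 876, 128000])
    for index, expected in [(0, 0), (875, 0), (876, 1), (127999, 1), (128000, 2)]:
        dataset_index = (index >= scale).sum() - 1
        assert dataset_index == expected, (index, dataset_index)
    print("index-to-dataset mapping behaves as expected")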