unit_test / uniperceiver /datasets /unified_dataset.py
herrius's picture
Upload 259 files
32b542e
import os
import copy
import torch
from .build import DATASETS_REGISTRY
from uniperceiver.config import configurable
import numpy as np
"""
only standard dataset support now
class dataset:
def __len__(self):
pass
def __getitem__(self, index):
pass
"""
class UnifiedDataset:
def __init__(self, cfg, task_cfg, stage, **kwargs):
self.cfg = cfg
self.task_cfg =task_cfg
assert stage == 'train', f'now only training dataset is supported'
datasets = dict()
for name, new_cfg in self.task_cfg.items():
datasets[name] = self.build_unit_dataset(new_cfg,
new_cfg.DATASETS.TRAIN,
stage=stage)
self.datasets = datasets
self.dataset_name = list(self.datasets.keys())
self.dataset_list = list(self.datasets.values())
self.dataset_length = np.array([len(ds) for ds in self.datasets.values()])
self.dataset_scale = np.array([0] + np.cumsum(self.dataset_length).tolist()[:-1])
# [0, 876, 128000, ....]
pass
def build_unit_dataset(self, cfg, name, stage):
dataset_mapper = DATASETS_REGISTRY.get(name)(cfg, stage)
return dataset_mapper
def __len__(self):
return np.cumsum(self.dataset_length).tolist()[-1]
def __getitem__(self, index):
dataset_index = (index >= self.dataset_scale).sum() - 1 # the dataset index
offset = self.dataset_scale[dataset_index]
ret = self.dataset_list[dataset_index][index-offset]
ret.update({"task_name": self.dataset_name[dataset_index]})
return ret