import enum
from functools import reduce
from typing import Dict, List, Tuple
import numpy as np
import copy
from utils.common.log import logger
from ..datasets.ab_dataset import ABDataset
from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
from data import get_dataset, MergedDataset, Scenario as DAScenario


class _ABDatasetMetaInfo:
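    """Metadata for one target (sub-)dataset: name, class list, task/object type,
    class aliases, shift type, ignored classes and class-index map.
    """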
def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map):
self.name = name
self.classes = classes
self.class_aliases = class_aliases
self.shift_type = shift_type
self.task_type = task_type
self.object_type = object_type
self.ignore_classes = ignore_classes
self.idx_map = idx_map
def __repr__(self) -> str:
return f'({self.name}, {self.classes}, {self.idx_map})'


class Scenario:
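    """A class-incremental scenario over a sequence of target tasks.

    Tracks the index of the current task and its class-label offset, and builds
    the train/val/test datasets of the current task on demand.
    """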
def __init__(self, config, target_datasets_info: List[_ABDatasetMetaInfo], num_classes: int, num_source_classes: int, data_dirs):
self.config = config
self.target_datasets_info = target_datasets_info
self.num_classes = num_classes
self.cur_task_index = 0
self.num_source_classes = num_source_classes
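        # label offset of the current task; target-task classes are appended after the source classes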
self.cur_class_offset = num_source_classes
self.data_dirs = data_dirs
self.target_tasks_order = [i.name for i in self.target_datasets_info]
        # total number of classes over all target tasks (the attribute name notwithstanding)
        self.num_tasks_to_be_learn = sum([len(i.classes) for i in target_datasets_info])
logger.info(f'[scenario build] # classes: {num_classes}, # tasks to be learnt: {len(target_datasets_info)}, '
f'# classes per task: {config["num_classes_per_task"]}')
def to_json(self):
config = copy.deepcopy(self.config)
config['da_scenario'] = config['da_scenario'].to_json()
target_datasets_info = [str(i) for i in self.target_datasets_info]
return dict(
config=config, target_datasets_info=target_datasets_info,
num_classes=self.num_classes
)
def __str__(self):
return f'Scenario({self.to_json()})'
def get_cur_class_offset(self):
return self.cur_class_offset
def get_cur_num_class(self):
return len(self.target_datasets_info[self.cur_task_index].classes)
def get_nc_per_task(self):
return len(self.target_datasets_info[0].classes)
def next_task(self):
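        """Advance to the next target task and shift the class-label offset past the finished task's classes."""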
self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
self.cur_task_index += 1
        logger.info(f'now, cur task: {self.cur_task_index}, cur_class_offset: {self.cur_class_offset}')
def get_cur_task_datasets(self):
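        """Build the train/val/test datasets of the current task.

        The 'train' split contains only the current task's classes, while the
        'val' and 'test' splits merge all tasks seen so far; splits with fewer
        than 1000 samples are repeated 5x.
        """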
dataset_info = self.target_datasets_info[self.cur_task_index]
dataset_name = dataset_info.name.split('|')[0]
# print()
# source_datasets_info = []
        res = {
            # train split: only the current task's classes
            **{split: get_dataset(dataset_name=dataset_name,
                                  root_dir=self.data_dirs[dataset_name],
                                  split=split,
                                  transform=None,
                                  ignore_classes=dataset_info.ignore_classes,
                                  idx_map=dataset_info.idx_map) for split in ['train']},
            # val/test splits: merge the classes of all tasks seen so far
            **{split: MergedDataset([get_dataset(dataset_name=dataset_name,
                                                 root_dir=self.data_dirs[dataset_name],
                                                 split=split,
                                                 transform=None,
                                                 ignore_classes=di.ignore_classes,
                                                 idx_map=di.idx_map)
                                     for di in self.target_datasets_info[0: self.cur_task_index + 1]])
               for split in ['val', 'test']}
        }
# if len(res['train']) < 200 or len(res['val']) < 200 or len(res['test']) < 200:
# return None
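        # repeat small splits (< 1000 samples) so downstream loaders see enough data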
if len(res['train']) < 1000:
res['train'] = MergedDataset([res['train']] * 5)
logger.info('aug train dataset')
if len(res['val']) < 1000:
res['val'] = MergedDataset(res['val'].datasets * 5)
logger.info('aug val dataset')
if len(res['test']) < 1000:
res['test'] = MergedDataset(res['test'].datasets * 5)
logger.info('aug test dataset')
# da_scenario: DAScenario = self.config['da_scenario']
# offline_datasets = da_scenario.get_offline_datasets()
for k, v in res.items():
logger.info(f'{k} dataset: {len(v)}')
# new_val_datasets = [
# *[d['val'] for d in offline_datasets.values()],
# res['val']
# ]
# res['val'] = MergedDataset(new_val_datasets)
# new_test_datasets = [
# *[d['test'] for d in offline_datasets.values()],
# res['test']
# ]
# res['test'] = MergedDataset(new_test_datasets)
return res
def get_cur_task_train_datasets(self):
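        """Return only the 'train' dataset of the current task (no merging or repetition)."""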
dataset_info = self.target_datasets_info[self.cur_task_index]
dataset_name = dataset_info.name.split('|')[0]
# print()
# source_datasets_info = []
res = get_dataset(dataset_name=dataset_name,
root_dir=self.data_dirs[dataset_name],
split='train',
transform=None,
ignore_classes=dataset_info.ignore_classes,
idx_map=dataset_info.idx_map)
return res
def get_online_cur_task_samples_for_training(self, num_samples):
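        """Fetch a batch of training samples from the current task for online training."""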
        dataset = self.get_cur_task_datasets()['train']
        # draw a single batch from a fresh train dataloader and return its first element (the inputs)
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None)))[0]