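"""Class-incremental scenario built on top of a domain-adaptation scenario.

Walks through a sequence of target tasks, building per-task train/val/test
datasets and tracking the current task index and the class-index offset.
"""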
import copy
from typing import List

from utils.common.log import logger
from ..dataloader import build_dataloader
from data import get_dataset, MergedDataset, Scenario as DAScenario


class _ABDatasetMetaInfo:
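    """Lightweight metadata record for one target dataset (label space, remapping and shift info)."""
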
    def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type, ignore_classes, idx_map):
        self.name = name
        self.classes = classes
        self.class_aliases = class_aliases
        self.shift_type = shift_type
        self.task_type = task_type
        self.object_type = object_type
        
        self.ignore_classes = ignore_classes
        self.idx_map = idx_map
        
    def __repr__(self) -> str:
        return f'({self.name}, {self.classes}, {self.idx_map})'


class Scenario:
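    """Class-incremental scenario over a sequence of target tasks.

    Wraps an underlying domain-adaptation scenario (`config['da_scenario']`,
    a `DAScenario`) and steps through the target tasks one at a time,
    tracking the current task index and the class-index offset separating
    source classes from the target classes learned so far.
    """
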
    def __init__(self, config, target_datasets_info: List[_ABDatasetMetaInfo], num_classes: int, num_source_classes: int, data_dirs):
        self.config = config
        self.target_datasets_info = target_datasets_info
        self.num_classes = num_classes
        self.cur_task_index = 0
        self.num_source_classes = num_source_classes
        # target class indices start right after the source classes
        self.cur_class_offset = num_source_classes
        self.data_dirs = data_dirs

        self.target_tasks_order = [i.name for i in self.target_datasets_info]
        # NOTE: despite the name, this is the total number of *classes* across all target tasks
        self.num_tasks_to_be_learn = sum(len(i.classes) for i in target_datasets_info)

        logger.info(f'[scenario build] # classes: {num_classes}, # tasks to be learnt: {len(target_datasets_info)}, '
                    f'# classes per task: {config["num_classes_per_task"]}')

    def to_json(self):
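        """Return a JSON-serializable summary of the scenario configuration."""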
        config = copy.deepcopy(self.config)
        config['da_scenario'] = config['da_scenario'].to_json()
        target_datasets_info = [str(i) for i in self.target_datasets_info]
        return dict(
            config=config, target_datasets_info=target_datasets_info,
            num_classes=self.num_classes
        )
        
    def __str__(self):
        return f'Scenario({self.to_json()})'
    
    def get_cur_class_offset(self):
        """Global index of the first class of the current task."""
        return self.cur_class_offset

    def get_cur_num_class(self):
        """Number of classes in the current task."""
        return len(self.target_datasets_info[self.cur_task_index].classes)

    def get_nc_per_task(self):
        """Number of classes per task (assumes all tasks have as many classes as the first)."""
        return len(self.target_datasets_info[0].classes)
    
    def next_task(self):
        # move the class offset past the current task's classes, then advance the task index
        self.cur_class_offset += len(self.target_datasets_info[self.cur_task_index].classes)
        self.cur_task_index += 1

        logger.info(f'now, cur task: {self.cur_task_index}, cur_class_offset: {self.cur_class_offset}')
        
    def get_cur_task_datasets(self):
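        """Build the datasets for the current task.

        The 'train' split comes from the current task's dataset only; the
        'val' and 'test' splits merge the corresponding splits of every task
        seen so far. Splits with fewer than 1000 samples are replicated 5x.
        """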
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]

        res = {
            # train split: the current task's dataset only
            'train': get_dataset(dataset_name=dataset_name,
                                 root_dir=self.data_dirs[dataset_name],
                                 split='train',
                                 transform=None,
                                 ignore_classes=dataset_info.ignore_classes,
                                 idx_map=dataset_info.idx_map),

            # val/test splits: merged over every task seen so far; each task's base
            # dataset name is the part before '|' in its task name (identical to
            # `dataset_name` when all tasks come from a single dataset)
            **{split: MergedDataset([get_dataset(dataset_name=di.name.split('|')[0],
                                                 root_dir=self.data_dirs[di.name.split('|')[0]],
                                                 split=split,
                                                 transform=None,
                                                 ignore_classes=di.ignore_classes,
                                                 idx_map=di.idx_map)
                                     for di in self.target_datasets_info[0: self.cur_task_index + 1]])
               for split in ['val', 'test']}
        }

        # replicate small splits 5x so downstream loaders see enough samples
        if len(res['train']) < 1000:
            res['train'] = MergedDataset([res['train']] * 5)
            logger.info('augmented small train dataset (5x)')
        if len(res['val']) < 1000:
            res['val'] = MergedDataset(res['val'].datasets * 5)
            logger.info('augmented small val dataset (5x)')
        if len(res['test']) < 1000:
            res['test'] = MergedDataset(res['test'].datasets * 5)
            logger.info('augmented small test dataset (5x)')

        # NOTE: merging the offline source-domain val/test sets from
        # `self.config['da_scenario']` (a `DAScenario`) into `res` is disabled here.

        for k, v in res.items():
            logger.info(f'{k} dataset: {len(v)}')

        return res
    
    def get_cur_task_train_datasets(self):
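        """Return just the train split of the current task's dataset (no merging or replication)."""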
        dataset_info = self.target_datasets_info[self.cur_task_index]
        dataset_name = dataset_info.name.split('|')[0]

        return get_dataset(dataset_name=dataset_name,
                           root_dir=self.data_dirs[dataset_name],
                           split='train',
                           transform=None,
                           ignore_classes=dataset_info.ignore_classes,
                           idx_map=dataset_info.idx_map)
    
    def get_online_cur_task_samples_for_training(self, num_samples):
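        """Fetch one shuffled batch of `num_samples` training inputs from the current task (labels dropped)."""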
        datasets = self.get_cur_task_datasets()
        train_set = datasets['train']
        # take a single shuffled batch; [0] keeps the inputs and discards the labels
        return next(iter(build_dataloader(train_set, num_samples, 0, True, None)))[0]
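

# Usage sketch (illustrative only: `config`, `infos`, and the paths below are hypothetical;
# in practice a `Scenario` is constructed by the surrounding benchmark-building code):
#
#   scenario = Scenario(config, infos, num_classes=100, num_source_classes=40,
#                       data_dirs={'CIFAR10': '/path/to/cifar10'})
#   for _ in range(len(scenario.target_tasks_order)):
#       datasets = scenario.get_cur_task_datasets()
#       offset = scenario.get_cur_class_offset()  # first class index of this task
#       ...                                       # train / evaluate on this task
#       scenario.next_task()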