# OMG_Seg: seg/datasets/concat_dataset.py
# Haobo Yuan — "add omg code" (b34d1d6)
from abc import ABC
import logging
from typing import Sequence, Union, Optional, Tuple
from mmengine.dataset import ConcatDataset, RepeatDataset, ClassBalancedDataset
from mmengine.logging import print_log
from mmengine.registry import DATASETS
from mmengine.dataset.base_dataset import BaseDataset
from mmdet.structures import TrackDataSample
from seg.models.utils import NO_OBJ
@DATASETS.register_module()
class ConcatOVDataset(ConcatDataset, ABC):
    """Concatenate datasets and merge their label spaces into one vocabulary.

    The thing and stuff class names of every sub-dataset are normalised and
    merged into global ``thing_classes`` / ``stuff_classes`` lists (global
    ``classes`` is things followed by stuff). Two names are treated as the
    same class when they share at least one comma-separated synonym.

    For every sub-dataset a ``mapper`` dict (local label id -> global label
    id) is stored in ``metainfo['mapper']``. It is applied to the
    ground-truth labels in :meth:`__getitem__`. Sub-datasets whose metainfo
    has no ``classes`` entry, or whose ``data_tag`` is ``'sam'``, are treated
    as class agnostic: local labels ``0..999`` map to ``-1``. In every mapper
    ``NO_OBJ`` maps to itself.

    Args:
        datasets: Sub-datasets, or config dicts to build them from.
        lazy_init: Forwarded to every config dict (and to the nested
            ``dataset`` config of dicts containing ``times``) and to
            ``ConcatDataset``.
        data_tag: Optional per-dataset tag. When given, it must have one
            entry per dataset, and the tag is attached to each sample's
            ``data_samples.data_tag``.
    """

    _fully_initialized: bool = False

    def __init__(self,
                 datasets: Sequence[Union[BaseDataset, dict]],
                 lazy_init: bool = False,
                 data_tag: Optional[Tuple[str]] = None,
                 ):
        # NOTE: the caller's config dicts are updated in place.
        for dataset in datasets:
            if isinstance(dataset, dict):
                dataset.update(lazy_init=lazy_init)
                # Wrapper configs with ``times`` carry the real dataset config
                # under ``dataset``; propagate lazy_init to it as well.
                if 'times' in dataset:
                    dataset['dataset'].update(lazy_init=lazy_init)
        # Class-related meta keys legitimately differ between sub-datasets;
        # they are replaced by the merged vocabulary further below.
        super().__init__(datasets, lazy_init=lazy_init,
                         ignore_keys=['classes', 'thing_classes', 'stuff_classes', 'palette'])

        self.data_tag = data_tag
        if self.data_tag is not None:
            assert len(self.data_tag) == len(datasets), (
                f'data_tag has {len(self.data_tag)} entries but '
                f'{len(datasets)} datasets were given')

        cls_names = [self._get_dataset_name(dataset) for dataset in self.datasets]

        # Merge every sub-dataset's class names into the global vocabulary.
        thing_classes = []
        stuff_classes = []
        thing_mapper = []
        stuff_mapper = []
        for idx, dataset in enumerate(self.datasets):
            metainfo = dataset.metainfo
            class_agnostic = 'classes' not in metainfo or (
                self.data_tag is not None and self.data_tag[idx] in ['sam'])
            if class_agnostic:
                # Empty mappers mark the dataset as class agnostic below.
                thing_mapper.append({})
                stuff_mapper.append({})
                continue
            local_things = metainfo['thing_classes'] \
                if 'thing_classes' in metainfo else metainfo['classes']
            local_stuff = metainfo['stuff_classes'] if 'stuff_classes' in metainfo else []
            thing_mapper.append(self._merge_class_names(local_things, thing_classes))
            stuff_mapper.append(self._merge_class_names(local_stuff, stuff_classes))

        # Build the per-dataset local -> global label mapping. Local stuff ids
        # follow the local thing ids; global stuff ids follow all global things.
        num_things = len(thing_classes)
        classes = [*thing_classes, *stuff_classes]
        mapper = []
        for local_thing_map, local_stuff_map in zip(thing_mapper, stuff_mapper):
            if not local_thing_map and not local_stuff_map:
                # class agnostic dataset: every local label becomes -1
                _mapper = {label: -1 for label in range(1000)}
            else:
                _mapper = dict(local_thing_map)
                local_offset = len(local_thing_map)
                for local_idx, stuff_idx in local_stuff_map.items():
                    assert stuff_idx < len(stuff_classes)
                    _mapper[local_idx + local_offset] = stuff_idx + num_things
                assert len(_mapper) == len(local_thing_map) + len(local_stuff_map)
            _mapper[NO_OBJ] = NO_OBJ
            mapper.append(_mapper)

        cls_name = '_'.join(cls_names)
        if len(cls_names) > 1:
            cls_name = "Concat_" + cls_name
        self.dataset_name = cls_name
        self._metainfo.update({
            'classes': classes,
            'thing_classes': thing_classes,
            'stuff_classes': stuff_classes,
            'mapper': mapper,
            'dataset_names': list(cls_names)
        })

        print_log(
            f"------------{self.dataset_name}------------",
            logger='current',
            level=logging.INFO
        )
        for idx, dataset in enumerate(self.datasets):
            dataset_type = cls_names[idx]
            times = dataset.times if isinstance(dataset, RepeatDataset) else 1
            # NOTE(review): len() on a lazily initialised sub-dataset may force
            # its full initialisation — confirm this is intended for lazy_init.
            print_log(
                f"|---dataset#{idx + 1} --> name: {dataset_type}; length: {len(dataset)}; repeat times: {times}",
                logger='current',
                level=logging.INFO
            )
        print_log(
            f"------num_things : {len(thing_classes)}; num_stuff : {len(stuff_classes)}------",
            logger='current',
            level=logging.INFO
        )

    @staticmethod
    def _get_dataset_name(dataset) -> str:
        """Return ``dataset_name`` of ``dataset``, falling back to its class name.

        ``RepeatDataset`` / ``ClassBalancedDataset`` wrappers are looked
        through, so the name of the wrapped dataset is reported.
        """
        if isinstance(dataset, (RepeatDataset, ClassBalancedDataset)):
            dataset = dataset.dataset
        return getattr(dataset, 'dataset_name', dataset.__class__.__name__)

    @staticmethod
    def _merge_class_names(local_classes, merged_classes: list) -> dict:
        """Map local class indices onto ``merged_classes``, appending new names.

        Each name is normalised: ``'_or_'`` and ``'/'`` become the synonym
        separator ``','``, remaining ``'_'`` become spaces, and the result is
        lower-cased. A name matches the first entry of ``merged_classes`` that
        shares at least one synonym with it. Unmatched names are appended to
        ``merged_classes``, which is modified in place.

        Returns:
            dict mapping each local index to its index in ``merged_classes``.
        """
        local_to_merged = {}
        for local_idx, name in enumerate(local_classes):
            name = name.replace('_or_', ',').replace('/', ',').replace('_', ' ').lower()
            synonyms = set(name.split(','))
            for merged_idx, merged_name in enumerate(merged_classes):
                if synonyms.intersection(merged_name.split(',')):
                    local_to_merged[local_idx] = merged_idx
                    break
            else:
                merged_classes.append(name)
                local_to_merged[local_idx] = len(merged_classes) - 1
        return local_to_merged

    def get_dataset_source(self, idx: int) -> int:
        """Return the index of the sub-dataset that global sample ``idx`` comes from."""
        dataset_idx, _ = self._get_ori_dataset_idx(idx)
        return dataset_idx

    @staticmethod
    def _remap_labels(data_sample, label_map) -> None:
        """Rewrite ``gt_sem_seg`` / ``gt_instances`` labels in place via ``label_map``."""
        if 'gt_sem_seg' in data_sample:
            data_sample.gt_sem_seg.sem_seg.apply_(label_map)
        if 'gt_instances' in data_sample:
            data_sample.gt_instances.labels.apply_(label_map)

    def __getitem__(self, idx):
        """Fetch sample ``idx`` with its ground-truth labels mapped to global ids.

        Labels that are missing from the sub-dataset's mapper raise
        ``KeyError``. When ``data_tag`` was given, the sub-dataset's tag is
        set on ``results['data_samples']``.
        """
        if not self._fully_initialized:
            print_log(
                'Please call `full_init` method manually to '
                'accelerate the speed.',
                logger='current',
                level=logging.WARNING)
            self.full_init()
        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        results = self.datasets[dataset_idx][sample_idx]
        # Read ``_metainfo`` directly: mmengine's ConcatDataset.metainfo
        # property returns a deep copy, which would copy every mapper and all
        # class names on each sample fetch.
        label_map = self._metainfo['mapper'][dataset_idx].__getitem__
        data_samples = results['data_samples']
        if isinstance(data_samples, TrackDataSample):
            for det_sample in data_samples:
                self._remap_labels(det_sample, label_map)
        else:
            self._remap_labels(data_samples, label_map)
        if self.data_tag is not None:
            data_samples.data_tag = self.data_tag[dataset_idx]
        return results