""" Dataset factory
Updated 2021 Wimlds in Detect Waste in Pomerania
"""
from collections import OrderedDict
from pathlib import Path
from .dataset_config import *
from .parsers import *
from .dataset import DetectionDatset
from .parsers import create_parser
# List of supported detect-waste dataset names.
# NOTE: the original literal was missing a comma after 'icra', so Python's
# implicit string concatenation fused it with 'drinkwaste' into a single
# bogus entry 'icradrinkwaste', making both datasets unselectable.
waste_datasets_list = ['taco', 'detectwaste', 'binary', 'multi',
                       'uav', 'mju', 'trashcan', 'wade', 'icra',
                       'drinkwaste']
def create_dataset(name, root, ann, splits=('train', 'val')):
    """Build detection dataset object(s) for the requested splits.

    Args:
        name: Dataset identifier (case-insensitive), e.g. 'coco2017',
            'taco', 'binary', 'uav', ... Must start with 'coco' or be a
            member of ``waste_datasets_list``.
        root: Root directory containing images / annotation files.
        ann: Annotation path/prefix forwarded to the waste dataset configs.
        splits: A split name or an iterable of split names to load.

    Returns:
        A single dataset when exactly one split is requested, otherwise a
        list of datasets in the order of ``splits``.

    Raises:
        RuntimeError: If a requested split is missing from the dataset config.
        AssertionError: If ``name`` does not match any known dataset.
    """
    if isinstance(splits, str):
        splits = (splits,)
    name = name.lower()
    root = Path(root)
    dataset_cls = DetectionDatset
    datasets = OrderedDict()
    if name.startswith('coco'):
        if 'coco2014' in name:
            dataset_cfg = Coco2014Cfg()
        else:
            dataset_cfg = Coco2017Cfg()
        for s in splits:
            if s not in dataset_cfg.splits:
                raise RuntimeError(f'{s} split not found in config')
            split_cfg = dataset_cfg.splits[s]
            # COCO configs store annotation filenames relative to root.
            ann_file = root / split_cfg['ann_filename']
            parser_cfg = CocoParserCfg(
                ann_filename=ann_file,
                has_labels=split_cfg['has_labels']
            )
            datasets[s] = dataset_cls(
                data_dir=root / Path(split_cfg['img_dir']),
                parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
            )
        # BUGFIX: the original code reset `datasets = OrderedDict()` here,
        # discarding every COCO dataset just built and crashing with an
        # IndexError below when a single split was requested.
    elif name in waste_datasets_list:
        if name.startswith('taco'):
            dataset_cfg = TACOCfg(root=root, ann=ann)
        elif name.startswith('detectwaste'):
            dataset_cfg = DetectwasteCfg(root=root, ann=ann)
        elif name.startswith('binary'):
            dataset_cfg = BinaryCfg(root=root, ann=ann)
        elif name.startswith('multi'):
            dataset_cfg = BinaryMultiCfg(root=root, ann=ann)
        elif name.startswith('uav'):
            dataset_cfg = UAVVasteCfg(root=root, ann=ann)
        elif name.startswith('trashcan'):
            dataset_cfg = TrashCanCfg(root=root, ann=ann)
        elif name.startswith('drinkwaste'):
            dataset_cfg = DrinkWasteCfg(root=root, ann=ann)
        elif name.startswith('mju'):
            dataset_cfg = MJU_WasteCfg(root=root, ann=ann)
        elif name.startswith('wade'):
            dataset_cfg = WadeCfg(root=root, ann=ann)
        elif name.startswith('icra'):
            dataset_cfg = ICRACfg(root=root, ann=ann)
        else:
            # Unreachable while every waste_datasets_list entry is covered
            # by a prefix branch above.
            assert False, f'Unknown dataset parser ({name})'
        # Waste configs compute their split entries lazily.
        dataset_cfg.add_split()
        for s in splits:
            if s not in dataset_cfg.splits:
                raise RuntimeError(f'{s} split not found in config')
            split_cfg = dataset_cfg.splits[s]
            # Waste configs store absolute/pre-joined annotation paths,
            # unlike the COCO branch which joins with `root`.
            parser_cfg = CocoParserCfg(
                ann_filename=split_cfg['ann_filename'],
                has_labels=split_cfg['has_labels']
            )
            datasets[s] = dataset_cls(
                data_dir=split_cfg['img_dir'],
                parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
            )
    else:
        assert False, f'Unknown dataset parser ({name})'
    datasets = list(datasets.values())
    return datasets if len(datasets) > 1 else datasets[0]
|