"""Dataset factory

Updated in 2021 by Wimlds for Detect Waste in Pomerania
"""
from collections import OrderedDict
from pathlib import Path

from .dataset_config import *
from .parsers import *
from .dataset import DetectionDatset
from .parsers import create_parser

# list of detect-waste datasets
waste_datasets_list = ['taco', 'detectwaste', 'binary', 'multi',
                       'uav', 'mju', 'trashcan', 'wade', 'icra',
                       'drinkwaste']


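def create_dataset(name, root, ann, splits=('train', 'val')):
    """Create a detection dataset for each requested split.

    Args:
        name: dataset name (case-insensitive); a COCO variant such as
            'coco' or 'coco2014', or one of ``waste_datasets_list``.
        root: dataset root directory.
        ann: annotation argument forwarded to the detect-waste dataset
            configs (not used for COCO).
        splits: a split name or a tuple of split names.

    Returns:
        A list of datasets (one per split) when more than one split is
        requested, otherwise the single dataset.
    """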
    if isinstance(splits, str):
        splits = (splits,)
    name = name.lower()
    root = Path(root)
    dataset_cls = DetectionDatset
    datasets = OrderedDict()
    if name.startswith('coco'):
        if 'coco2014' in name:
            dataset_cfg = Coco2014Cfg()
        else:
            dataset_cfg = Coco2017Cfg()
        for s in splits:
            if s not in dataset_cfg.splits:
                raise RuntimeError(f'{s} split not found in config')
            split_cfg = dataset_cfg.splits[s]
            ann_file = root / split_cfg['ann_filename']
            parser_cfg = CocoParserCfg(
                ann_filename=ann_file,
                has_labels=split_cfg['has_labels']
            )
            datasets[s] = dataset_cls(
                data_dir=root / Path(split_cfg['img_dir']),
                parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
            )
    elif name in waste_datasets_list:
        if name.startswith('taco'):
            dataset_cfg = TACOCfg(root=root, ann=ann)
        elif name.startswith('detectwaste'):
            dataset_cfg = DetectwasteCfg(root=root, ann=ann)
        elif name.startswith('binary'):
            dataset_cfg = BinaryCfg(root=root, ann=ann)
        elif name.startswith('multi'):
            dataset_cfg = BinaryMultiCfg(root=root, ann=ann)
        elif name.startswith('uav'):
            dataset_cfg = UAVVasteCfg(root=root, ann=ann)
        elif name.startswith('trashcan'):
            dataset_cfg = TrashCanCfg(root=root, ann=ann)
        elif name.startswith('drinkwaste'):
            dataset_cfg = DrinkWasteCfg(root=root, ann=ann)
        elif name.startswith('mju'):
            dataset_cfg = MJU_WasteCfg(root=root, ann=ann)
        elif name.startswith('wade'):
            dataset_cfg = WadeCfg(root=root, ann=ann)
        elif name.startswith('icra'):
            dataset_cfg = ICRACfg(root=root, ann=ann)
        else:
            raise RuntimeError(f'Unknown dataset parser ({name})')
        dataset_cfg.add_split()
        for s in splits:
            if s not in dataset_cfg.splits:
                raise RuntimeError(f'{s} split not found in config')
            split_cfg = dataset_cfg.splits[s]
            parser_cfg = CocoParserCfg(
                ann_filename=split_cfg['ann_filename'],
                has_labels=split_cfg['has_labels']
            )
            datasets[s] = dataset_cls(
                data_dir=split_cfg['img_dir'],
                parser=create_parser(dataset_cfg.parser, cfg=parser_cfg),
            )
    else:
        raise RuntimeError(f'Unknown dataset parser ({name})')

    datasets = list(datasets.values())
    return datasets if len(datasets) > 1 else datasets[0]