|
import math |
|
import os.path as osp |
|
|
|
import pytest |
|
from torch.utils.data import (DistributedSampler, RandomSampler, |
|
SequentialSampler) |
|
|
|
from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader, |
|
build_dataset) |
|
|
|
|
|
@DATASETS.register_module()
class ToyDataset(object):
    """Minimal map-style dataset stub for builder/dataloader tests.

    Args:
        cnt (int): Arbitrary marker used to verify that ``default_args``
            are forwarded by ``build_dataset``. Defaults to 0.
    """

    def __init__(self, cnt=0):
        self.cnt = cnt

    # Fixed: was misspelled ``__item__``, which is not part of the sequence
    # protocol — DataLoader iteration requires ``__getitem__``.
    def __getitem__(self, idx):
        # Identity mapping is sufficient for sampler/batching tests.
        return idx

    def __len__(self):
        return 100
|
|
|
|
|
def test_build_dataset():
    """Exercise build_dataset for plain, concatenated and invalid configs."""
    toy_cfg = dict(type='ToyDataset')
    toy = build_dataset(toy_cfg)
    assert isinstance(toy, ToyDataset)
    assert toy.cnt == 0
    # default_args must be forwarded to the dataset constructor.
    toy = build_dataset(toy_cfg, default_args=dict(cnt=1))
    assert isinstance(toy, ToyDataset)
    assert toy.cnt == 1

    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    img_dir = 'imgs/'
    ann_dir = 'gts/'

    def custom_cfg(**kwargs):
        # Shared boilerplate for every CustomDataset-based case below.
        base = dict(type='CustomDataset', pipeline=[], data_root=data_root)
        base.update(kwargs)
        return base

    # Lists of img/ann dirs produce a ConcatDataset spanning both copies.
    built = build_dataset(
        custom_cfg(img_dir=[img_dir, img_dir], ann_dir=[ann_dir, ann_dir]))
    assert isinstance(built, ConcatDataset)
    assert len(built) == 10

    # A list of split files also triggers concatenation.
    built = build_dataset(
        custom_cfg(
            img_dir=img_dir,
            ann_dir=ann_dir,
            split=['splits/train.txt', 'splits/val.txt']))
    assert isinstance(built, ConcatDataset)
    assert len(built) == 5

    # Single img_dir combined with lists of ann_dir and split.
    built = build_dataset(
        custom_cfg(
            img_dir=img_dir,
            ann_dir=[ann_dir, ann_dir],
            split=['splits/train.txt', 'splits/val.txt']))
    assert isinstance(built, ConcatDataset)
    assert len(built) == 5

    # test_mode without annotation dirs.
    built = build_dataset(
        custom_cfg(img_dir=[img_dir, img_dir], test_mode=True))
    assert isinstance(built, ConcatDataset)
    assert len(built) == 10

    # test_mode restricted by split files.
    built = build_dataset(
        custom_cfg(
            img_dir=[img_dir, img_dir],
            split=['splits/val.txt', 'splits/val.txt'],
            test_mode=True))
    assert isinstance(built, ConcatDataset)
    assert len(built) == 2

    # Mismatched img_dir/ann_dir list lengths must be rejected.
    with pytest.raises(AssertionError):
        build_dataset(
            custom_cfg(
                img_dir=[img_dir, img_dir],
                ann_dir=[ann_dir, ann_dir, ann_dir]))

    # Mismatched img_dir/split list lengths must be rejected.
    with pytest.raises(AssertionError):
        build_dataset(
            custom_cfg(
                img_dir=[img_dir, img_dir],
                split=['splits/val.txt', 'splits/val.txt',
                       'splits/val.txt']))

    # Mismatched ann_dir/split list lengths must be rejected.
    with pytest.raises(AssertionError):
        build_dataset(
            custom_cfg(
                img_dir=img_dir,
                ann_dir=[ann_dir, ann_dir],
                split=['splits/val.txt', 'splits/val.txt',
                       'splits/val.txt']))
|
|
|
|
|
def test_build_dataloader():
    """Verify build_dataloader wiring: batch size, loader length, sampler
    selection and worker count across dist/non-dist configurations."""
    dataset = ToyDataset()
    samples_per_gpu = 3
    # Batches a single-GPU loader should yield over the 100-item dataset.
    single_gpu_batches = int(math.ceil(len(dataset) / samples_per_gpu))

    # Distributed (default): shuffled DistributedSampler.
    dataloader = build_dataloader(
        dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == single_gpu_batches
    assert isinstance(dataloader.sampler, DistributedSampler)
    assert dataloader.sampler.shuffle

    # Distributed without shuffling.
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        shuffle=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == single_gpu_batches
    assert isinstance(dataloader.sampler, DistributedSampler)
    assert not dataloader.sampler.shuffle

    # Distributed multi-GPU: batch size and workers remain per-GPU values.
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        num_gpus=8)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == single_gpu_batches
    assert dataloader.num_workers == 2

    # Non-distributed, shuffled: RandomSampler.
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        dist=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == single_gpu_batches
    assert isinstance(dataloader.sampler, RandomSampler)
    assert dataloader.num_workers == 2

    # Non-distributed, unshuffled: SequentialSampler.
    # (Previously passed a hard-coded 3 here; use the shared variable so the
    # call stays in sync with the assertions below.)
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        shuffle=False,
        dist=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == single_gpu_batches
    assert isinstance(dataloader.sampler, SequentialSampler)
    assert dataloader.num_workers == 2

    # Non-distributed multi-GPU: batch size and workers scale by num_gpus.
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        num_gpus=8,
        dist=False)
    assert dataloader.batch_size == samples_per_gpu * 8
    assert len(dataloader) == int(
        math.ceil(len(dataset) / samples_per_gpu / 8))
    assert isinstance(dataloader.sampler, RandomSampler)
    assert dataloader.num_workers == 16
|
|