Spaces:
Sleeping
Sleeping
File size: 4,309 Bytes
500565b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import pickle
import sys
import unittest
from functools import partial
import torch
from iopath.common.file_io import LazyPath
from detectron2 import model_zoo
from detectron2.config import instantiate
from detectron2.data import (
DatasetFromList,
MapDataset,
ToIterableDataset,
build_batch_data_loader,
build_detection_test_loader,
build_detection_train_loader,
)
from detectron2.data.samplers import InferenceSampler, TrainingSampler
def _a_slow_func(x):
return "path/{}".format(x)
class TestDatasetFromList(unittest.TestCase):
    """Tests for ``DatasetFromList``."""

    # Failing for py3.6, likely due to pickle. Compare the whole version tuple:
    # checking only ``sys.version_info.minor`` would also skip on any future major
    # version whose minor number happens to be <= 6.
    @unittest.skipIf(sys.version_info < (3, 7), "Not supported in Python 3.6")
    def test_using_lazy_path(self):
        """``LazyPath`` entries stored in a ``DatasetFromList`` come back usable.

        Each item's ``file_name`` must still be a ``LazyPath`` and must resolve,
        via ``os.fspath``, to the string produced by its callable.
        """
        dataset = DatasetFromList(
            [{"file_name": LazyPath(partial(_a_slow_func, i))} for i in range(10)]
        )
        for i in range(10):
            path = dataset[i]["file_name"]
            # assertIsInstance reports the actual type on failure, unlike
            # assertTrue(isinstance(...)) which only reports "False is not true".
            self.assertIsInstance(path, LazyPath)
            self.assertEqual(os.fspath(path), _a_slow_func(i))
class TestMapDataset(unittest.TestCase):
    """Tests for ``MapDataset`` over map-style and iterable-style datasets."""

    @staticmethod
    def map_func(x):
        """Return ``x * 2``, or ``None`` when ``x`` is 2 (a rejected element)."""
        return None if x == 2 else x * 2

    def test_map_style(self):
        mapped = MapDataset(DatasetFromList([1, 2, 3]), TestMapDataset.map_func)
        self.assertEqual(mapped[0], 2)
        self.assertEqual(mapped[2], 6)
        # Index 1 holds the value 2, which map_func rejects with None; the test
        # only requires the result to be one of the other elements' mapped values.
        self.assertIn(mapped[1], [2, 6])

    def test_iter_style(self):
        class _Source(torch.utils.data.IterableDataset):
            def __iter__(self):
                yield from (1, 2, 3)

        mapped = MapDataset(_Source(), TestMapDataset.map_func)
        # Wrapping an iterable-style dataset must produce an iterable-style dataset.
        self.assertIsInstance(mapped, torch.utils.data.IterableDataset)
        # The None produced for the value 2 is expected to be dropped from the stream.
        self.assertEqual(list(mapped), [2, 6])

    def test_pickleability(self):
        original = MapDataset(DatasetFromList([1, 2, 3]), lambda x: x * 2)
        # The map function is a lambda, which plain pickle cannot serialize by
        # reference; this round-trip therefore checks that MapDataset handles it
        # (presumably via some picklable wrapper -- not visible here).
        restored = pickle.loads(pickle.dumps(original))
        self.assertEqual(restored[0], 2)
@unittest.skipIf(os.environ.get("CI"), "Skipped OSS testing due to COCO data requirement.")
class TestDataLoader(unittest.TestCase):
    """Tests for the detection train/test data loader builders."""

    def _get_kwargs(self):
        """Return instantiated keyword arguments for ``build_detection_train_loader``.

        Starts from the model-zoo ``common/data/coco.py`` training dataloader config,
        points its dataset at ``coco_2017_val_100``, removes the ``_target_`` entry,
        and instantiates every remaining value.
        """
        train_cfg = model_zoo.get_config("common/data/coco.py").dataloader.train
        train_cfg.dataset.names = "coco_2017_val_100"
        train_cfg.pop("_target_")
        return {key: instantiate(value) for key, value in train_cfg.items()}

    def test_build_dataloader_train(self):
        loader = build_detection_train_loader(**self._get_kwargs())
        # Pulling a single batch is enough to exercise the whole pipeline.
        next(iter(loader))

    def test_build_iterable_dataloader_train(self):
        kwargs = self._get_kwargs()
        # Swap the config's dataset for an iterable-style one driven by TrainingSampler.
        as_list = DatasetFromList(kwargs.pop("dataset"))
        as_iterable = ToIterableDataset(as_list, TrainingSampler(len(as_list)))
        loader = build_detection_train_loader(dataset=as_iterable, **kwargs)
        next(iter(loader))

    def _check_is_range(self, data_loader, N):
        """Assert that ``data_loader`` yields exactly the integers ``0 .. N-1``.

        Batches are flattened first. Order is ignored, but the length check
        together with the set check rules out both duplicates and missing items.
        """
        items = [item for batch in data_loader for item in batch]
        self.assertEqual(len(items), N)
        self.assertEqual(set(items), set(range(N)))

    def test_build_batch_dataloader_inference(self):
        # build_batch_data_loader is also usable for inference when paired with
        # an InferenceSampler.
        num_items = 96
        dataset = DatasetFromList(list(range(num_items)))
        loader = build_batch_data_loader(
            dataset, InferenceSampler(len(dataset)), 8, num_workers=3
        )
        self._check_is_range(loader, num_items)

    def test_build_dataloader_inference(self):
        num_items = 50
        dataset = DatasetFromList(list(range(num_items)))
        loader = build_detection_test_loader(
            dataset=dataset,
            sampler=InferenceSampler(len(dataset)),
            mapper=lambda x: x,
            num_workers=3,
        )
        self._check_is_range(loader, num_items)

    def test_build_iterable_dataloader_inference(self):
        # build_detection_test_loader must also accept an iterable-style dataset.
        num_items = 50
        base = DatasetFromList(list(range(num_items)))
        dataset = ToIterableDataset(base, InferenceSampler(len(base)))
        loader = build_detection_test_loader(dataset=dataset, mapper=lambda x: x, num_workers=3)
        self._check_is_range(loader, num_items)
|