|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
import os |
|
import unittest |
|
import unittest.mock |
|
|
|
import torch |
|
from omegaconf import OmegaConf |
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( |
|
SequenceDataLoaderMapProvider, |
|
SimpleDataLoaderMapProvider, |
|
) |
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource |
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset |
|
from pytorch3d.implicitron.tools.config import get_default_args |
|
from tests.common_testing import get_tests_dir |
|
from tests.implicitron.common_resources import get_skateboard_data |
|
|
|
# Folder with the reference fixtures compared against by these tests
# (test_one reads "data_source.yaml" from here).
DATA_DIR = get_tests_dir() / "implicitron" / "data"

# When True, test_one rewrites the reference YAML from the current default
# config before comparing against it.
DEBUG: bool = False
|
|
|
|
|
class TestDataSource(unittest.TestCase):
    """Tests for configuring ``ImplicitronDataSource`` and its data loaders."""

    def setUp(self):
        """Seed torch and obtain the skateboard test data for each test.

        ``get_skateboard_data()`` is entered as a context manager; its two
        results are stored as ``self.dataset_root`` and ``self.path_manager``,
        and the context is closed during test cleanup.
        """
        self.maxDiff = None
        torch.manual_seed(42)

        stack = contextlib.ExitStack()
        # Register the cleanup before entering the context so it is closed
        # even if something after ``enter_context`` (e.g. the tuple
        # unpacking) fails; unittest runs cleanups even when setUp raises.
        self.addCleanup(stack.close)
        self.dataset_root, self.path_manager = stack.enter_context(
            get_skateboard_data()
        )

    def _test_omegaconf_generic_failure(self):
        """Disabled check (no ``test`` prefix, so unittest does not collect it).

        Calls ``OmegaConf.structured`` on a dataclass that subclasses the
        parameterized generic ``torch.utils.data.Dataset[int]``.
        """
        from dataclasses import dataclass

        @dataclass
        class D(torch.utils.data.Dataset[int]):
            a: int = 3

        OmegaConf.structured(D)

    def _test_omegaconf_ListList(self):
        """Disabled check (no ``test`` prefix, so unittest does not collect it).

        Calls ``OmegaConf.structured`` on a dataclass whose field is typed as
        a nested ``Sequence[Sequence[int]]``.
        """
        from dataclasses import dataclass
        from typing import Sequence

        @dataclass
        class A:
            a: Sequence[Sequence[int]] = ((32,),)

        OmegaConf.structured(A)

    def test_JsonIndexDataset_args(self):
        """``get_default_args`` can build a config for ``JsonIndexDataset``."""
        get_default_args(JsonIndexDataset)

    def test_one(self):
        """The default ``ImplicitronDataSource`` config matches data_source.yaml.

        Both JSON dataset roots are blanked and the SQL provider section (if
        present) is removed before serializing. When ``DEBUG`` is set, the
        reference file is first overwritten with the freshly serialized YAML.
        """
        cfg = get_default_args(ImplicitronDataSource)
        cfg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
        cfg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""

        # The SQL provider section is only present in some configurations.
        if "dataset_map_provider_SqlIndexDatasetMapProvider_args" in cfg:
            del cfg.dataset_map_provider_SqlIndexDatasetMapProvider_args

        yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
        reference = DATA_DIR / "data_source.yaml"
        if DEBUG:
            reference.write_text(yaml)
        self.assertEqual(yaml, reference.read_text())

    def _skip_inside_re_worker(self):
        """Skip the current test when the ``INSIDE_RE_WORKER`` env var is set.

        Reports the test as skipped instead of silently passing it.
        NOTE(review): the reason these tests cannot run there is not visible
        in this file.
        """
        if os.environ.get("INSIDE_RE_WORKER") is not None:
            self.skipTest("INSIDE_RE_WORKER is set")

    def _json_index_data_source_args(self):
        """Return default data-source args wired to JsonIndexDatasetMapProvider.

        The provider is configured for the "skateboard" category rooted at
        ``self.dataset_root``, with test sequence 0 and
        ``n_frames_per_sequence = -1``.
        """
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
        dataset_args.category = "skateboard"
        dataset_args.test_restrict_sequence_id = 0
        dataset_args.n_frames_per_sequence = -1
        dataset_args.dataset_root = self.dataset_root
        return args

    def _check_train_loader(self, data_source):
        """Check the train loader has 81 batches and inspect its first batch.

        The first batch must have ``frame_type == ["test_known"]``.
        """
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 81)
        # The length check above guarantees the loader is non-empty.
        first_batch = next(iter(data_loaders.train))
        self.assertEqual(first_batch.frame_type, ["test_known"])

    def test_default(self):
        """Default loader settings select SequenceDataLoaderMapProvider."""
        self._skip_inside_re_worker()
        args = self._json_index_data_source_args()

        data_source = ImplicitronDataSource(**args)
        self.assertIsInstance(
            data_source.data_loader_map_provider, SequenceDataLoaderMapProvider
        )
        self._check_train_loader(data_source)

    def test_simple(self):
        """Explicitly selecting SimpleDataLoaderMapProvider is honored."""
        self._skip_inside_re_worker()
        args = self._json_index_data_source_args()
        args.data_loader_map_provider_class_type = "SimpleDataLoaderMapProvider"

        data_source = ImplicitronDataSource(**args)
        self.assertIsInstance(
            data_source.data_loader_map_provider, SimpleDataLoaderMapProvider
        )
        self._check_train_loader(data_source)
|
|