# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import unittest

import torch

from pytorch3d.implicitron.dataset.data_loader_map_provider import (  # noqa
    SequenceDataLoaderMapProvider,
    SimpleDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset  # noqa
from pytorch3d.implicitron.dataset.sql_dataset_provider import (  # noqa
    SqlIndexDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import (
    TrainEvalDataLoaderMapProvider,
)
from pytorch3d.implicitron.tools.config import get_default_args

logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)

_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
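
# NOTE: these tests read a real CO3D dataset from disk, so the hard-coded
# loader lengths asserted below are specific to that data. The dataset
# location is picked up from the CO3D_SQL_DATASET_ROOT environment variable
# read above (presumably consumed by the provider defaults when the tests run
# in an environment where the dataset is available).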


class TestCo3dSqlDataSource(unittest.TestCase):
    def test_no_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.ignore_subsets = True
        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.pick_categories = ["skateboard"]
        dataset_args.limit_sequences_to = 1
        data_source = ImplicitronDataSource(**args)
        self.assertIsInstance(
            data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider
        )
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 202)
        for frame in data_loaders.train:
            self.assertIsNone(frame.frame_type)
            self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
            break

    def test_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = (
            "skateboard/set_lists/set_lists_manyview_dev_0.json"
        )
        # this will naturally limit to one sequence (no need to limit by cat/sequence)
        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        for sampler_type in [
            "SimpleDataLoaderMapProvider",
            "SequenceDataLoaderMapProvider",
            "TrainEvalDataLoaderMapProvider",
        ]:
            args.data_loader_map_provider_class_type = sampler_type
            data_source = ImplicitronDataSource(**args)
            _, data_loaders = data_source.get_datasets_and_dataloaders()
            self.assertEqual(len(data_loaders.train), 102)
            self.assertEqual(len(data_loaders.val), 100)
            self.assertEqual(len(data_loaders.test), 100)
            for split in ["train", "val", "test"]:
                for frame in data_loaders[split]:
                    self.assertEqual(frame.frame_type, [split])
                    # check loading blobs
                    self.assertEqual(frame.image_rgb.shape[-1], 800)
                    break

    def test_sql_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]
        for sampler_type in [
            "SimpleDataLoaderMapProvider",
            "SequenceDataLoaderMapProvider",
            "TrainEvalDataLoaderMapProvider",
        ]:
            args.data_loader_map_provider_class_type = sampler_type
            data_source = ImplicitronDataSource(**args)
            _, data_loaders = data_source.get_datasets_and_dataloaders()
            self.assertEqual(len(data_loaders.train), 102)
            self.assertEqual(len(data_loaders.val), 100)
            self.assertEqual(len(data_loaders.test), 100)
            for split in ["train", "val", "test"]:
                for frame in data_loaders[split]:
                    self.assertEqual(frame.frame_type, [split])
                    self.assertEqual(
                        frame.image_rgb.shape[-1], 800
                    )  # check loading blobs
                    break

    def test_huge_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite"
        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        data_source = ImplicitronDataSource(**args)
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 3158974)
        self.assertEqual(len(data_loaders.val), 518417)
        self.assertEqual(len(data_loaders.test), 518417)
        for split in ["train", "val", "test"]:
            for frame in data_loaders[split]:
                self.assertEqual(frame.frame_type, [split])
                self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
                break

    def test_broken_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "et_non_est"
        provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"]
        with self.assertRaises(FileNotFoundError) as err:
            ImplicitronDataSource(**args)
        # check the hint text
        self.assertIn("Subset lists path given but not found", str(err.exception))

    def test_eval_batches(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
        provider_args.eval_batches_path = (
            "skateboard/eval_batches/eval_batches_manyview_dev_0.json"
        )
        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]
        data_source = ImplicitronDataSource(**args)
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 102)
        self.assertEqual(len(data_loaders.val), 100)
        self.assertEqual(len(data_loaders.test), 50)
        for split in ["train", "val", "test"]:
            for frame in data_loaders[split]:
                self.assertEqual(frame.frame_type, [split])
                self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
                break

    def test_eval_batches_from_subset_list_name(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_list_name = "manyview_dev_0"
        provider_args.category = "skateboard"
        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        data_source = ImplicitronDataSource(**args)
        dataset, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"])
        self.assertEqual(len(data_loaders.train), 102)
        self.assertEqual(len(data_loaders.val), 100)
        self.assertEqual(len(data_loaders.test), 50)
        for split in ["train", "val", "test"]:
            for frame in data_loaders[split]:
                self.assertEqual(frame.frame_type, [split])
                self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
                break

    def test_frame_access(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]
        frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args
        frame_builder_args.load_point_clouds = True
        frame_builder_args.box_crop = False  # required for .meta
        data_source = ImplicitronDataSource(**args)
        dataset_map, _ = data_source.get_datasets_and_dataloaders()
        dataset = dataset_map["train"]
        for idx in [10, ("245_26182_52130", 22)]:
            example_meta = dataset.meta[idx]
            example = dataset[idx]
            self.assertIsNone(example_meta.image_rgb)
            self.assertIsNone(example_meta.fg_probability)
            self.assertIsNone(example_meta.depth_map)
            self.assertIsNone(example_meta.sequence_point_cloud)
            self.assertIsNotNone(example_meta.camera)
            self.assertIsNotNone(example.image_rgb)
            self.assertIsNotNone(example.fg_probability)
            self.assertIsNotNone(example.depth_map)
            self.assertIsNotNone(example.sequence_point_cloud)
            self.assertIsNotNone(example.camera)
            self.assertEqual(example_meta.sequence_name, example.sequence_name)
            self.assertEqual(example_meta.frame_number, example.frame_number)
            self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
            self.assertEqual(example_meta.sequence_category, example.sequence_category)
            torch.testing.assert_close(example_meta.camera.R, example.camera.R)
            torch.testing.assert_close(example_meta.camera.T, example.camera.T)
            torch.testing.assert_close(
                example_meta.camera.focal_length, example.camera.focal_length
            )
            torch.testing.assert_close(
                example_meta.camera.principal_point, example.camera.principal_point
            )
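

# A minimal entry point so the module can also be run directly with
# `python <this_file>.py`; an assumption for convenience, since the suite is
# normally discovered and run by a test runner such as `python -m unittest`.
if __name__ == "__main__":
    unittest.main()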