|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import unittest |
|
from collections import Counter |
|
|
|
import pkg_resources |
|
|
|
import torch |
|
|
|
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset |
|
|
|
# Frame-data-builder settings that disable loading of all blob data (images,
# depths, masks) and box cropping, so the tests exercise only the sqlite
# metadata and never touch image files on disk.
NO_BLOBS_KWARGS = dict(
    dataset_root="",
    load_images=False,
    load_depths=False,
    load_masks=False,
    load_depth_masks=False,
    box_crop=False,
)
|
|
|
# Surface verbose output from the dataset module under test: attach a
# StreamHandler (stderr by default) and lower the level to DEBUG so that
# the module's internal logging is visible when the tests run.
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)
|
|
|
|
|
# Locations of the test fixtures: a 100-frame sqlite metadata file and the
# matching subset lists. `pkg_resources.resource_filename` is deprecated
# (setuptools >= 67); resolving relative to this file is equivalent for any
# normal (non-zipped) installation of the test package.
DATASET_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
SET_LIST_FILE = os.path.join(DATASET_ROOT, "set_lists_100.json")
|
|
|
|
|
class TestSqlDataset(unittest.TestCase):
    """Tests for SqlIndexDataset over a small fixture database.

    The fixture (sql_dataset_100.sqlite) contains 100 frames: 2 categories
    ("cat0", "cat1") with 5 sequences each and 10 frames per sequence; 22 of
    the frames have empty masks. All tests pass NO_BLOBS_KWARGS so only the
    metadata is exercised — no blob data is loaded from disk.
    """

    def _check_frames_in_sequence_order(self, dataset, starts_sequence, step=1):
        """Scan the dataset by integer index, checking frames are grouped.

        Verifies that each sequence appears as exactly one contiguous run and
        that frame numbers within a run increase by `step`.

        Args:
            dataset: the dataset to scan with integer indices.
            starts_sequence: predicate on frame_number, True iff the frame is
                expected to be the first of its sequence.
            step: expected frame_number increment between consecutive frames
                of the same sequence.

        Returns:
            The name of the last sequence visited.
        """
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]

            if starts_sequence(item.frame_number):
                # each sequence must appear exactly once, as a contiguous run
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + step)

            last_frame_number = item.frame_number

        return last_sequence

    def test_basic(self, sequence="cat1_seq2", frame_number=4):
        """Length, iteration order, and (sequence, frame) indexing."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # frames come out grouped by sequence with frame numbers stepping by 1
        self._check_frames_in_sequence_order(dataset, lambda fn: fn == 0)

        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # indexing by a (sequence_name, frame_number) pair
        item = dataset[sequence, frame_number]
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]  # sequences only contain frames 0..9

    def test_filter_empty_masks(self):
        """remove_empty_masks drops the 22 fixture frames with empty masks."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 78)

    def test_pick_frames_sql_clause(self):
        """A raw SQL pick clause filters frames by arbitrary columns."""
        dataset_no_empty_masks = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # this clause replicates the remove_empty_masks condition exactly
        self.assertEqual(len(dataset), len(dataset_no_empty_masks))
        for i in range(len(dataset)):
            item_nem = dataset_no_empty_masks[i]
            item = dataset[i]
            self.assertEqual(item_nem.image_path, item.image_path)

        # the clause composes with remove_empty_masks
        dataset_ts = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            pick_frames_sql_clause="frame_timestamp < 0.15",
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )
        self.assertEqual(len(dataset_ts), 19)

    def test_limit_categories(self, category="cat0"):
        """pick_categories keeps only frames of the listed categories."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            pick_categories=[category],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # one category == 5 sequences x 10 frames
        self.assertEqual(len(dataset), 50)
        for i in range(len(dataset)):
            self.assertEqual(dataset[i].sequence_category, category)

    def test_limit_sequences(self, num_sequences=3):
        """limit_sequences_to caps the number of sequences loaded."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_sequences_to=num_sequences,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 10 * num_sequences)

        def delist(sequence_name):
            # sequence_name may come back wrapped in a one-element list
            return sequence_name if isinstance(sequence_name, str) else sequence_name[0]

        unique_seqs = {delist(dataset[i].sequence_name) for i in range(len(dataset))}
        self.assertEqual(len(unique_seqs), num_sequences)

    def test_pick_exclude_sequences(self, sequence="cat1_seq2"):
        """pick_sequences / exclude_sequences select or drop whole sequences.

        (Renamed from `test_pick_exclude_sequencess` — typo fix.)
        """
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            pick_sequences=[sequence],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 10)
        unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
        self.assertCountEqual(unique_seqs, {sequence})

        item = dataset[sequence, 0]
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, 0)

        # excluding the same single sequence keeps the remaining 90 frames
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            exclude_sequences=[sequence],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 90)
        unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
        self.assertNotIn(sequence, unique_seqs)

        with self.assertRaises(IndexError):
            dataset[sequence, 0]

    def test_limit_frames(self, num_frames=13):
        """limit_to truncates the dataset to the first num_frames frames."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_to=num_frames,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), num_frames)
        unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
        # 13 frames cover the first sequence (10 frames) plus part of a second
        self.assertEqual(len(unique_seqs), 2)

        # a limit larger than the dataset is a no-op
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_to=1000,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

    def test_limit_frames_per_sequence(self, num_frames=2):
        """n_frames_per_sequence caps the frames kept in every sequence."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            n_frames_per_sequence=num_frames,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), num_frames * 10)
        seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
        self.assertEqual(len(seq_counts), 10)
        # every sequence is cut to exactly num_frames frames
        # (was hard-coded to {2}; tied to the parameter for consistency)
        self.assertEqual(set(seq_counts.values()), {num_frames})

        with self.assertRaises(IndexError):
            dataset[next(iter(seq_counts)), num_frames + 1]

        # a per-sequence limit above the sequence length is a no-op
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            n_frames_per_sequence=13,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )
        self.assertEqual(len(dataset), 100)

    def test_limit_sequence_per_category(self, num_sequences=2):
        """limit_sequences_per_category_to caps sequences in each category."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_sequences_per_category_to=num_sequences,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # num_sequences per category x 10 frames x 2 categories
        self.assertEqual(len(dataset), num_sequences * 10 * 2)
        seq_names = list(dataset.sequence_names())
        self.assertEqual(len(seq_names), num_sequences * 2)

        # fixture sequence names end in their index within the category
        for seq_name in seq_names:
            self.assertLess(int(seq_name[-1]), num_sequences)

        # a limit above the number of sequences per category is a no-op
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_sequences_per_category_to=13,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )
        self.assertEqual(len(dataset), 100)

    def test_filter_medley(self):
        """All frame/sequence filters applied together compose correctly."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            pick_categories=["cat1"],
            exclude_sequences=["cat1_seq0"],
            limit_sequences_to=2,
            limit_to=14,
            n_frames_per_sequence=6,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # expected per-sequence counts are pinned against the fixture data
        seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
        self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
        self.assertEqual(seq_counts["cat1_seq1"], 6)
        self.assertEqual(seq_counts["cat1_seq2"], 4)

    def test_subsets_trivial(self):
        """Loading both subsets covers the full dataset in sequence order."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            limit_to=100,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # the ordering invariants must also hold with subset lists applied
        self._check_frames_in_sequence_order(dataset, lambda fn: fn == 0)

    def test_subsets_filter_empty_masks(self):
        """remove_empty_masks works together with subset lists."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 78)

    def test_subsets_pick_frames_sql_clause(self):
        """The SQL pick clause works together with subset lists."""
        dataset_no_empty_masks = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # this clause replicates the remove_empty_masks condition exactly
        self.assertEqual(len(dataset), len(dataset_no_empty_masks))
        for i in range(len(dataset)):
            item_nem = dataset_no_empty_masks[i]
            item = dataset[i]
            self.assertEqual(item_nem.image_path, item.image_path)

        # the clause composes with remove_empty_masks and subsets
        dataset_ts = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            pick_frames_sql_clause="frame_timestamp < 0.15",
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset_ts), 19)

    def test_single_subset(self):
        """Loading only "train" keeps half the frames of every sequence."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 50)

        with self.assertRaises(IndexError):
            dataset[51]

        # within a sequence, train frames step by 2 (the alternating frames
        # belong to the other subset)
        last_sequence = self._check_frames_in_sequence_order(
            dataset, lambda fn: fn < 2, step=2
        )

        item = dataset[last_sequence, 0]
        self.assertEqual(item.sequence_name, last_sequence)

        with self.assertRaises(IndexError):
            dataset[last_sequence, 1]  # frame 1 is not in the train subset

    def test_subset_with_filters(self):
        """All filters compose with subset lists."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train"],
            pick_categories=["cat1"],
            exclude_sequences=["cat1_seq0"],
            limit_sequences_to=2,
            limit_to=7,
            n_frames_per_sequence=3,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # expected per-sequence counts are pinned against the fixture data
        seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
        self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
        self.assertEqual(seq_counts["cat1_seq1"], 3)
        self.assertEqual(seq_counts["cat1_seq2"], 2)

    def test_visitor(self):
        """sequence_frames_in_order yields timestamp-sorted frames with global indices."""
        dataset_sorted = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        sequences = dataset_sorted.sequence_names()
        i = 0
        for seq in sequences:
            last_ts = float("-Inf")
            for ts, _, idx in dataset_sorted.sequence_frames_in_order(seq):
                # global index runs consecutively across sequences;
                # timestamps are non-decreasing within a sequence
                self.assertEqual(i, idx)
                i += 1
                self.assertGreaterEqual(ts, last_ts)
                last_ts = ts

        # get_frame_numbers_and_timestamps operates on raw index ranges
        old_indices = None
        for seq in sequences:
            # NOTE(review): assumes get_loc on the sequence-name index returns
            # a slice of contiguous rows for the sequence — confirm in pandas
            rows = dataset_sorted._index.index.get_loc(seq)
            indices = list(range(rows.start or 0, rows.stop, rows.step or 1))
            fn_ts_list = dataset_sorted.get_frame_numbers_and_timestamps(indices)
            self.assertEqual(len(fn_ts_list), len(indices))

            if old_indices:
                # index lists spanning multiple sequences are rejected
                with self.assertRaises(ValueError):
                    dataset_sorted.get_frame_numbers_and_timestamps(
                        indices + old_indices
                    )

            old_indices = indices

    def test_visitor_subsets(self):
        """sequence_frames_in_order supports restricting to a named subset."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_to=100,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        sequences = dataset.sequence_names()
        i = 0
        for seq in sequences:
            last_ts = float("-Inf")
            seq_frames = list(dataset.sequence_frames_in_order(seq))
            self.assertEqual(len(seq_frames), 10)
            for ts, _, idx in seq_frames:
                self.assertEqual(i, idx)
                i += 1
                self.assertGreaterEqual(ts, last_ts)
                last_ts = ts

            # restricting to "train" yields half the frames, still sorted
            last_ts = float("-Inf")
            train_frames = list(dataset.sequence_frames_in_order(seq, "train"))
            self.assertEqual(len(train_frames), 5)
            for ts, _, _ in train_frames:
                self.assertGreaterEqual(ts, last_ts)
                last_ts = ts

    def test_category_to_sequence_names(self):
        """category_to_sequence_names maps categories to their sequences."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        cat_to_seqs = dataset.category_to_sequence_names()
        self.assertEqual(len(cat_to_seqs), 2)
        self.assertIn("cat1", cat_to_seqs)
        self.assertEqual(len(cat_to_seqs["cat1"]), 5)

        # the SQL-backed override must agree with the base-class implementation
        cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
        self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)

    def test_category_to_sequence_names_filters(self):
        """category_to_sequence_names respects sequence/mask filters."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            subset_lists_file=SET_LIST_FILE,
            exclude_sequences=["cat1_seq0"],
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        cat_to_seqs = dataset.category_to_sequence_names()
        self.assertEqual(len(cat_to_seqs), 2)
        self.assertIn("cat1", cat_to_seqs)
        self.assertEqual(len(cat_to_seqs["cat1"]), 4)  # cat1_seq0 excluded

        # the SQL-backed override must agree with the base-class implementation
        cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
        self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)

    def test_meta_access(self):
        """dataset.meta[idx] returns the same metadata as full item loading."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 50)

        # both integer and (sequence, frame) indexing are supported by meta
        for idx in [10, ("cat0_seq2", 2)]:
            example_meta = dataset.meta[idx]
            example = dataset[idx]
            self.assertEqual(example_meta.sequence_name, example.sequence_name)
            self.assertEqual(example_meta.frame_number, example.frame_number)
            self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
            self.assertEqual(example_meta.sequence_category, example.sequence_category)
            torch.testing.assert_close(example_meta.camera.R, example.camera.R)
            torch.testing.assert_close(example_meta.camera.T, example.camera.T)
            torch.testing.assert_close(
                example_meta.camera.focal_length, example.camera.focal_length
            )
            torch.testing.assert_close(
                example_meta.camera.principal_point, example.camera.principal_point
            )

    def test_meta_access_no_blobs(self):
        """meta access never loads blob data, even with a real dataset root."""
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train"],
            frame_data_builder_FrameDataBuilder_args={
                "dataset_root": ".",
                "box_crop": False,
            },
        )

        # blobs stay None; camera metadata is still populated
        self.assertIsNone(dataset.meta[0].image_rgb)
        self.assertIsNone(dataset.meta[0].fg_probability)
        self.assertIsNone(dataset.meta[0].depth_map)
        self.assertIsNone(dataset.meta[0].sequence_point_cloud)
        self.assertIsNotNone(dataset.meta[0].camera)
|
|