# unik3d/datasets/base_dataset.py
import os
from abc import abstractmethod
from copy import deepcopy
from math import ceil, log
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

import unik3d.datasets.pipelines as pipelines
from unik3d.utils import (eval_3d, eval_depth, identity, is_main_process,
                          recursive_index, sync_tensor_across_gpus)
from unik3d.utils.constants import (IMAGENET_DATASET_MEAN,
                                    IMAGENET_DATASET_STD, OPENAI_DATASET_MEAN,
                                    OPENAI_DATASET_STD)


class BaseDataset(Dataset):
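    """Common base class for UniK3D datasets.

    It owns the normalization statistics, the train-time augmentation and
    crop/resize pipelines, (un)packing of per-sequence fields into batch
    tensors, and the accumulation of depth/3D evaluation metrics across GPUs.
    Dataset-specific I/O lives in the abstract methods at the end of the class.
    """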
min_depth = 0.01
max_depth = 1000.0

    def __init__(
self,
image_shape: Tuple[int, int],
split_file: str,
test_mode: bool,
        normalize: Optional[str],
augmentations_db: Dict[str, Any],
shape_constraints: Dict[str, Any],
resize_method: str,
mini: float,
num_copies: int = 1,
**kwargs,
) -> None:
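        """
        Args:
            image_shape: Target (height, width), rounded up below to the
                nearest multiple of ``shape_constraints["shape_mult"]``.
            split_file: File listing the samples of this split.
            test_mode: If True, train-time augmentations are skipped and the
                resizer keeps the original image around for evaluation.
            normalize: One of None, "imagenet" or "openai"; selects the image
                normalization statistics.
            augmentations_db: Augmentation hyperparameters, set as attributes.
            shape_constraints: Shape rules consumed by the resizing pipeline.
            resize_method: Resize strategy identifier.
            mini: Subsampling fraction of the split; not used directly here.
            num_copies: Number of replicas per sample (see ``replicate``).
        """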
super().__init__()
assert normalize in [None, "imagenet", "openai"]
self.split_file = split_file
self.test_mode = test_mode
self.data_root = os.environ["DATAROOT"]
self.image_shape = image_shape
self.resize_method = resize_method
self.mini = mini
self.num_frames = 1
self.num_copies = num_copies
self.metrics_store = {}
self.metrics_count = {}
if normalize == "imagenet":
self.normalization_stats = {
"mean": torch.tensor(IMAGENET_DATASET_MEAN),
"std": torch.tensor(IMAGENET_DATASET_STD),
}
elif normalize == "openai":
self.normalization_stats = {
"mean": torch.tensor(OPENAI_DATASET_MEAN),
"std": torch.tensor(OPENAI_DATASET_STD),
}
else:
self.normalization_stats = {
"mean": torch.tensor([0.0, 0.0, 0.0]),
"std": torch.tensor([1.0, 1.0, 1.0]),
}
for k, v in augmentations_db.items():
setattr(self, k, v)
self.shape_constraints = shape_constraints
if not self.test_mode:
self._augmentation_space()
self.masker = pipelines.AnnotationMask(
min_value=0.0,
max_value=self.max_depth if test_mode else None,
custom_fn=identity,
)
self.filler = pipelines.RandomFiller(test_mode=test_mode)
shape_mult = self.shape_constraints["shape_mult"]
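        # Round the target shape up to the nearest multiple of shape_mult so
        # the resized images satisfy the pipeline's shape constraints.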
self.image_shape = [
ceil(self.image_shape[0] / shape_mult) * shape_mult,
ceil(self.image_shape[1] / shape_mult) * shape_mult,
]
self.resizer = pipelines.ContextCrop(
image_shape=self.image_shape,
train_ctx_range=(1.0 / self.random_scale, 1.0 * self.random_scale),
test_min_ctx=self.test_context,
keep_original=test_mode,
shape_constraints=self.shape_constraints,
)
self.collecter = pipelines.Collect(
keys=["image_fields", "mask_fields", "gt_fields", "camera_fields"]
)

    def __len__(self):
return len(self.dataset)

    def pack_batch(self, results):
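        """Concatenate every image/gt/mask/camera field across the sequence
        entries into flat batch tensors, and gather their paddings."""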
results["paddings"] = [
results[x]["paddings"][0] for x in results["sequence_fields"]
]
for fields_name in [
"image_fields",
"gt_fields",
"mask_fields",
"camera_fields",
]:
fields = results.get(fields_name)
packed = {
field: torch.cat(
[results[seq][field] for seq in results["sequence_fields"]]
)
for field in fields
}
results.update(packed)
return results

    def unpack_batch(self, results):
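        """Inverse of ``pack_batch``: split each flat batch tensor back into
        one slice per entry of ``sequence_fields``."""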
for fields_name in [
"image_fields",
"gt_fields",
"mask_fields",
"camera_fields",
]:
fields = results.get(fields_name)
unpacked = {
field: {
seq: results[field][idx : idx + 1]
for idx, seq in enumerate(results["sequence_fields"])
}
for field in fields
}
results.update(unpacked)
return results

    def _augmentation_space(self):
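        """Instantiate the train-time augmentations from the hyperparameters
        copied off ``augmentations_db`` in ``__init__``."""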
self.augmentations_dict = {
"Flip": pipelines.RandomFlip(prob=self.flip_p),
"Jitter": pipelines.RandomColorJitter(
(-self.random_jitter, self.random_jitter), prob=self.jitter_p
),
"Gamma": pipelines.RandomGamma(
(-self.random_gamma, self.random_gamma), prob=self.gamma_p
),
"Blur": pipelines.GaussianBlur(
kernel_size=13, sigma=(0.1, self.random_blur), prob=self.blur_p
),
"Grayscale": pipelines.RandomGrayscale(prob=self.grayscale_p),
}

    def augment(self, results):
for name, aug in self.augmentations_dict.items():
results = aug(results)
return results

    def prepare_depth_eval(self, inputs, preds):
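        """Collect GT depth, depth-like predictions and validity masks.

        When ``self.keyframe_idx`` is set, only that frame is evaluated; a
        None keyframe yields ``slice(None)``, i.e. all frames are kept.
        """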
new_preds = {}
keyframe_idx = getattr(self, "keyframe_idx", None)
slice_idx = slice(
keyframe_idx, keyframe_idx + 1 if keyframe_idx is not None else None
)
new_gts = inputs["depth"][slice_idx]
new_masks = inputs["depth_mask"][slice_idx].bool()
for key, val in preds.items():
if "depth" in key:
new_preds[key] = val[slice_idx]
return new_gts, new_preds, new_masks

    def prepare_points_eval(self, inputs, preds):
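        """Collect GT 3D points, point-like predictions and masks, preferring
        an explicit ``points_mask`` over the depth mask."""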
new_preds = {}
new_gts = inputs["points"]
new_masks = inputs["depth_mask"].bool()
if "points_mask" in inputs:
new_masks = inputs["points_mask"].bool()
for key, val in preds.items():
if "points" in key:
new_preds[key] = val
return new_gts, new_preds, new_masks

    def add_points(self, inputs):
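        """Back-project GT depth into 3D points, using the pre-resize camera
        when available."""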
inputs["points"] = inputs.get("camera_original", inputs["camera"]).reconstruct(
inputs["depth"]
)
return inputs

    @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def accumulate_metrics(
self,
inputs,
preds,
keyframe_idx=None,
        metrics=("depth", "points", "flow_fwd", "pairwise"),
):
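        """Accumulate every metric whose key appears in both the GT inputs and
        the predictions, optionally restricted to a single keyframe."""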
if "depth" in inputs and "points" not in inputs:
inputs = self.add_points(inputs)
available_metrics = []
for metric in metrics:
            metric_in_gt = any(metric in k for k in inputs)
            metric_in_pred = any(metric in k for k in preds)
if metric_in_gt and metric_in_pred:
available_metrics.append(metric)
if keyframe_idx is not None:
inputs = recursive_index(inputs, slice(keyframe_idx, keyframe_idx + 1))
preds = recursive_index(preds, slice(keyframe_idx, keyframe_idx + 1))
if "depth" in available_metrics:
depth_gt, depth_pred, depth_masks = self.prepare_depth_eval(inputs, preds)
self.accumulate_metrics_depth(depth_gt, depth_pred, depth_masks)
if "points" in available_metrics:
points_gt, points_pred, points_masks = self.prepare_points_eval(
inputs, preds
)
self.accumulate_metrics_3d(points_gt, points_pred, points_masks)

    @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def accumulate_metrics_depth(self, gts, preds, masks):
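        """Append per-sample depth metrics and valid-pixel counts to the
        running stores, one entry per depth-like prediction head."""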
for eval_type, pred in preds.items():
log_name = eval_type.replace("depth", "").strip("-").strip("_")
if log_name not in self.metrics_store:
self.metrics_store[log_name] = {}
current_count = self.metrics_count.get(
log_name, torch.tensor([], device=gts.device)
)
new_count = masks.view(gts.shape[0], -1).sum(dim=-1)
self.metrics_count[log_name] = torch.cat([current_count, new_count])
for k, v in eval_depth(gts, pred, masks, max_depth=self.max_depth).items():
current_metric = self.metrics_store[log_name].get(
k, torch.tensor([], device=gts.device)
)
self.metrics_store[log_name][k] = torch.cat([current_metric, v])

    @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def accumulate_metrics_3d(self, gts, preds, masks):
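        """Append per-sample 3D metrics to the running stores, evaluated at
        100 log-spaced distance thresholds in [min_depth, max_depth / 20]."""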
thresholds = torch.linspace(
log(self.min_depth),
log(self.max_depth / 20),
steps=100,
device=gts.device,
).exp()
for eval_type, pred in preds.items():
log_name = eval_type.replace("points", "").strip("-").strip("_")
if log_name not in self.metrics_store:
self.metrics_store[log_name] = {}
current_count = self.metrics_count.get(
log_name, torch.tensor([], device=gts.device)
)
new_count = masks.view(gts.shape[0], -1).sum(dim=-1)
self.metrics_count[log_name] = torch.cat([current_count, new_count])
for k, v in eval_3d(gts, pred, masks, thresholds=thresholds).items():
current_metric = self.metrics_store[log_name].get(
k, torch.tensor([], device=gts.device)
)
self.metrics_store[log_name][k] = torch.cat([current_metric, v])

    def get_evaluation(self, metrics=None):
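        """Synchronize the accumulated metrics across GPUs and reduce them to
        scalars, clearing the consumed accumulators afterwards."""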
metric_vals = {}
        eval_types = metrics if metrics is not None else self.metrics_store.keys()
        for eval_type in eval_types:
assert self.metrics_store[eval_type]
cnts = sync_tensor_across_gpus(self.metrics_count[eval_type])
for name, val in self.metrics_store[eval_type].items():
# vals_r = (sync_tensor_across_gpus(val) * cnts / cnts.sum()).sum()
vals_r = sync_tensor_across_gpus(val).mean()
metric_vals[f"{eval_type}_{name}".strip("_")] = np.round(
vals_r.cpu().item(), 5
)
self.metrics_store[eval_type] = {}
self.metrics_count = {}
return metric_vals

    def replicate(self, results):
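        """Duplicate sample (0, 0) into ``num_copies - 1`` extra sequence
        entries via deep copies."""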
for i in range(1, self.num_copies):
results[(0, i)] = {k: deepcopy(v) for k, v in results[(0, 0)].items()}
results["sequence_fields"].append((0, i))
return results

    def log_load_dataset(self):
if is_main_process():
info = f"Loaded {self.__class__.__name__} with {len(self)} images."
print(info)

    def pre_pipeline(self, results):
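        """Initialize the field-group sets and the per-frame default flags
        (depth scale, scale-invariance, density, quality, camera/pose
        validity) that pipeline stages and subclasses may overwrite."""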
results["image_fields"] = results.get("image_fields", set())
results["gt_fields"] = results.get("gt_fields", set())
results["mask_fields"] = results.get("mask_fields", set())
results["sequence_fields"] = results.get("sequence_fields", set())
results["camera_fields"] = results.get("camera_fields", set())
results["dataset_name"] = (
[self.__class__.__name__] * self.num_frames * self.num_copies
)
results["depth_scale"] = [self.depth_scale] * self.num_frames * self.num_copies
results["si"] = [False] * self.num_frames * self.num_copies
results["dense"] = [False] * self.num_frames * self.num_copies
results["synthetic"] = [False] * self.num_frames * self.num_copies
results["quality"] = [0] * self.num_frames * self.num_copies
results["valid_camera"] = [True] * self.num_frames * self.num_copies
results["valid_pose"] = [True] * self.num_frames * self.num_copies
return results

    def eval_mask(self, valid_mask):
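        """Hook for dataset-specific filtering of the validity mask; identity
        by default."""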
return valid_mask

    def chunk(self, dataset, chunk_dim=1, pct=1.0):
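        """Subsample ``dataset`` by keeping ``chunk_dim`` consecutive items
        out of every ``int(chunk_dim / pct)``, e.g. ``chunk_dim=2, pct=0.25``
        keeps items 0-1 out of every 8."""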
subsampled_datasets = [
x
for i in range(0, len(dataset), int(1 / pct * chunk_dim))
for x in dataset[i : i + chunk_dim]
]
return subsampled_datasets
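
    # Dataset-specific hooks, to be implemented by subclasses.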
    @abstractmethod
    def preprocess(self, results):
        raise NotImplementedError

    @abstractmethod
    def postprocess(self, results):
        raise NotImplementedError

    @abstractmethod
    def get_mapper(self):
        raise NotImplementedError

    @abstractmethod
    def get_intrinsics(self, idx, image_name):
        raise NotImplementedError

    @abstractmethod
    def get_extrinsics(self, idx, image_name):
        raise NotImplementedError

    @abstractmethod
    def load_dataset(self):
        raise NotImplementedError

    @abstractmethod
    def get_single_item(self, idx, sample=None, mapper=None):
        raise NotImplementedError

    @abstractmethod
    def __getitem__(self, idx):
        raise NotImplementedError
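
# A minimal subclass sketch, for illustration only: `ToyDataset`, its split
# file layout and its `depth_scale` are hypothetical, not part of this repo.
#
#   class ToyDataset(BaseDataset):
#       depth_scale = 1000.0  # assumed: depths stored in millimeters
#
#       def load_dataset(self):
#           # assumed format: one "rgb_path depth_path" pair per line
#           with open(os.path.join(self.data_root, self.split_file)) as f:
#               self.dataset = [line.split() for line in f]
#           self.log_load_dataset()
#
#       def __getitem__(self, idx):
#           return self.get_single_item(idx)
#
#   (the remaining abstract hooks, e.g. get_single_item and get_mapper,
#   would be implemented likewise)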