Spaces:
Running
on
Zero
Running
on
Zero
import os
from abc import abstractmethod
from copy import deepcopy
from math import ceil, log
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import Dataset

import unik3d.datasets.pipelines as pipelines
from unik3d.utils import (eval_3d, eval_depth, identity, is_main_process,
                          recursive_index, sync_tensor_across_gpus)
from unik3d.utils.constants import (IMAGENET_DATASET_MEAN,
                                    IMAGENET_DATASET_STD, OPENAI_DATASET_MEAN,
                                    OPENAI_DATASET_STD)
class BaseDataset(Dataset):
    """Common base for the depth / 3D datasets.

    Holds configuration shared by every dataset (normalization statistics,
    augmentation, masking, filling, resizing and collection pipelines) and
    accumulates evaluation metrics across batches. Subclasses provide the
    data by implementing ``load_dataset``, ``get_single_item``,
    ``__getitem__`` and the other hooks that raise ``NotImplementedError``
    here. They must also set ``self.dataset`` (read by ``__len__``) and
    ``self.depth_scale`` (read by ``pre_pipeline``).
    """

    min_depth = 0.01
    max_depth = 1000.0

    # Keys of ``results`` that list the per-view fields which
    # pack_batch()/unpack_batch() concatenate and split.
    _FIELD_GROUPS = ("image_fields", "gt_fields", "mask_fields", "camera_fields")

    def __init__(
        self,
        image_shape: Tuple[int, int],
        split_file: str,
        test_mode: bool,
        normalize: Optional[str],
        augmentations_db: Dict[str, Any],
        shape_constraints: Dict[str, Any],
        resize_method: str,
        mini: float,
        num_copies: int = 1,
        **kwargs,
    ) -> None:
        """Configure the shared pipelines.

        Args:
            image_shape: Target (height, width). Each side is rounded up
                to a multiple of ``shape_constraints["shape_mult"]``.
            split_file: Stored as ``self.split_file`` for subclasses.
            test_mode: Evaluation mode. It skips the augmentation setup and
                passes ``max_value=max_depth`` to the annotation masker. It is
                also forwarded to the filler and as ``keep_original`` to the
                resizer.
            normalize: ``None`` (mean 0, std 1), ``"imagenet"`` or
                ``"openai"``. Selects ``self.normalization_stats``.
            augmentations_db: Every item is set as an instance attribute.
                This class reads ``random_scale`` and ``test_context``. In
                training mode it also reads ``flip_p``, ``random_jitter``,
                ``jitter_p``, ``random_gamma``, ``gamma_p``, ``random_blur``,
                ``blur_p`` and ``grayscale_p``.
            shape_constraints: Must contain ``"shape_mult"``. Also forwarded
                to the resizer.
            resize_method: Stored as ``self.resize_method``.
            mini: Stored as ``self.mini`` (presumably a subsampling fraction
                used by subclasses -- TODO confirm).
            num_copies: Number of copies of each sample produced by
                ``replicate`` and accounted for in ``pre_pipeline``.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``normalize`` is not one of the supported values.
            KeyError: If the ``DATAROOT`` environment variable is not set.
        """
        super().__init__()
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if normalize not in (None, "imagenet", "openai"):
            raise ValueError(
                f"normalize must be None, 'imagenet' or 'openai', got {normalize!r}"
            )
        self.split_file = split_file
        self.test_mode = test_mode
        try:
            self.data_root = os.environ["DATAROOT"]
        except KeyError:
            raise KeyError("the DATAROOT environment variable must be set") from None
        self.image_shape = image_shape
        self.resize_method = resize_method
        self.mini = mini
        self.num_frames = 1
        self.num_copies = num_copies
        self.metrics_store = {}
        self.metrics_count = {}

        if normalize == "imagenet":
            mean, std = IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
        elif normalize == "openai":
            mean, std = OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
        else:
            # Identity normalization.
            mean, std = [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]
        self.normalization_stats = {
            "mean": torch.tensor(mean),
            "std": torch.tensor(std),
        }

        for key, value in augmentations_db.items():
            setattr(self, key, value)
        self.shape_constraints = shape_constraints
        if not self.test_mode:
            self._augmentation_space()

        self.masker = pipelines.AnnotationMask(
            min_value=0.0,
            max_value=self.max_depth if test_mode else None,
            custom_fn=identity,
        )
        self.filler = pipelines.RandomFiller(test_mode=test_mode)

        # Round each side up to the nearest multiple of shape_mult.
        shape_mult = self.shape_constraints["shape_mult"]
        self.image_shape = [
            ceil(self.image_shape[0] / shape_mult) * shape_mult,
            ceil(self.image_shape[1] / shape_mult) * shape_mult,
        ]
        self.resizer = pipelines.ContextCrop(
            image_shape=self.image_shape,
            train_ctx_range=(1.0 / self.random_scale, 1.0 * self.random_scale),
            test_min_ctx=self.test_context,
            keep_original=test_mode,
            shape_constraints=self.shape_constraints,
        )
        self.collecter = pipelines.Collect(
            keys=["image_fields", "mask_fields", "gt_fields", "camera_fields"]
        )

    def __len__(self):
        """Return ``len(self.dataset)``; ``self.dataset`` is set by subclasses."""
        return len(self.dataset)

    def pack_batch(self, results):
        """Concatenate per-view fields into batched values.

        For every field listed under the ``_FIELD_GROUPS`` keys, the values
        ``results[seq][field]`` for each ``seq`` in
        ``results["sequence_fields"]`` are joined in that order with
        ``torch.cat`` and stored at ``results[field]``.
        ``results["paddings"]`` collects the first padding entry of each
        view. Field groups absent from ``results`` are skipped.
        ``results`` is modified in place and returned.
        """
        sequence = results["sequence_fields"]
        results["paddings"] = [results[seq]["paddings"][0] for seq in sequence]
        for fields_name in self._FIELD_GROUPS:
            # Skip absent groups instead of failing on ``for field in None``.
            fields = results.get(fields_name) or ()
            packed = {
                field: torch.cat([results[seq][field] for seq in sequence])
                for field in fields
            }
            results.update(packed)
        return results

    def unpack_batch(self, results):
        """Split batched fields back into one slice per view.

        Each listed ``results[field]`` becomes
        ``{seq: results[field][i : i + 1]}`` for the i-th ``seq`` in
        ``results["sequence_fields"]``. Field groups absent from
        ``results`` are skipped. ``results`` is modified in place and
        returned.
        """
        sequence = results["sequence_fields"]
        for fields_name in self._FIELD_GROUPS:
            fields = results.get(fields_name) or ()
            unpacked = {
                field: {
                    seq: results[field][idx : idx + 1]
                    for idx, seq in enumerate(sequence)
                }
                for field in fields
            }
            results.update(unpacked)
        return results

    def _augmentation_space(self):
        """Build the training augmentations; ``augment`` applies them in this order."""
        self.augmentations_dict = {
            "Flip": pipelines.RandomFlip(prob=self.flip_p),
            "Jitter": pipelines.RandomColorJitter(
                (-self.random_jitter, self.random_jitter), prob=self.jitter_p
            ),
            "Gamma": pipelines.RandomGamma(
                (-self.random_gamma, self.random_gamma), prob=self.gamma_p
            ),
            "Blur": pipelines.GaussianBlur(
                kernel_size=13, sigma=(0.1, self.random_blur), prob=self.blur_p
            ),
            "Grayscale": pipelines.RandomGrayscale(prob=self.grayscale_p),
        }

    def augment(self, results):
        """Pass ``results`` through every augmentation in ``augmentations_dict``.

        ``augmentations_dict`` is only built outside test mode.
        """
        for aug in self.augmentations_dict.values():
            results = aug(results)
        return results

    def prepare_depth_eval(self, inputs, preds):
        """Select depth ground truth, depth predictions and the bool mask.

        If the instance has a non-None ``keyframe_idx`` attribute, only
        that index along dim 0 is kept. Only ``preds`` entries whose key
        contains ``"depth"`` are returned.

        Returns:
            ``(gts, preds, masks)``.
        """
        keyframe_idx = getattr(self, "keyframe_idx", None)
        if keyframe_idx is None:
            frames = slice(None)
        else:
            frames = slice(keyframe_idx, keyframe_idx + 1)
        new_gts = inputs["depth"][frames]
        new_masks = inputs["depth_mask"][frames].bool()
        new_preds = {key: val[frames] for key, val in preds.items() if "depth" in key}
        return new_gts, new_preds, new_masks

    def prepare_points_eval(self, inputs, preds):
        """Select 3D ground truth, point predictions and the bool mask.

        The mask is ``inputs["points_mask"]`` when present, otherwise
        ``inputs["depth_mask"]``. Only ``preds`` entries whose key contains
        ``"points"`` are returned.

        Returns:
            ``(gts, preds, masks)``.
        """
        new_gts = inputs["points"]
        # Only touch depth_mask when it is actually needed.
        mask_key = "points_mask" if "points_mask" in inputs else "depth_mask"
        new_masks = inputs[mask_key].bool()
        new_preds = {key: val for key, val in preds.items() if "points" in key}
        return new_gts, new_preds, new_masks

    def add_points(self, inputs):
        """Store ``camera.reconstruct(inputs["depth"])`` as ``inputs["points"]``.

        Prefers ``inputs["camera_original"]`` and falls back to
        ``inputs["camera"]``; the fallback is only looked up when needed.
        ``inputs`` is modified in place and returned.
        """
        if "camera_original" in inputs:
            camera = inputs["camera_original"]
        else:
            camera = inputs["camera"]
        inputs["points"] = camera.reconstruct(inputs["depth"])
        return inputs

    def accumulate_metrics(
        self,
        inputs,
        preds,
        keyframe_idx=None,
        metrics=("depth", "points", "flow_fwd", "pairwise"),
    ):
        """Accumulate evaluation metrics for one batch.

        A metric is considered available when its name is a substring of
        some key in ``inputs`` and of some key in ``preds``. Only
        ``"depth"`` and ``"points"`` are evaluated here. If ``inputs`` has
        ``"depth"`` but no ``"points"``, points are first reconstructed
        with ``add_points``, which writes into ``inputs``.

        Args:
            inputs: Ground-truth dict.
            preds: Prediction dict.
            keyframe_idx: If given, ``inputs`` and ``preds`` are first
                restricted to ``slice(keyframe_idx, keyframe_idx + 1)`` via
                ``recursive_index``.
            metrics: Candidate metric names (any iterable of strings).
        """
        if "depth" in inputs and "points" not in inputs:
            inputs = self.add_points(inputs)

        available_metrics = [
            metric
            for metric in metrics
            if any(metric in key for key in inputs)
            and any(metric in key for key in preds)
        ]

        if keyframe_idx is not None:
            frame = slice(keyframe_idx, keyframe_idx + 1)
            inputs = recursive_index(inputs, frame)
            preds = recursive_index(preds, frame)

        if "depth" in available_metrics:
            depth_gt, depth_pred, depth_masks = self.prepare_depth_eval(inputs, preds)
            self.accumulate_metrics_depth(depth_gt, depth_pred, depth_masks)
        if "points" in available_metrics:
            points_gt, points_pred, points_masks = self.prepare_points_eval(
                inputs, preds
            )
            self.accumulate_metrics_3d(points_gt, points_pred, points_masks)

    def _append_metrics(self, log_name, gts, masks, values):
        """Append per-sample mask sums and metric tensors under ``log_name``."""
        # NOTE(review): a bare "depth" and a bare "points" prediction both map
        # to log name "" and share one store/count entry. Counts are then
        # appended twice per batch, and equally named metrics get mixed.
        store = self.metrics_store.setdefault(log_name, {})
        empty = torch.tensor([], device=gts.device)
        # One value per sample; reshape (unlike view) accepts non-contiguous masks.
        new_count = masks.reshape(gts.shape[0], -1).sum(dim=-1)
        self.metrics_count[log_name] = torch.cat(
            [self.metrics_count.get(log_name, empty), new_count]
        )
        for name, value in values.items():
            store[name] = torch.cat([store.get(name, empty), value])

    def accumulate_metrics_depth(self, gts, preds, masks):
        """Evaluate each depth prediction with ``eval_depth`` and append results.

        Each ``preds`` entry is stored under its key with ``"depth"``
        removed and then ``-`` and ``_`` stripped from the ends
        (``"depth"`` -> ``""``).
        """
        for eval_type, pred in preds.items():
            log_name = eval_type.replace("depth", "").strip("-").strip("_")
            values = eval_depth(gts, pred, masks, max_depth=self.max_depth)
            self._append_metrics(log_name, gts, masks, values)

    def accumulate_metrics_3d(self, gts, preds, masks):
        """Evaluate each point prediction with ``eval_3d`` and append results.

        Uses 100 thresholds spaced log-uniformly from ``min_depth`` to
        ``max_depth / 20``. Each ``preds`` entry is stored under its key
        with ``"points"`` removed and then ``-`` and ``_`` stripped from
        the ends.
        """
        thresholds = torch.linspace(
            log(self.min_depth),
            log(self.max_depth / 20),
            steps=100,
            device=gts.device,
        ).exp()
        for eval_type, pred in preds.items():
            log_name = eval_type.replace("points", "").strip("-").strip("_")
            values = eval_3d(gts, pred, masks, thresholds=thresholds)
            self._append_metrics(log_name, gts, masks, values)

    def get_evaluation(self, metrics=None):
        """Reduce accumulated metrics to rounded scalars and reset them.

        Args:
            metrics: Log names to evaluate. Defaults to every name in
                ``self.metrics_store``.

        Returns:
            Dict mapping ``"<log_name>_<metric>"`` (with ``_`` stripped
            from the ends) to the mean of the accumulated values. The values
            are passed through ``sync_tensor_across_gpus`` (presumably an
            all-gather across ranks) and rounded to 5 decimals.

        Evaluated names are removed from ``metrics_store`` and
        ``metrics_count``, so a later round starts clean. Entries that were
        not evaluated are kept.
        """
        metric_vals = {}
        # Snapshot the names: entries are popped while iterating.
        eval_types = list(self.metrics_store if metrics is None else metrics)
        for eval_type in eval_types:
            store = self.metrics_store.pop(eval_type)
            assert store, f"no metrics accumulated for {eval_type!r}"
            for name, val in store.items():
                # Unweighted mean; self.metrics_count keeps per-sample mask
                # sums should a weighted mean be wanted.
                vals_r = sync_tensor_across_gpus(val).mean()
                metric_vals[f"{eval_type}_{name}".strip("_")] = np.round(
                    vals_r.cpu().item(), 5
                )
            self.metrics_count.pop(eval_type, None)
        return metric_vals

    def replicate(self, results):
        """Add deep copies of view ``(0, 0)`` as views ``(0, 1) .. (0, num_copies - 1)``.

        Each new key is appended to ``results["sequence_fields"]``.
        ``results`` is modified in place and returned.
        """
        source = results[(0, 0)]
        for copy_idx in range(1, self.num_copies):
            results[(0, copy_idx)] = {k: deepcopy(v) for k, v in source.items()}
            results["sequence_fields"].append((0, copy_idx))
        return results

    def log_load_dataset(self):
        """Print the dataset size, on the main process only."""
        if is_main_process():
            print(f"Loaded {self.__class__.__name__} with {len(self)} images.")

    def pre_pipeline(self, results):
        """Initialise the bookkeeping entries of a sample ``results`` dict.

        Missing field-group entries default to empty sets. The metadata
        lists get one entry per view (``num_frames * num_copies``).
        ``results`` is modified in place and returned.
        """
        # NOTE(review): "sequence_fields" defaults to a set here, yet
        # replicate() appends to it and pack_batch()/unpack_batch() rely on
        # its order, so callers presumably pass a list -- confirm.
        for key in (
            "image_fields",
            "gt_fields",
            "mask_fields",
            "sequence_fields",
            "camera_fields",
        ):
            results.setdefault(key, set())
        num_views = self.num_frames * self.num_copies
        results["dataset_name"] = [self.__class__.__name__] * num_views
        results["depth_scale"] = [self.depth_scale] * num_views
        results["si"] = [False] * num_views
        results["dense"] = [False] * num_views
        results["synthetic"] = [False] * num_views
        results["quality"] = [0] * num_views
        results["valid_camera"] = [True] * num_views
        results["valid_pose"] = [True] * num_views
        return results

    def eval_mask(self, valid_mask):
        """Hook for subclasses to restrict the evaluation mask; identity here."""
        return valid_mask

    def chunk(self, dataset, chunk_dim=1, pct=1.0):
        """Subsample ``dataset`` in contiguous runs of ``chunk_dim`` items.

        A run starts every ``int(1 / pct * chunk_dim)`` items, so roughly a
        ``pct`` fraction is kept. For example, ``chunk_dim=2, pct=0.5``
        keeps items 0, 1, 4, 5, 8, 9, ... ``pct`` is expected in (0, 1]:
        ``pct=0`` raises ZeroDivisionError, a stride of 0 raises ValueError
        (from ``range``), and a stride below ``chunk_dim`` yields
        overlapping runs.
        """
        stride = int(1 / pct * chunk_dim)
        return [
            item
            for start in range(0, len(dataset), stride)
            for item in dataset[start : start + chunk_dim]
        ]

    def preprocess(self, results):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

    def postprocess(self, results):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

    def get_mapper(self):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

    def get_intrinsics(self, idx, image_name):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

    def get_extrinsics(self, idx, image_name):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

    def load_dataset(self):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

    def get_single_item(self, idx, sample=None, mapper=None):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Subclass hook; not implemented in the base class."""
        raise NotImplementedError