import json
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data.distributed
import wandb
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from unik3d.utils.distributed import barrier, get_world_size, is_main_process
from unik3d.utils.misc import remove_leading_dim, remove_padding, ssi_helper
from unik3d.utils.visualization import colorize, image_grid


def stack_mixedshape_numpy(tensor_list, dim=0):
    """Zero-pad HxWxC numpy arrays to a common spatial shape, then stack them."""
    max_rows = max(tensor.shape[0] for tensor in tensor_list)
    max_columns = max(tensor.shape[1] for tensor in tensor_list)

    padded_tensors = []
    for tensor in tensor_list:
        rows, columns, *_ = tensor.shape
        pad_rows = max_rows - rows
        pad_columns = max_columns - columns
        padded_tensor = np.pad(
            tensor, ((0, pad_rows), (0, pad_columns), (0, 0)), mode="constant"
        )
        padded_tensors.append(padded_tensor)

    return np.stack(padded_tensors, axis=dim)
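

# A minimal usage sketch (shapes are illustrative only): two color maps of
# different sizes are zero-padded to the max shape before stacking.
#   a = np.zeros((2, 3, 3))
#   b = np.zeros((4, 2, 3))
#   stack_mixedshape_numpy([a, b]).shape  # -> (2, 4, 3, 3)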


def original_image(batch):
    """Resize the network input to GT resolution plus padding, then crop the padding away."""
    paddings = [
        torch.tensor(pads)
        for img_meta in batch["img_metas"]
        for pads in img_meta.get("paddings", [[0] * 4])
    ]
    # Reorder to (left, right, top, bottom).
    paddings = torch.stack(paddings).to(batch["data"]["image"].device)[
        ..., [0, 2, 1, 3]
    ]
    T, _, H, W = batch["data"]["depth"].shape
    # Height grows by top+bottom padding, width by left+right padding.
    batch["data"]["image"] = F.interpolate(
        batch["data"]["image"],
        (H + paddings[0][2] + paddings[0][3], W + paddings[0][0] + paddings[0][1]),
        mode="bilinear",
        align_corners=False,
        antialias=True,
    )
    batch["data"]["image"] = remove_padding(
        batch["data"]["image"], paddings.repeat(T, 1)
    )
    return batch
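

# A hedged example of the padding reindex above, assuming img_metas stores
# paddings as (left, top, right, bottom):
#   pads = torch.tensor([[2, 4, 3, 5]])  # (l, t, r, b)
#   pads[..., [0, 2, 1, 3]]              # -> [[2, 3, 4, 5]] = (l, r, t, b)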


def original_image_inv(batch, preds=None):
    """Crop the padding first, then resize image (and predictions) to GT resolution."""
    paddings = [
        torch.tensor(pads)
        for img_meta in batch["img_metas"]
        for pads in img_meta.get("padding_size", [[0] * 4])
    ]
    T, _, H, W = batch["data"]["depth"].shape
    # Repeat the per-sample paddings once per frame.
    batch["data"]["image"] = remove_padding(batch["data"]["image"], paddings * T)
    batch["data"]["image"] = F.interpolate(
        batch["data"]["image"],
        (H, W),
        mode="bilinear",
        align_corners=False,
        antialias=True,
    )
    if preds is not None:
        for key in ["depth"]:
            if key in preds:
                preds[key] = remove_padding(preds[key], paddings * T)
                preds[key] = F.interpolate(
                    preds[key],
                    (H, W),
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
    return batch, preds
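

# Hedged usage sketch: this is the inverse ordering of original_image (crop the
# padding first, then resize), applied to both the input image and the depth
# prediction.
#   batch, preds = original_image_inv(batch, preds)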


def aggregate_metrics(metrics_all, exclude_fn=lambda name: False):
    """Prepend an entry that averages each metric over all non-excluded datasets."""
    aggregate_name = "".join(
        name_ds[:3] for name_ds in metrics_all.keys() if not exclude_fn(name_ds)
    )
    metrics_aggregate = defaultdict(list)
    for name_ds, metrics in metrics_all.items():
        if exclude_fn(name_ds):
            continue
        for metrics_name, metrics_value in metrics.items():
            metrics_aggregate[metrics_name].append(metrics_value)
    return {
        aggregate_name: {k: sum(v) / len(v) for k, v in metrics_aggregate.items()},
        **metrics_all,
    }
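

# A minimal sketch of the aggregation (values are illustrative only); the
# aggregate key concatenates the first three letters of each dataset name:
#   aggregate_metrics({"KITTI": {"d1": 0.9}, "NYUv2Depth": {"d1": 0.8}})
#   # -> {"KITNYU": {"d1": 0.85}, "KITTI": {...}, "NYUv2Depth": {...}}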


GROUPS = {
    "SFoV": ["KITTI", "NYUv2Depth", "DiodeIndoor", "ETH3D", "IBims"],
    "SFoVDi": ["DiodeIndoor_F", "ETH3D_F", "IBims_F"],
    "LFoV": ["ADT", "KITTI360", "ScanNetpp_F"],
}


def aggregate_metrics_camera(metrics_all):
    """Add one averaged entry per GROUPS split, skipping datasets that were not evaluated."""
    available_groups = {
        k: v for k, v in GROUPS.items() if any(name in metrics_all for name in v)
    }
    for group_name, group_datasets in available_groups.items():
        metrics_aggregate = defaultdict(list)
        for dataset_name in group_datasets:
            if dataset_name not in metrics_all:
                print(
                    f"Dataset {dataset_name} not used for aggregation of {group_name}"
                )
                continue
            for metrics_name, metrics_value in metrics_all[dataset_name].items():
                metrics_aggregate[metrics_name].append(metrics_value)
        metrics_all[group_name] = {
            k: sum(v) / len(v) for k, v in metrics_aggregate.items()
        }
    return metrics_all
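

# A hedged example: if only ETH3D and IBims from the "SFoV" group were run,
# the group average is taken over those two and the missing datasets are
# reported via print.
#   aggregate_metrics_camera({"ETH3D": {"d1": 0.9}, "IBims": {"d1": 0.7}})
#   # -> adds {"SFoV": {"d1": 0.8}}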


def log_metrics(metrics_all, step):
    for name_ds, metrics in metrics_all.items():
        for metrics_name, metrics_value in metrics.items():
            try:
                wandb.log(
                    {f"Metrics/{name_ds}/{metrics_name}": metrics_value}, step=step
                )
            except Exception:
                # No active wandb run (or logging failed): fall back to stdout.
                print(f"Metrics/{name_ds}/{metrics_name} {round(metrics_value, 4)}")


def log_artifacts(artifacts_all, step, run_id):
    """Log one image grid per dataset: one row per artifact type, one column per sample."""
    for ds_name, artifacts in artifacts_all.items():
        rgbs, gts = artifacts["rgbs"], artifacts["gts"]
        logging_imgs = [
            *rgbs,
            *gts,
            *[
                x
                for k, v in artifacts.items()
                if ("rgbs" not in k and "gts" not in k)
                for x in v
            ],
        ]
        artifacts_grid = image_grid(logging_imgs, len(artifacts), len(rgbs))
        try:
            wandb.log({f"{ds_name}_test": [wandb.Image(artifacts_grid)]}, step=step)
        except Exception:
            print(f"Error while saving artifacts at step {step}")


def show(vals, dataset, ssi_depth=False):
    """Build uint8 visualization stacks (RGB, GT, prediction, error) for logging."""
    output_artifacts, additionals = {}, {}
    predictions, gts, errors, images = [], [], [], []
    for v in vals:
        image = v["image"][0].unsqueeze(0)
        gt = v["depth"][0].unsqueeze(0)
        prediction = v["depth_pred"][0].unsqueeze(0)

        # Downsample for memory and viz
        # if any([x in dataset.__class__.__name__ for x in ["DDAD", "Argoverse", "Waymo", "DrivingStereo"]]):
        #     gt = F.interpolate(gt, scale_factor=0.5, mode="nearest-exact")
        #     # Dilate for a better visualization
        #     gt[gt < 1e-4] = dilate(gt)[gt < 1e-4]

        # Resize GT to roughly 300k pixels while keeping the aspect ratio.
        H, W = gt.shape[-2:]
        aspect_ratio = H / W
        new_W = int((300_000 / aspect_ratio) ** 0.5)
        new_H = int(aspect_ratio * new_W)
        gt = F.interpolate(gt, (new_H, new_W), mode="nearest-exact")

        # Format predictions and errors for every metric used.
        prediction = F.interpolate(
            prediction,
            gt.shape[-2:],
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        error = torch.zeros_like(prediction)
        error[gt > dataset.min_depth] = (
            4
            * dataset.max_depth
            * torch.abs(gt - prediction)[gt > dataset.min_depth]
            / gt[gt > dataset.min_depth]
        )
        if ssi_depth:
            # Align prediction to GT scale and shift before visualization.
            scale, shift = ssi_helper(gt[gt > 0], prediction[gt > 0])
            prediction = (prediction * scale + shift).clip(0.0, dataset.max_depth)
        prediction = colorize(
            prediction.squeeze().cpu().detach().numpy(),
            vmin=dataset.min_depth,
            vmax=dataset.max_depth,
            cmap="magma_r",
        )
        error = error.clip(0.0, dataset.max_depth).cpu().detach().numpy()
        error = colorize(error.squeeze(), vmin=0.001, vmax=1.0, cmap="coolwarm")
        errors.append(error)
        predictions.append(prediction)

        # Un-normalize the RGB input and convert it to HWC uint8.
        image = F.interpolate(
            image, gt.shape[-2:], mode="bilinear", align_corners=False, antialias=True
        )
        image = image.cpu().detach() * dataset.normalization_stats["std"].view(
            1, -1, 1, 1
        ) + dataset.normalization_stats["mean"].view(1, -1, 1, 1)
        image = (
            (255 * image)
            .clip(0.0, 255.0)
            .to(torch.uint8)
            .permute(0, 2, 3, 1)
            .numpy()
            .squeeze()
        )
        gt = gt.clip(0.0, dataset.max_depth).cpu().detach().numpy()
        gt = colorize(
            gt.squeeze(), vmin=dataset.min_depth, vmax=dataset.max_depth, cmap="magma_r"
        )
        gts.append(gt)
        images.append(image)

        # Extra model outputs ("infos") are shown as RGB or colorized maps.
        for name, additional in v.get("infos", {}).items():
            if name not in additionals:
                additionals[name] = []
            if additional[0].shape[0] == 3:
                val = (
                    (127.5 * (additional[0] + 1))
                    .clip(0, 255)
                    .to(torch.uint8)
                    .cpu()
                    .detach()
                    .permute(1, 2, 0)
                    .numpy()
                )
            else:
                val = colorize(
                    additional[0].cpu().detach().squeeze().numpy(),
                    0.0,
                    dataset.max_depth,
                )
            additionals[name].append(val)

    output_artifacts.update(
        {
            "predictions": stack_mixedshape_numpy(predictions),
            "errors": stack_mixedshape_numpy(errors),
            "rgbs": stack_mixedshape_numpy(images),
            "gts": stack_mixedshape_numpy(gts),
            **{k: stack_mixedshape_numpy(v) for k, v in additionals.items()},
        }
    )
    return output_artifacts
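

# Hedged output sketch: show() returns one (N, H, W, 3) uint8 stack per key,
# where N is the number of sampled views.
#   artifacts = show(ds_show, test_loader.dataset, ssi_depth=True)
#   artifacts.keys()  # -> "predictions", "errors", "rgbs", "gts", plus any infos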


METRIC_B = "F1"
INVERT = True
# Scale-and-shift-align predictions to GT before visualization (see show()).
SSI_VISUALIZATION = True


def validate(
    model,
    test_loaders: Dict[str, DataLoader],
    step,
    run_id,
    context,
    idxs=(1, 100, 150, 1000),
):
    """Evaluate every test loader, then aggregate and log metrics and sample images."""
    metrics_all, predictions_select = {}, {}
    world_size = get_world_size()
    for name_ds, test_loader in test_loaders.items():
        # Wrap the sample indices per dataset without mutating the argument.
        idxs_ds = [idx % len(test_loader.dataset) for idx in idxs]
        ds_show = []
        for i, batch in enumerate(test_loader):
            with context:
                batch["data"] = {
                    k: v.to(model.device) for k, v in batch["data"].items()
                }
                preds = model(batch["data"], batch["img_metas"])
            if batch["data"]["image"].ndim == 5:
                batch["data"] = remove_leading_dim(batch["data"])
            if preds["depth"].ndim == 5:
                preds = remove_leading_dim(preds)
            batch = original_image(batch)
            test_loader.dataset.accumulate_metrics(
                inputs=batch["data"],
                preds=preds,
                keyframe_idx=batch["img_metas"][0].get("keyframe_idx"),
            )

            # Keep the middle frame of selected batches for image logging.
            if i * world_size in idxs_ds:
                ii = (len(preds["depth"]) + 1) // 2 - 1
                slice_ = slice(ii, ii + 1)
                batch["data"] = {k: v[slice_] for k, v in batch["data"].items()}
                preds["depth"] = preds["depth"][slice_]
                ds_show.append({**batch["data"], "depth_pred": preds["depth"]})

        barrier()
        metrics_all[name_ds] = test_loader.dataset.get_evaluation()
        predictions_select[name_ds] = show(
            ds_show, test_loader.dataset, ssi_depth=SSI_VISUALIZATION
        )

    barrier()
    if is_main_process():
        log_artifacts(artifacts_all=predictions_select, step=step, run_id=run_id)
        metrics_all = aggregate_metrics(
            metrics_all, exclude_fn=lambda name: "mono" in name
        )
        metrics_all = aggregate_metrics_camera(metrics_all)
        log_metrics(metrics_all=metrics_all, step=step)
    return metrics_all
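

# A hedged usage sketch: "context" is any reusable context manager wrapping the
# forward pass; autocast is one plausible choice, and "kitti_loader" is a
# hypothetical DataLoader (construction elided).
#   ctx = torch.autocast(device_type="cuda", dtype=torch.float16)
#   metrics = validate(
#       model,
#       test_loaders={"KITTI": kitti_loader},
#       step=10_000,
#       run_id="run-0",
#       context=ctx,
#   )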