""" | |
Author: Luigi Piccinelli | |
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) | |
""" | |
import importlib | |
import warnings | |
from copy import deepcopy | |
from math import ceil | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.transforms.v2.functional as TF | |
from einops import rearrange | |
from huggingface_hub import PyTorchModelHubMixin | |
from unik3d.models.decoder import Decoder | |
from unik3d.utils.camera import BatchCamera, Camera | |
from unik3d.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD | |
from unik3d.utils.distributed import is_main_process | |
from unik3d.utils.misc import get_params, last_stack, match_gt | |


def orthonormal_init(num_tokens, dims):
    pe = torch.randn(num_tokens, dims)
    # use Gram-Schmidt process to make the matrix orthonormal
    for i in range(num_tokens):
        for j in range(i):
            pe[i] -= torch.dot(pe[i], pe[j]) * pe[j]
        pe[i] = F.normalize(pe[i], p=2, dim=0)
    return pe
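
# Illustrative check (sketch): with num_tokens <= dims the returned rows form an
# orthonormal set, e.g.
#   pe = orthonormal_init(8, 32)
#   torch.allclose(pe @ pe.T, torch.eye(8), atol=1e-5)  # -> True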


def get_paddings(original_shape, aspect_ratio_range):
    # Original dimensions
    H_ori, W_ori = original_shape
    orig_aspect_ratio = W_ori / H_ori

    # Clamp the original aspect ratio to the closest value within the allowed range.
    min_ratio, max_ratio = aspect_ratio_range
    target_aspect_ratio = min(max_ratio, max(min_ratio, orig_aspect_ratio))

    if orig_aspect_ratio > target_aspect_ratio:  # Too wide: pad top/bottom.
        W_new = W_ori
        H_new = int(W_ori / target_aspect_ratio)
        pad_top = (H_new - H_ori) // 2
        pad_bottom = H_new - H_ori - pad_top
        pad_left, pad_right = 0, 0
    else:  # Too tall (or already in range): pad left/right, possibly by zero.
        H_new = H_ori
        W_new = int(H_ori * target_aspect_ratio)
        pad_left = (W_new - W_ori) // 2
        pad_right = W_new - W_ori - pad_left
        pad_top, pad_bottom = 0, 0

    return (pad_left, pad_right, pad_top, pad_bottom), (H_new, W_new)
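
# Worked example (values illustrative): a 1000x500 (HxW) image with
# aspect_ratio_range=(0.66, 2.0) has ratio 0.5 < 0.66, so it is padded to
# W_new = int(1000 * 0.66) = 660, i.e. paddings (80, 80, 0, 0) and shape (1000, 660).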


def get_resize_factor(original_shape, pixels_range, shape_multiplier=14):
    # Original dimensions
    H_ori, W_ori = original_shape
    n_pixels_ori = W_ori * H_ori

    # Determine the closest number of pixels within the range
    min_pixels, max_pixels = pixels_range
    target_pixels = min(max_pixels, max(min_pixels, n_pixels_ori))

    # Calculate the resize factor
    resize_factor = (target_pixels / n_pixels_ori) ** 0.5
    new_width = int(W_ori * resize_factor)
    new_height = int(H_ori * resize_factor)
    new_height = ceil(new_height / shape_multiplier) * shape_multiplier
    new_width = ceil(new_width / shape_multiplier) * shape_multiplier
    return resize_factor, (new_height, new_width)
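
# Worked example (values illustrative): a (480, 640) image with
# pixels_range=(200_000, 300_000) is scaled by sqrt(300_000 / 307_200) ~= 0.988,
# then each side is rounded up to a multiple of 14, giving (476, 644).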


def _postprocess(tensor, shapes, paddings, interpolation_mode="bilinear"):
    # interpolate to original (padded) size
    tensor = F.interpolate(
        tensor, size=shapes, mode=interpolation_mode, align_corners=False
    )
    # remove paddings
    pad1_l, pad1_r, pad1_t, pad1_b = paddings
    tensor = tensor[..., pad1_t : shapes[0] - pad1_b, pad1_l : shapes[1] - pad1_r]
    return tensor
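
# Note: `shapes` is the padded resolution the network input was derived from; the
# prediction is first resized back to it and the paddings are then cropped away.
# Continuing the example above: a prediction is resized to (1000, 660) and the
# (80, 80, 0, 0) paddings are cropped, yielding the original (1000, 500).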


class UniK3D(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="UniK3D",
    repo_url="https://github.com/lpiccinelli-eth/UniK3D",
    tags=["monocular-metric-3D-estimation"],
):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__()
        self.eps = eps
        self.build(config)
        self.build_losses(config)

    def pack_sequence(
        self,
        inputs: dict[str, torch.Tensor],
    ):
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.reshape(-1, *value.shape[2:])
            elif isinstance(value, BatchCamera):
                inputs[key] = value.reshape(-1)
        return inputs

    def unpack_sequence(self, inputs: dict[str, torch.Tensor], B: int, T: int):
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.reshape(B, T, *value.shape[1:])
            elif isinstance(value, BatchCamera):
                inputs[key] = value.reshape(B, T)
        return inputs
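
    # Shape convention (sketch): pack_sequence flattens the temporal axis, e.g. an
    # image tensor of shape (B, T, 3, H, W) becomes (B*T, 3, H, W), and
    # unpack_sequence restores (B*T, ...) back to (B, T, ...).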

    def forward_train(self, inputs, image_metas):
        losses = {"opt": {}, "stat": {}}
        B, T = inputs["image"].shape[:2]
        image_metas[0]["B"], image_metas[0]["T"] = B, T
        inputs = self.pack_sequence(inputs)  # move from (B, T, ...) -> (B*T, ...)
        inputs, outputs = self.encode_decode(inputs, image_metas)

        validity_mask = inputs["validity_mask"]
        # Be careful with possible NaNs in the reconstructed 3D points (unprojection
        # out-of-bound): record the NaN mask first, then zero the NaNs out.
        pts_gt = inputs["camera"].reconstruct(inputs["depth"]) * validity_mask.float()
        mask_pts_gt_nan = ~pts_gt.isnan().any(dim=1, keepdim=True)
        pts_gt = torch.where(pts_gt.isnan().any(dim=1, keepdim=True), 0.0, pts_gt)
        mask = (
            inputs["depth_mask"].bool() & validity_mask.bool() & mask_pts_gt_nan.bool()
        )

        # compute loss!
        inputs["distance"] = torch.norm(pts_gt, dim=1, keepdim=True)
        inputs["points"] = pts_gt
        inputs["depth_mask"] = mask
        losses = self.compute_losses(outputs, inputs, image_metas)

        outputs = self.unpack_sequence(outputs, B, T)
        return outputs, losses
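
    # Expected training inputs (sketch, inferred from the usage above): `inputs`
    # carries at least "image" (B, T, 3, H, W), "depth", "depth_mask",
    # "validity_mask" and a "camera" (BatchCamera); `image_metas` is a list of
    # per-image metadata dicts (paddings, scale-invariance flags, ...).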

    def forward_test(self, inputs, image_metas):
        B, T = inputs["image"].shape[:2]
        image_metas[0]["B"], image_metas[0]["T"] = B, T
        # move from (B, T, ...) -> (B*T, ...)
        inputs = self.pack_sequence(inputs)
        inputs, outputs = self.encode_decode(inputs, image_metas)

        # The GT depth is only used to recover the target output shape here; a dummy
        # tensor with the desired output shape works as well.
        depth_gt = inputs["depth"]

        outs = {}
        outs["points"] = match_gt(
            outputs["points"], depth_gt, padding1=inputs["paddings"], padding2=None
        )
        outs["confidence"] = match_gt(
            outputs["confidence"], depth_gt, padding1=inputs["paddings"], padding2=None
        )
        outs["distance"] = outs["points"].norm(dim=1, keepdim=True)
        outs["depth"] = outs["points"][:, -1:]
        outs["rays"] = outs["points"] / torch.norm(
            outs["points"], dim=1, keepdim=True
        ).clip(min=1e-5)
        outs = self.unpack_sequence(outs, B, T)
        return outs

    def forward(self, inputs, image_metas):
        if self.training:
            return self.forward_train(inputs, image_metas)
        else:
            return self.forward_test(inputs, image_metas)

    def encode_decode(self, inputs, image_metas=[]):
        B, _, H, W = inputs["image"].shape

        # Shortcut for evaluation: skip when no padding metadata is provided
        # (avoids errors at inference time).
        if len(image_metas) and "paddings" in image_metas[0]:
            # reorder paddings to (left, right, top, bottom)
            inputs["paddings"] = torch.tensor(
                [image_meta["paddings"] for image_meta in image_metas],
                device=self.device,
            )[..., [0, 2, 1, 3]]
            inputs["depth_paddings"] = torch.tensor(
                [image_meta["depth_paddings"] for image_meta in image_metas],
                device=self.device,
            )
            # At inference there are no image paddings on top of the depth ones
            # (no "crop" is applied to the GT in ContextCrop).
            if self.training:
                inputs["depth_paddings"] = inputs["depth_paddings"] + inputs["paddings"]
            else:
                inputs["paddings"] = inputs["paddings"].squeeze(0)
                inputs["depth_paddings"] = inputs["depth_paddings"].squeeze(0)

        if inputs.get("camera", None) is not None:
            inputs["rays"] = inputs["camera"].get_rays(shapes=(B, H, W))

        features, tokens = self.pixel_encoder(inputs["image"])
        inputs["features"] = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        inputs["tokens"] = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]

        outputs = self.pixel_decoder(inputs, image_metas)
        outputs["rays"] = rearrange(outputs["rays"], "b (h w) c -> b c h w", h=H, w=W)
        pts_3d = outputs["rays"] * outputs["distance"]
        outputs.update({"points": pts_3d, "depth": pts_3d[:, -1:]})
        return inputs, outputs
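
    # encode_decode returns the (augmented) inputs together with the decoder
    # outputs; judging from their use below, the outputs include per-pixel "rays",
    # "distance" and "confidence", plus the derived "points" (rays * distance) and
    # "depth" (z-component of the points).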

    def compute_losses(self, outputs, inputs, image_metas):
        B, _, H, W = inputs["image"].shape
        losses = {"opt": {}, "stat": {}}
        losses_to_be_computed = list(self.losses.keys())

        # depth (distance) loss
        si = torch.tensor(
            [x.get("si", False) for x in image_metas], device=self.device
        ).reshape(B)
        loss = self.losses["depth"]
        depth_losses = loss(
            outputs["distance"],
            target=inputs["distance"],
            mask=inputs["depth_mask"].clone(),
            si=si,
        )
        losses["opt"][loss.name] = loss.weight * depth_losses.mean()
        losses_to_be_computed.remove("depth")

        # camera (ray) loss
        loss = self.losses["camera"]
        camera_losses = loss(
            outputs["rays"], target=inputs["rays"], mask=inputs["validity_mask"].bool()
        )
        losses["opt"][loss.name] = loss.weight * camera_losses.mean()
        losses_to_be_computed.remove("camera")

        # confidence loss; after this, no other losses should remain
        loss = self.losses["confidence"]
        conf_losses = loss(
            outputs["confidence"],
            target_gt=inputs["depth"],
            target_pred=outputs["depth"],
            mask=inputs["depth_mask"].clone(),
        )
        losses["opt"][loss.name + "_conf"] = loss.weight * conf_losses.mean()
        losses_to_be_computed.remove("confidence")

        assert (
            not losses_to_be_computed
        ), f"Losses {losses_to_be_computed} not computed, revise `compute_loss` method"
        return losses
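
    # Returned structure: {"opt": {...}, "stat": {...}}; presumably only the "opt"
    # entries are meant to be optimized (back-propagated), while "stat" is reserved
    # for logging-only statistics (unused here).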

    def infer(
        self,
        rgb: torch.Tensor,
        camera: torch.Tensor | Camera | None = None,
        rays=None,
        normalize=True,
    ):
        ratio_bounds = self.shape_constraints["ratio_bounds"]
        pixels_bounds = [
            self.shape_constraints["pixels_min"],
            self.shape_constraints["pixels_max"],
        ]
        if hasattr(self, "resolution_level"):
            assert (
                self.resolution_level >= 0 and self.resolution_level < 10
            ), "resolution_level should be in [0, 10)"
            pixels_range = pixels_bounds[1] - pixels_bounds[0]
            interval = pixels_range / 10
            new_lowbound = self.resolution_level * interval + pixels_bounds[0]
            new_upbound = (self.resolution_level + 1) * interval + pixels_bounds[0]
            pixels_bounds = (new_lowbound, new_upbound)
        else:
            warnings.warn("!! self.resolution_level not set, using default bounds !!")

        # housekeeping on cpu/cuda and batchify
        if rgb.ndim == 3:
            rgb = rgb.unsqueeze(0)
        if camera is not None:
            camera = BatchCamera.from_camera(camera)
            camera = camera.to(self.device)
        B, _, H, W = rgb.shape
        rgb = rgb.to(self.device)

        # preprocess
        paddings, (padded_H, padded_W) = get_paddings((H, W), ratio_bounds)
        (pad_left, pad_right, pad_top, pad_bottom) = paddings
        resize_factor, (new_H, new_W) = get_resize_factor(
            (padded_H, padded_W), pixels_bounds
        )

        # -> rgb preprocess (input standardized, padded and resized)
        if normalize:
            rgb = TF.normalize(
                rgb.float() / 255.0,
                mean=IMAGENET_DATASET_MEAN,
                std=IMAGENET_DATASET_STD,
            )
        rgb = F.pad(rgb, (pad_left, pad_right, pad_top, pad_bottom), value=0.0)
        rgb = F.interpolate(
            rgb, size=(new_H, new_W), mode="bilinear", align_corners=False
        )

        # -> camera preprocess
        if camera is not None:
            camera = camera.crop(
                left=-pad_left, top=-pad_top, right=-pad_right, bottom=-pad_bottom
            )
            camera = camera.resize(resize_factor)

        # prepare inputs
        inputs = {"image": rgb}
        if camera is not None:
            inputs["camera"] = camera
            rays = camera.get_rays(shapes=(B, new_H, new_W), noisy=False).reshape(
                B, 3, new_H, new_W
            )
            inputs["rays"] = rays
        elif rays is not None:
            # User-provided rays are only used when no camera is given: they come at
            # the original resolution and must be padded and resized here, whereas
            # camera-derived rays above are already generated at network resolution.
            rays = rays.to(self.device)
            if rays.ndim == 3:
                rays = rays.unsqueeze(0)
            rays = F.pad(
                rays,
                (
                    max(0, pad_left),
                    max(0, pad_right),
                    max(0, pad_top),
                    max(0, pad_bottom),
                ),
                value=0.0,
            )
            rays = F.interpolate(
                rays, size=(new_H, new_W), mode="bilinear", align_corners=False
            )
            inputs["rays"] = rays

        # run model
        _, model_outputs = self.encode_decode(inputs, image_metas={})

        # collect outputs and post-process back to the original resolution
        out = {}
        out["confidence"] = _postprocess(
            model_outputs["confidence"],
            (padded_H, padded_W),
            paddings=paddings,
            interpolation_mode=self.interpolation_mode,
        )
        points = _postprocess(
            model_outputs["points"],
            (padded_H, padded_W),
            paddings=paddings,
            interpolation_mode=self.interpolation_mode,
        )
        rays = _postprocess(
            model_outputs["rays"],
            (padded_H, padded_W),
            paddings=paddings,
            interpolation_mode=self.interpolation_mode,
        )
        out["distance"] = points.norm(dim=1, keepdim=True)
        out["depth"] = points[:, -1:]
        out["points"] = points
        out["rays"] = rays / torch.norm(rays, dim=1, keepdim=True).clip(min=1e-5)
        out["lowres_features"] = model_outputs["lowres_features"]
        return out
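
    # Minimal usage sketch (assumes a checkpoint published via PyTorchModelHubMixin;
    # the repo id is a placeholder, PIL/numpy are used only for illustration):
    #
    #   model = UniK3D.from_pretrained("<hf-repo-id>").eval().cuda()
    #   rgb = torch.from_numpy(np.array(Image.open("image.png")))  # (H, W, 3), uint8
    #   rgb = rgb.permute(2, 0, 1)  # (3, H, W)
    #   with torch.no_grad():
    #       preds = model.infer(rgb)
    #   # preds keys: "points", "depth", "distance", "rays", "confidence",
    #   # "lowres_features", all at the original image resolution.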

    def load_pretrained(self, model_file):
        dict_model = torch.load(model_file, map_location="cpu", weights_only=False)
        if "model" in dict_model:
            dict_model = dict_model["model"]
        info = self.load_state_dict(dict_model, strict=False)
        if is_main_process():
            print(
                f"Loaded from {model_file} for {self.__class__.__name__} results in:",
                info,
            )

    def build(self, config):
        mod = importlib.import_module("unik3d.models.encoder")
        pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
        pixel_encoder_config = {
            **config["training"],
            **config["model"]["pixel_encoder"],
            **config["data"],
        }
        pixel_encoder = pixel_encoder_factory(pixel_encoder_config)

        pixel_encoder_embed_dims = (
            pixel_encoder.embed_dims
            if hasattr(pixel_encoder, "embed_dims")
            else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
        )
        config["model"]["pixel_encoder"]["embed_dim"] = getattr(
            pixel_encoder, "embed_dim"
        )
        config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
        config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
        config["model"]["pixel_encoder"]["cls_token_embed_dims"] = getattr(
            pixel_encoder, "cls_token_embed_dims", pixel_encoder_embed_dims
        )

        pixel_decoder = Decoder(config)

        self.pixel_encoder = pixel_encoder
        self.pixel_decoder = pixel_decoder
        self.slices_encoder_range = list(
            zip([0, *self.pixel_encoder.depths[:-1]], self.pixel_encoder.depths)
        )
        self.stacking_fn = last_stack
        self.shape_constraints = config["data"]["shape_constraints"]
        self.interpolation_mode = "bilinear"

    def build_losses(self, config):
        self.losses = {}
        for loss_name, loss_config in config["training"]["losses"].items():
            mod = importlib.import_module("unik3d.ops.losses")
            loss_factory = getattr(mod, loss_config["name"])
            self.losses[loss_name] = loss_factory.build(loss_config)
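
    # Expected losses config layout (illustrative sketch; the class names are
    # assumptions, only "name" is read here):
    #   config["training"]["losses"] = {
    #       "depth": {"name": "<class in unik3d.ops.losses>", ...},
    #       "camera": {"name": "...", ...},
    #       "confidence": {"name": "...", ...},
    #   }
    # The keys must match those consumed in `compute_losses` ("depth", "camera",
    # "confidence").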

    def get_params(self, config):
        if hasattr(self.pixel_encoder, "get_params"):
            encoder_p, _ = self.pixel_encoder.get_params(
                config["model"]["pixel_encoder"]["lr"],
                config["training"]["wd"],
                config["training"]["ld"],
            )
        else:
            encoder_p, _ = get_params(
                self.pixel_encoder,
                config["model"]["pixel_encoder"]["lr"],
                config["training"]["wd"],
            )
        decoder_p = self.pixel_decoder.get_params(
            config["training"]["lr"], config["training"]["wd"]
        )
        return [*encoder_p, *decoder_p]

    def step(self):
        self.pixel_decoder.steps += 1

    def parameters_grad(self):
        for p in self.parameters():
            if p.requires_grad:
                yield p

    @property
    def device(self):
        # `self.device` is accessed as an attribute throughout the class, so it is
        # exposed as a property.
        return next(self.parameters()).device