"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import importlib
import warnings
from copy import deepcopy
from math import ceil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2.functional as TF
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from unik3d.models.decoder import Decoder
from unik3d.utils.camera import BatchCamera, Camera
from unik3d.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
from unik3d.utils.distributed import is_main_process
from unik3d.utils.misc import get_params, last_stack, match_gt
def orthonormal_init(num_tokens, dims):
pe = torch.randn(num_tokens, dims)
# use the Gram-Schmidt process to make the rows orthonormal
for i in range(num_tokens):
for j in range(i):
pe[i] -= torch.dot(pe[i], pe[j]) * pe[j]
pe[i] = F.normalize(pe[i], p=2, dim=0)
return pe
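# Illustrative check (assumes num_tokens <= dims): the rows returned by
# orthonormal_init(32, 256) form an orthonormal set, so pe @ pe.T is numerically
# close to torch.eye(32).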
def get_paddings(original_shape, aspect_ratio_range):
# Original dimensions
H_ori, W_ori = original_shape
orig_aspect_ratio = W_ori / H_ori
# Determine the closest aspect ratio within the range
min_ratio, max_ratio = aspect_ratio_range
target_aspect_ratio = min(max_ratio, max(min_ratio, orig_aspect_ratio))
if orig_aspect_ratio > target_aspect_ratio: # Too wide
W_new = W_ori
H_new = int(W_ori / target_aspect_ratio)
pad_top = (H_new - H_ori) // 2
pad_bottom = H_new - H_ori - pad_top
pad_left, pad_right = 0, 0
else: # Too tall
H_new = H_ori
W_new = int(H_ori * target_aspect_ratio)
pad_left = (W_new - W_ori) // 2
pad_right = W_new - W_ori - pad_left
pad_top, pad_bottom = 0, 0
return (pad_left, pad_right, pad_top, pad_bottom), (H_new, W_new)
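# Illustrative example: get_paddings((500, 1500), (0.5, 2.0)) pads the 3:1 image vertically
# until it reaches the 2:1 upper bound, returning ((0, 0, 125, 125), (750, 1500)).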
def get_resize_factor(original_shape, pixels_range, shape_multiplier=14):
# Original dimensions
H_ori, W_ori = original_shape
n_pixels_ori = W_ori * H_ori
# Determine the closest number of pixels within the range
min_pixels, max_pixels = pixels_range
target_pixels = min(max_pixels, max(min_pixels, n_pixels_ori))
# Calculate the resize factor
resize_factor = (target_pixels / n_pixels_ori) ** 0.5
new_width = int(W_ori * resize_factor)
new_height = int(H_ori * resize_factor)
new_height = ceil(new_height / shape_multiplier) * shape_multiplier
new_width = ceil(new_width / shape_multiplier) * shape_multiplier
return resize_factor, (new_height, new_width)
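# Illustrative example: get_resize_factor((750, 1500), (200_000, 600_000)) scales the
# 1.125 Mpx input down by ~0.73 and rounds each side up to a multiple of 14,
# returning (~0.7303, (560, 1106)).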
def _postprocess(tensor, shapes, paddings, interpolation_mode="bilinear"):
# interpolate back to the (padded) target size
tensor = F.interpolate(
tensor, size=shapes, mode=interpolation_mode, align_corners=False
)
# remove paddings
pad1_l, pad1_r, pad1_t, pad1_b = paddings
tensor = tensor[..., pad1_t : shapes[0] - pad1_b, pad1_l : shapes[1] - pad1_r]
return tensor
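# In `infer`, `shapes` is the padded (H, W): after upsampling, the slice strips the
# paddings so the output matches the original, pre-padding resolution.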
class UniK3D(
nn.Module,
PyTorchModelHubMixin,
library_name="UniK3D",
repo_url="https://github.com/lpiccinelli-eth/UniK3D",
tags=["monocular-metric-3D-estimation"],
):
def __init__(
self,
config,
eps: float = 1e-6,
**kwargs,
):
super().__init__()
self.eps = eps
self.build(config)
self.build_losses(config)
def pack_sequence(
self,
inputs: dict[str, torch.Tensor],
):
for key, value in inputs.items():
if isinstance(value, torch.Tensor):
inputs[key] = value.reshape(-1, *value.shape[2:])
elif isinstance(value, BatchCamera):
inputs[key] = value.reshape(-1)
return inputs
def unpack_sequence(self, inputs: dict[str, torch.Tensor], B: int, T: int):
for key, value in inputs.items():
if isinstance(value, torch.Tensor):
inputs[key] = value.reshape(B, T, *value.shape[1:])
elif isinstance(value, BatchCamera):
inputs[key] = value.reshape(B, T)
return inputs
def forward_train(self, inputs, image_metas):
losses = {"opt": {}, "stat": {}}
B, T = inputs["image"].shape[:2]
image_metas[0]["B"], image_metas[0]["T"] = B, T
inputs = self.pack_sequence(inputs) # move from B, T, ... -> B*T, ...
inputs, outputs = self.encode_decode(inputs, image_metas)
validity_mask = inputs["validity_mask"]
# be careful about possible NaNs in the reconstructed 3D points (out-of-bound unprojection)
pts_gt = inputs["camera"].reconstruct(inputs["depth"]) * validity_mask.float()
# compute the NaN mask before zeroing, otherwise it would be vacuously all-True
mask_pts_gt_nan = ~pts_gt.isnan().any(dim=1, keepdim=True)
pts_gt = torch.where(pts_gt.isnan().any(dim=1, keepdim=True), 0.0, pts_gt)
mask = (
inputs["depth_mask"].bool() & validity_mask.bool() & mask_pts_gt_nan.bool()
)
# compute loss!
inputs["distance"] = torch.norm(pts_gt, dim=1, keepdim=True)
inputs["points"] = pts_gt
inputs["depth_mask"] = mask
losses = self.compute_losses(outputs, inputs, image_metas)
outputs = self.unpack_sequence(outputs, B, T)
return (
outputs,
losses,
)
def forward_test(self, inputs, image_metas):
B, T = inputs["image"].shape[:2]
image_metas[0]["B"], image_metas[0]["T"] = B, T
# move from B, T, ... -> B*T, ...
inputs = self.pack_sequence(inputs)
inputs, outputs = self.encode_decode(inputs, image_metas)
# depth_gt is only needed here for its spatial shape; a dummy tensor with the desired output shape would work as well
depth_gt = inputs["depth"]
outs = {}
outs["points"] = match_gt(
outputs["points"], depth_gt, padding1=inputs["paddings"], padding2=None
)
outs["confidence"] = match_gt(
outputs["confidence"], depth_gt, padding1=inputs["paddings"], padding2=None
)
outs["distance"] = outs["points"].norm(dim=1, keepdim=True)
outs["depth"] = outs["points"][:, -1:]
outs["rays"] = outs["points"] / torch.norm(
outs["points"], dim=1, keepdim=True
).clip(min=1e-5)
outs = self.unpack_sequence(outs, B, T)
return outs
def forward(self, inputs, image_metas):
if self.training:
return self.forward_train(inputs, image_metas)
else:
return self.forward_test(inputs, image_metas)
def encode_decode(self, inputs, image_metas=[]):
B, _, H, W = inputs["image"].shape
# shortcut for evaluation: image_metas may be empty or lack paddings, skip to avoid errors
if len(image_metas) and "paddings" in image_metas[0]:
# reindex paddings to (left, right, top, bottom) order
inputs["paddings"] = torch.tensor(
[image_meta["paddings"] for image_meta in image_metas],
device=self.device,
)[..., [0, 2, 1, 3]]
inputs["depth_paddings"] = torch.tensor(
[image_meta["depth_paddings"] for image_meta in image_metas],
device=self.device,
)
# at inference, image paddings are not added on top of the depth paddings (no crop is applied to the GT in ContextCrop)
if self.training:
inputs["depth_paddings"] = inputs["depth_paddings"] + inputs["paddings"]
else:
inputs["paddings"] = inputs["paddings"].squeeze(0)
inputs["depth_paddings"] = inputs["depth_paddings"].squeeze(0)
if inputs.get("camera", None) is not None:
inputs["rays"] = inputs["camera"].get_rays(shapes=(B, H, W))
features, tokens = self.pixel_encoder(inputs["image"])
inputs["features"] = [
self.stacking_fn(features[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
inputs["tokens"] = [
self.stacking_fn(tokens[i:j]).contiguous()
for i, j in self.slices_encoder_range
]
outputs = self.pixel_decoder(inputs, image_metas)
outputs["rays"] = rearrange(outputs["rays"], "b (h w) c -> b c h w", h=H, w=W)
pts_3d = outputs["rays"] * outputs["distance"]
outputs.update({"points": pts_3d, "depth": pts_3d[:, -1:]})
return inputs, outputs
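# encode_decode returns the (augmented) inputs together with decoder outputs holding
# per-pixel "rays", "distance", "points" (= rays * distance) and "depth" (z-component of the points).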
def compute_losses(self, outputs, inputs, image_metas):
B, _, H, W = inputs["image"].shape
losses = {"opt": {}, "stat": {}}
losses_to_be_computed = list(self.losses.keys())
# depth loss
si = torch.tensor(
[x.get("si", False) for x in image_metas], device=self.device
).reshape(B)
loss = self.losses["depth"]
depth_losses = loss(
outputs["distance"],
target=inputs["distance"],
mask=inputs["depth_mask"].clone(),
si=si,
)
losses["opt"][loss.name] = loss.weight * depth_losses.mean()
losses_to_be_computed.remove("depth")
loss = self.losses["camera"]
camera_losses = loss(
outputs["rays"], target=inputs["rays"], mask=inputs["validity_mask"].bool()
)
losses["opt"][loss.name] = loss.weight * camera_losses.mean()
losses_to_be_computed.remove("camera")
# confidence loss; after this, no further losses are expected
loss = self.losses["confidence"]
conf_losses = loss(
outputs["confidence"],
target_gt=inputs["depth"],
target_pred=outputs["depth"],
mask=inputs["depth_mask"].clone(),
)
losses["opt"][loss.name + "_conf"] = loss.weight * conf_losses.mean()
losses_to_be_computed.remove("confidence")
assert (
not losses_to_be_computed
), f"Losses {losses_to_be_computed} not computed, revise `compute_loss` method"
return losses
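# Sketch of typical inference usage (the hub repo id and image path below are illustrative,
# not asserted by this file):
#   model = UniK3D.from_pretrained("lpiccinelli/unik3d-vitl").eval()
#   rgb = torchvision.io.read_image("example.png")  # (3, H, W), uint8
#   predictions = model.infer(rgb, camera=None, normalize=True)
#   depth = predictions["depth"]  # (1, 1, H, W) metric depth (z-component of "points")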
@torch.no_grad()
def infer(
self,
rgb: torch.Tensor,
camera: torch.Tensor | Camera | None = None,
rays=None,
normalize=True,
):
ratio_bounds = self.shape_constraints["ratio_bounds"]
pixels_bounds = [
self.shape_constraints["pixels_min"],
self.shape_constraints["pixels_max"],
]
if hasattr(self, "resolution_level"):
assert (
self.resolution_level >= 0 and self.resolution_level < 10
), "resolution_level should be in [0, 10)"
pixels_range = pixels_bounds[1] - pixels_bounds[0]
interval = pixels_range / 10
new_lowbound = self.resolution_level * interval + pixels_bounds[0]
new_upbound = (self.resolution_level + 1) * interval + pixels_bounds[0]
pixels_bounds = (new_lowbound, new_upbound)
else:
warnings.warn("!! self.resolution_level not set, using default bounds !!")
# housekeeping: move tensors to the right device and add a batch dimension if needed
if rgb.ndim == 3:
rgb = rgb.unsqueeze(0)
if camera is not None:
camera = BatchCamera.from_camera(camera)
camera = camera.to(self.device)
B, _, H, W = rgb.shape
rgb = rgb.to(self.device)
# preprocess
paddings, (padded_H, padded_W) = get_paddings((H, W), ratio_bounds)
(pad_left, pad_right, pad_top, pad_bottom) = paddings
resize_factor, (new_H, new_W) = get_resize_factor(
(padded_H, padded_W), pixels_bounds
)
# -> rgb preprocess (input std-ized and resized)
if normalize:
rgb = TF.normalize(
rgb.float() / 255.0,
mean=IMAGENET_DATASET_MEAN,
std=IMAGENET_DATASET_STD,
)
rgb = F.pad(rgb, (pad_left, pad_right, pad_top, pad_bottom), value=0.0)
rgb = F.interpolate(
rgb, size=(new_H, new_W), mode="bilinear", align_corners=False
)
# -> camera preprocess
if camera is not None:
camera = camera.crop(
left=-pad_left, top=-pad_top, right=-pad_right, bottom=-pad_bottom
)
camera = camera.resize(resize_factor)
# prepare inputs
inputs = {"image": rgb}
if camera is not None:
inputs["camera"] = camera
rays = camera.get_rays(shapes=(B, new_H, new_W), noisy=False).reshape(
B, 3, new_H, new_W
)
inputs["rays"] = rays
elif rays is not None:  # externally provided rays; only used when no camera is given
rays = rays.to(self.device)
if rays.ndim == 3:
rays = rays.unsqueeze(0)
rays = F.pad(
rays,
(
max(0, pad_left),
max(0, pad_right),
max(0, pad_top),
max(0, pad_bottom),
),
value=0.0,
)
rays = F.interpolate(
rays, size=(new_H, new_W), mode="bilinear", align_corners=False
)
inputs["rays"] = rays
# run model
_, model_outputs = self.encode_decode(inputs, image_metas={})
# collect outputs
out = {}
out["confidence"] = _postprocess(
model_outputs["confidence"],
(padded_H, padded_W),
paddings=paddings,
interpolation_mode=self.interpolation_mode,
)
points = _postprocess(
model_outputs["points"],
(padded_H, padded_W),
paddings=paddings,
interpolation_mode=self.interpolation_mode,
)
rays = _postprocess(
model_outputs["rays"],
(padded_H, padded_W),
paddings=paddings,
interpolation_mode=self.interpolation_mode,
)
out["distance"] = points.norm(dim=1, keepdim=True)
out["depth"] = points[:, -1:]
out["points"] = points
out["rays"] = rays / torch.norm(rays, dim=1, keepdim=True).clip(min=1e-5)
out["lowres_features"] = model_outputs["lowres_features"]
return out
def load_pretrained(self, model_file):
dict_model = torch.load(model_file, map_location="cpu", weights_only=False)
if "model" in dict_model:
dict_model = dict_model["model"]
info = self.load_state_dict(dict_model, strict=False)
if is_main_process():
print(
f"Loaded from {model_file} for {self.__class__.__name__} results in:",
info,
)
def build(self, config):
mod = importlib.import_module("unik3d.models.encoder")
pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
pixel_encoder_config = {
**config["training"],
**config["model"]["pixel_encoder"],
**config["data"],
}
pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
pixel_encoder_embed_dims = (
pixel_encoder.embed_dims
if hasattr(pixel_encoder, "embed_dims")
else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
)
config["model"]["pixel_encoder"]["embed_dim"] = getattr(
pixel_encoder, "embed_dim"
)
config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
config["model"]["pixel_encoder"]["cls_token_embed_dims"] = getattr(
pixel_encoder, "cls_token_embed_dims", pixel_encoder_embed_dims
)
pixel_decoder = Decoder(config)
self.pixel_encoder = pixel_encoder
self.pixel_decoder = pixel_decoder
self.slices_encoder_range = list(
zip([0, *self.pixel_encoder.depths[:-1]], self.pixel_encoder.depths)
)
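# slices_encoder_range pairs consecutive encoder depth boundaries; e.g. for hypothetical
# depths [6, 12, 18, 24] it becomes [(0, 6), (6, 12), (12, 18), (18, 24)], so each slice
# groups the encoder outputs feeding one decoder stage.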
self.stacking_fn = last_stack
self.shape_constraints = config["data"]["shape_constraints"]
self.interpolation_mode = "bilinear"
def build_losses(self, config):
self.losses = {}
for loss_name, loss_config in config["training"]["losses"].items():
mod = importlib.import_module("unik3d.ops.losses")
loss_factory = getattr(mod, loss_config["name"])
self.losses[loss_name] = loss_factory.build(loss_config)
def get_params(self, config):
if hasattr(self.pixel_encoder, "get_params"):
encoder_p, _ = self.pixel_encoder.get_params(
config["model"]["pixel_encoder"]["lr"],
config["training"]["wd"],
config["training"]["ld"],
)
else:
encoder_p, _ = get_params(
self.pixel_encoder,
config["model"]["pixel_encoder"]["lr"],
config["training"]["wd"],
)
decoder_p = self.pixel_decoder.get_params(
config["training"]["lr"], config["training"]["wd"]
)
return [*encoder_p, *decoder_p]
def step(self):
self.pixel_decoder.steps += 1
def parameters_grad(self):
for p in self.parameters():
if p.requires_grad:
yield p
@property
def device(self):
return next(self.parameters()).device