""" | |
Author: Luigi Piccinelli | |
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) | |
""" | |
from math import tanh | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
from timm.models.layers import trunc_normal_ | |
from unik3d.layers import (MLP, AttentionBlock, AttentionLayer, GradChoker, | |
PositionEmbeddingSine, ResUpsampleBil) | |
from unik3d.utils.coordinate import coords_grid | |
from unik3d.utils.geometric import flat_interpolate | |
from unik3d.utils.misc import get_params | |
from unik3d.utils.positional_embedding import generate_fourier_features | |
from unik3d.utils.sht import rsh_cart_3 | |
from unik3d.utils.misc import profile_method | |


def orthonormal_init(num_tokens, dims):
    pe = torch.randn(num_tokens, dims)
    # Apply the Gram-Schmidt process to make the rows orthonormal.
    # Awful loop..
    for i in range(num_tokens):
        for j in range(i):
            pe[i] -= torch.dot(pe[i], pe[j]) * pe[j]
        pe[i] = F.normalize(pe[i], p=2, dim=0)
    return pe
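
# Illustrative check (a minimal sketch, not part of the original module):
#
#     pe = orthonormal_init(4, 16)
#     # rows are unit-norm and mutually orthogonal, so pe @ pe.T ~ identity
#     assert torch.allclose(pe @ pe.T, torch.eye(4), atol=1e-5)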


class ListAdapter(nn.Module):
    def __init__(self, input_dims: list[int], hidden_dim: int):
        super().__init__()
        self.input_adapters = nn.ModuleList([])
        self.num_chunks = len(input_dims)
        self.checkpoint = True
        for input_dim in input_dims:
            self.input_adapters.append(nn.Linear(input_dim, hidden_dim))

    def forward(self, xs: list[torch.Tensor]) -> list[torch.Tensor]:
        outs = [self.input_adapters[i](x) for i, x in enumerate(xs)]
        return outs
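
# Illustrative usage (a minimal sketch, assuming four encoder levels):
#
#     adapter = ListAdapter([96, 192, 384, 768], hidden_dim=512)
#     feats = [torch.randn(2, 1024, d) for d in (96, 192, 384, 768)]
#     feats = adapter(feats)  # each level is projected to (2, 1024, 512)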


class AngularModule(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        expansion: int = 4,
        dropout: float = 0.0,
        layer_scale: float = 1.0,
        **kwargs,
    ):
        super().__init__()
        self.pin_params = 3
        self.deg1_params = 3
        self.deg2_params = 5
        self.deg3_params = 7
        self.num_params = (
            self.pin_params + self.deg1_params + self.deg2_params + self.deg3_params
        )
        self.aggregate1 = AttentionBlock(
            hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            dropout=dropout,
            layer_scale=layer_scale,
        )
        self.aggregate2 = AttentionBlock(
            hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            dropout=dropout,
            layer_scale=layer_scale,
        )
        self.latents_pos = nn.Parameter(
            torch.randn(1, self.num_params, hidden_dim), requires_grad=True
        )
        self.in_features = nn.Identity()
        self.project_pin = nn.Linear(
            hidden_dim, self.pin_params * hidden_dim, bias=False
        )
        self.project_deg1 = nn.Linear(
            hidden_dim, self.deg1_params * hidden_dim, bias=False
        )
        self.project_deg2 = nn.Linear(
            hidden_dim, self.deg2_params * hidden_dim, bias=False
        )
        self.project_deg3 = nn.Linear(
            hidden_dim, self.deg3_params * hidden_dim, bias=False
        )
        self.out_pinhole = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=1)
        self.out_deg1 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)
        self.out_deg2 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)
        self.out_deg3 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)

    def fill_intrinsics(self, x):
        hfov, cx, cy = x.unbind(dim=-1)
        # 1.1 is a magic number such that hfov == pi / 2 when the raw logit is 0:
        # 2 * pi * sigmoid(-1.1) ~= pi / 2.
        hfov = torch.sigmoid(hfov - 1.1)
        ratio = self.shapes[0] / self.shapes[1]
        vfov = hfov * ratio
        cx = torch.sigmoid(cx)
        cy = torch.sigmoid(cy)
        correction_tensor = torch.tensor(
            [2 * torch.pi, 2 * torch.pi, self.shapes[1], self.shapes[0]],
            device=x.device,
            dtype=x.dtype,
        )
        intrinsics = torch.stack([hfov, vfov, cx, cy], dim=1)
        intrinsics = correction_tensor.unsqueeze(0) * intrinsics
        return intrinsics
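
    # A worked check of the hfov parametrization (sketch, not executed):
    # sigmoid(0 - 1.1) ~= 0.2497 and 2 * pi * 0.2497 ~= 1.57 ~= pi / 2, so a zero
    # logit corresponds to a ~90 degree horizontal FoV, while cx and cy are
    # squashed to (0, 1) and rescaled by image width and height respectively.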

    def forward(self, cls_tokens) -> tuple[torch.Tensor, torch.Tensor]:
        latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
        pin_tokens, deg1_tokens, deg2_tokens, deg3_tokens = cls_tokens.chunk(4, dim=1)
        pin_tokens = rearrange(
            self.project_pin(pin_tokens), "b n (h c) -> b (n h) c", h=self.pin_params
        )
        deg1_tokens = rearrange(
            self.project_deg1(deg1_tokens), "b n (h c) -> b (n h) c", h=self.deg1_params
        )
        deg2_tokens = rearrange(
            self.project_deg2(deg2_tokens), "b n (h c) -> b (n h) c", h=self.deg2_params
        )
        deg3_tokens = rearrange(
            self.project_deg3(deg3_tokens), "b n (h c) -> b (n h) c", h=self.deg3_params
        )
        tokens = torch.cat([pin_tokens, deg1_tokens, deg2_tokens, deg3_tokens], dim=1)
        tokens = self.aggregate1(tokens, pos_embed=latents_pos)
        tokens = self.aggregate2(tokens, pos_embed=latents_pos)
        tokens_pinhole, tokens_deg1, tokens_deg2, tokens_deg3 = torch.split(
            tokens,
            [self.pin_params, self.deg1_params, self.deg2_params, self.deg3_params],
            dim=1,
        )
        x = self.out_pinhole(tokens_pinhole).squeeze(-1)
        d1 = self.out_deg1(tokens_deg1)
        d2 = self.out_deg2(tokens_deg2)
        d3 = self.out_deg3(tokens_deg3)
        camera_intrinsics = self.fill_intrinsics(x)
        return camera_intrinsics, torch.cat([d1, d2, d3], dim=1)

    def set_shapes(self, shapes: tuple[int, int]):
        self.shapes = shapes
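
# Shape sketch for AngularModule.forward (illustrative, assuming hidden_dim=512
# and one camera token per encoder level):
#
#     module.set_shapes((H, W))
#     intrinsics, sh_coeffs = module(cls_tokens)   # cls_tokens: (B, 4, 512)
#     # intrinsics: (B, 4) as [hfov, vfov, cx, cy]; sh_coeffs: (B, 15, 3)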


class RadialModule(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        expansion: int = 4,
        depths: int | list[int] = 4,
        camera_dim: int = 256,
        dropout: float = 0.0,
        kernel_size: int = 7,
        layer_scale: float = 1.0,
        out_dim: int = 1,
        num_prompt_blocks: int = 1,
        use_norm: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.camera_dim = camera_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.ups = nn.ModuleList([])
        self.depth_mlp = nn.ModuleList([])
        self.process_features = nn.ModuleList([])
        self.project_features = nn.ModuleList([])
        self.out = nn.ModuleList([])
        self.prompt_camera = nn.ModuleList([])
        mult = 2
        self.to_latents = nn.Linear(hidden_dim, hidden_dim)
        for _ in range(4):
            self.prompt_camera.append(
                AttentionLayer(
                    num_blocks=num_prompt_blocks,
                    dim=hidden_dim,
                    num_heads=num_heads,
                    expansion=expansion,
                    dropout=dropout,
                    layer_scale=-1.0,
                    context_dim=hidden_dim,
                )
            )
        for i, depth in enumerate(depths):
            current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
            next_dim = mult * hidden_dim // int(2 ** (i + 1))
            output_dim = max(next_dim, out_dim)
            self.process_features.append(
                nn.ConvTranspose2d(
                    hidden_dim,
                    current_dim,
                    kernel_size=max(1, 2 * i),
                    stride=max(1, 2 * i),
                    padding=0,
                )
            )
            self.ups.append(
                ResUpsampleBil(
                    current_dim,
                    output_dim=output_dim,
                    expansion=expansion,
                    layer_scale=layer_scale,
                    kernel_size=kernel_size,
                    num_layers=depth,
                    use_norm=use_norm,
                )
            )
            depth_mlp = (
                nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
                if i == len(depths) - 1
                else nn.Identity()
            )
            self.depth_mlp.append(depth_mlp)
        self.confidence_mlp = nn.Sequential(
            nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim)
        )
        self.to_depth_lr = nn.Conv2d(
            output_dim,
            output_dim // 2,
            kernel_size=3,
            padding=1,
            padding_mode="reflect",
        )
        self.to_confidence_lr = nn.Conv2d(
            output_dim,
            output_dim // 2,
            kernel_size=3,
            padding=1,
            padding_mode="reflect",
        )
        self.to_depth_hr = nn.Sequential(
            nn.Conv2d(
                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect"
            ),
            nn.LeakyReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )
        self.to_confidence_hr = nn.Sequential(
            nn.Conv2d(
                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect"
            ),
            nn.LeakyReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
        )

    def set_original_shapes(self, shapes: tuple[int, int]):
        self.original_shapes = shapes

    def set_shapes(self, shapes: tuple[int, int]):
        self.shapes = shapes

    def embed_rays(self, rays):
        rays_embedding = flat_interpolate(
            rays, old=self.original_shapes, new=self.shapes, antialias=True
        )
        rays_embedding = rays_embedding / torch.norm(
            rays_embedding, dim=-1, keepdim=True
        ).clip(min=1e-4)
        x, y, z = rays_embedding[..., 0], rays_embedding[..., 1], rays_embedding[..., 2]
        polar = torch.acos(z)
        x_clipped = x.abs().clip(min=1e-3) * (2 * (x >= 0).int() - 1)
        azimuth = torch.atan2(y, x_clipped)
        rays_embedding = torch.stack([polar, azimuth], dim=-1)
        rays_embedding = generate_fourier_features(
            rays_embedding,
            dim=self.hidden_dim,
            max_freq=max(self.shapes) // 2,
            use_log=True,
            cat_orig=False,
        )
        return rays_embedding
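
    # embed_rays, in words: the unit rays are resampled to the feature-map
    # resolution, renormalized, converted to (polar, azimuth) angles on the unit
    # sphere, and finally expanded by generate_fourier_features into a
    # hidden_dim-channel embedding that conditions the image features below.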

    def condition(self, feat, rays_embeddings):
        conditioned_features = [
            prompter(rearrange(feature, "b h w c -> b (h w) c"), rays_embeddings)
            for prompter, feature in zip(self.prompt_camera, feat)
        ]
        return conditioned_features

    def process(self, features_list, rays_embeddings):
        conditioned_features = self.condition(features_list, rays_embeddings)
        init_latents = self.to_latents(conditioned_features[0])
        init_latents = rearrange(
            init_latents, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1]
        ).contiguous()
        conditioned_features = [
            rearrange(
                x, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1]
            ).contiguous()
            for x in conditioned_features
        ]
        latents = init_latents
        out_features = []
        for i, up in enumerate(self.ups):
            latents = latents + self.process_features[i](conditioned_features[i + 1])
            latents = up(latents)
            out_features.append(latents)
        return out_features, init_latents

    def depth_proj(self, out_features):
        depths = []
        h_out, w_out = out_features[-1].shape[-2:]
        # aggregate output and project to depth
        for i, (layer, features) in enumerate(zip(self.depth_mlp, out_features)):
            out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            if i < len(self.depth_mlp) - 1:
                continue
            depths.append(out_depth_features)
        out_depth_features = F.interpolate(
            out_depth_features, size=(h_out, w_out), mode="bilinear", align_corners=True
        )
        logdepth = self.to_depth_lr(out_depth_features)
        logdepth = F.interpolate(
            logdepth, size=self.original_shapes, mode="bilinear", align_corners=True
        )
        logdepth = self.to_depth_hr(logdepth)
        return logdepth

    def confidence_proj(self, out_features):
        highres_features = out_features[-1].permute(0, 2, 3, 1)
        confidence = self.confidence_mlp(highres_features).permute(0, 3, 1, 2)
        confidence = self.to_confidence_lr(confidence)
        confidence = F.interpolate(
            confidence, size=self.original_shapes, mode="bilinear", align_corners=True
        )
        confidence = self.to_confidence_hr(confidence)
        return confidence

    def decode(self, out_features):
        logdepth = self.depth_proj(out_features)
        confidence = self.confidence_proj(out_features)
        return logdepth, confidence

    def forward(
        self,
        features: list[torch.Tensor],
        rays_hr: torch.Tensor,
        pos_embed,
        level_embed,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        rays_embeddings = self.embed_rays(rays_hr)
        features, lowres_features = self.process(features, rays_embeddings)
        logdepth, logconf = self.decode(features)
        return logdepth, logconf, lowres_features
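
# Call-flow sketch for RadialModule (illustrative; note that pos_embed and
# level_embed are accepted by forward but not used in its body):
#
#     module.set_shapes((h, w))            # feature-map resolution
#     module.set_original_shapes((H, W))   # output / image resolution
#     logdepth, logconf, lowres = module(features, rays_hr, pos_embed, level_embed)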


class Decoder(nn.Module):
    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.build(config)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def run_camera(self, cls_tokens, original_shapes, rays_gt):
        H, W = original_shapes
        # camera layer
        intrinsics, sh_coeffs = self.angular_module(cls_tokens=cls_tokens)
        B, N = intrinsics.shape
        device = intrinsics.device
        dtype = intrinsics.dtype
        id_coords = coords_grid(B, H, W, device=sh_coeffs.device)

        # FoV-based back-projection to angles: intrinsics are [hfov, vfov, cx, cy].
        longitude = (
            (id_coords[:, 0] - intrinsics[:, 2].view(-1, 1, 1))
            / W
            * intrinsics[:, 0].view(-1, 1, 1)
        )
        latitude = (
            (id_coords[:, 1] - intrinsics[:, 3].view(-1, 1, 1))
            / H
            * intrinsics[:, 1].view(-1, 1, 1)
        )
        x = torch.cos(latitude) * torch.sin(longitude)
        z = torch.cos(latitude) * torch.cos(longitude)
        y = -torch.sin(latitude)
        unit_sphere = torch.stack([x, y, z], dim=-1)
        unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=-1, keepdim=True).clip(
            min=1e-5
        )
        harmonics = rsh_cart_3(unit_sphere)[..., 1:]  # remove constant-value harmonic
        rays_pred = torch.einsum("bhwc,bcd->bhwd", harmonics, sh_coeffs)
        rays_pred = rays_pred / torch.norm(rays_pred, dim=-1, keepdim=True).clip(
            min=1e-5
        )
        rays_pred = rays_pred.permute(0, 3, 1, 2)

        ### LEGACY CODE for training
        # if self.training:
        #     prob = 1 - tanh(self.steps / self.num_steps)
        #     where_use_gt_rays = torch.rand(B, 1, 1, device=device, dtype=dtype) < prob
        #     where_use_gt_rays = where_use_gt_rays.int()
        #     rays = rays_gt * where_use_gt_rays + rays_pred * (1 - where_use_gt_rays)
        #     # should also clean nans
        if self.training:
            rays = rays_pred
        else:
            rays = rays_gt if rays_gt is not None else rays_pred
        rays = rearrange(rays, "b c h w -> b (h w) c")
        return intrinsics, rays
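
    # Ray-prediction sketch: every pixel is lifted to a unit direction from the
    # predicted FoV and principal point, real spherical harmonics up to degree 3
    # are evaluated on that direction (15 values after dropping the constant
    # term), and the predicted (B, 15, 3) coefficients mix them into per-pixel
    # ray directions, which are then re-normalized.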

    def forward(self, inputs, image_metas) -> dict[str, torch.Tensor]:
        B, C, H, W = inputs["image"].shape
        device = inputs["image"].device
        rays_gt = inputs.get("rays", None)

        # get features in b n d format
        common_shape = inputs["features"][0].shape[1:3]
        # project each encoder level to the decoder hidden dimension
        features = self.input_adapter(inputs["features"])

        # positional embeddings, spatial and level
        level_embed = self.level_embeds.repeat(
            B, common_shape[0] * common_shape[1], 1, 1
        )
        level_embed = rearrange(level_embed, "b n l d -> b (n l) d")
        dummy_tensor = torch.zeros(
            B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False
        )
        pos_embed = self.pos_embed(dummy_tensor)
        pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(1, 4, 1)

        # get cls tokens projections
        camera_tokens = inputs["tokens"]
        camera_tokens = [self.choker(x.contiguous()) for x in camera_tokens]
        camera_tokens = self.camera_token_adapter(camera_tokens)
        self.angular_module.set_shapes((H, W))
        intrinsics, rays = self.run_camera(
            torch.cat(camera_tokens, dim=1),
            original_shapes=(H, W),
            rays_gt=rays_gt,
        )

        # run bulk of the model
        self.radial_module.set_shapes(common_shape)
        self.radial_module.set_original_shapes((H, W))
        logradius, logconfidence, lowres_features = self.radial_module(
            features=features,
            rays_hr=rays,
            pos_embed=pos_embed,
            level_embed=level_embed,
        )
        radius = torch.exp(logradius.clip(min=-8.0, max=8.0) + 2.0)
        confidence = torch.exp(logconfidence.clip(min=-8.0, max=10.0))
        outputs = {
            "distance": radius,
            "lowres_features": lowres_features,
            "confidence": confidence,
            "K": intrinsics,
            "rays": rays,
        }
        return outputs
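
    # Output sketch: "distance" is the per-pixel radial distance (exp of the
    # clipped log prediction), "rays" are flattened unit directions of shape
    # (B, H*W, 3), and "K" holds [hfov, vfov, cx, cy]; under this reading of the
    # code, 3D points follow as the rays scaled by the distance.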

    def no_weight_decay_keywords(self):
        return {"latents_pos", "level_embeds"}

    def get_params(self, lr, wd):
        angles_p, _ = get_params(self.angular_module, lr, wd)
        radius_p, _ = get_params(self.radial_module, lr, wd)
        tokens_p, _ = get_params(self.camera_token_adapter, lr, wd)
        input_p, _ = get_params(self.input_adapter, lr, wd)
        return [*tokens_p, *angles_p, *input_p, *radius_p]

    def build(self, config):
        input_dims = config["model"]["pixel_encoder"]["embed_dims"]
        hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
        expansion = config["model"]["expansion"]
        num_heads = config["model"]["num_heads"]
        dropout = config["model"]["pixel_decoder"]["dropout"]
        layer_scale = config["model"]["layer_scale"]
        depth = config["model"]["pixel_decoder"]["depths"]
        depths_encoder = config["model"]["pixel_encoder"]["depths"]
        out_dim = config["model"]["pixel_decoder"]["out_dim"]
        kernel_size = config["model"]["pixel_decoder"]["kernel_size"]

        self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder))
        input_dims = [input_dims[d - 1] for d in depths_encoder]
        self.steps = 0
        self.num_steps = config["model"].get("num_steps", 100000)
        camera_dims = input_dims

        self.choker = GradChoker(config["model"]["pixel_decoder"]["detach"])
        self.input_adapter = ListAdapter(input_dims, hidden_dim)
        self.camera_token_adapter = ListAdapter(camera_dims, hidden_dim)
        self.angular_module = AngularModule(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            dropout=dropout,
            layer_scale=layer_scale,
        )
        self.radial_module = RadialModule(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            expansion=expansion,
            depths=depth,
            dropout=dropout,
            camera_dim=96,
            layer_scale=layer_scale,
            out_dim=out_dim,
            kernel_size=kernel_size,
            num_prompt_blocks=config["model"]["pixel_decoder"]["num_prompt_blocks"],
            use_norm=False,
        )
        self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
        self.level_embeds = nn.Parameter(
            orthonormal_init(len(input_dims), hidden_dim).reshape(
                1, 1, len(input_dims), hidden_dim
            ),
            requires_grad=False,
        )
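

# Minimal usage sketch (illustrative only; the config keys mirror the accesses in
# build() above, everything about the surrounding pipeline is an assumption):
#
#     decoder = Decoder(config)
#     outputs = decoder(
#         {"image": image, "features": encoder_feats, "tokens": cls_tokens},
#         image_metas=None,
#     )
#     distance, rays, K = outputs["distance"], outputs["rays"], outputs["K"]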