Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,569 Bytes
1ea89dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import numpy as np
import torch
from einops import rearrange
from unik3d.utils.camera import CameraSampler
from unik3d.utils.coordinate import coords_grid
from unik3d.utils.geometric import iou
# Optional dependency: soft splatting is only needed to warp images during
# camera augmentation. The catch is intentionally broad (not just ImportError),
# presumably because a compiled extension can also fail to load with other
# errors; in every failure case ``splatting_function`` is left as ``None``.
try:
    from splatting import splatting_function
except Exception as exc:  # deliberate best-effort import
    splatting_function = None
    print(
        f"Splatting not available ({exc!r}), please install it from "
        "github.com/hperrot/splatting"
    )
def fill(self, rgb, mask):
    """Overwrite the masked-out pixels of every image in ``rgb``, in place.

    For each sample, one filler is drawn with ``np.random.choice`` from:
    standard normal noise, a constant ``-2``, a constant ``2``, or a
    constant ``0``.

    Args:
        self: Not used.
        rgb: Batched tensor whose first two dimensions are ``(B, C)``; it is
            written in place.
        mask: Tensor repeated ``C`` times along dim 1 and cast to bool;
            ``True`` marks pixels to keep, ``False`` marks pixels to overwrite.

    Returns:
        The same ``rgb`` tensor object, after filling.
    """

    def _gaussian(size, device):
        return torch.normal(0, 1.0, size=size, device=device)

    def _constant(value):
        def _filler(size, device):
            return torch.full(size, value, device=device, dtype=torch.float32)

        return _filler

    # Keep this order stable: it fixes which filler a given numpy seed picks.
    fillers = [_gaussian, _constant(-2.0), _constant(2.0), _constant(0.0)]

    channels = rgb.shape[1]
    keep = mask.repeat(1, channels, 1, 1).bool()
    for idx in range(rgb.shape[0]):
        chosen = np.random.choice(fillers)
        holes = ~keep[idx]
        # Indexing with ``idx`` yields a view, so this assignment writes into ``rgb``.
        rgb[idx][holes] = chosen(size=(int(holes.sum()),), device=rgb.device)
    return rgb
@torch.autocast(device_type="cuda", enabled=True, dtype=torch.float32)
def augment_camera(self, inputs, camera_sampler):
    """Randomly re-render some samples of a batch under a newly sampled camera.

    Each eligible sample is selected with probability 0.1
    (``torch.rand(...) > 0.9``). A sample is eligible when
    ``inputs["valid_camera"]`` is set for it and its ``inputs["depth_mask"]``
    has a positive mean. A selected sample is warped from
    ``inputs["camera"][i]`` into the camera returned by ``camera_sampler``
    using softmax splatting.

    A selected sample is skipped and left untouched when:
    - its field of view is below 40 degrees,
    - its depth guidance contains negative values,
    - the median flow magnitude exceeds 10% of ``max(H, W)``, or
    - the IoU between the warped and original validity masks falls below
      0.7 (if the valid area grows) or 0.25 (otherwise).

    For every augmented sample ``i``, these entries of ``inputs`` are
    overwritten in place:
    - ``"image"`` (warped RGB, holes filled via ``self.fill``),
    - ``"camera"``, ``"rays"``, ``"depth_mask"``, ``"validity_mask"``,
    - ``"grid_sample"`` (the projected coordinates, kept to un-warp
      predictions for the loss).

    Args:
        self: Object providing a ``fill(rgb, mask)`` method.
        inputs: Batch dictionary. Reads ``"image"`` (unpacked as
            ``(B, C, H, W)``), ``"depth"``, ``"depth_guidance"``,
            ``"validity_mask"``, ``"valid_camera"``, ``"depth_mask"`` and
            ``"camera"``; also writes ``"rays"`` and ``"grid_sample"``.
        camera_sampler: Called as
            ``camera_sampler(fx, fy, cx, cy, mult=1.0, ratio=ratio, H=H)``;
            the returned camera must provide ``project``,
            ``projection_mask``, ``overlap_mask`` and ``get_rays``.

    Returns:
        The same ``inputs`` dictionary, modified in place.

    Raises:
        ImportError: If a sample is about to be warped but the optional
            ``splatting`` package could not be imported.
    """
    rgb = inputs["image"]
    # from GT if dense/synthetic or from a model's metric output
    guidance = inputs["depth_guidance"]
    # NOTE: if inputs["validity_mask"] is already bool, ``.bool()`` returns the
    # same tensor, so the in-place update at the end of the loop also changes
    # ``validity_mask``. That write happens after the last read for sample i.
    validity_mask = inputs["validity_mask"].bool()
    # Only dtype/device of the depth are needed; cloning it was wasted work.
    dtype, device = inputs["depth"].dtype, inputs["depth"].device
    B, _, H, W = rgb.shape

    augmentable_indices = inputs["valid_camera"] & (
        inputs["depth_mask"].reshape(B, -1).float().mean(dim=1) > 0.0
    )
    augment_indices = torch.rand(B, 1, 1, device=device, dtype=dtype) > 0.9
    augment_indices[~augmentable_indices] = False
    id_coords = coords_grid(B, H, W, device=device)

    # NOTE(review): augment_indices is neither returned nor stored in
    # ``inputs``, so the ``= False`` writes below only affect this local tensor.
    augment_indices = augment_indices.reshape(-1)
    for i, is_augment in enumerate(augment_indices):
        if not is_augment:
            continue

        pinhole_camera = inputs["camera"][i]
        fov = max(pinhole_camera.hfov[0], pinhole_camera.vfov[0]) * 180 / np.pi
        ratio = min(70.0 / fov, 1.0)  # decrease effect for larger fov
        if fov < 40.0:  # skips ~5%
            augment_indices[i] = False
            continue

        rgb_i = rgb[i : i + 1]
        id_coords_i = id_coords[i : i + 1]
        validity_mask_i = validity_mask[i : i + 1]

        # get rescaled depth
        depth = guidance[i : i + 1]
        if (depth < 0.0).any():
            augment_indices[i] = False
            continue
        # NOTE(review): the original author questioned this sqrt ("why sqrt??").
        # ``sqrt`` returns a new tensor, so the in-place writes to ``depth``
        # below never modify ``guidance``.
        depth = depth.sqrt()
        # Invalid pixels are pushed to twice the max depth.
        depth[~validity_mask_i] = depth.max() * 2.0

        fx, fy, cx, cy = pinhole_camera.params[:, :4].unbind(dim=-1)
        new_camera = camera_sampler(fx, fy, cx, cy, mult=1.0, ratio=ratio, H=H)
        unprojected = pinhole_camera.reconstruct(depth)
        projected = new_camera.project(unprojected)
        projection_mask = new_camera.projection_mask
        overlap_mask = (
            new_camera.overlap_mask
            if new_camera.overlap_mask is not None
            else torch.ones_like(projection_mask)
        )
        mask = validity_mask_i & overlap_mask

        # TODO: if it is actually going out, we need to remember the regions.
        # Remember when the tangential distortion was keeping the validity_mask
        # border after re-warping.
        # Need a better way to define the overlap class: in case of a vortex-style
        # effect it will mask the wrong parts.
        # Also, is_collapse does not take the vortex effect into consideration;
        # how can we avoid a vortex in the first place?
        # is_collapse: every pixel of the first image row projects to a
        # non-negative coordinate in channel 1 of ``projected``.
        is_collapse = (projected[0, 1, 0, :] >= 0.0).all()
        if is_collapse:
            projected[~mask.repeat(1, 2, 1, 1)] = id_coords_i[~mask.repeat(1, 2, 1, 1)]
        flow = projected - id_coords_i
        depth[~mask] = depth.max() * 2.0

        if flow.norm(dim=1).median() / max(H, W) > 0.1:  # extreme cases
            augment_indices[i] = False
            continue

        # warp via soft splat
        if splatting_function is None:
            # Previously this failed with "'NoneType' object is not callable".
            raise ImportError(
                "Camera augmentation requires the `splatting` package; "
                "install it from github.com/hperrot/splatting"
            )
        depth_image = torch.cat([rgb_i, guidance[i : i + 1], mask], dim=1)
        # Splat weights -log(1 + depth): larger depth gives a smaller weight,
        # so nearer points presumably dominate the softmax blend.
        depth_image = splatting_function(
            "softmax", depth_image, flow, -torch.log(1 + depth.clip(0.01))
        )
        rgb_warp = depth_image[:, :3]
        validity_mask_i = depth_image[:, -1:] > 0.0

        expanding = validity_mask_i.sum() > validity_mask[i : i + 1].sum()
        threshold = 0.7 if expanding else 0.25
        _iou = iou(validity_mask_i, validity_mask[i : i + 1])
        if _iou < threshold:  # too strong augmentation, lose most of the image
            augment_indices[i] = False
            continue

        # where it goes out
        mask_unwarpable = projection_mask & overlap_mask
        inputs["depth_mask"][i] = inputs["depth_mask"][i] & mask_unwarpable.squeeze(0)

        # compute new rays, and use them for supervision (NaN rays are zeroed)
        rays = new_camera.get_rays(shapes=(1, H, W))
        rays = rearrange(rays, "b c h w -> b (h w) c")
        inputs["rays"][i] = torch.where(
            rays.isnan().any(dim=-1, keepdim=True), 0.0, rays
        )[0]

        # update image, camera and validity_mask
        inputs["camera"][i] = new_camera
        inputs["image"][i] = self.fill(rgb_warp, validity_mask_i)[0]
        inputs["validity_mask"][i] = inputs["validity_mask"][i] & mask_unwarpable[0]
        # needed to reverse the augmentation for loss-computation (i.e. un-warp the prediction)
        inputs["grid_sample"][i] = projected[0]
    return inputs
|