import numpy as np
import torch
from einops import rearrange

from unik3d.utils.camera import CameraSampler
from unik3d.utils.coordinate import coords_grid
from unik3d.utils.geometric import iou

try:
    from splatting import splatting_function
except Exception as e:
    splatting_function = None
    print(
        f"Splatting not available ({e}), please install it from github.com/hperrot/splatting"
    )


def fill(self, rgb, mask):
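    """
    Replace pixels outside the validity mask with a randomly chosen filler
    (Gaussian noise, black, white, or zero), presumably so the network cannot
    key on a fixed fill color. Takes `self` because it is bound as a method.
    """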
    def fill_noise(size, device):
        return torch.normal(0, 1.0, size=size, device=device)

    def fill_black(size, device):
        return -2 * torch.ones(size, device=device, dtype=torch.float32)

    def fill_white(size, device):
        return 2 * torch.ones(size, device=device, dtype=torch.float32)

    def fill_zero(size, device):
        return torch.zeros(size, device=device, dtype=torch.float32)

    B, C = rgb.shape[:2]
    validity_mask = mask.repeat(1, C, 1, 1).bool()
    for i in range(B):
        filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero])
        rgb[i][~validity_mask[i]] = filler_fn(
            size=rgb[i][~validity_mask[i]].shape, device=rgb.device
        )
    return rgb


@torch.autocast(device_type="cuda", enabled=True, dtype=torch.float32)
def augment_camera(self, inputs, camera_sampler):
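    """
    Camera augmentation: for a randomly selected subset of samples, re-project
    the scene onto a newly sampled camera and forward-warp the image with
    softmax splatting, using the depth guidance to resolve occlusions. Updates
    the image, camera, rays, and validity/depth masks, and stores the warp grid
    so predictions can be un-warped for loss computation.
    """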
    rgb = inputs["image"]
    gt = inputs["depth"].clone()
    guidance = inputs[
        "depth_guidance"
    ]  # from GT if dense/synthetic or from a model's metric output
    validity_mask = inputs["validity_mask"].bool()
    dtype, device = gt.dtype, gt.device
    B, C, H, W = rgb.shape
    augmentable_indices = inputs["valid_camera"] & (
        inputs["depth_mask"].reshape(B, -1).float().mean(dim=1) > 0.0
    )

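    # Augment each eligible sample with probability ~0.1.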
    augment_indices = torch.rand(B, 1, 1, device=device, dtype=dtype) > 0.9
    augment_indices[~augmentable_indices] = False
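    # identity pixel-coordinate grid, used as the reference for the warp flow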
    id_coords = coords_grid(B, H, W, device=device)
    augment_indices = augment_indices.reshape(-1)
    for i, is_augment in enumerate(augment_indices):
        if not is_augment:
            continue

        pinhole_camera = inputs["camera"][i]
        fov = max(pinhole_camera.hfov[0], pinhole_camera.vfov[0]) * 180 / np.pi
        ratio = min(70.0 / fov, 1.0)  # decrease effect for larger fov
        if fov < 40.0:  # skips ~5%
            augment_indices[i] = False
            continue

        rgb_i = rgb[i : i + 1]
        id_coords_i = id_coords[i : i + 1]

        validity_mask_i = validity_mask[i : i + 1]
        depth = guidance[i : i + 1]

        if (depth < 0.0).any():
            augment_indices[i] = False
            continue

        depth = depth.sqrt()  # why sqrt??
        depth[~validity_mask_i] = depth.max() * 2.0  # push invalid pixels far away so they get minimal weight in the softmax splat

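        # Sample a new camera and compute the warp flow by unprojecting with the
        # original pinhole camera and re-projecting into the new camera.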
        fx, fy, cx, cy = pinhole_camera.params[:, :4].unbind(dim=-1)
        new_camera = camera_sampler(fx, fy, cx, cy, mult=1.0, ratio=ratio, H=H)
        unprojected = pinhole_camera.reconstruct(depth)
        projected = new_camera.project(unprojected)
        projection_mask = new_camera.projection_mask
        overlap_mask = (
            new_camera.overlap_mask
            if new_camera.overlap_mask is not None
            else torch.ones_like(projection_mask)
        )
        mask = validity_mask_i & overlap_mask

        # If content actually warps out of the image, we need to remember those regions
        # (e.g. tangential distortion can keep the validity_mask border intact after re-warping).
        # TODO: a better definition of the overlap region is needed; with vortex-style warps
        # it will mask the wrong parts, and is_collapse does not account for the vortex
        # effect either. How can we avoid the vortex in the first place?
        is_collapse = (projected[0, 1, 0, :] >= 0.0).all()
        if is_collapse:
            projected[~mask.repeat(1, 2, 1, 1)] = id_coords_i[~mask.repeat(1, 2, 1, 1)]
        flow = projected - id_coords_i
        depth[~mask] = depth.max() * 2.0

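        # Reject the augmentation when the median displacement exceeds 10% of the image size.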
        if flow.norm(dim=1).median() / max(H, W) > 0.1:  # extreme cases
            augment_indices[i] = False
            continue

        # warp via softmax splatting; the weight -log(1 + depth) favors nearer pixels at occlusions
        depth_image = torch.cat([rgb_i, guidance[i : i + 1], mask], dim=1)
        depth_image = splatting_function(
            "softmax", depth_image, flow, -torch.log(1 + depth.clip(0.01))
        )
        rgb_warp = depth_image[:, :3]
        validity_mask_i = depth_image[:, -1:] > 0.0

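        # Keep the augmentation only if enough of the original valid region survives;
        # require a stricter IoU when the warp expands the valid area.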
        expanding = validity_mask_i.sum() > validity_mask[i : i + 1].sum()
        threshold = 0.7 if expanding else 0.25
        _iou = iou(validity_mask_i, validity_mask[i : i + 1])
        if _iou < threshold:  # augmentation too strong, most of the image is lost
            augment_indices[i] = False
            continue

        # keep depth supervision only where the new projection is valid and overlapping
        mask_unwarpable = projection_mask & overlap_mask
        inputs["depth_mask"][i] = inputs["depth_mask"][i] & mask_unwarpable.squeeze(0)

        # compute the new camera's rays and use them for supervision
        rays = new_camera.get_rays(shapes=(1, H, W))
        rays = rearrange(rays, "b c h w -> b (h w) c")
        inputs["rays"][i] = torch.where(
            rays.isnan().any(dim=-1, keepdim=True), 0.0, rays
        )[0]

        # update image, camera and validity_mask
        inputs["camera"][i] = new_camera
        inputs["image"][i] = self.fill(rgb_warp, validity_mask_i)[0]
        inputs["validity_mask"][i] = inputs["validity_mask"][i] & mask_unwarpable[0]

        # needed to reverse the augmentation for loss-computation (i.e. un-warp the prediction)
        inputs["grid_sample"][i] = projected[0]

    return inputs
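

# Usage sketch (an assumption inferred from the `self` parameters, not verified against
# the rest of the UniK3D training code): these functions are bound as methods of the
# training pipeline and called per batch, roughly as
#   inputs = self.augment_camera(inputs, camera_sampler)
# where `camera_sampler` is a CameraSampler instance providing the new camera models.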