File size: 4,475 Bytes
fe3e74d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
from kornia import create_meshgrid


def project_and_normalize(ref_grid, src_proj, length):
    """
    Project 3D points with a 4x4 matrix and rescale the resulting pixel
    coordinates so that 0 maps to -1 and (length - 1) maps to 1.

    The depth (third projected row) is floored at 1e-4 before the perspective
    divide, so points at or behind the camera plane do not divide by zero or
    by a negative value.

    @param ref_grid: b 3 n
    @param src_proj: b 4 4
    @param length:   int, side length used for both x and y normalization
    @return:  b, n, 2
    """
    rotation = src_proj[:, :3, :3]
    translation = src_proj[:, :3, 3:]
    projected = rotation @ ref_grid + translation  # b 3 n

    # floor the depth channel before dividing
    depth = torch.clamp(projected[:, -1:], min=1e-4)  # b 1 n
    pixels = projected[:, :2] / depth  # b 2 n

    # same scale is applied to both x and y
    half_extent = (length - 1) / 2
    normalized = pixels / half_extent - 1  # b 2 n, in -1~1 for in-range pixels
    return normalized.permute(0, 2, 1)  # b n 2


def construct_project_matrix(x_ratio, y_ratio, Ks, poses):
    """
    Build homogeneous 4x4 projection matrices from intrinsics and poses.

    Computes diag(x_ratio, y_ratio, 1) @ Ks @ poses for every batch element
    and appends the row (0, 0, 0, 1) so the result is square.

    The scale matrix and the padding row follow the floating-point dtype of
    the inputs rather than being fixed to float32, so float64 (or half)
    intrinsics/poses no longer fail with a dtype mismatch in the matmul.

    @param x_ratio: float
    @param y_ratio: float
    @param Ks:      b,3,3
    @param poses:   b,3,4
    @return:        b,4,4
    """
    rfn = Ks.shape[0]
    # Non-floating Ks keeps the previous float32 scale so the ratios are not
    # silently truncated to integers.
    scale_dtype = Ks.dtype if Ks.is_floating_point() else torch.float32
    scale_m = torch.tensor([x_ratio, y_ratio, 1.0], dtype=scale_dtype, device=Ks.device)
    scale_m = torch.diag(scale_m)
    ref_prj = scale_m[None, :, :] @ Ks @ poses  # rfn,3,4
    # new_zeros inherits dtype and device from the projected matrices
    pad_vals = ref_prj.new_zeros((rfn, 1, 4))
    pad_vals[:, :, 3] = 1.0
    ref_prj = torch.cat([ref_prj, pad_vals], 1)  # rfn,4,4
    return ref_prj

def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose):
    """
    Project every point of a 3D coordinate volume into a view and return the
    normalized 2D coordinates.

    @param volume_xyz: B,3,D,H,W (flattened with .view, so it must be viewable
                       as B,3,D*H*W)
    @param warp_size:  side length of the image being sampled; also used to
                       normalize the projected coordinates
    @param input_size: side length the intrinsics refer to; the intrinsics are
                       scaled by warp_size / input_size
    @param Ks:         B,3,3
    @param warp_pose:  B,3,4
    @return:           B,D,H,W,2
    """
    batch, _channels, depth, height, width = volume_xyz.shape
    scale = warp_size / input_size

    projection = construct_project_matrix(scale, scale, Ks, warp_pose)  # B,4,4
    flat_points = volume_xyz.view(batch, 3, depth * height * width)  # B,3,DHW
    flat_coords = project_and_normalize(flat_points, projection, warp_size)  # B,DHW,2
    return flat_coords.view(batch, depth, height, width, 2)


def create_target_volume(depth_size, volume_size, input_image_size, pose_target, K, near=None, far=None):
    """
    Build a volume of 3D points by back-projecting a pixel grid of the target
    view at a set of depth hypotheses.

    Depth hypotheses are depth_size evenly spaced values between near and far.
    When near and far are not both given, they are derived from the camera
    poses via near_far_from_unit_sphere_using_camera_poses.

    @param depth_size:       int, number of depth hypotheses (D)
    @param volume_size:      int, spatial side length of the volume (H = W)
    @param input_image_size: side length the intrinsics refer to; intrinsics
                             are scaled by volume_size / input_image_size
    @param pose_target:      b,3,4 (its device and dtype are used for the
                             pixel grid)
    @param K:                b,3,3
    @param near:             optional, b,1,h,w
    @param far:              optional, b,1,h,w
    @return:
        points: B,3,D,H,W  back-projected coordinates
        depths: B,1,D,H,W  depth value used for each point
    """
    batch = pose_target.shape[0]
    height = width = volume_size
    num_pixels = height * width

    # ---- depth hypotheses, shaped B,1,D,H*W ----
    if near is not None and far is not None:
        # per-pixel range: near, far are b,1,h,w
        steps = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype)  # d
        depth_values = steps.view(1, depth_size, 1, 1) * (far - near) + near  # b d h w
        depth_values = depth_values.view(batch, 1, depth_size, num_pixels)
    else:
        # per-camera range: near, far are b,1
        near, far = near_far_from_unit_sphere_using_camera_poses(pose_target)
        steps = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype)  # d
        near_col, far_col = near[:, None, :], far[:, None, :]  # b 1 1
        depth_values = steps[None, :, None] * (far_col - near_col) + near_col  # b d 1
        depth_values = depth_values.view(batch, 1, depth_size, 1)
        depth_values = depth_values.expand(batch, 1, depth_size, num_pixels)

    # ---- homogeneous pixel grid of the target (reference) view ----
    pixels = create_meshgrid(height, width, normalized_coordinates=False)  # (1, H, W, 2)
    pixels = pixels.to(pose_target.device).to(pose_target.dtype)
    pixels = pixels.permute(0, 3, 1, 2).reshape(1, 2, num_pixels)  # (1, 2, H*W)
    pixels = pixels.expand(batch, -1, -1)  # (B, 2, H*W)
    ones = torch.ones(batch, 1, num_pixels, dtype=pixels.dtype, device=pixels.device)
    homogeneous = torch.cat((pixels, ones), dim=1)  # (B, 3, H*W)

    # scale each homogeneous pixel by every depth hypothesis
    scaled = homogeneous.unsqueeze(2) * depth_values  # (B, 3, D, H*W)

    # ---- unproject with the inverse of the scaled projection matrix ----
    ratio = volume_size / input_image_size
    proj = construct_project_matrix(ratio, ratio, K, pose_target)  # B,4,4
    proj_inv = torch.inverse(proj)  # B,4,4
    inv_linear = proj_inv[:, :3, :3]  # B,3,3
    inv_offset = proj_inv[:, :3, 3:]  # B,3,1
    points = inv_linear @ scaled.view(batch, 3, depth_size * num_pixels) + inv_offset  # B,3,DHW

    return (
        points.reshape(batch, 3, depth_size, height, width),
        depth_values.view(batch, 1, depth_size, height, width),
    )

def near_far_from_unit_sphere_using_camera_poses(camera_poses):
    """
    Compute a depth range per camera from its world-to-camera pose.

    The camera centre is projected onto the camera's z direction to find the
    distance along the ray that is closest to the world origin; near and far
    are that distance minus and plus 1.0.

    @param camera_poses: b 3 4
    @return:
        near: b,1
        far: b,1
    """
    rot_c2w = camera_poses[..., :3, :3].permute(0, 2, 1)  # b 3 3, transpose of R_w2c
    trans_w2c = camera_poses[..., :3, 3:]  # b 3 1

    origin = (-rot_c2w @ trans_w2c)[..., 0]  # b 3, camera centre
    # third column of R_w2c.T, i.e. R_w2c.T @ (0,0,1)
    z_dir = rot_c2w[..., :3, 2]  # b 3

    sq_len = (z_dir ** 2).sum(dim=-1, keepdim=True)  # b 1
    along = -(z_dir * origin).sum(dim=-1, keepdim=True)  # b 1
    mid = along / sq_len  # b 1
    return mid - 1.0, mid + 1.0