""" "LiftFeat: 3D Geometry-Aware Local Feature Matching" MegaDepth data handling was adapted from LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py """ import torch from kornia.utils import create_meshgrid import matplotlib.pyplot as plt import pdb import cv2 @torch.no_grad() def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): """ Warp kpts0 from I0 to I1 with depth, K and Rt Also check covisibility and depth consistency. Depth is consistent if relative error < 0.2 (hard-coded). Args: kpts0 (torch.Tensor): [N, L, 2] - , depth0 (torch.Tensor): [N, H, W], depth1 (torch.Tensor): [N, H, W], T_0to1 (torch.Tensor): [N, 3, 4], K0 (torch.Tensor): [N, 3, 3], K1 (torch.Tensor): [N, 3, 3], Returns: calculable_mask (torch.Tensor): [N, L] warped_keypoints0 (torch.Tensor): [N, L, 2] """ kpts0_long = kpts0.round().long().clip(0, 2000-1) depth0[:, 0, :] = 0 ; depth1[:, 0, :] = 0 depth0[:, :, 0] = 0 ; depth1[:, :, 0] = 0 # Sample depth, get calculable_mask on depth != 0 kpts0_depth = torch.stack( [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 ) # (N, L) nonzero_mask = kpts0_depth > 0 # Draw cross marks on the image for each keypoint # for b in range(len(kpts0)): # fig, ax = plt.subplots(1,2) # depth_np = depth0.numpy()[b] # depth_np_plot = depth_np.copy() # for x, y in kpts0_long[b, nonzero_mask[b], :].numpy(): # cv2.drawMarker(depth_np_plot, (x, y), (255), cv2.MARKER_CROSS, markerSize=10, thickness=2) # ax[0].imshow(depth_np) # ax[1].imshow(depth_np_plot) # Unproject kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) # Rigid Transform w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] # Project w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-5) # (N, L, 2), +1e-4 to avoid zero depth # Covisible Check # h, w = depth1.shape[1:3] # covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ # (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) # w_kpts0_long = w_kpts0.long() # w_kpts0_long[~covisible_mask, :] = 0 # w_kpts0_depth = torch.stack( # [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 # ) # (N, L) # consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 valid_mask = nonzero_mask #* consistent_mask* covisible_mask return valid_mask, w_kpts0 @torch.no_grad() def spvs_coarse(data, scale = 8): """ Supervise corresp with dense depth & camera poses """ # 1. misc device = data['image0'].device N, _, H0, W0 = data['image0'].shape _, _, H1, W1 = data['image1'].shape #scale = 8 scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) # 2. warp grids # create kpts in meshgrid and resize them to image resolution grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) # [N, hw, 2] grid_pt1_i = scale1 * grid_pt1_c # warp kpts bi-directionally and check reproj error nonzero_m1, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) nonzero_m2, w_pt1_og = warp_kpts( w_pt1_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) dist = torch.linalg.norm( grid_pt1_i - w_pt1_og, dim=-1) mask_mutual = (dist < 1.5) & nonzero_m1 & nonzero_m2 #_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) batched_corrs = [ torch.cat([w_pt1_i[i, mask_mutual[i]] / data['scale0'][i], grid_pt1_i[i, mask_mutual[i]] / data['scale1'][i]],dim=-1) for i in range(len(mask_mutual))] #Remove repeated correspondences - this is important for network convergence corrs = [] for pts in batched_corrs: lut_mat12 = torch.ones((h1, w1, 4), device = device, dtype = torch.float32) * -1 lut_mat21 = torch.clone(lut_mat12) src_pts = pts[:, :2] / scale tgt_pts = pts[:, 2:] / scale try: lut_mat12[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1) mask_valid12 = torch.all(lut_mat12 >= 0, dim=-1) points = lut_mat12[mask_valid12] #Target-src check src_pts, tgt_pts = points[:, :2], points[:, 2:] lut_mat21[tgt_pts[:,1].long(), tgt_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1) mask_valid21 = torch.all(lut_mat21 >= 0, dim=-1) points = lut_mat21[mask_valid21] corrs.append(points) except: pdb.set_trace() print('..') #Plot for debug purposes # for i in range(len(corrs)): # plot_corrs(data['image0'][i], data['image1'][i], corrs[i][:, :2]*8, corrs[i][:, 2:]*8) return corrs @torch.no_grad() def get_correspondences(pts2, data, idx): device = data['image0'].device N, _, H0, W0 = data['image0'].shape _, _, H1, W1 = data['image1'].shape pts2 = pts2[None, ...] scale0 = data['scale0'][idx, None][None, ...] if 'scale0' in data else 1 scale1 = data['scale1'][idx, None][None, ...] if 'scale1' in data else 1 pts2 = scale1 * pts2 * 8 # warp kpts bi-directionally and check reproj error nonzero_m1, pts1 = warp_kpts(pts2, data['depth1'][idx][None, ...], data['depth0'][idx][None, ...], data['T_1to0'][idx][None, ...], data['K1'][idx][None, ...], data['K0'][idx][None, ...]) corrs = torch.cat([pts1[0, :] / data['scale0'][idx], pts2[0, :] / data['scale1'][idx]],dim=-1) #plot_corrs(data['image0'][idx], data['image1'][idx], corrs[:, :2], corrs[:, 2:]) return corrs