import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange

# Large finite value used instead of -inf when masking similarity scores.
INF = 1e9


def mask_border(m, b: int, v):
    """Mask borders with value, in place.

    Sets the first and the last ``b`` entries along each of the four spatial
    axes of ``m`` to ``v``. Does nothing when ``b <= 0``.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1], modified in place
        b (int): border width in coarse cells
        v (m.dtype): fill value
    """
    if b <= 0:
        return

    m[:, :b] = v
    m[:, :, :b] = v
    m[:, :, :, :b] = v
    m[:, :, :, :, :b] = v
    m[:, -b:] = v
    m[:, :, -b:] = v
    m[:, :, :, -b:] = v
    m[:, :, :, :, -b:] = v


def mask_border_with_padding(m, bd, v, p_m0, p_m1):
    """Mask borders with value, in place, respecting per-sample padding.

    The leading ``bd`` entries of every spatial axis are masked for the whole
    batch. The trailing border is placed per sample, ``bd`` cells before the
    end of that sample's valid extent (derived from the padding masks), rather
    than at the edge of the padded tensor. Does nothing when ``bd <= 0``.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1], modified in place
        bd (int): border width in coarse cells
        v (m.dtype): fill value
        p_m0, p_m1 (torch.Tensor): padding masks of image 0 / image 1,
            [N, H, W]; presumably nonzero on valid (unpadded) cells
            -- TODO confirm against the caller
    """
    if bd <= 0:
        return

    m[:, :bd] = v
    m[:, :, :bd] = v
    m[:, :, :, :bd] = v
    m[:, :, :, :, :bd] = v

    # Per-sample valid height (tallest column count) and width (widest row count).
    h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
    h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
    for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
        m[b_idx, h0 - bd:] = v
        m[b_idx, :, w0 - bd:] = v
        m[b_idx, :, :, h1 - bd:] = v
        m[b_idx, :, :, :, w1 - bd:] = v


def compute_max_candidates(p_m0, p_m1):
    """Compute the max candidates of all pairs within a batch.

    For every sample, takes the smaller of the valid-cell counts of the two
    images (valid height * valid width, derived from the padding masks), and
    sums this over the batch.

    Args:
        p_m0, p_m1 (torch.Tensor): padded masks, [N, H, W]
    Returns:
        torch.Tensor: scalar sum over the batch
    """
    h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
    h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
    max_cand = torch.sum(
        torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
    return max_cand


class CoarseMatching(nn.Module):
    """Coarse-level matcher.

    Projects both coarse feature maps with a shared linear layer, builds a
    temperature-scaled similarity matrix, and keeps the pairs that are the
    above-threshold nearest neighbour in either direction (0->1 or 1->0),
    excluding pairs near the image borders.

    Config keys read: 'd_model', 'thr', 'inference', 'border_rm',
    'train_coarse_percent', 'train_pad_num_gt_min', 'dsmax_temperature'.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # general config
        d_model = config['d_model']
        self.thr = config['thr']
        self.inference = config['inference']
        self.border_rm = config['border_rm']
        # -- # for training fine-level XoFTR
        self.train_coarse_percent = config['train_coarse_percent']
        self.train_pad_num_gt_min = config['train_pad_num_gt_min']
        self.final_proj = nn.Linear(d_model, d_model, bias=True)

        self.temperature = config['dsmax_temperature']

    def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
        """Compute coarse matches and store them in ``data``.

        Args:
            feat_c0 (torch.Tensor): [N, L, C]
            feat_c1 (torch.Tensor): [N, S, C]
            data (dict)
            mask_c0 (torch.Tensor): [N, L] (optional)
            mask_c1 (torch.Tensor): [N, S] (optional); must be given whenever
                mask_c0 is given, since the two are applied together
        Update:
            data (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                # additionally, only when ``self.inference`` is False:
                'conf_matrix_0_to_1' (torch.Tensor): [N, L, S],
                'conf_matrix_1_to_0' (torch.Tensor): [N, L, S],
                'gt_mask' (torch.Tensor): [M'],
                'mconf' (torch.Tensor): [M]}
            NOTE: M' != M during training.
        """
        feat_c0 = self.final_proj(feat_c0)
        feat_c1 = self.final_proj(feat_c1)

        # Scale features by 1/sqrt(C) so the dot products below do not grow
        # with the channel count (this is not an L2 normalisation).
        feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
                               [feat_c0, feat_c1])

        sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
                                  feat_c1) / self.temperature
        if mask_c0 is not None:
            # Invalidate every pair where either side is masked out.
            sim_matrix.masked_fill_(
                ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                -INF)

        if self.inference:
            # predict coarse matches from conf_matrix
            data.update(self.get_coarse_match_inference(sim_matrix, data))
        else:
            conf_matrix_0_to_1 = F.softmax(sim_matrix, 2)
            conf_matrix_1_to_0 = F.softmax(sim_matrix, 1)
            data.update({'conf_matrix_0_to_1': conf_matrix_0_to_1,
                         'conf_matrix_1_to_0': conf_matrix_1_to_0})

            # predict coarse matches from conf_matrix
            data.update(self.get_coarse_match_training(
                conf_matrix_0_to_1, conf_matrix_1_to_0, data))

    @staticmethod
    def _axes_lengths(data):
        """einops axis sizes of the two coarse grids, taken from data['hw0_c'] / data['hw1_c']."""
        return {
            'h0c': data['hw0_c'][0],
            'w0c': data['hw0_c'][1],
            'h1c': data['hw1_c'][0],
            'w1c': data['hw1_c'][1]
        }

    def _nn_mask(self, conf_matrix, dim):
        """Boolean mask of entries above ``self.thr`` that are also the maximum along ``dim``."""
        return (conf_matrix > self.thr) * \
            (conf_matrix == conf_matrix.max(dim=dim, keepdim=True)[0])

    def _remove_border_matches(self, mask, data, axes_lengths):
        """Clear candidate matches within ``self.border_rm`` cells of the borders.

        Takes and returns the flattened [N, L, S] mask. Uses padding-aware
        border removal when ``data`` carries 'mask0'/'mask1'.
        """
        mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
                         **axes_lengths)
        if 'mask0' not in data:
            mask_border(mask, self.border_rm, False)
        else:
            mask_border_with_padding(mask, self.border_rm, False,
                                     data['mask0'], data['mask1'])
        return rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
                         **axes_lengths)

    @staticmethod
    def _coarse_keypoints(b_ids, i_ids, j_ids, data):
        """Map flattened coarse indices to (x, y) coordinates at input resolution.

        A flattened index ``k`` on a coarse grid of width ``w`` is the cell
        (x=k % w, y=k // w). Both images are multiplied by
        ``data['hw0_i'][0] / data['hw0_c'][0]`` (derived from image 0 only;
        NOTE(review): assumes image 1 has the same stride -- confirm), and
        additionally per sample by ``data['scale0']`` / ``data['scale1']``
        when those keys are present.

        Returns:
            (mkpts0_c, mkpts1_c): each [M, 2]
        """
        scale = data['hw0_i'][0] / data['hw0_c'][0]
        scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
        scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale

        def to_xy(ids, w_c, s):
            return torch.stack(
                [ids % w_c, torch.div(ids, w_c, rounding_mode='trunc')],
                dim=1) * s

        mkpts0_c = to_xy(i_ids, data['hw0_c'][1], scale0)
        mkpts1_c = to_xy(j_ids, data['hw1_c'][1], scale1)
        return mkpts0_c, mkpts1_c

    @torch.no_grad()
    def get_coarse_match_training(self, conf_matrix_0_to_1, conf_matrix_1_to_0, data):
        """Select coarse matches during training/validation.

        When ``self.training`` is True, predictions are randomly sub-sampled
        and padded with ground-truth matches from data['spv_*_ids'] so that a
        fixed number of patches reaches the fine level. Ground-truth paddings
        get confidence 0.

        Args:
            conf_matrix_0_to_1 (torch.Tensor): [N, L, S]
            conf_matrix_1_to_0 (torch.Tensor): [N, L, S]
            data (dict): with keys ['hw0_i', 'hw0_c', 'hw1_c'], optionally
                'mask0'/'mask1' and 'scale0'/'scale1'; when ``self.training``
                also 'spv_b_ids', 'spv_i_ids', 'spv_j_ids'
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        axes_lengths = self._axes_lengths(data)
        _device = conf_matrix_0_to_1.device

        # confidence thresholding
        # {(nearest neighbour for 0 to 1) U (nearest neighbour for 1 to 0)}
        mask = torch.logical_or(self._nn_mask(conf_matrix_0_to_1, dim=2),
                                self._nn_mask(conf_matrix_1_to_0, dim=1))
        mask = self._remove_border_matches(mask, data, axes_lengths)

        # find all valid coarse matches
        b_ids, i_ids, j_ids = mask.nonzero(as_tuple=True)
        mconf = torch.maximum(conf_matrix_0_to_1[b_ids, i_ids, j_ids],
                              conf_matrix_1_to_0[b_ids, i_ids, j_ids])

        # random sampling of training samples for fine-level XoFTR
        # (optional) pad samples with gt coarse-level matches
        if self.training:
            # NOTE:
            # the sampling is performed across all pairs in a batch without manually balancing
            # samples for fine-level increases w.r.t. batch_size
            if 'mask0' not in data:
                # NOTE(review): this bound uses max(L, S) per pair, whereas
                # compute_max_candidates uses the per-pair min; kept unchanged.
                num_candidates_max = mask.size(0) * max(
                    mask.size(1), mask.size(2))
            else:
                num_candidates_max = compute_max_candidates(
                    data['mask0'], data['mask1'])
            num_matches_train = int(num_candidates_max *
                                    self.train_coarse_percent)
            num_matches_pred = len(b_ids)
            assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"

            # pred_indices is to select from prediction; at most this many
            # predictions are kept so that >= train_pad_num_gt_min gt pads fit
            num_pred_budget = num_matches_train - self.train_pad_num_gt_min
            if num_matches_pred <= num_pred_budget:
                pred_indices = torch.arange(num_matches_pred, device=_device)
            else:
                pred_indices = torch.randint(
                    num_matches_pred,
                    (num_pred_budget, ),
                    device=_device)

            # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
            num_gt = len(data['spv_b_ids'])
            gt_pad_indices = torch.randint(
                num_gt,
                (max(num_matches_train - num_matches_pred,
                     self.train_pad_num_gt_min), ),
                device=_device)
            # set conf of gt paddings to all zero (this is what marks them as gt below)
            mconf_gt = torch.zeros(num_gt, device=_device)

            def _merge(pred, gt):
                # kept predictions first, then the sampled gt paddings
                return torch.cat([pred[pred_indices], gt[gt_pad_indices]], dim=0)

            b_ids = _merge(b_ids, data['spv_b_ids'])
            i_ids = _merge(i_ids, data['spv_i_ids'])
            j_ids = _merge(j_ids, data['spv_j_ids'])
            mconf = _merge(mconf, mconf_gt)

        # these matches are selected patches that feed into fine-level network
        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

        # update with matches in original image resolution
        mkpts0_c, mkpts1_c = self._coarse_keypoints(b_ids, i_ids, j_ids, data)

        # these matches are the current prediction (for visualization);
        # mconf == 0 marks gt paddings, which are excluded here
        pred_mask = mconf != 0
        coarse_matches.update({
            'gt_mask': mconf == 0,
            'm_bids': b_ids[pred_mask],
            'mkpts0_c': mkpts0_c[pred_mask],
            'mkpts1_c': mkpts1_c[pred_mask],
            'mconf': mconf[pred_mask]
        })

        return coarse_matches

    @torch.no_grad()
    def get_coarse_match_inference(self, sim_matrix, data):
        """Select coarse matches at inference time.

        Args:
            sim_matrix (torch.Tensor): [N, L, S]
            data (dict): with keys ['hw0_i', 'hw0_c', 'hw1_c'], optionally
                'mask0'/'mask1' and 'scale0'/'scale1'
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M],
                'i_ids' (torch.Tensor): [M],
                'j_ids' (torch.Tensor): [M],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2]}
            Unlike training, no 'gt_mask' or 'mconf' entries are produced.
        """
        axes_lengths = self._axes_lengths(data)

        # softmax for 0 to 1, then confidence thresholding and nearest neighbour for 0 to 1
        conf_matrix_ = F.softmax(sim_matrix, 2)
        mask = self._nn_mask(conf_matrix_, dim=2)

        # Unlike training, only one [N, L, S] confidence matrix is kept alive at
        # a time to reduce vram: drop the 0->1 matrix before allocating 1->0.
        del conf_matrix_
        conf_matrix_ = F.softmax(sim_matrix, 1)
        # update mask {(nearest neighbour for 0 to 1) U (nearest neighbour for 1 to 0)}
        mask = torch.logical_or(mask, self._nn_mask(conf_matrix_, dim=1))
        del conf_matrix_

        mask = self._remove_border_matches(mask, data, axes_lengths)

        # find all valid coarse matches
        b_ids, i_ids, j_ids = mask.nonzero(as_tuple=True)

        # these matches are selected patches that feed into fine-level network
        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

        # update with matches in original image resolution
        mkpts0_c, mkpts1_c = self._coarse_keypoints(b_ids, i_ids, j_ids, data)

        # these matches are the current coarse level predictions
        coarse_matches.update({
            'm_bids': b_ids,
            'mkpts0_c': mkpts0_c,
            'mkpts1_c': mkpts1_c,
        })

        return coarse_matches