# NOTE: the original paste carried a Hugging Face Spaces page banner here
# ("Spaces: Running on Zero") — extraction artifact, not part of the module.
import torch
import torch.nn as nn
import torch.nn.functional as F
class FineSubMatching(nn.Module):
    """Fine-level and sub-pixel matching.

    For every coarse match, computes a dual-softmax confidence matrix between
    the two unfolded fine-level feature windows, selects discrete fine matches
    from it, and regresses a sub-pixel offset in (-0.5, 0.5) for each match
    with a small MLP.
    """

    def __init__(self, config):
        super().__init__()
        self.temperature = config['fine']['dsmax_temperature']
        self.W_f = config['fine_window_size']          # side length of the fine window
        self.denser = config['fine']['denser']         # keep all mutual-NN matches vs. only the best
        self.inference = config['fine']['inference']   # True: skip training-only outputs
        dim_f = config['resnet']['block_dims'][0]
        self.fine_thr = config['fine']['thr']          # confidence threshold for fine matches
        self.fine_proj = nn.Linear(dim_f, dim_f, bias=False)
        # Maps the concatenated features of a matched pair to 4 values:
        # (dx0, dy0, dx1, dy1), later squashed to (-0.5, 0.5) via tanh * 0.5.
        self.subpixel_mlp = nn.Sequential(nn.Linear(2 * dim_f, 2 * dim_f, bias=False),
                                          nn.ReLU(),
                                          nn.Linear(2 * dim_f, 4, bias=False))

    def forward(self, feat_f0_unfold, feat_f1_unfold, data):
        """
        Args:
            feat_f0_unfold (torch.Tensor): [M, WW, C]
            feat_f1_unfold (torch.Tensor): [M, WW, C]
            data (dict)
        Update:
            data (dict):{
                'conf_matrix_fine' (torch.Tensor): [M, WW, WW],
                'mkpts0_f' (torch.Tensor): [M', 2],
                'mkpts1_f' (torch.Tensor): [M', 2],
                'mconf_f' (torch.Tensor): [M']}
        """
        feat_f0 = self.fine_proj(feat_f0_unfold)
        feat_f1 = self.fine_proj(feat_f1_unfold)
        M, WW, C = feat_f0.shape

        # Corner case: no coarse matches found. Only legal at inference time;
        # the coarse matcher guarantees M > 0 during training.
        if M == 0:
            assert self.training == False, "M is always >0, when training, see coarse_matching.py"
            data.update({
                'mkpts0_f': data['mkpts0_c'],
                'mkpts1_f': data['mkpts1_c'],
                'mconf_f': torch.zeros(0, device=feat_f0_unfold.device),
            })
            return

        # Normalize features so the dot products are scale-invariant in C.
        feat_f0, feat_f1 = (feat / C**.5 for feat in (feat_f0, feat_f1))
        sim_matrix = torch.einsum("nlc,nsc->nls", feat_f0,
                                  feat_f1) / self.temperature
        # Dual-softmax: product of row-wise and column-wise softmaxes.
        conf_matrix_fine = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
        data.update({'conf_matrix_fine': conf_matrix_fine})

        # Predict fine-level and sub-pixel matches from the confidence matrix.
        data.update(**self.get_fine_sub_match(conf_matrix_fine, feat_f0_unfold, feat_f1_unfold, data))

    def get_fine_sub_match(self, conf_matrix_fine, feat_f0_unfold, feat_f1_unfold, data):
        """
        Args:
            conf_matrix_fine (torch.Tensor): [M, WW, WW]
            feat_f0_unfold (torch.Tensor): [M, WW, C]
            feat_f1_unfold (torch.Tensor): [M, WW, C]
            data (dict)
        Returns:
            dict: {
                'm_bids' (torch.Tensor): [M'],
                'mkpts0_f' (torch.Tensor): [M', 2],
                'mkpts1_f' (torch.Tensor): [M', 2],
                'mconf_f' (torch.Tensor): [M']}
            plus 'mkpts0_f_train'/'mkpts1_f_train' (with gradients attached)
            when not in inference mode.
        """
        # Match selection is discrete, so no gradients flow through it.
        with torch.no_grad():
            W_f = self.W_f
            # 1. confidence thresholding
            mask = conf_matrix_fine > self.fine_thr
            if mask.sum() == 0:
                # Force one dummy match so downstream code never receives an
                # empty selection here.
                mask[0, 0, 0] = 1
                conf_matrix_fine[0, 0, 0] = 1
            if not self.denser:
                # Keep only the single highest-confidence cell per window pair.
                mask = mask \
                    * (conf_matrix_fine == conf_matrix_fine.amax(dim=[1, 2], keepdim=True))
            else:
                # 2. mutual nearest: keep every cell that is the max of both
                # its row and its column (all matches in the fine window).
                mask = mask \
                    * (conf_matrix_fine == conf_matrix_fine.max(dim=2, keepdim=True)[0]) \
                    * (conf_matrix_fine == conf_matrix_fine.max(dim=1, keepdim=True)[0])

            # 3. find all valid fine matches
            # (this only works when each row of `mask` has at most one True)
            mask_v, all_j_ids = mask.max(dim=2)
            b_ids, i_ids = torch.where(mask_v)
            j_ids = all_j_ids[b_ids, i_ids]
            mconf = conf_matrix_fine[b_ids, i_ids, j_ids]

            # 4. express matches in original image resolution
            # indices of the coarse matches
            b_ids_c, i_ids_c, j_ids_c = data['b_ids'], data['i_ids'], data['j_ids']
            # scale factor coarse level -> fine level
            scale_f_c = data['hw0_f'][0] // data['hw0_c'][0]
            # coarse-level match coordinates scaled to the fine level
            mkpts0_c_scaled_to_f = torch.stack(
                [i_ids_c % data['hw0_c'][1], torch.div(i_ids_c, data['hw0_c'][1], rounding_mode='trunc')],
                dim=1) * scale_f_c
            mkpts1_c_scaled_to_f = torch.stack(
                [j_ids_c % data['hw1_c'][1], torch.div(j_ids_c, data['hw1_c'][1], rounding_mode='trunc')],
                dim=1) * scale_f_c

            # batch ids of the matches surviving the second (fine) thresholding
            updated_b_ids = b_ids_c[b_ids]

            # scale factor fine level -> image resolution (per-image resize
            # factors 'scale0'/'scale1' applied when the loader provides them)
            scale = data['hw0_i'][0] / data['hw0_f'][0]
            scale0 = scale * data['scale0'][updated_b_ids] if 'scale0' in data else scale
            scale1 = scale * data['scale1'][updated_b_ids] if 'scale1' in data else scale

            # discrete fine-level match coordinates inside the local window
            mkpts0_f_window = torch.stack(
                [i_ids % W_f, torch.div(i_ids, W_f, rounding_mode='trunc')],
                dim=1)
            mkpts1_f_window = torch.stack(
                [j_ids % W_f, torch.div(j_ids, W_f, rounding_mode='trunc')],
                dim=1)

        # Sub-pixel refinement (trained, hence outside no_grad); offsets in (-0.5, 0.5).
        sub_ref = self.subpixel_mlp(torch.cat([feat_f0_unfold[b_ids, i_ids],
                                               feat_f1_unfold[b_ids, j_ids]], dim=-1))
        sub_ref0, sub_ref1 = torch.chunk(sub_ref, 2, dim=-1)
        sub_ref0 = torch.tanh(sub_ref0) * 0.5
        sub_ref1 = torch.tanh(sub_ref1) * 0.5

        # final sub-pixel matches = (coarse position + windowed fine offset
        # - window center + sub-pixel offset) scaled to image resolution
        mkpts0_f_train = (mkpts0_f_window + mkpts0_c_scaled_to_f[b_ids] - (W_f // 2) + sub_ref0) * scale0
        mkpts1_f_train = (mkpts1_f_window + mkpts1_c_scaled_to_f[b_ids] - (W_f // 2) + sub_ref1) * scale1
        mkpts0_f = mkpts0_f_train.clone().detach()
        mkpts1_f = mkpts1_f_train.clone().detach()

        # Current predictions (for visualization); mconf == 0 marks gt matches,
        # which are dropped from the reported set.
        valid = mconf != 0
        sub_pixel_matches = {
            'm_bids': b_ids_c[b_ids[valid]],
            'mkpts0_f': mkpts0_f[valid],
            'mkpts1_f': mkpts1_f[valid],
            'mconf_f': mconf[valid]
        }

        # Matches with gradients attached, used for the training loss.
        if not self.inference:
            sub_pixel_matches.update({
                'mkpts0_f_train': mkpts0_f_train[valid],
                'mkpts1_f_train': mkpts1_f_train[valid],
            })
        return sub_pixel_matches