lsxi77777's picture
commit message
a930e1f
import torch
import torch.nn as nn
import torch.nn.functional as F
class FineSubMatching(nn.Module):
"""Fine-level and Sub-pixel matching"""
def __init__(self, config):
super().__init__()
self.temperature = config['fine']['dsmax_temperature']
self.W_f = config['fine_window_size']
self.denser = config['fine']['denser']
self.inference = config['fine']['inference']
dim_f = config['resnet']['block_dims'][0]
self.fine_thr = config['fine']['thr']
self.fine_proj = nn.Linear(dim_f, dim_f, bias=False)
self.subpixel_mlp = nn.Sequential(nn.Linear(2*dim_f, 2*dim_f, bias=False),
nn.ReLU(),
nn.Linear(2*dim_f, 4, bias=False))
def forward(self, feat_f0_unfold, feat_f1_unfold, data):
"""
Args:
feat_f0_unfold (torch.Tensor): [M, WW, C]
feat_f1_unfold (torch.Tensor): [M, WW, C]
data (dict)
Update:
data (dict):{
'expec_f' (torch.Tensor): [M, 3],
'mkpts0_f' (torch.Tensor): [M, 2],
'mkpts1_f' (torch.Tensor): [M, 2]}
"""
feat_f0 = self.fine_proj(feat_f0_unfold)
feat_f1 = self.fine_proj(feat_f1_unfold)
M, WW, C = feat_f0.shape
W_f = self.W_f
# corner case: if no coarse matches found
if M == 0:
assert self.training == False, "M is always >0, when training, see coarse_matching.py"
# logger.warning('No matches found in coarse-level.')
data.update({
'mkpts0_f': data['mkpts0_c'],
'mkpts1_f': data['mkpts1_c'],
'mconf_f': torch.zeros(0, device=feat_f0_unfold.device),
# 'mkpts0_f_train': data['mkpts0_c'],
# 'mkpts1_f_train': data['mkpts1_c'],
# 'conf_matrix_fine': torch.zeros(1, W_f*W_f, W_f*W_f, device=feat_f0.device)
})
return
# normalize
feat_f0, feat_f1 = map(lambda feat: feat / feat.shape[-1]**.5,
[feat_f0, feat_f1])
sim_matrix = torch.einsum("nlc,nsc->nls", feat_f0,
feat_f1) / self.temperature
conf_matrix_fine = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
data.update({'conf_matrix_fine': conf_matrix_fine})
# predict fine-level and sub-pixel matches from conf_matrix
data.update(**self.get_fine_sub_match(conf_matrix_fine, feat_f0_unfold, feat_f1_unfold, data))
def get_fine_sub_match(self, conf_matrix_fine, feat_f0_unfold, feat_f1_unfold, data):
"""
Args:
conf_matrix_fine (torch.Tensor): [M, WW, WW]
feat_f0_unfold (torch.Tensor): [M, WW, C]
feat_f1_unfold (torch.Tensor): [M, WW, C]
data (dict)
Update:
data (dict):{
'm_bids' (torch.Tensor): [M]
'expec_f' (torch.Tensor): [M, 3],
'mkpts0_f' (torch.Tensor): [M, 2],
'mkpts1_f' (torch.Tensor): [M, 2]}
"""
with torch.no_grad():
W_f = self.W_f
# 1. confidence thresholding
mask = conf_matrix_fine > self.fine_thr
if mask.sum() == 0:
mask[0,0,0] = 1
conf_matrix_fine[0,0,0] = 1
if not self.denser:
# match only the highest confidence
mask = mask \
* (conf_matrix_fine == conf_matrix_fine.amax(dim=[1,2], keepdim=True))
else:
# 2. mutual nearest, match all features in fine window
mask = mask \
* (conf_matrix_fine == conf_matrix_fine.max(dim=2, keepdim=True)[0]) \
* (conf_matrix_fine == conf_matrix_fine.max(dim=1, keepdim=True)[0])
# 3. find all valid fine matches
# this only works when at most one `True` in each row
mask_v, all_j_ids = mask.max(dim=2)
b_ids, i_ids = torch.where(mask_v)
j_ids = all_j_ids[b_ids, i_ids]
mconf = conf_matrix_fine[b_ids, i_ids, j_ids]
# 4. update with matches in original image resolution
# indices from coarse matches
b_ids_c, i_ids_c, j_ids_c = data['b_ids'], data['i_ids'], data['j_ids']
# scale (coarse level / fine-level)
scale_f_c = data['hw0_f'][0] // data['hw0_c'][0]
# coarse level matches scaled to fine-level (1/2)
mkpts0_c_scaled_to_f = torch.stack(
[i_ids_c % data['hw0_c'][1], torch.div(i_ids_c, data['hw0_c'][1], rounding_mode='trunc')],
dim=1) * scale_f_c
mkpts1_c_scaled_to_f = torch.stack(
[j_ids_c % data['hw1_c'][1], torch.div(j_ids_c, data['hw1_c'][1], rounding_mode='trunc')],
dim=1) * scale_f_c
# updated b_ids after second thresholding
updated_b_ids = b_ids_c[b_ids]
# scales (image res / fine level)
scale = data['hw0_i'][0] / data['hw0_f'][0]
scale0 = scale * data['scale0'][updated_b_ids] if 'scale0' in data else scale
scale1 = scale * data['scale1'][updated_b_ids] if 'scale1' in data else scale
# fine-level discrete matches on window coordiantes
mkpts0_f_window = torch.stack(
[i_ids % W_f, torch.div(i_ids, W_f, rounding_mode='trunc')],
dim=1)
mkpts1_f_window = torch.stack(
[j_ids % W_f, torch.div(j_ids, W_f, rounding_mode='trunc')],
dim=1)
# sub-pixel refinement
sub_ref = self.subpixel_mlp(torch.cat([feat_f0_unfold[b_ids, i_ids],
feat_f1_unfold[b_ids, j_ids]], dim=-1))
sub_ref0, sub_ref1 = torch.chunk(sub_ref, 2, dim=-1)
sub_ref0 = torch.tanh(sub_ref0) * 0.5
sub_ref1 = torch.tanh(sub_ref1) * 0.5
# final sub-pixel matches by (coarse-level + fine-level windowed + sub-pixel refinement)
mkpts0_f_train = (mkpts0_f_window + mkpts0_c_scaled_to_f[b_ids] - (W_f//2) + sub_ref0) * scale0
mkpts1_f_train = (mkpts1_f_window + mkpts1_c_scaled_to_f[b_ids] - (W_f//2) + sub_ref1) * scale1
mkpts0_f = mkpts0_f_train.clone().detach()
mkpts1_f = mkpts1_f_train.clone().detach()
# These matches is the current prediction (for visualization)
sub_pixel_matches = {
'm_bids': b_ids_c[b_ids[mconf != 0]], # mconf == 0 => gt matches
'mkpts0_f': mkpts0_f[mconf != 0],
'mkpts1_f': mkpts1_f[mconf != 0],
'mconf_f': mconf[mconf != 0]
}
# These matches are used for training
if not self.inference:
sub_pixel_matches.update({
'mkpts0_f_train': mkpts0_f_train[mconf != 0],
'mkpts1_f_train': mkpts1_f_train[mconf != 0],
})
return sub_pixel_matches