Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,290 Bytes
a930e1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import torch
import torch.nn as nn
import torch.nn.functional as F
class FineSubMatching(nn.Module):
    """Fine-level and sub-pixel matching.

    Given unfolded fine-level feature windows for two images, builds a
    dual-softmax confidence matrix over window positions and predicts
    discrete fine-level matches plus a sub-pixel offset regressed by a
    small MLP.
    """

    def __init__(self, config):
        super().__init__()
        self.temperature = config['fine']['dsmax_temperature']
        self.W_f = config['fine_window_size']
        self.denser = config['fine']['denser']
        self.inference = config['fine']['inference']
        dim_f = config['resnet']['block_dims'][0]
        self.fine_thr = config['fine']['thr']
        # shared projection applied to both images' fine features before
        # computing the similarity matrix
        self.fine_proj = nn.Linear(dim_f, dim_f, bias=False)
        # regresses (dx0, dy0, dx1, dy1) sub-pixel offsets from the two
        # matched feature vectors concatenated together
        self.subpixel_mlp = nn.Sequential(
            nn.Linear(2 * dim_f, 2 * dim_f, bias=False),
            nn.ReLU(),
            nn.Linear(2 * dim_f, 4, bias=False))

    def forward(self, feat_f0_unfold, feat_f1_unfold, data):
        """
        Args:
            feat_f0_unfold (torch.Tensor): [M, WW, C]
            feat_f1_unfold (torch.Tensor): [M, WW, C]
            data (dict)
        Update:
            data (dict): {
                'conf_matrix_fine' (torch.Tensor): [M, WW, WW],
                'm_bids' (torch.Tensor): [K],
                'mkpts0_f' (torch.Tensor): [K, 2],
                'mkpts1_f' (torch.Tensor): [K, 2],
                'mconf_f' (torch.Tensor): [K]}
        """
        feat_f0 = self.fine_proj(feat_f0_unfold)
        feat_f1 = self.fine_proj(feat_f1_unfold)
        M, WW, C = feat_f0.shape
        W_f = self.W_f

        # corner case: no coarse matches found (can only happen at inference)
        if M == 0:
            assert not self.training, "M is always >0, when training, see coarse_matching.py"
            data.update({
                'mkpts0_f': data['mkpts0_c'],
                'mkpts1_f': data['mkpts1_c'],
                'mconf_f': torch.zeros(0, device=feat_f0_unfold.device),
            })
            return

        # normalize by sqrt(C), then dual-softmax over both window axes
        feat_f0, feat_f1 = map(lambda feat: feat / feat.shape[-1]**.5,
                               [feat_f0, feat_f1])
        sim_matrix = torch.einsum("nlc,nsc->nls", feat_f0,
                                  feat_f1) / self.temperature
        conf_matrix_fine = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
        data.update({'conf_matrix_fine': conf_matrix_fine})

        # predict fine-level and sub-pixel matches from conf_matrix
        data.update(**self.get_fine_sub_match(conf_matrix_fine,
                                              feat_f0_unfold,
                                              feat_f1_unfold, data))

    def get_fine_sub_match(self, conf_matrix_fine, feat_f0_unfold,
                           feat_f1_unfold, data):
        """
        Args:
            conf_matrix_fine (torch.Tensor): [M, WW, WW]
            feat_f0_unfold (torch.Tensor): [M, WW, C]
            feat_f1_unfold (torch.Tensor): [M, WW, C]
            data (dict)
        Returns:
            dict: {
                'm_bids' (torch.Tensor): [K],
                'mkpts0_f' (torch.Tensor): [K, 2],
                'mkpts1_f' (torch.Tensor): [K, 2],
                'mconf_f' (torch.Tensor): [K],
                (training only)
                'mkpts0_f_train' (torch.Tensor): [K, 2],
                'mkpts1_f_train' (torch.Tensor): [K, 2]}
        """
        # Discrete match selection carries no gradient; the sub-pixel MLP
        # below stays OUTSIDE no_grad so mkpts*_f_train can backpropagate.
        with torch.no_grad():
            W_f = self.W_f

            # 1. confidence thresholding
            mask = conf_matrix_fine > self.fine_thr
            # guarantee at least one (dummy) match so downstream indexing works
            if mask.sum() == 0:
                mask[0, 0, 0] = 1
                conf_matrix_fine[0, 0, 0] = 1

            if not self.denser:
                # keep only the single highest-confidence entry per window pair
                mask = mask \
                    * (conf_matrix_fine == conf_matrix_fine.amax(dim=[1, 2], keepdim=True))
            else:
                # 2. mutual nearest: match all features in the fine window
                mask = mask \
                    * (conf_matrix_fine == conf_matrix_fine.max(dim=2, keepdim=True)[0]) \
                    * (conf_matrix_fine == conf_matrix_fine.max(dim=1, keepdim=True)[0])

            # 3. find all valid fine matches
            # this only works when at most one `True` in each row
            mask_v, all_j_ids = mask.max(dim=2)
            b_ids, i_ids = torch.where(mask_v)
            j_ids = all_j_ids[b_ids, i_ids]
            mconf = conf_matrix_fine[b_ids, i_ids, j_ids]

        # 4. update with matches in original image resolution
        # indices from coarse matches
        b_ids_c, i_ids_c, j_ids_c = data['b_ids'], data['i_ids'], data['j_ids']
        # scale factor (fine level / coarse level), e.g. 2 for a 1/2 fine grid
        scale_f_c = data['hw0_f'][0] // data['hw0_c'][0]
        # coarse-level match coordinates scaled up to fine-level resolution
        mkpts0_c_scaled_to_f = torch.stack(
            [i_ids_c % data['hw0_c'][1],
             torch.div(i_ids_c, data['hw0_c'][1], rounding_mode='trunc')],
            dim=1) * scale_f_c
        mkpts1_c_scaled_to_f = torch.stack(
            [j_ids_c % data['hw1_c'][1],
             torch.div(j_ids_c, data['hw1_c'][1], rounding_mode='trunc')],
            dim=1) * scale_f_c

        # updated b_ids after second thresholding
        updated_b_ids = b_ids_c[b_ids]

        # scales (image resolution / fine level); per-pair 'scale0'/'scale1'
        # compensate for image resizing when present
        scale = data['hw0_i'][0] / data['hw0_f'][0]
        scale0 = scale * data['scale0'][updated_b_ids] if 'scale0' in data else scale
        scale1 = scale * data['scale1'][updated_b_ids] if 'scale1' in data else scale

        # fine-level discrete matches in window coordinates
        mkpts0_f_window = torch.stack(
            [i_ids % W_f, torch.div(i_ids, W_f, rounding_mode='trunc')],
            dim=1)
        mkpts1_f_window = torch.stack(
            [j_ids % W_f, torch.div(j_ids, W_f, rounding_mode='trunc')],
            dim=1)

        # sub-pixel refinement: offsets in (-0.5, 0.5) regressed by the MLP
        sub_ref = self.subpixel_mlp(torch.cat([feat_f0_unfold[b_ids, i_ids],
                                               feat_f1_unfold[b_ids, j_ids]], dim=-1))
        sub_ref0, sub_ref1 = torch.chunk(sub_ref, 2, dim=-1)
        sub_ref0 = torch.tanh(sub_ref0) * 0.5
        sub_ref1 = torch.tanh(sub_ref1) * 0.5

        # final sub-pixel matches = coarse + fine-window offset + sub-pixel
        # refinement, mapped back to original image resolution
        mkpts0_f_train = (mkpts0_f_window + mkpts0_c_scaled_to_f[b_ids]
                          - (W_f // 2) + sub_ref0) * scale0
        mkpts1_f_train = (mkpts1_f_window + mkpts1_c_scaled_to_f[b_ids]
                          - (W_f // 2) + sub_ref1) * scale1
        mkpts0_f = mkpts0_f_train.clone().detach()
        mkpts1_f = mkpts1_f_train.clone().detach()

        # keep only real predictions; mconf == 0 marks gt-padded matches.
        # compute the validity mask once instead of re-evaluating it per key
        valid = mconf != 0
        # these matches are the current prediction (for visualization)
        sub_pixel_matches = {
            'm_bids': b_ids_c[b_ids[valid]],
            'mkpts0_f': mkpts0_f[valid],
            'mkpts1_f': mkpts1_f[valid],
            'mconf_f': mconf[valid]
        }
        # gradient-carrying matches, used only for training
        if not self.inference:
            sub_pixel_matches.update({
                'mkpts0_f_train': mkpts0_f_train[valid],
                'mkpts1_f_train': mkpts1_f_train[valid],
            })
        return sub_pixel_matches
|