# Spaces: Sleeping / Sleeping  (page-status residue from the web export; not part of the module)
import math
from typing import Any, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry.subpix import dsnt
from kornia.utils.grid import create_meshgrid
from loguru import logger
class FineMatching(nn.Module):
    """FineMatching with s2d paradigm.

    Refines coarse matches in two stages, writing all results into the
    ``data`` dict passed to :meth:`forward`:

    1. a pixel-level stage (:meth:`get_fine_ds_match`) that picks the best
       (left, right) window position from a dual-softmax correlation matrix;
    2. a sub-pixel stage (:meth:`get_fine_match_local`) that regresses an
       offset for the right keypoint from a 3x3 heatmap.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Read the fine-matching settings from ``config``.

        Args:
            config: must provide ``config['match_fine']['local_regress_temperature']``,
                ``config['match_fine']['local_regress_slicedim']`` and ``config['half']``.
        """
        super().__init__()
        self.config = config
        # Temperature dividing the 3x3 logits before the second-stage softmax.
        self.local_regress_temperature = config['match_fine']['local_regress_temperature']
        # Number of trailing feature channels reserved for the second stage.
        self.local_regress_slicedim = config['match_fine']['local_regress_slicedim']
        # When true, the offset grid in get_fine_ds_match is built in float16.
        self.fp16 = config['half']

    def forward(self, feat_0: torch.Tensor, feat_1: torch.Tensor, data: Dict[str, Any]) -> None:
        """Refine coarse matches and store the result in ``data``.

        Args:
            feat_0 (torch.Tensor): [M, WW, C] window features of image 0.
            feat_1 (torch.Tensor): [M, (W+2)*(W+2), C] window features of
                image 1, where W = sqrt(WW); the reshapes below require this
                padded window size.
            data (dict): reads 'hw0_i', 'hw0_f', 'mkpts0_c', 'mkpts1_c' and,
                when M > 0, also 'mconf', 'bs' and optionally
                'scale0' / 'scale1' / 'b_ids'.
        Update:
            data (dict):{
                'mkpts0_f' (torch.Tensor): fine keypoints in image 0,
                'mkpts1_f' (torch.Tensor): fine keypoints in image 1,
                'conf_matrix_f': written when M == 0 or while training,
                'sim_matrix_ff': written while training,
                'idx_l', 'idx_r', 'mkpts0_c', 'mkpts1_c': written/overwritten
                    by get_fine_ds_match when M > 0}
        Returns:
            None. All outputs are written into ``data``.
        """
        M, WW, C = feat_0.shape
        W = int(math.sqrt(WW))
        scale = data['hw0_i'][0] / data['hw0_f'][0]
        # Cached on the module because the helper methods read them back.
        self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale

        # corner case: if no coarse matches found
        if M == 0:
            assert not self.training, "M is always > 0 while training, see coarse_matching.py"
            data.update({
                'conf_matrix_f': torch.empty(0, WW, WW, device=feat_0.device),
                'mkpts0_f': data['mkpts0_c'],
                'mkpts1_f': data['mkpts1_c'],
            })
            return

        # compute pixel-level confidence matrix
        slicedim = self.local_regress_slicedim
        # Autocast only matters for CUDA tensors; enabling it only when the
        # input lives on CUDA keeps the numerics and avoids the "CUDA is not
        # available" warning torch emits on CPU-only hosts.
        with torch.autocast(enabled=feat_0.is_cuda, device_type='cuda'):
            # Leading channels feed the first stage, trailing ones the second.
            feat_f0, feat_f1 = feat_0[..., :-slicedim], feat_1[..., :-slicedim]
            feat_ff0, feat_ff1 = feat_0[..., -slicedim:], feat_1[..., -slicedim:]
            # Both operands are divided by sqrt(C) with C the *full* channel count.
            feat_f0, feat_f1 = feat_f0 / C**.5, feat_f1 / C**.5
            conf_matrix_f = torch.einsum('mlc,mrc->mlr', feat_f0, feat_f1)
            conf_matrix_ff = torch.einsum('mlc,mrc->mlr', feat_ff0, feat_ff1 / slicedim**.5)

        # NOTE(review): the softmax is placed outside the autocast region; the
        # extent of the `with` body was reconstructed from flattened source -- confirm.
        softmax_matrix_f = F.softmax(conf_matrix_f, 1) * F.softmax(conf_matrix_f, 2)
        # Drop the one-cell border of the (W+2)x(W+2) right window -> W x W.
        softmax_matrix_f = softmax_matrix_f.reshape(M, self.WW, self.W+2, self.W+2)
        softmax_matrix_f = softmax_matrix_f[..., 1:-1, 1:-1].reshape(M, self.WW, self.WW)

        # for fine-level supervision
        # NOTE(review): both updates are assumed to belong to this `if` -- confirm.
        if self.training:
            data.update({'sim_matrix_ff': conf_matrix_ff})
            data.update({'conf_matrix_f': softmax_matrix_f})

        # compute pixel-level absolute kpt coords
        self.get_fine_ds_match(softmax_matrix_f, data)

        # generate second-stage 3x3 grid
        num_kept = len(data['mconf'])
        idx_l, idx_r = data['idx_l'], data['idx_r']
        m_ids = torch.arange(M, device=idx_l.device, dtype=torch.long).unsqueeze(-1)
        m_ids = m_ids[:num_kept]
        idx_r_iids, idx_r_jids = idx_r // W, idx_r % W

        m_ids = m_ids.reshape(-1)
        idx_l = idx_l.reshape(-1)
        idx_r_iids = idx_r_iids.reshape(-1)
        idx_r_jids = idx_r_jids.reshape(-1)
        delta = create_meshgrid(3, 3, True, conf_matrix_ff.device).to(torch.long)  # [1, 3, 3, 2]

        m_ids = m_ids[..., None, None].expand(-1, 3, 3)
        idx_l = idx_l[..., None, None].expand(-1, 3, 3)

        # NOTE(review): idx_r indexes the cropped W x W grid, yet the result is
        # used below on the uncropped (W+2) x (W+2) grid without a +1 shift, so
        # the window looks offset by one cell and row/col 0 yields index -1
        # (wraps around). Left unchanged to keep existing behaviour -- confirm.
        idx_r_iids = idx_r_iids[..., None, None].expand(-1, 3, 3) + delta[None, ..., 1]
        idx_r_jids = idx_r_jids[..., None, None].expand(-1, 3, 3) + delta[None, ..., 0]

        if idx_l.numel() == 0:
            data.update({
                'mkpts0_f': data['mkpts0_c'],
                'mkpts1_f': data['mkpts1_c'],
            })
            return

        # compute second-stage heatmap
        conf_matrix_ff = conf_matrix_ff.reshape(M, self.WW, self.W+2, self.W+2)
        conf_matrix_ff = conf_matrix_ff[m_ids, idx_l, idx_r_iids, idx_r_jids]
        conf_matrix_ff = conf_matrix_ff.reshape(-1, 9)
        conf_matrix_ff = F.softmax(conf_matrix_ff / self.local_regress_temperature, -1)
        heatmap = conf_matrix_ff.reshape(-1, 3, 3)

        # compute coordinates from heatmap
        coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]

        # NOTE(review): presence of 'scale0' gates the use of 'scale1'; callers
        # presumably always provide both together -- confirm.
        if data['bs'] == 1:
            scale1 = scale * data['scale1'] if 'scale0' in data else scale
        elif 'scale0' in data:
            per_match = data['scale1'][data['b_ids']][:num_kept, ...]
            scale1 = scale * per_match[:, None, :].expand(-1, -1, 2).reshape(-1, 2)
        else:
            scale1 = scale

        # compute subpixel-level absolute kpt coords
        self.get_fine_match_local(coords_normalized, data, scale1)

    def get_fine_match_local(self, coords_normed: torch.Tensor, data: Dict[str, Any], scale1) -> None:
        """Apply the regressed 3x3 offset to the right keypoints.

        Args:
            coords_normed: expected heatmap coordinates (normalized), one row
                per match.
            data: reads 'mkpts0_c' / 'mkpts1_c'; writes 'mkpts0_f' / 'mkpts1_f'.
            scale1: scalar or tensor factor converting the offset to the
                coordinate frame of 'mkpts1_c'.
        """
        mkpts0_c, mkpts1_c = data['mkpts0_c'], data['mkpts1_c']

        # Left keypoints are kept as-is; only the right ones are shifted.
        # 3 // 2 == 1 is the half-width of the 3x3 window.
        mkpts0_f = mkpts0_c
        mkpts1_f = mkpts1_c + (coords_normed * (3 // 2) * scale1)

        data.update({
            "mkpts0_f": mkpts0_f,
            "mkpts1_f": mkpts1_f
        })

    def get_fine_ds_match(self, conf_matrix: torch.Tensor, data: Dict[str, Any]) -> None:
        """Pick the best (left, right) cell per match and move the coarse keypoints.

        Relies on ``self.W``, ``self.WW`` and ``self.scale`` set by ``forward``.

        Args:
            conf_matrix: [m, WW, WW] pixel-level confidence matrix.
            data: reads 'mconf', 'mkpts0_c', 'mkpts1_c' and optionally
                'scale0' / 'scale1' / 'b_ids'. Writes 'idx_l', 'idx_r' and
                OVERWRITES 'mkpts0_c' / 'mkpts1_c' with the shifted keypoints.
        """
        W, WW, scale = self.W, self.WW, self.scale
        m, _, _ = conf_matrix.shape
        num_kept = len(data['mconf'])

        # Argmax over the flattened (left, right) pairs of the first num_kept matches.
        conf_matrix = conf_matrix.reshape(m, -1)[:num_kept, ...]
        _, idx = torch.max(conf_matrix, dim=-1)
        idx = idx[:, None]
        idx_l, idx_r = idx // WW, idx % WW

        data.update({'idx_l': idx_l, 'idx_r': idx_r})

        # Per-cell offsets relative to the window centre.
        if self.fp16:
            grid = create_meshgrid(W, W, False, conf_matrix.device, dtype=torch.float16) - W // 2 + 0.5  # kornia >= 0.5.1
        else:
            grid = create_meshgrid(W, W, False, conf_matrix.device) - W // 2 + 0.5
        grid = grid.reshape(1, -1, 2).expand(m, -1, -1)
        delta_l = torch.gather(grid, 1, idx_l.unsqueeze(-1).expand(-1, -1, 2))
        delta_r = torch.gather(grid, 1, idx_r.unsqueeze(-1).expand(-1, -1, 2))

        # NOTE(review): 'scale1' is guarded by the presence of 'scale0' -- confirm
        # that both keys are always provided together.
        scale0 = scale * data['scale0'][data['b_ids']] if 'scale0' in data else scale
        scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale

        if torch.is_tensor(scale0) and scale0.numel() > 1:  # scale0 is a tensor
            mkpts0_f = (data['mkpts0_c'][:, None, :] + (delta_l * scale0[:num_kept, ...][:, None, :])).reshape(-1, 2)
            mkpts1_f = (data['mkpts1_c'][:, None, :] + (delta_r * scale1[:num_kept, ...][:, None, :])).reshape(-1, 2)
        else:  # scale0 is a float
            mkpts0_f = (data['mkpts0_c'][:, None, :] + (delta_l * scale0)).reshape(-1, 2)
            mkpts1_f = (data['mkpts1_c'][:, None, :] + (delta_r * scale1)).reshape(-1, 2)

        # The pixel-level result replaces the coarse keypoints in place.
        data.update({
            "mkpts0_c": mkpts0_f,
            "mkpts1_c": mkpts1_f
        })