# lsxi77777's picture
# commit message
# a930e1f
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange
# Large finite magnitude; ``-INF`` is written into masked-out similarity
# entries so they get (near-)zero weight after softmax.
INF = float(10 ** 9)
def mask_border(m, b: int, v):
    """Write ``v`` into the outer ``b`` cells of every spatial axis of ``m``, in place.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1]; modified in place.
        b (int): border width; nothing is written when ``b <= 0``.
        v (m.dtype): fill value.
    """
    if b > 0:
        leading, trailing = slice(None, b), slice(-b, None)
        for edge in (leading, trailing):
            # axes 1..4 are H0, W0, H1, W1; axis 0 (batch) is never masked
            for axis in range(1, 5):
                m[(slice(None),) * axis + (edge,)] = v
def mask_border_with_padding(m, bd, v, p_m0, p_m1):
    """Border-mask ``m`` in place, measuring the trailing border from the valid region.

    The leading ``bd`` cells of every spatial axis are always masked. The
    trailing border is placed ``bd`` cells before the end of each sample's
    valid (unpadded) extent, derived from the padding masks.

    Args:
        m (torch.Tensor): [N, H0, W0, H1, W1]; modified in place.
        bd (int): border width; nothing is written when ``bd <= 0``.
        v: fill value.
        p_m0, p_m1 (torch.Tensor): padding masks of image 0 / image 1; assumed
            [N, H, W] with nonzero entries marking valid cells — TODO confirm.
    """
    if bd <= 0:
        return
    everything = slice(None)
    for axis in range(1, 5):
        m[(everything,) * axis + (slice(None, bd),)] = v

    def _valid_extent(pad_mask):
        # rows: max over columns of the per-column valid count; cols: the converse
        rows = pad_mask.sum(1).max(-1)[0].int()
        cols = pad_mask.sum(-1).max(-1)[0].int()
        return rows, cols

    per_sample = zip(*_valid_extent(p_m0), *_valid_extent(p_m1))
    for sample, extents in enumerate(per_sample):
        # extents = (h0, w0, h1, w1), matching spatial axes 1..4 of m
        for axis, extent in enumerate(extents):
            m[(sample,) + (everything,) * axis + (slice(extent - bd, None),)] = v
def compute_max_candidates(p_m0, p_m1):
    """Return the batch total of ``min(valid area of image 0, valid area of image 1)``.

    Each valid area is (max per-column valid count) * (max per-row valid count)
    of the corresponding padding mask.

    Args:
        p_m0, p_m1 (torch.Tensor): padded masks, one per image of each pair.

    Returns:
        torch.Tensor: 0-d tensor holding the summed per-pair minimum area.
    """
    def _valid_area(pad_mask):
        height = pad_mask.sum(1).max(-1)[0]
        width = pad_mask.sum(-1).max(-1)[0]
        return height * width

    areas = torch.stack([_valid_area(p_m0), _valid_area(p_m1)], -1)
    return areas.min(-1)[0].sum()
class CoarseMatching(nn.Module):
    """Coarse-level matcher.

    Coarse features of both images go through a shared linear projection. They
    are scaled by 1/sqrt(C) and compared with a dot product divided by
    ``dsmax_temperature``. Coarse matches are then the union of:

    * the row-wise maxima of softmax over dim 2 (0 -> 1), and
    * the column-wise maxima of softmax over dim 1 (1 -> 0),

    each kept only if its confidence exceeds ``thr``. Candidates within
    ``border_rm`` cells of the (valid) border are discarded. In training mode
    the matches are subsampled and padded with ground-truth matches before
    being handed to the fine level.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # general config
        d_model = config['d_model']
        self.thr = config['thr']
        self.inference = config['inference']
        self.border_rm = config['border_rm']
        # -- # for training fine-level XoFTR
        self.train_coarse_percent = config['train_coarse_percent']
        self.train_pad_num_gt_min = config['train_pad_num_gt_min']
        self.final_proj = nn.Linear(d_model, d_model, bias=True)
        self.temperature = config['dsmax_temperature']

    def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
        """
        Args:
            feat_c0 (torch.Tensor): [N, L, C]
            feat_c1 (torch.Tensor): [N, S, C]
            data (dict)
            mask_c0 (torch.Tensor): [N, L] (optional)
            mask_c1 (torch.Tensor): [N, S] (optional; required whenever mask_c0 is given)
        Update:
            data (dict): {
                'conf_matrix_0_to_1' (torch.Tensor): [N, L, S]  (only when not self.inference),
                'conf_matrix_1_to_0' (torch.Tensor): [N, L, S]  (only when not self.inference),
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M']  (only when not self.inference),
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]  (only when not self.inference)}
            NOTE: M' != M during training.
        """
        feat_c0 = self.final_proj(feat_c0)
        feat_c1 = self.final_proj(feat_c1)
        # normalize: scale by 1/sqrt(C) before the dot product
        feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
                               [feat_c0, feat_c1])
        sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
                                  feat_c1) / self.temperature
        if mask_c0 is not None:
            # padded cells get -INF so they receive ~zero softmax weight
            sim_matrix.masked_fill_(
                ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
                -INF)
        if self.inference:
            # predict coarse matches from the similarity matrix
            data.update(**self.get_coarse_match_inference(sim_matrix, data))
        else:
            conf_matrix_0_to_1 = F.softmax(sim_matrix, 2)
            conf_matrix_1_to_0 = F.softmax(sim_matrix, 1)
            data.update({'conf_matrix_0_to_1': conf_matrix_0_to_1,
                         'conf_matrix_1_to_0': conf_matrix_1_to_0
                         })
            # predict coarse matches from the confidence matrices
            data.update(**self.get_coarse_match_training(conf_matrix_0_to_1, conf_matrix_1_to_0, data))

    @torch.no_grad()
    def get_coarse_match_training(self, conf_matrix_0_to_1, conf_matrix_1_to_0, data):
        """
        Args:
            conf_matrix_0_to_1 (torch.Tensor): [N, L, S], softmax over dim 2
            conf_matrix_1_to_0 (torch.Tensor): [N, L, S], softmax over dim 1
            data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'];
                in training mode also 'spv_b_ids', 'spv_i_ids', 'spv_j_ids'
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M'],
                'i_ids' (torch.Tensor): [M'],
                'j_ids' (torch.Tensor): [M'],
                'gt_mask' (torch.Tensor): [M'],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2],
                'mconf' (torch.Tensor): [M]}
        """
        _device = conf_matrix_0_to_1.device
        # confidence thresholding
        # {(nearest neighbour for 0 to 1) U (nearest neighbour for 1 to 0)}
        mask = torch.logical_or(
            (conf_matrix_0_to_1 > self.thr) * (conf_matrix_0_to_1 == conf_matrix_0_to_1.max(dim=2, keepdim=True)[0]),
            (conf_matrix_1_to_0 > self.thr) * (conf_matrix_1_to_0 == conf_matrix_1_to_0.max(dim=1, keepdim=True)[0]))
        mask = self._remove_border_matches(mask, data)

        # find all valid coarse matches
        b_ids, i_ids, j_ids = mask.nonzero(as_tuple=True)
        mconf = torch.maximum(conf_matrix_0_to_1[b_ids, i_ids, j_ids],
                              conf_matrix_1_to_0[b_ids, i_ids, j_ids])

        # random sampling of training samples for fine-level XoFTR
        # (optional) pad samples with gt coarse-level matches
        if self.training:
            # NOTE:
            # the sampling is performed across all pairs in a batch without manually balancing
            # samples for fine-level increases w.r.t. batch_size
            if 'mask0' not in data:
                num_candidates_max = mask.size(0) * max(
                    mask.size(1), mask.size(2))
            else:
                num_candidates_max = compute_max_candidates(
                    data['mask0'], data['mask1'])
            num_matches_train = int(num_candidates_max *
                                    self.train_coarse_percent)
            num_matches_pred = len(b_ids)
            assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"

            # pred_indices selects from the prediction (with replacement when subsampling)
            if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
                pred_indices = torch.arange(num_matches_pred, device=_device)
            else:
                pred_indices = torch.randint(
                    num_matches_pred,
                    (num_matches_train - self.train_pad_num_gt_min, ),
                    device=_device)

            # gt_pad_indices selects from the gt padding. e.g. max(3787-4800, 200)
            num_gt = len(data['spv_b_ids'])
            num_gt_pad = max(num_matches_train - num_matches_pred,
                             self.train_pad_num_gt_min)
            if num_gt > 0:
                gt_pad_indices = torch.randint(num_gt, (num_gt_pad, ), device=_device)
            else:
                # torch.randint cannot sample from an empty range; with no GT
                # matches available there is nothing to pad with.
                gt_pad_indices = torch.zeros(0, dtype=torch.long, device=_device)
            mconf_gt = torch.zeros(num_gt, device=_device)  # set conf of gt paddings to all zero

            def _pick(pred, gt):
                # sampled predictions first, then gt paddings
                return torch.cat([pred[pred_indices], gt[gt_pad_indices]], dim=0)

            b_ids = _pick(b_ids, data['spv_b_ids'])
            i_ids = _pick(i_ids, data['spv_i_ids'])
            j_ids = _pick(j_ids, data['spv_j_ids'])
            mconf = _pick(mconf, mconf_gt)

        # these matches are selected patches that feed into fine-level network
        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

        # update with matches in original image resolution
        mkpts0_c, mkpts1_c = self._ids_to_coarse_keypoints(b_ids, i_ids, j_ids, data)

        # these matches are the current prediction (for visualization);
        # gt paddings carry mconf == 0 and are excluded here
        is_pred = mconf != 0
        coarse_matches.update({
            'gt_mask': mconf == 0,
            'm_bids': b_ids[is_pred],
            'mkpts0_c': mkpts0_c[is_pred],
            'mkpts1_c': mkpts1_c[is_pred],
            'mconf': mconf[is_pred]
        })

        return coarse_matches

    @torch.no_grad()
    def get_coarse_match_inference(self, sim_matrix, data):
        """
        Args:
            sim_matrix (torch.Tensor): [N, L, S]
            data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
        Returns:
            coarse_matches (dict): {
                'b_ids' (torch.Tensor): [M],
                'i_ids' (torch.Tensor): [M],
                'j_ids' (torch.Tensor): [M],
                'm_bids' (torch.Tensor): [M],
                'mkpts0_c' (torch.Tensor): [M, 2],
                'mkpts1_c' (torch.Tensor): [M, 2]}
            No 'gt_mask' or 'mconf' is produced on this path.
        """
        # softmax for 0 to 1
        conf_matrix_ = F.softmax(sim_matrix, 2)
        # confidence thresholding and nearest neighbour for 0 to 1
        mask = (conf_matrix_ > self.thr) * (conf_matrix_ == conf_matrix_.max(dim=2, keepdim=True)[0])

        # unlike training, rebind the same name so only one conf matrix stays
        # alive afterwards (the 0->1 matrix can be freed)
        # softmax for 1 to 0
        conf_matrix_ = F.softmax(sim_matrix, 1)
        # update mask {(nearest neighbour for 0 to 1) U (nearest neighbour for 1 to 0)}
        mask = torch.logical_or(mask,
                                (conf_matrix_ > self.thr) * (conf_matrix_ == conf_matrix_.max(dim=1, keepdim=True)[0]))
        mask = self._remove_border_matches(mask, data)

        # find all valid coarse matches
        b_ids, i_ids, j_ids = mask.nonzero(as_tuple=True)

        # these matches are selected patches that feed into fine-level network
        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}

        # update with matches in original image resolution
        mkpts0_c, mkpts1_c = self._ids_to_coarse_keypoints(b_ids, i_ids, j_ids, data)

        # these matches are the current coarse level predictions
        coarse_matches.update({
            'm_bids': b_ids,
            'mkpts0_c': mkpts0_c,
            'mkpts1_c': mkpts1_c,
        })

        return coarse_matches

    def _remove_border_matches(self, mask, data):
        """Clear candidates within ``self.border_rm`` cells of the border; return an [N, L, S] mask.

        The [N, L, S] mask is viewed as [N, H0c, W0c, H1c, W1c] (row-major
        flattening of each coarse grid) so that each spatial axis can be
        addressed. When padding masks 'mask0'/'mask1' are present, the
        trailing border is measured from each sample's valid region.
        """
        axes_lengths = {
            'h0c': data['hw0_c'][0],
            'w0c': data['hw0_c'][1],
            'h1c': data['hw1_c'][0],
            'w1c': data['hw1_c'][1]
        }
        mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
                         **axes_lengths)
        if 'mask0' not in data:
            mask_border(mask, self.border_rm, False)
        else:
            mask_border_with_padding(mask, self.border_rm, False,
                                     data['mask0'], data['mask1'])
        return rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
                         **axes_lengths)

    @staticmethod
    def _ids_to_coarse_keypoints(b_ids, i_ids, j_ids, data):
        """Map flat coarse indices to (x, y) keypoints at input-image resolution.

        Returns:
            (mkpts0_c, mkpts1_c): each [M, 2], optionally rescaled per sample
            by data['scale0'] / data['scale1'] indexed with ``b_ids``.
        """
        # NOTE(review): image 1 reuses image 0's image/coarse ratio, which
        # assumes both images share the same downsampling stride.
        scale = data['hw0_i'][0] / data['hw0_c'][0]
        scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
        scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
        w0c, w1c = data['hw0_c'][1], data['hw1_c'][1]
        # flat index -> (column, row) == (x, y) on the row-major coarse grid
        mkpts0_c = torch.stack(
            [i_ids % w0c, torch.div(i_ids, w0c, rounding_mode='trunc')],
            dim=1) * scale0
        mkpts1_c = torch.stack(
            [j_ids % w1c, torch.div(j_ids, w1c, rounding_mode='trunc')],
            dim=1) * scale1
        return mkpts0_c, mkpts1_c