Realcat's picture
add: rdd sparse and dense match
1b369eb
# Description: RDD model
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from .utils import NestedTensor, nested_tensor_from_tensor_list, to_pixel_coords, read_config
from .models.detector import build_detector
from .models.descriptor import build_descriptor
from .models.soft_detect import SoftDetect
from .models.interpolator import InterpolateSparse2d
class RDD(nn.Module):
def __init__(self, detector, descriptor, detection_threshold=0.5, top_k=4096, train_detector=False, device='cuda'):
super().__init__()
self.detector = detector
self.descriptor = descriptor
self.interpolator = InterpolateSparse2d('bicubic')
self.detection_threshold = detection_threshold
self.top_k = top_k
self.device = device
if train_detector:
for p in self.detector.parameters():
p.requires_grad = True
for p in self.descriptor.parameters():
p.requires_grad = False
else:
for p in self.detector.parameters():
p.requires_grad = False
for p in self.descriptor.parameters():
p.requires_grad = True
self.softdetect = None
self.stride = descriptor.stride
def train(self, mode=True):
super().train(mode)
self.set_softdetect(top_k=500, scores_th=0.2)
def eval(self):
super().eval()
self.set_softdetect(top_k=self.top_k, scores_th=0.01)
def forward(self, samples: NestedTensor):
if not isinstance(samples, NestedTensor):
samples = nested_tensor_from_tensor_list(samples)
scoremap = self.detector(samples)
feats, matchibility = self.descriptor(samples)
return feats, scoremap, matchibility
def set_softdetect(self, top_k=4096, scores_th=0.01):
self.softdetect = SoftDetect(radius=2, top_k=top_k, scores_th=scores_th)
@torch.inference_mode()
def filter(self, matchibility):
# Filter out keypoints on the border
B, _, H, W = matchibility.shape
frame = torch.zeros(B, H, W, device=matchibility.device)
frame[:, self.stride:-self.stride, self.stride:-self.stride] = 1
matchibility = matchibility * frame
return matchibility
@torch.inference_mode()
def extract(self, x):
if self.softdetect is None:
self.eval()
x, rh1, rw1 = self.preprocess_tensor(x)
x = x.to(self.device).float()
B, _, _H1, _W1 = x.shape
M1, K1, H1 = self.forward(x)
M1 = F.normalize(M1, dim=1)
keypoints, kptscores, scoredispersitys = self.softdetect(K1)
keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])
keypoints = to_pixel_coords(keypoints, _H1, _W1)
feats = self.interpolator(M1, keypoints, H = _H1, W = _W1)
feats = F.normalize(feats, dim=-1)
# Correct kpt scale
keypoints = keypoints * torch.tensor([rw1,rh1], device=keypoints.device).view(1, -1)
valid = kptscores > self.detection_threshold
return [
{'keypoints': keypoints[b][valid[b]],
'scores': kptscores[b][valid[b]],
'descriptors': feats[b][valid[b]]} for b in range(B)
]
@torch.inference_mode()
def extract_3rd_party(self, x, model='aliked'):
"""
one image per batch
"""
x, rh1, rw1 = self.preprocess_tensor(x)
B, _, _H1, _W1 = x.shape
if model == 'aliked':
from third_party import extract_aliked_kpts
img = x
mkpts, scores = extract_aliked_kpts(img, self.device)
else:
raise ValueError('Unknown model')
M1, _ = self.descriptor(x)
M1 = F.normalize(M1, dim=1)
if mkpts.shape[1] > self.top_k:
idx = torch.argsort(scores, descending=True)[0][:self.top_k]
mkpts = mkpts[:,idx]
scores = scores[:,idx]
feats = self.interpolator(M1, mkpts, H = _H1, W = _W1)
feats = F.normalize(feats, dim=-1)
mkpts = mkpts * torch.tensor([rw1,rh1], device=mkpts.device).view(1, 1, -1)
return [
{'keypoints': mkpts[b],
'scores': scores[b],
'descriptors': feats[b]} for b in range(B)
]
@torch.inference_mode()
def extract_dense(self, x, n_limit=30000, thr=0.01):
self.set_softdetect(top_k=n_limit, scores_th=-1)
x, rh1, rw1 = self.preprocess_tensor(x)
B, _, _H1, _W1 = x.shape
M1, K1, H1 = self.forward(x)
M1 = F.normalize(M1, dim=1)
keypoints, kptscores, scoredispersitys = self.softdetect(K1)
keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])
keypoints = to_pixel_coords(keypoints, _H1, _W1)
feats = self.interpolator(M1, keypoints, H = _H1, W = _W1)
feats = F.normalize(feats, dim=-1)
H1 = self.filter(H1)
dense_kpts, dense_scores, inds = self.sample_dense_kpts(H1, n_limit=n_limit)
dense_keypoints = to_pixel_coords(dense_kpts, _H1, _W1)
dense_feats = self.interpolator(M1, dense_keypoints, H = _H1, W = _W1)
dense_feats = F.normalize(dense_feats, dim=-1)
keypoints = keypoints * torch.tensor([rw1,rh1], device=keypoints.device).view(1, -1)
dense_keypoints = dense_keypoints * torch.tensor([rw1,rh1], device=dense_keypoints.device).view(1, -1)
valid = kptscores > self.detection_threshold
valid_dense = dense_scores > thr
return [
{'keypoints': keypoints[b][valid[b]],
'scores': kptscores[b][valid[b]],
'descriptors': feats[b][valid[b]],
'keypoints_dense': dense_keypoints[b][valid_dense[b]],
'scores_dense': dense_scores[b][valid_dense[b]],
'descriptors_dense': dense_feats[b][valid_dense[b]]} for b in range(B)
]
@torch.inference_mode()
def sample_dense_kpts(self, keypoint_logits, threshold=0.01, n_limit=30000, force_kpts = True):
B, K, H, W = keypoint_logits.shape
if n_limit < 0 or n_limit > H*W:
n_limit = min(H*W - 1, n_limit)
scoremap = keypoint_logits.permute(0,2,3,1)
scoremap = scoremap.reshape(B, H, W)
frame = torch.zeros(B, H, W, device=keypoint_logits.device)
frame[:, 1:-1, 1:-1] = 1
scoremap = scoremap * frame
scoremap = scoremap.reshape(B, H*W)
grid = self.get_grid(B, H, W, device = keypoint_logits.device)
inds = torch.topk(scoremap, n_limit, dim=1).indices
# inds = torch.multinomial(scoremap, top_k, replacement=False)
kpts = torch.gather(grid, 1, inds[..., None].expand(B, n_limit, 2))
scoremap = torch.gather(scoremap, 1, inds)
if force_kpts:
valid = scoremap > threshold
kpts = kpts[valid][None]
scoremap = scoremap[valid][None]
return kpts, scoremap, inds
def preprocess_tensor(self, x):
""" Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
if isinstance(x, np.ndarray) and len(x.shape) == 3:
x = torch.tensor(x).permute(2,0,1)[None]
x = x.to(self.device).float()
H, W = x.shape[-2:]
_H, _W = (H//32) * 32, (W//32) * 32
rh, rw = H/_H, W/_W
x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
return x, rh, rw
@torch.inference_mode()
def get_grid(self, B, H, W, device = None):
x1_n = torch.meshgrid(
*[
torch.linspace(
-1 + 1 / n, 1 - 1 / n, n, device=device
)
for n in (B, H, W)
]
)
x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
return x1_n
def build(config=None, weights=None):
    """Construct an RDD model from a config mapping.

    Args:
        config: mapping with at least 'device', 'detection_threshold',
            'top_k' and 'train_detector'. If None, it is read from
            './configs/default.yaml' (relative to the working directory).
        weights: optional checkpoint path. When given, it is written into
            ``config['weights']`` (this mutates the passed mapping).

    Returns:
        The RDD model, moved to ``config['device']``. If ``config['weights']``
        is set, its state dict has been loaded.
    """
    if config is None:
        config = read_config('./configs/default.yaml')
    if weights is not None:
        config['weights'] = weights

    device = torch.device(config['device'])
    print('config', config)

    model = RDD(
        build_detector(config),
        build_descriptor(config),
        detection_threshold=config['detection_threshold'],
        top_k=config['top_k'],
        train_detector=config['train_detector'],
        device=device,
    )

    checkpoint = config['weights'] if 'weights' in config else None
    if checkpoint is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(checkpoint, map_location='cpu')
        model.load_state_dict(state)
    model.to(device)
    return model