# Description: RDD model
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np

from .utils import NestedTensor, nested_tensor_from_tensor_list, to_pixel_coords, read_config
from .models.detector import build_detector
from .models.descriptor import build_descriptor
from .models.soft_detect import SoftDetect
from .models.interpolator import InterpolateSparse2d


class RDD(nn.Module):
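    """Joint keypoint detector/descriptor. Wraps a detector branch (keypoint
    score map) and a descriptor branch (dense features plus a matchability
    map), with SoftDetect used to extract keypoints from the score map."""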

    def __init__(self, detector, descriptor, detection_threshold=0.5, top_k=4096,
                 train_detector=False, device='cuda'):
        super().__init__()
        self.detector = detector
        self.descriptor = descriptor
        self.interpolator = InterpolateSparse2d('bicubic')
        self.detection_threshold = detection_threshold
        self.top_k = top_k
        self.device = device

        # Only one branch is trained at a time; freeze the other's parameters.
        if train_detector:
            for p in self.detector.parameters():
                p.requires_grad = True
            for p in self.descriptor.parameters():
                p.requires_grad = False
        else:
            for p in self.detector.parameters():
                p.requires_grad = False
            for p in self.descriptor.parameters():
                p.requires_grad = True

        self.softdetect = None
        self.stride = descriptor.stride

    def train(self, mode=True):
        super().train(mode)
        # Training uses a small set of high-confidence keypoints.
        self.set_softdetect(top_k=500, scores_th=0.2)

    def eval(self):
        super().eval()
        # Evaluation keeps up to top_k keypoints at a low score threshold.
        self.set_softdetect(top_k=self.top_k, scores_th=0.01)

    def forward(self, samples: NestedTensor):
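        """Run both branches; returns descriptor features, the detector
        score map, and the matchability map."""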
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        scoremap = self.detector(samples)
        feats, matchibility = self.descriptor(samples)
        return feats, scoremap, matchibility

    def set_softdetect(self, top_k=4096, scores_th=0.01):
        self.softdetect = SoftDetect(radius=2, top_k=top_k, scores_th=scores_th)

    def filter(self, matchibility):
        # Zero out a border of width `stride` so no keypoints are kept there.
        B, _, H, W = matchibility.shape
        frame = torch.zeros(B, H, W, device=matchibility.device)
        frame[:, self.stride:-self.stride, self.stride:-self.stride] = 1
        matchibility = matchibility * frame
        return matchibility

    def extract(self, x):
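        """Detect keypoints with SoftDetect and sample descriptors at their
        locations. Returns one dict per image with 'keypoints', 'scores',
        and 'descriptors', keeping detections above detection_threshold."""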
        if self.softdetect is None:
            self.eval()
        x, rh1, rw1 = self.preprocess_tensor(x)
        B, _, _H1, _W1 = x.shape

        M1, K1, H1 = self.forward(x)
        M1 = F.normalize(M1, dim=1)

        keypoints, kptscores, scoredispersitys = self.softdetect(K1)
        keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
        kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])
        keypoints = to_pixel_coords(keypoints, _H1, _W1)

        feats = self.interpolator(M1, keypoints, H=_H1, W=_W1)
        feats = F.normalize(feats, dim=-1)

        # Undo the resize applied in preprocess_tensor.
        keypoints = keypoints * torch.tensor([rw1, rh1], device=keypoints.device).view(1, -1)

        valid = kptscores > self.detection_threshold
        return [
            {'keypoints': keypoints[b][valid[b]],
             'scores': kptscores[b][valid[b]],
             'descriptors': feats[b][valid[b]]} for b in range(B)
        ]

    def extract_3rd_party(self, x, model='aliked'):
        """Detect keypoints with a third-party detector (one image per batch)
        and describe them with the RDD descriptor branch."""
        x, rh1, rw1 = self.preprocess_tensor(x)
        B, _, _H1, _W1 = x.shape
        if model == 'aliked':
            from third_party import extract_aliked_kpts
            mkpts, scores = extract_aliked_kpts(x, self.device)
        else:
            raise ValueError('Unknown model')

        M1, _ = self.descriptor(x)
        M1 = F.normalize(M1, dim=1)

        # Keep only the top_k highest-scoring keypoints.
        if mkpts.shape[1] > self.top_k:
            idx = torch.argsort(scores, descending=True)[0][:self.top_k]
            mkpts = mkpts[:, idx]
            scores = scores[:, idx]

        feats = self.interpolator(M1, mkpts, H=_H1, W=_W1)
        feats = F.normalize(feats, dim=-1)

        # Undo the resize applied in preprocess_tensor.
        mkpts = mkpts * torch.tensor([rw1, rh1], device=mkpts.device).view(1, 1, -1)
        return [
            {'keypoints': mkpts[b],
             'scores': scores[b],
             'descriptors': feats[b]} for b in range(B)
        ]

    def extract_dense(self, x, n_limit=30000, thr=0.01):
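        """Extract SoftDetect keypoints plus a dense set of up to n_limit
        points sampled from the matchability map (kept if their score
        exceeds thr). Returns one dict per image with both sets."""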
        # scores_th=-1 disables SoftDetect's score threshold; filtering is
        # done explicitly via the validity masks below.
        self.set_softdetect(top_k=n_limit, scores_th=-1)
        x, rh1, rw1 = self.preprocess_tensor(x)
        B, _, _H1, _W1 = x.shape

        M1, K1, H1 = self.forward(x)
        M1 = F.normalize(M1, dim=1)

        # Sparse keypoints from the detector score map.
        keypoints, kptscores, scoredispersitys = self.softdetect(K1)
        keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
        kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])
        keypoints = to_pixel_coords(keypoints, _H1, _W1)
        feats = self.interpolator(M1, keypoints, H=_H1, W=_W1)
        feats = F.normalize(feats, dim=-1)

        # Dense keypoints from the border-filtered matchability map.
        H1 = self.filter(H1)
        dense_kpts, dense_scores, inds = self.sample_dense_kpts(H1, n_limit=n_limit)
        dense_keypoints = to_pixel_coords(dense_kpts, _H1, _W1)
        dense_feats = self.interpolator(M1, dense_keypoints, H=_H1, W=_W1)
        dense_feats = F.normalize(dense_feats, dim=-1)

        # Undo the resize applied in preprocess_tensor.
        keypoints = keypoints * torch.tensor([rw1, rh1], device=keypoints.device).view(1, -1)
        dense_keypoints = dense_keypoints * torch.tensor([rw1, rh1], device=dense_keypoints.device).view(1, -1)

        valid = kptscores > self.detection_threshold
        valid_dense = dense_scores > thr
        return [
            {'keypoints': keypoints[b][valid[b]],
             'scores': kptscores[b][valid[b]],
             'descriptors': feats[b][valid[b]],
             'keypoints_dense': dense_keypoints[b][valid_dense[b]],
             'scores_dense': dense_scores[b][valid_dense[b]],
             'descriptors_dense': dense_feats[b][valid_dense[b]]} for b in range(B)
        ]

    def sample_dense_kpts(self, keypoint_logits, threshold=0.01, n_limit=30000, force_kpts=True):
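        """Take the n_limit highest-scoring locations of a (B, 1, H, W) score
        map and return their normalized coordinates, scores, and flat indices.
        With force_kpts, points scoring at or below `threshold` are dropped;
        that masking flattens the batch dimension, so it assumes B == 1."""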
        B, K, H, W = keypoint_logits.shape
        # Clamp n_limit to the valid range for topk.
        if n_limit < 0 or n_limit > H * W:
            n_limit = H * W - 1
        scoremap = keypoint_logits.permute(0, 2, 3, 1)
        scoremap = scoremap.reshape(B, H, W)

        # Zero out a one-pixel border so no points land on the image edge.
        frame = torch.zeros(B, H, W, device=keypoint_logits.device)
        frame[:, 1:-1, 1:-1] = 1
        scoremap = scoremap * frame
        scoremap = scoremap.reshape(B, H * W)

        grid = self.get_grid(B, H, W, device=keypoint_logits.device)
        inds = torch.topk(scoremap, n_limit, dim=1).indices
        kpts = torch.gather(grid, 1, inds[..., None].expand(B, n_limit, 2))
        scoremap = torch.gather(scoremap, 1, inds)
        if force_kpts:
            valid = scoremap > threshold
            kpts = kpts[valid][None]
            scoremap = scoremap[valid][None]
        return kpts, scoremap, inds

    def preprocess_tensor(self, x):
        """Resize the image so both sides are divisible by 32 (avoids aliasing
        artifacts) and return the scale factors to map keypoints back."""
        if isinstance(x, np.ndarray) and len(x.shape) == 3:
            x = torch.tensor(x).permute(2, 0, 1)[None]
        x = x.to(self.device).float()

        H, W = x.shape[-2:]
        _H, _W = (H // 32) * 32, (W // 32) * 32
        rh, rw = H / _H, W / _W

        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw

    def get_grid(self, B, H, W, device=None):
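        """Dense (B, H*W, 2) grid of (x, y) coordinates in [-1, 1], offset
        to cell centers, matching the flattened score-map layout."""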
        x1_n = torch.meshgrid(
            *[
                torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
                for n in (B, H, W)
            ]
        )
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        return x1_n


def build(config=None, weights=None):
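    """Build an RDD model from a config dict (defaults to
    ./configs/default.yaml), optionally overriding the weights path."""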
    if config is None:
        config = read_config('./configs/default.yaml')
    if weights is not None:
        config['weights'] = weights
    device = torch.device(config['device'])

    detector = build_detector(config)
    descriptor = build_descriptor(config)
    model = RDD(
        detector,
        descriptor,
        detection_threshold=config['detection_threshold'],
        top_k=config['top_k'],
        train_detector=config['train_detector'],
        device=device
    )
    if config.get('weights') is not None:
        model.load_state_dict(torch.load(config['weights'], map_location='cpu'))
    model.to(device)
    return model
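

if __name__ == '__main__':
    # Minimal usage sketch, assuming ./configs/default.yaml exists and that
    # config['weights'] points to a valid checkpoint. The image path below is
    # purely illustrative, and this file uses relative imports, so run it as a
    # module (e.g. python -m <package>.RDD) rather than as a script.
    import cv2

    model = build()
    model.eval()

    img = cv2.imread('assets/example.jpg')  # hypothetical path; H x W x 3 uint8
    with torch.no_grad():
        out = model.extract(img)[0]  # extract() accepts an H x W x 3 ndarray
    print(out['keypoints'].shape, out['scores'].shape, out['descriptors'].shape)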