Spaces:
Sleeping
Sleeping
File size: 2,366 Bytes
2673dcd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import torch
import torch.nn as nn
import kornia
from types import SimpleNamespace
from .utils import ImagePreprocessor
class DISK(nn.Module):
    """Local feature extractor wrapping kornia's pretrained DISK model.

    Keyword arguments given to the constructor are merged over
    ``default_conf``. ``forward`` runs the network on an already-prepared
    batch; ``extract`` resizes a single image with ``ImagePreprocessor``
    first and maps the keypoints back to the input resolution.
    """

    default_conf = {
        'weights': 'depth',  # checkpoint name for kornia DISK.from_pretrained
        'max_num_keypoints': None,  # forwarded to the model as ``n``
        'desc_dim': 128,  # not read by this class
        'nms_window_size': 5,  # forwarded as ``window_size``
        'detection_threshold': 0.0,  # forwarded as ``score_threshold``
        'pad_if_not_divisible': True,
    }

    # Defaults for ImagePreprocessor used by ``extract``; per-call kwargs
    # passed to ``extract`` override these.
    preprocess_conf = {
        **ImagePreprocessor.default_conf,
        'resize': 1024,
        'grayscale': False,
    }

    required_data_keys = ['image']

    def __init__(self, **conf) -> None:
        """Merge ``conf`` over the defaults and load the pretrained model.

        Args:
            **conf: overrides for any key of ``default_conf``.
        """
        super().__init__()
        self.conf = SimpleNamespace(**{**self.default_conf, **conf})
        self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)

    def forward(self, data: dict) -> dict:
        """Compute keypoints, detection scores and descriptors for a batch.

        Args:
            data: dict that must contain ``'image'``, a batched image tensor
                (presumably (B, C, H, W) -- confirm against callers). A
                single-channel image is replicated to three channels before
                inference.

        Returns:
            dict with ``'keypoints'``, ``'keypoint_scores'`` and
            ``'descriptors'``, each stacked over the batch dimension and cast
            to the dtype/device of the input image.

        Raises:
            AssertionError: if a required key is missing from ``data``.
        """
        for key in self.required_data_keys:
            assert key in data, f'Missing key {key} in data'
        image = data['image']
        if image.shape[1] == 1:
            # kornia's DISK expects 3-channel input; replicating a grayscale
            # channel lets 1-channel images (e.g. extract(..., grayscale=True))
            # go through instead of failing inside the network.
            image = kornia.color.grayscale_to_rgb(image)

        features = self.model(
            image,
            n=self.conf.max_num_keypoints,
            window_size=self.conf.nms_window_size,
            score_threshold=self.conf.detection_threshold,
            pad_if_not_divisible=self.conf.pad_if_not_divisible,
        )
        keypoints = [f.keypoints for f in features]
        scores = [f.detection_scores for f in features]
        descriptors = [f.descriptors for f in features]
        del features  # drop the per-image feature objects before stacking

        # torch.stack needs every image to yield the same number of
        # keypoints. This always holds for a batch of one (as built by
        # extract) but may not for larger batches.
        keypoints = torch.stack(keypoints, 0)
        scores = torch.stack(scores, 0)
        descriptors = torch.stack(descriptors, 0)

        return {
            'keypoints': keypoints.to(image),
            'keypoint_scores': scores.to(image),
            'descriptors': descriptors.to(image),
        }

    def extract(self, img: torch.Tensor, **conf) -> dict:
        """Resize a single image, run ``forward`` and rescale the keypoints.

        Args:
            img: image tensor of shape (C, H, W) or (1, C, H, W).
            **conf: overrides merged over ``preprocess_conf`` for this call.

        Returns:
            The dict produced by ``forward`` plus ``'image_size'``, the
            original (width, height) as a float tensor of shape (1, 2).
            ``'keypoints'`` are mapped back from the resized image to the
            original resolution.

        Raises:
            AssertionError: if ``img`` is not a single (optionally batched)
                image.
        """
        if img.dim() == 3:
            img = img[None]  # add batch dim
        assert img.dim() == 4 and img.shape[0] == 1, (
            'Expected an image of shape (C, H, W) or (1, C, H, W), '
            f'got {tuple(img.shape)}'
        )
        shape = img.shape[-2:][::-1]  # original (W, H), before resizing
        img, scales = ImagePreprocessor(
            **{**self.preprocess_conf, **conf})(img)
        feats = self.forward({'image': img})
        feats['image_size'] = torch.tensor(shape)[None].to(img).float()
        # Undo the resize using the pixel-center convention (+/- 0.5).
        # Presumably `scales` holds the per-axis (x, y) resize factors
        # returned by ImagePreprocessor -- confirm in .utils.
        feats['keypoints'] = (feats['keypoints'] + .5) / scales[None] - .5
        return feats
|