Spaces:
Sleeping
Sleeping
File size: 6,240 Bytes
a80d6bb 86d1c07 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb c74a070 a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import os, pdb
from PIL import Image
import numpy as np
import torch
from .tools import common
from .tools.dataloader import norm_RGB
from .nets.patchnet import *
def load_network(model_fn):
    """Instantiate the network described by a checkpoint and load its weights.

    The checkpoint is read as a mapping with two entries used here:
    ``"net"`` (a Python expression string that constructs the model) and
    ``"state_dict"`` (the weights to load into it).

    Args:
        model_fn: Checkpoint location, passed straight to ``torch.load``.

    Returns:
        The constructed network with weights loaded, in eval mode.
    """
    # Deserialize onto the CPU so checkpoints saved from a GPU also load on
    # CPU-only machines; callers move the network to CUDA afterwards if wanted.
    checkpoint = torch.load(model_fn, map_location="cpu")
    print("\n>> Creating net = " + checkpoint["net"])
    # SECURITY: eval() executes whatever expression the checkpoint carries.
    # Only load checkpoints from a trusted source.
    net = eval(checkpoint["net"])
    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # initialization
    # Every occurrence of "module." is removed from the parameter names
    # (presumably the prefix added by a DataParallel wrapper -- TODO confirm).
    weights = checkpoint["state_dict"]
    net.load_state_dict({k.replace("module.", ""): v for k, v in weights.items()})
    return net.eval()
class NonMaxSuppression(torch.nn.Module):
    """Pick keypoint locations as thresholded local maxima of repeatability.

    A location is kept when its repeatability equals the maximum of its 3x3
    neighbourhood and both its repeatability and reliability reach their
    respective thresholds.

    Args:
        rel_thr: Minimum reliability for a location to be kept.
        rep_thr: Minimum repeatability for a location to be kept.
    """

    def __init__(self, rel_thr=0.7, rep_thr=0.7):
        # super() avoids depending on an `nn` name that is only in scope
        # through a wildcard import elsewhere in the file.
        super().__init__()
        self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.rel_thr = rel_thr
        self.rep_thr = rep_thr

    def forward(self, reliability, repeatability, **kw):
        """Return the coordinates of the selected maxima.

        Args:
            reliability: Sequence holding exactly one reliability map.
            repeatability: Sequence holding exactly one repeatability map.
            **kw: Ignored; lets the whole network output dict be passed in.

        Returns:
            Rows 2 and 3 of the transposed ``nonzero()`` index matrix, i.e.
            the indices along the last two axes of a 4-D map (callers unpack
            them as ``y, x``).
        """
        assert len(reliability) == len(repeatability) == 1
        reliability, repeatability = reliability[0], repeatability[0]

        # local maxima: a value that survives 3x3 max-pooling unchanged
        maxima = repeatability == self.max_filter(repeatability)

        # remove low peaks
        maxima &= repeatability >= self.rep_thr
        maxima &= reliability >= self.rel_thr

        return maxima.nonzero().t()[2:4]
def extract_multiscale(
    net,
    img,
    detector,
    scale_f=2**0.25,
    min_scale=0.0,
    max_scale=1,
    min_size=256,
    max_size=1024,
    verbose=False,
):
    """Detect keypoints and descriptors over a pyramid of down-scaled images.

    Starting at full resolution, the image is repeatedly shrunk by
    ``scale_f``. At every scale inside the allowed range the network is run,
    the detector selects locations, and the results are mapped back to the
    coordinates of the original image.

    Args:
        net: Callable invoked as ``net(imgs=[img])``; must return a mapping
            with ``"descriptors"``, ``"reliability"`` and ``"repeatability"``
            entries, each a sequence whose first item is a 4-D tensor.
        img: Tensor of shape ``(1, 3, H, W)``.
        detector: Callable invoked as ``detector(**res)`` that returns the
            ``(y, x)`` index tensors of the selected locations.
        scale_f: Factor the scale is divided by between pyramid levels.
        min_scale: Smallest scale factor to process.
        max_scale: Largest scale factor to process; must be ``<= 1``.
        min_size: Stop once the longest image side would fall below this.
        max_size: Skip scales whose longest image side would exceed this.
        verbose: Print the resolution of every processed level.

    Returns:
        ``(XYS, D, scores)``: an ``(N, 3)`` tensor of x, y and patch size
        (``32 / scale``), the ``N`` descriptors, and the per-keypoint score
        ``reliability * repeatability``.

    Raises:
        AssertionError: If ``img`` is not a single RGB image or
            ``max_scale > 1``.
    """
    # Validate before touching global state so a bad call leaves it untouched.
    B, three, H, W = img.shape
    assert B == 1 and three == 3, "should be a batch with a single RGB image"
    assert max_scale <= 1

    longest_side = max(H, W)
    lowest_scale = max(min_scale, min_size / longest_side)
    highest_scale = min(max_scale, max_size / longest_side)

    # Benchmark mode is switched off while the input size changes at every
    # pyramid level; the previous setting is restored on exit.
    old_bm = torch.backends.cudnn.benchmark
    torch.backends.cudnn.benchmark = False

    s = 1.0  # current scale factor
    X, Y, S, C, Q, D = [], [], [], [], [], []
    try:
        # extract keypoints at multiple scales
        while s + 0.001 >= lowest_scale:
            if s - 0.001 <= highest_scale:
                nh, nw = img.shape[2:]
                if verbose:
                    print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")

                # extract descriptors
                with torch.no_grad():
                    res = net(imgs=[img])

                # get output and reliability map
                descriptors = res["descriptors"][0]
                reliability = res["reliability"][0]
                repeatability = res["repeatability"][0]

                # extract maxima and descs
                y, x = detector(**res)  # nms
                c = reliability[0, 0, y, x]
                q = repeatability[0, 0, y, x]
                d = descriptors[0, :, y, x].t()
                n = d.shape[0]

                # accumulate multiple scales, in original-image coordinates
                X.append(x.float() * W / nw)
                Y.append(y.float() * H / nh)
                S.append(
                    (32 / s) * torch.ones(n, dtype=torch.float32, device=d.device)
                )
                C.append(c)
                Q.append(q)
                D.append(d)

            s /= scale_f

            # down-scale the image for next iteration
            nh, nw = round(H * s), round(W * s)
            img = torch.nn.functional.interpolate(
                img, (nh, nw), mode="bilinear", align_corners=False
            )
    finally:
        # restore value, even if the network or the detector raised
        torch.backends.cudnn.benchmark = old_bm

    Y = torch.cat(Y)
    X = torch.cat(X)
    S = torch.cat(S)  # scale
    scores = torch.cat(C) * torch.cat(Q)  # scores = reliability * repeatability
    XYS = torch.stack([X, Y, S], dim=-1)
    D = torch.cat(D)
    return XYS, D, scores
def extract_keypoints(args):
    """Extract keypoints for every requested image and save them to disk.

    For each image the result is written next to it as
    ``<image path>.<args.tag>`` in NumPy ``.npz`` format with the entries
    ``imsize``, ``keypoints``, ``descriptors`` and ``scores``.

    Args:
        args: Namespace providing ``gpu``, ``model``, ``reliability_thr``,
            ``repeatability_thr``, ``images``, ``scale_f``, ``min_scale``,
            ``max_scale``, ``min_size``, ``max_size``, ``top_k`` and ``tag``.
            ``args.images`` is consumed: entries are popped as they are
            processed, and a path ending in ``.txt`` is replaced by the
            paths listed in that file (one per line).
    """
    iscuda = common.torch_set_gpu(args.gpu)

    # load the network...
    net = load_network(args.model)
    if iscuda:
        net = net.cuda()

    # create the non-maxima detector
    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
    )

    while args.images:
        img_path = args.images.pop(0)

        if img_path.endswith(".txt"):
            # Expand a list file in place; close it promptly and ignore
            # blank lines, which would otherwise be opened as image paths.
            with open(img_path) as list_file:
                listed = [ln for ln in list_file.read().splitlines() if ln.strip()]
            args.images = listed + args.images
            continue

        print(f"\nExtracting features for {img_path}")
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        W, H = img.size
        img = norm_RGB(img)[None]  # add the batch dimension
        if iscuda:
            img = img.cuda()

        # extract keypoints/descriptors for a single image
        xys, desc, scores = extract_multiscale(
            net,
            img,
            detector,
            scale_f=args.scale_f,
            min_scale=args.min_scale,
            max_scale=args.max_scale,
            min_size=args.min_size,
            max_size=args.max_size,
            verbose=True,
        )

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()
        # Keep the top_k highest scores; a top_k of 0 turns the slice start
        # into None, which keeps everything.
        idxs = scores.argsort()[-args.top_k or None :]

        outpath = img_path + "." + args.tag
        print(f"Saving {len(idxs)} keypoints to {outpath}")
        # The context manager guarantees the archive is flushed and closed.
        with open(outpath, "wb") as out_file:
            np.savez(
                out_file,
                imsize=(W, H),
                keypoints=xys[idxs],
                descriptors=desc[idxs],
                scores=scores[idxs],
            )
if __name__ == "__main__":
    import argparse

    # Command-line entry point: collect the options, then run the extraction.
    parser = argparse.ArgumentParser("Extract keypoints for a given image")

    # model and inputs
    parser.add_argument("--model", type=str, required=True, help="model path")
    parser.add_argument(
        "--images", type=str, required=True, nargs="+", help="images / list"
    )

    # output
    parser.add_argument("--tag", type=str, default="r2d2", help="output file tag")
    parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints")

    # multi-scale and detection tunables (declared in a table; no help text)
    for option, caster, fallback in (
        ("--scale-f", float, 2**0.25),
        ("--min-size", int, 256),
        ("--max-size", int, 1024),
        ("--min-scale", float, 0),
        ("--max-scale", float, 1),
        ("--reliability-thr", float, 0.7),
        ("--repeatability-thr", float, 0.7),
    ):
        parser.add_argument(option, type=caster, default=fallback)

    # device selection
    parser.add_argument(
        "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU"
    )

    args = parser.parse_args()
    extract_keypoints(args)
|