Realcat committed on
Commit
b11e1d7
·
1 Parent(s): 5ff702b

add: scores to liftfeat

Browse files
imcui/hloc/extractors/liftfeat.py CHANGED
@@ -12,12 +12,6 @@ sys.path.append(str(fire_path))
12
  from models.liftfeat_wrapper import LiftFeat, MODEL_PATH
13
 
14
 
15
- def select_idx(N, M):
16
- numbers = list(range(0, N))
17
- selected = random.sample(numbers, M)
18
- return selected
19
-
20
-
21
  class Liftfeat(BaseModel):
22
  default_conf = {
23
  "keypoint_threshold": 0.05,
@@ -42,9 +36,9 @@ class Liftfeat(BaseModel):
42
 
43
  keypoints = pred["keypoints"]
44
  descriptors = pred["descriptors"]
45
- scores = torch.ones_like(pred["keypoints"][:, 0])
46
  if self.conf["max_keypoints"] < len(keypoints):
47
- idxs = select_idx(len(keypoints), self.conf["max_keypoints"])
48
  keypoints = keypoints[idxs, :2]
49
  descriptors = descriptors[idxs]
50
  scores = scores[idxs]
 
12
  from models.liftfeat_wrapper import LiftFeat, MODEL_PATH
13
 
14
 
 
 
 
 
 
 
15
  class Liftfeat(BaseModel):
16
  default_conf = {
17
  "keypoint_threshold": 0.05,
 
36
 
37
  keypoints = pred["keypoints"]
38
  descriptors = pred["descriptors"]
39
+ scores = pred["scores"]
40
  if self.conf["max_keypoints"] < len(keypoints):
41
+ idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
42
  keypoints = keypoints[idxs, :2]
43
  descriptors = descriptors[idxs]
44
  scores = scores[idxs]
imcui/third_party/LiftFeat/models/liftfeat_wrapper.py CHANGED
@@ -9,22 +9,21 @@ from models.model import LiftFeatSPModel
9
  from models.interpolator import InterpolateSparse2d
10
  from utils.config import featureboost_config
11
 
12
- device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
13
 
14
- MODEL_PATH=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
15
 
16
 
17
  class NonMaxSuppression(torch.nn.Module):
18
  def __init__(self, rep_thr=0.1, top_k=4096):
19
- super(NonMaxSuppression,self).__init__()
20
  self.max_filter = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
21
  self.rep_thr = rep_thr
22
- self.top_k=top_k
23
-
24
-
25
- def NMS(self, x, threshold = 0.05, kernel_size = 5):
26
  B, _, H, W = x.shape
27
- pad=kernel_size//2
28
  local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
29
  pos = (x == local_max) & (x > threshold)
30
  pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]
@@ -32,17 +31,18 @@ class NonMaxSuppression(torch.nn.Module):
32
  pad_val = max([len(x) for x in pos_batched])
33
  pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)
34
 
35
- #Pad kpts and build (B, N, 2) tensor
36
  for b in range(len(pos_batched)):
37
- pos[b, :len(pos_batched[b]), :] = pos_batched[b]
38
 
39
  return pos
40
-
41
  def forward(self, score):
42
- pos = self.NMS(score,self.rep_thr)
43
-
44
  return pos
45
 
 
46
  def load_model(model, weight_path):
47
  pretrained_weights = torch.load(weight_path, map_location="cpu")
48
 
@@ -72,82 +72,82 @@ def load_model(model, weight_path):
72
 
73
 
74
  import torch.nn as nn
 
 
75
  class LiftFeat(nn.Module):
76
- def __init__(self,weight=MODEL_PATH,top_k=4096,detect_threshold=0.1):
77
  super().__init__()
78
- self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
- self.net=LiftFeatSPModel(featureboost_config).to(self.device).eval()
80
- self.top_k=top_k
81
- self.sampler=InterpolateSparse2d('bicubic')
82
- self.net=load_model(self.net,weight)
83
- self.detector=NonMaxSuppression(rep_thr=detect_threshold)
84
- self.net=self.net.to(self.device)
85
- self.detector=self.detector.to(self.device)
86
- self.sampler=self.sampler.to(self.device)
87
-
88
- def image_preprocess(self,image: np.ndarray):
89
- H,W,C=image.shape[0],image.shape[1],image.shape[2]
90
-
91
- _H=math.ceil(H/32)*32
92
- _W=math.ceil(W/32)*32
93
-
94
- pad_h=_H-H
95
- pad_w=_W-W
96
-
97
- image=cv2.copyMakeBorder(image,0,pad_h,0,pad_w,cv2.BORDER_CONSTANT,None,(0, 0, 0))
98
-
99
- pad_info=[0,pad_h,0,pad_w]
100
-
101
- if len(image.shape)==3:
102
- image=image[None,...]
103
-
104
- image=torch.tensor(image).permute(0,3,1,2)/255
105
- image=image.to(device)
106
 
107
  return image, pad_info
108
-
109
  @torch.inference_mode()
110
- def extract(self,image: np.ndarray):
111
- image,pad_info=self.image_preprocess(image)
112
- B,_,_H1,_W1=image.shape
113
-
114
- M1,K1,D1=self.net.forward1(image)
115
- refine_M=self.net.forward2(M1,K1,D1)
116
-
117
- refine_M=refine_M.reshape(M1.shape[0],M1.shape[2],M1.shape[3],-1).permute(0,3,1,2)
118
- refine_M=torch.nn.functional.normalize(refine_M,2,dim=1)
119
-
120
- descs_map=refine_M
121
- # descs_map=M1
122
-
123
- scores=torch.softmax(K1,dim=1)[:,:64]
124
- heatmap=scores.permute(0,2,3,1).reshape(scores.shape[0],scores.shape[2],scores.shape[3],8,8)
125
- heatmap=heatmap.permute(0,1,3,2,4).reshape(scores.shape[0],1,scores.shape[2]*8,scores.shape[3]*8)
126
-
127
- pos=self.detector(heatmap)
128
- kpts=pos.squeeze(0)
129
- mask_w=kpts[...,0]<(_W1-pad_info[-1])
130
- kpts=kpts[mask_w]
131
- mask_h=kpts[..., 1]<(_H1-pad_info[1])
132
- kpts=kpts[mask_h]
133
-
134
- descs=self.sampler(descs_map,kpts.unsqueeze(0),_H1,_W1)
135
- descs=torch.nn.functional.normalize(descs,p=2,dim=1)
136
- descs=descs.squeeze(0)
137
-
138
- return {
139
- 'descriptors':descs,
140
- 'keypoints':kpts
141
- }
142
-
143
  def match_liftfeat(self, img1, img2, min_cossim=-1):
144
  # import pdb;pdb.set_trace()
145
- data1=self.extract(img1)
146
- data2=self.extract(img2)
147
-
148
- kpts1,feats1=data1['keypoints'],data1['descriptors']
149
- kpts2,feats2=data2['keypoints'],data2['descriptors']
150
-
151
  cossim = feats1 @ feats2.t()
152
  cossim_t = feats2 @ feats1.t()
153
 
@@ -165,9 +165,8 @@ class LiftFeat(nn.Module):
165
  else:
166
  idx0 = idx0[mutual]
167
  idx1 = match12[mutual]
168
-
169
- mkpts1,mkpts2=kpts1[idx0],kpts2[idx1]
170
- mkpts1,mkpts2=mkpts1.cpu().numpy(),mkpts2.cpu().numpy()
171
 
172
- return mkpts1, mkpts2
 
173
 
 
 
9
  from models.interpolator import InterpolateSparse2d
10
  from utils.config import featureboost_config
11
 
12
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
13
 
14
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), "../weights/LiftFeat.pth")
15
 
16
 
17
  class NonMaxSuppression(torch.nn.Module):
18
  def __init__(self, rep_thr=0.1, top_k=4096):
19
+ super(NonMaxSuppression, self).__init__()
20
  self.max_filter = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
21
  self.rep_thr = rep_thr
22
+ self.top_k = top_k
23
+
24
+ def NMS(self, x, threshold=0.05, kernel_size=5):
 
25
  B, _, H, W = x.shape
26
+ pad = kernel_size // 2
27
  local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
28
  pos = (x == local_max) & (x > threshold)
29
  pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]
 
31
  pad_val = max([len(x) for x in pos_batched])
32
  pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)
33
 
34
+ # Pad kpts and build (B, N, 2) tensor
35
  for b in range(len(pos_batched)):
36
+ pos[b, : len(pos_batched[b]), :] = pos_batched[b]
37
 
38
  return pos
39
+
40
  def forward(self, score):
41
+ pos = self.NMS(score, self.rep_thr)
42
+
43
  return pos
44
 
45
+
46
  def load_model(model, weight_path):
47
  pretrained_weights = torch.load(weight_path, map_location="cpu")
48
 
 
72
 
73
 
74
  import torch.nn as nn
75
+
76
+
77
  class LiftFeat(nn.Module):
78
+ def __init__(self, weight=MODEL_PATH, top_k=4096, detect_threshold=0.1):
79
  super().__init__()
80
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+ self.net = LiftFeatSPModel(featureboost_config).to(self.device).eval()
82
+ self.top_k = top_k
83
+ self.sampler = InterpolateSparse2d("bicubic")
84
+ self.net = load_model(self.net, weight)
85
+ self.detector = NonMaxSuppression(rep_thr=detect_threshold)
86
+ self.net = self.net.to(self.device)
87
+ self.detector = self.detector.to(self.device)
88
+ self.sampler = self.sampler.to(self.device)
89
+
90
+ def image_preprocess(self, image: np.ndarray):
91
+ H, W, C = image.shape[0], image.shape[1], image.shape[2]
92
+
93
+ _H = math.ceil(H / 32) * 32
94
+ _W = math.ceil(W / 32) * 32
95
+
96
+ pad_h = _H - H
97
+ pad_w = _W - W
98
+
99
+ image = cv2.copyMakeBorder(image, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, None, (0, 0, 0))
100
+
101
+ pad_info = [0, pad_h, 0, pad_w]
102
+
103
+ if len(image.shape) == 3:
104
+ image = image[None, ...]
105
+
106
+ image = torch.tensor(image).permute(0, 3, 1, 2) / 255
107
+ image = image.to(device)
108
 
109
  return image, pad_info
110
+
111
  @torch.inference_mode()
112
+ def extract(self, image: np.ndarray):
113
+ image, pad_info = self.image_preprocess(image)
114
+ B, _, _H1, _W1 = image.shape
115
+
116
+ M1, K1, D1 = self.net.forward1(image)
117
+ refine_M = self.net.forward2(M1, K1, D1)
118
+
119
+ refine_M = refine_M.reshape(M1.shape[0], M1.shape[2], M1.shape[3], -1).permute(0, 3, 1, 2)
120
+ refine_M = torch.nn.functional.normalize(refine_M, 2, dim=1)
121
+
122
+ descs_map = refine_M
123
+
124
+ scores = torch.softmax(K1, dim=1)[:, :64]
125
+ heatmap = scores.permute(0, 2, 3, 1).reshape(scores.shape[0], scores.shape[2], scores.shape[3], 8, 8)
126
+ heatmap = heatmap.permute(0, 1, 3, 2, 4).reshape(scores.shape[0], 1, scores.shape[2] * 8, scores.shape[3] * 8)
127
+
128
+ pos = self.detector(heatmap)
129
+ kpts = pos.squeeze(0)
130
+ mask_w = kpts[..., 0] < (_W1 - pad_info[-1])
131
+ kpts = kpts[mask_w]
132
+ mask_h = kpts[..., 1] < (_H1 - pad_info[1])
133
+ kpts = kpts[mask_h]
134
+
135
+ scores = self.sampler(heatmap, kpts.unsqueeze(0), _H1, _W1)
136
+ scores = scores.squeeze(0).reshape(-1)
137
+ descs = self.sampler(descs_map, kpts.unsqueeze(0), _H1, _W1)
138
+ descs = torch.nn.functional.normalize(descs, p=2, dim=1)
139
+ descs = descs.squeeze(0)
140
+
141
+ return {"descriptors": descs, "keypoints": kpts, "scores": scores}
142
+
 
 
143
  def match_liftfeat(self, img1, img2, min_cossim=-1):
144
  # import pdb;pdb.set_trace()
145
+ data1 = self.extract(img1)
146
+ data2 = self.extract(img2)
147
+
148
+ kpts1, feats1 = data1["keypoints"], data1["descriptors"]
149
+ kpts2, feats2 = data2["keypoints"], data2["descriptors"]
150
+
151
  cossim = feats1 @ feats2.t()
152
  cossim_t = feats2 @ feats1.t()
153
 
 
165
  else:
166
  idx0 = idx0[mutual]
167
  idx1 = match12[mutual]
 
 
 
168
 
169
+ mkpts1, mkpts2 = kpts1[idx0], kpts2[idx1]
170
+ mkpts1, mkpts2 = mkpts1.cpu().numpy(), mkpts2.cpu().numpy()
171
 
172
+ return mkpts1, mkpts2