Pinwheel commited on
Commit
8f6e968
Β·
1 Parent(s): baf6fe0

Initial commit

Browse files
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: SuperGlue Image Matching
3
- emoji: πŸ¦€
4
  colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
 
1
  ---
2
  title: SuperGlue Image Matching
3
+ emoji: πŸ§šβ€β™€οΈ
4
  colorFrom: purple
5
  colorTo: indigo
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.cm as cm
2
+ import torch
3
+ import gradio as gr
4
+ from models.matching import Matching
5
+ from models.utils import (make_matching_plot_fast, process_image)
6
+
7
+ torch.set_grad_enabled(False)
8
+
9
+ # Load the SuperPoint and SuperGlue models.
10
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
+
12
+ resize = [640, 640]
13
+ max_keypoints = 1024
14
+ keypoint_threshold = 0.005
15
+ nms_radius = 4
16
+ sinkhorn_iterations = 20
17
+ match_threshold = 0.2
18
+ resize_float = False
19
+
20
+ config_indoor = {
21
+ 'superpoint': {
22
+ 'nms_radius': nms_radius,
23
+ 'keypoint_threshold': keypoint_threshold,
24
+ 'max_keypoints': max_keypoints
25
+ },
26
+ 'superglue': {
27
+ 'weights': "indoor",
28
+ 'sinkhorn_iterations': sinkhorn_iterations,
29
+ 'match_threshold': match_threshold,
30
+ }
31
+ }
32
+
33
+ config_outdoor = {
34
+ 'superpoint': {
35
+ 'nms_radius': nms_radius,
36
+ 'keypoint_threshold': keypoint_threshold,
37
+ 'max_keypoints': max_keypoints
38
+ },
39
+ 'superglue': {
40
+ 'weights': "outdoor",
41
+ 'sinkhorn_iterations': sinkhorn_iterations,
42
+ 'match_threshold': match_threshold,
43
+ }
44
+ }
45
+
46
+ matching_indoor = Matching(config_indoor).eval().to(device)
47
+ matching_outdoor = Matching(config_outdoor).eval().to(device)
48
+
49
+ def run(input0, input1, superglue):
50
+ if superglue == "indoor":
51
+ matching = matching_indoor
52
+ else:
53
+ matching = matching_outdoor
54
+
55
+ name0 = 'image1'
56
+ name1 = 'image2'
57
+
58
+ # If a rotation integer is provided (e.g. from EXIF data), use it:
59
+ rot0, rot1 = 0, 0
60
+
61
+ # Load the image pair.
62
+ image0, inp0, scales0 = process_image(input0, device, resize, rot0, resize_float)
63
+ image1, inp1, scales1 = process_image(input1, device, resize, rot1, resize_float)
64
+
65
+ if image0 is None or image1 is None:
66
+ print('Problem reading image pair')
67
+ return
68
+
69
+ # Perform the matching.
70
+ pred = matching({'image0': inp0, 'image1': inp1})
71
+ pred = {k: v[0].detach().numpy() for k, v in pred.items()}
72
+ kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
73
+ matches, conf = pred['matches0'], pred['matching_scores0']
74
+
75
+ valid = matches > -1
76
+ mkpts0 = kpts0[valid]
77
+ mkpts1 = kpts1[matches[valid]]
78
+ mconf = conf[valid]
79
+
80
+
81
+ # Visualize the matches.
82
+ color = cm.jet(mconf)
83
+ text = [
84
+ 'SuperGlue',
85
+ 'Keypoints: {}:{}'.format(len(kpts0), len(kpts1)),
86
+ '{}'.format(len(mkpts0)),
87
+ ]
88
+
89
+ if rot0 != 0 or rot1 != 0:
90
+ text.append('Rotation: {}:{}'.format(rot0, rot1))
91
+
92
+ # Display extra parameter info.
93
+ k_thresh = matching.superpoint.config['keypoint_threshold']
94
+ m_thresh = matching.superglue.config['match_threshold']
95
+ small_text = [
96
+ 'Keypoint Threshold: {:.4f}'.format(k_thresh),
97
+ 'Match Threshold: {:.2f}'.format(m_thresh),
98
+ 'Image Pair: {}:{}'.format(name0, name1),
99
+ ]
100
+
101
+ output = make_matching_plot_fast(
102
+ image0, image1, kpts0, kpts1, mkpts0, mkpts1, color,
103
+ text, show_keypoints=True, small_text=small_text)
104
+
105
+ print('Source Image - {}, Destination Image - {}, {}, Match Percentage - {}'.format(name0, name1, text[2], len(mkpts0)/len(kpts0)))
106
+ return output, text[2], str((len(mkpts0)/len(kpts0))*100.0) + '%'
107
+
108
+ if __name__ == '__main__':
109
+
110
+ glue = gr.Interface(
111
+ fn=run,
112
+ inputs=[
113
+ gr.Image(label='Input Image'),
114
+ gr.Image(label='Match Image'),
115
+ gr.Radio(choices=["indoor", "outdoor"], value="Indoor", type="value", label="SuperGlueType", interactive=True),
116
+ ],
117
+ outputs=[gr.Image(
118
+ type="pil",
119
+ label="Result"),
120
+ gr.Textbox(label="Keypoints Matched"),
121
+ gr.Textbox(label="Match Percentage")
122
+ ]
123
+ )
124
+ glue.queue()
125
+ glue.launch()
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (149 Bytes). View file
 
models/__pycache__/matching.cpython-38.pyc ADDED
Binary file (1.67 kB). View file
 
models/__pycache__/superglue.cpython-38.pyc ADDED
Binary file (9.8 kB). View file
 
models/__pycache__/superpoint.cpython-38.pyc ADDED
Binary file (5.66 kB). View file
 
models/__pycache__/utils.cpython-38.pyc ADDED
Binary file (16.4 kB). View file
 
models/matching.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %BANNER_BEGIN%
2
+ # ---------------------------------------------------------------------
3
+ # %COPYRIGHT_BEGIN%
4
+ #
5
+ # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
6
+ #
7
+ # Unpublished Copyright (c) 2020
8
+ # Magic Leap, Inc., All Rights Reserved.
9
+ #
10
+ # NOTICE: All information contained herein is, and remains the property
11
+ # of COMPANY. The intellectual and technical concepts contained herein
12
+ # are proprietary to COMPANY and may be covered by U.S. and Foreign
13
+ # Patents, patents in process, and are protected by trade secret or
14
+ # copyright law. Dissemination of this information or reproduction of
15
+ # this material is strictly forbidden unless prior written permission is
16
+ # obtained from COMPANY. Access to the source code contained herein is
17
+ # hereby forbidden to anyone except current COMPANY employees, managers
18
+ # or contractors who have executed Confidentiality and Non-disclosure
19
+ # agreements explicitly covering such access.
20
+ #
21
+ # The copyright notice above does not evidence any actual or intended
22
+ # publication or disclosure of this source code, which includes
23
+ # information that is confidential and/or proprietary, and is a trade
24
+ # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
25
+ # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
26
+ # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
27
+ # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
28
+ # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
29
+ # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
30
+ # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
31
+ # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
32
+ #
33
+ # %COPYRIGHT_END%
34
+ # ----------------------------------------------------------------------
35
+ # %AUTHORS_BEGIN%
36
+ #
37
+ # Originating Authors: Paul-Edouard Sarlin
38
+ #
39
+ # %AUTHORS_END%
40
+ # --------------------------------------------------------------------*/
41
+ # %BANNER_END%
42
+
43
+ import torch
44
+
45
+ from .superpoint import SuperPoint
46
+ from .superglue import SuperGlue
47
+
48
+
49
+ class Matching(torch.nn.Module):
50
+ """ Image Matching Frontend (SuperPoint + SuperGlue) """
51
+ def __init__(self, config={}):
52
+ super().__init__()
53
+ self.superpoint = SuperPoint(config.get('superpoint', {}))
54
+ self.superglue = SuperGlue(config.get('superglue', {}))
55
+
56
+ def forward(self, data):
57
+ """ Run SuperPoint (optionally) and SuperGlue
58
+ SuperPoint is skipped if ['keypoints0', 'keypoints1'] exist in input
59
+ Args:
60
+ data: dictionary with minimal keys: ['image0', 'image1']
61
+ """
62
+ pred = {}
63
+
64
+ # Extract SuperPoint (keypoints, scores, descriptors) if not provided
65
+ if 'keypoints0' not in data:
66
+ pred0 = self.superpoint({'image': data['image0']})
67
+ pred = {**pred, **{k+'0': v for k, v in pred0.items()}}
68
+ if 'keypoints1' not in data:
69
+ pred1 = self.superpoint({'image': data['image1']})
70
+ pred = {**pred, **{k+'1': v for k, v in pred1.items()}}
71
+
72
+ # Batch all features
73
+ # We should either have i) one image per batch, or
74
+ # ii) the same number of local features for all images in the batch.
75
+ data = {**data, **pred}
76
+
77
+ for k in data:
78
+ if isinstance(data[k], (list, tuple)):
79
+ data[k] = torch.stack(data[k])
80
+
81
+ # Perform the matching
82
+ pred = {**pred, **self.superglue(data)}
83
+
84
+ return pred
models/superglue.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %BANNER_BEGIN%
2
+ # ---------------------------------------------------------------------
3
+ # %COPYRIGHT_BEGIN%
4
+ #
5
+ # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
6
+ #
7
+ # Unpublished Copyright (c) 2020
8
+ # Magic Leap, Inc., All Rights Reserved.
9
+ #
10
+ # NOTICE: All information contained herein is, and remains the property
11
+ # of COMPANY. The intellectual and technical concepts contained herein
12
+ # are proprietary to COMPANY and may be covered by U.S. and Foreign
13
+ # Patents, patents in process, and are protected by trade secret or
14
+ # copyright law. Dissemination of this information or reproduction of
15
+ # this material is strictly forbidden unless prior written permission is
16
+ # obtained from COMPANY. Access to the source code contained herein is
17
+ # hereby forbidden to anyone except current COMPANY employees, managers
18
+ # or contractors who have executed Confidentiality and Non-disclosure
19
+ # agreements explicitly covering such access.
20
+ #
21
+ # The copyright notice above does not evidence any actual or intended
22
+ # publication or disclosure of this source code, which includes
23
+ # information that is confidential and/or proprietary, and is a trade
24
+ # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
25
+ # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
26
+ # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
27
+ # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
28
+ # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
29
+ # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
30
+ # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
31
+ # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
32
+ #
33
+ # %COPYRIGHT_END%
34
+ # ----------------------------------------------------------------------
35
+ # %AUTHORS_BEGIN%
36
+ #
37
+ # Originating Authors: Paul-Edouard Sarlin
38
+ #
39
+ # %AUTHORS_END%
40
+ # --------------------------------------------------------------------*/
41
+ # %BANNER_END%
42
+
43
+ from copy import deepcopy
44
+ from pathlib import Path
45
+ from typing import List, Tuple
46
+
47
+ import torch
48
+ from torch import nn
49
+
50
+
51
+ def MLP(channels: List[int], do_bn: bool = True) -> nn.Module:
52
+ """ Multi-layer perceptron """
53
+ n = len(channels)
54
+ layers = []
55
+ for i in range(1, n):
56
+ layers.append(
57
+ nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
58
+ if i < (n-1):
59
+ if do_bn:
60
+ layers.append(nn.BatchNorm1d(channels[i]))
61
+ layers.append(nn.ReLU())
62
+ return nn.Sequential(*layers)
63
+
64
+
65
+ def normalize_keypoints(kpts, image_shape):
66
+ """ Normalize keypoints locations based on image image_shape"""
67
+ _, _, height, width = image_shape
68
+ one = kpts.new_tensor(1)
69
+ size = torch.stack([one*width, one*height])[None]
70
+ center = size / 2
71
+ scaling = size.max(1, keepdim=True).values * 0.7
72
+ return (kpts - center[:, None, :]) / scaling[:, None, :]
73
+
74
+
75
+ class KeypointEncoder(nn.Module):
76
+ """ Joint encoding of visual appearance and location using MLPs"""
77
+ def __init__(self, feature_dim: int, layers: List[int]) -> None:
78
+ super().__init__()
79
+ self.encoder = MLP([3] + layers + [feature_dim])
80
+ nn.init.constant_(self.encoder[-1].bias, 0.0)
81
+
82
+ def forward(self, kpts, scores):
83
+ inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
84
+ return self.encoder(torch.cat(inputs, dim=1))
85
+
86
+
87
+ def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
88
+ dim = query.shape[1]
89
+ scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
90
+ prob = torch.nn.functional.softmax(scores, dim=-1)
91
+ return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
92
+
93
+
94
+ class MultiHeadedAttention(nn.Module):
95
+ """ Multi-head attention to increase model expressivitiy """
96
+ def __init__(self, num_heads: int, d_model: int):
97
+ super().__init__()
98
+ assert d_model % num_heads == 0
99
+ self.dim = d_model // num_heads
100
+ self.num_heads = num_heads
101
+ self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
102
+ self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
103
+
104
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
105
+ batch_dim = query.size(0)
106
+ query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
107
+ for l, x in zip(self.proj, (query, key, value))]
108
+ x, _ = attention(query, key, value)
109
+ return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
110
+
111
+
112
+ class AttentionalPropagation(nn.Module):
113
+ def __init__(self, feature_dim: int, num_heads: int):
114
+ super().__init__()
115
+ self.attn = MultiHeadedAttention(num_heads, feature_dim)
116
+ self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
117
+ nn.init.constant_(self.mlp[-1].bias, 0.0)
118
+
119
+ def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
120
+ message = self.attn(x, source, source)
121
+ return self.mlp(torch.cat([x, message], dim=1))
122
+
123
+
124
+ class AttentionalGNN(nn.Module):
125
+ def __init__(self, feature_dim: int, layer_names: List[str]) -> None:
126
+ super().__init__()
127
+ self.layers = nn.ModuleList([
128
+ AttentionalPropagation(feature_dim, 4)
129
+ for _ in range(len(layer_names))])
130
+ self.names = layer_names
131
+
132
+ def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
133
+ for layer, name in zip(self.layers, self.names):
134
+ if name == 'cross':
135
+ src0, src1 = desc1, desc0
136
+ else: # if name == 'self':
137
+ src0, src1 = desc0, desc1
138
+ delta0, delta1 = layer(desc0, src0), layer(desc1, src1)
139
+ desc0, desc1 = (desc0 + delta0), (desc1 + delta1)
140
+ return desc0, desc1
141
+
142
+
143
+ def log_sinkhorn_iterations(Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, iters: int) -> torch.Tensor:
144
+ """ Perform Sinkhorn Normalization in Log-space for stability"""
145
+ u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
146
+ for _ in range(iters):
147
+ u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
148
+ v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
149
+ return Z + u.unsqueeze(2) + v.unsqueeze(1)
150
+
151
+
152
+ def log_optimal_transport(scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
153
+ """ Perform Differentiable Optimal Transport in Log-space for stability"""
154
+ b, m, n = scores.shape
155
+ one = scores.new_tensor(1)
156
+ ms, ns = (m*one).to(scores), (n*one).to(scores)
157
+
158
+ bins0 = alpha.expand(b, m, 1)
159
+ bins1 = alpha.expand(b, 1, n)
160
+ alpha = alpha.expand(b, 1, 1)
161
+
162
+ couplings = torch.cat([torch.cat([scores, bins0], -1),
163
+ torch.cat([bins1, alpha], -1)], 1)
164
+
165
+ norm = - (ms + ns).log()
166
+ log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
167
+ log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
168
+ log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
169
+
170
+ Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
171
+ Z = Z - norm # multiply probabilities by M+N
172
+ return Z
173
+
174
+
175
+ def arange_like(x, dim: int):
176
+ return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
177
+
178
+
179
+ class SuperGlue(nn.Module):
180
+ """SuperGlue feature matching middle-end
181
+
182
+ Given two sets of keypoints and locations, we determine the
183
+ correspondences by:
184
+ 1. Keypoint Encoding (normalization + visual feature and location fusion)
185
+ 2. Graph Neural Network with multiple self and cross-attention layers
186
+ 3. Final projection layer
187
+ 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
188
+ 5. Thresholding matrix based on mutual exclusivity and a match_threshold
189
+
190
+ The correspondence ids use -1 to indicate non-matching points.
191
+
192
+ Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
193
+ Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
194
+ Networks. In CVPR, 2020. https://arxiv.org/abs/1911.11763
195
+
196
+ """
197
+ default_config = {
198
+ 'descriptor_dim': 256,
199
+ 'weights': 'indoor',
200
+ 'keypoint_encoder': [32, 64, 128, 256],
201
+ 'GNN_layers': ['self', 'cross'] * 9,
202
+ 'sinkhorn_iterations': 100,
203
+ 'match_threshold': 0.2,
204
+ }
205
+
206
+ def __init__(self, config):
207
+ super().__init__()
208
+ self.config = {**self.default_config, **config}
209
+
210
+ self.kenc = KeypointEncoder(
211
+ self.config['descriptor_dim'], self.config['keypoint_encoder'])
212
+
213
+ self.gnn = AttentionalGNN(
214
+ feature_dim=self.config['descriptor_dim'], layer_names=self.config['GNN_layers'])
215
+
216
+ self.final_proj = nn.Conv1d(
217
+ self.config['descriptor_dim'], self.config['descriptor_dim'],
218
+ kernel_size=1, bias=True)
219
+
220
+ bin_score = torch.nn.Parameter(torch.tensor(1.))
221
+ self.register_parameter('bin_score', bin_score)
222
+
223
+ assert self.config['weights'] in ['indoor', 'outdoor']
224
+ path = Path(__file__).parent
225
+ path = path / 'weights/superglue_{}.pth'.format(self.config['weights'])
226
+ self.load_state_dict(torch.load(str(path)))
227
+ print('Loaded SuperGlue model (\"{}\" weights)'.format(
228
+ self.config['weights']))
229
+
230
+ def forward(self, data):
231
+ """Run SuperGlue on a pair of keypoints and descriptors"""
232
+ desc0, desc1 = data['descriptors0'], data['descriptors1']
233
+ kpts0, kpts1 = data['keypoints0'], data['keypoints1']
234
+
235
+ if kpts0.shape[1] == 0 or kpts1.shape[1] == 0: # no keypoints
236
+ shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1]
237
+ return {
238
+ 'matches0': kpts0.new_full(shape0, -1, dtype=torch.int),
239
+ 'matches1': kpts1.new_full(shape1, -1, dtype=torch.int),
240
+ 'matching_scores0': kpts0.new_zeros(shape0),
241
+ 'matching_scores1': kpts1.new_zeros(shape1),
242
+ }
243
+
244
+ # Keypoint normalization.
245
+ kpts0 = normalize_keypoints(kpts0, data['image0'].shape)
246
+ kpts1 = normalize_keypoints(kpts1, data['image1'].shape)
247
+
248
+ # Keypoint MLP encoder.
249
+ desc0 = desc0 + self.kenc(kpts0, data['scores0'])
250
+ desc1 = desc1 + self.kenc(kpts1, data['scores1'])
251
+
252
+ # Multi-layer Transformer network.
253
+ desc0, desc1 = self.gnn(desc0, desc1)
254
+
255
+ # Final MLP projection.
256
+ mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
257
+
258
+ # Compute matching descriptor distance.
259
+ scores = torch.einsum('bdn,bdm->bnm', mdesc0, mdesc1)
260
+ scores = scores / self.config['descriptor_dim']**.5
261
+
262
+ # Run the optimal transport.
263
+ scores = log_optimal_transport(
264
+ scores, self.bin_score,
265
+ iters=self.config['sinkhorn_iterations'])
266
+
267
+ # Get the matches with score above "match_threshold".
268
+ max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
269
+ indices0, indices1 = max0.indices, max1.indices
270
+ mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
271
+ mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
272
+ zero = scores.new_tensor(0)
273
+ mscores0 = torch.where(mutual0, max0.values.exp(), zero)
274
+ mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
275
+ valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
276
+ valid1 = mutual1 & valid0.gather(1, indices1)
277
+ indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
278
+ indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
279
+
280
+ return {
281
+ 'matches0': indices0, # use -1 for invalid match
282
+ 'matches1': indices1, # use -1 for invalid match
283
+ 'matching_scores0': mscores0,
284
+ 'matching_scores1': mscores1,
285
+ }
models/superpoint.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %BANNER_BEGIN%
2
+ # ---------------------------------------------------------------------
3
+ # %COPYRIGHT_BEGIN%
4
+ #
5
+ # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
6
+ #
7
+ # Unpublished Copyright (c) 2020
8
+ # Magic Leap, Inc., All Rights Reserved.
9
+ #
10
+ # NOTICE: All information contained herein is, and remains the property
11
+ # of COMPANY. The intellectual and technical concepts contained herein
12
+ # are proprietary to COMPANY and may be covered by U.S. and Foreign
13
+ # Patents, patents in process, and are protected by trade secret or
14
+ # copyright law. Dissemination of this information or reproduction of
15
+ # this material is strictly forbidden unless prior written permission is
16
+ # obtained from COMPANY. Access to the source code contained herein is
17
+ # hereby forbidden to anyone except current COMPANY employees, managers
18
+ # or contractors who have executed Confidentiality and Non-disclosure
19
+ # agreements explicitly covering such access.
20
+ #
21
+ # The copyright notice above does not evidence any actual or intended
22
+ # publication or disclosure of this source code, which includes
23
+ # information that is confidential and/or proprietary, and is a trade
24
+ # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
25
+ # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
26
+ # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
27
+ # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
28
+ # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
29
+ # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
30
+ # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
31
+ # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
32
+ #
33
+ # %COPYRIGHT_END%
34
+ # ----------------------------------------------------------------------
35
+ # %AUTHORS_BEGIN%
36
+ #
37
+ # Originating Authors: Paul-Edouard Sarlin
38
+ #
39
+ # %AUTHORS_END%
40
+ # --------------------------------------------------------------------*/
41
+ # %BANNER_END%
42
+
43
+ from pathlib import Path
44
+ import torch
45
+ from torch import nn
46
+
47
+ def simple_nms(scores, nms_radius: int):
48
+ """ Fast Non-maximum suppression to remove nearby points """
49
+ assert(nms_radius >= 0)
50
+
51
+ def max_pool(x):
52
+ return torch.nn.functional.max_pool2d(
53
+ x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
54
+
55
+ zeros = torch.zeros_like(scores)
56
+ max_mask = scores == max_pool(scores)
57
+ for _ in range(2):
58
+ supp_mask = max_pool(max_mask.float()) > 0
59
+ supp_scores = torch.where(supp_mask, zeros, scores)
60
+ new_max_mask = supp_scores == max_pool(supp_scores)
61
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
62
+ return torch.where(max_mask, scores, zeros)
63
+
64
+
65
+ def remove_borders(keypoints, scores, border: int, height: int, width: int):
66
+ """ Removes keypoints too close to the border """
67
+ mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
68
+ mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
69
+ mask = mask_h & mask_w
70
+ return keypoints[mask], scores[mask]
71
+
72
+
73
+ def top_k_keypoints(keypoints, scores, k: int):
74
+ if k >= len(keypoints):
75
+ return keypoints, scores
76
+ scores, indices = torch.topk(scores, k, dim=0)
77
+ return keypoints[indices], scores
78
+
79
+
80
+ def sample_descriptors(keypoints, descriptors, s: int = 8):
81
+ """ Interpolate descriptors at keypoint locations """
82
+ b, c, h, w = descriptors.shape
83
+ keypoints = keypoints - s / 2 + 0.5
84
+ keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
85
+ ).to(keypoints)[None]
86
+ keypoints = keypoints*2 - 1 # normalize to (-1, 1)
87
+ args = {'align_corners': True} if torch.__version__ >= '1.3' else {}
88
+ descriptors = torch.nn.functional.grid_sample(
89
+ descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
90
+ descriptors = torch.nn.functional.normalize(
91
+ descriptors.reshape(b, c, -1), p=2, dim=1)
92
+ return descriptors
93
+
94
+
95
+ class SuperPoint(nn.Module):
96
+ """SuperPoint Convolutional Detector and Descriptor
97
+
98
+ SuperPoint: Self-Supervised Interest Point Detection and
99
+ Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
100
+ Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
101
+
102
+ """
103
+ default_config = {
104
+ 'descriptor_dim': 256,
105
+ 'nms_radius': 4,
106
+ 'keypoint_threshold': 0.005,
107
+ 'max_keypoints': -1,
108
+ 'remove_borders': 4,
109
+ }
110
+
111
+ def __init__(self, config):
112
+ super().__init__()
113
+ self.config = {**self.default_config, **config}
114
+
115
+ self.relu = nn.ReLU(inplace=True)
116
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
117
+ c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
118
+
119
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
120
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
121
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
122
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
123
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
124
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
125
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
126
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
127
+
128
+ self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
129
+ self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
130
+
131
+ self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
132
+ self.convDb = nn.Conv2d(
133
+ c5, self.config['descriptor_dim'],
134
+ kernel_size=1, stride=1, padding=0)
135
+
136
+ path = Path(__file__).parent / 'weights/superpoint_v1.pth'
137
+ self.load_state_dict(torch.load(str(path)))
138
+
139
+ mk = self.config['max_keypoints']
140
+ if mk == 0 or mk < -1:
141
+ raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
142
+
143
+ print('Loaded SuperPoint model')
144
+
145
+ def forward(self, data):
146
+ """ Compute keypoints, scores, descriptors for image """
147
+ # Shared Encoder
148
+ x = self.relu(self.conv1a(data['image']))
149
+ x = self.relu(self.conv1b(x))
150
+ x = self.pool(x)
151
+ x = self.relu(self.conv2a(x))
152
+ x = self.relu(self.conv2b(x))
153
+ x = self.pool(x)
154
+ x = self.relu(self.conv3a(x))
155
+ x = self.relu(self.conv3b(x))
156
+ x = self.pool(x)
157
+ x = self.relu(self.conv4a(x))
158
+ x = self.relu(self.conv4b(x))
159
+
160
+ # Compute the dense keypoint scores
161
+ cPa = self.relu(self.convPa(x))
162
+ scores = self.convPb(cPa)
163
+ scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
164
+ b, _, h, w = scores.shape
165
+ scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
166
+ scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
167
+ scores = simple_nms(scores, self.config['nms_radius'])
168
+
169
+ # Extract keypoints
170
+ keypoints = [
171
+ torch.nonzero(s > self.config['keypoint_threshold'])
172
+ for s in scores]
173
+ scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
174
+
175
+ # Discard keypoints near the image borders
176
+ keypoints, scores = list(zip(*[
177
+ remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
178
+ for k, s in zip(keypoints, scores)]))
179
+
180
+ # Keep the k keypoints with highest score
181
+ if self.config['max_keypoints'] >= 0:
182
+ keypoints, scores = list(zip(*[
183
+ top_k_keypoints(k, s, self.config['max_keypoints'])
184
+ for k, s in zip(keypoints, scores)]))
185
+
186
+ # Convert (h, w) to (x, y)
187
+ keypoints = [torch.flip(k, [1]).float() for k in keypoints]
188
+
189
+ # Compute the dense descriptors
190
+ cDa = self.relu(self.convDa(x))
191
+ descriptors = self.convDb(cDa)
192
+ descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
193
+
194
+ # Extract descriptors
195
+ descriptors = [sample_descriptors(k[None], d[None], 8)[0]
196
+ for k, d in zip(keypoints, descriptors)]
197
+
198
+ return {
199
+ 'keypoints': keypoints,
200
+ 'scores': scores,
201
+ 'descriptors': descriptors,
202
+ }
models/utils.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %BANNER_BEGIN%
2
+ # ---------------------------------------------------------------------
3
+ # %COPYRIGHT_BEGIN%
4
+ #
5
+ # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
6
+ #
7
+ # Unpublished Copyright (c) 2020
8
+ # Magic Leap, Inc., All Rights Reserved.
9
+ #
10
+ # NOTICE: All information contained herein is, and remains the property
11
+ # of COMPANY. The intellectual and technical concepts contained herein
12
+ # are proprietary to COMPANY and may be covered by U.S. and Foreign
13
+ # Patents, patents in process, and are protected by trade secret or
14
+ # copyright law. Dissemination of this information or reproduction of
15
+ # this material is strictly forbidden unless prior written permission is
16
+ # obtained from COMPANY. Access to the source code contained herein is
17
+ # hereby forbidden to anyone except current COMPANY employees, managers
18
+ # or contractors who have executed Confidentiality and Non-disclosure
19
+ # agreements explicitly covering such access.
20
+ #
21
+ # The copyright notice above does not evidence any actual or intended
22
+ # publication or disclosure of this source code, which includes
23
+ # information that is confidential and/or proprietary, and is a trade
24
+ # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
25
+ # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
26
+ # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
27
+ # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
28
+ # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
29
+ # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
30
+ # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
31
+ # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
32
+ #
33
+ # %COPYRIGHT_END%
34
+ # ----------------------------------------------------------------------
35
+ # %AUTHORS_BEGIN%
36
+ #
37
+ # Originating Authors: Paul-Edouard Sarlin
38
+ # Daniel DeTone
39
+ # Tomasz Malisiewicz
40
+ #
41
+ # %AUTHORS_END%
42
+ # --------------------------------------------------------------------*/
43
+ # %BANNER_END%
44
+
45
+ from pathlib import Path
46
+ import time
47
+ from collections import OrderedDict
48
+ from threading import Thread
49
+ import numpy as np
50
+ import cv2
51
+ import torch
52
+ import matplotlib.pyplot as plt
53
+ import matplotlib
54
+ matplotlib.use('Agg')
55
+
56
+
57
+ class AverageTimer:
58
+ """ Class to help manage printing simple timing of code execution. """
59
+
60
+ def __init__(self, smoothing=0.3, newline=False):
61
+ self.smoothing = smoothing
62
+ self.newline = newline
63
+ self.times = OrderedDict()
64
+ self.will_print = OrderedDict()
65
+ self.reset()
66
+
67
+ def reset(self):
68
+ now = time.time()
69
+ self.start = now
70
+ self.last_time = now
71
+ for name in self.will_print:
72
+ self.will_print[name] = False
73
+
74
+ def update(self, name='default'):
75
+ now = time.time()
76
+ dt = now - self.last_time
77
+ if name in self.times:
78
+ dt = self.smoothing * dt + (1 - self.smoothing) * self.times[name]
79
+ self.times[name] = dt
80
+ self.will_print[name] = True
81
+ self.last_time = now
82
+
83
+ def print(self, text='Timer'):
84
+ total = 0.
85
+ print('[{}]'.format(text), end=' ')
86
+ for key in self.times:
87
+ val = self.times[key]
88
+ if self.will_print[key]:
89
+ print('%s=%.3f' % (key, val), end=' ')
90
+ total += val
91
+ print('total=%.3f sec {%.1f FPS}' % (total, 1./total), end=' ')
92
+ if self.newline:
93
+ print(flush=True)
94
+ else:
95
+ print(end='\r', flush=True)
96
+ self.reset()
97
+
98
+
99
+ class VideoStreamer:
100
+ """ Class to help process image streams. Four types of possible inputs:"
101
+ 1.) USB Webcam.
102
+ 2.) An IP camera
103
+ 3.) A directory of images (files in directory matching 'image_glob').
104
+ 4.) A video file, such as an .mp4 or .avi file.
105
+ """
106
+ def __init__(self, basedir, resize, skip, image_glob, max_length=1000000):
107
+ self._ip_grabbed = False
108
+ self._ip_running = False
109
+ self._ip_camera = False
110
+ self._ip_image = None
111
+ self._ip_index = 0
112
+ self.cap = []
113
+ self.camera = True
114
+ self.video_file = False
115
+ self.listing = []
116
+ self.resize = resize
117
+ self.interp = cv2.INTER_AREA
118
+ self.i = 0
119
+ self.skip = skip
120
+ self.max_length = max_length
121
+ if isinstance(basedir, int) or basedir.isdigit():
122
+ print('==> Processing USB webcam input: {}'.format(basedir))
123
+ self.cap = cv2.VideoCapture(int(basedir))
124
+ self.listing = range(0, self.max_length)
125
+ elif basedir.startswith(('http', 'rtsp')):
126
+ print('==> Processing IP camera input: {}'.format(basedir))
127
+ self.cap = cv2.VideoCapture(basedir)
128
+ self.start_ip_camera_thread()
129
+ self._ip_camera = True
130
+ self.listing = range(0, self.max_length)
131
+ elif Path(basedir).is_dir():
132
+ print('==> Processing image directory input: {}'.format(basedir))
133
+ self.listing = list(Path(basedir).glob(image_glob[0]))
134
+ for j in range(1, len(image_glob)):
135
+ image_path = list(Path(basedir).glob(image_glob[j]))
136
+ self.listing = self.listing + image_path
137
+ self.listing.sort()
138
+ self.listing = self.listing[::self.skip]
139
+ self.max_length = np.min([self.max_length, len(self.listing)])
140
+ if self.max_length == 0:
141
+ raise IOError('No images found (maybe bad \'image_glob\' ?)')
142
+ self.listing = self.listing[:self.max_length]
143
+ self.camera = False
144
+ elif Path(basedir).exists():
145
+ print('==> Processing video input: {}'.format(basedir))
146
+ self.cap = cv2.VideoCapture(basedir)
147
+ self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
148
+ num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
149
+ self.listing = range(0, num_frames)
150
+ self.listing = self.listing[::self.skip]
151
+ self.video_file = True
152
+ self.max_length = np.min([self.max_length, len(self.listing)])
153
+ self.listing = self.listing[:self.max_length]
154
+ else:
155
+ raise ValueError('VideoStreamer input \"{}\" not recognized.'.format(basedir))
156
+ if self.camera and not self.cap.isOpened():
157
+ raise IOError('Could not read camera')
158
+
159
+ def load_image(self, impath):
160
+ """ Read image as grayscale and resize to img_size.
161
+ Inputs
162
+ impath: Path to input image.
163
+ Returns
164
+ grayim: uint8 numpy array sized H x W.
165
+ """
166
+ grayim = cv2.imread(impath, 0)
167
+ if grayim is None:
168
+ raise Exception('Error reading image %s' % impath)
169
+ w, h = grayim.shape[1], grayim.shape[0]
170
+ w_new, h_new = process_resize(w, h, self.resize)
171
+ grayim = cv2.resize(
172
+ grayim, (w_new, h_new), interpolation=self.interp)
173
+ return grayim
174
+
175
+ def next_frame(self):
176
+ """ Return the next frame, and increment internal counter.
177
+ Returns
178
+ image: Next H x W image.
179
+ status: True or False depending whether image was loaded.
180
+ """
181
+
182
+ if self.i == self.max_length:
183
+ return (None, False)
184
+ if self.camera:
185
+
186
+ if self._ip_camera:
187
+ #Wait for first image, making sure we haven't exited
188
+ while self._ip_grabbed is False and self._ip_exited is False:
189
+ time.sleep(.001)
190
+
191
+ ret, image = self._ip_grabbed, self._ip_image.copy()
192
+ if ret is False:
193
+ self._ip_running = False
194
+ else:
195
+ ret, image = self.cap.read()
196
+ if ret is False:
197
+ print('VideoStreamer: Cannot get image from camera')
198
+ return (None, False)
199
+ w, h = image.shape[1], image.shape[0]
200
+ if self.video_file:
201
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i])
202
+
203
+ w_new, h_new = process_resize(w, h, self.resize)
204
+ image = cv2.resize(image, (w_new, h_new),
205
+ interpolation=self.interp)
206
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
207
+ else:
208
+ image_file = str(self.listing[self.i])
209
+ image = self.load_image(image_file)
210
+ self.i = self.i + 1
211
+ return (image, True)
212
+
213
+ def start_ip_camera_thread(self):
214
+ self._ip_thread = Thread(target=self.update_ip_camera, args=())
215
+ self._ip_running = True
216
+ self._ip_thread.start()
217
+ self._ip_exited = False
218
+ return self
219
+
220
+ def update_ip_camera(self):
221
+ while self._ip_running:
222
+ ret, img = self.cap.read()
223
+ if ret is False:
224
+ self._ip_running = False
225
+ self._ip_exited = True
226
+ self._ip_grabbed = False
227
+ return
228
+
229
+ self._ip_image = img
230
+ self._ip_grabbed = ret
231
+ self._ip_index += 1
232
+ #print('IPCAMERA THREAD got frame {}'.format(self._ip_index))
233
+
234
+
235
+ def cleanup(self):
236
+ self._ip_running = False
237
+
238
+ # --- PREPROCESSING ---
239
+
240
+ def process_resize(w, h, resize):
241
+ assert(len(resize) > 0 and len(resize) <= 2)
242
+ if len(resize) == 1 and resize[0] > -1:
243
+ scale = resize[0] / max(h, w)
244
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
245
+ elif len(resize) == 1 and resize[0] == -1:
246
+ w_new, h_new = w, h
247
+ else: # len(resize) == 2:
248
+ w_new, h_new = resize[0], resize[1]
249
+
250
+ # Issue warning if resolution is too small or too large.
251
+ if max(w_new, h_new) < 160:
252
+ print('Warning: input resolution is very small, results may vary')
253
+ elif max(w_new, h_new) > 2000:
254
+ print('Warning: input resolution is very large, results may vary')
255
+
256
+ return w_new, h_new
257
+
258
+
259
+ def frame2tensor(frame, device):
260
+ return torch.from_numpy(frame/255.).float()[None, None].to(device)
261
+
262
+
263
+ def read_image(path, device, resize, rotation, resize_float):
264
+ image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
265
+ if image is None:
266
+ return None, None, None
267
+ w, h = image.shape[1], image.shape[0]
268
+ w_new, h_new = process_resize(w, h, resize)
269
+ scales = (float(w) / float(w_new), float(h) / float(h_new))
270
+
271
+ if resize_float:
272
+ image = cv2.resize(image.astype('float32'), (w_new, h_new))
273
+ else:
274
+ image = cv2.resize(image, (w_new, h_new)).astype('float32')
275
+
276
+ if rotation != 0:
277
+ image = np.rot90(image, k=rotation)
278
+ if rotation % 2:
279
+ scales = scales[::-1]
280
+
281
+ inp = frame2tensor(image, device)
282
+ return image, inp, scales
283
+
284
+ def process_image(image, device, resize, rotation, resize_float):
285
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
286
+ if image is None:
287
+ return None, None, None
288
+ w, h = image.shape[1], image.shape[0]
289
+ w_new, h_new = process_resize(w, h, resize)
290
+ scales = (float(w) / float(w_new), float(h) / float(h_new))
291
+
292
+ if resize_float:
293
+ image = cv2.resize(image.astype('float32'), (w_new, h_new))
294
+ else:
295
+ image = cv2.resize(image, (w_new, h_new)).astype('float32')
296
+
297
+ if rotation != 0:
298
+ image = np.rot90(image, k=rotation)
299
+ if rotation % 2:
300
+ scales = scales[::-1]
301
+
302
+ inp = frame2tensor(image, device)
303
+ return image, inp, scales
304
+
305
+ # --- GEOMETRY ---
306
+
307
+
308
+ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
309
+ if len(kpts0) < 5:
310
+ return None
311
+
312
+ f_mean = np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]])
313
+ norm_thresh = thresh / f_mean
314
+
315
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
316
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
317
+
318
+ E, mask = cv2.findEssentialMat(
319
+ kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf,
320
+ method=cv2.RANSAC)
321
+
322
+ assert E is not None
323
+
324
+ best_num_inliers = 0
325
+ ret = None
326
+ for _E in np.split(E, len(E) / 3):
327
+ n, R, t, _ = cv2.recoverPose(
328
+ _E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
329
+ if n > best_num_inliers:
330
+ best_num_inliers = n
331
+ ret = (R, t[:, 0], mask.ravel() > 0)
332
+ return ret
333
+
334
+
335
+ def rotate_intrinsics(K, image_shape, rot):
336
+ """image_shape is the shape of the image after rotation"""
337
+ assert rot <= 3
338
+ h, w = image_shape[:2][::-1 if (rot % 2) else 1]
339
+ fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
340
+ rot = rot % 4
341
+ if rot == 1:
342
+ return np.array([[fy, 0., cy],
343
+ [0., fx, w-1-cx],
344
+ [0., 0., 1.]], dtype=K.dtype)
345
+ elif rot == 2:
346
+ return np.array([[fx, 0., w-1-cx],
347
+ [0., fy, h-1-cy],
348
+ [0., 0., 1.]], dtype=K.dtype)
349
+ else: # if rot == 3:
350
+ return np.array([[fy, 0., h-1-cy],
351
+ [0., fx, cx],
352
+ [0., 0., 1.]], dtype=K.dtype)
353
+
354
+
355
+ def rotate_pose_inplane(i_T_w, rot):
356
+ rotation_matrices = [
357
+ np.array([[np.cos(r), -np.sin(r), 0., 0.],
358
+ [np.sin(r), np.cos(r), 0., 0.],
359
+ [0., 0., 1., 0.],
360
+ [0., 0., 0., 1.]], dtype=np.float32)
361
+ for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
362
+ ]
363
+ return np.dot(rotation_matrices[rot], i_T_w)
364
+
365
+
366
+ def scale_intrinsics(K, scales):
367
+ scales = np.diag([1./scales[0], 1./scales[1], 1.])
368
+ return np.dot(scales, K)
369
+
370
+
371
+ def to_homogeneous(points):
372
+ return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
373
+
374
+
375
+ def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1):
376
+ kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
377
+ kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
378
+ kpts0 = to_homogeneous(kpts0)
379
+ kpts1 = to_homogeneous(kpts1)
380
+
381
+ t0, t1, t2 = T_0to1[:3, 3]
382
+ t_skew = np.array([
383
+ [0, -t2, t1],
384
+ [t2, 0, -t0],
385
+ [-t1, t0, 0]
386
+ ])
387
+ E = t_skew @ T_0to1[:3, :3]
388
+
389
+ Ep0 = kpts0 @ E.T # N x 3
390
+ p1Ep0 = np.sum(kpts1 * Ep0, -1) # N
391
+ Etp1 = kpts1 @ E # N x 3
392
+ d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2)
393
+ + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2))
394
+ return d
395
+
396
+
397
+ def angle_error_mat(R1, R2):
398
+ cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
399
+ cos = np.clip(cos, -1., 1.) # numercial errors can make it out of bounds
400
+ return np.rad2deg(np.abs(np.arccos(cos)))
401
+
402
+
403
+ def angle_error_vec(v1, v2):
404
+ n = np.linalg.norm(v1) * np.linalg.norm(v2)
405
+ return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
406
+
407
+
408
+ def compute_pose_error(T_0to1, R, t):
409
+ R_gt = T_0to1[:3, :3]
410
+ t_gt = T_0to1[:3, 3]
411
+ error_t = angle_error_vec(t, t_gt)
412
+ error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
413
+ error_R = angle_error_mat(R, R_gt)
414
+ return error_t, error_R
415
+
416
+
417
+ def pose_auc(errors, thresholds):
418
+ sort_idx = np.argsort(errors)
419
+ errors = np.array(errors.copy())[sort_idx]
420
+ recall = (np.arange(len(errors)) + 1) / len(errors)
421
+ errors = np.r_[0., errors]
422
+ recall = np.r_[0., recall]
423
+ aucs = []
424
+ for t in thresholds:
425
+ last_index = np.searchsorted(errors, t)
426
+ r = np.r_[recall[:last_index], recall[last_index-1]]
427
+ e = np.r_[errors[:last_index], t]
428
+ aucs.append(np.trapz(r, x=e)/t)
429
+ return aucs
430
+
431
+
432
+ # --- VISUALIZATION ---
433
+
434
+
435
+ def plot_image_pair(imgs, dpi=100, size=6, pad=.5):
436
+ n = len(imgs)
437
+ assert n == 2, 'number of images must be two'
438
+ figsize = (size*n, size*3/4) if size is not None else None
439
+ _, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
440
+ for i in range(n):
441
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap('gray'), vmin=0, vmax=255)
442
+ ax[i].get_yaxis().set_ticks([])
443
+ ax[i].get_xaxis().set_ticks([])
444
+ for spine in ax[i].spines.values(): # remove frame
445
+ spine.set_visible(False)
446
+ plt.tight_layout(pad=pad)
447
+
448
+
449
+ def plot_keypoints(kpts0, kpts1, color='w', ps=2):
450
+ ax = plt.gcf().axes
451
+ ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
452
+ ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
453
+
454
+
455
+ def plot_matches(kpts0, kpts1, color, lw=1.5, ps=4):
456
+ fig = plt.gcf()
457
+ ax = fig.axes
458
+ fig.canvas.draw()
459
+
460
+ transFigure = fig.transFigure.inverted()
461
+ fkpts0 = transFigure.transform(ax[0].transData.transform(kpts0))
462
+ fkpts1 = transFigure.transform(ax[1].transData.transform(kpts1))
463
+
464
+ fig.lines = [matplotlib.lines.Line2D(
465
+ (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), zorder=1,
466
+ transform=fig.transFigure, c=color[i], linewidth=lw)
467
+ for i in range(len(kpts0))]
468
+ ax[0].scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
469
+ ax[1].scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
470
+
471
+
472
+ def make_matching_plot(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
473
+ color, text, path, show_keypoints=False,
474
+ fast_viz=False, opencv_display=False,
475
+ opencv_title='matches', small_text=[]):
476
+
477
+ if fast_viz:
478
+ make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0, mkpts1,
479
+ color, text, path, show_keypoints, 10,
480
+ opencv_display, opencv_title, small_text)
481
+ return
482
+
483
+ plot_image_pair([image0, image1])
484
+ if show_keypoints:
485
+ plot_keypoints(kpts0, kpts1, color='k', ps=4)
486
+ plot_keypoints(kpts0, kpts1, color='w', ps=2)
487
+ plot_matches(mkpts0, mkpts1, color)
488
+
489
+ fig = plt.gcf()
490
+ txt_color = 'k' if image0[:100, :150].mean() > 200 else 'w'
491
+ fig.text(
492
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
493
+ fontsize=15, va='top', ha='left', color=txt_color)
494
+
495
+ txt_color = 'k' if image0[-100:, :150].mean() > 200 else 'w'
496
+ fig.text(
497
+ 0.01, 0.01, '\n'.join(small_text), transform=fig.axes[0].transAxes,
498
+ fontsize=5, va='bottom', ha='left', color=txt_color)
499
+
500
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
501
+ plt.close()
502
+
503
+
504
+ def make_matching_plot_fast(image0, image1, kpts0, kpts1, mkpts0,
505
+ mkpts1, color, text, path=None,
506
+ show_keypoints=False, margin=10,
507
+ opencv_display=False, opencv_title='',
508
+ small_text=[]):
509
+ H0, W0 = image0.shape
510
+ H1, W1 = image1.shape
511
+ H, W = max(H0, H1), W0 + W1 + margin
512
+
513
+ out = 255*np.ones((H, W), np.uint8)
514
+ out[:H0, :W0] = image0
515
+ out[:H1, W0+margin:] = image1
516
+ out = np.stack([out]*3, -1)
517
+
518
+ if show_keypoints:
519
+ kpts0, kpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
520
+ white = (255, 255, 255)
521
+ black = (0, 0, 0)
522
+ for x, y in kpts0:
523
+ cv2.circle(out, (x, y), 2, black, -1, lineType=cv2.LINE_AA)
524
+ cv2.circle(out, (x, y), 1, white, -1, lineType=cv2.LINE_AA)
525
+ for x, y in kpts1:
526
+ cv2.circle(out, (x + margin + W0, y), 2, black, -1,
527
+ lineType=cv2.LINE_AA)
528
+ cv2.circle(out, (x + margin + W0, y), 1, white, -1,
529
+ lineType=cv2.LINE_AA)
530
+
531
+ mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
532
+ color = (np.array(color[:, :3])*255).astype(int)[:, ::-1]
533
+ for (x0, y0), (x1, y1), c in zip(mkpts0, mkpts1, color):
534
+ c = c.tolist()
535
+ cv2.line(out, (x0, y0), (x1 + margin + W0, y1),
536
+ color=c, thickness=1, lineType=cv2.LINE_AA)
537
+ # display line end-points as circles
538
+ cv2.circle(out, (x0, y0), 2, c, -1, lineType=cv2.LINE_AA)
539
+ cv2.circle(out, (x1 + margin + W0, y1), 2, c, -1,
540
+ lineType=cv2.LINE_AA)
541
+
542
+ # Scale factor for consistent visualization across scales.
543
+ sc = min(H / 640., 2.0)
544
+
545
+ # Big text.
546
+ Ht = int(30 * sc) # text height
547
+ txt_color_fg = (255, 255, 255)
548
+ txt_color_bg = (0, 0, 0)
549
+ for i, t in enumerate(text):
550
+ cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
551
+ 1.0*sc, txt_color_bg, 2, cv2.LINE_AA)
552
+ cv2.putText(out, t, (int(8*sc), Ht*(i+1)), cv2.FONT_HERSHEY_DUPLEX,
553
+ 1.0*sc, txt_color_fg, 1, cv2.LINE_AA)
554
+
555
+ # Small text.
556
+ Ht = int(18 * sc) # text height
557
+ for i, t in enumerate(reversed(small_text)):
558
+ cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
559
+ 0.5*sc, txt_color_bg, 2, cv2.LINE_AA)
560
+ cv2.putText(out, t, (int(8*sc), int(H-Ht*(i+.6))), cv2.FONT_HERSHEY_DUPLEX,
561
+ 0.5*sc, txt_color_fg, 1, cv2.LINE_AA)
562
+ return out
563
+
564
+
565
+ def error_colormap(x):
566
+ return np.clip(
567
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)], -1), 0, 1)
models/weights/superglue_indoor.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e710469be25ebe1e2ccf68edcae8b2945b0617c8e7e68412251d9d47f5052b1
3
+ size 48233807
models/weights/superglue_outdoor.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f5f5e9bb3febf07b69df633c4c3ff7a17f8af26a023aae2b9303d22339195bd
3
+ size 48233807
models/weights/superpoint_v1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52b6708629640ca883673b5d5c097c4ddad37d8048b33f09c8ca0d69db12c40e
3
+ size 5206086
requirements.txt ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.8.3
2
+ aiosignal==1.2.0
3
+ anyio==3.6.2
4
+ async-timeout==4.0.2
5
+ attrs==22.1.0
6
+ bcrypt==4.0.1
7
+ certifi==2022.9.24
8
+ cffi==1.15.1
9
+ charset-normalizer==2.1.1
10
+ click==8.1.3
11
+ contourpy==1.0.6
12
+ cryptography==38.0.1
13
+ cycler==0.11.0
14
+ fastapi==0.85.1
15
+ ffmpy==0.3.0
16
+ fonttools==4.38.0
17
+ frozenlist==1.3.1
18
+ fsspec==2022.10.0
19
+ gradio==3.8.1
20
+ h11==0.12.0
21
+ httpcore==0.15.0
22
+ httpx==0.23.0
23
+ idna==3.4
24
+ Jinja2==3.1.2
25
+ kiwisolver==1.4.4
26
+ linkify-it-py==1.0.3
27
+ markdown-it-py==2.1.0
28
+ MarkupSafe==2.1.1
29
+ matplotlib==3.6.1
30
+ mdit-py-plugins==0.3.1
31
+ mdurl==0.1.2
32
+ multidict==6.0.2
33
+ numpy==1.23.4
34
+ opencv-python==4.6.0.66
35
+ orjson==3.8.1
36
+ packaging==21.3
37
+ pandas==1.5.1
38
+ paramiko==2.11.0
39
+ Pillow==9.3.0
40
+ pycparser==2.21
41
+ pycryptodome==3.15.0
42
+ pydantic==1.10.2
43
+ pydub==0.25.1
44
+ PyNaCl==1.5.0
45
+ pyparsing==3.0.9
46
+ python-dateutil==2.8.2
47
+ python-multipart==0.0.5
48
+ pytz==2022.5
49
+ PyYAML==6.0
50
+ requests==2.28.1
51
+ rfc3986==1.5.0
52
+ six==1.16.0
53
+ sniffio==1.3.0
54
+ starlette==0.20.4
55
+ torch==1.13.0
56
+ typing_extensions==4.4.0
57
+ uc-micro-py==1.0.1
58
+ urllib3==1.26.12
59
+ uvicorn==0.19.0
60
+ websockets==10.4
61
+ yarl==1.8.1