Spaces:
Running
Running
Realcat
commited on
Commit
·
4a7fc02
1
Parent(s):
260ecba
update: rerun ransac
Browse files- common/app_class.py +25 -2
- common/utils.py +59 -10
- common/viz.py +35 -1
common/app_class.py
CHANGED
@@ -9,6 +9,7 @@ from common.utils import (
|
|
9 |
load_config,
|
10 |
get_matcher_zoo,
|
11 |
run_matching,
|
|
|
12 |
gen_examples,
|
13 |
GRADIO_VERSION,
|
14 |
)
|
@@ -159,7 +160,9 @@ class ImageMatchingApp:
|
|
159 |
label="Ransac Iterations",
|
160 |
value=self.cfg["defaults"]["ransac_max_iter"],
|
161 |
)
|
162 |
-
|
|
|
|
|
163 |
with gr.Accordion("Geometry Setting", open=False):
|
164 |
with gr.Row(equal_height=False):
|
165 |
choice_geometry_type = gr.Radio(
|
@@ -171,6 +174,7 @@ class ImageMatchingApp:
|
|
171 |
)
|
172 |
|
173 |
# collect inputs
|
|
|
174 |
inputs = [
|
175 |
input_image0,
|
176 |
input_image1,
|
@@ -184,6 +188,7 @@ class ImageMatchingApp:
|
|
184 |
ransac_max_iter,
|
185 |
choice_geometry_type,
|
186 |
gr.State(self.matcher_zoo),
|
|
|
187 |
]
|
188 |
|
189 |
# Add some examples
|
@@ -207,7 +212,8 @@ class ImageMatchingApp:
|
|
207 |
with gr.Column():
|
208 |
output_keypoints = gr.Image(label="Keypoints", type="numpy")
|
209 |
output_matches_raw = gr.Image(
|
210 |
-
label="Raw Matches",
|
|
|
211 |
)
|
212 |
output_matches_ransac = gr.Image(
|
213 |
label="Ransac Matches", type="numpy"
|
@@ -254,6 +260,7 @@ class ImageMatchingApp:
|
|
254 |
matcher_info,
|
255 |
geometry_result,
|
256 |
output_wrapped,
|
|
|
257 |
]
|
258 |
# button callbacks
|
259 |
button_run.click(
|
@@ -288,6 +295,22 @@ class ImageMatchingApp:
|
|
288 |
fn=self.ui_reset_state, inputs=None, outputs=reset_outputs
|
289 |
)
|
290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
# estimate geo
|
292 |
choice_geometry_type.change(
|
293 |
fn=generate_warp_images,
|
|
|
9 |
load_config,
|
10 |
get_matcher_zoo,
|
11 |
run_matching,
|
12 |
+
run_ransac,
|
13 |
gen_examples,
|
14 |
GRADIO_VERSION,
|
15 |
)
|
|
|
160 |
label="Ransac Iterations",
|
161 |
value=self.cfg["defaults"]["ransac_max_iter"],
|
162 |
)
|
163 |
+
button_ransac = gr.Button(
|
164 |
+
value="Rerun RANSAC", variant="primary"
|
165 |
+
)
|
166 |
with gr.Accordion("Geometry Setting", open=False):
|
167 |
with gr.Row(equal_height=False):
|
168 |
choice_geometry_type = gr.Radio(
|
|
|
174 |
)
|
175 |
|
176 |
# collect inputs
|
177 |
+
state_cache = gr.State({})
|
178 |
inputs = [
|
179 |
input_image0,
|
180 |
input_image1,
|
|
|
188 |
ransac_max_iter,
|
189 |
choice_geometry_type,
|
190 |
gr.State(self.matcher_zoo),
|
191 |
+
# state_cache,
|
192 |
]
|
193 |
|
194 |
# Add some examples
|
|
|
212 |
with gr.Column():
|
213 |
output_keypoints = gr.Image(label="Keypoints", type="numpy")
|
214 |
output_matches_raw = gr.Image(
|
215 |
+
label="Raw Matches",
|
216 |
+
type="numpy",
|
217 |
)
|
218 |
output_matches_ransac = gr.Image(
|
219 |
label="Ransac Matches", type="numpy"
|
|
|
260 |
matcher_info,
|
261 |
geometry_result,
|
262 |
output_wrapped,
|
263 |
+
state_cache,
|
264 |
]
|
265 |
# button callbacks
|
266 |
button_run.click(
|
|
|
295 |
fn=self.ui_reset_state, inputs=None, outputs=reset_outputs
|
296 |
)
|
297 |
|
298 |
+
# run ransac button action
|
299 |
+
button_ransac.click(
|
300 |
+
fn=run_ransac,
|
301 |
+
inputs=[
|
302 |
+
ransac_method,
|
303 |
+
ransac_reproj_threshold,
|
304 |
+
ransac_confidence,
|
305 |
+
ransac_max_iter,
|
306 |
+
state_cache,
|
307 |
+
],
|
308 |
+
outputs=[
|
309 |
+
output_matches_ransac,
|
310 |
+
matches_result_info,
|
311 |
+
],
|
312 |
+
)
|
313 |
+
|
314 |
# estimate geo
|
315 |
choice_geometry_type.change(
|
316 |
fn=generate_warp_images,
|
common/utils.py
CHANGED
@@ -265,12 +265,13 @@ def filter_matches(
|
|
265 |
mask = np.array(mask.ravel().astype("bool"), dtype="bool")
|
266 |
if H is not None:
|
267 |
if feature_type == "KEYPOINT":
|
268 |
-
pred["
|
269 |
-
pred["
|
270 |
-
pred["
|
271 |
elif feature_type == "LINE":
|
272 |
-
pred["
|
273 |
-
pred["
|
|
|
274 |
return pred
|
275 |
|
276 |
|
@@ -440,6 +441,50 @@ def generate_warp_images(
|
|
440 |
return None, None
|
441 |
|
442 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
443 |
def run_matching(
|
444 |
image0: np.ndarray,
|
445 |
image1: np.ndarray,
|
@@ -496,7 +541,7 @@ def run_matching(
|
|
496 |
output_matches_ransac = None
|
497 |
|
498 |
# super slow!
|
499 |
-
if "roma" in key.lower():
|
500 |
gr.Info(
|
501 |
f"Success! Please be patient and allow for about 2-3 minutes."
|
502 |
f" Due to CPU inference, {key} is quiet slow."
|
@@ -592,7 +637,7 @@ def run_matching(
|
|
592 |
"Image 1 - Ransac matched keypoints",
|
593 |
]
|
594 |
output_matches_ransac, num_matches_ransac = display_matches(
|
595 |
-
pred, titles=titles
|
596 |
)
|
597 |
gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
|
598 |
logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
|
@@ -607,17 +652,20 @@ def run_matching(
|
|
607 |
choice_geometry_type,
|
608 |
)
|
609 |
plt.close("all")
|
610 |
-
del pred
|
611 |
logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
|
612 |
gr.Info(f"In summary, total time: {time.time()-t0:.3f}s")
|
613 |
|
|
|
|
|
|
|
614 |
return (
|
615 |
output_keypoints,
|
616 |
output_matches_raw,
|
617 |
output_matches_ransac,
|
618 |
{
|
619 |
-
"
|
620 |
-
"
|
621 |
},
|
622 |
{
|
623 |
"match_conf": match_conf,
|
@@ -627,6 +675,7 @@ def run_matching(
|
|
627 |
"geom_info": geom_info,
|
628 |
},
|
629 |
output_wrapped,
|
|
|
630 |
)
|
631 |
|
632 |
|
|
|
265 |
mask = np.array(mask.ravel().astype("bool"), dtype="bool")
|
266 |
if H is not None:
|
267 |
if feature_type == "KEYPOINT":
|
268 |
+
pred["mkeypoints0_orig"] = mkpts0[mask]
|
269 |
+
pred["mkeypoints1_orig"] = mkpts1[mask]
|
270 |
+
pred["mmconf"] = pred["mconf"][mask]
|
271 |
elif feature_type == "LINE":
|
272 |
+
pred["mline_keypoints0_orig"] = mkpts0[mask]
|
273 |
+
pred["mline_keypoints1_orig"] = mkpts1[mask]
|
274 |
+
pred["H"] = H
|
275 |
return pred
|
276 |
|
277 |
|
|
|
441 |
return None, None
|
442 |
|
443 |
|
444 |
+
def run_ransac(
|
445 |
+
ransac_method: str = DEFAULT_RANSAC_METHOD,
|
446 |
+
ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
|
447 |
+
ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
|
448 |
+
ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
|
449 |
+
state_cache: Dict[str, Any] = None,
|
450 |
+
):
|
451 |
+
t1 = time.time()
|
452 |
+
logger.info(
|
453 |
+
f"Run RANSAC matches using: {ransac_method} with threshold: {ransac_reproj_threshold}"
|
454 |
+
)
|
455 |
+
logger.info(
|
456 |
+
f"Run RANSAC matches using: {ransac_confidence} with iter: {ransac_max_iter}"
|
457 |
+
)
|
458 |
+
# if enable_ransac:
|
459 |
+
filter_matches(
|
460 |
+
state_cache,
|
461 |
+
ransac_method=ransac_method,
|
462 |
+
ransac_reproj_threshold=ransac_reproj_threshold,
|
463 |
+
ransac_confidence=ransac_confidence,
|
464 |
+
ransac_max_iter=ransac_max_iter,
|
465 |
+
)
|
466 |
+
gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
|
467 |
+
logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
|
468 |
+
t1 = time.time()
|
469 |
+
|
470 |
+
# plot images with ransac matches
|
471 |
+
titles = [
|
472 |
+
"Image 0 - Ransac matched keypoints",
|
473 |
+
"Image 1 - Ransac matched keypoints",
|
474 |
+
]
|
475 |
+
output_matches_ransac, num_matches_ransac = display_matches(
|
476 |
+
state_cache, titles=titles, tag="KPTS_RANSAC"
|
477 |
+
)
|
478 |
+
gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
|
479 |
+
logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
|
480 |
+
t1 = time.time()
|
481 |
+
num_matches_raw = state_cache["num_matches_raw"]
|
482 |
+
return output_matches_ransac, {
|
483 |
+
"num_matches_raw": num_matches_raw,
|
484 |
+
"num_matches_ransac": num_matches_ransac,
|
485 |
+
}
|
486 |
+
|
487 |
+
|
488 |
def run_matching(
|
489 |
image0: np.ndarray,
|
490 |
image1: np.ndarray,
|
|
|
541 |
output_matches_ransac = None
|
542 |
|
543 |
# super slow!
|
544 |
+
if "roma" in key.lower() and device == "cpu":
|
545 |
gr.Info(
|
546 |
f"Success! Please be patient and allow for about 2-3 minutes."
|
547 |
f" Due to CPU inference, {key} is quiet slow."
|
|
|
637 |
"Image 1 - Ransac matched keypoints",
|
638 |
]
|
639 |
output_matches_ransac, num_matches_ransac = display_matches(
|
640 |
+
pred, titles=titles, tag="KPTS_RANSAC"
|
641 |
)
|
642 |
gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
|
643 |
logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
|
|
|
652 |
choice_geometry_type,
|
653 |
)
|
654 |
plt.close("all")
|
655 |
+
# del pred
|
656 |
logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
|
657 |
gr.Info(f"In summary, total time: {time.time()-t0:.3f}s")
|
658 |
|
659 |
+
state_cache = pred
|
660 |
+
state_cache["num_matches_raw"] = num_matches_raw
|
661 |
+
state_cache["num_matches_ransac"] = num_matches_ransac
|
662 |
return (
|
663 |
output_keypoints,
|
664 |
output_matches_raw,
|
665 |
output_matches_ransac,
|
666 |
{
|
667 |
+
"num_raw_matches": num_matches_raw,
|
668 |
+
"num_ransac_matches": num_matches_ransac,
|
669 |
},
|
670 |
{
|
671 |
"match_conf": match_conf,
|
|
|
675 |
"geom_info": geom_info,
|
676 |
},
|
677 |
output_wrapped,
|
678 |
+
state_cache,
|
679 |
)
|
680 |
|
681 |
|
common/viz.py
CHANGED
@@ -156,7 +156,11 @@ def make_matching_figure(
|
|
156 |
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
|
157 |
|
158 |
# draw matches
|
159 |
-
if
|
|
|
|
|
|
|
|
|
160 |
fig.canvas.draw()
|
161 |
transFigure = fig.transFigure.inverted()
|
162 |
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
@@ -377,6 +381,7 @@ def display_matches(
|
|
377 |
titles: List[str] = [],
|
378 |
texts: List[str] = [],
|
379 |
dpi: int = 300,
|
|
|
380 |
) -> Tuple[np.ndarray, int]:
|
381 |
"""
|
382 |
Displays the matches between two images.
|
@@ -393,11 +398,13 @@ def display_matches(
|
|
393 |
img1 = pred["image1_orig"]
|
394 |
|
395 |
num_inliers = 0
|
|
|
396 |
if (
|
397 |
"keypoints0_orig" in pred
|
398 |
and "keypoints1_orig" in pred
|
399 |
and pred["keypoints0_orig"] is not None
|
400 |
and pred["keypoints1_orig"] is not None
|
|
|
401 |
):
|
402 |
mkpts0 = pred["keypoints0_orig"]
|
403 |
mkpts1 = pred["keypoints1_orig"]
|
@@ -417,11 +424,38 @@ def display_matches(
|
|
417 |
texts=texts,
|
418 |
)
|
419 |
fig = fig_mkpts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
if (
|
421 |
"line0_orig" in pred
|
422 |
and "line1_orig" in pred
|
423 |
and pred["line0_orig"] is not None
|
424 |
and pred["line1_orig"] is not None
|
|
|
425 |
):
|
426 |
# lines
|
427 |
mtlines0 = pred["line0_orig"]
|
|
|
156 |
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
|
157 |
|
158 |
# draw matches
|
159 |
+
if (
|
160 |
+
mkpts0.shape[0] != 0
|
161 |
+
and mkpts1.shape[0] != 0
|
162 |
+
and mkpts0.shape == mkpts1.shape
|
163 |
+
):
|
164 |
fig.canvas.draw()
|
165 |
transFigure = fig.transFigure.inverted()
|
166 |
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
|
|
381 |
titles: List[str] = [],
|
382 |
texts: List[str] = [],
|
383 |
dpi: int = 300,
|
384 |
+
tag: str = "KPTS_RAW", # KPTS_RAW, KPTS_RANSAC, LINES_RAW, LINES_RANSAC,
|
385 |
) -> Tuple[np.ndarray, int]:
|
386 |
"""
|
387 |
Displays the matches between two images.
|
|
|
398 |
img1 = pred["image1_orig"]
|
399 |
|
400 |
num_inliers = 0
|
401 |
+
# draw raw matches
|
402 |
if (
|
403 |
"keypoints0_orig" in pred
|
404 |
and "keypoints1_orig" in pred
|
405 |
and pred["keypoints0_orig"] is not None
|
406 |
and pred["keypoints1_orig"] is not None
|
407 |
+
and tag == "KPTS_RAW"
|
408 |
):
|
409 |
mkpts0 = pred["keypoints0_orig"]
|
410 |
mkpts1 = pred["keypoints1_orig"]
|
|
|
424 |
texts=texts,
|
425 |
)
|
426 |
fig = fig_mkpts
|
427 |
+
elif (
|
428 |
+
"mkeypoints0_orig" in pred
|
429 |
+
and "mkeypoints1_orig" in pred
|
430 |
+
and pred["mkeypoints0_orig"] is not None
|
431 |
+
and pred["mkeypoints1_orig"] is not None
|
432 |
+
and tag == "KPTS_RANSAC"
|
433 |
+
): # draw ransac matches
|
434 |
+
mkpts0 = pred["mkeypoints0_orig"]
|
435 |
+
mkpts1 = pred["mkeypoints1_orig"]
|
436 |
+
num_inliers = len(mkpts0)
|
437 |
+
if "mmconf" in pred:
|
438 |
+
mmconf = pred["mmconf"]
|
439 |
+
else:
|
440 |
+
mmconf = np.ones(len(mkpts0))
|
441 |
+
fig_mkpts = draw_matches_core(
|
442 |
+
mkpts0,
|
443 |
+
mkpts1,
|
444 |
+
img0,
|
445 |
+
img1,
|
446 |
+
mmconf,
|
447 |
+
dpi=dpi,
|
448 |
+
titles=titles,
|
449 |
+
texts=texts,
|
450 |
+
)
|
451 |
+
fig = fig_mkpts
|
452 |
+
# TODO: draw lines
|
453 |
if (
|
454 |
"line0_orig" in pred
|
455 |
and "line1_orig" in pred
|
456 |
and pred["line0_orig"] is not None
|
457 |
and pred["line1_orig"] is not None
|
458 |
+
# and (tag == "LINES_RAW" or tag == "LINES_RANSAC")
|
459 |
):
|
460 |
# lines
|
461 |
mtlines0 = pred["line0_orig"]
|