Spaces:
Running
Running
Realcat
commited on
Commit
·
0bc7901
1
Parent(s):
82ee2a0
add: ModelCache
Browse files- common/utils.py +73 -11
- hloc/matchers/omniglue.py +1 -0
- test_app_cli.py +36 -5
common/utils.py
CHANGED
@@ -1,7 +1,10 @@
|
|
1 |
import os
|
2 |
import cv2
|
|
|
3 |
import torch
|
4 |
import random
|
|
|
|
|
5 |
import numpy as np
|
6 |
import gradio as gr
|
7 |
from pathlib import Path
|
@@ -42,6 +45,66 @@ MATCHER_ZOO = None
|
|
42 |
models_already_loaded = {}
|
43 |
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
def load_config(config_name: str) -> Dict[str, Any]:
|
46 |
"""
|
47 |
Load a YAML configuration file.
|
@@ -579,6 +642,7 @@ def run_matching(
|
|
579 |
ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
|
580 |
choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
|
581 |
matcher_zoo: Dict[str, Any] = None,
|
|
|
582 |
) -> Tuple[
|
583 |
np.ndarray,
|
584 |
np.ndarray,
|
@@ -639,15 +703,12 @@ def run_matching(
|
|
639 |
match_conf["model"]["max_keypoints"] = extract_max_keypoints
|
640 |
t0 = time.time()
|
641 |
cache_key = "{}_{}".format(key, match_conf["model"]["name"])
|
642 |
-
|
643 |
-
|
644 |
matcher.conf["max_keypoints"] = extract_max_keypoints
|
645 |
matcher.conf["match_threshold"] = match_threshold
|
646 |
logger.info(f"Loaded cached model {cache_key}")
|
647 |
-
|
648 |
-
matcher = get_model(match_conf)
|
649 |
-
models_already_loaded[cache_key] = matcher
|
650 |
-
# gr.Info(f"Loading model using: {time.time()-t0:.3f}s")
|
651 |
logger.info(f"Loading model using: {time.time()-t0:.3f}s")
|
652 |
t1 = time.time()
|
653 |
|
@@ -663,14 +724,15 @@ def run_matching(
|
|
663 |
extract_conf["model"]["max_keypoints"] = extract_max_keypoints
|
664 |
extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
|
665 |
cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
|
666 |
-
|
667 |
-
|
|
|
|
|
|
|
668 |
extractor.conf["max_keypoints"] = extract_max_keypoints
|
669 |
extractor.conf["keypoint_threshold"] = keypoint_threshold
|
670 |
logger.info(f"Loaded cached model {cache_key}")
|
671 |
-
|
672 |
-
extractor = get_feature_model(extract_conf)
|
673 |
-
models_already_loaded[cache_key] = extractor
|
674 |
pred0 = extract_features.extract(
|
675 |
extractor, image0, extract_conf["preprocessing"]
|
676 |
)
|
|
|
1 |
import os
|
2 |
import cv2
|
3 |
+
import sys
|
4 |
import torch
|
5 |
import random
|
6 |
+
import psutil
|
7 |
+
import shutil
|
8 |
import numpy as np
|
9 |
import gradio as gr
|
10 |
from pathlib import Path
|
|
|
45 |
models_already_loaded = {}
|
46 |
|
47 |
|
48 |
+
class ModelCache:
|
49 |
+
def __init__(self, max_memory_size: int = 8):
|
50 |
+
self.max_memory_size = max_memory_size
|
51 |
+
self.current_memory_size = 0
|
52 |
+
self.model_dict = {}
|
53 |
+
self.model_timestamps = []
|
54 |
+
|
55 |
+
def cache_model(self, model_key, model_loader_func, model_conf):
|
56 |
+
if model_key in self.model_dict:
|
57 |
+
self.model_timestamps.remove(model_key)
|
58 |
+
self.model_timestamps.append(model_key)
|
59 |
+
logger.info(f"Load cached {model_key}")
|
60 |
+
return self.model_dict[model_key]
|
61 |
+
|
62 |
+
model = self._load_model_from_disk(model_loader_func, model_conf)
|
63 |
+
while self._calculate_model_memory() > self.max_memory_size:
|
64 |
+
if len(self.model_timestamps) == 0:
|
65 |
+
logger.warn(
|
66 |
+
"RAM: {}GB, MAX RAM: {}GB".format(
|
67 |
+
self._calculate_model_memory(), self.max_memory_size
|
68 |
+
)
|
69 |
+
)
|
70 |
+
break
|
71 |
+
oldest_model_key = self.model_timestamps.pop(0)
|
72 |
+
self.current_memory_size = self._calculate_model_memory()
|
73 |
+
logger.info(f"Del cached {oldest_model_key}")
|
74 |
+
del self.model_dict[oldest_model_key]
|
75 |
+
|
76 |
+
self.model_dict[model_key] = model
|
77 |
+
self.model_timestamps.append(model_key)
|
78 |
+
|
79 |
+
self.print_memory_usage()
|
80 |
+
logger.info(f"Total cached {list(self.model_dict.keys())}")
|
81 |
+
|
82 |
+
return model
|
83 |
+
|
84 |
+
def _load_model_from_disk(self, model_loader_func, model_conf):
|
85 |
+
return model_loader_func(model_conf)
|
86 |
+
|
87 |
+
def _calculate_model_memory(self, verbose=False):
|
88 |
+
host_colocation = int(os.environ.get("HOST_COLOCATION", "1"))
|
89 |
+
vm = psutil.virtual_memory()
|
90 |
+
du = shutil.disk_usage(".")
|
91 |
+
vm_ratio = host_colocation * vm.used / vm.total
|
92 |
+
if verbose:
|
93 |
+
logger.info(
|
94 |
+
f"RAM: {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}GB"
|
95 |
+
)
|
96 |
+
# logger.info(
|
97 |
+
# f"DISK: {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}GB"
|
98 |
+
# )
|
99 |
+
return vm.used / 1e9
|
100 |
+
|
101 |
+
def print_memory_usage(self):
|
102 |
+
self._calculate_model_memory(verbose=True)
|
103 |
+
|
104 |
+
|
105 |
+
model_cache = ModelCache()
|
106 |
+
|
107 |
+
|
108 |
def load_config(config_name: str) -> Dict[str, Any]:
|
109 |
"""
|
110 |
Load a YAML configuration file.
|
|
|
642 |
ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
|
643 |
choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
|
644 |
matcher_zoo: Dict[str, Any] = None,
|
645 |
+
use_cached_model: bool = True,
|
646 |
) -> Tuple[
|
647 |
np.ndarray,
|
648 |
np.ndarray,
|
|
|
703 |
match_conf["model"]["max_keypoints"] = extract_max_keypoints
|
704 |
t0 = time.time()
|
705 |
cache_key = "{}_{}".format(key, match_conf["model"]["name"])
|
706 |
+
matcher = model_cache.cache_model(cache_key, get_model, match_conf)
|
707 |
+
if use_cached_model:
|
708 |
matcher.conf["max_keypoints"] = extract_max_keypoints
|
709 |
matcher.conf["match_threshold"] = match_threshold
|
710 |
logger.info(f"Loaded cached model {cache_key}")
|
711 |
+
|
|
|
|
|
|
|
712 |
logger.info(f"Loading model using: {time.time()-t0:.3f}s")
|
713 |
t1 = time.time()
|
714 |
|
|
|
724 |
extract_conf["model"]["max_keypoints"] = extract_max_keypoints
|
725 |
extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
|
726 |
cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
|
727 |
+
|
728 |
+
extractor = model_cache.cache_model(
|
729 |
+
cache_key, get_feature_model, extract_conf
|
730 |
+
)
|
731 |
+
if use_cached_model:
|
732 |
extractor.conf["max_keypoints"] = extract_max_keypoints
|
733 |
extractor.conf["keypoint_threshold"] = keypoint_threshold
|
734 |
logger.info(f"Loaded cached model {cache_key}")
|
735 |
+
|
|
|
|
|
736 |
pred0 = extract_features.extract(
|
737 |
extractor, image0, extract_conf["preprocessing"]
|
738 |
)
|
hloc/matchers/omniglue.py
CHANGED
@@ -10,6 +10,7 @@ from ..utils.base_model import BaseModel
|
|
10 |
thirdparty_path = Path(__file__).parent / "../../third_party"
|
11 |
sys.path.append(str(thirdparty_path))
|
12 |
from omniglue.src import omniglue
|
|
|
13 |
omniglue_path = thirdparty_path / "omniglue"
|
14 |
|
15 |
|
|
|
10 |
thirdparty_path = Path(__file__).parent / "../../third_party"
|
11 |
sys.path.append(str(thirdparty_path))
|
12 |
from omniglue.src import omniglue
|
13 |
+
|
14 |
omniglue_path = thirdparty_path / "omniglue"
|
15 |
|
16 |
|
test_app_cli.py
CHANGED
@@ -12,11 +12,11 @@ from common.utils import (
|
|
12 |
from common.api import ImageMatchingAPI
|
13 |
|
14 |
|
15 |
-
def
|
16 |
img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
|
17 |
img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
|
18 |
-
image0 = cv2.imread(str(img_path1))[:, :, ::-1]
|
19 |
-
image1 = cv2.imread(str(img_path2))[:, :, ::-1]
|
20 |
|
21 |
matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"])
|
22 |
for k, v in matcher_zoo_restored.items():
|
@@ -27,15 +27,46 @@ def test_api(config: dict = None):
|
|
27 |
logger.info(f"Testing {k} ...")
|
28 |
api = ImageMatchingAPI(conf=v, device=device)
|
29 |
api(image0, image1)
|
30 |
-
log_path = ROOT / "
|
31 |
log_path.mkdir(exist_ok=True, parents=True)
|
32 |
api.visualize(log_path=log_path)
|
33 |
else:
|
34 |
logger.info(f"Skipping {k} ...")
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
if __name__ == "__main__":
|
38 |
import argparse
|
39 |
|
40 |
config = load_config(ROOT / "common/config.yaml")
|
41 |
-
|
|
|
|
12 |
from common.api import ImageMatchingAPI
|
13 |
|
14 |
|
15 |
+
def test_all(config: dict = None):
|
16 |
img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
|
17 |
img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
|
18 |
+
image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB
|
19 |
+
image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB
|
20 |
|
21 |
matcher_zoo_restored = get_matcher_zoo(config["matcher_zoo"])
|
22 |
for k, v in matcher_zoo_restored.items():
|
|
|
27 |
logger.info(f"Testing {k} ...")
|
28 |
api = ImageMatchingAPI(conf=v, device=device)
|
29 |
api(image0, image1)
|
30 |
+
log_path = ROOT / "experiments" / "all"
|
31 |
log_path.mkdir(exist_ok=True, parents=True)
|
32 |
api.visualize(log_path=log_path)
|
33 |
else:
|
34 |
logger.info(f"Skipping {k} ...")
|
35 |
|
36 |
|
37 |
+
def test_one():
|
38 |
+
img_path1 = ROOT / "datasets/sacre_coeur/mapping/02928139_3448003521.jpg"
|
39 |
+
img_path2 = ROOT / "datasets/sacre_coeur/mapping/17295357_9106075285.jpg"
|
40 |
+
image0 = cv2.imread(str(img_path1))[:, :, ::-1] # RGB
|
41 |
+
image1 = cv2.imread(str(img_path2))[:, :, ::-1] # RGB
|
42 |
+
|
43 |
+
conf = {
|
44 |
+
"matcher": {
|
45 |
+
"output": "matches-omniglue",
|
46 |
+
"model": {
|
47 |
+
"name": "omniglue",
|
48 |
+
"match_threshold": 0.2,
|
49 |
+
"features": "null",
|
50 |
+
},
|
51 |
+
"preprocessing": {
|
52 |
+
"grayscale": False,
|
53 |
+
"resize_max": 1024,
|
54 |
+
"dfactor": 8,
|
55 |
+
"force_resize": False,
|
56 |
+
},
|
57 |
+
},
|
58 |
+
"dense": True,
|
59 |
+
}
|
60 |
+
api = ImageMatchingAPI(conf=conf, device=device)
|
61 |
+
api(image0, image1)
|
62 |
+
log_path = ROOT / "experiments" / "one"
|
63 |
+
log_path.mkdir(exist_ok=True, parents=True)
|
64 |
+
api.visualize(log_path=log_path)
|
65 |
+
|
66 |
+
|
67 |
if __name__ == "__main__":
|
68 |
import argparse
|
69 |
|
70 |
config = load_config(ROOT / "common/config.yaml")
|
71 |
+
test_one()
|
72 |
+
test_all(config)
|