add initial version of mast3r sfm and glomap/colmap wrapper
- NOTICE +5 -0
- README.md +18 -2
- demo.py +2 -2
- demo_glomap.py +52 -0
- dust3r +1 -1
- kapture_mast3r_mapping.py +127 -0
- make_pairs.py +96 -0
- mast3r/catmlp_dpt_head.py +116 -0
- mast3r/cloud_opt/sparse_ga.py +47 -9
- mast3r/colmap/mapping.py +195 -0
- mast3r/demo.py +90 -46
- mast3r/demo_glomap.py +338 -0
- mast3r/image_pairs.py +115 -0
- mast3r/losses.py +1 -1
- mast3r/retrieval/graph.py +77 -0
- mast3r/retrieval/model.py +271 -0
- mast3r/retrieval/processor.py +129 -0
NOTICE
CHANGED
@@ -101,3 +101,8 @@ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 POSSIBILITY OF SUCH DAMAGE.
 
+====
+gtolias/how
+https://github.com/gtolias/how
+
+MIT License https://github.com/gtolias/how/blob/master/LICENSE
README.md
CHANGED
@@ -78,7 +78,19 @@ pip install -r dust3r/requirements.txt
 pip install -r dust3r/requirements_optional.txt
 ```
 
-3. Optional, compile the cuda kernels for RoPE (as in CroCo v2).
+3. compile and install ASMK
+```bash
+pip install cython
+
+git clone https://github.com/jenicek/asmk
+cd asmk/cython/
+cythonize *.pyx
+cd ..
+pip install .
+cd ..
+```
+
+4. Optional, compile the cuda kernels for RoPE (as in CroCo v2).
 ```bash
 # DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime.
 cd dust3r/croco/models/curope/
@@ -86,9 +98,10 @@ python setup.py build_ext --inplace
 cd ../../../../
 ```
 
-
 ### Checkpoints
 
+TODO upload retrieval_model somewhere
+
 You can obtain the checkpoints by two ways:
 
 1) You can use our huggingface_hub integration: the models will be downloaded automatically.
@@ -123,6 +136,7 @@ demo.py is the updated demo for MASt3R. It uses our new sparse global alignment
 python3 demo.py --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric
 
 # Use --weights to load a checkpoint from a local file, eg --weights checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth
+# Use --retrieval_model and point to the retrieval checkpoint to enable retrieval as a pairing strategy, asmk must be installed
 # Use --local_network to make it accessible on the local network, or --server_name to specify the url manually
 # Use --server_port to change the port, by default it will search for an available port starting at 7860
 # Use --device to use a different device, by default it's "cuda"
@@ -133,6 +147,8 @@ see https://github.com/naver/dust3r?tab=readme-ov-file#interactive-demo for details
 
 ### Interactive demo with docker
 
+TODO update with asmk/retrieval model
+
 To run MASt3R using Docker, including with NVIDIA CUDA support, follow these instructions:
 
 1. **Install Docker**: If not already installed, download and install `docker` and `docker compose` from the [Docker website](https://www.docker.com/get-started).
demo.py
CHANGED
@@ -47,5 +47,5 @@ if __name__ == '__main__':
     with get_context(args.tmp_dir) as tmpdirname:
         cache_path = os.path.join(tmpdirname, chkpt_tag)
         os.makedirs(cache_path, exist_ok=True)
-        main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port,
-                  share=args.share, gradio_delete_cache=args.gradio_delete_cache)
+        main_demo(cache_path, model, args.retrieval_model, args.device, args.image_size, server_name, args.server_port,
+                  silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache)
demo_glomap.py
ADDED
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# gradio demo executable
+# --------------------------------------------------------
+import pycolmap
+import os
+import torch
+import tempfile
+from contextlib import nullcontext
+
+from mast3r.demo_glomap import get_args_parser, main_demo
+
+from mast3r.model import AsymmetricMASt3R
+from mast3r.utils.misc import hash_md5
+
+import mast3r.utils.path_to_dust3r  # noqa
+from dust3r.demo import set_print_with_timestamp
+
+import matplotlib.pyplot as pl
+pl.ion()
+
+torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
+
+if __name__ == '__main__':
+    parser = get_args_parser()
+    args = parser.parse_args()
+    set_print_with_timestamp()
+
+    if args.server_name is not None:
+        server_name = args.server_name
+    else:
+        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
+
+    if args.weights is not None:
+        weights_path = args.weights
+    else:
+        weights_path = "naver/" + args.model_name
+
+    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
+    chkpt_tag = hash_md5(weights_path)
+
+    def get_context(tmp_dir):
+        return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
+            else nullcontext(tmp_dir)
+    with get_context(args.tmp_dir) as tmpdirname:
+        cache_path = os.path.join(tmpdirname, chkpt_tag)
+        os.makedirs(cache_path, exist_ok=True)
+        main_demo(args.glomap_bin, cache_path, model, args.retrieval_model, args.device, args.image_size, server_name,
+                  args.server_port, silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache)
dust3r
CHANGED
@@ -1 +1 @@
-Subproject commit
+Subproject commit c9e9336a6ba7c1f1873f9295852cea6dffaf770d
kapture_mast3r_mapping.py
ADDED
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# mast3r exec for running standard sfm
+# --------------------------------------------------------
+import pycolmap
+import os
+import os.path as path
+import argparse
+
+from mast3r.model import AsymmetricMASt3R
+from mast3r.colmap.mapping import (kapture_import_image_folder_or_list, run_mast3r_matching, pycolmap_run_triangulator,
+                                   pycolmap_run_mapper, glomap_run_mapper)
+from kapture.io.csv import kapture_from_dir
+
+from kapture.converter.colmap.database_extra import kapture_to_colmap, generate_priors_for_reconstruction
+from kapture_localization.utils.pairsfile import get_pairs_from_file
+from kapture.io.records import get_image_fullpath
+from kapture.converter.colmap.database import COLMAPDatabase
+
+
+def get_argparser():
+    parser = argparse.ArgumentParser(description='point triangulator with mast3r from kapture data')
+    parser_weights = parser.add_mutually_exclusive_group(required=True)
+    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
+    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
+                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])
+
+    parser_input = parser.add_mutually_exclusive_group(required=True)
+    parser_input.add_argument('-i', '--input', default=None, help='kdata')
+    parser_input.add_argument('--dir', default=None, help='image dir (individual intrinsics)')
+    parser_input.add_argument('--dir_same_camera', default=None, help='image dir (shared intrinsics)')
+
+    parser.add_argument('-o', '--output', required=True, help='output path to reconstruction')
+    parser.add_argument('--pairsfile_path', required=True, help='pairsfile')
+
+    parser.add_argument('--glomap_bin', default='glomap', type=str, help='glomap bin')
+
+    parser_mapper = parser.add_mutually_exclusive_group()
+    parser_mapper.add_argument('--ignore_pose', action='store_true', default=False)
+    parser_mapper.add_argument('--use_glomap_mapper', action='store_true', default=False)
+
+    parser_matching = parser.add_mutually_exclusive_group()
+    parser_matching.add_argument('--dense_matching', action='store_true', default=False)
+    parser_matching.add_argument('--pixel_tol', default=0, type=int)
+    parser.add_argument('--device', default='cuda')
+
+    parser.add_argument('--conf_thr', default=1.001, type=float)
+    parser.add_argument('--skip_geometric_verification', action='store_true', default=False)
+    parser.add_argument('--min_len_track', default=5, type=int)
+
+    return parser
+
+
+if __name__ == '__main__':
+    parser = get_argparser()
+    args = parser.parse_args()
+    if args.weights is not None:
+        weights_path = args.weights
+    else:
+        weights_path = "naver/" + args.model_name
+    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
+    maxdim = max(model.patch_embed.img_size)
+    patch_size = model.patch_embed.patch_size
+
+    if args.input is not None:
+        kdata = kapture_from_dir(args.input)
+        records_data_path = get_image_fullpath(args.input)
+    else:
+        if args.dir_same_camera is not None:
+            use_single_camera = True
+            records_data_path = args.dir_same_camera
+        elif args.dir is not None:
+            use_single_camera = False
+            records_data_path = args.dir
+        else:
+            raise ValueError('all inputs choices are None')
+        kdata = kapture_import_image_folder_or_list(records_data_path, use_single_camera)
+    has_pose = kdata.trajectories is not None
+    image_pairs = get_pairs_from_file(args.pairsfile_path, kdata.records_camera, kdata.records_camera)
+
+    colmap_db_path = path.join(args.output, 'colmap.db')
+    reconstruction_path = path.join(args.output, "reconstruction")
+    priors_txt_path = path.join(args.output, "priors_for_reconstruction")
+    for path_i in [reconstruction_path, priors_txt_path]:
+        os.makedirs(path_i, exist_ok=True)
+    assert not os.path.isfile(colmap_db_path)
+
+    colmap_db = COLMAPDatabase.connect(colmap_db_path)
+    try:
+        kapture_to_colmap(kdata, args.input, tar_handler=None, database=colmap_db,
+                          keypoints_type=None, descriptors_type=None, export_two_view_geometry=False)
+        if has_pose:
+            generate_priors_for_reconstruction(kdata, colmap_db, priors_txt_path)
+
+        colmap_image_pairs = run_mast3r_matching(model, maxdim, patch_size, args.device,
+                                                 kdata, records_data_path, image_pairs, colmap_db,
+                                                 args.dense_matching, args.pixel_tol, args.conf_thr,
+                                                 args.skip_geometric_verification, args.min_len_track)
+        colmap_db.close()
+    except Exception as e:
+        print(f'Error {e}')
+        colmap_db.close()
+        exit(1)
+
+    if len(colmap_image_pairs) == 0:
+        raise Exception("no matches were kept")
+
+    # colmap db is now full, run colmap
+    colmap_world_to_cam = {}
+    if not args.skip_geometric_verification:
+        print("verify_matches")
+        f = open(args.output + '/pairs.txt', "w")
+        for image_path1, image_path2 in colmap_image_pairs:
+            f.write("{} {}\n".format(image_path1, image_path2))
+        f.close()
+        pycolmap.verify_matches(colmap_db_path, args.output + '/pairs.txt')
+
+    print("running mapping")
+    if has_pose and not args.ignore_pose and not args.use_glomap_mapper:
+        pycolmap_run_triangulator(colmap_db_path, priors_txt_path, reconstruction_path, records_data_path)
+    elif not args.use_glomap_mapper:
+        pycolmap_run_mapper(colmap_db_path, reconstruction_path, records_data_path)
+    else:
+        glomap_run_mapper(args.glomap_bin, colmap_db_path, reconstruction_path, records_data_path)
make_pairs.py
ADDED
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# make pairs using mast3r scene_graph, including retrieval
+# --------------------------------------------------------
+import argparse
+import torch
+import os
+import os.path as path
+import PIL
+from PIL import Image
+import pathlib
+from kapture.io.csv import table_to_file
+
+from mast3r.model import AsymmetricMASt3R
+from mast3r.retrieval.processor import Retriever
+from mast3r.image_pairs import make_pairs
+
+
+def get_argparser():
+    parser = argparse.ArgumentParser(description='point triangulator with mast3r from kapture data')
+    parser.add_argument('--dir', required=True, help='image dir')
+    parser.add_argument('--scene_graph', default='retrieval-20-1-10-1')
+    parser.add_argument('--output', required=True, help='txt file')
+
+    parser_weights = parser.add_mutually_exclusive_group(required=False)
+    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
+    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
+                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])
+    parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")
+
+    return parser
+
+
+def get_image_list(images_path):
+    file_list = [path.relpath(path.join(dirpath, filename), images_path)
+                 for dirpath, dirs, filenames in os.walk(images_path)
+                 for filename in filenames]
+    file_list = sorted(file_list)
+    image_list = []
+    for filename in file_list:
+        # test if file is a valid image
+        try:
+            # lazy load
+            with Image.open(path.join(images_path, filename)) as im:
+                width, height = im.size
+            image_list.append(filename)
+        except (OSError, PIL.UnidentifiedImageError):
+            # It is not a valid image: skip it
+            print(f'Skipping invalid image file {filename}')
+            continue
+    return image_list
+
+
+def main(dir, scene_graph, output, backbone=None, retrieval_model=None):
+    imgs = get_image_list(dir)
+
+    sim_matrix = None
+    if 'retrieval' in scene_graph:
+        retriever = Retriever(retrieval_model, backbone=backbone)
+        imgs_fp = [path.join(dir, filename) for filename in imgs]
+        with torch.no_grad():
+            sim_matrix = retriever(imgs_fp)
+
+        # Cleanup
+        del retriever
+        torch.cuda.empty_cache()
+
+    pairs = make_pairs(imgs, scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)
+
+    os.umask(0o002)
+    p = pathlib.Path(output)
+    os.makedirs(str(p.parent.resolve()), exist_ok=True)
+
+    with open(output, 'w') as fid:
+        table_to_file(fid, pairs, header='# query_image, map_image, score')
+
+
+if __name__ == '__main__':
+    parser = get_argparser()
+    args = parser.parse_args()
+
+    if "retrieval" in args.scene_graph:
+        assert args.retrieval_model is not None
+        if args.weights is not None:
+            weights_path = args.weights
+        else:
+            weights_path = "naver/" + args.model_name
+        backbone = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
+        retrieval_model = args.retrieval_model
+    else:
+        backbone = None
+        retrieval_model = None
+    main(args.dir, args.scene_graph, args.output, backbone, retrieval_model)
mast3r/catmlp_dpt_head.py
CHANGED
@@ -5,6 +5,7 @@
 # MASt3R heads
 # --------------------------------------------------------
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 import mast3r.utils.path_to_dust3r  # noqa
@@ -12,6 +13,7 @@ from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf  # noqa
 from dust3r.heads.dpt_head import PixelwiseTaskWithDPT  # noqa
 import dust3r.utils.path_to_croco  # noqa
 from models.blocks import Mlp  # noqa
+from models.dpt_block import Interpolate  # noqa
 
 
 def reg_desc(desc, mode):
@@ -96,6 +98,113 @@ class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
         return out
 
 
+class MLP_MiniConv_Head(nn.Module):
+    """
+    A special convolutional head inspired by the DPT architecture.
+    An MLP predicts pixelwise feats at a lower resolution; the prediction is upsampled to the
+    target resolution and goes through a mini convolutional head.
+
+    Input: [B, S, D]  # S = (H//p) * (W//p)
+
+    MLP:
+        D -> (mlp_hidden_dim) -> mlp_odim * (p/2)**2
+        reshape to [mlp_odim, H/2, W/2] (the MLP predicts in half-res)
+
+    MiniConv head from DPT:
+        Upsample x2 -> [mlp_odim, H, W]
+        Conv 3x3 -> [conv_inner_dim, H, W]
+        ReLU
+        Conv 1x1 -> [odim, H, W]
+    """
+
+    def __init__(self, idim, mlp_hidden_dim, mlp_odim, conv_inner_dim, odim, patch_size, subpatch=2, **kw):
+        super().__init__()
+        self.patch_size = patch_size
+        self.subpatch = subpatch
+        self.sub_patch_size = patch_size // subpatch
+        self.mlp = Mlp(idim, mlp_hidden_dim, mlp_odim * self.sub_patch_size**2, **kw)  # D -> mlp_odim*sub_patch_size**2
+
+        # DPT conv head
+        self.head = nn.Sequential(Interpolate(scale_factor=self.subpatch, mode="bilinear", align_corners=True) if self.subpatch != 1 else nn.Identity(),
+                                  nn.Conv2d(mlp_odim, conv_inner_dim, kernel_size=3, stride=1, padding=1),
+                                  nn.ReLU(True),
+                                  nn.Conv2d(conv_inner_dim, odim, kernel_size=1, stride=1, padding=0)
+                                  )
+
+    def forward(self, decout, img_shape):
+        H, W = img_shape
+        tokens = decout[-1]
+        B, S, D = tokens.shape
+        # extract features
+        feat = self.mlp(tokens)  # [B, S, mlp_odim*sub_patch_size**2]
+        feat = feat.transpose(-1, -2).reshape(B, -1, H // self.patch_size, W // self.patch_size)
+        feat = F.pixel_shuffle(feat, self.sub_patch_size)  # B,mlp_odim,H/sub,W/sub
+
+        return self.head(feat)  # B, odim, H, W
+
+
+class Cat_MLP_LocalFeatures_MiniConv_Pts3d(nn.Module):
+    """ Mixture between MLP and MLP-Convolutional head that outputs 3d points (with miniconv) and local features (with MLP).
+    Simply contains two heads: an MLP_MiniConv_Head for 3D points and an Mlp for features.
+    The input for both heads is a concatenation of Encoder and Decoder outputs.
+    """
+
+    def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., mlp_odim=24, conv_inner_dim=100, subpatch=2, **kw):
+        super().__init__()
+
+        self.local_feat_dim = local_feat_dim
+        patch_size = net.patch_embed.patch_size
+        if isinstance(patch_size, tuple):
+            assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance(
+                patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
+            assert patch_size[0] == patch_size[1], "Error, non square patches not managed"
+            patch_size = patch_size[0]
+        self.patch_size = patch_size
+
+        self.depth_mode = net.depth_mode
+        self.conf_mode = net.conf_mode
+        self.desc_mode = net.desc_mode
+        self.desc_conf_mode = net.desc_conf_mode
+        self.has_conf = has_conf
+        self.two_confs = net.two_confs  # independent confs for 3D regr and descs
+        idim = net.enc_embed_dim + net.dec_embed_dim
+        self.head_pts3d = MLP_MiniConv_Head(idim=idim,
+                                            mlp_hidden_dim=int(hidden_dim_factor * idim),
+                                            mlp_odim=mlp_odim + self.has_conf,
+                                            conv_inner_dim=conv_inner_dim,
+                                            odim=3 + self.has_conf,
+                                            subpatch=subpatch,
+                                            patch_size=self.patch_size,
+                                            **kw)
+
+        self.head_local_features = Mlp(in_features=idim,
+                                       hidden_features=int(hidden_dim_factor * idim),
+                                       out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2)
+
+    def forward(self, decout, img_shape):
+        enc_output, dec_output = decout[0], decout[-1]  # recover encoder and decoder outputs
+        cat_output = torch.cat([enc_output, dec_output], dim=-1)  # concatenate
+        # pass through the heads
+        pts3d = self.head_pts3d([cat_output], img_shape)
+
+        H, W = img_shape
+        B, S, D = cat_output.shape
+
+        # extract local features
+        local_features = self.head_local_features(cat_output)  # B,S,D
+        local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
+        local_features = F.pixel_shuffle(local_features, self.patch_size)  # B,d,H,W
+
+        # post process 3D pts, descriptors and confidences
+        out = postprocess(torch.cat([pts3d, local_features], dim=1),
+                          depth_mode=self.depth_mode,
+                          conf_mode=self.conf_mode,
+                          desc_dim=self.local_feat_dim,
+                          desc_mode=self.desc_mode,
+                          two_confs=self.two_confs, desc_conf_mode=self.desc_conf_mode)
+        return out
+
+
 def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
     """" build a prediction head for the decoder
     """
@@ -118,6 +227,13 @@ def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
                                              depth_mode=net.depth_mode,
                                              conf_mode=net.conf_mode,
                                              head_type='regression')
+    elif head_type == 'catconv' and output_mode.startswith('pts3d+desc'):
+        local_feat_dim = int(output_mode[10:])
+        # more params (announced by a ':' and comma separated)
+        kw = {}
+        if ':' in head_type:
+            kw = eval("dict(" + head_type[8:] + ")")
+        return Cat_MLP_LocalFeatures_MiniConv_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf, **kw)
     else:
        raise NotImplementedError(
            f"unexpected {head_type=} and {output_mode=}")
mast3r/cloud_opt/sparse_ga.py
CHANGED
@@ -15,6 +15,7 @@ from collections import namedtuple
 from functools import lru_cache
 from scipy import sparse as sp
 import copy
+import scipy.cluster.hierarchy as sch
 
 from mast3r.utils.misc import mkdir_for, hash_md5
 from mast3r.cloud_opt.utils.losses import gamma_loss
@@ -116,7 +117,7 @@ def convert_dust3r_pairs_naming(imgs, pairs_in):
 
 
 def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
-                            device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
+                            kinematic_mode='hclust-ward', device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
     """ Sparse alignment with MASt3R
         imgs: list of image paths
         cache_path: path where to dump temporary files (str)
@@ -137,17 +138,54 @@ def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc
     tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \
         prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device)
 
-    # compute minimal spanning tree
-    mst = compute_min_spanning_tree(pairwise_scores)
+    # smartly combine all useful data
+    imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = \
+        condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype)
+
+    # Build kinematic chain
+    if kinematic_mode == 'mst':
+        # compute minimal spanning tree
+        mst = compute_min_spanning_tree(pairwise_scores)
+
+    elif kinematic_mode.startswith('hclust'):
+        mode, linkage = kinematic_mode.split('-')
+
+        # Convert the affinity matrix to a distance matrix (if needed)
+        n_patches = (imsizes // subsample).prod(dim=1)
+        max_n_corres = 3 * torch.minimum(n_patches[:, None], n_patches[None, :])
+        pws = (pairwise_scores.clone() / max_n_corres).clip(max=1)
+        pws.fill_diagonal_(1)
+        pws = to_numpy(pws)
+        distance_matrix = np.where(pws, 1 - pws, 2)
+
+        # Compute the condensed distance matrix
+        condensed_distance_matrix = sch.distance.squareform(distance_matrix)
+
+        # Perform hierarchical clustering using the linkage method
+        Z = sch.linkage(condensed_distance_matrix, method=linkage)
+        # dendrogram = sch.dendrogram(Z)
+
+        tree = np.eye(len(imgs))
+        new_to_old_nodes = {i: i for i in range(len(imgs))}
+        for i, (a, b) in enumerate(Z[:, :2].astype(int)):
+            # given two nodes to be merged, we choose which one is the best representant
+            a = new_to_old_nodes[a]
+            b = new_to_old_nodes[b]
+            tree[a, b] = tree[b, a] = 1
+            best = a if pws[a].sum() > pws[b].sum() else b
+            new_to_old_nodes[len(imgs) + i] = best
+            pws[best] = np.maximum(pws[a], pws[b])  # update the node
+
+        pairwise_scores = torch.from_numpy(tree)  # this output just gives 1s for connected edges and zeros for other, i.e. no scores or priority
+        mst = compute_min_spanning_tree(pairwise_scores)
+
+    else:
+        raise ValueError(f'bad {kinematic_mode=}')
 
     # remove all edges not in the spanning tree?
     # min_spanning_tree = {(imgs[i],imgs[j]) for i,j in mst[1]}
     # tmp_pairs = {(a,b):v for (a,b),v in tmp_pairs.items() if {(a,b),(b,a)} & min_spanning_tree}
 
-    # smartly combine all useful data
-    imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = \
-        condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype)
-
     imgs, res_coarse, res_fine = sparse_scene_optimizer(
         imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths, mst,
         shared_intrinsics=shared_intrinsics, cache_path=cache_path, device=device, dtype=dtype, **kw)
@@ -157,8 +195,8 @@ def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc
 
 def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
                            preds_21, canonical_paths, mst, cache_path,
-                           lr1=0.
-                           lr2=0.
+                           lr1=0.07, niter1=300, loss1=gamma_loss(1.5),
+                           lr2=0.01, niter2=300, loss2=gamma_loss(0.5),
                            lossd=gamma_loss(1.1),
                            opt_pp=True, opt_depth=True,
                            schedule=cosine_schedule, depth_mode='add', exp_depth=False,
mast3r/colmap/mapping.py
ADDED
@@ -0,0 +1,195 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# colmap mapper/colmap point_triangulator/glomap mapper from mast3r matches
+# --------------------------------------------------------
+import pycolmap
+import os
+import os.path as path
+import kapture.io
+import kapture.io.csv
+import subprocess
+import PIL
+from tqdm import tqdm
+import PIL.Image
+import numpy as np
+from typing import List, Tuple, Union
+
+from mast3r.model import AsymmetricMASt3R
+from mast3r.colmap.database import export_matches, get_im_matches
+
+import mast3r.utils.path_to_dust3r  # noqa
+from dust3r_visloc.datasets.utils import get_resize_function
+
+import kapture
+from kapture.converter.colmap.database_extra import get_colmap_camera_ids_from_db, get_colmap_image_ids_from_db
+from kapture.utils.paths import path_secure
+
+from dust3r.datasets.utils.transforms import ImgNorm
+from dust3r.inference import inference
+
+
+def scene_prepare_images(root: str, maxdim: int, patch_size: int, image_paths: List[str]):
+    images = []
+    # image loading
+    for idx in tqdm(range(len(image_paths))):
+        rgb_image = PIL.Image.open(os.path.join(root, image_paths[idx])).convert('RGB')
+
+        # resize images
+        W, H = rgb_image.size
+        resize_func, _, to_orig = get_resize_function(maxdim, patch_size, H, W)
+        rgb_tensor = resize_func(ImgNorm(rgb_image))
+
+        # image dictionary
+        images.append({'img': rgb_tensor.unsqueeze(0),
+                       'true_shape': np.int32([rgb_tensor.shape[1:]]),
+                       'to_orig': to_orig,
+                       'idx': idx,
+                       'instance': image_paths[idx],
+                       'orig_shape': np.int32([H, W])})
+    return images
+
+
+def remove_duplicates(images, image_pairs):
+    pairs_added = set()
+    pairs = []
+    for (i, _), (j, _) in image_pairs:
+        smallidx, bigidx = min(i, j), max(i, j)
+        if (smallidx, bigidx) in pairs_added:
+            continue
+        pairs_added.add((smallidx, bigidx))
+        pairs.append((images[i], images[j]))
+    return pairs
+
+
+def run_mast3r_matching(model: AsymmetricMASt3R, maxdim: int, patch_size: int, device,
+                        kdata: kapture.Kapture, root_path: str, image_pairs_kapture: List[Tuple[str, str]],
+                        colmap_db,
+                        dense_matching: bool, pixel_tol: int, conf_thr: float, skip_geometric_verification: bool,
+                        min_len_track: int):
+    assert kdata.records_camera is not None
+    image_paths = kdata.records_camera.data_list()
+    image_path_to_idx = {image_path: idx for idx, image_path in enumerate(image_paths)}
+    image_path_to_ts = {kdata.records_camera[ts, camid]: (ts, camid) for ts, camid in kdata.records_camera.key_pairs()}
+
+    images = scene_prepare_images(root_path, maxdim, patch_size, image_paths)
+    image_pairs = [((image_path_to_idx[image_path1], image_path1), (image_path_to_idx[image_path2], image_path2))
+                   for image_path1, image_path2 in image_pairs_kapture]
+    matching_pairs = remove_duplicates(images, image_pairs)
+
+    colmap_camera_ids = get_colmap_camera_ids_from_db(colmap_db, kdata.records_camera)
+    colmap_image_ids = get_colmap_image_ids_from_db(colmap_db)
+    im_keypoints = {idx: {} for idx in range(len(image_paths))}
+
+    im_matches = {}
+    image_to_colmap = {}
+    for image_path, idx in image_path_to_idx.items():
+        _, camid = image_path_to_ts[image_path]
+        colmap_camid = colmap_camera_ids[camid]
+        colmap_imid = colmap_image_ids[image_path]
+        image_to_colmap[idx] = {
+            'colmap_imid': colmap_imid,
+            'colmap_camid': colmap_camid
+        }
+
+    # compute 2D-2D matching from dust3r inference
+    for chunk in tqdm(range(0, len(matching_pairs), 4)):
+        pairs_chunk = matching_pairs[chunk:chunk + 4]
+        output = inference(pairs_chunk, model, device, batch_size=1, verbose=False)
+        pred1, pred2 = output['pred1'], output['pred2']
+        # TODO handle caching
+        im_images_chunk = get_im_matches(pred1, pred2, pairs_chunk, image_to_colmap,
+                                         im_keypoints, conf_thr, not dense_matching, pixel_tol)
+        im_matches.update(im_images_chunk.items())
+
+    # filter matches, convert them and export keypoints and matches to colmap db
+    colmap_image_pairs = export_matches(
+        colmap_db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification)
+    colmap_db.commit()
+
+    return colmap_image_pairs
+
+
+def pycolmap_run_triangulator(colmap_db_path, prior_recon_path, recon_path, image_root_path):
+    print("running mapping")
+    reconstruction = pycolmap.Reconstruction(prior_recon_path)
+    pycolmap.triangulate_points(
+        reconstruction=reconstruction,
+        database_path=colmap_db_path,
+        image_path=image_root_path,
+        output_path=recon_path,
+        refine_intrinsics=False,
+    )
+
+
+def pycolmap_run_mapper(colmap_db_path, recon_path, image_root_path):
+    print("running mapping")
+    reconstructions = pycolmap.incremental_mapping(
+        database_path=colmap_db_path,
+        image_path=image_root_path,
+        output_path=recon_path,
+        options=pycolmap.IncrementalPipelineOptions({'multiple_models': False,
+                                                     'extract_colors': True,
+                                                     })
+    )
+
+
+def glomap_run_mapper(glomap_bin, colmap_db_path, recon_path, image_root_path):
+    print("running mapping")
+    args = [
+        'mapper',
+        '--database_path',
+        colmap_db_path,
+        '--image_path',
+        image_root_path,
+        '--output_path',
+        recon_path
+    ]
+    args.insert(0, glomap_bin)
+    glomap_process = subprocess.Popen(args)
+    glomap_process.wait()
+
+    if glomap_process.returncode != 0:
+        raise ValueError(
+            '\nSubprocess Error (Return code:'
+            f' {glomap_process.returncode} )')
+
+
+def kapture_import_image_folder_or_list(images_path: Union[str, Tuple[str, List[str]]], use_single_camera=False) -> kapture.Kapture:
+    images = kapture.RecordsCamera()
+
+    if isinstance(images_path, str):
+        images_root = images_path
+        file_list = [path.relpath(path.join(dirpath, filename), images_root)
+                     for dirpath, dirs, filenames in os.walk(images_root)
+                     for filename in filenames]
+        file_list = sorted(file_list)
+    else:
+        images_root, file_list = images_path
+
+    sensors = kapture.Sensors()
+    for n, filename in enumerate(file_list):
+        # test if file is a valid image
+        try:
+            # lazy load
+            with PIL.Image.open(path.join(images_root, filename)) as im:
+                width, height = im.size
+                model_params = [width, height]
+        except (OSError, PIL.UnidentifiedImageError):
+            # It is not a valid image: skip it
+            print(f'Skipping invalid image file {filename}')
+            continue
+
+        camera_id = f'sensor'
+        if use_single_camera and camera_id not in sensors:
+            sensors[camera_id] = kapture.Camera(kapture.CameraType.UNKNOWN_CAMERA, model_params)
+        elif use_single_camera:
+            assert sensors[camera_id].camera_params[0] == width and sensors[camera_id].camera_params[1] == height
+        else:
+            camera_id = camera_id + f'{n}'
+            sensors[camera_id] = kapture.Camera(kapture.CameraType.UNKNOWN_CAMERA, model_params)
+
+        images[(n, camera_id)] = path_secure(filename)  # don't forget windows
+
+    return kapture.Kapture(sensors=sensors, records_camera=images)
mast3r/demo.py
CHANGED
@@ -15,12 +15,14 @@ import copy
|
|
15 |
from scipy.spatial.transform import Rotation
|
16 |
import tempfile
|
17 |
import shutil
|
|
|
18 |
|
19 |
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
|
20 |
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
|
|
|
|
|
21 |
|
22 |
import mast3r.utils.path_to_dust3r # noqa
|
23 |
-
from dust3r.image_pairs import make_pairs
|
24 |
from dust3r.utils.image import load_images
|
25 |
from dust3r.utils.device import to_numpy
|
26 |
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
@@ -29,7 +31,7 @@ from dust3r.demo import get_args_parser as dust3r_get_args_parser
|
|
29 |
import matplotlib.pyplot as pl
|
30 |
|
31 |
|
32 |
-
class SparseGAState
|
33 |
def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
|
34 |
self.sparse_ga = sparse_ga
|
35 |
self.cache_dir = cache_dir
|
@@ -52,6 +54,7 @@ def get_args_parser():
|
|
52 |
parser.add_argument('--share', action='store_true')
|
53 |
parser.add_argument('--gradio_delete_cache', default=None, type=int,
|
54 |
help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
|
|
|
55 |
|
56 |
actions = parser._actions
|
57 |
for action in actions:
|
@@ -136,10 +139,10 @@ def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=F
|
|
136 |
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
|
137 |
|
138 |
|
139 |
-
def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size,
|
140 |
-
filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr,
|
141 |
-
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
142 |
-
win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
|
143 |
"""
|
144 |
from a list of images, run mast3r inference, sparse global aligner.
|
145 |
then run get_3D_model_from_scene
|
@@ -155,10 +158,26 @@ def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent,
|
|
155 |
scene_graph_params.append(str(winsize))
|
156 |
elif scenegraph_type == "oneref":
|
157 |
scene_graph_params.append(str(refid))
|
|
|
|
|
|
|
|
|
158 |
if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
|
159 |
scene_graph_params.append('noncyclic')
|
160 |
scene_graph = '-'.join(scene_graph_params)
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
if optim_level == 'coarse':
|
163 |
niter2 = 0
|
164 |
# Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
|
@@ -190,39 +209,66 @@ def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent,
|
|
190 |
|
191 |
def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
|
192 |
num_files = len(inputfiles) if inputfiles is not None else 1
|
193 |
-
show_win_controls = scenegraph_type in ["swin", "logwin"]
|
194 |
-
show_winsize = scenegraph_type in ["swin", "logwin"]
|
195 |
-
show_cyclic = scenegraph_type in ["swin", "logwin"]
|
196 |
max_winsize, min_winsize = 1, 1
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
|
|
|
|
|
|
206 |
else:
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
|
|
216 |
|
217 |
-
|
|
|
218 |
share=False, gradio_delete_cache=False):
|
219 |
if not silent:
|
220 |
print('Outputing stuff in', tmpdirname)
|
221 |
|
222 |
-
recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model,
|
223 |
-
silent, image_size)
|
224 |
model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
|
225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
def get_context(delete_cache):
|
227 |
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
|
228 |
title = "MASt3R Demo"
|
@@ -241,33 +287,31 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
|
|
241 |
with gradio.Column():
|
242 |
with gradio.Row():
|
243 |
lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
|
244 |
-
niter1 = gradio.
|
245 |
-
label="
|
246 |
-
lr2 = gradio.Slider(label="Fine LR", value=0.
|
247 |
-
niter2 = gradio.
|
248 |
-
label="
|
249 |
optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
|
250 |
value='refine', label="OptLevel",
|
251 |
info="Optimization level")
|
252 |
with gradio.Row():
|
253 |
-
matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=
|
254 |
minimum=0., maximum=30., step=0.1,
|
255 |
info="Before Fallback to Regr3D!")
|
256 |
shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
|
257 |
info="Only optimize one set of intrinsics for all views")
|
258 |
-
scenegraph_type = gradio.Dropdown(
|
259 |
-
("swin: sliding window", "swin"),
|
260 |
-
("logwin: sliding window with long range", "logwin"),
|
261 |
-
("oneref: match one image with all", "oneref")],
|
262 |
value='complete', label="Scenegraph",
|
263 |
info="Define how to make pairs",
|
264 |
interactive=True)
|
265 |
-
with gradio.Column(visible=False) as
|
266 |
winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
|
267 |
minimum=1, maximum=1, step=1)
|
268 |
win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
|
269 |
-
|
270 |
-
|
|
|
271 |
run_btn = gradio.Button("Run")
|
272 |
|
273 |
with gradio.Row():
|
@@ -288,13 +332,13 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
|
|
288 |
# events
|
289 |
scenegraph_type.change(set_scenegraph_options,
|
290 |
inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
|
291 |
-
outputs=[
|
292 |
inputfiles.change(set_scenegraph_options,
|
293 |
inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
|
294 |
-
outputs=[
|
295 |
win_cyclic.change(set_scenegraph_options,
|
296 |
inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
|
297 |
-
outputs=[
|
298 |
run_btn.click(fn=recon_fun,
|
299 |
inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
|
300 |
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
|
|
15 |
from scipy.spatial.transform import Rotation
|
16 |
import tempfile
|
17 |
import shutil
|
18 |
+
import torch
|
19 |
|
20 |
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
|
21 |
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
|
22 |
+
from mast3r.image_pairs import make_pairs
|
23 |
+
from mast3r.retrieval.processor import Retriever
|
24 |
|
25 |
import mast3r.utils.path_to_dust3r # noqa
|
|
|
26 |
from dust3r.utils.image import load_images
|
27 |
from dust3r.utils.device import to_numpy
|
28 |
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
|
|
|
31 |
import matplotlib.pyplot as pl
|
32 |
|
33 |
|
34 |
+
class SparseGAState:
|
35 |
def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
|
36 |
self.sparse_ga = sparse_ga
|
37 |
self.cache_dir = cache_dir
|
|
|
54 |
parser.add_argument('--share', action='store_true')
|
55 |
parser.add_argument('--gradio_delete_cache', default=None, type=int,
|
56 |
help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
|
57 |
+
parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")
|
58 |
|
59 |
actions = parser._actions
|
60 |
for action in actions:
|
|
|
139 |
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
|
140 |
|
141 |
|
142 |
+
def get_reconstructed_scene(outdir, gradio_delete_cache, model, retrieval_model, device, silent, image_size,
|
143 |
+
current_scene_state, filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr,
|
144 |
+
matching_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
|
145 |
+
scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
|
146 |
"""
|
147 |
from a list of images, run mast3r inference, sparse global aligner.
|
148 |
then run get_3D_model_from_scene
|
|
|
158 |
scene_graph_params.append(str(winsize))
|
159 |
elif scenegraph_type == "oneref":
|
160 |
scene_graph_params.append(str(refid))
|
161 |
+
elif scenegraph_type == "retrieval":
|
162 |
+
scene_graph_params.append(str(winsize)) # Na
|
163 |
+
scene_graph_params.append(str(refid)) # k
|
164 |
+
|
165 |
if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
|
166 |
scene_graph_params.append('noncyclic')
|
167 |
scene_graph = '-'.join(scene_graph_params)
|
168 |
+
|
169 |
+
sim_matrix = None
|
170 |
+
if 'retrieval' in scenegraph_type:
|
171 |
+
assert retrieval_model is not None
|
172 |
+
retriever = Retriever(retrieval_model, backbone=model, device=device)
|
173 |
+
with torch.no_grad():
|
174 |
+
sim_matrix = retriever(filelist)
|
175 |
+
|
176 |
+
# Cleanup
|
177 |
+
del retriever
|
178 |
+
torch.cuda.empty_cache()
|
179 |
+
|
180 |
+
pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)
|
181 |
if optim_level == 'coarse':
|
182 |
niter2 = 0
|
183 |
# Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
|
|
|
209 |
|
210 |
def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
|
211 |
num_files = len(inputfiles) if inputfiles is not None else 1
|
|
|
|
|
|
|
212 |
max_winsize, min_winsize = 1, 1
|
213 |
+
|
214 |
+
winsize = gradio.Slider(visible=False)
|
215 |
+
win_cyclic = gradio.Checkbox(visible=False)
|
216 |
+
graph_opt = gradio.Column(visible=False)
|
217 |
+
refid = gradio.Slider(visible=False)
|
218 |
+
|
219 |
+
if scenegraph_type in ["swin", "logwin"]:
|
220 |
+
if scenegraph_type == "swin":
|
221 |
+
if win_cyclic:
|
222 |
+
max_winsize = max(1, math.ceil((num_files - 1) / 2))
|
223 |
+
else:
|
224 |
+
max_winsize = num_files - 1
|
225 |
else:
|
226 |
+
if win_cyclic:
|
227 |
+
half_size = math.ceil((num_files - 1) / 2)
|
228 |
+
max_winsize = max(1, math.ceil(math.log(half_size, 2)))
|
229 |
+
else:
|
230 |
+
max_winsize = max(1, math.ceil(math.log(num_files, 2)))
|
231 |
+
|
232 |
+
winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
|
233 |
+
minimum=min_winsize, maximum=max_winsize, step=1, visible=True)
|
234 |
+
        win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=True)
        graph_opt = gradio.Column(visible=True)
        refid = gradio.Slider(visible=False)

    elif scenegraph_type == "retrieval":
        graph_opt = gradio.Column(visible=True)
        winsize = gradio.Slider(label="Retrieval: Num. key images", value=min(20, num_files),
                                minimum=0, maximum=num_files, step=1, visible=True)
        win_cyclic = gradio.Checkbox(visible=False)
        refid = gradio.Slider(label="Retrieval: Num neighbors", value=min(num_files - 1, 10), minimum=1,
                              maximum=num_files - 1, step=1, visible=True)

    elif scenegraph_type == "oneref":
        graph_opt = gradio.Column(visible=True)
        winsize = gradio.Slider(visible=False)
        win_cyclic = gradio.Checkbox(visible=False)
        refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                              maximum=num_files - 1, step=1, visible=True)

    return graph_opt, winsize, win_cyclic, refid


def main_demo(tmpdirname, model, retrieval_model, device, image_size, server_name, server_port, silent=False,
              share=False, gradio_delete_cache=False):
    if not silent:
        print('Outputting stuff in', tmpdirname)

    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model,
                                  retrieval_model, device, silent, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)

    available_scenegraph_type = [("complete: all possible image pairs", "complete"),
                                 ("swin: sliding window", "swin"),
                                 ("logwin: sliding window with long range", "logwin"),
                                 ("oneref: match one image with all", "oneref")]
    if retrieval_model is not None:
        available_scenegraph_type.insert(1, ("retrieval: connect views based on similarity", "retrieval"))

    def get_context(delete_cache):
        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        title = "MASt3R Demo"

[...]

        with gradio.Column():
            with gradio.Row():
                lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
                niter1 = gradio.Slider(value=300, minimum=0, maximum=1000, step=1,
                                       label="Iterations", info="For coarse alignment")
                lr2 = gradio.Slider(label="Fine LR", value=0.01, minimum=0.005, maximum=0.05, step=0.001)
                niter2 = gradio.Slider(value=300, minimum=0, maximum=1000, step=1,
                                       label="Iterations", info="For refinement")
                optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
                                              value='refine', label="OptLevel",
                                              info="Optimization level")
            with gradio.Row():
                matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=0.,
                                                  minimum=0., maximum=30., step=0.1,
                                                  info="Before Fallback to Regr3D!")
                shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                    info="Only optimize one set of intrinsics for all views")
                scenegraph_type = gradio.Dropdown(available_scenegraph_type,
                                                  value='complete', label="Scenegraph",
                                                  info="Define how to make pairs",
                                                  interactive=True)
                with gradio.Column(visible=False) as graph_opt:
                    winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                            minimum=1, maximum=1, step=1)
                    win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
                    refid = gradio.Slider(label="Scene Graph: Id", value=0,
                                          minimum=0, maximum=0, step=1, visible=False)

            run_btn = gradio.Button("Run")

            with gradio.Row():

[...]

        # events
        scenegraph_type.change(set_scenegraph_options,
                               inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                               outputs=[graph_opt, winsize, win_cyclic, refid])
        inputfiles.change(set_scenegraph_options,
                          inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                          outputs=[graph_opt, winsize, win_cyclic, refid])
        win_cyclic.change(set_scenegraph_options,
                          inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                          outputs=[graph_opt, winsize, win_cyclic, refid])
        run_btn.click(fn=recon_fun,
                      inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                              as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
[...]
mast3r/demo_glomap.py
ADDED
@@ -0,0 +1,338 @@
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# gradio demo functions
# --------------------------------------------------------
import pycolmap
import gradio
import os
import numpy as np
import functools
import trimesh
import copy
from scipy.spatial.transform import Rotation
import tempfile
import shutil
import PIL.Image
import torch

from kapture.converter.colmap.database_extra import kapture_to_colmap
from kapture.converter.colmap.database import COLMAPDatabase

from mast3r.colmap.mapping import kapture_import_image_folder_or_list, run_mast3r_matching, glomap_run_mapper
from mast3r.demo import set_scenegraph_options
from mast3r.retrieval.processor import Retriever
from mast3r.image_pairs import make_pairs

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.image import load_images
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL
from dust3r.demo import get_args_parser as dust3r_get_args_parser

import matplotlib.pyplot as pl


class GlomapRecon:
    def __init__(self, world_to_cam, intrinsics, points3d, imgs):
        self.world_to_cam = world_to_cam
        self.intrinsics = intrinsics
        self.points3d = points3d
        self.imgs = imgs


class GlomapReconState:
    def __init__(self, glomap_recon, should_delete=False, cache_dir=None, outfile_name=None):
        self.glomap_recon = glomap_recon
        self.cache_dir = cache_dir
        self.outfile_name = outfile_name
        self.should_delete = should_delete

    def __del__(self):
        if not self.should_delete:
            return
        if self.cache_dir is not None and os.path.isdir(self.cache_dir):
            shutil.rmtree(self.cache_dir)
        self.cache_dir = None
        if self.outfile_name is not None and os.path.isfile(self.outfile_name):
            os.remove(self.outfile_name)
        self.outfile_name = None


def get_args_parser():
    parser = dust3r_get_args_parser()
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--gradio_delete_cache', default=None, type=int,
                        help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
    parser.add_argument('--glomap_bin', default='glomap', type=str, help='glomap bin')
    parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")

    actions = parser._actions
    for action in actions:
        if action.dest == 'model_name':
            action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
    # change defaults
    parser.prog = 'mast3r demo'
    return parser


def get_reconstructed_scene(glomap_bin, outdir, gradio_delete_cache, model, retrieval_model, device, silent, image_size,
                            current_scene_state, filelist, transparent_cams, cam_size, scenegraph_type, winsize,
                            win_cyclic, refid, shared_intrinsics, **kw):
    """
    from a list of images, run mast3r matching, fill a colmap database and
    run the glomap mapper, then run get_3D_model_from_scene
    """
    imgs = load_images(filelist, size=image_size, verbose=not silent)
    if len(imgs) == 1:
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
        filelist = [filelist[0], filelist[0]]

    scene_graph_params = [scenegraph_type]
    if scenegraph_type in ["swin", "logwin"]:
        scene_graph_params.append(str(winsize))
    elif scenegraph_type == "oneref":
        scene_graph_params.append(str(refid))
    elif scenegraph_type == "retrieval":
        scene_graph_params.append(str(winsize))  # Na
        scene_graph_params.append(str(refid))  # k

    if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
        scene_graph_params.append('noncyclic')
    scene_graph = '-'.join(scene_graph_params)

    sim_matrix = None
    if 'retrieval' in scenegraph_type:
        assert retrieval_model is not None
        retriever = Retriever(retrieval_model, backbone=model, device=device)
        with torch.no_grad():
            sim_matrix = retriever(filelist)

        # Cleanup
        del retriever
        torch.cuda.empty_cache()

    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)

    if current_scene_state is not None and \
            not current_scene_state.should_delete and \
            current_scene_state.cache_dir is not None:
        cache_dir = current_scene_state.cache_dir
    elif gradio_delete_cache:
        cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
    else:
        cache_dir = os.path.join(outdir, 'cache')

    root_path = os.path.commonpath(filelist)
    filelist_relpath = [
        os.path.relpath(filename, root_path).replace('\\', '/')
        for filename in filelist
    ]
    kdata = kapture_import_image_folder_or_list((root_path, filelist_relpath), shared_intrinsics)
    image_pairs = [
        (filelist_relpath[img1['idx']], filelist_relpath[img2['idx']])
        for img1, img2 in pairs
    ]

    colmap_db_path = os.path.join(cache_dir, 'colmap.db')
    if os.path.isfile(colmap_db_path):
        os.remove(colmap_db_path)

    os.makedirs(os.path.dirname(colmap_db_path), exist_ok=True)
    colmap_db = COLMAPDatabase.connect(colmap_db_path)
    try:
        kapture_to_colmap(kdata, root_path, tar_handler=None, database=colmap_db,
                          keypoints_type=None, descriptors_type=None, export_two_view_geometry=False)
        colmap_image_pairs = run_mast3r_matching(model, image_size, 16, device,
                                                 kdata, root_path, image_pairs, colmap_db,
                                                 False, 5, 1.001,
                                                 False, 3)
        colmap_db.close()
    except Exception as e:
        print(f'Error {e}')
        colmap_db.close()
        exit(1)

    if len(colmap_image_pairs) == 0:
        raise Exception("no matches were kept")

    # the colmap db is now full: verify matches with colmap, then run the glomap mapper
    print("verify_matches")
    with open(os.path.join(cache_dir, 'pairs.txt'), "w") as f:
        for image_path1, image_path2 in colmap_image_pairs:
            f.write("{} {}\n".format(image_path1, image_path2))
    pycolmap.verify_matches(colmap_db_path, os.path.join(cache_dir, 'pairs.txt'))

    reconstruction_path = os.path.join(cache_dir, "reconstruction")
    if os.path.isdir(reconstruction_path):
        shutil.rmtree(reconstruction_path)
    os.makedirs(reconstruction_path, exist_ok=True)
    glomap_run_mapper(glomap_bin, colmap_db_path, reconstruction_path, root_path)

    if current_scene_state is not None and \
            not current_scene_state.should_delete and \
            current_scene_state.outfile_name is not None:
        outfile_name = current_scene_state.outfile_name
    else:
        outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)

    output_recon = pycolmap.Reconstruction(os.path.join(reconstruction_path, '0'))
    print(output_recon.summary())

    colmap_world_to_cam = {}
    colmap_intrinsics = {}
    colmap_image_id_to_name = {}
    images = {}
    num_reg_images = output_recon.num_reg_images()
    for idx, (colmap_imgid, colmap_image) in enumerate(output_recon.images.items()):
        colmap_image_id_to_name[colmap_imgid] = colmap_image.name
        if callable(colmap_image.cam_from_world.matrix):
            colmap_world_to_cam[colmap_imgid] = colmap_image.cam_from_world.matrix()
        else:
            colmap_world_to_cam[colmap_imgid] = colmap_image.cam_from_world.matrix
        camera = output_recon.cameras[colmap_image.camera_id]
        K = np.eye(3)
        K[0, 0] = camera.focal_length_x
        K[1, 1] = camera.focal_length_y
        K[0, 2] = camera.principal_point_x
        K[1, 2] = camera.principal_point_y
        colmap_intrinsics[colmap_imgid] = K

        with PIL.Image.open(os.path.join(root_path, colmap_image.name)) as im:
            images[colmap_imgid] = np.asarray(im)

        if idx + 1 == num_reg_images:
            break  # bug with the iterable ?
    points3D = []
    num_points3D = output_recon.num_points3D()
    for idx, (pt3d_id, pts3d) in enumerate(output_recon.points3D.items()):
        points3D.append((pts3d.xyz, pts3d.color))
        if idx + 1 == num_points3D:
            break  # bug with the iterable ?
    scene = GlomapRecon(colmap_world_to_cam, colmap_intrinsics, points3D, images)
    scene_state = GlomapReconState(scene, gradio_delete_cache, cache_dir, outfile_name)
    outfile = get_3D_model_from_scene(silent, scene_state, transparent_cams, cam_size)
    return scene_state, outfile


def get_3D_model_from_scene(silent, scene_state, transparent_cams=False, cam_size=0.05):
    """
    extract 3D_model (glb file) from a reconstructed scene
    """
    if scene_state is None:
        return None
    outfile = scene_state.outfile_name
    if outfile is None:
        return None

    recon = scene_state.glomap_recon

    scene = trimesh.Scene()
    pts = np.stack([p[0] for p in recon.points3d], axis=0)
    col = np.stack([p[1] for p in recon.points3d], axis=0)
    pct = trimesh.PointCloud(pts, colors=col)
    scene.add_geometry(pct)

    # add each camera
    cams2world = []
    for i, (id, pose_w2c_3x4) in enumerate(recon.world_to_cam.items()):
        intrinsics = recon.intrinsics[id]
        focal = (intrinsics[0, 0] + intrinsics[1, 1]) / 2.0
        camera_edge_color = CAM_COLORS[i % len(CAM_COLORS)]
        pose_w2c = np.eye(4)
        pose_w2c[:3, :] = pose_w2c_3x4
        pose_c2w = np.linalg.inv(pose_w2c)
        cams2world.append(pose_c2w)
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else recon.imgs[id], focal,
                      imsize=recon.imgs[id].shape[1::-1], screen_width=cam_size)

    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)

    return outfile


def main_demo(glomap_bin, tmpdirname, model, retrieval_model, device, image_size, server_name, server_port,
              silent=False, share=False, gradio_delete_cache=False):
    if not silent:
        print('Outputting stuff in', tmpdirname)

    recon_fun = functools.partial(get_reconstructed_scene, glomap_bin, tmpdirname, gradio_delete_cache, model,
                                  retrieval_model, device, silent, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)

    available_scenegraph_type = [("complete: all possible image pairs", "complete"),
                                 ("swin: sliding window", "swin"),
                                 ("logwin: sliding window with long range", "logwin"),
                                 ("oneref: match one image with all", "oneref")]
    if retrieval_model is not None:
        available_scenegraph_type.insert(1, ("retrieval: connect views based on similarity", "retrieval"))

    def get_context(delete_cache):
        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        title = "MASt3R Demo"
        if delete_cache:
            return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
        else:
            return gradio.Blocks(css=css, title="MASt3R Demo")  # for compatibility with older versions

    with get_context(gradio_delete_cache) as demo:
        # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                    info="Only optimize one set of intrinsics for all views")
                scenegraph_type = gradio.Dropdown(available_scenegraph_type,
                                                  value='complete', label="Scenegraph",
                                                  info="Define how to make pairs",
                                                  interactive=True)
                with gradio.Column(visible=False) as win_col:
                    winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                            minimum=1, maximum=1, step=1)
                    win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
                    refid = gradio.Slider(label="Scene Graph: Id", value=0,
                                          minimum=0, maximum=0, step=1, visible=False)
            run_btn = gradio.Button("Run")

            with gradio.Row():
                # adjust the camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.01, minimum=0.001, maximum=1.0, step=0.001)
            with gradio.Row():
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()

            # events
            scenegraph_type.change(set_scenegraph_options,
                                   inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                                   outputs=[win_col, winsize, win_cyclic, refid])
            inputfiles.change(set_scenegraph_options,
                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                              outputs=[win_col, winsize, win_cyclic, refid])
            win_cyclic.change(set_scenegraph_options,
                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                              outputs=[win_col, winsize, win_cyclic, refid])
            run_btn.click(fn=recon_fun,
                          inputs=[scene, inputfiles, transparent_cams, cam_size,
                                  scenegraph_type, winsize, win_cyclic, refid, shared_intrinsics],
                          outputs=[scene, outmodel])
            cam_size.change(fn=model_from_scene_fun,
                            inputs=[scene, transparent_cams, cam_size],
                            outputs=outmodel)
            transparent_cams.change(model_from_scene_fun,
                                    inputs=[scene, transparent_cams, cam_size],
                                    outputs=outmodel)
    demo.launch(share=share, server_name=server_name, server_port=server_port)
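To run this glomap-backed demo outside the repository's own launcher, here is a minimal sketch of an entry point (hypothetical: the weight-loading convention mirrors the dust3r demo, and the `glomap` binary must be installed and reachable through `--glomap_bin`):

```python
# Hypothetical launcher for mast3r/demo_glomap.py; not the repository's own
# top-level demo_glomap.py, whose exact contents are not shown here.
import tempfile

from mast3r.model import AsymmetricMASt3R
from mast3r.demo_glomap import get_args_parser, main_demo

if __name__ == '__main__':
    args = get_args_parser().parse_args()
    weights = args.weights if args.weights is not None else "naver/" + args.model_name
    model = AsymmetricMASt3R.from_pretrained(weights).to(args.device)
    with tempfile.TemporaryDirectory(suffix='_mast3r_glomap_demo') as tmpdirname:
        main_demo(args.glomap_bin, tmpdirname, model, args.retrieval_model, args.device,
                  args.image_size, args.server_name, args.server_port, silent=args.silent,
                  share=args.share, gradio_delete_cache=args.gradio_delete_cache)
```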
mast3r/image_pairs.py
ADDED
@@ -0,0 +1,115 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities needed to load image pairs
# --------------------------------------------------------
import numpy as np
import torch

from mast3r.retrieval.graph import make_pairs_fps


def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True, sim_mat=None):
    pairs = []
    if scene_graph == 'complete':  # complete graph
        for i in range(len(imgs)):
            for j in range(i):
                pairs.append((imgs[i], imgs[j]))
    elif scene_graph.startswith('swin'):
        iscyclic = not scene_graph.endswith('noncyclic')
        try:
            winsize = int(scene_graph.split('-')[1])
        except Exception:
            winsize = 3
        pairsid = set()
        for i in range(len(imgs)):
            for j in range(1, winsize + 1):
                idx = (i + j)
                if iscyclic:
                    idx = idx % len(imgs)  # explicit loop closure
                if idx >= len(imgs):
                    continue
                pairsid.add((i, idx) if i < idx else (idx, i))
        for i, j in pairsid:
            pairs.append((imgs[i], imgs[j]))
    elif scene_graph.startswith('logwin'):
        iscyclic = not scene_graph.endswith('noncyclic')
        try:
            winsize = int(scene_graph.split('-')[1])
        except Exception:
            winsize = 3
        offsets = [2**i for i in range(winsize)]
        pairsid = set()
        for i in range(len(imgs)):
            ixs_l = [i - off for off in offsets]
            ixs_r = [i + off for off in offsets]
            for j in ixs_l + ixs_r:
                if iscyclic:
                    j = j % len(imgs)  # explicit loop closure
                if j < 0 or j >= len(imgs) or j == i:
                    continue
                pairsid.add((i, j) if i < j else (j, i))
        for i, j in pairsid:
            pairs.append((imgs[i], imgs[j]))
    elif scene_graph.startswith('oneref'):
        refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
        for j in range(len(imgs)):
            if j != refid:
                pairs.append((imgs[refid], imgs[j]))
    elif scene_graph.startswith('retrieval'):
        mode, Na, k = scene_graph.split('-')
        assert sim_mat is not None, "sim_mat is required for retrieval mode"

        fps_pairs, anchor_idxs = make_pairs_fps(sim_mat, Na=int(Na), tokK=int(k), dist_thresh=None)

        for i, j in fps_pairs:
            pairs.append((imgs[i], imgs[j]))
    else:
        raise ValueError(f'unrecognized value for {scene_graph=}')

    if symmetrize:
        pairs += [(img2, img1) for img1, img2 in pairs]

    # now, remove edges
    if isinstance(prefilter, str) and prefilter.startswith('seq'):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]))

    if isinstance(prefilter, str) and prefilter.startswith('cyc'):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)

    return pairs


def sel(x, kept):
    if isinstance(x, dict):
        return {k: sel(v, kept) for k, v in x.items()}
    if isinstance(x, (torch.Tensor, np.ndarray)):
        return x[kept]
    if isinstance(x, (tuple, list)):
        return type(x)([x[k] for k in kept])


def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
    # number of images
    n = max(max(e) for e in edges) + 1

    kept = []
    for e, (i, j) in enumerate(edges):
        dis = abs(i - j)
        if cyclic:
            dis = min(dis, abs(i + n - j), abs(i - n - j))
        if dis <= seq_dis_thr:
            kept.append(e)
    return kept


def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
    edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    return [pairs[i] for i in kept]


def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
    edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
    return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
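To illustrate the scene-graph strings that `make_pairs` parses, a small sketch; the dicts below only carry the `idx` field that the pairing logic uses, whereas real entries from `load_images` also hold the image tensor and its true shape:

```python
from mast3r.image_pairs import make_pairs

imgs = [{'idx': i} for i in range(6)]

complete = make_pairs(imgs, scene_graph='complete')      # all 15 pairs, symmetrized to 30
swin = make_pairs(imgs, scene_graph='swin-2-noncyclic')  # each image paired with the next 2
logwin = make_pairs(imgs, scene_graph='logwin-3')        # offsets 1, 2, 4 with loop closure
oneref = make_pairs(imgs, scene_graph='oneref-0')        # image 0 matched with all others

print(len(complete), len(swin), len(logwin), len(oneref))
```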
mast3r/losses.py
CHANGED
@@ -273,7 +273,7 @@ class InfoNCE(MatchingCriterion):
 
 
 class APLoss (MatchingCriterion):
-    """ AP loss
+    """ AP loss
     """
 
     def __init__(self, nq='torch', min=0, max=1, euc=False, **kw):
mast3r/retrieval/graph.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Building the graph based on retrieval results.
# --------------------------------------------------------
import numpy as np


def farthest_point_sampling(dist, N=None, dist_thresh=None):
    """Farthest point sampling.

    Args:
        dist: NxN distance matrix.
        N: Number of points to sample.
        dist_thresh: Distance threshold. Point sampling terminates once the
            maximum distance is below this threshold.

    Returns:
        indices: Indices of the sampled points.
        distances: Distance of each sampled point to the previously sampled set.
    """

    assert N is not None or dist_thresh is not None, "Either N or dist_thresh must be provided."

    if N is None:
        N = dist.shape[0]

    indices = []
    distances = [0]
    indices.append(np.random.choice(dist.shape[0]))
    for i in range(1, N):
        d = dist[indices].min(axis=0)
        bst = d.argmax()
        bst_dist = d[bst]
        if dist_thresh is not None and bst_dist < dist_thresh:
            break
        indices.append(bst)
        distances.append(bst_dist)
    return np.array(indices), np.array(distances)


def make_pairs_fps(sim_mat, Na=20, tokK=1, dist_thresh=None):
    dist_mat = 1 - sim_mat

    pairs = set()
    keyimgs_idx = np.array([])
    if Na != 0:
        keyimgs_idx, _ = farthest_point_sampling(dist_mat, N=Na, dist_thresh=dist_thresh)

    # 1. Complete graph between key images
    for i in range(len(keyimgs_idx)):
        for j in range(i + 1, len(keyimgs_idx)):
            idx_i, idx_j = keyimgs_idx[i], keyimgs_idx[j]
            pairs.add((idx_i, idx_j))

    # 2. Connect non-key images to the nearest key image
    keyimg_dist_mat = dist_mat[:, keyimgs_idx]
    for i in range(keyimg_dist_mat.shape[0]):
        if i in keyimgs_idx:
            continue
        j = keyimg_dist_mat[i].argmin()  # smallest distance = most similar key image
        i1, i2 = min(i, keyimgs_idx[j]), max(i, keyimgs_idx[j])
        if i1 != i2 and (i1, i2) not in pairs:
            pairs.add((i1, i2))

    # 3. Add some local connections (k-NN) for each view
    if tokK > 0:
        for i in range(dist_mat.shape[0]):
            idx = dist_mat[i].argsort()[:tokK]
            for j in idx:
                i1, i2 = min(i, j), max(i, j)
                if i1 != i2 and (i1, i2) not in pairs:
                    pairs.add((i1, i2))

    pairs = list(pairs)

    return pairs, keyimgs_idx
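A quick sanity check of the FPS-based graph construction, using a toy similarity matrix (note that the first key image is drawn at random inside `farthest_point_sampling`, so the output varies between runs):

```python
import numpy as np
from mast3r.retrieval.graph import make_pairs_fps

# 5 views whose pairwise similarity decays with index distance (1 on the diagonal)
n = 5
idx = np.arange(n)
sim_mat = np.exp(-np.abs(idx[:, None] - idx[None, :]) / 2.0)

pairs, key_idxs = make_pairs_fps(sim_mat, Na=2, tokK=1, dist_thresh=None)
print('key images:', key_idxs)
print('pairs:', sorted(pairs))  # undirected (i, j) edges with i < j
```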
mast3r/retrieval/model.py
ADDED
@@ -0,0 +1,271 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Whitener and RetrievalModel
# --------------------------------------------------------
import numpy as np
from tqdm import tqdm
import time

import torch
import torch.nn as nn

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.image import load_images

default_device = torch.device('cuda' if torch.cuda.is_available() and torch.cuda.device_count() > 0 else 'cpu')


# from https://github.com/gtolias/how/blob/4d73c88e0ffb55506e2ce6249e2a015ef6ccf79f/how/utils/whitening.py#L20
def pcawhitenlearn_shrinkage(X, s=1.0):
    """Learn PCA whitening with shrinkage from given descriptors"""
    N = X.shape[0]

    # Learning PCA w/o annotations
    m = X.mean(axis=0, keepdims=True)
    Xc = X - m
    Xcov = np.dot(Xc.T, Xc)
    Xcov = (Xcov + Xcov.T) / (2 * N)
    eigval, eigvec = np.linalg.eig(Xcov)
    order = eigval.argsort()[::-1]
    eigval = eigval[order]
    eigvec = eigvec[:, order]

    eigval = np.clip(eigval, a_min=1e-14, a_max=None)
    P = np.dot(np.linalg.inv(np.diag(np.power(eigval, 0.5 * s))), eigvec.T)

    return m, P.T


class Dust3rInputFromImageList(torch.utils.data.Dataset):
    def __init__(self, image_list, imsize=512):
        super().__init__()
        self.image_list = image_list
        assert imsize == 512
        self.imsize = imsize

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        return load_images([self.image_list[index]], size=self.imsize, verbose=False)[0]


class Whitener(nn.Module):
    def __init__(self, dim, l2norm=None):
        super().__init__()
        self.m = torch.nn.Parameter(torch.zeros((1, dim)).double())
        self.p = torch.nn.Parameter(torch.eye(dim, dim).double())
        self.l2norm = l2norm  # if not None, apply l2 norm along a given dimension

    def forward(self, x):
        with torch.autocast(self.m.device.type, enabled=False):
            shape = x.size()
            input_type = x.dtype
            x_reshaped = x.view(-1, shape[-1]).to(dtype=self.m.dtype)
            # Center the input data
            x_centered = x_reshaped - self.m
            # Apply PCA transformation
            pca_output = torch.matmul(x_centered, self.p)
            # reshape back
            pca_output_shape = shape  # list(shape[:-1]) + [shape[-1]]
            pca_output = pca_output.view(pca_output_shape)
            if self.l2norm is not None:
                return torch.nn.functional.normalize(pca_output, dim=self.l2norm).to(dtype=input_type)
            return pca_output.to(dtype=input_type)


def weighted_spoc(feat, attn):
    """
    feat: BxNxC
    attn: BxN
    output: BxC L2-normalized weighted-sum pooling of the features
    """
    return torch.nn.functional.normalize((feat * attn[:, :, None]).sum(dim=1), dim=1)


def how_select_local(feat, attn, nfeat):
    """
    feat: BxNxC
    attn: BxN
    nfeat: number of features to keep (a negative value is a fraction of N)
    """
    # get nfeat
    if nfeat < 0:
        assert nfeat >= -1.0
        nfeat = int(-nfeat * feat.size(1))
    else:
        nfeat = int(nfeat)
    # sort by attention and keep the top-k features
    topk_attn, topk_indices = torch.topk(attn, min(nfeat, attn.size(1)), dim=1)
    topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, feat.size(2))
    topk_features = torch.gather(feat, 1, topk_indices_expanded)
    return topk_features, topk_attn, topk_indices


class RetrievalModel(nn.Module):
    def __init__(self, backbone, freeze_backbone=1, prewhiten=None, hdims=[1024], residual=False, postwhiten=None,
                 featweights='l2norm', nfeat=300, pretrained_retrieval=None):
        super().__init__()
        self.backbone = backbone
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
        self.backbone_dim = backbone.enc_embed_dim
        self.prewhiten = nn.Identity() if prewhiten is None else Whitener(self.backbone_dim)
        self.prewhiten_freq = prewhiten
        if prewhiten is not None and prewhiten != -1:
            for p in self.prewhiten.parameters():
                p.requires_grad = False
        self.residual = residual
        self.projector = self.build_projector(hdims, residual)
        self.dim = hdims[-1] if len(hdims) > 0 else self.backbone_dim
        self.postwhiten_freq = postwhiten
        self.postwhiten = nn.Identity() if postwhiten is None else Whitener(self.dim)
        if postwhiten is not None and postwhiten != -1:
            assert len(hdims) > 0
            for p in self.postwhiten.parameters():
                p.requires_grad = False
        self.featweights = featweights
        if featweights == 'l2norm':
            self.attention = lambda x: x.norm(dim=-1)
        else:
            raise NotImplementedError(featweights)
        self.nfeat = nfeat
        self.pretrained_retrieval = pretrained_retrieval
        if self.pretrained_retrieval is not None:
            ckpt = torch.load(pretrained_retrieval, 'cpu')
            msg = self.load_state_dict(ckpt['model'], strict=False)
            assert len(msg.unexpected_keys) == 0 and all(k.startswith('backbone')
                                                         or k.startswith('postwhiten') for k in msg.missing_keys)

    def build_projector(self, hdims, residual):
        if self.residual:
            assert hdims[-1] == self.backbone_dim
        d = self.backbone_dim
        if len(hdims) == 0:
            return nn.Identity()
        layers = []
        for i in range(len(hdims) - 1):
            layers.append(nn.Linear(d, hdims[i]))
            d = hdims[i]
            layers.append(nn.LayerNorm(d))
            layers.append(nn.GELU())
        layers.append(nn.Linear(d, hdims[-1]))
        return nn.Sequential(*layers)

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        ss = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        if self.freeze_backbone:
            ss = {k: v for k, v in ss.items() if not k.startswith('backbone')}
        return ss

    def reinitialize_whitening(self, epoch, train_dataset, nimgs=5000, log_writer=None, max_nfeat_per_image=None, seed=0, device=default_device):
        do_prewhiten = self.prewhiten_freq is not None and self.pretrained_retrieval is None and \
            (epoch == 0 or (self.prewhiten_freq > 0 and epoch % self.prewhiten_freq == 0))
        do_postwhiten = self.postwhiten_freq is not None and ((epoch == 0 and self.postwhiten_freq in [0, -1])
                                                              or (self.postwhiten_freq > 0 and
                                                                  epoch % self.postwhiten_freq == 0 and epoch > 0))
        if do_prewhiten or do_postwhiten:
            self.eval()
            imdataset = train_dataset.imlist_dataset_n_images(nimgs, seed)
            loader = torch.utils.data.DataLoader(imdataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
            if do_prewhiten:
                print('Re-initialization of pre-whitening')
                t = time.time()
                with torch.no_grad():
                    features = []
                    for d in tqdm(loader):
                        feat = self.backbone._encode_image(d['img'][0, ...].to(device),
                                                           true_shape=d['true_shape'][0, ...])[0]
                        feat = feat.flatten(0, 1)
                        if max_nfeat_per_image is not None and max_nfeat_per_image < feat.size(0):
                            l2norms = torch.linalg.vector_norm(feat, dim=1)
                            feat = feat[torch.argsort(-l2norms)[:max_nfeat_per_image], :]
                        features.append(feat.cpu())
                    features = torch.cat(features, dim=0)
                    features = features.numpy()
                    m, P = pcawhitenlearn_shrinkage(features)
                    self.prewhiten.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})
                prewhiten_time = time.time() - t
                print(f'Done in {prewhiten_time:.1f} seconds')
                if log_writer is not None:
                    log_writer.add_scalar('time/prewhiten', prewhiten_time, epoch)
            if do_postwhiten:
                print('Re-initialization of post-whitening')
                t = time.time()
                with torch.no_grad():
                    features = []
                    for d in tqdm(loader):
                        backbone_feat = self.backbone._encode_image(d['img'][0, ...].to(device),
                                                                    true_shape=d['true_shape'][0, ...])[0]
                        backbone_feat_prewhitened = self.prewhiten(backbone_feat)
                        proj_feat = self.projector(backbone_feat_prewhitened) + \
                            (0.0 if not self.residual else backbone_feat_prewhitened)
                        proj_feat = proj_feat.flatten(0, 1)
                        if max_nfeat_per_image is not None and max_nfeat_per_image < proj_feat.size(0):
                            l2norms = torch.linalg.vector_norm(proj_feat, dim=1)
                            proj_feat = proj_feat[torch.argsort(-l2norms)[:max_nfeat_per_image], :]
                        features.append(proj_feat.cpu())
                    features = torch.cat(features, dim=0)
                    features = features.numpy()
                    m, P = pcawhitenlearn_shrinkage(features)
                    self.postwhiten.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})
                postwhiten_time = time.time() - t
                print(f'Done in {postwhiten_time:.1f} seconds')
                if log_writer is not None:
                    log_writer.add_scalar('time/postwhiten', postwhiten_time, epoch)

    def extract_features_and_attention(self, x):
        backbone_feat = self.backbone._encode_image(x['img'], true_shape=x['true_shape'])[0]
        backbone_feat_prewhitened = self.prewhiten(backbone_feat)
        proj_feat = self.projector(backbone_feat_prewhitened) + \
            (0.0 if not self.residual else backbone_feat_prewhitened)
        attention = self.attention(proj_feat)
        proj_feat_whitened = self.postwhiten(proj_feat)
        return proj_feat_whitened, attention

    def forward_local(self, x):
        feat, attn = self.extract_features_and_attention(x)
        return how_select_local(feat, attn, self.nfeat)

    def forward_global(self, x):
        feat, attn = self.extract_features_and_attention(x)
        return weighted_spoc(feat, attn)

    def forward(self, x):
        return self.forward_global(x)


def identity(x):  # to avoid Can't pickle local object 'extract_local_features.<locals>.<lambda>'
    return x


@torch.no_grad()
def extract_local_features(model, images, imsize, seed=0, tocpu=False, max_nfeat_per_image=None,
                           max_nfeat_per_image2=None, device=default_device):
    model.eval()
    imdataset = Dust3rInputFromImageList(images, imsize=imsize) if isinstance(images, list) else images
    loader = torch.utils.data.DataLoader(imdataset, batch_size=1, shuffle=False,
                                         num_workers=8, pin_memory=True, collate_fn=identity)
    with torch.no_grad():
        features = []
        imids = []
        for i, d in enumerate(tqdm(loader)):
            dd = d[0]
            dd['img'] = dd['img'].to(device, non_blocking=True)
            feat, _, _ = model.forward_local(dd)
            feat = feat.flatten(0, 1)
            if max_nfeat_per_image is not None and feat.size(0) > max_nfeat_per_image:
                feat = feat[torch.randperm(feat.size(0))[:max_nfeat_per_image], :]
            if max_nfeat_per_image2 is not None and feat.size(0) > max_nfeat_per_image2:
                feat = feat[:max_nfeat_per_image2, :]
            features.append(feat)
            if tocpu:
                features[-1] = features[-1].cpu()
            imids.append(i * torch.ones_like(features[-1][:, 0]).to(dtype=torch.int64))
    features = torch.cat(features, dim=0)
    imids = torch.cat(imids, dim=0)
    return features, imids
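As a small illustration of the whitening machinery above, a sketch that learns a PCA whitening on random descriptors and loads it into a `Whitener` (the dimensions are arbitrary):

```python
import numpy as np
import torch
from mast3r.retrieval.model import Whitener, pcawhitenlearn_shrinkage

X = np.random.randn(1000, 64)          # 1000 synthetic 64-d descriptors
m, P = pcawhitenlearn_shrinkage(X, s=1.0)

whitener = Whitener(64, l2norm=-1)     # L2-normalize along the last dimension
whitener.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})

out = whitener(torch.from_numpy(X[:4]).float())
print(out.shape, out.norm(dim=-1))     # rows come out unit-norm after whitening
```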
mast3r/retrieval/processor.py
ADDED
@@ -0,0 +1,129 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Main Retriever class
# --------------------------------------------------------
import os
import argparse
import numpy as np
import torch

from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.model import RetrievalModel, extract_local_features

try:
    import faiss
    faiss.StandardGpuResources()  # when loading the checkpoint, it will try to instantiate FaissGpuL2Index
except AttributeError:
    import asmk.index

    class FaissCpuL2Index(asmk.index.FaissL2Index):
        def __init__(self, gpu_id):
            super().__init__()
            self.gpu_id = gpu_id

        def _faiss_index_flat(self, dim):
            """Return initialized faiss.IndexFlatL2"""
            return faiss.IndexFlatL2(dim)

    asmk.index.FaissGpuL2Index = FaissCpuL2Index

from asmk import asmk_method  # noqa


def get_args_parser():
    parser = argparse.ArgumentParser('Retrieval scores from a set of images', add_help=False, allow_abbrev=False)
    parser.add_argument('--model', type=str, required=True,
                        help="shortname of a retrieval model or path to the corresponding .pth")
    parser.add_argument('--input', type=str, required=True,
                        help="directory containing images or a file containing a list of image paths")
    parser.add_argument('--outfile', type=str, required=True, help="numpy file where to store the matrix score")
    return parser


def get_impaths(imlistfile):
    with open(imlistfile, 'r') as fid:
        impaths = [f for f in fid.read().splitlines() if not f.startswith('#')
                   and len(f) > 0]  # ignore comments and empty lines
    return impaths


def get_impaths_from_imdir(imdir, extensions=['png', 'jpg', 'PNG', 'JPG']):
    assert os.path.isdir(imdir)
    impaths = [os.path.join(imdir, f) for f in sorted(os.listdir(imdir)) if any(f.endswith(ext) for ext in extensions)]
    return impaths


def get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile):
    if os.path.isfile(input_imdir_or_imlistfile):
        return get_impaths(input_imdir_or_imlistfile)
    else:
        return get_impaths_from_imdir(input_imdir_or_imlistfile)


class Retriever(object):
    def __init__(self, modelname, backbone=None, device='cuda'):
        # load the model
        assert os.path.isfile(modelname), modelname
        print(f'Loading retrieval model from {modelname}')
        ckpt = torch.load(modelname, 'cpu')  # TODO from pretrained to download it automatically
        ckpt_args = ckpt['args']
        if backbone is None:
            backbone = AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained)
        self.model = RetrievalModel(
            backbone, freeze_backbone=ckpt_args.freeze_backbone, prewhiten=ckpt_args.prewhiten,
            hdims=list(map(int, ckpt_args.hdims.split('_'))) if len(ckpt_args.hdims) > 0 else [],
            residual=getattr(ckpt_args, 'residual', False), postwhiten=ckpt_args.postwhiten,
            featweights=ckpt_args.featweights, nfeat=ckpt_args.nfeat
        ).to(device)
        self.device = device
        msg = self.model.load_state_dict(ckpt['model'], strict=False)
        assert all(k.startswith('backbone') for k in msg.missing_keys)
        assert len(msg.unexpected_keys) == 0
        self.imsize = ckpt_args.imsize

        # load the asmk codebook
        dname, bname = os.path.split(modelname)  # TODO they should both be in the same file ?
        bname_splits = bname.split('_')
        cache_codebook_fname = os.path.join(dname, '_'.join(bname_splits[:-1]) + '_codebook.pkl')
        assert os.path.isfile(cache_codebook_fname), cache_codebook_fname
        asmk_params = {'index': {'gpu_id': 0}, 'train_codebook': {'codebook': {'size': '64k'}},
                       'build_ivf': {'kernel': {'binary': True}, 'ivf': {'use_idf': False},
                                     'quantize': {'multiple_assignment': 1}, 'aggregate': {}},
                       'query_ivf': {'quantize': {'multiple_assignment': 5}, 'aggregate': {},
                                     'search': {'topk': None},
                                     'similarity': {'similarity_threshold': 0.0, 'alpha': 3.0}}}
        asmk_params['train_codebook']['codebook']['size'] = ckpt_args.nclusters
        self.asmk = asmk_method.ASMKMethod.initialize_untrained(asmk_params)
        self.asmk = self.asmk.train_codebook(None, cache_path=cache_codebook_fname)

    def __call__(self, input_imdir_or_imlistfile, outfile=None):
        # get impaths
        if isinstance(input_imdir_or_imlistfile, str):
            impaths = get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile)
        else:
            impaths = input_imdir_or_imlistfile  # we're assuming a list has been passed
        print(f'Found {len(impaths)} images')

        # build the database
        feat, ids = extract_local_features(self.model, impaths, self.imsize, tocpu=True, device=self.device)
        feat = feat.cpu().numpy()
        ids = ids.cpu().numpy()
        asmk_dataset = self.asmk.build_ivf(feat, ids)

        # we actually retrieve the same set of images
        metadata, query_ids, ranks, ranked_scores = asmk_dataset.query_ivf(feat, ids)

        # scores come back ordered by rank, so scatter them back to image order
        scores = np.empty_like(ranked_scores)
        scores[np.arange(ranked_scores.shape[0])[:, None], ranks] = ranked_scores

        # save
        if outfile is not None:
            if os.path.dirname(outfile):
                os.makedirs(os.path.dirname(outfile), exist_ok=True)
            np.save(outfile, scores)
            print(f'Scores matrix saved in {outfile}')
        return scores
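Finally, a minimal sketch of how `get_args_parser` and `Retriever` could be combined into a standalone script (a hypothetical driver, not the repository's own entry point; the checkpoint's companion `*_codebook.pkl` must sit next to the `.pth` as `__init__` expects):

```python
import argparse
import torch

from mast3r.retrieval.processor import Retriever, get_args_parser

if __name__ == '__main__':
    parser = argparse.ArgumentParser(parents=[get_args_parser()])
    args = parser.parse_args()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    retriever = Retriever(args.model, device=device)
    scores = retriever(args.input, outfile=args.outfile)  # NxN similarity matrix
    print(scores.shape)
```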