yocabon committed
Commit 35e2575 · 1 Parent(s): b1b5578

add initial version of mast3r sfm and glomap/colmap wrapper

NOTICE CHANGED
@@ -101,3 +101,8 @@ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 POSSIBILITY OF SUCH DAMAGE.
 
+====
+gtolias/how
+https://github.com/gtolias/how
+
+MIT License https://github.com/gtolias/how/blob/master/LICENSE
README.md CHANGED
@@ -78,7 +78,19 @@ pip install -r dust3r/requirements.txt
 pip install -r dust3r/requirements_optional.txt
 ```
 
-3. Optional, compile the cuda kernels for RoPE (as in CroCo v2).
+3. Compile and install ASMK.
+```bash
+pip install cython
+
+git clone https://github.com/jenicek/asmk
+cd asmk/cython/
+cythonize *.pyx
+cd ..
+pip install .
+cd ..
+```
+
+4. Optional, compile the cuda kernels for RoPE (as in CroCo v2).
 ```bash
 # DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime.
 cd dust3r/croco/models/curope/
@@ -86,9 +98,10 @@ python setup.py build_ext --inplace
 cd ../../../../
 ```
 
-
 ### Checkpoints
 
+TODO upload retrieval_model somewhere
+
 You can obtain the checkpoints by two ways:
 
 1) You can use our huggingface_hub integration: the models will be downloaded automatically.
@@ -123,6 +136,7 @@ demo.py is the updated demo for MASt3R. It uses our new sparse global alignment
 python3 demo.py --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric
 
 # Use --weights to load a checkpoint from a local file, eg --weights checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth
+# Use --retrieval_model and point to the retrieval checkpoint to enable retrieval as a pairing strategy; asmk must be installed
 # Use --local_network to make it accessible on the local network, or --server_name to specify the url manually
 # Use --server_port to change the port, by default it will search for an available port starting at 7860
 # Use --device to use a different device, by default it's "cuda"
@@ -133,6 +147,8 @@ see https://github.com/naver/dust3r?tab=readme-ov-file#interactive-demo for details
 
 ### Interactive demo with docker
 
+TODO update with asmk/retrieval model
+
 To run MASt3R using Docker, including with NVIDIA CUDA support, follow these instructions:
 
 1. **Install Docker**: If not already installed, download and install `docker` and `docker compose` from the [Docker website](https://www.docker.com/get-started).
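Once asmk and a retrieval checkpoint are available, the demo can be launched with retrieval-based pairing. A minimal launch sketch; the checkpoint path `checkpoints/retrieval_model.pth` is a placeholder, since the retrieval model upload is still marked TODO above:

```python
# Hedged example: launch the demo with retrieval pairing enabled.
# Equivalent shell command:
#   python3 demo.py --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
#       --retrieval_model checkpoints/retrieval_model.pth
import subprocess

subprocess.run(['python3', 'demo.py',
                '--model_name', 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric',
                '--retrieval_model', 'checkpoints/retrieval_model.pth'],  # placeholder path; requires asmk
               check=True)
```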
demo.py CHANGED
@@ -47,5 +47,5 @@ if __name__ == '__main__':
     with get_context(args.tmp_dir) as tmpdirname:
         cache_path = os.path.join(tmpdirname, chkpt_tag)
         os.makedirs(cache_path, exist_ok=True)
-        main_demo(cache_path, model, args.device, args.image_size, server_name, args.server_port, silent=args.silent,
-                  share=args.share, gradio_delete_cache=args.gradio_delete_cache)
+        main_demo(cache_path, model, args.retrieval_model, args.device, args.image_size, server_name, args.server_port,
+                  silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache)
demo_glomap.py ADDED
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# gradio demo executable
+# --------------------------------------------------------
+import pycolmap
+import os
+import torch
+import tempfile
+from contextlib import nullcontext
+
+from mast3r.demo_glomap import get_args_parser, main_demo
+
+from mast3r.model import AsymmetricMASt3R
+from mast3r.utils.misc import hash_md5
+
+import mast3r.utils.path_to_dust3r  # noqa
+from dust3r.demo import set_print_with_timestamp
+
+import matplotlib.pyplot as pl
+pl.ion()
+
+torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
+
+if __name__ == '__main__':
+    parser = get_args_parser()
+    args = parser.parse_args()
+    set_print_with_timestamp()
+
+    if args.server_name is not None:
+        server_name = args.server_name
+    else:
+        server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
+
+    if args.weights is not None:
+        weights_path = args.weights
+    else:
+        weights_path = "naver/" + args.model_name
+
+    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
+    chkpt_tag = hash_md5(weights_path)
+
+    def get_context(tmp_dir):
+        return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
+            else nullcontext(tmp_dir)
+    with get_context(args.tmp_dir) as tmpdirname:
+        cache_path = os.path.join(tmpdirname, chkpt_tag)
+        os.makedirs(cache_path, exist_ok=True)
+        main_demo(args.glomap_bin, cache_path, model, args.retrieval_model, args.device, args.image_size, server_name,
+                  args.server_port, silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache)
dust3r CHANGED
@@ -1 +1 @@
-Subproject commit 9869e71f9165aa53c53ec0979cea1122a569ade4
+Subproject commit c9e9336a6ba7c1f1873f9295852cea6dffaf770d
kapture_mast3r_mapping.py ADDED
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# mast3r exec for running standard sfm
+# --------------------------------------------------------
+import pycolmap
+import os
+import os.path as path
+import argparse
+
+from mast3r.model import AsymmetricMASt3R
+from mast3r.colmap.mapping import (kapture_import_image_folder_or_list, run_mast3r_matching, pycolmap_run_triangulator,
+                                   pycolmap_run_mapper, glomap_run_mapper)
+from kapture.io.csv import kapture_from_dir
+
+from kapture.converter.colmap.database_extra import kapture_to_colmap, generate_priors_for_reconstruction
+from kapture_localization.utils.pairsfile import get_pairs_from_file
+from kapture.io.records import get_image_fullpath
+from kapture.converter.colmap.database import COLMAPDatabase
+
+
+def get_argparser():
+    parser = argparse.ArgumentParser(description='point triangulator with mast3r from kapture data')
+    parser_weights = parser.add_mutually_exclusive_group(required=True)
+    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
+    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
+                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])
+
+    parser_input = parser.add_mutually_exclusive_group(required=True)
+    parser_input.add_argument('-i', '--input', default=None, help='kdata')
+    parser_input.add_argument('--dir', default=None, help='image dir (individual intrinsics)')
+    parser_input.add_argument('--dir_same_camera', default=None, help='image dir (shared intrinsics)')
+
+    parser.add_argument('-o', '--output', required=True, help='output path to reconstruction')
+    parser.add_argument('--pairsfile_path', required=True, help='pairsfile')
+
+    parser.add_argument('--glomap_bin', default='glomap', type=str, help='glomap bin')
+
+    parser_mapper = parser.add_mutually_exclusive_group()
+    parser_mapper.add_argument('--ignore_pose', action='store_true', default=False)
+    parser_mapper.add_argument('--use_glomap_mapper', action='store_true', default=False)
+
+    parser_matching = parser.add_mutually_exclusive_group()
+    parser_matching.add_argument('--dense_matching', action='store_true', default=False)
+    parser_matching.add_argument('--pixel_tol', default=0, type=int)
+    parser.add_argument('--device', default='cuda')
+
+    parser.add_argument('--conf_thr', default=1.001, type=float)
+    parser.add_argument('--skip_geometric_verification', action='store_true', default=False)
+    parser.add_argument('--min_len_track', default=5, type=int)
+
+    return parser
+
+
+if __name__ == '__main__':
+    parser = get_argparser()
+    args = parser.parse_args()
+    if args.weights is not None:
+        weights_path = args.weights
+    else:
+        weights_path = "naver/" + args.model_name
+    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
+    maxdim = max(model.patch_embed.img_size)
+    patch_size = model.patch_embed.patch_size
+
+    if args.input is not None:
+        kdata = kapture_from_dir(args.input)
+        records_data_path = get_image_fullpath(args.input)
+    else:
+        if args.dir_same_camera is not None:
+            use_single_camera = True
+            records_data_path = args.dir_same_camera
+        elif args.dir is not None:
+            use_single_camera = False
+            records_data_path = args.dir
+        else:
+            raise ValueError('all inputs choices are None')
+        kdata = kapture_import_image_folder_or_list(records_data_path, use_single_camera)
+    has_pose = kdata.trajectories is not None
+    image_pairs = get_pairs_from_file(args.pairsfile_path, kdata.records_camera, kdata.records_camera)
+
+    colmap_db_path = path.join(args.output, 'colmap.db')
+    reconstruction_path = path.join(args.output, "reconstruction")
+    priors_txt_path = path.join(args.output, "priors_for_reconstruction")
+    for path_i in [reconstruction_path, priors_txt_path]:
+        os.makedirs(path_i, exist_ok=True)
+    assert not os.path.isfile(colmap_db_path)
+
+    colmap_db = COLMAPDatabase.connect(colmap_db_path)
+    try:
+        kapture_to_colmap(kdata, args.input, tar_handler=None, database=colmap_db,
+                          keypoints_type=None, descriptors_type=None, export_two_view_geometry=False)
+        if has_pose:
+            generate_priors_for_reconstruction(kdata, colmap_db, priors_txt_path)
+
+        colmap_image_pairs = run_mast3r_matching(model, maxdim, patch_size, args.device,
+                                                 kdata, records_data_path, image_pairs, colmap_db,
+                                                 args.dense_matching, args.pixel_tol, args.conf_thr,
+                                                 args.skip_geometric_verification, args.min_len_track)
+        colmap_db.close()
+    except Exception as e:
+        print(f'Error {e}')
+        colmap_db.close()
+        exit(1)
+
+    if len(colmap_image_pairs) == 0:
+        raise Exception("no matches were kept")
+
+    # colmap db is now full, run colmap
+    colmap_world_to_cam = {}
+    if not args.skip_geometric_verification:
+        print("verify_matches")
+        f = open(args.output + '/pairs.txt', "w")
+        for image_path1, image_path2 in colmap_image_pairs:
+            f.write("{} {}\n".format(image_path1, image_path2))
+        f.close()
+        pycolmap.verify_matches(colmap_db_path, args.output + '/pairs.txt')
+
+    print("running mapping")
+    if has_pose and not args.ignore_pose and not args.use_glomap_mapper:
+        pycolmap_run_triangulator(colmap_db_path, priors_txt_path, reconstruction_path, records_data_path)
+    elif not args.use_glomap_mapper:
+        pycolmap_run_mapper(colmap_db_path, reconstruction_path, records_data_path)
+    else:
+        glomap_run_mapper(args.glomap_bin, colmap_db_path, reconstruction_path, records_data_path)
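A plausible end-to-end invocation of this script, assuming a pairs file produced by make_pairs.py below; all paths are placeholders:

```python
# Hedged sketch: map an image folder with shared intrinsics, using glomap
# instead of the pycolmap mappers. Equivalent shell command:
#   python3 kapture_mast3r_mapping.py --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
#       --dir_same_camera images/ --pairsfile_path pairs.txt --output recon/ --use_glomap_mapper
import subprocess

subprocess.run(['python3', 'kapture_mast3r_mapping.py',
                '--model_name', 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric',
                '--dir_same_camera', 'images/',   # placeholder image folder (shared intrinsics)
                '--pairsfile_path', 'pairs.txt',  # placeholder pairs file, e.g. from make_pairs.py
                '--output', 'recon/',
                '--use_glomap_mapper'],
               check=True)
```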
make_pairs.py ADDED
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# make pairs using mast3r scene_graph, including retrieval
+# --------------------------------------------------------
+import argparse
+import torch
+import os
+import os.path as path
+import PIL
+from PIL import Image
+import pathlib
+from kapture.io.csv import table_to_file
+
+from mast3r.model import AsymmetricMASt3R
+from mast3r.retrieval.processor import Retriever
+from mast3r.image_pairs import make_pairs
+
+
+def get_argparser():
+    parser = argparse.ArgumentParser(description='make pairs using mast3r scene_graph, including retrieval')
+    parser.add_argument('--dir', required=True, help='image dir')
+    parser.add_argument('--scene_graph', default='retrieval-20-1-10-1')
+    parser.add_argument('--output', required=True, help='txt file')
+
+    parser_weights = parser.add_mutually_exclusive_group(required=False)
+    parser_weights.add_argument("--weights", type=str, help="path to the model weights", default=None)
+    parser_weights.add_argument("--model_name", type=str, help="name of the model weights",
+                                choices=["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"])
+    parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")
+    parser.add_argument('--device', default='cuda')  # referenced below when loading the backbone
+
+    return parser
+
+
+def get_image_list(images_path):
+    file_list = [path.relpath(path.join(dirpath, filename), images_path)
+                 for dirpath, dirs, filenames in os.walk(images_path)
+                 for filename in filenames]
+    file_list = sorted(file_list)
+    image_list = []
+    for filename in file_list:
+        # test if file is a valid image
+        try:
+            # lazy load
+            with Image.open(path.join(images_path, filename)) as im:
+                width, height = im.size
+                image_list.append(filename)
+        except (OSError, PIL.UnidentifiedImageError):
+            # It is not a valid image: skip it
+            print(f'Skipping invalid image file {filename}')
+            continue
+    return image_list
+
+
+def main(dir, scene_graph, output, backbone=None, retrieval_model=None):
+    imgs = get_image_list(dir)
+
+    sim_matrix = None
+    if 'retrieval' in scene_graph:
+        retriever = Retriever(retrieval_model, backbone=backbone)
+        imgs_fp = [path.join(dir, filename) for filename in imgs]
+        with torch.no_grad():
+            sim_matrix = retriever(imgs_fp)
+
+        # Cleanup
+        del retriever
+        torch.cuda.empty_cache()
+
+    pairs = make_pairs(imgs, scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)
+
+    os.umask(0o002)
+    p = pathlib.Path(output)
+    os.makedirs(str(p.parent.resolve()), exist_ok=True)
+
+    with open(output, 'w') as fid:
+        table_to_file(fid, pairs, header='# query_image, map_image, score')
+
+
+if __name__ == '__main__':
+    parser = get_argparser()
+    args = parser.parse_args()
+
+    if "retrieval" in args.scene_graph:
+        assert args.retrieval_model is not None
+        if args.weights is not None:
+            weights_path = args.weights
+        else:
+            weights_path = "naver/" + args.model_name
+        backbone = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
+        retrieval_model = args.retrieval_model
+    else:
+        backbone = None
+        retrieval_model = None
+    main(args.dir, args.scene_graph, args.output, backbone, retrieval_model)
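The same logic can be driven programmatically through `main()`. A minimal sketch, assuming a local image folder and a retrieval checkpoint (both placeholders):

```python
# Hedged sketch of programmatic use; equivalent to:
#   python3 make_pairs.py --dir images/ --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
#       --retrieval_model checkpoints/retrieval_model.pth --output pairs.txt
from mast3r.model import AsymmetricMASt3R
from make_pairs import main

backbone = AsymmetricMASt3R.from_pretrained(
    'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric').to('cuda')
main(dir='images/',                                      # placeholder image folder
     scene_graph='retrieval-20-1-10-1',                  # the script's default retrieval graph
     output='pairs.txt',
     backbone=backbone,
     retrieval_model='checkpoints/retrieval_model.pth')  # placeholder checkpoint path
```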
mast3r/catmlp_dpt_head.py CHANGED
@@ -5,6 +5,7 @@
 # MASt3R heads
 # --------------------------------------------------------
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 import mast3r.utils.path_to_dust3r  # noqa
@@ -12,6 +13,7 @@ from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf  # noqa
 from dust3r.heads.dpt_head import PixelwiseTaskWithDPT  # noqa
 import dust3r.utils.path_to_croco  # noqa
 from models.blocks import Mlp  # noqa
+from models.dpt_block import Interpolate  # noqa
 
 
 def reg_desc(desc, mode):
@@ -96,6 +98,113 @@ class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
         return out
 
 
+class MLP_MiniConv_Head(nn.Module):
+    """
+    A special convolutional head inspired by the DPT architecture.
+    An MLP predicts pixelwise feats at lower resolution; the prediction is upsampled
+    to the target resolution and goes through a mini convolutional head.
+
+    Input: [B, S, D]  # S = (H//p) * (W//p)
+
+    MLP:
+        D -> (mlp_hidden_dim) -> mlp_odim * (p/subpatch)**2
+        reshape to [mlp_odim, H/2, W/2] (the MLP predicts at half resolution for the default subpatch=2)
+
+    MiniConv head from DPT:
+        Upsample x2 -> [mlp_odim, H, W]
+        Conv 3x3 -> [conv_inner_dim, H, W]
+        ReLU
+        Conv 1x1 -> [odim, H, W]
+    """
+
+    def __init__(self, idim, mlp_hidden_dim, mlp_odim, conv_inner_dim, odim, patch_size, subpatch=2, **kw):
+        super().__init__()
+        self.patch_size = patch_size
+        self.subpatch = subpatch
+        self.sub_patch_size = patch_size // subpatch
+        self.mlp = Mlp(idim, mlp_hidden_dim, mlp_odim * self.sub_patch_size**2, **kw)  # D -> mlp_odim*sub_patch_size**2
+
+        # DPT conv head
+        self.head = nn.Sequential(Interpolate(scale_factor=self.subpatch, mode="bilinear", align_corners=True) if self.subpatch != 1 else nn.Identity(),
+                                  nn.Conv2d(mlp_odim, conv_inner_dim, kernel_size=3, stride=1, padding=1),
+                                  nn.ReLU(True),
+                                  nn.Conv2d(conv_inner_dim, odim, kernel_size=1, stride=1, padding=0)
+                                  )
+
+    def forward(self, decout, img_shape):
+        H, W = img_shape
+        tokens = decout[-1]
+        B, S, D = tokens.shape
+        # extract features
+        feat = self.mlp(tokens)  # [B, S, mlp_odim*sub_patch_size**2]
+        feat = feat.transpose(-1, -2).reshape(B, -1, H // self.patch_size, W // self.patch_size)
+        feat = F.pixel_shuffle(feat, self.sub_patch_size)  # B, mlp_odim, H/subpatch, W/subpatch
+
+        return self.head(feat)  # B, odim, H, W
+
+
+class Cat_MLP_LocalFeatures_MiniConv_Pts3d(nn.Module):
+    """ Mixture between MLP and MLP-Convolutional head that outputs 3d points (with miniconv) and local features (with MLP).
+    Contains one MLP_MiniConv_Head for the 3D points and one Mlp for the local features.
+    The input for both heads is a concatenation of Encoder and Decoder outputs.
+    """
+
+    def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., mlp_odim=24, conv_inner_dim=100, subpatch=2, **kw):
+        super().__init__()
+
+        self.local_feat_dim = local_feat_dim
+        patch_size = net.patch_embed.patch_size
+        if isinstance(patch_size, tuple):
+            assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance(
+                patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
+            assert patch_size[0] == patch_size[1], "Error, non square patches not managed"
+            patch_size = patch_size[0]
+        self.patch_size = patch_size
+
+        self.depth_mode = net.depth_mode
+        self.conf_mode = net.conf_mode
+        self.desc_mode = net.desc_mode
+        self.desc_conf_mode = net.desc_conf_mode
+        self.has_conf = has_conf
+        self.two_confs = net.two_confs  # independent confs for 3D regr and descs
+        idim = net.enc_embed_dim + net.dec_embed_dim
+        self.head_pts3d = MLP_MiniConv_Head(idim=idim,
+                                            mlp_hidden_dim=int(hidden_dim_factor * idim),
+                                            mlp_odim=mlp_odim + self.has_conf,
+                                            conv_inner_dim=conv_inner_dim,
+                                            odim=3 + self.has_conf,
+                                            subpatch=subpatch,
+                                            patch_size=self.patch_size,
+                                            **kw)
+
+        self.head_local_features = Mlp(in_features=idim,
+                                       hidden_features=int(hidden_dim_factor * idim),
+                                       out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2)
+
+    def forward(self, decout, img_shape):
+        enc_output, dec_output = decout[0], decout[-1]  # recover encoder and decoder outputs
+        cat_output = torch.cat([enc_output, dec_output], dim=-1)  # concatenate
+        # pass through the heads
+        pts3d = self.head_pts3d([cat_output], img_shape)
+
+        H, W = img_shape
+        B, S, D = cat_output.shape
+
+        # extract local features
+        local_features = self.head_local_features(cat_output)  # B,S,D
+        local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
+        local_features = F.pixel_shuffle(local_features, self.patch_size)  # B,d,H,W
+
+        # post process 3D pts, descriptors and confidences
+        out = postprocess(torch.cat([pts3d, local_features], dim=1),
+                          depth_mode=self.depth_mode,
+                          conf_mode=self.conf_mode,
+                          desc_dim=self.local_feat_dim,
+                          desc_mode=self.desc_mode,
+                          two_confs=self.two_confs, desc_conf_mode=self.desc_conf_mode)
+        return out
+
+
 def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
     """" build a prediction head for the decoder
     """
@@ -118,6 +227,13 @@ def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
                                             depth_mode=net.depth_mode,
                                             conf_mode=net.conf_mode,
                                             head_type='regression')
+    elif head_type == 'catconv' and output_mode.startswith('pts3d+desc'):
+        local_feat_dim = int(output_mode[10:])
+        # more params (announced by a ':' and comma separated)
+        kw = {}
+        if ':' in head_type:
+            kw = eval("dict(" + head_type[8:] + ")")
+        return Cat_MLP_LocalFeatures_MiniConv_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf, **kw)
     else:
        raise NotImplementedError(
            f"unexpected {head_type=} and {output_mode=}")
mast3r/cloud_opt/sparse_ga.py CHANGED
@@ -15,6 +15,7 @@ from collections import namedtuple
 from functools import lru_cache
 from scipy import sparse as sp
 import copy
+import scipy.cluster.hierarchy as sch
 
 from mast3r.utils.misc import mkdir_for, hash_md5
 from mast3r.cloud_opt.utils.losses import gamma_loss
@@ -116,7 +117,7 @@ def convert_dust3r_pairs_naming(imgs, pairs_in):
 
 
 def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
-                            device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
+                            kinematic_mode='hclust-ward', device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
     """ Sparse alignment with MASt3R
         imgs: list of image paths
         cache_path: path where to dump temporary files (str)
@@ -137,17 +138,54 @@ def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc
     tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \
         prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device)
 
-    # compute minimal spanning tree
-    mst = compute_min_spanning_tree(pairwise_scores)
+    # smartly combine all useful data
+    imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = \
+        condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype)
+
+    # Build kinematic chain
+    if kinematic_mode == 'mst':
+        # compute minimal spanning tree
+        mst = compute_min_spanning_tree(pairwise_scores)
+
+    elif kinematic_mode.startswith('hclust'):
+        mode, linkage = kinematic_mode.split('-')
+
+        # Convert the affinity matrix to a distance matrix (if needed)
+        n_patches = (imsizes // subsample).prod(dim=1)
+        max_n_corres = 3 * torch.minimum(n_patches[:, None], n_patches[None, :])
+        pws = (pairwise_scores.clone() / max_n_corres).clip(max=1)
+        pws.fill_diagonal_(1)
+        pws = to_numpy(pws)
+        distance_matrix = np.where(pws, 1 - pws, 2)
+
+        # Compute the condensed distance matrix
+        condensed_distance_matrix = sch.distance.squareform(distance_matrix)
+
+        # Perform hierarchical clustering using the linkage method
+        Z = sch.linkage(condensed_distance_matrix, method=linkage)
+        # dendrogram = sch.dendrogram(Z)
+
+        tree = np.eye(len(imgs))
+        new_to_old_nodes = {i: i for i in range(len(imgs))}
+        for i, (a, b) in enumerate(Z[:, :2].astype(int)):
+            # given two nodes to be merged, we choose which one is the best representative
+            a = new_to_old_nodes[a]
+            b = new_to_old_nodes[b]
+            tree[a, b] = tree[b, a] = 1
+            best = a if pws[a].sum() > pws[b].sum() else b
+            new_to_old_nodes[len(imgs) + i] = best
+            pws[best] = np.maximum(pws[a], pws[b])  # update the node
+
+        pairwise_scores = torch.from_numpy(tree)  # this output just gives 1s for connected edges and zeros for the others, i.e. no scores or priority
+        mst = compute_min_spanning_tree(pairwise_scores)
+
+    else:
+        raise ValueError(f'bad {kinematic_mode=}')
 
     # remove all edges not in the spanning tree?
     # min_spanning_tree = {(imgs[i],imgs[j]) for i,j in mst[1]}
     # tmp_pairs = {(a,b):v for (a,b),v in tmp_pairs.items() if {(a,b),(b,a)} & min_spanning_tree}
 
-    # smartly combine all useful data
-    imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = \
-        condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype)
-
     imgs, res_coarse, res_fine = sparse_scene_optimizer(
         imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths, mst,
         shared_intrinsics=shared_intrinsics, cache_path=cache_path, device=device, dtype=dtype, **kw)
@@ -157,8 +195,8 @@ def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc
 
 def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
                            preds_21, canonical_paths, mst, cache_path,
-                           lr1=0.2, niter1=500, loss1=gamma_loss(1.1),
-                           lr2=0.02, niter2=500, loss2=gamma_loss(0.4),
+                           lr1=0.07, niter1=300, loss1=gamma_loss(1.5),
+                           lr2=0.01, niter2=300, loss2=gamma_loss(0.5),
                            lossd=gamma_loss(1.1),
                            opt_pp=True, opt_depth=True,
                            schedule=cosine_schedule, depth_mode='add', exp_depth=False,
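The new `hclust-*` kinematic mode can be illustrated on a toy 4-image affinity matrix; this standalone sketch mirrors the merge loop above (the scores are invented):

```python
# Toy illustration of the hclust kinematic chain: affinities -> distances ->
# agglomerative clustering -> binary tree whose edges seed the spanning tree.
import numpy as np
import scipy.cluster.hierarchy as sch

pws = np.array([[1.0, 0.9, 0.1, 0.2],       # invented normalized affinities in [0, 1]
                [0.9, 1.0, 0.3, 0.1],
                [0.1, 0.3, 1.0, 0.8],
                [0.2, 0.1, 0.8, 1.0]])
distance_matrix = np.where(pws, 1 - pws, 2)  # zero affinity -> large distance (2)
Z = sch.linkage(sch.distance.squareform(distance_matrix), method='ward')

n = len(pws)
tree = np.eye(n)
new_to_old = {i: i for i in range(n)}
for i, (a, b) in enumerate(Z[:, :2].astype(int)):
    a, b = new_to_old[a], new_to_old[b]      # representatives of the two merged clusters
    tree[a, b] = tree[b, a] = 1              # connect them in the kinematic chain
    best = a if pws[a].sum() > pws[b].sum() else b
    new_to_old[n + i] = best                 # the stronger node represents the merge
    pws[best] = np.maximum(pws[a], pws[b])
print(tree)  # adjacency matrix fed to compute_min_spanning_tree in the real code
```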
mast3r/colmap/mapping.py ADDED
@@ -0,0 +1,195 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# colmap mapper/colmap point_triangulator/glomap mapper from mast3r matches
+# --------------------------------------------------------
+import pycolmap
+import os
+import os.path as path
+import kapture.io
+import kapture.io.csv
+import subprocess
+import PIL
+from tqdm import tqdm
+import PIL.Image
+import numpy as np
+from typing import List, Tuple, Union
+
+from mast3r.model import AsymmetricMASt3R
+from mast3r.colmap.database import export_matches, get_im_matches
+
+import mast3r.utils.path_to_dust3r  # noqa
+from dust3r_visloc.datasets.utils import get_resize_function
+
+import kapture
+from kapture.converter.colmap.database_extra import get_colmap_camera_ids_from_db, get_colmap_image_ids_from_db
+from kapture.utils.paths import path_secure
+
+from dust3r.datasets.utils.transforms import ImgNorm
+from dust3r.inference import inference
+
+
+def scene_prepare_images(root: str, maxdim: int, patch_size: int, image_paths: List[str]):
+    images = []
+    # image loading
+    for idx in tqdm(range(len(image_paths))):
+        rgb_image = PIL.Image.open(os.path.join(root, image_paths[idx])).convert('RGB')
+
+        # resize images
+        W, H = rgb_image.size
+        resize_func, _, to_orig = get_resize_function(maxdim, patch_size, H, W)
+        rgb_tensor = resize_func(ImgNorm(rgb_image))
+
+        # image dictionary
+        images.append({'img': rgb_tensor.unsqueeze(0),
+                       'true_shape': np.int32([rgb_tensor.shape[1:]]),
+                       'to_orig': to_orig,
+                       'idx': idx,
+                       'instance': image_paths[idx],
+                       'orig_shape': np.int32([H, W])})
+    return images
+
+
+def remove_duplicates(images, image_pairs):
+    pairs_added = set()
+    pairs = []
+    for (i, _), (j, _) in image_pairs:
+        smallidx, bigidx = min(i, j), max(i, j)
+        if (smallidx, bigidx) in pairs_added:
+            continue
+        pairs_added.add((smallidx, bigidx))
+        pairs.append((images[i], images[j]))
+    return pairs
+
+
+def run_mast3r_matching(model: AsymmetricMASt3R, maxdim: int, patch_size: int, device,
+                        kdata: kapture.Kapture, root_path: str, image_pairs_kapture: List[Tuple[str, str]],
+                        colmap_db,
+                        dense_matching: bool, pixel_tol: int, conf_thr: float, skip_geometric_verification: bool,
+                        min_len_track: int):
+    assert kdata.records_camera is not None
+    image_paths = kdata.records_camera.data_list()
+    image_path_to_idx = {image_path: idx for idx, image_path in enumerate(image_paths)}
+    image_path_to_ts = {kdata.records_camera[ts, camid]: (ts, camid) for ts, camid in kdata.records_camera.key_pairs()}
+
+    images = scene_prepare_images(root_path, maxdim, patch_size, image_paths)
+    image_pairs = [((image_path_to_idx[image_path1], image_path1), (image_path_to_idx[image_path2], image_path2))
+                   for image_path1, image_path2 in image_pairs_kapture]
+    matching_pairs = remove_duplicates(images, image_pairs)
+
+    colmap_camera_ids = get_colmap_camera_ids_from_db(colmap_db, kdata.records_camera)
+    colmap_image_ids = get_colmap_image_ids_from_db(colmap_db)
+    im_keypoints = {idx: {} for idx in range(len(image_paths))}
+
+    im_matches = {}
+    image_to_colmap = {}
+    for image_path, idx in image_path_to_idx.items():
+        _, camid = image_path_to_ts[image_path]
+        colmap_camid = colmap_camera_ids[camid]
+        colmap_imid = colmap_image_ids[image_path]
+        image_to_colmap[idx] = {
+            'colmap_imid': colmap_imid,
+            'colmap_camid': colmap_camid
+        }
+
+    # compute 2D-2D matching from dust3r inference
+    for chunk in tqdm(range(0, len(matching_pairs), 4)):
+        pairs_chunk = matching_pairs[chunk:chunk + 4]
+        output = inference(pairs_chunk, model, device, batch_size=1, verbose=False)
+        pred1, pred2 = output['pred1'], output['pred2']
+        # TODO handle caching
+        im_images_chunk = get_im_matches(pred1, pred2, pairs_chunk, image_to_colmap,
+                                         im_keypoints, conf_thr, not dense_matching, pixel_tol)
+        im_matches.update(im_images_chunk.items())
+
+    # filter matches, convert them and export keypoints and matches to colmap db
+    colmap_image_pairs = export_matches(
+        colmap_db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification)
+    colmap_db.commit()
+
+    return colmap_image_pairs
+
+
+def pycolmap_run_triangulator(colmap_db_path, prior_recon_path, recon_path, image_root_path):
+    print("running triangulation")
+    reconstruction = pycolmap.Reconstruction(prior_recon_path)
+    pycolmap.triangulate_points(
+        reconstruction=reconstruction,
+        database_path=colmap_db_path,
+        image_path=image_root_path,
+        output_path=recon_path,
+        refine_intrinsics=False,
+    )
+
+
+def pycolmap_run_mapper(colmap_db_path, recon_path, image_root_path):
+    print("running mapping")
+    reconstructions = pycolmap.incremental_mapping(
+        database_path=colmap_db_path,
+        image_path=image_root_path,
+        output_path=recon_path,
+        options=pycolmap.IncrementalPipelineOptions({'multiple_models': False,
+                                                     'extract_colors': True,
+                                                     })
+    )
+
+
+def glomap_run_mapper(glomap_bin, colmap_db_path, recon_path, image_root_path):
+    print("running mapping")
+    args = [
+        'mapper',
+        '--database_path',
+        colmap_db_path,
+        '--image_path',
+        image_root_path,
+        '--output_path',
+        recon_path
+    ]
+    args.insert(0, glomap_bin)
+    glomap_process = subprocess.Popen(args)
+    glomap_process.wait()
+
+    if glomap_process.returncode != 0:
+        raise ValueError(
+            '\nSubprocess Error (Return code:'
+            f' {glomap_process.returncode} )')
+
+
+def kapture_import_image_folder_or_list(images_path: Union[str, Tuple[str, List[str]]], use_single_camera=False) -> kapture.Kapture:
+    images = kapture.RecordsCamera()
+
+    if isinstance(images_path, str):
+        images_root = images_path
+        file_list = [path.relpath(path.join(dirpath, filename), images_root)
+                     for dirpath, dirs, filenames in os.walk(images_root)
+                     for filename in filenames]
+        file_list = sorted(file_list)
+    else:
+        images_root, file_list = images_path
+
+    sensors = kapture.Sensors()
+    for n, filename in enumerate(file_list):
+        # test if file is a valid image
+        try:
+            # lazy load
+            with PIL.Image.open(path.join(images_root, filename)) as im:
+                width, height = im.size
+                model_params = [width, height]
+        except (OSError, PIL.UnidentifiedImageError):
+            # It is not a valid image: skip it
+            print(f'Skipping invalid image file {filename}')
+            continue
+
+        camera_id = 'sensor'
+        if use_single_camera and camera_id not in sensors:
+            sensors[camera_id] = kapture.Camera(kapture.CameraType.UNKNOWN_CAMERA, model_params)
+        elif use_single_camera:
+            assert sensors[camera_id].camera_params[0] == width and sensors[camera_id].camera_params[1] == height
+        else:
+            camera_id = camera_id + f'{n}'
+            sensors[camera_id] = kapture.Camera(kapture.CameraType.UNKNOWN_CAMERA, model_params)
+
+        images[(n, camera_id)] = path_secure(filename)  # don't forget windows
+
+    return kapture.Kapture(sensors=sensors, records_camera=images)
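For reference, `glomap_run_mapper` is a thin subprocess wrapper; calling it directly is equivalent to invoking the glomap CLI. A hedged usage sketch (all paths below are placeholders):

```python
# Equivalent shell command:
#   glomap mapper --database_path scene/colmap.db --image_path scene/images --output_path scene/reconstruction
from mast3r.colmap.mapping import glomap_run_mapper

glomap_run_mapper(glomap_bin='glomap',               # glomap binary must be on PATH
                  colmap_db_path='scene/colmap.db',  # db filled by run_mast3r_matching
                  recon_path='scene/reconstruction',
                  image_root_path='scene/images')
```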
mast3r/demo.py CHANGED
@@ -15,12 +15,14 @@ import copy
 from scipy.spatial.transform import Rotation
 import tempfile
 import shutil
+import torch
 
 from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
 from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
+from mast3r.image_pairs import make_pairs
+from mast3r.retrieval.processor import Retriever
 
 import mast3r.utils.path_to_dust3r  # noqa
-from dust3r.image_pairs import make_pairs
 from dust3r.utils.image import load_images
 from dust3r.utils.device import to_numpy
 from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
@@ -29,7 +31,7 @@ from dust3r.demo import get_args_parser as dust3r_get_args_parser
 import matplotlib.pyplot as pl
 
 
-class SparseGAState():
+class SparseGAState:
     def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
         self.sparse_ga = sparse_ga
         self.cache_dir = cache_dir
@@ -52,6 +54,7 @@ def get_args_parser():
     parser.add_argument('--share', action='store_true')
     parser.add_argument('--gradio_delete_cache', default=None, type=int,
                         help='age/frequency at which gradio removes the file. If >0, matching cache is purged')
+    parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")
 
     actions = parser._actions
     for action in actions:
@@ -136,10 +139,10 @@ def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=F
                                  transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
 
 
-def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
-                            filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
-                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
-                            win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
+def get_reconstructed_scene(outdir, gradio_delete_cache, model, retrieval_model, device, silent, image_size,
+                            current_scene_state, filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr,
+                            matching_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
+                            scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
     """
    from a list of images, run mast3r inference, sparse global aligner.
    then run get_3D_model_from_scene
@@ -155,10 +158,26 @@ def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent,
         scene_graph_params.append(str(winsize))
     elif scenegraph_type == "oneref":
         scene_graph_params.append(str(refid))
+    elif scenegraph_type == "retrieval":
+        scene_graph_params.append(str(winsize))  # Na
+        scene_graph_params.append(str(refid))  # k
+
     if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
         scene_graph_params.append('noncyclic')
     scene_graph = '-'.join(scene_graph_params)
-    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
+
+    sim_matrix = None
+    if 'retrieval' in scenegraph_type:
+        assert retrieval_model is not None
+        retriever = Retriever(retrieval_model, backbone=model, device=device)
+        with torch.no_grad():
+            sim_matrix = retriever(filelist)
+
+        # Cleanup
+        del retriever
+        torch.cuda.empty_cache()
+
+    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)
     if optim_level == 'coarse':
         niter2 = 0
     # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
@@ -190,39 +209,66 @@ def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent,
 
 def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
     num_files = len(inputfiles) if inputfiles is not None else 1
-    show_win_controls = scenegraph_type in ["swin", "logwin"]
-    show_winsize = scenegraph_type in ["swin", "logwin"]
-    show_cyclic = scenegraph_type in ["swin", "logwin"]
     max_winsize, min_winsize = 1, 1
-    if scenegraph_type == "swin":
-        if win_cyclic:
-            max_winsize = max(1, math.ceil((num_files - 1) / 2))
-        else:
-            max_winsize = num_files - 1
-    elif scenegraph_type == "logwin":
-        if win_cyclic:
-            half_size = math.ceil((num_files - 1) / 2)
-            max_winsize = max(1, math.ceil(math.log(half_size, 2)))
+
+    winsize = gradio.Slider(visible=False)
+    win_cyclic = gradio.Checkbox(visible=False)
+    graph_opt = gradio.Column(visible=False)
+    refid = gradio.Slider(visible=False)
+
+    if scenegraph_type in ["swin", "logwin"]:
+        if scenegraph_type == "swin":
+            if win_cyclic:
+                max_winsize = max(1, math.ceil((num_files - 1) / 2))
+            else:
+                max_winsize = num_files - 1
         else:
-            max_winsize = max(1, math.ceil(math.log(num_files, 2)))
-    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
-                            minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
-    win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
-    win_col = gradio.Column(visible=show_win_controls)
-    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
-                          maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
-    return win_col, winsize, win_cyclic, refid
+            if win_cyclic:
+                half_size = math.ceil((num_files - 1) / 2)
+                max_winsize = max(1, math.ceil(math.log(half_size, 2)))
+            else:
+                max_winsize = max(1, math.ceil(math.log(num_files, 2)))
+
+        winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
+                                minimum=min_winsize, maximum=max_winsize, step=1, visible=True)
+        win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=True)
+        graph_opt = gradio.Column(visible=True)
+        refid = gradio.Slider(visible=False)
+
+    elif scenegraph_type == "retrieval":
+        graph_opt = gradio.Column(visible=True)
+        winsize = gradio.Slider(label="Retrieval: Num. key images", value=min(20, num_files),
+                                minimum=0, maximum=num_files, step=1, visible=True)
+        win_cyclic = gradio.Checkbox(visible=False)
+        refid = gradio.Slider(label="Retrieval: Num neighbors", value=min(num_files - 1, 10), minimum=1,
+                              maximum=num_files - 1, step=1, visible=True)
+
+    elif scenegraph_type == "oneref":
+        graph_opt = gradio.Column(visible=True)
+        winsize = gradio.Slider(visible=False)
+        win_cyclic = gradio.Checkbox(visible=False)
+        refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
+                              maximum=num_files - 1, step=1, visible=True)
 
+    return graph_opt, winsize, win_cyclic, refid
 
-def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
+
+def main_demo(tmpdirname, model, retrieval_model, device, image_size, server_name, server_port, silent=False,
               share=False, gradio_delete_cache=False):
     if not silent:
         print('Outputting stuff in', tmpdirname)
 
-    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device,
-                                  silent, image_size)
+    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model,
+                                  retrieval_model, device, silent, image_size)
     model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
 
+    available_scenegraph_type = [("complete: all possible image pairs", "complete"),
+                                 ("swin: sliding window", "swin"),
+                                 ("logwin: sliding window with long range", "logwin"),
+                                 ("oneref: match one image with all", "oneref")]
+    if retrieval_model is not None:
+        available_scenegraph_type.insert(1, ("retrieval: connect views based on similarity", "retrieval"))
+
     def get_context(delete_cache):
         css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
         title = "MASt3R Demo"
@@ -241,33 +287,31 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
             with gradio.Column():
                 with gradio.Row():
                     lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
-                    niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
-                                           label="num_iterations", info="For coarse alignment!")
-                    lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
-                    niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
-                                           label="num_iterations", info="For refinement!")
+                    niter1 = gradio.Slider(value=300, minimum=0, maximum=1000, step=1,
+                                           label="Iterations", info="For coarse alignment")
+                    lr2 = gradio.Slider(label="Fine LR", value=0.01, minimum=0.005, maximum=0.05, step=0.001)
+                    niter2 = gradio.Slider(value=300, minimum=0, maximum=1000, step=1,
+                                           label="Iterations", info="For refinement")
                     optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
                                                   value='refine', label="OptLevel",
                                                   info="Optimization level")
                 with gradio.Row():
-                    matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
+                    matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=0.,
                                                       minimum=0., maximum=30., step=0.1,
                                                       info="Before Fallback to Regr3D!")
                     shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                         info="Only optimize one set of intrinsics for all views")
-                    scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
-                                                       ("swin: sliding window", "swin"),
-                                                       ("logwin: sliding window with long range", "logwin"),
-                                                       ("oneref: match one image with all", "oneref")],
+                    scenegraph_type = gradio.Dropdown(available_scenegraph_type,
                                                       value='complete', label="Scenegraph",
                                                       info="Define how to make pairs",
                                                       interactive=True)
-                    with gradio.Column(visible=False) as win_col:
+                    with gradio.Column(visible=False) as graph_opt:
                         winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                                 minimum=1, maximum=1, step=1)
                         win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
-                    refid = gradio.Slider(label="Scene Graph: Id", value=0,
-                                          minimum=0, maximum=0, step=1, visible=False)
+                        refid = gradio.Slider(label="Scene Graph: Id", value=0,
+                                              minimum=0, maximum=0, step=1, visible=False)
+
             run_btn = gradio.Button("Run")
 
             with gradio.Row():
@@ -288,13 +332,13 @@ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, s
         # events
         scenegraph_type.change(set_scenegraph_options,
                                inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
-                               outputs=[win_col, winsize, win_cyclic, refid])
+                               outputs=[graph_opt, winsize, win_cyclic, refid])
         inputfiles.change(set_scenegraph_options,
                           inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
-                          outputs=[win_col, winsize, win_cyclic, refid])
+                          outputs=[graph_opt, winsize, win_cyclic, refid])
         win_cyclic.change(set_scenegraph_options,
                           inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
-                          outputs=[win_col, winsize, win_cyclic, refid])
+                          outputs=[graph_opt, winsize, win_cyclic, refid])
         run_btn.click(fn=recon_fun,
                       inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                               as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
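Outside the gradio UI, the retrieval pairing path added above boils down to a few calls. A minimal sketch; the file names and the retrieval checkpoint are placeholders:

```python
# Hedged sketch of retrieval-based pairing: the demo builds 'retrieval-{Na}-{k}'
# from the two sliders and passes the similarity matrix to make_pairs.
import torch
from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.processor import Retriever
from mast3r.image_pairs import make_pairs
from dust3r.utils.image import load_images

filelist = ['img_000.jpg', 'img_001.jpg', 'img_002.jpg']   # placeholder paths
imgs = load_images(filelist, size=512)
model = AsymmetricMASt3R.from_pretrained(
    'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric').to('cuda')
retriever = Retriever('checkpoints/retrieval_model.pth',   # placeholder checkpoint
                      backbone=model, device='cuda')
with torch.no_grad():
    sim_matrix = retriever(filelist)
del retriever
torch.cuda.empty_cache()

pairs = make_pairs(imgs, scene_graph='retrieval-20-10',    # Na=20 key images, k=10 neighbors
                   prefilter=None, symmetrize=True, sim_mat=sim_matrix)
```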
mast3r/demo_glomap.py ADDED
@@ -0,0 +1,338 @@
+ #!/usr/bin/env python3
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # gradio demo functions
+ # --------------------------------------------------------
+ import pycolmap
+ import gradio
+ import os
+ import numpy as np
+ import functools
+ import trimesh
+ import copy
+ from scipy.spatial.transform import Rotation
+ import tempfile
+ import shutil
+ import PIL.Image
+ import torch
+
+ from kapture.converter.colmap.database_extra import kapture_to_colmap
+ from kapture.converter.colmap.database import COLMAPDatabase
+
+ from mast3r.colmap.mapping import kapture_import_image_folder_or_list, run_mast3r_matching, glomap_run_mapper
+ from mast3r.demo import set_scenegraph_options
+ from mast3r.retrieval.processor import Retriever
+ from mast3r.image_pairs import make_pairs
+
+ import mast3r.utils.path_to_dust3r  # noqa
+ from dust3r.utils.image import load_images
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL
+ from dust3r.demo import get_args_parser as dust3r_get_args_parser
+
+ import matplotlib.pyplot as pl
+
+
+ class GlomapRecon:
+     def __init__(self, world_to_cam, intrinsics, points3d, imgs):
+         self.world_to_cam = world_to_cam
+         self.intrinsics = intrinsics
+         self.points3d = points3d
+         self.imgs = imgs
+
+
+ class GlomapReconState:
+     def __init__(self, glomap_recon, should_delete=False, cache_dir=None, outfile_name=None):
+         self.glomap_recon = glomap_recon
+         self.cache_dir = cache_dir
+         self.outfile_name = outfile_name
+         self.should_delete = should_delete
+
+     def __del__(self):
+         if not self.should_delete:
+             return
+         if self.cache_dir is not None and os.path.isdir(self.cache_dir):
+             shutil.rmtree(self.cache_dir)
+         self.cache_dir = None
+         if self.outfile_name is not None and os.path.isfile(self.outfile_name):
+             os.remove(self.outfile_name)
+         self.outfile_name = None
+
+
+ def get_args_parser():
+     parser = dust3r_get_args_parser()
+     parser.add_argument('--share', action='store_true')
+     parser.add_argument('--gradio_delete_cache', default=None, type=int,
+                         help='age/frequency at which gradio removes cached files. If >0, the matching cache is purged')
+     parser.add_argument('--glomap_bin', default='glomap', type=str, help='path to the glomap binary')
+     parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval_model to be loaded")
+
+     actions = parser._actions
+     for action in actions:
+         if action.dest == 'model_name':
+             action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
+     # change defaults
+     parser.prog = 'mast3r demo'
+     return parser
+
+
+ def get_reconstructed_scene(glomap_bin, outdir, gradio_delete_cache, model, retrieval_model, device, silent, image_size,
+                             current_scene_state, filelist, transparent_cams, cam_size, scenegraph_type, winsize,
+                             win_cyclic, refid, shared_intrinsics, **kw):
+     """
+     from a list of images, run mast3r matching, export the matches to a colmap database,
+     run the glomap mapper, then extract a 3D model with get_3D_model_from_scene
+     """
+     imgs = load_images(filelist, size=image_size, verbose=not silent)
+     if len(imgs) == 1:
+         imgs = [imgs[0], copy.deepcopy(imgs[0])]
+         imgs[1]['idx'] = 1
+         filelist = [filelist[0], filelist[0]]
+
+     scene_graph_params = [scenegraph_type]
+     if scenegraph_type in ["swin", "logwin"]:
+         scene_graph_params.append(str(winsize))
+     elif scenegraph_type == "oneref":
+         scene_graph_params.append(str(refid))
+     elif scenegraph_type == "retrieval":
+         scene_graph_params.append(str(winsize))  # Na
+         scene_graph_params.append(str(refid))  # k
+
+     if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
+         scene_graph_params.append('noncyclic')
+     scene_graph = '-'.join(scene_graph_params)
+
+     sim_matrix = None
+     if 'retrieval' in scenegraph_type:
+         assert retrieval_model is not None
+         retriever = Retriever(retrieval_model, backbone=model, device=device)
+         with torch.no_grad():
+             sim_matrix = retriever(filelist)
+
+         # Cleanup
+         del retriever
+         torch.cuda.empty_cache()
+
+     pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)
+
+     if current_scene_state is not None and \
+             not current_scene_state.should_delete and \
+             current_scene_state.cache_dir is not None:
+         cache_dir = current_scene_state.cache_dir
+     elif gradio_delete_cache:
+         cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
+     else:
+         cache_dir = os.path.join(outdir, 'cache')
+
+     root_path = os.path.commonpath(filelist)
+     filelist_relpath = [
+         os.path.relpath(filename, root_path).replace('\\', '/')
+         for filename in filelist
+     ]
+     kdata = kapture_import_image_folder_or_list((root_path, filelist_relpath), shared_intrinsics)
+     image_pairs = [
+         (filelist_relpath[img1['idx']], filelist_relpath[img2['idx']])
+         for img1, img2 in pairs
+     ]
+
+     colmap_db_path = os.path.join(cache_dir, 'colmap.db')
+     if os.path.isfile(colmap_db_path):
+         os.remove(colmap_db_path)
+
+     os.makedirs(os.path.dirname(colmap_db_path), exist_ok=True)
+     colmap_db = COLMAPDatabase.connect(colmap_db_path)
+     try:
+         kapture_to_colmap(kdata, root_path, tar_handler=None, database=colmap_db,
+                           keypoints_type=None, descriptors_type=None, export_two_view_geometry=False)
+         colmap_image_pairs = run_mast3r_matching(model, image_size, 16, device,
+                                                  kdata, root_path, image_pairs, colmap_db,
+                                                  False, 5, 1.001,
+                                                  False, 3)
+         colmap_db.close()
+     except Exception as e:
+         print(f'Error {e}')
+         colmap_db.close()
+         raise
+
+     if len(colmap_image_pairs) == 0:
+         raise Exception("no matches were kept")
+
+     # colmap db is now full, run colmap
+     print("verify_matches")
+     pairs_path = os.path.join(cache_dir, 'pairs.txt')
+     with open(pairs_path, "w") as f:
+         for image_path1, image_path2 in colmap_image_pairs:
+             f.write("{} {}\n".format(image_path1, image_path2))
+     pycolmap.verify_matches(colmap_db_path, pairs_path)
+
+     reconstruction_path = os.path.join(cache_dir, "reconstruction")
+     if os.path.isdir(reconstruction_path):
+         shutil.rmtree(reconstruction_path)
+     os.makedirs(reconstruction_path, exist_ok=True)
+     glomap_run_mapper(glomap_bin, colmap_db_path, reconstruction_path, root_path)
+
+     if current_scene_state is not None and \
+             not current_scene_state.should_delete and \
+             current_scene_state.outfile_name is not None:
+         outfile_name = current_scene_state.outfile_name
+     else:
+         outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
+
+     output_recon = pycolmap.Reconstruction(os.path.join(reconstruction_path, '0'))
+     print(output_recon.summary())
+
+     colmap_world_to_cam = {}
+     colmap_intrinsics = {}
+     colmap_image_id_to_name = {}
+     images = {}
+     num_reg_images = output_recon.num_reg_images()
+     for idx, (colmap_imgid, colmap_image) in enumerate(output_recon.images.items()):
+         colmap_image_id_to_name[colmap_imgid] = colmap_image.name
+         if callable(colmap_image.cam_from_world.matrix):
+             colmap_world_to_cam[colmap_imgid] = colmap_image.cam_from_world.matrix()
+         else:
+             colmap_world_to_cam[colmap_imgid] = colmap_image.cam_from_world.matrix
+         camera = output_recon.cameras[colmap_image.camera_id]
+         K = np.eye(3)
+         K[0, 0] = camera.focal_length_x
+         K[1, 1] = camera.focal_length_y
+         K[0, 2] = camera.principal_point_x
+         K[1, 2] = camera.principal_point_y
+         colmap_intrinsics[colmap_imgid] = K
+
+         with PIL.Image.open(os.path.join(root_path, colmap_image.name)) as im:
+             images[colmap_imgid] = np.asarray(im)
+
+         if idx + 1 == num_reg_images:
+             break  # bug with the iterable ?
+     points3D = []
+     num_points3D = output_recon.num_points3D()
+     for idx, (pt3d_id, pts3d) in enumerate(output_recon.points3D.items()):
+         points3D.append((pts3d.xyz, pts3d.color))
+         if idx + 1 == num_points3D:
+             break  # bug with the iterable ?
+     scene = GlomapRecon(colmap_world_to_cam, colmap_intrinsics, points3D, images)
+     scene_state = GlomapReconState(scene, gradio_delete_cache, cache_dir, outfile_name)
+     outfile = get_3D_model_from_scene(silent, scene_state, transparent_cams, cam_size)
+     return scene_state, outfile
+
+
+ def get_3D_model_from_scene(silent, scene_state, transparent_cams=False, cam_size=0.05):
+     """
+     extract 3D_model (glb file) from a reconstructed scene
+     """
+     if scene_state is None:
+         return None
+     outfile = scene_state.outfile_name
+     if outfile is None:
+         return None
+
+     recon = scene_state.glomap_recon
+
+     scene = trimesh.Scene()
+     pts = np.stack([p[0] for p in recon.points3d], axis=0)
+     col = np.stack([p[1] for p in recon.points3d], axis=0)
+     pct = trimesh.PointCloud(pts, colors=col)
+     scene.add_geometry(pct)
+
+     # add each camera
+     cams2world = []
+     for i, (img_id, pose_w2c_3x4) in enumerate(recon.world_to_cam.items()):
+         intrinsics = recon.intrinsics[img_id]
+         focal = (intrinsics[0, 0] + intrinsics[1, 1]) / 2.0
+         camera_edge_color = CAM_COLORS[i % len(CAM_COLORS)]
+         pose_w2c = np.eye(4)
+         pose_w2c[:3, :] = pose_w2c_3x4
+         pose_c2w = np.linalg.inv(pose_w2c)
+         cams2world.append(pose_c2w)
+         add_scene_cam(scene, pose_c2w, camera_edge_color,
+                       None if transparent_cams else recon.imgs[img_id], focal,
+                       imsize=recon.imgs[img_id].shape[1::-1], screen_width=cam_size)
+
+     rot = np.eye(4)
+     rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
+     scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
+     if not silent:
+         print('(exporting 3D scene to', outfile, ')')
+     scene.export(file_obj=outfile)
+
+     return outfile
+
+
+ def main_demo(glomap_bin, tmpdirname, model, retrieval_model, device, image_size, server_name, server_port,
+               silent=False, share=False, gradio_delete_cache=False):
+     if not silent:
+         print('Outputting stuff in', tmpdirname)
+
+     recon_fun = functools.partial(get_reconstructed_scene, glomap_bin, tmpdirname, gradio_delete_cache, model,
+                                   retrieval_model, device, silent, image_size)
+     model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
+
+     available_scenegraph_type = [("complete: all possible image pairs", "complete"),
+                                  ("swin: sliding window", "swin"),
+                                  ("logwin: sliding window with long range", "logwin"),
+                                  ("oneref: match one image with all", "oneref")]
+     if retrieval_model is not None:
+         available_scenegraph_type.insert(1, ("retrieval: connect views based on similarity", "retrieval"))
+
+     def get_context(delete_cache):
+         css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
+         title = "MASt3R Demo"
+         if delete_cache:
+             return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
+         else:
+             return gradio.Blocks(css=css, title=title)  # for compatibility with older versions
+
+     with get_context(gradio_delete_cache) as demo:
+         # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
+         scene = gradio.State(None)
+         gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
+         with gradio.Column():
+             inputfiles = gradio.File(file_count="multiple")
+             with gradio.Row():
+                 shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
+                                                     info="Only optimize one set of intrinsics for all views")
+                 scenegraph_type = gradio.Dropdown(available_scenegraph_type,
+                                                   value='complete', label="Scenegraph",
+                                                   info="Define how to make pairs",
+                                                   interactive=True)
+                 with gradio.Column(visible=False) as win_col:
+                     winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
+                                             minimum=1, maximum=1, step=1)
+                     win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
+                 refid = gradio.Slider(label="Scene Graph: Id", value=0,
+                                       minimum=0, maximum=0, step=1, visible=False)
+             run_btn = gradio.Button("Run")
+
+             with gradio.Row():
+                 # adjust the camera size in the output pointcloud
+                 cam_size = gradio.Slider(label="cam_size", value=0.01, minimum=0.001, maximum=1.0, step=0.001)
+             with gradio.Row():
+                 transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
+
+             outmodel = gradio.Model3D()
+
+             # events
+             scenegraph_type.change(set_scenegraph_options,
+                                    inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                                    outputs=[win_col, winsize, win_cyclic, refid])
+             inputfiles.change(set_scenegraph_options,
+                               inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                               outputs=[win_col, winsize, win_cyclic, refid])
+             win_cyclic.change(set_scenegraph_options,
+                               inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
+                               outputs=[win_col, winsize, win_cyclic, refid])
+             run_btn.click(fn=recon_fun,
+                           inputs=[scene, inputfiles, transparent_cams, cam_size,
+                                   scenegraph_type, winsize, win_cyclic, refid, shared_intrinsics],
+                           outputs=[scene, outmodel])
+             cam_size.change(fn=model_from_scene_fun,
+                             inputs=[scene, transparent_cams, cam_size],
+                             outputs=outmodel)
+             transparent_cams.change(model_from_scene_fun,
+                                     inputs=[scene, transparent_cams, cam_size],
+                                     outputs=outmodel)
+     demo.launch(share=share, server_name=server_name, server_port=server_port)
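
The gradio UI above is a thin wrapper around a handful of library calls. Below is a minimal headless sketch of the same matching → COLMAP database → glomap pipeline; the image folder, cache paths and checkpoint name are illustrative assumptions, and the positional matching parameters simply mirror the values hard-coded in the demo.

```python
# Headless sketch of the demo_glomap.py pipeline (paths are hypothetical).
import os

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.image import load_images

from kapture.converter.colmap.database import COLMAPDatabase
from kapture.converter.colmap.database_extra import kapture_to_colmap

from mast3r.model import AsymmetricMASt3R
from mast3r.image_pairs import make_pairs
from mast3r.colmap.mapping import (kapture_import_image_folder_or_list,
                                   run_mast3r_matching, glomap_run_mapper)

device, image_size = 'cuda', 512
model = AsymmetricMASt3R.from_pretrained(
    'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric').to(device)

root_path = 'data/scene'  # hypothetical image folder
filelist = [os.path.join(root_path, f) for f in sorted(os.listdir(root_path))]
imgs = load_images(filelist, size=image_size)
pairs = make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True)

relpaths = [os.path.relpath(f, root_path).replace('\\', '/') for f in filelist]
kdata = kapture_import_image_folder_or_list((root_path, relpaths), False)  # shared_intrinsics=False
image_pairs = [(relpaths[a['idx']], relpaths[b['idx']]) for a, b in pairs]

os.makedirs('cache/reconstruction', exist_ok=True)
colmap_db = COLMAPDatabase.connect('cache/colmap.db')
kapture_to_colmap(kdata, root_path, tar_handler=None, database=colmap_db,
                  keypoints_type=None, descriptors_type=None, export_two_view_geometry=False)
# same positional matching parameters as hard-coded in the demo above
kept_pairs = run_mast3r_matching(model, image_size, 16, device,
                                 kdata, root_path, image_pairs, colmap_db,
                                 False, 5, 1.001, False, 3)
colmap_db.close()

# (the demo additionally runs pycolmap.verify_matches on the kept pairs before mapping)
glomap_run_mapper('glomap', 'cache/colmap.db', 'cache/reconstruction', root_path)
```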
mast3r/image_pairs.py ADDED
@@ -0,0 +1,115 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # utilities needed to load image pairs
+ # --------------------------------------------------------
+ import numpy as np
+ import torch
+
+ from mast3r.retrieval.graph import make_pairs_fps
+
+
+ def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True, sim_mat=None):
+     pairs = []
+     if scene_graph == 'complete':  # complete graph
+         for i in range(len(imgs)):
+             for j in range(i):
+                 pairs.append((imgs[i], imgs[j]))
+     elif scene_graph.startswith('swin'):
+         iscyclic = not scene_graph.endswith('noncyclic')
+         try:
+             winsize = int(scene_graph.split('-')[1])
+         except Exception:
+             winsize = 3
+         pairsid = set()
+         for i in range(len(imgs)):
+             for j in range(1, winsize + 1):
+                 idx = (i + j)
+                 if iscyclic:
+                     idx = idx % len(imgs)  # explicit loop closure
+                 if idx >= len(imgs):
+                     continue
+                 pairsid.add((i, idx) if i < idx else (idx, i))
+         for i, j in pairsid:
+             pairs.append((imgs[i], imgs[j]))
+     elif scene_graph.startswith('logwin'):
+         iscyclic = not scene_graph.endswith('noncyclic')
+         try:
+             winsize = int(scene_graph.split('-')[1])
+         except Exception:
+             winsize = 3
+         offsets = [2**i for i in range(winsize)]
+         pairsid = set()
+         for i in range(len(imgs)):
+             ixs_l = [i - off for off in offsets]
+             ixs_r = [i + off for off in offsets]
+             for j in ixs_l + ixs_r:
+                 if iscyclic:
+                     j = j % len(imgs)  # explicit loop closure
+                 if j < 0 or j >= len(imgs) or j == i:
+                     continue
+                 pairsid.add((i, j) if i < j else (j, i))
+         for i, j in pairsid:
+             pairs.append((imgs[i], imgs[j]))
+     elif scene_graph.startswith('oneref'):
+         refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
+         for j in range(len(imgs)):
+             if j != refid:
+                 pairs.append((imgs[refid], imgs[j]))
+     elif scene_graph.startswith('retrieval'):
+         mode, Na, k = scene_graph.split('-')
+         assert sim_mat is not None, "sim_mat is required for retrieval mode"
+
+         fps_pairs, anchor_idxs = make_pairs_fps(sim_mat, Na=int(Na), tokK=int(k), dist_thresh=None)
+
+         for i, j in fps_pairs:
+             pairs.append((imgs[i], imgs[j]))
+     else:
+         raise ValueError(f'unrecognized value for {scene_graph=}')
+
+     if symmetrize:
+         pairs += [(img2, img1) for img1, img2 in pairs]
+
+     # now, remove edges
+     if isinstance(prefilter, str) and prefilter.startswith('seq'):
+         pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
+
+     if isinstance(prefilter, str) and prefilter.startswith('cyc'):
+         pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)
+
+     return pairs
+
+
+ def sel(x, kept):
+     if isinstance(x, dict):
+         return {k: sel(v, kept) for k, v in x.items()}
+     if isinstance(x, (torch.Tensor, np.ndarray)):
+         return x[kept]
+     if isinstance(x, (tuple, list)):
+         return type(x)([x[k] for k in kept])
+
+
+ def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
+     # number of images
+     n = max(max(e) for e in edges) + 1
+
+     kept = []
+     for e, (i, j) in enumerate(edges):
+         dis = abs(i - j)
+         if cyclic:
+             dis = min(dis, abs(i + n - j), abs(i - n - j))
+         if dis <= seq_dis_thr:
+             kept.append(e)
+     return kept
+
+
+ def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
+     edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs]
+     kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
+     return [pairs[i] for i in kept]
+
+
+ def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
+     edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
+     kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
+     print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
+     return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
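
To make the scene-graph strings concrete, here is a toy run of `make_pairs`. The stand-in image dicts are a simplifying assumption: with no `prefilter` and `symmetrize=False`, pair construction only indexes into the list.

```python
# Toy illustration of the scene-graph strings parsed above (6 hypothetical images).
from mast3r.image_pairs import make_pairs

imgs = [dict(idx=i) for i in range(6)]  # make_pairs treats images as opaque objects

complete = make_pairs(imgs, scene_graph='complete', symmetrize=False)      # all 15 unordered pairs
swin = make_pairs(imgs, scene_graph='swin-2-noncyclic', symmetrize=False)  # (i, i+1) and (i, i+2)
logwin = make_pairs(imgs, scene_graph='logwin-2', symmetrize=False)        # cyclic offsets 1 and 2
oneref = make_pairs(imgs, scene_graph='oneref-3', symmetrize=False)        # image 3 vs all others
print(len(complete), len(swin), len(logwin), len(oneref))  # 15 9 12 5
# 'retrieval-Na-k' additionally requires sim_mat (see mast3r/retrieval/graph.py)
```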
mast3r/losses.py CHANGED
@@ -273,7 +273,7 @@ class InfoNCE(MatchingCriterion):
 
 
 class APLoss (MatchingCriterion):
-    """ AP loss.
+    """ AP loss
     """
 
     def __init__(self, nq='torch', min=0, max=1, euc=False, **kw):
mast3r/retrieval/graph.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # Building the graph based on retrieval results.
+ # --------------------------------------------------------
+ import numpy as np
+
+
+ def farthest_point_sampling(dist, N=None, dist_thresh=None):
+     """Farthest point sampling.
+
+     Args:
+         dist: NxN distance matrix.
+         N: Number of points to sample.
+         dist_thresh: Distance threshold. Point sampling terminates once the
+             maximum distance is below this threshold.
+
+     Returns:
+         indices: Indices of the sampled points.
+         distances: Distance of each sampled point to the previously sampled set.
+     """
+
+     assert N is not None or dist_thresh is not None, "Either N or dist_thresh must be provided."
+
+     if N is None:
+         N = dist.shape[0]
+
+     indices = []
+     distances = [0]
+     indices.append(np.random.choice(dist.shape[0]))
+     for i in range(1, N):
+         d = dist[indices].min(axis=0)
+         bst = d.argmax()
+         bst_dist = d[bst]
+         if dist_thresh is not None and bst_dist < dist_thresh:
+             break
+         indices.append(bst)
+         distances.append(bst_dist)
+     return np.array(indices), np.array(distances)
+
+
+ def make_pairs_fps(sim_mat, Na=20, tokK=1, dist_thresh=None):
+     dist_mat = 1 - sim_mat
+
+     pairs = set()
+     keyimgs_idx = np.array([])
+     if Na != 0:
+         keyimgs_idx, _ = farthest_point_sampling(dist_mat, N=Na, dist_thresh=dist_thresh)
+
+     # 1. Complete graph between key images
+     for i in range(len(keyimgs_idx)):
+         for j in range(i + 1, len(keyimgs_idx)):
+             idx_i, idx_j = keyimgs_idx[i], keyimgs_idx[j]
+             pairs.add((idx_i, idx_j))
+
+     # 2. Connect non-key images to the nearest key image
+     keyimg_dist_mat = dist_mat[:, keyimgs_idx]
+     for i in range(keyimg_dist_mat.shape[0]):
+         if i in keyimgs_idx:
+             continue
+         j = keyimg_dist_mat[i].argmin()
+         i1, i2 = min(i, keyimgs_idx[j]), max(i, keyimgs_idx[j])
+         if i1 != i2 and (i1, i2) not in pairs:
+             pairs.add((i1, i2))
+
+     # 3. Add some local connections (k-NN) for each view
+     if tokK > 0:
+         for i in range(dist_mat.shape[0]):
+             idx = dist_mat[i].argsort()[:tokK]
+             for j in idx:
+                 i1, i2 = min(i, j), max(i, j)
+                 if i1 != i2 and (i1, i2) not in pairs:
+                     pairs.add((i1, i2))
+
+     pairs = list(pairs)
+
+     return pairs, keyimgs_idx
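
A quick way to sanity-check the graph construction is to feed `make_pairs_fps` a synthetic similarity matrix; the sizes below are arbitrary assumptions, not values used anywhere in the code.

```python
# Sanity-check sketch for make_pairs_fps on a synthetic similarity matrix.
import numpy as np
from mast3r.retrieval.graph import make_pairs_fps

rng = np.random.default_rng(0)
feats = rng.normal(size=(30, 8))
feats /= np.linalg.norm(feats, axis=1, keepdims=True)
sim_mat = feats @ feats.T  # symmetric, self-similarity on the diagonal

# 5 anchor images picked by farthest point sampling + 1 k-NN edge per view
pairs, anchors = make_pairs_fps(sim_mat, Na=5, tokK=1, dist_thresh=None)
print(len(anchors), len(pairs))  # 5 anchors; anchor clique + nearest-anchor + k-NN edges
```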
mast3r/retrieval/model.py ADDED
@@ -0,0 +1,271 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # Whitener and RetrievalModel
+ # --------------------------------------------------------
+ import numpy as np
+ from tqdm import tqdm
+ import time
+
+ import torch
+ import torch.nn as nn
+
+ import mast3r.utils.path_to_dust3r  # noqa
+ from dust3r.utils.image import load_images
+
+ default_device = torch.device('cuda' if torch.cuda.is_available() and torch.cuda.device_count() > 0 else 'cpu')
+
+
+ # from https://github.com/gtolias/how/blob/4d73c88e0ffb55506e2ce6249e2a015ef6ccf79f/how/utils/whitening.py#L20
+ def pcawhitenlearn_shrinkage(X, s=1.0):
+     """Learn PCA whitening with shrinkage from given descriptors"""
+     N = X.shape[0]
+
+     # Learning PCA w/o annotations
+     m = X.mean(axis=0, keepdims=True)
+     Xc = X - m
+     Xcov = np.dot(Xc.T, Xc)
+     Xcov = (Xcov + Xcov.T) / (2 * N)
+     eigval, eigvec = np.linalg.eig(Xcov)
+     order = eigval.argsort()[::-1]
+     eigval = eigval[order]
+     eigvec = eigvec[:, order]
+
+     eigval = np.clip(eigval, a_min=1e-14, a_max=None)
+     P = np.dot(np.linalg.inv(np.diag(np.power(eigval, 0.5 * s))), eigvec.T)
+
+     return m, P.T
+
+
+ class Dust3rInputFromImageList(torch.utils.data.Dataset):
+     def __init__(self, image_list, imsize=512):
+         super().__init__()
+         self.image_list = image_list
+         assert imsize == 512
+         self.imsize = imsize
+
+     def __len__(self):
+         return len(self.image_list)
+
+     def __getitem__(self, index):
+         return load_images([self.image_list[index]], size=self.imsize, verbose=False)[0]
+
+
+ class Whitener(nn.Module):
+     def __init__(self, dim, l2norm=None):
+         super().__init__()
+         self.m = torch.nn.Parameter(torch.zeros((1, dim)).double())
+         self.p = torch.nn.Parameter(torch.eye(dim, dim).double())
+         self.l2norm = l2norm  # if not None, apply l2 norm along a given dimension
+
+     def forward(self, x):
+         with torch.autocast(self.m.device.type, enabled=False):
+             shape = x.size()
+             input_type = x.dtype
+             x_reshaped = x.view(-1, shape[-1]).to(dtype=self.m.dtype)
+             # Center the input data
+             x_centered = x_reshaped - self.m
+             # Apply PCA transformation
+             pca_output = torch.matmul(x_centered, self.p)
+             # reshape back
+             pca_output_shape = shape  # list(shape[:-1]) + [shape[-1]]
+             pca_output = pca_output.view(pca_output_shape)
+             if self.l2norm is not None:
+                 return torch.nn.functional.normalize(pca_output, dim=self.l2norm).to(dtype=input_type)
+             return pca_output.to(dtype=input_type)
+
+
+ def weighted_spoc(feat, attn):
+     """
+     feat: BxNxC
+     attn: BxN
+     output: BxC L2-normalization weighted-sum-pooling of features
+     """
+     return torch.nn.functional.normalize((feat * attn[:, :, None]).sum(dim=1), dim=1)
+
+
+ def how_select_local(feat, attn, nfeat):
+     """
+     feat: BxNxC
+     attn: BxN
+     nfeat: nfeat to keep
+     """
+     # get nfeat
+     if nfeat < 0:
+         assert nfeat >= -1.0
+         nfeat = int(-nfeat * feat.size(1))
+     else:
+         nfeat = int(nfeat)
+     # sort by attention and keep the top nfeat features
+     topk_attn, topk_indices = torch.topk(attn, min(nfeat, attn.size(1)), dim=1)
+     topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, feat.size(2))
+     topk_features = torch.gather(feat, 1, topk_indices_expanded)
+     return topk_features, topk_attn, topk_indices
+
+
+ class RetrievalModel(nn.Module):
+     def __init__(self, backbone, freeze_backbone=1, prewhiten=None, hdims=[1024], residual=False, postwhiten=None,
+                  featweights='l2norm', nfeat=300, pretrained_retrieval=None):
+         super().__init__()
+         self.backbone = backbone
+         self.freeze_backbone = freeze_backbone
+         if freeze_backbone:
+             for p in self.backbone.parameters():
+                 p.requires_grad = False
+         self.backbone_dim = backbone.enc_embed_dim
+         self.prewhiten = nn.Identity() if prewhiten is None else Whitener(self.backbone_dim)
+         self.prewhiten_freq = prewhiten
+         if prewhiten is not None and prewhiten != -1:
+             for p in self.prewhiten.parameters():
+                 p.requires_grad = False
+         self.residual = residual
+         self.projector = self.build_projector(hdims, residual)
+         self.dim = hdims[-1] if len(hdims) > 0 else self.backbone_dim
+         self.postwhiten_freq = postwhiten
+         self.postwhiten = nn.Identity() if postwhiten is None else Whitener(self.dim)
+         if postwhiten is not None and postwhiten != -1:
+             assert len(hdims) > 0
+             for p in self.postwhiten.parameters():
+                 p.requires_grad = False
+         self.featweights = featweights
+         if featweights == 'l2norm':
+             self.attention = lambda x: x.norm(dim=-1)
+         else:
+             raise NotImplementedError(featweights)
+         self.nfeat = nfeat
+         self.pretrained_retrieval = pretrained_retrieval
+         if self.pretrained_retrieval is not None:
+             ckpt = torch.load(pretrained_retrieval, 'cpu')
+             msg = self.load_state_dict(ckpt['model'], strict=False)
+             assert len(msg.unexpected_keys) == 0 and all(k.startswith('backbone')
+                                                          or k.startswith('postwhiten') for k in msg.missing_keys)
+
+     def build_projector(self, hdims, residual):
+         if self.residual:
+             assert hdims[-1] == self.backbone_dim
+         d = self.backbone_dim
+         if len(hdims) == 0:
+             return nn.Identity()
+         layers = []
+         for i in range(len(hdims) - 1):
+             layers.append(nn.Linear(d, hdims[i]))
+             d = hdims[i]
+             layers.append(nn.LayerNorm(d))
+             layers.append(nn.GELU())
+         layers.append(nn.Linear(d, hdims[-1]))
+         return nn.Sequential(*layers)
+
+     def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
+         ss = super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
+         if self.freeze_backbone:
+             ss = {k: v for k, v in ss.items() if not k.startswith('backbone')}
+         return ss
+
+     def reinitialize_whitening(self, epoch, train_dataset, nimgs=5000, log_writer=None, max_nfeat_per_image=None, seed=0, device=default_device):
+         do_prewhiten = self.prewhiten_freq is not None and self.pretrained_retrieval is None and \
+             (epoch == 0 or (self.prewhiten_freq > 0 and epoch % self.prewhiten_freq == 0))
+         do_postwhiten = self.postwhiten_freq is not None and ((epoch == 0 and self.postwhiten_freq in [0, -1])
+                                                               or (self.postwhiten_freq > 0 and
+                                                                   epoch % self.postwhiten_freq == 0 and epoch > 0))
+         if do_prewhiten or do_postwhiten:
+             self.eval()
+             imdataset = train_dataset.imlist_dataset_n_images(nimgs, seed)
+             loader = torch.utils.data.DataLoader(imdataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True)
+             if do_prewhiten:
+                 print('Re-initialization of pre-whitening')
+                 t = time.time()
+                 with torch.no_grad():
+                     features = []
+                     for d in tqdm(loader):
+                         feat = self.backbone._encode_image(d['img'][0, ...].to(device),
+                                                            true_shape=d['true_shape'][0, ...])[0]
+                         feat = feat.flatten(0, 1)
+                         if max_nfeat_per_image is not None and max_nfeat_per_image < feat.size(0):
+                             l2norms = torch.linalg.vector_norm(feat, dim=1)
+                             feat = feat[torch.argsort(-l2norms)[:max_nfeat_per_image], :]
+                         features.append(feat.cpu())
+                     features = torch.cat(features, dim=0)
+                     features = features.numpy()
+                 m, P = pcawhitenlearn_shrinkage(features)
+                 self.prewhiten.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})
+                 prewhiten_time = time.time() - t
+                 print(f'Done in {prewhiten_time:.1f} seconds')
+                 if log_writer is not None:
+                     log_writer.add_scalar('time/prewhiten', prewhiten_time, epoch)
+             if do_postwhiten:
+                 print('Re-initialization of post-whitening')
+                 t = time.time()
+                 with torch.no_grad():
+                     features = []
+                     for d in tqdm(loader):
+                         backbone_feat = self.backbone._encode_image(d['img'][0, ...].to(device),
+                                                                     true_shape=d['true_shape'][0, ...])[0]
+                         backbone_feat_prewhitened = self.prewhiten(backbone_feat)
+                         proj_feat = self.projector(backbone_feat_prewhitened) + \
+                             (0.0 if not self.residual else backbone_feat_prewhitened)
+                         proj_feat = proj_feat.flatten(0, 1)
+                         if max_nfeat_per_image is not None and max_nfeat_per_image < proj_feat.size(0):
+                             l2norms = torch.linalg.vector_norm(proj_feat, dim=1)
+                             proj_feat = proj_feat[torch.argsort(-l2norms)[:max_nfeat_per_image], :]
+                         features.append(proj_feat.cpu())
+                     features = torch.cat(features, dim=0)
+                     features = features.numpy()
+                 m, P = pcawhitenlearn_shrinkage(features)
+                 self.postwhiten.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})
+                 postwhiten_time = time.time() - t
+                 print(f'Done in {postwhiten_time:.1f} seconds')
+                 if log_writer is not None:
+                     log_writer.add_scalar('time/postwhiten', postwhiten_time, epoch)
+
+     def extract_features_and_attention(self, x):
+         backbone_feat = self.backbone._encode_image(x['img'], true_shape=x['true_shape'])[0]
+         backbone_feat_prewhitened = self.prewhiten(backbone_feat)
+         proj_feat = self.projector(backbone_feat_prewhitened) + \
+             (0.0 if not self.residual else backbone_feat_prewhitened)
+         attention = self.attention(proj_feat)
+         proj_feat_whitened = self.postwhiten(proj_feat)
+         return proj_feat_whitened, attention
+
+     def forward_local(self, x):
+         feat, attn = self.extract_features_and_attention(x)
+         return how_select_local(feat, attn, self.nfeat)
+
+     def forward_global(self, x):
+         feat, attn = self.extract_features_and_attention(x)
+         return weighted_spoc(feat, attn)
+
+     def forward(self, x):
+         return self.forward_global(x)
+
+
+ def identity(x):  # to avoid Can't pickle local object 'extract_local_features.<locals>.<lambda>'
+     return x
+
+
+ @torch.no_grad()
+ def extract_local_features(model, images, imsize, seed=0, tocpu=False, max_nfeat_per_image=None,
+                            max_nfeat_per_image2=None, device=default_device):
+     model.eval()
+     imdataset = Dust3rInputFromImageList(images, imsize=imsize) if isinstance(images, list) else images
+     loader = torch.utils.data.DataLoader(imdataset, batch_size=1, shuffle=False,
+                                          num_workers=8, pin_memory=True, collate_fn=identity)
+     features = []
+     imids = []
+     for i, d in enumerate(tqdm(loader)):
+         dd = d[0]
+         dd['img'] = dd['img'].to(device, non_blocking=True)
+         feat, _, _ = model.forward_local(dd)
+         feat = feat.flatten(0, 1)
+         if max_nfeat_per_image is not None and feat.size(0) > max_nfeat_per_image:
+             feat = feat[torch.randperm(feat.size(0))[:max_nfeat_per_image], :]
+         if max_nfeat_per_image2 is not None and feat.size(0) > max_nfeat_per_image2:
+             feat = feat[:max_nfeat_per_image2, :]
+         features.append(feat)
+         if tocpu:
+             features[-1] = features[-1].cpu()
+         imids.append(i * torch.ones_like(features[-1][:, 0]).to(dtype=torch.int64))
+     features = torch.cat(features, dim=0)
+     imids = torch.cat(imids, dim=0)
+     return features, imids
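
The whitening round-trip can be checked in isolation: learn `(m, P)` on random descriptors with `pcawhitenlearn_shrinkage` and load them into a `Whitener`, exactly as `reinitialize_whitening` does above. The dimensions below are arbitrary.

```python
# Round-trip sketch: learned PCA whitening loaded into a Whitener (arbitrary dimensions).
import numpy as np
import torch
from mast3r.retrieval.model import Whitener, pcawhitenlearn_shrinkage

X = np.random.randn(1000, 64)
m, P = pcawhitenlearn_shrinkage(X, s=1.0)  # m: (1, 64) mean, P: (64, 64) projection

whitener = Whitener(64, l2norm=-1)  # l2-normalize along the last dimension
whitener.load_state_dict({'m': torch.from_numpy(m), 'p': torch.from_numpy(P)})

out = whitener(torch.from_numpy(X[:5]).float())
print(out.shape, out.norm(dim=-1))  # torch.Size([5, 64]), norms ~1
```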
mast3r/retrieval/processor.py ADDED
@@ -0,0 +1,129 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # Main Retriever class
+ # --------------------------------------------------------
+ import os
+ import argparse
+ import numpy as np
+ import torch
+
+ from mast3r.model import AsymmetricMASt3R
+ from mast3r.retrieval.model import RetrievalModel, extract_local_features
+
+ try:
+     import faiss
+     faiss.StandardGpuResources()  # when loading the checkpoint, it will try to instantiate FaissGpuL2Index
+ except AttributeError:
+     import asmk.index
+
+     class FaissCpuL2Index(asmk.index.FaissL2Index):
+         def __init__(self, gpu_id):
+             super().__init__()
+             self.gpu_id = gpu_id
+
+         def _faiss_index_flat(self, dim):
+             """Return initialized faiss.IndexFlatL2"""
+             return faiss.IndexFlatL2(dim)
+
+     asmk.index.FaissGpuL2Index = FaissCpuL2Index
+
+ from asmk import asmk_method  # noqa
+
+
+ def get_args_parser():
+     parser = argparse.ArgumentParser('Retrieval scores from a set of images', add_help=False, allow_abbrev=False)
+     parser.add_argument('--model', type=str, required=True,
+                         help="shortname of a retrieval model or path to the corresponding .pth")
+     parser.add_argument('--input', type=str, required=True,
+                         help="directory containing images or a file containing a list of image paths")
+     parser.add_argument('--outfile', type=str, required=True, help="numpy file where to store the matrix score")
+     return parser
+
+
+ def get_impaths(imlistfile):
+     with open(imlistfile, 'r') as fid:
+         impaths = [f for f in fid.read().splitlines() if not f.startswith('#')
+                    and len(f) > 0]  # ignore comments and empty lines
+     return impaths
+
+
+ def get_impaths_from_imdir(imdir, extensions=['png', 'jpg', 'PNG', 'JPG']):
+     assert os.path.isdir(imdir)
+     impaths = [os.path.join(imdir, f) for f in sorted(os.listdir(imdir)) if any(f.endswith(ext) for ext in extensions)]
+     return impaths
+
+
+ def get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile):
+     if os.path.isfile(input_imdir_or_imlistfile):
+         return get_impaths(input_imdir_or_imlistfile)
+     else:
+         return get_impaths_from_imdir(input_imdir_or_imlistfile)
+
+
+ class Retriever(object):
+     def __init__(self, modelname, backbone=None, device='cuda'):
+         # load the model
+         assert os.path.isfile(modelname), modelname
+         print(f'Loading retrieval model from {modelname}')
+         ckpt = torch.load(modelname, 'cpu')  # TODO from pretrained to download it automatically
+         ckpt_args = ckpt['args']
+         if backbone is None:
+             backbone = AsymmetricMASt3R.from_pretrained(ckpt_args.pretrained)
+         self.model = RetrievalModel(
+             backbone, freeze_backbone=ckpt_args.freeze_backbone, prewhiten=ckpt_args.prewhiten,
+             hdims=list(map(int, ckpt_args.hdims.split('_'))) if len(ckpt_args.hdims) > 0 else [],
+             residual=getattr(ckpt_args, 'residual', False), postwhiten=ckpt_args.postwhiten,
+             featweights=ckpt_args.featweights, nfeat=ckpt_args.nfeat
+         ).to(device)
+         self.device = device
+         msg = self.model.load_state_dict(ckpt['model'], strict=False)
+         assert all(k.startswith('backbone') for k in msg.missing_keys)
+         assert len(msg.unexpected_keys) == 0
+         self.imsize = ckpt_args.imsize
+
+         # load the asmk codebook
+         dname, bname = os.path.split(modelname)  # TODO they should both be in the same file ?
+         bname_splits = bname.split('_')
+         cache_codebook_fname = os.path.join(dname, '_'.join(bname_splits[:-1]) + '_codebook.pkl')
+         assert os.path.isfile(cache_codebook_fname), cache_codebook_fname
+         asmk_params = {'index': {'gpu_id': 0}, 'train_codebook': {'codebook': {'size': '64k'}},
+                        'build_ivf': {'kernel': {'binary': True}, 'ivf': {'use_idf': False},
+                                      'quantize': {'multiple_assignment': 1}, 'aggregate': {}},
+                        'query_ivf': {'quantize': {'multiple_assignment': 5}, 'aggregate': {},
+                                      'search': {'topk': None},
+                                      'similarity': {'similarity_threshold': 0.0, 'alpha': 3.0}}}
+         asmk_params['train_codebook']['codebook']['size'] = ckpt_args.nclusters
+         self.asmk = asmk_method.ASMKMethod.initialize_untrained(asmk_params)
+         self.asmk = self.asmk.train_codebook(None, cache_path=cache_codebook_fname)
+
+     def __call__(self, input_imdir_or_imlistfile, outfile=None):
+         # get impaths
+         if isinstance(input_imdir_or_imlistfile, str):
+             impaths = get_impaths_from_imdir_or_imlistfile(input_imdir_or_imlistfile)
+         else:
+             impaths = input_imdir_or_imlistfile  # we're assuming a list has been passed
+         print(f'Found {len(impaths)} images')
+
+         # build the database
+         feat, ids = extract_local_features(self.model, impaths, self.imsize, tocpu=True, device=self.device)
+         feat = feat.cpu().numpy()
+         ids = ids.cpu().numpy()
+         asmk_dataset = self.asmk.build_ivf(feat, ids)
+
+         # we actually retrieve the same set of images
+         metadata, query_ids, ranks, ranked_scores = asmk_dataset.query_ivf(feat, ids)
+
+         # query_ivf returns scores ordered by rank, so scatter them back to the original image order
+         scores = np.empty_like(ranked_scores)
+         scores[np.arange(ranked_scores.shape[0])[:, None], ranks] = ranked_scores
+
+         # save
+         if outfile is not None:
+             outdir = os.path.dirname(outfile)
+             if outdir:
+                 os.makedirs(outdir, exist_ok=True)
+             np.save(outfile, scores)
+             print(f'Scores matrix saved in {outfile}')
+         return scores
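
Putting the retriever and the pairing code together, a typical call looks like the sketch below. The checkpoint path is a placeholder (see the TODO above about uploading the retrieval model), the image folder is illustrative, and a matching `*_codebook.pkl` is expected next to the `.pth` per the naming convention in `__init__`.

```python
# Sketch: retrieval-based scene graph (checkpoint path is a placeholder).
import glob
import torch

from mast3r.model import AsymmetricMASt3R
from mast3r.retrieval.processor import Retriever
from mast3r.image_pairs import make_pairs

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.image import load_images

device = 'cuda'
backbone = AsymmetricMASt3R.from_pretrained(
    'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric').to(device)
# expects 'checkpoints/..._codebook.pkl' next to the .pth (see __init__ above)
retriever = Retriever('checkpoints/retrieval_model.pth', backbone=backbone, device=device)

filelist = sorted(glob.glob('data/scene/*.jpg'))  # illustrative folder
with torch.no_grad():
    sim_matrix = retriever(filelist)  # NxN score matrix

imgs = load_images(filelist, size=512)
# 'retrieval-Na-k'; Na (number of anchors) should not exceed the number of images
pairs = make_pairs(imgs, scene_graph='retrieval-20-1', symmetrize=True, sim_mat=sim_matrix)
```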