hugoycj commited on
Commit
2caa1bd
1 Parent(s): 9059c91

feat: Add mast3r dependencies

Browse files
dust3r/model.py CHANGED
@@ -14,6 +14,7 @@ from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, in
14
  from .heads import head_factory
15
  from dust3r.patch_embed import get_patch_embed
16
 
 
17
  import dust3r.utils.path_to_croco # noqa: F401
18
  from models.croco import CroCoNet # noqa
19
 
@@ -78,7 +79,10 @@ class AsymmetricCroCo3DStereo (
78
 
79
  @classmethod
80
  def from_pretrained(cls, pretrained_model_name_or_path, **kw):
81
- return load_model(pretrained_model_name_or_path, device='cpu', landscape_only=kw['landscape_only'])
 
 
 
82
 
83
  def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
84
  self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)
 
14
  from .heads import head_factory
15
  from dust3r.patch_embed import get_patch_embed
16
 
17
+ import urllib
18
  import dust3r.utils.path_to_croco # noqa: F401
19
  from models.croco import CroCoNet # noqa
20
 
 
79
 
80
  @classmethod
81
  def from_pretrained(cls, pretrained_model_name_or_path, **kw):
82
+ if os.path.isfile(pretrained_model_name_or_path) or urllib.parse.urlparse(pretrained_model_name_or_path).scheme in ('http', 'https'):
83
+ return load_model(pretrained_model_name_or_path, device='cpu', landscape_only=kw['landscape_only'])
84
+ else:
85
+ return super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)
86
 
87
  def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
88
  self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)
mast3r/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/catmlp_dpt_head.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # MASt3R heads
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ import mast3r.utils.path_to_dust3r # noqa
11
+ from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf # noqa
12
+ from dust3r.heads.dpt_head import PixelwiseTaskWithDPT # noqa
13
+ import dust3r.utils.path_to_croco # noqa
14
+ from models.blocks import Mlp # noqa
15
+
16
+
17
+ def reg_desc(desc, mode):
18
+ if 'norm' in mode:
19
+ desc = desc / desc.norm(dim=-1, keepdim=True)
20
+ else:
21
+ raise ValueError(f"Unknown desc mode {mode}")
22
+ return desc
23
+
24
+
25
+ def postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None):
26
+ if desc_conf_mode is None:
27
+ desc_conf_mode = conf_mode
28
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,D
29
+ res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
30
+ if conf_mode is not None:
31
+ res['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
32
+ if desc_dim is not None:
33
+ start = 3 + int(conf_mode is not None)
34
+ res['desc'] = reg_desc(fmap[..., start:start + desc_dim], mode=desc_mode)
35
+ if two_confs:
36
+ res['desc_conf'] = reg_dense_conf(fmap[..., start + desc_dim], mode=desc_conf_mode)
37
+ else:
38
+ res['desc_conf'] = res['conf'].clone()
39
+ return res
40
+
41
+
42
+ class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
43
+ """ Mixture between MLP and DPT head that outputs 3d points and local features (with MLP).
44
+ The input for both heads is a concatenation of Encoder and Decoder outputs
45
+ """
46
+
47
+ def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None,
48
+ num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None, head_type="regression", **kwargs):
49
+ super().__init__(num_channels=num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx,
50
+ dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode, head_type=head_type)
51
+ self.local_feat_dim = local_feat_dim
52
+
53
+ patch_size = net.patch_embed.patch_size
54
+ if isinstance(patch_size, tuple):
55
+ assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance(
56
+ patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
57
+ assert patch_size[0] == patch_size[1], "Error, non square patches not managed"
58
+ patch_size = patch_size[0]
59
+ self.patch_size = patch_size
60
+
61
+ self.desc_mode = net.desc_mode
62
+ self.has_conf = has_conf
63
+ self.two_confs = net.two_confs # independent confs for 3D regr and descs
64
+ self.desc_conf_mode = net.desc_conf_mode
65
+ idim = net.enc_embed_dim + net.dec_embed_dim
66
+
67
+ self.head_local_features = Mlp(in_features=idim,
68
+ hidden_features=int(hidden_dim_factor * idim),
69
+ out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2)
70
+
71
+ def forward(self, decout, img_shape):
72
+ # pass through the heads
73
+ pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1]))
74
+
75
+ # recover encoder and decoder outputs
76
+ enc_output, dec_output = decout[0], decout[-1]
77
+ cat_output = torch.cat([enc_output, dec_output], dim=-1) # concatenate
78
+ H, W = img_shape
79
+ B, S, D = cat_output.shape
80
+
81
+ # extract local_features
82
+ local_features = self.head_local_features(cat_output) # B,S,D
83
+ local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
84
+ local_features = F.pixel_shuffle(local_features, self.patch_size) # B,d,H,W
85
+
86
+ # post process 3D pts, descriptors and confidences
87
+ out = torch.cat([pts3d, local_features], dim=1)
88
+ if self.postprocess:
89
+ out = self.postprocess(out,
90
+ depth_mode=self.depth_mode,
91
+ conf_mode=self.conf_mode,
92
+ desc_dim=self.local_feat_dim,
93
+ desc_mode=self.desc_mode,
94
+ two_confs=self.two_confs,
95
+ desc_conf_mode=self.desc_conf_mode)
96
+ return out
97
+
98
+
99
+ def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
100
+ """" build a prediction head for the decoder
101
+ """
102
+ if head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
103
+ local_feat_dim = int(output_mode[10:])
104
+ assert net.dec_depth > 9
105
+ l2 = net.dec_depth
106
+ feature_dim = 256
107
+ last_dim = feature_dim // 2
108
+ out_nchan = 3
109
+ ed = net.enc_embed_dim
110
+ dd = net.dec_embed_dim
111
+ return Cat_MLP_LocalFeatures_DPT_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf,
112
+ num_channels=out_nchan + has_conf,
113
+ feature_dim=feature_dim,
114
+ last_dim=last_dim,
115
+ hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
116
+ dim_tokens=[ed, dd, dd, dd],
117
+ postprocess=postprocess,
118
+ depth_mode=net.depth_mode,
119
+ conf_mode=net.conf_mode,
120
+ head_type='regression')
121
+ else:
122
+ raise NotImplementedError(
123
+ f"unexpected {head_type=} and {output_mode=}")
mast3r/cloud_opt/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/cloud_opt/sparse_ga.py ADDED
@@ -0,0 +1,1035 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # MASt3R Sparse Global Alignement
6
+ # --------------------------------------------------------
7
+ from tqdm import tqdm
8
+ import roma
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ import os
14
+ from collections import namedtuple
15
+ from functools import lru_cache
16
+ from scipy import sparse as sp
17
+ import copy
18
+
19
+ from mast3r.utils.misc import mkdir_for, hash_md5
20
+ from mast3r.cloud_opt.utils.losses import gamma_loss
21
+ from mast3r.cloud_opt.utils.schedules import linear_schedule, cosine_schedule
22
+ from mast3r.fast_nn import fast_reciprocal_NNs, merge_corres
23
+
24
+ import mast3r.utils.path_to_dust3r # noqa
25
+ from dust3r.utils.geometry import inv, geotrf # noqa
26
+ from dust3r.utils.device import to_cpu, to_numpy, todevice # noqa
27
+ from dust3r.post_process import estimate_focal_knowing_depth # noqa
28
+ from dust3r.optim_factory import adjust_learning_rate_by_lr # noqa
29
+ from dust3r.viz import SceneViz
30
+
31
+
32
+ class SparseGA():
33
+ def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None):
34
+ def fetch_img(im):
35
+ def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.)
36
+ for im1, im2 in pairs_in:
37
+ if im1['instance'] == im:
38
+ return torgb(im1['img'])
39
+ if im2['instance'] == im:
40
+ return torgb(im2['img'])
41
+ self.canonical_paths = canonical_paths
42
+ self.img_paths = img_paths
43
+ self.imgs = [fetch_img(img) for img in img_paths]
44
+ self.intrinsics = res_fine['intrinsics']
45
+ self.cam2w = res_fine['cam2w']
46
+ self.depthmaps = res_fine['depthmaps']
47
+ self.pts3d = res_fine['pts3d']
48
+ self.pts3d_colors = []
49
+ self.working_device = self.cam2w.device
50
+ for i in range(len(self.imgs)):
51
+ im = self.imgs[i]
52
+ x, y = anchors[i][0][..., :2].detach().cpu().numpy().T
53
+ self.pts3d_colors.append(im[y, x])
54
+ assert self.pts3d_colors[-1].shape == self.pts3d[i].shape
55
+ self.n_imgs = len(self.imgs)
56
+
57
+ def get_focals(self):
58
+ return torch.tensor([ff[0, 0] for ff in self.intrinsics]).to(self.working_device)
59
+
60
+ def get_principal_points(self):
61
+ return torch.stack([ff[:2, -1] for ff in self.intrinsics]).to(self.working_device)
62
+
63
+ def get_im_poses(self):
64
+ return self.cam2w
65
+
66
+ def get_sparse_pts3d(self):
67
+ return self.pts3d
68
+
69
+ def get_dense_pts3d(self, clean_depth=True, subsample=8):
70
+ assert self.canonical_paths, 'cache_path is required for dense 3d points'
71
+ device = self.cam2w.device
72
+ confs = []
73
+ base_focals = []
74
+ anchors = {}
75
+ for i, canon_path in enumerate(self.canonical_paths):
76
+ (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
77
+ confs.append(conf)
78
+ base_focals.append(focal)
79
+
80
+ H, W = conf.shape
81
+ pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
82
+ idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
83
+ anchors[i] = (pixels, idxs[i], offsets[i])
84
+
85
+ # densify sparse depthmaps
86
+ pts3d, depthmaps = make_pts3d(anchors, self.intrinsics, self.cam2w, [
87
+ d.ravel() for d in self.depthmaps], base_focals=base_focals, ret_depth=True)
88
+
89
+ return pts3d, depthmaps, confs
90
+
91
+ def get_pts3d_colors(self):
92
+ return self.pts3d_colors
93
+
94
+ def get_depthmaps(self):
95
+ return self.depthmaps
96
+
97
+ def get_masks(self):
98
+ return [slice(None, None) for _ in range(len(self.imgs))]
99
+
100
+ def show(self, show_cams=True):
101
+ pts3d, _, confs = self.get_dense_pts3d()
102
+ show_reconstruction(self.imgs, self.intrinsics if show_cams else None, self.cam2w,
103
+ [p.clip(min=-50, max=50) for p in pts3d],
104
+ masks=[c > 1 for c in confs])
105
+
106
+
107
+ def convert_dust3r_pairs_naming(imgs, pairs_in):
108
+ for pair_id in range(len(pairs_in)):
109
+ for i in range(2):
110
+ pairs_in[pair_id][i]['instance'] = imgs[pairs_in[pair_id][i]['idx']]
111
+ return pairs_in
112
+
113
+
114
+ def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
115
+ device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
116
+ """ Sparse alignment with MASt3R
117
+ imgs: list of image paths
118
+ cache_path: path where to dump temporary files (str)
119
+
120
+ lr1, niter1: learning rate and #iterations for coarse global alignment (3D matching)
121
+ lr2, niter2: learning rate and #iterations for refinement (2D reproj error)
122
+
123
+ lora_depth: smart dimensionality reduction with depthmaps
124
+ """
125
+ # Convert pair naming convention from dust3r to mast3r
126
+ pairs_in = convert_dust3r_pairs_naming(imgs, pairs_in)
127
+ # forward pass
128
+ pairs, cache_path = forward_mast3r(pairs_in, model,
129
+ cache_path=cache_path, subsample=subsample,
130
+ desc_conf=desc_conf, device=device)
131
+
132
+ # extract canonical pointmaps
133
+ tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \
134
+ prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device)
135
+
136
+ # compute minimal spanning tree
137
+ mst = compute_min_spanning_tree(pairwise_scores)
138
+
139
+ # remove all edges not in the spanning tree?
140
+ # min_spanning_tree = {(imgs[i],imgs[j]) for i,j in mst[1]}
141
+ # tmp_pairs = {(a,b):v for (a,b),v in tmp_pairs.items() if {(a,b),(b,a)} & min_spanning_tree}
142
+
143
+ # smartly combine all useful data
144
+ imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = \
145
+ condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype)
146
+
147
+ imgs, res_coarse, res_fine = sparse_scene_optimizer(
148
+ imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths, mst,
149
+ shared_intrinsics=shared_intrinsics, cache_path=cache_path, device=device, dtype=dtype, **kw)
150
+
151
+ return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths)
152
+
153
+
154
+ def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
155
+ preds_21, canonical_paths, mst, cache_path,
156
+ lr1=0.2, niter1=500, loss1=gamma_loss(1.1),
157
+ lr2=0.02, niter2=500, loss2=gamma_loss(0.4),
158
+ lossd=gamma_loss(1.1),
159
+ opt_pp=True, opt_depth=True,
160
+ schedule=cosine_schedule, depth_mode='add', exp_depth=False,
161
+ lora_depth=False, # dict(k=96, gamma=15, min_norm=.5),
162
+ shared_intrinsics=False,
163
+ init={}, device='cuda', dtype=torch.float32,
164
+ matching_conf_thr=5., loss_dust3r_w=0.01,
165
+ verbose=True, dbg=()):
166
+ init = copy.deepcopy(init)
167
+ # extrinsic parameters
168
+ vec0001 = torch.tensor((0, 0, 0, 1), dtype=dtype, device=device)
169
+ quats = [nn.Parameter(vec0001.clone()) for _ in range(len(imgs))]
170
+ trans = [nn.Parameter(torch.zeros(3, device=device, dtype=dtype)) for _ in range(len(imgs))]
171
+
172
+ # initialize
173
+ ones = torch.ones((len(imgs), 1), device=device, dtype=dtype)
174
+ median_depths = torch.ones(len(imgs), device=device, dtype=dtype)
175
+ for img in imgs:
176
+ idx = imgs.index(img)
177
+ init_values = init.setdefault(img, {})
178
+ if verbose and init_values:
179
+ print(f' >> initializing img=...{img[-25:]} [{idx}] for {set(init_values)}')
180
+
181
+ K = init_values.get('intrinsics')
182
+ if K is not None:
183
+ K = K.detach()
184
+ focal = K[:2, :2].diag().mean()
185
+ pp = K[:2, 2]
186
+ base_focals[idx] = focal
187
+ pps[idx] = pp
188
+ pps[idx] /= imsizes[idx] # default principal_point would be (0.5, 0.5)
189
+
190
+ depth = init_values.get('depthmap')
191
+ if depth is not None:
192
+ core_depth[idx] = depth.detach()
193
+
194
+ median_depths[idx] = med_depth = core_depth[idx].median()
195
+ core_depth[idx] /= med_depth
196
+
197
+ cam2w = init_values.get('cam2w')
198
+ if cam2w is not None:
199
+ rot = cam2w[:3, :3].detach()
200
+ cam_center = cam2w[:3, 3].detach()
201
+ quats[idx].data[:] = roma.rotmat_to_unitquat(rot)
202
+ trans_offset = med_depth * torch.cat((imsizes[idx] / base_focals[idx] * (0.5 - pps[idx]), ones[:1, 0]))
203
+ trans[idx].data[:] = cam_center + rot @ trans_offset
204
+ del rot
205
+ assert False, 'inverse kinematic chain not yet implemented'
206
+
207
+ # intrinsics parameters
208
+ if shared_intrinsics:
209
+ # Optimize a single set of intrinsics for all cameras. Use averages as init.
210
+ confs = torch.stack([torch.load(pth)[0][2].mean() for pth in canonical_paths]).to(pps)
211
+ weighting = confs / confs.sum()
212
+ pp = nn.Parameter((weighting @ pps).to(dtype))
213
+ pps = [pp for _ in range(len(imgs))]
214
+ focal_m = weighting @ base_focals
215
+ log_focal = nn.Parameter(focal_m.view(1).log().to(dtype))
216
+ log_focals = [log_focal for _ in range(len(imgs))]
217
+ else:
218
+ pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
219
+ log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]
220
+
221
+ diags = imsizes.float().norm(dim=1)
222
+ min_focals = 0.25 * diags # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26
223
+ max_focals = 10 * diags
224
+
225
+ assert len(mst[1]) == len(pps) - 1
226
+
227
+ def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth):
228
+ # make intrinsics
229
+ focals = torch.cat(log_focals).exp().clip(min=min_focals, max=max_focals)
230
+ pps = torch.stack(pps)
231
+ K = torch.eye(3, dtype=dtype, device=device)[None].expand(len(imgs), 3, 3).clone()
232
+ K[:, 0, 0] = K[:, 1, 1] = focals
233
+ K[:, 0:2, 2] = pps * imsizes
234
+ if trans is None:
235
+ return K
236
+
237
+ # security! optimization is always trying to crush the scale down
238
+ sizes = torch.cat(log_sizes).exp()
239
+ global_scaling = 1 / sizes.min()
240
+
241
+ # compute distance of camera to focal plane
242
+ # tan(fov) = W/2 / focal
243
+ z_cameras = sizes * median_depths * focals / base_focals
244
+
245
+ # make extrinsic
246
+ rel_cam2cam = torch.eye(4, dtype=dtype, device=device)[None].expand(len(imgs), 4, 4).clone()
247
+ rel_cam2cam[:, :3, :3] = roma.unitquat_to_rotmat(F.normalize(torch.stack(quats), dim=1))
248
+ rel_cam2cam[:, :3, 3] = torch.stack(trans)
249
+
250
+ # camera are defined as a kinematic chain
251
+ tmp_cam2w = [None] * len(K)
252
+ tmp_cam2w[mst[0]] = rel_cam2cam[mst[0]]
253
+ for i, j in mst[1]:
254
+ # i is the cam_i_to_world reference, j is the relative pose = cam_j_to_cam_i
255
+ tmp_cam2w[j] = tmp_cam2w[i] @ rel_cam2cam[j]
256
+ tmp_cam2w = torch.stack(tmp_cam2w)
257
+
258
+ # smart reparameterizaton of cameras
259
+ trans_offset = z_cameras.unsqueeze(1) * torch.cat((imsizes / focals.unsqueeze(1) * (0.5 - pps), ones), dim=-1)
260
+ new_trans = global_scaling * (tmp_cam2w[:, :3, 3:4] - tmp_cam2w[:, :3, :3] @ trans_offset.unsqueeze(-1))
261
+ cam2w = torch.cat((torch.cat((tmp_cam2w[:, :3, :3], new_trans), dim=2),
262
+ vec0001.view(1, 1, 4).expand(len(K), 1, 4)), dim=1)
263
+
264
+ depthmaps = []
265
+ for i in range(len(imgs)):
266
+ core_depth_img = core_depth[i]
267
+ if exp_depth:
268
+ core_depth_img = core_depth_img.exp()
269
+ if lora_depth: # compute core_depth as a low-rank decomposition of 3d points
270
+ core_depth_img = lora_depth_proj[i] @ core_depth_img
271
+ if depth_mode == 'add':
272
+ core_depth_img = z_cameras[i] + (core_depth_img - 1) * (median_depths[i] * sizes[i])
273
+ elif depth_mode == 'mul':
274
+ core_depth_img = z_cameras[i] * core_depth_img
275
+ else:
276
+ raise ValueError(f'Bad {depth_mode=}')
277
+ depthmaps.append(global_scaling * core_depth_img)
278
+
279
+ return K, (inv(cam2w), cam2w), depthmaps
280
+
281
+ K = make_K_cam_depth(log_focals, pps, None, None, None, None)
282
+
283
+ if shared_intrinsics:
284
+ print('init focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
285
+ else:
286
+ print('init focals =', to_numpy(K[:, 0, 0]))
287
+
288
+ # spectral low-rank projection of depthmaps
289
+ if lora_depth:
290
+ core_depth, lora_depth_proj = spectral_projection_of_depthmaps(
291
+ imgs, K, core_depth, subsample, cache_path=cache_path, **lora_depth)
292
+ if exp_depth:
293
+ core_depth = [d.clip(min=1e-4).log() for d in core_depth]
294
+ core_depth = [nn.Parameter(d.ravel().to(dtype)) for d in core_depth]
295
+ log_sizes = [nn.Parameter(torch.zeros(1, dtype=dtype, device=device)) for _ in range(len(imgs))]
296
+
297
+ # Fetch img slices
298
+ _, confs_sum, imgs_slices = corres
299
+
300
+ # Define which pairs are fine to use with matching
301
+ def matching_check(x): return x.max() > matching_conf_thr
302
+ is_matching_ok = {}
303
+ for s in imgs_slices:
304
+ is_matching_ok[s.img1, s.img2] = matching_check(s.confs)
305
+
306
+ # Prepare slices and corres for losses
307
+ dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
308
+ loss3d_slices = [s for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
309
+ cleaned_corres2d = []
310
+ for cci, (img1, pix1, confs, confsum, imgs_slices) in enumerate(corres2d):
311
+ cf_sum = 0
312
+ pix1_filtered = []
313
+ confs_filtered = []
314
+ curstep = 0
315
+ cleaned_slices = []
316
+ for img2, slice2 in imgs_slices:
317
+ if is_matching_ok[img1, img2]:
318
+ tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
319
+ pix1_filtered.append(pix1[tslice])
320
+ confs_filtered.append(confs[tslice])
321
+ cleaned_slices.append((img2, slice2))
322
+ curstep += slice2.stop - slice2.start
323
+ if pix1_filtered != []:
324
+ pix1_filtered = torch.cat(pix1_filtered)
325
+ confs_filtered = torch.cat(confs_filtered)
326
+ cf_sum = confs_filtered.sum()
327
+ cleaned_corres2d.append((img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices))
328
+
329
+ def loss_dust3r(cam2w, pts3d, pix_loss):
330
+ # In the case no correspondence could be established, fallback to DUSt3R GA regression loss formulation (sparsified)
331
+ loss = 0.
332
+ cf_sum = 0.
333
+ for s in dust3r_slices:
334
+ if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
335
+ continue
336
+ # fallback to dust3r regression
337
+ tgt_pts, tgt_confs = preds_21[imgs[s.img2]][imgs[s.img1]]
338
+ tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
339
+ cf_sum += tgt_confs.sum()
340
+ loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
341
+ return loss / cf_sum if cf_sum != 0. else 0.
342
+
343
+ def loss_3d(K, w2cam, pts3d, pix_loss):
344
+ # For each correspondence, we have two 3D points (one for each image of the pair).
345
+ # For each 3D point, we have 2 reproj errors
346
+ if any(v.get('freeze') for v in init.values()):
347
+ pts3d_1 = []
348
+ pts3d_2 = []
349
+ confs = []
350
+ for s in loss3d_slices:
351
+ if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
352
+ continue
353
+ pts3d_1.append(pts3d[s.img1][s.slice1])
354
+ pts3d_2.append(pts3d[s.img2][s.slice2])
355
+ confs.append(s.confs)
356
+ else:
357
+ pts3d_1 = [pts3d[s.img1][s.slice1] for s in loss3d_slices]
358
+ pts3d_2 = [pts3d[s.img2][s.slice2] for s in loss3d_slices]
359
+ confs = [s.confs for s in loss3d_slices]
360
+
361
+ if pts3d_1 != []:
362
+ confs = torch.cat(confs)
363
+ pts3d_1 = torch.cat(pts3d_1)
364
+ pts3d_2 = torch.cat(pts3d_2)
365
+ loss = confs @ pix_loss(pts3d_1, pts3d_2)
366
+ cf_sum = confs.sum()
367
+ else:
368
+ loss = 0.
369
+ cf_sum = 1.
370
+
371
+ return loss / cf_sum
372
+
373
+ def loss_2d(K, w2cam, pts3d, pix_loss):
374
+ # For each correspondence, we have two 3D points (one for each image of the pair).
375
+ # For each 3D point, we have 2 reproj errors
376
+ proj_matrix = K @ w2cam[:, :3]
377
+ loss = npix = 0
378
+ for img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices in cleaned_corres2d:
379
+ if init[imgs[img1]].get('freeze', 0) >= 1:
380
+ continue # no need
381
+ pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in cleaned_slices]
382
+ if pts3d_in_img1 != []:
383
+ pts3d_in_img1 = torch.cat(pts3d_in_img1)
384
+ loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
385
+ npix += confs_filtered.sum()
386
+
387
+ return loss / npix if npix != 0 else 0.
388
+
389
+ def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
390
+ # create optimizer
391
+ params = pps + log_focals + quats + trans + log_sizes + core_depth
392
+ optimizer = torch.optim.Adam(params, lr=1, weight_decay=0, betas=(0.9, 0.9))
393
+ ploss = pix_loss if 'meta' in repr(pix_loss) else (lambda a: pix_loss)
394
+
395
+ with tqdm(total=niter) as bar:
396
+ for iter in range(niter or 1):
397
+ K, (w2cam, cam2w), depthmaps = make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth)
398
+ pts3d = make_pts3d(anchors, K, cam2w, depthmaps, base_focals=base_focals)
399
+ if niter == 0:
400
+ break
401
+
402
+ alpha = (iter / niter)
403
+ lr = schedule(alpha, lr_base, lr_end)
404
+ adjust_learning_rate_by_lr(optimizer, lr)
405
+ pix_loss = ploss(1 - alpha)
406
+ optimizer.zero_grad()
407
+ loss = loss_func(K, w2cam, pts3d, pix_loss) + loss_dust3r_w * loss_dust3r(cam2w, pts3d, lossd)
408
+ loss.backward()
409
+ optimizer.step()
410
+
411
+ # make sure the pose remains well optimizable
412
+ for i in range(len(imgs)):
413
+ quats[i].data[:] /= quats[i].data.norm()
414
+
415
+ loss = float(loss)
416
+ if loss != loss:
417
+ break # NaN loss
418
+ bar.set_postfix_str(f'{lr=:.4f}, {loss=:.3f}')
419
+ bar.update(1)
420
+
421
+ if niter:
422
+ print(f'>> final loss = {loss}')
423
+ return dict(intrinsics=K.detach(), cam2w=cam2w.detach(),
424
+ depthmaps=[d.detach() for d in depthmaps], pts3d=[p.detach() for p in pts3d])
425
+
426
+ # at start, don't optimize 3d points
427
+ for i, img in enumerate(imgs):
428
+ trainable = not (init[img].get('freeze'))
429
+ pps[i].requires_grad_(False)
430
+ log_focals[i].requires_grad_(False)
431
+ quats[i].requires_grad_(trainable)
432
+ trans[i].requires_grad_(trainable)
433
+ log_sizes[i].requires_grad_(trainable)
434
+ core_depth[i].requires_grad_(False)
435
+
436
+ res_coarse = optimize_loop(loss_3d, lr_base=lr1, niter=niter1, pix_loss=loss1)
437
+
438
+ res_fine = None
439
+ if niter2:
440
+ # now we can optimize 3d points
441
+ for i, img in enumerate(imgs):
442
+ if init[img].get('freeze', 0) >= 1:
443
+ continue
444
+ pps[i].requires_grad_(bool(opt_pp))
445
+ log_focals[i].requires_grad_(True)
446
+ core_depth[i].requires_grad_(opt_depth)
447
+
448
+ # refinement with 2d reproj
449
+ res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)
450
+
451
+ K = make_K_cam_depth(log_focals, pps, None, None, None, None)
452
+ if shared_intrinsics:
453
+ print('Final focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
454
+ else:
455
+ print('Final focals =', to_numpy(K[:, 0, 0]))
456
+
457
+ return imgs, res_coarse, res_fine
458
+
459
+
460
+ @lru_cache
461
+ def mask110(device, dtype):
462
+ return torch.tensor((1, 1, 0), device=device, dtype=dtype)
463
+
464
+
465
+ def proj3d(inv_K, pixels, z):
466
+ if pixels.shape[-1] == 2:
467
+ pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1)
468
+ return z.unsqueeze(-1) * (pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype))
469
+
470
+
471
+ def make_pts3d(anchors, K, cam2w, depthmaps, base_focals=None, ret_depth=False):
472
+ focals = K[:, 0, 0]
473
+ invK = inv(K)
474
+ all_pts3d = []
475
+ depth_out = []
476
+
477
+ for img, (pixels, idxs, offsets) in anchors.items():
478
+ # from depthmaps to 3d points
479
+ if base_focals is None:
480
+ pass
481
+ else:
482
+ # compensate for focal
483
+ # depth + depth * (offset - 1) * base_focal / focal
484
+ # = depth * (1 + (offset - 1) * (base_focal / focal))
485
+ offsets = 1 + (offsets - 1) * (base_focals[img] / focals[img])
486
+
487
+ pts3d = proj3d(invK[img], pixels, depthmaps[img][idxs] * offsets)
488
+ if ret_depth:
489
+ depth_out.append(pts3d[..., 2]) # before camera rotation
490
+ # rotate to world coordinate
491
+ pts3d = geotrf(cam2w[img], pts3d)
492
+ all_pts3d.append(pts3d)
493
+
494
+ if ret_depth:
495
+ return all_pts3d, depth_out
496
+ return all_pts3d
497
+
498
+
499
+ def make_dense_pts3d(intrinsics, cam2w, depthmaps, canonical_paths, subsample, device='cuda'):
500
+ base_focals = []
501
+ anchors = {}
502
+ confs = []
503
+ for i, canon_path in enumerate(canonical_paths):
504
+ (canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
505
+ confs.append(conf)
506
+ base_focals.append(focal)
507
+ H, W = conf.shape
508
+ pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
509
+ idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
510
+ anchors[i] = (pixels, idxs[i], offsets[i])
511
+
512
+ # densify sparse depthmaps
513
+ pts3d, depthmaps_out = make_pts3d(anchors, intrinsics, cam2w, [
514
+ d.ravel() for d in depthmaps], base_focals=base_focals, ret_depth=True)
515
+
516
+ return pts3d, depthmaps_out, confs
517
+
518
+
519
+ @torch.no_grad()
520
+ def forward_mast3r(pairs, model, cache_path, desc_conf='desc_conf',
521
+ device='cuda', subsample=8, **matching_kw):
522
+ res_paths = {}
523
+
524
+ for img1, img2 in tqdm(pairs):
525
+ idx1 = hash_md5(img1['instance'])
526
+ idx2 = hash_md5(img2['instance'])
527
+
528
+ path1 = cache_path + f'/forward/{idx1}/{idx2}.pth'
529
+ path2 = cache_path + f'/forward/{idx2}/{idx1}.pth'
530
+ path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx1}-{idx2}.pth'
531
+ path_corres2 = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx2}-{idx1}.pth'
532
+
533
+ if os.path.isfile(path_corres2) and not os.path.isfile(path_corres):
534
+ score, (xy1, xy2, confs) = torch.load(path_corres2)
535
+ torch.save((score, (xy2, xy1, confs)), path_corres)
536
+
537
+ if not all(os.path.isfile(p) for p in (path1, path2, path_corres)):
538
+ if model is None:
539
+ continue
540
+ res = symmetric_inference(model, img1, img2, device=device)
541
+ X11, X21, X22, X12 = [r['pts3d'][0] for r in res]
542
+ C11, C21, C22, C12 = [r['conf'][0] for r in res]
543
+ descs = [r['desc'][0] for r in res]
544
+ qonfs = [r[desc_conf][0] for r in res]
545
+
546
+ # save
547
+ torch.save(to_cpu((X11, C11, X21, C21)), mkdir_for(path1))
548
+ torch.save(to_cpu((X22, C22, X12, C12)), mkdir_for(path2))
549
+
550
+ # perform reciprocal matching
551
+ corres = extract_correspondences(descs, qonfs, device=device, subsample=subsample)
552
+
553
+ conf_score = (C11.mean() * C12.mean() * C21.mean() * C22.mean()).sqrt().sqrt()
554
+ matching_score = (float(conf_score), float(corres[2].sum()), len(corres[2]))
555
+ if cache_path is not None:
556
+ torch.save((matching_score, corres), mkdir_for(path_corres))
557
+
558
+ res_paths[img1['instance'], img2['instance']] = (path1, path2), path_corres
559
+
560
+ del model
561
+ torch.cuda.empty_cache()
562
+
563
+ return res_paths, cache_path
564
+
565
+
566
+ def symmetric_inference(model, img1, img2, device):
567
+ shape1 = torch.from_numpy(img1['true_shape']).to(device, non_blocking=True)
568
+ shape2 = torch.from_numpy(img2['true_shape']).to(device, non_blocking=True)
569
+ img1 = img1['img'].to(device, non_blocking=True)
570
+ img2 = img2['img'].to(device, non_blocking=True)
571
+
572
+ # compute encoder only once
573
+ feat1, feat2, pos1, pos2 = model._encode_image_pairs(img1, img2, shape1, shape2)
574
+
575
+ def decoder(feat1, feat2, pos1, pos2, shape1, shape2):
576
+ dec1, dec2 = model._decoder(feat1, pos1, feat2, pos2)
577
+ with torch.cuda.amp.autocast(enabled=False):
578
+ res1 = model._downstream_head(1, [tok.float() for tok in dec1], shape1)
579
+ res2 = model._downstream_head(2, [tok.float() for tok in dec2], shape2)
580
+ return res1, res2
581
+
582
+ # decoder 1-2
583
+ res11, res21 = decoder(feat1, feat2, pos1, pos2, shape1, shape2)
584
+ # decoder 2-1
585
+ res22, res12 = decoder(feat2, feat1, pos2, pos1, shape2, shape1)
586
+
587
+ return (res11, res21, res22, res12)
588
+
589
+
590
+ def extract_correspondences(feats, qonfs, subsample=8, device=None, ptmap_key='pred_desc'):
591
+ feat11, feat21, feat22, feat12 = feats
592
+ qonf11, qonf21, qonf22, qonf12 = qonfs
593
+ assert feat11.shape[:2] == feat12.shape[:2] == qonf11.shape == qonf12.shape
594
+ assert feat21.shape[:2] == feat22.shape[:2] == qonf21.shape == qonf22.shape
595
+
596
+ if '3d' in ptmap_key:
597
+ opt = dict(device='cpu', workers=32)
598
+ else:
599
+ opt = dict(device=device, dist='dot', block_size=2**13)
600
+
601
+ # matching the two pairs
602
+ idx1 = []
603
+ idx2 = []
604
+ qonf1 = []
605
+ qonf2 = []
606
+ # TODO add non symmetric / pixel_tol options
607
+ for A, B, QA, QB in [(feat11, feat21, qonf11.cpu(), qonf21.cpu()),
608
+ (feat12, feat22, qonf12.cpu(), qonf22.cpu())]:
609
+ nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
610
+ nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)
611
+
612
+ idx1.append(np.r_[nn1to2[0], nn2to1[1]])
613
+ idx2.append(np.r_[nn1to2[1], nn2to1[0]])
614
+ qonf1.append(QA.ravel()[idx1[-1]])
615
+ qonf2.append(QB.ravel()[idx2[-1]])
616
+
617
+ # merge corres from opposite pairs
618
+ H1, W1 = feat11.shape[:2]
619
+ H2, W2 = feat22.shape[:2]
620
+ cat = np.concatenate
621
+
622
+ xy1, xy2, idx = merge_corres(cat(idx1), cat(idx2), (H1, W1), (H2, W2), ret_xy=True, ret_index=True)
623
+ corres = (xy1.copy(), xy2.copy(), np.sqrt(cat(qonf1)[idx] * cat(qonf2)[idx]))
624
+
625
+ return todevice(corres, device)
626
+
627
+
628
+ @torch.no_grad()
629
+ def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_conf_thr=0,
630
+ cache_path=None, device='cuda', **kw):
631
+ canonical_views = {}
632
+ pairwise_scores = torch.zeros((len(imgs), len(imgs)), device=device)
633
+ canonical_paths = []
634
+ preds_21 = {}
635
+
636
+ for img in tqdm(imgs):
637
+ if cache_path:
638
+ cache = os.path.join(cache_path, 'canon_views', hash_md5(img) + f'_{subsample=}_{kw=}.pth')
639
+ canonical_paths.append(cache)
640
+ try:
641
+ (canon, canon2, cconf), focal = torch.load(cache, map_location=device)
642
+ except IOError:
643
+ # cache does not exist yet, we create it!
644
+ canon = focal = None
645
+
646
+ # collect all pred1
647
+ n_pairs = sum((img in pair) for pair in tmp_pairs)
648
+
649
+ ptmaps11 = None
650
+ pixels = {}
651
+ n = 0
652
+ for (img1, img2), ((path1, path2), path_corres) in tmp_pairs.items():
653
+ score = None
654
+ if img == img1:
655
+ X, C, X2, C2 = torch.load(path1, map_location=device)
656
+ score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
657
+ pixels[img2] = xy1, confs
658
+ if img not in preds_21:
659
+ preds_21[img] = {}
660
+ # Subsample preds_21
661
+ preds_21[img][img2] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()
662
+
663
+ if img == img2:
664
+ X, C, X2, C2 = torch.load(path2, map_location=device)
665
+ score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
666
+ pixels[img1] = xy2, confs
667
+ if img not in preds_21:
668
+ preds_21[img] = {}
669
+ preds_21[img][img1] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()
670
+
671
+ if score is not None:
672
+ i, j = imgs.index(img1), imgs.index(img2)
673
+ # score = score[0]
674
+ # score = np.log1p(score[2])
675
+ score = score[2]
676
+ pairwise_scores[i, j] = score
677
+ pairwise_scores[j, i] = score
678
+
679
+ if canon is not None:
680
+ continue
681
+ if ptmaps11 is None:
682
+ H, W = C.shape
683
+ ptmaps11 = torch.empty((n_pairs, H, W, 3), device=device)
684
+ confs11 = torch.empty((n_pairs, H, W), device=device)
685
+
686
+ ptmaps11[n] = X
687
+ confs11[n] = C
688
+ n += 1
689
+
690
+ if canon is None:
691
+ canon, canon2, cconf = canonical_view(ptmaps11, confs11, subsample, **kw)
692
+ del ptmaps11
693
+ del confs11
694
+
695
+ # compute focals
696
+ H, W = canon.shape[:2]
697
+ pp = torch.tensor([W / 2, H / 2], device=device)
698
+ if focal is None:
699
+ focal = estimate_focal_knowing_depth(canon[None], pp, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5)
700
+ if cache:
701
+ torch.save(to_cpu(((canon, canon2, cconf), focal)), mkdir_for(cache))
702
+
703
+ # extract depth offsets with correspondences
704
+ core_depth = canon[subsample // 2::subsample, subsample // 2::subsample, 2]
705
+ idxs, offsets = anchor_depth_offsets(canon2, pixels, subsample=subsample)
706
+
707
+ canonical_views[img] = (pp, (H, W), focal.view(1), core_depth, pixels, idxs, offsets)
708
+
709
+ return tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21
710
+
711
+
712
+ def load_corres(path_corres, device, min_conf_thr):
713
+ score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
714
+ valid = confs > min_conf_thr if min_conf_thr else slice(None)
715
+ # valid = (xy1 > 0).all(dim=1) & (xy2 > 0).all(dim=1) & (xy1 < 512).all(dim=1) & (xy2 < 512).all(dim=1)
716
+ # print(f'keeping {valid.sum()} / {len(valid)} correspondences')
717
+ return score, (xy1[valid], xy2[valid], confs[valid])
718
+
719
+
720
+ PairOfSlices = namedtuple(
721
+ 'ImgPair', 'img1, slice1, pix1, anchor_idxs1, img2, slice2, pix2, anchor_idxs2, confs, confs_sum')
722
+
723
+
724
+ def condense_data(imgs, tmp_paths, canonical_views, preds_21, dtype=torch.float32):
725
+ # aggregate all data properly
726
+ set_imgs = set(imgs)
727
+
728
+ principal_points = []
729
+ shapes = []
730
+ focals = []
731
+ core_depth = []
732
+ img_anchors = {}
733
+ tmp_pixels = {}
734
+
735
+ for idx1, img1 in enumerate(imgs):
736
+ # load stuff
737
+ pp, shape, focal, anchors, pixels_confs, idxs, offsets = canonical_views[img1]
738
+
739
+ principal_points.append(pp)
740
+ shapes.append(shape)
741
+ focals.append(focal)
742
+ core_depth.append(anchors)
743
+
744
+ img_uv1 = []
745
+ img_idxs = []
746
+ img_offs = []
747
+ cur_n = [0]
748
+
749
+ for img2, (pixels, match_confs) in pixels_confs.items():
750
+ if img2 not in set_imgs:
751
+ continue
752
+ assert len(pixels) == len(idxs[img2]) == len(offsets[img2])
753
+ img_uv1.append(torch.cat((pixels, torch.ones_like(pixels[:, :1])), dim=-1))
754
+ img_idxs.append(idxs[img2])
755
+ img_offs.append(offsets[img2])
756
+ cur_n.append(cur_n[-1] + len(pixels))
757
+ # store the position of 3d points
758
+ tmp_pixels[img1, img2] = pixels.to(dtype), match_confs.to(dtype), slice(*cur_n[-2:])
759
+ img_anchors[idx1] = (torch.cat(img_uv1), torch.cat(img_idxs), torch.cat(img_offs))
760
+
761
+ all_confs = []
762
+ imgs_slices = []
763
+ corres2d = {img: [] for img in range(len(imgs))}
764
+
765
+ for img1, img2 in tmp_paths:
766
+ try:
767
+ pix1, confs1, slice1 = tmp_pixels[img1, img2]
768
+ pix2, confs2, slice2 = tmp_pixels[img2, img1]
769
+ except KeyError:
770
+ continue
771
+ img1 = imgs.index(img1)
772
+ img2 = imgs.index(img2)
773
+ confs = (confs1 * confs2).sqrt()
774
+
775
+ # prepare for loss_3d
776
+ all_confs.append(confs)
777
+ anchor_idxs1 = canonical_views[imgs[img1]][5][imgs[img2]]
778
+ anchor_idxs2 = canonical_views[imgs[img2]][5][imgs[img1]]
779
+ imgs_slices.append(PairOfSlices(img1, slice1, pix1, anchor_idxs1,
780
+ img2, slice2, pix2, anchor_idxs2,
781
+ confs, float(confs.sum())))
782
+
783
+ # prepare for loss_2d
784
+ corres2d[img1].append((pix1, confs, img2, slice2))
785
+ corres2d[img2].append((pix2, confs, img1, slice1))
786
+
787
+ all_confs = torch.cat(all_confs)
788
+ corres = (all_confs, float(all_confs.sum()), imgs_slices)
789
+
790
+ def aggreg_matches(img1, list_matches):
791
+ pix1, confs, img2, slice2 = zip(*list_matches)
792
+ all_pix1 = torch.cat(pix1).to(dtype)
793
+ all_confs = torch.cat(confs).to(dtype)
794
+ return img1, all_pix1, all_confs, float(all_confs.sum()), [(j, sl2) for j, sl2 in zip(img2, slice2)]
795
+ corres2d = [aggreg_matches(img, m) for img, m in corres2d.items()]
796
+
797
+ imsizes = torch.tensor([(W, H) for H, W in shapes], device=pp.device) # (W,H)
798
+ principal_points = torch.stack(principal_points)
799
+ focals = torch.cat(focals)
800
+
801
+ # Subsample preds_21
802
+ subsamp_preds_21 = {}
803
+ for imk, imv in preds_21.items():
804
+ subsamp_preds_21[imk] = {}
805
+ for im2k, (pred, conf) in preds_21[imk].items():
806
+ idxs = img_anchors[imgs.index(im2k)][1]
807
+ subsamp_preds_21[imk][im2k] = (pred[idxs], conf[idxs]) # anchors subsample
808
+
809
+ return imsizes, principal_points, focals, core_depth, img_anchors, corres, corres2d, subsamp_preds_21
810
+
811
+
812
+ def canonical_view(ptmaps11, confs11, subsample, mode='avg-angle'):
813
+ assert len(ptmaps11) == len(confs11) > 0, 'not a single view1 for img={i}'
814
+
815
+ # canonical pointmap is just a weighted average
816
+ confs11 = confs11.unsqueeze(-1) - 0.999
817
+ canon = (confs11 * ptmaps11).sum(0) / confs11.sum(0)
818
+
819
+ canon_depth = ptmaps11[..., 2].unsqueeze(1)
820
+ S = slice(subsample // 2, None, subsample)
821
+ center_depth = canon_depth[:, :, S, S]
822
+ center_depth = torch.clip(center_depth, min=torch.finfo(center_depth.dtype).eps)
823
+
824
+ stacked_depth = F.pixel_unshuffle(canon_depth, subsample)
825
+ stacked_confs = F.pixel_unshuffle(confs11[:, None, :, :, 0], subsample)
826
+
827
+ if mode == 'avg-reldepth':
828
+ rel_depth = stacked_depth / center_depth
829
+ stacked_canon = (stacked_confs * rel_depth).sum(dim=0) / stacked_confs.sum(dim=0)
830
+ canon2 = F.pixel_shuffle(stacked_canon.unsqueeze(0), subsample).squeeze()
831
+
832
+ elif mode == 'avg-angle':
833
+ xy = ptmaps11[..., 0:2].permute(0, 3, 1, 2)
834
+ stacked_xy = F.pixel_unshuffle(xy, subsample)
835
+ B, _, H, W = stacked_xy.shape
836
+ stacked_radius = (stacked_xy.view(B, 2, -1, H, W) - xy[:, :, None, S, S]).norm(dim=1)
837
+ stacked_radius.clip_(min=1e-8)
838
+
839
+ stacked_angle = torch.arctan((stacked_depth - center_depth) / stacked_radius)
840
+ avg_angle = (stacked_confs * stacked_angle).sum(dim=0) / stacked_confs.sum(dim=0)
841
+
842
+ # back to depth
843
+ stacked_depth = stacked_radius.mean(dim=0) * torch.tan(avg_angle)
844
+
845
+ canon2 = F.pixel_shuffle((1 + stacked_depth / canon[S, S, 2]).unsqueeze(0), subsample).squeeze()
846
+ else:
847
+ raise ValueError(f'bad {mode=}')
848
+
849
+ confs = (confs11.square().sum(dim=0) / confs11.sum(dim=0)).squeeze()
850
+ return canon, canon2, confs
851
+
852
+
853
+ def anchor_depth_offsets(canon_depth, pixels, subsample=8):
854
+ device = canon_depth.device
855
+
856
+ # create a 2D grid of anchor 3D points
857
+ H1, W1 = canon_depth.shape
858
+ yx = np.mgrid[subsample // 2:H1:subsample, subsample // 2:W1:subsample]
859
+ H2, W2 = yx.shape[1:]
860
+ cy, cx = yx.reshape(2, -1)
861
+ core_depth = canon_depth[cy, cx]
862
+ assert (core_depth > 0).all()
863
+
864
+ # slave 3d points (attached to core 3d points)
865
+ core_idxs = {} # core_idxs[img2] = {corr_idx:core_idx}
866
+ core_offs = {} # core_offs[img2] = {corr_idx:3d_offset}
867
+
868
+ for img2, (xy1, _confs) in pixels.items():
869
+ px, py = xy1.long().T
870
+
871
+ # find nearest anchor == block quantization
872
+ core_idx = (py // subsample) * W2 + (px // subsample)
873
+ core_idxs[img2] = core_idx.to(device)
874
+
875
+ # compute relative depth offsets w.r.t. anchors
876
+ ref_z = core_depth[core_idx]
877
+ pts_z = canon_depth[py, px]
878
+ offset = pts_z / ref_z
879
+ core_offs[img2] = offset.detach().to(device)
880
+
881
+ return core_idxs, core_offs
882
+
883
+
884
+ def spectral_clustering(graph, k=None, normalized_cuts=False):
885
+ graph.fill_diagonal_(0)
886
+
887
+ # graph laplacian
888
+ degrees = graph.sum(dim=-1)
889
+ laplacian = torch.diag(degrees) - graph
890
+ if normalized_cuts:
891
+ i_inv = torch.diag(degrees.sqrt().reciprocal())
892
+ laplacian = i_inv @ laplacian @ i_inv
893
+
894
+ # compute eigenvectors!
895
+ eigval, eigvec = torch.linalg.eigh(laplacian)
896
+ return eigval[:k], eigvec[:, :k]
897
+
898
+
899
+ def sim_func(p1, p2, gamma):
900
+ diff = (p1 - p2).norm(dim=-1)
901
+ avg_depth = (p1[:, :, 2] + p2[:, :, 2])
902
+ rel_distance = diff / avg_depth
903
+ sim = torch.exp(-gamma * rel_distance.square())
904
+ return sim
905
+
906
+
907
+ def backproj(K, depthmap, subsample):
908
+ H, W = depthmap.shape
909
+ uv = np.mgrid[subsample // 2:subsample * W:subsample, subsample // 2:subsample * H:subsample].T.reshape(H, W, 2)
910
+ xyz = depthmap.unsqueeze(-1) * geotrf(inv(K), todevice(uv, K.device), ncol=3)
911
+ return xyz
912
+
913
+
914
+ def spectral_projection_depth(K, depthmap, subsample, k=64, cache_path='',
915
+ normalized_cuts=True, gamma=7, min_norm=5):
916
+ try:
917
+ if cache_path:
918
+ cache_path = cache_path + f'_{k=}_norm={normalized_cuts}_{gamma=}.pth'
919
+ lora_proj = torch.load(cache_path, map_location=K.device)
920
+
921
+ except IOError:
922
+ # reconstruct 3d points in camera coordinates
923
+ xyz = backproj(K, depthmap, subsample)
924
+
925
+ # compute all distances
926
+ xyz = xyz.reshape(-1, 3)
927
+ graph = sim_func(xyz[:, None], xyz[None, :], gamma=gamma)
928
+ _, lora_proj = spectral_clustering(graph, k, normalized_cuts=normalized_cuts)
929
+
930
+ if cache_path:
931
+ torch.save(lora_proj.cpu(), mkdir_for(cache_path))
932
+
933
+ lora_proj, coeffs = lora_encode_normed(lora_proj, depthmap.ravel(), min_norm=min_norm)
934
+
935
+ # depthmap ~= lora_proj @ coeffs
936
+ return coeffs, lora_proj
937
+
938
+
939
+ def lora_encode_normed(lora_proj, x, min_norm, global_norm=False):
940
+ # encode the pointmap
941
+ coeffs = torch.linalg.pinv(lora_proj) @ x
942
+
943
+ # rectify the norm of basis vector to be ~ equal
944
+ if coeffs.ndim == 1:
945
+ coeffs = coeffs[:, None]
946
+ if global_norm:
947
+ lora_proj *= coeffs[1:].norm() * min_norm / coeffs.shape[1]
948
+ elif min_norm:
949
+ lora_proj *= coeffs.norm(dim=1).clip(min=min_norm)
950
+ # can have rounding errors here!
951
+ coeffs = (torch.linalg.pinv(lora_proj.double()) @ x.double()).float()
952
+
953
+ return lora_proj.detach(), coeffs.detach()
954
+
955
+
956
+ @torch.no_grad()
957
+ def spectral_projection_of_depthmaps(imgs, intrinsics, depthmaps, subsample, cache_path=None, **kw):
958
+ # recover 3d points
959
+ core_depth = []
960
+ lora_proj = []
961
+
962
+ for i, img in enumerate(tqdm(imgs)):
963
+ cache = os.path.join(cache_path, 'lora_depth', hash_md5(img)) if cache_path else None
964
+ depth, proj = spectral_projection_depth(intrinsics[i], depthmaps[i], subsample,
965
+ cache_path=cache, **kw)
966
+ core_depth.append(depth)
967
+ lora_proj.append(proj)
968
+
969
+ return core_depth, lora_proj
970
+
971
+
972
+ def reproj2d(Trf, pts3d):
973
+ res = (pts3d @ Trf[:3, :3].transpose(-1, -2)) + Trf[:3, 3]
974
+ clipped_z = res[:, 2:3].clip(min=1e-3) # make sure we don't have nans!
975
+ uv = res[:, 0:2] / clipped_z
976
+ return uv.clip(min=-1000, max=2000)
977
+
978
+
979
+ def bfs(tree, start_node):
980
+ order, predecessors = sp.csgraph.breadth_first_order(tree, start_node, directed=False)
981
+ ranks = np.arange(len(order))
982
+ ranks[order] = ranks.copy()
983
+ return ranks, predecessors
984
+
985
+
986
+ def compute_min_spanning_tree(pws):
987
+ sparse_graph = sp.dok_array(pws.shape)
988
+ for i, j in pws.nonzero().cpu().tolist():
989
+ sparse_graph[i, j] = -float(pws[i, j])
990
+ msp = sp.csgraph.minimum_spanning_tree(sparse_graph)
991
+
992
+ # now reorder the oriented edges, starting from the central point
993
+ ranks1, _ = bfs(msp, 0)
994
+ ranks2, _ = bfs(msp, ranks1.argmax())
995
+ ranks1, _ = bfs(msp, ranks2.argmax())
996
+ # this is the point farther from any leaf
997
+ root = np.minimum(ranks1, ranks2).argmax()
998
+
999
+ # find the ordered list of edges that describe the tree
1000
+ order, predecessors = sp.csgraph.breadth_first_order(msp, root, directed=False)
1001
+ order = order[1:] # root not do not have a predecessor
1002
+ edges = [(predecessors[i], i) for i in order]
1003
+
1004
+ return root, edges
1005
+
1006
+
1007
+ def show_reconstruction(shapes_or_imgs, K, cam2w, pts3d, gt_cam2w=None, gt_K=None, cam_size=None, masks=None, **kw):
1008
+ viz = SceneViz()
1009
+
1010
+ cc = cam2w[:, :3, 3]
1011
+ cs = cam_size or float(torch.cdist(cc, cc).fill_diagonal_(np.inf).min(dim=0).values.median())
1012
+ colors = 64 + np.random.randint(255 - 64, size=(len(cam2w), 3))
1013
+
1014
+ if isinstance(shapes_or_imgs, np.ndarray) and shapes_or_imgs.ndim == 2:
1015
+ cam_kws = dict(imsizes=shapes_or_imgs[:, ::-1], cam_size=cs)
1016
+ else:
1017
+ imgs = shapes_or_imgs
1018
+ cam_kws = dict(images=imgs, cam_size=cs)
1019
+ if K is not None:
1020
+ viz.add_cameras(to_numpy(cam2w), to_numpy(K), colors=colors, **cam_kws)
1021
+
1022
+ if gt_cam2w is not None:
1023
+ if gt_K is None:
1024
+ gt_K = K
1025
+ viz.add_cameras(to_numpy(gt_cam2w), to_numpy(gt_K), colors=colors, marker='o', **cam_kws)
1026
+
1027
+ if pts3d is not None:
1028
+ for i, p in enumerate(pts3d):
1029
+ if not len(p):
1030
+ continue
1031
+ if masks is None:
1032
+ viz.add_pointcloud(to_numpy(p), color=tuple(colors[i].tolist()))
1033
+ else:
1034
+ viz.add_pointcloud(to_numpy(p), mask=masks[i], color=imgs[i])
1035
+ viz.show(**kw)
mast3r/cloud_opt/triangulation.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Matches Triangulation Utils
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ # Batched Matches Triangulation
12
+ def batched_triangulate(pts2d, # [B, Ncams, Npts, 2]
13
+ proj_mats): # [B, Ncams, 3, 4] I@E projection matrix
14
+ B, Ncams, Npts, two = pts2d.shape
15
+ assert two==2
16
+ assert proj_mats.shape == (B, Ncams, 3, 4)
17
+ # P - xP
18
+ x = proj_mats[...,0,:][...,None,:] - torch.einsum('bij,bik->bijk', pts2d[...,0], proj_mats[...,2,:]) # [B, Ncams, Npts, 4]
19
+ y = proj_mats[...,1,:][...,None,:] - torch.einsum('bij,bik->bijk', pts2d[...,1], proj_mats[...,2,:]) # [B, Ncams, Npts, 4]
20
+ eq = torch.cat([x, y], dim=1).transpose(1, 2) # [B, Npts, 2xNcams, 4]
21
+ return torch.linalg.lstsq(eq[...,:3], -eq[...,3]).solution
22
+
23
+ def matches_to_depths(intrinsics, # input camera intrinsics [B, Ncams, 3, 3]
24
+ extrinsics, # input camera extrinsics [B, Ncams, 3, 4]
25
+ matches, # input correspondences [B, Ncams, Npts, 2]
26
+ batchsize=16, # bs for batched processing
27
+ min_num_valids_ratio=.3 # at least this ratio of image pairs need to predict a match for a given pixel of img1
28
+ ):
29
+ B, Nv, H, W, five = matches.shape
30
+ min_num_valids = np.floor(Nv*min_num_valids_ratio)
31
+ out_aggregated_points, out_depths, out_confs = [], [], []
32
+ for b in range(B//batchsize+1): # batched processing
33
+ start, stop = b*batchsize,min(B,(b+1)*batchsize)
34
+ sub_batch=slice(start,stop)
35
+ sub_batchsize = stop-start
36
+ if sub_batchsize==0:continue
37
+ points1, points2, confs = matches[sub_batch, ..., :2], matches[sub_batch, ..., 2:4], matches[sub_batch, ..., -1]
38
+ allpoints = torch.cat([points1.view([sub_batchsize*Nv,1,H*W,2]), points2.view([sub_batchsize*Nv,1,H*W,2])],dim=1) # [BxNv, 2, HxW, 2]
39
+
40
+ allcam_Ps = intrinsics[sub_batch] @ extrinsics[sub_batch,:,:3,:]
41
+ cam_Ps1, cam_Ps2 = allcam_Ps[:,[0]].repeat([1,Nv,1,1]), allcam_Ps[:,1:] # [B, Nv, 3, 4]
42
+ formatted_camPs = torch.cat([cam_Ps1.reshape([sub_batchsize*Nv,1,3,4]), cam_Ps2.reshape([sub_batchsize*Nv,1,3,4])],dim=1) # [BxNv, 2, 3, 4]
43
+
44
+ # Triangulate matches to 3D
45
+ points_3d_world = batched_triangulate(allpoints, formatted_camPs) # [BxNv, HxW, three]
46
+
47
+ # Aggregate pairwise predictions
48
+ points_3d_world = points_3d_world.view([sub_batchsize,Nv,H,W,3])
49
+ valids = points_3d_world.isfinite()
50
+ valids_sum = valids.sum(dim=-1)
51
+ validsuni=valids_sum.unique()
52
+ assert torch.all(torch.logical_or(validsuni == 0 , validsuni == 3)), "Error, can only be nan for none or all XYZ values, not a subset"
53
+ confs[valids_sum==0] = 0.
54
+ points_3d_world = points_3d_world*confs[...,None]
55
+
56
+ # Take care of NaNs
57
+ normalization = confs.sum(dim=1)[:,None].repeat(1,Nv,1,1)
58
+ normalization[normalization <= 1e-5] = 1.
59
+ points_3d_world[valids] /= normalization[valids_sum==3][:,None].repeat(1,3).view(-1)
60
+ points_3d_world[~valids] = 0.
61
+ aggregated_points = points_3d_world.sum(dim=1) # weighted average (by confidence value) ignoring nans
62
+
63
+ # Reset invalid values to nans, with a min visibility threshold
64
+ aggregated_points[valids_sum.sum(dim=1)/3 <= min_num_valids] = torch.nan
65
+
66
+ # From 3D to depths
67
+ refcamE = extrinsics[sub_batch, 0]
68
+ points_3d_camera = (refcamE[:,:3, :3] @ aggregated_points.view(sub_batchsize,-1,3).transpose(-2,-1) + refcamE[:,:3,[3]]).transpose(-2,-1) # [B,HxW,3]
69
+ depths = points_3d_camera.view(sub_batchsize,H,W,3)[..., 2] # [B,H,W]
70
+
71
+ # Cat results
72
+ out_aggregated_points.append(aggregated_points.cpu())
73
+ out_depths.append(depths.cpu())
74
+ out_confs.append(confs.sum(dim=1).cpu())
75
+
76
+ out_aggregated_points = torch.cat(out_aggregated_points,dim=0)
77
+ out_depths = torch.cat(out_depths,dim=0)
78
+ out_confs = torch.cat(out_confs,dim=0)
79
+
80
+ return out_aggregated_points, out_depths, out_confs
mast3r/cloud_opt/tsdf_optimizer.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ from matplotlib import pyplot as pl
6
+
7
+ import mast3r.utils.path_to_dust3r # noqa
8
+ from dust3r.utils.geometry import depthmap_to_pts3d, geotrf, inv
9
+
10
+
11
+ class TSDFPostProcess:
12
+ """ Optimizes a signed distance-function to improve depthmaps.
13
+ """
14
+
15
+ def __init__(self, optimizer, subsample=8, TSDF_thresh=0., TSDF_batchsize=int(1e7)):
16
+ self.TSDF_thresh = TSDF_thresh # None -> no TSDF
17
+ self.TSDF_batchsize = TSDF_batchsize
18
+ self.optimizer = optimizer
19
+
20
+ pts3d, depthmaps, confs = optimizer.get_dense_pts3d(clean_depth=False, subsample=subsample)
21
+ pts3d, depthmaps = self._TSDF_postprocess_or_not(pts3d, depthmaps, confs)
22
+ self.pts3d = pts3d
23
+ self.depthmaps = depthmaps
24
+ self.confs = confs
25
+
26
+ def _get_depthmaps(self, TSDF_filtering_thresh=None):
27
+ if TSDF_filtering_thresh:
28
+ self._refine_depths_with_TSDF(self.optimizer, TSDF_filtering_thresh) # compute refined depths if needed
29
+ dms = self.TSDF_im_depthmaps if TSDF_filtering_thresh else self.im_depthmaps
30
+ return [d.exp() for d in dms]
31
+
32
+ @torch.no_grad()
33
+ def _refine_depths_with_TSDF(self, TSDF_filtering_thresh, niter=1, nsamples=1000):
34
+ """
35
+ Leverage TSDF to post-process estimated depths
36
+ for each pixel, find zero level of TSDF along ray (or closest to 0)
37
+ """
38
+ print("Post-Processing Depths with TSDF fusion.")
39
+ self.TSDF_im_depthmaps = []
40
+ alldepths, allposes, allfocals, allpps, allimshapes = self._get_depthmaps(), self.optimizer.get_im_poses(
41
+ ), self.optimizer.get_focals(), self.optimizer.get_principal_points(), self.imshapes
42
+ for vi in tqdm(range(self.optimizer.n_imgs)):
43
+ dm, pose, focal, pp, imshape = alldepths[vi], allposes[vi], allfocals[vi], allpps[vi], allimshapes[vi]
44
+ minvals = torch.full(dm.shape, 1e20)
45
+
46
+ for it in range(niter):
47
+ H, W = dm.shape
48
+ curthresh = (niter - it) * TSDF_filtering_thresh
49
+ dm_offsets = (torch.randn(H, W, nsamples).to(dm) - 1.) * \
50
+ curthresh # decreasing search std along with iterations
51
+ newdm = dm[..., None] + dm_offsets # [H,W,Nsamp]
52
+ curproj = self._backproj_pts3d(in_depths=[newdm], in_im_poses=pose[None], in_focals=focal[None], in_pps=pp[None], in_imshapes=[
53
+ imshape])[0] # [H,W,Nsamp,3]
54
+ # Batched TSDF eval
55
+ curproj = curproj.view(-1, 3)
56
+ tsdf_vals = []
57
+ valids = []
58
+ for batch in range(0, len(curproj), self.TSDF_batchsize):
59
+ values, valid = self._TSDF_query(
60
+ curproj[batch:min(batch + self.TSDF_batchsize, len(curproj))], curthresh)
61
+ tsdf_vals.append(values)
62
+ valids.append(valid)
63
+ tsdf_vals = torch.cat(tsdf_vals, dim=0)
64
+ valids = torch.cat(valids, dim=0)
65
+
66
+ tsdf_vals = tsdf_vals.view([H, W, nsamples])
67
+ valids = valids.view([H, W, nsamples])
68
+
69
+ # keep depth value that got us the closest to 0
70
+ tsdf_vals[~valids] = torch.inf # ignore invalid values
71
+ tsdf_vals = tsdf_vals.abs()
72
+ mins = torch.argmin(tsdf_vals, dim=-1, keepdim=True)
73
+ # when all samples live on a very flat zone, do nothing
74
+ allbad = (tsdf_vals == curthresh).sum(dim=-1) == nsamples
75
+ dm[~allbad] = torch.gather(newdm, -1, mins)[..., 0][~allbad]
76
+
77
+ # Save refined depth map
78
+ self.TSDF_im_depthmaps.append(dm.log())
79
+
80
+ def _TSDF_query(self, qpoints, TSDF_filtering_thresh, weighted=True):
81
+ """
82
+ TSDF query call: returns the weighted TSDF value for each query point [N, 3]
83
+ """
84
+ N, three = qpoints.shape
85
+ assert three == 3
86
+ qpoints = qpoints[None].repeat(self.optimizer.n_imgs, 1, 1) # [B,N,3]
87
+ # get projection coordinates and depths onto images
88
+ coords_and_depth = self._proj_pts3d(pts3d=qpoints, cam2worlds=self.optimizer.get_im_poses(
89
+ ), focals=self.optimizer.get_focals(), pps=self.optimizer.get_principal_points())
90
+ image_coords = coords_and_depth[..., :2].round().to(int) # for now, there's no interpolation...
91
+ proj_depths = coords_and_depth[..., -1]
92
+ # recover depth values after scene optim
93
+ pred_depths, pred_confs, valids = self._get_pixel_depths(image_coords)
94
+ # Gather TSDF scores
95
+ all_SDF_scores = pred_depths - proj_depths # SDF
96
+ unseen = all_SDF_scores < -TSDF_filtering_thresh # handle visibility
97
+ # all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh,TSDF_filtering_thresh) # SDF -> TSDF
98
+ all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, 1e20) # SDF -> TSDF
99
+ # Gather TSDF confidences and ignore points that are unseen, either OOB during reproj or too far behind seen depth
100
+ all_TSDF_weights = (~unseen).float() * valids.float()
101
+ if weighted:
102
+ all_TSDF_weights = pred_confs.exp() * all_TSDF_weights
103
+ # Aggregate all votes, ignoring zeros
104
+ TSDF_weights = all_TSDF_weights.sum(dim=0)
105
+ valids = TSDF_weights != 0.
106
+ TSDF_wsum = (all_TSDF_weights * all_TSDF_scores).sum(dim=0)
107
+ TSDF_wsum[valids] /= TSDF_weights[valids]
108
+ return TSDF_wsum, valids
109
+
110
+ def _get_pixel_depths(self, image_coords, TSDF_filtering_thresh=None, with_normals_conf=False):
111
+ """ Recover depth value for each input pixel coordinate, along with OOB validity mask
112
+ """
113
+ B, N, two = image_coords.shape
114
+ assert B == self.optimizer.n_imgs and two == 2
115
+ depths = torch.zeros([B, N], device=image_coords.device)
116
+ valids = torch.zeros([B, N], dtype=bool, device=image_coords.device)
117
+ confs = torch.zeros([B, N], device=image_coords.device)
118
+ curconfs = self._get_confs_with_normals() if with_normals_conf else self.im_conf
119
+ for ni, (imc, depth, conf) in enumerate(zip(image_coords, self._get_depthmaps(TSDF_filtering_thresh), curconfs)):
120
+ H, W = depth.shape
121
+ valids[ni] = torch.logical_and(0 <= imc[:, 1], imc[:, 1] <
122
+ H) & torch.logical_and(0 <= imc[:, 0], imc[:, 0] < W)
123
+ imc[~valids[ni]] = 0
124
+ depths[ni] = depth[imc[:, 1], imc[:, 0]]
125
+ confs[ni] = conf.cuda()[imc[:, 1], imc[:, 0]]
126
+ return depths, confs, valids
127
+
128
+ def _get_confs_with_normals(self):
129
+ outconfs = []
130
+ # Confidence based on depth gradient
131
+
132
+ class Sobel(nn.Module):
133
+ def __init__(self):
134
+ super().__init__()
135
+ self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False)
136
+ Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
137
+ Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
138
+ G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
139
+ G = G.unsqueeze(1)
140
+ self.filter.weight = nn.Parameter(G, requires_grad=False)
141
+
142
+ def forward(self, img):
143
+ x = self.filter(img)
144
+ x = torch.mul(x, x)
145
+ x = torch.sum(x, dim=1, keepdim=True)
146
+ x = torch.sqrt(x)
147
+ return x
148
+
149
+ grad_op = Sobel().to(self.im_depthmaps[0].device)
150
+ for conf, depth in zip(self.im_conf, self.im_depthmaps):
151
+ grad_confs = (1. - grad_op(depth[None, None])[0, 0]).clip(0)
152
+ if not 'dbg show':
153
+ pl.imshow(grad_confs.cpu())
154
+ pl.show()
155
+ outconfs.append(conf * grad_confs.to(conf))
156
+ return outconfs
157
+
158
+ def _proj_pts3d(self, pts3d, cam2worlds, focals, pps):
159
+ """
160
+ Projection operation: from 3D points to 2D coordinates + depths
161
+ """
162
+ B = pts3d.shape[0]
163
+ assert pts3d.shape[0] == cam2worlds.shape[0]
164
+ # prepare extrinsics
165
+ R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
166
+ Rinv = R.transpose(-2, -1)
167
+ tinv = -Rinv @ t[..., None]
168
+
169
+ # prepare intrinsics
170
+ intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(focals.shape[0], 1, 1)
171
+ if len(focals.shape) == 1:
172
+ focals = torch.stack([focals, focals], dim=-1)
173
+ intrinsics[:, 0, 0] = focals[:, 0]
174
+ intrinsics[:, 1, 1] = focals[:, 1]
175
+ intrinsics[:, :2, -1] = pps
176
+ # Project
177
+ projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N]
178
+ projpts = projpts.transpose(-2, -1) # [B,N,3]
179
+ projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z)
180
+ return projpts
181
+
182
+ def _backproj_pts3d(self, in_depths=None, in_im_poses=None,
183
+ in_focals=None, in_pps=None, in_imshapes=None):
184
+ """
185
+ Backprojection operation: from image depths to 3D points
186
+ """
187
+ # Get depths and projection params if not provided
188
+ focals = self.optimizer.get_focals() if in_focals is None else in_focals
189
+ im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
190
+ depth = self._get_depthmaps() if in_depths is None else in_depths
191
+ pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
192
+ imshapes = self.imshapes if in_imshapes is None else in_imshapes
193
+ def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
194
+ dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[[i]]) for i in range(im_poses.shape[0])]
195
+
196
+ def autoprocess(x):
197
+ x = x[0]
198
+ return x.transpose(-2, -1) if len(x.shape) == 4 else x
199
+ return [geotrf(pose, autoprocess(pt)) for pose, pt in zip(im_poses, dm_to_3d)]
200
+
201
+ def _pts3d_to_depth(self, pts3d, cam2worlds, focals, pps):
202
+ """
203
+ Projection operation: from 3D points to 2D coordinates + depths
204
+ """
205
+ B = pts3d.shape[0]
206
+ assert pts3d.shape[0] == cam2worlds.shape[0]
207
+ # prepare extrinsics
208
+ R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
209
+ Rinv = R.transpose(-2, -1)
210
+ tinv = -Rinv @ t[..., None]
211
+
212
+ # prepare intrinsics
213
+ intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(self.optimizer.n_imgs, 1, 1)
214
+ if len(focals.shape) == 1:
215
+ focals = torch.stack([focals, focals], dim=-1)
216
+ intrinsics[:, 0, 0] = focals[:, 0]
217
+ intrinsics[:, 1, 1] = focals[:, 1]
218
+ intrinsics[:, :2, -1] = pps
219
+ # Project
220
+ projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N]
221
+ projpts = projpts.transpose(-2, -1) # [B,N,3]
222
+ projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z)
223
+ return projpts
224
+
225
+ def _depth_to_pts3d(self, in_depths=None, in_im_poses=None, in_focals=None, in_pps=None, in_imshapes=None):
226
+ """
227
+ Backprojection operation: from image depths to 3D points
228
+ """
229
+ # Get depths and projection params if not provided
230
+ focals = self.optimizer.get_focals() if in_focals is None else in_focals
231
+ im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
232
+ depth = self._get_depthmaps() if in_depths is None else in_depths
233
+ pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
234
+ imshapes = self.imshapes if in_imshapes is None else in_imshapes
235
+
236
+ def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
237
+
238
+ dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i + 1]) for i in range(im_poses.shape[0])]
239
+
240
+ def autoprocess(x):
241
+ x = x[0]
242
+ H, W, three = x.shape[:3]
243
+ return x.transpose(-2, -1) if len(x.shape) == 4 else x
244
+ return [geotrf(pp, autoprocess(pt)) for pp, pt in zip(im_poses, dm_to_3d)]
245
+
246
+ def _get_pts3d(self, TSDF_filtering_thresh=None, **kw):
247
+ """
248
+ return 3D points (possibly filtering depths with TSDF)
249
+ """
250
+ return self._backproj_pts3d(in_depths=self._get_depthmaps(TSDF_filtering_thresh=TSDF_filtering_thresh), **kw)
251
+
252
+ def _TSDF_postprocess_or_not(self, pts3d, depthmaps, confs, niter=1):
253
+ # Setup inner variables
254
+ self.imshapes = [im.shape[:2] for im in self.optimizer.imgs]
255
+ self.im_depthmaps = [dd.log().view(imshape) for dd, imshape in zip(depthmaps, self.imshapes)]
256
+ self.im_conf = confs
257
+
258
+ if self.TSDF_thresh > 0.:
259
+ # Create or update self.TSDF_im_depthmaps that contain logdepths filtered with TSDF
260
+ self._refine_depths_with_TSDF(self.TSDF_thresh, niter=niter)
261
+ depthmaps = [dd.exp() for dd in self.TSDF_im_depthmaps]
262
+ # Turn them into 3D points
263
+ pts3d = self._backproj_pts3d(in_depths=depthmaps)
264
+ depthmaps = [dd.flatten() for dd in depthmaps]
265
+ pts3d = [pp.view(-1, 3) for pp in pts3d]
266
+ return pts3d, depthmaps
267
+
268
+ def get_dense_pts3d(self, clean_depth=True):
269
+ return self.pts3d, self.depthmaps, self.confs
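
The core of `_TSDF_query` above is a per-view signed distance (predicted depth minus the query point's projected depth) that gets truncated and then confidence-averaged, ignoring views where the point is occluded or out of bounds. A toy, single-point sketch of that fusion with hand-picked numbers, not from the commit:

import torch

pred_depths = torch.tensor([2.00, 2.10, 5.00])   # depth observed by 3 views at the reprojected pixel
proj_depths = torch.tensor([2.05, 2.05, 2.05])   # depth of the query point in each view
confs = torch.tensor([1.0, 1.0, 0.2])            # per-view confidences
thresh = 0.1                                      # TSDF truncation threshold

sdf = pred_depths - proj_depths                   # signed distance along each viewing ray
unseen = sdf < -thresh                            # point hidden far behind the observed surface
tsdf = sdf.clip(-thresh, 1e20)                    # truncate on the negative side only, as above
weights = (~unseen).float() * confs.exp()
fused = (weights * tsdf).sum() / weights.sum()
print(float(fused))                               # weighted TSDF value for this query point
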
mast3r/cloud_opt/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/cloud_opt/utils/losses.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # losses for sparse ga
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import numpy as np
9
+
10
+
11
+ def l05_loss(x, y):
12
+ return torch.linalg.norm(x - y, dim=-1).sqrt()
13
+
14
+
15
+ def l1_loss(x, y):
16
+ return torch.linalg.norm(x - y, dim=-1)
17
+
18
+
19
+ def gamma_loss(gamma, mul=1, offset=None, clip=np.inf):
20
+ if offset is None:
21
+ if gamma == 1:
22
+ return l1_loss
23
+ # d(x**p)/dx = 1 ==> p * x**(p-1) == 1 ==> x = (1/p)**(1/(p-1))
24
+ offset = (1 / gamma)**(1 / (gamma - 1))
25
+
26
+ def loss_func(x, y):
27
+ return (mul * l1_loss(x, y).clip(max=clip) + offset) ** gamma - offset ** gamma
28
+ return loss_func
29
+
30
+
31
+ def meta_gamma_loss():
32
+ return lambda alpha: gamma_loss(alpha)
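
`gamma_loss` raises a shifted L1 residual to the power `gamma`; the offset in the comment above is chosen so the loss has unit slope at zero residual, which keeps gradients sensibly scaled for small errors. A quick numerical check of that property (hypothetical gamma value):

gamma = 0.5
offset = (1 / gamma) ** (1 / (gamma - 1))   # same offset formula as in gamma_loss

def f(r):                                    # loss as a function of the residual norm r
    return (r + offset) ** gamma - offset ** gamma

eps = 1e-6
print((f(eps) - f(0.0)) / eps)               # finite-difference slope at r = 0, ~1.0
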
mast3r/cloud_opt/utils/schedules.py ADDED
@@ -0,0 +1,17 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # lr schedules for sparse ga
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+
9
+
10
+ def linear_schedule(alpha, lr_base, lr_end=0):
11
+ lr = (1 - alpha) * lr_base + alpha * lr_end
12
+ return lr
13
+
14
+
15
+ def cosine_schedule(alpha, lr_base, lr_end=0):
16
+ lr = lr_end + (lr_base - lr_end) * (1 + np.cos(alpha * np.pi)) / 2
17
+ return lr
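
Both schedules map a normalized progress value `alpha` in [0, 1] to a learning rate between `lr_base` and `lr_end`; the cosine variant stays near `lr_base` longer and decays faster towards the end. A small comparison sketch (illustrative lr values):

import numpy as np

def linear_schedule(alpha, lr_base, lr_end=0):
    return (1 - alpha) * lr_base + alpha * lr_end

def cosine_schedule(alpha, lr_base, lr_end=0):
    return lr_end + (lr_base - lr_end) * (1 + np.cos(alpha * np.pi)) / 2

for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"{alpha:.2f}  linear={linear_schedule(alpha, 1e-2):.4f}  cosine={cosine_schedule(alpha, 1e-2):.4f}")
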
mast3r/colmap/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/colmap/database.py ADDED
@@ -0,0 +1,383 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # MASt3R to colmap export functions
6
+ # --------------------------------------------------------
7
+ import os
8
+ import torch
9
+ import copy
10
+ import numpy as np
11
+ import torchvision
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+ from scipy.cluster.hierarchy import DisjointSet
15
+ from scipy.spatial.transform import Rotation as R
16
+
17
+ from mast3r.utils.misc import hash_md5
18
+
19
+ from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns
20
+
21
+ import mast3r.utils.path_to_dust3r # noqa
22
+ from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf # noqa
23
+
24
+
25
+ def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz):
26
+ if viz:
27
+ from matplotlib import pyplot as pl
28
+
29
+ image_mean = torch.as_tensor(
30
+ [0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
31
+ image_std = torch.as_tensor(
32
+ [0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
33
+ rgb0 = img0['img'] * image_std + image_mean
34
+ rgb0 = torchvision.transforms.functional.to_pil_image(rgb0[0])
35
+ rgb0 = np.array(rgb0)
36
+
37
+ rgb1 = img1['img'] * image_std + image_mean
38
+ rgb1 = torchvision.transforms.functional.to_pil_image(rgb1[0])
39
+ rgb1 = np.array(rgb1)
40
+
41
+ imgs = [rgb0, rgb1]
42
+ # visualize a few matches
43
+ n_viz = 100
44
+ num_matches = matches_im0.shape[0]
45
+ match_idx_to_viz = np.round(np.linspace(
46
+ 0, num_matches - 1, n_viz)).astype(int)
47
+ viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
48
+
49
+ H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
50
+ rgb0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)),
51
+ (0, 0), (0, 0)), 'constant', constant_values=0)
52
+ rgb1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)),
53
+ (0, 0), (0, 0)), 'constant', constant_values=0)
54
+ img = np.concatenate((rgb0, rgb1), axis=1)
55
+ pl.figure()
56
+ pl.imshow(img)
57
+ cmap = pl.get_cmap('jet')
58
+ for ii in range(n_viz):
59
+ (x0, y0), (x1,
60
+ y1) = viz_matches_im0[ii].T, viz_matches_im1[ii].T
61
+ pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(ii /
62
+ (n_viz - 1)), scalex=False, scaley=False)
63
+ pl.show(block=True)
64
+
65
+ matches = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)]
66
+ imgs = [img0, img1]
67
+ imidx0 = img0['idx']
68
+ imidx1 = img1['idx']
69
+ ravel_matches = []
70
+ for j in range(2):
71
+ H, W = imgs[j]['true_shape'][0]
72
+ with np.errstate(invalid='ignore'):
73
+ qx, qy = matches[j].round().astype(np.int32).T
74
+ ravel_matches_j = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
75
+ ravel_matches.append(ravel_matches_j)
76
+ imidxj = imgs[j]['idx']
77
+ for m in ravel_matches_j:
78
+ if m not in im_keypoints[imidxj]:
79
+ im_keypoints[imidxj][m] = 0
80
+ im_keypoints[imidxj][m] += 1
81
+ imid0 = copy.deepcopy(image_to_colmap[imidx0]['colmap_imid'])
82
+ imid1 = copy.deepcopy(image_to_colmap[imidx1]['colmap_imid'])
83
+ if imid0 > imid1:
84
+ colmap_matches = np.stack([ravel_matches[1], ravel_matches[0]], axis=-1)
85
+ imid0, imid1 = imid1, imid0
86
+ imidx0, imidx1 = imidx1, imidx0
87
+ else:
88
+ colmap_matches = np.stack([ravel_matches[0], ravel_matches[1]], axis=-1)
89
+ colmap_matches = np.unique(colmap_matches, axis=0)
90
+ return imidx0, imidx1, colmap_matches
91
+
92
+
93
+ def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr,
94
+ is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'):
95
+ im_matches = {}
96
+ for i in range(len(pred1['pts3d'])):
97
+ imidx0 = pairs[i][0]['idx']
98
+ imidx1 = pairs[i][1]['idx']
99
+ if 'desc' in pred1: # mast3r
100
+ descs = [pred1['desc'][i], pred2['desc'][i]]
101
+ confidences = [pred1['desc_conf'][i], pred2['desc_conf'][i]]
102
+ desc_dim = descs[0].shape[-1]
103
+
104
+ if is_sparse:
105
+ corres = extract_correspondences_nonsym(descs[0], descs[1], confidences[0], confidences[1],
106
+ device=device, subsample=subsample, pixel_tol=pixel_tol)
107
+ conf = corres[2]
108
+ mask = conf >= conf_thr
109
+ matches_im0 = corres[0][mask].cpu().numpy()
110
+ matches_im1 = corres[1][mask].cpu().numpy()
111
+ else:
112
+ confidence_masks = [confidences[0] >=
113
+ conf_thr, confidences[1] >= conf_thr]
114
+ pts2d_list, desc_list = [], []
115
+ for j in range(2):
116
+ conf_j = confidence_masks[j].cpu().numpy().flatten()
117
+ true_shape_j = pairs[i][j]['true_shape'][0]
118
+ pts2d_j = xy_grid(
119
+ true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j]
120
+ desc_j = descs[j].detach().cpu(
121
+ ).numpy().reshape(-1, desc_dim)[conf_j]
122
+ pts2d_list.append(pts2d_j)
123
+ desc_list.append(desc_j)
124
+ if len(desc_list[0]) == 0 or len(desc_list[1]) == 0:
125
+ continue
126
+
127
+ nn0, nn1 = bruteforce_reciprocal_nns(desc_list[0], desc_list[1],
128
+ device=device, dist='dot', block_size=2**13)
129
+ reciprocal_in_P0 = (nn1[nn0] == np.arange(len(nn0)))
130
+
131
+ matches_im1 = pts2d_list[1][nn0][reciprocal_in_P0]
132
+ matches_im0 = pts2d_list[0][reciprocal_in_P0]
133
+ else:
134
+ pts3d = [pred1['pts3d'][i], pred2['pts3d_in_other_view'][i]]
135
+ confidences = [pred1['conf'][i], pred2['conf'][i]]
136
+
137
+ if is_sparse:
138
+ corres = extract_correspondences_nonsym(pts3d[0], pts3d[1], confidences[0], confidences[1],
139
+ device=device, subsample=subsample, pixel_tol=pixel_tol,
140
+ ptmap_key='3d')
141
+ conf = corres[2]
142
+ mask = conf >= conf_thr
143
+ matches_im0 = corres[0][mask].cpu().numpy()
144
+ matches_im1 = corres[1][mask].cpu().numpy()
145
+ else:
146
+ confidence_masks = [confidences[0] >=
147
+ conf_thr, confidences[1] >= conf_thr]
148
+ # find 2D-2D matches between the two images
149
+ pts2d_list, pts3d_list = [], []
150
+ for j in range(2):
151
+ conf_j = confidence_masks[j].cpu().numpy().flatten()
152
+ true_shape_j = pairs[i][j]['true_shape'][0]
153
+ pts2d_j = xy_grid(true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j]
154
+ pts3d_j = pts3d[j].detach().cpu().numpy().reshape(-1, 3)[conf_j]
155
+ pts2d_list.append(pts2d_j)
156
+ pts3d_list.append(pts3d_j)
157
+
158
+ PQ, PM = pts3d_list[0], pts3d_list[1]
159
+ if len(PQ) == 0 or len(PM) == 0:
160
+ continue
161
+ reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(
162
+ PQ, PM)
163
+
164
+ matches_im1 = pts2d_list[1][reciprocal_in_PM]
165
+ matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]
166
+
167
+ if len(matches_im0) == 0:
168
+ continue
169
+ imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
170
+ image_to_colmap, im_keypoints,
171
+ matches_im0, matches_im1, viz)
172
+ im_matches[(imidx0, imidx1)] = colmap_matches
173
+ return im_matches
174
+
175
+
176
+ def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample,
177
+ image_to_colmap, im_keypoints, conf_thr,
178
+ viz=False, device='cuda'):
179
+ im_matches = {}
180
+ for i in range(len(pairs)):
181
+ imidx0 = pairs[i][0]['idx']
182
+ imidx1 = pairs[i][1]['idx']
183
+
184
+ corres_idx1 = hash_md5(pairs[i][0]['instance'])
185
+ corres_idx2 = hash_md5(pairs[i][1]['instance'])
186
+
187
+ path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx1}-{corres_idx2}.pth'
188
+ if os.path.isfile(path_corres):
189
+ score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
190
+ else:
191
+ path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx2}-{corres_idx1}.pth'
192
+ score, (xy2, xy1, confs) = torch.load(path_corres, map_location=device)
193
+ mask = confs >= conf_thr
194
+ matches_im0 = xy1[mask].cpu().numpy()
195
+ matches_im1 = xy2[mask].cpu().numpy()
196
+
197
+ if len(matches_im0) == 0:
198
+ continue
199
+ imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
200
+ image_to_colmap, im_keypoints,
201
+ matches_im0, matches_im1, viz)
202
+ im_matches[(imidx0, imidx1)] = colmap_matches
203
+ return im_matches
204
+
205
+
206
+ def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model):
207
+ # add cameras/images to the db
208
+ # with the output of ga as prior
209
+ image_to_colmap = {}
210
+ im_keypoints = {}
211
+ for idx in range(len(image_paths)):
212
+ im_keypoints[idx] = {}
213
+ H, W = images[idx]["orig_shape"]
214
+ if focals is None:
215
+ focal_x = focal_y = 1.2 * max(W, H)
216
+ prior_focal_length = False
217
+ cx = W / 2.0
218
+ cy = H / 2.0
219
+ elif isinstance(focals[idx], np.ndarray) and len(focals[idx].shape) == 2:
220
+ # intrinsics
221
+ focal_x = focals[idx][0, 0]
222
+ focal_y = focals[idx][1, 1]
223
+ cx = focals[idx][0, 2] * images[idx]["to_orig"][0, 0]
224
+ cy = focals[idx][1, 2] * images[idx]["to_orig"][1, 1]
225
+ prior_focal_length = True
226
+ else:
227
+ focal_x = focal_y = float(focals[idx])
228
+ prior_focal_length = True
229
+ cx = W / 2.0
230
+ cy = H / 2.0
231
+ focal_x = focal_x * images[idx]["to_orig"][0, 0]
232
+ focal_y = focal_y * images[idx]["to_orig"][1, 1]
233
+
234
+ if camera_model == "SIMPLE_PINHOLE":
235
+ model_id = 0
236
+ focal = (focal_x + focal_y) / 2.0
237
+ params = np.asarray([focal, cx, cy], np.float64)
238
+ elif camera_model == "PINHOLE":
239
+ model_id = 1
240
+ params = np.asarray([focal_x, focal_y, cx, cy], np.float64)
241
+ elif camera_model == "SIMPLE_RADIAL":
242
+ model_id = 2
243
+ focal = (focal_x + focal_y) / 2.0
244
+ params = np.asarray([focal, cx, cy, 0.0], np.float64)
245
+ elif camera_model == "OPENCV":
246
+ model_id = 4
247
+ params = np.asarray([focal_x, focal_y, cx, cy, 0.0, 0.0, 0.0, 0.0], np.float64)
248
+ else:
249
+ raise ValueError(f"invalid camera model {camera_model}")
250
+
251
+ H, W = int(H), int(W)
252
+ # OPENCV camera model
253
+ camid = db.add_camera(
254
+ model_id, W, H, params, prior_focal_length=prior_focal_length)
255
+ if ga_world_to_cam is None:
256
+ prior_t = np.zeros(3)
257
+ prior_q = np.zeros(4)
258
+ else:
259
+ q = R.from_matrix(ga_world_to_cam[idx][:3, :3]).as_quat()
260
+ prior_t = ga_world_to_cam[idx][:3, 3]
261
+ prior_q = np.array([q[-1], q[0], q[1], q[2]])
262
+ imid = db.add_image(
263
+ image_paths[idx], camid, prior_q=prior_q, prior_t=prior_t)
264
+ image_to_colmap[idx] = {
265
+ 'colmap_imid': imid,
266
+ 'colmap_camid': camid
267
+ }
268
+ return image_to_colmap, im_keypoints
269
+
270
+
271
+ def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification):
272
+ colmap_image_pairs = []
273
+ # 2D-2D are quite dense
274
+ # we want to remove the very small tracks
275
+ # and export only kpt for which we have values
276
+ # build tracks
277
+ print("building tracks")
278
+ keypoints_to_track_id = {}
279
+ track_id_to_kpt_list = []
280
+ to_merge = []
281
+ for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()):
282
+ if imidx0 not in keypoints_to_track_id:
283
+ keypoints_to_track_id[imidx0] = {}
284
+ if imidx1 not in keypoints_to_track_id:
285
+ keypoints_to_track_id[imidx1] = {}
286
+
287
+ for m in colmap_matches:
288
+ if m[0] not in keypoints_to_track_id[imidx0] and m[1] not in keypoints_to_track_id[imidx1]:
289
+ # new pair of kpts never seen before
290
+ track_idx = len(track_id_to_kpt_list)
291
+ keypoints_to_track_id[imidx0][m[0]] = track_idx
292
+ keypoints_to_track_id[imidx1][m[1]] = track_idx
293
+ track_id_to_kpt_list.append(
294
+ [(imidx0, m[0]), (imidx1, m[1])])
295
+ elif m[1] not in keypoints_to_track_id[imidx1]:
296
+ # 0 has a track, not 1
297
+ track_idx = keypoints_to_track_id[imidx0][m[0]]
298
+ keypoints_to_track_id[imidx1][m[1]] = track_idx
299
+ track_id_to_kpt_list[track_idx].append((imidx1, m[1]))
300
+ elif m[0] not in keypoints_to_track_id[imidx0]:
301
+ # 1 has a track, not 0
302
+ track_idx = keypoints_to_track_id[imidx1][m[1]]
303
+ keypoints_to_track_id[imidx0][m[0]] = track_idx
304
+ track_id_to_kpt_list[track_idx].append((imidx0, m[0]))
305
+ else:
306
+ # both have tracks, merge them
307
+ track_idx0 = keypoints_to_track_id[imidx0][m[0]]
308
+ track_idx1 = keypoints_to_track_id[imidx1][m[1]]
309
+ if track_idx0 != track_idx1:
310
+ # let's deal with them later
311
+ to_merge.append((track_idx0, track_idx1))
312
+
313
+ # regroup merge targets
314
+ print("merging tracks")
315
+ unique = np.unique(to_merge)
316
+ tree = DisjointSet(unique)
317
+ for track_idx0, track_idx1 in tqdm(to_merge):
318
+ tree.merge(track_idx0, track_idx1)
319
+
320
+ subsets = tree.subsets()
321
+ print("applying merge")
322
+ for setvals in tqdm(subsets):
323
+ new_trackid = len(track_id_to_kpt_list)
324
+ kpt_list = []
325
+ for track_idx in setvals:
326
+ kpt_list.extend(track_id_to_kpt_list[track_idx])
327
+ for imidx, kpid in track_id_to_kpt_list[track_idx]:
328
+ keypoints_to_track_id[imidx][kpid] = new_trackid
329
+ track_id_to_kpt_list.append(kpt_list)
330
+
331
+ # binc = np.bincount([len(v) for v in track_id_to_kpt_list])
332
+ # nonzero = np.nonzero(binc)
333
+ # nonzerobinc = binc[nonzero[0]]
334
+ # print(nonzero[0].tolist())
335
+ # print(nonzerobinc)
336
+ num_valid_tracks = sum(
337
+ [1 for v in track_id_to_kpt_list if len(v) >= min_len_track])
338
+
339
+ keypoints_to_idx = {}
340
+ print(f"squashing keypoints - {num_valid_tracks} valid tracks")
341
+ for imidx, keypoints_imid in tqdm(im_keypoints.items()):
342
+ imid = image_to_colmap[imidx]['colmap_imid']
343
+ keypoints_kept = []
344
+ keypoints_to_idx[imidx] = {}
345
+ for kp in keypoints_imid.keys():
346
+ if kp not in keypoints_to_track_id[imidx]:
347
+ continue
348
+ track_idx = keypoints_to_track_id[imidx][kp]
349
+ track_length = len(track_id_to_kpt_list[track_idx])
350
+ if track_length < min_len_track:
351
+ continue
352
+ keypoints_to_idx[imidx][kp] = len(keypoints_kept)
353
+ keypoints_kept.append(kp)
354
+ if len(keypoints_kept) == 0:
355
+ continue
356
+ keypoints_kept = np.array(keypoints_kept)
357
+ keypoints_kept = np.unravel_index(keypoints_kept, images[imidx]['true_shape'][0])[
358
+ 0].base[:, ::-1].copy().astype(np.float32)
359
+ # rescale coordinates
360
+ keypoints_kept[:, 0] += 0.5
361
+ keypoints_kept[:, 1] += 0.5
362
+ keypoints_kept = geotrf(images[imidx]['to_orig'], keypoints_kept, norm=True)
363
+
364
+ H, W = images[imidx]['orig_shape']
365
+ keypoints_kept[:, 0] = keypoints_kept[:, 0].clip(min=0, max=W - 0.01)
366
+ keypoints_kept[:, 1] = keypoints_kept[:, 1].clip(min=0, max=H - 0.01)
367
+
368
+ db.add_keypoints(imid, keypoints_kept)
369
+
370
+ print("exporting im_matches")
371
+ for (imidx0, imidx1), colmap_matches in im_matches.items():
372
+ imid0, imid1 = image_to_colmap[imidx0]['colmap_imid'], image_to_colmap[imidx1]['colmap_imid']
373
+ assert imid0 < imid1
374
+ final_matches = np.array([[keypoints_to_idx[imidx0][m[0]], keypoints_to_idx[imidx1][m[1]]]
375
+ for m in colmap_matches
376
+ if m[0] in keypoints_to_idx[imidx0] and m[1] in keypoints_to_idx[imidx1]])
377
+ if len(final_matches) > 0:
378
+ colmap_image_pairs.append(
379
+ (images[imidx0]['instance'], images[imidx1]['instance']))
380
+ db.add_matches(imid0, imid1, final_matches)
381
+ if skip_geometric_verification:
382
+ db.add_two_view_geometry(imid0, imid1, final_matches)
383
+ return colmap_image_pairs
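
`export_matches` above grows feature tracks by chaining pairwise keypoint matches, then merges tracks that end up sharing a keypoint with a union-find structure (`scipy.cluster.hierarchy.DisjointSet`, the same class imported in the diff). A minimal sketch of that merge step with toy track ids:

import numpy as np
from scipy.cluster.hierarchy import DisjointSet

# pairs of track ids found to describe the same physical point (toy values)
to_merge = [(0, 3), (3, 5), (1, 4)]

tree = DisjointSet(np.unique(to_merge).tolist())
for a, b in to_merge:
    tree.merge(a, b)

print(tree.subsets())   # e.g. [{0, 3, 5}, {1, 4}] -> two merged tracks
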
mast3r/datasets/__init__.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ from .base.mast3r_base_stereo_view_dataset import MASt3RBaseStereoViewDataset
5
+
6
+ import mast3r.utils.path_to_dust3r # noqa
7
+ from dust3r.datasets.arkitscenes import ARKitScenes as DUSt3R_ARKitScenes # noqa
8
+ from dust3r.datasets.blendedmvs import BlendedMVS as DUSt3R_BlendedMVS # noqa
9
+ from dust3r.datasets.co3d import Co3d as DUSt3R_Co3d # noqa
10
+ from dust3r.datasets.megadepth import MegaDepth as DUSt3R_MegaDepth # noqa
11
+ from dust3r.datasets.scannetpp import ScanNetpp as DUSt3R_ScanNetpp # noqa
12
+ from dust3r.datasets.staticthings3d import StaticThings3D as DUSt3R_StaticThings3D # noqa
13
+ from dust3r.datasets.waymo import Waymo as DUSt3R_Waymo # noqa
14
+ from dust3r.datasets.wildrgbd import WildRGBD as DUSt3R_WildRGBD # noqa
15
+
16
+
17
+ class ARKitScenes(DUSt3R_ARKitScenes, MASt3RBaseStereoViewDataset):
18
+ def __init__(self, *args, split, ROOT, **kwargs):
19
+ super().__init__(*args, split=split, ROOT=ROOT, **kwargs)
20
+ self.is_metric_scale = True
21
+
22
+
23
+ class BlendedMVS(DUSt3R_BlendedMVS, MASt3RBaseStereoViewDataset):
24
+ def __init__(self, *args, ROOT, split=None, **kwargs):
25
+ super().__init__(*args, ROOT=ROOT, split=split, **kwargs)
26
+ self.is_metric_scale = False
27
+
28
+
29
+ class Co3d(DUSt3R_Co3d, MASt3RBaseStereoViewDataset):
30
+ def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
31
+ super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
32
+ self.is_metric_scale = False
33
+
34
+
35
+ class MegaDepth(DUSt3R_MegaDepth, MASt3RBaseStereoViewDataset):
36
+ def __init__(self, *args, split, ROOT, **kwargs):
37
+ super().__init__(*args, split=split, ROOT=ROOT, **kwargs)
38
+ self.is_metric_scale = False
39
+
40
+
41
+ class ScanNetpp(DUSt3R_ScanNetpp, MASt3RBaseStereoViewDataset):
42
+ def __init__(self, *args, ROOT, **kwargs):
43
+ super().__init__(*args, ROOT=ROOT, **kwargs)
44
+ self.is_metric_scale = True
45
+
46
+
47
+ class StaticThings3D(DUSt3R_StaticThings3D, MASt3RBaseStereoViewDataset):
48
+ def __init__(self, ROOT, *args, mask_bg='rand', **kwargs):
49
+ super().__init__(ROOT, *args, mask_bg=mask_bg, **kwargs)
50
+ self.is_metric_scale = False
51
+
52
+
53
+ class Waymo(DUSt3R_Waymo, MASt3RBaseStereoViewDataset):
54
+ def __init__(self, *args, ROOT, **kwargs):
55
+ super().__init__(*args, ROOT=ROOT, **kwargs)
56
+ self.is_metric_scale = True
57
+
58
+
59
+ class WildRGBD(DUSt3R_WildRGBD, MASt3RBaseStereoViewDataset):
60
+ def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
61
+ super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
62
+ self.is_metric_scale = True
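
Each dataset above only combines the existing DUSt3R loader (which provides `_get_views`) with `MASt3RBaseStereoViewDataset` (which provides `__getitem__`, augmentation and correspondence extraction) and sets whether its depths are metric. A stripped-down sketch of the same mixin pattern with hypothetical stand-in classes (ToyLoader / ToyBase / ToyDataset are not part of the repo):

class ToyBase:
    # plays the role of MASt3RBaseStereoViewDataset: shared __getitem__ logic
    def __getitem__(self, idx):
        views = self._get_views(idx)            # supplied by the loader mixin
        for v in views:
            v['is_metric_scale'] = self.is_metric_scale
        return views

class ToyLoader:
    # plays the role of a DUSt3R dataset: knows how to load raw views
    def _get_views(self, idx):
        return [{'idx': idx}, {'idx': idx + 1}]

class ToyDataset(ToyLoader, ToyBase):
    def __init__(self):
        self.is_metric_scale = True             # the only thing the subclasses above set

print(ToyDataset()[0])
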
mast3r/datasets/base/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/datasets/base/mast3r_base_stereo_view_dataset.py ADDED
@@ -0,0 +1,355 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # base class for implementing datasets
6
+ # --------------------------------------------------------
7
+ import PIL.Image
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+ import torch
11
+ import copy
12
+
13
+ from mast3r.datasets.utils.cropping import (extract_correspondences_from_pts3d,
14
+ gen_random_crops, in2d_rect, crop_to_homography)
15
+
16
+ import mast3r.utils.path_to_dust3r # noqa
17
+ from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset, view_name, is_good_type # noqa
18
+ from dust3r.datasets.utils.transforms import ImgNorm
19
+ from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf, depthmap_to_camera_coordinates
20
+ import dust3r.datasets.utils.cropping as cropping
21
+
22
+
23
+ class MASt3RBaseStereoViewDataset(BaseStereoViewDataset):
24
+ def __init__(self, *, # only keyword arguments
25
+ split=None,
26
+ resolution=None, # square_size or (width, height) or list of [(width,height), ...]
27
+ transform=ImgNorm,
28
+ aug_crop=False,
29
+ aug_swap=False,
30
+ aug_monocular=False,
31
+ aug_portrait_or_landscape=True, # automatic choice between landscape/portrait when possible
32
+ aug_rot90=False,
33
+ n_corres=0,
34
+ nneg=0,
35
+ n_tentative_crops=4,
36
+ seed=None):
37
+ super().__init__(split=split, resolution=resolution, transform=transform, aug_crop=aug_crop, seed=seed)
38
+ self.is_metric_scale = False # by default a dataset is not metric scale, subclasses can override this
39
+
40
+ self.aug_swap = aug_swap
41
+ self.aug_monocular = aug_monocular
42
+ self.aug_portrait_or_landscape = aug_portrait_or_landscape
43
+ self.aug_rot90 = aug_rot90
44
+
45
+ self.n_corres = n_corres
46
+ self.nneg = nneg
47
+ assert self.n_corres == 'all' or isinstance(self.n_corres, int) or (isinstance(self.n_corres, list) and len(
48
+ self.n_corres) == self.num_views), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}"
49
+ assert self.nneg == 0 or self.n_corres != 'all'
50
+ self.n_tentative_crops = n_tentative_crops
51
+
52
+ def _swap_view_aug(self, views):
53
+ if self._rng.random() < 0.5:
54
+ views.reverse()
55
+
56
+ def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
57
+ """ This function:
58
+ - first downsizes the image with LANCZOS interpolation,
59
+ which is better than bilinear interpolation when downscaling
60
+ """
61
+ if not isinstance(image, PIL.Image.Image):
62
+ image = PIL.Image.fromarray(image)
63
+
64
+ # transpose the resolution if necessary
65
+ W, H = image.size # new size
66
+ assert resolution[0] >= resolution[1]
67
+ if H > 1.1 * W:
68
+ # image is portrait mode
69
+ resolution = resolution[::-1]
70
+ elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
71
+ # image is square, so we choose (portrait, landscape) randomly
72
+ if rng.integers(2) and self.aug_portrait_or_landscape:
73
+ resolution = resolution[::-1]
74
+
75
+ # high-quality Lanczos down-scaling
76
+ target_resolution = np.array(resolution)
77
+ image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
78
+
79
+ # actual cropping (if necessary) with bilinear interpolation
80
+ offset_factor = 0.5
81
+ intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
82
+ crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
83
+ image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
84
+
85
+ return image, depthmap, intrinsics2
86
+
87
+ def generate_crops_from_pair(self, view1, view2, resolution, aug_crop_arg, n_crops=4, rng=np.random):
88
+ views = [view1, view2]
89
+
90
+ if aug_crop_arg is False:
91
+ # compatibility
92
+ for i in range(2):
93
+ view = views[i]
94
+ view['img'], view['depthmap'], view['camera_intrinsics'] = self._crop_resize_if_necessary(view['img'],
95
+ view['depthmap'],
96
+ view['camera_intrinsics'],
97
+ resolution,
98
+ rng=rng)
99
+ view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'],
100
+ view['camera_intrinsics'],
101
+ view['camera_pose'])
102
+ return
103
+
104
+ # extract correspondences
105
+ corres = extract_correspondences_from_pts3d(*views, target_n_corres=None, rng=rng)
106
+
107
+ # generate 4 random crops in each view
108
+ view_crops = []
109
+ crops_resolution = []
110
+ corres_msks = []
111
+ for i in range(2):
112
+
113
+ if aug_crop_arg == 'auto':
114
+ S = min(views[i]['img'].size)
115
+ R = min(resolution)
116
+ aug_crop = S * (S - R) // R
117
+ aug_crop = max(.1 * S, aug_crop) # for cropping: augment scale of at least 10%, and more if possible
118
+ else:
119
+ aug_crop = aug_crop_arg
120
+
121
+ # transpose the target resolution if necessary
122
+ assert resolution[0] >= resolution[1]
123
+ W, H = imsize = views[i]['img'].size
124
+ crop_resolution = resolution
125
+ if H > 1.1 * W:
126
+ # image is portrait mode
127
+ crop_resolution = resolution[::-1]
128
+ elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
129
+ # image is square, so we choose (portrait, landscape) randomly
130
+ if rng.integers(2):
131
+ crop_resolution = resolution[::-1]
132
+
133
+ crops = gen_random_crops(imsize, n_crops, crop_resolution, aug_crop=aug_crop, rng=rng)
134
+ view_crops.append(crops)
135
+ crops_resolution.append(crop_resolution)
136
+
137
+ # compute correspondences
138
+ corres_msks.append(in2d_rect(corres[i], crops))
139
+
140
+ # compute IoU for each
141
+ intersection = np.float32(corres_msks[0]).T @ np.float32(corres_msks[1])
142
+ # select best pair of crops
143
+ best = np.unravel_index(intersection.argmax(), (n_crops, n_crops))
144
+ crops = [view_crops[i][c] for i, c in enumerate(best)]
145
+
146
+ # crop with the homography
147
+ for i in range(2):
148
+ view = views[i]
149
+ imsize, K_new, R, H = crop_to_homography(view['camera_intrinsics'], crops[i], crops_resolution[i])
150
+ # imsize, K_new, H = upscale_homography(imsize, resolution, K_new, H)
151
+
152
+ # update camera params
153
+ K_old = view['camera_intrinsics']
154
+ view['camera_intrinsics'] = K_new
155
+ view['camera_pose'] = view['camera_pose'].copy()
156
+ view['camera_pose'][:3, :3] = view['camera_pose'][:3, :3] @ R
157
+
158
+ # apply homography to image and depthmap
159
+ homo8 = (H / H[2, 2]).ravel().tolist()[:8]
160
+ view['img'] = view['img'].transform(imsize, Image.Transform.PERSPECTIVE,
161
+ homo8,
162
+ resample=Image.Resampling.BICUBIC)
163
+
164
+ depthmap2 = depthmap_to_camera_coordinates(view['depthmap'], K_old)[0] @ R[:, 2]
165
+ view['depthmap'] = np.array(Image.fromarray(depthmap2).transform(
166
+ imsize, Image.Transform.PERSPECTIVE, homo8))
167
+
168
+ if 'track_labels' in view:
169
+ # convert from uint64 --> uint32, because PIL.Image cannot handle uint64
170
+ mapping, track_labels = np.unique(view['track_labels'], return_inverse=True)
171
+ track_labels = track_labels.astype(np.uint32).reshape(view['track_labels'].shape)
172
+
173
+ # homography transformation
174
+ res = np.array(Image.fromarray(track_labels).transform(imsize, Image.Transform.PERSPECTIVE, homo8))
175
+ view['track_labels'] = mapping[res] # mapping back to uint64
176
+
177
+ # recompute 3d points from scratch
178
+ view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'],
179
+ view['camera_intrinsics'],
180
+ view['camera_pose'])
181
+
182
+ def __getitem__(self, idx):
183
+ if isinstance(idx, tuple):
184
+ # the idx is specifying the aspect-ratio
185
+ idx, ar_idx = idx
186
+ else:
187
+ assert len(self._resolutions) == 1
188
+ ar_idx = 0
189
+
190
+ # set-up the rng
191
+ if self.seed: # reseed for each __getitem__
192
+ self._rng = np.random.default_rng(seed=self.seed + idx)
193
+ elif not hasattr(self, '_rng'):
194
+ seed = torch.initial_seed() # this is different for each dataloader process
195
+ self._rng = np.random.default_rng(seed=seed)
196
+
197
+ # over-loaded code
198
+ resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
199
+ views = self._get_views(idx, resolution, self._rng)
200
+ assert len(views) == self.num_views
201
+
202
+ for v, view in enumerate(views):
203
+ assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
204
+ view['idx'] = (idx, ar_idx, v)
205
+ view['is_metric_scale'] = self.is_metric_scale
206
+
207
+ assert 'camera_intrinsics' in view
208
+ if 'camera_pose' not in view:
209
+ view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
210
+ else:
211
+ assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
212
+ assert 'pts3d' not in view
213
+ assert 'valid_mask' not in view
214
+ assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
215
+
216
+ pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
217
+
218
+ view['pts3d'] = pts3d
219
+ view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
220
+
221
+ self.generate_crops_from_pair(views[0], views[1], resolution=resolution,
222
+ aug_crop_arg=self.aug_crop,
223
+ n_crops=self.n_tentative_crops,
224
+ rng=self._rng)
225
+ for v, view in enumerate(views):
226
+ # encode the image
227
+ width, height = view['img'].size
228
+ view['true_shape'] = np.int32((height, width))
229
+ view['img'] = self.transform(view['img'])
230
+ # Pixels for which depth is fundamentally undefined
231
+ view['sky_mask'] = (view['depthmap'] < 0)
232
+
233
+ if self.aug_swap:
234
+ self._swap_view_aug(views)
235
+
236
+ if self.aug_monocular:
237
+ if self._rng.random() < self.aug_monocular:
238
+ views = [copy.deepcopy(views[0]) for _ in range(len(views))]
239
+
240
+ # automatic extraction of correspondences from pts3d + pose
241
+ if self.n_corres > 0 and ('corres' not in view):
242
+ corres1, corres2, valid = extract_correspondences_from_pts3d(*views, self.n_corres,
243
+ self._rng, nneg=self.nneg)
244
+ views[0]['corres'] = corres1
245
+ views[1]['corres'] = corres2
246
+ views[0]['valid_corres'] = valid
247
+ views[1]['valid_corres'] = valid
248
+
249
+ if self.aug_rot90 is False:
250
+ pass
251
+ elif self.aug_rot90 == 'same':
252
+ rotate_90(views, k=self._rng.choice(4))
253
+ elif self.aug_rot90 == 'diff':
254
+ rotate_90(views[:1], k=self._rng.choice(4))
255
+ rotate_90(views[1:], k=self._rng.choice(4))
256
+ else:
257
+ raise ValueError(f'Bad value for {self.aug_rot90=}')
258
+
259
+ # check data-types metric_scale
260
+ for v, view in enumerate(views):
261
+ if 'corres' not in view:
262
+ view['corres'] = np.full((self.n_corres, 2), np.nan, dtype=np.float32)
263
+
264
+ # check all datatypes
265
+ for key, val in view.items():
266
+ res, err_msg = is_good_type(key, val)
267
+ assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
268
+ K = view['camera_intrinsics']
269
+
270
+ # check shapes
271
+ assert view['depthmap'].shape == view['img'].shape[1:]
272
+ assert view['depthmap'].shape == view['pts3d'].shape[:2]
273
+ assert view['depthmap'].shape == view['valid_mask'].shape
274
+
275
+ # last thing done!
276
+ for view in views:
277
+ # transpose to make sure all views are the same size
278
+ transpose_to_landscape(view)
279
+ # this allows checking whether the RNG is in the same state each time
280
+ view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
281
+
282
+ return views
283
+
284
+
285
+ def transpose_to_landscape(view, revert=False):
286
+ height, width = view['true_shape']
287
+
288
+ if width < height:
289
+ if revert:
290
+ height, width = width, height
291
+
292
+ # rectify portrait to landscape
293
+ assert view['img'].shape == (3, height, width)
294
+ view['img'] = view['img'].swapaxes(1, 2)
295
+
296
+ assert view['valid_mask'].shape == (height, width)
297
+ view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
298
+
299
+ assert view['sky_mask'].shape == (height, width)
300
+ view['sky_mask'] = view['sky_mask'].swapaxes(0, 1)
301
+
302
+ assert view['depthmap'].shape == (height, width)
303
+ view['depthmap'] = view['depthmap'].swapaxes(0, 1)
304
+
305
+ assert view['pts3d'].shape == (height, width, 3)
306
+ view['pts3d'] = view['pts3d'].swapaxes(0, 1)
307
+
308
+ # transpose x and y pixels
309
+ view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
310
+
311
+ # transpose correspondences x and y
312
+ view['corres'] = view['corres'][:, [1, 0]]
313
+
314
+
315
+ def rotate_90(views, k=1):
316
+ from scipy.spatial.transform import Rotation
317
+ # print('rotation =', k)
318
+
319
+ RT = np.eye(4, dtype=np.float32)
320
+ RT[:3, :3] = Rotation.from_euler('z', 90 * k, degrees=True).as_matrix()
321
+
322
+ for view in views:
323
+ view['img'] = torch.rot90(view['img'], k=k, dims=(-2, -1)) # WARNING!! dims=(-1,-2) != dims=(-2,-1)
324
+ view['depthmap'] = np.rot90(view['depthmap'], k=k).copy()
325
+ view['camera_pose'] = view['camera_pose'] @ RT
326
+
327
+ RT2 = np.eye(3, dtype=np.float32)
328
+ RT2[:2, :2] = RT[:2, :2] * ((1, -1), (-1, 1))
329
+ H, W = view['depthmap'].shape
330
+ if k % 4 == 0:
331
+ pass
332
+ elif k % 4 == 1:
333
+ # top-left (0,0) pixel becomes (0,H-1)
334
+ RT2[:2, 2] = (0, H - 1)
335
+ elif k % 4 == 2:
336
+ # top-left (0,0) pixel becomes (W-1,H-1)
337
+ RT2[:2, 2] = (W - 1, H - 1)
338
+ elif k % 4 == 3:
339
+ # top-left (0,0) pixel becomes (W-1,0)
340
+ RT2[:2, 2] = (W - 1, 0)
341
+ else:
342
+ raise ValueError(f'Bad value for {k=}')
343
+
344
+ view['camera_intrinsics'][:2, 2] = geotrf(RT2, view['camera_intrinsics'][:2, 2])
345
+ if k % 2 == 1:
346
+ K = view['camera_intrinsics']
347
+ np.fill_diagonal(K, K.diagonal()[[1, 0, 2]])
348
+
349
+ pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
350
+ view['pts3d'] = pts3d
351
+ view['valid_mask'] = np.rot90(view['valid_mask'], k=k).copy()
352
+ view['sky_mask'] = np.rot90(view['sky_mask'], k=k).copy()
353
+
354
+ view['corres'] = geotrf(RT2, view['corres']).round().astype(view['corres'].dtype)
355
+ view['true_shape'] = np.int32((H, W))
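
`transpose_to_landscape` above rectifies portrait views by swapping the image axes of every per-view array and permuting the intrinsics rows accordingly. A toy check of the underlying equivalence, using a dummy depth map (not part of the commit):

import numpy as np

H, W = 5, 3                                     # portrait: width < height
K = np.array([[100.,   0., W / 2],
              [  0., 120., H / 2],
              [  0.,   0., 1.]])

depth = np.arange(H * W, dtype=np.float32).reshape(H, W)
depth_T = depth.swapaxes(0, 1)                  # landscape version, shape (W, H)
K_T = K[[1, 0, 2]]                              # swap the x and y rows, as in the code above

# a pixel (x, y) of the portrait view is pixel (y, x) of the landscape view
x, y = 1, 4
assert depth[y, x] == depth_T[x, y]
print(K_T)
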
mast3r/datasets/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/datasets/utils/cropping.py ADDED
@@ -0,0 +1,219 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # cropping/match extraction
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import mast3r.utils.path_to_dust3r # noqa
9
+ from dust3r.utils.device import to_numpy
10
+ from dust3r.utils.geometry import inv, geotrf
11
+
12
+
13
+ def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
14
+ is_reciprocal1 = (corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)))
15
+ pos1 = is_reciprocal1.nonzero()[0]
16
+ pos2 = corres_1_to_2[pos1]
17
+ if ret_recip:
18
+ return is_reciprocal1, pos1, pos2
19
+ return pos1, pos2
20
+
21
+
22
+ def extract_correspondences_from_pts3d(view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0):
23
+ view1, view2 = to_numpy((view1, view2))
24
+ # project pixels from image1 --> 3d points --> image2 pixels
25
+ shape1, corres1_to_2 = reproject_view(view1['pts3d'], view2)
26
+ shape2, corres2_to_1 = reproject_view(view2['pts3d'], view1)
27
+
28
+ # compute reciprocal correspondences:
29
+ # pos1 == valid pixels (correspondences) in image1
30
+ is_reciprocal1, pos1, pos2 = reciprocal_1d(corres1_to_2, corres2_to_1, ret_recip=True)
31
+ is_reciprocal2 = (corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)))
32
+
33
+ if target_n_corres is None:
34
+ if ret_xy:
35
+ pos1 = unravel_xy(pos1, shape1)
36
+ pos2 = unravel_xy(pos2, shape2)
37
+ return pos1, pos2
38
+
39
+ available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
40
+ target_n_positives = int(target_n_corres * (1 - nneg))
41
+ n_positives = min(len(pos1), target_n_positives)
42
+ n_negatives = min(target_n_corres - n_positives, available_negatives)
43
+
44
+ if n_negatives + n_positives != target_n_corres:
45
+ # should be really rare => when there are not enough negatives
46
+ # in that case, break nneg and add a few more positives ?
47
+ n_positives = target_n_corres - n_negatives
48
+ assert n_positives <= len(pos1)
49
+
50
+ assert n_positives <= len(pos1)
51
+ assert n_positives <= len(pos2)
52
+ assert n_negatives <= (~is_reciprocal1).sum()
53
+ assert n_negatives <= (~is_reciprocal2).sum()
54
+ assert n_positives + n_negatives == target_n_corres
55
+
56
+ valid = np.ones(n_positives, dtype=bool)
57
+ if n_positives < len(pos1):
58
+ # random sub-sampling of valid correspondences
59
+ perm = rng.permutation(len(pos1))[:n_positives]
60
+ pos1 = pos1[perm]
61
+ pos2 = pos2[perm]
62
+
63
+ if n_negatives > 0:
64
+ # add false correspondences if not enough
65
+ def norm(p): return p / p.sum()
66
+ pos1 = np.r_[pos1, rng.choice(shape1[0] * shape1[1], size=n_negatives, replace=False, p=norm(~is_reciprocal1))]
67
+ pos2 = np.r_[pos2, rng.choice(shape2[0] * shape2[1], size=n_negatives, replace=False, p=norm(~is_reciprocal2))]
68
+ valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
69
+
70
+ # convert (x+W*y) back to 2d (x,y) coordinates
71
+ if ret_xy:
72
+ pos1 = unravel_xy(pos1, shape1)
73
+ pos2 = unravel_xy(pos2, shape2)
74
+ return pos1, pos2, valid
75
+
76
+
77
+ def reproject_view(pts3d, view2):
78
+ shape = view2['pts3d'].shape[:2]
79
+ return reproject(pts3d, view2['camera_intrinsics'], inv(view2['camera_pose']), shape)
80
+
81
+
82
+ def reproject(pts3d, K, world2cam, shape):
83
+ H, W, THREE = pts3d.shape
84
+ assert THREE == 3
85
+
86
+ # reproject in camera2 space
87
+ with np.errstate(divide='ignore', invalid='ignore'):
88
+ pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
89
+
90
+ # quantize to pixel positions
91
+ return (H, W), ravel_xy(pos, shape)
92
+
93
+
94
+ def ravel_xy(pos, shape):
95
+ H, W = shape
96
+ with np.errstate(invalid='ignore'):
97
+ qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
98
+ quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
99
+ return quantized_pos
100
+
101
+
102
+ def unravel_xy(pos, shape):
103
+ # convert (x+W*y) back to 2d (x,y) coordinates
104
+ return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
105
+
106
+
107
+ def _rotation_origin_to_pt(target):
108
+ """ Align the origin (0,0,1) with the target point (x,y,1) in projective space.
109
+ Method: rotate z to put target on (x'+,0,1), then rotate on Y to get (0,0,1) and un-rotate z.
110
+ """
111
+ from scipy.spatial.transform import Rotation
112
+ x, y = target
113
+ rot_z = np.arctan2(y, x)
114
+ rot_y = np.arctan(np.linalg.norm(target))
115
+ R = Rotation.from_euler('ZYZ', [rot_z, rot_y, -rot_z]).as_matrix()
116
+ return R
117
+
118
+
119
+ def _dotmv(Trf, pts, ncol=None, norm=False):
120
+ assert Trf.ndim >= 2
121
+ ncol = ncol or pts.shape[-1]
122
+
123
+ # adapt shape if necessary
124
+ output_reshape = pts.shape[:-1]
125
+ if Trf.ndim >= 3:
126
+ n = Trf.ndim - 2
127
+ assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
128
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
129
+
130
+ if pts.ndim > Trf.ndim:
131
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
132
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
133
+ elif pts.ndim == 2:
134
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
135
+ pts = pts[:, None, :]
136
+
137
+ if pts.shape[-1] + 1 == Trf.shape[-1]:
138
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
139
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
140
+
141
+ elif pts.shape[-1] == Trf.shape[-1]:
142
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
143
+ pts = pts @ Trf
144
+ else:
145
+ pts = Trf @ pts.T
146
+ if pts.ndim >= 2:
147
+ pts = pts.swapaxes(-1, -2)
148
+
149
+ if norm:
150
+ pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
151
+ if norm != 1:
152
+ pts *= norm
153
+
154
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
155
+ return res
156
+
157
+
158
+ def crop_to_homography(K, crop, target_size=None):
159
+ """ Given an image and its intrinsics,
160
+ we want to replicate a rectangular crop with an homography,
161
+ so that the principal point of the new 'crop' is centered.
162
+ """
163
+ # build intrinsics for the crop
164
+ crop = np.round(crop)
165
+ crop_size = crop[2:] - crop[:2]
166
+ K2 = K.copy() # same focal
167
+ K2[:2, 2] = crop_size / 2 # new principal point is perfectly centered
168
+
169
+ # find which corner is the most far-away from current principal point
170
+ # so that the final homography does not go over the image borders
171
+ corners = crop.reshape(-1, 2)
172
+ corner_idx = np.abs(corners - K[:2, 2]).argmax(0)
173
+ corner = corners[corner_idx, [0, 1]]
174
+ # align with the corresponding corner from the target view
175
+ corner2 = np.c_[[0, 0], crop_size][[0, 1], corner_idx]
176
+
177
+ old_pt = _dotmv(np.linalg.inv(K), corner, norm=1)
178
+ new_pt = _dotmv(np.linalg.inv(K2), corner2, norm=1)
179
+ R = _rotation_origin_to_pt(old_pt) @ np.linalg.inv(_rotation_origin_to_pt(new_pt))
180
+
181
+ if target_size is not None:
182
+ imsize = target_size
183
+ target_size = np.asarray(target_size)
184
+ scaling = min(target_size / crop_size)
185
+ K2[:2] *= scaling
186
+ K2[:2, 2] = target_size / 2
187
+ else:
188
+ imsize = tuple(np.int32(crop_size).tolist())
189
+
190
+ return imsize, K2, R, K @ R @ np.linalg.inv(K2)
191
+
192
+
193
+ def gen_random_crops(imsize, n_crops, resolution, aug_crop, rng=np.random):
194
+ """ Generate random crops of size=resolution,
195
+ for an input image upscaled to (imsize + randint(0 , aug_crop))
196
+ """
197
+ resolution_crop = np.array(resolution) * min(np.array(imsize) / resolution)
198
+
199
+ # (virtually) upscale the input image
200
+ # scaling = rng.uniform(1, 1+(aug_crop+1)/min(imsize))
201
+ scaling = np.exp(rng.uniform(0, np.log(1 + aug_crop / min(imsize))))
202
+ imsize2 = np.int32(np.array(imsize) * scaling)
203
+
204
+ # generate some random crops
205
+ topleft = rng.random((n_crops, 2)) * (imsize2 - resolution_crop)
206
+ crops = np.c_[topleft, topleft + resolution_crop]
207
+ # print(f"{scaling=}, {topleft=}")
208
+ # reduce the resolution to come back to original size
209
+ crops /= scaling
210
+ return crops
211
+
212
+
213
+ def in2d_rect(corres, crops):
214
+ # corres = (N,2)
215
+ # crops = (M,4)
216
+ # output = (N, M)
217
+ is_sup = (corres[:, None] >= crops[None, :, 0:2])
218
+ is_inf = (corres[:, None] < crops[None, :, 2:4])
219
+ return (is_sup & is_inf).all(axis=-1)
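A short sketch of in2d_rect (illustrative only): for every 2D correspondence it returns one boolean per crop, telling whether the point falls inside that crop rectangle.

import numpy as np

corres = np.array([[10., 10.], [50., 50.], [200., 200.]])   # (N, 2) points
crops = np.array([[0, 0, 64, 64], [40, 40, 256, 256]])      # (M, 4) l, t, r, b rectangles
inside = in2d_rect(corres, crops)                           # (N, M) boolean matrix
# inside == [[True, False], [True, True], [False, True]]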
mast3r/demo.py ADDED
@@ -0,0 +1,321 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
3
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ #
5
+ # --------------------------------------------------------
6
+ # sparse gradio demo functions
7
+ # --------------------------------------------------------
8
+ import math
9
+ import gradio
10
+ import os
11
+ import numpy as np
12
+ import functools
13
+ import trimesh
14
+ import copy
15
+ from scipy.spatial.transform import Rotation
16
+ import tempfile
17
+ import shutil
18
+
19
+ from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
20
+ from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
21
+
22
+ import mast3r.utils.path_to_dust3r # noqa
23
+ from dust3r.image_pairs import make_pairs
24
+ from dust3r.utils.image import load_images
25
+ from dust3r.utils.device import to_numpy
26
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
27
+
28
+ import matplotlib.pyplot as pl
29
+
30
+
31
+ class SparseGAState():
32
+ def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
33
+ self.sparse_ga = sparse_ga
34
+ self.cache_dir = cache_dir
35
+ self.outfile_name = outfile_name
36
+ self.should_delete = should_delete
37
+
38
+ def __del__(self):
39
+ if not self.should_delete:
40
+ return
41
+ if self.cache_dir is not None and os.path.isdir(self.cache_dir):
42
+ shutil.rmtree(self.cache_dir)
43
+ self.cache_dir = None
44
+ if self.outfile_name is not None and os.path.isfile(self.outfile_name):
45
+ os.remove(self.outfile_name)
46
+ self.outfile_name = None
47
+
48
+
49
+ def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
50
+ cam_color=None, as_pointcloud=False,
51
+ transparent_cams=False, silent=False):
52
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
53
+ pts3d = to_numpy(pts3d)
54
+ imgs = to_numpy(imgs)
55
+ focals = to_numpy(focals)
56
+ cams2world = to_numpy(cams2world)
57
+
58
+ scene = trimesh.Scene()
59
+
60
+ # full pointcloud
61
+ if as_pointcloud:
62
+ pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
63
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
64
+ valid_msk = np.isfinite(pts.sum(axis=1))
65
+ pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
66
+ scene.add_geometry(pct)
67
+ else:
68
+ meshes = []
69
+ for i in range(len(imgs)):
70
+ pts3d_i = pts3d[i].reshape(imgs[i].shape)
71
+ msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
72
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
73
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
74
+ scene.add_geometry(mesh)
75
+
76
+ # add each camera
77
+ for i, pose_c2w in enumerate(cams2world):
78
+ if isinstance(cam_color, list):
79
+ camera_edge_color = cam_color[i]
80
+ else:
81
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
82
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
83
+ None if transparent_cams else imgs[i], focals[i],
84
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
85
+
86
+ rot = np.eye(4)
87
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
88
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
89
+ if not silent:
90
+ print('(exporting 3D scene to', outfile, ')')
91
+ scene.export(file_obj=outfile)
92
+ return outfile
93
+
94
+
95
+ def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
96
+ clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
97
+ """
98
+ extract 3D_model (glb file) from a reconstructed scene
99
+ """
100
+ if scene_state is None:
101
+ return None
102
+ outfile = scene_state.outfile_name
103
+ if outfile is None:
104
+ return None
105
+
106
+ # get optimized values from scene
107
+ scene = scene_state.sparse_ga
108
+ rgbimg = scene.imgs
109
+ focals = scene.get_focals().cpu()
110
+ cams2world = scene.get_im_poses().cpu()
111
+
112
+ # 3D pointcloud from depthmap, poses and intrinsics
113
+ if TSDF_thresh > 0:
114
+ tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
115
+ pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
116
+ else:
117
+ pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
118
+ msk = to_numpy([c > min_conf_thr for c in confs])
119
+ return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
120
+ transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
121
+
122
+
123
+ def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
124
+ filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
125
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
126
+ win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
127
+ """
128
+ From a list of images, run MASt3R inference and the sparse global aligner,
129
+ then run get_3D_model_from_scene.
130
+ """
131
+ print(image_size, current_scene_state, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
132
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
133
+ win_cyclic, refid, TSDF_thresh, shared_intrinsics)
134
+ # 512 None refine+depth 0.07 500 0.014 200 1.5 5 True False True False 0.2 logwin 6 False 0 0 True
135
+ imgs = load_images(filelist, size=image_size, verbose=not silent)
136
+ if len(imgs) == 1:
137
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
138
+ imgs[1]['idx'] = 1
139
+ filelist = [filelist[0], filelist[0] + '_2']
140
+
141
+ scene_graph_params = [scenegraph_type]
142
+ if scenegraph_type in ["swin", "logwin"]:
143
+ scene_graph_params.append(str(winsize))
144
+ elif scenegraph_type == "oneref":
145
+ scene_graph_params.append(str(refid))
146
+ if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
147
+ scene_graph_params.append('noncyclic')
148
+ scene_graph = '-'.join(scene_graph_params)
149
+ pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
150
+ print(pairs, len(imgs))
151
+ if optim_level == 'coarse':
152
+ niter2 = 0
153
+ # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
154
+ if current_scene_state is not None and \
155
+ not current_scene_state.should_delete and \
156
+ current_scene_state.cache_dir is not None:
157
+ cache_dir = current_scene_state.cache_dir
158
+ elif gradio_delete_cache:
159
+ cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
160
+ else:
161
+ cache_dir = os.path.join(outdir, 'cache')
162
+ os.makedirs(cache_dir, exist_ok=True)
163
+ scene = sparse_global_alignment(filelist, pairs, cache_dir,
164
+ model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
165
+ opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
166
+ matching_conf_thr=matching_conf_thr, **kw)
167
+ if current_scene_state is not None and \
168
+ not current_scene_state.should_delete and \
169
+ current_scene_state.outfile_name is not None:
170
+ outfile_name = current_scene_state.outfile_name
171
+ else:
172
+ outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
173
+
174
+ scene_state = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name)
175
+ outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
176
+ clean_depth, transparent_cams, cam_size, TSDF_thresh)
177
+ print(outfile)
178
+ return scene_state, outfile
179
+
180
+
181
+ def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
182
+ num_files = len(inputfiles) if inputfiles is not None else 1
183
+ show_win_controls = scenegraph_type in ["swin", "logwin"]
184
+ show_winsize = scenegraph_type in ["swin", "logwin"]
185
+ show_cyclic = scenegraph_type in ["swin", "logwin"]
186
+ max_winsize, min_winsize = 1, 1
187
+ if scenegraph_type == "swin":
188
+ if win_cyclic:
189
+ max_winsize = max(1, math.ceil((num_files - 1) / 2))
190
+ else:
191
+ max_winsize = num_files - 1
192
+ elif scenegraph_type == "logwin":
193
+ if win_cyclic:
194
+ half_size = math.ceil((num_files - 1) / 2)
195
+ max_winsize = max(1, math.ceil(math.log(half_size, 2)))
196
+ else:
197
+ max_winsize = max(1, math.ceil(math.log(num_files, 2)))
198
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
199
+ minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
200
+ win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
201
+ win_col = gradio.Column(visible=show_win_controls)
202
+ refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
203
+ maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
204
+ return win_col, winsize, win_cyclic, refid
205
+
206
+
207
+ def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
208
+ share=False, gradio_delete_cache=False):
209
+ if not silent:
210
+ print('Outputting stuff in', tmpdirname)
211
+
212
+ recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device,
213
+ silent, image_size)
214
+ model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
215
+
216
+ def get_context(delete_cache):
217
+ css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
218
+ title = "MASt3R Demo"
219
+ if delete_cache:
220
+ return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
221
+ else:
222
+ return gradio.Blocks(css=css, title="MASt3R Demo") # for compatibility with older versions
223
+
224
+ with get_context(gradio_delete_cache) as demo:
225
+ # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
226
+ scene = gradio.State(None)
227
+ gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
228
+ with gradio.Column():
229
+ inputfiles = gradio.File(file_count="multiple")
230
+ with gradio.Row():
231
+ with gradio.Column():
232
+ with gradio.Row():
233
+ lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
234
+ niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
235
+ label="num_iterations", info="For coarse alignment!")
236
+ lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
237
+ niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
238
+ label="num_iterations", info="For refinement!")
239
+ optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
240
+ value='refine+depth', label="OptLevel",
241
+ info="Optimization level")
242
+ with gradio.Row():
243
+ matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
244
+ minimum=0., maximum=30., step=0.1,
245
+ info="Before Fallback to Regr3D!")
246
+ shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
247
+ info="Only optimize one set of intrinsics for all views")
248
+ scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
249
+ ("swin: sliding window", "swin"),
250
+ ("logwin: sliding window with long range", "logwin"),
251
+ ("oneref: match one image with all", "oneref")],
252
+ value='complete', label="Scenegraph",
253
+ info="Define how to make pairs",
254
+ interactive=True)
255
+ with gradio.Column(visible=False) as win_col:
256
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
257
+ minimum=1, maximum=1, step=1)
258
+ win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
259
+ refid = gradio.Slider(label="Scene Graph: Id", value=0,
260
+ minimum=0, maximum=0, step=1, visible=False)
261
+ run_btn = gradio.Button("Run")
262
+
263
+ with gradio.Row():
264
+ # adjust the confidence threshold
265
+ min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
266
+ # adjust the camera size in the output pointcloud
267
+ cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
268
+ TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
269
+ with gradio.Row():
270
+ as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
271
+ # two post process implemented
272
+ mask_sky = gradio.Checkbox(value=False, label="Mask sky")
273
+ clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
274
+ transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
275
+
276
+ outmodel = gradio.Model3D()
277
+
278
+ # events
279
+ scenegraph_type.change(set_scenegraph_options,
280
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
281
+ outputs=[win_col, winsize, win_cyclic, refid])
282
+ inputfiles.change(set_scenegraph_options,
283
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
284
+ outputs=[win_col, winsize, win_cyclic, refid])
285
+ win_cyclic.change(set_scenegraph_options,
286
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
287
+ outputs=[win_col, winsize, win_cyclic, refid])
288
+ run_btn.click(fn=recon_fun,
289
+ inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
290
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
291
+ scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
292
+ outputs=[scene, outmodel])
293
+ min_conf_thr.release(fn=model_from_scene_fun,
294
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
295
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
296
+ outputs=outmodel)
297
+ cam_size.change(fn=model_from_scene_fun,
298
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
299
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
300
+ outputs=outmodel)
301
+ TSDF_thresh.change(fn=model_from_scene_fun,
302
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
303
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
304
+ outputs=outmodel)
305
+ as_pointcloud.change(fn=model_from_scene_fun,
306
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
307
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
308
+ outputs=outmodel)
309
+ mask_sky.change(fn=model_from_scene_fun,
310
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
311
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
312
+ outputs=outmodel)
313
+ clean_depth.change(fn=model_from_scene_fun,
314
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
315
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
316
+ outputs=outmodel)
317
+ transparent_cams.change(model_from_scene_fun,
318
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
319
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
320
+ outputs=outmodel)
321
+ demo.launch(share=share, server_name=server_name, server_port=server_port)
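A hypothetical launch snippet (not part of the commit; the checkpoint path, host and port below are placeholders): main_demo only needs a loaded MASt3R model, a device and a scratch directory.

import tempfile
import torch
from mast3r.model import AsymmetricMASt3R
from mast3r.demo import main_demo

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AsymmetricMASt3R.from_pretrained('checkpoints/mast3r_checkpoint.pth').to(device)  # placeholder path
with tempfile.TemporaryDirectory(suffix='_mast3r_demo') as tmpdirname:
    main_demo(tmpdirname, model, device, image_size=512,
              server_name='127.0.0.1', server_port=7860, silent=False)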
mast3r/fast_nn.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # MASt3R Fast Nearest Neighbor
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import numpy as np
9
+ import math
10
+ from scipy.spatial import KDTree
11
+
12
+ import mast3r.utils.path_to_dust3r # noqa
13
+ from dust3r.utils.device import to_numpy, todevice # noqa
14
+
15
+
16
+ @torch.no_grad()
17
+ def bruteforce_reciprocal_nns(A, B, device='cuda', block_size=None, dist='l2'):
18
+ if isinstance(A, np.ndarray):
19
+ A = torch.from_numpy(A).to(device)
20
+ if isinstance(B, np.ndarray):
21
+ B = torch.from_numpy(B).to(device)
22
+
23
+ A = A.to(device)
24
+ B = B.to(device)
25
+
26
+ if dist == 'l2':
27
+ dist_func = torch.cdist
28
+ argmin = torch.min
29
+ elif dist == 'dot':
30
+ def dist_func(A, B):
31
+ return A @ B.T
32
+
33
+ def argmin(X, dim):
34
+ sim, nn = torch.max(X, dim=dim)
35
+ return sim.neg_(), nn
36
+ else:
37
+ raise ValueError(f'Unknown {dist=}')
38
+
39
+ if block_size is None or len(A) * len(B) <= block_size**2:
40
+ dists = dist_func(A, B)
41
+ _, nn_A = argmin(dists, dim=1)
42
+ _, nn_B = argmin(dists, dim=0)
43
+ else:
44
+ dis_A = torch.full((A.shape[0],), float('inf'), device=device, dtype=A.dtype)
45
+ dis_B = torch.full((B.shape[0],), float('inf'), device=device, dtype=B.dtype)
46
+ nn_A = torch.full((A.shape[0],), -1, device=device, dtype=torch.int64)
47
+ nn_B = torch.full((B.shape[0],), -1, device=device, dtype=torch.int64)
48
+ number_of_iteration_A = math.ceil(A.shape[0] / block_size)
49
+ number_of_iteration_B = math.ceil(B.shape[0] / block_size)
50
+
51
+ for i in range(number_of_iteration_A):
52
+ A_i = A[i * block_size:(i + 1) * block_size]
53
+ for j in range(number_of_iteration_B):
54
+ B_j = B[j * block_size:(j + 1) * block_size]
55
+ dists_blk = dist_func(A_i, B_j) # A, B, 1
56
+ # dists_blk = dists[i * block_size:(i+1)*block_size, j * block_size:(j+1)*block_size]
57
+ min_A_i, argmin_A_i = argmin(dists_blk, dim=1)
58
+ min_B_j, argmin_B_j = argmin(dists_blk, dim=0)
59
+
60
+ col_mask = min_A_i < dis_A[i * block_size:(i + 1) * block_size]
61
+ line_mask = min_B_j < dis_B[j * block_size:(j + 1) * block_size]
62
+
63
+ dis_A[i * block_size:(i + 1) * block_size][col_mask] = min_A_i[col_mask]
64
+ dis_B[j * block_size:(j + 1) * block_size][line_mask] = min_B_j[line_mask]
65
+
66
+ nn_A[i * block_size:(i + 1) * block_size][col_mask] = argmin_A_i[col_mask] + (j * block_size)
67
+ nn_B[j * block_size:(j + 1) * block_size][line_mask] = argmin_B_j[line_mask] + (i * block_size)
68
+ nn_A = nn_A.cpu().numpy()
69
+ nn_B = nn_B.cpu().numpy()
70
+ return nn_A, nn_B
71
+
72
+
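A small sketch of bruteforce_reciprocal_nns (illustrative, CPU-only so it runs without a GPU): with dist='dot' and a block_size smaller than the problem size, the blocked code path above is exercised.

import torch

A = torch.randn(1000, 24)
B = torch.randn(800, 24)
nn_A, nn_B = bruteforce_reciprocal_nns(A, B, device='cpu', block_size=512, dist='dot')
# nn_A[i] is the index in B of the best match of A[i]; a match is reciprocal when nn_B[nn_A[i]] == i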
73
+ class cdistMatcher:
74
+ def __init__(self, db_pts, device='cuda'):
75
+ self.db_pts = db_pts.to(device)
76
+ self.device = device
77
+
78
+ def query(self, queries, k=1, **kw):
79
+ assert k == 1
80
+ if queries.numel() == 0:
81
+ return None, []
82
+ nnA, nnB = bruteforce_reciprocal_nns(queries, self.db_pts, device=self.device, **kw)
83
+ dis = None
84
+ return dis, nnA
85
+
86
+
87
+ def merge_corres(idx1, idx2, shape1=None, shape2=None, ret_xy=True, ret_index=False):
88
+ assert idx1.dtype == idx2.dtype == np.int32
89
+
90
+ # unique and sort along idx1
91
+ corres = np.unique(np.c_[idx2, idx1].view(np.int64), return_index=ret_index)
92
+ if ret_index:
93
+ corres, indices = corres
94
+ xy2, xy1 = corres[:, None].view(np.int32).T
95
+
96
+ if ret_xy:
97
+ assert shape1 and shape2
98
+ xy1 = np.unravel_index(xy1, shape1)
99
+ xy2 = np.unravel_index(xy2, shape2)
100
+ if ret_xy != 'y_x':
101
+ xy1 = xy1[0].base[:, ::-1]
102
+ xy2 = xy2[0].base[:, ::-1]
103
+
104
+ if ret_index:
105
+ return xy1, xy2, indices
106
+ return xy1, xy2
107
+
108
+
109
+ def fast_reciprocal_NNs(pts1, pts2, subsample_or_initxy1=8, ret_xy=True, pixel_tol=0, ret_basin=False,
110
+ device='cuda', **matcher_kw):
111
+ H1, W1, DIM1 = pts1.shape
112
+ H2, W2, DIM2 = pts2.shape
113
+ assert DIM1 == DIM2
114
+
115
+ pts1 = pts1.reshape(-1, DIM1)
116
+ pts2 = pts2.reshape(-1, DIM2)
117
+
118
+ if isinstance(subsample_or_initxy1, int) and pixel_tol == 0:
119
+ S = subsample_or_initxy1
120
+ y1, x1 = np.mgrid[S // 2:H1:S, S // 2:W1:S].reshape(2, -1)
121
+ max_iter = 10
122
+ else:
123
+ x1, y1 = subsample_or_initxy1
124
+ if isinstance(x1, torch.Tensor):
125
+ x1 = x1.cpu().numpy()
126
+ if isinstance(y1, torch.Tensor):
127
+ y1 = y1.cpu().numpy()
128
+ max_iter = 1
129
+
130
+ xy1 = np.int32(np.unique(x1 + W1 * y1)) # make sure there are no duplicates
131
+ xy2 = np.full_like(xy1, -1)
132
+ old_xy1 = xy1.copy()
133
+ old_xy2 = xy2.copy()
134
+
135
+ if 'dist' in matcher_kw or 'block_size' in matcher_kw \
136
+ or (isinstance(device, str) and device.startswith('cuda')) \
137
+ or (isinstance(device, torch.device) and device.type.startswith('cuda')):
138
+ pts1 = pts1.to(device)
139
+ pts2 = pts2.to(device)
140
+ tree1 = cdistMatcher(pts1, device=device)
141
+ tree2 = cdistMatcher(pts2, device=device)
142
+ else:
143
+ pts1, pts2 = to_numpy((pts1, pts2))
144
+ tree1 = KDTree(pts1)
145
+ tree2 = KDTree(pts2)
146
+
147
+ notyet = np.ones(len(xy1), dtype=bool)
148
+ basin = np.full((H1 * W1 + 1,), -1, dtype=np.int32) if ret_basin else None
149
+
150
+ niter = 0
151
+ # n_notyet = [len(notyet)]
152
+ while notyet.any():
153
+ _, xy2[notyet] = to_numpy(tree2.query(pts1[xy1[notyet]], **matcher_kw))
154
+ if not ret_basin:
155
+ notyet &= (old_xy2 != xy2) # remove points that have converged
156
+
157
+ _, xy1[notyet] = to_numpy(tree1.query(pts2[xy2[notyet]], **matcher_kw))
158
+ if ret_basin:
159
+ basin[old_xy1[notyet]] = xy1[notyet]
160
+ notyet &= (old_xy1 != xy1) # remove points that have converged
161
+
162
+ # n_notyet.append(notyet.sum())
163
+ niter += 1
164
+ if niter >= max_iter:
165
+ break
166
+
167
+ old_xy2[:] = xy2
168
+ old_xy1[:] = xy1
169
+
170
+ # print('notyet_stats:', ' '.join(map(str, (n_notyet+[0]*10)[:max_iter])))
171
+
172
+ if pixel_tol > 0:
173
+ # in case we only want to match some specific points
174
+ # and still have some way of checking reciprocity
175
+ old_yx1 = np.unravel_index(old_xy1, (H1, W1))[0].base
176
+ new_yx1 = np.unravel_index(xy1, (H1, W1))[0].base
177
+ dis = np.linalg.norm(old_yx1 - new_yx1, axis=-1)
178
+ converged = dis < pixel_tol
179
+ if not isinstance(subsample_or_initxy1, int):
180
+ xy1 = old_xy1 # replace new points by old ones
181
+ else:
182
+ converged = ~notyet # converged correspondences
183
+
184
+ # keep only unique correspondences, and sort on xy1
185
+ xy1, xy2 = merge_corres(xy1[converged], xy2[converged], (H1, W1), (H2, W2), ret_xy=ret_xy)
186
+ if ret_basin:
187
+ return xy1, xy2, basin
188
+ return xy1, xy2
189
+
190
+
191
+ def extract_correspondences_nonsym(A, B, confA, confB, subsample=8, device=None, ptmap_key='pred_desc', pixel_tol=0):
192
+ if '3d' in ptmap_key:
193
+ opt = dict(device='cpu', workers=32)
194
+ else:
195
+ opt = dict(device=device, dist='dot', block_size=2**13)
196
+
197
+ # matching the two pairs
198
+ idx1 = []
199
+ idx2 = []
200
+ # merge corres from opposite pairs
201
+ HA, WA = A.shape[:2]
202
+ HB, WB = B.shape[:2]
203
+ if pixel_tol == 0:
204
+ nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
205
+ nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)
206
+ else:
207
+ S = subsample
208
+ yA, xA = np.mgrid[S // 2:HA:S, S // 2:WA:S].reshape(2, -1)
209
+ yB, xB = np.mgrid[S // 2:HB:S, S // 2:WB:S].reshape(2, -1)
210
+
211
+ nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=(xA, yA), ret_xy=False, pixel_tol=pixel_tol, **opt)
212
+ nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=(xB, yB), ret_xy=False, pixel_tol=pixel_tol, **opt)
213
+
214
+ idx1 = np.r_[nn1to2[0], nn2to1[1]]
215
+ idx2 = np.r_[nn1to2[1], nn2to1[0]]
216
+
217
+ c1 = confA.ravel()[idx1]
218
+ c2 = confB.ravel()[idx2]
219
+
220
+ xy1, xy2, idx = merge_corres(idx1, idx2, (HA, WA), (HB, WB), ret_xy=True, ret_index=True)
221
+ conf = np.minimum(c1[idx], c2[idx])
222
+ corres = (xy1.copy(), xy2.copy(), conf)
223
+ return todevice(corres, device)
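A usage sketch of fast_reciprocal_NNs (illustrative only): on CPU and without extra matcher arguments it falls back to the KDTree branch; with random descriptor maps it simply returns whatever reciprocal matches survive the fixed-point iterations.

import torch

H, W, D = 32, 48, 24
desc1 = torch.randn(H, W, D)
desc2 = torch.randn(H, W, D)
xy1, xy2 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8, device='cpu')
# xy1 and xy2 are (N, 2) integer arrays of matching (x, y) pixel coordinates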
mast3r/losses.py ADDED
@@ -0,0 +1,508 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Implementation of MASt3R training losses
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ from sklearn.metrics import average_precision_score
11
+
12
+ import mast3r.utils.path_to_dust3r # noqa
13
+ from dust3r.losses import BaseCriterion, Criterion, MultiLoss, Sum, ConfLoss
14
+ from dust3r.losses import Regr3D as Regr3D_dust3r
15
+ from dust3r.utils.geometry import (geotrf, inv, normalize_pointcloud)
16
+ from dust3r.inference import get_pred_pts3d
17
+ from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale
18
+
19
+
20
+ def apply_log_to_norm(xyz):
21
+ d = xyz.norm(dim=-1, keepdim=True)
22
+ xyz = xyz / d.clip(min=1e-8)
23
+ xyz = xyz * torch.log1p(d)
24
+ return xyz
25
+
26
+
27
+ class Regr3D (Regr3D_dust3r):
28
+ def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, opt_fit_gt=False,
29
+ sky_loss_value=2, max_metric_scale=False, loss_in_log=False):
30
+ self.loss_in_log = loss_in_log
31
+ if norm_mode.startswith('?'):
32
+ # do not normalize pts from metric scale datasets
33
+ self.norm_all = False
34
+ self.norm_mode = norm_mode[1:]
35
+ else:
36
+ self.norm_all = True
37
+ self.norm_mode = norm_mode
38
+ super().__init__(criterion, self.norm_mode, gt_scale)
39
+
40
+ self.sky_loss_value = sky_loss_value
41
+ self.max_metric_scale = max_metric_scale
42
+
43
+ def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
44
+ # everything is normalized w.r.t. camera of view1
45
+ in_camera1 = inv(gt1['camera_pose'])
46
+ gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3
47
+ gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3
48
+
49
+ valid1 = gt1['valid_mask'].clone()
50
+ valid2 = gt2['valid_mask'].clone()
51
+
52
+ if dist_clip is not None:
53
+ # points that are too far-away == invalid
54
+ dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
55
+ dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
56
+ valid1 = valid1 & (dis1 <= dist_clip)
57
+ valid2 = valid2 & (dis2 <= dist_clip)
58
+
59
+ if self.loss_in_log == 'before':
60
+ # this only makes sense when depth_mode == 'linear'
61
+ gt_pts1 = apply_log_to_norm(gt_pts1)
62
+ gt_pts2 = apply_log_to_norm(gt_pts2)
63
+
64
+ pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False).clone()
65
+ pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True).clone()
66
+
67
+ if not self.norm_all:
68
+ if self.max_metric_scale:
69
+ B = valid1.shape[0]
70
+ # valid1: B, H, W
71
+ # torch.linalg.norm(gt_pts1, dim=-1) -> B, H, W
72
+ # dist1_to_cam1 -> reshape to B, H*W
73
+ dist1_to_cam1 = torch.where(valid1, torch.linalg.norm(gt_pts1, dim=-1), 0).view(B, -1)
74
+ dist2_to_cam1 = torch.where(valid2, torch.linalg.norm(gt_pts2, dim=-1), 0).view(B, -1)
75
+
76
+ # is_metric_scale: B
77
+ # dist1_to_cam1.max(dim=-1).values -> B
78
+ gt1['is_metric_scale'] = gt1['is_metric_scale'] \
79
+ & (dist1_to_cam1.max(dim=-1).values < self.max_metric_scale) \
80
+ & (dist2_to_cam1.max(dim=-1).values < self.max_metric_scale)
81
+ gt2['is_metric_scale'] = gt1['is_metric_scale']
82
+
83
+ mask = ~gt1['is_metric_scale']
84
+ else:
85
+ mask = torch.ones_like(gt1['is_metric_scale'])
86
+ # normalize 3d points
87
+ if self.norm_mode and mask.any():
88
+ pr_pts1[mask], pr_pts2[mask] = normalize_pointcloud(pr_pts1[mask], pr_pts2[mask], self.norm_mode,
89
+ valid1[mask], valid2[mask])
90
+
91
+ if self.norm_mode and not self.gt_scale:
92
+ gt_pts1, gt_pts2, norm_factor = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode,
93
+ valid1, valid2, ret_factor=True)
94
+ # apply the same normalization to prediction
95
+ pr_pts1[~mask] = pr_pts1[~mask] / norm_factor[~mask]
96
+ pr_pts2[~mask] = pr_pts2[~mask] / norm_factor[~mask]
97
+
98
+ # return sky segmentation, making sure they don't include any labelled 3d points
99
+ sky1 = gt1['sky_mask'] & (~valid1)
100
+ sky2 = gt2['sky_mask'] & (~valid2)
101
+ return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, sky1, sky2, {}
102
+
103
+ def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
104
+ gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
105
+ self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)
106
+
107
+ if self.sky_loss_value > 0:
108
+ assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss'
109
+ # add the sky pixel as "valid" pixels...
110
+ mask1 = mask1 | sky1
111
+ mask2 = mask2 | sky2
112
+
113
+ # loss on img1 side
114
+ pred_pts1 = pred_pts1[mask1]
115
+ gt_pts1 = gt_pts1[mask1]
116
+ if self.loss_in_log and self.loss_in_log != 'before':
117
+ # this only makes sense when depth_mode == 'exp'
118
+ pred_pts1 = apply_log_to_norm(pred_pts1)
119
+ gt_pts1 = apply_log_to_norm(gt_pts1)
120
+ l1 = self.criterion(pred_pts1, gt_pts1)
121
+
122
+ # loss on gt2 side
123
+ pred_pts2 = pred_pts2[mask2]
124
+ gt_pts2 = gt_pts2[mask2]
125
+ if self.loss_in_log and self.loss_in_log != 'before':
126
+ pred_pts2 = apply_log_to_norm(pred_pts2)
127
+ gt_pts2 = apply_log_to_norm(gt_pts2)
128
+ l2 = self.criterion(pred_pts2, gt_pts2)
129
+
130
+ if self.sky_loss_value > 0:
131
+ assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss'
132
+ # ... but force the loss to be high there
133
+ l1 = torch.where(sky1[mask1], self.sky_loss_value, l1)
134
+ l2 = torch.where(sky2[mask2], self.sky_loss_value, l2)
135
+ self_name = type(self).__name__
136
+ details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())}
137
+ return Sum((l1, mask1), (l2, mask2)), (details | monitoring)
138
+
139
+
140
+ class Regr3D_ShiftInv (Regr3D):
141
+ """ Same than Regr3D but invariant to depth shift.
142
+ """
143
+
144
+ def get_all_pts3d(self, gt1, gt2, pred1, pred2):
145
+ # compute unnormalized points
146
+ gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
147
+ super().get_all_pts3d(gt1, gt2, pred1, pred2)
148
+
149
+ # compute median depth
150
+ gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]
151
+ pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]
152
+ gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None]
153
+ pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]
154
+
155
+ # subtract the median depth
156
+ gt_z1 -= gt_shift_z
157
+ gt_z2 -= gt_shift_z
158
+ pred_z1 -= pred_shift_z
159
+ pred_z2 -= pred_shift_z
160
+
161
+ # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach())
162
+ return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring
163
+
164
+
165
+ class Regr3D_ScaleInv (Regr3D):
166
+ """ Same than Regr3D but invariant to depth scale.
167
+ if gt_scale == True: enforce the prediction to take the same scale as the GT
168
+ """
169
+
170
+ def get_all_pts3d(self, gt1, gt2, pred1, pred2):
171
+ # compute depth-normalized points
172
+ gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
173
+ super().get_all_pts3d(gt1, gt2, pred1, pred2)
174
+
175
+ # measure scene scale
176
+ _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
177
+ _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)
178
+
179
+ # prevent predictions from being in a ridiculous range
180
+ pred_scale = pred_scale.clip(min=1e-3, max=1e3)
181
+
182
+ # subtract the median depth
183
+ if self.gt_scale:
184
+ pred_pts1 *= gt_scale / pred_scale
185
+ pred_pts2 *= gt_scale / pred_scale
186
+ # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
187
+ else:
188
+ gt_pts1 /= gt_scale
189
+ gt_pts2 /= gt_scale
190
+ pred_pts1 /= pred_scale
191
+ pred_pts2 /= pred_scale
192
+ # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())
193
+
194
+ return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring
195
+
196
+
197
+ class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):
198
+ # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
199
+ pass
200
+
201
+
202
+ def get_similarities(desc1, desc2, euc=False):
203
+ if euc: # Euclidean distance in the same range as the similarities
204
+ dists = (desc1[:, :, None] - desc2[:, None]).norm(dim=-1)
205
+ sim = 1 / (1 + dists)
206
+ else:
207
+ # Compute similarities
208
+ sim = desc1 @ desc2.transpose(-2, -1)
209
+ return sim
210
+
211
+
212
+ class MatchingCriterion(BaseCriterion):
213
+ def __init__(self, reduction='mean', fp=torch.float32):
214
+ super().__init__(reduction)
215
+ self.fp = fp
216
+
217
+ def forward(self, a, b, valid_matches=None, euc=False):
218
+ assert a.ndim >= 2 and 1 <= a.shape[-1], f'Bad shape = {a.shape}'
219
+ dist = self.loss(a.to(self.fp), b.to(self.fp), valid_matches, euc=euc)
220
+ # one dimension less or reduction to single value
221
+ assert (valid_matches is None and dist.ndim == a.ndim -
222
+ 1) or self.reduction in ['mean', 'sum', '1-mean', 'none']
223
+ if self.reduction == 'none':
224
+ return dist
225
+ if self.reduction == 'sum':
226
+ return dist.sum()
227
+ if self.reduction == 'mean':
228
+ return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
229
+ if self.reduction == '1-mean':
230
+ return 1. - dist.mean() if dist.numel() > 0 else dist.new_ones(())
231
+ raise ValueError(f'bad {self.reduction=} mode')
232
+
233
+ def loss(self, a, b, valid_matches=None):
234
+ raise NotImplementedError
235
+
236
+
237
+ class InfoNCE(MatchingCriterion):
238
+ def __init__(self, temperature=0.07, eps=1e-8, mode='all', **kwargs):
239
+ super().__init__(**kwargs)
240
+ self.temperature = temperature
241
+ self.eps = eps
242
+ assert mode in ['all', 'proper', 'dual']
243
+ self.mode = mode
244
+
245
+ def loss(self, desc1, desc2, valid_matches=None, euc=False):
246
+ # valid positives are along diagonals
247
+ B, N, D = desc1.shape
248
+ B2, N2, D2 = desc2.shape
249
+ assert B == B2 and D == D2
250
+ if valid_matches is None:
251
+ valid_matches = torch.ones([B, N], dtype=bool)
252
+ # torch.all(valid_matches.sum(dim=-1) > 0) some pairs have no matches????
253
+ assert valid_matches.shape == torch.Size([B, N]) and valid_matches.sum() > 0
254
+
255
+ # Tempered similarities
256
+ sim = get_similarities(desc1, desc2, euc) / self.temperature
257
+ sim[sim.isnan()] = -torch.inf # ignore nans
258
+ # Softmax of positives with temperature
259
+ sim = sim.exp_() # save peak memory
260
+ positives = sim.diagonal(dim1=-2, dim2=-1)
261
+
262
+ # Loss
263
+ if self.mode == 'all': # Previous InfoNCE
264
+ loss = -torch.log((positives / sim.sum(dim=-1).sum(dim=-1, keepdim=True)).clip(self.eps))
265
+ elif self.mode == 'proper': # Proper InfoNCE
266
+ loss = -(torch.log((positives / sim.sum(dim=-2)).clip(self.eps)) +
267
+ torch.log((positives / sim.sum(dim=-1)).clip(self.eps)))
268
+ elif self.mode == 'dual': # Dual Softmax
269
+ loss = -(torch.log((positives**2 / sim.sum(dim=-1) / sim.sum(dim=-2)).clip(self.eps)))
270
+ else:
271
+ raise ValueError("This should not happen...")
272
+ return loss[valid_matches]
273
+
274
+
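A hypothetical sanity check of the InfoNCE criterion above (not part of the commit): random unit-norm descriptors, where the i-th row of desc1 is declared to match the i-th row of desc2.

import torch
import torch.nn.functional as F

B, N, D = 2, 64, 24
desc1 = F.normalize(torch.randn(B, N, D), dim=-1)
desc2 = F.normalize(torch.randn(B, N, D), dim=-1)
criterion = InfoNCE(temperature=0.07, mode='proper', reduction='mean')
loss = criterion(desc1, desc2)   # scalar tensor, averaged over the B*N positive pairs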
275
+ class APLoss (MatchingCriterion):
276
+ """ AP loss.
277
+ """
278
+
279
+ def __init__(self, nq='torch', min=0, max=1, euc=False, **kw):
280
+ super().__init__(**kw)
281
+ # Exact/True AP loss (not differentiable)
282
+ if nq == 0:
283
+ nq = 'sklearn' # special case
284
+ try:
285
+ self.compute_AP = eval('self.compute_true_AP_' + nq)
286
+ except:
287
+ raise ValueError("Unknown mode %s for AP loss" % nq)
288
+
289
+ @staticmethod
290
+ def compute_true_AP_sklearn(scores, labels):
291
+ def compute_AP(label, score):
292
+ return average_precision_score(label, score)
293
+
294
+ aps = scores.new_zeros((scores.shape[0], scores.shape[1]))
295
+ label_np = labels.cpu().numpy().astype(bool)
296
+ scores_np = scores.cpu().numpy()
297
+ for bi in range(scores_np.shape[0]):
298
+ for i in range(scores_np.shape[1]):
299
+ labels = label_np[bi, i, :]
300
+ if labels.sum() < 1:
301
+ continue
302
+ aps[bi, i] = compute_AP(labels, scores_np[bi, i, :])
303
+ return aps
304
+
305
+ @staticmethod
306
+ def compute_true_AP_torch(scores, labels):
307
+ assert scores.shape == labels.shape
308
+ B, N, M = labels.shape
309
+ dev = labels.device
310
+ with torch.no_grad():
311
+ # sort scores
312
+ _, order = scores.sort(dim=-1, descending=True)
313
+ # sort labels accordingly
314
+ labels = labels[torch.arange(B, device=dev)[:, None, None].expand(order.shape),
315
+ torch.arange(N, device=dev)[None, :, None].expand(order.shape),
316
+ order]
317
+ # compute number of positives per query
318
+ npos = labels.sum(dim=-1)
319
+ assert torch.all(torch.isclose(npos, npos[0, 0])
320
+ ), "only implemented for constant number of positives per query"
321
+ npos = int(npos[0, 0])
322
+ # compute precision at each recall point
323
+ posrank = labels.nonzero()[:, -1].view(B, N, npos)
324
+ recall = torch.arange(1, 1 + npos, dtype=torch.float32, device=dev)[None, None, :].expand(B, N, npos)
325
+ precision = recall / (1 + posrank).float()
326
+ # average precision values at all recall points
327
+ aps = precision.mean(dim=-1)
328
+
329
+ return aps
330
+
331
+ def loss(self, desc1, desc2, valid_matches=None, euc=False): # if matches is None, positives are the diagonal
332
+ B, N1, D = desc1.shape
333
+ B2, N2, D2 = desc2.shape
334
+ assert B == B2 and D == D2
335
+
336
+ scores = get_similarities(desc1, desc2, euc)
337
+
338
+ labels = torch.zeros([B, N1, N2], dtype=scores.dtype, device=scores.device)
339
+
340
+ # allow all diagonal positives and only mask afterwards
341
+ labels.diagonal(dim1=-2, dim2=-1)[...] = 1.
342
+ apscore = self.compute_AP(scores, labels)
343
+ if valid_matches is not None:
344
+ apscore = apscore[valid_matches]
345
+ return apscore
346
+
347
+
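A tiny worked example for the exact AP computation above (illustrative): a single query whose only positive is ranked second among four candidates, so its average precision is 1/2.

import torch

scores = torch.tensor([[[0.9, 0.8, 0.1, 0.0]]])    # (B=1, N=1, M=4) similarity scores
labels = torch.tensor([[[0., 1., 0., 0.]]])        # the lone positive has the 2nd-highest score
ap = APLoss.compute_true_AP_torch(scores, labels)  # tensor([[0.5]])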
348
+ class MatchingLoss (Criterion, MultiLoss):
349
+ """
350
+ Matching loss per image
351
+ only compares pixels inside an image, not across the whole batch as would usually be done
352
+ """
353
+
354
+ def __init__(self, criterion, withconf=False, use_pts3d=False, negatives_padding=0, blocksize=4096):
355
+ super().__init__(criterion)
356
+ self.negatives_padding = negatives_padding
357
+ self.use_pts3d = use_pts3d
358
+ self.blocksize = blocksize
359
+ self.withconf = withconf
360
+
361
+ def add_negatives(self, outdesc2, desc2, batchid, x2, y2):
362
+ if self.negatives_padding:
363
+ B, H, W, D = desc2.shape
364
+ negatives = torch.ones([B, H, W], device=desc2.device, dtype=bool)
365
+ negatives[batchid, y2, x2] = False
366
+ sel = negatives & (negatives.view([B, -1]).cumsum(dim=-1).view(B, H, W)
367
+ <= self.negatives_padding) # take the N-first negatives
368
+ outdesc2 = torch.cat([outdesc2, desc2[sel].view([B, -1, D])], dim=1)
369
+ return outdesc2
370
+
371
+ def get_confs(self, pred1, pred2, sel1, sel2):
372
+ if self.withconf:
373
+ if self.use_pts3d:
374
+ outconfs1 = pred1['conf'][sel1]
375
+ outconfs2 = pred2['conf'][sel2]
376
+ else:
377
+ outconfs1 = pred1['desc_conf'][sel1]
378
+ outconfs2 = pred2['desc_conf'][sel2]
379
+ else:
380
+ outconfs1 = outconfs2 = None
381
+ return outconfs1, outconfs2
382
+
383
+ def get_descs(self, pred1, pred2):
384
+ if self.use_pts3d:
385
+ desc1, desc2 = pred1['pts3d'], pred2['pts3d_in_other_view']
386
+ else:
387
+ desc1, desc2 = pred1['desc'], pred2['desc']
388
+ return desc1, desc2
389
+
390
+ def get_matching_descs(self, gt1, gt2, pred1, pred2, **kw):
391
+ outdesc1 = outdesc2 = outconfs1 = outconfs2 = None
392
+ # Recover descs, GT corres and valid mask
393
+ desc1, desc2 = self.get_descs(pred1, pred2)
394
+
395
+ (x1, y1), (x2, y2) = gt1['corres'].unbind(-1), gt2['corres'].unbind(-1)
396
+ valid_matches = gt1['valid_corres']
397
+
398
+ # Select descs that have GT matches
399
+ B, N = x1.shape
400
+ batchid = torch.arange(B)[:, None].repeat(1, N) # B, N
401
+ outdesc1, outdesc2 = desc1[batchid, y1, x1], desc2[batchid, y2, x2] # B, N, D
402
+
403
+ # Pad with unused negatives
404
+ outdesc2 = self.add_negatives(outdesc2, desc2, batchid, x2, y2)
405
+
406
+ # Gather confs if needed
407
+ sel1 = batchid, y1, x1
408
+ sel2 = batchid, y2, x2
409
+ outconfs1, outconfs2 = self.get_confs(pred1, pred2, sel1, sel2)
410
+
411
+ return outdesc1, outdesc2, outconfs1, outconfs2, valid_matches, {'use_euclidean_dist': self.use_pts3d}
412
+
413
+ def blockwise_criterion(self, descs1, descs2, confs1, confs2, valid_matches, euc, rng=np.random, shuffle=True):
414
+ loss = None
415
+ details = {}
416
+ B, N, D = descs1.shape
417
+
418
+ if N <= self.blocksize: # Blocks are larger than provided descs, compute regular loss
419
+ loss = self.criterion(descs1, descs2, valid_matches, euc=euc)
420
+ else: # Compute criterion on the blockdiagonal only, after shuffling
421
+ # Shuffle if necessary
422
+ matches_perm = slice(None)
423
+ if shuffle:
424
+ matches_perm = np.stack([rng.choice(range(N), size=N, replace=False) for _ in range(B)])
425
+ batchid = torch.tile(torch.arange(B), (N, 1)).T
426
+ matches_perm = batchid, matches_perm
427
+
428
+ descs1 = descs1[matches_perm]
429
+ descs2 = descs2[matches_perm]
430
+ valid_matches = valid_matches[matches_perm]
431
+
432
+ assert N % self.blocksize == 0, "Error, can't chunk block-diagonal, please check blocksize"
433
+ n_chunks = N // self.blocksize
434
+ descs1 = descs1.reshape([B * n_chunks, self.blocksize, D]) # [B*(N//blocksize), blocksize, D]
435
+ descs2 = descs2.reshape([B * n_chunks, self.blocksize, D]) # [B*(N//blocksize), blocksize, D]
436
+ valid_matches = valid_matches.view([B * n_chunks, self.blocksize])
437
+ loss = self.criterion(descs1, descs2, valid_matches, euc=euc)
438
+ if self.withconf:
439
+ confs1, confs2 = map(lambda x: x[matches_perm], (confs1, confs2)) # apply perm to confidences if needed
440
+
441
+ if self.withconf:
442
+ # split confidences between positives/negatives for loss computation
443
+ details['conf_pos'] = map(lambda x: x[valid_matches.view(B, -1)], (confs1, confs2))
444
+ details['conf_neg'] = map(lambda x: x[~valid_matches.view(B, -1)], (confs1, confs2))
445
+ details['Conf1_std'] = confs1.std()
446
+ details['Conf2_std'] = confs2.std()
447
+
448
+ return loss, details
449
+
450
+ def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
451
+ # Gather preds and GT
452
+ descs1, descs2, confs1, confs2, valid_matches, monitoring = self.get_matching_descs(
453
+ gt1, gt2, pred1, pred2, **kw)
454
+
455
+ # loss on matches
456
+ loss, details = self.blockwise_criterion(descs1, descs2, confs1, confs2,
457
+ valid_matches, euc=monitoring.pop('use_euclidean_dist', False))
458
+
459
+ details[type(self).__name__] = float(loss.mean())
460
+ return loss, (details | monitoring)
461
+
462
+
463
+ class ConfMatchingLoss(ConfLoss):
464
+ """ Weight matching by learned confidence. Same as ConfLoss but for a matching criterion
465
+ Assuming the input matching_loss is a match-level loss.
466
+ """
467
+
468
+ def __init__(self, pixel_loss, alpha=1., confmode='prod', neg_conf_loss_quantile=False):
469
+ super().__init__(pixel_loss, alpha)
470
+ self.pixel_loss.withconf = True
471
+ self.confmode = confmode
472
+ self.neg_conf_loss_quantile = neg_conf_loss_quantile
473
+
474
+ def aggregate_confs(self, confs1, confs2): # get the confidences resulting from the two view predictions
475
+ if self.confmode == 'prod':
476
+ confs = confs1 * confs2 if confs1 is not None and confs2 is not None else 1.
477
+ elif self.confmode == 'mean':
478
+ confs = .5 * (confs1 + confs2) if confs1 is not None and confs2 is not None else 1.
479
+ else:
480
+ raise ValueError(f"Unknown conf mode {self.confmode}")
481
+ return confs
482
+
483
+ def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
484
+ # compute per-pixel loss
485
+ loss, details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
486
+ # Recover confidences for positive and negative samples
487
+ conf1_pos, conf2_pos = details.pop('conf_pos')
488
+ conf1_neg, conf2_neg = details.pop('conf_neg')
489
+ conf_pos = self.aggregate_confs(conf1_pos, conf2_pos)
490
+
491
+ # weight Matching loss by confidence on positives
492
+ conf_pos, log_conf_pos = self.get_conf_log(conf_pos)
493
+ conf_loss = loss * conf_pos - self.alpha * log_conf_pos
494
+ # average + nan protection (in case of no valid pixels at all)
495
+ conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
496
+ # Add negative confs loss to give some supervision signal to confidences for pixels that are not matched in GT
497
+ if self.neg_conf_loss_quantile:
498
+ conf_neg = torch.cat([conf1_neg, conf2_neg])
499
+ conf_neg, log_conf_neg = self.get_conf_log(conf_neg)
500
+
501
+ # recover quantile that will be used for negatives loss value assignment
502
+ neg_loss_value = torch.quantile(loss, self.neg_conf_loss_quantile).detach()
503
+ neg_loss = neg_loss_value * conf_neg - self.alpha * log_conf_neg
504
+
505
+ neg_loss = neg_loss.mean() if neg_loss.numel() > 0 else 0
506
+ conf_loss = conf_loss + neg_loss
507
+
508
+ return conf_loss, dict(matching_conf_loss=float(conf_loss), **details)
mast3r/model.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # MASt3R model class
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import os
10
+
11
+ from mast3r.catmlp_dpt_head import mast3r_head_factory
12
+
13
+ import mast3r.utils.path_to_dust3r # noqa
14
+ from dust3r.model import AsymmetricCroCo3DStereo # noqa
15
+ from dust3r.utils.misc import transpose_to_landscape # noqa
16
+
17
+
18
+ inf = float('inf')
19
+
20
+
21
+ def load_model(model_path, device, verbose=True):
22
+ if verbose:
23
+ print('... loading model from', model_path)
24
+ ckpt = torch.load(model_path, map_location='cpu')
25
+ args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
26
+ if 'landscape_only' not in args:
27
+ args = args[:-1] + ', landscape_only=False)'
28
+ else:
29
+ args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
30
+ assert "landscape_only=False" in args
31
+ if verbose:
32
+ print(f"instantiating : {args}")
33
+ net = eval(args)
34
+ s = net.load_state_dict(ckpt['model'], strict=False)
35
+ if verbose:
36
+ print(s)
37
+ return net.to(device)
38
+
39
+
40
+ class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
41
+ def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
42
+ self.desc_mode = desc_mode
43
+ self.two_confs = two_confs
44
+ self.desc_conf_mode = desc_conf_mode
45
+ super().__init__(**kwargs)
46
+
47
+ @classmethod
48
+ def from_pretrained(cls, pretrained_model_name_or_path, **kw):
49
+ if os.path.isfile(pretrained_model_name_or_path):
50
+ return load_model(pretrained_model_name_or_path, device='cpu')
51
+ else:
52
+ return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)
53
+
54
+ def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
55
+ assert img_size[0] % patch_size == 0 and img_size[
56
+ 1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
57
+ self.output_mode = output_mode
58
+ self.head_type = head_type
59
+ self.depth_mode = depth_mode
60
+ self.conf_mode = conf_mode
61
+ if self.desc_conf_mode is None:
62
+ self.desc_conf_mode = conf_mode
63
+ # allocate heads
64
+ self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
65
+ self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
66
+ # magic wrapper
67
+ self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
68
+ self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
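A hypothetical two-view inference sketch (not part of the commit; the checkpoint and image paths are placeholders, and it relies on dust3r's generic inference helpers): a local .pth path goes through load_model() above, anything else falls back to the Hugging Face hub loader.

import torch
from dust3r.inference import inference
from dust3r.utils.image import load_images
from mast3r.model import AsymmetricMASt3R

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AsymmetricMASt3R.from_pretrained('checkpoints/mast3r_checkpoint.pth').to(device)  # placeholder path
images = load_images(['img1.jpg', 'img2.jpg'], size=512)          # placeholder image paths
output = inference([tuple(images)], model, device, batch_size=1)
desc1 = output['pred1']['desc']   # per-pixel local features predicted by the MASt3R head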
mast3r/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/utils/coarse_to_fine.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # coarse to fine utilities
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+
9
+
10
+ def crop_tag(cell):
11
+ return f'[{cell[1]}:{cell[3]},{cell[0]}:{cell[2]}]'
12
+
13
+
14
+ def crop_slice(cell):
15
+ return slice(cell[1], cell[3]), slice(cell[0], cell[2])
16
+
17
+
18
+ def _start_pos(total_size, win_size, overlap):
19
+ # we must have AT LEAST overlap between segments
20
+ # first segment starts at 0, last segment starts at total_size-win_size
21
+ assert 0 <= overlap < 1
22
+ assert total_size >= win_size
23
+ spacing = win_size * (1 - overlap)
24
+ last_pt = total_size - win_size
25
+ n_windows = 2 + int((last_pt - 1) // spacing)
26
+ return np.linspace(0, last_pt, n_windows).round().astype(int)
27
+
28
+
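A quick illustration of _start_pos (not part of the commit): window start offsets for a 1024-pixel axis tiled with 512-pixel windows and at least 50% overlap.

starts = _start_pos(total_size=1024, win_size=512, overlap=0.5)
# -> array([  0, 256, 512]), i.e. windows [0:512], [256:768] and [512:1024]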
29
+ def multiple_of_16(x):
30
+ return (x // 16) * 16
31
+
32
+
33
+ def _make_overlapping_grid(H, W, size, overlap):
34
+ H_win = multiple_of_16(H * size // max(H, W))
35
+ W_win = multiple_of_16(W * size // max(H, W))
36
+ x = _start_pos(W, W_win, overlap)
37
+ y = _start_pos(H, H_win, overlap)
38
+ grid = np.stack(np.meshgrid(x, y, indexing='xy'), axis=-1)
39
+ grid = np.concatenate((grid, grid + (W_win, H_win)), axis=-1)
40
+ return grid.reshape(-1, 4)
41
+
42
+
43
+ def _cell_size(cell2):
44
+ width, height = cell2[:, 2] - cell2[:, 0], cell2[:, 3] - cell2[:, 1]
45
+ assert width.min() >= 0
46
+ assert height.min() >= 0
47
+ return width, height
48
+
49
+
50
+ def _norm_windows(cell2, H2, W2, forced_resolution=None):
51
+ # make sure the window aspect ratio is 3/4, or the output resolution is forced_resolution if defined
52
+ outcell = cell2.copy()
53
+ width, height = _cell_size(cell2)
54
+ width2, height2 = width.clip(max=W2), height.clip(max=H2)
55
+ if forced_resolution is None:
56
+ width2[width < height] = (height2[width < height] * 3.01 / 4).clip(max=W2)
57
+ height2[width >= height] = (width2[width >= height] * 3.01 / 4).clip(max=H2)
58
+ else:
59
+ forced_H, forced_W = forced_resolution
60
+ width2[:] = forced_W
61
+ height2[:] = forced_H
62
+
63
+ half = (width2 - width) / 2
64
+ outcell[:, 0] -= half
65
+ outcell[:, 2] += half
66
+ half = (height2 - height) / 2
67
+ outcell[:, 1] -= half
68
+ outcell[:, 3] += half
69
+
70
+ # proj to integers
71
+ outcell = np.floor(outcell).astype(int)
72
+ # Take care of flooring errors
73
+ tmpw, tmph = _cell_size(outcell)
74
+ outcell[:, 0] += tmpw.astype(tmpw.dtype) - width2.astype(tmpw.dtype)
75
+ outcell[:, 1] += tmph.astype(tmpw.dtype) - height2.astype(tmpw.dtype)
76
+
77
+ # make sure 0 <= x < W2 and 0 <= y < H2
78
+ outcell[:, 0::2] -= outcell[:, [0]].clip(max=0)
79
+ outcell[:, 1::2] -= outcell[:, [1]].clip(max=0)
80
+ outcell[:, 0::2] -= outcell[:, [2]].clip(min=W2) - W2
81
+ outcell[:, 1::2] -= outcell[:, [3]].clip(min=H2) - H2
82
+
83
+ width, height = _cell_size(outcell)
84
+ assert np.all(width == width2.astype(width.dtype)) and np.all(
85
+ height == height2.astype(height.dtype)), "Error, output is not of the expected shape."
86
+ assert np.all(width <= W2)
87
+ assert np.all(height <= H2)
88
+ return outcell
89
+
90
+
91
+ def _weight_pixels(cell, pix, assigned, gauss_var=2):
92
+ center = cell.reshape(-1, 2, 2).mean(axis=1)
93
+ width, height = _cell_size(cell)
94
+
95
+ # square distance between each cell center and each point
96
+ dist = (center[:, None] - pix[None]) / np.c_[width, height][:, None]
97
+ dist2 = np.square(dist).sum(axis=-1)
98
+
99
+ assert assigned.shape == dist2.shape
100
+ res = np.where(assigned, np.exp(-gauss_var * dist2), 0)
101
+ return res
102
+
103
+
104
+ def pos2d_in_rect(p1, cell1):
105
+ x, y = p1.T
106
+ l, t, r, b = cell1
107
+ assigned = (l <= x) & (x < r) & (t <= y) & (y < b)
108
+ return assigned
109
+
110
+
111
+ def _score_cell(cell1, H2, W2, p1, p2, min_corres=10, forced_resolution=None):
112
+ assert p1.shape == p2.shape
113
+
114
+ # compute keypoint assignment
115
+ assigned = pos2d_in_rect(p1, cell1[None].T)
116
+ assert assigned.shape == (len(cell1), len(p1))
117
+
118
+ # remove cells without correspondences
119
+ valid_cells = assigned.sum(axis=1) >= min_corres
120
+ cell1 = cell1[valid_cells]
121
+ assigned = assigned[valid_cells]
122
+ if not valid_cells.any():
123
+ return cell1, cell1, assigned
124
+
125
+ # fill-in the assigned points in both image
126
+ assigned_p1 = np.empty((len(cell1), len(p1), 2), dtype=np.float32)
127
+ assigned_p2 = np.empty((len(cell1), len(p2), 2), dtype=np.float32)
128
+ assigned_p1[:] = p1[None]
129
+ assigned_p2[:] = p2[None]
130
+ assigned_p1[~assigned] = np.nan
131
+ assigned_p2[~assigned] = np.nan
132
+
133
+ # find the median center and scale of assigned points in each cell
134
+ # cell_center1 = np.nanmean(assigned_p1, axis=1)
135
+ cell_center2 = np.nanmean(assigned_p2, axis=1)
136
+ im1_q25, im1_q75 = np.nanquantile(assigned_p1, (0.1, 0.9), axis=1)
137
+ im2_q25, im2_q75 = np.nanquantile(assigned_p2, (0.1, 0.9), axis=1)
138
+
139
+ robust_std1 = (im1_q75 - im1_q25).clip(20.)
140
+ robust_std2 = (im2_q75 - im2_q25).clip(20.)
141
+
142
+ cell_size1 = (cell1[:, 2:4] - cell1[:, 0:2])
143
+ cell_size2 = cell_size1 * robust_std2 / robust_std1
144
+ cell2 = np.c_[cell_center2 - cell_size2 / 2, cell_center2 + cell_size2 / 2]
145
+
146
+ # make sure cell bounds are valid
147
+ cell2 = _norm_windows(cell2, H2, W2, forced_resolution=forced_resolution)
148
+
149
+ # compute correspondence weights
150
+ corres_weights = _weight_pixels(cell1, p1, assigned) * _weight_pixels(cell2, p2, assigned)
151
+
152
+ # return a list of window pairs and assigned correspondences
153
+ return cell1, cell2, corres_weights
154
+
155
+
156
+ def greedy_selection(corres_weights, target=0.9):
157
+ # corres_weight = (n_cell_pair, n_corres) matrix.
158
+ # If corres_weights[c, p] > 0, correspondence p is visible in cell pair c
159
+ assert 0 < target <= 1
160
+ corres_weights = corres_weights.copy()
161
+
162
+ total = corres_weights.max(axis=0).sum()
163
+ target *= total
164
+
165
+ # init = empty
166
+ res = []
167
+ cur = np.zeros(corres_weights.shape[1]) # current selection
168
+
169
+ while cur.sum() < target:
170
+ # pick the nex best cell pair
171
+ best = corres_weights.sum(axis=1).argmax()
172
+ res.append(best)
173
+
174
+ # update current
175
+ cur += corres_weights[best]
176
+ # print('appending', best, 'with score', corres_weights[best].sum(), '-->', cur.sum())
177
+
178
+ # remove from all other views
179
+ corres_weights = (corres_weights - corres_weights[best]).clip(min=0)
180
+
181
+ return res
182
+
183
+
184
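To make the greedy covering concrete, here is a tiny worked example; the weight matrix below is made up purely for illustration:

    import numpy as np
    # toy case: 3 candidate cell pairs x 4 correspondences
    w = np.array([[1., 1., 0., 0.],
                  [0., 1., 1., 0.],
                  [0., 0., 1., 1.]])
    # pair 0 covers correspondences {0, 1}; adding pair 2 then covers {2, 3}
    print(greedy_selection(w, target=0.9))  # -> [0, 2]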
+ def select_pairs_of_crops(img_q, img_b, pos2d_in_query, pos2d_in_ref, maxdim=512, overlap=.5, forced_resolution=None):
+     # prepare the overlapping cells
+     grid_q = _make_overlapping_grid(*img_q.shape[:2], maxdim, overlap)
+     grid_b = _make_overlapping_grid(*img_b.shape[:2], maxdim, overlap)
+
+     assert forced_resolution is None or len(forced_resolution) == 2
+     if forced_resolution is None:
+         # no resolution constraint on the crops
+         forced_resolution1 = forced_resolution2 = None
+     elif isinstance(forced_resolution[0], int) or not len(forced_resolution[0]) == 2:
+         # same resolution constraint for both images
+         forced_resolution1 = forced_resolution2 = forced_resolution
+     else:
+         assert len(forced_resolution[1]) == 2
+         forced_resolution1 = forced_resolution[0]
+         forced_resolution2 = forced_resolution[1]
+
+     # make sure crops respect constraints
+     grid_q = _norm_windows(grid_q.astype(float), *img_q.shape[:2], forced_resolution=forced_resolution1)
+     grid_b = _norm_windows(grid_b.astype(float), *img_b.shape[:2], forced_resolution=forced_resolution2)
+
+     # score cells
+     pairs_q = _score_cell(grid_q, *img_b.shape[:2], pos2d_in_query, pos2d_in_ref, forced_resolution=forced_resolution2)
+     pairs_b = _score_cell(grid_b, *img_q.shape[:2], pos2d_in_ref, pos2d_in_query, forced_resolution=forced_resolution1)
+     pairs_b = pairs_b[1], pairs_b[0], pairs_b[2]  # cellq, cellb, corres_weights
+
+     # greedy selection until the target fraction of correspondences is covered
+     cell1, cell2, corres_weights = map(np.concatenate, zip(pairs_q, pairs_b))
+     if len(corres_weights) == 0:
+         return  # tolerated for empty generators
+     order = greedy_selection(corres_weights, target=0.9)
+
+     for i in order:
+         def pair_tag(qi, bi): return (str(qi) + crop_tag(cell1[i]), str(bi) + crop_tag(cell2[i]))
+         yield cell1[i], cell2[i], pair_tag
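For orientation, a rough usage sketch of the generator above; the image sizes and keypoint arrays are made-up placeholders, and the helpers (_make_overlapping_grid, _cell_size, _norm_windows, crop_tag) are assumed to be defined earlier in this file:

    import numpy as np

    # hypothetical inputs: two images and N matched 2D keypoints (x, y) in each
    img_q = np.zeros((768, 1024, 3), dtype=np.uint8)
    img_b = np.zeros((600, 800, 3), dtype=np.uint8)
    pts_q = np.random.rand(200, 2) * (1024, 768)
    pts_b = np.random.rand(200, 2) * (800, 600)

    for crop_q, crop_b, pair_tag in select_pairs_of_crops(img_q, img_b, pts_q, pts_b, maxdim=512, overlap=0.5):
        # each crop is an (l, t, r, b) window; pair_tag(i, j) builds unique names for the crop pair
        print(pair_tag(0, 1), crop_q, crop_b)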
mast3r/utils/collate.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # Collate extensions
+ # --------------------------------------------------------
+
+ import torch
+ import collections
+ from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format
+ from typing import Callable, Dict, Optional, Tuple, Type, Union, List
+
+
+ def cat_collate_tensor_fn(batch, *, collate_fn_map):
+     return torch.cat(batch, dim=0)
+
+
+ def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
+     return [item for bb in batch for item in bb]  # concatenate all lists
+
+
+ cat_collate_fn_map = default_collate_fn_map.copy()
+ cat_collate_fn_map[torch.Tensor] = cat_collate_tensor_fn
+ cat_collate_fn_map[List] = cat_collate_list_fn
+ cat_collate_fn_map[type(None)] = lambda _, **kw: None  # when there are some Nones, simply return a single None
+
+
+ def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
+     r"""Custom collate function that concatenates items instead of stacking them, and handles NoneTypes."""
+     elem = batch[0]
+     elem_type = type(elem)
+
+     if collate_fn_map is not None:
+         if elem_type in collate_fn_map:
+             return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
+
+         for collate_type in collate_fn_map:
+             if isinstance(elem, collate_type):
+                 return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
+
+     if isinstance(elem, collections.abc.Mapping):
+         try:
+             return elem_type({key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
+         except TypeError:
+             # The mapping type may not support `__init__(iterable)`.
+             return {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
+     elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
+         return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
+     elif isinstance(elem, collections.abc.Sequence):
+         transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.
+
+         if isinstance(elem, tuple):
+             # Backwards compatibility.
+             return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
+         else:
+             try:
+                 return elem_type([cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
+             except TypeError:
+                 # The sequence type may not support `__init__(iterable)` (e.g., `range`).
+                 return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
+
+     raise TypeError(default_collate_err_msg_format.format(elem_type))
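A small sketch of how cat_collate differs from the default collate; the sample dicts below are hypothetical:

    import torch

    samples = [dict(pts=torch.rand(3, 2)), dict(pts=torch.rand(5, 2))]
    merged = cat_collate(samples, collate_fn_map=cat_collate_fn_map)
    # tensors are concatenated along dim 0 rather than stacked into a new batch dim
    assert merged['pts'].shape == (8, 2)
    # it can also be plugged into a DataLoader, e.g.
    # DataLoader(dataset, batch_size=4, collate_fn=lambda b: cat_collate(b, collate_fn_map=cat_collate_fn_map))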
mast3r/utils/misc.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # utility functions for MASt3R
+ # --------------------------------------------------------
+ import os
+ import hashlib
+
+
+ def mkdir_for(f):
+     os.makedirs(os.path.dirname(f), exist_ok=True)
+     return f
+
+
+ def hash_md5(s):
+     return hashlib.md5(s.encode('utf-8')).hexdigest()
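As an illustration (the path and key below are made up), the two helpers can be combined to prepare cache files:

    cache_file = mkdir_for('/tmp/mast3r_cache/pairs/scene1.npz')  # ensures the parent directory exists
    key = hash_md5('img_001.jpg-img_002.jpg')                     # 32-character hex digest, usable as a cache key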
mast3r/utils/path_to_dust3r.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+ #
+ # --------------------------------------------------------
+ # dust3r submodule import
+ # --------------------------------------------------------
+
+ import sys
+ import os.path as path
+ HERE_PATH = path.normpath(path.dirname(__file__))
+ DUSt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../'))
+ DUSt3R_LIB_PATH = path.join(DUSt3R_REPO_PATH, 'dust3r')
+ # check for the dust3r directory in the repo, to be sure the submodule is cloned
+ if path.isdir(DUSt3R_LIB_PATH):
+     # workaround for sibling import
+     sys.path.insert(0, DUSt3R_REPO_PATH)
+ else:
+     raise ImportError(f"dust3r is not initialized, could not find: {DUSt3R_LIB_PATH}.\n "
+                       "Did you forget to run 'git submodule update --init --recursive' ?")
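A minimal sketch of how this shim is meant to be used: it is imported for its side effect before any dust3r import (the calling module here is hypothetical):

    import mast3r.utils.path_to_dust3r  # noqa: F401  (adds the repo root to sys.path)
    from dust3r.model import AsymmetricCroCo3DStereo  # dust3r now resolves as a sibling package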