Commit 2caa1bd by hugoycj
Parent(s): 9059c91

feat: Add mast3r dependencies
Files changed:
- dust3r/model.py +5 -1
- mast3r/__init__.py +2 -0
- mast3r/catmlp_dpt_head.py +123 -0
- mast3r/cloud_opt/__init__.py +2 -0
- mast3r/cloud_opt/sparse_ga.py +1035 -0
- mast3r/cloud_opt/triangulation.py +80 -0
- mast3r/cloud_opt/tsdf_optimizer.py +269 -0
- mast3r/cloud_opt/utils/__init__.py +2 -0
- mast3r/cloud_opt/utils/losses.py +32 -0
- mast3r/cloud_opt/utils/schedules.py +17 -0
- mast3r/colmap/__init__.py +2 -0
- mast3r/colmap/database.py +383 -0
- mast3r/datasets/__init__.py +62 -0
- mast3r/datasets/base/__init__.py +2 -0
- mast3r/datasets/base/mast3r_base_stereo_view_dataset.py +355 -0
- mast3r/datasets/utils/__init__.py +2 -0
- mast3r/datasets/utils/cropping.py +219 -0
- mast3r/demo.py +321 -0
- mast3r/fast_nn.py +223 -0
- mast3r/losses.py +508 -0
- mast3r/model.py +68 -0
- mast3r/utils/__init__.py +2 -0
- mast3r/utils/coarse_to_fine.py +214 -0
- mast3r/utils/collate.py +62 -0
- mast3r/utils/misc.py +17 -0
- mast3r/utils/path_to_dust3r.py +19 -0
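For orientation (derived from the file list above, not part of the commit itself), the added files form the following package layout, alongside the single modification to dust3r/model.py:

    mast3r/
    ├── __init__.py
    ├── catmlp_dpt_head.py
    ├── demo.py
    ├── fast_nn.py
    ├── losses.py
    ├── model.py
    ├── cloud_opt/
    │   ├── __init__.py
    │   ├── sparse_ga.py
    │   ├── triangulation.py
    │   ├── tsdf_optimizer.py
    │   └── utils/
    │       ├── __init__.py
    │       ├── losses.py
    │       └── schedules.py
    ├── colmap/
    │   ├── __init__.py
    │   └── database.py
    ├── datasets/
    │   ├── __init__.py
    │   ├── base/
    │   │   ├── __init__.py
    │   │   └── mast3r_base_stereo_view_dataset.py
    │   └── utils/
    │       ├── __init__.py
    │       └── cropping.py
    └── utils/
        ├── __init__.py
        ├── coarse_to_fine.py
        ├── collate.py
        ├── misc.py
        └── path_to_dust3r.py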
dust3r/model.py  CHANGED

@@ -14,6 +14,7 @@ from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, in
 from .heads import head_factory
 from dust3r.patch_embed import get_patch_embed
 
+import urllib
 import dust3r.utils.path_to_croco  # noqa: F401
 from models.croco import CroCoNet  # noqa
 
@@ -78,7 +79,10 @@ class AsymmetricCroCo3DStereo (
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kw):
-
+        if os.path.isfile(pretrained_model_name_or_path) or urllib.parse.urlparse(pretrained_model_name_or_path).scheme in ('http', 'https'):
+            return load_model(pretrained_model_name_or_path, device='cpu', landscape_only=kw['landscape_only'])
+        else:
+            return super(AsymmetricCroCo3DStereo, cls).from_pretrained(pretrained_model_name_or_path, **kw)
 
     def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
         self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)
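A note on the change above (not part of the diff): the new from_pretrained first checks whether the argument is an existing file or an http(s) URL and, if so, routes it through load_model, which expects landscape_only to be supplied via **kw; anything else falls back to the parent Hugging Face loader. A minimal sketch with placeholder paths, URL and repo id:

# Hedged sketch of the three dispatch cases; the checkpoint path, URL and repo id are placeholders.
from dust3r.model import AsymmetricCroCo3DStereo

m1 = AsymmetricCroCo3DStereo.from_pretrained('checkpoints/local_ckpt.pth', landscape_only=False)    # local file -> load_model
m2 = AsymmetricCroCo3DStereo.from_pretrained('https://example.com/ckpt.pth', landscape_only=False)  # http(s) URL -> load_model
m3 = AsymmetricCroCo3DStereo.from_pretrained('some-org/some-dust3r-model')                          # otherwise -> parent from_pretrained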
mast3r/__init__.py  ADDED
@@ -0,0 +1,2 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/catmlp_dpt_head.py  ADDED
@@ -0,0 +1,123 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R heads
# --------------------------------------------------------
import torch
import torch.nn.functional as F

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.heads.postprocess import reg_dense_depth, reg_dense_conf  # noqa
from dust3r.heads.dpt_head import PixelwiseTaskWithDPT  # noqa
import dust3r.utils.path_to_croco  # noqa
from models.blocks import Mlp  # noqa


def reg_desc(desc, mode):
    if 'norm' in mode:
        desc = desc / desc.norm(dim=-1, keepdim=True)
    else:
        raise ValueError(f"Unknown desc mode {mode}")
    return desc


def postprocess(out, depth_mode, conf_mode, desc_dim=None, desc_mode='norm', two_confs=False, desc_conf_mode=None):
    if desc_conf_mode is None:
        desc_conf_mode = conf_mode
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,D
    res = dict(pts3d=reg_dense_depth(fmap[..., 0:3], mode=depth_mode))
    if conf_mode is not None:
        res['conf'] = reg_dense_conf(fmap[..., 3], mode=conf_mode)
    if desc_dim is not None:
        start = 3 + int(conf_mode is not None)
        res['desc'] = reg_desc(fmap[..., start:start + desc_dim], mode=desc_mode)
        if two_confs:
            res['desc_conf'] = reg_dense_conf(fmap[..., start + desc_dim], mode=desc_conf_mode)
        else:
            res['desc_conf'] = res['conf'].clone()
    return res


class Cat_MLP_LocalFeatures_DPT_Pts3d(PixelwiseTaskWithDPT):
    """ Mixture between MLP and DPT head that outputs 3d points and local features (with MLP).
    The input for both heads is a concatenation of Encoder and Decoder outputs
    """

    def __init__(self, net, has_conf=False, local_feat_dim=16, hidden_dim_factor=4., hooks_idx=None, dim_tokens=None,
                 num_channels=1, postprocess=None, feature_dim=256, last_dim=32, depth_mode=None, conf_mode=None, head_type="regression", **kwargs):
        super().__init__(num_channels=num_channels, feature_dim=feature_dim, last_dim=last_dim, hooks_idx=hooks_idx,
                         dim_tokens=dim_tokens, depth_mode=depth_mode, postprocess=postprocess, conf_mode=conf_mode, head_type=head_type)
        self.local_feat_dim = local_feat_dim

        patch_size = net.patch_embed.patch_size
        if isinstance(patch_size, tuple):
            assert len(patch_size) == 2 and isinstance(patch_size[0], int) and isinstance(
                patch_size[1], int), "What is your patchsize format? Expected a single int or a tuple of two ints."
            assert patch_size[0] == patch_size[1], "Error, non square patches not managed"
            patch_size = patch_size[0]
        self.patch_size = patch_size

        self.desc_mode = net.desc_mode
        self.has_conf = has_conf
        self.two_confs = net.two_confs  # independent confs for 3D regr and descs
        self.desc_conf_mode = net.desc_conf_mode
        idim = net.enc_embed_dim + net.dec_embed_dim

        self.head_local_features = Mlp(in_features=idim,
                                       hidden_features=int(hidden_dim_factor * idim),
                                       out_features=(self.local_feat_dim + self.two_confs) * self.patch_size**2)

    def forward(self, decout, img_shape):
        # pass through the heads
        pts3d = self.dpt(decout, image_size=(img_shape[0], img_shape[1]))

        # recover encoder and decoder outputs
        enc_output, dec_output = decout[0], decout[-1]
        cat_output = torch.cat([enc_output, dec_output], dim=-1)  # concatenate
        H, W = img_shape
        B, S, D = cat_output.shape

        # extract local_features
        local_features = self.head_local_features(cat_output)  # B,S,D
        local_features = local_features.transpose(-1, -2).view(B, -1, H // self.patch_size, W // self.patch_size)
        local_features = F.pixel_shuffle(local_features, self.patch_size)  # B,d,H,W

        # post process 3D pts, descriptors and confidences
        out = torch.cat([pts3d, local_features], dim=1)
        if self.postprocess:
            out = self.postprocess(out,
                                   depth_mode=self.depth_mode,
                                   conf_mode=self.conf_mode,
                                   desc_dim=self.local_feat_dim,
                                   desc_mode=self.desc_mode,
                                   two_confs=self.two_confs,
                                   desc_conf_mode=self.desc_conf_mode)
        return out


def mast3r_head_factory(head_type, output_mode, net, has_conf=False):
    """" build a prediction head for the decoder
    """
    if head_type == 'catmlp+dpt' and output_mode.startswith('pts3d+desc'):
        local_feat_dim = int(output_mode[10:])
        assert net.dec_depth > 9
        l2 = net.dec_depth
        feature_dim = 256
        last_dim = feature_dim // 2
        out_nchan = 3
        ed = net.enc_embed_dim
        dd = net.dec_embed_dim
        return Cat_MLP_LocalFeatures_DPT_Pts3d(net, local_feat_dim=local_feat_dim, has_conf=has_conf,
                                               num_channels=out_nchan + has_conf,
                                               feature_dim=feature_dim,
                                               last_dim=last_dim,
                                               hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
                                               dim_tokens=[ed, dd, dd, dd],
                                               postprocess=postprocess,
                                               depth_mode=net.depth_mode,
                                               conf_mode=net.conf_mode,
                                               head_type='regression')
    else:
        raise NotImplementedError(
            f"unexpected {head_type=} and {output_mode=}")
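As a reading aid (not part of the commit), the factory above derives the local-descriptor width from the output_mode string and sizes the DPT and MLP branches from it. A small sketch of that arithmetic, using an illustrative output_mode and an assumed patch size:

# Illustrative only: how mast3r_head_factory sizes the head for output_mode='pts3d+desc24'.
output_mode = 'pts3d+desc24'
local_feat_dim = int(output_mode[10:])        # len('pts3d+desc') == 10, so the suffix gives 24-dim descriptors
has_conf = True
num_channels = 3 + has_conf                   # DPT branch regresses xyz plus an optional confidence channel
patch_size = 16                               # assumed patch size for this sketch
two_confs = True
mlp_out = (local_feat_dim + two_confs) * patch_size**2   # per-token output width of the MLP branch
# After F.pixel_shuffle, the MLP branch becomes a (local_feat_dim + two_confs)-channel map at full resolution,
# and postprocess() slices the concatenated output as [xyz | conf | desc | desc_conf].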
mast3r/cloud_opt/__init__.py  ADDED
@@ -0,0 +1,2 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/cloud_opt/sparse_ga.py
ADDED
@@ -0,0 +1,1035 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# MASt3R Sparse Global Alignement
|
6 |
+
# --------------------------------------------------------
|
7 |
+
from tqdm import tqdm
|
8 |
+
import roma
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import numpy as np
|
13 |
+
import os
|
14 |
+
from collections import namedtuple
|
15 |
+
from functools import lru_cache
|
16 |
+
from scipy import sparse as sp
|
17 |
+
import copy
|
18 |
+
|
19 |
+
from mast3r.utils.misc import mkdir_for, hash_md5
|
20 |
+
from mast3r.cloud_opt.utils.losses import gamma_loss
|
21 |
+
from mast3r.cloud_opt.utils.schedules import linear_schedule, cosine_schedule
|
22 |
+
from mast3r.fast_nn import fast_reciprocal_NNs, merge_corres
|
23 |
+
|
24 |
+
import mast3r.utils.path_to_dust3r # noqa
|
25 |
+
from dust3r.utils.geometry import inv, geotrf # noqa
|
26 |
+
from dust3r.utils.device import to_cpu, to_numpy, todevice # noqa
|
27 |
+
from dust3r.post_process import estimate_focal_knowing_depth # noqa
|
28 |
+
from dust3r.optim_factory import adjust_learning_rate_by_lr # noqa
|
29 |
+
from dust3r.viz import SceneViz
|
30 |
+
|
31 |
+
|
32 |
+
class SparseGA():
|
33 |
+
def __init__(self, img_paths, pairs_in, res_fine, anchors, canonical_paths=None):
|
34 |
+
def fetch_img(im):
|
35 |
+
def torgb(x): return (x[0].permute(1, 2, 0).numpy() * .5 + .5).clip(min=0., max=1.)
|
36 |
+
for im1, im2 in pairs_in:
|
37 |
+
if im1['instance'] == im:
|
38 |
+
return torgb(im1['img'])
|
39 |
+
if im2['instance'] == im:
|
40 |
+
return torgb(im2['img'])
|
41 |
+
self.canonical_paths = canonical_paths
|
42 |
+
self.img_paths = img_paths
|
43 |
+
self.imgs = [fetch_img(img) for img in img_paths]
|
44 |
+
self.intrinsics = res_fine['intrinsics']
|
45 |
+
self.cam2w = res_fine['cam2w']
|
46 |
+
self.depthmaps = res_fine['depthmaps']
|
47 |
+
self.pts3d = res_fine['pts3d']
|
48 |
+
self.pts3d_colors = []
|
49 |
+
self.working_device = self.cam2w.device
|
50 |
+
for i in range(len(self.imgs)):
|
51 |
+
im = self.imgs[i]
|
52 |
+
x, y = anchors[i][0][..., :2].detach().cpu().numpy().T
|
53 |
+
self.pts3d_colors.append(im[y, x])
|
54 |
+
assert self.pts3d_colors[-1].shape == self.pts3d[i].shape
|
55 |
+
self.n_imgs = len(self.imgs)
|
56 |
+
|
57 |
+
def get_focals(self):
|
58 |
+
return torch.tensor([ff[0, 0] for ff in self.intrinsics]).to(self.working_device)
|
59 |
+
|
60 |
+
def get_principal_points(self):
|
61 |
+
return torch.stack([ff[:2, -1] for ff in self.intrinsics]).to(self.working_device)
|
62 |
+
|
63 |
+
def get_im_poses(self):
|
64 |
+
return self.cam2w
|
65 |
+
|
66 |
+
def get_sparse_pts3d(self):
|
67 |
+
return self.pts3d
|
68 |
+
|
69 |
+
def get_dense_pts3d(self, clean_depth=True, subsample=8):
|
70 |
+
assert self.canonical_paths, 'cache_path is required for dense 3d points'
|
71 |
+
device = self.cam2w.device
|
72 |
+
confs = []
|
73 |
+
base_focals = []
|
74 |
+
anchors = {}
|
75 |
+
for i, canon_path in enumerate(self.canonical_paths):
|
76 |
+
(canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
|
77 |
+
confs.append(conf)
|
78 |
+
base_focals.append(focal)
|
79 |
+
|
80 |
+
H, W = conf.shape
|
81 |
+
pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
|
82 |
+
idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
|
83 |
+
anchors[i] = (pixels, idxs[i], offsets[i])
|
84 |
+
|
85 |
+
# densify sparse depthmaps
|
86 |
+
pts3d, depthmaps = make_pts3d(anchors, self.intrinsics, self.cam2w, [
|
87 |
+
d.ravel() for d in self.depthmaps], base_focals=base_focals, ret_depth=True)
|
88 |
+
|
89 |
+
return pts3d, depthmaps, confs
|
90 |
+
|
91 |
+
def get_pts3d_colors(self):
|
92 |
+
return self.pts3d_colors
|
93 |
+
|
94 |
+
def get_depthmaps(self):
|
95 |
+
return self.depthmaps
|
96 |
+
|
97 |
+
def get_masks(self):
|
98 |
+
return [slice(None, None) for _ in range(len(self.imgs))]
|
99 |
+
|
100 |
+
def show(self, show_cams=True):
|
101 |
+
pts3d, _, confs = self.get_dense_pts3d()
|
102 |
+
show_reconstruction(self.imgs, self.intrinsics if show_cams else None, self.cam2w,
|
103 |
+
[p.clip(min=-50, max=50) for p in pts3d],
|
104 |
+
masks=[c > 1 for c in confs])
|
105 |
+
|
106 |
+
|
107 |
+
def convert_dust3r_pairs_naming(imgs, pairs_in):
|
108 |
+
for pair_id in range(len(pairs_in)):
|
109 |
+
for i in range(2):
|
110 |
+
pairs_in[pair_id][i]['instance'] = imgs[pairs_in[pair_id][i]['idx']]
|
111 |
+
return pairs_in
|
112 |
+
|
113 |
+
|
114 |
+
def sparse_global_alignment(imgs, pairs_in, cache_path, model, subsample=8, desc_conf='desc_conf',
|
115 |
+
device='cuda', dtype=torch.float32, shared_intrinsics=False, **kw):
|
116 |
+
""" Sparse alignment with MASt3R
|
117 |
+
imgs: list of image paths
|
118 |
+
cache_path: path where to dump temporary files (str)
|
119 |
+
|
120 |
+
lr1, niter1: learning rate and #iterations for coarse global alignment (3D matching)
|
121 |
+
lr2, niter2: learning rate and #iterations for refinement (2D reproj error)
|
122 |
+
|
123 |
+
lora_depth: smart dimensionality reduction with depthmaps
|
124 |
+
"""
|
125 |
+
# Convert pair naming convention from dust3r to mast3r
|
126 |
+
pairs_in = convert_dust3r_pairs_naming(imgs, pairs_in)
|
127 |
+
# forward pass
|
128 |
+
pairs, cache_path = forward_mast3r(pairs_in, model,
|
129 |
+
cache_path=cache_path, subsample=subsample,
|
130 |
+
desc_conf=desc_conf, device=device)
|
131 |
+
|
132 |
+
# extract canonical pointmaps
|
133 |
+
tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21 = \
|
134 |
+
prepare_canonical_data(imgs, pairs, subsample, cache_path=cache_path, mode='avg-angle', device=device)
|
135 |
+
|
136 |
+
# compute minimal spanning tree
|
137 |
+
mst = compute_min_spanning_tree(pairwise_scores)
|
138 |
+
|
139 |
+
# remove all edges not in the spanning tree?
|
140 |
+
# min_spanning_tree = {(imgs[i],imgs[j]) for i,j in mst[1]}
|
141 |
+
# tmp_pairs = {(a,b):v for (a,b),v in tmp_pairs.items() if {(a,b),(b,a)} & min_spanning_tree}
|
142 |
+
|
143 |
+
# smartly combine all useful data
|
144 |
+
imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21 = \
|
145 |
+
condense_data(imgs, tmp_pairs, canonical_views, preds_21, dtype)
|
146 |
+
|
147 |
+
imgs, res_coarse, res_fine = sparse_scene_optimizer(
|
148 |
+
imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d, preds_21, canonical_paths, mst,
|
149 |
+
shared_intrinsics=shared_intrinsics, cache_path=cache_path, device=device, dtype=dtype, **kw)
|
150 |
+
|
151 |
+
return SparseGA(imgs, pairs_in, res_fine or res_coarse, anchors, canonical_paths)
|
152 |
+
|
153 |
+
|
154 |
+
def sparse_scene_optimizer(imgs, subsample, imsizes, pps, base_focals, core_depth, anchors, corres, corres2d,
|
155 |
+
preds_21, canonical_paths, mst, cache_path,
|
156 |
+
lr1=0.2, niter1=500, loss1=gamma_loss(1.1),
|
157 |
+
lr2=0.02, niter2=500, loss2=gamma_loss(0.4),
|
158 |
+
lossd=gamma_loss(1.1),
|
159 |
+
opt_pp=True, opt_depth=True,
|
160 |
+
schedule=cosine_schedule, depth_mode='add', exp_depth=False,
|
161 |
+
lora_depth=False, # dict(k=96, gamma=15, min_norm=.5),
|
162 |
+
shared_intrinsics=False,
|
163 |
+
init={}, device='cuda', dtype=torch.float32,
|
164 |
+
matching_conf_thr=5., loss_dust3r_w=0.01,
|
165 |
+
verbose=True, dbg=()):
|
166 |
+
init = copy.deepcopy(init)
|
167 |
+
# extrinsic parameters
|
168 |
+
vec0001 = torch.tensor((0, 0, 0, 1), dtype=dtype, device=device)
|
169 |
+
quats = [nn.Parameter(vec0001.clone()) for _ in range(len(imgs))]
|
170 |
+
trans = [nn.Parameter(torch.zeros(3, device=device, dtype=dtype)) for _ in range(len(imgs))]
|
171 |
+
|
172 |
+
# initialize
|
173 |
+
ones = torch.ones((len(imgs), 1), device=device, dtype=dtype)
|
174 |
+
median_depths = torch.ones(len(imgs), device=device, dtype=dtype)
|
175 |
+
for img in imgs:
|
176 |
+
idx = imgs.index(img)
|
177 |
+
init_values = init.setdefault(img, {})
|
178 |
+
if verbose and init_values:
|
179 |
+
print(f' >> initializing img=...{img[-25:]} [{idx}] for {set(init_values)}')
|
180 |
+
|
181 |
+
K = init_values.get('intrinsics')
|
182 |
+
if K is not None:
|
183 |
+
K = K.detach()
|
184 |
+
focal = K[:2, :2].diag().mean()
|
185 |
+
pp = K[:2, 2]
|
186 |
+
base_focals[idx] = focal
|
187 |
+
pps[idx] = pp
|
188 |
+
pps[idx] /= imsizes[idx] # default principal_point would be (0.5, 0.5)
|
189 |
+
|
190 |
+
depth = init_values.get('depthmap')
|
191 |
+
if depth is not None:
|
192 |
+
core_depth[idx] = depth.detach()
|
193 |
+
|
194 |
+
median_depths[idx] = med_depth = core_depth[idx].median()
|
195 |
+
core_depth[idx] /= med_depth
|
196 |
+
|
197 |
+
cam2w = init_values.get('cam2w')
|
198 |
+
if cam2w is not None:
|
199 |
+
rot = cam2w[:3, :3].detach()
|
200 |
+
cam_center = cam2w[:3, 3].detach()
|
201 |
+
quats[idx].data[:] = roma.rotmat_to_unitquat(rot)
|
202 |
+
trans_offset = med_depth * torch.cat((imsizes[idx] / base_focals[idx] * (0.5 - pps[idx]), ones[:1, 0]))
|
203 |
+
trans[idx].data[:] = cam_center + rot @ trans_offset
|
204 |
+
del rot
|
205 |
+
assert False, 'inverse kinematic chain not yet implemented'
|
206 |
+
|
207 |
+
# intrinsics parameters
|
208 |
+
if shared_intrinsics:
|
209 |
+
# Optimize a single set of intrinsics for all cameras. Use averages as init.
|
210 |
+
confs = torch.stack([torch.load(pth)[0][2].mean() for pth in canonical_paths]).to(pps)
|
211 |
+
weighting = confs / confs.sum()
|
212 |
+
pp = nn.Parameter((weighting @ pps).to(dtype))
|
213 |
+
pps = [pp for _ in range(len(imgs))]
|
214 |
+
focal_m = weighting @ base_focals
|
215 |
+
log_focal = nn.Parameter(focal_m.view(1).log().to(dtype))
|
216 |
+
log_focals = [log_focal for _ in range(len(imgs))]
|
217 |
+
else:
|
218 |
+
pps = [nn.Parameter(pp.to(dtype)) for pp in pps]
|
219 |
+
log_focals = [nn.Parameter(f.view(1).log().to(dtype)) for f in base_focals]
|
220 |
+
|
221 |
+
diags = imsizes.float().norm(dim=1)
|
222 |
+
min_focals = 0.25 * diags # diag = 1.2~1.4*max(W,H) => beta >= 1/(2*1.2*tan(fov/2)) ~= 0.26
|
223 |
+
max_focals = 10 * diags
|
224 |
+
|
225 |
+
assert len(mst[1]) == len(pps) - 1
|
226 |
+
|
227 |
+
def make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth):
|
228 |
+
# make intrinsics
|
229 |
+
focals = torch.cat(log_focals).exp().clip(min=min_focals, max=max_focals)
|
230 |
+
pps = torch.stack(pps)
|
231 |
+
K = torch.eye(3, dtype=dtype, device=device)[None].expand(len(imgs), 3, 3).clone()
|
232 |
+
K[:, 0, 0] = K[:, 1, 1] = focals
|
233 |
+
K[:, 0:2, 2] = pps * imsizes
|
234 |
+
if trans is None:
|
235 |
+
return K
|
236 |
+
|
237 |
+
# security! optimization is always trying to crush the scale down
|
238 |
+
sizes = torch.cat(log_sizes).exp()
|
239 |
+
global_scaling = 1 / sizes.min()
|
240 |
+
|
241 |
+
# compute distance of camera to focal plane
|
242 |
+
# tan(fov) = W/2 / focal
|
243 |
+
z_cameras = sizes * median_depths * focals / base_focals
|
244 |
+
|
245 |
+
# make extrinsic
|
246 |
+
rel_cam2cam = torch.eye(4, dtype=dtype, device=device)[None].expand(len(imgs), 4, 4).clone()
|
247 |
+
rel_cam2cam[:, :3, :3] = roma.unitquat_to_rotmat(F.normalize(torch.stack(quats), dim=1))
|
248 |
+
rel_cam2cam[:, :3, 3] = torch.stack(trans)
|
249 |
+
|
250 |
+
# camera are defined as a kinematic chain
|
251 |
+
tmp_cam2w = [None] * len(K)
|
252 |
+
tmp_cam2w[mst[0]] = rel_cam2cam[mst[0]]
|
253 |
+
for i, j in mst[1]:
|
254 |
+
# i is the cam_i_to_world reference, j is the relative pose = cam_j_to_cam_i
|
255 |
+
tmp_cam2w[j] = tmp_cam2w[i] @ rel_cam2cam[j]
|
256 |
+
tmp_cam2w = torch.stack(tmp_cam2w)
|
257 |
+
|
258 |
+
# smart reparameterizaton of cameras
|
259 |
+
trans_offset = z_cameras.unsqueeze(1) * torch.cat((imsizes / focals.unsqueeze(1) * (0.5 - pps), ones), dim=-1)
|
260 |
+
new_trans = global_scaling * (tmp_cam2w[:, :3, 3:4] - tmp_cam2w[:, :3, :3] @ trans_offset.unsqueeze(-1))
|
261 |
+
cam2w = torch.cat((torch.cat((tmp_cam2w[:, :3, :3], new_trans), dim=2),
|
262 |
+
vec0001.view(1, 1, 4).expand(len(K), 1, 4)), dim=1)
|
263 |
+
|
264 |
+
depthmaps = []
|
265 |
+
for i in range(len(imgs)):
|
266 |
+
core_depth_img = core_depth[i]
|
267 |
+
if exp_depth:
|
268 |
+
core_depth_img = core_depth_img.exp()
|
269 |
+
if lora_depth: # compute core_depth as a low-rank decomposition of 3d points
|
270 |
+
core_depth_img = lora_depth_proj[i] @ core_depth_img
|
271 |
+
if depth_mode == 'add':
|
272 |
+
core_depth_img = z_cameras[i] + (core_depth_img - 1) * (median_depths[i] * sizes[i])
|
273 |
+
elif depth_mode == 'mul':
|
274 |
+
core_depth_img = z_cameras[i] * core_depth_img
|
275 |
+
else:
|
276 |
+
raise ValueError(f'Bad {depth_mode=}')
|
277 |
+
depthmaps.append(global_scaling * core_depth_img)
|
278 |
+
|
279 |
+
return K, (inv(cam2w), cam2w), depthmaps
|
280 |
+
|
281 |
+
K = make_K_cam_depth(log_focals, pps, None, None, None, None)
|
282 |
+
|
283 |
+
if shared_intrinsics:
|
284 |
+
print('init focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
|
285 |
+
else:
|
286 |
+
print('init focals =', to_numpy(K[:, 0, 0]))
|
287 |
+
|
288 |
+
# spectral low-rank projection of depthmaps
|
289 |
+
if lora_depth:
|
290 |
+
core_depth, lora_depth_proj = spectral_projection_of_depthmaps(
|
291 |
+
imgs, K, core_depth, subsample, cache_path=cache_path, **lora_depth)
|
292 |
+
if exp_depth:
|
293 |
+
core_depth = [d.clip(min=1e-4).log() for d in core_depth]
|
294 |
+
core_depth = [nn.Parameter(d.ravel().to(dtype)) for d in core_depth]
|
295 |
+
log_sizes = [nn.Parameter(torch.zeros(1, dtype=dtype, device=device)) for _ in range(len(imgs))]
|
296 |
+
|
297 |
+
# Fetch img slices
|
298 |
+
_, confs_sum, imgs_slices = corres
|
299 |
+
|
300 |
+
# Define which pairs are fine to use with matching
|
301 |
+
def matching_check(x): return x.max() > matching_conf_thr
|
302 |
+
is_matching_ok = {}
|
303 |
+
for s in imgs_slices:
|
304 |
+
is_matching_ok[s.img1, s.img2] = matching_check(s.confs)
|
305 |
+
|
306 |
+
# Prepare slices and corres for losses
|
307 |
+
dust3r_slices = [s for s in imgs_slices if not is_matching_ok[s.img1, s.img2]]
|
308 |
+
loss3d_slices = [s for s in imgs_slices if is_matching_ok[s.img1, s.img2]]
|
309 |
+
cleaned_corres2d = []
|
310 |
+
for cci, (img1, pix1, confs, confsum, imgs_slices) in enumerate(corres2d):
|
311 |
+
cf_sum = 0
|
312 |
+
pix1_filtered = []
|
313 |
+
confs_filtered = []
|
314 |
+
curstep = 0
|
315 |
+
cleaned_slices = []
|
316 |
+
for img2, slice2 in imgs_slices:
|
317 |
+
if is_matching_ok[img1, img2]:
|
318 |
+
tslice = slice(curstep, curstep + slice2.stop - slice2.start, slice2.step)
|
319 |
+
pix1_filtered.append(pix1[tslice])
|
320 |
+
confs_filtered.append(confs[tslice])
|
321 |
+
cleaned_slices.append((img2, slice2))
|
322 |
+
curstep += slice2.stop - slice2.start
|
323 |
+
if pix1_filtered != []:
|
324 |
+
pix1_filtered = torch.cat(pix1_filtered)
|
325 |
+
confs_filtered = torch.cat(confs_filtered)
|
326 |
+
cf_sum = confs_filtered.sum()
|
327 |
+
cleaned_corres2d.append((img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices))
|
328 |
+
|
329 |
+
def loss_dust3r(cam2w, pts3d, pix_loss):
|
330 |
+
# In the case no correspondence could be established, fallback to DUSt3R GA regression loss formulation (sparsified)
|
331 |
+
loss = 0.
|
332 |
+
cf_sum = 0.
|
333 |
+
for s in dust3r_slices:
|
334 |
+
if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
|
335 |
+
continue
|
336 |
+
# fallback to dust3r regression
|
337 |
+
tgt_pts, tgt_confs = preds_21[imgs[s.img2]][imgs[s.img1]]
|
338 |
+
tgt_pts = geotrf(cam2w[s.img2], tgt_pts)
|
339 |
+
cf_sum += tgt_confs.sum()
|
340 |
+
loss += tgt_confs @ pix_loss(pts3d[s.img1], tgt_pts)
|
341 |
+
return loss / cf_sum if cf_sum != 0. else 0.
|
342 |
+
|
343 |
+
def loss_3d(K, w2cam, pts3d, pix_loss):
|
344 |
+
# For each correspondence, we have two 3D points (one for each image of the pair).
|
345 |
+
# For each 3D point, we have 2 reproj errors
|
346 |
+
if any(v.get('freeze') for v in init.values()):
|
347 |
+
pts3d_1 = []
|
348 |
+
pts3d_2 = []
|
349 |
+
confs = []
|
350 |
+
for s in loss3d_slices:
|
351 |
+
if init[imgs[s.img1]].get('freeze') and init[imgs[s.img2]].get('freeze'):
|
352 |
+
continue
|
353 |
+
pts3d_1.append(pts3d[s.img1][s.slice1])
|
354 |
+
pts3d_2.append(pts3d[s.img2][s.slice2])
|
355 |
+
confs.append(s.confs)
|
356 |
+
else:
|
357 |
+
pts3d_1 = [pts3d[s.img1][s.slice1] for s in loss3d_slices]
|
358 |
+
pts3d_2 = [pts3d[s.img2][s.slice2] for s in loss3d_slices]
|
359 |
+
confs = [s.confs for s in loss3d_slices]
|
360 |
+
|
361 |
+
if pts3d_1 != []:
|
362 |
+
confs = torch.cat(confs)
|
363 |
+
pts3d_1 = torch.cat(pts3d_1)
|
364 |
+
pts3d_2 = torch.cat(pts3d_2)
|
365 |
+
loss = confs @ pix_loss(pts3d_1, pts3d_2)
|
366 |
+
cf_sum = confs.sum()
|
367 |
+
else:
|
368 |
+
loss = 0.
|
369 |
+
cf_sum = 1.
|
370 |
+
|
371 |
+
return loss / cf_sum
|
372 |
+
|
373 |
+
def loss_2d(K, w2cam, pts3d, pix_loss):
|
374 |
+
# For each correspondence, we have two 3D points (one for each image of the pair).
|
375 |
+
# For each 3D point, we have 2 reproj errors
|
376 |
+
proj_matrix = K @ w2cam[:, :3]
|
377 |
+
loss = npix = 0
|
378 |
+
for img1, pix1_filtered, confs_filtered, cf_sum, cleaned_slices in cleaned_corres2d:
|
379 |
+
if init[imgs[img1]].get('freeze', 0) >= 1:
|
380 |
+
continue # no need
|
381 |
+
pts3d_in_img1 = [pts3d[img2][slice2] for img2, slice2 in cleaned_slices]
|
382 |
+
if pts3d_in_img1 != []:
|
383 |
+
pts3d_in_img1 = torch.cat(pts3d_in_img1)
|
384 |
+
loss += confs_filtered @ pix_loss(pix1_filtered, reproj2d(proj_matrix[img1], pts3d_in_img1))
|
385 |
+
npix += confs_filtered.sum()
|
386 |
+
|
387 |
+
return loss / npix if npix != 0 else 0.
|
388 |
+
|
389 |
+
def optimize_loop(loss_func, lr_base, niter, pix_loss, lr_end=0):
|
390 |
+
# create optimizer
|
391 |
+
params = pps + log_focals + quats + trans + log_sizes + core_depth
|
392 |
+
optimizer = torch.optim.Adam(params, lr=1, weight_decay=0, betas=(0.9, 0.9))
|
393 |
+
ploss = pix_loss if 'meta' in repr(pix_loss) else (lambda a: pix_loss)
|
394 |
+
|
395 |
+
with tqdm(total=niter) as bar:
|
396 |
+
for iter in range(niter or 1):
|
397 |
+
K, (w2cam, cam2w), depthmaps = make_K_cam_depth(log_focals, pps, trans, quats, log_sizes, core_depth)
|
398 |
+
pts3d = make_pts3d(anchors, K, cam2w, depthmaps, base_focals=base_focals)
|
399 |
+
if niter == 0:
|
400 |
+
break
|
401 |
+
|
402 |
+
alpha = (iter / niter)
|
403 |
+
lr = schedule(alpha, lr_base, lr_end)
|
404 |
+
adjust_learning_rate_by_lr(optimizer, lr)
|
405 |
+
pix_loss = ploss(1 - alpha)
|
406 |
+
optimizer.zero_grad()
|
407 |
+
loss = loss_func(K, w2cam, pts3d, pix_loss) + loss_dust3r_w * loss_dust3r(cam2w, pts3d, lossd)
|
408 |
+
loss.backward()
|
409 |
+
optimizer.step()
|
410 |
+
|
411 |
+
# make sure the pose remains well optimizable
|
412 |
+
for i in range(len(imgs)):
|
413 |
+
quats[i].data[:] /= quats[i].data.norm()
|
414 |
+
|
415 |
+
loss = float(loss)
|
416 |
+
if loss != loss:
|
417 |
+
break # NaN loss
|
418 |
+
bar.set_postfix_str(f'{lr=:.4f}, {loss=:.3f}')
|
419 |
+
bar.update(1)
|
420 |
+
|
421 |
+
if niter:
|
422 |
+
print(f'>> final loss = {loss}')
|
423 |
+
return dict(intrinsics=K.detach(), cam2w=cam2w.detach(),
|
424 |
+
depthmaps=[d.detach() for d in depthmaps], pts3d=[p.detach() for p in pts3d])
|
425 |
+
|
426 |
+
# at start, don't optimize 3d points
|
427 |
+
for i, img in enumerate(imgs):
|
428 |
+
trainable = not (init[img].get('freeze'))
|
429 |
+
pps[i].requires_grad_(False)
|
430 |
+
log_focals[i].requires_grad_(False)
|
431 |
+
quats[i].requires_grad_(trainable)
|
432 |
+
trans[i].requires_grad_(trainable)
|
433 |
+
log_sizes[i].requires_grad_(trainable)
|
434 |
+
core_depth[i].requires_grad_(False)
|
435 |
+
|
436 |
+
res_coarse = optimize_loop(loss_3d, lr_base=lr1, niter=niter1, pix_loss=loss1)
|
437 |
+
|
438 |
+
res_fine = None
|
439 |
+
if niter2:
|
440 |
+
# now we can optimize 3d points
|
441 |
+
for i, img in enumerate(imgs):
|
442 |
+
if init[img].get('freeze', 0) >= 1:
|
443 |
+
continue
|
444 |
+
pps[i].requires_grad_(bool(opt_pp))
|
445 |
+
log_focals[i].requires_grad_(True)
|
446 |
+
core_depth[i].requires_grad_(opt_depth)
|
447 |
+
|
448 |
+
# refinement with 2d reproj
|
449 |
+
res_fine = optimize_loop(loss_2d, lr_base=lr2, niter=niter2, pix_loss=loss2)
|
450 |
+
|
451 |
+
K = make_K_cam_depth(log_focals, pps, None, None, None, None)
|
452 |
+
if shared_intrinsics:
|
453 |
+
print('Final focal (shared) = ', to_numpy(K[0, 0, 0]).round(2))
|
454 |
+
else:
|
455 |
+
print('Final focals =', to_numpy(K[:, 0, 0]))
|
456 |
+
|
457 |
+
return imgs, res_coarse, res_fine
|
458 |
+
|
459 |
+
|
460 |
+
@lru_cache
|
461 |
+
def mask110(device, dtype):
|
462 |
+
return torch.tensor((1, 1, 0), device=device, dtype=dtype)
|
463 |
+
|
464 |
+
|
465 |
+
def proj3d(inv_K, pixels, z):
|
466 |
+
if pixels.shape[-1] == 2:
|
467 |
+
pixels = torch.cat((pixels, torch.ones_like(pixels[..., :1])), dim=-1)
|
468 |
+
return z.unsqueeze(-1) * (pixels * inv_K.diag() + inv_K[:, 2] * mask110(z.device, z.dtype))
|
469 |
+
|
470 |
+
|
471 |
+
def make_pts3d(anchors, K, cam2w, depthmaps, base_focals=None, ret_depth=False):
|
472 |
+
focals = K[:, 0, 0]
|
473 |
+
invK = inv(K)
|
474 |
+
all_pts3d = []
|
475 |
+
depth_out = []
|
476 |
+
|
477 |
+
for img, (pixels, idxs, offsets) in anchors.items():
|
478 |
+
# from depthmaps to 3d points
|
479 |
+
if base_focals is None:
|
480 |
+
pass
|
481 |
+
else:
|
482 |
+
# compensate for focal
|
483 |
+
# depth + depth * (offset - 1) * base_focal / focal
|
484 |
+
# = depth * (1 + (offset - 1) * (base_focal / focal))
|
485 |
+
offsets = 1 + (offsets - 1) * (base_focals[img] / focals[img])
|
486 |
+
|
487 |
+
pts3d = proj3d(invK[img], pixels, depthmaps[img][idxs] * offsets)
|
488 |
+
if ret_depth:
|
489 |
+
depth_out.append(pts3d[..., 2]) # before camera rotation
|
490 |
+
# rotate to world coordinate
|
491 |
+
pts3d = geotrf(cam2w[img], pts3d)
|
492 |
+
all_pts3d.append(pts3d)
|
493 |
+
|
494 |
+
if ret_depth:
|
495 |
+
return all_pts3d, depth_out
|
496 |
+
return all_pts3d
|
497 |
+
|
498 |
+
|
499 |
+
def make_dense_pts3d(intrinsics, cam2w, depthmaps, canonical_paths, subsample, device='cuda'):
|
500 |
+
base_focals = []
|
501 |
+
anchors = {}
|
502 |
+
confs = []
|
503 |
+
for i, canon_path in enumerate(canonical_paths):
|
504 |
+
(canon, canon2, conf), focal = torch.load(canon_path, map_location=device)
|
505 |
+
confs.append(conf)
|
506 |
+
base_focals.append(focal)
|
507 |
+
H, W = conf.shape
|
508 |
+
pixels = torch.from_numpy(np.mgrid[:W, :H].T.reshape(-1, 2)).float().to(device)
|
509 |
+
idxs, offsets = anchor_depth_offsets(canon2, {i: (pixels, None)}, subsample=subsample)
|
510 |
+
anchors[i] = (pixels, idxs[i], offsets[i])
|
511 |
+
|
512 |
+
# densify sparse depthmaps
|
513 |
+
pts3d, depthmaps_out = make_pts3d(anchors, intrinsics, cam2w, [
|
514 |
+
d.ravel() for d in depthmaps], base_focals=base_focals, ret_depth=True)
|
515 |
+
|
516 |
+
return pts3d, depthmaps_out, confs
|
517 |
+
|
518 |
+
|
519 |
+
@torch.no_grad()
|
520 |
+
def forward_mast3r(pairs, model, cache_path, desc_conf='desc_conf',
|
521 |
+
device='cuda', subsample=8, **matching_kw):
|
522 |
+
res_paths = {}
|
523 |
+
|
524 |
+
for img1, img2 in tqdm(pairs):
|
525 |
+
idx1 = hash_md5(img1['instance'])
|
526 |
+
idx2 = hash_md5(img2['instance'])
|
527 |
+
|
528 |
+
path1 = cache_path + f'/forward/{idx1}/{idx2}.pth'
|
529 |
+
path2 = cache_path + f'/forward/{idx2}/{idx1}.pth'
|
530 |
+
path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx1}-{idx2}.pth'
|
531 |
+
path_corres2 = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{idx2}-{idx1}.pth'
|
532 |
+
|
533 |
+
if os.path.isfile(path_corres2) and not os.path.isfile(path_corres):
|
534 |
+
score, (xy1, xy2, confs) = torch.load(path_corres2)
|
535 |
+
torch.save((score, (xy2, xy1, confs)), path_corres)
|
536 |
+
|
537 |
+
if not all(os.path.isfile(p) for p in (path1, path2, path_corres)):
|
538 |
+
if model is None:
|
539 |
+
continue
|
540 |
+
res = symmetric_inference(model, img1, img2, device=device)
|
541 |
+
X11, X21, X22, X12 = [r['pts3d'][0] for r in res]
|
542 |
+
C11, C21, C22, C12 = [r['conf'][0] for r in res]
|
543 |
+
descs = [r['desc'][0] for r in res]
|
544 |
+
qonfs = [r[desc_conf][0] for r in res]
|
545 |
+
|
546 |
+
# save
|
547 |
+
torch.save(to_cpu((X11, C11, X21, C21)), mkdir_for(path1))
|
548 |
+
torch.save(to_cpu((X22, C22, X12, C12)), mkdir_for(path2))
|
549 |
+
|
550 |
+
# perform reciprocal matching
|
551 |
+
corres = extract_correspondences(descs, qonfs, device=device, subsample=subsample)
|
552 |
+
|
553 |
+
conf_score = (C11.mean() * C12.mean() * C21.mean() * C22.mean()).sqrt().sqrt()
|
554 |
+
matching_score = (float(conf_score), float(corres[2].sum()), len(corres[2]))
|
555 |
+
if cache_path is not None:
|
556 |
+
torch.save((matching_score, corres), mkdir_for(path_corres))
|
557 |
+
|
558 |
+
res_paths[img1['instance'], img2['instance']] = (path1, path2), path_corres
|
559 |
+
|
560 |
+
del model
|
561 |
+
torch.cuda.empty_cache()
|
562 |
+
|
563 |
+
return res_paths, cache_path
|
564 |
+
|
565 |
+
|
566 |
+
def symmetric_inference(model, img1, img2, device):
|
567 |
+
shape1 = torch.from_numpy(img1['true_shape']).to(device, non_blocking=True)
|
568 |
+
shape2 = torch.from_numpy(img2['true_shape']).to(device, non_blocking=True)
|
569 |
+
img1 = img1['img'].to(device, non_blocking=True)
|
570 |
+
img2 = img2['img'].to(device, non_blocking=True)
|
571 |
+
|
572 |
+
# compute encoder only once
|
573 |
+
feat1, feat2, pos1, pos2 = model._encode_image_pairs(img1, img2, shape1, shape2)
|
574 |
+
|
575 |
+
def decoder(feat1, feat2, pos1, pos2, shape1, shape2):
|
576 |
+
dec1, dec2 = model._decoder(feat1, pos1, feat2, pos2)
|
577 |
+
with torch.cuda.amp.autocast(enabled=False):
|
578 |
+
res1 = model._downstream_head(1, [tok.float() for tok in dec1], shape1)
|
579 |
+
res2 = model._downstream_head(2, [tok.float() for tok in dec2], shape2)
|
580 |
+
return res1, res2
|
581 |
+
|
582 |
+
# decoder 1-2
|
583 |
+
res11, res21 = decoder(feat1, feat2, pos1, pos2, shape1, shape2)
|
584 |
+
# decoder 2-1
|
585 |
+
res22, res12 = decoder(feat2, feat1, pos2, pos1, shape2, shape1)
|
586 |
+
|
587 |
+
return (res11, res21, res22, res12)
|
588 |
+
|
589 |
+
|
590 |
+
def extract_correspondences(feats, qonfs, subsample=8, device=None, ptmap_key='pred_desc'):
|
591 |
+
feat11, feat21, feat22, feat12 = feats
|
592 |
+
qonf11, qonf21, qonf22, qonf12 = qonfs
|
593 |
+
assert feat11.shape[:2] == feat12.shape[:2] == qonf11.shape == qonf12.shape
|
594 |
+
assert feat21.shape[:2] == feat22.shape[:2] == qonf21.shape == qonf22.shape
|
595 |
+
|
596 |
+
if '3d' in ptmap_key:
|
597 |
+
opt = dict(device='cpu', workers=32)
|
598 |
+
else:
|
599 |
+
opt = dict(device=device, dist='dot', block_size=2**13)
|
600 |
+
|
601 |
+
# matching the two pairs
|
602 |
+
idx1 = []
|
603 |
+
idx2 = []
|
604 |
+
qonf1 = []
|
605 |
+
qonf2 = []
|
606 |
+
# TODO add non symmetric / pixel_tol options
|
607 |
+
for A, B, QA, QB in [(feat11, feat21, qonf11.cpu(), qonf21.cpu()),
|
608 |
+
(feat12, feat22, qonf12.cpu(), qonf22.cpu())]:
|
609 |
+
nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
|
610 |
+
nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)
|
611 |
+
|
612 |
+
idx1.append(np.r_[nn1to2[0], nn2to1[1]])
|
613 |
+
idx2.append(np.r_[nn1to2[1], nn2to1[0]])
|
614 |
+
qonf1.append(QA.ravel()[idx1[-1]])
|
615 |
+
qonf2.append(QB.ravel()[idx2[-1]])
|
616 |
+
|
617 |
+
# merge corres from opposite pairs
|
618 |
+
H1, W1 = feat11.shape[:2]
|
619 |
+
H2, W2 = feat22.shape[:2]
|
620 |
+
cat = np.concatenate
|
621 |
+
|
622 |
+
xy1, xy2, idx = merge_corres(cat(idx1), cat(idx2), (H1, W1), (H2, W2), ret_xy=True, ret_index=True)
|
623 |
+
corres = (xy1.copy(), xy2.copy(), np.sqrt(cat(qonf1)[idx] * cat(qonf2)[idx]))
|
624 |
+
|
625 |
+
return todevice(corres, device)
|
626 |
+
|
627 |
+
|
628 |
+
@torch.no_grad()
|
629 |
+
def prepare_canonical_data(imgs, tmp_pairs, subsample, order_imgs=False, min_conf_thr=0,
|
630 |
+
cache_path=None, device='cuda', **kw):
|
631 |
+
canonical_views = {}
|
632 |
+
pairwise_scores = torch.zeros((len(imgs), len(imgs)), device=device)
|
633 |
+
canonical_paths = []
|
634 |
+
preds_21 = {}
|
635 |
+
|
636 |
+
for img in tqdm(imgs):
|
637 |
+
if cache_path:
|
638 |
+
cache = os.path.join(cache_path, 'canon_views', hash_md5(img) + f'_{subsample=}_{kw=}.pth')
|
639 |
+
canonical_paths.append(cache)
|
640 |
+
try:
|
641 |
+
(canon, canon2, cconf), focal = torch.load(cache, map_location=device)
|
642 |
+
except IOError:
|
643 |
+
# cache does not exist yet, we create it!
|
644 |
+
canon = focal = None
|
645 |
+
|
646 |
+
# collect all pred1
|
647 |
+
n_pairs = sum((img in pair) for pair in tmp_pairs)
|
648 |
+
|
649 |
+
ptmaps11 = None
|
650 |
+
pixels = {}
|
651 |
+
n = 0
|
652 |
+
for (img1, img2), ((path1, path2), path_corres) in tmp_pairs.items():
|
653 |
+
score = None
|
654 |
+
if img == img1:
|
655 |
+
X, C, X2, C2 = torch.load(path1, map_location=device)
|
656 |
+
score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
|
657 |
+
pixels[img2] = xy1, confs
|
658 |
+
if img not in preds_21:
|
659 |
+
preds_21[img] = {}
|
660 |
+
# Subsample preds_21
|
661 |
+
preds_21[img][img2] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()
|
662 |
+
|
663 |
+
if img == img2:
|
664 |
+
X, C, X2, C2 = torch.load(path2, map_location=device)
|
665 |
+
score, (xy1, xy2, confs) = load_corres(path_corres, device, min_conf_thr)
|
666 |
+
pixels[img1] = xy2, confs
|
667 |
+
if img not in preds_21:
|
668 |
+
preds_21[img] = {}
|
669 |
+
preds_21[img][img1] = X2[::subsample, ::subsample].reshape(-1, 3), C2[::subsample, ::subsample].ravel()
|
670 |
+
|
671 |
+
if score is not None:
|
672 |
+
i, j = imgs.index(img1), imgs.index(img2)
|
673 |
+
# score = score[0]
|
674 |
+
# score = np.log1p(score[2])
|
675 |
+
score = score[2]
|
676 |
+
pairwise_scores[i, j] = score
|
677 |
+
pairwise_scores[j, i] = score
|
678 |
+
|
679 |
+
if canon is not None:
|
680 |
+
continue
|
681 |
+
if ptmaps11 is None:
|
682 |
+
H, W = C.shape
|
683 |
+
ptmaps11 = torch.empty((n_pairs, H, W, 3), device=device)
|
684 |
+
confs11 = torch.empty((n_pairs, H, W), device=device)
|
685 |
+
|
686 |
+
ptmaps11[n] = X
|
687 |
+
confs11[n] = C
|
688 |
+
n += 1
|
689 |
+
|
690 |
+
if canon is None:
|
691 |
+
canon, canon2, cconf = canonical_view(ptmaps11, confs11, subsample, **kw)
|
692 |
+
del ptmaps11
|
693 |
+
del confs11
|
694 |
+
|
695 |
+
# compute focals
|
696 |
+
H, W = canon.shape[:2]
|
697 |
+
pp = torch.tensor([W / 2, H / 2], device=device)
|
698 |
+
if focal is None:
|
699 |
+
focal = estimate_focal_knowing_depth(canon[None], pp, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5)
|
700 |
+
if cache:
|
701 |
+
torch.save(to_cpu(((canon, canon2, cconf), focal)), mkdir_for(cache))
|
702 |
+
|
703 |
+
# extract depth offsets with correspondences
|
704 |
+
core_depth = canon[subsample // 2::subsample, subsample // 2::subsample, 2]
|
705 |
+
idxs, offsets = anchor_depth_offsets(canon2, pixels, subsample=subsample)
|
706 |
+
|
707 |
+
canonical_views[img] = (pp, (H, W), focal.view(1), core_depth, pixels, idxs, offsets)
|
708 |
+
|
709 |
+
return tmp_pairs, pairwise_scores, canonical_views, canonical_paths, preds_21
|
710 |
+
|
711 |
+
|
712 |
+
def load_corres(path_corres, device, min_conf_thr):
|
713 |
+
score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
|
714 |
+
valid = confs > min_conf_thr if min_conf_thr else slice(None)
|
715 |
+
# valid = (xy1 > 0).all(dim=1) & (xy2 > 0).all(dim=1) & (xy1 < 512).all(dim=1) & (xy2 < 512).all(dim=1)
|
716 |
+
# print(f'keeping {valid.sum()} / {len(valid)} correspondences')
|
717 |
+
return score, (xy1[valid], xy2[valid], confs[valid])
|
718 |
+
|
719 |
+
|
720 |
+
PairOfSlices = namedtuple(
|
721 |
+
'ImgPair', 'img1, slice1, pix1, anchor_idxs1, img2, slice2, pix2, anchor_idxs2, confs, confs_sum')
|
722 |
+
|
723 |
+
|
724 |
+
def condense_data(imgs, tmp_paths, canonical_views, preds_21, dtype=torch.float32):
|
725 |
+
# aggregate all data properly
|
726 |
+
set_imgs = set(imgs)
|
727 |
+
|
728 |
+
principal_points = []
|
729 |
+
shapes = []
|
730 |
+
focals = []
|
731 |
+
core_depth = []
|
732 |
+
img_anchors = {}
|
733 |
+
tmp_pixels = {}
|
734 |
+
|
735 |
+
for idx1, img1 in enumerate(imgs):
|
736 |
+
# load stuff
|
737 |
+
pp, shape, focal, anchors, pixels_confs, idxs, offsets = canonical_views[img1]
|
738 |
+
|
739 |
+
principal_points.append(pp)
|
740 |
+
shapes.append(shape)
|
741 |
+
focals.append(focal)
|
742 |
+
core_depth.append(anchors)
|
743 |
+
|
744 |
+
img_uv1 = []
|
745 |
+
img_idxs = []
|
746 |
+
img_offs = []
|
747 |
+
cur_n = [0]
|
748 |
+
|
749 |
+
for img2, (pixels, match_confs) in pixels_confs.items():
|
750 |
+
if img2 not in set_imgs:
|
751 |
+
continue
|
752 |
+
assert len(pixels) == len(idxs[img2]) == len(offsets[img2])
|
753 |
+
img_uv1.append(torch.cat((pixels, torch.ones_like(pixels[:, :1])), dim=-1))
|
754 |
+
img_idxs.append(idxs[img2])
|
755 |
+
img_offs.append(offsets[img2])
|
756 |
+
cur_n.append(cur_n[-1] + len(pixels))
|
757 |
+
# store the position of 3d points
|
758 |
+
tmp_pixels[img1, img2] = pixels.to(dtype), match_confs.to(dtype), slice(*cur_n[-2:])
|
759 |
+
img_anchors[idx1] = (torch.cat(img_uv1), torch.cat(img_idxs), torch.cat(img_offs))
|
760 |
+
|
761 |
+
all_confs = []
|
762 |
+
imgs_slices = []
|
763 |
+
corres2d = {img: [] for img in range(len(imgs))}
|
764 |
+
|
765 |
+
for img1, img2 in tmp_paths:
|
766 |
+
try:
|
767 |
+
pix1, confs1, slice1 = tmp_pixels[img1, img2]
|
768 |
+
pix2, confs2, slice2 = tmp_pixels[img2, img1]
|
769 |
+
except KeyError:
|
770 |
+
continue
|
771 |
+
img1 = imgs.index(img1)
|
772 |
+
img2 = imgs.index(img2)
|
773 |
+
confs = (confs1 * confs2).sqrt()
|
774 |
+
|
775 |
+
# prepare for loss_3d
|
776 |
+
all_confs.append(confs)
|
777 |
+
anchor_idxs1 = canonical_views[imgs[img1]][5][imgs[img2]]
|
778 |
+
anchor_idxs2 = canonical_views[imgs[img2]][5][imgs[img1]]
|
779 |
+
imgs_slices.append(PairOfSlices(img1, slice1, pix1, anchor_idxs1,
|
780 |
+
img2, slice2, pix2, anchor_idxs2,
|
781 |
+
confs, float(confs.sum())))
|
782 |
+
|
783 |
+
# prepare for loss_2d
|
784 |
+
corres2d[img1].append((pix1, confs, img2, slice2))
|
785 |
+
corres2d[img2].append((pix2, confs, img1, slice1))
|
786 |
+
|
787 |
+
all_confs = torch.cat(all_confs)
|
788 |
+
corres = (all_confs, float(all_confs.sum()), imgs_slices)
|
789 |
+
|
790 |
+
def aggreg_matches(img1, list_matches):
|
791 |
+
pix1, confs, img2, slice2 = zip(*list_matches)
|
792 |
+
all_pix1 = torch.cat(pix1).to(dtype)
|
793 |
+
all_confs = torch.cat(confs).to(dtype)
|
794 |
+
return img1, all_pix1, all_confs, float(all_confs.sum()), [(j, sl2) for j, sl2 in zip(img2, slice2)]
|
795 |
+
corres2d = [aggreg_matches(img, m) for img, m in corres2d.items()]
|
796 |
+
|
797 |
+
imsizes = torch.tensor([(W, H) for H, W in shapes], device=pp.device) # (W,H)
|
798 |
+
principal_points = torch.stack(principal_points)
|
799 |
+
focals = torch.cat(focals)
|
800 |
+
|
801 |
+
# Subsample preds_21
|
802 |
+
subsamp_preds_21 = {}
|
803 |
+
for imk, imv in preds_21.items():
|
804 |
+
subsamp_preds_21[imk] = {}
|
805 |
+
for im2k, (pred, conf) in preds_21[imk].items():
|
806 |
+
idxs = img_anchors[imgs.index(im2k)][1]
|
807 |
+
subsamp_preds_21[imk][im2k] = (pred[idxs], conf[idxs]) # anchors subsample
|
808 |
+
|
809 |
+
return imsizes, principal_points, focals, core_depth, img_anchors, corres, corres2d, subsamp_preds_21
|
810 |
+
|
811 |
+
|
812 |
+
def canonical_view(ptmaps11, confs11, subsample, mode='avg-angle'):
|
813 |
+
assert len(ptmaps11) == len(confs11) > 0, 'not a single view1 for img={i}'
|
814 |
+
|
815 |
+
# canonical pointmap is just a weighted average
|
816 |
+
confs11 = confs11.unsqueeze(-1) - 0.999
|
817 |
+
canon = (confs11 * ptmaps11).sum(0) / confs11.sum(0)
|
818 |
+
|
819 |
+
canon_depth = ptmaps11[..., 2].unsqueeze(1)
|
820 |
+
S = slice(subsample // 2, None, subsample)
|
821 |
+
center_depth = canon_depth[:, :, S, S]
|
822 |
+
center_depth = torch.clip(center_depth, min=torch.finfo(center_depth.dtype).eps)
|
823 |
+
|
824 |
+
stacked_depth = F.pixel_unshuffle(canon_depth, subsample)
|
825 |
+
stacked_confs = F.pixel_unshuffle(confs11[:, None, :, :, 0], subsample)
|
826 |
+
|
827 |
+
if mode == 'avg-reldepth':
|
828 |
+
rel_depth = stacked_depth / center_depth
|
829 |
+
stacked_canon = (stacked_confs * rel_depth).sum(dim=0) / stacked_confs.sum(dim=0)
|
830 |
+
canon2 = F.pixel_shuffle(stacked_canon.unsqueeze(0), subsample).squeeze()
|
831 |
+
|
832 |
+
elif mode == 'avg-angle':
|
833 |
+
xy = ptmaps11[..., 0:2].permute(0, 3, 1, 2)
|
834 |
+
stacked_xy = F.pixel_unshuffle(xy, subsample)
|
835 |
+
B, _, H, W = stacked_xy.shape
|
836 |
+
stacked_radius = (stacked_xy.view(B, 2, -1, H, W) - xy[:, :, None, S, S]).norm(dim=1)
|
837 |
+
stacked_radius.clip_(min=1e-8)
|
838 |
+
|
839 |
+
stacked_angle = torch.arctan((stacked_depth - center_depth) / stacked_radius)
|
840 |
+
avg_angle = (stacked_confs * stacked_angle).sum(dim=0) / stacked_confs.sum(dim=0)
|
841 |
+
|
842 |
+
# back to depth
|
843 |
+
stacked_depth = stacked_radius.mean(dim=0) * torch.tan(avg_angle)
|
844 |
+
|
845 |
+
canon2 = F.pixel_shuffle((1 + stacked_depth / canon[S, S, 2]).unsqueeze(0), subsample).squeeze()
|
846 |
+
else:
|
847 |
+
raise ValueError(f'bad {mode=}')
|
848 |
+
|
849 |
+
confs = (confs11.square().sum(dim=0) / confs11.sum(dim=0)).squeeze()
|
850 |
+
return canon, canon2, confs
|
851 |
+
|
852 |
+
|
853 |
+
def anchor_depth_offsets(canon_depth, pixels, subsample=8):
|
854 |
+
device = canon_depth.device
|
855 |
+
|
856 |
+
# create a 2D grid of anchor 3D points
|
857 |
+
H1, W1 = canon_depth.shape
|
858 |
+
yx = np.mgrid[subsample // 2:H1:subsample, subsample // 2:W1:subsample]
|
859 |
+
H2, W2 = yx.shape[1:]
|
860 |
+
cy, cx = yx.reshape(2, -1)
|
861 |
+
core_depth = canon_depth[cy, cx]
|
862 |
+
assert (core_depth > 0).all()
|
863 |
+
|
864 |
+
# slave 3d points (attached to core 3d points)
|
865 |
+
core_idxs = {} # core_idxs[img2] = {corr_idx:core_idx}
|
866 |
+
core_offs = {} # core_offs[img2] = {corr_idx:3d_offset}
|
867 |
+
|
868 |
+
for img2, (xy1, _confs) in pixels.items():
|
869 |
+
px, py = xy1.long().T
|
870 |
+
|
871 |
+
# find nearest anchor == block quantization
|
872 |
+
core_idx = (py // subsample) * W2 + (px // subsample)
|
873 |
+
core_idxs[img2] = core_idx.to(device)
|
874 |
+
|
875 |
+
# compute relative depth offsets w.r.t. anchors
|
876 |
+
ref_z = core_depth[core_idx]
|
877 |
+
pts_z = canon_depth[py, px]
|
878 |
+
offset = pts_z / ref_z
|
879 |
+
core_offs[img2] = offset.detach().to(device)
|
880 |
+
|
881 |
+
return core_idxs, core_offs
|
882 |
+
|
883 |
+
|
884 |
+
def spectral_clustering(graph, k=None, normalized_cuts=False):
|
885 |
+
graph.fill_diagonal_(0)
|
886 |
+
|
887 |
+
# graph laplacian
|
888 |
+
degrees = graph.sum(dim=-1)
|
889 |
+
laplacian = torch.diag(degrees) - graph
|
890 |
+
if normalized_cuts:
|
891 |
+
i_inv = torch.diag(degrees.sqrt().reciprocal())
|
892 |
+
laplacian = i_inv @ laplacian @ i_inv
|
893 |
+
|
894 |
+
# compute eigenvectors!
|
895 |
+
eigval, eigvec = torch.linalg.eigh(laplacian)
|
896 |
+
return eigval[:k], eigvec[:, :k]
|
897 |
+
|
898 |
+
|
899 |
+
def sim_func(p1, p2, gamma):
|
900 |
+
diff = (p1 - p2).norm(dim=-1)
|
901 |
+
avg_depth = (p1[:, :, 2] + p2[:, :, 2])
|
902 |
+
rel_distance = diff / avg_depth
|
903 |
+
sim = torch.exp(-gamma * rel_distance.square())
|
904 |
+
return sim
|
905 |
+
|
906 |
+
|
907 |
+
def backproj(K, depthmap, subsample):
|
908 |
+
H, W = depthmap.shape
|
909 |
+
uv = np.mgrid[subsample // 2:subsample * W:subsample, subsample // 2:subsample * H:subsample].T.reshape(H, W, 2)
|
910 |
+
xyz = depthmap.unsqueeze(-1) * geotrf(inv(K), todevice(uv, K.device), ncol=3)
|
911 |
+
return xyz
|
912 |
+
|
913 |
+
|
914 |
+
def spectral_projection_depth(K, depthmap, subsample, k=64, cache_path='',
|
915 |
+
normalized_cuts=True, gamma=7, min_norm=5):
|
916 |
+
try:
|
917 |
+
if cache_path:
|
918 |
+
cache_path = cache_path + f'_{k=}_norm={normalized_cuts}_{gamma=}.pth'
|
919 |
+
lora_proj = torch.load(cache_path, map_location=K.device)
|
920 |
+
|
921 |
+
except IOError:
|
922 |
+
# reconstruct 3d points in camera coordinates
|
923 |
+
xyz = backproj(K, depthmap, subsample)
|
924 |
+
|
925 |
+
# compute all distances
|
926 |
+
xyz = xyz.reshape(-1, 3)
|
927 |
+
graph = sim_func(xyz[:, None], xyz[None, :], gamma=gamma)
|
928 |
+
_, lora_proj = spectral_clustering(graph, k, normalized_cuts=normalized_cuts)
|
929 |
+
|
930 |
+
if cache_path:
|
931 |
+
torch.save(lora_proj.cpu(), mkdir_for(cache_path))
|
932 |
+
|
933 |
+
lora_proj, coeffs = lora_encode_normed(lora_proj, depthmap.ravel(), min_norm=min_norm)
|
934 |
+
|
935 |
+
# depthmap ~= lora_proj @ coeffs
|
936 |
+
return coeffs, lora_proj
|
937 |
+
|
938 |
+
|
939 |
+
def lora_encode_normed(lora_proj, x, min_norm, global_norm=False):
|
940 |
+
# encode the pointmap
|
941 |
+
coeffs = torch.linalg.pinv(lora_proj) @ x
|
942 |
+
|
943 |
+
# rectify the norm of basis vector to be ~ equal
|
944 |
+
if coeffs.ndim == 1:
|
945 |
+
coeffs = coeffs[:, None]
|
946 |
+
if global_norm:
|
947 |
+
lora_proj *= coeffs[1:].norm() * min_norm / coeffs.shape[1]
|
948 |
+
elif min_norm:
|
949 |
+
lora_proj *= coeffs.norm(dim=1).clip(min=min_norm)
|
950 |
+
# can have rounding errors here!
|
951 |
+
coeffs = (torch.linalg.pinv(lora_proj.double()) @ x.double()).float()
|
952 |
+
|
953 |
+
return lora_proj.detach(), coeffs.detach()
|
954 |
+
|
955 |
+
|
956 |
+
@torch.no_grad()
|
957 |
+
def spectral_projection_of_depthmaps(imgs, intrinsics, depthmaps, subsample, cache_path=None, **kw):
|
958 |
+
# recover 3d points
|
959 |
+
core_depth = []
|
960 |
+
lora_proj = []
|
961 |
+
|
962 |
+
for i, img in enumerate(tqdm(imgs)):
|
963 |
+
cache = os.path.join(cache_path, 'lora_depth', hash_md5(img)) if cache_path else None
|
964 |
+
depth, proj = spectral_projection_depth(intrinsics[i], depthmaps[i], subsample,
|
965 |
+
cache_path=cache, **kw)
|
966 |
+
core_depth.append(depth)
|
967 |
+
lora_proj.append(proj)
|
968 |
+
|
969 |
+
return core_depth, lora_proj
|
970 |
+
|
971 |
+
|
972 |
+
def reproj2d(Trf, pts3d):
|
973 |
+
res = (pts3d @ Trf[:3, :3].transpose(-1, -2)) + Trf[:3, 3]
|
974 |
+
clipped_z = res[:, 2:3].clip(min=1e-3) # make sure we don't have nans!
|
975 |
+
uv = res[:, 0:2] / clipped_z
|
976 |
+
return uv.clip(min=-1000, max=2000)
|
977 |
+
|
978 |
+
|
979 |
+
def bfs(tree, start_node):
|
980 |
+
order, predecessors = sp.csgraph.breadth_first_order(tree, start_node, directed=False)
|
981 |
+
ranks = np.arange(len(order))
|
982 |
+
ranks[order] = ranks.copy()
|
983 |
+
return ranks, predecessors
|
984 |
+
|
985 |
+
|
986 |
+
def compute_min_spanning_tree(pws):
|
987 |
+
sparse_graph = sp.dok_array(pws.shape)
|
988 |
+
for i, j in pws.nonzero().cpu().tolist():
|
989 |
+
sparse_graph[i, j] = -float(pws[i, j])
|
990 |
+
msp = sp.csgraph.minimum_spanning_tree(sparse_graph)
|
991 |
+
|
992 |
+
# now reorder the oriented edges, starting from the central point
|
993 |
+
ranks1, _ = bfs(msp, 0)
|
994 |
+
ranks2, _ = bfs(msp, ranks1.argmax())
|
995 |
+
ranks1, _ = bfs(msp, ranks2.argmax())
|
996 |
+
# this is the point farthest from any leaf
|
997 |
+
root = np.minimum(ranks1, ranks2).argmax()
|
998 |
+
|
999 |
+
# find the ordered list of edges that describe the tree
|
1000 |
+
order, predecessors = sp.csgraph.breadth_first_order(msp, root, directed=False)
|
1001 |
+
order = order[1:] # the root does not have a predecessor
|
1002 |
+
edges = [(predecessors[i], i) for i in order]
|
1003 |
+
|
1004 |
+
return root, edges
|
1005 |
+
|
1006 |
+
|
1007 |
+
def show_reconstruction(shapes_or_imgs, K, cam2w, pts3d, gt_cam2w=None, gt_K=None, cam_size=None, masks=None, **kw):
|
1008 |
+
viz = SceneViz()
|
1009 |
+
|
1010 |
+
cc = cam2w[:, :3, 3]
|
1011 |
+
cs = cam_size or float(torch.cdist(cc, cc).fill_diagonal_(np.inf).min(dim=0).values.median())
|
1012 |
+
colors = 64 + np.random.randint(255 - 64, size=(len(cam2w), 3))
|
1013 |
+
|
1014 |
+
if isinstance(shapes_or_imgs, np.ndarray) and shapes_or_imgs.ndim == 2:
|
1015 |
+
cam_kws = dict(imsizes=shapes_or_imgs[:, ::-1], cam_size=cs)
|
1016 |
+
else:
|
1017 |
+
imgs = shapes_or_imgs
|
1018 |
+
cam_kws = dict(images=imgs, cam_size=cs)
|
1019 |
+
if K is not None:
|
1020 |
+
viz.add_cameras(to_numpy(cam2w), to_numpy(K), colors=colors, **cam_kws)
|
1021 |
+
|
1022 |
+
if gt_cam2w is not None:
|
1023 |
+
if gt_K is None:
|
1024 |
+
gt_K = K
|
1025 |
+
viz.add_cameras(to_numpy(gt_cam2w), to_numpy(gt_K), colors=colors, marker='o', **cam_kws)
|
1026 |
+
|
1027 |
+
if pts3d is not None:
|
1028 |
+
for i, p in enumerate(pts3d):
|
1029 |
+
if not len(p):
|
1030 |
+
continue
|
1031 |
+
if masks is None:
|
1032 |
+
viz.add_pointcloud(to_numpy(p), color=tuple(colors[i].tolist()))
|
1033 |
+
else:
|
1034 |
+
viz.add_pointcloud(to_numpy(p), mask=masks[i], color=imgs[i])
|
1035 |
+
viz.show(**kw)
|
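For readers skimming the diff: the low-rank depth encoding above boils down to approximating each depthmap as `lora_proj @ coeffs`, with coefficients obtained from the pseudo-inverse of the basis (see `lora_encode_normed`). A minimal standalone sketch with toy data, not part of the repository:

import torch

N, k = 256, 16                                  # toy signal length and basis size
lora_proj = torch.randn(N, k)                   # stand-in for the spectral basis
depth = torch.rand(N)                           # stand-in for a flattened depthmap

coeffs = torch.linalg.pinv(lora_proj) @ depth   # encode (least-squares fit)
recon = lora_proj @ coeffs                      # decode: depthmap ~= lora_proj @ coeffs
print(float((recon - depth).abs().mean()))      # residual of the low-rank approximation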
mast3r/cloud_opt/triangulation.py
ADDED
@@ -0,0 +1,80 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# Matches Triangulation Utils
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
# Batched Matches Triangulation
|
12 |
+
def batched_triangulate(pts2d, # [B, Ncams, Npts, 2]
|
13 |
+
proj_mats): # [B, Ncams, 3, 4] I@E projection matrix
|
14 |
+
B, Ncams, Npts, two = pts2d.shape
|
15 |
+
assert two==2
|
16 |
+
assert proj_mats.shape == (B, Ncams, 3, 4)
|
17 |
+
# P - xP
|
18 |
+
x = proj_mats[...,0,:][...,None,:] - torch.einsum('bij,bik->bijk', pts2d[...,0], proj_mats[...,2,:]) # [B, Ncams, Npts, 4]
|
19 |
+
y = proj_mats[...,1,:][...,None,:] - torch.einsum('bij,bik->bijk', pts2d[...,1], proj_mats[...,2,:]) # [B, Ncams, Npts, 4]
|
20 |
+
eq = torch.cat([x, y], dim=1).transpose(1, 2) # [B, Npts, 2xNcams, 4]
|
21 |
+
return torch.linalg.lstsq(eq[...,:3], -eq[...,3]).solution
|
22 |
+
|
23 |
+
def matches_to_depths(intrinsics, # input camera intrinsics [B, Ncams, 3, 3]
|
24 |
+
extrinsics, # input camera extrinsics [B, Ncams, 3, 4]
|
25 |
+
matches, # input correspondences [B, Ncams, Npts, 2]
|
26 |
+
batchsize=16, # bs for batched processing
|
27 |
+
min_num_valids_ratio=.3 # at least this ratio of image pairs need to predict a match for a given pixel of img1
|
28 |
+
):
|
29 |
+
B, Nv, H, W, five = matches.shape
|
30 |
+
min_num_valids = np.floor(Nv*min_num_valids_ratio)
|
31 |
+
out_aggregated_points, out_depths, out_confs = [], [], []
|
32 |
+
for b in range(B//batchsize+1): # batched processing
|
33 |
+
start, stop = b*batchsize,min(B,(b+1)*batchsize)
|
34 |
+
sub_batch=slice(start,stop)
|
35 |
+
sub_batchsize = stop-start
|
36 |
+
if sub_batchsize==0:continue
|
37 |
+
points1, points2, confs = matches[sub_batch, ..., :2], matches[sub_batch, ..., 2:4], matches[sub_batch, ..., -1]
|
38 |
+
allpoints = torch.cat([points1.view([sub_batchsize*Nv,1,H*W,2]), points2.view([sub_batchsize*Nv,1,H*W,2])],dim=1) # [BxNv, 2, HxW, 2]
|
39 |
+
|
40 |
+
allcam_Ps = intrinsics[sub_batch] @ extrinsics[sub_batch,:,:3,:]
|
41 |
+
cam_Ps1, cam_Ps2 = allcam_Ps[:,[0]].repeat([1,Nv,1,1]), allcam_Ps[:,1:] # [B, Nv, 3, 4]
|
42 |
+
formatted_camPs = torch.cat([cam_Ps1.reshape([sub_batchsize*Nv,1,3,4]), cam_Ps2.reshape([sub_batchsize*Nv,1,3,4])],dim=1) # [BxNv, 2, 3, 4]
|
43 |
+
|
44 |
+
# Triangulate matches to 3D
|
45 |
+
points_3d_world = batched_triangulate(allpoints, formatted_camPs) # [BxNv, HxW, three]
|
46 |
+
|
47 |
+
# Aggregate pairwise predictions
|
48 |
+
points_3d_world = points_3d_world.view([sub_batchsize,Nv,H,W,3])
|
49 |
+
valids = points_3d_world.isfinite()
|
50 |
+
valids_sum = valids.sum(dim=-1)
|
51 |
+
validsuni=valids_sum.unique()
|
52 |
+
assert torch.all(torch.logical_or(validsuni == 0 , validsuni == 3)), "Error, can only be nan for none or all XYZ values, not a subset"
|
53 |
+
confs[valids_sum==0] = 0.
|
54 |
+
points_3d_world = points_3d_world*confs[...,None]
|
55 |
+
|
56 |
+
# Take care of NaNs
|
57 |
+
normalization = confs.sum(dim=1)[:,None].repeat(1,Nv,1,1)
|
58 |
+
normalization[normalization <= 1e-5] = 1.
|
59 |
+
points_3d_world[valids] /= normalization[valids_sum==3][:,None].repeat(1,3).view(-1)
|
60 |
+
points_3d_world[~valids] = 0.
|
61 |
+
aggregated_points = points_3d_world.sum(dim=1) # weighted average (by confidence value) ignoring nans
|
62 |
+
|
63 |
+
# Reset invalid values to nans, with a min visibility threshold
|
64 |
+
aggregated_points[valids_sum.sum(dim=1)/3 <= min_num_valids] = torch.nan
|
65 |
+
|
66 |
+
# From 3D to depths
|
67 |
+
refcamE = extrinsics[sub_batch, 0]
|
68 |
+
points_3d_camera = (refcamE[:,:3, :3] @ aggregated_points.view(sub_batchsize,-1,3).transpose(-2,-1) + refcamE[:,:3,[3]]).transpose(-2,-1) # [B,HxW,3]
|
69 |
+
depths = points_3d_camera.view(sub_batchsize,H,W,3)[..., 2] # [B,H,W]
|
70 |
+
|
71 |
+
# Cat results
|
72 |
+
out_aggregated_points.append(aggregated_points.cpu())
|
73 |
+
out_depths.append(depths.cpu())
|
74 |
+
out_confs.append(confs.sum(dim=1).cpu())
|
75 |
+
|
76 |
+
out_aggregated_points = torch.cat(out_aggregated_points,dim=0)
|
77 |
+
out_depths = torch.cat(out_depths,dim=0)
|
78 |
+
out_confs = torch.cat(out_confs,dim=0)
|
79 |
+
|
80 |
+
return out_aggregated_points, out_depths, out_confs
|
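As a quick sanity check of `batched_triangulate` (toy data, not from the repository; the camera layout and point range are arbitrary, and the function above is assumed to be in scope), one can project synthetic 3D points into two cameras and verify that the least-squares solve recovers them:

import torch

B, Ncams, Npts = 1, 2, 5
pts3d = torch.rand(B, Npts, 3) + torch.tensor([0., 0., 2.])   # points in front of the cameras
proj = torch.zeros(B, Ncams, 3, 4)
proj[:, :, :3, :3] = torch.eye(3)                             # K = I, R = I for both cameras
proj[:, 1, 0, 3] = 0.5                                        # second camera translated along x

homog = torch.cat([pts3d, torch.ones(B, Npts, 1)], dim=-1)    # [B, Npts, 4]
uvw = torch.einsum('bcij,bnj->bcni', proj, homog)             # project into each camera
pts2d = uvw[..., :2] / uvw[..., 2:3]                          # [B, Ncams, Npts, 2]

rec = batched_triangulate(pts2d, proj)                        # [B, Npts, 3], should match pts3d
print(float((rec - pts3d).abs().max()))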
mast3r/cloud_opt/tsdf_optimizer.py
ADDED
@@ -0,0 +1,269 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import numpy as np
|
4 |
+
from tqdm import tqdm
|
5 |
+
from matplotlib import pyplot as pl
|
6 |
+
|
7 |
+
import mast3r.utils.path_to_dust3r # noqa
|
8 |
+
from dust3r.utils.geometry import depthmap_to_pts3d, geotrf, inv
|
9 |
+
|
10 |
+
|
11 |
+
class TSDFPostProcess:
|
12 |
+
""" Optimizes a signed distance-function to improve depthmaps.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, optimizer, subsample=8, TSDF_thresh=0., TSDF_batchsize=int(1e7)):
|
16 |
+
self.TSDF_thresh = TSDF_thresh # None -> no TSDF
|
17 |
+
self.TSDF_batchsize = TSDF_batchsize
|
18 |
+
self.optimizer = optimizer
|
19 |
+
|
20 |
+
pts3d, depthmaps, confs = optimizer.get_dense_pts3d(clean_depth=False, subsample=subsample)
|
21 |
+
pts3d, depthmaps = self._TSDF_postprocess_or_not(pts3d, depthmaps, confs)
|
22 |
+
self.pts3d = pts3d
|
23 |
+
self.depthmaps = depthmaps
|
24 |
+
self.confs = confs
|
25 |
+
|
26 |
+
def _get_depthmaps(self, TSDF_filtering_thresh=None):
|
27 |
+
if TSDF_filtering_thresh:
|
28 |
+
self._refine_depths_with_TSDF(self.optimizer, TSDF_filtering_thresh) # compute refined depths if needed
|
29 |
+
dms = self.TSDF_im_depthmaps if TSDF_filtering_thresh else self.im_depthmaps
|
30 |
+
return [d.exp() for d in dms]
|
31 |
+
|
32 |
+
@torch.no_grad()
|
33 |
+
def _refine_depths_with_TSDF(self, TSDF_filtering_thresh, niter=1, nsamples=1000):
|
34 |
+
"""
|
35 |
+
Leverage TSDF to post-process estimated depths
|
36 |
+
for each pixel, find zero level of TSDF along ray (or closest to 0)
|
37 |
+
"""
|
38 |
+
print("Post-Processing Depths with TSDF fusion.")
|
39 |
+
self.TSDF_im_depthmaps = []
|
40 |
+
alldepths, allposes, allfocals, allpps, allimshapes = self._get_depthmaps(), self.optimizer.get_im_poses(
|
41 |
+
), self.optimizer.get_focals(), self.optimizer.get_principal_points(), self.imshapes
|
42 |
+
for vi in tqdm(range(self.optimizer.n_imgs)):
|
43 |
+
dm, pose, focal, pp, imshape = alldepths[vi], allposes[vi], allfocals[vi], allpps[vi], allimshapes[vi]
|
44 |
+
minvals = torch.full(dm.shape, 1e20)
|
45 |
+
|
46 |
+
for it in range(niter):
|
47 |
+
H, W = dm.shape
|
48 |
+
curthresh = (niter - it) * TSDF_filtering_thresh
|
49 |
+
dm_offsets = (torch.randn(H, W, nsamples).to(dm) - 1.) * \
|
50 |
+
curthresh # decreasing search std along with iterations
|
51 |
+
newdm = dm[..., None] + dm_offsets # [H,W,Nsamp]
|
52 |
+
curproj = self._backproj_pts3d(in_depths=[newdm], in_im_poses=pose[None], in_focals=focal[None], in_pps=pp[None], in_imshapes=[
|
53 |
+
imshape])[0] # [H,W,Nsamp,3]
|
54 |
+
# Batched TSDF eval
|
55 |
+
curproj = curproj.view(-1, 3)
|
56 |
+
tsdf_vals = []
|
57 |
+
valids = []
|
58 |
+
for batch in range(0, len(curproj), self.TSDF_batchsize):
|
59 |
+
values, valid = self._TSDF_query(
|
60 |
+
curproj[batch:min(batch + self.TSDF_batchsize, len(curproj))], curthresh)
|
61 |
+
tsdf_vals.append(values)
|
62 |
+
valids.append(valid)
|
63 |
+
tsdf_vals = torch.cat(tsdf_vals, dim=0)
|
64 |
+
valids = torch.cat(valids, dim=0)
|
65 |
+
|
66 |
+
tsdf_vals = tsdf_vals.view([H, W, nsamples])
|
67 |
+
valids = valids.view([H, W, nsamples])
|
68 |
+
|
69 |
+
# keep depth value that got us the closest to 0
|
70 |
+
tsdf_vals[~valids] = torch.inf # ignore invalid values
|
71 |
+
tsdf_vals = tsdf_vals.abs()
|
72 |
+
mins = torch.argmin(tsdf_vals, dim=-1, keepdim=True)
|
73 |
+
# when all samples live on a very flat zone, do nothing
|
74 |
+
allbad = (tsdf_vals == curthresh).sum(dim=-1) == nsamples
|
75 |
+
dm[~allbad] = torch.gather(newdm, -1, mins)[..., 0][~allbad]
|
76 |
+
|
77 |
+
# Save refined depth map
|
78 |
+
self.TSDF_im_depthmaps.append(dm.log())
|
79 |
+
|
80 |
+
def _TSDF_query(self, qpoints, TSDF_filtering_thresh, weighted=True):
|
81 |
+
"""
|
82 |
+
TSDF query call: returns the weighted TSDF value for each query point [N, 3]
|
83 |
+
"""
|
84 |
+
N, three = qpoints.shape
|
85 |
+
assert three == 3
|
86 |
+
qpoints = qpoints[None].repeat(self.optimizer.n_imgs, 1, 1) # [B,N,3]
|
87 |
+
# get projection coordinates and depths onto images
|
88 |
+
coords_and_depth = self._proj_pts3d(pts3d=qpoints, cam2worlds=self.optimizer.get_im_poses(
|
89 |
+
), focals=self.optimizer.get_focals(), pps=self.optimizer.get_principal_points())
|
90 |
+
image_coords = coords_and_depth[..., :2].round().to(int) # for now, there's no interpolation...
|
91 |
+
proj_depths = coords_and_depth[..., -1]
|
92 |
+
# recover depth values after scene optim
|
93 |
+
pred_depths, pred_confs, valids = self._get_pixel_depths(image_coords)
|
94 |
+
# Gather TSDF scores
|
95 |
+
all_SDF_scores = pred_depths - proj_depths # SDF
|
96 |
+
unseen = all_SDF_scores < -TSDF_filtering_thresh # handle visibility
|
97 |
+
# all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh,TSDF_filtering_thresh) # SDF -> TSDF
|
98 |
+
all_TSDF_scores = all_SDF_scores.clip(-TSDF_filtering_thresh, 1e20) # SDF -> TSDF
|
99 |
+
# Gather TSDF confidences and ignore points that are unseen, either OOB during reproj or too far behind seen depth
|
100 |
+
all_TSDF_weights = (~unseen).float() * valids.float()
|
101 |
+
if weighted:
|
102 |
+
all_TSDF_weights = pred_confs.exp() * all_TSDF_weights
|
103 |
+
# Aggregate all votes, ignoring zeros
|
104 |
+
TSDF_weights = all_TSDF_weights.sum(dim=0)
|
105 |
+
valids = TSDF_weights != 0.
|
106 |
+
TSDF_wsum = (all_TSDF_weights * all_TSDF_scores).sum(dim=0)
|
107 |
+
TSDF_wsum[valids] /= TSDF_weights[valids]
|
108 |
+
return TSDF_wsum, valids
|
109 |
+
|
110 |
+
def _get_pixel_depths(self, image_coords, TSDF_filtering_thresh=None, with_normals_conf=False):
|
111 |
+
""" Recover depth value for each input pixel coordinate, along with OOB validity mask
|
112 |
+
"""
|
113 |
+
B, N, two = image_coords.shape
|
114 |
+
assert B == self.optimizer.n_imgs and two == 2
|
115 |
+
depths = torch.zeros([B, N], device=image_coords.device)
|
116 |
+
valids = torch.zeros([B, N], dtype=bool, device=image_coords.device)
|
117 |
+
confs = torch.zeros([B, N], device=image_coords.device)
|
118 |
+
curconfs = self._get_confs_with_normals() if with_normals_conf else self.im_conf
|
119 |
+
for ni, (imc, depth, conf) in enumerate(zip(image_coords, self._get_depthmaps(TSDF_filtering_thresh), curconfs)):
|
120 |
+
H, W = depth.shape
|
121 |
+
valids[ni] = torch.logical_and(0 <= imc[:, 1], imc[:, 1] <
|
122 |
+
H) & torch.logical_and(0 <= imc[:, 0], imc[:, 0] < W)
|
123 |
+
imc[~valids[ni]] = 0
|
124 |
+
depths[ni] = depth[imc[:, 1], imc[:, 0]]
|
125 |
+
confs[ni] = conf.cuda()[imc[:, 1], imc[:, 0]]
|
126 |
+
return depths, confs, valids
|
127 |
+
|
128 |
+
def _get_confs_with_normals(self):
|
129 |
+
outconfs = []
|
130 |
+
# Confidence based on depth gradient
|
131 |
+
|
132 |
+
class Sobel(nn.Module):
|
133 |
+
def __init__(self):
|
134 |
+
super().__init__()
|
135 |
+
self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False)
|
136 |
+
Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
|
137 |
+
Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
|
138 |
+
G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
|
139 |
+
G = G.unsqueeze(1)
|
140 |
+
self.filter.weight = nn.Parameter(G, requires_grad=False)
|
141 |
+
|
142 |
+
def forward(self, img):
|
143 |
+
x = self.filter(img)
|
144 |
+
x = torch.mul(x, x)
|
145 |
+
x = torch.sum(x, dim=1, keepdim=True)
|
146 |
+
x = torch.sqrt(x)
|
147 |
+
return x
|
148 |
+
|
149 |
+
grad_op = Sobel().to(self.im_depthmaps[0].device)
|
150 |
+
for conf, depth in zip(self.im_conf, self.im_depthmaps):
|
151 |
+
grad_confs = (1. - grad_op(depth[None, None])[0, 0]).clip(0)
|
152 |
+
if not 'dbg show':
|
153 |
+
pl.imshow(grad_confs.cpu())
|
154 |
+
pl.show()
|
155 |
+
outconfs.append(conf * grad_confs.to(conf))
|
156 |
+
return outconfs
|
157 |
+
|
158 |
+
def _proj_pts3d(self, pts3d, cam2worlds, focals, pps):
|
159 |
+
"""
|
160 |
+
Projection operation: from 3D points to 2D coordinates + depths
|
161 |
+
"""
|
162 |
+
B = pts3d.shape[0]
|
163 |
+
assert pts3d.shape[0] == cam2worlds.shape[0]
|
164 |
+
# prepare Extrinsics
|
165 |
+
R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
|
166 |
+
Rinv = R.transpose(-2, -1)
|
167 |
+
tinv = -Rinv @ t[..., None]
|
168 |
+
|
169 |
+
# prepare intrinsics
|
170 |
+
intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(focals.shape[0], 1, 1)
|
171 |
+
if len(focals.shape) == 1:
|
172 |
+
focals = torch.stack([focals, focals], dim=-1)
|
173 |
+
intrinsics[:, 0, 0] = focals[:, 0]
|
174 |
+
intrinsics[:, 1, 1] = focals[:, 1]
|
175 |
+
intrinsics[:, :2, -1] = pps
|
176 |
+
# Project
|
177 |
+
projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N]
|
178 |
+
projpts = projpts.transpose(-2, -1) # [B,N,3]
|
179 |
+
projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z)
|
180 |
+
return projpts
|
181 |
+
|
182 |
+
def _backproj_pts3d(self, in_depths=None, in_im_poses=None,
|
183 |
+
in_focals=None, in_pps=None, in_imshapes=None):
|
184 |
+
"""
|
185 |
+
Backprojection operation: from image depths to 3D points
|
186 |
+
"""
|
187 |
+
# Get depths and projection params if not provided
|
188 |
+
focals = self.optimizer.get_focals() if in_focals is None else in_focals
|
189 |
+
im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
|
190 |
+
depth = self._get_depthmaps() if in_depths is None else in_depths
|
191 |
+
pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
|
192 |
+
imshapes = self.imshapes if in_imshapes is None else in_imshapes
|
193 |
+
def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
|
194 |
+
dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[[i]]) for i in range(im_poses.shape[0])]
|
195 |
+
|
196 |
+
def autoprocess(x):
|
197 |
+
x = x[0]
|
198 |
+
return x.transpose(-2, -1) if len(x.shape) == 4 else x
|
199 |
+
return [geotrf(pose, autoprocess(pt)) for pose, pt in zip(im_poses, dm_to_3d)]
|
200 |
+
|
201 |
+
def _pts3d_to_depth(self, pts3d, cam2worlds, focals, pps):
|
202 |
+
"""
|
203 |
+
Projection operation: from 3D points to 2D coordinates + depths
|
204 |
+
"""
|
205 |
+
B = pts3d.shape[0]
|
206 |
+
assert pts3d.shape[0] == cam2worlds.shape[0]
|
207 |
+
# prepare Extrinsics
|
208 |
+
R, t = cam2worlds[:, :3, :3], cam2worlds[:, :3, -1]
|
209 |
+
Rinv = R.transpose(-2, -1)
|
210 |
+
tinv = -Rinv @ t[..., None]
|
211 |
+
|
212 |
+
# prepare intrinsics
|
213 |
+
intrinsics = torch.eye(3).to(cam2worlds)[None].repeat(self.optimizer.n_imgs, 1, 1)
|
214 |
+
if len(focals.shape) == 1:
|
215 |
+
focals = torch.stack([focals, focals], dim=-1)
|
216 |
+
intrinsics[:, 0, 0] = focals[:, 0]
|
217 |
+
intrinsics[:, 1, 1] = focals[:, 1]
|
218 |
+
intrinsics[:, :2, -1] = pps
|
219 |
+
# Project
|
220 |
+
projpts = intrinsics @ (Rinv @ pts3d.transpose(-2, -1) + tinv) # I(RX+t) : [B,3,N]
|
221 |
+
projpts = projpts.transpose(-2, -1) # [B,N,3]
|
222 |
+
projpts[..., :2] /= projpts[..., [-1]] # [B,N,3] (X/Z , Y/Z, Z)
|
223 |
+
return projpts
|
224 |
+
|
225 |
+
def _depth_to_pts3d(self, in_depths=None, in_im_poses=None, in_focals=None, in_pps=None, in_imshapes=None):
|
226 |
+
"""
|
227 |
+
Backprojection operation: from image depths to 3D points
|
228 |
+
"""
|
229 |
+
# Get depths and projection params if not provided
|
230 |
+
focals = self.optimizer.get_focals() if in_focals is None else in_focals
|
231 |
+
im_poses = self.optimizer.get_im_poses() if in_im_poses is None else in_im_poses
|
232 |
+
depth = self._get_depthmaps() if in_depths is None else in_depths
|
233 |
+
pp = self.optimizer.get_principal_points() if in_pps is None else in_pps
|
234 |
+
imshapes = self.imshapes if in_imshapes is None else in_imshapes
|
235 |
+
|
236 |
+
def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *imshapes[i])
|
237 |
+
|
238 |
+
dm_to_3d = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i + 1]) for i in range(im_poses.shape[0])]
|
239 |
+
|
240 |
+
def autoprocess(x):
|
241 |
+
x = x[0]
|
242 |
+
H, W, three = x.shape[:3]
|
243 |
+
return x.transpose(-2, -1) if len(x.shape) == 4 else x
|
244 |
+
return [geotrf(pp, autoprocess(pt)) for pp, pt in zip(im_poses, dm_to_3d)]
|
245 |
+
|
246 |
+
def _get_pts3d(self, TSDF_filtering_thresh=None, **kw):
|
247 |
+
"""
|
248 |
+
return 3D points (possibly filtering depths with TSDF)
|
249 |
+
"""
|
250 |
+
return self._backproj_pts3d(in_depths=self._get_depthmaps(TSDF_filtering_thresh=TSDF_filtering_thresh), **kw)
|
251 |
+
|
252 |
+
def _TSDF_postprocess_or_not(self, pts3d, depthmaps, confs, niter=1):
|
253 |
+
# Setup inner variables
|
254 |
+
self.imshapes = [im.shape[:2] for im in self.optimizer.imgs]
|
255 |
+
self.im_depthmaps = [dd.log().view(imshape) for dd, imshape in zip(depthmaps, self.imshapes)]
|
256 |
+
self.im_conf = confs
|
257 |
+
|
258 |
+
if self.TSDF_thresh > 0.:
|
259 |
+
# Create or update self.TSDF_im_depthmaps that contain logdepths filtered with TSDF
|
260 |
+
self._refine_depths_with_TSDF(self.TSDF_thresh, niter=niter)
|
261 |
+
depthmaps = [dd.exp() for dd in self.TSDF_im_depthmaps]
|
262 |
+
# Turn them into 3D points
|
263 |
+
pts3d = self._backproj_pts3d(in_depths=depthmaps)
|
264 |
+
depthmaps = [dd.flatten() for dd in depthmaps]
|
265 |
+
pts3d = [pp.view(-1, 3) for pp in pts3d]
|
266 |
+
return pts3d, depthmaps
|
267 |
+
|
268 |
+
def get_dense_pts3d(self, clean_depth=True):
|
269 |
+
return self.pts3d, self.depthmaps, self.confs
|
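The refinement loop in `_refine_depths_with_TSDF` essentially samples candidate depths around the current estimate along each viewing ray and keeps the one whose truncated signed distance is closest to zero. A one-dimensional toy illustration of that idea (not repository code; the fused TSDF is replaced by a simple closed-form stand-in):

import torch

true_depth = 2.0
tsdf = lambda d: (true_depth - d).clip(-0.05, 0.05)   # stand-in for the fused, truncated SDF

d0 = torch.tensor(2.3)                                # noisy initial depth estimate
candidates = d0 + 0.3 * torch.randn(1000)             # random samples along the viewing ray
best = candidates[tsdf(candidates).abs().argmin()]    # keep the sample closest to the zero level
print(float(best))                                    # close to 2.0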
mast3r/cloud_opt/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
mast3r/cloud_opt/utils/losses.py
ADDED
@@ -0,0 +1,32 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# losses for sparse ga
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
def l05_loss(x, y):
|
12 |
+
return torch.linalg.norm(x - y, dim=-1).sqrt()
|
13 |
+
|
14 |
+
|
15 |
+
def l1_loss(x, y):
|
16 |
+
return torch.linalg.norm(x - y, dim=-1)
|
17 |
+
|
18 |
+
|
19 |
+
def gamma_loss(gamma, mul=1, offset=None, clip=np.inf):
|
20 |
+
if offset is None:
|
21 |
+
if gamma == 1:
|
22 |
+
return l1_loss
|
23 |
+
# d(x**p)/dx = 1 ==> p * x**(p-1) == 1 ==> x = (1/p)**(1/(p-1))
|
24 |
+
offset = (1 / gamma)**(1 / (gamma - 1))
|
25 |
+
|
26 |
+
def loss_func(x, y):
|
27 |
+
return (mul * l1_loss(x, y).clip(max=clip) + offset) ** gamma - offset ** gamma
|
28 |
+
return loss_func
|
29 |
+
|
30 |
+
|
31 |
+
def meta_gamma_loss():
|
32 |
+
return lambda alpha: gamma_loss(alpha)
|
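The offset derived in `gamma_loss` (see the derivative comment above) makes the robust loss behave like an L1 loss for small residuals: its slope at zero error is exactly one. A small numerical check, assuming the functions from this file are in scope:

import torch

loss = gamma_loss(0.5)           # gamma < 1: saturating, outlier-robust loss
x = torch.zeros(1, 3)
y = x.clone()
y[0, 0] = 1e-4                   # tiny residual along one coordinate

print(float(loss(x, y) / 1e-4))  # ~1.0: locally equivalent to l1_loss
print(float(loss(x, x)))         # exactly 0 at zero residual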
mast3r/cloud_opt/utils/schedules.py
ADDED
@@ -0,0 +1,17 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# lr schedules for sparse ga
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
def linear_schedule(alpha, lr_base, lr_end=0):
|
11 |
+
lr = (1 - alpha) * lr_base + alpha * lr_end
|
12 |
+
return lr
|
13 |
+
|
14 |
+
|
15 |
+
def cosine_schedule(alpha, lr_base, lr_end=0):
|
16 |
+
lr = lr_end + (lr_base - lr_end) * (1 + np.cos(alpha * np.pi)) / 2
|
17 |
+
return lr
|
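Both schedules map a progress value `alpha` in [0, 1] from `lr_base` down to `lr_end`; the cosine variant decays more slowly at the start and end. A small illustration, assuming the two functions above are in scope:

for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f'{alpha:.2f}  linear={linear_schedule(alpha, 1e-2):.6f}  cosine={cosine_schedule(alpha, 1e-2):.6f}')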
mast3r/colmap/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
mast3r/colmap/database.py
ADDED
@@ -0,0 +1,383 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# MASt3R to colmap export functions
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
import copy
|
10 |
+
import numpy as np
|
11 |
+
import torchvision
|
12 |
+
import numpy as np
|
13 |
+
from tqdm import tqdm
|
14 |
+
from scipy.cluster.hierarchy import DisjointSet
|
15 |
+
from scipy.spatial.transform import Rotation as R
|
16 |
+
|
17 |
+
from mast3r.utils.misc import hash_md5
|
18 |
+
|
19 |
+
from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns
|
20 |
+
|
21 |
+
import mast3r.utils.path_to_dust3r # noqa
|
22 |
+
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf # noqa
|
23 |
+
|
24 |
+
|
25 |
+
def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz):
|
26 |
+
if viz:
|
27 |
+
from matplotlib import pyplot as pl
|
28 |
+
|
29 |
+
image_mean = torch.as_tensor(
|
30 |
+
[0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
|
31 |
+
image_std = torch.as_tensor(
|
32 |
+
[0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
|
33 |
+
rgb0 = img0['img'] * image_std + image_mean
|
34 |
+
rgb0 = torchvision.transforms.functional.to_pil_image(rgb0[0])
|
35 |
+
rgb0 = np.array(rgb0)
|
36 |
+
|
37 |
+
rgb1 = img1['img'] * image_std + image_mean
|
38 |
+
rgb1 = torchvision.transforms.functional.to_pil_image(rgb1[0])
|
39 |
+
rgb1 = np.array(rgb1)
|
40 |
+
|
41 |
+
imgs = [rgb0, rgb1]
|
42 |
+
# visualize a few matches
|
43 |
+
n_viz = 100
|
44 |
+
num_matches = matches_im0.shape[0]
|
45 |
+
match_idx_to_viz = np.round(np.linspace(
|
46 |
+
0, num_matches - 1, n_viz)).astype(int)
|
47 |
+
viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
|
48 |
+
|
49 |
+
H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
|
50 |
+
rgb0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)),
|
51 |
+
(0, 0), (0, 0)), 'constant', constant_values=0)
|
52 |
+
rgb1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)),
|
53 |
+
(0, 0), (0, 0)), 'constant', constant_values=0)
|
54 |
+
img = np.concatenate((rgb0, rgb1), axis=1)
|
55 |
+
pl.figure()
|
56 |
+
pl.imshow(img)
|
57 |
+
cmap = pl.get_cmap('jet')
|
58 |
+
for ii in range(n_viz):
|
59 |
+
(x0, y0), (x1,
|
60 |
+
y1) = viz_matches_im0[ii].T, viz_matches_im1[ii].T
|
61 |
+
pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(ii /
|
62 |
+
(n_viz - 1)), scalex=False, scaley=False)
|
63 |
+
pl.show(block=True)
|
64 |
+
|
65 |
+
matches = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)]
|
66 |
+
imgs = [img0, img1]
|
67 |
+
imidx0 = img0['idx']
|
68 |
+
imidx1 = img1['idx']
|
69 |
+
ravel_matches = []
|
70 |
+
for j in range(2):
|
71 |
+
H, W = imgs[j]['true_shape'][0]
|
72 |
+
with np.errstate(invalid='ignore'):
|
73 |
+
qx, qy = matches[j].round().astype(np.int32).T
|
74 |
+
ravel_matches_j = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
|
75 |
+
ravel_matches.append(ravel_matches_j)
|
76 |
+
imidxj = imgs[j]['idx']
|
77 |
+
for m in ravel_matches_j:
|
78 |
+
if m not in im_keypoints[imidxj]:
|
79 |
+
im_keypoints[imidxj][m] = 0
|
80 |
+
im_keypoints[imidxj][m] += 1
|
81 |
+
imid0 = copy.deepcopy(image_to_colmap[imidx0]['colmap_imid'])
|
82 |
+
imid1 = copy.deepcopy(image_to_colmap[imidx1]['colmap_imid'])
|
83 |
+
if imid0 > imid1:
|
84 |
+
colmap_matches = np.stack([ravel_matches[1], ravel_matches[0]], axis=-1)
|
85 |
+
imid0, imid1 = imid1, imid0
|
86 |
+
imidx0, imidx1 = imidx1, imidx0
|
87 |
+
else:
|
88 |
+
colmap_matches = np.stack([ravel_matches[0], ravel_matches[1]], axis=-1)
|
89 |
+
colmap_matches = np.unique(colmap_matches, axis=0)
|
90 |
+
return imidx0, imidx1, colmap_matches
|
91 |
+
|
92 |
+
|
93 |
+
def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr,
|
94 |
+
is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'):
|
95 |
+
im_matches = {}
|
96 |
+
for i in range(len(pred1['pts3d'])):
|
97 |
+
imidx0 = pairs[i][0]['idx']
|
98 |
+
imidx1 = pairs[i][1]['idx']
|
99 |
+
if 'desc' in pred1: # mast3r
|
100 |
+
descs = [pred1['desc'][i], pred2['desc'][i]]
|
101 |
+
confidences = [pred1['desc_conf'][i], pred2['desc_conf'][i]]
|
102 |
+
desc_dim = descs[0].shape[-1]
|
103 |
+
|
104 |
+
if is_sparse:
|
105 |
+
corres = extract_correspondences_nonsym(descs[0], descs[1], confidences[0], confidences[1],
|
106 |
+
device=device, subsample=subsample, pixel_tol=pixel_tol)
|
107 |
+
conf = corres[2]
|
108 |
+
mask = conf >= conf_thr
|
109 |
+
matches_im0 = corres[0][mask].cpu().numpy()
|
110 |
+
matches_im1 = corres[1][mask].cpu().numpy()
|
111 |
+
else:
|
112 |
+
confidence_masks = [confidences[0] >=
|
113 |
+
conf_thr, confidences[1] >= conf_thr]
|
114 |
+
pts2d_list, desc_list = [], []
|
115 |
+
for j in range(2):
|
116 |
+
conf_j = confidence_masks[j].cpu().numpy().flatten()
|
117 |
+
true_shape_j = pairs[i][j]['true_shape'][0]
|
118 |
+
pts2d_j = xy_grid(
|
119 |
+
true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j]
|
120 |
+
desc_j = descs[j].detach().cpu(
|
121 |
+
).numpy().reshape(-1, desc_dim)[conf_j]
|
122 |
+
pts2d_list.append(pts2d_j)
|
123 |
+
desc_list.append(desc_j)
|
124 |
+
if len(desc_list[0]) == 0 or len(desc_list[1]) == 0:
|
125 |
+
continue
|
126 |
+
|
127 |
+
nn0, nn1 = bruteforce_reciprocal_nns(desc_list[0], desc_list[1],
|
128 |
+
device=device, dist='dot', block_size=2**13)
|
129 |
+
reciprocal_in_P0 = (nn1[nn0] == np.arange(len(nn0)))
|
130 |
+
|
131 |
+
matches_im1 = pts2d_list[1][nn0][reciprocal_in_P0]
|
132 |
+
matches_im0 = pts2d_list[0][reciprocal_in_P0]
|
133 |
+
else:
|
134 |
+
pts3d = [pred1['pts3d'][i], pred2['pts3d_in_other_view'][i]]
|
135 |
+
confidences = [pred1['conf'][i], pred2['conf'][i]]
|
136 |
+
|
137 |
+
if is_sparse:
|
138 |
+
corres = extract_correspondences_nonsym(pts3d[0], pts3d[1], confidences[0], confidences[1],
|
139 |
+
device=device, subsample=subsample, pixel_tol=pixel_tol,
|
140 |
+
ptmap_key='3d')
|
141 |
+
conf = corres[2]
|
142 |
+
mask = conf >= conf_thr
|
143 |
+
matches_im0 = corres[0][mask].cpu().numpy()
|
144 |
+
matches_im1 = corres[1][mask].cpu().numpy()
|
145 |
+
else:
|
146 |
+
confidence_masks = [confidences[0] >=
|
147 |
+
conf_thr, confidences[1] >= conf_thr]
|
148 |
+
# find 2D-2D matches between the two images
|
149 |
+
pts2d_list, pts3d_list = [], []
|
150 |
+
for j in range(2):
|
151 |
+
conf_j = confidence_masks[j].cpu().numpy().flatten()
|
152 |
+
true_shape_j = pairs[i][j]['true_shape'][0]
|
153 |
+
pts2d_j = xy_grid(true_shape_j[1], true_shape_j[0]).reshape(-1, 2)[conf_j]
|
154 |
+
pts3d_j = pts3d[j].detach().cpu().numpy().reshape(-1, 3)[conf_j]
|
155 |
+
pts2d_list.append(pts2d_j)
|
156 |
+
pts3d_list.append(pts3d_j)
|
157 |
+
|
158 |
+
PQ, PM = pts3d_list[0], pts3d_list[1]
|
159 |
+
if len(PQ) == 0 or len(PM) == 0:
|
160 |
+
continue
|
161 |
+
reciprocal_in_PM, nnM_in_PQ, num_matches = find_reciprocal_matches(
|
162 |
+
PQ, PM)
|
163 |
+
|
164 |
+
matches_im1 = pts2d_list[1][reciprocal_in_PM]
|
165 |
+
matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]
|
166 |
+
|
167 |
+
if len(matches_im0) == 0:
|
168 |
+
continue
|
169 |
+
imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
|
170 |
+
image_to_colmap, im_keypoints,
|
171 |
+
matches_im0, matches_im1, viz)
|
172 |
+
im_matches[(imidx0, imidx1)] = colmap_matches
|
173 |
+
return im_matches
|
174 |
+
|
175 |
+
|
176 |
+
def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample,
|
177 |
+
image_to_colmap, im_keypoints, conf_thr,
|
178 |
+
viz=False, device='cuda'):
|
179 |
+
im_matches = {}
|
180 |
+
for i in range(len(pairs)):
|
181 |
+
imidx0 = pairs[i][0]['idx']
|
182 |
+
imidx1 = pairs[i][1]['idx']
|
183 |
+
|
184 |
+
corres_idx1 = hash_md5(pairs[i][0]['instance'])
|
185 |
+
corres_idx2 = hash_md5(pairs[i][1]['instance'])
|
186 |
+
|
187 |
+
path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx1}-{corres_idx2}.pth'
|
188 |
+
if os.path.isfile(path_corres):
|
189 |
+
score, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)
|
190 |
+
else:
|
191 |
+
path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx2}-{corres_idx1}.pth'
|
192 |
+
score, (xy2, xy1, confs) = torch.load(path_corres, map_location=device)
|
193 |
+
mask = confs >= conf_thr
|
194 |
+
matches_im0 = xy1[mask].cpu().numpy()
|
195 |
+
matches_im1 = xy2[mask].cpu().numpy()
|
196 |
+
|
197 |
+
if len(matches_im0) == 0:
|
198 |
+
continue
|
199 |
+
imidx0, imidx1, colmap_matches = convert_im_matches_pairs(pairs[i][0], pairs[i][1],
|
200 |
+
image_to_colmap, im_keypoints,
|
201 |
+
matches_im0, matches_im1, viz)
|
202 |
+
im_matches[(imidx0, imidx1)] = colmap_matches
|
203 |
+
return im_matches
|
204 |
+
|
205 |
+
|
206 |
+
def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model):
|
207 |
+
# add cameras/images to the db
|
208 |
+
# with the output of ga as prior
|
209 |
+
image_to_colmap = {}
|
210 |
+
im_keypoints = {}
|
211 |
+
for idx in range(len(image_paths)):
|
212 |
+
im_keypoints[idx] = {}
|
213 |
+
H, W = images[idx]["orig_shape"]
|
214 |
+
if focals is None:
|
215 |
+
focal_x = focal_y = 1.2 * max(W, H)
|
216 |
+
prior_focal_length = False
|
217 |
+
cx = W / 2.0
|
218 |
+
cy = H / 2.0
|
219 |
+
elif isinstance(focals[idx], np.ndarray) and len(focals[idx].shape) == 2:
|
220 |
+
# intrinsics
|
221 |
+
focal_x = focals[idx][0, 0]
|
222 |
+
focal_y = focals[idx][1, 1]
|
223 |
+
cx = focals[idx][0, 2] * images[idx]["to_orig"][0, 0]
|
224 |
+
cy = focals[idx][1, 2] * images[idx]["to_orig"][1, 1]
|
225 |
+
prior_focal_length = True
|
226 |
+
else:
|
227 |
+
focal_x = focal_y = float(focals[idx])
|
228 |
+
prior_focal_length = True
|
229 |
+
cx = W / 2.0
|
230 |
+
cy = H / 2.0
|
231 |
+
focal_x = focal_x * images[idx]["to_orig"][0, 0]
|
232 |
+
focal_y = focal_y * images[idx]["to_orig"][1, 1]
|
233 |
+
|
234 |
+
if camera_model == "SIMPLE_PINHOLE":
|
235 |
+
model_id = 0
|
236 |
+
focal = (focal_x + focal_y) / 2.0
|
237 |
+
params = np.asarray([focal, cx, cy], np.float64)
|
238 |
+
elif camera_model == "PINHOLE":
|
239 |
+
model_id = 1
|
240 |
+
params = np.asarray([focal_x, focal_y, cx, cy], np.float64)
|
241 |
+
elif camera_model == "SIMPLE_RADIAL":
|
242 |
+
model_id = 2
|
243 |
+
focal = (focal_x + focal_y) / 2.0
|
244 |
+
params = np.asarray([focal, cx, cy, 0.0], np.float64)
|
245 |
+
elif camera_model == "OPENCV":
|
246 |
+
model_id = 4
|
247 |
+
params = np.asarray([focal_x, focal_y, cx, cy, 0.0, 0.0, 0.0, 0.0], np.float64)
|
248 |
+
else:
|
249 |
+
raise ValueError(f"invalid camera model {camera_model}")
|
250 |
+
|
251 |
+
H, W = int(H), int(W)
|
252 |
+
# OPENCV camera model
|
253 |
+
camid = db.add_camera(
|
254 |
+
model_id, W, H, params, prior_focal_length=prior_focal_length)
|
255 |
+
if ga_world_to_cam is None:
|
256 |
+
prior_t = np.zeros(3)
|
257 |
+
prior_q = np.zeros(4)
|
258 |
+
else:
|
259 |
+
q = R.from_matrix(ga_world_to_cam[idx][:3, :3]).as_quat()
|
260 |
+
prior_t = ga_world_to_cam[idx][:3, 3]
|
261 |
+
prior_q = np.array([q[-1], q[0], q[1], q[2]])
|
262 |
+
imid = db.add_image(
|
263 |
+
image_paths[idx], camid, prior_q=prior_q, prior_t=prior_t)
|
264 |
+
image_to_colmap[idx] = {
|
265 |
+
'colmap_imid': imid,
|
266 |
+
'colmap_camid': camid
|
267 |
+
}
|
268 |
+
return image_to_colmap, im_keypoints
|
269 |
+
|
270 |
+
|
271 |
+
def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification):
|
272 |
+
colmap_image_pairs = []
|
273 |
+
# 2D-2D are quite dense
|
274 |
+
# we want to remove the very small tracks
|
275 |
+
# and export only kpt for which we have values
|
276 |
+
# build tracks
|
277 |
+
print("building tracks")
|
278 |
+
keypoints_to_track_id = {}
|
279 |
+
track_id_to_kpt_list = []
|
280 |
+
to_merge = []
|
281 |
+
for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()):
|
282 |
+
if imidx0 not in keypoints_to_track_id:
|
283 |
+
keypoints_to_track_id[imidx0] = {}
|
284 |
+
if imidx1 not in keypoints_to_track_id:
|
285 |
+
keypoints_to_track_id[imidx1] = {}
|
286 |
+
|
287 |
+
for m in colmap_matches:
|
288 |
+
if m[0] not in keypoints_to_track_id[imidx0] and m[1] not in keypoints_to_track_id[imidx1]:
|
289 |
+
# new pair of kpts never seen before
|
290 |
+
track_idx = len(track_id_to_kpt_list)
|
291 |
+
keypoints_to_track_id[imidx0][m[0]] = track_idx
|
292 |
+
keypoints_to_track_id[imidx1][m[1]] = track_idx
|
293 |
+
track_id_to_kpt_list.append(
|
294 |
+
[(imidx0, m[0]), (imidx1, m[1])])
|
295 |
+
elif m[1] not in keypoints_to_track_id[imidx1]:
|
296 |
+
# 0 has a track, not 1
|
297 |
+
track_idx = keypoints_to_track_id[imidx0][m[0]]
|
298 |
+
keypoints_to_track_id[imidx1][m[1]] = track_idx
|
299 |
+
track_id_to_kpt_list[track_idx].append((imidx1, m[1]))
|
300 |
+
elif m[0] not in keypoints_to_track_id[imidx0]:
|
301 |
+
# 1 has a track, not 0
|
302 |
+
track_idx = keypoints_to_track_id[imidx1][m[1]]
|
303 |
+
keypoints_to_track_id[imidx0][m[0]] = track_idx
|
304 |
+
track_id_to_kpt_list[track_idx].append((imidx0, m[0]))
|
305 |
+
else:
|
306 |
+
# both have tracks, merge them
|
307 |
+
track_idx0 = keypoints_to_track_id[imidx0][m[0]]
|
308 |
+
track_idx1 = keypoints_to_track_id[imidx1][m[1]]
|
309 |
+
if track_idx0 != track_idx1:
|
310 |
+
# let's deal with them later
|
311 |
+
to_merge.append((track_idx0, track_idx1))
|
312 |
+
|
313 |
+
# regroup merge targets
|
314 |
+
print("merging tracks")
|
315 |
+
unique = np.unique(to_merge)
|
316 |
+
tree = DisjointSet(unique)
|
317 |
+
for track_idx0, track_idx1 in tqdm(to_merge):
|
318 |
+
tree.merge(track_idx0, track_idx1)
|
319 |
+
|
320 |
+
subsets = tree.subsets()
|
321 |
+
print("applying merge")
|
322 |
+
for setvals in tqdm(subsets):
|
323 |
+
new_trackid = len(track_id_to_kpt_list)
|
324 |
+
kpt_list = []
|
325 |
+
for track_idx in setvals:
|
326 |
+
kpt_list.extend(track_id_to_kpt_list[track_idx])
|
327 |
+
for imidx, kpid in track_id_to_kpt_list[track_idx]:
|
328 |
+
keypoints_to_track_id[imidx][kpid] = new_trackid
|
329 |
+
track_id_to_kpt_list.append(kpt_list)
|
330 |
+
|
331 |
+
# binc = np.bincount([len(v) for v in track_id_to_kpt_list])
|
332 |
+
# nonzero = np.nonzero(binc)
|
333 |
+
# nonzerobinc = binc[nonzero[0]]
|
334 |
+
# print(nonzero[0].tolist())
|
335 |
+
# print(nonzerobinc)
|
336 |
+
num_valid_tracks = sum(
|
337 |
+
[1 for v in track_id_to_kpt_list if len(v) >= min_len_track])
|
338 |
+
|
339 |
+
keypoints_to_idx = {}
|
340 |
+
print(f"squashing keypoints - {num_valid_tracks} valid tracks")
|
341 |
+
for imidx, keypoints_imid in tqdm(im_keypoints.items()):
|
342 |
+
imid = image_to_colmap[imidx]['colmap_imid']
|
343 |
+
keypoints_kept = []
|
344 |
+
keypoints_to_idx[imidx] = {}
|
345 |
+
for kp in keypoints_imid.keys():
|
346 |
+
if kp not in keypoints_to_track_id[imidx]:
|
347 |
+
continue
|
348 |
+
track_idx = keypoints_to_track_id[imidx][kp]
|
349 |
+
track_length = len(track_id_to_kpt_list[track_idx])
|
350 |
+
if track_length < min_len_track:
|
351 |
+
continue
|
352 |
+
keypoints_to_idx[imidx][kp] = len(keypoints_kept)
|
353 |
+
keypoints_kept.append(kp)
|
354 |
+
if len(keypoints_kept) == 0:
|
355 |
+
continue
|
356 |
+
keypoints_kept = np.array(keypoints_kept)
|
357 |
+
keypoints_kept = np.unravel_index(keypoints_kept, images[imidx]['true_shape'][0])[
|
358 |
+
0].base[:, ::-1].copy().astype(np.float32)
|
359 |
+
# rescale coordinates
|
360 |
+
keypoints_kept[:, 0] += 0.5
|
361 |
+
keypoints_kept[:, 1] += 0.5
|
362 |
+
keypoints_kept = geotrf(images[imidx]['to_orig'], keypoints_kept, norm=True)
|
363 |
+
|
364 |
+
H, W = images[imidx]['orig_shape']
|
365 |
+
keypoints_kept[:, 0] = keypoints_kept[:, 0].clip(min=0, max=W - 0.01)
|
366 |
+
keypoints_kept[:, 1] = keypoints_kept[:, 1].clip(min=0, max=H - 0.01)
|
367 |
+
|
368 |
+
db.add_keypoints(imid, keypoints_kept)
|
369 |
+
|
370 |
+
print("exporting im_matches")
|
371 |
+
for (imidx0, imidx1), colmap_matches in im_matches.items():
|
372 |
+
imid0, imid1 = image_to_colmap[imidx0]['colmap_imid'], image_to_colmap[imidx1]['colmap_imid']
|
373 |
+
assert imid0 < imid1
|
374 |
+
final_matches = np.array([[keypoints_to_idx[imidx0][m[0]], keypoints_to_idx[imidx1][m[1]]]
|
375 |
+
for m in colmap_matches
|
376 |
+
if m[0] in keypoints_to_idx[imidx0] and m[1] in keypoints_to_idx[imidx1]])
|
377 |
+
if len(final_matches) > 0:
|
378 |
+
colmap_image_pairs.append(
|
379 |
+
(images[imidx0]['instance'], images[imidx1]['instance']))
|
380 |
+
db.add_matches(imid0, imid1, final_matches)
|
381 |
+
if skip_geometric_verification:
|
382 |
+
db.add_two_view_geometry(imid0, imid1, final_matches)
|
383 |
+
return colmap_image_pairs
|
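The track-merging step in `export_matches` relies on scipy's `DisjointSet` to fuse track ids that turn out to describe the same 3D point. A toy illustration of that union-find step (made-up track ids, not repository data):

from scipy.cluster.hierarchy import DisjointSet

to_merge = [(0, 2), (2, 5), (3, 4)]                      # pairs of track ids to fuse
tree = DisjointSet(sorted({i for pair in to_merge for i in pair}))
for a, b in to_merge:
    tree.merge(a, b)
print([sorted(s) for s in tree.subsets()])               # e.g. [[0, 2, 5], [3, 4]]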
mast3r/datasets/__init__.py
ADDED
@@ -0,0 +1,62 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
|
4 |
+
from .base.mast3r_base_stereo_view_dataset import MASt3RBaseStereoViewDataset
|
5 |
+
|
6 |
+
import mast3r.utils.path_to_dust3r # noqa
|
7 |
+
from dust3r.datasets.arkitscenes import ARKitScenes as DUSt3R_ARKitScenes # noqa
|
8 |
+
from dust3r.datasets.blendedmvs import BlendedMVS as DUSt3R_BlendedMVS # noqa
|
9 |
+
from dust3r.datasets.co3d import Co3d as DUSt3R_Co3d # noqa
|
10 |
+
from dust3r.datasets.megadepth import MegaDepth as DUSt3R_MegaDepth # noqa
|
11 |
+
from dust3r.datasets.scannetpp import ScanNetpp as DUSt3R_ScanNetpp # noqa
|
12 |
+
from dust3r.datasets.staticthings3d import StaticThings3D as DUSt3R_StaticThings3D # noqa
|
13 |
+
from dust3r.datasets.waymo import Waymo as DUSt3R_Waymo # noqa
|
14 |
+
from dust3r.datasets.wildrgbd import WildRGBD as DUSt3R_WildRGBD # noqa
|
15 |
+
|
16 |
+
|
17 |
+
class ARKitScenes(DUSt3R_ARKitScenes, MASt3RBaseStereoViewDataset):
|
18 |
+
def __init__(self, *args, split, ROOT, **kwargs):
|
19 |
+
super().__init__(*args, split=split, ROOT=ROOT, **kwargs)
|
20 |
+
self.is_metric_scale = True
|
21 |
+
|
22 |
+
|
23 |
+
class BlendedMVS(DUSt3R_BlendedMVS, MASt3RBaseStereoViewDataset):
|
24 |
+
def __init__(self, *args, ROOT, split=None, **kwargs):
|
25 |
+
super().__init__(*args, ROOT=ROOT, split=split, **kwargs)
|
26 |
+
self.is_metric_scale = False
|
27 |
+
|
28 |
+
|
29 |
+
class Co3d(DUSt3R_Co3d, MASt3RBaseStereoViewDataset):
|
30 |
+
def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
|
31 |
+
super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
|
32 |
+
self.is_metric_scale = False
|
33 |
+
|
34 |
+
|
35 |
+
class MegaDepth(DUSt3R_MegaDepth, MASt3RBaseStereoViewDataset):
|
36 |
+
def __init__(self, *args, split, ROOT, **kwargs):
|
37 |
+
super().__init__(*args, split=split, ROOT=ROOT, **kwargs)
|
38 |
+
self.is_metric_scale = False
|
39 |
+
|
40 |
+
|
41 |
+
class ScanNetpp(DUSt3R_ScanNetpp, MASt3RBaseStereoViewDataset):
|
42 |
+
def __init__(self, *args, ROOT, **kwargs):
|
43 |
+
super().__init__(*args, ROOT=ROOT, **kwargs)
|
44 |
+
self.is_metric_scale = True
|
45 |
+
|
46 |
+
|
47 |
+
class StaticThings3D(DUSt3R_StaticThings3D, MASt3RBaseStereoViewDataset):
|
48 |
+
def __init__(self, ROOT, *args, mask_bg='rand', **kwargs):
|
49 |
+
super().__init__(ROOT, *args, mask_bg=mask_bg, **kwargs)
|
50 |
+
self.is_metric_scale = False
|
51 |
+
|
52 |
+
|
53 |
+
class Waymo(DUSt3R_Waymo, MASt3RBaseStereoViewDataset):
|
54 |
+
def __init__(self, *args, ROOT, **kwargs):
|
55 |
+
super().__init__(*args, ROOT=ROOT, **kwargs)
|
56 |
+
self.is_metric_scale = True
|
57 |
+
|
58 |
+
|
59 |
+
class WildRGBD(DUSt3R_WildRGBD, MASt3RBaseStereoViewDataset):
|
60 |
+
def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
|
61 |
+
super().__init__(mask_bg, *args, ROOT=ROOT, **kwargs)
|
62 |
+
self.is_metric_scale = True
|
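Each wrapper above subclasses the corresponding DUSt3R dataset together with `MASt3RBaseStereoViewDataset` and only overrides the metric-scale flag after `super().__init__()`. A minimal toy sketch of that pattern (stand-in classes, not the repository's):

class _ToyBase:
    def __init__(self, **kw):
        self.is_metric_scale = False       # default, as in MASt3RBaseStereoViewDataset

class _ToyARKitScenes(_ToyBase):
    def __init__(self, **kw):
        super().__init__(**kw)
        self.is_metric_scale = True        # e.g. ARKitScenes depths are metric

print(_ToyARKitScenes().is_metric_scale)   # True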
mast3r/datasets/base/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
mast3r/datasets/base/mast3r_base_stereo_view_dataset.py
ADDED
@@ -0,0 +1,355 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# base class for implementing datasets
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import PIL.Image
|
8 |
+
import PIL.Image as Image
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import copy
|
12 |
+
|
13 |
+
from mast3r.datasets.utils.cropping import (extract_correspondences_from_pts3d,
|
14 |
+
gen_random_crops, in2d_rect, crop_to_homography)
|
15 |
+
|
16 |
+
import mast3r.utils.path_to_dust3r # noqa
|
17 |
+
from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset, view_name, is_good_type # noqa
|
18 |
+
from dust3r.datasets.utils.transforms import ImgNorm
|
19 |
+
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf, depthmap_to_camera_coordinates
|
20 |
+
import dust3r.datasets.utils.cropping as cropping
|
21 |
+
|
22 |
+
|
23 |
+
class MASt3RBaseStereoViewDataset(BaseStereoViewDataset):
|
24 |
+
def __init__(self, *, # only keyword arguments
|
25 |
+
split=None,
|
26 |
+
resolution=None, # square_size or (width, height) or list of [(width,height), ...]
|
27 |
+
transform=ImgNorm,
|
28 |
+
aug_crop=False,
|
29 |
+
aug_swap=False,
|
30 |
+
aug_monocular=False,
|
31 |
+
aug_portrait_or_landscape=True, # automatic choice between landscape/portrait when possible
|
32 |
+
aug_rot90=False,
|
33 |
+
n_corres=0,
|
34 |
+
nneg=0,
|
35 |
+
n_tentative_crops=4,
|
36 |
+
seed=None):
|
37 |
+
super().__init__(split=split, resolution=resolution, transform=transform, aug_crop=aug_crop, seed=seed)
|
38 |
+
self.is_metric_scale = False # by default a dataset is not metric scale, subclasses can overwrite this
|
39 |
+
|
40 |
+
self.aug_swap = aug_swap
|
41 |
+
self.aug_monocular = aug_monocular
|
42 |
+
self.aug_portrait_or_landscape = aug_portrait_or_landscape
|
43 |
+
self.aug_rot90 = aug_rot90
|
44 |
+
|
45 |
+
self.n_corres = n_corres
|
46 |
+
self.nneg = nneg
|
47 |
+
assert self.n_corres == 'all' or isinstance(self.n_corres, int) or (isinstance(self.n_corres, list) and len(
|
48 |
+
self.n_corres) == self.num_views), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}"
|
49 |
+
assert self.nneg == 0 or self.n_corres != 'all'
|
50 |
+
self.n_tentative_crops = n_tentative_crops
|
51 |
+
|
52 |
+
def _swap_view_aug(self, views):
|
53 |
+
if self._rng.random() < 0.5:
|
54 |
+
views.reverse()
|
55 |
+
|
56 |
+
def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
|
57 |
+
""" This function:
|
58 |
+
- first downsizes the image with LANCZOS interpolation,
|
59 |
+
which is better than bilinear interpolation in terms of aliasing artefacts.
|
60 |
+
"""
|
61 |
+
if not isinstance(image, PIL.Image.Image):
|
62 |
+
image = PIL.Image.fromarray(image)
|
63 |
+
|
64 |
+
# transpose the resolution if necessary
|
65 |
+
W, H = image.size # new size
|
66 |
+
assert resolution[0] >= resolution[1]
|
67 |
+
if H > 1.1 * W:
|
68 |
+
# image is portrait mode
|
69 |
+
resolution = resolution[::-1]
|
70 |
+
elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
|
71 |
+
# image is square, so we choose (portrait, landscape) randomly
|
72 |
+
if rng.integers(2) and self.aug_portrait_or_landscape:
|
73 |
+
resolution = resolution[::-1]
|
74 |
+
|
75 |
+
# high-quality Lanczos down-scaling
|
76 |
+
target_resolution = np.array(resolution)
|
77 |
+
image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
|
78 |
+
|
79 |
+
# actual cropping (if necessary) with bilinear interpolation
|
80 |
+
offset_factor = 0.5
|
81 |
+
intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=offset_factor)
|
82 |
+
crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
|
83 |
+
image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
|
84 |
+
|
85 |
+
return image, depthmap, intrinsics2
|
86 |
+
|
87 |
+
def generate_crops_from_pair(self, view1, view2, resolution, aug_crop_arg, n_crops=4, rng=np.random):
|
88 |
+
views = [view1, view2]
|
89 |
+
|
90 |
+
if aug_crop_arg is False:
|
91 |
+
# compatibility
|
92 |
+
for i in range(2):
|
93 |
+
view = views[i]
|
94 |
+
view['img'], view['depthmap'], view['camera_intrinsics'] = self._crop_resize_if_necessary(view['img'],
|
95 |
+
view['depthmap'],
|
96 |
+
view['camera_intrinsics'],
|
97 |
+
resolution,
|
98 |
+
rng=rng)
|
99 |
+
view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'],
|
100 |
+
view['camera_intrinsics'],
|
101 |
+
view['camera_pose'])
|
102 |
+
return
|
103 |
+
|
104 |
+
# extract correspondences
|
105 |
+
corres = extract_correspondences_from_pts3d(*views, target_n_corres=None, rng=rng)
|
106 |
+
|
107 |
+
# generate 4 random crops in each view
|
108 |
+
view_crops = []
|
109 |
+
crops_resolution = []
|
110 |
+
corres_msks = []
|
111 |
+
for i in range(2):
|
112 |
+
|
113 |
+
if aug_crop_arg == 'auto':
|
114 |
+
S = min(views[i]['img'].size)
|
115 |
+
R = min(resolution)
|
116 |
+
aug_crop = S * (S - R) // R
|
117 |
+
aug_crop = max(.1 * S, aug_crop) # for cropping: augment scale of at least 10%, and more if possible
|
118 |
+
else:
|
119 |
+
aug_crop = aug_crop_arg
|
120 |
+
|
121 |
+
# transpose the target resolution if necessary
|
122 |
+
assert resolution[0] >= resolution[1]
|
123 |
+
W, H = imsize = views[i]['img'].size
|
124 |
+
crop_resolution = resolution
|
125 |
+
if H > 1.1 * W:
|
126 |
+
# image is portrait mode
|
127 |
+
crop_resolution = resolution[::-1]
|
128 |
+
elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
|
129 |
+
# image is square, so we choose (portrait, landscape) randomly
|
130 |
+
if rng.integers(2):
|
131 |
+
crop_resolution = resolution[::-1]
|
132 |
+
|
133 |
+
crops = gen_random_crops(imsize, n_crops, crop_resolution, aug_crop=aug_crop, rng=rng)
|
134 |
+
view_crops.append(crops)
|
135 |
+
crops_resolution.append(crop_resolution)
|
136 |
+
|
137 |
+
# compute correspondences
|
138 |
+
corres_msks.append(in2d_rect(corres[i], crops))
|
139 |
+
|
140 |
+
# compute IoU for each
|
141 |
+
intersection = np.float32(corres_msks[0]).T @ np.float32(corres_msks[1])
|
142 |
+
# select best pair of crops
|
143 |
+
best = np.unravel_index(intersection.argmax(), (n_crops, n_crops))
|
144 |
+
crops = [view_crops[i][c] for i, c in enumerate(best)]
|
145 |
+
|
146 |
+
# crop with the homography
|
147 |
+
for i in range(2):
|
148 |
+
view = views[i]
|
149 |
+
imsize, K_new, R, H = crop_to_homography(view['camera_intrinsics'], crops[i], crops_resolution[i])
|
150 |
+
# imsize, K_new, H = upscale_homography(imsize, resolution, K_new, H)
|
151 |
+
|
152 |
+
# update camera params
|
153 |
+
K_old = view['camera_intrinsics']
|
154 |
+
view['camera_intrinsics'] = K_new
|
155 |
+
view['camera_pose'] = view['camera_pose'].copy()
|
156 |
+
view['camera_pose'][:3, :3] = view['camera_pose'][:3, :3] @ R
|
157 |
+
|
158 |
+
# apply homography to image and depthmap
|
159 |
+
homo8 = (H / H[2, 2]).ravel().tolist()[:8]
|
160 |
+
view['img'] = view['img'].transform(imsize, Image.Transform.PERSPECTIVE,
|
161 |
+
homo8,
|
162 |
+
resample=Image.Resampling.BICUBIC)
|
163 |
+
|
164 |
+
depthmap2 = depthmap_to_camera_coordinates(view['depthmap'], K_old)[0] @ R[:, 2]
|
165 |
+
view['depthmap'] = np.array(Image.fromarray(depthmap2).transform(
|
166 |
+
imsize, Image.Transform.PERSPECTIVE, homo8))
|
167 |
+
|
168 |
+
if 'track_labels' in view:
|
169 |
+
# convert from uint64 --> uint32, because PIL.Image cannot handle uint64
|
170 |
+
mapping, track_labels = np.unique(view['track_labels'], return_inverse=True)
|
171 |
+
track_labels = track_labels.astype(np.uint32).reshape(view['track_labels'].shape)
|
172 |
+
|
173 |
+
# homography transformation
|
174 |
+
res = np.array(Image.fromarray(track_labels).transform(imsize, Image.Transform.PERSPECTIVE, homo8))
|
175 |
+
view['track_labels'] = mapping[res] # mapping back to uint64
|
176 |
+
|
177 |
+
# recompute 3d points from scratch
|
178 |
+
view['pts3d'], view['valid_mask'] = depthmap_to_absolute_camera_coordinates(view['depthmap'],
|
179 |
+
view['camera_intrinsics'],
|
180 |
+
view['camera_pose'])
|
181 |
+
|
182 |
+
def __getitem__(self, idx):
|
183 |
+
if isinstance(idx, tuple):
|
184 |
+
# the idx is specifying the aspect-ratio
|
185 |
+
idx, ar_idx = idx
|
186 |
+
else:
|
187 |
+
assert len(self._resolutions) == 1
|
188 |
+
ar_idx = 0
|
189 |
+
|
190 |
+
# set-up the rng
|
191 |
+
if self.seed: # reseed for each __getitem__
|
192 |
+
self._rng = np.random.default_rng(seed=self.seed + idx)
|
193 |
+
elif not hasattr(self, '_rng'):
|
194 |
+
seed = torch.initial_seed() # this is different for each dataloader process
|
195 |
+
self._rng = np.random.default_rng(seed=seed)
|
196 |
+
|
197 |
+
# over-loaded code
|
198 |
+
resolution = self._resolutions[ar_idx] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
|
199 |
+
views = self._get_views(idx, resolution, self._rng)
|
200 |
+
assert len(views) == self.num_views
|
201 |
+
|
202 |
+
for v, view in enumerate(views):
|
203 |
+
assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
|
204 |
+
view['idx'] = (idx, ar_idx, v)
|
205 |
+
view['is_metric_scale'] = self.is_metric_scale
|
206 |
+
|
207 |
+
assert 'camera_intrinsics' in view
|
208 |
+
if 'camera_pose' not in view:
|
209 |
+
view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
|
210 |
+
else:
|
211 |
+
assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
|
212 |
+
assert 'pts3d' not in view
|
213 |
+
assert 'valid_mask' not in view
|
214 |
+
assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
|
215 |
+
|
216 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
217 |
+
|
218 |
+
view['pts3d'] = pts3d
|
219 |
+
view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
|
220 |
+
|
221 |
+
self.generate_crops_from_pair(views[0], views[1], resolution=resolution,
|
222 |
+
aug_crop_arg=self.aug_crop,
|
223 |
+
n_crops=self.n_tentative_crops,
|
224 |
+
rng=self._rng)
|
225 |
+
for v, view in enumerate(views):
|
226 |
+
# encode the image
|
227 |
+
width, height = view['img'].size
|
228 |
+
view['true_shape'] = np.int32((height, width))
|
229 |
+
view['img'] = self.transform(view['img'])
|
230 |
+
# Pixels for which depth is fundamentally undefined
|
231 |
+
view['sky_mask'] = (view['depthmap'] < 0)
|
232 |
+
|
233 |
+
if self.aug_swap:
|
234 |
+
self._swap_view_aug(views)
|
235 |
+
|
236 |
+
if self.aug_monocular:
|
237 |
+
if self._rng.random() < self.aug_monocular:
|
238 |
+
views = [copy.deepcopy(views[0]) for _ in range(len(views))]
|
239 |
+
|
240 |
+
# automatic extraction of correspondences from pts3d + pose
|
241 |
+
if self.n_corres > 0 and ('corres' not in view):
|
242 |
+
corres1, corres2, valid = extract_correspondences_from_pts3d(*views, self.n_corres,
|
243 |
+
self._rng, nneg=self.nneg)
|
244 |
+
views[0]['corres'] = corres1
|
245 |
+
views[1]['corres'] = corres2
|
246 |
+
views[0]['valid_corres'] = valid
|
247 |
+
views[1]['valid_corres'] = valid
|
248 |
+
|
249 |
+
if self.aug_rot90 is False:
|
250 |
+
pass
|
251 |
+
elif self.aug_rot90 == 'same':
|
252 |
+
rotate_90(views, k=self._rng.choice(4))
|
253 |
+
elif self.aug_rot90 == 'diff':
|
254 |
+
rotate_90(views[:1], k=self._rng.choice(4))
|
255 |
+
rotate_90(views[1:], k=self._rng.choice(4))
|
256 |
+
else:
|
257 |
+
raise ValueError(f'Bad value for {self.aug_rot90=}')
|
258 |
+
|
259 |
+
# check data-types metric_scale
|
260 |
+
for v, view in enumerate(views):
|
261 |
+
if 'corres' not in view:
|
262 |
+
view['corres'] = np.full((self.n_corres, 2), np.nan, dtype=np.float32)
|
263 |
+
|
264 |
+
# check all datatypes
|
265 |
+
for key, val in view.items():
|
266 |
+
res, err_msg = is_good_type(key, val)
|
267 |
+
assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
|
268 |
+
K = view['camera_intrinsics']
|
269 |
+
|
270 |
+
# check shapes
|
271 |
+
assert view['depthmap'].shape == view['img'].shape[1:]
|
272 |
+
assert view['depthmap'].shape == view['pts3d'].shape[:2]
|
273 |
+
assert view['depthmap'].shape == view['valid_mask'].shape
|
274 |
+
|
275 |
+
# last thing done!
|
276 |
+
for view in views:
|
277 |
+
# transpose to make sure all views are the same size
|
278 |
+
transpose_to_landscape(view)
|
279 |
+
# this allows to check whether the RNG is is the same state each time
|
280 |
+
view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
|
281 |
+
|
282 |
+
return views
|
283 |
+
|
284 |
+
|
285 |
+
def transpose_to_landscape(view, revert=False):
|
286 |
+
height, width = view['true_shape']
|
287 |
+
|
288 |
+
if width < height:
|
289 |
+
if revert:
|
290 |
+
height, width = width, height
|
291 |
+
|
292 |
+
# rectify portrait to landscape
|
293 |
+
assert view['img'].shape == (3, height, width)
|
294 |
+
view['img'] = view['img'].swapaxes(1, 2)
|
295 |
+
|
296 |
+
assert view['valid_mask'].shape == (height, width)
|
297 |
+
view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
|
298 |
+
|
299 |
+
assert view['sky_mask'].shape == (height, width)
|
300 |
+
view['sky_mask'] = view['sky_mask'].swapaxes(0, 1)
|
301 |
+
|
302 |
+
assert view['depthmap'].shape == (height, width)
|
303 |
+
view['depthmap'] = view['depthmap'].swapaxes(0, 1)
|
304 |
+
|
305 |
+
assert view['pts3d'].shape == (height, width, 3)
|
306 |
+
view['pts3d'] = view['pts3d'].swapaxes(0, 1)
|
307 |
+
|
308 |
+
# transpose x and y pixels
|
309 |
+
view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
|
310 |
+
|
311 |
+
# transpose correspondences x and y
|
312 |
+
view['corres'] = view['corres'][:, [1, 0]]
|
313 |
+
|
314 |
+
|
315 |
+
def rotate_90(views, k=1):
|
316 |
+
from scipy.spatial.transform import Rotation
|
317 |
+
# print('rotation =', k)
|
318 |
+
|
319 |
+
RT = np.eye(4, dtype=np.float32)
|
320 |
+
RT[:3, :3] = Rotation.from_euler('z', 90 * k, degrees=True).as_matrix()
|
321 |
+
|
322 |
+
for view in views:
|
323 |
+
view['img'] = torch.rot90(view['img'], k=k, dims=(-2, -1)) # WARNING!! dims=(-1,-2) != dims=(-2,-1)
|
324 |
+
view['depthmap'] = np.rot90(view['depthmap'], k=k).copy()
|
325 |
+
view['camera_pose'] = view['camera_pose'] @ RT
|
326 |
+
|
327 |
+
RT2 = np.eye(3, dtype=np.float32)
|
328 |
+
RT2[:2, :2] = RT[:2, :2] * ((1, -1), (-1, 1))
|
329 |
+
H, W = view['depthmap'].shape
|
330 |
+
if k % 4 == 0:
|
331 |
+
pass
|
332 |
+
elif k % 4 == 1:
|
333 |
+
# top-left (0,0) pixel becomes (0,H-1)
|
334 |
+
RT2[:2, 2] = (0, H - 1)
|
335 |
+
elif k % 4 == 2:
|
336 |
+
# top-left (0,0) pixel becomes (W-1,H-1)
|
337 |
+
RT2[:2, 2] = (W - 1, H - 1)
|
338 |
+
elif k % 4 == 3:
|
339 |
+
# top-left (0,0) pixel becomes (W-1,0)
|
340 |
+
RT2[:2, 2] = (W - 1, 0)
|
341 |
+
else:
|
342 |
+
raise ValueError(f'Bad value for {k=}')
|
343 |
+
|
344 |
+
view['camera_intrinsics'][:2, 2] = geotrf(RT2, view['camera_intrinsics'][:2, 2])
|
345 |
+
if k % 2 == 1:
|
346 |
+
K = view['camera_intrinsics']
|
347 |
+
np.fill_diagonal(K, K.diagonal()[[1, 0, 2]])
|
348 |
+
|
349 |
+
pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
|
350 |
+
view['pts3d'] = pts3d
|
351 |
+
view['valid_mask'] = np.rot90(view['valid_mask'], k=k).copy()
|
352 |
+
view['sky_mask'] = np.rot90(view['sky_mask'], k=k).copy()
|
353 |
+
|
354 |
+
view['corres'] = geotrf(RT2, view['corres']).round().astype(view['corres'].dtype)
|
355 |
+
view['true_shape'] = np.int32((H, W))
|
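For reference, a minimal sketch of how a dataset derived from this base class is typically indexed; the class name and constructor arguments below are illustrative placeholders (not part of this commit), but the tuple index follows the (idx, ar_idx) convention handled by __getitem__ above:

# illustrative usage sketch only -- SomeMASt3RDataset and its arguments are hypothetical
dataset = SomeMASt3RDataset(split='train', resolution=[(512, 384)], n_corres=8192, seed=777)
views = dataset[(0, 0)]        # sample index 0, aspect-ratio index 0
view1, view2 = views           # dicts with 'img', 'depthmap', 'pts3d', 'valid_mask', 'corres', ...
print(view1['true_shape'], view1['corres'].shape)
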
mast3r/datasets/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
mast3r/datasets/utils/cropping.py
ADDED
@@ -0,0 +1,219 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# cropping/match extraction
# --------------------------------------------------------
import numpy as np
import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.device import to_numpy
from dust3r.utils.geometry import inv, geotrf


def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
    is_reciprocal1 = (corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2)))
    pos1 = is_reciprocal1.nonzero()[0]
    pos2 = corres_1_to_2[pos1]
    if ret_recip:
        return is_reciprocal1, pos1, pos2
    return pos1, pos2


def extract_correspondences_from_pts3d(view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0):
    view1, view2 = to_numpy((view1, view2))
    # project pixels from image1 --> 3d points --> image2 pixels
    shape1, corres1_to_2 = reproject_view(view1['pts3d'], view2)
    shape2, corres2_to_1 = reproject_view(view2['pts3d'], view1)

    # compute reciprocal correspondences:
    # pos1 == valid pixels (correspondences) in image1
    is_reciprocal1, pos1, pos2 = reciprocal_1d(corres1_to_2, corres2_to_1, ret_recip=True)
    is_reciprocal2 = (corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1)))

    if target_n_corres is None:
        if ret_xy:
            pos1 = unravel_xy(pos1, shape1)
            pos2 = unravel_xy(pos2, shape2)
        return pos1, pos2

    available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
    target_n_positives = int(target_n_corres * (1 - nneg))
    n_positives = min(len(pos1), target_n_positives)
    n_negatives = min(target_n_corres - n_positives, available_negatives)

    if n_negatives + n_positives != target_n_corres:
        # should be really rare => when there are not enough negatives
        # in that case, break nneg and add a few more positives ?
        n_positives = target_n_corres - n_negatives
        assert n_positives <= len(pos1)

    assert n_positives <= len(pos1)
    assert n_positives <= len(pos2)
    assert n_negatives <= (~is_reciprocal1).sum()
    assert n_negatives <= (~is_reciprocal2).sum()
    assert n_positives + n_negatives == target_n_corres

    valid = np.ones(n_positives, dtype=bool)
    if n_positives < len(pos1):
        # random sub-sampling of valid correspondences
        perm = rng.permutation(len(pos1))[:n_positives]
        pos1 = pos1[perm]
        pos2 = pos2[perm]

    if n_negatives > 0:
        # add false correspondences if not enough
        def norm(p): return p / p.sum()
        pos1 = np.r_[pos1, rng.choice(shape1[0] * shape1[1], size=n_negatives, replace=False, p=norm(~is_reciprocal1))]
        pos2 = np.r_[pos2, rng.choice(shape2[0] * shape2[1], size=n_negatives, replace=False, p=norm(~is_reciprocal2))]
        valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]

    # convert (x+W*y) back to 2d (x,y) coordinates
    if ret_xy:
        pos1 = unravel_xy(pos1, shape1)
        pos2 = unravel_xy(pos2, shape2)
    return pos1, pos2, valid


def reproject_view(pts3d, view2):
    shape = view2['pts3d'].shape[:2]
    return reproject(pts3d, view2['camera_intrinsics'], inv(view2['camera_pose']), shape)


def reproject(pts3d, K, world2cam, shape):
    H, W, THREE = pts3d.shape
    assert THREE == 3

    # reproject in camera2 space
    with np.errstate(divide='ignore', invalid='ignore'):
        pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)

    # quantize to pixel positions
    return (H, W), ravel_xy(pos, shape)


def ravel_xy(pos, shape):
    H, W = shape
    with np.errstate(invalid='ignore'):
        qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
    quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(min=0, max=H - 1, out=qy)
    return quantized_pos


def unravel_xy(pos, shape):
    # convert (x+W*y) back to 2d (x,y) coordinates
    return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()


def _rotation_origin_to_pt(target):
    """ Align the origin (0,0,1) with the target point (x,y,1) in projective space.
    Method: rotate z to put target on (x'+,0,1), then rotate on Y to get (0,0,1) and un-rotate z.
    """
    from scipy.spatial.transform import Rotation
    x, y = target
    rot_z = np.arctan2(y, x)
    rot_y = np.arctan(np.linalg.norm(target))
    R = Rotation.from_euler('ZYZ', [rot_z, rot_y, -rot_z]).as_matrix()
    return R


def _dotmv(Trf, pts, ncol=None, norm=False):
    assert Trf.ndim >= 2
    ncol = ncol or pts.shape[-1]

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    if Trf.ndim >= 3:
        n = Trf.ndim - 2
        assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
        Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

        if pts.ndim > Trf.ndim:
            # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
            pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
        elif pts.ndim == 2:
            # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
            pts = pts[:, None, :]

    if pts.shape[-1] + 1 == Trf.shape[-1]:
        Trf = Trf.swapaxes(-1, -2)  # transpose Trf
        pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]

    elif pts.shape[-1] == Trf.shape[-1]:
        Trf = Trf.swapaxes(-1, -2)  # transpose Trf
        pts = pts @ Trf
    else:
        pts = Trf @ pts.T
        if pts.ndim >= 2:
            pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def crop_to_homography(K, crop, target_size=None):
    """ Given an image and its intrinsics,
    we want to replicate a rectangular crop with an homography,
    so that the principal point of the new 'crop' is centered.
    """
    # build intrinsics for the crop
    crop = np.round(crop)
    crop_size = crop[2:] - crop[:2]
    K2 = K.copy()  # same focal
    K2[:2, 2] = crop_size / 2  # new principal point is perfectly centered

    # find which corner is the most far-away from current principal point
    # so that the final homography does not go over the image borders
    corners = crop.reshape(-1, 2)
    corner_idx = np.abs(corners - K[:2, 2]).argmax(0)
    corner = corners[corner_idx, [0, 1]]
    # align with the corresponding corner from the target view
    corner2 = np.c_[[0, 0], crop_size][[0, 1], corner_idx]

    old_pt = _dotmv(np.linalg.inv(K), corner, norm=1)
    new_pt = _dotmv(np.linalg.inv(K2), corner2, norm=1)
    R = _rotation_origin_to_pt(old_pt) @ np.linalg.inv(_rotation_origin_to_pt(new_pt))

    if target_size is not None:
        imsize = target_size
        target_size = np.asarray(target_size)
        scaling = min(target_size / crop_size)
        K2[:2] *= scaling
        K2[:2, 2] = target_size / 2
    else:
        imsize = tuple(np.int32(crop_size).tolist())

    return imsize, K2, R, K @ R @ np.linalg.inv(K2)


def gen_random_crops(imsize, n_crops, resolution, aug_crop, rng=np.random):
    """ Generate random crops of size=resolution,
        for an input image upscaled to (imsize + randint(0 , aug_crop))
    """
    resolution_crop = np.array(resolution) * min(np.array(imsize) / resolution)

    # (virtually) upscale the input image
    # scaling = rng.uniform(1, 1+(aug_crop+1)/min(imsize))
    scaling = np.exp(rng.uniform(0, np.log(1 + aug_crop / min(imsize))))
    imsize2 = np.int32(np.array(imsize) * scaling)

    # generate some random crops
    topleft = rng.random((n_crops, 2)) * (imsize2 - resolution_crop)
    crops = np.c_[topleft, topleft + resolution_crop]
    # print(f"{scaling=}, {topleft=}")
    # reduce the resolution to come back to original size
    crops /= scaling
    return crops


def in2d_rect(corres, crops):
    # corres = (N,2)
    # crops = (M,4)
    # output = (N, M)
    is_sup = (corres[:, None] >= crops[None, :, 0:2])
    is_inf = (corres[:, None] < crops[None, :, 2:4])
    return (is_sup & is_inf).all(axis=-1)
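As a quick illustration of how these helpers fit together, the sketch below draws candidate crop boxes and tests which 2D correspondences fall inside each one; all values (image size, resolution, correspondences) are dummy numbers made up for the example:

# illustrative sketch only -- imsize, resolution and corres are dummy values
import numpy as np
rng = np.random.default_rng(0)
imsize, resolution = (640, 480), (512, 384)
crops = gen_random_crops(imsize, n_crops=4, resolution=resolution, aug_crop=256, rng=rng)  # (4, 4) boxes
corres = rng.random((100, 2)) * imsize        # fake (x, y) correspondences in this view
inside = in2d_rect(corres, crops)             # (100, 4) boolean mask, one column per crop
print(inside.sum(axis=0))                     # number of correspondences kept by each crop
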
mast3r/demo.py
ADDED
@@ -0,0 +1,321 @@
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# sparse gradio demo functions
# --------------------------------------------------------
import math
import gradio
import os
import numpy as np
import functools
import trimesh
import copy
from scipy.spatial.transform import Rotation
import tempfile
import shutil

from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes

import matplotlib.pyplot as pl


class SparseGAState():
    def __init__(self, sparse_ga, should_delete=False, cache_dir=None, outfile_name=None):
        self.sparse_ga = sparse_ga
        self.cache_dir = cache_dir
        self.outfile_name = outfile_name
        self.should_delete = should_delete

    def __del__(self):
        if not self.should_delete:
            return
        if self.cache_dir is not None and os.path.isdir(self.cache_dir):
            shutil.rmtree(self.cache_dir)
        self.cache_dir = None
        if self.outfile_name is not None and os.path.isfile(self.outfile_name):
            os.remove(self.outfile_name)
        self.outfile_name = None


def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False):
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # full pointcloud
    if as_pointcloud:
        pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
        valid_msk = np.isfinite(pts.sum(axis=1))
        pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
        scene.add_geometry(pct)
    else:
        meshes = []
        for i in range(len(imgs)):
            pts3d_i = pts3d[i].reshape(imgs[i].shape)
            msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
            meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile


def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
    """
    extract 3D_model (glb file) from a reconstructed scene
    """
    if scene_state is None:
        return None
    outfile = scene_state.outfile_name
    if outfile is None:
        return None

    # get optimized values from scene
    scene = scene_state.sparse_ga
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()

    # 3D pointcloud from depthmap, poses and intrinsics
    if TSDF_thresh > 0:
        tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
        pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
    else:
        pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
    msk = to_numpy([c > min_conf_thr for c in confs])
    return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)


def get_reconstructed_scene(outdir, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
                            filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
                            win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
    """
    from a list of images, run mast3r inference, sparse global aligner.
    then run get_3D_model_from_scene
    """
    print(image_size, current_scene_state, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
          as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
          win_cyclic, refid, TSDF_thresh, shared_intrinsics)
    # 512 None refine+depth 0.07 500 0.014 200 1.5 5 True False True False 0.2 logwin 6 False 0 0 True
    imgs = load_images(filelist, size=image_size, verbose=not silent)
    if len(imgs) == 1:
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
        filelist = [filelist[0], filelist[0] + '_2']

    scene_graph_params = [scenegraph_type]
    if scenegraph_type in ["swin", "logwin"]:
        scene_graph_params.append(str(winsize))
    elif scenegraph_type == "oneref":
        scene_graph_params.append(str(refid))
    if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
        scene_graph_params.append('noncyclic')
    scene_graph = '-'.join(scene_graph_params)
    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
    print(pairs, len(imgs))
    if optim_level == 'coarse':
        niter2 = 0
    # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
    if current_scene_state is not None and \
            not current_scene_state.should_delete and \
            current_scene_state.cache_dir is not None:
        cache_dir = current_scene_state.cache_dir
    elif gradio_delete_cache:
        cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
    else:
        cache_dir = os.path.join(outdir, 'cache')
    os.makedirs(cache_dir, exist_ok=True)
    scene = sparse_global_alignment(filelist, pairs, cache_dir,
                                    model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
                                    opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
                                    matching_conf_thr=matching_conf_thr, **kw)
    if current_scene_state is not None and \
            not current_scene_state.should_delete and \
            current_scene_state.outfile_name is not None:
        outfile_name = current_scene_state.outfile_name
    else:
        outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)

    scene_state = SparseGAState(scene, gradio_delete_cache, cache_dir, outfile_name)
    outfile = get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size, TSDF_thresh)
    print(outfile)
    return scene_state, outfile


def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
    num_files = len(inputfiles) if inputfiles is not None else 1
    show_win_controls = scenegraph_type in ["swin", "logwin"]
    show_winsize = scenegraph_type in ["swin", "logwin"]
    show_cyclic = scenegraph_type in ["swin", "logwin"]
    max_winsize, min_winsize = 1, 1
    if scenegraph_type == "swin":
        if win_cyclic:
            max_winsize = max(1, math.ceil((num_files - 1) / 2))
        else:
            max_winsize = num_files - 1
    elif scenegraph_type == "logwin":
        if win_cyclic:
            half_size = math.ceil((num_files - 1) / 2)
            max_winsize = max(1, math.ceil(math.log(half_size, 2)))
        else:
            max_winsize = max(1, math.ceil(math.log(num_files, 2)))
    winsize = gradio.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=min_winsize, maximum=max_winsize, step=1, visible=show_winsize)
    win_cyclic = gradio.Checkbox(value=win_cyclic, label="Cyclic sequence", visible=show_cyclic)
    win_col = gradio.Column(visible=show_win_controls)
    refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files - 1, step=1, visible=scenegraph_type == 'oneref')
    return win_col, winsize, win_cyclic, refid


def main_demo(tmpdirname, model, device, image_size, server_name, server_port, silent=False,
              share=False, gradio_delete_cache=False):
    if not silent:
        print('Outputing stuff in', tmpdirname)

    recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, gradio_delete_cache, model, device,
                                  silent, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)

    def get_context(delete_cache):
        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        title = "MASt3R Demo"
        if delete_cache:
            return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
        else:
            return gradio.Blocks(css=css, title="MASt3R Demo")  # for compatibility with older versions

    with get_context(gradio_delete_cache) as demo:
        # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                with gradio.Column():
                    with gradio.Row():
                        lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
                        niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
                                               label="num_iterations", info="For coarse alignment!")
                        lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
                        niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
                                               label="num_iterations", info="For refinement!")
                        optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
                                                      value='refine+depth', label="OptLevel",
                                                      info="Optimization level")
                    with gradio.Row():
                        matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
                                                          minimum=0., maximum=30., step=0.1,
                                                          info="Before Fallback to Regr3D!")
                        shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                            info="Only optimize one set of intrinsics for all views")
                        scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
                                                           ("swin: sliding window", "swin"),
                                                           ("logwin: sliding window with long range", "logwin"),
                                                           ("oneref: match one image with all", "oneref")],
                                                          value='complete', label="Scenegraph",
                                                          info="Define how to make pairs",
                                                          interactive=True)
                        with gradio.Column(visible=False) as win_col:
                            winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                                    minimum=1, maximum=1, step=1)
                            win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
                        refid = gradio.Slider(label="Scene Graph: Id", value=0,
                                              minimum=0, maximum=0, step=1, visible=False)
            run_btn = gradio.Button("Run")

            with gradio.Row():
                # adjust the confidence threshold
                min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
                # adjust the camera size in the output pointcloud
                cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
                TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
            with gradio.Row():
                as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
                # two post process implemented
                mask_sky = gradio.Checkbox(value=False, label="Mask sky")
                clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()

            # events
            scenegraph_type.change(set_scenegraph_options,
                                   inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                                   outputs=[win_col, winsize, win_cyclic, refid])
            inputfiles.change(set_scenegraph_options,
                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                              outputs=[win_col, winsize, win_cyclic, refid])
            win_cyclic.change(set_scenegraph_options,
                              inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                              outputs=[win_col, winsize, win_cyclic, refid])
            run_btn.click(fn=recon_fun,
                          inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                                  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                                  scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
                          outputs=[scene, outmodel])
            min_conf_thr.release(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                 outputs=outmodel)
            cam_size.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
                            outputs=outmodel)
            TSDF_thresh.change(fn=model_from_scene_fun,
                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
                               outputs=outmodel)
            as_pointcloud.change(fn=model_from_scene_fun,
                                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                         clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                 outputs=outmodel)
            mask_sky.change(fn=model_from_scene_fun,
                            inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                    clean_depth, transparent_cams, cam_size, TSDF_thresh],
                            outputs=outmodel)
            clean_depth.change(fn=model_from_scene_fun,
                               inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size, TSDF_thresh],
                               outputs=outmodel)
            transparent_cams.change(model_from_scene_fun,
                                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
                                            clean_depth, transparent_cams, cam_size, TSDF_thresh],
                                    outputs=outmodel)
    demo.launch(share=share, server_name=server_name, server_port=server_port)
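A minimal sketch of how this entry point might be wired up; the model object, device and port values below are placeholders (model loading lives outside this file, e.g. in mast3r/model.py, which is not shown here):

# illustrative sketch only -- 'model' is assumed to be a MASt3R network loaded elsewhere
import tempfile
with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:
    main_demo(tmpdirname, model, device='cuda', image_size=512,
              server_name='0.0.0.0', server_port=7860, silent=False,
              share=False, gradio_delete_cache=False)
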
mast3r/fast_nn.py
ADDED
@@ -0,0 +1,223 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R Fast Nearest Neighbor
# --------------------------------------------------------
import torch
import numpy as np
import math
from scipy.spatial import KDTree

import mast3r.utils.path_to_dust3r  # noqa
from dust3r.utils.device import to_numpy, todevice  # noqa


@torch.no_grad()
def bruteforce_reciprocal_nns(A, B, device='cuda', block_size=None, dist='l2'):
    if isinstance(A, np.ndarray):
        A = torch.from_numpy(A).to(device)
    if isinstance(B, np.ndarray):
        B = torch.from_numpy(B).to(device)

    A = A.to(device)
    B = B.to(device)

    if dist == 'l2':
        dist_func = torch.cdist
        argmin = torch.min
    elif dist == 'dot':
        def dist_func(A, B):
            return A @ B.T

        def argmin(X, dim):
            sim, nn = torch.max(X, dim=dim)
            return sim.neg_(), nn
    else:
        raise ValueError(f'Unknown {dist=}')

    if block_size is None or len(A) * len(B) <= block_size**2:
        dists = dist_func(A, B)
        _, nn_A = argmin(dists, dim=1)
        _, nn_B = argmin(dists, dim=0)
    else:
        dis_A = torch.full((A.shape[0],), float('inf'), device=device, dtype=A.dtype)
        dis_B = torch.full((B.shape[0],), float('inf'), device=device, dtype=B.dtype)
        nn_A = torch.full((A.shape[0],), -1, device=device, dtype=torch.int64)
        nn_B = torch.full((B.shape[0],), -1, device=device, dtype=torch.int64)
        number_of_iteration_A = math.ceil(A.shape[0] / block_size)
        number_of_iteration_B = math.ceil(B.shape[0] / block_size)

        for i in range(number_of_iteration_A):
            A_i = A[i * block_size:(i + 1) * block_size]
            for j in range(number_of_iteration_B):
                B_j = B[j * block_size:(j + 1) * block_size]
                dists_blk = dist_func(A_i, B_j)  # A, B, 1
                # dists_blk = dists[i * block_size:(i+1)*block_size, j * block_size:(j+1)*block_size]
                min_A_i, argmin_A_i = argmin(dists_blk, dim=1)
                min_B_j, argmin_B_j = argmin(dists_blk, dim=0)

                col_mask = min_A_i < dis_A[i * block_size:(i + 1) * block_size]
                line_mask = min_B_j < dis_B[j * block_size:(j + 1) * block_size]

                dis_A[i * block_size:(i + 1) * block_size][col_mask] = min_A_i[col_mask]
                dis_B[j * block_size:(j + 1) * block_size][line_mask] = min_B_j[line_mask]

                nn_A[i * block_size:(i + 1) * block_size][col_mask] = argmin_A_i[col_mask] + (j * block_size)
                nn_B[j * block_size:(j + 1) * block_size][line_mask] = argmin_B_j[line_mask] + (i * block_size)
    nn_A = nn_A.cpu().numpy()
    nn_B = nn_B.cpu().numpy()
    return nn_A, nn_B


class cdistMatcher:
    def __init__(self, db_pts, device='cuda'):
        self.db_pts = db_pts.to(device)
        self.device = device

    def query(self, queries, k=1, **kw):
        assert k == 1
        if queries.numel() == 0:
            return None, []
        nnA, nnB = bruteforce_reciprocal_nns(queries, self.db_pts, device=self.device, **kw)
        dis = None
        return dis, nnA


def merge_corres(idx1, idx2, shape1=None, shape2=None, ret_xy=True, ret_index=False):
    assert idx1.dtype == idx2.dtype == np.int32

    # unique and sort along idx1
    corres = np.unique(np.c_[idx2, idx1].view(np.int64), return_index=ret_index)
    if ret_index:
        corres, indices = corres
    xy2, xy1 = corres[:, None].view(np.int32).T

    if ret_xy:
        assert shape1 and shape2
        xy1 = np.unravel_index(xy1, shape1)
        xy2 = np.unravel_index(xy2, shape2)
        if ret_xy != 'y_x':
            xy1 = xy1[0].base[:, ::-1]
            xy2 = xy2[0].base[:, ::-1]

    if ret_index:
        return xy1, xy2, indices
    return xy1, xy2


def fast_reciprocal_NNs(pts1, pts2, subsample_or_initxy1=8, ret_xy=True, pixel_tol=0, ret_basin=False,
                        device='cuda', **matcher_kw):
    H1, W1, DIM1 = pts1.shape
    H2, W2, DIM2 = pts2.shape
    assert DIM1 == DIM2

    pts1 = pts1.reshape(-1, DIM1)
    pts2 = pts2.reshape(-1, DIM2)

    if isinstance(subsample_or_initxy1, int) and pixel_tol == 0:
        S = subsample_or_initxy1
        y1, x1 = np.mgrid[S // 2:H1:S, S // 2:W1:S].reshape(2, -1)
        max_iter = 10
    else:
        x1, y1 = subsample_or_initxy1
        if isinstance(x1, torch.Tensor):
            x1 = x1.cpu().numpy()
        if isinstance(y1, torch.Tensor):
            y1 = y1.cpu().numpy()
        max_iter = 1

    xy1 = np.int32(np.unique(x1 + W1 * y1))  # make sure there's no doublons
    xy2 = np.full_like(xy1, -1)
    old_xy1 = xy1.copy()
    old_xy2 = xy2.copy()

    if 'dist' in matcher_kw or 'block_size' in matcher_kw \
            or (isinstance(device, str) and device.startswith('cuda')) \
            or (isinstance(device, torch.device) and device.type.startswith('cuda')):
        pts1 = pts1.to(device)
        pts2 = pts2.to(device)
        tree1 = cdistMatcher(pts1, device=device)
        tree2 = cdistMatcher(pts2, device=device)
    else:
        pts1, pts2 = to_numpy((pts1, pts2))
        tree1 = KDTree(pts1)
        tree2 = KDTree(pts2)

    notyet = np.ones(len(xy1), dtype=bool)
    basin = np.full((H1 * W1 + 1,), -1, dtype=np.int32) if ret_basin else None

    niter = 0
    # n_notyet = [len(notyet)]
    while notyet.any():
        _, xy2[notyet] = to_numpy(tree2.query(pts1[xy1[notyet]], **matcher_kw))
        if not ret_basin:
            notyet &= (old_xy2 != xy2)  # remove points that have converged

        _, xy1[notyet] = to_numpy(tree1.query(pts2[xy2[notyet]], **matcher_kw))
        if ret_basin:
            basin[old_xy1[notyet]] = xy1[notyet]
        notyet &= (old_xy1 != xy1)  # remove points that have converged

        # n_notyet.append(notyet.sum())
        niter += 1
        if niter >= max_iter:
            break

        old_xy2[:] = xy2
        old_xy1[:] = xy1

    # print('notyet_stats:', ' '.join(map(str, (n_notyet+[0]*10)[:max_iter])))

    if pixel_tol > 0:
        # in case we only want to match some specific points
        # and still have some way of checking reciprocity
        old_yx1 = np.unravel_index(old_xy1, (H1, W1))[0].base
        new_yx1 = np.unravel_index(xy1, (H1, W1))[0].base
        dis = np.linalg.norm(old_yx1 - new_yx1, axis=-1)
        converged = dis < pixel_tol
        if not isinstance(subsample_or_initxy1, int):
            xy1 = old_xy1  # replace new points by old ones
    else:
        converged = ~notyet  # converged correspondences

    # keep only unique correspondences, and sort on xy1
    xy1, xy2 = merge_corres(xy1[converged], xy2[converged], (H1, W1), (H2, W2), ret_xy=ret_xy)
    if ret_basin:
        return xy1, xy2, basin
    return xy1, xy2


def extract_correspondences_nonsym(A, B, confA, confB, subsample=8, device=None, ptmap_key='pred_desc', pixel_tol=0):
    if '3d' in ptmap_key:
        opt = dict(device='cpu', workers=32)
    else:
        opt = dict(device=device, dist='dot', block_size=2**13)

    # matching the two pairs
    idx1 = []
    idx2 = []
    # merge corres from opposite pairs
    HA, WA = A.shape[:2]
    HB, WB = B.shape[:2]
    if pixel_tol == 0:
        nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=subsample, ret_xy=False, **opt)
        nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=subsample, ret_xy=False, **opt)
    else:
        S = subsample
        yA, xA = np.mgrid[S // 2:HA:S, S // 2:WA:S].reshape(2, -1)
        yB, xB = np.mgrid[S // 2:HB:S, S // 2:WB:S].reshape(2, -1)

        nn1to2 = fast_reciprocal_NNs(A, B, subsample_or_initxy1=(xA, yA), ret_xy=False, pixel_tol=pixel_tol, **opt)
        nn2to1 = fast_reciprocal_NNs(B, A, subsample_or_initxy1=(xB, yB), ret_xy=False, pixel_tol=pixel_tol, **opt)

    idx1 = np.r_[nn1to2[0], nn2to1[1]]
    idx2 = np.r_[nn1to2[1], nn2to1[0]]

    c1 = confA.ravel()[idx1]
    c2 = confB.ravel()[idx2]

    xy1, xy2, idx = merge_corres(idx1, idx2, (HA, WA), (HB, WB), ret_xy=True, ret_index=True)
    conf = np.minimum(c1[idx], c2[idx])
    corres = (xy1.copy(), xy2.copy(), conf)
    return todevice(corres, device)
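A short sketch of the intended use of fast_reciprocal_NNs for descriptor matching; the desc1/desc2 tensors below are random dummies standing in for the per-pixel descriptor maps that the network would normally produce:

# illustrative sketch only -- random (H, W, D) descriptor maps instead of real network outputs
import torch
desc1 = torch.randn(384, 512, 24)
desc2 = torch.randn(384, 512, 24)
xy1, xy2 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
                               device='cpu', dist='dot', block_size=2**13)
# xy1, xy2 are (N, 2) integer arrays of matching (x, y) pixel coordinates in each image
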
mast3r/losses.py
ADDED
@@ -0,0 +1,508 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# Implementation of MASt3R training losses
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import numpy as np
|
10 |
+
from sklearn.metrics import average_precision_score
|
11 |
+
|
12 |
+
import mast3r.utils.path_to_dust3r # noqa
|
13 |
+
from dust3r.losses import BaseCriterion, Criterion, MultiLoss, Sum, ConfLoss
|
14 |
+
from dust3r.losses import Regr3D as Regr3D_dust3r
|
15 |
+
from dust3r.utils.geometry import (geotrf, inv, normalize_pointcloud)
|
16 |
+
from dust3r.inference import get_pred_pts3d
|
17 |
+
from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale
|
18 |
+
|
19 |
+
|
20 |
+
def apply_log_to_norm(xyz):
|
21 |
+
d = xyz.norm(dim=-1, keepdim=True)
|
22 |
+
xyz = xyz / d.clip(min=1e-8)
|
23 |
+
xyz = xyz * torch.log1p(d)
|
24 |
+
return xyz
|
25 |
+
|
26 |
+
|
27 |
+
class Regr3D (Regr3D_dust3r):
|
28 |
+
def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False, opt_fit_gt=False,
|
29 |
+
sky_loss_value=2, max_metric_scale=False, loss_in_log=False):
|
30 |
+
self.loss_in_log = loss_in_log
|
31 |
+
if norm_mode.startswith('?'):
|
32 |
+
# do no norm pts from metric scale datasets
|
33 |
+
self.norm_all = False
|
34 |
+
self.norm_mode = norm_mode[1:]
|
35 |
+
else:
|
36 |
+
self.norm_all = True
|
37 |
+
self.norm_mode = norm_mode
|
38 |
+
super().__init__(criterion, self.norm_mode, gt_scale)
|
39 |
+
|
40 |
+
self.sky_loss_value = sky_loss_value
|
41 |
+
self.max_metric_scale = max_metric_scale
|
42 |
+
|
43 |
+
def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
|
44 |
+
# everything is normalized w.r.t. camera of view1
|
45 |
+
in_camera1 = inv(gt1['camera_pose'])
|
46 |
+
gt_pts1 = geotrf(in_camera1, gt1['pts3d']) # B,H,W,3
|
47 |
+
gt_pts2 = geotrf(in_camera1, gt2['pts3d']) # B,H,W,3
|
48 |
+
|
49 |
+
valid1 = gt1['valid_mask'].clone()
|
50 |
+
valid2 = gt2['valid_mask'].clone()
|
51 |
+
|
52 |
+
if dist_clip is not None:
|
53 |
+
# points that are too far-away == invalid
|
54 |
+
dis1 = gt_pts1.norm(dim=-1) # (B, H, W)
|
55 |
+
dis2 = gt_pts2.norm(dim=-1) # (B, H, W)
|
56 |
+
valid1 = valid1 & (dis1 <= dist_clip)
|
57 |
+
valid2 = valid2 & (dis2 <= dist_clip)
|
58 |
+
|
59 |
+
if self.loss_in_log == 'before':
|
60 |
+
# this only make sense when depth_mode == 'linear'
|
61 |
+
gt_pts1 = apply_log_to_norm(gt_pts1)
|
62 |
+
gt_pts2 = apply_log_to_norm(gt_pts2)
|
63 |
+
|
64 |
+
pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False).clone()
|
65 |
+
pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True).clone()
|
66 |
+
|
67 |
+
if not self.norm_all:
|
68 |
+
if self.max_metric_scale:
|
69 |
+
B = valid1.shape[0]
|
70 |
+
# valid1: B, H, W
|
71 |
+
# torch.linalg.norm(gt_pts1, dim=-1) -> B, H, W
|
72 |
+
# dist1_to_cam1 -> reshape to B, H*W
|
73 |
+
dist1_to_cam1 = torch.where(valid1, torch.linalg.norm(gt_pts1, dim=-1), 0).view(B, -1)
|
74 |
+
dist2_to_cam1 = torch.where(valid2, torch.linalg.norm(gt_pts2, dim=-1), 0).view(B, -1)
|
75 |
+
|
76 |
+
# is_metric_scale: B
|
77 |
+
# dist1_to_cam1.max(dim=-1).values -> B
|
78 |
+
gt1['is_metric_scale'] = gt1['is_metric_scale'] \
|
79 |
+
& (dist1_to_cam1.max(dim=-1).values < self.max_metric_scale) \
|
80 |
+
& (dist2_to_cam1.max(dim=-1).values < self.max_metric_scale)
|
81 |
+
gt2['is_metric_scale'] = gt1['is_metric_scale']
|
82 |
+
|
83 |
+
mask = ~gt1['is_metric_scale']
|
84 |
+
else:
|
85 |
+
mask = torch.ones_like(gt1['is_metric_scale'])
|
86 |
+
# normalize 3d points
|
87 |
+
if self.norm_mode and mask.any():
|
88 |
+
pr_pts1[mask], pr_pts2[mask] = normalize_pointcloud(pr_pts1[mask], pr_pts2[mask], self.norm_mode,
|
89 |
+
valid1[mask], valid2[mask])
|
90 |
+
|
91 |
+
if self.norm_mode and not self.gt_scale:
|
92 |
+
gt_pts1, gt_pts2, norm_factor = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode,
|
93 |
+
valid1, valid2, ret_factor=True)
|
94 |
+
# apply the same normalization to prediction
|
95 |
+
pr_pts1[~mask] = pr_pts1[~mask] / norm_factor[~mask]
|
96 |
+
pr_pts2[~mask] = pr_pts2[~mask] / norm_factor[~mask]
|
97 |
+
|
98 |
+
# return sky segmentation, making sure they don't include any labelled 3d points
|
99 |
+
sky1 = gt1['sky_mask'] & (~valid1)
|
100 |
+
sky2 = gt2['sky_mask'] & (~valid2)
|
101 |
+
return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, sky1, sky2, {}
|
102 |
+
|
103 |
+
def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
|
104 |
+
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
|
105 |
+
self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)
|
106 |
+
|
107 |
+
if self.sky_loss_value > 0:
|
108 |
+
assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss'
|
109 |
+
# add the sky pixel as "valid" pixels...
|
110 |
+
mask1 = mask1 | sky1
|
111 |
+
mask2 = mask2 | sky2
|
112 |
+
|
113 |
+
# loss on img1 side
|
114 |
+
pred_pts1 = pred_pts1[mask1]
|
115 |
+
gt_pts1 = gt_pts1[mask1]
|
116 |
+
if self.loss_in_log and self.loss_in_log != 'before':
|
117 |
+
# this only make sense when depth_mode == 'exp'
|
118 |
+
pred_pts1 = apply_log_to_norm(pred_pts1)
|
119 |
+
gt_pts1 = apply_log_to_norm(gt_pts1)
|
120 |
+
l1 = self.criterion(pred_pts1, gt_pts1)
|
121 |
+
|
122 |
+
# loss on gt2 side
|
123 |
+
pred_pts2 = pred_pts2[mask2]
|
124 |
+
gt_pts2 = gt_pts2[mask2]
|
125 |
+
if self.loss_in_log and self.loss_in_log != 'before':
|
126 |
+
pred_pts2 = apply_log_to_norm(pred_pts2)
|
127 |
+
gt_pts2 = apply_log_to_norm(gt_pts2)
|
128 |
+
l2 = self.criterion(pred_pts2, gt_pts2)
|
129 |
+
|
130 |
+
if self.sky_loss_value > 0:
|
131 |
+
assert self.criterion.reduction == 'none', 'sky_loss_value should be 0 if no conf loss'
|
132 |
+
# ... but force the loss to be high there
|
133 |
+
l1 = torch.where(sky1[mask1], self.sky_loss_value, l1)
|
134 |
+
l2 = torch.where(sky2[mask2], self.sky_loss_value, l2)
|
135 |
+
self_name = type(self).__name__
|
136 |
+
details = {self_name + '_pts3d_1': float(l1.mean()), self_name + '_pts3d_2': float(l2.mean())}
|
137 |
+
return Sum((l1, mask1), (l2, mask2)), (details | monitoring)
|
138 |
+
|
139 |
+
|
140 |
+
class Regr3D_ShiftInv (Regr3D):
|
141 |
+
""" Same than Regr3D but invariant to depth shift.
|
142 |
+
"""
|
143 |
+
|
144 |
+
def get_all_pts3d(self, gt1, gt2, pred1, pred2):
|
145 |
+
# compute unnormalized points
|
146 |
+
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
|
147 |
+
super().get_all_pts3d(gt1, gt2, pred1, pred2)
|
148 |
+
|
149 |
+
# compute median depth
|
150 |
+
gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]
|
151 |
+
pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]
|
152 |
+
gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None]
|
153 |
+
pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]
|
154 |
+
|
155 |
+
# subtract the median depth
|
156 |
+
gt_z1 -= gt_shift_z
|
157 |
+
gt_z2 -= gt_shift_z
|
158 |
+
pred_z1 -= pred_shift_z
|
159 |
+
pred_z2 -= pred_shift_z
|
160 |
+
|
161 |
+
# monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach())
|
162 |
+
return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring
|
163 |
+
|
164 |
+
|
165 |
+
class Regr3D_ScaleInv (Regr3D):
|
166 |
+
""" Same than Regr3D but invariant to depth scale.
|
167 |
+
if gt_scale == True: enforce the prediction to take the same scale than GT
|
168 |
+
"""
|
169 |
+
|
170 |
+
def get_all_pts3d(self, gt1, gt2, pred1, pred2):
|
171 |
+
# compute depth-normalized points
|
172 |
+
gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring = \
|
173 |
+
super().get_all_pts3d(gt1, gt2, pred1, pred2)
|
174 |
+
|
175 |
+
# measure scene scale
|
176 |
+
_, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
|
177 |
+
_, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)
|
178 |
+
|
179 |
+
# prevent predictions to be in a ridiculous range
|
180 |
+
pred_scale = pred_scale.clip(min=1e-3, max=1e3)
|
181 |
+
|
182 |
+
# subtract the median depth
|
183 |
+
if self.gt_scale:
|
184 |
+
pred_pts1 *= gt_scale / pred_scale
|
185 |
+
pred_pts2 *= gt_scale / pred_scale
|
186 |
+
# monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
|
187 |
+
else:
|
188 |
+
gt_pts1 /= gt_scale
|
189 |
+
gt_pts2 /= gt_scale
|
190 |
+
pred_pts1 /= pred_scale
|
191 |
+
pred_pts2 /= pred_scale
|
192 |
+
# monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())
|
193 |
+
|
194 |
+
return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, sky1, sky2, monitoring
|
195 |
+
|
196 |
+
|
197 |
+
class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):
|
198 |
+
# calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
|
199 |
+
pass
|
200 |
+
|
201 |
+
|
202 |
+
def get_similarities(desc1, desc2, euc=False):
|
203 |
+
if euc: # euclidean distance in same range than similarities
|
204 |
+
dists = (desc1[:, :, None] - desc2[:, None]).norm(dim=-1)
|
205 |
+
sim = 1 / (1 + dists)
|
206 |
+
else:
|
207 |
+
# Compute similarities
|
208 |
+
sim = desc1 @ desc2.transpose(-2, -1)
|
209 |
+
return sim
|
210 |
+
|
211 |
+
|
212 |
+
class MatchingCriterion(BaseCriterion):
|
213 |
+
def __init__(self, reduction='mean', fp=torch.float32):
|
214 |
+
super().__init__(reduction)
|
215 |
+
self.fp = fp
|
216 |
+
|
217 |
+
def forward(self, a, b, valid_matches=None, euc=False):
|
218 |
+
assert a.ndim >= 2 and 1 <= a.shape[-1], f'Bad shape = {a.shape}'
|
219 |
+
dist = self.loss(a.to(self.fp), b.to(self.fp), valid_matches, euc=euc)
|
220 |
+
# one dimension less or reduction to single value
|
221 |
+
assert (valid_matches is None and dist.ndim == a.ndim -
|
222 |
+
1) or self.reduction in ['mean', 'sum', '1-mean', 'none']
|
223 |
+
if self.reduction == 'none':
|
224 |
+
return dist
|
225 |
+
if self.reduction == 'sum':
|
226 |
+
return dist.sum()
|
227 |
+
if self.reduction == 'mean':
|
228 |
+
return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
|
229 |
+
if self.reduction == '1-mean':
|
230 |
+
return 1. - dist.mean() if dist.numel() > 0 else dist.new_ones(())
|
231 |
+
raise ValueError(f'bad {self.reduction=} mode')
|
232 |
+
|
233 |
+
def loss(self, a, b, valid_matches=None):
|
234 |
+
raise NotImplementedError
|
235 |
+
|
236 |
+
|
237 |
+
class InfoNCE(MatchingCriterion):
|
238 |
+
def __init__(self, temperature=0.07, eps=1e-8, mode='all', **kwargs):
|
239 |
+
super().__init__(**kwargs)
|
240 |
+
self.temperature = temperature
|
241 |
+
self.eps = eps
|
242 |
+
assert mode in ['all', 'proper', 'dual']
|
243 |
+
self.mode = mode
|
244 |
+
|
245 |
+
def loss(self, desc1, desc2, valid_matches=None, euc=False):
|
246 |
+
# valid positives are along diagonals
|
247 |
+
B, N, D = desc1.shape
|
248 |
+
B2, N2, D2 = desc2.shape
|
249 |
+
assert B == B2 and D == D2
|
250 |
+
if valid_matches is None:
|
251 |
+
valid_matches = torch.ones([B, N], dtype=bool)
|
252 |
+
# note: some pairs may have no valid matches at all, so we cannot assert torch.all(valid_matches.sum(dim=-1) > 0)
|
253 |
+
assert valid_matches.shape == torch.Size([B, N]) and valid_matches.sum() > 0
|
254 |
+
|
255 |
+
# Tempered similarities
|
256 |
+
sim = get_similarities(desc1, desc2, euc) / self.temperature
|
257 |
+
sim[sim.isnan()] = -torch.inf # ignore nans
|
258 |
+
# Softmax of positives with temperature
|
259 |
+
sim = sim.exp_() # save peak memory
|
260 |
+
positives = sim.diagonal(dim1=-2, dim2=-1)
|
261 |
+
|
262 |
+
# Loss
|
263 |
+
if self.mode == 'all': # Previous InfoNCE
|
264 |
+
loss = -torch.log((positives / sim.sum(dim=-1).sum(dim=-1, keepdim=True)).clip(self.eps))
|
265 |
+
elif self.mode == 'proper': # Proper InfoNCE
|
266 |
+
loss = -(torch.log((positives / sim.sum(dim=-2)).clip(self.eps)) +
|
267 |
+
torch.log((positives / sim.sum(dim=-1)).clip(self.eps)))
|
268 |
+
elif self.mode == 'dual': # Dual Softmax
|
269 |
+
loss = -(torch.log((positives**2 / sim.sum(dim=-1) / sim.sum(dim=-2)).clip(self.eps)))
|
270 |
+
else:
|
271 |
+
raise ValueError("This should not happen...")
|
272 |
+
return loss[valid_matches]
|
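A usage sketch on random data, assuming the reduction behavior inherited from the dust3r BaseCriterion (toy shapes, not part of this commit):

import torch
import torch.nn.functional as F

crit = InfoNCE(temperature=0.07, mode='proper', reduction='mean')
desc1 = F.normalize(torch.randn(2, 64, 24), dim=-1)
desc2 = F.normalize(torch.randn(2, 64, 24), dim=-1)
valid = torch.ones(2, 64, dtype=bool)
print(crit(desc1, desc2, valid_matches=valid))  # scalar loss; positives sit on the diagonal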
273 |
+
|
274 |
+
|
275 |
+
class APLoss (MatchingCriterion):
|
276 |
+
""" AP loss.
|
277 |
+
"""
|
278 |
+
|
279 |
+
def __init__(self, nq='torch', min=0, max=1, euc=False, **kw):
|
280 |
+
super().__init__(**kw)
|
281 |
+
# Exact/True AP loss (not differentiable)
|
282 |
+
if nq == 0:
|
283 |
+
nq = 'sklearn' # special case
|
284 |
+
try:
|
285 |
+
self.compute_AP = eval('self.compute_true_AP_' + nq)
|
286 |
+
except Exception:
|
287 |
+
raise ValueError("Unknown mode %s for AP loss" % nq)
|
288 |
+
|
289 |
+
@staticmethod
|
290 |
+
def compute_true_AP_sklearn(scores, labels):
|
291 |
+
def compute_AP(label, score):
|
292 |
+
return average_precision_score(label, score)
|
293 |
+
|
294 |
+
aps = scores.new_zeros((scores.shape[0], scores.shape[1]))
|
295 |
+
label_np = labels.cpu().numpy().astype(bool)
|
296 |
+
scores_np = scores.cpu().numpy()
|
297 |
+
for bi in range(scores_np.shape[0]):
|
298 |
+
for i in range(scores_np.shape[1]):
|
299 |
+
labels = label_np[bi, i, :]
|
300 |
+
if labels.sum() < 1:
|
301 |
+
continue
|
302 |
+
aps[bi, i] = compute_AP(labels, scores_np[bi, i, :])
|
303 |
+
return aps
|
304 |
+
|
305 |
+
@staticmethod
|
306 |
+
def compute_true_AP_torch(scores, labels):
|
307 |
+
assert scores.shape == labels.shape
|
308 |
+
B, N, M = labels.shape
|
309 |
+
dev = labels.device
|
310 |
+
with torch.no_grad():
|
311 |
+
# sort scores
|
312 |
+
_, order = scores.sort(dim=-1, descending=True)
|
313 |
+
# sort labels accordingly
|
314 |
+
labels = labels[torch.arange(B, device=dev)[:, None, None].expand(order.shape),
|
315 |
+
torch.arange(N, device=dev)[None, :, None].expand(order.shape),
|
316 |
+
order]
|
317 |
+
# compute number of positives per query
|
318 |
+
npos = labels.sum(dim=-1)
|
319 |
+
assert torch.all(torch.isclose(npos, npos[0, 0])
|
320 |
+
), "only implemented for constant number of positives per query"
|
321 |
+
npos = int(npos[0, 0])
|
322 |
+
# compute precision at each recall point
|
323 |
+
posrank = labels.nonzero()[:, -1].view(B, N, npos)
|
324 |
+
recall = torch.arange(1, 1 + npos, dtype=torch.float32, device=dev)[None, None, :].expand(B, N, npos)
|
325 |
+
precision = recall / (1 + posrank).float()
|
326 |
+
# average precision values at all recall points
|
327 |
+
aps = precision.mean(dim=-1)
|
328 |
+
|
329 |
+
return aps
|
330 |
+
|
331 |
+
def loss(self, desc1, desc2, valid_matches=None, euc=False): # if matches is None, positives are the diagonal
|
332 |
+
B, N1, D = desc1.shape
|
333 |
+
B2, N2, D2 = desc2.shape
|
334 |
+
assert B == B2 and D == D2
|
335 |
+
|
336 |
+
scores = get_similarities(desc1, desc2, euc)
|
337 |
+
|
338 |
+
labels = torch.zeros([B, N1, N2], dtype=scores.dtype, device=scores.device)
|
339 |
+
|
340 |
+
# allow all diagonal positives and only mask afterwards
|
341 |
+
labels.diagonal(dim1=-2, dim2=-1)[...] = 1.
|
342 |
+
apscore = self.compute_AP(scores, labels)
|
343 |
+
if valid_matches is not None:
|
344 |
+
apscore = apscore[valid_matches]
|
345 |
+
return apscore
|
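A tiny worked example of the rank-based AP (one query, one positive ranked 2nd out of 4 candidates, hence AP = 0.5):

import torch

scores = torch.tensor([[[0.9, 0.8, 0.1, 0.0]]])   # (B=1, N=1, M=4) similarity scores
labels = torch.tensor([[[0., 1., 0., 0.]]])       # the positive has the 2nd highest score
print(APLoss.compute_true_AP_torch(scores, labels))  # tensor([[0.5000]])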
346 |
+
|
347 |
+
|
348 |
+
class MatchingLoss (Criterion, MultiLoss):
|
349 |
+
"""
|
350 |
+
Matching loss per image
|
351 |
+
only compares pixels within the same image, not across the whole batch as would usually be done
|
352 |
+
"""
|
353 |
+
|
354 |
+
def __init__(self, criterion, withconf=False, use_pts3d=False, negatives_padding=0, blocksize=4096):
|
355 |
+
super().__init__(criterion)
|
356 |
+
self.negatives_padding = negatives_padding
|
357 |
+
self.use_pts3d = use_pts3d
|
358 |
+
self.blocksize = blocksize
|
359 |
+
self.withconf = withconf
|
360 |
+
|
361 |
+
def add_negatives(self, outdesc2, desc2, batchid, x2, y2):
|
362 |
+
if self.negatives_padding:
|
363 |
+
B, H, W, D = desc2.shape
|
364 |
+
negatives = torch.ones([B, H, W], device=desc2.device, dtype=bool)
|
365 |
+
negatives[batchid, y2, x2] = False
|
366 |
+
sel = negatives & (negatives.view([B, -1]).cumsum(dim=-1).view(B, H, W)
|
367 |
+
<= self.negatives_padding) # take the N-first negatives
|
368 |
+
outdesc2 = torch.cat([outdesc2, desc2[sel].view([B, -1, D])], dim=1)
|
369 |
+
return outdesc2
|
370 |
+
|
371 |
+
def get_confs(self, pred1, pred2, sel1, sel2):
|
372 |
+
if self.withconf:
|
373 |
+
if self.use_pts3d:
|
374 |
+
outconfs1 = pred1['conf'][sel1]
|
375 |
+
outconfs2 = pred2['conf'][sel2]
|
376 |
+
else:
|
377 |
+
outconfs1 = pred1['desc_conf'][sel1]
|
378 |
+
outconfs2 = pred2['desc_conf'][sel2]
|
379 |
+
else:
|
380 |
+
outconfs1 = outconfs2 = None
|
381 |
+
return outconfs1, outconfs2
|
382 |
+
|
383 |
+
def get_descs(self, pred1, pred2):
|
384 |
+
if self.use_pts3d:
|
385 |
+
desc1, desc2 = pred1['pts3d'], pred2['pts3d_in_other_view']
|
386 |
+
else:
|
387 |
+
desc1, desc2 = pred1['desc'], pred2['desc']
|
388 |
+
return desc1, desc2
|
389 |
+
|
390 |
+
def get_matching_descs(self, gt1, gt2, pred1, pred2, **kw):
|
391 |
+
outdesc1 = outdesc2 = outconfs1 = outconfs2 = None
|
392 |
+
# Recover descs, GT corres and valid mask
|
393 |
+
desc1, desc2 = self.get_descs(pred1, pred2)
|
394 |
+
|
395 |
+
(x1, y1), (x2, y2) = gt1['corres'].unbind(-1), gt2['corres'].unbind(-1)
|
396 |
+
valid_matches = gt1['valid_corres']
|
397 |
+
|
398 |
+
# Select descs that have GT matches
|
399 |
+
B, N = x1.shape
|
400 |
+
batchid = torch.arange(B)[:, None].repeat(1, N) # B, N
|
401 |
+
outdesc1, outdesc2 = desc1[batchid, y1, x1], desc2[batchid, y2, x2] # B, N, D
|
402 |
+
|
403 |
+
# Pad with unused negatives
|
404 |
+
outdesc2 = self.add_negatives(outdesc2, desc2, batchid, x2, y2)
|
405 |
+
|
406 |
+
# Gather confs if needed
|
407 |
+
sel1 = batchid, y1, x1
|
408 |
+
sel2 = batchid, y2, x2
|
409 |
+
outconfs1, outconfs2 = self.get_confs(pred1, pred2, sel1, sel2)
|
410 |
+
|
411 |
+
return outdesc1, outdesc2, outconfs1, outconfs2, valid_matches, {'use_euclidean_dist': self.use_pts3d}
|
412 |
+
|
413 |
+
def blockwise_criterion(self, descs1, descs2, confs1, confs2, valid_matches, euc, rng=np.random, shuffle=True):
|
414 |
+
loss = None
|
415 |
+
details = {}
|
416 |
+
B, N, D = descs1.shape
|
417 |
+
|
418 |
+
if N <= self.blocksize: # Blocks are larger than provided descs, compute regular loss
|
419 |
+
loss = self.criterion(descs1, descs2, valid_matches, euc=euc)
|
420 |
+
else: # Compute criterion on the blockdiagonal only, after shuffling
|
421 |
+
# Shuffle if necessary
|
422 |
+
matches_perm = slice(None)
|
423 |
+
if shuffle:
|
424 |
+
matches_perm = np.stack([rng.choice(range(N), size=N, replace=False) for _ in range(B)])
|
425 |
+
batchid = torch.tile(torch.arange(B), (N, 1)).T
|
426 |
+
matches_perm = batchid, matches_perm
|
427 |
+
|
428 |
+
descs1 = descs1[matches_perm]
|
429 |
+
descs2 = descs2[matches_perm]
|
430 |
+
valid_matches = valid_matches[matches_perm]
|
431 |
+
|
432 |
+
assert N % self.blocksize == 0, "Error, can't chunk block-diagonal, please check blocksize"
|
433 |
+
n_chunks = N // self.blocksize
|
434 |
+
descs1 = descs1.reshape([B * n_chunks, self.blocksize, D]) # [B*(N//blocksize), blocksize, D]
|
435 |
+
descs2 = descs2.reshape([B * n_chunks, self.blocksize, D]) # [B*(N//blocksize), blocksize, D]
|
436 |
+
valid_matches = valid_matches.view([B * n_chunks, self.blocksize])
|
437 |
+
loss = self.criterion(descs1, descs2, valid_matches, euc=euc)
|
438 |
+
if self.withconf:
|
439 |
+
confs1, confs2 = map(lambda x: x[matches_perm], (confs1, confs2)) # apply perm to confidences if needed
|
440 |
+
|
441 |
+
if self.withconf:
|
442 |
+
# split confidences between positives/negatives for loss computation
|
443 |
+
details['conf_pos'] = map(lambda x: x[valid_matches.view(B, -1)], (confs1, confs2))
|
444 |
+
details['conf_neg'] = map(lambda x: x[~valid_matches.view(B, -1)], (confs1, confs2))
|
445 |
+
details['Conf1_std'] = confs1.std()
|
446 |
+
details['Conf2_std'] = confs2.std()
|
447 |
+
|
448 |
+
return loss, details
|
449 |
+
|
450 |
+
def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
|
451 |
+
# Gather preds and GT
|
452 |
+
descs1, descs2, confs1, confs2, valid_matches, monitoring = self.get_matching_descs(
|
453 |
+
gt1, gt2, pred1, pred2, **kw)
|
454 |
+
|
455 |
+
# loss on matches
|
456 |
+
loss, details = self.blockwise_criterion(descs1, descs2, confs1, confs2,
|
457 |
+
valid_matches, euc=monitoring.pop('use_euclidean_dist', False))
|
458 |
+
|
459 |
+
details[type(self).__name__] = float(loss.mean())
|
460 |
+
return loss, (details | monitoring)
|
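The block-diagonal trick above trades one large N x N similarity matrix for N // blocksize smaller problems; a shape-only sketch of the reshaping, assuming N is a multiple of blocksize:

import torch

B, N, D, blocksize = 2, 8192, 24, 4096
descs = torch.randn(B, N, D)
blocks = descs.reshape(B * (N // blocksize), blocksize, D)
print(blocks.shape)  # torch.Size([4, 4096, 24]): each chunk is only scored against itself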
461 |
+
|
462 |
+
|
463 |
+
class ConfMatchingLoss(ConfLoss):
|
464 |
+
""" Weight matching by learned confidence. Same as ConfLoss but for a matching criterion
|
465 |
+
Assuming the input matching_loss is a match-level loss.
|
466 |
+
"""
|
467 |
+
|
468 |
+
def __init__(self, pixel_loss, alpha=1., confmode='prod', neg_conf_loss_quantile=False):
|
469 |
+
super().__init__(pixel_loss, alpha)
|
470 |
+
self.pixel_loss.withconf = True
|
471 |
+
self.confmode = confmode
|
472 |
+
self.neg_conf_loss_quantile = neg_conf_loss_quantile
|
473 |
+
|
474 |
+
def aggregate_confs(self, confs1, confs2): # get the confidences resulting from the two view predictions
|
475 |
+
if self.confmode == 'prod':
|
476 |
+
confs = confs1 * confs2 if confs1 is not None and confs2 is not None else 1.
|
477 |
+
elif self.confmode == 'mean':
|
478 |
+
confs = .5 * (confs1 + confs2) if confs1 is not None and confs2 is not None else 1.
|
479 |
+
else:
|
480 |
+
raise ValueError(f"Unknown conf mode {self.confmode}")
|
481 |
+
return confs
|
482 |
+
|
483 |
+
def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
|
484 |
+
# compute per-pixel loss
|
485 |
+
loss, details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
|
486 |
+
# Recover confidences for positive and negative samples
|
487 |
+
conf1_pos, conf2_pos = details.pop('conf_pos')
|
488 |
+
conf1_neg, conf2_neg = details.pop('conf_neg')
|
489 |
+
conf_pos = self.aggregate_confs(conf1_pos, conf2_pos)
|
490 |
+
|
491 |
+
# weight Matching loss by confidence on positives
|
492 |
+
conf_pos, log_conf_pos = self.get_conf_log(conf_pos)
|
493 |
+
conf_loss = loss * conf_pos - self.alpha * log_conf_pos
|
494 |
+
# average + nan protection (in case of no valid pixels at all)
|
495 |
+
conf_loss = conf_loss.mean() if conf_loss.numel() > 0 else 0
|
496 |
+
# Add a loss on negative confidences, to give some supervision signal to confidences of pixels that have no GT match
|
497 |
+
if self.neg_conf_loss_quantile:
|
498 |
+
conf_neg = torch.cat([conf1_neg, conf2_neg])
|
499 |
+
conf_neg, log_conf_neg = self.get_conf_log(conf_neg)
|
500 |
+
|
501 |
+
# quantile of the positive losses, used as the loss value assigned to negatives
|
502 |
+
neg_loss_value = torch.quantile(loss, self.neg_conf_loss_quantile).detach()
|
503 |
+
neg_loss = neg_loss_value * conf_neg - self.alpha * log_conf_neg
|
504 |
+
|
505 |
+
neg_loss = neg_loss.mean() if neg_loss.numel() > 0 else 0
|
506 |
+
conf_loss = conf_loss + neg_loss
|
507 |
+
|
508 |
+
return conf_loss, dict(matching_conf_loss=float(conf_loss), **details)
|
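How these pieces might fit together; a hypothetical composition sketch (the alpha value and the InfoNCE mode are assumptions, the actual training configuration is not part of this commit):

matching_criterion = MatchingLoss(InfoNCE(mode='proper'), withconf=True)
matching_loss = ConfMatchingLoss(matching_criterion, alpha=0.2, confmode='prod')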
mast3r/model.py
ADDED
@@ -0,0 +1,68 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# MASt3R model class
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import os
|
10 |
+
|
11 |
+
from mast3r.catmlp_dpt_head import mast3r_head_factory
|
12 |
+
|
13 |
+
import mast3r.utils.path_to_dust3r # noqa
|
14 |
+
from dust3r.model import AsymmetricCroCo3DStereo # noqa
|
15 |
+
from dust3r.utils.misc import transpose_to_landscape # noqa
|
16 |
+
|
17 |
+
|
18 |
+
inf = float('inf')
|
19 |
+
|
20 |
+
|
21 |
+
def load_model(model_path, device, verbose=True):
|
22 |
+
if verbose:
|
23 |
+
print('... loading model from', model_path)
|
24 |
+
ckpt = torch.load(model_path, map_location='cpu')
|
25 |
+
args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
|
26 |
+
if 'landscape_only' not in args:
|
27 |
+
args = args[:-1] + ', landscape_only=False)'
|
28 |
+
else:
|
29 |
+
args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
|
30 |
+
assert "landscape_only=False" in args
|
31 |
+
if verbose:
|
32 |
+
print(f"instantiating : {args}")
|
33 |
+
net = eval(args)
|
34 |
+
s = net.load_state_dict(ckpt['model'], strict=False)
|
35 |
+
if verbose:
|
36 |
+
print(s)
|
37 |
+
return net.to(device)
|
38 |
+
|
39 |
+
|
40 |
+
class AsymmetricMASt3R(AsymmetricCroCo3DStereo):
|
41 |
+
def __init__(self, desc_mode=('norm'), two_confs=False, desc_conf_mode=None, **kwargs):
|
42 |
+
self.desc_mode = desc_mode
|
43 |
+
self.two_confs = two_confs
|
44 |
+
self.desc_conf_mode = desc_conf_mode
|
45 |
+
super().__init__(**kwargs)
|
46 |
+
|
47 |
+
@classmethod
|
48 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kw):
|
49 |
+
if os.path.isfile(pretrained_model_name_or_path):
|
50 |
+
return load_model(pretrained_model_name_or_path, device='cpu')
|
51 |
+
else:
|
52 |
+
return super(AsymmetricMASt3R, cls).from_pretrained(pretrained_model_name_or_path, **kw)
|
53 |
+
|
54 |
+
def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size, **kw):
|
55 |
+
assert img_size[0] % patch_size == 0 and img_size[
|
56 |
+
1] % patch_size == 0, f'{img_size=} must be multiple of {patch_size=}'
|
57 |
+
self.output_mode = output_mode
|
58 |
+
self.head_type = head_type
|
59 |
+
self.depth_mode = depth_mode
|
60 |
+
self.conf_mode = conf_mode
|
61 |
+
if self.desc_conf_mode is None:
|
62 |
+
self.desc_conf_mode = conf_mode
|
63 |
+
# allocate heads
|
64 |
+
self.downstream_head1 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
|
65 |
+
self.downstream_head2 = mast3r_head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
|
66 |
+
# magic wrapper
|
67 |
+
self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
|
68 |
+
self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
|
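A loading sketch: from_pretrained accepts either a local checkpoint file (handled by load_model above) or a Hugging Face model id; the id below is the published MASt3R checkpoint name, quoted here as an assumption:

from mast3r.model import AsymmetricMASt3R

model = AsymmetricMASt3R.from_pretrained(
    "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").to('cuda')
# or, from a local file:
# model = AsymmetricMASt3R.from_pretrained('path/to/checkpoint.pth')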
mast3r/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
mast3r/utils/coarse_to_fine.py
ADDED
@@ -0,0 +1,214 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# coarse to fine utilities
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
|
10 |
+
def crop_tag(cell):
|
11 |
+
return f'[{cell[1]}:{cell[3]},{cell[0]}:{cell[2]}]'
|
12 |
+
|
13 |
+
|
14 |
+
def crop_slice(cell):
|
15 |
+
return slice(cell[1], cell[3]), slice(cell[0], cell[2])
|
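A quick check of the cell convention used throughout this file: a cell is (x0, y0, x1, y1), and crop_slice turns it into (rows, cols) slices:

cell = (10, 20, 110, 220)
print(crop_tag(cell))    # [20:220,10:110]
print(crop_slice(cell))  # (slice(20, 220, None), slice(10, 110, None))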
16 |
+
|
17 |
+
|
18 |
+
def _start_pos(total_size, win_size, overlap):
|
19 |
+
# we must have AT LEAST overlap between segments
|
20 |
+
# first segment starts at 0, last segment starts at total_size-win_size
|
21 |
+
assert 0 <= overlap < 1
|
22 |
+
assert total_size >= win_size
|
23 |
+
spacing = win_size * (1 - overlap)
|
24 |
+
last_pt = total_size - win_size
|
25 |
+
n_windows = 2 + int((last_pt - 1) // spacing)
|
26 |
+
return np.linspace(0, last_pt, n_windows).round().astype(int)
|
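A worked example of the window placement: with total_size=100, win_size=40 and overlap=0.5 the spacing is 20 pixels, and the last window is placed so it ends exactly at the border:

print(_start_pos(100, 40, 0.5))  # [ 0 20 40 60]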
27 |
+
|
28 |
+
|
29 |
+
def multiple_of_16(x):
|
30 |
+
return (x // 16) * 16
|
31 |
+
|
32 |
+
|
33 |
+
def _make_overlapping_grid(H, W, size, overlap):
|
34 |
+
H_win = multiple_of_16(H * size // max(H, W))
|
35 |
+
W_win = multiple_of_16(W * size // max(H, W))
|
36 |
+
x = _start_pos(W, W_win, overlap)
|
37 |
+
y = _start_pos(H, H_win, overlap)
|
38 |
+
grid = np.stack(np.meshgrid(x, y, indexing='xy'), axis=-1)
|
39 |
+
grid = np.concatenate((grid, grid + (W_win, H_win)), axis=-1)
|
40 |
+
return grid.reshape(-1, 4)
|
41 |
+
|
42 |
+
|
43 |
+
def _cell_size(cell2):
|
44 |
+
width, height = cell2[:, 2] - cell2[:, 0], cell2[:, 3] - cell2[:, 1]
|
45 |
+
assert width.min() >= 0
|
46 |
+
assert height.min() >= 0
|
47 |
+
return width, height
|
48 |
+
|
49 |
+
|
50 |
+
def _norm_windows(cell2, H2, W2, forced_resolution=None):
|
51 |
+
# make sure the window aspect ratio is 3/4, or the output resolution is forced_resolution if defined
|
52 |
+
outcell = cell2.copy()
|
53 |
+
width, height = _cell_size(cell2)
|
54 |
+
width2, height2 = width.clip(max=W2), height.clip(max=H2)
|
55 |
+
if forced_resolution is None:
|
56 |
+
width2[width < height] = (height2[width < height] * 3.01 / 4).clip(max=W2)
|
57 |
+
height2[width >= height] = (width2[width >= height] * 3.01 / 4).clip(max=H2)
|
58 |
+
else:
|
59 |
+
forced_H, forced_W = forced_resolution
|
60 |
+
width2[:] = forced_W
|
61 |
+
height2[:] = forced_H
|
62 |
+
|
63 |
+
half = (width2 - width) / 2
|
64 |
+
outcell[:, 0] -= half
|
65 |
+
outcell[:, 2] += half
|
66 |
+
half = (height2 - height) / 2
|
67 |
+
outcell[:, 1] -= half
|
68 |
+
outcell[:, 3] += half
|
69 |
+
|
70 |
+
# project to integers
|
71 |
+
outcell = np.floor(outcell).astype(int)
|
72 |
+
# Take care of flooring errors
|
73 |
+
tmpw, tmph = _cell_size(outcell)
|
74 |
+
outcell[:, 0] += tmpw.astype(tmpw.dtype) - width2.astype(tmpw.dtype)
|
75 |
+
outcell[:, 1] += tmph.astype(tmpw.dtype) - height2.astype(tmpw.dtype)
|
76 |
+
|
77 |
+
# make sure 0 <= x < W2 and 0 <= y < H2
|
78 |
+
outcell[:, 0::2] -= outcell[:, [0]].clip(max=0)
|
79 |
+
outcell[:, 1::2] -= outcell[:, [1]].clip(max=0)
|
80 |
+
outcell[:, 0::2] -= outcell[:, [2]].clip(min=W2) - W2
|
81 |
+
outcell[:, 1::2] -= outcell[:, [3]].clip(min=H2) - H2
|
82 |
+
|
83 |
+
width, height = _cell_size(outcell)
|
84 |
+
assert np.all(width == width2.astype(width.dtype)) and np.all(
|
85 |
+
height == height2.astype(height.dtype)), "Error, output is not of the expected shape."
|
86 |
+
assert np.all(width <= W2)
|
87 |
+
assert np.all(height <= H2)
|
88 |
+
return outcell
|
89 |
+
|
90 |
+
|
91 |
+
def _weight_pixels(cell, pix, assigned, gauss_var=2):
|
92 |
+
center = cell.reshape(-1, 2, 2).mean(axis=1)
|
93 |
+
width, height = _cell_size(cell)
|
94 |
+
|
95 |
+
# square distance between each cell center and each point
|
96 |
+
dist = (center[:, None] - pix[None]) / np.c_[width, height][:, None]
|
97 |
+
dist2 = np.square(dist).sum(axis=-1)
|
98 |
+
|
99 |
+
assert assigned.shape == dist2.shape
|
100 |
+
res = np.where(assigned, np.exp(-gauss_var * dist2), 0)
|
101 |
+
return res
|
102 |
+
|
103 |
+
|
104 |
+
def pos2d_in_rect(p1, cell1):
|
105 |
+
x, y = p1.T
|
106 |
+
l, t, r, b = cell1
|
107 |
+
assigned = (l <= x) & (x < r) & (t <= y) & (y < b)
|
108 |
+
return assigned
|
109 |
+
|
110 |
+
|
111 |
+
def _score_cell(cell1, H2, W2, p1, p2, min_corres=10, forced_resolution=None):
|
112 |
+
assert p1.shape == p2.shape
|
113 |
+
|
114 |
+
# compute keypoint assignment
|
115 |
+
assigned = pos2d_in_rect(p1, cell1[None].T)
|
116 |
+
assert assigned.shape == (len(cell1), len(p1))
|
117 |
+
|
118 |
+
# remove cells without correspondences
|
119 |
+
valid_cells = assigned.sum(axis=1) >= min_corres
|
120 |
+
cell1 = cell1[valid_cells]
|
121 |
+
assigned = assigned[valid_cells]
|
122 |
+
if not valid_cells.any():
|
123 |
+
return cell1, cell1, assigned
|
124 |
+
|
125 |
+
# fill-in the assigned points in both image
|
126 |
+
assigned_p1 = np.empty((len(cell1), len(p1), 2), dtype=np.float32)
|
127 |
+
assigned_p2 = np.empty((len(cell1), len(p2), 2), dtype=np.float32)
|
128 |
+
assigned_p1[:] = p1[None]
|
129 |
+
assigned_p2[:] = p2[None]
|
130 |
+
assigned_p1[~assigned] = np.nan
|
131 |
+
assigned_p2[~assigned] = np.nan
|
132 |
+
|
133 |
+
# find the median center and scale of assigned points in each cell
|
134 |
+
# cell_center1 = np.nanmean(assigned_p1, axis=1)
|
135 |
+
cell_center2 = np.nanmean(assigned_p2, axis=1)
|
136 |
+
im1_q25, im1_q75 = np.nanquantile(assigned_p1, (0.1, 0.9), axis=1)
|
137 |
+
im2_q25, im2_q75 = np.nanquantile(assigned_p2, (0.1, 0.9), axis=1)
|
138 |
+
|
139 |
+
robust_std1 = (im1_q75 - im1_q25).clip(20.)
|
140 |
+
robust_std2 = (im2_q75 - im2_q25).clip(20.)
|
141 |
+
|
142 |
+
cell_size1 = (cell1[:, 2:4] - cell1[:, 0:2])
|
143 |
+
cell_size2 = cell_size1 * robust_std2 / robust_std1
|
144 |
+
cell2 = np.c_[cell_center2 - cell_size2 / 2, cell_center2 + cell_size2 / 2]
|
145 |
+
|
146 |
+
# make sure cell bounds are valid
|
147 |
+
cell2 = _norm_windows(cell2, H2, W2, forced_resolution=forced_resolution)
|
148 |
+
|
149 |
+
# compute correspondence weights
|
150 |
+
corres_weights = _weight_pixels(cell1, p1, assigned) * _weight_pixels(cell2, p2, assigned)
|
151 |
+
|
152 |
+
# return a list of window pairs and assigned correspondences
|
153 |
+
return cell1, cell2, corres_weights
|
154 |
+
|
155 |
+
|
156 |
+
def greedy_selection(corres_weights, target=0.9):
|
157 |
+
# corres_weights = (n_cell_pair, n_corres) matrix.
|
158 |
+
# corres_weights[c, p] > 0 means that correspondence p is visible in cell pair c
|
159 |
+
assert 0 < target <= 1
|
160 |
+
corres_weights = corres_weights.copy()
|
161 |
+
|
162 |
+
total = corres_weights.max(axis=0).sum()
|
163 |
+
target *= total
|
164 |
+
|
165 |
+
# init = empty
|
166 |
+
res = []
|
167 |
+
cur = np.zeros(corres_weights.shape[1]) # current selection
|
168 |
+
|
169 |
+
while cur.sum() < target:
|
170 |
+
# pick the next best cell pair
|
171 |
+
best = corres_weights.sum(axis=1).argmax()
|
172 |
+
res.append(best)
|
173 |
+
|
174 |
+
# update current
|
175 |
+
cur += corres_weights[best]
|
176 |
+
# print('appending', best, 'with score', corres_weights[best].sum(), '-->', cur.sum())
|
177 |
+
|
178 |
+
# remove from all other views
|
179 |
+
corres_weights = (corres_weights - corres_weights[best]).clip(min=0)
|
180 |
+
|
181 |
+
return res
|
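A toy run of the greedy cover with hand-made weights: two cell pairs covering disjoint correspondences are enough to reach the 90% target, and the redundant third pair is never picked:

import numpy as np

w = np.array([[1., 1., 0., 0.],
              [0., 0., 1., 1.],
              [1., 0., 0., 0.]])
print(greedy_selection(w, target=0.9))  # indices of the selected cell pairs, here [0, 1]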
182 |
+
|
183 |
+
|
184 |
+
def select_pairs_of_crops(img_q, img_b, pos2d_in_query, pos2d_in_ref, maxdim=512, overlap=.5, forced_resolution=None):
|
185 |
+
# prepare the overlapping cells
|
186 |
+
grid_q = _make_overlapping_grid(*img_q.shape[:2], maxdim, overlap)
|
187 |
+
grid_b = _make_overlapping_grid(*img_b.shape[:2], maxdim, overlap)
|
188 |
+
|
189 |
+
assert forced_resolution is None or len(forced_resolution) == 2
|
190 |
+
if forced_resolution is None or isinstance(forced_resolution[0], int) or not len(forced_resolution[0]) == 2:
|
191 |
+
forced_resolution1 = forced_resolution2 = forced_resolution
|
192 |
+
else:
|
193 |
+
assert len(forced_resolution[1]) == 2
|
194 |
+
forced_resolution1 = forced_resolution[0]
|
195 |
+
forced_resolution2 = forced_resolution[1]
|
196 |
+
|
197 |
+
# Make sure crops respect constraints
|
198 |
+
grid_q = _norm_windows(grid_q.astype(float), *img_q.shape[:2], forced_resolution=forced_resolution1)
|
199 |
+
grid_b = _norm_windows(grid_b.astype(float), *img_b.shape[:2], forced_resolution=forced_resolution2)
|
200 |
+
|
201 |
+
# score cells
|
202 |
+
pairs_q = _score_cell(grid_q, *img_b.shape[:2], pos2d_in_query, pos2d_in_ref, forced_resolution=forced_resolution2)
|
203 |
+
pairs_b = _score_cell(grid_b, *img_q.shape[:2], pos2d_in_ref, pos2d_in_query, forced_resolution=forced_resolution1)
|
204 |
+
pairs_b = pairs_b[1], pairs_b[0], pairs_b[2] # cellq, cellb, corres_weights
|
205 |
+
|
206 |
+
# greedy selection until all correspondences are generated
|
207 |
+
cell1, cell2, corres_weights = map(np.concatenate, zip(pairs_q, pairs_b))
|
208 |
+
if len(corres_weights) == 0:
|
209 |
+
return # tolerated for empty generators
|
210 |
+
order = greedy_selection(corres_weights, target=0.9)
|
211 |
+
|
212 |
+
for i in order:
|
213 |
+
def pair_tag(qi, bi): return (str(qi) + crop_tag(cell1[i]), str(bi) + crop_tag(cell2[i]))
|
214 |
+
yield cell1[i], cell2[i], pair_tag
|
mast3r/utils/collate.py
ADDED
@@ -0,0 +1,62 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# Collate extensions
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import collections
|
10 |
+
from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format
|
11 |
+
from typing import Callable, Dict, Optional, Tuple, Type, Union, List
|
12 |
+
|
13 |
+
|
14 |
+
def cat_collate_tensor_fn(batch, *, collate_fn_map):
|
15 |
+
return torch.cat(batch, dim=0)
|
16 |
+
|
17 |
+
|
18 |
+
def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
|
19 |
+
return [item for bb in batch for item in bb] # concatenate all lists
|
20 |
+
|
21 |
+
|
22 |
+
cat_collate_fn_map = default_collate_fn_map.copy()
|
23 |
+
cat_collate_fn_map[torch.Tensor] = cat_collate_tensor_fn
|
24 |
+
cat_collate_fn_map[List] = cat_collate_list_fn
|
25 |
+
cat_collate_fn_map[type(None)] = lambda _, **kw: None  # when the batch contains Nones, simply return a single None
|
26 |
+
|
27 |
+
|
28 |
+
def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
|
29 |
+
r"""Custom collate function that concatenates stuff instead of stacking them, and handles NoneTypes """
|
30 |
+
elem = batch[0]
|
31 |
+
elem_type = type(elem)
|
32 |
+
|
33 |
+
if collate_fn_map is not None:
|
34 |
+
if elem_type in collate_fn_map:
|
35 |
+
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
|
36 |
+
|
37 |
+
for collate_type in collate_fn_map:
|
38 |
+
if isinstance(elem, collate_type):
|
39 |
+
return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
|
40 |
+
|
41 |
+
if isinstance(elem, collections.abc.Mapping):
|
42 |
+
try:
|
43 |
+
return elem_type({key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
|
44 |
+
except TypeError:
|
45 |
+
# The mapping type may not support `__init__(iterable)`.
|
46 |
+
return {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
|
47 |
+
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
|
48 |
+
return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
|
49 |
+
elif isinstance(elem, collections.abc.Sequence):
|
50 |
+
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
|
51 |
+
|
52 |
+
if isinstance(elem, tuple):
|
53 |
+
# Backwards compatibility.
|
54 |
+
return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
|
55 |
+
else:
|
56 |
+
try:
|
57 |
+
return elem_type([cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
|
58 |
+
except TypeError:
|
59 |
+
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
|
60 |
+
return [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
|
61 |
+
|
62 |
+
raise TypeError(default_collate_err_msg_format.format(elem_type))
|
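A behavior sketch with toy data: unlike the default collate, tensors are concatenated along dim 0 (no new batch dimension) and lists are flattened:

import torch

batch = [{'pts': torch.zeros(3, 2), 'names': ['a']},
         {'pts': torch.ones(5, 2), 'names': ['b', 'c']}]
out = cat_collate(batch, collate_fn_map=cat_collate_fn_map)
print(out['pts'].shape)  # torch.Size([8, 2])
print(out['names'])      # ['a', 'b', 'c']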
mast3r/utils/misc.py
ADDED
@@ -0,0 +1,17 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# utility functions for MASt3R
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import os
|
8 |
+
import hashlib
|
9 |
+
|
10 |
+
|
11 |
+
def mkdir_for(f):
|
12 |
+
os.makedirs(os.path.dirname(f), exist_ok=True)
|
13 |
+
return f
|
14 |
+
|
15 |
+
|
16 |
+
def hash_md5(s):
|
17 |
+
return hashlib.md5(s.encode('utf-8')).hexdigest()
|
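Both helpers are small conveniences for building cache paths and keys; a quick example (paths are placeholders):

cache_file = mkdir_for('/tmp/mast3r_cache/corres/pair_000.npy')  # ensures the parent dir exists
print(hash_md5('img_0001.jpg-img_0002.jpg'))  # deterministic 32-character hex key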
mast3r/utils/path_to_dust3r.py
ADDED
@@ -0,0 +1,19 @@
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# dust3r submodule import
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import sys
|
9 |
+
import os.path as path
|
10 |
+
HERE_PATH = path.normpath(path.dirname(__file__))
|
11 |
+
DUSt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../'))
|
12 |
+
DUSt3R_LIB_PATH = path.join(DUSt3R_REPO_PATH, 'dust3r')
|
13 |
+
# check the presence of the dust3r directory in the repo, to be sure the submodule is cloned
|
14 |
+
if path.isdir(DUSt3R_LIB_PATH):
|
15 |
+
# workaround for sibling import
|
16 |
+
sys.path.insert(0, DUSt3R_REPO_PATH)
|
17 |
+
else:
|
18 |
+
raise ImportError(f"dust3r is not initialized, could not find: {DUSt3R_LIB_PATH}.\n "
|
19 |
+
"Did you forget to run 'git submodule update --init --recursive' ?")
|