Spaces:
Sleeping
Sleeping
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# MASt3R to colmap export functions
# --------------------------------------------------------
import os | |
import torch | |
import copy | |
import numpy as np | |
import torchvision | |
import numpy as np | |
from tqdm import tqdm | |
from scipy.cluster.hierarchy import DisjointSet | |
from scipy.spatial.transform import Rotation as R | |
from mast3r.utils.misc import hash_md5 | |
from mast3r.fast_nn import extract_correspondences_nonsym, bruteforce_reciprocal_nns | |
import mast3r.utils.path_to_dust3r # noqa | |
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid, geotrf # noqa | |
def convert_im_matches_pairs(img0, img1, image_to_colmap, im_keypoints, matches_im0, matches_im1, viz):
    """Turn pixel matches between two views into COLMAP-ready keypoint-id pairs.

    Each match coordinate is rounded to the nearest pixel, clipped to the
    image bounds given by ``true_shape`` and flattened to ``x + W * y``.

    Args:
        img0, img1: per-view dicts; ``'idx'`` and ``'true_shape'`` are read,
            plus ``'img'`` when ``viz`` is set.
        image_to_colmap: maps a view index to a dict holding ``'colmap_imid'``.
        im_keypoints: ``{view_idx: {flat_keypoint_id: count}}``; updated in
            place with one increment per match that lands on the keypoint.
        matches_im0, matches_im1: arrays of (x, y) coordinates, one row per
            match (row ``k`` of each array forms a match).
        viz: if truthy, show up to 100 matches in a blocking matplotlib window.

    Returns:
        ``(imidx0, imidx1, colmap_matches)``: the two view indices ordered so
        that the first has the smaller COLMAP image id, and the de-duplicated
        rows of flattened keypoint ids in that same order.
    """
    if viz:
        from matplotlib import pyplot as pl

        # presumably undoes a (mean=0.5, std=0.5) input normalisation -- confirm
        norm_mean = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
        norm_std = torch.as_tensor([0.5, 0.5, 0.5], device='cpu').reshape(1, 3, 1, 1)
        panels = []
        for view in (img0, img1):
            restored = view['img'] * norm_std + norm_mean
            pil_image = torchvision.transforms.functional.to_pil_image(restored[0])
            panels.append(np.array(pil_image))
        left, right = panels

        # visualize a few matches, evenly spread over the match list
        n_viz = 100
        last_match = matches_im0.shape[0] - 1
        picked = np.round(np.linspace(0, last_match, n_viz)).astype(int)
        viz_xy0 = matches_im0[picked]
        viz_xy1 = matches_im1[picked]

        (H0, W0), (H1, W1) = left.shape[:2], right.shape[:2]
        no_pad = (0, 0)
        # pad the shorter image at the bottom so both can sit side by side
        left = np.pad(left, ((0, max(H1 - H0, 0)), no_pad, no_pad), 'constant', constant_values=0)
        right = np.pad(right, ((0, max(H0 - H1, 0)), no_pad, no_pad), 'constant', constant_values=0)
        canvas = np.concatenate((left, right), axis=1)
        pl.figure()
        pl.imshow(canvas)
        cmap = pl.get_cmap('jet')
        for ii in range(n_viz):
            x0, y0 = viz_xy0[ii].T
            x1, y1 = viz_xy1[ii].T
            # second image is drawn to the right, hence the W0 offset
            pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(ii / (n_viz - 1)),
                    scalex=False, scaley=False)
        pl.show(block=True)

    xy_per_view = [matches_im0.astype(np.float64), matches_im1.astype(np.float64)]
    views = [img0, img1]
    imidx0 = img0['idx']
    imidx1 = img1['idx']

    ravel_matches = []
    for view, xy in zip(views, xy_per_view):
        H, W = view['true_shape'][0]
        # non-finite coordinates would warn on the integer cast; silence that
        with np.errstate(invalid='ignore'):
            qx, qy = xy.round().astype(np.int32).T
        qx.clip(min=0, max=W - 1, out=qx)
        qy.clip(min=0, max=H - 1, out=qy)
        flat_ids = qx + W * qy
        ravel_matches.append(flat_ids)

        view_idx = view['idx']
        for flat_id in flat_ids:
            counts = im_keypoints[view_idx]
            counts[flat_id] = counts.get(flat_id, 0) + 1

    imid0 = copy.deepcopy(image_to_colmap[imidx0]['colmap_imid'])
    imid1 = copy.deepcopy(image_to_colmap[imidx1]['colmap_imid'])

    # order the pair by increasing COLMAP image id
    swap = imid0 > imid1
    first, second = (1, 0) if swap else (0, 1)
    colmap_matches = np.stack([ravel_matches[first], ravel_matches[second]], axis=-1)
    if swap:
        imidx0, imidx1 = imidx1, imidx0
    colmap_matches = np.unique(colmap_matches, axis=0)
    return imidx0, imidx1, colmap_matches
def get_im_matches(pred1, pred2, pairs, image_to_colmap, im_keypoints, conf_thr,
                   is_sparse=True, subsample=8, pixel_tol=0, viz=False, device='cuda'):
    """Extract 2D-2D matches for every predicted pair and convert them for COLMAP.

    When ``pred1`` contains ``'desc'`` the matching uses the ``'desc'`` /
    ``'desc_conf'`` entries of both predictions; otherwise it uses
    ``pred1['pts3d']``, ``pred2['pts3d_in_other_view']`` and the ``'conf'``
    entries. In both cases ``is_sparse`` selects
    ``extract_correspondences_nonsym`` (thresholded afterwards on the returned
    confidence) versus reciprocal nearest neighbours over all pixels whose
    confidence is at least ``conf_thr``.

    Args:
        pred1, pred2: prediction dicts, indexed per pair.
        pairs: sequence where ``pairs[i][0]`` / ``pairs[i][1]`` are the view
            dicts of pair ``i`` (``'idx'`` and ``'true_shape'`` are read).
        image_to_colmap, im_keypoints, viz: forwarded to
            ``convert_im_matches_pairs`` (``im_keypoints`` is updated there).
        conf_thr: minimum confidence for a match / pixel to be kept.
        subsample, pixel_tol, device: forwarded to the matching helpers.

    Returns:
        dict mapping ``(imidx0, imidx1)`` -- as ordered by
        ``convert_im_matches_pairs`` -- to the array of matches. Pairs that
        yield no match are left out.
    """
    def confident_matches(corres):
        # corres[0] / corres[1] hold the coordinates in each view, corres[2] the confidence
        keep = corres[2] >= conf_thr
        return corres[0][keep].cpu().numpy(), corres[1][keep].cpu().numpy()

    im_matches = {}
    num_pairs = len(pred1['pts3d'])
    for i in range(num_pairs):
        view0 = pairs[i][0]
        imidx0 = view0['idx']
        view1 = pairs[i][1]
        imidx1 = view1['idx']
        views = (view0, view1)

        if 'desc' in pred1:  # mast3r
            feats = [pred1['desc'][i], pred2['desc'][i]]
            confs = [pred1['desc_conf'][i], pred2['desc_conf'][i]]
            desc_dim = feats[0].shape[-1]
            if is_sparse:
                corres = extract_correspondences_nonsym(feats[0], feats[1], confs[0], confs[1],
                                                        device=device, subsample=subsample,
                                                        pixel_tol=pixel_tol)
                matches_im0, matches_im1 = confident_matches(corres)
            else:
                pts2d_list, desc_list = [], []
                for j in range(2):
                    keep_j = (confs[j] >= conf_thr).cpu().numpy().flatten()
                    shape_j = views[j]['true_shape'][0]
                    grid_j = xy_grid(shape_j[1], shape_j[0]).reshape(-1, 2)
                    flat_desc_j = feats[j].detach().cpu().numpy().reshape(-1, desc_dim)
                    pts2d_list.append(grid_j[keep_j])
                    desc_list.append(flat_desc_j[keep_j])
                if 0 in (len(desc_list[0]), len(desc_list[1])):
                    continue
                nn0, nn1 = bruteforce_reciprocal_nns(desc_list[0], desc_list[1],
                                                     device=device, dist='dot', block_size=2**13)
                # keep only points of view 0 whose nearest neighbour points back at them
                is_reciprocal = nn1[nn0] == np.arange(len(nn0))
                matches_im1 = pts2d_list[1][nn0][is_reciprocal]
                matches_im0 = pts2d_list[0][is_reciprocal]
        else:
            pts3d = [pred1['pts3d'][i], pred2['pts3d_in_other_view'][i]]
            confs = [pred1['conf'][i], pred2['conf'][i]]
            if is_sparse:
                corres = extract_correspondences_nonsym(pts3d[0], pts3d[1], confs[0], confs[1],
                                                        device=device, subsample=subsample,
                                                        pixel_tol=pixel_tol, ptmap_key='3d')
                matches_im0, matches_im1 = confident_matches(corres)
            else:
                # find 2D-2D matches between the two images
                pts2d_list, pts3d_list = [], []
                for j in range(2):
                    keep_j = (confs[j] >= conf_thr).cpu().numpy().flatten()
                    shape_j = views[j]['true_shape'][0]
                    grid_j = xy_grid(shape_j[1], shape_j[0]).reshape(-1, 2)
                    flat_pts_j = pts3d[j].detach().cpu().numpy().reshape(-1, 3)
                    pts2d_list.append(grid_j[keep_j])
                    pts3d_list.append(flat_pts_j[keep_j])
                PQ, PM = pts3d_list
                if 0 in (len(PQ), len(PM)):
                    continue
                reciprocal_in_PM, nnM_in_PQ, _ = find_reciprocal_matches(PQ, PM)
                matches_im1 = pts2d_list[1][reciprocal_in_PM]
                matches_im0 = pts2d_list[0][nnM_in_PQ][reciprocal_in_PM]

        if len(matches_im0) == 0:
            continue
        # the returned indices may be swapped with respect to the pair order
        imidx0, imidx1, colmap_matches = convert_im_matches_pairs(view0, view1,
                                                                  image_to_colmap, im_keypoints,
                                                                  matches_im0, matches_im1, viz)
        im_matches[(imidx0, imidx1)] = colmap_matches
    return im_matches
def get_im_matches_from_cache(pairs, cache_path, desc_conf, subsample,
                              image_to_colmap, im_keypoints, conf_thr,
                              viz=False, device='cuda'):
    """Load cached correspondences for every pair and convert them for COLMAP.

    For each pair, the cache file is looked up under
    ``<cache_path>/corres_conf=<desc_conf>_subsample=<subsample>/`` and is
    named after the md5 hashes of the two ``'instance'`` values. If the file
    for the (view0, view1) order is absent, the file for the reverse order is
    loaded instead and its two coordinate sets are swapped back.

    Args:
        pairs: sequence where ``pair[0]`` / ``pair[1]`` are the view dicts
            (``'idx'`` and ``'instance'`` are read).
        cache_path, desc_conf, subsample: identify the cache directory.
        image_to_colmap, im_keypoints, viz: forwarded to
            ``convert_im_matches_pairs`` (``im_keypoints`` is updated there).
        conf_thr: minimum confidence for a cached match to be kept.
        device: ``map_location`` used when loading the cache files.

    Returns:
        dict mapping ``(imidx0, imidx1)`` -- as ordered by
        ``convert_im_matches_pairs`` -- to the array of matches. Pairs without
        any match above ``conf_thr`` are left out.
    """
    im_matches = {}
    for pair in pairs:
        view0 = pair[0]
        imidx0 = view0['idx']
        view1 = pair[1]
        imidx1 = view1['idx']

        corres_idx1 = hash_md5(view0['instance'])
        corres_idx2 = hash_md5(view1['instance'])

        # NOTE(review): torch.load unpickles the file; cache files must come from a trusted source.
        path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx1}-{corres_idx2}.pth'
        if not os.path.isfile(path_corres):
            # cached in the opposite order: swap the coordinate sets while unpacking
            path_corres = cache_path + f'/corres_conf={desc_conf}_{subsample=}/{corres_idx2}-{corres_idx1}.pth'
            _, (xy2, xy1, confs) = torch.load(path_corres, map_location=device)
        else:
            _, (xy1, xy2, confs) = torch.load(path_corres, map_location=device)

        keep = confs >= conf_thr
        matches_im0 = xy1[keep].cpu().numpy()
        matches_im1 = xy2[keep].cpu().numpy()
        if len(matches_im0) == 0:
            continue

        # the returned indices may be swapped with respect to the pair order
        imidx0, imidx1, colmap_matches = convert_im_matches_pairs(view0, view1,
                                                                  image_to_colmap, im_keypoints,
                                                                  matches_im0, matches_im1, viz)
        im_matches[(imidx0, imidx1)] = colmap_matches
    return im_matches
def export_images(db, images, image_paths, focals, ga_world_to_cam, camera_model):
    """Register one camera and one image per input path in the COLMAP database.

    Args:
        db: database object exposing ``add_camera`` and ``add_image``.
        images: per-image dicts; ``'orig_shape'`` (H, W) and ``'to_orig'``
            are read (only the two diagonal scale entries of ``'to_orig'``).
        image_paths: image names passed to ``db.add_image``.
        focals: ``None`` (a default of ``1.2 * max(W, H)`` is used and flagged
            as not being a prior), or per image either a 2D intrinsics matrix
            or a scalar focal length.
        ga_world_to_cam: ``None`` (zero pose priors) or per-image matrices
            whose ``[:3, :3]`` / ``[:3, 3]`` parts give the prior pose.
        camera_model: one of ``"SIMPLE_PINHOLE"``, ``"PINHOLE"``,
            ``"SIMPLE_RADIAL"``, ``"OPENCV"``.

    Returns:
        ``(image_to_colmap, im_keypoints)``: a dict mapping each image index
        to its ``'colmap_imid'`` / ``'colmap_camid'``, and a dict with an
        empty keypoint table per image index.

    Raises:
        ValueError: if ``camera_model`` is not one of the supported names.
    """
    image_to_colmap = {}
    im_keypoints = {}
    for idx, image_path in enumerate(image_paths):
        im_keypoints[idx] = {}
        image = images[idx]
        H, W = image["orig_shape"]

        if focals is None:
            # no estimate available: fall back to a default focal, not used as a prior
            focal_x = focal_y = 1.2 * max(W, H)
            prior_focal_length = False
            cx, cy = W / 2.0, H / 2.0
        elif isinstance(focals[idx], np.ndarray) and focals[idx].ndim == 2:
            # full intrinsics matrix
            intrinsics = focals[idx]
            focal_x = intrinsics[0, 0]
            focal_y = intrinsics[1, 1]
            cx = intrinsics[0, 2] * image["to_orig"][0, 0]
            cy = intrinsics[1, 2] * image["to_orig"][1, 1]
            prior_focal_length = True
        else:
            # single scalar focal, principal point at the image centre
            focal_x = focal_y = float(focals[idx])
            prior_focal_length = True
            cx, cy = W / 2.0, H / 2.0

        # NOTE(review): this rescaling is also applied to the default focal
        # above, which is already derived from orig_shape -- confirm intended.
        focal_x = focal_x * image["to_orig"][0, 0]
        focal_y = focal_y * image["to_orig"][1, 1]

        # COLMAP camera model id and its parameter vector (distortion terms start at zero)
        if camera_model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL"):
            focal = (focal_x + focal_y) / 2.0
            if camera_model == "SIMPLE_PINHOLE":
                model_id, values = 0, [focal, cx, cy]
            else:
                model_id, values = 2, [focal, cx, cy, 0.0]
        elif camera_model == "PINHOLE":
            model_id, values = 1, [focal_x, focal_y, cx, cy]
        elif camera_model == "OPENCV":
            model_id, values = 4, [focal_x, focal_y, cx, cy, 0.0, 0.0, 0.0, 0.0]
        else:
            raise ValueError(f"invalid camera model {camera_model}")
        params = np.asarray(values, np.float64)

        H, W = int(H), int(W)
        camid = db.add_camera(model_id, W, H, params, prior_focal_length=prior_focal_length)

        if ga_world_to_cam is None:
            prior_t = np.zeros(3)
            prior_q = np.zeros(4)
        else:
            pose = ga_world_to_cam[idx]
            # scipy returns the quaternion scalar-last (x, y, z, w); store it scalar-first
            qx, qy, qz, qw = R.from_matrix(pose[:3, :3]).as_quat()
            prior_t = pose[:3, 3]
            prior_q = np.array([qw, qx, qy, qz])

        imid = db.add_image(image_path, camid, prior_q=prior_q, prior_t=prior_t)
        image_to_colmap[idx] = {'colmap_imid': imid, 'colmap_camid': camid}
    return image_to_colmap, im_keypoints
def _build_tracks(im_matches):
    """Group pairwise keypoint matches into multi-view tracks.

    Returns ``(keypoints_to_track_id, track_id_to_kpt_list, superseded)``:
    the track id of every matched keypoint per image, the list of
    ``(imidx, keypoint_id)`` members of every track, and the set of track ids
    that were folded into a merged track (their member lists are stale).
    """
    print("building tracks")
    keypoints_to_track_id = {}
    track_id_to_kpt_list = []
    to_merge = []
    for (imidx0, imidx1), colmap_matches in tqdm(im_matches.items()):
        tracks0 = keypoints_to_track_id.setdefault(imidx0, {})
        tracks1 = keypoints_to_track_id.setdefault(imidx1, {})
        for kp0, kp1 in colmap_matches:
            track0 = tracks0.get(kp0)
            track1 = tracks1.get(kp1)
            if track0 is None and track1 is None:
                # new pair of kpts never seen before
                track_idx = len(track_id_to_kpt_list)
                tracks0[kp0] = track_idx
                tracks1[kp1] = track_idx
                track_id_to_kpt_list.append([(imidx0, kp0), (imidx1, kp1)])
            elif track1 is None:
                # 0 has a track, not 1
                tracks1[kp1] = track0
                track_id_to_kpt_list[track0].append((imidx1, kp1))
            elif track0 is None:
                # 1 has a track, not 0
                tracks0[kp0] = track1
                track_id_to_kpt_list[track1].append((imidx0, kp0))
            elif track0 != track1:
                # both have distinct tracks: merge them once all matches are seen
                to_merge.append((track0, track1))

    # regroup merge targets into connected sets of track ids
    print("merging tracks")
    tree = DisjointSet(np.unique(to_merge))
    for track_a, track_b in tqdm(to_merge):
        tree.merge(track_a, track_b)
    subsets = tree.subsets()

    print("applying merge")
    superseded = set()
    for members in tqdm(subsets):
        new_trackid = len(track_id_to_kpt_list)
        merged_kpts = []
        for track_idx in members:
            merged_kpts.extend(track_id_to_kpt_list[track_idx])
            superseded.add(int(track_idx))
        for imidx, kpid in merged_kpts:
            keypoints_to_track_id[imidx][kpid] = new_trackid
        track_id_to_kpt_list.append(merged_kpts)
    return keypoints_to_track_id, track_id_to_kpt_list, superseded


def export_matches(db, images, image_to_colmap, im_keypoints, im_matches, min_len_track, skip_geometric_verification):
    """Write keypoints and matches to the COLMAP database, dropping short tracks.

    The 2D-2D matches are quite dense, so matched keypoints are first grouped
    into tracks; only keypoints that belong to a track with at least
    ``min_len_track`` members are exported, and matches are re-indexed to the
    exported keypoints.

    Args:
        db: database object exposing ``add_keypoints``, ``add_matches`` and
            ``add_two_view_geometry``.
        images: per-image dicts; ``'true_shape'``, ``'to_orig'``,
            ``'orig_shape'`` and ``'instance'`` are read.
        image_to_colmap: maps an image index to a dict holding ``'colmap_imid'``.
        im_keypoints: ``{imidx: {flat_keypoint_id: count}}``; only the keys
            (and their order) are used.
        im_matches: ``{(imidx0, imidx1): rows of (flat_id0, flat_id1)}`` with
            ``imidx0`` being the image with the smaller COLMAP id.
        min_len_track: minimum number of keypoints in a track to keep it.
        skip_geometric_verification: if truthy, the matches are also stored
            as two-view geometry.

    Returns:
        list of ``(instance0, instance1)`` for every pair with exported matches.
    """
    colmap_image_pairs = []
    keypoints_to_track_id, track_id_to_kpt_list, superseded = _build_tracks(im_matches)

    # Tracks folded into a merged track are skipped, otherwise their members
    # would be counted once more in addition to the merged track.
    num_valid_tracks = sum(
        1 for track_idx, kpts in enumerate(track_id_to_kpt_list)
        if track_idx not in superseded and len(kpts) >= min_len_track)

    keypoints_to_idx = {}
    print(f"squashing keypoints - {num_valid_tracks} valid tracks")
    for imidx, keypoints_imid in tqdm(im_keypoints.items()):
        imid = image_to_colmap[imidx]['colmap_imid']
        # an image may have keypoints without any match left in im_matches
        image_tracks = keypoints_to_track_id.get(imidx, {})
        keypoints_kept = []
        slot_of = keypoints_to_idx[imidx] = {}
        for kp in keypoints_imid:
            track_idx = image_tracks.get(kp)
            if track_idx is None or len(track_id_to_kpt_list[track_idx]) < min_len_track:
                continue
            slot_of[kp] = len(keypoints_kept)
            keypoints_kept.append(kp)
        if not keypoints_kept:
            continue

        # flat id -> (row, col), stored as (x, y) = (col, row)
        rows, cols = np.unravel_index(np.array(keypoints_kept), images[imidx]['true_shape'][0])
        keypoints_xy = np.stack([cols, rows], axis=-1).astype(np.float32)
        # rescale coordinates: shift by half a pixel, then apply 'to_orig'
        # (presumably the transform back to the original image -- confirm)
        keypoints_xy += 0.5
        keypoints_xy = geotrf(images[imidx]['to_orig'], keypoints_xy, norm=True)

        H, W = images[imidx]['orig_shape']
        keypoints_xy[:, 0] = keypoints_xy[:, 0].clip(min=0, max=W - 0.01)
        keypoints_xy[:, 1] = keypoints_xy[:, 1].clip(min=0, max=H - 0.01)

        db.add_keypoints(imid, keypoints_xy)

    print("exporting im_matches")
    for (imidx0, imidx1), colmap_matches in im_matches.items():
        imid0 = image_to_colmap[imidx0]['colmap_imid']
        imid1 = image_to_colmap[imidx1]['colmap_imid']
        assert imid0 < imid1
        slots0 = keypoints_to_idx[imidx0]
        slots1 = keypoints_to_idx[imidx1]
        # keep only matches whose two keypoints were both exported
        final_matches = np.array([[slots0[kp0], slots1[kp1]]
                                  for kp0, kp1 in colmap_matches
                                  if kp0 in slots0 and kp1 in slots1])
        if len(final_matches) > 0:
            colmap_image_pairs.append(
                (images[imidx0]['instance'], images[imidx1]['instance']))
            db.add_matches(imid0, imid1, final_matches)
            if skip_geometric_verification:
                db.add_two_view_geometry(imid0, imid1, final_matches)
    return colmap_image_pairs