mast3r-sfm / mast3r / demo_glomap.py

#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# gradio demo functions
# --------------------------------------------------------
import pycolmap
import gradio
import os
import numpy as np
import functools
import trimesh
import copy
from scipy.spatial.transform import Rotation
import tempfile
import shutil
import PIL.Image
import torch
from kapture.converter.colmap.database_extra import kapture_to_colmap
from kapture.converter.colmap.database import COLMAPDatabase
from mast3r.colmap.mapping import kapture_import_image_folder_or_list, run_mast3r_matching, glomap_run_mapper
from mast3r.demo import set_scenegraph_options
from mast3r.retrieval.processor import Retriever
from mast3r.image_pairs import make_pairs
import mast3r.utils.path_to_dust3r # noqa
from dust3r.utils.image import load_images
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL
from dust3r.demo import get_args_parser as dust3r_get_args_parser
import matplotlib.pyplot as pl


class GlomapRecon:
def __init__(self, world_to_cam, intrinsics, points3d, imgs):
self.world_to_cam = world_to_cam
self.intrinsics = intrinsics
self.points3d = points3d
self.imgs = imgs


class GlomapReconState:
def __init__(self, glomap_recon, should_delete=False, cache_dir=None, outfile_name=None):
self.glomap_recon = glomap_recon
self.cache_dir = cache_dir
self.outfile_name = outfile_name
self.should_delete = should_delete

    def __del__(self):
if not self.should_delete:
return
if self.cache_dir is not None and os.path.isdir(self.cache_dir):
shutil.rmtree(self.cache_dir)
self.cache_dir = None
if self.outfile_name is not None and os.path.isfile(self.outfile_name):
os.remove(self.outfile_name)
self.outfile_name = None


def get_args_parser():
parser = dust3r_get_args_parser()
parser.add_argument('--share', action='store_true')
    parser.add_argument('--gradio_delete_cache', default=None, type=int,
                        help='age/frequency (in seconds) at which gradio deletes its cached files. If >0, the matching cache is purged as well')
    parser.add_argument('--glomap_bin', default='glomap', type=str, help='path to the glomap executable')
    parser.add_argument('--retrieval_model', default=None, type=str, help='path of the retrieval model checkpoint to load')
actions = parser._actions
for action in actions:
if action.dest == 'model_name':
action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
# change defaults
parser.prog = 'mast3r demo'
return parser


def get_reconstructed_scene(glomap_bin, outdir, gradio_delete_cache, model, retrieval_model, device, silent, image_size,
current_scene_state, filelist, transparent_cams, cam_size, scenegraph_type, winsize,
win_cyclic, refid, shared_intrinsics, **kw):
"""
from a list of images, run mast3r inference, sparse global aligner.
then run get_3D_model_from_scene
"""
imgs = load_images(filelist, size=image_size, verbose=not silent)
if len(imgs) == 1:
imgs = [imgs[0], copy.deepcopy(imgs[0])]
imgs[1]['idx'] = 1
filelist = [filelist[0], filelist[0]]
scene_graph_params = [scenegraph_type]
if scenegraph_type in ["swin", "logwin"]:
scene_graph_params.append(str(winsize))
elif scenegraph_type == "oneref":
scene_graph_params.append(str(refid))
elif scenegraph_type == "retrieval":
scene_graph_params.append(str(winsize)) # Na
scene_graph_params.append(str(refid)) # k
if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
scene_graph_params.append('noncyclic')
scene_graph = '-'.join(scene_graph_params)
sim_matrix = None
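
    # optionally compute an image-retrieval similarity matrix to drive pair selection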
if 'retrieval' in scenegraph_type:
assert retrieval_model is not None
retriever = Retriever(retrieval_model, backbone=model, device=device)
with torch.no_grad():
sim_matrix = retriever(filelist)
# Cleanup
del retriever
torch.cuda.empty_cache()
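
    # build the image pairs according to the selected scene graph (and similarity matrix, if any)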
pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)
if current_scene_state is not None and \
not current_scene_state.should_delete and \
current_scene_state.cache_dir is not None:
cache_dir = current_scene_state.cache_dir
elif gradio_delete_cache:
cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
else:
cache_dir = os.path.join(outdir, 'cache')
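
    # import the image list as a kapture dataset, with paths expressed relative to their common root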
root_path = os.path.commonpath(filelist)
filelist_relpath = [
os.path.relpath(filename, root_path).replace('\\', '/')
for filename in filelist
]
kdata = kapture_import_image_folder_or_list((root_path, filelist_relpath), shared_intrinsics)
image_pairs = [
(filelist_relpath[img1['idx']], filelist_relpath[img2['idx']])
for img1, img2 in pairs
]
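
    # export the kapture data to a fresh colmap database and fill it with MASt3R keypoints/matches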
colmap_db_path = os.path.join(cache_dir, 'colmap.db')
if os.path.isfile(colmap_db_path):
os.remove(colmap_db_path)
os.makedirs(os.path.dirname(colmap_db_path), exist_ok=True)
colmap_db = COLMAPDatabase.connect(colmap_db_path)
try:
kapture_to_colmap(kdata, root_path, tar_handler=None, database=colmap_db,
keypoints_type=None, descriptors_type=None, export_two_view_geometry=False)
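        # the numeric/boolean positional arguments below are matching options (patch size,
        # pixel tolerance, confidence threshold, minimum track length...); see
        # run_mast3r_matching in mast3r.colmap.mapping for their exact meaning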
colmap_image_pairs = run_mast3r_matching(model, image_size, 16, device,
kdata, root_path, image_pairs, colmap_db,
False, 5, 1.001,
False, 3)
colmap_db.close()
except Exception as e:
print(f'Error {e}')
colmap_db.close()
exit(1)
if len(colmap_image_pairs) == 0:
raise Exception("no matches were kept")
# colmap db is now full, run colmap
colmap_world_to_cam = {}
print("verify_matches")
    pairs_path = os.path.join(cache_dir, 'pairs.txt')
    with open(pairs_path, "w") as f:
        for image_path1, image_path2 in colmap_image_pairs:
            f.write(f"{image_path1} {image_path2}\n")
    pycolmap.verify_matches(colmap_db_path, pairs_path)
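
    # run the glomap global mapper on the verified database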
reconstruction_path = os.path.join(cache_dir, "reconstruction")
if os.path.isdir(reconstruction_path):
shutil.rmtree(reconstruction_path)
os.makedirs(reconstruction_path, exist_ok=True)
glomap_run_mapper(glomap_bin, colmap_db_path, reconstruction_path, root_path)
if current_scene_state is not None and \
not current_scene_state.should_delete and \
current_scene_state.outfile_name is not None:
outfile_name = current_scene_state.outfile_name
else:
outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)
    output_recon = pycolmap.Reconstruction(os.path.join(reconstruction_path, '0'))
    print(output_recon.summary())
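
    # convert the pycolmap reconstruction into plain numpy structures (poses, intrinsics, input images)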
colmap_world_to_cam = {}
colmap_intrinsics = {}
colmap_image_id_to_name = {}
images = {}
    num_reg_images = output_recon.num_reg_images()
    for idx, (colmap_imgid, colmap_image) in enumerate(output_recon.images.items()):
colmap_image_id_to_name[colmap_imgid] = colmap_image.name
if callable(colmap_image.cam_from_world.matrix):
colmap_world_to_cam[colmap_imgid] = colmap_image.cam_from_world.matrix(
)
else:
colmap_world_to_cam[colmap_imgid] = colmap_image.cam_from_world.matrix
        camera = output_recon.cameras[colmap_image.camera_id]
K = np.eye(3)
K[0, 0] = camera.focal_length_x
K[1, 1] = camera.focal_length_y
K[0, 2] = camera.principal_point_x
K[1, 2] = camera.principal_point_y
colmap_intrinsics[colmap_imgid] = K
with PIL.Image.open(os.path.join(root_path, colmap_image.name)) as im:
images[colmap_imgid] = np.asarray(im)
if idx + 1 == num_reg_images:
break # bug with the iterable ?
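
    # gather the sparse point cloud (xyz positions and rgb colors)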
points3D = []
    num_points3D = output_recon.num_points3D()
    for idx, (pt3d_id, pts3d) in enumerate(output_recon.points3D.items()):
points3D.append((pts3d.xyz, pts3d.color))
if idx + 1 == num_points3D:
break # bug with the iterable ?
scene = GlomapRecon(colmap_world_to_cam, colmap_intrinsics, points3D, images)
scene_state = GlomapReconState(scene, gradio_delete_cache, cache_dir, outfile_name)
outfile = get_3D_model_from_scene(silent, scene_state, transparent_cams, cam_size)
return scene_state, outfile


def get_3D_model_from_scene(silent, scene_state, transparent_cams=False, cam_size=0.05):
"""
extract 3D_model (glb file) from a reconstructed scene
"""
if scene_state is None:
return None
outfile = scene_state.outfile_name
if outfile is None:
return None
recon = scene_state.glomap_recon
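
    # build a trimesh scene with the sparse point cloud and one camera frustum per registered image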
scene = trimesh.Scene()
pts = np.stack([p[0] for p in recon.points3d], axis=0)
col = np.stack([p[1] for p in recon.points3d], axis=0)
pct = trimesh.PointCloud(pts, colors=col)
scene.add_geometry(pct)
# add each camera
cams2world = []
for i, (id, pose_w2c_3x4) in enumerate(recon.world_to_cam.items()):
intrinsics = recon.intrinsics[id]
focal = (intrinsics[0, 0] + intrinsics[1, 1]) / 2.0
camera_edge_color = CAM_COLORS[i % len(CAM_COLORS)]
pose_w2c = np.eye(4)
pose_w2c[:3, :] = pose_w2c_3x4
pose_c2w = np.linalg.inv(pose_w2c)
cams2world.append(pose_c2w)
add_scene_cam(scene, pose_c2w, camera_edge_color,
None if transparent_cams else recon.imgs[id], focal,
imsize=recon.imgs[id].shape[1::-1], screen_width=cam_size)
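
    # re-center the scene on the first camera (OPENGL axis convention, flipped 180° around Y)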
rot = np.eye(4)
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
if not silent:
print('(exporting 3D scene to', outfile, ')')
scene.export(file_obj=outfile)
return outfile


def main_demo(glomap_bin, tmpdirname, model, retrieval_model, device, image_size, server_name, server_port,
silent=False, share=False, gradio_delete_cache=False):
if not silent:
        print('Outputting results in', tmpdirname)
recon_fun = functools.partial(get_reconstructed_scene, glomap_bin, tmpdirname, gradio_delete_cache, model,
retrieval_model, device, silent, image_size)
model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)
available_scenegraph_type = [("complete: all possible image pairs", "complete"),
("swin: sliding window", "swin"),
("logwin: sliding window with long range", "logwin"),
("oneref: match one image with all", "oneref")]
if retrieval_model is not None:
available_scenegraph_type.insert(1, ("retrieval: connect views based on similarity", "retrieval"))

    def get_context(delete_cache):
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
title = "MASt3R Demo"
if delete_cache:
return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
else:
return gradio.Blocks(css=css, title="MASt3R Demo") # for compatibility with older versions

    with get_context(gradio_delete_cache) as demo:
        # scene state is saved so that you can change cam_size, transparent_cams... without rerunning the inference
scene = gradio.State(None)
gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
with gradio.Column():
inputfiles = gradio.File(file_count="multiple")
with gradio.Row():
shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
info="Only optimize one set of intrinsics for all views")
scenegraph_type = gradio.Dropdown(available_scenegraph_type,
value='complete', label="Scenegraph",
info="Define how to make pairs",
interactive=True)
with gradio.Column(visible=False) as win_col:
winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
minimum=1, maximum=1, step=1)
win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
refid = gradio.Slider(label="Scene Graph: Id", value=0,
minimum=0, maximum=0, step=1, visible=False)
run_btn = gradio.Button("Run")
with gradio.Row():
# adjust the camera size in the output pointcloud
cam_size = gradio.Slider(label="cam_size", value=0.01, minimum=0.001, maximum=1.0, step=0.001)
with gradio.Row():
transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
outmodel = gradio.Model3D()
# events
scenegraph_type.change(set_scenegraph_options,
inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
outputs=[win_col, winsize, win_cyclic, refid])
inputfiles.change(set_scenegraph_options,
inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
outputs=[win_col, winsize, win_cyclic, refid])
win_cyclic.change(set_scenegraph_options,
inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
outputs=[win_col, winsize, win_cyclic, refid])
run_btn.click(fn=recon_fun,
inputs=[scene, inputfiles, transparent_cams, cam_size,
scenegraph_type, winsize, win_cyclic, refid, shared_intrinsics],
outputs=[scene, outmodel])
cam_size.change(fn=model_from_scene_fun,
inputs=[scene, transparent_cams, cam_size],
outputs=outmodel)
transparent_cams.change(model_from_scene_fun,
inputs=[scene, transparent_cams, cam_size],
outputs=outmodel)
demo.launch(share=share, server_name=server_name, server_port=server_port)
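

# --------------------------------------------------------
# Minimal launcher sketch (an illustration, not part of the original file): it assumes
# AsymmetricMASt3R.from_pretrained from mast3r.model and the usual dust3r parser
# arguments (--weights / --model_name, --device, --image_size, --server_name,
# --server_port, --silent); the actual entry point may differ.
# --------------------------------------------------------
if __name__ == '__main__':
    from mast3r.model import AsymmetricMASt3R  # assumed import path

    parser = get_args_parser()
    args = parser.parse_args()

    # load the checkpoint either from an explicit path or from the hub model name (assumption)
    weights_path = args.weights if args.weights is not None else "naver/" + args.model_name
    model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)

    with tempfile.TemporaryDirectory(suffix='_mast3r_glomap_demo') as tmpdirname:
        main_demo(args.glomap_bin, tmpdirname, model, args.retrieval_model, args.device,
                  args.image_size, args.server_name, args.server_port, silent=args.silent,
                  share=args.share, gradio_delete_cache=args.gradio_delete_cache)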