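# Page-header residue captured together with this source ("Spaces:", "Sleeping", "Sleeping");
# it is not part of the program and is kept only as a comment.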
import os | |
import time | |
import torch | |
import numpy as np | |
import gradio as gr | |
import urllib.parse | |
import tempfile | |
import subprocess | |
from dust3r.losses import L21 | |
from spann3r.model import Spann3R | |
from mast3r.model import AsymmetricMASt3R | |
from spann3r.datasets import Demo | |
from torch.utils.data import DataLoader | |
import cv2 | |
import json | |
import glob | |
from dust3r.post_process import estimate_focal_knowing_depth | |
from mast3r.demo import get_reconstructed_scene | |
from scipy.spatial.transform import Rotation | |
from transformers import AutoModelForImageSegmentation | |
from torchvision import transforms | |
from PIL import Image | |
import open3d as o3d | |
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds | |
from gs_utils import point2gs | |
from pose_utils import solve_cemara | |
from gradio.helpers import Examples as GradioExamples | |
from gradio.utils import get_cache_folder | |
from pathlib import Path | |
import os | |
import shutil | |
import math | |
import zipfile | |
from pathlib import Path | |
# Default values
# Spann3R checkpoint (local path) and DUSt3R / MASt3R weights (remote URLs).
DEFAULT_CKPT_PATH = 'checkpoints/spann3r.pth'
DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
DEFAULT_MAST3R_PATH = 'https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'

# Use the first CUDA device when available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEFAULT_DEVICE = 'cuda:0'
else:
    DEFAULT_DEVICE = 'cpu'

# 4x4 homogeneous matrix that flips the y and z axes (the same flip that
# get_transform_json applies for its "CV2 GL format" conversion).
OPENGL = np.diag([1, -1, -1, 1])
class Examples(GradioExamples):
    """Gradio ``Examples`` whose cache can be redirected to a named folder.

    When ``directory_name`` is given, the cache paths set by the base class are
    overridden with ``get_cache_folder() / directory_name`` and a ``log.csv``
    inside it. ``create()`` is then called explicitly in every case.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        # `_initiated_directly=False` is a private Gradio flag; this subclass
        # calls `create()` itself after adjusting the cache paths.
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            folder = get_cache_folder() / directory_name
            self.cached_folder = folder
            self.cached_file = Path(folder) / "log.csv"
        self.create()
def export_geometry(geometry, file_format='obj'):
    """
    Export an Open3D geometry (triangle mesh or point cloud) to a temporary file.

    Before writing, the geometry is transformed by ``inv(OPENGL @ R_y(180 deg))``.
    The transform is applied to a deep copy, so the caller's object is left
    unchanged.

    Args:
        geometry: Open3D geometry object (TriangleMesh or PointCloud)
        file_format: str, output format. Meshes support 'obj' and 'ply';
            point clouds support 'ply' and 'pcd'.

    Returns:
        str: Path to the exported file

    Raises:
        ValueError: If the geometry type is not supported, or the file format is
            invalid or not supported for that geometry type.
        RuntimeError: If Open3D fails to write the file.
    """
    import copy  # local import: only needed to duplicate the geometry

    # Validate geometry type
    if not isinstance(geometry, (o3d.geometry.TriangleMesh, o3d.geometry.PointCloud)):
        raise ValueError("Geometry must be either TriangleMesh or PointCloud")

    # Validate the file format and the (geometry, format) combination up front,
    # so a mismatch surfaces as ValueError instead of silently writing nothing.
    supported_formats = {
        'obj': '.obj',
        'ply': '.ply',
        'pcd': '.pcd'
    }
    fmt = file_format.lower()
    if fmt not in supported_formats:
        raise ValueError(f"Unsupported file format. Supported formats: {list(supported_formats.keys())}")
    is_mesh = isinstance(geometry, o3d.geometry.TriangleMesh)
    if is_mesh and fmt not in ('obj', 'ply'):
        raise ValueError(f"Format {file_format} not supported for triangle meshes. Use 'obj' or 'ply'")
    if not is_mesh and fmt not in ('ply', 'pcd'):
        raise ValueError(f"Format {file_format} not supported for point clouds. Use 'ply' or 'pcd'")

    # Rotate 180 degrees about y and undo the OpenGL axis flip.
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    transform = np.linalg.inv(OPENGL @ rot)

    # Work on a real copy: `transform()` mutates the geometry in place.
    geometry_copy = copy.deepcopy(geometry)
    geometry_copy.transform(transform)

    # mkstemp creates the file atomically (unlike the deprecated mktemp);
    # Open3D then overwrites it.
    fd, output_path = tempfile.mkstemp(suffix=supported_formats[fmt])
    os.close(fd)

    try:
        if is_mesh:
            written = o3d.io.write_triangle_mesh(output_path, geometry_copy,
                                                 write_ascii=False, compressed=True)
        else:
            written = o3d.io.write_point_cloud(output_path, geometry_copy,
                                               write_ascii=False, compressed=True)
        if not written:
            raise IOError(f"Open3D could not write {output_path}")
    except Exception as e:
        # Clean up temporary file if export fails
        if os.path.exists(output_path):
            os.remove(output_path)
        raise RuntimeError(f"Failed to export geometry: {str(e)}") from e
    return output_path
def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
    """Extract frames from the start of a video into a new temporary directory.

    ffmpeg keeps only frames with timestamp < ``duration`` seconds, resamples
    them to ``fps`` frames per second, and writes them as 001.jpg, 002.jpg, ...

    Args:
        video_path: path of the input video.
        duration: cut-off time in seconds.
        fps: output frame rate for ffmpeg's ``fps`` filter.

    Returns:
        str: path of the temporary directory holding the frames. The caller
        owns the directory and is responsible for deleting it.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        In both cases the temporary directory is removed before re-raising.
    """
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")
    filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", filter_complex,
        "-vsync", "0",
        output_path
    ]
    try:
        subprocess.run(command, check=True)
    except BaseException:
        # Do not leak the (possibly partially filled) directory on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
def load_ckpt(model_path_or_url, verbose=True):
    """Load a checkpoint onto the CPU from a local path or an http(s) URL.

    Args:
        model_path_or_url: filesystem path, or a URL whose scheme is ``http`` or
            ``https`` (fetched with ``torch.hub.load_state_dict_from_url``).
        verbose: print where the checkpoint comes from and show download progress.

    Returns:
        The deserialized checkpoint object.
    """
    if verbose:
        print('... loading model from', model_path_or_url)
    scheme = urllib.parse.urlparse(model_path_or_url).scheme
    if scheme not in ('http', 'https'):
        # NOTE: torch.load unpickles; only point this at trusted checkpoint files.
        return torch.load(model_path_or_url, map_location='cpu')
    return torch.hub.load_state_dict_from_url(model_path_or_url, map_location='cpu', progress=verbose)
def load_model(ckpt_path, device):
    """Build a Spann3R model on ``device``, load its weights and set eval mode.

    The backbone weights come from ``DEFAULT_DUST3R_PATH``. The checkpoint at
    ``ckpt_path`` must hold the state dict under the ``'model'`` key.
    """
    net = Spann3R(dus3r_name=DEFAULT_DUST3R_PATH, use_feat=False).to(device)
    state_dict = load_ckpt(ckpt_path)['model']
    net.load_state_dict(state_dict)
    net.eval()
    return net
# Models are loaded once at import time and reused by every request.
model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
# BiRefNet foreground segmentation model (used when "Remove Background" is on).
# nn.Module.to() and .eval() both return the module itself, so they can be chained.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    'zhengpeng7/BiRefNet', trust_remote_code=True
).to(DEFAULT_DEVICE).eval()
def extract_object(birefnet, image):
    """Predict a foreground mask for a PIL image with BiRefNet.

    Args:
        birefnet: segmentation model. The sigmoid of its last output
            (``birefnet(x)[-1]``) is used as the mask.
        image: PIL image. Non-RGB modes (e.g. "L", "RGBA") are converted to RGB
            first, because the normalization below expects exactly 3 channels.

    Returns:
        PIL.Image.Image: mask resized back to ``image.size`` (presumably a
        single-channel "L" image, given a 1-channel model output).
    """
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(DEFAULT_DEVICE)
    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    return pred_pil.resize(image.size)
def generate_mask(image: np.ndarray):
    """Return a soft foreground mask in [0, 1] for an RGB image.

    Args:
        image: float image, presumably HxWx3 RGB in [0, 1] (as built in
            ``reconstruct``). Values are clipped to [0, 1] before the uint8 cast.
            Without clipping, out-of-range floats would not map to valid pixel
            values (float->uint8 overflow is not saturating).

    Returns:
        np.ndarray: float mask with values in [0, 1], same height/width as ``image``.
    """
    # Convert numpy array to PIL Image
    image_u8 = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    pil_image = Image.fromarray(image_u8)
    # Extract object and get mask (resized to the input size by extract_object)
    mask = extract_object(birefnet, pil_image)
    # Convert mask to numpy array
    return np.array(mask) / 255.0
def center_pcd(pcd: o3d.geometry.PointCloud, normalize=False) -> o3d.geometry.PointCloud:
    """Translate a point cloud so that its centroid is at the origin.

    Args:
        pcd: input point cloud.
        normalize: if True, also scale so that the farthest point lies at
            distance 1, and return a NEW point cloud with colors and normals
            copied over. If False, ``pcd`` itself is modified and returned.

    Returns:
        The centered (and optionally normalized) point cloud. An empty point
        cloud is returned unchanged. If all points coincide, the cloud is
        centered but not scaled.
    """
    points = np.asarray(pcd.points)
    if len(points) == 0:
        # The mean over zero rows would be NaN and max() would raise.
        return pcd
    centered_points = points - points.mean(axis=0)
    if not normalize:
        pcd.points = o3d.utility.Vector3dVector(centered_points)
        return pcd
    max_distance = np.linalg.norm(centered_points, axis=1).max()
    # Guard against 0/0 -> NaN for a degenerate (single-location) cloud.
    scale = max_distance if max_distance > 0 else 1.0
    normalized_pcd = o3d.geometry.PointCloud()
    normalized_pcd.points = o3d.utility.Vector3dVector(centered_points / scale)
    # Colors and normals are direction/appearance data: copied as-is.
    if pcd.has_colors():
        normalized_pcd.colors = pcd.colors
    if pcd.has_normals():
        normalized_pcd.normals = pcd.normals
    return normalized_pcd
def center_mesh(mesh: o3d.geometry.TriangleMesh, normalize=False) -> o3d.geometry.TriangleMesh:
    """Translate a mesh so that the centroid of its vertices is at the origin.

    Args:
        mesh: input triangle mesh.
        normalize: if True, also scale so that the farthest vertex lies at
            distance 1, and return a NEW mesh. The new mesh carries over the
            triangles, the vertex colors and the vertex normals (re-scaled to
            unit length). If False, ``mesh`` itself is modified and returned.

    Returns:
        The centered (and optionally normalized) mesh. A mesh without vertices
        is returned unchanged. If all vertices coincide, the mesh is centered
        but not scaled.
    """
    vertices = np.asarray(mesh.vertices)
    if len(vertices) == 0:
        # The mean over zero rows would be NaN and max() would raise.
        return mesh
    centered_vertices = vertices - vertices.mean(axis=0)
    if not normalize:
        # Update the mesh with the centered vertices
        mesh.vertices = o3d.utility.Vector3dVector(centered_vertices)
        return mesh
    max_distance = np.linalg.norm(centered_vertices, axis=1).max()
    # Guard against 0/0 -> NaN for a degenerate mesh.
    scale = max_distance if max_distance > 0 else 1.0
    normalized_mesh = o3d.geometry.TriangleMesh()
    normalized_mesh.vertices = o3d.utility.Vector3dVector(centered_vertices / scale)
    normalized_mesh.triangles = mesh.triangles
    if mesh.has_vertex_colors():
        normalized_mesh.vertex_colors = mesh.vertex_colors
    if mesh.has_vertex_normals():
        vertex_normals = np.asarray(mesh.vertex_normals)
        lengths = np.linalg.norm(vertex_normals, axis=1, keepdims=True)
        # Leave zero-length normals as zeros instead of turning them into NaN.
        lengths[lengths == 0] = 1.0
        normalized_mesh.vertex_normals = o3d.utility.Vector3dVector(vertex_normals / lengths)
    return normalized_mesh
def get_transform_json(H, W, focal, poses_all):
    """Build a transforms.json-style dict: shared intrinsics plus per-frame poses.

    Args:
        H: image height in pixels (principal point y = H/2).
        W: image width in pixels (principal point x = W/2).
        focal: focal length in pixels, used for both fl_x and fl_y. Either an
            object with ``.item()`` (1-element tensor or array, NumPy scalar) or a
            plain number.
        poses_all: sequence of 4x4 camera poses (arrays or nested lists). Each
            pose is copied before its y/z columns are negated, so the caller's
            poses are NOT modified.

    Returns:
        dict: keys w, h, fl_x, fl_y, cx, cy and 'frames'. Each frame repeats the
        intrinsics and adds file_path (images/NNNN.jpg), mask_path
        (masks/NNNN.png) and transform_matrix (nested lists).
    """
    fl = focal.item() if hasattr(focal, 'item') else float(focal)
    intrinsics = {
        'w': W,
        'h': H,
        'fl_x': fl,
        'fl_y': fl,
        'cx': W/2,
        'cy': H/2,
    }
    frames = []
    for i, pose in enumerate(poses_all):
        # CV2 -> GL format: negate the y and z axes, on a private copy.
        gl_pose = np.array(pose, copy=True)
        gl_pose[:3, 1] *= -1
        gl_pose[:3, 2] *= -1
        frames.append({
            **intrinsics,
            'file_path': f"images/{i:04d}.jpg",
            "mask_path": f"masks/{i:04d}.png",
            'transform_matrix': gl_pose.tolist(),
        })
    return {**intrinsics, 'frames': frames}
def organize_and_zip_output(images_all, masks_all, transform_json_path, output_dir=None):
    """
    Organize reconstruction outputs into a fixed layout and zip them.

    Archive layout::

        images/0000.jpg, images/0001.jpg, ...   (from images_all, JPEG quality 90)
        masks/0000.png,  masks/0001.png,  ...   (from masks_all)
        transforms.json                          (copy of transform_json_path)

    Files are staged in a private temporary directory, which is always removed
    afterwards. ``output_dir`` only determines the zip file name, so an existing
    directory with that name is never written into or deleted.

    Args:
        images_all: List of float RGB images, presumably in [0, 1] (clipped).
        masks_all: List of masks, presumably in [0, 1] (clipped), saved as 8-bit PNG.
        transform_json_path: Path to the transform JSON file. This file is
            deleted when the function finishes, whether or not it succeeds.
        output_dir: Optional base name. Defaults to ``reconstruction_<timestamp>``.

    Returns:
        str: Path to the created zip file (``<output_dir>.zip``).

    Raises:
        IOError: If OpenCV fails to write an image or mask.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    base_name = os.path.normpath(output_dir or f"reconstruction_{timestamp}")
    zip_path = f"{base_name}.zip"
    staging_dir = tempfile.mkdtemp(prefix="reconstruction_")
    try:
        images_dir = os.path.join(staging_dir, "images")
        masks_dir = os.path.join(staging_dir, "masks")
        os.makedirs(images_dir)
        os.makedirs(masks_dir)

        for i, image in enumerate(images_all):
            image_path = os.path.join(images_dir, f"{i:04d}.jpg")
            # Clip before the uint8 cast; [..., ::-1] converts RGB to OpenCV's BGR.
            image_u8 = (np.clip(np.asarray(image, dtype=np.float64), 0.0, 1.0) * 255).astype(np.uint8)[..., ::-1]
            if not cv2.imwrite(image_path, image_u8, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
                raise IOError(f"Failed to write image {image_path}")

        for i, mask in enumerate(masks_all):
            mask_path = os.path.join(masks_dir, f"{i:04d}.png")
            mask_u8 = (np.clip(np.asarray(mask, dtype=np.float64), 0.0, 1.0) * 255).astype(np.uint8)
            if not cv2.imwrite(mask_path, mask_u8):
                raise IOError(f"Failed to write mask {mask_path}")

        shutil.copy2(transform_json_path, os.path.join(staging_dir, "transforms.json"))

        zip_parent = os.path.dirname(zip_path)
        if zip_parent:
            os.makedirs(zip_parent, exist_ok=True)
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _, files in os.walk(staging_dir):
                for name in files:
                    file_path = os.path.join(root, name)
                    zipf.write(file_path, os.path.relpath(file_path, staging_dir))
        return zip_path
    finally:
        # Only remove what this function created, plus the consumed JSON file.
        shutil.rmtree(staging_dir, ignore_errors=True)
        if os.path.exists(transform_json_path):
            os.remove(transform_json_path)
def get_keyframes(temp_dir: str, kf_every: int = 10):
    """
    Select keyframes from a directory of extracted frames at a fixed interval.

    Args:
        temp_dir: Directory containing extracted frames named by frame number
            (e.g. 001.jpg, 002.jpg). Files are ordered numerically by stem.
        kf_every: Select every Nth frame as a keyframe. Must be a positive
            whole number; integral floats such as 10.0 (e.g. a UI slider value)
            are accepted.

    Returns:
        List[str]: Sorted list of paths to selected keyframe images. If the
        interval yields fewer than 2 frames, the first and last frames are used.

    Raises:
        ValueError: If kf_every is not a positive whole number, or if the
            directory holds fewer than 2 frames.
    """
    try:
        step = int(kf_every)
    except (TypeError, ValueError):
        raise ValueError(f"kf_every must be a positive integer, got {kf_every!r}") from None
    # Reject non-integral values (2.5) and non-positive steps: a negative slice
    # step would silently reverse the frame order.
    if step != kf_every or step < 1:
        raise ValueError(f"kf_every must be a positive integer, got {kf_every!r}")

    # Get all jpg files, sorted by frame number to ensure correct order
    frame_paths = sorted(glob.glob(os.path.join(temp_dir, "*.jpg")),
                         key=lambda p: int(Path(p).stem))
    keyframe_paths = frame_paths[::step]

    # Ensure we have at least 2 frames for reconstruction
    if len(keyframe_paths) >= 2:
        return keyframe_paths
    if len(frame_paths) >= 2:
        return [frame_paths[0], frame_paths[-1]]
    raise ValueError(f"Not enough frames found in {temp_dir}. Need at least 2 frames for reconstruction.")
def _pnp_camera_to_world(pts, points_2d, intrinsic, dist_coeffs):
    """Estimate one frame's camera pose from its per-pixel 3D points via PnP-RANSAC.

    Args:
        pts: per-pixel 3D points, presumably (H, W, 3); flattened to (N, 3).
        points_2d: matching pixel coordinates (u, v), flattened to (N, 2).
        intrinsic: 3x3 pinhole intrinsic matrix.
        dist_coeffs: distortion coefficients passed to OpenCV.

    Returns:
        np.ndarray: 4x4 inverse of the [R|t] extrinsic estimated by OpenCV
        (i.e. the camera-to-world pose).

    Raises:
        RuntimeError: if OpenCV reports that PnP-RANSAC failed.
    """
    success, rotation_vector, translation_vector, _ = cv2.solvePnPRansac(
        pts.reshape(-1, 3).astype(np.float32),
        points_2d.reshape(-1, 2).astype(np.float32),
        intrinsic.astype(np.float32),
        dist_coeffs)
    if not success:
        # Without this check cv2.Rodrigues(None) fails with an opaque error.
        raise RuntimeError("PnP-RANSAC failed to estimate a camera pose")
    rotation_matrix, _ = cv2.Rodrigues(rotation_vector)
    extrinsic_matrix = np.eye(4)
    extrinsic_matrix[:3, :3] = rotation_matrix
    extrinsic_matrix[:3, 3] = translation_vector.reshape(-1)
    return np.linalg.inv(extrinsic_matrix)


def reconstruct(video_path, conf_thresh, kf_every,
                remove_background=False, enable_registration=True, output_3d_model=True):
    """Reconstruct a 3D model from a video (Gradio callback).

    Pipeline:
    1. Extract frames and run Spann3R to get per-frame point maps.
    2. Estimate a camera pose for each frame with PnP-RANSAC.
    3. Build colored point clouds and mesh them into a coarse mesh.
    4. Optionally refine the combined cloud with multiway registration.
    5. Export the results.

    Args:
        video_path: path of the input video.
        conf_thresh: a point is kept only if (conf - 1) / conf > conf_thresh.
        kf_every: keyframe interval forwarded to the Demo dataset.
        remove_background: if True, also require the BiRefNet mask value > 0.5.
        enable_registration: if True, replace the combined cloud with the
            (centered) output of improved_multiway_registration.
        output_3d_model: if True, export a splat PLY via point2gs; otherwise
            export the point cloud as PLY.

    Returns:
        tuple: (coarse mesh OBJ path, [splat-or-point-cloud PLY path, transforms JSON path]).
    """
    # Extract frames and load the whole sequence as a single batch.
    demo_path = extract_frames(video_path)
    try:
        dataset = Demo(ROOT=demo_path, resolution=224, full_video=True, kf_every=kf_every)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
        batch = next(iter(dataloader))
    finally:
        # The frames are decoded into `batch` now; the extracted JPEGs are no longer needed.
        shutil.rmtree(demo_path, ignore_errors=True)
    for view in batch:
        view['img'] = view['img'].to(DEFAULT_DEVICE, non_blocking=True)

    demo_name = os.path.basename(video_path)
    print(f'Started reconstruction for {demo_name}')
    start = time.time()
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        preds, _ = model.forward(batch)
    elapsed = time.time() - start
    fps = len(batch) / elapsed if elapsed > 0 else float('inf')
    print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')

    ##### estimate focal length from the first view, principal point at the center
    _, H, W, _ = preds[0]['pts3d'].shape
    pp = torch.tensor((W/2, H/2))
    focal = estimate_focal_knowing_depth(preds[0]['pts3d'].cpu(), pp, focal_mode='weiszfeld')
    print(f'Estimated focal of first camera: {focal.item()} (224x224)')
    intrinsic = np.eye(3)
    intrinsic[0, 0] = focal.item()
    intrinsic[1, 1] = focal.item()
    intrinsic[:2, 2] = pp.numpy()

    # PnP inputs shared by every frame (hoisted out of the loop).
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    points_2d = np.stack((u, v), axis=-1)
    dist_coeffs = np.zeros(4, dtype=np.float32)

    pcds = []
    poses_all = []
    images_all = []
    masks_all = []
    last_focal = None
    for j, view in enumerate(batch):
        image = (view['img'].permute(0, 2, 3, 1).cpu().numpy()[0] + 1) / 2
        pts_key = 'pts3d' if j == 0 else 'pts3d_in_other_view'
        pts_tensor = preds[j][pts_key]
        pts = pts_tensor.detach().cpu().numpy()[0]
        pts_normal = pts2normal(pts_tensor[0]).cpu().numpy()

        ##### Solve PnP-RANSAC
        poses_all.append(_pnp_camera_to_world(pts, points_2d, intrinsic, dist_coeffs))

        conf = preds[j]['conf'][0].cpu().data.numpy()
        conf_sig = (conf - 1) / conf
        if remove_background:
            mask = generate_mask(image)
        else:
            mask = np.ones_like(conf)
        combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)
        # The camera is not stored; the returned focal seeds the next frame.
        # Use the configured device rather than a hard-coded "cuda" so CPU-only
        # hosts work.
        _camera, last_focal = solve_cemara(torch.tensor(pts), torch.tensor(conf_sig) > 0.001,
                                           DEFAULT_DEVICE, focal=last_focal)

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
        pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
        pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
        pcds.append(pcd)
        images_all.append(image)
        masks_all.append(mask)

    transform_dict = get_transform_json(H, W, focal, poses_all)
    # mkstemp creates the file atomically (mktemp is race-prone and deprecated).
    json_fd, temp_json_file = tempfile.mkstemp(suffix='.json')
    with os.fdopen(json_fd, 'w') as f:
        json.dump(transform_dict, f, indent=4)

    # Coarse result: merged clouds -> mesh, centered and scaled to unit radius.
    pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
    o3d_geometry = point2mesh(pcd_combined)
    o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
    coarse_output_path = export_geometry(o3d_geometry_centered)

    if enable_registration:
        pcd_combined, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
        pcd_combined = center_pcd(pcd_combined)

    # NOTE: packaging images/masks/transforms into a zip is currently disabled:
    # zip_path = organize_and_zip_output(images_all, masks_all, temp_json_file)
    if output_3d_model:
        gs_fd, gs_output_path = tempfile.mkstemp(suffix='.ply')
        os.close(gs_fd)
        point2gs(gs_output_path, pcd_combined)
        return coarse_output_path, [gs_output_path, temp_json_file]
    pcd_output_path = export_geometry(pcd_combined, file_format='ply')
    return coarse_output_path, [pcd_output_path, temp_json_file]
# Paths of the .mp4/.webm example videos under ./examples.
# NOTE(review): `example_videos` is not referenced in this file; the Examples
# widget below re-scans ./examples and keeps only .webm files.
example_videos = [os.path.join('./examples', f) for f in os.listdir('./examples') if f.endswith(('.mp4', '.webm'))]
# Update the Gradio interface with improved layout
# NOTE(review): indentation was lost in the copy this was recovered from; the
# widget nesting below is reconstructed from the layout calls -- confirm it
# against the intended UI.
with gr.Blocks(
    title="StableRecon: 3D Reconstruction from Video",
    css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
""",
    # Extra <head> markup: the Google Analytics (gtag.js) snippet.
    head="""
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-1FWSVCGZTG');
</script>
""",
) as iface:
    # Page header: title, badges and project description.
    gr.Markdown(
        """
# StableRecon: Making Video to 3D easy
<p align="center">
<a title="Github" href="https://github.com/Stable-X/StableRecon" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/Stable-X/StableRecon?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<div style="background-color: #f0f0f0; padding: 10px; border-radius: 5px; margin-bottom: 20px;">
<strong>📢 About StableRecon:</strong> This is an experimental open-source project building on <a href="https://github.com/naver/dust3r" target="_blank">dust3r</a> and <a href="https://github.com/HengyiWang/spann3r" target="_blank">spann3r</a>. We're exploring video-to-3D conversion, using spann3r for tracking and implementing our own backend and meshing. While it's a work in progress with plenty of room for improvement, we're excited to share it with the community. We welcome your feedback, especially on failure cases, as we continue to develop and refine this tool.
</div>
"""
    )
    with gr.Row():
        # Left column: input video and reconstruction options.
        with gr.Column(scale=1):
            video_input = gr.Video(label="Input Video", sources=["upload"])
            with gr.Row():
                # Passed to reconstruct() as conf_thresh / kf_every.
                conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
                kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
            with gr.Row():
                remove_background = gr.Checkbox(label="Remove Background", value=False)
                # Maps to reconstruct(enable_registration=...).
                enable_registration = gr.Checkbox(
                    label="Enable Refinement",
                    value=False,
                    info="Improves alignment but takes longer"
                )
                # Maps to reconstruct(output_3d_model=...): splat PLY vs point-cloud PLY.
                output_3d_model = gr.Checkbox(
                    label="Output Splat",
                    value=True,
                    info="Generate Splat (PLY) instead of Point Cloud (PLY)"
                )
            reconstruct_btn = gr.Button("Start Reconstruction")
        # Right column: results.
        with gr.Column(scale=2):
            with gr.Tab("3D Models"):
                with gr.Group():
                    # Shows the coarse mesh (first value returned by reconstruct()).
                    initial_model = gr.Model3D(
                        label="Reconstructed Mesh",
                        display_mode="solid",
                        clear_color=[0.0, 0.0, 0.0, 0.0]
                    )
                with gr.Group():
                    # Receives the list [splat-or-point-cloud path, transforms json path].
                    output_model = gr.File(
                        label="Reconstructed Results",
                    )
    # Example videos (.webm only). Only video_input is listed as an input;
    # with cache_examples=False, presumably clicking an example just fills the
    # video input rather than running reconstruct() -- confirm with Gradio docs.
    Examples(
        fn=reconstruct,
        examples=sorted([
            os.path.join("examples", name)
            for name in os.listdir(os.path.join("examples")) if name.endswith('.webm')
        ]),
        inputs=[video_input],
        outputs=[initial_model, output_model],
        directory_name="examples_video",
        cache_examples=False,
    )
    # The input order must match reconstruct()'s positional parameters.
    reconstruct_btn.click(
        fn=reconstruct,
        inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
        outputs=[initial_model, output_model]
    )

if __name__ == "__main__":
    # Listen on all interfaces.
    iface.launch(server_name="0.0.0.0")