# StableRecon / app.py (Stable-X)
# Commit 2c5f88b: "feat: Add rendering output and refinement flag" (17.8 kB)
import os
import time
import torch
import numpy as np
import gradio as gr
import urllib.parse
import tempfile
import shutil
import subprocess
from dust3r.losses import L21
from spann3r.model import Spann3R
from spann3r.datasets import Demo
from torch.utils.data import DataLoader
import trimesh
from scipy.spatial.transform import Rotation
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
from PIL import Image
import open3d as o3d
from spann3r.tools.vis import render_frames
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
from gs_utils import point2gs
from pose_utils import solve_cemara
from gradio.helpers import Examples as GradioExamples
from gradio.utils import get_cache_folder
from pathlib import Path
# Default values
DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# 4x4 homogeneous matrix that negates the y and z axes (a 180-degree rotation
# about x). The name suggests the OpenCV <-> OpenGL camera-axis convention.
OPENGL = np.diag([1, -1, -1, 1])
class Examples(GradioExamples):
    """``gr.Examples`` variant whose example cache can live in a named sub-folder.

    When ``directory_name`` is given, the cache folder becomes
    ``get_cache_folder() / directory_name`` and the cache log is ``log.csv``
    inside it. The component is then built via ``create()``.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        # Tell the base class it was not initiated directly; presumably gradio
        # then leaves create() to the caller, which is why it is invoked below.
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if directory_name is not None:
            cache_dir = get_cache_folder() / directory_name
            self.cached_folder = cache_dir
            self.cached_file = Path(cache_dir) / "log.csv"
        # Runs whether or not the cache location was overridden.
        self.create()
def export_geometry(geometry):
    """Write a mesh to a fresh temporary ``.obj`` file and return its path.

    Before writing, the mesh is transformed by ``inv(OPENGL @ R_y(180 deg))``.
    Open3D's ``transform`` works in place, so the caller's ``geometry`` object
    is modified as a side effect.

    Args:
        geometry: An Open3D triangle mesh.

    Returns:
        Path of the written ``.obj`` file, inside a newly created temp directory.

    Raises:
        OSError: if Open3D reports that writing the mesh failed.
    """
    # A private directory from mkdtemp guarantees the path is fresh and cannot
    # be pre-created by another process (tempfile.mktemp is race-prone).
    output_path = os.path.join(tempfile.mkdtemp(), 'model.obj')
    # Apply rotation
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    transform = np.linalg.inv(OPENGL @ rot)
    geometry.transform(transform)
    # write_triangle_mesh returns False on failure instead of raising; surface
    # it rather than handing the UI a path that does not exist.
    if not o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True):
        raise OSError(f"Failed to write mesh to {output_path}")
    return output_path
def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
    """Sample frames from the start of a video into a new temp directory with ffmpeg.

    Args:
        video_path: Path to the input video file.
        duration: Only frames with timestamp ``t < duration`` seconds are kept.
        fps: Sampling rate, in frames per second, applied to the kept frames.

    Returns:
        Path of a newly created temporary directory holding ``001.jpg``,
        ``002.jpg``, ... The caller owns the directory and must delete it.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        In both cases the temporary directory is removed before re-raising.
    """
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")
    # `select` keeps frames with t < duration, then `fps` resamples them.
    filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", filter_complex,
        # Passthrough timestamps so ffmpeg does not duplicate frames.
        "-vsync", "0",
        output_path,
    ]
    try:
        subprocess.run(command, check=True)
    except Exception:
        # Don't leak the (possibly half-filled) frame directory on failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
def cat_meshes(meshes):
    """Concatenate several mesh dicts into a single mesh dict.

    Each input is a dict with ``'vertices'``, ``'faces'`` and ``'face_colors'``
    arrays, the same layout that ``pts3d_to_trimesh`` returns. The face indices
    of each mesh are offset by the total vertex count of the meshes before it.
    The input arrays are left untouched.

    Args:
        meshes: Non-empty iterable of mesh dicts.

    Returns:
        ``dict(vertices=..., face_colors=..., faces=...)`` with concatenated arrays.

    Raises:
        ValueError: if ``meshes`` is empty (nothing to concatenate).
    """
    all_vertices, all_faces, all_colors = [], [], []
    offset = 0
    for mesh in meshes:
        all_vertices.append(mesh['vertices'])
        # Shift a copy of the indices; never modify the caller's face array.
        all_faces.append(np.asarray(mesh['faces']) + offset)
        all_colors.append(mesh['face_colors'])
        offset += len(mesh['vertices'])
    return dict(
        vertices=np.concatenate(all_vertices),
        face_colors=np.concatenate(all_colors),
        faces=np.concatenate(all_faces),
    )
def load_ckpt(model_path_or_url, verbose=True):
    """Load a checkpoint object onto the CPU from a local path or an http(s) URL.

    Args:
        model_path_or_url: Filesystem path, or an ``http``/``https`` URL that is
            fetched through ``torch.hub.load_state_dict_from_url``.
        verbose: Print the source before loading and show a download progress bar.

    Returns:
        Whatever object the checkpoint file contains, with tensors mapped to CPU.
    """
    if verbose:
        print('... loading model from', model_path_or_url)
    scheme = urllib.parse.urlparse(model_path_or_url).scheme
    if scheme not in ('http', 'https'):
        return torch.load(model_path_or_url, map_location='cpu')
    return torch.hub.load_state_dict_from_url(model_path_or_url, map_location='cpu', progress=verbose)
def load_model(ckpt_path, device):
    """Build a Spann3R network on ``device`` and load its weights.

    The backbone location ``DEFAULT_DUST3R_PATH`` is passed as ``dus3r_name``.
    The checkpoint at ``ckpt_path`` must store the state dict under ``'model'``.

    Returns:
        The network, in eval mode.
    """
    network = Spann3R(dus3r_name=DEFAULT_DUST3R_PATH,
                      use_feat=False)
    network = network.to(device)
    state_dict = load_ckpt(ckpt_path)['model']
    network.load_state_dict(state_dict)
    network.eval()
    return network
def pts3d_to_trimesh(img, pts3d, valid=None):
    """Triangulate a per-pixel point map into a coloured mesh dict.

    Every 2x2 pixel neighbourhood becomes two triangles. Each triangle is also
    emitted with reversed winding, a cheap way to cancel back-face culling.

    Args:
        img: (H, W, 3) colour image, one colour per pixel.
        pts3d: (H, W, 3) 3D point per pixel; must have the same shape as ``img``.
        valid: Optional (H, W) boolean mask. Faces touching an invalid pixel are dropped.

    Returns:
        ``dict(vertices=(H*W, 3), face_colors=(F, 3), faces=(F, 3))``.
    """
    H, W, channels = img.shape
    assert channels == 3
    assert img.shape == pts3d.shape
    vertices = pts3d.reshape(-1, 3)

    grid = np.arange(len(vertices)).reshape(H, W)
    top_left = grid[:-1, :-1].ravel()
    top_right = grid[:-1, 1:].ravel()
    bottom_left = grid[1:, :-1].ravel()
    bottom_right = grid[1:, 1:].ravel()

    upper = np.stack([top_left, top_right, bottom_left], axis=1)
    lower = np.stack([top_right, bottom_left, bottom_right], axis=1)
    # Each triangle plus its reversed-winding twin.
    faces = np.concatenate([upper, upper[:, ::-1], lower, lower[:, ::-1]], axis=0)

    # Upper triangles take the top-left pixel colour, lower ones the bottom-right.
    upper_colors = img[:-1, :-1].reshape(-1, 3)
    lower_colors = img[1:, 1:].reshape(-1, 3)
    face_colors = np.concatenate([upper_colors, upper_colors, lower_colors, lower_colors], axis=0)

    if valid is not None:
        assert valid.shape == (H, W)
        keep = valid.ravel()[faces].all(axis=-1)
        faces, face_colors = faces[keep], face_colors[keep]

    assert len(faces) == len(face_colors)
    return dict(vertices=vertices, face_colors=face_colors, faces=faces)
# Both networks are loaded once at import time and shared by every request.
model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
# NOTE: trust_remote_code=True executes the model code published in the
# 'zhengpeng7/BiRefNet' Hub repository.
birefnet = (
    AutoModelForImageSegmentation
    .from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
    .to(DEFAULT_DEVICE)
    .eval()
)
def extract_object(birefnet, image):
    """Predict a foreground mask for a PIL image with the given segmentation model.

    Args:
        birefnet: Segmentation model. It is called on a normalized
            (1, 3, 1024, 1024) batch, and the sigmoid of its last output is used
            as the foreground probability.
        image: PIL image of any mode. It is converted to RGB before inference.

    Returns:
        The predicted mask as a PIL image resized back to ``image.size``;
        brighter means more likely foreground.
    """
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # Normalize expects exactly 3 channels; RGBA / greyscale / palette inputs
    # would otherwise fail there. For RGB input this is a no-op copy.
    rgb_image = image.convert('RGB')
    input_images = transform_image(rgb_image).unsqueeze(0).to(DEFAULT_DEVICE)
    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    return pred_pil.resize(image.size)
def generate_mask(image: np.ndarray):
    """Compute a soft foreground mask for a float RGB image using the module-level BiRefNet.

    Args:
        image: (H, W, 3) float image whose values are expected in [0, 1].

    Returns:
        The mask as a float array scaled to [0, 1] (presumably (H, W) — it is
        the model's mask resized to the image size).
    """
    # Clip before the uint8 cast: out-of-range floats (e.g. 1.01 * 255) would
    # otherwise overflow and wrap to unpredictable values.
    pil_image = Image.fromarray((np.clip(image, 0.0, 1.0) * 255).astype(np.uint8))
    mask = extract_object(birefnet, pil_image)
    return np.array(mask) / 255.0
def center_pcd(pcd: o3d.geometry.PointCloud, normalize=False) -> o3d.geometry.PointCloud:
    """Translate a point cloud so its centroid sits at the origin.

    Args:
        pcd: Input point cloud.
        normalize: If False, ``pcd`` is modified in place and returned. If True,
            the centered points are also scaled so the farthest point is at
            distance 1, and a *new* point cloud is returned with the input's
            colors and normals copied over. Scaling is skipped when all points
            coincide.

    Returns:
        The centered (and optionally normalized) point cloud.
    """
    points = np.asarray(pcd.points)
    # An empty cloud has no centroid; use the origin instead of averaging nothing.
    centroid = points.mean(axis=0) if len(points) else np.zeros(3)
    centered_points = points - centroid
    if not normalize:
        pcd.points = o3d.utility.Vector3dVector(centered_points)
        return pcd
    # initial=0.0 keeps max() well defined for an empty cloud.
    max_distance = np.linalg.norm(centered_points, axis=1).max(initial=0.0)
    if max_distance > 0:
        # Skip scaling for degenerate clouds to avoid 0/0 -> NaN.
        centered_points = centered_points / max_distance
    normalized_pcd = o3d.geometry.PointCloud()
    normalized_pcd.points = o3d.utility.Vector3dVector(centered_points)
    # Colors and normals are unaffected by translation + uniform scaling; copy as-is.
    if pcd.has_colors():
        normalized_pcd.colors = pcd.colors
    if pcd.has_normals():
        normalized_pcd.normals = pcd.normals
    return normalized_pcd
def center_mesh(mesh: o3d.geometry.TriangleMesh, normalize=False) -> o3d.geometry.TriangleMesh:
    """Translate a triangle mesh so its vertex centroid sits at the origin.

    Args:
        mesh: Input mesh.
        normalize: If False, ``mesh`` is modified in place and returned. If True,
            the centered vertices are also scaled so the farthest vertex is at
            distance 1, and a *new* mesh is returned. It carries the input's
            triangles, vertex colors and unit-length vertex normals; other
            attributes such as triangle normals are not copied. Scaling is
            skipped when all vertices coincide.

    Returns:
        The centered (and optionally normalized) mesh.
    """
    vertices = np.asarray(mesh.vertices)
    # An empty mesh has no centroid; use the origin instead of averaging nothing.
    centroid = vertices.mean(axis=0) if len(vertices) else np.zeros(3)
    centered_vertices = vertices - centroid
    if not normalize:
        mesh.vertices = o3d.utility.Vector3dVector(centered_vertices)
        return mesh
    # initial=0.0 keeps max() well defined for an empty mesh.
    max_distance = np.linalg.norm(centered_vertices, axis=1).max(initial=0.0)
    if max_distance > 0:
        # Skip scaling for degenerate meshes to avoid 0/0 -> NaN.
        centered_vertices = centered_vertices / max_distance
    normalized_mesh = o3d.geometry.TriangleMesh()
    normalized_mesh.vertices = o3d.utility.Vector3dVector(centered_vertices)
    normalized_mesh.triangles = mesh.triangles
    if mesh.has_vertex_colors():
        normalized_mesh.vertex_colors = mesh.vertex_colors
    if mesh.has_vertex_normals():
        normals = np.asarray(mesh.vertex_normals)
        lengths = np.linalg.norm(normals, axis=1, keepdims=True)
        # Rescale to unit length; zero-length normals stay zero instead of NaN.
        unit_normals = np.divide(normals, lengths, out=np.zeros_like(normals), where=lengths > 0)
        normalized_mesh.vertex_normals = o3d.utility.Vector3dVector(unit_normals)
    return normalized_mesh
def _views_to_pointclouds(batch, preds, conf_thresh, remove_background):
    """Turn per-view Spann3R predictions into filtered, coloured point clouds plus cameras.

    For each view, points are kept when their squashed confidence exceeds
    ``conf_thresh``. When ``remove_background`` is set, they must also lie on
    the BiRefNet foreground. A camera is estimated with ``solve_cemara``,
    reusing the previous view's focal length.

    Returns:
        ``(pcds, cameras)``: one Open3D point cloud and one camera per view.
    """
    pcds = []
    cameras = []
    last_focal = None
    for j, view in enumerate(batch):
        # View 0 is read from 'pts3d'; later views from 'pts3d_in_other_view'
        # (presumably expressed in the first view's frame).
        pts_key = 'pts3d' if j == 0 else 'pts3d_in_other_view'
        image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
        # Assumes the network input is normalized to [-1, 1]; map back to [0, 1].
        image = (image + 1) / 2
        pts = preds[j][pts_key].detach().cpu().numpy()[0]
        pts_normal = pts2normal(preds[j][pts_key][0]).cpu().numpy()
        conf = preds[j]['conf'][0].cpu().data.numpy()
        # (c - 1) / c maps c >= 1 (which this appears to assume) into [0, 1).
        conf_sig = (conf - 1) / conf
        if remove_background:
            mask = generate_mask(image)
        else:
            mask = np.ones_like(conf)
        combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)
        # Camera solving uses its own fixed 0.001 threshold, not conf_thresh.
        camera, last_focal = solve_cemara(torch.tensor(pts), torch.tensor(conf_sig) > 0.001,
                                          DEFAULT_DEVICE, focal=last_focal)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
        pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
        pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
        pcds.append(pcd)
        cameras.append(camera)
    return pcds, cameras


@torch.no_grad()
def reconstruct(video_path, conf_thresh=1e-3, kf_every=1,
                remove_background=False, enable_registration=True, output_3d_model=True):
    """Reconstruct a 3D model from a video; a generator for Gradio streaming outputs.

    Args:
        video_path: Path of the uploaded video. ``None`` raises a user-facing ``gr.Error``.
        conf_thresh: Minimum squashed confidence for a point to be kept.
        kf_every: Keyframe interval for the ``Demo`` dataset. It is cast to int
            because UI sliders may deliver floats.
        remove_background: Drop points that BiRefNet classifies as background.
        enable_registration: Run multiway registration before building the splat.
        output_3d_model: If True, produce a Gaussian-splat ``.ply``; otherwise
            return the output of ``render_frames``.

    Yields:
        ``(coarse_obj_path, None)`` once the coarse mesh is exported, then
        ``(coarse_obj_path, refined_path)``.

    The extracted frame directory is deleted when the generator finishes,
    raises, or is closed.
    """
    if video_path is None:
        raise gr.Error("Please upload a video first.")
    demo_path = extract_frames(video_path)
    try:
        # Load dataset
        dataset = Demo(ROOT=demo_path, resolution=224, full_video=True, kf_every=int(kf_every))
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
        batch = next(iter(dataloader))
        for view in batch:
            view['img'] = view['img'].to(DEFAULT_DEVICE, non_blocking=True)

        demo_name = os.path.basename(video_path)
        print(f'Started reconstruction for {demo_name}')
        start = time.perf_counter()
        preds, _ = model.forward(batch)
        elapsed = time.perf_counter() - start
        fps = len(batch) / elapsed
        print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')

        pcds, cameras_all = _views_to_pointclouds(batch, preds, conf_thresh, remove_background)

        pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
        o3d_geometry = point2mesh(pcd_combined)
        o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
        # Create coarse result
        coarse_output_path = export_geometry(o3d_geometry_centered)
        yield coarse_output_path, None

        if output_3d_model:
            # Build the splat only when it is actually returned. A fresh private
            # directory replaces the race-prone tempfile.mktemp.
            gs_output_path = os.path.join(tempfile.mkdtemp(), 'splat.ply')
            if enable_registration:
                transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
                transformed_pcds = center_pcd(transformed_pcds)
                point2gs(gs_output_path, transformed_pcds)
            else:
                point2gs(gs_output_path, pcd_combined)
            yield coarse_output_path, gs_output_path
        else:
            # Render with the un-normalized mesh so it stays aligned with the cameras.
            render_video_path = render_frames(o3d_geometry, cameras_all, demo_path)
            yield coarse_output_path, render_video_path
    finally:
        # Clean up the extracted frames even on error or early generator close.
        shutil.rmtree(demo_path, ignore_errors=True)
# Every .mp4 / .webm clip in ./examples, in directory-listing order.
# (The Examples widget below builds its own sorted, .webm-only list.)
example_videos = [
    os.path.join('./examples', clip)
    for clip in os.listdir('./examples')
    if clip.endswith(('.mp4', '.webm'))
]
# Gradio interface: inputs on the left; coarse mesh and refined result on the right.
with gr.Blocks(
    title="StableRecon: 3D Reconstruction from Video",
    css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
""",
    head="""
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-1FWSVCGZTG');
</script>
""",
) as iface:
    gr.Markdown(
        """
# StableRecon: Making Video to 3D easy
<p align="center">
<a title="Github" href="https://github.com/Stable-X/StableRecon" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/Stable-X/StableRecon?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<div style="background-color: #f0f0f0; padding: 10px; border-radius: 5px; margin-bottom: 20px;">
<strong>📢 About StableRecon:</strong> This is an experimental open-source project building on <a href="https://github.com/naver/dust3r" target="_blank">dust3r</a> and <a href="https://github.com/HengyiWang/spann3r" target="_blank">spann3r</a>. We're exploring video-to-3D conversion, using spann3r for tracking and implementing our own backend and meshing. While it's a work in progress with plenty of room for improvement, we're excited to share it with the community. We welcome your feedback, especially on failure cases, as we continue to develop and refine this tool.
</div>
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.Video(label="Input Video", sources=["upload"])
            with gr.Row():
                conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
                kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
            with gr.Row():
                remove_background = gr.Checkbox(label="Remove Background", value=False)
                enable_registration = gr.Checkbox(
                    label="Enable Refinement",
                    value=False,
                    info="Improves alignment but takes longer"
                )
                output_3d_model = gr.Checkbox(
                    label="Output Splat",
                    value=True,
                    info="Generate Splat (PLY) instead of video render"
                )
            reconstruct_btn = gr.Button("Start Reconstruction")
        with gr.Column(scale=2):
            with gr.Tab("3D Models"):
                with gr.Group():
                    initial_model = gr.Model3D(
                        label="Initial 3D Model",
                        display_mode="solid",
                        clear_color=[0.0, 0.0, 0.0, 0.0]
                    )
                    gr.Markdown(
                        """
<div class="model-description">
This is the initial 3D model generated from the video. Finishes within 10 seconds.
</div>
"""
                    )
                with gr.Group():
                    output_model = gr.File(
                        label="Refined Result (Splat or Video)",
                        file_types=[".ply", ".mp4"],
                        file_count="single"
                    )
                    # Use HTML list markup: Markdown list syntax is not rendered
                    # inside a raw HTML <div> block.
                    gr.Markdown(
                        """
<div class="model-description">
Downloads as either:
<ul>
<li>PLY file: Gaussian Splat model (when "Output Splat" is enabled)</li>
<li>MP4 file: 360° rotating render video (when "Output Splat" is disabled)</li>
</ul>
Time: ~60 seconds with refinement, ~30 seconds without
</div>
"""
                    )
    # cache_examples=False: clicking an example only fills the video input.
    Examples(
        fn=reconstruct,
        examples=sorted(
            os.path.join("examples", name)
            for name in os.listdir("examples")
            if name.endswith('.webm')
        ),
        inputs=[video_input],
        outputs=[initial_model, output_model],
        directory_name="examples_video",
        cache_examples=False,
    )
    reconstruct_btn.click(
        fn=reconstruct,
        inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
        outputs=[initial_model, output_model]
    )

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0")