# StableRecon / app.py
import os
import time
import torch
import numpy as np
import gradio as gr
import urllib.parse
import tempfile
import subprocess
import shutil
from dust3r.losses import L21
from spann3r.model import Spann3R
from spann3r.datasets import Demo
from torch.utils.data import DataLoader
import trimesh
from scipy.spatial.transform import Rotation
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
from PIL import Image
import open3d as o3d
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
from gs_utils import point2gs
from gradio.helpers import Examples as GradioExamples
from gradio.utils import get_cache_folder
from pathlib import Path
# Default values
DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
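# OpenGL camera convention flips Y and Z relative to the OpenCV-style axes
# (x right, y down, z forward) used by the reconstruction pipeline.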
OPENGL = np.array([[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1]])
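# Thin wrapper around gradio's Examples that stores cached outputs in a
# custom subdirectory of the Gradio cache folder.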
class Examples(GradioExamples):
def __init__(self, *args, directory_name=None, **kwargs):
super().__init__(*args, **kwargs, _initiated_directly=False)
if directory_name is not None:
self.cached_folder = get_cache_folder() / directory_name
self.cached_file = Path(self.cached_folder) / "log.csv"
self.create()
def export_geometry(geometry):
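    """Write an Open3D triangle mesh to a temporary .obj, first rotating it
    out of the OpenGL viewer convention."""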
    fd, output_path = tempfile.mkstemp(suffix='.obj')  # mkstemp avoids the race in the deprecated tempfile.mktemp
    os.close(fd)
# Apply rotation
rot = np.eye(4)
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
transform = np.linalg.inv(OPENGL @ rot)
geometry.transform(transform)
o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
return output_path
def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
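    """Sample the first `duration` seconds of `video_path` at `fps` frames per
    second with ffmpeg; returns the temporary directory of numbered JPEGs."""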
temp_dir = tempfile.mkdtemp()
output_path = os.path.join(temp_dir, "%03d.jpg")
filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"
command = [
"ffmpeg",
"-i", video_path,
"-vf", filter_complex,
"-vsync", "0",
output_path
]
subprocess.run(command, check=True)
return temp_dir
def cat_meshes(meshes):
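    """Concatenate a list of vertex/face dicts into one mesh, offsetting each
    sub-mesh's face indices by the running vertex count."""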
vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
n_vertices = np.cumsum([0]+[len(v) for v in vertices])
for i in range(len(faces)):
faces[i][:] += n_vertices[i]
vertices = np.concatenate(vertices)
colors = np.concatenate(colors)
faces = np.concatenate(faces)
return dict(vertices=vertices, face_colors=colors, faces=faces)
def load_ckpt(model_path_or_url, verbose=True):
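    """Load a checkpoint dict from a local path or an http(s) URL onto the CPU."""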
if verbose:
print('... loading model from', model_path_or_url)
is_url = urllib.parse.urlparse(model_path_or_url).scheme in ('http', 'https')
if is_url:
ckpt = torch.hub.load_state_dict_from_url(model_path_or_url, map_location='cpu', progress=verbose)
else:
ckpt = torch.load(model_path_or_url, map_location='cpu')
return ckpt
def load_model(ckpt_path, device):
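    """Build a Spann3R model (with the DUSt3R backbone) and load pretrained weights."""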
model = Spann3R(dus3r_name=DEFAULT_DUST3R_PATH,
use_feat=False).to(device)
model.load_state_dict(load_ckpt(ckpt_path)['model'])
model.eval()
return model
def pts3d_to_trimesh(img, pts3d, valid=None):
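    """Convert an (H, W, 3) point map plus its matching image into a triangle
    soup: each pixel quad becomes two triangles, each emitted in both windings
    so the mesh renders regardless of face culling."""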
H, W, THREE = img.shape
assert THREE == 3
assert img.shape == pts3d.shape
vertices = pts3d.reshape(-1, 3)
# make squares: each pixel == 2 triangles
idx = np.arange(len(vertices)).reshape(H, W)
idx1 = idx[:-1, :-1].ravel() # top-left corner
idx2 = idx[:-1, +1:].ravel() # top-right corner
idx3 = idx[+1:, :-1].ravel() # bottom-left corner
idx4 = idx[+1:, +1:].ravel() # bottom-right corner
faces = np.concatenate((
np.c_[idx1, idx2, idx3],
np.c_[idx3, idx2, idx1], # same triangle, but backward (cheap solution to cancel face culling)
np.c_[idx2, idx3, idx4],
np.c_[idx4, idx3, idx2], # same triangle, but backward (cheap solution to cancel face culling)
), axis=0)
# prepare triangle colors
face_colors = np.concatenate((
img[:-1, :-1].reshape(-1, 3),
img[:-1, :-1].reshape(-1, 3),
img[+1:, +1:].reshape(-1, 3),
img[+1:, +1:].reshape(-1, 3)
), axis=0)
# remove invalid faces
if valid is not None:
assert valid.shape == (H, W)
valid_idxs = valid.ravel()
valid_faces = valid_idxs[faces].all(axis=-1)
faces = faces[valid_faces]
face_colors = face_colors[valid_faces]
assert len(faces) == len(face_colors)
return dict(vertices=vertices, face_colors=face_colors, faces=faces)
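# Instantiate the reconstruction and segmentation models once at import time so
# every Gradio request shares them.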
model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
birefnet.to(DEFAULT_DEVICE)
birefnet.eval()
def extract_object(birefnet, image):
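    """Segment the foreground object with BiRefNet; returns a PIL mask resized
    back to the input image's size."""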
# Data settings
image_size = (1024, 1024)
transform_image = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
input_images = transform_image(image).unsqueeze(0).to(DEFAULT_DEVICE)
# Prediction
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image.size)
return mask
def generate_mask(image: np.ndarray):
# Convert numpy array to PIL Image
pil_image = Image.fromarray((image * 255).astype(np.uint8))
# Extract object and get mask
mask = extract_object(birefnet, pil_image)
# Convert mask to numpy array
mask_np = np.array(mask) / 255.0
return mask_np
def center_pcd(pcd: o3d.geometry.PointCloud, normalize=False) -> o3d.geometry.PointCloud:
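    """Translate a point cloud so its centroid sits at the origin; optionally
    scale it to fit inside the unit sphere."""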
# Convert to numpy array
points = np.asarray(pcd.points)
# Compute centroid
centroid = np.mean(points, axis=0)
# Center the point cloud
centered_points = points - centroid
if normalize:
# Compute the maximum distance from the center
max_distance = np.max(np.linalg.norm(centered_points, axis=1))
# Normalize the point cloud
normalized_points = centered_points / max_distance
# Create a new point cloud with the normalized points
normalized_pcd = o3d.geometry.PointCloud()
normalized_pcd.points = o3d.utility.Vector3dVector(normalized_points)
# If the original point cloud has colors, normalize them too
if pcd.has_colors():
normalized_pcd.colors = pcd.colors
# If the original point cloud has normals, copy them
if pcd.has_normals():
normalized_pcd.normals = pcd.normals
return normalized_pcd
else:
pcd.points = o3d.utility.Vector3dVector(centered_points)
return pcd
@torch.no_grad()
def reconstruct(video_path, conf_thresh=1e-3, kf_every=1,
                remove_background=False):
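    """Run the two-stage reconstruction as a generator: yields a quick coarse
    mesh first, then the coarse mesh together with a refined Gaussian-splat .ply."""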
# Extract frames from video
demo_path = extract_frames(video_path)
# Load dataset
dataset = Demo(ROOT=demo_path, resolution=224, full_video=True, kf_every=kf_every)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
batch = next(iter(dataloader))
for view in batch:
view['img'] = view['img'].to(DEFAULT_DEVICE, non_blocking=True)
demo_name = os.path.basename(video_path)
print(f'Started reconstruction for {demo_name}')
start = time.time()
preds, preds_all = model.forward(batch)
end = time.time()
fps = len(batch) / (end - start)
print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')
# Process results
pcds = []
for j, view in enumerate(batch):
image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
image = (image + 1) / 2
pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
pts_normal = pts2normal(preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'][0]).cpu().numpy()
conf = preds[j]['conf'][0].cpu().data.numpy()
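        # DUSt3R confidences are >= 1, so (conf - 1) / conf maps them into [0, 1) for thresholding.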
conf_sig = (conf - 1) / conf
if remove_background:
mask = generate_mask(image)
else:
mask = np.ones_like(conf)
combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
pcds.append(pcd)
pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
pcd_combined = center_pcd(pcd_combined, normalize=True)
o3d_geometry = point2mesh(pcd_combined)
# Create coarse result
coarse_output_path = export_geometry(o3d_geometry)
yield coarse_output_path, None
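    # Refine cross-view alignment with multiway (pose-graph) registration before splatting.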
transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
transformed_pcds = center_pcd(transformed_pcds)
    # Create refined result as a Gaussian-splat .ply
    fd, refined_output_path = tempfile.mkstemp(suffix='.ply')
    os.close(fd)
point2gs(refined_output_path, transformed_pcds)
yield coarse_output_path, refined_output_path
    # Clean up the temporary frame directory
    shutil.rmtree(demo_path, ignore_errors=True)
example_videos = [os.path.join('./examples', f) for f in os.listdir('./examples') if f.endswith(('.mp4', '.webm'))]
# Update the Gradio interface with improved layout
with gr.Blocks(
title="StableRecon: 3D Reconstruction from Video",
css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
""",
head="""
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-1FWSVCGZTG');
</script>
""",
) as iface:
gr.Markdown(
"""
# StableRecon: Making Video to 3D easy
<p align="center">
<a title="Github" href="https://github.com/Stable-X/StableRecon" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/Stable-X/StableRecon?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<div style="background-color: #f0f0f0; padding: 10px; border-radius: 5px; margin-bottom: 20px;">
<strong>📢 About StableRecon:</strong> This is an experimental open-source project building on <a href="https://github.com/naver/dust3r" target="_blank">dust3r</a> and <a href="https://github.com/HengyiWang/spann3r" target="_blank">spann3r</a>. We're exploring video-to-3D conversion, using spann3r for tracking and implementing our own backend and meshing. While it's a work in progress with plenty of room for improvement, we're excited to share it with the community. We welcome your feedback, especially on failure cases, as we continue to develop and refine this tool.
</div>
"""
)
with gr.Row():
with gr.Column(scale=1):
video_input = gr.Video(label="Input Video", sources=["upload"])
with gr.Row():
conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
with gr.Row():
remove_background = gr.Checkbox(label="Remove Background", value=False)
reconstruct_btn = gr.Button("Start Reconstruction")
with gr.Column(scale=2):
with gr.Tab("3D Models"):
with gr.Group():
initial_model = gr.Model3D(label="Initial 3D Model", display_mode="solid",
clear_color=[0.0, 0.0, 0.0, 0.0])
gr.Markdown(
"""
<div class="model-description">
This is the initial 3D model generated from the video. It typically appears within 10 seconds.
</div>
"""
)
with gr.Group():
optimized_model = gr.Model3D(label="Optimized 3D Model", display_mode="solid",
clear_color=[0.0, 0.0, 0.0, 0.0])
gr.Markdown(
"""
<div class="model-description">
This is the optimized 3D model, refined with Gaussian Splatting for improved accuracy and detail. It typically finishes within 60 seconds.
</div>
"""
)
with gr.Tab("Help"):
gr.Markdown(
"""
## How to use this tool:
1. Upload a video of the object you want to reconstruct.
2. Adjust the Confidence Threshold and Keyframe Interval if needed.
3. Choose whether to remove the background.
4. Click "Start Reconstruction" to begin the process.
5. The Initial 3D Model will appear first, giving you a quick preview.
6. Once processing is complete, the Optimized 3D Model will show the final result.
### Tips:
- For best results, ensure your video captures the object from multiple angles.
- If the model appears noisy, try increasing the Confidence Threshold.
- Experiment with different Keyframe Intervals to balance speed and accuracy.
"""
)
Examples(
fn=reconstruct,
examples=sorted([
            os.path.join("examples", name)
            for name in os.listdir("examples") if name.endswith('.webm')
]),
inputs=[video_input],
outputs=[initial_model, optimized_model],
directory_name="examples_video",
cache_examples=False,
)
reconstruct_btn.click(
fn=reconstruct,
inputs=[video_input, conf_thresh, kf_every, remove_background],
outputs=[initial_model, optimized_model]
)
if __name__ == "__main__":
iface.launch(server_name="0.0.0.0")