# Hugging Face Spaces page header (scrape residue, not part of the program): "Spaces: Running on L40S"
import os
import shutil
import subprocess
import tempfile
import time
import urllib.parse

# Third-party imports keep their original relative order (torch before open3d).
import torch
import numpy as np
import gradio as gr
from torch.utils.data import DataLoader
import trimesh
from scipy.spatial.transform import Rotation
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
from PIL import Image
import open3d as o3d

from dust3r.losses import L21
from spann3r.model import Spann3R
from spann3r.datasets import Demo
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
from gs_utils import point2gs
# Module-wide defaults: checkpoint locations and the torch device to run on.
DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
DEFAULT_DEVICE = 'cpu' if not torch.cuda.is_available() else 'cuda:0'

# 4x4 homogeneous matrix that negates the Y and Z axes
# (named after the OpenGL convention — presumably a camera-axis flip).
OPENGL = np.diag([1, -1, -1, 1])
def export_geometry(geometry):
    """Write a triangle mesh to a freshly created temporary ``.obj`` file.

    Before writing, the mesh is transformed by the inverse of ``OPENGL @ rot``,
    where ``rot`` is a 180-degree rotation about the Y axis. The transform is
    applied to the object that was passed in (its result is not reassigned),
    so the caller's geometry is modified as well.

    Args:
        geometry: Mesh object accepted by ``o3d.io.write_triangle_mesh`` and
            exposing a ``transform(4x4 matrix)`` method.

    Returns:
        Path of the written ``.obj`` file. The caller owns the file and is
        responsible for deleting it.

    Raises:
        RuntimeError: If Open3D reports that the mesh could not be written.
    """
    # Reserve the file atomically. ``tempfile.mktemp`` only invents a name and
    # leaves a window in which another process can claim the same path.
    with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as tmp:
        output_path = tmp.name

    # Apply rotation
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    transform = np.linalg.inv(OPENGL @ rot)
    geometry.transform(transform)

    written = o3d.io.write_triangle_mesh(output_path, geometry, write_ascii=False, compressed=True)
    if not written:
        # Do not hand an empty placeholder file to the caller.
        os.remove(output_path)
        raise RuntimeError(f'Failed to write mesh to {output_path}')
    return output_path
def extract_frames(video_path: str, duration: float = 20.0, fps: float = 3.0) -> str:
    """Dump frames of a video into a new temporary directory using ffmpeg.

    Only frames whose timestamp is below ``duration`` seconds are selected,
    and the selected stream is resampled to ``fps`` frames per second. Frames
    are written as ``001.jpg``, ``002.jpg``, ... inside the directory.

    Args:
        video_path: Path of the input video handed to ``ffmpeg -i``.
        duration: Length, in seconds, of the leading part of the video to use.
        fps: Target frame rate of the extracted image sequence.

    Returns:
        Path of the temporary directory holding the frames. The caller is
        responsible for removing it.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        OSError: If the ffmpeg executable cannot be started.
    """
    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "%03d.jpg")
    filter_complex = f"select='if(lt(t,{duration}),1,0)',fps={fps}"
    command = [
        "ffmpeg",
        "-i", video_path,
        "-vf", filter_complex,
        "-vsync", "0",
        output_path,
    ]
    try:
        subprocess.run(command, check=True)
    except Exception:
        # The caller never learns the directory name when we fail, so it
        # has to be removed here or it would leak.
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
    return temp_dir
def cat_meshes(meshes):
    """Merge several mesh dicts into a single mesh dict.

    Each input mesh is a mapping with ``'vertices'``, ``'faces'`` and
    ``'face_colors'`` arrays. Face indices of every mesh are shifted by the
    number of vertices that precede it in the merged vertex array. The input
    arrays are left untouched, so the same meshes can be merged repeatedly.

    Args:
        meshes: Iterable of mesh dicts as described above.

    Returns:
        dict with concatenated ``vertices``, ``face_colors`` and ``faces``.

    Raises:
        ValueError: If ``meshes`` is empty.
    """
    meshes = list(meshes)
    if not meshes:
        raise ValueError('cat_meshes() needs at least one mesh')

    vertices = [m['vertices'] for m in meshes]
    colors = [m['face_colors'] for m in meshes]
    faces = [np.asarray(m['faces']) for m in meshes]

    # offsets[i] == number of vertices contributed by meshes 0..i-1
    offsets = np.cumsum([0] + [len(v) for v in vertices[:-1]])

    # Build shifted copies instead of ``+=`` on the inputs: mutating the
    # callers' arrays would double-shift them on a second call.
    shifted = [np.add(f, off, dtype=f.dtype) for f, off in zip(faces, offsets)]

    return dict(
        vertices=np.concatenate(vertices),
        face_colors=np.concatenate(colors),
        faces=np.concatenate(shifted),
    )
def load_ckpt(model_path_or_url, verbose=True):
    """Load a checkpoint onto the CPU from a local path or an http(s) URL.

    Args:
        model_path_or_url: Filesystem path, or a URL whose scheme is ``http``
            or ``https``.
        verbose: When true, print the source and (for URLs) show download
            progress.

    Returns:
        Whatever object the checkpoint file deserializes to.
    """
    if verbose:
        print('... loading model from', model_path_or_url)

    scheme = urllib.parse.urlparse(model_path_or_url).scheme
    if scheme not in ('http', 'https'):
        # Anything that is not an http(s) URL is treated as a local file.
        return torch.load(model_path_or_url, map_location='cpu')

    return torch.hub.load_state_dict_from_url(model_path_or_url, map_location='cpu', progress=verbose)
def load_model(ckpt_path, device):
    """Build a Spann3R network on ``device`` and restore its weights.

    Args:
        ckpt_path: Path or URL forwarded to ``load_ckpt``; the loaded
            checkpoint must expose the weights under the ``'model'`` key.
        device: Target device passed to ``.to()``.

    Returns:
        The network, switched to evaluation mode.
    """
    network = Spann3R(dus3r_name=DEFAULT_DUST3R_PATH, use_feat=False)
    network = network.to(device)

    checkpoint = load_ckpt(ckpt_path)
    network.load_state_dict(checkpoint['model'])

    network.eval()
    return network
def pts3d_to_trimesh(img, pts3d, valid=None):
    """Triangulate a per-pixel 3D point map into a colored mesh dict.

    Every 2x2 pixel neighbourhood becomes two triangles, and each triangle is
    emitted twice with opposite winding so it stays visible under back-face
    culling.

    Args:
        img: ``(H, W, 3)`` array of per-pixel colors.
        pts3d: ``(H, W, 3)`` array of per-pixel 3D points (same shape as img).
        valid: Optional ``(H, W)`` boolean mask; faces touching any invalid
            pixel are dropped.

    Returns:
        dict with ``vertices`` (H*W, 3), ``faces`` and ``face_colors``.
    """
    height, width, channels = img.shape
    assert channels == 3
    assert img.shape == pts3d.shape

    vertices = pts3d.reshape(-1, 3)

    # Vertex index of every pixel, then the four corners of each pixel quad.
    grid = np.arange(len(vertices)).reshape(height, width)
    top_left = grid[:-1, :-1].ravel()
    top_right = grid[:-1, 1:].ravel()
    bottom_left = grid[1:, :-1].ravel()
    bottom_right = grid[1:, 1:].ravel()

    upper = np.column_stack((top_left, top_right, bottom_left))
    lower = np.column_stack((top_right, bottom_left, bottom_right))
    # Reversing the column order flips the winding of the same triangle.
    faces = np.concatenate((upper, upper[:, ::-1], lower, lower[:, ::-1]), axis=0)

    # Upper triangles take the top-left pixel color, lower ones the bottom-right.
    upper_rgb = img[:-1, :-1].reshape(-1, 3)
    lower_rgb = img[1:, 1:].reshape(-1, 3)
    face_colors = np.concatenate((upper_rgb, upper_rgb, lower_rgb, lower_rgb), axis=0)

    if valid is not None:
        assert valid.shape == (height, width)
        # Keep a face only when all three of its vertices are valid.
        keep = valid.ravel()[faces].all(axis=-1)
        faces, face_colors = faces[keep], face_colors[keep]

    assert len(faces) == len(face_colors)
    return dict(vertices=vertices, face_colors=face_colors, faces=faces)
# Both networks are created once at import time and shared by every request.
model = load_model(ckpt_path=DEFAULT_CKPT_PATH, device=DEFAULT_DEVICE)

# Segmentation network used for optional background removal.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    'zhengpeng7/BiRefNet',
    trust_remote_code=True,
)
birefnet.to(DEFAULT_DEVICE)
birefnet.eval()
def extract_object(birefnet, image):
    """Predict a foreground mask for a PIL image with the given network.

    The image is resized to 1024x1024, converted to a normalized tensor and
    run through ``birefnet`` on ``DEFAULT_DEVICE``. The last element of the
    network output is passed through a sigmoid and converted back to a PIL
    image at the input's original size.

    Args:
        birefnet: Callable segmentation model; its output is indexed with
            ``[-1]`` to obtain the logits.
        image: Input PIL image.

    Returns:
        PIL image holding the predicted mask, sized like ``image``.
    """
    target_size = (1024, 1024)
    preprocess = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        # Standard ImageNet channel mean / std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    batch = preprocess(image).unsqueeze(0).to(DEFAULT_DEVICE)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    soft_mask = probabilities[0].squeeze()
    return transforms.ToPILImage()(soft_mask).resize(image.size)
def generate_mask(image: np.ndarray):
    """Return a foreground mask for ``image`` scaled to the 0..1 range.

    Args:
        image: Array that is multiplied by 255 and cast to ``uint8`` before
            being turned into a PIL image (so values are expected in 0..1).

    Returns:
        Array obtained by dividing the predicted mask image by 255.
    """
    # PIL needs 8-bit data.
    as_uint8 = (image * 255).astype(np.uint8)
    # Segment with the module-level network.
    object_mask = extract_object(birefnet, Image.fromarray(as_uint8))
    # Back to a float array.
    return np.array(object_mask) / 255.0
def reconstruct(video_path, conf_thresh, kf_every,
                remove_background=False):
    """Reconstruct a 3D model from a video, yielding results progressively.

    This is a generator, intended to be used as a streaming Gradio handler:

    1. first yield: ``(coarse_mesh_path, None)`` — a mesh built from the
       per-view point clouds as predicted by the network;
    2. second yield: ``(coarse_mesh_path, refined_path)`` — where
       ``refined_path`` is the file written by ``point2gs`` after multiway
       registration of the point clouds.

    The directory of extracted frames is always removed when the generator
    finishes, fails, or is closed early.

    Args:
        video_path: Path of the input video.
        conf_thresh: Points are kept when ``(conf - 1) / conf`` exceeds this
            value.
        kf_every: Keyframe interval forwarded to the ``Demo`` dataset.
        remove_background: When true, additionally keep only pixels whose
            predicted foreground mask is above 0.5.

    Raises:
        gr.Error: If no video was supplied or no frames could be loaded.
    """
    if not video_path:
        raise gr.Error("Please provide an input video before reconstructing.")

    # Extract frames from video
    demo_path = extract_frames(video_path)
    try:
        # Load dataset
        dataset = Demo(ROOT=demo_path, resolution=224, full_video=True, kf_every=kf_every)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
        # A bare next() would raise StopIteration, which a generator turns
        # into an opaque RuntimeError; report the problem explicitly.
        batch = next(iter(dataloader), None)
        if batch is None:
            raise gr.Error("No frames could be extracted from the input video.")

        for view in batch:
            view['img'] = view['img'].to(DEFAULT_DEVICE, non_blocking=True)

        demo_name = os.path.basename(video_path)
        print(f'Started reconstruction for {demo_name}')

        start = time.time()
        # Pure inference: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            preds, preds_all = model.forward(batch)
        end = time.time()

        fps = len(batch) / (end - start)
        print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')

        # Process results
        pcds = []
        for j, view in enumerate(batch):
            image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
            # Shift from [-1, 1] to [0, 1] (assumes the dataset normalizes to [-1, 1]).
            image = (image + 1) / 2

            # The first view stores its points under a different key.
            pts_key = 'pts3d' if j == 0 else 'pts3d_in_other_view'
            pts_tensor = preds[j][pts_key]
            pts = pts_tensor.detach().cpu().numpy()[0]
            pts_normal = pts2normal(pts_tensor[0]).cpu().numpy()

            conf = preds[j]['conf'][0].detach().cpu().numpy()
            # Rescaled confidence; lies in [0, 1) if conf >= 1 (presumably the
            # model's convention — confirm against Spann3R).
            conf_sig = (conf - 1) / conf

            if remove_background:
                mask = generate_mask(image)
            else:
                mask = np.ones_like(conf)

            combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)

            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
            pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
            pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
            pcds.append(pcd)

        pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
        o3d_geometry = point2mesh(pcd_combined)

        # Create coarse result
        coarse_output_path = export_geometry(o3d_geometry)
        yield coarse_output_path, None

        # Perform global optimization
        print("Performing global registration...")
        transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)

        # Create refined result. The file is reserved atomically rather than
        # via the race-prone tempfile.mktemp.
        with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as tmp:
            refined_output_path = tmp.name
        point2gs(refined_output_path, transformed_pcds)
        yield coarse_output_path, refined_output_path
    finally:
        # Clean up the frame directory without a shell, on errors and early
        # generator close as well.
        shutil.rmtree(demo_path, ignore_errors=True)
# Gradio front-end: page styling, header badges, input controls and the two
# 3D viewers, wired to the ``reconstruct`` generator.
_PAGE_CSS = """
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
"""

# Extra <head> markup injected into the page (Google tag snippet).
_PAGE_HEAD = """
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag() {dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'G-1FWSVCGZTG');
</script>
"""

# Title and badge links rendered at the top of the page.
_PAGE_HEADER_MD = """
# StableSpann3r: Making Spann3r stable with Odometry Backend
<p align="center">
<a title="Website" href="https://stable-x.github.io/StableSpann3r/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
</a>
<a title="arXiv" href="https://arxiv.org/abs/XXXX.XXXXX" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
</a>
<a title="Github" href="https://github.com/Stable-X/StableSpann3r" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/Stable-X/StableSpann3r?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://x.com/ychngji6" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
"""

with gr.Blocks(
    title="StableSpann3r: Making Spann3r stable with Odometry Backend",
    css=_PAGE_CSS,
    head=_PAGE_HEAD,
) as iface:
    gr.Markdown(_PAGE_HEADER_MD)

    with gr.Row():
        # Left column: input video and reconstruction settings.
        with gr.Column(scale=1):
            video_input = gr.Video(label="Input Video")
            with gr.Row():
                conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
                kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
            with gr.Row():
                remove_background = gr.Checkbox(label="Remove Background", value=False)
            reconstruct_btn = gr.Button("Reconstruct")

        # Right column: one tab per output model.
        with gr.Column(scale=2):
            with gr.Tab("Coarse Model"):
                coarse_model = gr.Model3D(
                    label="Coarse 3D Model",
                    display_mode="solid",
                    clear_color=[0.0, 0.0, 0.0, 0.0],
                )
            with gr.Tab("Refined Model"):
                refined_model = gr.Model3D(
                    label="Refined Gaussian Splatting",
                    display_mode="solid",
                    clear_color=[0.0, 0.0, 0.0, 0.0],
                )

    # ``reconstruct`` is a generator: each yield updates both viewers.
    reconstruct_btn.click(
        fn=reconstruct,
        inputs=[video_input, conf_thresh, kf_every, remove_background],
        outputs=[coarse_model, refined_model],
    )

if __name__ == "__main__":
    # Listen on all interfaces.
    iface.launch(server_name="0.0.0.0")