Spaces:
Sleeping
Sleeping
feat: Update demo
Browse files
app.py
CHANGED
@@ -18,6 +18,9 @@ from PIL import Image
|
|
18 |
import open3d as o3d
|
19 |
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
|
20 |
from gs_utils import point2gs
|
|
|
|
|
|
|
21 |
|
22 |
# Default values
|
23 |
DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
|
@@ -29,6 +32,14 @@ OPENGL = np.array([[1, 0, 0, 0],
|
|
29 |
[0, 0, -1, 0],
|
30 |
[0, 0, 0, 1]])
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
def export_geometry(geometry):
|
33 |
output_path = tempfile.mktemp(suffix='.obj')
|
34 |
|
@@ -163,6 +174,41 @@ def generate_mask(image: np.ndarray):
|
|
163 |
# Convert mask to numpy array
|
164 |
mask_np = np.array(mask) / 255.0
|
165 |
return mask_np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
@torch.no_grad()
|
167 |
def reconstruct(video_path, conf_thresh, kf_every,
|
168 |
remove_background=False):
|
@@ -209,6 +255,7 @@ def reconstruct(video_path, conf_thresh, kf_every,
|
|
209 |
pcds.append(pcd)
|
210 |
|
211 |
pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
|
|
|
212 |
o3d_geometry = point2mesh(pcd_combined)
|
213 |
|
214 |
# Create coarse result
|
@@ -216,10 +263,9 @@ def reconstruct(video_path, conf_thresh, kf_every,
|
|
216 |
|
217 |
yield coarse_output_path, None
|
218 |
|
219 |
-
# Perform global optimization
|
220 |
-
print("Performing global registration...")
|
221 |
transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
|
222 |
-
|
|
|
223 |
# Create refined result
|
224 |
refined_output_path = tempfile.mktemp(suffix='.ply')
|
225 |
point2gs(refined_output_path, transformed_pcds)
|
@@ -228,9 +274,11 @@ def reconstruct(video_path, conf_thresh, kf_every,
|
|
228 |
# Clean up temporary directory
|
229 |
os.system(f"rm -rf {demo_path}")
|
230 |
|
|
|
|
|
231 |
# Update the Gradio interface with improved layout
|
232 |
with gr.Blocks(
|
233 |
-
title="StableSpann3r:
|
234 |
css="""
|
235 |
#download {
|
236 |
height: 118px;
|
@@ -276,12 +324,6 @@ with gr.Blocks(
|
|
276 |
"""
|
277 |
# StableSpann3r: Making Spann3r stable with Odometry Backend
|
278 |
<p align="center">
|
279 |
-
<a title="Website" href="https://stable-x.github.io/StableSpann3r/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
280 |
-
<img src="https://www.obukhov.ai/img/badges/badge-website.svg">
|
281 |
-
</a>
|
282 |
-
<a title="arXiv" href="https://arxiv.org/abs/XXXX.XXXXX" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
283 |
-
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
|
284 |
-
</a>
|
285 |
<a title="Github" href="https://github.com/Stable-X/StableSpann3r" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
286 |
<img src="https://img.shields.io/github/stars/Stable-X/StableSpann3r?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
|
287 |
</a>
|
@@ -293,26 +335,72 @@ with gr.Blocks(
|
|
293 |
)
|
294 |
with gr.Row():
|
295 |
with gr.Column(scale=1):
|
296 |
-
video_input = gr.Video(label="Input Video")
|
297 |
with gr.Row():
|
298 |
conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
|
299 |
kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
|
300 |
with gr.Row():
|
301 |
remove_background = gr.Checkbox(label="Remove Background", value=False)
|
302 |
-
reconstruct_btn = gr.Button("
|
303 |
|
304 |
with gr.Column(scale=2):
|
305 |
-
with gr.Tab("
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
|
312 |
reconstruct_btn.click(
|
313 |
fn=reconstruct,
|
314 |
inputs=[video_input, conf_thresh, kf_every, remove_background],
|
315 |
-
outputs=[
|
316 |
)
|
317 |
|
318 |
if __name__ == "__main__":
|
|
|
18 |
import open3d as o3d
|
19 |
from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
|
20 |
from gs_utils import point2gs
|
21 |
+
from gradio.helpers import Examples as GradioExamples
|
22 |
+
from gradio.utils import get_cache_folder
|
23 |
+
from pathlib import Path
|
24 |
|
25 |
# Default values
|
26 |
DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
|
|
|
32 |
[0, 0, -1, 0],
|
33 |
[0, 0, 0, 1]])
|
34 |
|
35 |
+
class Examples(GradioExamples):
    """Gradio examples helper whose cache location can be overridden.

    The parent class is constructed with ``_initiated_directly=False`` and
    this subclass calls ``create()`` itself, after the cache paths have
    optionally been redirected.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        """Build the examples component.

        Args:
            *args: Positional arguments forwarded to ``GradioExamples``.
            directory_name: Optional sub-folder name (under the folder
                returned by ``get_cache_folder()``) used for cached
                examples. When ``None`` the parent's paths are kept.
            **kwargs: Keyword arguments forwarded to ``GradioExamples``.
        """
        super().__init__(*args, _initiated_directly=False, **kwargs)

        if directory_name is not None:
            # Point both the cache folder and its log file at the custom
            # directory before the examples are created.
            cache_dir = get_cache_folder() / directory_name
            self.cached_folder = cache_dir
            self.cached_file = Path(cache_dir) / "log.csv"

        self.create()
|
42 |
+
|
43 |
def export_geometry(geometry):
|
44 |
output_path = tempfile.mktemp(suffix='.obj')
|
45 |
|
|
|
174 |
# Convert mask to numpy array
|
175 |
mask_np = np.array(mask) / 255.0
|
176 |
return mask_np
|
177 |
+
|
178 |
+
def center_pcd(pcd: o3d.geometry.PointCloud, normalize=False) -> o3d.geometry.PointCloud:
    """Translate a point cloud so that its centroid sits at the origin.

    Args:
        pcd: Point cloud to center.
        normalize: When True, additionally scale the centered points so the
            farthest point lies at distance 1 from the origin.

    Returns:
        With ``normalize=False`` the input ``pcd`` itself, whose points have
        been replaced in place by the centered points. With
        ``normalize=True`` a new point cloud holding the centered and scaled
        points; colors and normals, when present, are copied over unchanged.
        An empty input cloud is returned as-is in both modes.
    """
    points = np.asarray(pcd.points)

    # Nothing to center; also avoids a NaN centroid and np.max raising on an
    # empty array.
    if points.shape[0] == 0:
        return pcd

    centered_points = points - np.mean(points, axis=0)

    if not normalize:
        pcd.points = o3d.utility.Vector3dVector(centered_points)
        return pcd

    # Largest distance of any point from the (now zero) centroid.
    max_distance = np.max(np.linalg.norm(centered_points, axis=1))

    # A degenerate cloud (all points coincide) has max_distance == 0; skip
    # the division so the points stay at the origin instead of becoming NaN.
    if max_distance > 0:
        centered_points = centered_points / max_distance

    normalized_pcd = o3d.geometry.PointCloud()
    normalized_pcd.points = o3d.utility.Vector3dVector(centered_points)

    # Colors and normals are carried over as they are (not rescaled).
    if pcd.has_colors():
        normalized_pcd.colors = pcd.colors
    if pcd.has_normals():
        normalized_pcd.normals = pcd.normals

    return normalized_pcd
|
211 |
+
|
212 |
@torch.no_grad()
|
213 |
def reconstruct(video_path, conf_thresh, kf_every,
|
214 |
remove_background=False):
|
|
|
255 |
pcds.append(pcd)
|
256 |
|
257 |
pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
|
258 |
+
pcd_combined = center_pcd(pcd_combined, normalize=True)
|
259 |
o3d_geometry = point2mesh(pcd_combined)
|
260 |
|
261 |
# Create coarse result
|
|
|
263 |
|
264 |
yield coarse_output_path, None
|
265 |
|
|
|
|
|
266 |
transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
|
267 |
+
transformed_pcds = center_pcd(transformed_pcds)
|
268 |
+
|
269 |
# Create refined result
|
270 |
refined_output_path = tempfile.mktemp(suffix='.ply')
|
271 |
point2gs(refined_output_path, transformed_pcds)
|
|
|
274 |
# Clean up temporary directory
|
275 |
os.system(f"rm -rf {demo_path}")
|
276 |
|
277 |
+
example_videos = [os.path.join('./examples', f) for f in os.listdir('./examples') if f.endswith(('.mp4', '.webm'))]
|
278 |
+
|
279 |
# Update the Gradio interface with improved layout
|
280 |
with gr.Blocks(
|
281 |
+
title="StableSpann3r: 3D Reconstruction from Video",
|
282 |
css="""
|
283 |
#download {
|
284 |
height: 118px;
|
|
|
324 |
"""
|
325 |
# StableSpann3r: Making Spann3r stable with Odometry Backend
|
326 |
<p align="center">
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
<a title="Github" href="https://github.com/Stable-X/StableSpann3r" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
|
328 |
<img src="https://img.shields.io/github/stars/Stable-X/StableSpann3r?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
|
329 |
</a>
|
|
|
335 |
)
|
336 |
with gr.Row():
|
337 |
with gr.Column(scale=1):
|
338 |
+
video_input = gr.Video(label="Input Video", sources=["upload"])
|
339 |
with gr.Row():
|
340 |
conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
|
341 |
kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
|
342 |
with gr.Row():
|
343 |
remove_background = gr.Checkbox(label="Remove Background", value=False)
|
344 |
+
reconstruct_btn = gr.Button("Start Reconstruction")
|
345 |
|
346 |
with gr.Column(scale=2):
|
347 |
+
with gr.Tab("3D Models"):
|
348 |
+
with gr.Group():
|
349 |
+
initial_model = gr.Model3D(label="Initial 3D Model", display_mode="solid",
|
350 |
+
clear_color=[0.0, 0.0, 0.0, 0.0])
|
351 |
+
gr.Markdown(
|
352 |
+
"""
|
353 |
+
<div class="model-description">
|
354 |
+
This is the initial 3D model generated from the video. Finish within 10 seconds.
|
355 |
+
</div>
|
356 |
+
"""
|
357 |
+
)
|
358 |
+
|
359 |
+
with gr.Group():
|
360 |
+
optimized_model = gr.Model3D(label="Optimized 3D Model", display_mode="solid",
|
361 |
+
clear_color=[0.0, 0.0, 0.0, 0.0])
|
362 |
+
gr.Markdown(
|
363 |
+
"""
|
364 |
+
<div class="model-description">
|
365 |
+
This is the optimized 3D model with improved accuracy and detail using Gaussian Splatting. Finish within 60 seconds.
|
366 |
+
</div>
|
367 |
+
"""
|
368 |
+
)
|
369 |
+
|
370 |
+
with gr.Tab("Help"):
|
371 |
+
gr.Markdown(
|
372 |
+
"""
|
373 |
+
## How to use this tool:
|
374 |
+
1. Upload a video of the object you want to reconstruct.
|
375 |
+
2. Adjust the Confidence Threshold and Keyframe Interval if needed.
|
376 |
+
3. Choose whether to remove the background.
|
377 |
+
4. Click "Start Reconstruction" to begin the process.
|
378 |
+
5. The Initial 3D Model will appear first, giving you a quick preview.
|
379 |
+
6. Once processing is complete, the Optimized 3D Model will show the final result.
|
380 |
+
|
381 |
+
### Tips:
|
382 |
+
- For best results, ensure your video captures the object from multiple angles.
|
383 |
+
- If the model appears noisy, try increasing the Confidence Threshold.
|
384 |
+
- Experiment with different Keyframe Intervals to balance speed and accuracy.
|
385 |
+
"""
|
386 |
+
)
|
387 |
+
|
388 |
+
Examples(
|
389 |
+
fn=reconstruct,
|
390 |
+
examples=sorted([
|
391 |
+
os.path.join("examples", name)
|
392 |
+
for name in os.listdir(os.path.join("examples")) if name.endswith('.webm')
|
393 |
+
]),
|
394 |
+
inputs=[video_input],
|
395 |
+
outputs=[initial_model, optimized_model],
|
396 |
+
directory_name="examples_video",
|
397 |
+
cache_examples=False,
|
398 |
+
)
|
399 |
|
400 |
reconstruct_btn.click(
|
401 |
fn=reconstruct,
|
402 |
inputs=[video_input, conf_thresh, kf_every, remove_background],
|
403 |
+
outputs=[initial_model, optimized_model]
|
404 |
)
|
405 |
|
406 |
if __name__ == "__main__":
|