feat: Add rendering output and refinement flag
app.py CHANGED
@@ -16,12 +16,13 @@ from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 from PIL import Image
 import open3d as o3d
+from spann3r.tools.vis import render_frames
 from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
 from gs_utils import point2gs
+from pose_utils import solve_cemara
 from gradio.helpers import Examples as GradioExamples
 from gradio.utils import get_cache_folder
 from pathlib import Path
-
 # Default values
 DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
 DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
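Note: both new imports come from modules that ship with this Space's repo, not from PyPI. Their signatures are not shown in this diff; the sketch below is inferred only from their call sites later in the commit and should be treated as an assumption:

    # Hedged sketch, inferred from call sites in this commit only:
    #   solve_cemara(pts, mask, device, focal=None) -> (camera, focal)    # pose_utils
    #   render_frames(geometry, cameras, out_dir)   -> path to a render   # spann3r.tools.vis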
@@ -209,9 +210,47 @@ def center_pcd(pcd: o3d.geometry.PointCloud, normalize=False) -> o3d.geometry.Po
     pcd.points = o3d.utility.Vector3dVector(centered_points)
     return pcd
 
+def center_mesh(mesh: o3d.geometry.TriangleMesh, normalize=False) -> o3d.geometry.TriangleMesh:
+    # Convert to numpy array
+    vertices = np.asarray(mesh.vertices)
+
+    # Compute centroid
+    centroid = np.mean(vertices, axis=0)
+
+    # Center the mesh
+    centered_vertices = vertices - centroid
+
+    if normalize:
+        # Compute the maximum distance from the center
+        max_distance = np.max(np.linalg.norm(centered_vertices, axis=1))
+
+        # Normalize the mesh
+        normalized_vertices = centered_vertices / max_distance
+
+        # Create a new mesh with the normalized vertices
+        normalized_mesh = o3d.geometry.TriangleMesh()
+        normalized_mesh.vertices = o3d.utility.Vector3dVector(normalized_vertices)
+        normalized_mesh.triangles = mesh.triangles
+
+        # If the original mesh has vertex colors, copy them
+        if mesh.has_vertex_colors():
+            normalized_mesh.vertex_colors = mesh.vertex_colors
+
+        # If the original mesh has vertex normals, normalize them
+        if mesh.has_vertex_normals():
+            vertex_normals = np.asarray(mesh.vertex_normals)
+            normalized_vertex_normals = vertex_normals / np.linalg.norm(vertex_normals, axis=1, keepdims=True)
+            normalized_mesh.vertex_normals = o3d.utility.Vector3dVector(normalized_vertex_normals)
+
+        return normalized_mesh
+    else:
+        # Update the mesh with the centered vertices
+        mesh.vertices = o3d.utility.Vector3dVector(centered_vertices)
+        return mesh
+
 @torch.no_grad()
 def reconstruct(video_path, conf_thresh, kf_every,
-                remove_background=False):
+                remove_background=False, enable_registration=True, output_3d_model=True):
     # Extract frames from video
     demo_path = extract_frames(video_path)
 
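The new center_mesh mirrors the existing center_pcd: shift vertices to the centroid and, with normalize=True, scale them into the unit sphere. A minimal sanity check, assuming only numpy and open3d are installed; the box mesh and its translation are purely illustrative:

    import numpy as np
    import open3d as o3d

    # Build an off-center test mesh.
    mesh = o3d.geometry.TriangleMesh.create_box(width=2.0, height=1.0, depth=1.0)
    mesh.translate((5.0, 5.0, 5.0))

    centered = center_mesh(mesh, normalize=True)
    verts = np.asarray(centered.vertices)

    # Centroid lands at the origin and every vertex fits in the unit sphere.
    assert np.allclose(verts.mean(axis=0), 0.0, atol=1e-6)
    assert np.linalg.norm(verts, axis=1).max() <= 1.0 + 1e-6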
@@ -234,6 +273,8 @@ def reconstruct(video_path, conf_thresh, kf_every,
 
     # Process results
     pcds = []
+    cameras_all = []
+    last_focal = None
     for j, view in enumerate(batch):
         image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
         image = (image + 1) / 2
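The two new accumulators set up the per-keyframe camera solve in the next hunk: cameras_all collects one solved camera per view for the video render, and last_focal threads the focal estimate from frame to frame so solve_cemara reuses the previous estimate instead of re-solving it from scratch on every keyframe.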
@@ -248,28 +289,40 @@ def reconstruct(video_path, conf_thresh, kf_every,
 
         combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)
 
+        camera, last_focal = solve_cemara(torch.tensor(pts), torch.tensor(conf_sig) > 0.001,
+                                          "cuda", focal=last_focal)
+
         pcd = o3d.geometry.PointCloud()
         pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
         pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
         pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
         pcds.append(pcd)
+        cameras_all.append(camera)
+
 
     pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
-    pcd_combined = center_pcd(pcd_combined, normalize=True)
     o3d_geometry = point2mesh(pcd_combined)
-
-    # Create coarse result
-    coarse_output_path = export_geometry(o3d_geometry)
-
-    yield coarse_output_path, None
-
-    transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
-    transformed_pcds = center_pcd(transformed_pcds)
+    o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
 
     # Create coarse result
-
-
-
+    coarse_output_path = export_geometry(o3d_geometry_centered)
+    yield coarse_output_path, None
+
+    gs_output_path = tempfile.mktemp(suffix='.ply')
+    if enable_registration:
+        transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
+        transformed_pcds = center_pcd(transformed_pcds)
+        point2gs(gs_output_path, transformed_pcds)
+    else:
+        point2gs(gs_output_path, pcd_combined)
+
+    if output_3d_model:
+        # Create 3D model result using gaussian splatting
+        yield coarse_output_path, gs_output_path
+    else:
+        gs_output_path = tempfile.mktemp(suffix='.ply')
+        render_video_path = render_frames(o3d_geometry, cameras_all, demo_path)
+        yield coarse_output_path, render_video_path
 
     # Clean up temporary directory
     os.system(f"rm -rf {demo_path}")
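After this hunk, reconstruct is a two-step generator: it first yields (coarse_mesh_path, None) so the UI can show a quick preview, then yields the final pair, where the second element is a .ply splat or an .mp4 render depending on output_3d_model. A hedged sketch of driving it outside Gradio; the file name and parameter values are placeholders, not defaults from the app:

    for coarse_path, refined in reconstruct("capture.mp4",
                                            conf_thresh=0.001, kf_every=2,
                                            remove_background=False,
                                            enable_registration=True,
                                            output_3d_model=True):
        if refined is None:
            print("coarse mesh ready:", coarse_path)  # first yield: preview only
        else:
            print("refined result:", refined)         # .ply here; .mp4 if output_3d_model=False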
@@ -345,13 +398,26 @@ with gr.Blocks(
             kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
             with gr.Row():
                 remove_background = gr.Checkbox(label="Remove Background", value=False)
+                enable_registration = gr.Checkbox(
+                    label="Enable Refinement",
+                    value=False,
+                    info="Improves alignment but takes longer"
+                )
+                output_3d_model = gr.Checkbox(
+                    label="Output Splat",
+                    value=True,
+                    info="Generate Splat (PLY) instead of video render"
+                )
             reconstruct_btn = gr.Button("Start Reconstruction")
 
         with gr.Column(scale=2):
             with gr.Tab("3D Models"):
                 with gr.Group():
-                    initial_model = gr.Model3D(
-
+                    initial_model = gr.Model3D(
+                        label="Initial 3D Model",
+                        display_mode="solid",
+                        clear_color=[0.0, 0.0, 0.0, 0.0]
+                    )
                     gr.Markdown(
                         """
                         <div class="model-description">
@@ -361,33 +427,21 @@ with gr.Blocks(
                     )
 
                 with gr.Group():
-
-
+                    output_model = gr.File(
+                        label="Refined Result (Splat or Video)",
+                        file_types=[".ply", ".mp4"],
+                        file_count="single"
+                    )
                     gr.Markdown(
                         """
                         <div class="model-description">
-
+                        Downloads as either:
+                        - PLY file: Gaussian Splat Model (when "Output Splat" is enabled)
+                        - MP4 file: 360° rotating render video (when "Output Splat" is disabled)
+                        <br>Time: ~60 seconds with refinement, ~30 seconds without
                         </div>
                         """
                     )
-
-            with gr.Tab("Help"):
-                gr.Markdown(
-                    """
-                    ## How to use this tool:
-                    1. Upload a video of the object you want to reconstruct.
-                    2. Adjust the Confidence Threshold and Keyframe Interval if needed.
-                    3. Choose whether to remove the background.
-                    4. Click "Start Reconstruction" to begin the process.
-                    5. The Initial 3D Model will appear first, giving you a quick preview.
-                    6. Once processing is complete, the Optimized 3D Model will show the final result.
-
-                    ### Tips:
-                    - For best results, ensure your video captures the object from multiple angles.
-                    - If the model appears noisy, try increasing the Confidence Threshold.
-                    - Experiment with different Keyframe Intervals to balance speed and accuracy.
-                    """
-                )
 
     Examples(
         fn=reconstruct,
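Taken together with the branches added to reconstruct above, the two new checkboxes map to the second output as follows:

    Enable Refinement | Output Splat | Second output
    ------------------+--------------+---------------------------------------------
    on                | on           | .ply splat from the registered point clouds
    off               | on           | .ply splat from the merged, unregistered cloud
    on or off         | off          | .mp4 render via render_frames

One observation from the diff itself: when "Output Splat" is off, point2gs still writes a splat to a temp file that is then never returned, since gs_output_path is reassigned in the else branch.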
@@ -396,15 +450,15 @@ with gr.Blocks(
             for name in os.listdir(os.path.join("examples")) if name.endswith('.webm')
         ]),
         inputs=[video_input],
-        outputs=[initial_model,
+        outputs=[initial_model, output_model],
         directory_name="examples_video",
         cache_examples=False,
     )
 
     reconstruct_btn.click(
         fn=reconstruct,
-        inputs=[video_input, conf_thresh, kf_every, remove_background],
-        outputs=[initial_model,
+        inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
+        outputs=[initial_model, output_model]
     )
 
 if __name__ == "__main__":
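Because reconstruct yields twice, Gradio streams the outputs: the Model3D preview fills in on the first yield and the File output on the second. A self-contained sketch of the same generator-handler pattern; this is a hypothetical demo, not part of app.py:

    import time
    import gradio as gr

    def slow_pipeline(x):
        # First yield updates only the preview; the second fills in the final slot.
        yield f"coarse: {x}", None
        time.sleep(1)  # stand-in for registration / splatting
        yield f"coarse: {x}", f"final: {x}"

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Input")
        preview = gr.Textbox(label="Preview")
        final = gr.Textbox(label="Final")
        gr.Button("Run").click(fn=slow_pipeline, inputs=[inp], outputs=[preview, final])

    demo.launch()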