Stable-X committed
Commit 2c5f88b · 1 parent: 2cc5b1b

feat: Add rendering output and refinement flag

Files changed (1): app.py (+94 −40)
app.py CHANGED
@@ -16,12 +16,13 @@ from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 from PIL import Image
 import open3d as o3d
+from spann3r.tools.vis import render_frames
 from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
 from gs_utils import point2gs
+from pose_utils import solve_cemara
 from gradio.helpers import Examples as GradioExamples
 from gradio.utils import get_cache_folder
 from pathlib import Path
-
 # Default values
 DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
 DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
@@ -209,9 +210,47 @@ def center_pcd(pcd: o3d.geometry.PointCloud, normalize=False) -> o3d.geometry.PointCloud:
     pcd.points = o3d.utility.Vector3dVector(centered_points)
     return pcd
 
+def center_mesh(mesh: o3d.geometry.TriangleMesh, normalize=False) -> o3d.geometry.TriangleMesh:
+    # Convert to numpy array
+    vertices = np.asarray(mesh.vertices)
+
+    # Compute centroid
+    centroid = np.mean(vertices, axis=0)
+
+    # Center the mesh
+    centered_vertices = vertices - centroid
+
+    if normalize:
+        # Compute the maximum distance from the center
+        max_distance = np.max(np.linalg.norm(centered_vertices, axis=1))
+
+        # Normalize the mesh
+        normalized_vertices = centered_vertices / max_distance
+
+        # Create a new mesh with the normalized vertices
+        normalized_mesh = o3d.geometry.TriangleMesh()
+        normalized_mesh.vertices = o3d.utility.Vector3dVector(normalized_vertices)
+        normalized_mesh.triangles = mesh.triangles
+
+        # If the original mesh has vertex colors, copy them
+        if mesh.has_vertex_colors():
+            normalized_mesh.vertex_colors = mesh.vertex_colors
+
+        # If the original mesh has vertex normals, normalize them
+        if mesh.has_vertex_normals():
+            vertex_normals = np.asarray(mesh.vertex_normals)
+            normalized_vertex_normals = vertex_normals / np.linalg.norm(vertex_normals, axis=1, keepdims=True)
+            normalized_mesh.vertex_normals = o3d.utility.Vector3dVector(normalized_vertex_normals)
+
+        return normalized_mesh
+    else:
+        # Update the mesh with the centered vertices
+        mesh.vertices = o3d.utility.Vector3dVector(centered_vertices)
+        return mesh
+
 @torch.no_grad()
 def reconstruct(video_path, conf_thresh, kf_every,
-                remove_background=False):
+                remove_background=False, enable_registration=True, output_3d_model=True):
     # Extract frames from video
     demo_path = extract_frames(video_path)
 
@@ -234,6 +273,8 @@ def reconstruct(video_path, conf_thresh, kf_every,
 
     # Process results
     pcds = []
+    cameras_all = []
+    last_focal = None
     for j, view in enumerate(batch):
         image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
         image = (image + 1) / 2
@@ -248,28 +289,40 @@ def reconstruct(video_path, conf_thresh, kf_every,
 
         combined_mask = (conf_sig > conf_thresh) & (mask > 0.5)
 
+        camera, last_focal = solve_cemara(torch.tensor(pts), torch.tensor(conf_sig) > 0.001,
+                                          "cuda", focal=last_focal)
+
         pcd = o3d.geometry.PointCloud()
         pcd.points = o3d.utility.Vector3dVector(pts[combined_mask])
         pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
         pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
         pcds.append(pcd)
+        cameras_all.append(camera)
+
 
     pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
-    pcd_combined = center_pcd(pcd_combined, normalize=True)
     o3d_geometry = point2mesh(pcd_combined)
-
-    # Create coarse result
-    coarse_output_path = export_geometry(o3d_geometry)
-
-    yield coarse_output_path, None
-
-    transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
-    transformed_pcds = center_pcd(transformed_pcds)
+    o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
 
     # Create coarse result
-    refined_output_path = tempfile.mktemp(suffix='.ply')
-    point2gs(refined_output_path, transformed_pcds)
-    yield coarse_output_path, refined_output_path
+    coarse_output_path = export_geometry(o3d_geometry_centered)
+    yield coarse_output_path, None
+
+    gs_output_path = tempfile.mktemp(suffix='.ply')
+    if enable_registration:
+        transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
+        transformed_pcds = center_pcd(transformed_pcds)
+        point2gs(gs_output_path, transformed_pcds)
+    else:
+        point2gs(gs_output_path, pcd_combined)
+
+    if output_3d_model:
+        # Create 3D model result using gaussian splatting
+        yield coarse_output_path, gs_output_path
+    else:
+        gs_output_path = tempfile.mktemp(suffix='.ply')
+        render_video_path = render_frames(o3d_geometry, cameras_all, demo_path)
+        yield coarse_output_path, render_video_path
 
     # Clean up temporary directory
     os.system(f"rm -rf {demo_path}")
@@ -345,13 +398,26 @@ with gr.Blocks(
             kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
             with gr.Row():
                 remove_background = gr.Checkbox(label="Remove Background", value=False)
+                enable_registration = gr.Checkbox(
+                    label="Enable Refinement",
+                    value=False,
+                    info="Improves alignment but takes longer"
+                )
+                output_3d_model = gr.Checkbox(
+                    label="Output Splat",
+                    value=True,
+                    info="Generate Splat (PLY) instead of video render"
+                )
             reconstruct_btn = gr.Button("Start Reconstruction")
 
         with gr.Column(scale=2):
             with gr.Tab("3D Models"):
                 with gr.Group():
-                    initial_model = gr.Model3D(label="Initial 3D Model", display_mode="solid",
-                                               clear_color=[0.0, 0.0, 0.0, 0.0])
+                    initial_model = gr.Model3D(
+                        label="Initial 3D Model",
+                        display_mode="solid",
+                        clear_color=[0.0, 0.0, 0.0, 0.0]
+                    )
                     gr.Markdown(
                         """
                         <div class="model-description">
@@ -361,33 +427,21 @@ with gr.Blocks(
                     )
 
                 with gr.Group():
-                    optimized_model = gr.Model3D(label="Optimized 3D Model", display_mode="solid",
-                                                 clear_color=[0.0, 0.0, 0.0, 0.0])
+                    output_model = gr.File(
+                        label="Refined Result (Splat or Video)",
+                        file_types=[".ply", ".mp4"],
+                        file_count="single"
+                    )
                     gr.Markdown(
                         """
                         <div class="model-description">
-                        This is the optimized 3D model with improved accuracy and detail using Gaussian Splatting. Finish within 60 seconds.
+                        Downloads as either:
+                        - PLY file: Gaussian Splat model (when "Output Splat" is enabled)
+                        - MP4 file: 360° rotating render video (when "Output Splat" is disabled)
+                        <br>Time: ~60 seconds with refinement, ~30 seconds without
                         </div>
                         """
                     )
-
-            with gr.Tab("Help"):
-                gr.Markdown(
-                    """
-                    ## How to use this tool:
-                    1. Upload a video of the object you want to reconstruct.
-                    2. Adjust the Confidence Threshold and Keyframe Interval if needed.
-                    3. Choose whether to remove the background.
-                    4. Click "Start Reconstruction" to begin the process.
-                    5. The Initial 3D Model will appear first, giving you a quick preview.
-                    6. Once processing is complete, the Optimized 3D Model will show the final result.
-
-                    ### Tips:
-                    - For best results, ensure your video captures the object from multiple angles.
-                    - If the model appears noisy, try increasing the Confidence Threshold.
-                    - Experiment with different Keyframe Intervals to balance speed and accuracy.
-                    """
-                )
 
     Examples(
         fn=reconstruct,
@@ -396,15 +450,15 @@ with gr.Blocks(
             for name in os.listdir(os.path.join("examples")) if name.endswith('.webm')
         ]),
         inputs=[video_input],
-        outputs=[initial_model, optimized_model],
+        outputs=[initial_model, output_model],
        directory_name="examples_video",
         cache_examples=False,
     )
 
     reconstruct_btn.click(
         fn=reconstruct,
-        inputs=[video_input, conf_thresh, kf_every, remove_background],
-        outputs=[initial_model, optimized_model]
+        inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
+        outputs=[initial_model, output_model]
     )
 
 if __name__ == "__main__":
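A note on the new `center_mesh` helper above: it mirrors the existing `center_pcd`, translating vertices to their centroid and, with `normalize=True`, scaling the mesh into the unit ball (vertex colors are copied over and vertex normals re-normalized). A minimal sanity-check sketch, assuming `center_mesh` is importable from app.py; the Open3D box is just a stand-in input:

```python
# Hypothetical sanity check for center_mesh (not part of the commit).
import numpy as np
import open3d as o3d

from app import center_mesh  # assumes app.py is importable in this environment

mesh = o3d.geometry.TriangleMesh.create_box()             # 8 vertices in [0, 1]^3
centered = center_mesh(mesh, normalize=True)

v = np.asarray(centered.vertices)
assert np.allclose(v.mean(axis=0), 0.0, atol=1e-6)        # centroid moved to the origin
assert np.linalg.norm(v, axis=1).max() <= 1.0 + 1e-6      # scaled into the unit ball
```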
 