Stable-X committed on
Commit
1139032
1 Parent(s): 46dd982

feat: Update demo

Browse files
Files changed (1) hide show
  1. app.py +107 -19
app.py CHANGED
@@ -18,6 +18,9 @@ from PIL import Image
18
  import open3d as o3d
19
  from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
20
  from gs_utils import point2gs
 
 
 
21
 
22
  # Default values
23
  DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
@@ -29,6 +32,14 @@ OPENGL = np.array([[1, 0, 0, 0],
29
  [0, 0, -1, 0],
30
  [0, 0, 0, 1]])
31
 
 
 
 
 
 
 
 
 
32
  def export_geometry(geometry):
33
  output_path = tempfile.mktemp(suffix='.obj')
34
 
@@ -163,6 +174,41 @@ def generate_mask(image: np.ndarray):
163
  # Convert mask to numpy array
164
  mask_np = np.array(mask) / 255.0
165
  return mask_np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  @torch.no_grad()
167
  def reconstruct(video_path, conf_thresh, kf_every,
168
  remove_background=False):
@@ -209,6 +255,7 @@ def reconstruct(video_path, conf_thresh, kf_every,
209
  pcds.append(pcd)
210
 
211
  pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
 
212
  o3d_geometry = point2mesh(pcd_combined)
213
 
214
  # Create coarse result
@@ -216,10 +263,9 @@ def reconstruct(video_path, conf_thresh, kf_every,
216
 
217
  yield coarse_output_path, None
218
 
219
- # Perform global optimization
220
- print("Performing global registration...")
221
  transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
222
-
 
223
  # Create coarse result
224
  refined_output_path = tempfile.mktemp(suffix='.ply')
225
  point2gs(refined_output_path, transformed_pcds)
@@ -228,9 +274,11 @@ def reconstruct(video_path, conf_thresh, kf_every,
228
  # Clean up temporary directory
229
  os.system(f"rm -rf {demo_path}")
230
 
 
 
231
  # Update the Gradio interface with improved layout
232
  with gr.Blocks(
233
- title="StableSpann3r: Making Spann3r stable with Odometry Backend",
234
  css="""
235
  #download {
236
  height: 118px;
@@ -276,12 +324,6 @@ with gr.Blocks(
276
  """
277
  # StableSpann3r: Making Spann3r stable with Odometry Backend
278
  <p align="center">
279
- <a title="Website" href="https://stable-x.github.io/StableSpann3r/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
280
- <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
281
- </a>
282
- <a title="arXiv" href="https://arxiv.org/abs/XXXX.XXXXX" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
283
- <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
284
- </a>
285
  <a title="Github" href="https://github.com/Stable-X/StableSpann3r" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
286
  <img src="https://img.shields.io/github/stars/Stable-X/StableSpann3r?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
287
  </a>
@@ -293,26 +335,72 @@ with gr.Blocks(
293
  )
294
  with gr.Row():
295
  with gr.Column(scale=1):
296
- video_input = gr.Video(label="Input Video")
297
  with gr.Row():
298
  conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
299
  kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
300
  with gr.Row():
301
  remove_background = gr.Checkbox(label="Remove Background", value=False)
302
- reconstruct_btn = gr.Button("Reconstruct")
303
 
304
  with gr.Column(scale=2):
305
- with gr.Tab("Coarse Model"):
306
- coarse_model = gr.Model3D(label="Coarse 3D Model", display_mode="solid",
307
- clear_color=[0.0, 0.0, 0.0, 0.0])
308
- with gr.Tab("Refined Model"):
309
- refined_model = gr.Model3D(label="Refined Gaussian Splatting", display_mode="solid",
310
- clear_color=[0.0, 0.0, 0.0, 0.0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  reconstruct_btn.click(
313
  fn=reconstruct,
314
  inputs=[video_input, conf_thresh, kf_every, remove_background],
315
- outputs=[coarse_model, refined_model]
316
  )
317
 
318
  if __name__ == "__main__":
 
18
  import open3d as o3d
19
  from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
20
  from gs_utils import point2gs
21
+ from gradio.helpers import Examples as GradioExamples
22
+ from gradio.utils import get_cache_folder
23
+ from pathlib import Path
24
 
25
  # Default values
26
  DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
 
32
  [0, 0, -1, 0],
33
  [0, 0, 0, 1]])
34
 
35
class Examples(GradioExamples):
    """Gradio ``Examples`` variant whose cache location can be renamed.

    The parent constructor is always invoked with
    ``_initiated_directly=False`` and ``create()`` is called explicitly once
    the cache attributes have been (optionally) overridden.

    Args:
        *args: Forwarded unchanged to ``GradioExamples.__init__``.
        directory_name: Optional sub-directory name; when given,
            ``cached_folder`` becomes ``get_cache_folder() / directory_name``
            and ``cached_file`` becomes ``log.csv`` inside that folder.
        **kwargs: Forwarded unchanged to ``GradioExamples.__init__``.
    """

    def __init__(self, *args, directory_name=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)

        if directory_name is not None:
            # Point both cache attributes at the caller-chosen sub-directory.
            cache_dir = get_cache_folder() / directory_name
            self.cached_folder = cache_dir
            self.cached_file = Path(cache_dir) / "log.csv"

        # NOTE(review): presumably deferred until here so that the overridden
        # cache paths are in place before the component is built - confirm
        # against the installed gradio version.
        self.create()
42
+
43
  def export_geometry(geometry):
44
  output_path = tempfile.mktemp(suffix='.obj')
45
 
 
174
  # Convert mask to numpy array
175
  mask_np = np.array(mask) / 255.0
176
  return mask_np
177
+
178
def center_pcd(pcd: o3d.geometry.PointCloud, normalize=False) -> o3d.geometry.PointCloud:
    """Translate a point cloud so that its centroid sits at the origin.

    Args:
        pcd: Point cloud to center.
        normalize: When true, additionally scale the centered points so that
            the farthest point lies at unit distance from the origin.

    Returns:
        With ``normalize=True``: a new point cloud holding the centered and
        scaled points together with the colors and normals of ``pcd``
        (``pcd`` itself is not modified).
        With ``normalize=False``: ``pcd`` itself, whose points have been
        replaced in place by the centered ones.
        An empty point cloud is returned unchanged in either mode.
    """
    points = np.asarray(pcd.points)

    # Nothing to center: the mean of zero points is NaN and np.max below
    # would raise on an empty array.
    if points.shape[0] == 0:
        return pcd

    # Shift every point by the centroid.
    centered_points = points - np.mean(points, axis=0)

    if not normalize:
        # In-place update: the caller's object is modified and returned.
        pcd.points = o3d.utility.Vector3dVector(centered_points)
        return pcd

    # Radius of the cloud measured from its centroid.
    max_distance = np.max(np.linalg.norm(centered_points, axis=1))

    # When every point coincides with the centroid the radius is zero;
    # skip the scaling instead of producing NaNs from 0 / 0.
    if max_distance > 0:
        centered_points = centered_points / max_distance

    # Build a fresh point cloud for the normalized result.
    normalized_pcd = o3d.geometry.PointCloud()
    normalized_pcd.points = o3d.utility.Vector3dVector(centered_points)

    # Colors do not depend on position, so they are copied as-is.
    if pcd.has_colors():
        normalized_pcd.colors = pcd.colors

    # Translation and uniform scaling leave normal directions unchanged,
    # so normals are copied as-is too.
    if pcd.has_normals():
        normalized_pcd.normals = pcd.normals

    return normalized_pcd
211
+
212
  @torch.no_grad()
213
  def reconstruct(video_path, conf_thresh, kf_every,
214
  remove_background=False):
 
255
  pcds.append(pcd)
256
 
257
  pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
258
+ pcd_combined = center_pcd(pcd_combined, normalize=True)
259
  o3d_geometry = point2mesh(pcd_combined)
260
 
261
  # Create coarse result
 
263
 
264
  yield coarse_output_path, None
265
 
 
 
266
  transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
267
+ transformed_pcds = center_pcd(transformed_pcds)
268
+
269
  # Create coarse result
270
  refined_output_path = tempfile.mktemp(suffix='.ply')
271
  point2gs(refined_output_path, transformed_pcds)
 
274
  # Clean up temporary directory
275
  os.system(f"rm -rf {demo_path}")
276
 
277
+ example_videos = [os.path.join('./examples', f) for f in os.listdir('./examples') if f.endswith(('.mp4', '.webm'))]
278
+
279
  # Update the Gradio interface with improved layout
280
  with gr.Blocks(
281
+ title="StableSpann3r: 3D Reconstruction from Video",
282
  css="""
283
  #download {
284
  height: 118px;
 
324
  """
325
  # StableSpann3r: Making Spann3r stable with Odometry Backend
326
  <p align="center">
 
 
 
 
 
 
327
  <a title="Github" href="https://github.com/Stable-X/StableSpann3r" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
328
  <img src="https://img.shields.io/github/stars/Stable-X/StableSpann3r?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
329
  </a>
 
335
  )
336
  with gr.Row():
337
  with gr.Column(scale=1):
338
+ video_input = gr.Video(label="Input Video", sources=["upload"])
339
  with gr.Row():
340
  conf_thresh = gr.Slider(0, 1, value=1e-3, label="Confidence Threshold")
341
  kf_every = gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval")
342
  with gr.Row():
343
  remove_background = gr.Checkbox(label="Remove Background", value=False)
344
+ reconstruct_btn = gr.Button("Start Reconstruction")
345
 
346
  with gr.Column(scale=2):
347
+ with gr.Tab("3D Models"):
348
+ with gr.Group():
349
+ initial_model = gr.Model3D(label="Initial 3D Model", display_mode="solid",
350
+ clear_color=[0.0, 0.0, 0.0, 0.0])
351
+ gr.Markdown(
352
+ """
353
+ <div class="model-description">
354
+ This is the initial 3D model generated from the video. Finish within 10 seconds.
355
+ </div>
356
+ """
357
+ )
358
+
359
+ with gr.Group():
360
+ optimized_model = gr.Model3D(label="Optimized 3D Model", display_mode="solid",
361
+ clear_color=[0.0, 0.0, 0.0, 0.0])
362
+ gr.Markdown(
363
+ """
364
+ <div class="model-description">
365
+ This is the optimized 3D model with improved accuracy and detail using Gaussian Splatting. Finish within 60 seconds.
366
+ </div>
367
+ """
368
+ )
369
+
370
+ with gr.Tab("Help"):
371
+ gr.Markdown(
372
+ """
373
+ ## How to use this tool:
374
+ 1. Upload a video of the object you want to reconstruct.
375
+ 2. Adjust the Confidence Threshold and Keyframe Interval if needed.
376
+ 3. Choose whether to remove the background.
377
+ 4. Click "Start Reconstruction" to begin the process.
378
+ 5. The Initial 3D Model will appear first, giving you a quick preview.
379
+ 6. Once processing is complete, the Optimized 3D Model will show the final result.
380
+
381
+ ### Tips:
382
+ - For best results, ensure your video captures the object from multiple angles.
383
+ - If the model appears noisy, try increasing the Confidence Threshold.
384
+ - Experiment with different Keyframe Intervals to balance speed and accuracy.
385
+ """
386
+ )
387
+
388
+ Examples(
389
+ fn=reconstruct,
390
+ examples=sorted([
391
+ os.path.join("examples", name)
392
+ for name in os.listdir(os.path.join("examples")) if name.endswith('.webm')
393
+ ]),
394
+ inputs=[video_input],
395
+ outputs=[initial_model, optimized_model],
396
+ directory_name="examples_video",
397
+ cache_examples=False,
398
+ )
399
 
400
  reconstruct_btn.click(
401
  fn=reconstruct,
402
  inputs=[video_input, conf_thresh, kf_every, remove_background],
403
+ outputs=[initial_model, optimized_model]
404
  )
405
 
406
  if __name__ == "__main__":