Stable-X committed
Commit 727fb54
Parent: fe87d83

fix: Replace render output with point cloud

Files changed (2):
  1. app.py +13 -35
  2. vis_utils.py +0 -171
app.py CHANGED
@@ -16,7 +16,6 @@ from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 from PIL import Image
 import open3d as o3d
-from vis_utils import render_frames
 from backend_utils import improved_multiway_registration, pts2normal, point2mesh, combine_and_clean_point_clouds
 from gs_utils import point2gs
 from pose_utils import solve_cemara
@@ -306,23 +305,19 @@ def reconstruct(video_path, conf_thresh, kf_every,

     # Create coarse result
     coarse_output_path = export_geometry(o3d_geometry_centered)
-    yield coarse_output_path, None

-    gs_output_path = tempfile.mktemp(suffix='.ply')
     if enable_registration:
-        transformed_pcds, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
-        transformed_pcds = center_pcd(transformed_pcds)
-        point2gs(gs_output_path, transformed_pcds)
-    else:
-        point2gs(gs_output_path, pcd_combined)
+        pcd_combined, _, _ = improved_multiway_registration(pcds, voxel_size=0.01)
+        pcd_combined = center_pcd(pcd_combined)

     if output_3d_model:
+        gs_output_path = tempfile.mktemp(suffix='.ply')
+        point2gs(gs_output_path, pcd_combined)
         # Create 3D model result using gaussian splatting
-        yield coarse_output_path, gs_output_path
+        return coarse_output_path, gs_output_path
     else:
-        gs_output_path = tempfile.mktemp(suffix='.ply')
-        render_video_path = render_frames(o3d_geometry, cameras_all, demo_path)
-        yield coarse_output_path, render_video_path
+        pcd_output_path = export_geometry(pcd_combined)
+        return coarse_output_path, pcd_output_path

     # Clean up temporary directory
     os.system(f"rm -rf {demo_path}")
@@ -406,7 +401,7 @@ with gr.Blocks(
         output_3d_model = gr.Checkbox(
             label="Output Splat",
             value=True,
-            info="Generate Splat (PLY) instead of video render"
+            info="Generate Splat (PLY) instead of Point Cloud (PLY)"
         )
         reconstruct_btn = gr.Button("Start Reconstruction")

@@ -414,33 +409,16 @@
         with gr.Tab("3D Models"):
             with gr.Group():
                 initial_model = gr.Model3D(
-                    label="Initial 3D Model",
+                    label="Reconstructed Mesh",
                     display_mode="solid",
                     clear_color=[0.0, 0.0, 0.0, 0.0]
                 )
-                gr.Markdown(
-                    """
-                    <div class="model-description">
-                    This is the initial 3D model generated from the video. Finish within 10 seconds.
-                    </div>
-                    """
-                )

             with gr.Group():
-                output_model = gr.File(
-                    label="Refined Result (Splat or Video)",
-                    file_types=[".ply", ".mp4"],
-                    file_count="single"
-                )
-                gr.Markdown(
-                    """
-                    <div class="model-description">
-                    Downloads as either:
-                    - PLY file: Gaussin Splat Model (when "Output Splat" is enabled)
-                    - MP4 file: 360° rotating render video (when "Output Splat" is disabled)
-                    <br>Time: ~60 seconds with refinement, ~30 seconds without
-                    </div>
-                    """
+                initial_model = gr.Model3D(
+                    label="Reconstructed PointCloud or Splat",
+                    display_mode="solid",
+                    clear_color=[0.0, 0.0, 0.0, 0.0]
                 )

     Examples(
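
With this change, reconstruct is an ordinary function that returns (coarse_output_path, refined_path) instead of a generator that yields intermediate results, and the refined output is always a PLY (Gaussian splat or point cloud) rather than a rendered video. A minimal wiring sketch under that assumption follows; the component names video_input and refined_model are illustrative placeholders and do not appear in this diff:

# Hypothetical wiring sketch (not part of this commit): a non-generator
# reconstruct() maps onto a single Gradio click handler with two outputs.
reconstruct_btn.click(
    fn=reconstruct,
    inputs=[video_input, conf_thresh, kf_every, output_3d_model],  # placeholder input components
    outputs=[initial_model, refined_model],  # coarse mesh PLY, splat/point-cloud PLY
)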
vis_utils.py DELETED
@@ -1,171 +0,0 @@
-import os
-import cv2
-import imageio
-import numpy as np
-import open3d as o3d
-import os.path as osp
-import matplotlib.pyplot as plt
-import matplotlib.colors as mcolors
-
-def render_frames(o3d_geometry, camera_all, output_dir, save_video=True, save_camera=True):
-    # Create off-screen renderer
-    render = o3d.visualization.rendering.OffscreenRenderer(
-        width=camera_all[0].intrinsic.width,
-        height=camera_all[0].intrinsic.height
-    )
-
-    render_frame_path = os.path.join(output_dir, 'render_frames')
-    render_camera_path = os.path.join(output_dir, 'render_cameras')
-    os.makedirs(render_frame_path, exist_ok=True)
-    os.makedirs(render_camera_path, exist_ok=True)
-
-    video_path = os.path.join(output_dir, 'render_frame.mp4')
-    if save_video:
-        writer = imageio.get_writer(video_path, fps=10)
-
-    material = o3d.visualization.rendering.MaterialRecord()
-    material.shader = 'defaultUnlit'  # Use unlit shader for point clouds
-    material.point_size = 1.0  # Match original point size
-    material.base_color = [1.0, 1.0, 1.0, 1.0]
-
-    for i, camera_params in enumerate(camera_all):
-        if camera_params is None:
-            continue
-
-        # Set camera view
-        render.setup_camera(
-            camera_params.intrinsic.intrinsic_matrix,
-            camera_params.extrinsic,
-            camera_params.intrinsic.width,
-            camera_params.intrinsic.height
-        )
-
-        if save_camera:
-            o3d.io.write_pinhole_camera_parameters(
-                os.path.join(render_camera_path, f'camera_{i:03d}.json'),
-                camera_params
-            )
-
-        # Render
-        render.scene.add_geometry("points", o3d_geometry, material)
-        img = render.render_to_image()
-        render.scene.remove_geometry("points")
-
-        # Save frame
-        image_uint8 = (np.asarray(img) * 255).astype(np.uint8)
-        frame_filename = f'frame_{i:03d}.png'
-        imageio.imwrite(osp.join(render_frame_path, frame_filename), image_uint8)
-
-        if save_video:
-            writer.append_data(image_uint8)
-
-    if save_video:
-        writer.close()
-
-    return video_path
-
-def find_render_cam(pcd, width=1920, height=1080):
-    # For headless servers, we'll need to pre-define camera parameters
-    # This creates a default viewing angle looking at the center of the point cloud
-
-    # Calculate point cloud center and scale
-    center = pcd.get_center()
-    scale = np.max(pcd.get_max_bound() - pcd.get_min_bound())
-
-    # Create default camera parameters
-    camera_params = o3d.camera.PinholeCameraParameters()
-
-    # Set intrinsic parameters
-    intrinsic = o3d.camera.PinholeCameraIntrinsic()
-    intrinsic.set_intrinsics(
-        width=width,
-        height=height,
-        fx=width,
-        fy=width,
-        cx=width/2,
-        cy=height/2
-    )
-    camera_params.intrinsic = intrinsic
-
-    # Set extrinsic parameters (looking at center from a 45-degree angle)
-    camera_params.extrinsic = np.array([
-        [1, 0, 0, 0],
-        [0, np.cos(np.pi/4), -np.sin(np.pi/4), 0],
-        [0, np.sin(np.pi/4), np.cos(np.pi/4), 2*scale],
-        [0, 0, 0, 1]
-    ])
-
-    return camera_params
-
-def vis_pred_and_imgs(pts_all, save_path, images_all=None, conf_all=None, save_video=True):
-    # Set matplotlib backend to non-interactive
-    plt.switch_backend('Agg')
-
-    # Normalization
-    min_val = pts_all.min(axis=(0, 1, 2), keepdims=True)
-    max_val = pts_all.max(axis=(0, 1, 2), keepdims=True)
-    pts_all = (pts_all - min_val) / (max_val - min_val)
-
-    pts_save_path = osp.join(save_path, 'pts')
-    os.makedirs(pts_save_path, exist_ok=True)
-
-    if images_all is not None:
-        images_save_path = osp.join(save_path, 'imgs')
-        os.makedirs(images_save_path, exist_ok=True)
-
-    if conf_all is not None:
-        conf_save_path = osp.join(save_path, 'confs')
-        os.makedirs(conf_save_path, exist_ok=True)
-
-    if save_video:
-        pts_video_path = osp.join(save_path, 'pts.mp4')
-        pts_writer = imageio.get_writer(pts_video_path, fps=10)
-
-        if images_all is not None:
-            imgs_video_path = osp.join(save_path, 'imgs.mp4')
-            imgs_writer = imageio.get_writer(imgs_video_path, fps=10)
-
-        if conf_all is not None:
-            conf_video_path = osp.join(save_path, 'confs.mp4')
-            conf_writer = imageio.get_writer(conf_video_path, fps=10)
-
-    for frame_id in range(pts_all.shape[0]):
-        # Points visualization
-        pt_vis = pts_all[frame_id].astype(np.float32)
-        pt_vis_rgb = mcolors.hsv_to_rgb(1-pt_vis)
-        pt_vis_rgb_uint8 = (pt_vis_rgb * 255).astype(np.uint8)
-
-        # Use matplotlib in non-interactive mode
-        fig, ax = plt.subplots()
-        ax.imshow(pt_vis_rgb_uint8)
-        plt.savefig(osp.join(pts_save_path, f'pts_{frame_id:04d}.png'))
-        plt.close(fig)
-
-        if save_video:
-            pts_writer.append_data(pt_vis_rgb_uint8)
-
-        if images_all is not None:
-            image = images_all[frame_id]
-            image_uint8 = (image * 255).astype(np.uint8)
-            imageio.imwrite(osp.join(images_save_path, f'img_{frame_id:04d}.png'), image_uint8)
-
-            if save_video:
-                imgs_writer.append_data(image_uint8)
-
-        if conf_all is not None:
-            fig, ax = plt.subplots()
-            conf_image = plt.cm.jet(conf_all[frame_id])
-            ax.imshow(conf_image)
-            plt.savefig(osp.join(conf_save_path, f'conf_{frame_id:04d}.png'))
-            plt.close(fig)
-
-            conf_image_uint8 = (conf_image * 255).astype(np.uint8)
-            if save_video:
-                conf_writer.append_data(conf_image_uint8)
-
-    if save_video:
-        pts_writer.close()
-        if images_all is not None:
-            imgs_writer.close()
-        if conf_all is not None:
-            conf_writer.close()
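
The deleted module provided the off-screen rendering path that app.py no longer calls. For reference, a minimal sketch of how these helpers were invoked, based on the removed render_frames(o3d_geometry, cameras_all, demo_path) call in app.py and the signatures above; pcd and demo_path stand in for an Open3D geometry and an output directory:

# Sketch of the removed rendering path; relies on Open3D's OffscreenRenderer,
# which needs an EGL/GPU-capable environment.
from vis_utils import render_frames, find_render_cam

camera = find_render_cam(pcd, width=1920, height=1080)  # default 45-degree view of the cloud
video_path = render_frames(pcd, [camera], demo_path,     # one frame per camera, plus an MP4
                           save_video=True, save_camera=False)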