Stable-X committed on
Commit d1dbe71
1 Parent(s): 6f6423c

feat: Add backend for refinement

Files changed (3)
  1. app.py +62 -15
  2. backend_utils.py +144 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -15,7 +15,9 @@ from scipy.spatial.transform import Rotation
 from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 from PIL import Image
-import spaces
+import open3d as o3d
+from backend_utils import improved_multiway_registration
+
 
 # Default values
 DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
@@ -143,7 +145,6 @@ def generate_mask(image: np.ndarray):
     mask_np = np.array(mask) / 255.0
     return mask_np
 
-@spaces.GPU
 @torch.no_grad()
 def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
     # Extract frames from video
@@ -176,7 +177,7 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_b
         if remove_background:
             mask = generate_mask(image)
         else:
-            mask = np.ones_like(conf)  # Change this to match conf shape
+            mask = np.ones_like(conf)
 
         images_all.append((image[None, ...] + 1.0)/2.0)
         pts_all.append(pts[None, ...])
@@ -192,6 +193,54 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_b
     conf_sig_all = (conf_all-1) / conf_all
     combined_mask = (conf_sig_all > conf_thresh) & (mask_all > 0.5)
 
+    # Create coarse result
+    coarse_scene = create_scene(pts_all, images_all, combined_mask, as_pointcloud)
+    coarse_output_path = save_scene(coarse_scene, as_pointcloud)
+
+    yield coarse_output_path, None, f"Reconstruction completed. FPS: {fps:.2f}"
+
+    # Create point clouds for multiway registration
+    pcds = []
+    for j in range(len(pts_all)):
+        pcd = o3d.geometry.PointCloud()
+        mask = combined_mask[j]
+        pcd.points = o3d.utility.Vector3dVector(pts_all[j][mask])
+        pcd.colors = o3d.utility.Vector3dVector(images_all[j][mask])
+        pcds.append(pcd)
+
+    # Perform global optimization
+    print("Performing global registration...")
+    transformed_pcds, pose_graph = improved_multiway_registration(pcds, voxel_size=0.01)
+
+    # Apply transformations from pose_graph to original pts_all
+    transformed_pts_all = np.zeros_like(pts_all)
+    for j in range(len(pts_all)):
+        # Get the transformation matrix from the pose graph
+        transformation = pose_graph.nodes[j].pose
+
+        # Reshape pts_all[j] to (H*W, 3)
+        H, W, _ = pts_all[j].shape
+        pts_reshaped = pts_all[j].reshape(-1, 3)
+
+        # Apply transformation to all points
+        homogeneous_pts = np.hstack((pts_reshaped, np.ones((pts_reshaped.shape[0], 1))))
+        transformed_pts = (transformation @ homogeneous_pts.T).T[:, :3]
+
+        # Reshape back to (H, W, 3) and store
+        transformed_pts_all[j] = transformed_pts.reshape(H, W, 3)
+
+    print(f"Original shape: {pts_all.shape}, Transformed shape: {transformed_pts_all.shape}")
+
+    # Create refined result
+    refined_scene = create_scene(transformed_pts_all, images_all, combined_mask, as_pointcloud)
+    refined_output_path = save_scene(refined_scene, as_pointcloud)
+
+    # Clean up temporary directory
+    os.system(f"rm -rf {demo_path}")
+
+    yield coarse_output_path, refined_output_path, f"Refinement completed. FPS: {fps:.2f}"
+
+def create_scene(pts_all, images_all, combined_mask, as_pointcloud):
     scene = trimesh.Scene()
 
     if as_pointcloud:
@@ -206,37 +255,35 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_b
             meshes.append(pts3d_to_trimesh(images_all[i], pts_all[i], combined_mask[i]))
         mesh = trimesh.Trimesh(**cat_meshes(meshes))
         scene.add_geometry(mesh)
-
+
     rot = np.eye(4)
     rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
     scene.apply_transform(np.linalg.inv(OPENGL @ rot))
-
-    # Save the scene as GLB
+    return scene
+def save_scene(scene, as_pointcloud):
     if as_pointcloud:
-        output_path = tempfile.mktemp(suffix='.ply')
+        output_path = tempfile.mktemp(suffix='.ply')
     else:
         output_path = tempfile.mktemp(suffix='.obj')
     scene.export(output_path)
-
-    # Clean up temporary directory
-    os.system(f"rm -rf {demo_path}")
-
-    return output_path, f"Reconstruction completed. FPS: {fps:.2f}"
+    return output_path
 
+# Update the Gradio interface
 iface = gr.Interface(
     fn=reconstruct,
     inputs=[
         gr.Video(label="Input Video"),
-        gr.Slider(0, 1, value=1e-3, label="Confidence Threshold"),
+        gr.Slider(0, 1, value=1e-6, label="Confidence Threshold"),
         gr.Slider(1, 30, step=1, value=5, label="Keyframe Interval"),
         gr.Checkbox(label="As Pointcloud", value=False),
         gr.Checkbox(label="Remove Background", value=False)
     ],
     outputs=[
-        gr.Model3D(label="3D Model", display_mode="solid"),
+        gr.Model3D(label="Coarse 3D Model", display_mode="solid"),
+        gr.Model3D(label="Refined 3D Model", display_mode="solid"),
         gr.Textbox(label="Status")
     ],
-    title="3D Reconstruction with Spatial Memory and Background Removal",
+    title="3D Reconstruction with Spatial Memory, Background Removal, and Global Optimization",
 )
 
 if __name__ == "__main__":
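A note on the control flow in this diff (not part of the commit itself): replacing `return` with two `yield` statements turns `reconstruct` into a generator, which Gradio treats as a streaming function. The first yield shows the coarse model immediately, with `None` for the still-empty refined slot; the second yield fills in the refined model once registration finishes. A minimal sketch of the pattern, assuming a Gradio version with generator support; the function and labels below are illustrative:

import time
import gradio as gr

def two_stage(text):
    # First update: coarse output ready, refined slot still empty
    yield f"coarse: {text}", None
    time.sleep(2)  # stand-in for the expensive refinement step
    # Second update: both outputs populated
    yield f"coarse: {text}", f"refined: {text}"

demo = gr.Interface(
    fn=two_stage,
    inputs=gr.Textbox(label="Input"),
    outputs=[gr.Textbox(label="Coarse"), gr.Textbox(label="Refined")],
)

if __name__ == "__main__":
    demo.launch()

Each yield must supply one value per declared output, which is why the coarse path is yielded twice.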
backend_utils.py ADDED
@@ -0,0 +1,144 @@
+import numpy as np
+import open3d as o3d
+
+def improved_multiway_registration(pcds, voxel_size=0.05, max_correspondence_distance_coarse=None, max_correspondence_distance_fine=None, overlap=3, quadratic_overlap=True, use_colored_icp=True):
+    if max_correspondence_distance_coarse is None:
+        max_correspondence_distance_coarse = voxel_size * 15
+    if max_correspondence_distance_fine is None:
+        max_correspondence_distance_fine = voxel_size * 1.5
+
+    def preprocess_point_cloud(pcd, voxel_size):
+        pcd_down = pcd.voxel_down_sample(voxel_size)
+        pcd_down.estimate_normals(
+            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30))
+        # Apply statistical outlier removal
+        cl, ind = pcd_down.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
+        pcd_down = pcd_down.select_by_index(ind)
+        return pcd_down
+
+    def pairwise_registration(source, target, use_colored_icp, voxel_size, max_correspondence_distance_coarse, max_correspondence_distance_fine):
+        current_transformation = np.identity(4)  # Start with identity matrix
+
+        if use_colored_icp:
+            print("Apply colored point cloud registration")
+            voxel_radius = [5*voxel_size, 3*voxel_size, voxel_size]
+            max_iter = [60, 35, 20]
+
+            for scale in range(3):
+                iter = max_iter[scale]
+                radius = voxel_radius[scale]
+
+                source_down = source.voxel_down_sample(radius)
+                target_down = target.voxel_down_sample(radius)
+
+                source_down.estimate_normals(
+                    o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))
+                target_down.estimate_normals(
+                    o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))
+
+                try:
+                    result_icp = o3d.pipelines.registration.registration_colored_icp(
+                        source_down, target_down, radius, current_transformation,
+                        o3d.pipelines.registration.TransformationEstimationForColoredICP(),
+                        o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6,
+                                                                          relative_rmse=1e-6,
+                                                                          max_iteration=iter))
+                    current_transformation = result_icp.transformation
+                except RuntimeError as e:
+                    print(f"Colored ICP failed at scale {scale}: {str(e)}")
+                    print("Keeping the previous transformation")
+                    # We keep the previous transformation, no need to reassign
+
+            transformation_icp = current_transformation
+        else:
+            print("Apply point-to-plane ICP")
+            try:
+                icp_coarse = o3d.pipelines.registration.registration_icp(
+                    source, target, max_correspondence_distance_coarse, current_transformation,
+                    o3d.pipelines.registration.TransformationEstimationPointToPlane())
+                current_transformation = icp_coarse.transformation
+
+                icp_fine = o3d.pipelines.registration.registration_icp(
+                    source, target, max_correspondence_distance_fine,
+                    current_transformation,
+                    o3d.pipelines.registration.TransformationEstimationPointToPlane())
+                transformation_icp = icp_fine.transformation
+            except RuntimeError as e:
+                print(f"Point-to-plane ICP failed: {str(e)}")
+                print("Keeping the best available transformation")
+                transformation_icp = current_transformation
+
+        try:
+            information_icp = o3d.pipelines.registration.get_information_matrix_from_point_clouds(
+                source, target, max_correspondence_distance_fine,
+                transformation_icp)
+        except RuntimeError as e:
+            print(f"Failed to compute information matrix: {str(e)}")
+            print("Using identity information matrix")
+            information_icp = np.identity(6)
+
+        return transformation_icp, information_icp
+
+    def full_registration(pcds_down):
+        pose_graph = o3d.pipelines.registration.PoseGraph()
+        odometry = np.identity(4)
+        pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(odometry))
+        n_pcds = len(pcds_down)
+
+        pairs = []
+        for i in range(n_pcds - 1):
+            for j in range(i + 1, min(i + overlap + 1, n_pcds)):
+                pairs.append((i, j))
+                if quadratic_overlap:
+                    q = 2**(j-i)
+                    if q > overlap and i + q < n_pcds:
+                        pairs.append((i, i + q))
+
+        for source_id, target_id in pairs:
+            transformation_icp, information_icp = pairwise_registration(
+                pcds_down[source_id], pcds_down[target_id], use_colored_icp,
+                voxel_size, max_correspondence_distance_coarse, max_correspondence_distance_fine)
+            print(f"Build PoseGraph: {source_id} -> {target_id}")
+
+            if target_id == source_id + 1:
+                odometry = np.dot(transformation_icp, odometry)
+                pose_graph.nodes.append(
+                    o3d.pipelines.registration.PoseGraphNode(
+                        np.linalg.inv(odometry)))
+
+            pose_graph.edges.append(
+                o3d.pipelines.registration.PoseGraphEdge(source_id,
+                                                         target_id,
+                                                         transformation_icp,
+                                                         information_icp,
+                                                         uncertain=False))
+        return pose_graph
+
+    # Preprocess point clouds
+    print("Preprocessing point clouds...")
+    pcds_down = [preprocess_point_cloud(pcd, voxel_size) for pcd in pcds]
+
+    print("Full registration ...")
+    pose_graph = full_registration(pcds_down)
+
+    print("Optimizing PoseGraph ...")
+    option = o3d.pipelines.registration.GlobalOptimizationOption(
+        max_correspondence_distance=max_correspondence_distance_fine,
+        edge_prune_threshold=0.25,
+        reference_node=0)
+
+    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
+        o3d.pipelines.registration.global_optimization(
+            pose_graph,
+            o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
+            o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
+            option)
+
+    print("Transform points and combine")
+    pcd_combined = o3d.geometry.PointCloud()
+    for point_id in range(len(pcds)):
+        print(pose_graph.nodes[point_id].pose)
+        pcds[point_id].transform(pose_graph.nodes[point_id].pose)
+        pcd_combined += pcds[point_id]
+
+    return pcd_combined, pose_graph
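A hypothetical smoke test for `improved_multiway_registration` (not part of the commit; sizes, offsets, and noise levels are arbitrary): two noisy, slightly shifted copies of one random cloud should register back on top of each other. Colors are set on both clouds because the function defaults to colored ICP:

import numpy as np
import open3d as o3d
from backend_utils import improved_multiway_registration

base = np.random.rand(2000, 3)
pcds = []
for shift in ([0.0, 0.0, 0.0], [0.03, 0.0, 0.0]):  # second cloud offset by 3 cm
    pcd = o3d.geometry.PointCloud()
    pts = base + shift + np.random.normal(0.0, 0.002, base.shape)  # mild noise
    pcd.points = o3d.utility.Vector3dVector(pts)
    pcd.colors = o3d.utility.Vector3dVector(np.tile([0.5, 0.5, 0.5], (len(base), 1)))
    pcds.append(pcd)

combined, pose_graph = improved_multiway_registration(pcds, voxel_size=0.01)
print(len(pose_graph.nodes), "poses;", len(combined.points), "points in merged cloud")

Note that the function transforms the input clouds in place and returns the merged cloud plus the optimized pose graph; app.py uses only `pose_graph.nodes[j].pose`, re-applying the poses to the full-resolution `pts_all` arrays rather than to the downsampled clouds.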
requirements.txt CHANGED
@@ -16,4 +16,5 @@ gdown
 imageio[ffmpeg]
 transformers
 kornia
-timm
+timm
+open3d
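Since the new `open3d` requirement is unpinned, a quick illustrative check (not in the commit) that the installed build exposes the `pipelines.registration` namespace backend_utils.py depends on:

import open3d as o3d

# Older open3d releases kept these functions under open3d.registration instead
assert hasattr(o3d.pipelines.registration, "registration_colored_icp")
print("open3d", o3d.__version__, "looks compatible")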