feat: Add backend for refinement
- app.py +62 -15
- backend_utils.py +144 -0
- requirements.txt +2 -1
app.py
CHANGED
@@ -15,7 +15,9 @@ from scipy.spatial.transform import Rotation
 from transformers import AutoModelForImageSegmentation
 from torchvision import transforms
 from PIL import Image
-import spaces
+import open3d as o3d
+from backend_utils import improved_multiway_registration
+

 # Default values
 DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
@@ -143,7 +145,6 @@ def generate_mask(image: np.ndarray):
     mask_np = np.array(mask) / 255.0
     return mask_np

-@spaces.GPU
 @torch.no_grad()
 def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
     # Extract frames from video
@@ -176,7 +177,7 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_b
         if remove_background:
             mask = generate_mask(image)
         else:
-            mask = np.ones_like(conf)
+            mask = np.ones_like(conf)

         images_all.append((image[None, ...] + 1.0)/2.0)
         pts_all.append(pts[None, ...])
@@ -192,6 +193,54 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_b
     conf_sig_all = (conf_all-1) / conf_all
     combined_mask = (conf_sig_all > conf_thresh) & (mask_all > 0.5)

+    # Create coarse result
+    coarse_scene = create_scene(pts_all, images_all, combined_mask, as_pointcloud)
+    coarse_output_path = save_scene(coarse_scene, as_pointcloud)
+
+    yield coarse_output_path, None, f"Reconstruction completed. FPS: {fps:.2f}"
+
+    # Create point clouds for multiway registration
+    pcds = []
+    for j in range(len(pts_all)):
+        pcd = o3d.geometry.PointCloud()
+        mask = combined_mask[j]
+        pcd.points = o3d.utility.Vector3dVector(pts_all[j][mask])
+        pcd.colors = o3d.utility.Vector3dVector(images_all[j][mask])
+        pcds.append(pcd)
+
+    # Perform global optimization
+    print("Performing global registration...")
+    transformed_pcds, pose_graph = improved_multiway_registration(pcds, voxel_size=0.01)
+
+    # Apply transformations from pose_graph to original pts_all
+    transformed_pts_all = np.zeros_like(pts_all)
+    for j in range(len(pts_all)):
+        # Get the transformation matrix from the pose graph
+        transformation = pose_graph.nodes[j].pose
+
+        # Reshape pts_all[j] to (H*W, 3)
+        H, W, _ = pts_all[j].shape
+        pts_reshaped = pts_all[j].reshape(-1, 3)
+
+        # Apply transformation to all points
+        homogeneous_pts = np.hstack((pts_reshaped, np.ones((pts_reshaped.shape[0], 1))))
+        transformed_pts = (transformation @ homogeneous_pts.T).T[:, :3]
+
+        # Reshape back to (H, W, 3) and store
+        transformed_pts_all[j] = transformed_pts.reshape(H, W, 3)
+
+    print(f"Original shape: {pts_all.shape}, Transformed shape: {transformed_pts_all.shape}")
+
+    # Create refined result
+    refined_scene = create_scene(transformed_pts_all, images_all, combined_mask, as_pointcloud)
+    refined_output_path = save_scene(refined_scene, as_pointcloud)
+
+    # Clean up temporary directory
+    os.system(f"rm -rf {demo_path}")
+
+    yield coarse_output_path, refined_output_path, f"Refinement completed. FPS: {fps:.2f}"
+
+def create_scene(pts_all, images_all, combined_mask, as_pointcloud):
     scene = trimesh.Scene()

     if as_pointcloud:
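The refinement step in this hunk reduces to one core operation: applying each optimized 4x4 pose from the pose graph to a dense (H, W, 3) point map via homogeneous coordinates. A self-contained sketch of just that operation, for reference; the apply_pose helper and the toy dimensions are illustrative, not part of this commit:

import numpy as np

def apply_pose(points_hw3, pose):
    # Flatten the (H, W, 3) point map to (H*W, 3), lift to homogeneous
    # coordinates, transform, then drop the homogeneous coordinate.
    H, W, _ = points_hw3.shape
    pts = points_hw3.reshape(-1, 3)
    pts_h = np.hstack([pts, np.ones((pts.shape[0], 1))])  # (H*W, 4)
    out = (pose @ pts_h.T).T[:, :3]
    return out.reshape(H, W, 3)

# Toy check: a pure translation should shift every point by the same offset.
pose = np.eye(4)
pose[:3, 3] = [1.0, 0.0, 0.5]
points = np.zeros((4, 5, 3))
assert np.allclose(apply_pose(points, pose), [1.0, 0.0, 0.5])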
@@ -206,37 +255,35 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_b
         meshes.append(pts3d_to_trimesh(images_all[i], pts_all[i], combined_mask[i]))
     mesh = trimesh.Trimesh(**cat_meshes(meshes))
     scene.add_geometry(mesh)
-
+
     rot = np.eye(4)
     rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
     scene.apply_transform(np.linalg.inv(OPENGL @ rot))
-
-
+    return scene
+def save_scene(scene, as_pointcloud):
     if as_pointcloud:
-        output_path = tempfile.mktemp(suffix='.ply')
+        output_path = tempfile.mktemp(suffix='.ply')
     else:
         output_path = tempfile.mktemp(suffix='.obj')
     scene.export(output_path)
-
-    # Clean up temporary directory
-    os.system(f"rm -rf {demo_path}")
-
-    return output_path, f"Reconstruction completed. FPS: {fps:.2f}"
+    return output_path

+# Update the Gradio interface
 iface = gr.Interface(
     fn=reconstruct,
     inputs=[
         gr.Video(label="Input Video"),
-        gr.Slider(0, 1, value=1e-
+        gr.Slider(0, 1, value=1e-6, label="Confidence Threshold"),
         gr.Slider(1, 30, step=1, value=5, label="Keyframe Interval"),
         gr.Checkbox(label="As Pointcloud", value=False),
         gr.Checkbox(label="Remove Background", value=False)
     ],
     outputs=[
-        gr.Model3D(label="3D Model", display_mode="solid"),
+        gr.Model3D(label="Coarse 3D Model", display_mode="solid"),
+        gr.Model3D(label="Refined 3D Model", display_mode="solid"),
         gr.Textbox(label="Status")
     ],
-    title="3D Reconstruction with Spatial Memory and Background Removal",
+    title="3D Reconstruction with Spatial Memory, Background Removal, and Global Optimization",
 )

 if __name__ == "__main__":
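Since reconstruct now yields twice instead of returning once, Gradio treats it as a generator: the first yield fills the coarse model slot while refinement is still running, and the second yield populates both slots. A stripped-down sketch of the same progressive-output pattern; the slow_pipeline function is hypothetical, for illustration only:

import time
import gradio as gr

def slow_pipeline(text):
    # First yield: the quick result shows immediately; the second
    # output stays empty while the slow stage runs.
    yield f"coarse: {text}", None
    time.sleep(2)  # stand-in for the expensive refinement stage
    # Second yield: both outputs are now populated.
    yield f"coarse: {text}", f"refined: {text}"

iface = gr.Interface(
    fn=slow_pipeline,
    inputs=gr.Textbox(label="Input"),
    outputs=[gr.Textbox(label="Coarse"), gr.Textbox(label="Refined")],
)

if __name__ == "__main__":
    iface.launch()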
backend_utils.py
ADDED
@@ -0,0 +1,144 @@
import numpy as np
import open3d as o3d

def improved_multiway_registration(pcds, voxel_size=0.05, max_correspondence_distance_coarse=None, max_correspondence_distance_fine=None, overlap=3, quadratic_overlap=True, use_colored_icp=True):
    if max_correspondence_distance_coarse is None:
        max_correspondence_distance_coarse = voxel_size * 15
    if max_correspondence_distance_fine is None:
        max_correspondence_distance_fine = voxel_size * 1.5

    def preprocess_point_cloud(pcd, voxel_size):
        pcd_down = pcd.voxel_down_sample(voxel_size)
        pcd_down.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel_size * 2, max_nn=30))
        # Apply statistical outlier removal
        cl, ind = pcd_down.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
        pcd_down = pcd_down.select_by_index(ind)
        return pcd_down

    def pairwise_registration(source, target, use_colored_icp, voxel_size, max_correspondence_distance_coarse, max_correspondence_distance_fine):
        current_transformation = np.identity(4)  # Start with identity matrix

        if use_colored_icp:
            print("Apply colored point cloud registration")
            voxel_radius = [5*voxel_size, 3*voxel_size, voxel_size]
            max_iter = [60, 35, 20]

            for scale in range(3):
                iter = max_iter[scale]
                radius = voxel_radius[scale]

                source_down = source.voxel_down_sample(radius)
                target_down = target.voxel_down_sample(radius)

                source_down.estimate_normals(
                    o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))
                target_down.estimate_normals(
                    o3d.geometry.KDTreeSearchParamHybrid(radius=radius * 2, max_nn=30))

                try:
                    result_icp = o3d.pipelines.registration.registration_colored_icp(
                        source_down, target_down, radius, current_transformation,
                        o3d.pipelines.registration.TransformationEstimationForColoredICP(),
                        o3d.pipelines.registration.ICPConvergenceCriteria(relative_fitness=1e-6,
                                                                          relative_rmse=1e-6,
                                                                          max_iteration=iter))
                    current_transformation = result_icp.transformation
                except RuntimeError as e:
                    print(f"Colored ICP failed at scale {scale}: {str(e)}")
                    print("Keeping the previous transformation")
                    # We keep the previous transformation, no need to reassign

            transformation_icp = current_transformation
        else:
            print("Apply point-to-plane ICP")
            try:
                icp_coarse = o3d.pipelines.registration.registration_icp(
                    source, target, max_correspondence_distance_coarse, current_transformation,
                    o3d.pipelines.registration.TransformationEstimationPointToPlane())
                current_transformation = icp_coarse.transformation

                icp_fine = o3d.pipelines.registration.registration_icp(
                    source, target, max_correspondence_distance_fine,
                    current_transformation,
                    o3d.pipelines.registration.TransformationEstimationPointToPlane())
                transformation_icp = icp_fine.transformation
            except RuntimeError as e:
                print(f"Point-to-plane ICP failed: {str(e)}")
                print("Keeping the best available transformation")
                transformation_icp = current_transformation

        try:
            information_icp = o3d.pipelines.registration.get_information_matrix_from_point_clouds(
                source, target, max_correspondence_distance_fine,
                transformation_icp)
        except RuntimeError as e:
            print(f"Failed to compute information matrix: {str(e)}")
            print("Using identity information matrix")
            information_icp = np.identity(6)

        return transformation_icp, information_icp

    def full_registration(pcds_down):
        pose_graph = o3d.pipelines.registration.PoseGraph()
        odometry = np.identity(4)
        pose_graph.nodes.append(o3d.pipelines.registration.PoseGraphNode(odometry))
        n_pcds = len(pcds_down)

        pairs = []
        for i in range(n_pcds - 1):
            for j in range(i + 1, min(i + overlap + 1, n_pcds)):
                pairs.append((i, j))
                if quadratic_overlap:
                    q = 2**(j-i)
                    if q > overlap and i + q < n_pcds:
                        pairs.append((i, i + q))

        for source_id, target_id in pairs:
            transformation_icp, information_icp = pairwise_registration(
                pcds_down[source_id], pcds_down[target_id], use_colored_icp,
                voxel_size, max_correspondence_distance_coarse, max_correspondence_distance_fine)
            print(f"Build PoseGraph: {source_id} -> {target_id}")

            if target_id == source_id + 1:
                odometry = np.dot(transformation_icp, odometry)
                pose_graph.nodes.append(
                    o3d.pipelines.registration.PoseGraphNode(
                        np.linalg.inv(odometry)))

            pose_graph.edges.append(
                o3d.pipelines.registration.PoseGraphEdge(source_id,
                                                         target_id,
                                                         transformation_icp,
                                                         information_icp,
                                                         uncertain=False))
        return pose_graph

    # Preprocess point clouds
    print("Preprocessing point clouds...")
    pcds_down = [preprocess_point_cloud(pcd, voxel_size) for pcd in pcds]

    print("Full registration ...")
    pose_graph = full_registration(pcds_down)

    print("Optimizing PoseGraph ...")
    option = o3d.pipelines.registration.GlobalOptimizationOption(
        max_correspondence_distance=max_correspondence_distance_fine,
        edge_prune_threshold=0.25,
        reference_node=0)

    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
        o3d.pipelines.registration.global_optimization(
            pose_graph,
            o3d.pipelines.registration.GlobalOptimizationLevenbergMarquardt(),
            o3d.pipelines.registration.GlobalOptimizationConvergenceCriteria(),
            option)

    print("Transform points and combine")
    pcd_combined = o3d.geometry.PointCloud()
    for point_id in range(len(pcds)):
        print(pose_graph.nodes[point_id].pose)
        pcds[point_id].transform(pose_graph.nodes[point_id].pose)
        pcd_combined += pcds[point_id]

    return pcd_combined, pose_graph
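To sanity-check improved_multiway_registration outside the app, one option is to register a few synthetic, slightly offset copies of the same colored cloud and inspect the optimized poses. A minimal sketch under that assumption; the cloud size, noise level, offsets, and output filename are made up for illustration:

import numpy as np
import open3d as o3d
from backend_utils import improved_multiway_registration

rng = np.random.default_rng(0)
base = rng.random((2000, 3))

# Three noisy, shifted copies of one random colored cloud; the
# registration should recover the small offsets.
pcds = []
for shift in (0.0, 0.02, 0.04):
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(base + shift + rng.normal(0, 0.002, base.shape))
    pcd.colors = o3d.utility.Vector3dVector(base)  # reuse coordinates as RGB in [0, 1)
    pcds.append(pcd)

combined, pose_graph = improved_multiway_registration(pcds, voxel_size=0.01)

# One optimized pose per input cloud; node 0 is the reference frame.
for i, node in enumerate(pose_graph.nodes):
    print(f"pose {i}:\n{node.pose}")
o3d.io.write_point_cloud("combined.ply", combined)

Note that the function returns both the merged cloud and the pose graph; app.py above uses only the pose graph, re-transforming its original dense point maps rather than consuming the merged cloud.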
requirements.txt
CHANGED
@@ -16,4 +16,5 @@ gdown
 imageio[ffmpeg]
 transformers
 kornia
-timm
+timm
+open3d