Stable-X chongjie commited on
Commit
8e3d0ca
1 Parent(s): acc5f47

Update app.py (#1)

Browse files

- Update app.py (699bcd394d0c919b423d19d2b428de41b533eda7)


Co-authored-by: Hugo <[email protected]>

Files changed (1) hide show
  1. app.py +57 -15
app.py CHANGED
@@ -12,10 +12,12 @@ from spann3r.datasets import Demo
12
  from torch.utils.data import DataLoader
13
  import trimesh
14
  from scipy.spatial.transform import Rotation
15
- import spaces
 
 
16
 
17
  # Default values
18
- DEFAULT_CKPT_PATH = './checkpoints/spann3r.pth'
19
  DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
20
  DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
21
 
@@ -106,10 +108,42 @@ def pts3d_to_trimesh(img, pts3d, valid=None):
106
  return dict(vertices=vertices, face_colors=face_colors, faces=faces)
107
 
108
  model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- @spaces.GPU
111
  @torch.no_grad()
112
- def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False):
113
  # Extract frames from video
114
  demo_path = extract_frames(video_path)
115
 
@@ -131,36 +165,43 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False):
131
  print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')
132
 
133
  # Process results
134
- pts_all, images_all, conf_all = [], [], []
135
  for j, view in enumerate(batch):
136
  image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
137
  pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
138
  conf = preds[j]['conf'][0].cpu().data.numpy()
139
 
 
 
 
 
 
140
  images_all.append((image[None, ...] + 1.0)/2.0)
141
  pts_all.append(pts[None, ...])
142
  conf_all.append(conf[None, ...])
 
143
 
144
  images_all = np.concatenate(images_all, axis=0)
145
  pts_all = np.concatenate(pts_all, axis=0) * 10
146
  conf_all = np.concatenate(conf_all, axis=0)
 
147
 
148
  # Create point cloud or mesh
149
  conf_sig_all = (conf_all-1) / conf_all
150
- mask = conf_sig_all > conf_thresh
151
 
152
  scene = trimesh.Scene()
153
 
154
  if as_pointcloud:
155
  pcd = trimesh.PointCloud(
156
- vertices=pts_all[mask].reshape(-1, 3),
157
- colors=images_all[mask].reshape(-1, 3)
158
  )
159
  scene.add_geometry(pcd)
160
  else:
161
  meshes = []
162
  for i in range(len(images_all)):
163
- meshes.append(pts3d_to_trimesh(images_all[i], pts_all[i], mask[i]))
164
  mesh = trimesh.Trimesh(**cat_meshes(meshes))
165
  scene.add_geometry(mesh)
166
 
@@ -168,11 +209,11 @@ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False):
168
  rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
169
  scene.apply_transform(np.linalg.inv(OPENGL @ rot))
170
 
 
171
  if as_pointcloud:
172
  output_path = tempfile.mktemp(suffix='.ply')
173
  else:
174
  output_path = tempfile.mktemp(suffix='.obj')
175
-
176
  scene.export(output_path)
177
 
178
  # Clean up temporary directory
@@ -185,15 +226,16 @@ iface = gr.Interface(
185
  inputs=[
186
  gr.Video(label="Input Video"),
187
  gr.Slider(0, 1, value=1e-3, label="Confidence Threshold"),
188
- gr.Slider(1, 30, step=1, value=1, label="Keyframe Interval"),
189
- gr.Checkbox(label="As Pointcloud", value=False)
 
190
  ],
191
  outputs=[
192
- gr.Model3D(label="3D Model (GLB)", display_mode="solid"),
193
  gr.Textbox(label="Status")
194
  ],
195
- title="3D Reconstruction with Spatial Memory",
196
  )
197
 
198
  if __name__ == "__main__":
199
- iface.launch()
 
12
  from torch.utils.data import DataLoader
13
  import trimesh
14
  from scipy.spatial.transform import Rotation
15
+ from transformers import AutoModelForImageSegmentation
16
+ from torchvision import transforms
17
+ from PIL import Image
18
 
19
  # Default values
20
+ DEFAULT_CKPT_PATH = 'https://huggingface.co/spaces/Stable-X/StableSpann3R/resolve/main/checkpoints/spann3r.pth'
21
  DEFAULT_DUST3R_PATH = 'https://huggingface.co/camenduru/dust3r/resolve/main/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth'
22
  DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
23
 
 
108
  return dict(vertices=vertices, face_colors=face_colors, faces=faces)
109
 
110
  model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
111
+ birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
112
+ birefnet.to(DEFAULT_DEVICE)
113
+ birefnet.eval()
114
+
115
+ def extract_object(birefnet, image):
116
+ # Data settings
117
+ image_size = (1024, 1024)
118
+ transform_image = transforms.Compose([
119
+ transforms.Resize(image_size),
120
+ transforms.ToTensor(),
121
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
122
+ ])
123
+
124
+ input_images = transform_image(image).unsqueeze(0).to(DEFAULT_DEVICE)
125
+
126
+ # Prediction
127
+ with torch.no_grad():
128
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
129
+ pred = preds[0].squeeze()
130
+ pred_pil = transforms.ToPILImage()(pred)
131
+ mask = pred_pil.resize(image.size)
132
+ return mask
133
+
134
+ def generate_mask(image: np.ndarray):
135
+ # Convert numpy array to PIL Image
136
+ pil_image = Image.fromarray((image * 255).astype(np.uint8))
137
+
138
+ # Extract object and get mask
139
+ mask = extract_object(birefnet, pil_image)
140
+
141
+ # Convert mask to numpy array
142
+ mask_np = np.array(mask) / 255.0
143
+ return mask_np
144
 
 
145
  @torch.no_grad()
146
+ def reconstruct(video_path, conf_thresh, kf_every, as_pointcloud=False, remove_background=False):
147
  # Extract frames from video
148
  demo_path = extract_frames(video_path)
149
 
 
165
  print(f'Finished reconstruction for {demo_name}, FPS: {fps:.2f}')
166
 
167
  # Process results
168
+ pts_all, images_all, conf_all, mask_all = [], [], [], []
169
  for j, view in enumerate(batch):
170
  image = view['img'].permute(0, 2, 3, 1).cpu().numpy()[0]
171
  pts = preds[j]['pts3d' if j==0 else 'pts3d_in_other_view'].detach().cpu().numpy()[0]
172
  conf = preds[j]['conf'][0].cpu().data.numpy()
173
 
174
+ if remove_background:
175
+ mask = generate_mask(image)
176
+ else:
177
+ mask = np.ones_like(conf) # Change this to match conf shape
178
+
179
  images_all.append((image[None, ...] + 1.0)/2.0)
180
  pts_all.append(pts[None, ...])
181
  conf_all.append(conf[None, ...])
182
+ mask_all.append(mask[None, ...])
183
 
184
  images_all = np.concatenate(images_all, axis=0)
185
  pts_all = np.concatenate(pts_all, axis=0) * 10
186
  conf_all = np.concatenate(conf_all, axis=0)
187
+ mask_all = np.concatenate(mask_all, axis=0)
188
 
189
  # Create point cloud or mesh
190
  conf_sig_all = (conf_all-1) / conf_all
191
+ combined_mask = (conf_sig_all > conf_thresh) & (mask_all > 0.5)
192
 
193
  scene = trimesh.Scene()
194
 
195
  if as_pointcloud:
196
  pcd = trimesh.PointCloud(
197
+ vertices=pts_all[combined_mask].reshape(-1, 3),
198
+ colors=images_all[combined_mask].reshape(-1, 3)
199
  )
200
  scene.add_geometry(pcd)
201
  else:
202
  meshes = []
203
  for i in range(len(images_all)):
204
+ meshes.append(pts3d_to_trimesh(images_all[i], pts_all[i], combined_mask[i]))
205
  mesh = trimesh.Trimesh(**cat_meshes(meshes))
206
  scene.add_geometry(mesh)
207
 
 
209
  rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
210
  scene.apply_transform(np.linalg.inv(OPENGL @ rot))
211
 
212
+ # Save the scene as GLB
213
  if as_pointcloud:
214
  output_path = tempfile.mktemp(suffix='.ply')
215
  else:
216
  output_path = tempfile.mktemp(suffix='.obj')
 
217
  scene.export(output_path)
218
 
219
  # Clean up temporary directory
 
226
  inputs=[
227
  gr.Video(label="Input Video"),
228
  gr.Slider(0, 1, value=1e-3, label="Confidence Threshold"),
229
+ gr.Slider(1, 30, step=1, value=5, label="Keyframe Interval"),
230
+ gr.Checkbox(label="As Pointcloud", value=False),
231
+ gr.Checkbox(label="Remove Background", value=False)
232
  ],
233
  outputs=[
234
+ gr.Model3D(label="3D Model", display_mode="solid"),
235
  gr.Textbox(label="Status")
236
  ],
237
+ title="3D Reconstruction with Spatial Memory and Background Removal",
238
  )
239
 
240
  if __name__ == "__main__":
241
+ iface.launch(server_name="0.0.0.0",)