Stable-X committed on
Commit
2bca3f5
1 Parent(s): 549d99a

feat: Clean codes

Browse files
Files changed (2) hide show
  1. app.py +0 -91
  2. spann3r/model.py +5 -2
app.py CHANGED
@@ -161,8 +161,6 @@ def load_model(ckpt_path, device):
161
  return model
162
 
163
  model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
164
- mast3r_model = AsymmetricMASt3R.from_pretrained(DEFAULT_MAST3R_PATH).to(DEFAULT_DEVICE)
165
- mast3r_model.eval()
166
 
167
  birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
168
  birefnet.to(DEFAULT_DEVICE)
@@ -386,87 +384,6 @@ def get_keyframes(temp_dir: str, kf_every: int = 10):
386
  raise ValueError(f"Not enough frames found in {temp_dir}. Need at least 2 frames for reconstruction.")
387
 
388
  return keyframe_paths
389
-
390
- from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
391
- from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
392
- from dust3r.utils.image import load_images
393
- from dust3r.image_pairs import make_pairs
394
- from dust3r.utils.device import to_numpy
395
- def invert_matrix(mat):
396
- """Invert a torch or numpy matrix."""
397
- if isinstance(mat, torch.Tensor):
398
- return torch.linalg.inv(mat)
399
- if isinstance(mat, np.ndarray):
400
- return np.linalg.inv(mat)
401
- raise ValueError(f'Unsupported matrix type: {type(mat)}')
402
-
403
- def refine(
404
- video_path: str,
405
- conf_thresh: float = 5.0,
406
- kf_every: int = 30,
407
- remove_background: bool = False,
408
- enable_registration: bool = True,
409
- output_3d_model: bool = True
410
- ) -> dict:
411
- # Extract keyframes from video
412
- temp_dir = extract_frames(video_path)
413
- keyframe_paths = get_keyframes(temp_dir, kf_every*3)
414
-
415
- image_size = 512
416
- images = load_images(keyframe_paths, size=image_size)
417
-
418
- # Create output directory
419
- output_dir = tempfile.mkdtemp()
420
-
421
- # Generate pairs and run inference
422
- pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
423
- cache_dir = os.path.join(output_dir, 'cache')
424
- if os.path.exists(cache_dir):
425
- os.system(f'rm -rf {cache_dir}')
426
- scene = sparse_global_alignment(keyframe_paths, pairs, cache_dir,
427
- mast3r_model, lr1=0.07, niter1=500, lr2=0.014,
428
- niter2=200 if enable_registration else 0, device=DEFAULT_DEVICE,
429
- opt_depth=True if enable_registration else False, shared_intrinsics=True,
430
- matching_conf_thr=5.)
431
-
432
- # Extract scene information
433
- imgs = np.array(scene.imgs)
434
-
435
- tsdf = TSDFPostProcess(scene, TSDF_thresh=0)
436
- pts3d, _, confs = tsdf.get_dense_pts3d(clean_depth=True)
437
- masks = np.array(to_numpy([c > 1.5 for c in confs]))
438
-
439
- pcds = []
440
- for pts, conf_mask, image in zip(pts3d, masks, imgs):
441
- if remove_background:
442
- mask = generate_mask(image)
443
- else:
444
- mask = np.ones_like(conf_mask)
445
- combined_mask = conf_mask & (mask > 0.5)
446
-
447
- pts = pts.reshape(combined_mask.shape[0], combined_mask.shape[1], 3)
448
- pts_normal = pts2normal(pts).cpu().numpy()
449
- pts = pts.cpu().numpy()
450
- pcd = o3d.geometry.PointCloud()
451
- pcd.points = o3d.utility.Vector3dVector(pts[combined_mask] / 5)
452
- pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
453
- pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
454
- pcds.append(pcd)
455
-
456
- pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
457
- o3d_geometry = point2mesh(pcd_combined, depth=9)
458
- o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
459
-
460
- # Create coarse result
461
- coarse_output_path = export_geometry(o3d_geometry_centered)
462
-
463
- if output_3d_model:
464
- gs_output_path = tempfile.mktemp(suffix='.ply')
465
- point2gs(gs_output_path, pcd_combined)
466
- return coarse_output_path, [gs_output_path]
467
- else:
468
- pcd_output_path = export_geometry(pcd_combined, file_format='ply')
469
- return coarse_output_path, [pcd_output_path]
470
 
471
  @torch.no_grad()
472
  def reconstruct(video_path, conf_thresh, kf_every,
@@ -661,7 +578,6 @@ with gr.Blocks(
661
  info="Generate Splat (PLY) instead of Point Cloud (PLY)"
662
  )
663
  reconstruct_btn = gr.Button("Start Reconstruction")
664
- refine_btn = gr.Button("Start Refinement")
665
 
666
  with gr.Column(scale=2):
667
  with gr.Tab("3D Models"):
@@ -695,12 +611,5 @@ with gr.Blocks(
695
  outputs=[initial_model, output_model]
696
  )
697
 
698
- refine_btn.click(
699
- fn=refine,
700
- inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
701
- outputs=[initial_model, output_model]
702
- )
703
-
704
-
705
  if __name__ == "__main__":
706
  iface.launch(server_name="0.0.0.0")
 
161
  return model
162
 
163
  model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
 
 
164
 
165
  birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
166
  birefnet.to(DEFAULT_DEVICE)
 
384
  raise ValueError(f"Not enough frames found in {temp_dir}. Need at least 2 frames for reconstruction.")
385
 
386
  return keyframe_paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
  @torch.no_grad()
389
  def reconstruct(video_path, conf_thresh, kf_every,
 
578
  info="Generate Splat (PLY) instead of Point Cloud (PLY)"
579
  )
580
  reconstruct_btn = gr.Button("Start Reconstruction")
 
581
 
582
  with gr.Column(scale=2):
583
  with gr.Tab("3D Models"):
 
611
  outputs=[initial_model, output_model]
612
  )
613
 
 
 
 
 
 
 
 
614
  if __name__ == "__main__":
615
  iface.launch(server_name="0.0.0.0")
spann3r/model.py CHANGED
@@ -201,7 +201,7 @@ class SpatialMemory():
201
 
202
  print('Memory pruned:', num_mem_b, '->', num_mem_a)
203
 
204
-
205
  class Spann3R(nn.Module):
206
  def __init__(self, dus3r_name="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
207
  use_feat=False, mem_pos_enc=False, memory_dropout=0.15):
@@ -211,7 +211,10 @@ class Spann3R(nn.Module):
211
  self.mem_pos_enc = mem_pos_enc
212
 
213
  # DUSt3R
214
- self.dust3r = AsymmetricCroCo3DStereo.from_pretrained(dus3r_name, landscape_only=True)
 
 
 
215
 
216
  # Memory encoder
217
  self.set_memory_encoder(enc_embed_dim=768 if use_feat else 1024, memory_dropout=memory_dropout)
 
201
 
202
  print('Memory pruned:', num_mem_b, '->', num_mem_a)
203
 
204
+ import math
205
  class Spann3R(nn.Module):
206
  def __init__(self, dus3r_name="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
207
  use_feat=False, mem_pos_enc=False, memory_dropout=0.15):
 
211
  self.mem_pos_enc = mem_pos_enc
212
 
213
  # DUSt3R
214
+ self.dust3r = AsymmetricCroCo3DStereo(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768,
215
+ enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R',
216
+ img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -math.inf, math.inf),
217
+ conf_mode=('exp', 1, math.inf), landscape_only=True)
218
 
219
  # Memory encoder
220
  self.set_memory_encoder(enc_embed_dim=768 if use_feat else 1024, memory_dropout=memory_dropout)