Spaces:
Sleeping
Sleeping
feat: Clean codes
Browse files
- app.py +0 -91
- spann3r/model.py +5 -2
app.py
CHANGED
@@ -161,8 +161,6 @@ def load_model(ckpt_path, device):
|
|
161 |
return model
|
162 |
|
163 |
model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
|
164 |
-
mast3r_model = AsymmetricMASt3R.from_pretrained(DEFAULT_MAST3R_PATH).to(DEFAULT_DEVICE)
|
165 |
-
mast3r_model.eval()
|
166 |
|
167 |
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
|
168 |
birefnet.to(DEFAULT_DEVICE)
|
@@ -386,87 +384,6 @@ def get_keyframes(temp_dir: str, kf_every: int = 10):
|
|
386 |
raise ValueError(f"Not enough frames found in {temp_dir}. Need at least 2 frames for reconstruction.")
|
387 |
|
388 |
return keyframe_paths
|
389 |
-
|
390 |
-
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
|
391 |
-
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
|
392 |
-
from dust3r.utils.image import load_images
|
393 |
-
from dust3r.image_pairs import make_pairs
|
394 |
-
from dust3r.utils.device import to_numpy
|
395 |
-
def invert_matrix(mat):
|
396 |
-
"""Invert a torch or numpy matrix."""
|
397 |
-
if isinstance(mat, torch.Tensor):
|
398 |
-
return torch.linalg.inv(mat)
|
399 |
-
if isinstance(mat, np.ndarray):
|
400 |
-
return np.linalg.inv(mat)
|
401 |
-
raise ValueError(f'Unsupported matrix type: {type(mat)}')
|
402 |
-
|
403 |
-
def refine(
|
404 |
-
video_path: str,
|
405 |
-
conf_thresh: float = 5.0,
|
406 |
-
kf_every: int = 30,
|
407 |
-
remove_background: bool = False,
|
408 |
-
enable_registration: bool = True,
|
409 |
-
output_3d_model: bool = True
|
410 |
-
) -> dict:
|
411 |
-
# Extract keyframes from video
|
412 |
-
temp_dir = extract_frames(video_path)
|
413 |
-
keyframe_paths = get_keyframes(temp_dir, kf_every*3)
|
414 |
-
|
415 |
-
image_size = 512
|
416 |
-
images = load_images(keyframe_paths, size=image_size)
|
417 |
-
|
418 |
-
# Create output directory
|
419 |
-
output_dir = tempfile.mkdtemp()
|
420 |
-
|
421 |
-
# Generate pairs and run inference
|
422 |
-
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
|
423 |
-
cache_dir = os.path.join(output_dir, 'cache')
|
424 |
-
if os.path.exists(cache_dir):
|
425 |
-
os.system(f'rm -rf {cache_dir}')
|
426 |
-
scene = sparse_global_alignment(keyframe_paths, pairs, cache_dir,
|
427 |
-
mast3r_model, lr1=0.07, niter1=500, lr2=0.014,
|
428 |
-
niter2=200 if enable_registration else 0, device=DEFAULT_DEVICE,
|
429 |
-
opt_depth=True if enable_registration else False, shared_intrinsics=True,
|
430 |
-
matching_conf_thr=5.)
|
431 |
-
|
432 |
-
# Extract scene information
|
433 |
-
imgs = np.array(scene.imgs)
|
434 |
-
|
435 |
-
tsdf = TSDFPostProcess(scene, TSDF_thresh=0)
|
436 |
-
pts3d, _, confs = tsdf.get_dense_pts3d(clean_depth=True)
|
437 |
-
masks = np.array(to_numpy([c > 1.5 for c in confs]))
|
438 |
-
|
439 |
-
pcds = []
|
440 |
-
for pts, conf_mask, image in zip(pts3d, masks, imgs):
|
441 |
-
if remove_background:
|
442 |
-
mask = generate_mask(image)
|
443 |
-
else:
|
444 |
-
mask = np.ones_like(conf_mask)
|
445 |
-
combined_mask = conf_mask & (mask > 0.5)
|
446 |
-
|
447 |
-
pts = pts.reshape(combined_mask.shape[0], combined_mask.shape[1], 3)
|
448 |
-
pts_normal = pts2normal(pts).cpu().numpy()
|
449 |
-
pts = pts.cpu().numpy()
|
450 |
-
pcd = o3d.geometry.PointCloud()
|
451 |
-
pcd.points = o3d.utility.Vector3dVector(pts[combined_mask] / 5)
|
452 |
-
pcd.colors = o3d.utility.Vector3dVector(image[combined_mask])
|
453 |
-
pcd.normals = o3d.utility.Vector3dVector(pts_normal[combined_mask])
|
454 |
-
pcds.append(pcd)
|
455 |
-
|
456 |
-
pcd_combined = combine_and_clean_point_clouds(pcds, voxel_size=0.001)
|
457 |
-
o3d_geometry = point2mesh(pcd_combined, depth=9)
|
458 |
-
o3d_geometry_centered = center_mesh(o3d_geometry, normalize=True)
|
459 |
-
|
460 |
-
# Create coarse result
|
461 |
-
coarse_output_path = export_geometry(o3d_geometry_centered)
|
462 |
-
|
463 |
-
if output_3d_model:
|
464 |
-
gs_output_path = tempfile.mktemp(suffix='.ply')
|
465 |
-
point2gs(gs_output_path, pcd_combined)
|
466 |
-
return coarse_output_path, [gs_output_path]
|
467 |
-
else:
|
468 |
-
pcd_output_path = export_geometry(pcd_combined, file_format='ply')
|
469 |
-
return coarse_output_path, [pcd_output_path]
|
470 |
|
471 |
@torch.no_grad()
|
472 |
def reconstruct(video_path, conf_thresh, kf_every,
|
@@ -661,7 +578,6 @@ with gr.Blocks(
|
|
661 |
info="Generate Splat (PLY) instead of Point Cloud (PLY)"
|
662 |
)
|
663 |
reconstruct_btn = gr.Button("Start Reconstruction")
|
664 |
-
refine_btn = gr.Button("Start Refinement")
|
665 |
|
666 |
with gr.Column(scale=2):
|
667 |
with gr.Tab("3D Models"):
|
@@ -695,12 +611,5 @@ with gr.Blocks(
|
|
695 |
outputs=[initial_model, output_model]
|
696 |
)
|
697 |
|
698 |
-
refine_btn.click(
|
699 |
-
fn=refine,
|
700 |
-
inputs=[video_input, conf_thresh, kf_every, remove_background, enable_registration, output_3d_model],
|
701 |
-
outputs=[initial_model, output_model]
|
702 |
-
)
|
703 |
-
|
704 |
-
|
705 |
if __name__ == "__main__":
|
706 |
iface.launch(server_name="0.0.0.0")
|
|
|
161 |
return model
|
162 |
|
163 |
model = load_model(DEFAULT_CKPT_PATH, DEFAULT_DEVICE)
|
|
|
|
|
164 |
|
165 |
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
|
166 |
birefnet.to(DEFAULT_DEVICE)
|
|
|
384 |
raise ValueError(f"Not enough frames found in {temp_dir}. Need at least 2 frames for reconstruction.")
|
385 |
|
386 |
return keyframe_paths
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
|
388 |
@torch.no_grad()
|
389 |
def reconstruct(video_path, conf_thresh, kf_every,
|
|
|
578 |
info="Generate Splat (PLY) instead of Point Cloud (PLY)"
|
579 |
)
|
580 |
reconstruct_btn = gr.Button("Start Reconstruction")
|
|
|
581 |
|
582 |
with gr.Column(scale=2):
|
583 |
with gr.Tab("3D Models"):
|
|
|
611 |
outputs=[initial_model, output_model]
|
612 |
)
|
613 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
614 |
if __name__ == "__main__":
|
615 |
iface.launch(server_name="0.0.0.0")
|
spann3r/model.py
CHANGED
@@ -201,7 +201,7 @@ class SpatialMemory():
|
|
201 |
|
202 |
print('Memory pruned:', num_mem_b, '->', num_mem_a)
|
203 |
|
204 |
-
|
205 |
class Spann3R(nn.Module):
|
206 |
def __init__(self, dus3r_name="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
|
207 |
use_feat=False, mem_pos_enc=False, memory_dropout=0.15):
|
@@ -211,7 +211,10 @@ class Spann3R(nn.Module):
|
|
211 |
self.mem_pos_enc = mem_pos_enc
|
212 |
|
213 |
# DUSt3R
|
214 |
-
self.dust3r = AsymmetricCroCo3DStereo
|
|
|
|
|
|
|
215 |
|
216 |
# Memory encoder
|
217 |
self.set_memory_encoder(enc_embed_dim=768 if use_feat else 1024, memory_dropout=memory_dropout)
|
|
|
201 |
|
202 |
print('Memory pruned:', num_mem_b, '->', num_mem_a)
|
203 |
|
204 |
+
import math
|
205 |
class Spann3R(nn.Module):
|
206 |
def __init__(self, dus3r_name="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
|
207 |
use_feat=False, mem_pos_enc=False, memory_dropout=0.15):
|
|
|
211 |
self.mem_pos_enc = mem_pos_enc
|
212 |
|
213 |
# DUSt3R
|
214 |
+
self.dust3r = AsymmetricCroCo3DStereo(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768,
|
215 |
+
enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R',
|
216 |
+
img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -math.inf, math.inf),
|
217 |
+
conf_mode=('exp', 1, math.inf), landscape_only=True)
|
218 |
|
219 |
# Memory encoder
|
220 |
self.set_memory_encoder(enc_embed_dim=768 if use_feat else 1024, memory_dropout=memory_dropout)
|