hujiecpp committed
Commit bea11e6 · 1 Parent(s): b59e9c0

init project

Files changed (1)
  1. app.py +25 -23
app.py CHANGED
@@ -257,23 +257,10 @@ def slerp_multiple(vectors, t_values):
     return interpolated_vector
 
 # @torch.no_grad
-def get_mask_from_img_sam1(sam1_image, yolov8_image, original_size, input_size, transform):
+def get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_image, yolov8_image, original_size, input_size, transform):
 
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
-    mobilesamv2 = sam_model_registry['sam_vit_h'](None)
-    sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
-    image_encoder = sam1.vision_encoder
-
-    prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
-    mobilesamv2.prompt_encoder = prompt_encoder
-    mobilesamv2.mask_decoder = mask_decoder
-    mobilesamv2.image_encoder=image_encoder
-    mobilesamv2.to(device=device)
-    mobilesamv2.eval()
 
-    YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'
-    yolov8 = ObjectAwareModel(YOLO8_CKP)
 
     sam_mask=[]
     img_area = original_size[0] * original_size[1]
@@ -327,15 +314,10 @@ def get_mask_from_img_sam1(sam1_image, yolov8_image, original_size, input_size,
     return ret_mask
 
 # @torch.no_grad
-def get_cog_feats(images):
+def get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2):
 
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-    sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
-
-    siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
-    siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
-
     cog_seg_maps = []
     rev_cog_seg_maps = []
     inference_state = sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
@@ -346,7 +328,7 @@ def get_cog_feats(images):
     np_images = images.np_images
     np_images_size = images.np_images_size
 
-    sam1_masks = get_mask_from_img_sam1(sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
+    sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
     for mask in sam1_masks:
         _, _, _ = sam2.add_new_mask(
             inference_state=inference_state,
@@ -368,7 +350,7 @@ def get_cog_feats(images):
         if out_frame_idx == 0:
             continue
 
-        sam1_masks = get_mask_from_img_sam1(sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
+        sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
 
         for sam1_mask in sam1_masks:
             flg = 1
@@ -484,13 +466,33 @@ def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
     MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
     mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)
 
+    sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
+
+    siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
+    siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
+
+    SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
+    mobilesamv2 = sam_model_registry['sam_vit_h'](None)
+    sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
+    image_encoder = sam1.vision_encoder
+
+    prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
+    mobilesamv2.prompt_encoder = prompt_encoder
+    mobilesamv2.mask_decoder = mask_decoder
+    mobilesamv2.image_encoder=image_encoder
+    mobilesamv2.to(device=device)
+    mobilesamv2.eval()
+
+    YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'
+    yolov8 = ObjectAwareModel(YOLO8_CKP)
+
     if len(filelist) < 2:
         raise gradio.Error("Please input at least 2 images.")
 
     images = Images(filelist=filelist, device=device)
 
     # try:
-    cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images)
+    cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2)
     imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
     # except Exception as e:
     #     rev_cog_seg_maps = []
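
Note on the change, for readers skimming the hunks: the segmentation and feature models (SAM2, SigLIP, the prompt-guided MobileSAMv2 decoder, and the YOLOv8 object-aware detector) are no longer constructed inside get_mask_from_img_sam1 and get_cog_feats; they are built once in get_reconstructed_scene and passed down as arguments. The sketch below condenses that call pattern from the diff; it is illustrative only, imports are omitted (all names come from app.py's existing imports), and the "..." bodies stand for the unchanged code shown in the hunks.

# Minimal sketch of the refactored call pattern (condensed from the diff above).
# The "..." bodies are placeholders for the unchanged code in app.py.

def get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_image, yolov8_image,
                           original_size, input_size, transform):
    # Uses the detector and mask model passed in; no per-call checkpoint loading.
    ...

def get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2):
    # Per-frame segmentation reuses the same model instances, e.g. for frame 0:
    #   sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[0],
    #       np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
    ...

# Inside get_reconstructed_scene, the models are now built once (the new "+"
# lines of the last hunk) and threaded through a single get_cog_feats call:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
mobilesamv2 = sam_model_registry['sam_vit_h'](None)  # then assembled and moved to device as in the hunk
yolov8 = ObjectAwareModel('./checkpoints/ObjectAwareModel.pt')

images = Images(filelist=filelist, device=device)  # filelist is a get_reconstructed_scene parameter
cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(
    images, sam2, siglip, siglip_processor, yolov8, mobilesamv2)

Since get_mask_from_img_sam1 is called for every frame inside the SAM2 propagation loop, loading the checkpoints once per reconstruction rather than once per call avoids repeatedly re-initializing the same models.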