hujiecpp committed
Commit b00e12c · 1 Parent(s): 9225e86

init project

Files changed (1)
  1. app.py +47 -21
app.py CHANGED
@@ -34,14 +34,26 @@ from typing import Any, Dict, Generator,List
 import matplotlib.pyplot as pl
 
 from modules.mobilesamv2.utils.transforms import ResizeLongestSide
-from modules.pe3r.models import Models
+# from modules.pe3r.models import Models
 import torchvision.transforms as tvf
 
-from transformers import AutoTokenizer, AutoModel, AutoProcessor
+
+
+sys.path.append(os.path.abspath('./modules/ultralytics'))
+
+from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
+from modules.mast3r.model import AsymmetricMASt3R
+
+# from modules.sam2.build_sam import build_sam2_video_predictor
+from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
+from modules.mobilesamv2 import sam_model_registry
+
+from sam2.sam2_video_predictor import SAM2VideoPredictor
+
 
 silent = False
 # device = 'cuda' if torch.cuda.is_available() else 'cpu' #'cpu' #
-pe3r = Models('cpu') #
+# pe3r = Models('cpu') #
 # print(device)
 
 def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
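Note: with the `Models('cpu')` wrapper gone, the SAM/MobileSAMv2, YOLOv8, SAM2 and MASt3R models are constructed inside the Gradio callbacks that use them, so every call reloads its checkpoints. A minimal sketch of how those loads could be memoized per process, reusing the same repo-local modules and checkpoint paths as the diff; the helper names (`load_mobilesamv2`, `load_yolov8`) are hypothetical and not part of this commit:

from functools import lru_cache

from transformers import SamModel
from modules.mobilesamv2 import sam_model_registry
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel

SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
YOLO8_CKP = './checkpoints/ObjectAwareModel.pt'

@lru_cache(maxsize=None)
def load_mobilesamv2(device: str):
    # Same assembly as in get_mask_from_img_sam1 below, done once per device.
    mobilesamv2 = sam_model_registry['sam_vit_h'](None)
    sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
    prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
    mobilesamv2.image_encoder = sam1.vision_encoder
    mobilesamv2.prompt_encoder = prompt_encoder
    mobilesamv2.mask_decoder = mask_decoder
    mobilesamv2.to(device=device)
    mobilesamv2.eval()
    return mobilesamv2

@lru_cache(maxsize=None)
def load_yolov8():
    # The object-aware detector picks its device per call, so no device argument here.
    return ObjectAwareModel(YOLO8_CKP)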
@@ -248,13 +260,25 @@ def slerp_multiple(vectors, t_values):
 def get_mask_from_img_sam1(sam1_image, yolov8_image, original_size, input_size, transform):
 
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    pe3r.yolov8.to(device)
-    pe3r.mobilesamv2.to(device)
+    SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
+    mobilesamv2 = sam_model_registry['sam_vit_h'](None)
+    sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
+    image_encoder = sam1.vision_encoder
+
+    prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
+    mobilesamv2.prompt_encoder = prompt_encoder
+    mobilesamv2.mask_decoder = mask_decoder
+    mobilesamv2.image_encoder=image_encoder
+    mobilesamv2.to(device=device)
+    mobilesamv2.eval()
+
+    YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'
+    yolov8 = ObjectAwareModel(YOLO8_CKP)
 
     sam_mask=[]
     img_area = original_size[0] * original_size[1]
 
-    obj_results = pe3r.yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
+    obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
     input_boxes1 = obj_results[0].boxes.xyxy
     input_boxes1 = input_boxes1.cpu().numpy()
     input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
@@ -268,21 +292,21 @@ def get_mask_from_img_sam1(sam1_image, yolov8_image, original_size, input_size,
 
     # input_boxes = torch.cat((input_boxes1, input_boxes2), dim=0)
 
-    input_image = pe3r.mobilesamv2.preprocess(sam1_image)
-    image_embedding = pe3r.mobilesamv2.image_encoder(input_image)['last_hidden_state']
+    input_image = mobilesamv2.preprocess(sam1_image)
+    image_embedding = mobilesamv2.image_encoder(input_image)['last_hidden_state']
 
     image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
-    prompt_embedding=pe3r.mobilesamv2.prompt_encoder.get_dense_pe()
+    prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
     prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
     for (boxes,) in batch_iterator(320, input_boxes):
         with torch.no_grad():
             image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
             prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
-            sparse_embeddings, dense_embeddings = pe3r.mobilesamv2.prompt_encoder(
+            sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
                 points=None,
                 boxes=boxes,
                 masks=None,)
-            low_res_masks, _ = pe3r.mobilesamv2.mask_decoder(
+            low_res_masks, _ = mobilesamv2.mask_decoder(
                 image_embeddings=image_embedding,
                 image_pe=prompt_embedding,
                 sparse_prompt_embeddings=sparse_embeddings,
@@ -290,8 +314,8 @@ def get_mask_from_img_sam1(sam1_image, yolov8_image, original_size, input_size,
                 multimask_output=False,
                 simple_type=True,
             )
-            low_res_masks=pe3r.mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
-            sam_mask_pre = (low_res_masks > pe3r.mobilesamv2.mask_threshold)
+            low_res_masks=mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
+            sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)
             for mask in sam_mask_pre:
                 if mask.sum() / img_area > 0.002:
                     sam_mask.append(mask.squeeze(1))
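The three hunks above keep MobileSAMv2's original box-prompt loop and only swap the `pe3r.*` handles for the locally built `mobilesamv2` and `yolov8`. For orientation, a sketch of how the inputs of `get_mask_from_img_sam1` could be prepared with the imported `ResizeLongestSide` transform; this mirrors the standard SAM preprocessing, the helper name is hypothetical, and the exact tensor layout expected by `mobilesamv2.preprocess` is an assumption rather than something this commit specifies (the app builds these tensors in its own image-loading helpers):

import numpy as np
import torch
from modules.mobilesamv2.utils.transforms import ResizeLongestSide

def prepare_sam1_inputs(rgb: np.ndarray, device: str = 'cpu'):
    # rgb: HxWx3 uint8 frame. Returns the (sam1_image, original_size, input_size,
    # transform) pieces that get_mask_from_img_sam1 expects.
    transform = ResizeLongestSide(1024)      # SAM-family models use a 1024-pixel long side
    original_size = rgb.shape[:2]            # (H, W) of the raw frame
    resized = transform.apply_image(rgb)     # resized HWC uint8 image
    input_size = resized.shape[:2]           # size after resize, before square padding
    sam1_image = torch.as_tensor(resized, device=device)
    sam1_image = sam1_image.permute(2, 0, 1).contiguous()[None, :, :, :]  # assumed 1x3xHxW layout
    return sam1_image, original_size, input_size, transform

In the call inside get_cog_feats below, the detector image (`yolov8_image`) is the plain numpy frame (`np_images[0]`), while `original_size` and `input_size` come from the numpy and SAM-resized image sizes respectively.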
@@ -306,13 +330,15 @@ def get_mask_from_img_sam1(sam1_image, yolov8_image, original_size, input_size,
 def get_cog_feats(images):
 
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    pe3r.sam2.to(device)
+
+    sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
+
     siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
     siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
 
     cog_seg_maps = []
     rev_cog_seg_maps = []
-    inference_state = pe3r.sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
+    inference_state = sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
     mask_num = 0
 
     sam1_images = images.sam1_images
@@ -322,7 +348,7 @@ def get_cog_feats(images):
 
     sam1_masks = get_mask_from_img_sam1(sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
     for mask in sam1_masks:
-        _, _, _ = pe3r.sam2.add_new_mask(
+        _, _, _ = sam2.add_new_mask(
             inference_state=inference_state,
             frame_idx=0,
             obj_id=mask_num,
@@ -331,7 +357,7 @@ def get_cog_feats(images):
         mask_num += 1
 
     video_segments = {} # video_segments contains the per-frame segmentation results
-    for out_frame_idx, out_obj_ids, out_mask_logits in pe3r.sam2.propagate_in_video(inference_state):
+    for out_frame_idx, out_obj_ids, out_mask_logits in sam2.propagate_in_video(inference_state):
         sam2_masks = (out_mask_logits > 0.0).squeeze(1)
 
         video_segments[out_frame_idx] = {
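These hunks switch get_cog_feats from the vendored `pe3r.sam2` predictor to the `sam2` package's `SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large')`. Note that the stock predictor's `init_state` takes a `video_path` rather than pre-loaded frames, so the `images=` / `video_height=` / `video_width=` call above appears to rely on a patched `sam2` install (the commented-out `modules.sam2.build_sam` import hints at the previously vendored copy). For illustration only, a hypothetical consumer of the `video_segments` dict built above, assuming each frame entry maps `obj_id` to a boolean mask:

import torch

def frame_label_map(frame_masks, height, width):
    # Collapse one frame's per-object masks into a single integer label map
    # (0 = background, obj_id + 1 elsewhere). Later objects overwrite earlier
    # ones where masks overlap.
    labels = torch.zeros((height, width), dtype=torch.long)
    for obj_id, mask in frame_masks.items():
        mask_bool = torch.as_tensor(mask, dtype=torch.bool, device=labels.device)
        labels[mask_bool] = obj_id + 1
    return labels

For example, `frame_label_map(video_segments[0], *images.sam2_video_size)` would give the label map of the first frame, assuming `sam2_video_size` is `(height, width)` as in the `init_state` call.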
@@ -455,8 +481,8 @@ def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
     """
 
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-    pe3r.mast3r.to(device)
+    MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
+    mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)
 
     if len(filelist) < 2:
         raise gradio.Error("Please input at least 2 images.")
@@ -485,7 +511,7 @@ def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
         scenegraph_type = scenegraph_type + "-" + str(refid)
 
     pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
-    output = inference(pairs, pe3r.mast3r, device, batch_size=1, verbose=not silent)
+    output = inference(pairs, mast3r, device, batch_size=1, verbose=not silent)
     mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
     scene_1 = global_aligner(output, cog_seg_maps, rev_cog_seg_maps, cog_feats, device=device, mode=mode, verbose=not silent)
     lr = 0.01
@@ -498,7 +524,7 @@ def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
         # print(imgs[i]['img'].shape, scene.imgs[i].shape, ImgNorm(scene.imgs[i])[None])
         imgs[i]['img'] = ImgNorm(scene_1.imgs[i])[None]
     pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
-    output = inference(pairs, pe3r.mast3r, device, batch_size=1, verbose=not silent)
+    output = inference(pairs, mast3r, device, batch_size=1, verbose=not silent)
     mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
     scene = global_aligner(output, cog_seg_maps, rev_cog_seg_maps, cog_feats, device=device, mode=mode, verbose=not silent)
     ori_imgs = scene.ori_imgs
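On the reconstruction side, `get_reconstructed_scene` now pulls `naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric` with `AsymmetricMASt3R.from_pretrained` on every request. A small optional warm-up sketch, assuming the checkpoint is resolved through the standard huggingface_hub cache; the helper name is hypothetical and not part of this commit, and the built model could equally be kept around and handed to the function instead of reloading it per call:

from huggingface_hub import snapshot_download
from modules.mast3r.model import AsymmetricMASt3R

MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'

def warm_up_mast3r(device: str = 'cpu') -> AsymmetricMASt3R:
    # Pre-fetch the repo files once at startup (or hit the local cache), then
    # build the model so the first reconstruction request does not stall on
    # the download.
    snapshot_download(repo_id=MAST3R_CKP)
    return AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)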