hujiecpp commited on
Commit
c62842b
·
1 Parent(s): 6befd7a

init project

Browse files
Files changed (1) hide show
  1. modules/pe3r/demo.py +6 -6
modules/pe3r/demo.py CHANGED
@@ -236,17 +236,17 @@ def slerp_multiple(vectors, t_values):
236
  return interpolated_vector
237
 
238
  @torch.no_grad
239
- def get_mask_from_img_sam1(mobilesamv2, yolov8, sam1_image, yolov8_image, original_size, input_size, transform):
240
  sam_mask=[]
241
  img_area = original_size[0] * original_size[1]
242
 
243
- obj_results = yolov8(yolov8_image,device='cuda',retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
244
  input_boxes1 = obj_results[0].boxes.xyxy
245
  input_boxes1 = input_boxes1.cpu().numpy()
246
  input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
247
  input_boxes = torch.from_numpy(input_boxes1).cuda()
248
 
249
- # obj_results = yolov8(yolov8_image,device='cuda',retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
250
  # input_boxes2 = obj_results[0].boxes.xyxy
251
  # input_boxes2 = input_boxes2.cpu().numpy()
252
  # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
@@ -289,7 +289,7 @@ def get_mask_from_img_sam1(mobilesamv2, yolov8, sam1_image, yolov8_image, origin
289
  return ret_mask
290
 
291
  @torch.no_grad
292
- def get_cog_feats(images, pe3r):
293
  cog_seg_maps = []
294
  rev_cog_seg_maps = []
295
  inference_state = pe3r.sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
@@ -300,7 +300,7 @@ def get_cog_feats(images, pe3r):
300
  np_images = images.np_images
301
  np_images_size = images.np_images_size
302
 
303
- sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
304
  for mask in sam1_masks:
305
  _, _, _ = pe3r.sam2.add_new_mask(
306
  inference_state=inference_state,
@@ -438,7 +438,7 @@ def get_reconstructed_scene(outdir, pe3r, device, silent, filelist, schedule, ni
438
  images = Images(filelist=filelist, device=device)
439
 
440
  try:
441
- cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, pe3r)
442
  imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
443
  except Exception as e:
444
  rev_cog_seg_maps = []
 
236
  return interpolated_vector
237
 
238
  @torch.no_grad
239
+ def get_mask_from_img_sam1(mobilesamv2, yolov8, sam1_image, yolov8_image, original_size, input_size, transform, device):
240
  sam_mask=[]
241
  img_area = original_size[0] * original_size[1]
242
 
243
+ obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
244
  input_boxes1 = obj_results[0].boxes.xyxy
245
  input_boxes1 = input_boxes1.cpu().numpy()
246
  input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
247
  input_boxes = torch.from_numpy(input_boxes1).cuda()
248
 
249
+ # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
250
  # input_boxes2 = obj_results[0].boxes.xyxy
251
  # input_boxes2 = input_boxes2.cpu().numpy()
252
  # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
 
289
  return ret_mask
290
 
291
  @torch.no_grad
292
+ def get_cog_feats(images, pe3r, device):
293
  cog_seg_maps = []
294
  rev_cog_seg_maps = []
295
  inference_state = pe3r.sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
 
300
  np_images = images.np_images
301
  np_images_size = images.np_images_size
302
 
303
+ sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform, device)
304
  for mask in sam1_masks:
305
  _, _, _ = pe3r.sam2.add_new_mask(
306
  inference_state=inference_state,
 
438
  images = Images(filelist=filelist, device=device)
439
 
440
  try:
441
+ cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, pe3r, device)
442
  imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
443
  except Exception as e:
444
  rev_cog_seg_maps = []