Jie Hu committed on
Commit 665754b
1 Parent(s): 43b0caa

init project

Files changed (1):
  1. app.py +192 -195
app.py CHANGED
@@ -83,7 +83,7 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
     if not silent:
         print('(exporting 3D scene to', outfile, ')')
     # scene.export(file_obj=outfile)
-    print('ttttt')
+
     return outfile

 # @spaces.GPU(duration=180)
@@ -244,197 +244,197 @@ def slerp_multiple(vectors, t_values):

     return interpolated_vector

-@torch.no_grad
-def get_mask_from_img_sam1(mobilesamv2, yolov8, sam1_image, yolov8_image, original_size, input_size, transform):
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# @torch.no_grad
+# def get_mask_from_img_sam1(mobilesamv2, yolov8, sam1_image, yolov8_image, original_size, input_size, transform):
+# device = 'cuda' if torch.cuda.is_available() else 'cpu'

-    sam_mask=[]
-    img_area = original_size[0] * original_size[1]
+# sam_mask=[]
+# img_area = original_size[0] * original_size[1]

-    obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
-    input_boxes1 = obj_results[0].boxes.xyxy
-    input_boxes1 = input_boxes1.cpu().numpy()
-    input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
-    input_boxes = torch.from_numpy(input_boxes1).to(device)
+# obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
+# input_boxes1 = obj_results[0].boxes.xyxy
+# input_boxes1 = input_boxes1.cpu().numpy()
+# input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
+# input_boxes = torch.from_numpy(input_boxes1).to(device)

-    # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
-    # input_boxes2 = obj_results[0].boxes.xyxy
-    # input_boxes2 = input_boxes2.cpu().numpy()
-    # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
-    # input_boxes2 = torch.from_numpy(input_boxes2).to(device)
-
-    # input_boxes = torch.cat((input_boxes1, input_boxes2), dim=0)
-
-    input_image = mobilesamv2.preprocess(sam1_image)
-    image_embedding = mobilesamv2.image_encoder(input_image)['last_hidden_state']
-
-    image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
-    prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
-    prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
-    for (boxes,) in batch_iterator(320, input_boxes):
-        with torch.no_grad():
-            image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
-            prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
-            sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
-                points=None,
-                boxes=boxes,
-                masks=None,)
-            low_res_masks, _ = mobilesamv2.mask_decoder(
-                image_embeddings=image_embedding,
-                image_pe=prompt_embedding,
-                sparse_prompt_embeddings=sparse_embeddings,
-                dense_prompt_embeddings=dense_embeddings,
-                multimask_output=False,
-                simple_type=True,
-            )
-            low_res_masks=mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
-            sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)
-            for mask in sam_mask_pre:
-                if mask.sum() / img_area > 0.002:
-                    sam_mask.append(mask.squeeze(1))
-    sam_mask=torch.cat(sam_mask)
-    sorted_sam_mask = sorted(sam_mask, key=(lambda x: x.sum()), reverse=True)
-    keep = mask_nms(sorted_sam_mask)
-    ret_mask = filter(sorted_sam_mask, keep)
-
-    return ret_mask
-
-@torch.no_grad
-def get_cog_feats(images):
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    cog_seg_maps = []
-    rev_cog_seg_maps = []
-    inference_state = pe3r.sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
-    mask_num = 0
-
-    sam1_images = images.sam1_images
-    sam1_images_size = images.sam1_images_size
-    np_images = images.np_images
-    np_images_size = images.np_images_size
+# # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
+# # input_boxes2 = obj_results[0].boxes.xyxy
+# # input_boxes2 = input_boxes2.cpu().numpy()
+# # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
+# # input_boxes2 = torch.from_numpy(input_boxes2).to(device)
+
+# # input_boxes = torch.cat((input_boxes1, input_boxes2), dim=0)
+
+# input_image = mobilesamv2.preprocess(sam1_image)
+# image_embedding = mobilesamv2.image_encoder(input_input)['last_hidden_state']
+
+# image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
+# prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
+# prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
+# for (boxes,) in batch_iterator(320, input_boxes):
+# with torch.no_grad():
+# image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
+# prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
+# sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
+# points=None,
+# boxes=boxes,
+# masks=None,)
+# low_res_masks, _ = mobilesamv2.mask_decoder(
+# image_embeddings=image_embedding,
+# image_pe=prompt_embedding,
+# sparse_prompt_embeddings=sparse_embeddings,
+# dense_prompt_embeddings=dense_embeddings,
+# multimask_output=False,
+# simple_type=True,
+# )
+# low_res_masks=mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
+# sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)
+# for mask in sam_mask_pre:
+# if mask.sum() / img_area > 0.002:
+# sam_mask.append(mask.squeeze(1))
+# sam_mask=torch.cat(sam_mask)
+# sorted_sam_mask = sorted(sam_mask, key=(lambda x: x.sum()), reverse=True)
+# keep = mask_nms(sorted_sam_mask)
+# ret_mask = filter(sorted_sam_mask, keep)
+
+# return ret_mask
+
+# @torch.no_grad
+# def get_cog_feats(images):
+# device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# cog_seg_maps = []
+# rev_cog_seg_maps = []
+# inference_state = pe3r.sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
+# mask_num = 0
+
+# sam1_images = images.sam1_images
+# sam1_images_size = images.sam1_images_size
+# np_images = images.np_images
+# np_images_size = images.np_images_size

-    sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
-    for mask in sam1_masks:
-        _, _, _ = pe3r.sam2.add_new_mask(
-            inference_state=inference_state,
-            frame_idx=0,
-            obj_id=mask_num,
-            mask=mask,
-        )
-        mask_num += 1
-
-    video_segments = {} # video_segments contains the per-frame segmentation results
-    for out_frame_idx, out_obj_ids, out_mask_logits in pe3r.sam2.propagate_in_video(inference_state):
-        sam2_masks = (out_mask_logits > 0.0).squeeze(1)
-
-        video_segments[out_frame_idx] = {
-            out_obj_id: sam2_masks[i].cpu().numpy()
-            for i, out_obj_id in enumerate(out_obj_ids)
-        }
-
-        if out_frame_idx == 0:
-            continue
-
-        sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
-
-        for sam1_mask in sam1_masks:
-            flg = 1
-            for sam2_mask in sam2_masks:
-                # print(sam1_mask.shape, sam2_mask.shape)
-                area1 = sam1_mask.sum()
-                area2 = sam2_mask.sum()
-                intersection = (sam1_mask & sam2_mask).sum()
-                if min(intersection / area1, intersection / area2) > 0.25:
-                    flg = 0
-                    break
-            if flg:
-                video_segments[out_frame_idx][mask_num] = sam1_mask.cpu().numpy()
-                mask_num += 1
-
-    multi_view_clip_feats = torch.zeros((mask_num+1, 1024))
-    multi_view_clip_feats_map = {}
-    multi_view_clip_area_map = {}
-    for now_frame in range(0, len(video_segments), 1):
-        image = np_images[now_frame]
-
-        seg_img_list = []
-        out_obj_id_list = []
-        out_obj_mask_list = []
-        out_obj_area_list = []
-        # NOTE: background: -1
-        rev_seg_map = -np.ones(image.shape[:2], dtype=np.int64)
-        sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=False)
-        for out_obj_id, mask in sorted_dict_items:
-            if mask.sum() == 0:
-                continue
-            rev_seg_map[mask] = out_obj_id
-        rev_cog_seg_maps.append(rev_seg_map)
-
-        seg_map = -np.ones(image.shape[:2], dtype=np.int64)
-        sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=True)
-        for out_obj_id, mask in sorted_dict_items:
-            if mask.sum() == 0:
-                continue
-            box = np.int32(box_xyxy_to_xywh(mask_to_box(mask)))
+# sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
+# for mask in sam1_masks:
+# _, _, _ = pe3r.sam2.add_new_mask(
+# inference_state=inference_state,
+# frame_idx=0,
+# obj_id=mask_num,
+# mask=mask,
+# )
+# mask_num += 1
+
+# video_segments = {} # video_segments contains the per-frame segmentation results
+# for out_frame_idx, out_obj_ids, out_mask_logits in pe3r.sam2.propagate_in_video(inference_state):
+# sam2_masks = (out_mask_logits > 0.0).squeeze(1)
+
+# video_segments[out_frame_idx] = {
+# out_obj_id: sam2_masks[i].cpu().numpy()
+# for i, out_obj_id in enumerate(out_obj_ids)
+# }
+
+# if out_frame_idx == 0:
+# continue
+
+# sam1_masks = get_mask_from_img_sam1(pe3r.mobilesamv2, pe3r.yolov8, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
+
+# for sam1_mask in sam1_masks:
+# flg = 1
+# for sam2_mask in sam2_masks:
+# # print(sam1_mask.shape, sam2_mask.shape)
+# area1 = sam1_mask.sum()
+# area2 = sam2_mask.sum()
+# intersection = (sam1_mask & sam2_mask).sum()
+# if min(intersection / area1, intersection / area2) > 0.25:
+# flg = 0
+# break
+# if flg:
+# video_segments[out_frame_idx][mask_num] = sam1_mask.cpu().numpy()
+# mask_num += 1
+
+# multi_view_clip_feats = torch.zeros((mask_num+1, 1024))
+# multi_view_clip_feats_map = {}
+# multi_view_clip_area_map = {}
+# for now_frame in range(0, len(video_segments), 1):
+# image = np_images[now_frame]
+
+# seg_img_list = []
+# out_obj_id_list = []
+# out_obj_mask_list = []
+# out_obj_area_list = []
+# # NOTE: background: -1
+# rev_seg_map = -np.ones(image.shape[:2], dtype=np.int64)
+# sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=False)
+# for out_obj_id, mask in sorted_dict_items:
+# if mask.sum() == 0:
+# continue
+# rev_seg_map[mask] = out_obj_id
+# rev_cog_seg_maps.append(rev_seg_map)
+
+# seg_map = -np.ones(image.shape[:2], dtype=np.int64)
+# sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=True)
+# for out_obj_id, mask in sorted_dict_items:
+# if mask.sum() == 0:
+# continue
+# box = np.int32(box_xyxy_to_xywh(mask_to_box(mask)))

-            if box[2] == 0 and box[3] == 0:
-                continue
-            # print(box)
-            seg_img = get_seg_img(mask, box, image)
-            pad_seg_img = cv2.resize(pad_img(seg_img), (256,256))
-            seg_img_list.append(pad_seg_img)
-            seg_map[mask] = out_obj_id
-            out_obj_id_list.append(out_obj_id)
-            out_obj_area_list.append(np.count_nonzero(mask))
-            out_obj_mask_list.append(mask)
-
-        if len(seg_img_list) == 0:
-            cog_seg_maps.append(seg_map)
-            continue
-
-        seg_imgs = np.stack(seg_img_list, axis=0) # b,H,W,3
-        seg_imgs = torch.from_numpy(seg_imgs).permute(0,3,1,2) # / 255.0
+# if box[2] == 0 and box[3] == 0:
+# continue
+# # print(box)
+# seg_img = get_seg_img(mask, box, image)
+# pad_seg_img = cv2.resize(pad_img(seg_img), (256,256))
+# seg_img_list.append(pad_seg_img)
+# seg_map[mask] = out_obj_id
+# out_obj_id_list.append(out_obj_id)
+# out_obj_area_list.append(np.count_nonzero(mask))
+# out_obj_mask_list.append(mask)
+
+# if len(seg_img_list) == 0:
+# cog_seg_maps.append(seg_map)
+# continue
+
+# seg_imgs = np.stack(seg_img_list, axis=0) # b,H,W,3
+# seg_imgs = torch.from_numpy(seg_imgs).permute(0,3,1,2) # / 255.0

-        inputs = pe3r.siglip_processor(images=seg_imgs, return_tensors="pt")
-        inputs = {key: value.to(device) for key, value in inputs.items()}
+# inputs = pe3r.siglip_processor(images=seg_imgs, return_tensors="pt")
+# inputs = {key: value.to(device) for key, value in inputs.items()}

-        image_features = pe3r.siglip.get_image_features(**inputs)
-        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-        image_features = image_features.detach().cpu()
-
-        for i in range(len(out_obj_mask_list)):
-            for j in range(i + 1, len(out_obj_mask_list)):
-                mask1 = out_obj_mask_list[i]
-                mask2 = out_obj_mask_list[j]
-                intersection = np.logical_and(mask1, mask2).sum()
-                area1 = out_obj_area_list[i]
-                area2 = out_obj_area_list[j]
-                if min(intersection / area1, intersection / area2) > 0.025:
-                    conf1 = area1 / (area1 + area2)
-                    # conf2 = area2 / (area1 + area2)
-                    image_features[j] = slerp(image_features[j], image_features[i], conf1)
-
-        for i, clip_feat in enumerate(image_features):
-            id = out_obj_id_list[i]
-            if id in multi_view_clip_feats_map.keys():
-                multi_view_clip_feats_map[id].append(clip_feat)
-                multi_view_clip_area_map[id].append(out_obj_area_list[i])
-            else:
-                multi_view_clip_feats_map[id] = [clip_feat]
-                multi_view_clip_area_map[id] = [out_obj_area_list[i]]
-
-        cog_seg_maps.append(seg_map)
-        del image_features
+# image_features = pe3r.siglip.get_image_features(**inputs)
+# image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+# image_features = image_features.detach().cpu()
+
+# for i in range(len(out_obj_mask_list)):
+# for j in range(i + 1, len(out_obj_mask_list)):
+# mask1 = out_obj_mask_list[i]
+# mask2 = out_obj_mask_list[j]
+# intersection = np.logical_and(mask1, mask2).sum()
+# area1 = out_obj_area_list[i]
+# area2 = out_obj_area_list[j]
+# if min(intersection / area1, intersection / area2) > 0.025:
+# conf1 = area1 / (area1 + area2)
+# # conf2 = area2 / (area1 + area2)
+# image_features[j] = slerp(image_features[j], image_features[i], conf1)
+
+# for i, clip_feat in enumerate(image_features):
+# id = out_obj_id_list[i]
+# if id in multi_view_clip_feats_map.keys():
+# multi_view_clip_feats_map[id].append(clip_feat)
+# multi_view_clip_area_map[id].append(out_obj_area_list[i])
+# else:
+# multi_view_clip_feats_map[id] = [clip_feat]
+# multi_view_clip_area_map[id] = [out_obj_area_list[i]]
+
+# cog_seg_maps.append(seg_map)
+# del image_features

-    for i in range(mask_num):
-        if i in multi_view_clip_feats_map.keys():
-            clip_feats = multi_view_clip_feats_map[i]
-            mask_area = multi_view_clip_area_map[i]
-            multi_view_clip_feats[i] = slerp_multiple(torch.stack(clip_feats), np.stack(mask_area))
-        else:
-            multi_view_clip_feats[i] = torch.zeros((1024))
-    multi_view_clip_feats[mask_num] = torch.zeros((1024))
+# for i in range(mask_num):
+# if i in multi_view_clip_feats_map.keys():
+# clip_feats = multi_view_clip_feats_map[i]
+# mask_area = multi_view_clip_area_map[i]
+# multi_view_clip_feats[i] = slerp_multiple(torch.stack(clip_feats), np.stack(mask_area))
+# else:
+# multi_view_clip_feats[i] = torch.zeros((1024))
+# multi_view_clip_feats[mask_num] = torch.zeros((1024))

-    return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats
+# return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats

 @spaces.GPU(duration=180)
 def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
@@ -452,16 +452,16 @@ def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
     images = Images(filelist=filelist, device=device)

     # try:
-    cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images)
-    imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
+    # cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images)
+    # imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
     # except Exception as e:
-    # rev_cog_seg_maps = []
-    # for tmp_img in images.np_images:
-    # rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)
-    # rev_cog_seg_maps.append(rev_seg_map)
-    # cog_seg_maps = rev_cog_seg_maps
-    # cog_feats = torch.zeros((1, 1024))
-    # imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
+    rev_cog_seg_maps = []
+    for tmp_img in images.np_images:
+        rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)
+        rev_cog_seg_maps.append(rev_seg_map)
+    cog_seg_maps = rev_cog_seg_maps
+    cog_feats = torch.zeros((1, 1024))
+    imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)

     if len(imgs) == 1:
         imgs = [imgs[0], copy.deepcopy(imgs[0])]
@@ -499,11 +499,8 @@ def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
         scene.ori_imgs = ori_imgs
         print(e)

-    print('a')
-
     outfile = get_3D_model_from_scene(outdir, scene, min_conf_thr, as_pointcloud, mask_sky,
                                       clean_depth, transparent_cams, cam_size)
-    print('b')
    # also return rgb, depth and confidence imgs
    # depth is normalized with the max value for all images
    # we apply the jet colormap on the confidence maps
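For context on the hunk at lines 452-465: the active code path now builds background-only segmentation inputs instead of calling get_cog_feats. Below is a minimal, self-contained sketch of that fallback; the helper name background_only_seg_inputs and the feat_dim parameter are hypothetical, chosen here only to match the 1024-dimensional features used elsewhere in app.py.

import numpy as np
import torch

def background_only_seg_inputs(np_images, feat_dim=1024):
    # Hypothetical helper mirroring the fallback branch in the diff above:
    # label every pixel as background (-1) and return a single all-zero feature vector.
    rev_cog_seg_maps = [-np.ones(img.shape[:2], dtype=np.int64) for img in np_images]
    cog_seg_maps = rev_cog_seg_maps
    cog_feats = torch.zeros((1, feat_dim))
    return cog_seg_maps, rev_cog_seg_maps, cog_feats

# Example with two dummy frames:
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(2)]
seg_maps, rev_seg_maps, feats = background_only_seg_inputs(frames)
print(len(seg_maps), rev_seg_maps[0].shape, feats.shape)  # 2 (480, 640) torch.Size([1, 1024])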