Tony Lian commited on
Commit
93de48e
1 Parent(s): 0cbad80

Apply batching to SAM to reduce the memory cost with many objects

Browse files
Files changed (1) hide show
  1. generation.py +13 -7
generation.py CHANGED
@@ -53,7 +53,7 @@ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input
53
  batch_size = input_len
54
 
55
  run_times = int(np.ceil(input_len / batch_size))
56
- single_object_images, single_object_pil_images_box_ann, latents_all = [], [], []
57
  for batch_idx in range(run_times):
58
  input_latents_batch, bboxes_batch, phrases_batch = input_latents[batch_idx * batch_size:(batch_idx + 1) * batch_size], \
59
  bboxes[batch_idx * batch_size:(batch_idx + 1) * batch_size], phrases[batch_idx * batch_size:(batch_idx + 1) * batch_size]
@@ -68,17 +68,23 @@ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input
68
  gc.collect()
69
  torch.cuda.empty_cache()
70
 
71
- single_object_images.append(single_object_images_batch)
 
 
 
72
  single_object_pil_images_box_ann.append(single_object_pil_images_box_ann_batch)
73
  latents_all.append(latents_all_batch)
74
 
75
- single_object_images, single_object_pil_images_box_ann, latents_all = np.concatenate(single_object_images, axis=0), sum(single_object_pil_images_box_ann, []), torch.cat(latents_all, dim=1)
76
-
77
- mask_selected, conf_score_selected = sam.sam_refine_boxes(sam_input_images=single_object_images, boxes=bboxes, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
 
 
 
78
 
79
- # mask_selected: List[List[Array of shape (64, 64)]]
80
 
81
- mask_selected = np.array(mask_selected)[:, 0]
82
 
83
  mask_selected_tensor = torch.tensor(mask_selected)
84
 
 
53
  batch_size = input_len
54
 
55
  run_times = int(np.ceil(input_len / batch_size))
56
+ mask_selected_list, single_object_pil_images_box_ann, latents_all = [], [], []
57
  for batch_idx in range(run_times):
58
  input_latents_batch, bboxes_batch, phrases_batch = input_latents[batch_idx * batch_size:(batch_idx + 1) * batch_size], \
59
  bboxes[batch_idx * batch_size:(batch_idx + 1) * batch_size], phrases[batch_idx * batch_size:(batch_idx + 1) * batch_size]
 
68
  gc.collect()
69
  torch.cuda.empty_cache()
70
 
71
+ # `sam_refine_boxes` also calls `empty_cache` so we don't need to explicitly empty the cache again.
72
+ mask_selected, _ = sam.sam_refine_boxes(sam_input_images=single_object_images_batch, boxes=bboxes_batch, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
73
+
74
+ mask_selected_list.append(np.array(mask_selected)[:, 0])
75
  single_object_pil_images_box_ann.append(single_object_pil_images_box_ann_batch)
76
  latents_all.append(latents_all_batch)
77
 
78
+ single_object_pil_images_box_ann, latents_all = sum(single_object_pil_images_box_ann, []), torch.cat(latents_all, dim=1)
79
+
80
+ # mask_selected_list: List(batch)[List(image)[List(box)[Array of shape (64, 64)]]]
81
+
82
+ mask_selected = np.concatenate(mask_selected_list, axis=0)
83
+ mask_selected = mask_selected.reshape((-1, *mask_selected.shape[-2:]))
84
 
85
+ assert mask_selected.shape[0] == input_latents.shape[0], f"{mask_selected.shape[0]} != {input_latents.shape[0]}"
86
 
87
+ print(mask_selected.shape)
88
 
89
  mask_selected_tensor = torch.tensor(mask_selected)
90