Tony Lian committed

Commit 93de48e
1 Parent(s): 0cbad80

Apply batching to SAM to reduce the memory cost with many objects

Files changed: generation.py (+13 -7)
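The gist of the change, sketched independently of the Space's code: instead of refining every object's mask in one SAM call after generation, mask refinement now happens inside the existing per-batch loop and the per-batch results are stitched back together afterwards. The sketch below is a minimal illustration only; `refine_masks_for_batch` is a hypothetical stand-in for `sam.sam_refine_boxes`, and note that in the diff below `batch_size` is set to `input_len`, so the loop runs as a single batch unless a smaller batch size is chosen.

import numpy as np

def refine_in_batches(images, boxes, batch_size, refine_masks_for_batch):
    # Hypothetical sketch of the batching pattern; refine_masks_for_batch
    # stands in for sam.sam_refine_boxes and is assumed to return an array
    # of shape (len(batch), H, W) with one selected mask per object.
    n = len(images)
    run_times = int(np.ceil(n / batch_size))
    mask_batches = []
    for batch_idx in range(run_times):
        sl = slice(batch_idx * batch_size, (batch_idx + 1) * batch_size)
        # Only one slice of images/boxes is handed to SAM at a time,
        # which caps the peak memory used for mask refinement.
        mask_batches.append(refine_masks_for_batch(images[sl], boxes[sl]))
    # Reassemble the per-batch results into one (n, H, W) array.
    masks = np.concatenate(mask_batches, axis=0)
    assert masks.shape[0] == n
    return masks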
generation.py CHANGED

@@ -53,7 +53,7 @@ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input
     batch_size = input_len
 
     run_times = int(np.ceil(input_len / batch_size))
-
+    mask_selected_list, single_object_pil_images_box_ann, latents_all = [], [], []
     for batch_idx in range(run_times):
         input_latents_batch, bboxes_batch, phrases_batch = input_latents[batch_idx * batch_size:(batch_idx + 1) * batch_size], \
             bboxes[batch_idx * batch_size:(batch_idx + 1) * batch_size], phrases[batch_idx * batch_size:(batch_idx + 1) * batch_size]
@@ -68,17 +68,23 @@ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input
         gc.collect()
         torch.cuda.empty_cache()
 
-
+        # `sam_refine_boxes` also calls `empty_cache` so we don't need to explicitly empty the cache again.
+        mask_selected, _ = sam.sam_refine_boxes(sam_input_images=single_object_images_batch, boxes=bboxes_batch, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
+
+        mask_selected_list.append(np.array(mask_selected)[:, 0])
         single_object_pil_images_box_ann.append(single_object_pil_images_box_ann_batch)
         latents_all.append(latents_all_batch)
 
-
-
-
+    single_object_pil_images_box_ann, latents_all = sum(single_object_pil_images_box_ann, []), torch.cat(latents_all, dim=1)
+
+    # mask_selected_list: List(batch)[List(image)[List(box)[Array of shape (64, 64)]]]
+
+    mask_selected = np.concatenate(mask_selected_list, axis=0)
+    mask_selected = mask_selected.reshape((-1, *mask_selected.shape[-2:]))
 
-
+    assert mask_selected.shape[0] == input_latents.shape[0], f"{mask_selected.shape[0]} != {input_latents.shape[0]}"
 
-    mask_selected
+    print(mask_selected.shape)
 
     mask_selected_tensor = torch.tensor(mask_selected)
 
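To make the end-of-loop bookkeeping concrete, here is a small numpy check of the concatenate/reshape step added above. The 64x64 mask size comes from the comment in the diff; the batch sizes and the single box per image are assumptions chosen only for illustration.

import numpy as np

# Per-batch output of the SAM refinement, converted with np.array:
# shape (num_images_in_batch, num_boxes_per_image, 64, 64).
batch0 = np.zeros((3, 1, 64, 64))
batch1 = np.zeros((2, 1, 64, 64))

# The diff keeps only the first (selected) mask per image via [:, 0].
mask_selected_list = [batch0[:, 0], batch1[:, 0]]  # each (num_images, 64, 64)

mask_selected = np.concatenate(mask_selected_list, axis=0)              # (5, 64, 64)
mask_selected = mask_selected.reshape((-1, *mask_selected.shape[-2:]))  # still (5, 64, 64)

# One mask per single-object latent, which is what the assert in the diff checks.
print(mask_selected.shape)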