Spaces:
Sleeping
Sleeping
Tony Lian
commited on
Commit
•
61ac46b
1
Parent(s):
7b0a7ad
Add batched single object generation
Browse files- app.py +4 -4
- generation.py +53 -24
- models/models.py +19 -17
- models/pipelines.py +47 -29
- models/sam.py +50 -29
app.py
CHANGED
@@ -109,7 +109,7 @@ def get_ours_image(response, seed, num_inference_steps=20, dpm_scheduler=True, u
|
|
109 |
spec, bg_seed=seed, overall_prompt_override=overall_prompt_override, fg_seed_start=fg_seed_start,
|
110 |
fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio, use_autocast=use_autocast,
|
111 |
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
|
112 |
-
so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt
|
113 |
)
|
114 |
images = [image_np]
|
115 |
if show_so_imgs:
|
@@ -201,7 +201,7 @@ html = f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to
|
|
201 |
<p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
|
202 |
<p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
|
203 |
<p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.</p>
|
204 |
-
<p>4. The diffusion model only runs
|
205 |
<p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (Currently we are using a T4, and you can add a A10G to make it 5x faster) {duplicate_html}</p>
|
206 |
<br/>
|
207 |
<p>Implementation note: In this demo, we replace the attention manipulation in our layout-guided Stable Diffusion described in our paper with GLIGEN due to much faster inference speed (<b>FlashAttention supported, no backprop needed</b> during inference). Compared to vanilla GLIGEN, we have better coherence. Other parts of text-to-image pipeline, including single object generation and SAM, remain the same. The settings and examples in the prompt are simplified in this demo.</p>
|
@@ -237,12 +237,12 @@ with gr.Blocks(
|
|
237 |
with gr.Column(scale=1):
|
238 |
response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
|
239 |
overall_prompt_override = gr.Textbox(lines=2, label="Prompt for overall generation (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
|
|
|
240 |
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
241 |
with gr.Accordion("Advanced options (play around for better generation)", open=False):
|
242 |
frozen_step_ratio = gr.Slider(0, 1, value=0.4, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
|
243 |
gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
|
244 |
-
|
245 |
-
dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend 50 or even more inference steps)", show_label=False, value=True)
|
246 |
use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision", show_label=False, value=True)
|
247 |
fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
|
248 |
fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
|
|
|
109 |
spec, bg_seed=seed, overall_prompt_override=overall_prompt_override, fg_seed_start=fg_seed_start,
|
110 |
fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio, use_autocast=use_autocast,
|
111 |
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key,
|
112 |
+
so_negative_prompt=so_negative_prompt, overall_negative_prompt=overall_negative_prompt, so_batch_size=8
|
113 |
)
|
114 |
images = [image_np]
|
115 |
if show_so_imgs:
|
|
|
201 |
<p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
|
202 |
<p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
|
203 |
<p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.</p>
|
204 |
+
<p>4. The diffusion model only runs 50 steps by default in this demo. You can make it run more/fewer steps to get higher quality images or faster generation (or tweak frozen steps/guidance steps for better guidance and coherence).</p>
|
205 |
<p>5. Duplicate this space and add GPU or clone the space and run locally to skip the queue and run our model faster. (Currently we are using a T4, and you can add a A10G to make it 5x faster) {duplicate_html}</p>
|
206 |
<br/>
|
207 |
<p>Implementation note: In this demo, we replace the attention manipulation in our layout-guided Stable Diffusion described in our paper with GLIGEN due to much faster inference speed (<b>FlashAttention supported, no backprop needed</b> during inference). Compared to vanilla GLIGEN, we have better coherence. Other parts of text-to-image pipeline, including single object generation and SAM, remain the same. The settings and examples in the prompt are simplified in this demo.</p>
|
|
|
237 |
with gr.Column(scale=1):
|
238 |
response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
|
239 |
overall_prompt_override = gr.Textbox(lines=2, label="Prompt for overall generation (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
|
240 |
+
num_inference_steps = gr.Slider(1, 250, value=50, step=1, label="Number of denoising steps (set to 20 to trade quality for faster generation)")
|
241 |
seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
|
242 |
with gr.Accordion("Advanced options (play around for better generation)", open=False):
|
243 |
frozen_step_ratio = gr.Slider(0, 1, value=0.4, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
|
244 |
gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
|
245 |
+
dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
|
|
|
246 |
use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision", show_label=False, value=True)
|
247 |
fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
|
248 |
fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
|
generation.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
version = "v3.0"
|
2 |
|
3 |
import torch
|
|
|
4 |
import models
|
5 |
import utils
|
6 |
from models import pipelines, sam
|
@@ -21,7 +22,6 @@ H, W = height // 8, width // 8 # size of the latent
|
|
21 |
guidance_scale = 7.5 # Scale for classifier-free guidance
|
22 |
|
23 |
# batch size that is not 1 is not supported
|
24 |
-
so_batch_size = 1
|
25 |
overall_batch_size = 1
|
26 |
|
27 |
# discourage masks with confidence below
|
@@ -33,41 +33,70 @@ discourage_mask_below_coarse_iou = 0.25
|
|
33 |
run_ind = None
|
34 |
|
35 |
|
36 |
-
def
|
37 |
sam_refine_kwargs, num_inference_steps, gligen_scheduled_sampling_beta=0.3,
|
38 |
-
verbose=False, scheduler_key=None, visualize=True):
|
|
|
39 |
|
40 |
-
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
mask_selected_tensor = torch.tensor(mask_selected)
|
51 |
|
52 |
-
|
|
|
|
|
53 |
|
54 |
def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_list, so_input_embeddings, verbose=False, **kwargs):
|
55 |
-
latents_all_list, mask_tensor_list
|
56 |
|
57 |
if not so_prompt_phrase_word_box_list:
|
58 |
return latents_all_list, mask_tensor_list
|
59 |
|
60 |
-
|
61 |
|
62 |
-
for
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
mask_tensor_list.append(mask_tensor)
|
70 |
-
so_img_list.append(so_img)
|
71 |
|
72 |
return latents_all_list, mask_tensor_list, so_img_list
|
73 |
|
@@ -77,7 +106,7 @@ def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_li
|
|
77 |
def run(
|
78 |
spec, bg_seed = 1, overall_prompt_override="", fg_seed_start = 20, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta = 0.3, num_inference_steps = 20,
|
79 |
so_center_box = False, fg_blending_ratio = 0.1, scheduler_key='dpm_scheduler', so_negative_prompt = DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt = DEFAULT_OVERALL_NEGATIVE_PROMPT, so_horizontal_center_only = True,
|
80 |
-
align_with_overall_bboxes = False, horizontal_shift_only = True, use_autocast = False
|
81 |
):
|
82 |
"""
|
83 |
so_center_box: using centered box in single object generation
|
@@ -130,7 +159,7 @@ def run(
|
|
130 |
latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
|
131 |
so_prompt_phrase_word_box_list, input_latents_list,
|
132 |
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
|
133 |
-
sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose
|
134 |
)
|
135 |
|
136 |
|
|
|
1 |
version = "v3.0"
|
2 |
|
3 |
import torch
|
4 |
+
import numpy as np
|
5 |
import models
|
6 |
import utils
|
7 |
from models import pipelines, sam
|
|
|
22 |
guidance_scale = 7.5 # Scale for classifier-free guidance
|
23 |
|
24 |
# batch size that is not 1 is not supported
|
|
|
25 |
overall_batch_size = 1
|
26 |
|
27 |
# discourage masks with confidence below
|
|
|
33 |
run_ind = None
|
34 |
|
35 |
|
36 |
+
def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings,
|
37 |
sam_refine_kwargs, num_inference_steps, gligen_scheduled_sampling_beta=0.3,
|
38 |
+
verbose=False, scheduler_key=None, visualize=True, batch_size=None):
|
39 |
+
# batch_size=None: does not limit the batch size (pass all input together)
|
40 |
|
41 |
+
# prompts and words are not used since we don't have cross-attention control in this function
|
42 |
|
43 |
+
input_latents = torch.cat(input_latents_list, dim=0)
|
44 |
+
|
45 |
+
# We need to "unsqueeze" to tell that we have only one box and phrase in each batch item
|
46 |
+
bboxes, phrases = [[item] for item in bboxes], [[item] for item in phrases]
|
47 |
+
|
48 |
+
input_len = len(bboxes)
|
49 |
+
assert len(bboxes) == len(phrases), f"{len(bboxes)} != {len(phrases)}"
|
50 |
+
|
51 |
+
if batch_size is None:
|
52 |
+
batch_size = input_len
|
53 |
+
|
54 |
+
run_times = int(np.ceil(input_len / batch_size))
|
55 |
+
single_object_images, single_object_pil_images_box_ann, latents_all = [], [], []
|
56 |
+
for batch_idx in range(run_times):
|
57 |
+
input_latents_batch, bboxes_batch, phrases_batch = input_latents[batch_idx * batch_size:(batch_idx + 1) * batch_size], \
|
58 |
+
bboxes[batch_idx * batch_size:(batch_idx + 1) * batch_size], phrases[batch_idx * batch_size:(batch_idx + 1) * batch_size]
|
59 |
+
input_embeddings_batch = input_embeddings[0], input_embeddings[1][batch_idx * batch_size:(batch_idx + 1) * batch_size]
|
60 |
+
|
61 |
+
_, single_object_images_batch, single_object_pil_images_box_ann_batch, latents_all_batch = pipelines.generate_gligen(
|
62 |
+
model_dict, input_latents_batch, input_embeddings_batch, num_inference_steps, bboxes_batch, phrases_batch, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
|
63 |
+
guidance_scale=guidance_scale, return_saved_cross_attn=False,
|
64 |
+
return_box_vis=True, save_all_latents=True, batched_condition=True, scheduler_key=scheduler_key
|
65 |
+
)
|
66 |
+
|
67 |
+
single_object_images.append(single_object_images_batch)
|
68 |
+
single_object_pil_images_box_ann.append(single_object_pil_images_box_ann_batch)
|
69 |
+
latents_all.append(latents_all_batch)
|
70 |
|
71 |
+
single_object_images, single_object_pil_images_box_ann, latents_all = np.concatenate(single_object_images, axis=0), sum(single_object_pil_images_box_ann, []), torch.cat(latents_all, dim=1)
|
72 |
|
73 |
+
mask_selected, conf_score_selected = sam.sam_refine_boxes(sam_input_images=single_object_images, boxes=bboxes, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
|
74 |
+
|
75 |
+
# mask_selected: List[List[Array of shape (64, 64)]]
|
76 |
+
|
77 |
+
mask_selected = np.array(mask_selected)[:, 0]
|
78 |
+
|
79 |
mask_selected_tensor = torch.tensor(mask_selected)
|
80 |
|
81 |
+
latents_all = latents_all.transpose(0,1)[:,:,None,...]
|
82 |
+
|
83 |
+
return latents_all, mask_selected_tensor, single_object_pil_images_box_ann
|
84 |
|
85 |
def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_list, so_input_embeddings, verbose=False, **kwargs):
|
86 |
+
latents_all_list, mask_tensor_list = [], []
|
87 |
|
88 |
if not so_prompt_phrase_word_box_list:
|
89 |
return latents_all_list, mask_tensor_list
|
90 |
|
91 |
+
prompts, bboxes, phrases, words = [], [], [], []
|
92 |
|
93 |
+
for prompt, phrase, word, box in so_prompt_phrase_word_box_list:
|
94 |
+
prompts.append(prompt)
|
95 |
+
bboxes.append(box)
|
96 |
+
phrases.append(phrase)
|
97 |
+
words.append(word)
|
98 |
+
|
99 |
+
latents_all_list, mask_tensor_list, so_img_list = generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings=so_input_embeddings, verbose=verbose, **kwargs)
|
|
|
|
|
100 |
|
101 |
return latents_all_list, mask_tensor_list, so_img_list
|
102 |
|
|
|
106 |
def run(
|
107 |
spec, bg_seed = 1, overall_prompt_override="", fg_seed_start = 20, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta = 0.3, num_inference_steps = 20,
|
108 |
so_center_box = False, fg_blending_ratio = 0.1, scheduler_key='dpm_scheduler', so_negative_prompt = DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt = DEFAULT_OVERALL_NEGATIVE_PROMPT, so_horizontal_center_only = True,
|
109 |
+
align_with_overall_bboxes = False, horizontal_shift_only = True, use_autocast = False, so_batch_size = None
|
110 |
):
|
111 |
"""
|
112 |
so_center_box: using centered box in single object generation
|
|
|
159 |
latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
|
160 |
so_prompt_phrase_word_box_list, input_latents_list,
|
161 |
gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
|
162 |
+
sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose, batch_size=so_batch_size
|
163 |
)
|
164 |
|
165 |
|
models/models.py
CHANGED
@@ -75,20 +75,22 @@ def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_
|
|
75 |
return text_embeddings
|
76 |
return text_embeddings, uncond_embeddings, cond_embeddings
|
77 |
|
78 |
-
def
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
75 |
return text_embeddings
|
76 |
return text_embeddings, uncond_embeddings, cond_embeddings
|
77 |
|
78 |
+
def process_input_embeddings(input_embeddings):
|
79 |
+
assert isinstance(input_embeddings, (tuple, list))
|
80 |
+
if len(input_embeddings) == 3:
|
81 |
+
# input_embeddings: text_embeddings, uncond_embeddings, cond_embeddings
|
82 |
+
# Assume `uncond_embeddings` is full (has batch size the same as cond_embeddings)
|
83 |
+
_, uncond_embeddings, cond_embeddings = input_embeddings
|
84 |
+
assert uncond_embeddings.shape[0] == cond_embeddings.shape[0], f"{uncond_embeddings.shape[0]} != {cond_embeddings.shape[0]}"
|
85 |
+
return input_embeddings
|
86 |
+
elif len(input_embeddings) == 2:
|
87 |
+
# input_embeddings: uncond_embeddings, cond_embeddings
|
88 |
+
# uncond_embeddings may have only one item
|
89 |
+
uncond_embeddings, cond_embeddings = input_embeddings
|
90 |
+
if uncond_embeddings.shape[0] == 1:
|
91 |
+
uncond_embeddings = uncond_embeddings.expand(cond_embeddings.shape)
|
92 |
+
# We follow the convention: negative (unconditional) prompt comes first
|
93 |
+
text_embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0)
|
94 |
+
return text_embeddings, uncond_embeddings, cond_embeddings
|
95 |
+
else:
|
96 |
+
raise ValueError(f"input_embeddings length: {len(input_embeddings)}")
|
models/pipelines.py
CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
|
|
5 |
import gc
|
6 |
import numpy as np
|
7 |
from .attention import GatedSelfAttentionDense
|
8 |
-
from .models import torch_device
|
9 |
|
10 |
@torch.no_grad()
|
11 |
def encode(model_dict, image, generator):
|
@@ -88,17 +88,56 @@ def gligen_enable_fuser(unet, enabled=True):
|
|
88 |
if isinstance(module, GatedSelfAttentionDense):
|
89 |
module.enabled = enabled
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
@torch.no_grad()
|
92 |
def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
|
93 |
frozen_steps=20, frozen_mask=None,
|
94 |
return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
|
95 |
offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
|
96 |
-
return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler'):
|
97 |
"""
|
98 |
The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
|
99 |
"""
|
100 |
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
|
101 |
-
|
|
|
102 |
|
103 |
if latents.dim() == 5:
|
104 |
# latents_all from the input side, different from the latents_all to be saved
|
@@ -122,33 +161,12 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
122 |
if frozen_mask is not None:
|
123 |
frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
|
124 |
|
125 |
-
batch_size = 1
|
126 |
-
|
127 |
# 5.1 Prepare GLIGEN variables
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
_boxes = bboxes
|
132 |
|
133 |
-
|
134 |
-
boxes = torch.zeros(max_objs, 4, device=torch_device, dtype=dtype)
|
135 |
-
phrase_embeddings = torch.zeros(max_objs, 768, device=torch_device, dtype=dtype)
|
136 |
-
masks = torch.zeros(max_objs, device=torch_device, dtype=dtype)
|
137 |
-
|
138 |
-
if n_objs > 0:
|
139 |
-
boxes[:n_objs] = torch.tensor(_boxes[:n_objs])
|
140 |
-
tokenizer_inputs = tokenizer(phrases, padding=True, return_tensors="pt").to(torch_device)
|
141 |
-
_phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
|
142 |
-
phrase_embeddings[:n_objs] = _phrase_embeddings[:n_objs]
|
143 |
-
masks[:n_objs] = 1
|
144 |
-
|
145 |
-
# Classifier-free guidance
|
146 |
-
repeat_batch = batch_size * num_images_per_prompt * 2
|
147 |
-
|
148 |
-
boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
|
149 |
-
phrase_embeddings = phrase_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
|
150 |
-
masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
|
151 |
-
masks[:repeat_batch // 2] = 0
|
152 |
|
153 |
if return_saved_cross_attn:
|
154 |
saved_attns = []
|
@@ -215,7 +233,7 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
|
|
215 |
if return_saved_cross_attn:
|
216 |
ret.append(saved_attns)
|
217 |
if return_box_vis:
|
218 |
-
pil_images = [utils.draw_box(Image.fromarray(image),
|
219 |
ret.append(pil_images)
|
220 |
if save_all_latents:
|
221 |
latents_all = torch.stack(latents_all, dim=0)
|
|
|
5 |
import gc
|
6 |
import numpy as np
|
7 |
from .attention import GatedSelfAttentionDense
|
8 |
+
from .models import process_input_embeddings, torch_device
|
9 |
|
10 |
@torch.no_grad()
|
11 |
def encode(model_dict, image, generator):
|
|
|
88 |
if isinstance(module, GatedSelfAttentionDense):
|
89 |
module.enabled = enabled
|
90 |
|
91 |
+
def prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt):
|
92 |
+
batch_size = len(bboxes)
|
93 |
+
|
94 |
+
assert len(phrases) == len(bboxes)
|
95 |
+
max_objs = 30
|
96 |
+
|
97 |
+
n_objs = min(max([len(bboxes_item) for bboxes_item in bboxes]), max_objs)
|
98 |
+
boxes = torch.zeros((batch_size, max_objs, 4), device=torch_device, dtype=dtype)
|
99 |
+
phrase_embeddings = torch.zeros((batch_size, max_objs, 768), device=torch_device, dtype=dtype)
|
100 |
+
# masks is a 1D tensor deciding which of the enteries to be enabled
|
101 |
+
masks = torch.zeros((batch_size, max_objs), device=torch_device, dtype=dtype)
|
102 |
+
|
103 |
+
if n_objs > 0:
|
104 |
+
for idx, (bboxes_item, phrases_item) in enumerate(zip(bboxes, phrases)):
|
105 |
+
# the length of `bboxes_item` could be smaller than `n_objs` because n_objs takes the max of item length
|
106 |
+
bboxes_item = torch.tensor(bboxes_item[:n_objs])
|
107 |
+
boxes[idx, :bboxes_item.shape[0]] = bboxes_item
|
108 |
+
|
109 |
+
tokenizer_inputs = tokenizer(phrases_item[:n_objs], padding=True, return_tensors="pt").to(torch_device)
|
110 |
+
_phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
|
111 |
+
phrase_embeddings[idx, :_phrase_embeddings.shape[0]] = _phrase_embeddings
|
112 |
+
assert bboxes_item.shape[0] == _phrase_embeddings.shape[0], f"{bboxes_item.shape[0]} != {_phrase_embeddings.shape[0]}"
|
113 |
+
|
114 |
+
masks[idx, :bboxes_item.shape[0]] = 1
|
115 |
+
|
116 |
+
# Classifier-free guidance
|
117 |
+
repeat_times = num_images_per_prompt * 2
|
118 |
+
condition_len = batch_size * repeat_times
|
119 |
+
|
120 |
+
boxes = boxes.repeat(repeat_times, 1, 1)
|
121 |
+
phrase_embeddings = phrase_embeddings.repeat(repeat_times, 1, 1)
|
122 |
+
masks = masks.repeat(repeat_times, 1)
|
123 |
+
masks[:condition_len // 2] = 0
|
124 |
+
|
125 |
+
# print("shapes:", boxes.shape, phrase_embeddings.shape, masks.shape)
|
126 |
+
|
127 |
+
return boxes, phrase_embeddings, masks, condition_len
|
128 |
+
|
129 |
@torch.no_grad()
|
130 |
def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
|
131 |
frozen_steps=20, frozen_mask=None,
|
132 |
return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
|
133 |
offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
|
134 |
+
return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False):
|
135 |
"""
|
136 |
The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
|
137 |
"""
|
138 |
vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict[scheduler_key], model_dict.dtype
|
139 |
+
|
140 |
+
text_embeddings, _, cond_embeddings = process_input_embeddings(input_embeddings)
|
141 |
|
142 |
if latents.dim() == 5:
|
143 |
# latents_all from the input side, different from the latents_all to be saved
|
|
|
161 |
if frozen_mask is not None:
|
162 |
frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
|
163 |
|
|
|
|
|
164 |
# 5.1 Prepare GLIGEN variables
|
165 |
+
if not batched_condition:
|
166 |
+
# Add batch dimension to bboxes and phrases
|
167 |
+
bboxes, phrases = [bboxes], [phrases]
|
|
|
168 |
|
169 |
+
boxes, phrase_embeddings, masks, condition_len = prepare_gligen_condition(bboxes, phrases, dtype, tokenizer, text_encoder, num_images_per_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
if return_saved_cross_attn:
|
172 |
saved_attns = []
|
|
|
233 |
if return_saved_cross_attn:
|
234 |
ret.append(saved_attns)
|
235 |
if return_box_vis:
|
236 |
+
pil_images = [utils.draw_box(Image.fromarray(image), bboxes_item, phrases_item) for image, bboxes_item, phrases_item in zip(images, bboxes, phrases)]
|
237 |
ret.append(pil_images)
|
238 |
if save_all_latents:
|
239 |
latents_all = torch.stack(latents_all, dim=0)
|
models/sam.py
CHANGED
@@ -2,6 +2,7 @@ import gc
|
|
2 |
import matplotlib.pyplot as plt
|
3 |
import numpy as np
|
4 |
import torch
|
|
|
5 |
from models import torch_device
|
6 |
from transformers import SamModel, SamProcessor
|
7 |
import utils
|
@@ -20,10 +21,18 @@ def load_sam():
|
|
20 |
|
21 |
# Not fully backward compatible with the previous implementation
|
22 |
# Reference: lmdv2/notebooks/gen_masked_latents_multi_object_ref_ca_loss_modular.ipynb
|
23 |
-
def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_shape=None):
|
24 |
"""target_mask_shape: (h, w)"""
|
25 |
sam_model, sam_processor = sam_model_dict['sam_model'], sam_model_dict['sam_processor']
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
with torch.no_grad():
|
28 |
with torch.autocast(torch_device):
|
29 |
inputs = sam_processor(image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
|
@@ -31,18 +40,17 @@ def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_
|
|
31 |
masks = sam_processor.image_processor.post_process_masks(
|
32 |
outputs.pred_masks.cpu().float(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
|
33 |
)
|
34 |
-
conf_scores = outputs.iou_scores.
|
35 |
del inputs, outputs
|
36 |
-
|
37 |
-
gc.collect()
|
38 |
-
if torch_device == "cuda":
|
39 |
-
torch.cuda.empty_cache()
|
40 |
-
|
41 |
-
masks = masks[0][0].numpy()
|
42 |
|
43 |
-
|
44 |
-
|
45 |
|
|
|
|
|
|
|
|
|
|
|
46 |
return masks, conf_scores
|
47 |
|
48 |
def sam_point_input(sam_model_dict, image, input_points, **kwargs):
|
@@ -154,26 +162,39 @@ def sam_refine_attn(sam_input_image, token_attn_np, model_dict, height, width, H
|
|
154 |
|
155 |
return mask_selected, conf_score_selected
|
156 |
|
157 |
-
def sam_refine_box(sam_input_image, box,
|
|
|
|
|
|
|
|
|
158 |
# (w, h)
|
159 |
-
input_boxes = utils.scale_proportion(box, H=height, W=width)
|
160 |
-
input_boxes = [input_boxes]
|
161 |
|
162 |
-
masks, conf_scores = sam_box_input(model_dict, image=
|
163 |
|
164 |
-
|
165 |
-
if verbose:
|
166 |
-
# Also the box is the input for SAM
|
167 |
-
plt.title("Binary mask from input box (for iou)")
|
168 |
-
plt.imshow(mask_binary)
|
169 |
-
plt.show()
|
170 |
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import matplotlib.pyplot as plt
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
from models import torch_device
|
7 |
from transformers import SamModel, SamProcessor
|
8 |
import utils
|
|
|
21 |
|
22 |
# Not fully backward compatible with the previous implementation
|
23 |
# Reference: lmdv2/notebooks/gen_masked_latents_multi_object_ref_ca_loss_modular.ipynb
|
24 |
+
def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_shape=None, return_numpy=True):
|
25 |
"""target_mask_shape: (h, w)"""
|
26 |
sam_model, sam_processor = sam_model_dict['sam_model'], sam_model_dict['sam_processor']
|
27 |
|
28 |
+
if input_boxes and isinstance(input_boxes[0], tuple):
|
29 |
+
# Convert tuple to list
|
30 |
+
input_boxes = [list(input_box) for input_box in input_boxes]
|
31 |
+
|
32 |
+
if input_boxes and input_boxes[0] and isinstance(input_boxes[0][0], tuple):
|
33 |
+
# Convert tuple to list
|
34 |
+
input_boxes = [[list(input_box) for input_box in input_boxes_item] for input_boxes_item in input_boxes]
|
35 |
+
|
36 |
with torch.no_grad():
|
37 |
with torch.autocast(torch_device):
|
38 |
inputs = sam_processor(image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
|
|
|
40 |
masks = sam_processor.image_processor.post_process_masks(
|
41 |
outputs.pred_masks.cpu().float(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
|
42 |
)
|
43 |
+
conf_scores = outputs.iou_scores.cpu().numpy()[0,0]
|
44 |
del inputs, outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
+
gc.collect()
|
47 |
+
torch.cuda.empty_cache()
|
48 |
|
49 |
+
if return_numpy:
|
50 |
+
masks = [F.interpolate(masks_item.type(torch.float), target_mask_shape, mode='bilinear').type(torch.bool).numpy() for masks_item in masks]
|
51 |
+
else:
|
52 |
+
masks = [F.interpolate(masks_item.type(torch.float), target_mask_shape, mode='bilinear').type(torch.bool) for masks_item in masks]
|
53 |
+
|
54 |
return masks, conf_scores
|
55 |
|
56 |
def sam_point_input(sam_model_dict, image, input_points, **kwargs):
|
|
|
162 |
|
163 |
return mask_selected, conf_score_selected
|
164 |
|
165 |
+
def sam_refine_box(sam_input_image, box, *args, **kwargs):
|
166 |
+
sam_input_images, boxes = [sam_input_image], [box]
|
167 |
+
return sam_refine_boxes(sam_input_images, boxes, *args, **kwargs)
|
168 |
+
|
169 |
+
def sam_refine_boxes(sam_input_images, boxes, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
|
170 |
# (w, h)
|
171 |
+
input_boxes = [[utils.scale_proportion(box, H=height, W=width) for box in boxes_item] for boxes_item in boxes]
|
|
|
172 |
|
173 |
+
masks, conf_scores = sam_box_input(model_dict, image=sam_input_images, input_boxes=input_boxes, target_mask_shape=(H, W))
|
174 |
|
175 |
+
mask_selected_batched_list, conf_score_selected_batched_list = [], []
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
+
for boxes_item, masks_item in zip(boxes, masks):
|
178 |
+
mask_selected_list, conf_score_selected_list = [], []
|
179 |
+
for box, three_masks in zip(boxes_item, masks_item):
|
180 |
+
mask_binary = utils.proportion_to_mask(box, H, W, return_np=True)
|
181 |
+
if verbose:
|
182 |
+
# Also the box is the input for SAM
|
183 |
+
plt.title("Binary mask from input box (for iou)")
|
184 |
+
plt.imshow(mask_binary)
|
185 |
+
plt.show()
|
186 |
+
|
187 |
+
coarse_ious = get_iou_with_resize(mask_binary, three_masks, masks_shape=mask_binary.shape)
|
188 |
+
|
189 |
+
mask_selected, conf_score_selected = select_mask(three_masks, conf_scores, coarse_ious=coarse_ious,
|
190 |
+
rule="largest_over_conf",
|
191 |
+
discourage_mask_below_confidence=discourage_mask_below_confidence,
|
192 |
+
discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
|
193 |
+
verbose=True)
|
194 |
+
|
195 |
+
mask_selected_list.append(mask_selected)
|
196 |
+
conf_score_selected_list.append(conf_score_selected)
|
197 |
+
mask_selected_batched_list.append(mask_selected_list)
|
198 |
+
conf_score_selected_batched_list.append(conf_score_selected_list)
|
199 |
+
|
200 |
+
return mask_selected_batched_list, conf_score_selected_batched_list
|