Commit d871568 — Tony Lian committed (1 parent: ee0b9c2)

Use fast schedule for per-box generation to speed up

Files changed:
- app.py +3 -3
- generation.py +7 -6
- models/pipelines.py +8 -2
- utils/latents.py +6 -2
- utils/schedule.py +19 -0
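Summary of the change, as visible in the hunks below: per-box (single-object) generation now denoises at full step resolution only up to the foreground-frozen steps and then subsamples the remaining timesteps at a fixed rate. generation.py turns this on via a module-level `use_fast_schedule` flag and forwards `fast_after_steps=frozen_steps if use_fast_schedule else None, fast_rate=2` down to `pipelines.generate_gligen` through `**kwargs`. models/pipelines.py rewrites `scheduler.timesteps` with the new `utils/schedule.get_fast_schedule` helper, adds a `dynamic_num_inference_steps` option that re-derives `scheduler.num_inference_steps` at every step so the sampler's step size matches a non-uniform schedule, and stops saving per-step latents past `fast_after_steps`. utils/latents.py accordingly allocates composed latents only for the steps that are actually saved, and app.py updates the defaults and labels of the denoising-step, frozen-step-ratio, and GLIGEN beta sliders.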
app.py CHANGED

@@ -238,11 +238,11 @@ with gr.Blocks(
         with gr.Column(scale=1):
             response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
             overall_prompt_override = gr.Textbox(lines=2, label="Prompt for overall generation (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
-            num_inference_steps = gr.Slider(1, 250, value=…
+            num_inference_steps = gr.Slider(1, 250, value=50, step=1, label="Number of denoising steps (set to >=50 for higher generation quality)")
             seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
             with gr.Accordion("Advanced options (play around for better generation)", open=False):
-                frozen_step_ratio = gr.Slider(0, 1, value=0.…
-                gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.…
+                frozen_step_ratio = gr.Slider(0, 1, value=0.5, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
+                gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.4, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
                 dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
                 use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision (faster but with slightly lower quality)", show_label=False, value=True)
                 fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
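This hunk only changes slider defaults and labels; how `frozen_step_ratio` becomes the integer `frozen_steps` that generation.py passes as `fast_after_steps` is not part of this diff. A minimal sketch of the presumable mapping — the rounding rule below is an assumption for illustration, not taken from the commit:

```python
# Assumed derivation (not shown in this diff): frozen_steps from the new slider defaults.
num_inference_steps = 50      # new slider default
frozen_step_ratio = 0.5       # new slider default
frozen_steps = int(round(frozen_step_ratio * num_inference_steps))
print(frozen_steps)           # 25 full-resolution steps before the fast tail kicks in
```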
generation.py CHANGED

@@ -10,6 +10,8 @@ from shared import model_dict, sam_model_dict, DEFAULT_SO_NEGATIVE_PROMPT, DEFAU…
 import gc
 
 verbose = False
+# Accelerates per-box generation
+use_fast_schedule = True
 
 vae, tokenizer, text_encoder, unet, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.dtype
 

@@ -36,7 +38,7 @@ run_ind = None
 
 def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings,
                                           sam_refine_kwargs, num_inference_steps, gligen_scheduled_sampling_beta=0.3,
-                                          verbose=False, scheduler_key=None, visualize=True, batch_size=None):
+                                          verbose=False, scheduler_key=None, visualize=True, batch_size=None, **kwargs):
     # batch_size=None: does not limit the batch size (pass all input together)
 
     # prompts and words are not used since we don't have cross-attention control in this function

@@ -62,7 +64,7 @@ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input…
     _, single_object_images_batch, single_object_pil_images_box_ann_batch, latents_all_batch = pipelines.generate_gligen(
         model_dict, input_latents_batch, input_embeddings_batch, num_inference_steps, bboxes_batch, phrases_batch, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
         guidance_scale=guidance_scale, return_saved_cross_attn=False,
-        return_box_vis=True, save_all_latents=True, batched_condition=True, scheduler_key=scheduler_key
+        return_box_vis=True, save_all_latents=True, batched_condition=True, scheduler_key=scheduler_key, **kwargs
     )
 
     gc.collect()

@@ -172,16 +174,15 @@ def run(
     latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
         so_prompt_phrase_word_box_list, input_latents_list,
         gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
-        sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose, batch_size=so_batch_size
+        sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose, batch_size=so_batch_size,
+        fast_after_steps=frozen_steps if use_fast_schedule else None, fast_rate=2
     )
 
-
-
     composed_latents, foreground_indices, offset_list = latents.compose_latents_with_alignment(
         model_dict, latents_all_list, mask_tensor_list, num_inference_steps,
         overall_batch_size, height, width, latents_bg=latents_bg,
         align_with_overall_bboxes=align_with_overall_bboxes, overall_bboxes=overall_bboxes,
-        horizontal_shift_only=horizontal_shift_only
+        horizontal_shift_only=horizontal_shift_only, use_fast_schedule=use_fast_schedule, fast_after_steps=frozen_steps
    )
 
     overall_bboxes_flattened, overall_phrases_flattened = [], []
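The forwarding added here is a plain `**kwargs` pass-through: `generate_single_object_with_box_batch` does not need to know about the new schedule options, it just relays them to `pipelines.generate_gligen`. A self-contained sketch of that pattern, with stub functions standing in for the real ones (names suffixed `_stub` and the concrete values are illustrative, not from the repo):

```python
# Stub illustration of the **kwargs pass-through used in this commit: schedule options
# reach the pipeline call without widening every intermediate function signature.
def generate_gligen_stub(num_inference_steps, fast_after_steps=None, fast_rate=2):
    return {"steps": num_inference_steps,
            "fast_after_steps": fast_after_steps,
            "fast_rate": fast_rate}

def generate_single_object_with_box_batch_stub(num_inference_steps, batch_size=None, **kwargs):
    # kwargs may carry fast_after_steps / fast_rate; they are simply forwarded downstream.
    return generate_gligen_stub(num_inference_steps, **kwargs)

use_fast_schedule, frozen_steps = True, 25
print(generate_single_object_with_box_batch_stub(
    50, fast_after_steps=frozen_steps if use_fast_schedule else None, fast_rate=2))
# {'steps': 50, 'fast_after_steps': 25, 'fast_rate': 2}
```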
models/pipelines.py CHANGED

@@ -1,6 +1,7 @@
 import torch
 from tqdm import tqdm
 import utils
+from utils import schedule
 from PIL import Image
 import gc
 import numpy as np

@@ -131,7 +132,7 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
                     frozen_steps=20, frozen_mask=None,
                     return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
                     offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
-                    return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False):
+                    return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False, dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2):
     """
     The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
     """

@@ -157,6 +158,8 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
     latents_all = [latents]
 
     scheduler.set_timesteps(num_inference_steps)
+    if fast_after_steps is not None:
+        scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
 
     if frozen_mask is not None:
         frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)

@@ -212,6 +215,9 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
             # perform guidance
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+            if dynamic_num_inference_steps:
+                schedule.dynamically_adjust_inference_steps(scheduler, index, t)
 
             # compute the previous noisy sample x_t -> x_t-1
             latents = scheduler.step(noise_pred, t, latents).prev_sample

@@ -219,7 +225,7 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
             if frozen_mask is not None and index < frozen_steps:
                 latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
 
-            if save_all_latents:
+            if save_all_latents and (fast_after_steps is None or index < fast_after_steps):
                 if offload_latents_to_cpu:
                     latents_all.append(latents.cpu())
                 else:
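Why `dynamically_adjust_inference_steps` exists: diffusers' DDIM/DPM-style schedulers typically derive the previous timestep inside `scheduler.step` as `t - num_train_timesteps // num_inference_steps`, so once `scheduler.timesteps` becomes non-uniform that divisor has to be refreshed at every step. A small standalone illustration of the arithmetic (the concrete timestep values below are made up for the example, not taken from the commit):

```python
# Standalone arithmetic check: with a non-uniform schedule, recomputing
# num_inference_steps each step keeps t - num_train_timesteps // num_inference_steps
# aligned with the actual next timestep.
num_train_timesteps = 1000
timesteps = [981, 961, 941, 921, 881, 841, 801]  # uniform head, then a 2x-faster tail

for index, t in enumerate(timesteps):
    prev_t = timesteps[index + 1] if index + 1 < len(timesteps) else -1
    num_inference_steps = num_train_timesteps // (t - prev_t)  # per-step adjustment
    derived_prev = t - num_train_timesteps // num_inference_steps
    print(t, prev_t, derived_prev)
# Here derived_prev equals prev_t at every interior step; at the final step it drops
# below 0, which DDIM-style schedulers treat as the fully denoised endpoint.
```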
utils/latents.py CHANGED

@@ -35,7 +35,7 @@ def blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=0.01):
     return latents
 
 @torch.no_grad()
-def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, latents_bg=None, bg_seed=None, compose_box_to_bg=True):
+def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, latents_bg=None, bg_seed=None, compose_box_to_bg=True, use_fast_schedule=False, fast_after_steps=None):
     unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
 
     if latents_bg is None:

@@ -43,7 +43,11 @@ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inferenc…
         latents_bg = get_scaled_latents(overall_batch_size, unet.config.in_channels, height, width, generator, dtype, scheduler)
 
     # Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
-    composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
+    if use_fast_schedule:
+        # If we use fast schedule, we only need to compose the frozen steps.
+        composed_latents = torch.zeros((fast_after_steps + 1, *latents_bg.shape), dtype=dtype)
+    else:
+        composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
     composed_latents[0] = latents_bg
 
     foreground_indices = torch.zeros(latents_bg.shape[-2:], dtype=torch.long)
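The allocation change mirrors the new saving behavior in pipelines.py: once per-box latents are no longer appended past `fast_after_steps`, composition only needs `fast_after_steps + 1` snapshots (the initial noise plus the frozen steps). A quick shape illustration — the latent dimensions below are the usual SD-style 4x64x64 and are assumptions for the example:

```python
import torch

num_inference_steps, fast_after_steps = 50, 25
latents_bg = torch.zeros(1, 4, 64, 64)   # (batch, channels, H/8, W/8), illustrative shape

use_fast_schedule = True
if use_fast_schedule:
    # Only the steps that are actually saved per box need to be composed.
    composed_latents = torch.zeros((fast_after_steps + 1, *latents_bg.shape))
else:
    composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape))
print(composed_latents.shape)            # torch.Size([26, 1, 4, 64, 64])
```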
utils/schedule.py ADDED

@@ -0,0 +1,19 @@
+import torch
+import warnings
+
+def get_fast_schedule(origial_timesteps, fast_after_steps, fast_rate):
+    if fast_after_steps >= len(origial_timesteps) - 1:
+        return origial_timesteps
+    new_timesteps = torch.cat((origial_timesteps[:fast_after_steps], origial_timesteps[fast_after_steps+1::fast_rate]), dim=0)
+    return new_timesteps
+
+def dynamically_adjust_inference_steps(scheduler, index, t):
+    prev_t = scheduler.timesteps[index+1] if index+1 < len(scheduler.timesteps) else -1
+    scheduler.num_inference_steps = scheduler.config.num_train_timesteps // (t - prev_t)
+    if index+1 < len(scheduler.timesteps):
+        if scheduler.config.num_train_timesteps // scheduler.num_inference_steps != t - prev_t:
+            warnings.warn(f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) != ({t} - {prev_t}), so the step sizes may not be accurate")
+    else:
+        # as long as we hit final cumprob, it should be fine.
+        if scheduler.config.num_train_timesteps // scheduler.num_inference_steps > t - prev_t:
+            warnings.warn(f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) > ({t} - {prev_t}), so the step sizes may not be accurate")
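A usage example for the new helper. The timestep values are illustrative: a 50-step DDIM-like schedule over 1000 training timesteps is assumed, which is not specified by the commit itself.

```python
import torch
from utils.schedule import get_fast_schedule

timesteps = torch.arange(981, 0, -20)            # 981, 961, ..., 1  (50 steps)
fast = get_fast_schedule(timesteps, 25, 2)       # keep 25 steps, then take every 2nd
print(len(timesteps), len(fast))                 # 50 37
print(fast[:3].tolist())                         # [981, 961, 941]  head unchanged
print(fast[24:28].tolist())                      # [501, 461, 421, 381]  tail subsampled at 2x
```

Note that the element at index `fast_after_steps` itself (481 in this example) is dropped, since the subsampled tail slice starts at `fast_after_steps + 1`.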