Tony Lian commited on
Commit
d871568
1 Parent(s): ee0b9c2

Use fast schedule for per-box generation to speed up

Browse files
Files changed (5) hide show
  1. app.py +3 -3
  2. generation.py +7 -6
  3. models/pipelines.py +8 -2
  4. utils/latents.py +6 -2
  5. utils/schedule.py +19 -0
app.py CHANGED
@@ -238,11 +238,11 @@ with gr.Blocks(
238
  with gr.Column(scale=1):
239
  response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
240
  overall_prompt_override = gr.Textbox(lines=2, label="Prompt for overall generation (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
241
- num_inference_steps = gr.Slider(1, 250, value=20, step=1, label="Number of denoising steps (set to >=50 for higher generation quality)")
242
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
243
  with gr.Accordion("Advanced options (play around for better generation)", open=False):
244
- frozen_step_ratio = gr.Slider(0, 1, value=0.4, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
245
- gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
246
  dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
247
  use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision (faster but with slightly lower quality)", show_label=False, value=True)
248
  fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
 
238
  with gr.Column(scale=1):
239
  response = gr.Textbox(lines=8, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
240
  overall_prompt_override = gr.Textbox(lines=2, label="Prompt for overall generation (optional but recommended)", placeholder="You can put your input prompt for layout generation here, helpful if your scene cannot be represented by background prompt and boxes only, e.g., with object interactions. If left empty: background prompt with [objects].", value="")
241
+ num_inference_steps = gr.Slider(1, 250, value=50, step=1, label="Number of denoising steps (set to >=50 for higher generation quality)")
242
  seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
243
  with gr.Accordion("Advanced options (play around for better generation)", open=False):
244
+ frozen_step_ratio = gr.Slider(0, 1, value=0.5, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
245
+ gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.4, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
246
  dpm_scheduler = gr.Checkbox(label="Use DPM scheduler (unchecked: DDIM scheduler, may have better coherence, recommend >=50 inference steps)", show_label=False, value=True)
247
  use_autocast = gr.Checkbox(label="Use FP16 Mixed Precision (faster but with slightly lower quality)", show_label=False, value=True)
248
  fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
generation.py CHANGED
@@ -10,6 +10,8 @@ from shared import model_dict, sam_model_dict, DEFAULT_SO_NEGATIVE_PROMPT, DEFAU
10
  import gc
11
 
12
  verbose = False
 
 
13
 
14
  vae, tokenizer, text_encoder, unet, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.dtype
15
 
@@ -36,7 +38,7 @@ run_ind = None
36
 
37
  def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings,
38
  sam_refine_kwargs, num_inference_steps, gligen_scheduled_sampling_beta=0.3,
39
- verbose=False, scheduler_key=None, visualize=True, batch_size=None):
40
  # batch_size=None: does not limit the batch size (pass all input together)
41
 
42
  # prompts and words are not used since we don't have cross-attention control in this function
@@ -62,7 +64,7 @@ def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input
62
  _, single_object_images_batch, single_object_pil_images_box_ann_batch, latents_all_batch = pipelines.generate_gligen(
63
  model_dict, input_latents_batch, input_embeddings_batch, num_inference_steps, bboxes_batch, phrases_batch, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
64
  guidance_scale=guidance_scale, return_saved_cross_attn=False,
65
- return_box_vis=True, save_all_latents=True, batched_condition=True, scheduler_key=scheduler_key
66
  )
67
 
68
  gc.collect()
@@ -172,16 +174,15 @@ def run(
172
  latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
173
  so_prompt_phrase_word_box_list, input_latents_list,
174
  gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
175
- sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose, batch_size=so_batch_size
 
176
  )
177
 
178
-
179
-
180
  composed_latents, foreground_indices, offset_list = latents.compose_latents_with_alignment(
181
  model_dict, latents_all_list, mask_tensor_list, num_inference_steps,
182
  overall_batch_size, height, width, latents_bg=latents_bg,
183
  align_with_overall_bboxes=align_with_overall_bboxes, overall_bboxes=overall_bboxes,
184
- horizontal_shift_only=horizontal_shift_only
185
  )
186
 
187
  overall_bboxes_flattened, overall_phrases_flattened = [], []
 
10
  import gc
11
 
12
  verbose = False
13
+ # Accelerates per-box generation
14
+ use_fast_schedule = True
15
 
16
  vae, tokenizer, text_encoder, unet, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.dtype
17
 
 
38
 
39
  def generate_single_object_with_box_batch(prompts, bboxes, phrases, words, input_latents_list, input_embeddings,
40
  sam_refine_kwargs, num_inference_steps, gligen_scheduled_sampling_beta=0.3,
41
+ verbose=False, scheduler_key=None, visualize=True, batch_size=None, **kwargs):
42
  # batch_size=None: does not limit the batch size (pass all input together)
43
 
44
  # prompts and words are not used since we don't have cross-attention control in this function
 
64
  _, single_object_images_batch, single_object_pil_images_box_ann_batch, latents_all_batch = pipelines.generate_gligen(
65
  model_dict, input_latents_batch, input_embeddings_batch, num_inference_steps, bboxes_batch, phrases_batch, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
66
  guidance_scale=guidance_scale, return_saved_cross_attn=False,
67
+ return_box_vis=True, save_all_latents=True, batched_condition=True, scheduler_key=scheduler_key, **kwargs
68
  )
69
 
70
  gc.collect()
 
174
  latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
175
  so_prompt_phrase_word_box_list, input_latents_list,
176
  gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
177
+ sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, num_inference_steps=num_inference_steps, scheduler_key=scheduler_key, verbose=verbose, batch_size=so_batch_size,
178
+ fast_after_steps=frozen_steps if use_fast_schedule else None, fast_rate=2
179
  )
180
 
 
 
181
  composed_latents, foreground_indices, offset_list = latents.compose_latents_with_alignment(
182
  model_dict, latents_all_list, mask_tensor_list, num_inference_steps,
183
  overall_batch_size, height, width, latents_bg=latents_bg,
184
  align_with_overall_bboxes=align_with_overall_bboxes, overall_bboxes=overall_bboxes,
185
+ horizontal_shift_only=horizontal_shift_only, use_fast_schedule=use_fast_schedule, fast_after_steps=frozen_steps
186
  )
187
 
188
  overall_bboxes_flattened, overall_phrases_flattened = [], []
models/pipelines.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from tqdm import tqdm
3
  import utils
 
4
  from PIL import Image
5
  import gc
6
  import numpy as np
@@ -131,7 +132,7 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
131
  frozen_steps=20, frozen_mask=None,
132
  return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
133
  offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
134
- return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False):
135
  """
136
  The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
137
  """
@@ -157,6 +158,8 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
157
  latents_all = [latents]
158
 
159
  scheduler.set_timesteps(num_inference_steps)
 
 
160
 
161
  if frozen_mask is not None:
162
  frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
@@ -212,6 +215,9 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
212
  # perform guidance
213
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
214
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
 
215
 
216
  # compute the previous noisy sample x_t -> x_t-1
217
  latents = scheduler.step(noise_pred, t, latents).prev_sample
@@ -219,7 +225,7 @@ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps,
219
  if frozen_mask is not None and index < frozen_steps:
220
  latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
221
 
222
- if save_all_latents:
223
  if offload_latents_to_cpu:
224
  latents_all.append(latents.cpu())
225
  else:
 
1
  import torch
2
  from tqdm import tqdm
3
  import utils
4
+ from utils import schedule
5
  from PIL import Image
6
  import gc
7
  import numpy as np
 
132
  frozen_steps=20, frozen_mask=None,
133
  return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
134
  offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
135
+ return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler', batched_condition=False, dynamic_num_inference_steps=False, fast_after_steps=None, fast_rate=2):
136
  """
137
  The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
138
  """
 
158
  latents_all = [latents]
159
 
160
  scheduler.set_timesteps(num_inference_steps)
161
+ if fast_after_steps is not None:
162
+ scheduler.timesteps = schedule.get_fast_schedule(scheduler.timesteps, fast_after_steps, fast_rate)
163
 
164
  if frozen_mask is not None:
165
  frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
 
215
  # perform guidance
216
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
217
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
218
+
219
+ if dynamic_num_inference_steps:
220
+ schedule.dynamically_adjust_inference_steps(scheduler, index, t)
221
 
222
  # compute the previous noisy sample x_t -> x_t-1
223
  latents = scheduler.step(noise_pred, t, latents).prev_sample
 
225
  if frozen_mask is not None and index < frozen_steps:
226
  latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
227
 
228
+ if save_all_latents and (fast_after_steps is None or index < fast_after_steps):
229
  if offload_latents_to_cpu:
230
  latents_all.append(latents.cpu())
231
  else:
utils/latents.py CHANGED
@@ -35,7 +35,7 @@ def blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=0.01):
35
  return latents
36
 
37
  @torch.no_grad()
38
- def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, latents_bg=None, bg_seed=None, compose_box_to_bg=True):
39
  unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
40
 
41
  if latents_bg is None:
@@ -43,7 +43,11 @@ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inferenc
43
  latents_bg = get_scaled_latents(overall_batch_size, unet.config.in_channels, height, width, generator, dtype, scheduler)
44
 
45
  # Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
46
- composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
 
 
 
 
47
  composed_latents[0] = latents_bg
48
 
49
  foreground_indices = torch.zeros(latents_bg.shape[-2:], dtype=torch.long)
 
35
  return latents
36
 
37
  @torch.no_grad()
38
+ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, latents_bg=None, bg_seed=None, compose_box_to_bg=True, use_fast_schedule=False, fast_after_steps=None):
39
  unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
40
 
41
  if latents_bg is None:
 
43
  latents_bg = get_scaled_latents(overall_batch_size, unet.config.in_channels, height, width, generator, dtype, scheduler)
44
 
45
  # Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
46
+ if use_fast_schedule:
47
+ # If we use fast schedule, we only need to compose the frozen steps.
48
+ composed_latents = torch.zeros((fast_after_steps + 1, *latents_bg.shape), dtype=dtype)
49
+ else:
50
+ composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
51
  composed_latents[0] = latents_bg
52
 
53
  foreground_indices = torch.zeros(latents_bg.shape[-2:], dtype=torch.long)
utils/schedule.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import warnings
3
+
4
+ def get_fast_schedule(origial_timesteps, fast_after_steps, fast_rate):
5
+ if fast_after_steps >= len(origial_timesteps) - 1:
6
+ return origial_timesteps
7
+ new_timesteps = torch.cat((origial_timesteps[:fast_after_steps], origial_timesteps[fast_after_steps+1::fast_rate]), dim=0)
8
+ return new_timesteps
9
+
10
+ def dynamically_adjust_inference_steps(scheduler, index, t):
11
+ prev_t = scheduler.timesteps[index+1] if index+1 < len(scheduler.timesteps) else -1
12
+ scheduler.num_inference_steps = scheduler.config.num_train_timesteps // (t - prev_t)
13
+ if index+1 < len(scheduler.timesteps):
14
+ if scheduler.config.num_train_timesteps // scheduler.num_inference_steps != t - prev_t:
15
+ warnings.warn(f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) != ({t} - {prev_t}), so the step sizes may not be accurate")
16
+ else:
17
+ # as long as we hit final cumprob, it should be fine.
18
+ if scheduler.config.num_train_timesteps // scheduler.num_inference_steps > t - prev_t:
19
+ warnings.warn(f"({scheduler.config.num_train_timesteps} // {scheduler.num_inference_steps}) > ({t} - {prev_t}), so the step sizes may not be accurate")