Tony Lian commited on
Commit
2a53583
1 Parent(s): 1d6e0a9

Add caching and allow scale boxes

Browse files
Files changed (1) hide show
  1. app.py +21 -11
app.py CHANGED
@@ -61,7 +61,7 @@ layout_placeholder = """Caption: A realistic photo of a gray cat and an orange d
61
  Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
62
  Background prompt: A realistic photo of a grassy area."""
63
 
64
- def get_lmd_prompt(prompt, template=""):
65
  if prompt == "":
66
  prompt = prompt_placeholder
67
  if template == "":
@@ -88,7 +88,7 @@ def get_layout_image(response):
88
  def get_layout_image_gallery(response):
89
  return [get_layout_image(response)]
90
 
91
- def get_ours_image(response, seed, num_inference_steps, dpm_scheduler, fg_seed_start, fg_blending_ratio=0.1, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta=0.3, so_negative_prompt="", overall_negative_prompt="", show_so_imgs=False, scale_boxes=False):
92
  if response == "":
93
  response = layout_placeholder
94
  gen_boxes, bg_prompt = parse_input(response)
@@ -116,7 +116,7 @@ def get_ours_image(response, seed, num_inference_steps, dpm_scheduler, fg_seed_s
116
  images.extend([np.asarray(so_img) for so_img in so_img_list])
117
  return images
118
 
119
- def get_baseline_image(prompt, seed):
120
  if prompt == "":
121
  prompt = prompt_placeholder
122
 
@@ -222,8 +222,11 @@ with gr.Blocks(
222
  generate_btn.click(fn=get_lmd_prompt, inputs=[prompt, template], outputs=output, api_name="get_lmd_prompt")
223
 
224
  gr.Examples(
225
- stage1_examples,
226
- [prompt]
 
 
 
227
  )
228
 
229
  # with gr.Tab("(Optional) Visualize ChatGPT-generated Layout"):
@@ -251,17 +254,21 @@ with gr.Blocks(
251
  gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
252
  so_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for single object generation", value=DEFAULT_SO_NEGATIVE_PROMPT)
253
  overall_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for overall generation", value=DEFAULT_OVERALL_NEGATIVE_PROMPT)
254
- show_so_imgs = gr.Checkbox(label="Show annotated single object generations", show_label=False)
 
255
  with gr.Column(scale=1):
256
  gallery = gr.Gallery(
257
  label="Generated image", show_label=False, elem_id="gallery"
258
  ).style(columns=[1], rows=[1], object_fit="contain", preview=True)
259
  visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
260
- generate_btn.click(fn=get_ours_image, inputs=[response, seed, num_inference_steps, dpm_scheduler, fg_seed_start, fg_blending_ratio, frozen_step_ratio, gligen_scheduled_sampling_beta, so_negative_prompt, overall_negative_prompt, show_so_imgs], outputs=gallery, api_name="layout-to-image")
261
 
262
  gr.Examples(
263
- stage2_examples,
264
- [response, seed]
 
 
 
265
  )
266
 
267
  with gr.Tab("Baseline: Stable Diffusion"):
@@ -280,8 +287,11 @@ with gr.Blocks(
280
  generate_btn.click(fn=get_baseline_image, inputs=[sd_prompt, seed], outputs=gallery, api_name="baseline")
281
 
282
  gr.Examples(
283
- stage1_examples,
284
- [sd_prompt]
 
 
 
285
  )
286
 
287
  g.launch()
 
61
  Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
62
  Background prompt: A realistic photo of a grassy area."""
63
 
64
+ def get_lmd_prompt(prompt, template=default_template):
65
  if prompt == "":
66
  prompt = prompt_placeholder
67
  if template == "":
 
88
  def get_layout_image_gallery(response):
89
  return [get_layout_image(response)]
90
 
91
+ def get_ours_image(response, seed, num_inference_steps=20, dpm_scheduler=True, fg_seed_start=20, fg_blending_ratio=0.1, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta=0.3, so_negative_prompt=DEFAULT_SO_NEGATIVE_PROMPT, overall_negative_prompt=DEFAULT_OVERALL_NEGATIVE_PROMPT, show_so_imgs=False, scale_boxes=False):
92
  if response == "":
93
  response = layout_placeholder
94
  gen_boxes, bg_prompt = parse_input(response)
 
116
  images.extend([np.asarray(so_img) for so_img in so_img_list])
117
  return images
118
 
119
+ def get_baseline_image(prompt, seed=0):
120
  if prompt == "":
121
  prompt = prompt_placeholder
122
 
 
222
  generate_btn.click(fn=get_lmd_prompt, inputs=[prompt, template], outputs=output, api_name="get_lmd_prompt")
223
 
224
  gr.Examples(
225
+ examples=stage1_examples,
226
+ inputs=[prompt],
227
+ outputs=[output],
228
+ fn=get_lmd_prompt,
229
+ cache_examples=True
230
  )
231
 
232
  # with gr.Tab("(Optional) Visualize ChatGPT-generated Layout"):
 
254
  gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
255
  so_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for single object generation", value=DEFAULT_SO_NEGATIVE_PROMPT)
256
  overall_negative_prompt = gr.Textbox(lines=1, label="Negative prompt for overall generation", value=DEFAULT_OVERALL_NEGATIVE_PROMPT)
257
+ show_so_imgs = gr.Checkbox(label="Show annotated single object generations", show_label=False, value=False)
258
+ scale_boxes = gr.Checkbox(label="Scale bounding boxes to just fit the scene", show_label=False, value=False)
259
  with gr.Column(scale=1):
260
  gallery = gr.Gallery(
261
  label="Generated image", show_label=False, elem_id="gallery"
262
  ).style(columns=[1], rows=[1], object_fit="contain", preview=True)
263
  visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
264
+ generate_btn.click(fn=get_ours_image, inputs=[response, seed, num_inference_steps, dpm_scheduler, fg_seed_start, fg_blending_ratio, frozen_step_ratio, gligen_scheduled_sampling_beta, so_negative_prompt, overall_negative_prompt, show_so_imgs, scale_boxes], outputs=gallery, api_name="layout-to-image")
265
 
266
  gr.Examples(
267
+ examples=stage2_examples,
268
+ inputs=[response, seed],
269
+ outputs=[gallery],
270
+ fn=get_ours_image,
271
+ cache_examples=True
272
  )
273
 
274
  with gr.Tab("Baseline: Stable Diffusion"):
 
287
  generate_btn.click(fn=get_baseline_image, inputs=[sd_prompt, seed], outputs=gallery, api_name="baseline")
288
 
289
  gr.Examples(
290
+ examples=stage1_examples,
291
+ inputs=[sd_prompt],
292
+ outputs=[gallery],
293
+ fn=get_baseline_image,
294
+ cache_examples=True
295
  )
296
 
297
  g.launch()