prithivMLmods committed on
Commit
a5ce6db
·
verified ·
1 Parent(s): bdf5103

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -212
app.py CHANGED
@@ -1,11 +1,8 @@
1
- # -*- coding: utf-8 -*-
2
  import gradio as gr
3
  import spaces
4
  import torch
5
  from diffusers import AutoencoderKL, TCDScheduler
6
  from diffusers.models.model_loading_utils import load_state_dict
7
- # Remove ImageSlider import as it's no longer needed
8
- # from gradio_imageslider import ImageSlider
9
  from huggingface_hub import hf_hub_download
10
 
11
  from controlnet_union import ControlNetModel_Union
@@ -14,7 +11,6 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
14
  from PIL import Image, ImageDraw
15
  import numpy as np
16
 
17
- # --- Model Loading (Keep as is) ---
18
  config_file = hf_hub_download(
19
  "xinsir/controlnet-union-sdxl-1.0",
20
  filename="config_promax.json",
@@ -26,9 +22,10 @@ model_file = hf_hub_download(
26
  "xinsir/controlnet-union-sdxl-1.0",
27
  filename="diffusion_pytorch_model_promax.safetensors",
28
  )
29
- state_dict = load_state_dict(model_file)
 
30
  model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
31
- controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
32
  )
33
  model.to(device="cuda", dtype=torch.float16)
34
 
@@ -46,8 +43,6 @@ pipe = StableDiffusionXLFillPipeline.from_pretrained(
46
 
47
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
48
 
49
- # --- Helper Functions (Keep as is, except infer) ---
50
-
51
  def can_expand(source_width, source_height, target_width, target_height, alignment):
52
  """Checks if the image can be expanded based on the alignment."""
53
  if alignment in ("Left", "Right") and source_width >= target_width:
@@ -63,7 +58,7 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
63
  scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
64
  new_width = int(image.width * scale_factor)
65
  new_height = int(image.height * scale_factor)
66
-
67
  # Resize the source image to fit within target size
68
  source = image.resize((new_width, new_height), Image.LANCZOS)
69
 
@@ -135,7 +130,7 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
135
  right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
136
  top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
137
  bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
138
-
139
  if alignment == "Left":
140
  left_overlap = margin_x + overlap_x if overlap_left else margin_x
141
  elif alignment == "Right":
@@ -145,7 +140,6 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
145
  elif alignment == "Bottom":
146
  bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
147
 
148
-
149
  # Draw the mask
150
  mask_draw.rectangle([
151
  (left_overlap, top_overlap),
@@ -156,47 +150,33 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
156
 
157
  def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
158
  background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
159
-
160
  # Create a preview image showing the mask
161
  preview = background.copy().convert('RGBA')
162
-
163
  # Create a semi-transparent red overlay
164
  red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64)) # Reduced alpha to 64 (25% opacity)
165
-
166
  # Convert black pixels in the mask to semi-transparent red
167
  red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
168
  red_mask.paste(red_overlay, (0, 0), mask)
169
-
170
  # Overlay the red mask on the background
171
  preview = Image.alpha_composite(preview, red_mask)
172
-
173
  return preview
174
 
175
  @spaces.GPU(duration=24)
176
  def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
177
- if image is None:
178
- raise gr.Error("Please upload an input image.")
179
-
180
  background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
181
-
182
  if not can_expand(background.width, background.height, width, height, alignment):
183
- # Optionally provide feedback or default to middle
184
- # gr.Warning(f"Cannot expand image with '{alignment}' alignment as source dimension is larger than target. Defaulting to 'Middle'.")
185
  alignment = "Middle"
186
- # Recalculate background and mask if alignment changed due to this check
187
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
188
-
189
 
190
  cnet_image = background.copy()
191
- # Apply mask to create the input for controlnet (black out non-masked area)
192
- # cnet_image.paste(0, (0, 0), mask) # This line seems incorrect for inpainting/outpainting, usually the unmasked area is kept
193
- # The pipeline expects the original image content where mask=0 and potentially noise/latents where mask=1
194
- # Let's keep the original image content in the unmasked area and let the pipeline handle the masked area.
195
- # The `StableDiffusionXLFillPipeline` likely uses the `image` input and `mask` differently than standard inpainting.
196
- # Based on typical diffusers pipelines, `image` is often the *original* content placed on the canvas.
197
- # Let's pass `background` as the image input for the pipeline.
198
 
199
- final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
200
 
201
  (
202
  prompt_embeds,
@@ -205,42 +185,25 @@ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_
205
  negative_pooled_prompt_embeds,
206
  ) = pipe.encode_prompt(final_prompt, "cuda", True)
207
 
208
- # The pipeline expects the `image` and `mask_image` arguments for inpainting/outpainting
209
- # `image` should be the canvas with the original image placed.
210
- # `mask_image` defines the area to be filled (white=fill, black=keep).
211
- # Our mask is inverted (black=keep, white=fill). Invert it.
212
- inverted_mask = Image.fromarray(255 - np.array(mask))
213
-
214
- # Run the pipeline
215
- # Note: The generator inside the pipeline call is not used here as we only need the final result.
216
- # We iterate once to get the final image.
217
- generated_image = None
218
- for img_output in pipe(
219
  prompt_embeds=prompt_embeds,
220
  negative_prompt_embeds=negative_prompt_embeds,
221
  pooled_prompt_embeds=pooled_prompt_embeds,
222
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
223
- image=background, # Pass the background with the source image placed
224
- mask_image=inverted_mask, # Pass the inverted mask (white = area to fill)
225
- control_image=background, # ControlNet Union might need the full image context
226
- num_inference_steps=num_inference_steps,
227
- output_type="pil" # Ensure PIL images are returned
228
  ):
229
- generated_image = img_output[0] # The pipeline returns a list containing the image
230
-
231
- if generated_image is None:
232
- raise gr.Error("Image generation failed.")
233
-
234
- # The pipeline should return the complete image already composited.
235
- # No need to manually paste.
236
- final_image = generated_image.convert("RGB")
237
 
238
- # Yield only the final generated image
239
- yield final_image
240
 
 
241
 
242
  def clear_result():
243
- """Clears the result Image component."""
244
  return gr.update(value=None)
245
 
246
  def preload_presets(target_ratio, ui_width, ui_height):
@@ -248,21 +211,19 @@ def preload_presets(target_ratio, ui_width, ui_height):
248
  if target_ratio == "9:16":
249
  changed_width = 720
250
  changed_height = 1280
251
- return changed_width, changed_height, gr.update(open=False) # Close accordion
252
  elif target_ratio == "16:9":
253
  changed_width = 1280
254
  changed_height = 720
255
- return changed_width, changed_height, gr.update(open=False) # Close accordion
256
  elif target_ratio == "1:1":
257
  changed_width = 1024
258
  changed_height = 1024
259
- return changed_width, changed_height, gr.update(open=False) # Close accordion
260
  elif target_ratio == "Custom":
261
- # Keep current slider values but open the accordion
262
  return ui_width, ui_height, gr.update(open=True)
263
 
264
  def select_the_right_preset(user_width, user_height):
265
- """Selects the preset radio button based on current width/height."""
266
  if user_width == 720 and user_height == 1280:
267
  return "9:16"
268
  elif user_width == 1280 and user_height == 720:
@@ -273,49 +234,24 @@ def select_the_right_preset(user_width, user_height):
273
  return "Custom"
274
 
275
  def toggle_custom_resize_slider(resize_option):
276
- """Shows/hides the custom resize slider."""
277
  return gr.update(visible=(resize_option == "Custom"))
278
 
279
  def update_history(new_image, history):
280
  """Updates the history gallery with the new image."""
281
- if new_image is None: # Don't add None to history
282
- return history
283
  if history is None:
284
  history = []
285
- # Prepend the new image (as PIL) to the history list
286
  history.insert(0, new_image)
287
- # Limit history size if desired (e.g., keep last 12)
288
- max_history = 12
289
- if len(history) > max_history:
290
- history = history[:max_history]
291
  return history
292
 
293
- # --- Gradio UI ---
294
-
295
  css = """
296
  .gradio-container {
297
- max-width: 1200px !important; /* Limit overall width */
298
- margin: auto; /* Center the container */
299
- }
300
- /* Ensure gallery items are reasonably sized */
301
- #history_gallery .thumbnail-item {
302
- height: 100px !important; /* Adjust as needed */
303
- }
304
- #history_gallery .gallery {
305
- grid-template-columns: repeat(auto-fill, minmax(100px, 1fr)) !important; /* Adjust column size */
306
  }
307
-
 
308
  """
309
 
310
- title = """<h1 align="center">Diffusers Image Outpaint</h1>
311
- <div align="center">Drop an image you would like to extend, pick your expected ratio and hit Generate.</div>
312
- <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
313
- <p style="display: flex;gap: 6px;">
314
- <a href="https://huggingface.co/spaces/fffiloni/diffusers-image-outpaint?duplicate=true">
315
- <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate this Space">
316
- </a> to skip the queue and enjoy faster inference on the GPU of your choice
317
- </p>
318
- </div>
319
  """
320
 
321
  with gr.Blocks(css=css) as demo:
@@ -323,7 +259,7 @@ with gr.Blocks(css=css) as demo:
323
  gr.HTML(title)
324
 
325
  with gr.Row():
326
- with gr.Column(scale=1): # Input column
327
  input_image = gr.Image(
328
  type="pil",
329
  label="Input Image"
@@ -331,65 +267,62 @@ with gr.Blocks(css=css) as demo:
331
 
332
  with gr.Row():
333
  with gr.Column(scale=2):
334
- prompt_input = gr.Textbox(label="Prompt (Optional)", placeholder="Describe the desired extended scene...")
335
- with gr.Column(scale=1, min_width=150):
336
- run_button = gr.Button("Generate", variant="primary")
337
 
338
  with gr.Row():
339
  target_ratio = gr.Radio(
340
- label="Target Ratio",
341
  choices=["9:16", "16:9", "1:1", "Custom"],
342
  value="9:16",
343
  scale=2
344
  )
345
-
346
  alignment_dropdown = gr.Dropdown(
347
  choices=["Middle", "Left", "Right", "Top", "Bottom"],
348
  value="Middle",
349
- label="Align Source Image"
350
  )
351
 
352
  with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
353
  with gr.Column():
354
  with gr.Row():
355
  width_slider = gr.Slider(
356
- label="Target Width (px)",
357
- minimum=512, # Lowered min slightly
358
- maximum=2048, # Increased max slightly
359
- step=64, # SDXL optimal step
360
  value=720,
361
  )
362
  height_slider = gr.Slider(
363
- label="Target Height (px)",
364
- minimum=512, # Lowered min slightly
365
- maximum=2048, # Increased max slightly
366
- step=64, # SDXL optimal step
367
  value=1280,
368
  )
369
-
370
- num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=20, step=1, value=8) # Increased max steps slightly
371
  with gr.Group():
372
  overlap_percentage = gr.Slider(
373
  label="Mask overlap (%)",
374
  minimum=1,
375
  maximum=50,
376
  value=10,
377
- step=1,
378
- info="How much the new area overlaps the original image."
379
  )
380
- gr.Markdown("Select sides to overlap (influences mask generation):")
381
  with gr.Row():
382
- overlap_top = gr.Checkbox(label="Top", value=True)
383
- overlap_right = gr.Checkbox(label="Right", value=True)
384
  with gr.Row():
385
- overlap_left = gr.Checkbox(label="Left", value=True)
386
- overlap_bottom = gr.Checkbox(label="Bottom", value=True)
387
  with gr.Row():
388
  resize_option = gr.Radio(
389
- label="Resize input image before placing",
390
  choices=["Full", "50%", "33%", "25%", "Custom"],
391
- value="Full",
392
- info="Scales the source image down before placing it on the target canvas."
393
  )
394
  custom_resize_percentage = gr.Slider(
395
  label="Custom resize (%)",
@@ -399,67 +332,37 @@ with gr.Blocks(css=css) as demo:
399
  value=50,
400
  visible=False
401
  )
402
-
403
  with gr.Column():
404
- preview_button = gr.Button("Preview Alignment & Mask")
405
-
406
-
407
  gr.Examples(
408
  examples=[
409
- ["./examples/example_1.webp", 1280, 720, "Middle", "A wide landscape view of the mountains"],
410
- ["./examples/example_2.jpg", 1440, 810, "Left", "Full body shot of the cat sitting on a rug"],
411
- ["./examples/example_3.jpg", 1024, 1024, "Top", "The cloudy sky above the building"],
412
- ["./examples/example_3.jpg", 1024, 1024, "Bottom", "The street below the building"],
413
  ],
414
- inputs=[input_image, width_slider, height_slider, alignment_dropdown, prompt_input],
415
- label="Examples (Click to load)"
416
- )
417
-
418
- with gr.Column(scale=1): # Output column
419
- # Replace ImageSlider with gr.Image
420
- result_image = gr.Image(
421
- label="Generated Image",
422
- interactive=False,
423
- show_download_button=True,
424
- type="pil" # Ensure output is PIL for history
425
  )
426
- with gr.Row():
427
- use_as_input_button = gr.Button("Use as Input", visible=False)
428
- clear_button = gr.Button("Clear Output") # Added clear button
429
-
430
- preview_mask_image = gr.Image(label="Alignment & Mask Preview", interactive=False)
431
-
432
- history_gallery = gr.Gallery(
433
- label="History",
434
- columns=4, # Adjust columns as needed
435
- object_fit="contain",
436
- interactive=False,
437
- show_label=True,
438
- elem_id="history_gallery",
439
- height=300 # Set a fixed height for the gallery area
440
- )
441
 
 
 
 
442
 
443
- # --- Event Handlers ---
 
444
 
445
- def use_output_as_input(output_pil_image):
446
- """Sets the generated output PIL image as the new input image."""
447
- # output_image comes directly from result_image which is PIL type
448
- return gr.update(value=output_pil_image)
449
 
450
  use_as_input_button.click(
451
  fn=use_output_as_input,
452
- inputs=[result_image], # Input is the single result image
453
  outputs=[input_image]
454
  )
455
-
456
- clear_button.click(
457
- fn=lambda: (gr.update(value=None), gr.update(visible=False), gr.update(value=None)), # Clear image, hide button, clear preview
458
- inputs=None,
459
- outputs=[result_image, use_as_input_button, preview_mask_image],
460
- queue=False
461
- )
462
-
463
  target_ratio.change(
464
  fn=preload_presets,
465
  inputs=[target_ratio, width_slider, height_slider],
@@ -467,18 +370,17 @@ with gr.Blocks(css=css) as demo:
467
  queue=False
468
  )
469
 
470
- # Link sliders back to ratio selector and potentially open accordion
471
  width_slider.change(
472
- fn=lambda w, h: (select_the_right_preset(w, h), gr.update() if select_the_right_preset(w, h) == "Custom" else gr.update()),
473
  inputs=[width_slider, height_slider],
474
- outputs=[target_ratio, settings_panel],
475
  queue=False
476
  )
477
 
478
  height_slider.change(
479
- fn=lambda w, h: (select_the_right_preset(w, h), gr.update() if select_the_right_preset(w, h) == "Custom" else gr.update()),
480
  inputs=[width_slider, height_slider],
481
- outputs=[target_ratio, settings_panel],
482
  queue=False
483
  )
484
 
@@ -488,59 +390,53 @@ with gr.Blocks(css=css) as demo:
488
  outputs=[custom_resize_percentage],
489
  queue=False
490
  )
491
-
492
- # Define common inputs for generation
493
- gen_inputs = [
494
- input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
495
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
496
- overlap_left, overlap_right, overlap_top, overlap_bottom
497
- ]
498
-
499
- # Define common steps after generation
500
- def handle_output(generated_image, current_history):
501
- # generated_image is the single PIL image from infer
502
- new_history = update_history(generated_image, current_history)
503
- button_visibility = gr.update(visible=True) if generated_image else gr.update(visible=False)
504
- return generated_image, new_history, button_visibility
505
-
506
  run_button.click(
507
- fn=lambda: (gr.update(value=None), gr.update(visible=False)), # Clear result and hide button first
508
  inputs=None,
509
- outputs=[result_image, use_as_input_button],
510
- queue=False # Don't queue the clearing part
511
  ).then(
512
- fn=infer, # Run the generation
513
- inputs=gen_inputs,
514
- outputs=result_image, # Output is the single generated image
 
 
 
 
 
 
515
  ).then(
516
- fn=handle_output, # Process output: update history, show button
517
- inputs=[result_image, history_gallery],
518
- outputs=[result_image, history_gallery, use_as_input_button] # Update result again (no change), history, and button visibility
519
  )
520
 
521
  prompt_input.submit(
522
- fn=lambda: (gr.update(value=None), gr.update(visible=False)), # Clear result and hide button first
523
  inputs=None,
524
- outputs=[result_image, use_as_input_button],
525
- queue=False # Don't queue the clearing part
526
  ).then(
527
- fn=infer, # Run the generation
528
- inputs=gen_inputs,
529
- outputs=result_image, # Output is the single generated image
 
 
 
 
 
 
530
  ).then(
531
- fn=handle_output, # Process output: update history, show button
532
- inputs=[result_image, history_gallery],
533
- outputs=[result_image, history_gallery, use_as_input_button] # Update result again (no change), history, and button visibility
534
  )
535
 
536
-
537
  preview_button.click(
538
  fn=preview_image_and_mask,
539
  inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
540
  overlap_left, overlap_right, overlap_top, overlap_bottom],
541
- outputs=preview_mask_image, # Output to the preview image component
542
- queue=False # Preview should be fast
543
  )
544
 
545
- # Launch the app
546
- demo.queue(max_size=12).launch(share=False, ssr_mode=False, show_error=True)
 
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
  from diffusers import AutoencoderKL, TCDScheduler
5
  from diffusers.models.model_loading_utils import load_state_dict
 
 
6
  from huggingface_hub import hf_hub_download
7
 
8
  from controlnet_union import ControlNetModel_Union
 
11
  from PIL import Image, ImageDraw
12
  import numpy as np
13
 
 
14
  config_file = hf_hub_download(
15
  "xinsir/controlnet-union-sdxl-1.0",
16
  filename="config_promax.json",
 
22
  "xinsir/controlnet-union-sdxl-1.0",
23
  filename="diffusion_pytorch_model_promax.safetensors",
24
  )
25
+
26
+ sstate_dict = load_state_dict(model_file)
27
  model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
28
+ controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
29
  )
30
  model.to(device="cuda", dtype=torch.float16)
31
 
 
43
 
44
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
 
 
46
  def can_expand(source_width, source_height, target_width, target_height, alignment):
47
  """Checks if the image can be expanded based on the alignment."""
48
  if alignment in ("Left", "Right") and source_width >= target_width:
 
58
  scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
59
  new_width = int(image.width * scale_factor)
60
  new_height = int(image.height * scale_factor)
61
+
62
  # Resize the source image to fit within target size
63
  source = image.resize((new_width, new_height), Image.LANCZOS)
64
 
 
130
  right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
131
  top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
132
  bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
133
+
134
  if alignment == "Left":
135
  left_overlap = margin_x + overlap_x if overlap_left else margin_x
136
  elif alignment == "Right":
 
140
  elif alignment == "Bottom":
141
  bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
142
 
 
143
  # Draw the mask
144
  mask_draw.rectangle([
145
  (left_overlap, top_overlap),
 
150
 
151
  def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
152
  background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
153
+
154
  # Create a preview image showing the mask
155
  preview = background.copy().convert('RGBA')
156
+
157
  # Create a semi-transparent red overlay
158
  red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64)) # Reduced alpha to 64 (25% opacity)
159
+
160
  # Convert black pixels in the mask to semi-transparent red
161
  red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
162
  red_mask.paste(red_overlay, (0, 0), mask)
163
+
164
  # Overlay the red mask on the background
165
  preview = Image.alpha_composite(preview, red_mask)
166
+
167
  return preview
168
 
169
  @spaces.GPU(duration=24)
170
  def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
 
 
 
171
  background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
172
+
173
  if not can_expand(background.width, background.height, width, height, alignment):
 
 
174
  alignment = "Middle"
 
 
 
175
 
176
  cnet_image = background.copy()
177
+ cnet_image.paste(0, (0, 0), mask)
 
 
 
 
 
 
178
 
179
+ final_prompt = f"{prompt_input} , high quality, 4k"
180
 
181
  (
182
  prompt_embeds,
 
185
  negative_pooled_prompt_embeds,
186
  ) = pipe.encode_prompt(final_prompt, "cuda", True)
187
 
188
+ # Generate the image
189
+ for image in pipe(
 
 
 
 
 
 
 
 
 
190
  prompt_embeds=prompt_embeds,
191
  negative_prompt_embeds=negative_prompt_embeds,
192
  pooled_prompt_embeds=pooled_prompt_embeds,
193
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
194
+ image=cnet_image,
195
+ num_inference_steps=num_inference_steps
 
 
 
196
  ):
197
+ pass # Wait for the generation to complete
198
+ generated_image = image # Get the last image
 
 
 
 
 
 
199
 
200
+ generated_image = generated_image.convert("RGBA")
201
+ cnet_image.paste(generated_image, (0, 0), mask)
202
 
203
+ return cnet_image
204
 
205
  def clear_result():
206
+ """Clears the result Image."""
207
  return gr.update(value=None)
208
 
209
  def preload_presets(target_ratio, ui_width, ui_height):
 
211
  if target_ratio == "9:16":
212
  changed_width = 720
213
  changed_height = 1280
214
+ return changed_width, changed_height, gr.update()
215
  elif target_ratio == "16:9":
216
  changed_width = 1280
217
  changed_height = 720
218
+ return changed_width, changed_height, gr.update()
219
  elif target_ratio == "1:1":
220
  changed_width = 1024
221
  changed_height = 1024
222
+ return changed_width, changed_height, gr.update()
223
  elif target_ratio == "Custom":
 
224
  return ui_width, ui_height, gr.update(open=True)
225
 
226
  def select_the_right_preset(user_width, user_height):
 
227
  if user_width == 720 and user_height == 1280:
228
  return "9:16"
229
  elif user_width == 1280 and user_height == 720:
 
234
  return "Custom"
235
 
236
  def toggle_custom_resize_slider(resize_option):
 
237
  return gr.update(visible=(resize_option == "Custom"))
238
 
239
  def update_history(new_image, history):
240
  """Updates the history gallery with the new image."""
 
 
241
  if history is None:
242
  history = []
 
243
  history.insert(0, new_image)
 
 
 
 
244
  return history
245
 
 
 
246
  css = """
247
  .gradio-container {
248
+ width: 1200px !important;
 
 
 
 
 
 
 
 
249
  }
250
+ h1 { text-align: center; }
251
+ footer { visibility: hidden; }
252
  """
253
 
254
+ title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
 
 
 
 
 
 
 
 
255
  """
256
 
257
  with gr.Blocks(css=css) as demo:
 
259
  gr.HTML(title)
260
 
261
  with gr.Row():
262
+ with gr.Column():
263
  input_image = gr.Image(
264
  type="pil",
265
  label="Input Image"
 
267
 
268
  with gr.Row():
269
  with gr.Column(scale=2):
270
+ prompt_input = gr.Textbox(label="Prompt (Optional)")
271
+ with gr.Column(scale=1):
272
+ run_button = gr.Button("Generate")
273
 
274
  with gr.Row():
275
  target_ratio = gr.Radio(
276
+ label="Expected Ratio",
277
  choices=["9:16", "16:9", "1:1", "Custom"],
278
  value="9:16",
279
  scale=2
280
  )
281
+
282
  alignment_dropdown = gr.Dropdown(
283
  choices=["Middle", "Left", "Right", "Top", "Bottom"],
284
  value="Middle",
285
+ label="Alignment"
286
  )
287
 
288
  with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
289
  with gr.Column():
290
  with gr.Row():
291
  width_slider = gr.Slider(
292
+ label="Target Width",
293
+ minimum=720,
294
+ maximum=1536,
295
+ step=8,
296
  value=720,
297
  )
298
  height_slider = gr.Slider(
299
+ label="Target Height",
300
+ minimum=720,
301
+ maximum=1536,
302
+ step=8,
303
  value=1280,
304
  )
305
+
306
+ num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
307
  with gr.Group():
308
  overlap_percentage = gr.Slider(
309
  label="Mask overlap (%)",
310
  minimum=1,
311
  maximum=50,
312
  value=10,
313
+ step=1
 
314
  )
 
315
  with gr.Row():
316
+ overlap_top = gr.Checkbox(label="Overlap Top", value=True)
317
+ overlap_right = gr.Checkbox(label="Overlap Right", value=True)
318
  with gr.Row():
319
+ overlap_left = gr.Checkbox(label="Overlap Left", value=True)
320
+ overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
321
  with gr.Row():
322
  resize_option = gr.Radio(
323
+ label="Resize input image",
324
  choices=["Full", "50%", "33%", "25%", "Custom"],
325
+ value="Full"
 
326
  )
327
  custom_resize_percentage = gr.Slider(
328
  label="Custom resize (%)",
 
332
  value=50,
333
  visible=False
334
  )
335
+
336
  with gr.Column():
337
+ preview_button = gr.Button("Preview alignment and mask")
338
+
 
339
  gr.Examples(
340
  examples=[
341
+ ["./examples/example_1.webp", 1280, 720, "Middle"],
342
+ ["./examples/example_2.jpg", 1440, 810, "Left"],
343
+ ["./examples/example_3.jpg", 1024, 1024, "Top"],
344
+ ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
345
  ],
346
+ inputs=[input_image, width_slider, height_slider, alignment_dropdown],
 
 
 
 
 
 
 
 
 
 
347
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
+ with gr.Column():
350
+ result = gr.Image(label="Generated Image", type="pil")
351
+ use_as_input_button = gr.Button("Use as Input Image", visible=False)
352
 
353
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
354
+ preview_image = gr.Image(label="Preview")
355
 
356
+ def use_output_as_input(output_image):
357
+ """Sets the generated output as the new input image."""
358
+ return gr.update(value=output_image)
 
359
 
360
  use_as_input_button.click(
361
  fn=use_output_as_input,
362
+ inputs=[result],
363
  outputs=[input_image]
364
  )
365
+
 
 
 
 
 
 
 
366
  target_ratio.change(
367
  fn=preload_presets,
368
  inputs=[target_ratio, width_slider, height_slider],
 
370
  queue=False
371
  )
372
 
 
373
  width_slider.change(
374
+ fn=select_the_right_preset,
375
  inputs=[width_slider, height_slider],
376
+ outputs=[target_ratio],
377
  queue=False
378
  )
379
 
380
  height_slider.change(
381
+ fn=select_the_right_preset,
382
  inputs=[width_slider, height_slider],
383
+ outputs=[target_ratio],
384
  queue=False
385
  )
386
 
 
390
  outputs=[custom_resize_percentage],
391
  queue=False
392
  )
393
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
  run_button.click(
395
+ fn=clear_result,
396
  inputs=None,
397
+ outputs=result,
 
398
  ).then(
399
+ fn=infer,
400
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
401
+ resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
402
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
403
+ outputs=result,
404
+ ).then(
405
+ fn=lambda x, history: update_history(x, history),
406
+ inputs=[result, history_gallery],
407
+ outputs=history_gallery,
408
  ).then(
409
+ fn=lambda: gr.update(visible=True),
410
+ inputs=None,
411
+ outputs=use_as_input_button,
412
  )
413
 
414
  prompt_input.submit(
415
+ fn=clear_result,
416
  inputs=None,
417
+ outputs=result,
 
418
  ).then(
419
+ fn=infer,
420
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
421
+ resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
422
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
423
+ outputs=result,
424
+ ).then(
425
+ fn=lambda x, history: update_history(x, history),
426
+ inputs=[result, history_gallery],
427
+ outputs=history_gallery,
428
  ).then(
429
+ fn=lambda: gr.update(visible=True),
430
+ inputs=None,
431
+ outputs=use_as_input_button,
432
  )
433
 
 
434
  preview_button.click(
435
  fn=preview_image_and_mask,
436
  inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
437
  overlap_left, overlap_right, overlap_top, overlap_bottom],
438
+ outputs=preview_image,
439
+ queue=False
440
  )
441
 
442
+ demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)