prithivMLmods committed on
Commit 19d58d4 · verified · 1 Parent(s): d4884bc

Update app.py

Files changed (1):
  1. app.py +132 -252

app.py CHANGED
@@ -11,56 +11,52 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 from PIL import Image, ImageDraw
 import numpy as np

-# --- Configuration and Model Loading ---
-
-# Load ControlNet Union
 config_file = hf_hub_download(
     "xinsir/controlnet-union-sdxl-1.0",
     filename="config_promax.json",
 )
 config = ControlNetModel_Union.load_config(config_file)
 controlnet_model = ControlNetModel_Union.from_config(config)
 model_file = hf_hub_download(
     "xinsir/controlnet-union-sdxl-1.0",
     filename="diffusion_pytorch_model_promax.safetensors",
 )
 sstate_dict = load_state_dict(model_file)
 model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
     controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
 )
 model.to(device="cuda", dtype=torch.float16)

-# Load VAE
 vae = AutoencoderKL.from_pretrained(
     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
 ).to("cuda")

-# --- Load Multiple Pipelines ---
-pipelines = {}
-
-# Load RealVisXL V5.0 Lightning
-pipe_v5 = StableDiffusionXLFillPipeline.from_pretrained(
     "SG161222/RealVisXL_V5.0_Lightning",
     torch_dtype=torch.float16,
     vae=vae,
-    controlnet=model, # Use the same controlnet
-    variant="fp16",
-).to("cuda")
-pipe_v5.scheduler = TCDScheduler.from_config(pipe_v5.scheduler.config)
-pipelines["RealVisXL V5.0 Lightning"] = pipe_v5
-
-# Load RealVisXL V4.0 Lightning
-pipe_v4 = StableDiffusionXLFillPipeline.from_pretrained(
-    "SG161222/RealVisXL_V4.0_Lightning",
-    torch_dtype=torch.float16,
-    vae=vae, # Use the same VAE
-    controlnet=model, # Use the same controlnet
     variant="fp16",
 ).to("cuda")
-pipe_v4.scheduler = TCDScheduler.from_config(pipe_v4.scheduler.config)
-pipelines["RealVisXL V4.0 Lightning"] = pipe_v4

-# --- Helper Functions ---

 def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
     target_size = (width, height)
@@ -69,7 +65,7 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
     scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
     new_width = int(image.width * scale_factor)
     new_height = int(image.height * scale_factor)
-
     # Resize the source image to fit within target size
     source = image.resize((new_width, new_height), Image.LANCZOS)

@@ -121,10 +117,6 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
     elif alignment == "Bottom":
         margin_x = (target_size[0] - new_width) // 2
         margin_y = target_size[1] - new_height
-    else: # Default to Middle if alignment is somehow invalid
-        margin_x = (target_size[0] - new_width) // 2
-        margin_y = (target_size[1] - new_height) // 2
-

     # Adjust margins to eliminate gaps
     margin_x = max(0, min(margin_x, target_size[0] - new_width))
@@ -135,126 +127,66 @@ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_opti
     background.paste(source, (margin_x, margin_y))

     # Create the mask
-    mask = Image.new('L', target_size, 255) # White background (area to be filled)
     mask_draw = ImageDraw.Draw(mask)

-    # Calculate overlap areas (where the mask should be black = keep original)
-    white_gaps_patch = 2 # Small value to ensure no tiny gaps at edges if overlap is off

-    # Determine the coordinates for the black rectangle (the non-masked area)
-    # Start with the full area covered by the pasted image
-    left_black = margin_x
-    top_black = margin_y
-    right_black = margin_x + new_width
-    bottom_black = margin_y + new_height
-
-    # Adjust the black area based on overlap checkboxes
-    if overlap_left:
-        left_black += overlap_x
-    else:
-        # If not overlapping left, ensure the black mask starts exactly at the image edge or slightly inside
-        left_black += white_gaps_patch if alignment != "Left" else 0
-
-    if overlap_right:
-        right_black -= overlap_x
-    else:
-        # If not overlapping right, ensure the black mask ends exactly at the image edge or slightly inside
-        right_black -= white_gaps_patch if alignment != "Right" else 0
-
-    if overlap_top:
-        top_black += overlap_y
-    else:
-        # If not overlapping top, ensure the black mask starts exactly at the image edge or slightly inside
-        top_black += white_gaps_patch if alignment != "Top" else 0

-    if overlap_bottom:
-        bottom_black -= overlap_y
-    else:
-        # If not overlapping bottom, ensure the black mask ends exactly at the image edge or slightly inside
-        bottom_black -= white_gaps_patch if alignment != "Bottom" else 0
-
-    # Ensure coordinates are valid (left < right, top < bottom)
-    left_black = min(left_black, target_size[0])
-    top_black = min(top_black, target_size[1])
-    right_black = max(left_black, right_black) # Ensure right >= left
-    bottom_black = max(top_black, bottom_black) # Ensure bottom >= top
-    right_black = min(right_black, target_size[0])
-    bottom_black = min(bottom_black, target_size[1])
-
-
-    # Draw the black rectangle onto the white mask
-    # The area *inside* this rectangle will be kept (mask value 0)
-    # The area *outside* this rectangle will be filled (mask value 255)
-    if right_black > left_black and bottom_black > top_black:
-        mask_draw.rectangle(
-            [(left_black, top_black), (right_black, bottom_black)],
-            fill=0 # Black means keep this area
-        )

     return background, mask

-
 @spaces.GPU(duration=24)
-def infer(selected_model_name, image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
-    if image is None:
-        raise gr.Error("Please upload an input image.")
-    try:
-        # Select the pipeline based on the dropdown choice
-        pipe = pipelines[selected_model_name]
-
-        background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
-
-        # Create the controlnet input image (original image pasted on white bg, with masked area blacked out)
-        cnet_image = background.copy()
-        # Create a black image of the same size as the mask
-        black_fill = Image.new('RGB', mask.size, (0, 0, 0))
-        # Paste the black fill using the mask (where mask is 255/white, paste black)
-        cnet_image.paste(black_fill, (0, 0), mask)
-
-        final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
-
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-        ) = pipe.encode_prompt(final_prompt, "cuda", True)
-
-        # Generate the image
-        generator = torch.Generator(device="cuda").manual_seed(np.random.randint(0, 2**32)) # Add random seed
-
-        # The pipeline expects the 'image' argument to be the background with the original content
-        # and the 'mask_image' argument to define the area to *inpaint* (white area in our mask)
-        result_image = pipe(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            image=background, # The background containing the original image
-            mask_image=mask, # The mask (white = fill, black = keep)
-            control_image=cnet_image, # ControlNet input image
-            num_inference_steps=num_inference_steps,
-            generator=generator, # Use generator for reproducibility if needed
-            output_type="pil" # Ensure PIL output
-        ).images[0]
-
-        # The pipeline directly returns the final composited image.
-        # No need for manual pasting like before.
-
-        return result_image
-    except Exception as e:
-        print(f"Error during inference: {e}")
-        import traceback
-        traceback.print_exc()
-        # Return the background image or raise a Gradio error for clarity
-        # raise gr.Error(f"Inference failed: {e}")
-        # Or return the prepared background/mask for debugging
-        background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
-        # Combine background and mask for visualization
-        debug_img = Image.blend(background.convert("RGBA"), mask.convert("RGBA"), 0.5)
-        return debug_img # Return a debug image or None
-

 def clear_result():
     """Clears the result Image."""
@@ -265,21 +197,17 @@ def preload_presets(target_ratio, ui_width, ui_height):
     if target_ratio == "9:16":
         changed_width = 720
         changed_height = 1280
-        return changed_width, changed_height, gr.update(open=False) # Close accordion on preset
     elif target_ratio == "16:9":
         changed_width = 1280
         changed_height = 720
-        return changed_width, changed_height, gr.update(open=False) # Close accordion on preset
     elif target_ratio == "1:1":
         changed_width = 1024
         changed_height = 1024
-        return changed_width, changed_height, gr.update(open=False) # Close accordion on preset
     elif target_ratio == "Custom":
-        # When switching to Custom, keep current slider values but open accordion
         return ui_width, ui_height, gr.update(open=True)
-    # Should not happen, but return current values if it does
-    return ui_width, ui_height, gr.update()
-

 def select_the_right_preset(user_width, user_height):
     if user_width == 720 and user_height == 1280:
@@ -296,71 +224,54 @@ def toggle_custom_resize_slider(resize_option):

 def update_history(new_image, history):
     """Updates the history gallery with the new image."""
-    if new_image is None: # Don't add None to history (e.g., on clear or error)
-        return history
     if history is None:
         history = []
-    # Prepend the new image (as PIL or path depending on Gallery config)
     history.insert(0, new_image)
-    # Limit history size if desired (e.g., keep last 12)
-    max_history = 12
-    if len(history) > max_history:
-        history = history[:max_history]
     return history

-# --- CSS and Title ---
 css = """
 h1 {
-    text-align: center;
-    display: block;
-}
-.gradio-container {
-    max-width: 1280px !important;
-    margin: auto !important;
 }
 """

 title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
-<p align="center">Expand images using ControlNet Union and Lightning models. Choose a base model below.</p>
 """

-# --- Gradio UI ---
 with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     with gr.Column():
         gr.HTML(title)

         with gr.Row():
-            with gr.Column(scale=2): # Input column
                 input_image = gr.Image(
                     type="pil",
                     label="Input Image"
                 )
-
-                # --- Model Selector ---
-                model_selector = gr.Dropdown(
-                    label="Select Model",
-                    choices=list(pipelines.keys()),
-                    value="RealVisXL V5.0 Lightning", # Default model
                 )
-
                 with gr.Row():
                     with gr.Column(scale=2):
-                        prompt_input = gr.Textbox(label="Prompt (Describe the desired output)", placeholder="e.g., beautiful landscape, photorealistic")
-                    with gr.Column(scale=1, min_width=120):
-                        run_button = gr.Button("Generate", variant="primary")

                 with gr.Row():
                     target_ratio = gr.Radio(
-                        label="Target Ratio",
                         choices=["9:16", "16:9", "1:1", "Custom"],
-                        value="9:16", # Default ratio
                         scale=2
                     )
-
                     alignment_dropdown = gr.Dropdown(
                         choices=["Middle", "Left", "Right", "Top", "Bottom"],
                         value="Middle",
-                        label="Align Input Image"
                     )

                 with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
@@ -368,43 +279,38 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
                     with gr.Row():
                         width_slider = gr.Slider(
                             label="Target Width",
-                            minimum=512, # Lowered minimum slightly
                             maximum=1536,
-                            step=64, # Steps of 64 common for SDXL
-                            value=720, # Default width
                         )
                         height_slider = gr.Slider(
                             label="Target Height",
-                            minimum=512, # Lowered minimum slightly
                             maximum=1536,
-                            step=64, # Steps of 64
-                            value=1280, # Default height
                         )
-
                     num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
-
                     with gr.Group():
                         overlap_percentage = gr.Slider(
                             label="Mask overlap (%)",
-                            info="Percentage of the input image edge to keep (reduces seams)",
                             minimum=1,
                             maximum=50,
-                            value=10, # Default overlap
                             step=1
                         )
-                        gr.Markdown("Select edges to apply overlap:")
                        with gr.Row():
-                            overlap_top = gr.Checkbox(label="Top", value=True)
-                            overlap_right = gr.Checkbox(label="Right", value=True)
-                            overlap_left = gr.Checkbox(label="Left", value=True)
-                            overlap_bottom = gr.Checkbox(label="Bottom", value=True)
-
                         with gr.Row():
                             resize_option = gr.Radio(
-                                label="Resize input image before placing",
-                                info="Scale the input image relative to its fitted size",
                                 choices=["Full", "50%", "33%", "25%", "Custom"],
-                                value="Full" # Default resize option
                             )
                             custom_resize_percentage = gr.Slider(
                                 label="Custom resize (%)",
@@ -412,48 +318,35 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
                                 maximum=100,
                                 step=1,
                                 value=50,
-                                visible=False # Initially hidden
                             )
-
                 gr.Examples(
                     examples=[
-                        ["./examples/example_1.webp", "RealVisXL V5.0 Lightning", 1280, 720, "Middle"],
-                        ["./examples/example_2.jpg", "RealVisXL V4.0 Lightning", 1440, 810, "Left"],
-                        ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024, "Top"],
-                        ["./examples/example_3.jpg", "RealVisXL V5.0 Lightning", 1024, 1024, "Bottom"],
                     ],
-                    inputs=[input_image, model_selector, width_slider, height_slider, alignment_dropdown],
-                    label="Examples (Prompt is optional)"
                 )

-            with gr.Column(scale=3): # Output column
                 result = gr.Image(
                     interactive=False,
                     label="Generated Image",
                     format="png",
                 )
-                history_gallery = gr.Gallery(
-                    label="History",
-                    columns=4, # Adjust columns as needed
-                    object_fit="contain",
-                    interactive=False,
-                    show_label=True,
-                    allow_preview=True,
-                    preview=True
-                )
-
-        # --- Event Listeners ---

-        # Update sliders and accordion based on ratio selection
         target_ratio.change(
             fn=preload_presets,
             inputs=[target_ratio, width_slider, height_slider],
             outputs=[width_slider, height_slider, settings_panel],
             queue=False
         )
-
-        # Update ratio selection based on slider changes
         width_slider.change(
             fn=select_the_right_preset,
             inputs=[width_slider, height_slider],
@@ -466,55 +359,42 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
             outputs=[target_ratio],
             queue=False
         )
-
-        # Show/hide custom resize slider
         resize_option.change(
             fn=toggle_custom_resize_slider,
             inputs=[resize_option],
             outputs=[custom_resize_percentage],
             queue=False
         )
-
-        # Define inputs for the main inference function
-        infer_inputs = [
-            model_selector, input_image, width_slider, height_slider, overlap_percentage,
-            num_inference_steps, resize_option, custom_resize_percentage, prompt_input,
-            alignment_dropdown, overlap_left, overlap_right, overlap_top, overlap_bottom
-        ]
-
-        # --- Run Button Click ---
         run_button.click(
             fn=clear_result,
             inputs=None,
-            outputs=[result], # Clear only the main result image
-            queue=False # Clearing should be fast
         ).then(
             fn=infer,
-            inputs=infer_inputs,
-            outputs=[result], # Output to the main result image
         ).then(
-            fn=update_history, # Use the specific update function
-            inputs=[result, history_gallery], # Pass the result and current history
-            outputs=[history_gallery], # Update the history gallery
         )
-
-        # --- Prompt Submit (Enter Key) ---
         prompt_input.submit(
-            fn=clear_result,
             inputs=None,
-            outputs=[result],
-            queue=False
         ).then(
             fn=infer,
-            inputs=infer_inputs,
-            outputs=[result],
         ).then(
-            fn=update_history,
             inputs=[result, history_gallery],
-            outputs=[history_gallery],
         )

-# --- Launch App ---
-# Make sure you have example images at the specified paths or remove/update the gr.Examples section
-# Create an 'examples' directory and place images like 'example_1.webp', 'example_2.jpg', 'example_3.jpg' inside it.
 demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)
 
 from PIL import Image, ImageDraw
 import numpy as np

+# Load configuration and models
 config_file = hf_hub_download(
     "xinsir/controlnet-union-sdxl-1.0",
     filename="config_promax.json",
 )
+
 config = ControlNetModel_Union.load_config(config_file)
 controlnet_model = ControlNetModel_Union.from_config(config)
 model_file = hf_hub_download(
     "xinsir/controlnet-union-sdxl-1.0",
     filename="diffusion_pytorch_model_promax.safetensors",
 )
+
 sstate_dict = load_state_dict(model_file)
 model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
     controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
 )
 model.to(device="cuda", dtype=torch.float16)

 vae = AutoencoderKL.from_pretrained(
     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
 ).to("cuda")

+# Initially load the default pipeline
+pipe = StableDiffusionXLFillPipeline.from_pretrained(
     "SG161222/RealVisXL_V5.0_Lightning",
     torch_dtype=torch.float16,
     vae=vae,
+    controlnet=model,
     variant="fp16",
 ).to("cuda")

+pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+
+def load_model(selected_model):
+    global pipe
+    model_path = f"SG161222/{selected_model}"
+    pipe = StableDiffusionXLFillPipeline.from_pretrained(
+        model_path,
+        torch_dtype=torch.float16,
+        vae=vae,
+        controlnet=model,
+        variant="fp16",
+    ).to("cuda")
+    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+    return f"Loaded model: {selected_model}"
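
A note on the new loading scheme: load_model rebinds the module-level pipe, so any inference issued after the dropdown change runs on the newly selected checkpoint. A minimal direct-call sketch (illustrative only; assumes a CUDA device and access to the SG161222 repos):

    status = load_model("RealVisXL_V4.0_Lightning")
    print(status)  # -> "Loaded model: RealVisXL_V4.0_Lightning"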

 def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
     target_size = (width, height)

     scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
     new_width = int(image.width * scale_factor)
     new_height = int(image.height * scale_factor)
+
     # Resize the source image to fit within target size
     source = image.resize((new_width, new_height), Image.LANCZOS)

     elif alignment == "Bottom":
         margin_x = (target_size[0] - new_width) // 2
         margin_y = target_size[1] - new_height

     # Adjust margins to eliminate gaps
     margin_x = max(0, min(margin_x, target_size[0] - new_width))

     background.paste(source, (margin_x, margin_y))

     # Create the mask
+    mask = Image.new('L', target_size, 255)
     mask_draw = ImageDraw.Draw(mask)

+    # Calculate overlap areas
+    white_gaps_patch = 2

+    left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
+    right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
+    top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
+    bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
+
+    if alignment == "Left":
+        left_overlap = margin_x + overlap_x if overlap_left else margin_x
+    elif alignment == "Right":
+        right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
+    elif alignment == "Top":
+        top_overlap = margin_y + overlap_y if overlap_top else margin_y
+    elif alignment == "Bottom":
+        bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height

+    # Draw the mask
+    mask_draw.rectangle([
+        (left_overlap, top_overlap),
+        (right_overlap, bottom_overlap)
+    ], fill=0)

     return background, mask
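
To make the rectangle arithmetic concrete, here is a self-contained sketch with illustrative numbers; overlap_x and overlap_y stand in for the pixel overlaps that prepare_image_and_mask derives from overlap_percentage elsewhere in the function, and all overlap checkboxes are assumed enabled:

    from PIL import Image, ImageDraw

    target_size = (1280, 720)
    margin_x, margin_y = 280, 0        # a 720x720 source centered horizontally
    new_width = new_height = 720
    overlap_x = overlap_y = 72         # assumed pixel overlap

    mask = Image.new('L', target_size, 255)   # white = area to outpaint
    ImageDraw.Draw(mask).rectangle(
        [(margin_x + overlap_x, margin_y + overlap_y),
         (margin_x + new_width - overlap_x, margin_y + new_height - overlap_y)],
        fill=0,                                # black = keep original pixels
    )
    # Kept region: (352, 72)-(928, 648), i.e. the pasted source minus a
    # 72 px border that the model may repaint to blend the seam.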

 @spaces.GPU(duration=24)
+def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
+    background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
+
+    cnet_image = background.copy()
+    cnet_image.paste(0, (0, 0), mask)
+
+    final_prompt = f"{prompt_input} , high quality, 4k"
+
+    (
+        prompt_embeds,
+        negative_prompt_embeds,
+        pooled_prompt_embeds,
+        negative_pooled_prompt_embeds,
+    ) = pipe.encode_prompt(final_prompt, "cuda", True)
+
+    # Generate the image
+    for image in pipe(
+        prompt_embeds=prompt_embeds,
+        negative_prompt_embeds=negative_prompt_embeds,
+        pooled_prompt_embeds=pooled_prompt_embeds,
+        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        image=cnet_image,
+        num_inference_steps=num_inference_steps
+    ):
+        pass  # Wait for the generation to complete
+    generated_image = image  # Get the last image
+
+    generated_image = generated_image.convert("RGBA")
+    cnet_image.paste(generated_image, (0, 0), mask)
+
+    return cnet_image
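
Two details of the new infer are worth noting. First, pipe(...) is consumed with a for loop: the StableDiffusionXLFillPipeline used here appears to yield progressively refined frames, so draining the loop and reading the loop variable afterwards picks up the final image (this also rebinds the image parameter, which is safe because the original input is no longer needed at that point). A stand-in sketch of the pattern, with a hypothetical fake_pipe:

    def fake_pipe(steps=3):
        # stand-in for a pipeline that yields intermediate results
        for i in range(1, steps + 1):
            yield f"image_after_step_{i}"

    for image in fake_pipe():
        pass            # drain the generator
    final = image       # -> "image_after_step_3"

Second, the final cnet_image.paste(generated_image, (0, 0), mask) composites the generated pixels back only where the mask is white, so the untouched part of the source survives verbatim.
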
 def clear_result():
     """Clears the result Image."""

     if target_ratio == "9:16":
         changed_width = 720
         changed_height = 1280
+        return changed_width, changed_height, gr.update()
     elif target_ratio == "16:9":
         changed_width = 1280
         changed_height = 720
+        return changed_width, changed_height, gr.update()
     elif target_ratio == "1:1":
         changed_width = 1024
         changed_height = 1024
+        return changed_width, changed_height, gr.update()
     elif target_ratio == "Custom":
         return ui_width, ui_height, gr.update(open=True)

 def select_the_right_preset(user_width, user_height):
     if user_width == 720 and user_height == 1280:

 def update_history(new_image, history):
     """Updates the history gallery with the new image."""
     if history is None:
         history = []
     history.insert(0, new_image)
     return history
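
Behavior sketch for update_history (plain strings for illustration; the app passes PIL images):

    h = update_history("first", None)   # -> ["first"]
    h = update_history("second", h)     # -> ["second", "first"], newest first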

+# CSS and Title
 css = """
 h1 {
+    text-align: center;
+    display: block;
 }
 """

 title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
 """

 with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     with gr.Column():
         gr.HTML(title)

         with gr.Row():
+            with gr.Column():
                 input_image = gr.Image(
                     type="pil",
                     label="Input Image"
                 )
+                model_selection = gr.Dropdown(
+                    choices=["RealVisXL_V5.0_Lightning", "RealVisXL_V4.0_Lightning"],
+                    value="RealVisXL_V5.0_Lightning",
+                    label="Select Model"
                 )

                 with gr.Row():
                     with gr.Column(scale=2):
+                        prompt_input = gr.Textbox(label="Prompt (Optional)")
+                    with gr.Column(scale=1):
+                        run_button = gr.Button("Generate")

                 with gr.Row():
                     target_ratio = gr.Radio(
+                        label="Expected Ratio",
                         choices=["9:16", "16:9", "1:1", "Custom"],
+                        value="9:16",
                         scale=2
                     )
                     alignment_dropdown = gr.Dropdown(
                         choices=["Middle", "Left", "Right", "Top", "Bottom"],
                         value="Middle",
+                        label="Alignment"
                     )

                 with gr.Accordion(label="Advanced settings", open=False) as settings_panel:

                     with gr.Row():
                         width_slider = gr.Slider(
                             label="Target Width",
+                            minimum=720,
                             maximum=1536,
+                            step=8,
+                            value=720,
                         )
                         height_slider = gr.Slider(
                             label="Target Height",
+                            minimum=720,
                             maximum=1536,
+                            step=8,
+                            value=1280,
                         )

                     num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)

                     with gr.Group():
                         overlap_percentage = gr.Slider(
                             label="Mask overlap (%)",
                             minimum=1,
                             maximum=50,
+                            value=10,
                             step=1
                         )
                         with gr.Row():
+                            overlap_top = gr.Checkbox(label="Overlap Top", value=True)
+                            overlap_right = gr.Checkbox(label="Overlap Right", value=True)
+                        with gr.Row():
+                            overlap_left = gr.Checkbox(label="Overlap Left", value=True)
+                            overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
                         with gr.Row():
                             resize_option = gr.Radio(
+                                label="Resize input image",
                                 choices=["Full", "50%", "33%", "25%", "Custom"],
+                                value="Full"
                             )
                             custom_resize_percentage = gr.Slider(
                                 label="Custom resize (%)",
                                 maximum=100,
                                 step=1,
                                 value=50,
+                                visible=False
                             )
+                status_text = gr.Textbox(label="Status", interactive=False)
                 gr.Examples(
                     examples=[
+                        ["./examples/example_1.webp", 1280, 720, "Middle"],
+                        ["./examples/example_2.jpg", 1440, 810, "Left"],
+                        ["./examples/example_3.jpg", 1024, 1024, "Top"],
+                        ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
                     ],
+                    inputs=[input_image, width_slider, height_slider, alignment_dropdown],
                 )

+            with gr.Column():
                 result = gr.Image(
                     interactive=False,
                     label="Generated Image",
                     format="png",
                 )
+                history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)

+        # Event handlers
+        model_selection.change(fn=load_model, inputs=model_selection, outputs=status_text)
         target_ratio.change(
             fn=preload_presets,
             inputs=[target_ratio, width_slider, height_slider],
             outputs=[width_slider, height_slider, settings_panel],
             queue=False
         )
         width_slider.change(
             fn=select_the_right_preset,
             inputs=[width_slider, height_slider],
             outputs=[target_ratio],
             queue=False
         )
         resize_option.change(
             fn=toggle_custom_resize_slider,
             inputs=[resize_option],
             outputs=[custom_resize_percentage],
             queue=False
         )
         run_button.click(
             fn=clear_result,
             inputs=None,
+            outputs=result,
         ).then(
             fn=infer,
+            inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
+                    resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
+                    overlap_left, overlap_right, overlap_top, overlap_bottom],
+            outputs=result,
         ).then(
+            fn=lambda x, history: update_history(x, history),
+            inputs=[result, history_gallery],
+            outputs=history_gallery,
         )
         prompt_input.submit(
+            fn=clear_result,
             inputs=None,
+            outputs=result,
         ).then(
             fn=infer,
+            inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
+                    resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
+                    overlap_left, overlap_right, overlap_top, overlap_bottom],
+            outputs=result,
         ).then(
+            fn=lambda x, history: update_history(x, history),
             inputs=[result, history_gallery],
+            outputs=history_gallery,
         )
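
Both triggers rely on Gradio's event chaining: each .then() runs only after the previous step completes, so the result is cleared, regenerated, and only then prepended to the history gallery. A minimal self-contained sketch of the same pattern (hypothetical components, not from this commit):

    import gradio as gr

    with gr.Blocks() as chain_demo:
        btn = gr.Button("Go")
        out = gr.Textbox()
        log = gr.Textbox()
        btn.click(
            fn=lambda: "", inputs=None, outputs=out              # step 1: clear
        ).then(
            fn=lambda: "result", inputs=None, outputs=out        # step 2: compute
        ).then(
            fn=lambda r: f"saved: {r}", inputs=out, outputs=log  # step 3: record
        )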
+        demo.load(fn=load_model, inputs=model_selection, outputs=status_text)

 demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)