prithivMLmods committed
Commit 4074d29 · verified · 1 Parent(s): 03b41ea

Update app.py

Files changed (1):
  1. app.py +580 -365
app.py CHANGED
@@ -3,51 +3,77 @@ import spaces
import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
- # Remove ImageSlider import
- # from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download

- from controlnet_union import ControlNetModel_Union
- from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline

from PIL import Image, ImageDraw
import numpy as np

- # --- Model Loading (Unchanged) ---
- config_file = hf_hub_download(
-     "xinsir/controlnet-union-sdxl-1.0",
-     filename="config_promax.json",
- )
-
- config = ControlNetModel_Union.load_config(config_file)
- controlnet_model = ControlNetModel_Union.from_config(config)
- model_file = hf_hub_download(
-     "xinsir/controlnet-union-sdxl-1.0",
-     filename="diffusion_pytorch_model_promax.safetensors",
- )
-
- sstate_dict = load_state_dict(model_file)
- model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
-     controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
- )
- model.to(device="cuda", dtype=torch.float16)
- #----------------------
-
- vae = AutoencoderKL.from_pretrained(
-     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
- ).to("cuda")
-
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
-     "SG161222/RealVisXL_V5.0_Lightning",
-     torch_dtype=torch.float16,
-     vae=vae,
-     controlnet=model,
-     variant="fp16",
- ).to("cuda")
-
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
-
- # --- Helper Functions (Mostly Unchanged) ---
def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Checks if the image can be expanded based on the alignment."""
    if alignment in ("Left", "Right") and source_width >= target_width:
@@ -57,211 +83,305 @@ def can_expand(source_width, source_height, target_width, target_height, alignment):
    return True

def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
-     target_size = (width, height)
-
-     # Calculate the scaling factor to fit the image within the target size
-     scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
-     new_width = int(image.width * scale_factor)
-     new_height = int(image.height * scale_factor)
-
-     # Resize the source image to fit within target size
-     source = image.resize((new_width, new_height), Image.LANCZOS)
-
-     # Apply resize option using percentages
-     if resize_option == "Full":
-         resize_percentage = 100
-     elif resize_option == "50%":
-         resize_percentage = 50
-     elif resize_option == "33%":
-         resize_percentage = 33
-     elif resize_option == "25%":
-         resize_percentage = 25
-     else: # Custom
-         resize_percentage = custom_resize_percentage
-
-     # Calculate new dimensions based on percentage
-     resize_factor = resize_percentage / 100
-     new_width = int(source.width * resize_factor)
-     new_height = int(source.height * resize_factor)
-
-     # Ensure minimum size of 64 pixels
-     new_width = max(new_width, 64)
-     new_height = max(new_height, 64)
-
-     # Resize the image
-     source = source.resize((new_width, new_height), Image.LANCZOS)
-
-     # Calculate the overlap in pixels based on the percentage
-     overlap_x = int(new_width * (overlap_percentage / 100))
-     overlap_y = int(new_height * (overlap_percentage / 100))
-
-     # Ensure minimum overlap of 1 pixel
-     overlap_x = max(overlap_x, 1)
-     overlap_y = max(overlap_y, 1)
-
-     # Calculate margins based on alignment
-     if alignment == "Middle":
-         margin_x = (target_size[0] - new_width) // 2
-         margin_y = (target_size[1] - new_height) // 2
-     elif alignment == "Left":
-         margin_x = 0
-         margin_y = (target_size[1] - new_height) // 2
-     elif alignment == "Right":
-         margin_x = target_size[0] - new_width
-         margin_y = (target_size[1] - new_height) // 2
-     elif alignment == "Top":
-         margin_x = (target_size[0] - new_width) // 2
-         margin_y = 0
-     elif alignment == "Bottom":
-         margin_x = (target_size[0] - new_width) // 2
-         margin_y = target_size[1] - new_height
-
-     # Adjust margins to eliminate gaps
-     margin_x = max(0, min(margin_x, target_size[0] - new_width))
-     margin_y = max(0, min(margin_y, target_size[1] - new_height))
-
-     # Create a new background image and paste the resized source image
-     background = Image.new('RGB', target_size, (255, 255, 255))
-     background.paste(source, (margin_x, margin_y))
-
-     # Create the mask
-     mask = Image.new('L', target_size, 255)
-     mask_draw = ImageDraw.Draw(mask)
-
-     # Calculate overlap areas
-     white_gaps_patch = 2 # Pixels to leave unmasked at edges if overlap is disabled for that edge
-
-     left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
-     right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
-     top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
-     bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
-
-     # Adjust overlap boundaries based on alignment when specific overlap directions are *disabled*
-     # This prevents unmasking the absolute edge of the canvas in alignment modes
-     if alignment == "Left":
-         left_overlap = margin_x + overlap_x if overlap_left else margin_x # Keep edge masked if alignment is left
-     elif alignment == "Right":
-         right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width # Keep edge masked
-     elif alignment == "Top":
-         top_overlap = margin_y + overlap_y if overlap_top else margin_y # Keep edge masked
-     elif alignment == "Bottom":
-         bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height # Keep edge masked
-
-     # Ensure coordinates are within bounds
-     left_overlap = max(0, left_overlap)
-     top_overlap = max(0, top_overlap)
-     right_overlap = min(target_size[0], right_overlap)
-     bottom_overlap = min(target_size[1], bottom_overlap)
-
-     # Draw the mask (black rectangle for the area to keep)
-     if right_overlap > left_overlap and bottom_overlap > top_overlap:
-         mask_draw.rectangle([
-             (left_overlap, top_overlap),
-             (right_overlap, bottom_overlap)
-         ], fill=0) # 0 means keep this area (not masked for inpainting)
-
-     # Invert the mask: White areas (255) will be inpainted. Black (0) is kept.
-     mask = Image.fromarray(255 - np.array(mask))
-
-     return background, mask

def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
-     background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
-
-     # Create a preview image showing the mask
-     preview = background.copy().convert('RGBA')
-
-     # Create a semi-transparent red overlay for the masked (inpainting) area
-     red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 100)) # 100 alpha (~40% opacity)
-
-     # The mask is now white (255) where inpainting happens. Use this directly.
-     preview.paste(red_overlay, (0, 0), mask)
-
-     return preview
-
- @spaces.GPU(duration=24)
- def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
-     background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
-
-     # Ensure alignment allows expansion, default to Middle if not
-     source_w, source_h = background.size # Use background size after initial resize/placement
-     target_w, target_h = width, height
-     if alignment in ("Left", "Right") and source_w >= target_w:
-         print(f"Warning: Source width ({source_w}) >= target width ({target_w}) with {alignment} alignment. Forcing Middle alignment.")
-         alignment = "Middle"
-         # Re-prepare mask/background with corrected alignment if needed (optional, depends if prepare func uses alignment early)
-         # background, mask = prepare_image_and_mask(...) # If needed
-     if alignment in ("Top", "Bottom") and source_h >= target_h:
-         print(f"Warning: Source height ({source_h}) >= target height ({target_h}) with {alignment} alignment. Forcing Middle alignment.")
-         alignment = "Middle"
-         # Re-prepare mask/background with corrected alignment if needed
-         # background, mask = prepare_image_and_mask(...) # If needed
-
-     # Image for ControlNet input (masked original content)
-     # The pipeline expects the original image content in the non-masked area
-     cnet_image = background.copy()
-     # The pipeline's `image` argument is the *initial* content for the *masked* area (often noise, but here we provide the background)
-     # The `mask_image` tells the pipeline *where* to perform the inpainting/outpainting.
-     # The controlnet `image` needs the original content visible in the non-masked area.
-     # ControlNet Union seems to work well by just passing the background with the source image pasted.
-
-     final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
-
-     (
-         prompt_embeds,
-         negative_prompt_embeds,
-         pooled_prompt_embeds,
-         negative_pooled_prompt_embeds,
-     ) = pipe.encode_prompt(final_prompt, "cuda", True)
-
-     # The pipeline call
-     # Note: The pipeline expects `image` (initial state for masked area) and `mask_image`
-     # The `control_image` is implicitly handled by the ControlNet attached to the pipeline
-     output_image = pipe(
-         prompt_embeds=prompt_embeds,
-         negative_prompt_embeds=negative_prompt_embeds,
-         pooled_prompt_embeds=pooled_prompt_embeds,
-         negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-         image=background, # Provide the initial canvas state
-         mask_image=mask, # Provide the mask (white is area to change)
-         control_image=cnet_image, # Pass the control image explicitly if needed by pipeline logic
-         num_inference_steps=num_inference_steps,
-         output_type="pil" # Ensure PIL output
-     ).images[0]
-
-     # The pipeline should have already handled the compositing based on the mask
-     # If not, uncomment the paste operation below:
-     # final_image = background.copy().convert("RGBA") # Start with original background
-     # output_image = output_image.convert("RGBA")
-     # mask_rgba = mask.convert('L').point(lambda p: 255 if p > 128 else 0) # Ensure mask is binary 0/255
-     # final_image.paste(output_image, (0, 0), mask_rgba) # Paste generated content using the mask
-
-     # Return the single final image
-     return output_image
-
-
- def clear_result():
-     """Clears the result Image component."""
-     return gr.update(value=None)
-
- # --- UI Helper Functions (Unchanged) ---

def preload_presets(target_ratio, ui_width, ui_height):
    """Updates the width and height sliders based on the selected aspect ratio."""
    if target_ratio == "9:16":
        changed_width = 720
        changed_height = 1280
-         return changed_width, changed_height, gr.update() # Close accordion
    elif target_ratio == "16:9":
        changed_width = 1280
        changed_height = 720
-         return changed_width, changed_height, gr.update() # Close accordion
    elif target_ratio == "1:1":
        changed_width = 1024
        changed_height = 1024
-         return changed_width, changed_height, gr.update() # Close accordion
    elif target_ratio == "Custom":
-         # Don't change sliders, just open accordion
-         return ui_width, ui_height, gr.update(open=True)

def select_the_right_preset(user_width, user_height):
    """Updates the radio button based on the current slider values."""
@@ -280,172 +400,254 @@ def toggle_custom_resize_slider(resize_option):

def update_history(new_image, history):
    """Updates the history gallery with the new image."""
    if history is None:
        history = []
-     # Ensure new_image is a PIL Image before adding
-     if isinstance(new_image, Image.Image):
-         history.insert(0, new_image)
    return history

# --- Gradio UI Definition ---
css = """
.gradio-container {
-     width: 1200px !important;
    margin: auto !important; /* Center the container */
}
- h1 { text-align: center; }
- footer { visibility: hidden; }
/* Ensure result image takes reasonable space */
#result-image img {
    max-height: 768px; /* Adjust max height as needed */
    object-fit: contain;
-     width: auto;
    height: auto;
}
#history-gallery .thumbnail-item { /* Style history items */
    height: 100px !important;
}
#history-gallery .gallery {
    grid-template-rows: repeat(auto-fill, 100px) !important;
}

"""

- title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>"""

- with gr.Blocks(css=css) as demo:
-     with gr.Column():
-         gr.HTML(title)

-         with gr.Row():
-             with gr.Column(scale=1): # Left column for inputs
-                 input_image = gr.Image(
-                     type="pil",
-                     label="Input Image",
-                     height=400 # Give input image reasonable height
                )

                with gr.Row():
-                     with gr.Column(scale=2):
-                         prompt_input = gr.Textbox(label="Prompt (Optional)", placeholder="Describe the scene to expand...")
-                     with gr.Column(scale=1):
-                         run_button = gr.Button("Generate", variant="primary") # Make primary

-                 with gr.Row():
-                     target_ratio = gr.Radio(
-                         label="Target Ratio",
-                         choices=["9:16", "16:9", "1:1", "Custom"],
-                         value="9:16",
-                         scale=2
                    )

-                     alignment_dropdown = gr.Dropdown(
-                         choices=["Middle", "Left", "Right", "Top", "Bottom"],
-                         value="Middle",
-                         label="Align Source Image"
-                     )

-                 with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
-                     with gr.Row():
-                         width_slider = gr.Slider(
-                             label="Target Width",
-                             minimum=512, # Lowered minimum slightly
-                             maximum=2048, # Increased maximum slightly
-                             step=64, # Use steps of 64 common for SD
-                             value=720,
-                         )
-                         height_slider = gr.Slider(
-                             label="Target Height",
-                             minimum=512,
-                             maximum=2048,
-                             step=64,
-                             value=1280,
-                         )
-                     num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=12, step=1, value=4) # TCD/Lightning allows few steps
-
-                     with gr.Group():
-                         overlap_percentage = gr.Slider(
-                             label="Mask overlap (%)",
-                             minimum=1,
-                             maximum=50,
-                             value=12, # Default overlap
-                             step=1
-                         )
-                         with gr.Row():
-                             overlap_top = gr.Checkbox(label="Top", value=True)
-                             overlap_right = gr.Checkbox(label="Right", value=True)
-                             overlap_bottom = gr.Checkbox(label="Bottom", value=True)
-                             overlap_left = gr.Checkbox(label="Left", value=True)
-
-                     with gr.Row():
-                         resize_option = gr.Radio(
-                             label="Resize input within target",
-                             choices=["Full", "50%", "33%", "25%", "Custom"],
-                             value="Full"
-                         )
-                         custom_resize_percentage = gr.Slider(
-                             label="Custom resize (%)",
-                             minimum=1,
-                             maximum=100,
-                             step=1,
-                             value=50,
-                             visible=False # Initially hidden
-                         )
-
-                 preview_button = gr.Button("Preview Mask & Alignment")
-                 preview_image = gr.Image(label="Mask Preview (Red = Outpaint Area)", type="pil", interactive=False)

                gr.Examples(
-                     examples=[
-                         ["./examples/example_1.webp", "A wide landscape view of the mountains", 1280, 720, "Middle"],
-                         ["./examples/example_2.jpg", "Full body shot of the astronaut on the moon", 720, 1280, "Middle"],
-                         ["./examples/example_3.jpg", "Expanding the sky and ground around the subject", 1024, 1024, "Middle"],
-                         ["./examples/example_3.jpg", "Expanding downwards from the subject", 1024, 1024, "Top"], # Align subject Top
-                         ["./examples/example_3.jpg", "Expanding upwards from the subject", 1024, 1024, "Bottom"], # Align subject Bottom
-                     ],
                    inputs=[input_image, prompt_input, width_slider, height_slider, alignment_dropdown],
-                     label="Examples (Click to load)"
                )

-             with gr.Column(scale=1): # Right column for output
-                 # Replace ImageSlider with gr.Image
-                 result = gr.Image(label="Generated Image", type="pil", interactive=False, elem_id="result-image")
-                 use_as_input_button = gr.Button("Use Result as Input Image", visible=False) # Initially hidden

-                 history_gallery = gr.Gallery(
-                     label="History",
-                     columns=6,
-                     object_fit="contain",
-                     interactive=False,
-                     height=110, # Fixed height for the row
-                     elem_id="history-gallery"
-                 )

    # --- Event Handling ---

-     def use_output_as_input(output_image):
-         """Sets the generated output as the new input image."""
-         # output_image is now the single final image from gr.Image
-         return gr.update(value=output_image)

    use_as_input_button.click(
-         fn=use_output_as_input,
-         inputs=[result], # Input is the result image component
-         outputs=[input_image] # Output updates the input image component
    )

    target_ratio.change(
        fn=preload_presets,
        inputs=[target_ratio, width_slider, height_slider],
-         outputs=[width_slider, height_slider, settings_panel], # Also control accordion state
        queue=False
    )

-     # Link sliders back to the ratio selector
    width_slider.change(
        fn=select_the_right_preset,
        inputs=[width_slider, height_slider],
@@ -472,58 +674,71 @@ with gr.Blocks(css=css) as demo:
        resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
        overlap_left, overlap_right, overlap_top, overlap_bottom
    ]

-     # Chain generation logic
-     run_button.click(
-         fn=clear_result,
-         inputs=None,
-         outputs=[result], # Clear the single image output
-         queue=False # Run clearing immediately
    ).then(
        fn=infer,
        inputs=gen_inputs,
-         outputs=[result], # Output the single image to the result component
-     ).then(
-         # Update history with the single result image
        fn=lambda res_img, hist: update_history(res_img, hist),
        inputs=[result, history_gallery],
        outputs=[history_gallery],
-         queue=False # Update history immediately after generation
    ).then(
-         # Show the 'Use as Input' button
-         fn=lambda: gr.update(visible=True),
-         inputs=None,
        outputs=[use_as_input_button],
        queue=False # Show button immediately
    )

-     prompt_input.submit(
        fn=clear_result,
-         inputs=None,
-         outputs=[result],
        queue=False
    ).then(
        fn=infer,
        inputs=gen_inputs,
-         outputs=[result],
-     ).then(
        fn=lambda res_img, hist: update_history(res_img, hist),
        inputs=[result, history_gallery],
        outputs=[history_gallery],
        queue=False
    ).then(
-         fn=lambda: gr.update(visible=True),
-         inputs=None,
        outputs=[use_as_input_button],
        queue=False
    )

    preview_button.click(
        fn=preview_image_and_mask,
-         inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
-                 overlap_left, overlap_right, overlap_top, overlap_bottom],
        outputs=preview_image,
-         queue=False # Preview should be fast
    )

- demo.queue(max_size=10).launch(ssr_mode=False, show_error=True) # Removed share=False for potential Hugging Face Spaces use

import torch
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
+ # Removed ImageSlider import
from huggingface_hub import hf_hub_download

+ # Ensure these custom modules are accessible in the environment
+ # If running locally, they should be in the same directory or installed
+ try:
+     from controlnet_union import ControlNetModel_Union
+     from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
+ except ImportError as e:
+     print(f"Error importing custom modules: {e}")
+     print("Please ensure 'controlnet_union.py' and 'pipeline_fill_sd_xl.py' are in the working directory or installed.")
+     # Optionally, try installing if running in a suitable environment
+     # import os
+     # os.system("pip install git+https://github.com/UNION-AI-Research/FILL-Context-Aware-Outpainting.git") # Or wherever the package is hosted
+     # Re-try import might be needed depending on environment setup
+     exit()
+

from PIL import Image, ImageDraw
import numpy as np
+ import os # For checking example files
+
+ # --- Model Loading ---
+ # Use environment variable for model cache if needed
+ # HUGGINGFACE_HUB_CACHE = os.environ.get("HUGGINGFACE_HUB_CACHE", None)
+
+ try:
+     config_file = hf_hub_download(
+         "xinsir/controlnet-union-sdxl-1.0",
+         filename="config_promax.json",
+         # cache_dir=HUGGINGFACE_HUB_CACHE
+     )
+
+     config = ControlNetModel_Union.load_config(config_file)
+     controlnet_model = ControlNetModel_Union.from_config(config)
+     model_file = hf_hub_download(
+         "xinsir/controlnet-union-sdxl-1.0",
+         filename="diffusion_pytorch_model_promax.safetensors",
+         # cache_dir=HUGGINGFACE_HUB_CACHE
+     )

+     sstate_dict = load_state_dict(model_file)
+     model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
+         controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
+     )
+     model.to(device="cuda", dtype=torch.float16)
+     print("ControlNet loaded successfully.")
+
+     vae = AutoencoderKL.from_pretrained(
+         "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, # cache_dir=HUGGINGFACE_HUB_CACHE
+     ).to("cuda")
+     print("VAE loaded successfully.")
+
+     pipe = StableDiffusionXLFillPipeline.from_pretrained(
+         "SG161222/RealVisXL_V5.0_Lightning",
+         torch_dtype=torch.float16,
+         vae=vae,
+         controlnet=model,
+         variant="fp16",
+         # cache_dir=HUGGINGFACE_HUB_CACHE
+     ).to("cuda")
+     print("Pipeline loaded successfully.")
+
+     pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
+     print("Scheduler configured.")
+
+ except Exception as e:
+     print(f"Error during model loading: {e}")
+     raise e
+
+ # --- Helper Functions ---

def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Checks if the image can be expanded based on the alignment."""
    if alignment in ("Left", "Right") and source_width >= target_width:

    return True

def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
+     if image is None:
+         raise gr.Error("Input image not provided.")
+     try:
+         target_size = (width, height)
+
+         # Calculate the scaling factor to fit the image within the target size
+         scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
+         new_width = int(image.width * scale_factor)
+         new_height = int(image.height * scale_factor)
+
+         # Resize the source image to fit within target size
+         source = image.resize((new_width, new_height), Image.LANCZOS)
+
+         # Apply resize option using percentages
+         if resize_option == "Full":
+             resize_percentage = 100
+         elif resize_option == "50%":
+             resize_percentage = 50
+         elif resize_option == "33%":
+             resize_percentage = 33
+         elif resize_option == "25%":
+             resize_percentage = 25
+         elif resize_option == "Custom":
+             resize_percentage = custom_resize_percentage
+         else:
+             raise ValueError(f"Invalid resize option: {resize_option}")
+
+         # Calculate new dimensions based on percentage
+         resize_factor = resize_percentage / 100
+         new_width = int(source.width * resize_factor)
+         new_height = int(source.height * resize_factor)
+
+         # Ensure minimum size of 64 pixels
+         new_width = max(new_width, 64)
+         new_height = max(new_height, 64)
+
+         # Ensure dimensions fit within target (can happen if original image is tiny and resize % is large)
+         new_width = min(new_width, target_size[0])
+         new_height = min(new_height, target_size[1])
+
+         # Resize the image
+         source = source.resize((new_width, new_height), Image.LANCZOS)
+
+         # Calculate the overlap in pixels based on the percentage
+         overlap_x = int(new_width * (overlap_percentage / 100))
+         overlap_y = int(new_height * (overlap_percentage / 100))
+
+         # Ensure minimum overlap of 1 pixel if overlap is enabled, otherwise 0
+         overlap_x = max(overlap_x, 1) if overlap_left or overlap_right else 0
+         overlap_y = max(overlap_y, 1) if overlap_top or overlap_bottom else 0
+
+         # Calculate margins based on alignment
+         if alignment == "Middle":
+             margin_x = (target_size[0] - new_width) // 2
+             margin_y = (target_size[1] - new_height) // 2
+         elif alignment == "Left":
+             margin_x = 0
+             margin_y = (target_size[1] - new_height) // 2
+         elif alignment == "Right":
+             margin_x = target_size[0] - new_width
+             margin_y = (target_size[1] - new_height) // 2
+         elif alignment == "Top":
+             margin_x = (target_size[0] - new_width) // 2
+             margin_y = 0
+         elif alignment == "Bottom":
+             margin_x = (target_size[0] - new_width) // 2
+             margin_y = target_size[1] - new_height
+         else:
+             raise ValueError(f"Invalid alignment: {alignment}")
+
+         # Adjust margins to ensure image is fully within bounds (should be redundant with min check above)
+         margin_x = max(0, min(margin_x, target_size[0] - new_width))
+         margin_y = max(0, min(margin_y, target_size[1] - new_height))
+
+         # Create a new background image and paste the resized source image
+         background = Image.new('RGB', target_size, (255, 255, 255)) # White background
+         background.paste(source, (margin_x, margin_y))
+
+         # Create the mask (initially all black - meaning keep everything)
+         mask_np = np.zeros(target_size[::-1], dtype=np.uint8) # Use numpy for easier slicing [::-1] for (height, width)
+
+         # Calculate the coordinates of the *source image* area within the target canvas
+         source_left = margin_x
+         source_top = margin_y
+         source_right = margin_x + new_width
+         source_bottom = margin_y + new_height
+
+         # Calculate the coordinates of the *unmasked* area (area to keep from source)
+         unmasked_left = source_left + overlap_x if overlap_left else source_left
+         unmasked_top = source_top + overlap_y if overlap_top else source_top
+         unmasked_right = source_right - overlap_x if overlap_right else source_right
+         unmasked_bottom = source_bottom - overlap_y if overlap_bottom else source_bottom
+
+         # Special handling for edge alignments to ensure the edge itself is kept if overlap disabled
+         if alignment == "Left" and not overlap_left:
+             unmasked_left = source_left
+         if alignment == "Right" and not overlap_right:
+             unmasked_right = source_right
+         if alignment == "Top" and not overlap_top:
+             unmasked_top = source_top
+         if alignment == "Bottom" and not overlap_bottom:
+             unmasked_bottom = source_bottom
+
+         # Ensure coordinates are valid and clipped to the source image area within the canvas
+         unmasked_left = max(source_left, min(unmasked_left, source_right))
+         unmasked_top = max(source_top, min(unmasked_top, source_bottom))
+         unmasked_right = max(source_left, min(unmasked_right, source_right))
+         unmasked_bottom = max(source_top, min(unmasked_bottom, source_bottom))
+
+         # Create the final mask: White (255) = Area to inpaint/outpaint, Black (0) = Area to keep
+         final_mask_np = np.ones(target_size[::-1], dtype=np.uint8) * 255 # Start with all white (change everything)
+         if unmasked_right > unmasked_left and unmasked_bottom > unmasked_top:
+             # Set the area to keep (calculated unmasked rectangle) to black (0)
+             final_mask_np[unmasked_top:unmasked_bottom, unmasked_left:unmasked_right] = 0
+
+         mask = Image.fromarray(final_mask_np)
+
+         return background, mask
+     except Exception as e:
+         print(f"Error in prepare_image_and_mask: {e}")
+         raise gr.Error(f"Failed to prepare image and mask: {e}")
+

def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
+     if image is None:
+         return None # Or return a placeholder image/message
+     try:
+         background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
+
+         # Create a preview image showing the mask
+         preview = background.copy().convert('RGBA')
+
+         # Create a semi-transparent red overlay for the masked (inpainting/outpainting) area
+         red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 100)) # 100 alpha (~40% opacity)
+
+         # The mask is white (255) where outpainting happens. Use this directly.
+         preview.paste(red_overlay, (0, 0), mask) # Paste red where mask is white
+
+         return preview
+     except Exception as e:
+         print(f"Error during preview generation: {e}")
+         # Return the original background or an error placeholder
+         if 'background' in locals():
+             return background.convert('RGBA')
+         else:
+             return Image.new('RGBA', (width, height), (200, 200, 200, 255)) # Grey placeholder
+
+
+ @spaces.GPU(duration=60) # Adjusted duration slightly
+ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom, progress=gr.Progress(track_tqdm=True)):
+     if image is None:
+         raise gr.Error("Please provide an input image.")
+
+     try:
+         # --- Preparation ---
+         progress(0.1, desc="Preparing image and mask...")
+         original_alignment = alignment
+         background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
+
+         # --- Alignment Check & Correction ---
+         # Get dimensions *after* initial placement and resize
+         pasted_source_img_width = int(image.width * min(width / image.width, height / image.height) * (custom_resize_percentage if resize_option=='Custom' else {'Full':100, '50%':50, '33%':33, '25%':25}[resize_option])/100)
+         pasted_source_img_height = int(image.height * min(width / image.width, height / image.height) * (custom_resize_percentage if resize_option=='Custom' else {'Full':100, '50%':50, '33%':33, '25%':25}[resize_option])/100)
+         pasted_source_img_width = max(64, min(pasted_source_img_width, width))
+         pasted_source_img_height = max(64, min(pasted_source_img_height, height))
+
+         needs_reprepare = False
+         if alignment in ("Left", "Right") and pasted_source_img_width >= width:
+             print(f"Warning: Source width ({pasted_source_img_width}) >= target width ({width}) with {alignment} alignment. Forcing Middle alignment.")
+             alignment = "Middle"
+             needs_reprepare = True
+         if alignment in ("Top", "Bottom") and pasted_source_img_height >= height:
+             print(f"Warning: Source height ({pasted_source_img_height}) >= target height ({height}) with {alignment} alignment. Forcing Middle alignment.")
+             alignment = "Middle"
+             needs_reprepare = True
+
+         if needs_reprepare and alignment != original_alignment:
+             print("Re-preparing mask due to alignment change.")
+             progress(0.15, desc="Re-preparing mask for Middle alignment...")
+             background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
+
+         # ControlNet expects the image with the *original* content visible in the non-masked area
+         cnet_image = background.copy()
+         # In some ControlNet inpainting setups, you might mask the control image too,
+         # but Union ControlNet Fill often works well with the unmasked source pasted onto the background.
+         # cnet_image.paste(0, mask=ImageOps.invert(mask)) # Optional: Black out masked area in CNet image
+
+         # --- Prompt Encoding ---
+         progress(0.2, desc="Encoding prompt...")
+         final_prompt = f"{prompt_input}, high quality, 4k" if prompt_input else "high quality, 4k" # Add default tags if no prompt
+         negative_prompt = "low quality, blurry, noisy, text, words, letters, watermark, signature, username, artist name, deformed, distorted, disfigured, bad anatomy, extra limbs, missing limbs"
+
+         # Note: TCD/Lightning pipelines often work better *without* explicit negative prompts encoded
+         # Try encoding only the positive prompt first
+         (
+             prompt_embeds,
+             _, # negative_prompt_embeds (set to None or handle differently for TCD)
+             pooled_prompt_embeds,
+             _, # negative_pooled_prompt_embeds
+         ) = pipe.encode_prompt(final_prompt, "cuda", False) # do_classifier_free_guidance=False for TCD
+
+         # --- Inference ---
+         progress(0.3, desc="Starting diffusion process...")
+         print(f"Running inference with {num_inference_steps} steps...")
+         pipeline_output = pipe(
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=None, # Pass None for TCD/Lightning
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=None, # Pass None for TCD/Lightning
+             image=background, # Initial state for masked area (background with source)
+             mask_image=mask, # Mask (white = change)
+             control_image=cnet_image, # ControlNet input
+             num_inference_steps=num_inference_steps,
+             guidance_scale=0.0, # Crucial for TCD/Lightning
+             controlnet_conditioning_scale=0.8, # Default for FILL pipeline, adjust if needed
+             output_type="pil" # Ensure PIL output
+             # Add tqdm=True if supported by the custom pipeline and using gr.Progress without track_tqdm
+         )
+
+         # --- Process Output ---
+         progress(0.9, desc="Processing results...")
+         # Check if the pipeline returned a standard output object or a generator
+         output_image = None
+         if hasattr(pipeline_output, 'images'): # Standard diffusers output
+             print("Pipeline returned a standard output object.")
+             if len(pipeline_output.images) > 0:
+                 output_image = pipeline_output.images[0]
+             else:
+                 raise ValueError("Pipeline output contained no images.")
+         # Check if it's iterable (generator) - less likely with direct call and output_type='pil' but good practice
+         elif hasattr(pipeline_output, '__iter__') and not isinstance(pipeline_output, dict):
+             print("Pipeline returned a generator, iterating to get the final image.")
+             last_item = None
+             for item in pipeline_output:
+                 last_item = item
+             # Try to extract image from the last yielded item (structure can vary)
+             if isinstance(last_item, tuple) and len(last_item) > 0 and isinstance(last_item[0], Image.Image):
+                 output_image = last_item[0]
+             elif isinstance(last_item, dict) and 'images' in last_item and len(last_item['images']) > 0:
+                 output_image = last_item['images'][0]
+             elif isinstance(last_item, Image.Image):
+                 output_image = last_item
+             elif hasattr(last_item, 'images') and len(last_item.images) > 0: # Handle case where object yielded early
+                 output_image = last_item.images[0]
+
+             if output_image is None:
+                 raise ValueError("Pipeline generator did not yield a valid final image structure.")
+         else:
+             raise TypeError(f"Unexpected pipeline output type: {type(pipeline_output)}. Cannot extract image.")
+
+         print("Inference complete.")
+         progress(1.0, desc="Done!")
+         return output_image
+
+     except Exception as e:
+         print(f"Error during inference: {e}")
+         import traceback
+         traceback.print_exc() # Print full traceback to console/logs
+         raise gr.Error(f"Inference failed: {e}")
+
+
+ def clear_result(*args):
+     """Clears the result Image and related components."""
+     updates = {
+         result: gr.update(value=None),
+         use_as_input_button: gr.update(visible=False),
+     }
+     # If preview image is passed as an arg, clear it too
+     if len(args) > 0 and isinstance(args[0], gr.Image):
+         updates[args[0]] = gr.update(value=None) # Assuming preview_image is the first optional arg
+     return updates
+
+
+ # --- UI Helper Functions ---
def preload_presets(target_ratio, ui_width, ui_height):
    """Updates the width and height sliders based on the selected aspect ratio."""
+     settings_update = gr.update() # Default: no change to accordion state
    if target_ratio == "9:16":
        changed_width = 720
        changed_height = 1280
    elif target_ratio == "16:9":
        changed_width = 1280
        changed_height = 720
    elif target_ratio == "1:1":
        changed_width = 1024
        changed_height = 1024
    elif target_ratio == "Custom":
+         changed_width = ui_width # Keep current slider values
+         changed_height = ui_height
+         settings_update = gr.update(open=True) # Open accordion for custom
+     else: # Should not happen
+         changed_width = ui_width
+         changed_height = ui_height
+
+     return changed_width, changed_height, settings_update

def select_the_right_preset(user_width, user_height):
    """Updates the radio button based on the current slider values."""

def update_history(new_image, history):
    """Updates the history gallery with the new image."""
+     if not isinstance(new_image, Image.Image): # Don't add if generation failed (None)
+         return history or [] # Return current or empty list
+
    if history is None:
        history = []
+     history.insert(0, new_image)
+     # Limit history size (optional)
+     max_history = 12
+     if len(history) > max_history:
+         history = history[:max_history]
    return history

# --- Gradio UI Definition ---
css = """
.gradio-container {
+     max-width: 1200px !important; /* Use max-width for responsiveness */
    margin: auto !important; /* Center the container */
+     padding: 10px; /* Add some padding */
}
+ h1 { text-align: center; margin-bottom: 15px;}
+ footer { display: none !important; /* More reliable way to hide footer */ }
+
/* Ensure result image takes reasonable space */
#result-image img {
    max-height: 768px; /* Adjust max height as needed */
    object-fit: contain;
+     width: 100%; /* Allow image to use column width */
    height: auto;
+     display: block; /* Prevent extra space below image */
+     margin: auto; /* Center image within its container */
+ }
+ #input-image img {
+     max-height: 400px;
+     object-fit: contain;
+     width: 100%;
+     height: auto;
+     display: block;
+     margin: auto;
+ }
+ #preview-image img {
+     max-height: 250px; /* Smaller preview */
+     object-fit: contain;
+     width: 100%;
+     height: auto;
+     display: block;
+     margin: auto;
}
+
#history-gallery .thumbnail-item { /* Style history items */
    height: 100px !important;
+     overflow: hidden; /* Hide overflow */
}
#history-gallery .gallery {
    grid-template-rows: repeat(auto-fill, 100px) !important;
+     gap: 4px !important; /* Add small gap */
+ }
+ #history-gallery .thumbnail-item img {
+     object-fit: contain !important; /* Ensure history previews fit */
+     height: 100%;
+     width: 100%;
}

+ /* Make Checkboxes smaller and closer */
+ .gradio-checkboxgroup .wrap {
+     gap: 0.5rem 1rem !important; /* Adjust spacing */
+ }
+ .gradio-checkbox label span {
+     font-size: 0.9em; /* Slightly smaller label text */
+ }
+ .gradio-checkbox input {
+     transform: scale(0.9); /* Slightly smaller checkbox */
+ }
+
+ /* Style Accordion */
+ .gradio-accordion .label-wrap { /* Target the label wrapper */
+     border: 1px solid #e0e0e0;
+     border-radius: 5px;
+     padding: 8px 12px;
+     background-color: #f9f9f9;
+ }
"""

+ title = """<h1 align="center">🖼️ Diffusers Image Outpaint Lightning ⚡</h1>"""

+ # --- Example Files Handling ---
+ # Create examples directory if it doesn't exist
+ if not os.path.exists("./examples"):
+     os.makedirs("./examples")

+ # Check for example images and provide defaults or placeholders if missing
+ example_files = {
+     "ex1": "./examples/example_1.webp",
+     "ex2": "./examples/example_2.jpg",
+     "ex3": "./examples/example_3.jpg"
+ }
+ default_image_path = None # Will be set to the first available example
+
+ # You might want to download example images if they don't exist
+ # from huggingface_hub import hf_hub_download
+ # def download_example(repo_id, filename, local_path):
+ #     if not os.path.exists(local_path):
+ #         try:
+ #             hf_hub_download(repo_id=repo_id, filename=filename, local_dir="./examples", local_dir_use_symlinks=False)
+ #             print(f"Downloaded {filename}")
+ #         except Exception as e:
+ #             print(f"Failed to download example {filename}: {e}")
+ #             return False # Indicate failure
+ #     return os.path.exists(local_path)
+
+ # Example: download_example("path/to/your/example-repo", "example_1.webp", example_files["ex1"])
+ # For now, we just check existence
+
+ examples_available = {key: os.path.exists(path) for key, path in example_files.items()}
+
+ example_list = []
+ if examples_available["ex1"]:
+     example_list.append([example_files["ex1"], "A wide landscape view of the mountains", 1280, 720, "Middle"])
+     if default_image_path is None: default_image_path = example_files["ex1"]
+ if examples_available["ex2"]:
+     example_list.append([example_files["ex2"], "Full body shot of the astronaut on the moon", 720, 1280, "Middle"])
+     if default_image_path is None: default_image_path = example_files["ex2"]
+ if examples_available["ex3"]:
+     example_list.append([example_files["ex3"], "Expanding the sky and ground around the subject", 1024, 1024, "Middle"])
+     example_list.append([example_files["ex3"], "Expanding downwards from the subject", 1024, 1024, "Top"])
+     example_list.append([example_files["ex3"], "Expanding upwards from the subject", 1024, 1024, "Bottom"])
+     if default_image_path is None: default_image_path = example_files["ex3"]
+
+ if not example_list:
+     print("Warning: No example images found in ./examples/. Examples section will be empty.")
+     # Optionally create a placeholder image
+     # placeholder = Image.new('RGB', (512, 512), color = 'grey')
+     # placeholder_path = "./examples/placeholder.png"
+     # placeholder.save(placeholder_path)
+     # example_list.append([placeholder_path, "Placeholder", 1024, 1024, "Middle"])
+     # default_image_path = placeholder_path
+
+ # --- UI ---
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added a theme
+     gr.HTML(title)
+
+     with gr.Row():
+         with gr.Column(scale=1): # Left column for inputs
+             input_image = gr.Image(
+                 value=default_image_path, # Load default example
+                 type="pil",
+                 label="Input Image",
+                 elem_id="input-image"
+             )
+
+             prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the scene to expand (optional but recommended)...", lines=2)
+
+             with gr.Row():
+                 target_ratio = gr.Radio(
+                     label="Target Aspect Ratio",
+                     choices=["9:16", "16:9", "1:1", "Custom"],
+                     value="9:16",
+                     scale=2
+                 )
+                 alignment_dropdown = gr.Dropdown(
+                     choices=["Middle", "Left", "Right", "Top", "Bottom"],
+                     value="Middle",
+                     label="Align Source Image",
+                     scale=1
                )

+             with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
                with gr.Row():
+                     width_slider = gr.Slider(
+                         label="Target Width", minimum=512, maximum=2048, step=64, value=720
+                     )
+                     height_slider = gr.Slider(
+                         label="Target Height", minimum=512, maximum=2048, step=64, value=1280
+                     )
+                     num_inference_steps = gr.Slider(
+                         label="Steps (TCD/Lightning: 1-8)", minimum=1, maximum=12, step=1, value=4
+                     )

+                 with gr.Group():
+                     overlap_percentage = gr.Slider(
+                         label="Mask Overlap with Source (%)", minimum=0, maximum=50, value=12, step=1
                    )
+                     gr.Markdown("Select edges to overlap:", scale=0) # Add context
+                     with gr.Row(elem_classes="gradio-checkboxgroup"): # Apply CSS class
+                         overlap_top = gr.Checkbox(label="Top", value=True, scale=1)
+                         overlap_bottom = gr.Checkbox(label="Bottom", value=True, scale=1)
+                         overlap_left = gr.Checkbox(label="Left", value=True, scale=1)
+                         overlap_right = gr.Checkbox(label="Right", value=True, scale=1)

+                 with gr.Row():
+                     resize_option = gr.Radio(
+                         label="Resize source within target",
+                         choices=["Full", "50%", "33%", "25%", "Custom"],
+                         value="Full",
+                         scale=2
+                     )
+                     custom_resize_percentage = gr.Slider(
+                         label="Custom resize (%)", minimum=1, maximum=100, step=1, value=50, visible=False, scale=1
+                     )

+             preview_button = gr.Button("Preview Mask & Alignment")
+             preview_image = gr.Image(label="Mask Preview (Red = Outpaint Area)", type="pil", interactive=False, elem_id="preview-image")

+             if example_list:
                gr.Examples(
+                     examples=example_list,
                    inputs=[input_image, prompt_input, width_slider, height_slider, alignment_dropdown],
+                     label="Examples (Click to load)",
+                     examples_per_page=10
                )
+             else:
+                 gr.Markdown("_(No example files found in ./examples)_")

+             run_button = gr.Button("Generate", variant="primary")

+         with gr.Column(scale=1): # Right column for output
+             result = gr.Image(label="Generated Image", type="pil", interactive=False, elem_id="result-image")
+             use_as_input_button = gr.Button("Use Result as Input Image", visible=False)

+             history_gallery = gr.Gallery(
+                 label="History", columns=6, object_fit="contain", interactive=False,
+                 height=110, elem_id="history-gallery"
+             )

    # --- Event Handling ---

+     # Function to set result as input and clear result area
+     def use_output_as_input_and_clear(output_image):
+         return {
+             input_image: gr.update(value=output_image),
+             result: gr.update(value=None), # Clear result after using it
+             use_as_input_button: gr.update(visible=False) # Hide button again
+         }

    use_as_input_button.click(
+         fn=use_output_as_input_and_clear,
+         inputs=[result],
+         outputs=[input_image, result, use_as_input_button]
    )

    target_ratio.change(
        fn=preload_presets,
        inputs=[target_ratio, width_slider, height_slider],
+         outputs=[width_slider, height_slider, settings_panel],
        queue=False
    )

    width_slider.change(
        fn=select_the_right_preset,
        inputs=[width_slider, height_slider],

        resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
        overlap_left, overlap_right, overlap_top, overlap_bottom
    ]
+     gen_outputs = [result] # Single output image

+     # Chain generation logic for Run button
+     run_trigger = run_button.click(
+         fn=clear_result, # Clear previous result first
+         inputs=[], # No inputs needed for clear
+         outputs=[result, use_as_input_button], # Components to clear/hide
+         queue=False
    ).then(
        fn=infer,
        inputs=gen_inputs,
+         outputs=gen_outputs,
+     )
+
+     # After generation finishes (successfully or not), update history and button visibility
+     run_trigger.then(
        fn=lambda res_img, hist: update_history(res_img, hist),
        inputs=[result, history_gallery],
        outputs=[history_gallery],
+         queue=False # Update history immediately
    ).then(
+         # Show the 'Use as Input' button only if generation was successful (result is not None)
+         fn=lambda res_img: gr.update(visible=isinstance(res_img, Image.Image)),
+         inputs=[result],
        outputs=[use_as_input_button],
        queue=False # Show button immediately
    )

+     # Chain generation logic for Enter key in Prompt textbox
+     submit_trigger = prompt_input.submit(
        fn=clear_result,
+         inputs=[],
+         outputs=[result, use_as_input_button],
        queue=False
    ).then(
        fn=infer,
        inputs=gen_inputs,
+         outputs=gen_outputs,
+     )
+
+     submit_trigger.then(
        fn=lambda res_img, hist: update_history(res_img, hist),
        inputs=[result, history_gallery],
        outputs=[history_gallery],
        queue=False
    ).then(
+         fn=lambda res_img: gr.update(visible=isinstance(res_img, Image.Image)),
+         inputs=[result],
        outputs=[use_as_input_button],
        queue=False
    )

+     # Preview button logic
+     preview_inputs = [
+         input_image, width_slider, height_slider, overlap_percentage, resize_option,
+         custom_resize_percentage, alignment_dropdown, overlap_left, overlap_right,
+         overlap_top, overlap_bottom
+     ]
    preview_button.click(
        fn=preview_image_and_mask,
+         inputs=preview_inputs,
        outputs=preview_image,
+         queue=False
    )

+ # Launch the interface
+ demo.queue(max_size=10).launch(ssr_mode=False, show_error=True, debug=True) # Add debug=True for more logs
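
A quick illustration, not part of the commit, of the mask convention the new prepare_image_and_mask builds with NumPy and that preview_image_and_mask and the pipeline's mask_image consume: white (255) marks pixels to outpaint, black (0) marks pixels to keep. The canvas size and keep-rectangle below are made-up toy values; only the convention is taken from the diff.

# Minimal sketch of the mask convention used by the updated code (toy sizes, assumed values).
import numpy as np
from PIL import Image

target_w, target_h = 8, 6            # toy canvas size (width, height)
keep_box = (2, 1, 6, 5)              # (left, top, right, bottom) of the area to keep

mask_np = np.ones((target_h, target_w), dtype=np.uint8) * 255   # start all white: everything gets outpainted
left, top, right, bottom = keep_box
mask_np[top:bottom, left:right] = 0                              # black out the rectangle to preserve

mask = Image.fromarray(mask_np)      # single-channel 'L' mask, same shape convention as the diff
assert mask.getpixel((0, 0)) == 255  # outside the box: will be generated
assert mask.getpixel((3, 2)) == 0    # inside the box: source pixels are kept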