prithivMLmods commited on
Commit
e01e01c
·
verified ·
1 Parent(s): 4074d29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +404 -602
app.py CHANGED
@@ -1,79 +1,53 @@
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
  from diffusers import AutoencoderKL, TCDScheduler
5
  from diffusers.models.model_loading_utils import load_state_dict
6
- # Removed ImageSlider import
 
7
  from huggingface_hub import hf_hub_download
8
 
9
- # Ensure these custom modules are accessible in the environment
10
- # If running locally, they should be in the same directory or installed
11
- try:
12
- from controlnet_union import ControlNetModel_Union
13
- from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
14
- except ImportError as e:
15
- print(f"Error importing custom modules: {e}")
16
- print("Please ensure 'controlnet_union.py' and 'pipeline_fill_sd_xl.py' are in the working directory or installed.")
17
- # Optionally, try installing if running in a suitable environment
18
- # import os
19
- # os.system("pip install git+https://github.com/UNION-AI-Research/FILL-Context-Aware-Outpainting.git") # Or wherever the package is hosted
20
- # Re-try import might be needed depending on environment setup
21
- exit()
22
-
23
 
24
  from PIL import Image, ImageDraw
25
  import numpy as np
26
- import os # For checking example files
27
-
28
- # --- Model Loading ---
29
- # Use environment variable for model cache if needed
30
- # HUGGINGFACE_HUB_CACHE = os.environ.get("HUGGINGFACE_HUB_CACHE", None)
31
 
32
- try:
33
- config_file = hf_hub_download(
34
- "xinsir/controlnet-union-sdxl-1.0",
35
- filename="config_promax.json",
36
- # cache_dir=HUGGINGFACE_HUB_CACHE
37
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- config = ControlNetModel_Union.load_config(config_file)
40
- controlnet_model = ControlNetModel_Union.from_config(config)
41
- model_file = hf_hub_download(
42
- "xinsir/controlnet-union-sdxl-1.0",
43
- filename="diffusion_pytorch_model_promax.safetensors",
44
- # cache_dir=HUGGINGFACE_HUB_CACHE
45
- )
46
-
47
- sstate_dict = load_state_dict(model_file)
48
- model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
49
- controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
50
- )
51
- model.to(device="cuda", dtype=torch.float16)
52
- print("ControlNet loaded successfully.")
53
-
54
- vae = AutoencoderKL.from_pretrained(
55
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, # cache_dir=HUGGINGFACE_HUB_CACHE
56
- ).to("cuda")
57
- print("VAE loaded successfully.")
58
-
59
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
60
- "SG161222/RealVisXL_V5.0_Lightning",
61
- torch_dtype=torch.float16,
62
- vae=vae,
63
- controlnet=model,
64
- variant="fp16",
65
- # cache_dir=HUGGINGFACE_HUB_CACHE
66
- ).to("cuda")
67
- print("Pipeline loaded successfully.")
68
-
69
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
70
- print("Scheduler configured.")
71
-
72
- except Exception as e:
73
- print(f"Error during model loading: {e}")
74
- raise e
75
-
76
- # --- Helper Functions ---
77
  def can_expand(source_width, source_height, target_width, target_height, alignment):
78
  """Checks if the image can be expanded based on the alignment."""
79
  if alignment in ("Left", "Right") and source_width >= target_width:
@@ -83,308 +57,212 @@ def can_expand(source_width, source_height, target_width, target_height, alignme
83
  return True
84
 
85
  def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
86
- if image is None:
87
- raise gr.Error("Input image not provided.")
88
- try:
89
- target_size = (width, height)
90
-
91
- # Calculate the scaling factor to fit the image within the target size
92
- scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
93
- new_width = int(image.width * scale_factor)
94
- new_height = int(image.height * scale_factor)
95
-
96
- # Resize the source image to fit within target size
97
- source = image.resize((new_width, new_height), Image.LANCZOS)
98
-
99
- # Apply resize option using percentages
100
- if resize_option == "Full":
101
- resize_percentage = 100
102
- elif resize_option == "50%":
103
- resize_percentage = 50
104
- elif resize_option == "33%":
105
- resize_percentage = 33
106
- elif resize_option == "25%":
107
- resize_percentage = 25
108
- elif resize_option == "Custom":
109
- resize_percentage = custom_resize_percentage
110
- else:
111
- raise ValueError(f"Invalid resize option: {resize_option}")
112
-
113
-
114
- # Calculate new dimensions based on percentage
115
- resize_factor = resize_percentage / 100
116
- new_width = int(source.width * resize_factor)
117
- new_height = int(source.height * resize_factor)
118
-
119
- # Ensure minimum size of 64 pixels
120
- new_width = max(new_width, 64)
121
- new_height = max(new_height, 64)
122
-
123
- # Ensure dimensions fit within target (can happen if original image is tiny and resize % is large)
124
- new_width = min(new_width, target_size[0])
125
- new_height = min(new_height, target_size[1])
126
-
127
- # Resize the image
128
- source = source.resize((new_width, new_height), Image.LANCZOS)
129
-
130
- # Calculate the overlap in pixels based on the percentage
131
- overlap_x = int(new_width * (overlap_percentage / 100))
132
- overlap_y = int(new_height * (overlap_percentage / 100))
133
-
134
- # Ensure minimum overlap of 1 pixel if overlap is enabled, otherwise 0
135
- overlap_x = max(overlap_x, 1) if overlap_left or overlap_right else 0
136
- overlap_y = max(overlap_y, 1) if overlap_top or overlap_bottom else 0
137
-
138
- # Calculate margins based on alignment
139
- if alignment == "Middle":
140
- margin_x = (target_size[0] - new_width) // 2
141
- margin_y = (target_size[1] - new_height) // 2
142
- elif alignment == "Left":
143
- margin_x = 0
144
- margin_y = (target_size[1] - new_height) // 2
145
- elif alignment == "Right":
146
- margin_x = target_size[0] - new_width
147
- margin_y = (target_size[1] - new_height) // 2
148
- elif alignment == "Top":
149
- margin_x = (target_size[0] - new_width) // 2
150
- margin_y = 0
151
- elif alignment == "Bottom":
152
- margin_x = (target_size[0] - new_width) // 2
153
- margin_y = target_size[1] - new_height
154
- else:
155
- raise ValueError(f"Invalid alignment: {alignment}")
156
-
157
-
158
- # Adjust margins to ensure image is fully within bounds (should be redundant with min check above)
159
- margin_x = max(0, min(margin_x, target_size[0] - new_width))
160
- margin_y = max(0, min(margin_y, target_size[1] - new_height))
161
-
162
- # Create a new background image and paste the resized source image
163
- background = Image.new('RGB', target_size, (255, 255, 255)) # White background
164
- background.paste(source, (margin_x, margin_y))
165
-
166
- # Create the mask (initially all black - meaning keep everything)
167
- mask_np = np.zeros(target_size[::-1], dtype=np.uint8) # Use numpy for easier slicing [::-1] for (height, width)
168
-
169
- # Calculate the coordinates of the *source image* area within the target canvas
170
- source_left = margin_x
171
- source_top = margin_y
172
- source_right = margin_x + new_width
173
- source_bottom = margin_y + new_height
174
-
175
- # Calculate the coordinates of the *unmasked* area (area to keep from source)
176
- unmasked_left = source_left + overlap_x if overlap_left else source_left
177
- unmasked_top = source_top + overlap_y if overlap_top else source_top
178
- unmasked_right = source_right - overlap_x if overlap_right else source_right
179
- unmasked_bottom = source_bottom - overlap_y if overlap_bottom else source_bottom
180
-
181
- # Special handling for edge alignments to ensure the edge itself is kept if overlap disabled
182
- if alignment == "Left" and not overlap_left:
183
- unmasked_left = source_left
184
- if alignment == "Right" and not overlap_right:
185
- unmasked_right = source_right
186
- if alignment == "Top" and not overlap_top:
187
- unmasked_top = source_top
188
- if alignment == "Bottom" and not overlap_bottom:
189
- unmasked_bottom = source_bottom
190
-
191
- # Ensure coordinates are valid and clipped to the source image area within the canvas
192
- unmasked_left = max(source_left, min(unmasked_left, source_right))
193
- unmasked_top = max(source_top, min(unmasked_top, source_bottom))
194
- unmasked_right = max(source_left, min(unmasked_right, source_right))
195
- unmasked_bottom = max(source_top, min(unmasked_bottom, source_bottom))
196
-
197
- # Create the final mask: White (255) = Area to inpaint/outpaint, Black (0) = Area to keep
198
- final_mask_np = np.ones(target_size[::-1], dtype=np.uint8) * 255 # Start with all white (change everything)
199
- if unmasked_right > unmasked_left and unmasked_bottom > unmasked_top:
200
- # Set the area to keep (calculated unmasked rectangle) to black (0)
201
- final_mask_np[unmasked_top:unmasked_bottom, unmasked_left:unmasked_right] = 0
202
-
203
- mask = Image.fromarray(final_mask_np)
204
-
205
- return background, mask
206
- except Exception as e:
207
- print(f"Error in prepare_image_and_mask: {e}")
208
- raise gr.Error(f"Failed to prepare image and mask: {e}")
209
-
210
 
211
  def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
212
- if image is None:
213
- return None # Or return a placeholder image/message
214
- try:
215
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
216
 
217
- # Create a preview image showing the mask
218
- preview = background.copy().convert('RGBA')
219
 
220
- # Create a semi-transparent red overlay for the masked (inpainting/outpainting) area
221
- red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 100)) # 100 alpha (~40% opacity)
222
 
223
- # The mask is white (255) where outpainting happens. Use this directly.
224
- preview.paste(red_overlay, (0, 0), mask) # Paste red where mask is white
 
225
 
226
- return preview
227
- except Exception as e:
228
- print(f"Error during preview generation: {e}")
229
- # Return the original background or an error placeholder
230
- if 'background' in locals():
231
- return background.convert('RGBA')
232
- else:
233
- return Image.new('RGBA', (width, height), (200, 200, 200, 255)) # Grey placeholder
234
 
 
235
 
236
- @spaces.GPU(duration=60) # Adjusted duration slightly
237
- def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom, progress=gr.Progress(track_tqdm=True)):
238
  if image is None:
239
- raise gr.Error("Please provide an input image.")
 
 
240
 
241
- try:
242
- # --- Preparation ---
243
- progress(0.1, desc="Preparing image and mask...")
244
- original_alignment = alignment
 
245
  background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
246
 
247
- # --- Alignment Check & Correction ---
248
- # Get dimensions *after* initial placement and resize
249
- pasted_source_img_width = int(image.width * min(width / image.width, height / image.height) * (custom_resize_percentage if resize_option=='Custom' else {'Full':100, '50%':50, '33%':33, '25%':25}[resize_option])/100)
250
- pasted_source_img_height = int(image.height * min(width / image.width, height / image.height) * (custom_resize_percentage if resize_option=='Custom' else {'Full':100, '50%':50, '33%':33, '25%':25}[resize_option])/100)
251
- pasted_source_img_width = max(64, min(pasted_source_img_width, width))
252
- pasted_source_img_height = max(64, min(pasted_source_img_height, height))
253
-
254
- needs_reprepare = False
255
- if alignment in ("Left", "Right") and pasted_source_img_width >= width:
256
- print(f"Warning: Source width ({pasted_source_img_width}) >= target width ({width}) with {alignment} alignment. Forcing Middle alignment.")
257
- alignment = "Middle"
258
- needs_reprepare = True
259
- if alignment in ("Top", "Bottom") and pasted_source_img_height >= height:
260
- print(f"Warning: Source height ({pasted_source_img_height}) >= target height ({height}) with {alignment} alignment. Forcing Middle alignment.")
261
- alignment = "Middle"
262
- needs_reprepare = True
263
-
264
- if needs_reprepare and alignment != original_alignment:
265
- print("Re-preparing mask due to alignment change.")
266
- progress(0.15, desc="Re-preparing mask for Middle alignment...")
267
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
268
-
269
- # ControlNet expects the image with the *original* content visible in the non-masked area
270
- cnet_image = background.copy()
271
- # In some ControlNet inpainting setups, you might mask the control image too,
272
- # but Union ControlNet Fill often works well with the unmasked source pasted onto the background.
273
- # cnet_image.paste(0, mask=ImageOps.invert(mask)) # Optional: Black out masked area in CNet image
274
-
275
- # --- Prompt Encoding ---
276
- progress(0.2, desc="Encoding prompt...")
277
- final_prompt = f"{prompt_input}, high quality, 4k" if prompt_input else "high quality, 4k" # Add default tags if no prompt
278
- negative_prompt = "low quality, blurry, noisy, text, words, letters, watermark, signature, username, artist name, deformed, distorted, disfigured, bad anatomy, extra limbs, missing limbs"
279
-
280
-
281
- # Note: TCD/Lightning pipelines often work better *without* explicit negative prompts encoded
282
- # Try encoding only the positive prompt first
283
- (
284
- prompt_embeds,
285
- _, # negative_prompt_embeds (set to None or handle differently for TCD)
286
- pooled_prompt_embeds,
287
- _, # negative_pooled_prompt_embeds
288
- ) = pipe.encode_prompt(final_prompt, "cuda", False) # do_classifier_free_guidance=False for TCD
289
-
290
-
291
- # --- Inference ---
292
- progress(0.3, desc="Starting diffusion process...")
293
- print(f"Running inference with {num_inference_steps} steps...")
294
- pipeline_output = pipe(
295
- prompt_embeds=prompt_embeds,
296
- negative_prompt_embeds=None, # Pass None for TCD/Lightning
297
- pooled_prompt_embeds=pooled_prompt_embeds,
298
- negative_pooled_prompt_embeds=None, # Pass None for TCD/Lightning
299
- image=background, # Initial state for masked area (background with source)
300
- mask_image=mask, # Mask (white = change)
301
- control_image=cnet_image, # ControlNet input
302
- num_inference_steps=num_inference_steps,
303
- guidance_scale=0.0, # Crucial for TCD/Lightning
304
- controlnet_conditioning_scale=0.8, # Default for FILL pipeline, adjust if needed
305
- output_type="pil" # Ensure PIL output
306
- # Add tqdm=True if supported by the custom pipeline and using gr.Progress without track_tqdm
307
- )
308
-
309
- # --- Process Output ---
310
- progress(0.9, desc="Processing results...")
311
- # Check if the pipeline returned a standard output object or a generator
312
- output_image = None
313
- if hasattr(pipeline_output, 'images'): # Standard diffusers output
314
- print("Pipeline returned a standard output object.")
315
- if len(pipeline_output.images) > 0:
316
- output_image = pipeline_output.images[0]
317
- else:
318
- raise ValueError("Pipeline output contained no images.")
319
- # Check if it's iterable (generator) - less likely with direct call and output_type='pil' but good practice
320
- elif hasattr(pipeline_output, '__iter__') and not isinstance(pipeline_output, dict):
321
- print("Pipeline returned a generator, iterating to get the final image.")
322
- last_item = None
323
- for item in pipeline_output:
324
- last_item = item
325
- # Try to extract image from the last yielded item (structure can vary)
326
- if isinstance(last_item, tuple) and len(last_item) > 0 and isinstance(last_item[0], Image.Image):
327
- output_image = last_item[0]
328
- elif isinstance(last_item, dict) and 'images' in last_item and len(last_item['images']) > 0:
329
- output_image = last_item['images'][0]
330
- elif isinstance(last_item, Image.Image):
331
- output_image = last_item
332
- elif hasattr(last_item, 'images') and len(last_item.images) > 0: # Handle case where object yielded early
333
- output_image = last_item.images[0]
334
-
335
- if output_image is None:
336
- raise ValueError("Pipeline generator did not yield a valid final image structure.")
337
- else:
338
- raise TypeError(f"Unexpected pipeline output type: {type(pipeline_output)}. Cannot extract image.")
339
-
340
- print("Inference complete.")
341
- progress(1.0, desc="Done!")
342
- return output_image
343
-
344
- except Exception as e:
345
- print(f"Error during inference: {e}")
346
- import traceback
347
- traceback.print_exc() # Print full traceback to console/logs
348
- raise gr.Error(f"Inference failed: {e}")
349
-
350
-
351
- def clear_result(*args):
352
- """Clears the result Image and related components."""
353
- updates = {
354
- result: gr.update(value=None),
355
- use_as_input_button: gr.update(visible=False),
356
- }
357
- # If preview image is passed as an arg, clear it too
358
- if len(args) > 0 and isinstance(args[0], gr.Image):
359
- updates[args[0]] = gr.update(value=None) # Assuming preview_image is the first optional arg
360
- return updates
361
-
362
-
363
- # --- UI Helper Functions ---
364
  def preload_presets(target_ratio, ui_width, ui_height):
365
  """Updates the width and height sliders based on the selected aspect ratio."""
366
- settings_update = gr.update() # Default: no change to accordion state
367
  if target_ratio == "9:16":
368
  changed_width = 720
369
  changed_height = 1280
 
370
  elif target_ratio == "16:9":
371
  changed_width = 1280
372
  changed_height = 720
 
373
  elif target_ratio == "1:1":
374
  changed_width = 1024
375
  changed_height = 1024
 
376
  elif target_ratio == "Custom":
377
- changed_width = ui_width # Keep current slider values
378
- changed_height = ui_height
379
- settings_update = gr.update(open=True) # Open accordion for custom
380
- else: # Should not happen
381
- changed_width = ui_width
382
- changed_height = ui_height
383
-
384
- return changed_width, changed_height, settings_update
385
 
386
  def select_the_right_preset(user_width, user_height):
387
- """Updates the radio button based on the current slider values."""
388
  if user_width == 720 and user_height == 1280:
389
  return "9:16"
390
  elif user_width == 1280 and user_height == 720:
@@ -400,245 +278,186 @@ def toggle_custom_resize_slider(resize_option):
400
 
401
  def update_history(new_image, history):
402
  """Updates the history gallery with the new image."""
403
- if not isinstance(new_image, Image.Image): # Don't add if generation failed (None)
404
- return history or [] # Return current or empty list
405
-
406
  if history is None:
407
  history = []
 
408
  history.insert(0, new_image)
409
- # Limit history size (optional)
410
  max_history = 12
411
  if len(history) > max_history:
412
  history = history[:max_history]
413
  return history
414
 
415
- # --- Gradio UI Definition ---
 
416
  css = """
417
  .gradio-container {
418
- max-width: 1200px !important; /* Use max-width for responsiveness */
419
- margin: auto !important; /* Center the container */
420
- padding: 10px; /* Add some padding */
421
- }
422
- h1 { text-align: center; margin-bottom: 15px;}
423
- footer { display: none !important; /* More reliable way to hide footer */ }
424
-
425
- /* Ensure result image takes reasonable space */
426
- #result-image img {
427
- max-height: 768px; /* Adjust max height as needed */
428
- object-fit: contain;
429
- width: 100%; /* Allow image to use column width */
430
- height: auto;
431
- display: block; /* Prevent extra space below image */
432
- margin: auto; /* Center image within its container */
433
- }
434
- #input-image img {
435
- max-height: 400px;
436
- object-fit: contain;
437
- width: 100%;
438
- height: auto;
439
- display: block;
440
- margin: auto;
441
- }
442
- #preview-image img {
443
- max-height: 250px; /* Smaller preview */
444
- object-fit: contain;
445
- width: 100%;
446
- height: auto;
447
- display: block;
448
- margin: auto;
449
- }
450
-
451
- #history-gallery .thumbnail-item { /* Style history items */
452
- height: 100px !important;
453
- overflow: hidden; /* Hide overflow */
454
  }
455
- #history-gallery .gallery {
456
- grid-template-rows: repeat(auto-fill, 100px) !important;
457
- gap: 4px !important; /* Add small gap */
458
  }
459
- #history-gallery .thumbnail-item img {
460
- object-fit: contain !important; /* Ensure history previews fit */
461
- height: 100%;
462
- width: 100%;
463
  }
464
 
465
- /* Make Checkboxes smaller and closer */
466
- .gradio-checkboxgroup .wrap {
467
- gap: 0.5rem 1rem !important; /* Adjust spacing */
468
- }
469
- .gradio-checkbox label span {
470
- font-size: 0.9em; /* Slightly smaller label text */
471
- }
472
- .gradio-checkbox input {
473
- transform: scale(0.9); /* Slightly smaller checkbox */
474
- }
475
-
476
- /* Style Accordion */
477
- .gradio-accordion .label-wrap { /* Target the label wrapper */
478
- border: 1px solid #e0e0e0;
479
- border-radius: 5px;
480
- padding: 8px 12px;
481
- background-color: #f9f9f9;
482
- }
483
  """
484
 
485
- title = """<h1 align="center">🖼️ Diffusers Image Outpaint Lightning ⚡</h1>"""
 
 
 
 
 
 
 
 
 
486
 
487
- # --- Example Files Handling ---
488
- # Create examples directory if it doesn't exist
489
- if not os.path.exists("./examples"):
490
- os.makedirs("./examples")
491
 
492
- # Check for example images and provide defaults or placeholders if missing
493
- example_files = {
494
- "ex1": "./examples/example_1.webp",
495
- "ex2": "./examples/example_2.jpg",
496
- "ex3": "./examples/example_3.jpg"
497
- }
498
- default_image_path = None # Will be set to the first available example
499
-
500
- # You might want to download example images if they don't exist
501
- # from huggingface_hub import hf_hub_download
502
- # def download_example(repo_id, filename, local_path):
503
- # if not os.path.exists(local_path):
504
- # try:
505
- # hf_hub_download(repo_id=repo_id, filename=filename, local_dir="./examples", local_dir_use_symlinks=False)
506
- # print(f"Downloaded {filename}")
507
- # except Exception as e:
508
- # print(f"Failed to download example {filename}: {e}")
509
- # return False # Indicate failure
510
- # return os.path.exists(local_path)
511
-
512
- # Example: download_example("path/to/your/example-repo", "example_1.webp", example_files["ex1"])
513
- # For now, we just check existence
514
-
515
- examples_available = {key: os.path.exists(path) for key, path in example_files.items()}
516
-
517
- example_list = []
518
- if examples_available["ex1"]:
519
- example_list.append([example_files["ex1"], "A wide landscape view of the mountains", 1280, 720, "Middle"])
520
- if default_image_path is None: default_image_path = example_files["ex1"]
521
- if examples_available["ex2"]:
522
- example_list.append([example_files["ex2"], "Full body shot of the astronaut on the moon", 720, 1280, "Middle"])
523
- if default_image_path is None: default_image_path = example_files["ex2"]
524
- if examples_available["ex3"]:
525
- example_list.append([example_files["ex3"], "Expanding the sky and ground around the subject", 1024, 1024, "Middle"])
526
- example_list.append([example_files["ex3"], "Expanding downwards from the subject", 1024, 1024, "Top"])
527
- example_list.append([example_files["ex3"], "Expanding upwards from the subject", 1024, 1024, "Bottom"])
528
- if default_image_path is None: default_image_path = example_files["ex3"]
529
-
530
- if not example_list:
531
- print("Warning: No example images found in ./examples/. Examples section will be empty.")
532
- # Optionally create a placeholder image
533
- # placeholder = Image.new('RGB', (512, 512), color = 'grey')
534
- # placeholder_path = "./examples/placeholder.png"
535
- # placeholder.save(placeholder_path)
536
- # example_list.append([placeholder_path, "Placeholder", 1024, 1024, "Middle"])
537
- # default_image_path = placeholder_path
538
-
539
- # --- UI ---
540
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added a theme
541
- gr.HTML(title)
542
-
543
- with gr.Row():
544
- with gr.Column(scale=1): # Left column for inputs
545
- input_image = gr.Image(
546
- value=default_image_path, # Load default example
547
- type="pil",
548
- label="Input Image",
549
- elem_id="input-image"
550
- )
551
-
552
- prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the scene to expand (optional but recommended)...", lines=2)
553
-
554
- with gr.Row():
555
- target_ratio = gr.Radio(
556
- label="Target Aspect Ratio",
557
- choices=["9:16", "16:9", "1:1", "Custom"],
558
- value="9:16",
559
- scale=2
560
- )
561
- alignment_dropdown = gr.Dropdown(
562
- choices=["Middle", "Left", "Right", "Top", "Bottom"],
563
- value="Middle",
564
- label="Align Source Image",
565
- scale=1
566
  )
567
 
568
- with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
569
  with gr.Row():
570
- width_slider = gr.Slider(
571
- label="Target Width", minimum=512, maximum=2048, step=64, value=720
572
- )
573
- height_slider = gr.Slider(
574
- label="Target Height", minimum=512, maximum=2048, step=64, value=1280
575
- )
576
- num_inference_steps = gr.Slider(
577
- label="Steps (TCD/Lightning: 1-8)", minimum=1, maximum=12, step=1, value=4
578
- )
579
-
580
- with gr.Group():
581
- overlap_percentage = gr.Slider(
582
- label="Mask Overlap with Source (%)", minimum=0, maximum=50, value=12, step=1
583
- )
584
- gr.Markdown("Select edges to overlap:", scale=0) # Add context
585
- with gr.Row(elem_classes="gradio-checkboxgroup"): # Apply CSS class
586
- overlap_top = gr.Checkbox(label="Top", value=True, scale=1)
587
- overlap_bottom = gr.Checkbox(label="Bottom", value=True, scale=1)
588
- overlap_left = gr.Checkbox(label="Left", value=True, scale=1)
589
- overlap_right = gr.Checkbox(label="Right", value=True, scale=1)
590
-
591
 
592
  with gr.Row():
593
- resize_option = gr.Radio(
594
- label="Resize source within target",
595
- choices=["Full", "50%", "33%", "25%", "Custom"],
596
- value="Full",
597
- scale=2
598
  )
599
- custom_resize_percentage = gr.Slider(
600
- label="Custom resize (%)", minimum=1, maximum=100, step=1, value=50, visible=False, scale=1
 
 
 
601
  )
602
 
603
- preview_button = gr.Button("Preview Mask & Alignment")
604
- preview_image = gr.Image(label="Mask Preview (Red = Outpaint Area)", type="pil", interactive=False, elem_id="preview-image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
 
606
- if example_list:
607
  gr.Examples(
608
- examples=example_list,
609
- inputs=[input_image, prompt_input, width_slider, height_slider, alignment_dropdown],
610
- label="Examples (Click to load)",
611
- examples_per_page=10
 
 
 
 
612
  )
613
- else:
614
- gr.Markdown("_(No example files found in ./examples)_")
615
 
616
- run_button = gr.Button("Generate", variant="primary")
617
-
618
-
619
- with gr.Column(scale=1): # Right column for output
620
- result = gr.Image(label="Generated Image", type="pil", interactive=False, elem_id="result-image")
621
- use_as_input_button = gr.Button("Use Result as Input Image", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
 
623
- history_gallery = gr.Gallery(
624
- label="History", columns=6, object_fit="contain", interactive=False,
625
- height=110, elem_id="history-gallery"
626
- )
627
 
628
- # --- Event Handling ---
629
 
630
- # Function to set result as input and clear result area
631
- def use_output_as_input_and_clear(output_image):
632
- return {
633
- input_image: gr.update(value=output_image),
634
- result: gr.update(value=None), # Clear result after using it
635
- use_as_input_button: gr.update(visible=False) # Hide button again
636
- }
637
 
638
  use_as_input_button.click(
639
- fn=use_output_as_input_and_clear,
640
- inputs=[result],
641
- outputs=[input_image, result, use_as_input_button]
 
 
 
 
 
 
 
642
  )
643
 
644
  target_ratio.change(
@@ -648,16 +467,18 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added a theme
648
  queue=False
649
  )
650
 
 
651
  width_slider.change(
652
- fn=select_the_right_preset,
653
  inputs=[width_slider, height_slider],
654
- outputs=[target_ratio],
655
  queue=False
656
  )
 
657
  height_slider.change(
658
- fn=select_the_right_preset,
659
  inputs=[width_slider, height_slider],
660
- outputs=[target_ratio],
661
  queue=False
662
  )
663
 
@@ -668,77 +489,58 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Added a theme
668
  queue=False
669
  )
670
 
671
- # Consolidate common inputs for generation
672
  gen_inputs = [
673
  input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
674
  resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
675
  overlap_left, overlap_right, overlap_top, overlap_bottom
676
  ]
677
- gen_outputs = [result] # Single output image
678
 
679
- # Chain generation logic for Run button
680
- run_trigger = run_button.click(
681
- fn=clear_result, # Clear previous result first
682
- inputs=[], # No inputs needed for clear
683
- outputs=[result, use_as_input_button], # Components to clear/hide
684
- queue=False
 
 
 
 
 
 
685
  ).then(
686
- fn=infer,
687
  inputs=gen_inputs,
688
- outputs=gen_outputs,
689
- )
690
-
691
- # After generation finishes (successfully or not), update history and button visibility
692
- run_trigger.then(
693
- fn=lambda res_img, hist: update_history(res_img, hist),
694
- inputs=[result, history_gallery],
695
- outputs=[history_gallery],
696
- queue=False # Update history immediately
697
  ).then(
698
- # Show the 'Use as Input' button only if generation was successful (result is not None)
699
- fn=lambda res_img: gr.update(visible=isinstance(res_img, Image.Image)),
700
- inputs=[result],
701
- outputs=[use_as_input_button],
702
- queue=False # Show button immediately
703
  )
704
 
705
-
706
- # Chain generation logic for Enter key in Prompt textbox
707
- submit_trigger = prompt_input.submit(
708
- fn=clear_result,
709
- inputs=[],
710
- outputs=[result, use_as_input_button],
711
- queue=False
712
  ).then(
713
- fn=infer,
714
  inputs=gen_inputs,
715
- outputs=gen_outputs,
716
- )
717
-
718
- submit_trigger.then(
719
- fn=lambda res_img, hist: update_history(res_img, hist),
720
- inputs=[result, history_gallery],
721
- outputs=[history_gallery],
722
- queue=False
723
  ).then(
724
- fn=lambda res_img: gr.update(visible=isinstance(res_img, Image.Image)),
725
- inputs=[result],
726
- outputs=[use_as_input_button],
727
- queue=False
728
  )
729
 
730
- # Preview button logic
731
- preview_inputs = [
732
- input_image, width_slider, height_slider, overlap_percentage, resize_option,
733
- custom_resize_percentage, alignment_dropdown, overlap_left, overlap_right,
734
- overlap_top, overlap_bottom
735
- ]
736
  preview_button.click(
737
  fn=preview_image_and_mask,
738
- inputs=preview_inputs,
739
- outputs=preview_image,
740
- queue=False
 
741
  )
742
 
743
- # Launch the interface
744
- demo.queue(max_size=10).launch(ssr_mode=False, show_error=True, debug=True) # Add debug=True for more logs
 
1
+ # -*- coding: utf-8 -*-
2
  import gradio as gr
3
  import spaces
4
  import torch
5
  from diffusers import AutoencoderKL, TCDScheduler
6
  from diffusers.models.model_loading_utils import load_state_dict
7
+ # Remove ImageSlider import as it's no longer needed
8
+ # from gradio_imageslider import ImageSlider
9
  from huggingface_hub import hf_hub_download
10
 
11
+ from controlnet_union import ControlNetModel_Union
12
+ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  from PIL import Image, ImageDraw
15
  import numpy as np
 
 
 
 
 
16
 
17
+ # --- Model Loading (Keep as is) ---
18
+ config_file = hf_hub_download(
19
+ "xinsir/controlnet-union-sdxl-1.0",
20
+ filename="config_promax.json",
21
+ )
22
+
23
+ config = ControlNetModel_Union.load_config(config_file)
24
+ controlnet_model = ControlNetModel_Union.from_config(config)
25
+ model_file = hf_hub_download(
26
+ "xinsir/controlnet-union-sdxl-1.0",
27
+ filename="diffusion_pytorch_model_promax.safetensors",
28
+ )
29
+ state_dict = load_state_dict(model_file)
30
+ model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
31
+ controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
32
+ )
33
+ model.to(device="cuda", dtype=torch.float16)
34
+
35
+ vae = AutoencoderKL.from_pretrained(
36
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
37
+ ).to("cuda")
38
+
39
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
40
+ "SG161222/RealVisXL_V5.0_Lightning",
41
+ torch_dtype=torch.float16,
42
+ vae=vae,
43
+ controlnet=model,
44
+ variant="fp16",
45
+ ).to("cuda")
46
+
47
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
48
+
49
+ # --- Helper Functions (Keep as is, except infer) ---
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def can_expand(source_width, source_height, target_width, target_height, alignment):
52
  """Checks if the image can be expanded based on the alignment."""
53
  if alignment in ("Left", "Right") and source_width >= target_width:
 
57
  return True
58
 
59
  def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
60
+ target_size = (width, height)
61
+
62
+ # Calculate the scaling factor to fit the image within the target size
63
+ scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
64
+ new_width = int(image.width * scale_factor)
65
+ new_height = int(image.height * scale_factor)
66
+
67
+ # Resize the source image to fit within target size
68
+ source = image.resize((new_width, new_height), Image.LANCZOS)
69
+
70
+ # Apply resize option using percentages
71
+ if resize_option == "Full":
72
+ resize_percentage = 100
73
+ elif resize_option == "50%":
74
+ resize_percentage = 50
75
+ elif resize_option == "33%":
76
+ resize_percentage = 33
77
+ elif resize_option == "25%":
78
+ resize_percentage = 25
79
+ else: # Custom
80
+ resize_percentage = custom_resize_percentage
81
+
82
+ # Calculate new dimensions based on percentage
83
+ resize_factor = resize_percentage / 100
84
+ new_width = int(source.width * resize_factor)
85
+ new_height = int(source.height * resize_factor)
86
+
87
+ # Ensure minimum size of 64 pixels
88
+ new_width = max(new_width, 64)
89
+ new_height = max(new_height, 64)
90
+
91
+ # Resize the image
92
+ source = source.resize((new_width, new_height), Image.LANCZOS)
93
+
94
+ # Calculate the overlap in pixels based on the percentage
95
+ overlap_x = int(new_width * (overlap_percentage / 100))
96
+ overlap_y = int(new_height * (overlap_percentage / 100))
97
+
98
+ # Ensure minimum overlap of 1 pixel
99
+ overlap_x = max(overlap_x, 1)
100
+ overlap_y = max(overlap_y, 1)
101
+
102
+ # Calculate margins based on alignment
103
+ if alignment == "Middle":
104
+ margin_x = (target_size[0] - new_width) // 2
105
+ margin_y = (target_size[1] - new_height) // 2
106
+ elif alignment == "Left":
107
+ margin_x = 0
108
+ margin_y = (target_size[1] - new_height) // 2
109
+ elif alignment == "Right":
110
+ margin_x = target_size[0] - new_width
111
+ margin_y = (target_size[1] - new_height) // 2
112
+ elif alignment == "Top":
113
+ margin_x = (target_size[0] - new_width) // 2
114
+ margin_y = 0
115
+ elif alignment == "Bottom":
116
+ margin_x = (target_size[0] - new_width) // 2
117
+ margin_y = target_size[1] - new_height
118
+
119
+ # Adjust margins to eliminate gaps
120
+ margin_x = max(0, min(margin_x, target_size[0] - new_width))
121
+ margin_y = max(0, min(margin_y, target_size[1] - new_height))
122
+
123
+ # Create a new background image and paste the resized source image
124
+ background = Image.new('RGB', target_size, (255, 255, 255))
125
+ background.paste(source, (margin_x, margin_y))
126
+
127
+ # Create the mask
128
+ mask = Image.new('L', target_size, 255)
129
+ mask_draw = ImageDraw.Draw(mask)
130
+
131
+ # Calculate overlap areas
132
+ white_gaps_patch = 2
133
+
134
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
135
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
136
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
137
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
138
+
139
+ if alignment == "Left":
140
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x
141
+ elif alignment == "Right":
142
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
143
+ elif alignment == "Top":
144
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y
145
+ elif alignment == "Bottom":
146
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
147
+
148
+
149
+ # Draw the mask
150
+ mask_draw.rectangle([
151
+ (left_overlap, top_overlap),
152
+ (right_overlap, bottom_overlap)
153
+ ], fill=0)
154
+
155
+ return background, mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
158
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
 
 
 
159
 
160
+ # Create a preview image showing the mask
161
+ preview = background.copy().convert('RGBA')
162
 
163
+ # Create a semi-transparent red overlay
164
+ red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64)) # Reduced alpha to 64 (25% opacity)
165
 
166
+ # Convert black pixels in the mask to semi-transparent red
167
+ red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
168
+ red_mask.paste(red_overlay, (0, 0), mask)
169
 
170
+ # Overlay the red mask on the background
171
+ preview = Image.alpha_composite(preview, red_mask)
 
 
 
 
 
 
172
 
173
+ return preview
174
 
175
+ @spaces.GPU(duration=24)
176
+ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
177
  if image is None:
178
+ raise gr.Error("Please upload an input image.")
179
+
180
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
181
 
182
+ if not can_expand(background.width, background.height, width, height, alignment):
183
+ # Optionally provide feedback or default to middle
184
+ # gr.Warning(f"Cannot expand image with '{alignment}' alignment as source dimension is larger than target. Defaulting to 'Middle'.")
185
+ alignment = "Middle"
186
+ # Recalculate background and mask if alignment changed due to this check
187
  background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
188
 
189
+
190
+ cnet_image = background.copy()
191
+ # Apply mask to create the input for controlnet (black out non-masked area)
192
+ # cnet_image.paste(0, (0, 0), mask) # This line seems incorrect for inpainting/outpainting, usually the unmasked area is kept
193
+ # The pipeline expects the original image content where mask=0 and potentially noise/latents where mask=1
194
+ # Let's keep the original image content in the unmasked area and let the pipeline handle the masked area.
195
+ # The `StableDiffusionXLFillPipeline` likely uses the `image` input and `mask` differently than standard inpainting.
196
+ # Based on typical diffusers pipelines, `image` is often the *original* content placed on the canvas.
197
+ # Let's pass `background` as the image input for the pipeline.
198
+
199
+ final_prompt = f"{prompt_input} , high quality, 4k" if prompt_input else "high quality, 4k"
200
+
201
+ (
202
+ prompt_embeds,
203
+ negative_prompt_embeds,
204
+ pooled_prompt_embeds,
205
+ negative_pooled_prompt_embeds,
206
+ ) = pipe.encode_prompt(final_prompt, "cuda", True, negative_prompt="") # Add default negative prompt
207
+
208
+ # The pipeline expects the `image` and `mask_image` arguments for inpainting/outpainting
209
+ # `image` should be the canvas with the original image placed.
210
+ # `mask_image` defines the area to be filled (white=fill, black=keep).
211
+ # Our mask is inverted (black=keep, white=fill). Invert it.
212
+ inverted_mask = Image.fromarray(255 - np.array(mask))
213
+
214
+ # Run the pipeline
215
+ # Note: The generator inside the pipeline call is not used here as we only need the final result.
216
+ # We iterate once to get the final image.
217
+ generated_image = None
218
+ for img_output in pipe(
219
+ prompt_embeds=prompt_embeds,
220
+ negative_prompt_embeds=negative_prompt_embeds,
221
+ pooled_prompt_embeds=pooled_prompt_embeds,
222
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
223
+ image=background, # Pass the background with the source image placed
224
+ mask_image=inverted_mask, # Pass the inverted mask (white = area to fill)
225
+ control_image=background, # ControlNet Union might need the full image context
226
+ num_inference_steps=num_inference_steps,
227
+ output_type="pil" # Ensure PIL images are returned
228
+ ):
229
+ generated_image = img_output[0] # The pipeline returns a list containing the image
230
+
231
+ if generated_image is None:
232
+ raise gr.Error("Image generation failed.")
233
+
234
+ # The pipeline should return the complete image already composited.
235
+ # No need to manually paste.
236
+ final_image = generated_image.convert("RGB")
237
+
238
+ # Yield only the final generated image
239
+ yield final_image
240
+
241
+
242
+ def clear_result():
243
+ """Clears the result Image component."""
244
+ return gr.update(value=None)
245
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  def preload_presets(target_ratio, ui_width, ui_height):
247
  """Updates the width and height sliders based on the selected aspect ratio."""
 
248
  if target_ratio == "9:16":
249
  changed_width = 720
250
  changed_height = 1280
251
+ return changed_width, changed_height, gr.update(open=False) # Close accordion
252
  elif target_ratio == "16:9":
253
  changed_width = 1280
254
  changed_height = 720
255
+ return changed_width, changed_height, gr.update(open=False) # Close accordion
256
  elif target_ratio == "1:1":
257
  changed_width = 1024
258
  changed_height = 1024
259
+ return changed_width, changed_height, gr.update(open=False) # Close accordion
260
  elif target_ratio == "Custom":
261
+ # Keep current slider values but open the accordion
262
+ return ui_width, ui_height, gr.update(open=True)
 
 
 
 
 
 
263
 
264
  def select_the_right_preset(user_width, user_height):
265
+ """Selects the preset radio button based on current width/height."""
266
  if user_width == 720 and user_height == 1280:
267
  return "9:16"
268
  elif user_width == 1280 and user_height == 720:
 
278
 
279
  def update_history(new_image, history):
280
  """Updates the history gallery with the new image."""
281
+ if new_image is None: # Don't add None to history
282
+ return history
 
283
  if history is None:
284
  history = []
285
+ # Prepend the new image (as PIL) to the history list
286
  history.insert(0, new_image)
287
+ # Limit history size if desired (e.g., keep last 12)
288
  max_history = 12
289
  if len(history) > max_history:
290
  history = history[:max_history]
291
  return history
292
 
293
+ # --- Gradio UI ---
294
+
295
  css = """
296
  .gradio-container {
297
+ max-width: 1200px !important; /* Limit overall width */
298
+ margin: auto; /* Center the container */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  }
300
+ /* Ensure gallery items are reasonably sized */
301
+ #history_gallery .thumbnail-item {
302
+ height: 100px !important; /* Adjust as needed */
303
  }
304
+ #history_gallery .gallery {
305
+ grid-template-columns: repeat(auto-fill, minmax(100px, 1fr)) !important; /* Adjust column size */
 
 
306
  }
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  """
309
 
310
+ title = """<h1 align="center">Diffusers Image Outpaint</h1>
311
+ <div align="center">Drop an image you would like to extend, pick your expected ratio and hit Generate.</div>
312
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
313
+ <p style="display: flex;gap: 6px;">
314
+ <a href="https://huggingface.co/spaces/fffiloni/diffusers-image-outpaint?duplicate=true">
315
+ <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-md.svg" alt="Duplicate this Space">
316
+ </a> to skip the queue and enjoy faster inference on the GPU of your choice
317
+ </p>
318
+ </div>
319
+ """
320
 
321
+ with gr.Blocks(css=css) as demo:
322
+ with gr.Column():
323
+ gr.HTML(title)
 
324
 
325
+ with gr.Row():
326
+ with gr.Column(scale=1): # Input column
327
+ input_image = gr.Image(
328
+ type="pil",
329
+ label="Input Image"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  )
331
 
 
332
  with gr.Row():
333
+ with gr.Column(scale=2):
334
+ prompt_input = gr.Textbox(label="Prompt (Optional)", placeholder="Describe the desired extended scene...")
335
+ with gr.Column(scale=1, min_width=150):
336
+ run_button = gr.Button("Generate", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  with gr.Row():
339
+ target_ratio = gr.Radio(
340
+ label="Target Ratio",
341
+ choices=["9:16", "16:9", "1:1", "Custom"],
342
+ value="9:16",
343
+ scale=2
344
  )
345
+
346
+ alignment_dropdown = gr.Dropdown(
347
+ choices=["Middle", "Left", "Right", "Top", "Bottom"],
348
+ value="Middle",
349
+ label="Align Source Image"
350
  )
351
 
352
+ with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
353
+ with gr.Column():
354
+ with gr.Row():
355
+ width_slider = gr.Slider(
356
+ label="Target Width (px)",
357
+ minimum=512, # Lowered min slightly
358
+ maximum=2048, # Increased max slightly
359
+ step=64, # SDXL optimal step
360
+ value=720,
361
+ )
362
+ height_slider = gr.Slider(
363
+ label="Target Height (px)",
364
+ minimum=512, # Lowered min slightly
365
+ maximum=2048, # Increased max slightly
366
+ step=64, # SDXL optimal step
367
+ value=1280,
368
+ )
369
+
370
+ num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=20, step=1, value=8) # Increased max steps slightly
371
+ with gr.Group():
372
+ overlap_percentage = gr.Slider(
373
+ label="Mask overlap (%)",
374
+ minimum=1,
375
+ maximum=50,
376
+ value=10,
377
+ step=1,
378
+ info="How much the new area overlaps the original image."
379
+ )
380
+ gr.Markdown("Select sides to overlap (influences mask generation):")
381
+ with gr.Row():
382
+ overlap_top = gr.Checkbox(label="Top", value=True)
383
+ overlap_right = gr.Checkbox(label="Right", value=True)
384
+ with gr.Row():
385
+ overlap_left = gr.Checkbox(label="Left", value=True)
386
+ overlap_bottom = gr.Checkbox(label="Bottom", value=True)
387
+ with gr.Row():
388
+ resize_option = gr.Radio(
389
+ label="Resize input image before placing",
390
+ choices=["Full", "50%", "33%", "25%", "Custom"],
391
+ value="Full",
392
+ info="Scales the source image down before placing it on the target canvas."
393
+ )
394
+ custom_resize_percentage = gr.Slider(
395
+ label="Custom resize (%)",
396
+ minimum=1,
397
+ maximum=100,
398
+ step=1,
399
+ value=50,
400
+ visible=False
401
+ )
402
+
403
+ with gr.Column():
404
+ preview_button = gr.Button("Preview Alignment & Mask")
405
+
406
 
 
407
  gr.Examples(
408
+ examples=[
409
+ ["./examples/example_1.webp", 1280, 720, "Middle", "A wide landscape view of the mountains"],
410
+ ["./examples/example_2.jpg", 1440, 810, "Left", "Full body shot of the cat sitting on a rug"],
411
+ ["./examples/example_3.jpg", 1024, 1024, "Top", "The cloudy sky above the building"],
412
+ ["./examples/example_3.jpg", 1024, 1024, "Bottom", "The street below the building"],
413
+ ],
414
+ inputs=[input_image, width_slider, height_slider, alignment_dropdown, prompt_input],
415
+ label="Examples (Click to load)"
416
  )
 
 
417
 
418
+ with gr.Column(scale=1): # Output column
419
+ # Replace ImageSlider with gr.Image
420
+ result_image = gr.Image(
421
+ label="Generated Image",
422
+ interactive=False,
423
+ show_download_button=True,
424
+ type="pil" # Ensure output is PIL for history
425
+ )
426
+ with gr.Row():
427
+ use_as_input_button = gr.Button("Use as Input", visible=False)
428
+ clear_button = gr.Button("Clear Output") # Added clear button
429
+
430
+ preview_mask_image = gr.Image(label="Alignment & Mask Preview", interactive=False)
431
+
432
+ history_gallery = gr.Gallery(
433
+ label="History",
434
+ columns=4, # Adjust columns as needed
435
+ object_fit="contain",
436
+ interactive=False,
437
+ show_label=True,
438
+ elem_id="history_gallery",
439
+ height=300 # Set a fixed height for the gallery area
440
+ )
441
 
 
 
 
 
442
 
443
+ # --- Event Handlers ---
444
 
445
+ def use_output_as_input(output_pil_image):
446
+ """Sets the generated output PIL image as the new input image."""
447
+ # output_image comes directly from result_image which is PIL type
448
+ return gr.update(value=output_pil_image)
 
 
 
449
 
450
  use_as_input_button.click(
451
+ fn=use_output_as_input,
452
+ inputs=[result_image], # Input is the single result image
453
+ outputs=[input_image]
454
+ )
455
+
456
+ clear_button.click(
457
+ fn=lambda: (gr.update(value=None), gr.update(visible=False), gr.update(value=None)), # Clear image, hide button, clear preview
458
+ inputs=None,
459
+ outputs=[result_image, use_as_input_button, preview_mask_image],
460
+ queue=False
461
  )
462
 
463
  target_ratio.change(
 
467
  queue=False
468
  )
469
 
470
+ # Link sliders back to ratio selector and potentially open accordion
471
  width_slider.change(
472
+ fn=lambda w, h: (select_the_right_preset(w, h), gr.update() if select_the_right_preset(w, h) == "Custom" else gr.update()),
473
  inputs=[width_slider, height_slider],
474
+ outputs=[target_ratio, settings_panel],
475
  queue=False
476
  )
477
+
478
  height_slider.change(
479
+ fn=lambda w, h: (select_the_right_preset(w, h), gr.update() if select_the_right_preset(w, h) == "Custom" else gr.update()),
480
  inputs=[width_slider, height_slider],
481
+ outputs=[target_ratio, settings_panel],
482
  queue=False
483
  )
484
 
 
489
  queue=False
490
  )
491
 
492
+ # Define common inputs for generation
493
  gen_inputs = [
494
  input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
495
  resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
496
  overlap_left, overlap_right, overlap_top, overlap_bottom
497
  ]
 
498
 
499
+ # Define common steps after generation
500
+ def handle_output(generated_image, current_history):
501
+ # generated_image is the single PIL image from infer
502
+ new_history = update_history(generated_image, current_history)
503
+ button_visibility = gr.update(visible=True) if generated_image else gr.update(visible=False)
504
+ return generated_image, new_history, button_visibility
505
+
506
+ run_button.click(
507
+ fn=lambda: (gr.update(value=None), gr.update(visible=False)), # Clear result and hide button first
508
+ inputs=None,
509
+ outputs=[result_image, use_as_input_button],
510
+ queue=False # Don't queue the clearing part
511
  ).then(
512
+ fn=infer, # Run the generation
513
  inputs=gen_inputs,
514
+ outputs=result_image, # Output is the single generated image
 
 
 
 
 
 
 
 
515
  ).then(
516
+ fn=handle_output, # Process output: update history, show button
517
+ inputs=[result_image, history_gallery],
518
+ outputs=[result_image, history_gallery, use_as_input_button] # Update result again (no change), history, and button visibility
 
 
519
  )
520
 
521
+ prompt_input.submit(
522
+ fn=lambda: (gr.update(value=None), gr.update(visible=False)), # Clear result and hide button first
523
+ inputs=None,
524
+ outputs=[result_image, use_as_input_button],
525
+ queue=False # Don't queue the clearing part
 
 
526
  ).then(
527
+ fn=infer, # Run the generation
528
  inputs=gen_inputs,
529
+ outputs=result_image, # Output is the single generated image
 
 
 
 
 
 
 
530
  ).then(
531
+ fn=handle_output, # Process output: update history, show button
532
+ inputs=[result_image, history_gallery],
533
+ outputs=[result_image, history_gallery, use_as_input_button] # Update result again (no change), history, and button visibility
 
534
  )
535
 
536
+
 
 
 
 
 
537
  preview_button.click(
538
  fn=preview_image_and_mask,
539
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
540
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
541
+ outputs=preview_mask_image, # Output to the preview image component
542
+ queue=False # Preview should be fast
543
  )
544
 
545
+ # Launch the app
546
+ demo.queue(max_size=12).launch(share=False, ssr_mode=False, show_error=True)