prithivMLmods committed on
Commit
11d7c13
·
verified ·
1 Parent(s): 1b7b670

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +178 -396
app.py CHANGED
@@ -2,447 +2,229 @@ import gradio as gr
2
  import spaces
3
  import torch
4
  from diffusers import AutoencoderKL, TCDScheduler
5
- from diffusers.models.model_loading_utils import load_state_dict
6
- from gradio_imageslider import ImageSlider
7
- from huggingface_hub import hf_hub_download
8
-
9
  from controlnet_union import ControlNetModel_Union
10
  from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 
 
11
 
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
15
- config_file = hf_hub_download(
 
 
16
  "xinsir/controlnet-union-sdxl-1.0",
17
- filename="config_promax.json",
18
- )
19
-
20
- config = ControlNetModel_Union.load_config(config_file)
21
- controlnet_model = ControlNetModel_Union.from_config(config)
22
- model_file = hf_hub_download(
23
- "xinsir/controlnet-union-sdxl-1.0",
24
- filename="diffusion_pytorch_model_promax.safetensors",
25
- )
26
-
27
- sstate_dict = load_state_dict(model_file)
28
- model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
29
- controlnet_model, sstate_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
30
- )
31
- model.to(device="cuda", dtype=torch.float16)
32
- #----------------------
33
 
34
  vae = AutoencoderKL.from_pretrained(
35
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
 
36
  ).to("cuda")
37
 
38
  pipe = StableDiffusionXLFillPipeline.from_pretrained(
39
  "SG161222/RealVisXL_V5.0_Lightning",
40
  torch_dtype=torch.float16,
41
  vae=vae,
42
- controlnet=model,
43
  variant="fp16",
44
  ).to("cuda")
45
-
46
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
47
 
 
48
  def can_expand(source_width, source_height, target_width, target_height, alignment):
49
- """Checks if the image can be expanded based on the alignment."""
50
  if alignment in ("Left", "Right") and source_width >= target_width:
51
  return False
52
  if alignment in ("Top", "Bottom") and source_height >= target_height:
53
  return False
54
  return True
55
 
56
- def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
57
- target_size = (width, height)
58
-
59
- # Calculate the scaling factor to fit the image within the target size
60
- scale_factor = min(target_size[0] / image.width, target_size[1] / image.height)
61
- new_width = int(image.width * scale_factor)
62
- new_height = int(image.height * scale_factor)
63
-
64
- # Resize the source image to fit within target size
65
- source = image.resize((new_width, new_height), Image.LANCZOS)
66
-
67
- # Apply resize option using percentages
68
- if resize_option == "Full":
69
- resize_percentage = 100
70
- elif resize_option == "50%":
71
- resize_percentage = 50
72
- elif resize_option == "33%":
73
- resize_percentage = 33
74
- elif resize_option == "25%":
75
- resize_percentage = 25
76
- else: # Custom
77
- resize_percentage = custom_resize_percentage
78
-
79
- # Calculate new dimensions based on percentage
80
- resize_factor = resize_percentage / 100
81
- new_width = int(source.width * resize_factor)
82
- new_height = int(source.height * resize_factor)
83
-
84
- # Ensure minimum size of 64 pixels
85
- new_width = max(new_width, 64)
86
- new_height = max(new_height, 64)
87
-
88
- # Resize the image
89
- source = source.resize((new_width, new_height), Image.LANCZOS)
90
-
91
- # Calculate the overlap in pixels based on the percentage
92
- overlap_x = int(new_width * (overlap_percentage / 100))
93
- overlap_y = int(new_height * (overlap_percentage / 100))
94
-
95
- # Ensure minimum overlap of 1 pixel
96
- overlap_x = max(overlap_x, 1)
97
- overlap_y = max(overlap_y, 1)
98
-
99
- # Calculate margins based on alignment
100
- if alignment == "Middle":
101
- margin_x = (target_size[0] - new_width) // 2
102
- margin_y = (target_size[1] - new_height) // 2
103
- elif alignment == "Left":
104
- margin_x = 0
105
- margin_y = (target_size[1] - new_height) // 2
106
- elif alignment == "Right":
107
- margin_x = target_size[0] - new_width
108
- margin_y = (target_size[1] - new_height) // 2
109
- elif alignment == "Top":
110
- margin_x = (target_size[0] - new_width) // 2
111
- margin_y = 0
112
- elif alignment == "Bottom":
113
- margin_x = (target_size[0] - new_width) // 2
114
- margin_y = target_size[1] - new_height
115
-
116
- # Adjust margins to eliminate gaps
117
- margin_x = max(0, min(margin_x, target_size[0] - new_width))
118
- margin_y = max(0, min(margin_y, target_size[1] - new_height))
119
-
120
- # Create a new background image and paste the resized source image
121
- background = Image.new('RGB', target_size, (255, 255, 255))
122
- background.paste(source, (margin_x, margin_y))
123
-
124
- # Create the mask
125
- mask = Image.new('L', target_size, 255)
126
- mask_draw = ImageDraw.Draw(mask)
127
-
128
- # Calculate overlap areas
129
- white_gaps_patch = 2
130
-
131
- left_overlap = margin_x + overlap_x if overlap_left else margin_x + white_gaps_patch
132
- right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width - white_gaps_patch
133
- top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
134
- bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height - white_gaps_patch
135
-
136
- if alignment == "Left":
137
- left_overlap = margin_x + overlap_x if overlap_left else margin_x
138
- elif alignment == "Right":
139
- right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
140
- elif alignment == "Top":
141
- top_overlap = margin_y + overlap_y if overlap_top else margin_y
142
- elif alignment == "Bottom":
143
- bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
144
-
145
-
146
- # Draw the mask
147
- mask_draw.rectangle([
148
- (left_overlap, top_overlap),
149
- (right_overlap, bottom_overlap)
150
- ], fill=0)
151
-
152
- return background, mask
153
-
154
- def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
155
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
156
-
157
- # Create a preview image showing the mask
158
- preview = background.copy().convert('RGBA')
159
-
160
- # Create a semi-transparent red overlay
161
- red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64)) # Reduced alpha to 64 (25% opacity)
162
-
163
- # Convert black pixels in the mask to semi-transparent red
164
- red_mask = Image.new('RGBA', background.size, (0, 0, 0, 0))
165
- red_mask.paste(red_overlay, (0, 0), mask)
166
-
167
- # Overlay the red mask on the background
168
- preview = Image.alpha_composite(preview, red_mask)
169
-
170
- return preview
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  @spaces.GPU(duration=24)
173
- def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
174
- background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
175
-
 
 
 
 
 
 
176
  if not can_expand(background.width, background.height, width, height, alignment):
177
  alignment = "Middle"
178
 
179
- cnet_image = background.copy()
180
- cnet_image.paste(0, (0, 0), mask)
181
 
182
  final_prompt = f"{prompt_input} , high quality, 4k"
183
-
184
- (
185
- prompt_embeds,
186
- negative_prompt_embeds,
187
- pooled_prompt_embeds,
188
- negative_pooled_prompt_embeds,
189
- ) = pipe.encode_prompt(final_prompt, "cuda", True)
190
-
191
- for image in pipe(
192
- prompt_embeds=prompt_embeds,
193
- negative_prompt_embeds=negative_prompt_embeds,
194
- pooled_prompt_embeds=pooled_prompt_embeds,
195
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
196
- image=cnet_image,
197
  num_inference_steps=num_inference_steps
198
- ):
199
- yield cnet_image, image
 
 
200
 
201
- image = image.convert("RGBA")
202
- cnet_image.paste(image, (0, 0), mask)
 
 
 
203
 
204
- yield background, cnet_image
205
 
206
  def clear_result():
207
- """Clears the result ImageSlider."""
208
  return gr.update(value=None)
209
 
210
- def preload_presets(target_ratio, ui_width, ui_height):
211
- """Updates the width and height sliders based on the selected aspect ratio."""
212
- if target_ratio == "9:16":
213
- changed_width = 720
214
- changed_height = 1280
215
- return changed_width, changed_height, gr.update()
216
- elif target_ratio == "16:9":
217
- changed_width = 1280
218
- changed_height = 720
219
- return changed_width, changed_height, gr.update()
220
- elif target_ratio == "1:1":
221
- changed_width = 1024
222
- changed_height = 1024
223
- return changed_width, changed_height, gr.update()
224
- elif target_ratio == "Custom":
225
- return ui_width, ui_height, gr.update(open=True)
226
-
227
- def select_the_right_preset(user_width, user_height):
228
- if user_width == 720 and user_height == 1280:
229
- return "9:16"
230
- elif user_width == 1280 and user_height == 720:
231
- return "16:9"
232
- elif user_width == 1024 and user_height == 1024:
233
- return "1:1"
234
- else:
235
- return "Custom"
236
-
237
- def toggle_custom_resize_slider(resize_option):
238
- return gr.update(visible=(resize_option == "Custom"))
239
-
240
- def update_history(new_image, history):
241
- """Updates the history gallery with the new image."""
242
- if history is None:
243
- history = []
244
- history.insert(0, new_image)
245
- return history
246
 
247
- css = """
248
- .gradio-container {
249
- width: 1200px !important;
250
- }
251
- h1 { text-align: center; }
252
- footer { visibility: hidden; }
253
- """
254
 
255
- title = """<h1 align="center">Diffusers Image Outpaint Lightning</h1>
256
- """
257
 
258
- with gr.Blocks(css=css) as demo:
259
- with gr.Column():
260
- gr.HTML(title)
261
-
262
- with gr.Row():
263
- with gr.Column():
264
- input_image = gr.Image(
265
- type="pil",
266
- label="Input Image"
267
- )
268
-
269
- with gr.Row():
270
- with gr.Column(scale=2):
271
- prompt_input = gr.Textbox(label="Prompt (Optional)")
272
- with gr.Column(scale=1):
273
- run_button = gr.Button("Generate")
274
-
275
- with gr.Row():
276
- target_ratio = gr.Radio(
277
- label="Expected Ratio",
278
- choices=["9:16", "16:9", "1:1", "Custom"],
279
- value="9:16",
280
- scale=2
281
- )
282
-
283
- alignment_dropdown = gr.Dropdown(
284
- choices=["Middle", "Left", "Right", "Top", "Bottom"],
285
- value="Middle",
286
- label="Alignment"
287
- )
288
-
289
- with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
290
- with gr.Column():
291
- with gr.Row():
292
- width_slider = gr.Slider(
293
- label="Target Width",
294
- minimum=720,
295
- maximum=1536,
296
- step=8,
297
- value=720, # Set a default value
298
- )
299
- height_slider = gr.Slider(
300
- label="Target Height",
301
- minimum=720,
302
- maximum=1536,
303
- step=8,
304
- value=1280, # Set a default value
305
- )
306
-
307
- num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
308
- with gr.Group():
309
- overlap_percentage = gr.Slider(
310
- label="Mask overlap (%)",
311
- minimum=1,
312
- maximum=50,
313
- value=10,
314
- step=1
315
- )
316
- with gr.Row():
317
- overlap_top = gr.Checkbox(label="Overlap Top", value=True)
318
- overlap_right = gr.Checkbox(label="Overlap Right", value=True)
319
- with gr.Row():
320
- overlap_left = gr.Checkbox(label="Overlap Left", value=True)
321
- overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
322
- with gr.Row():
323
- resize_option = gr.Radio(
324
- label="Resize input image",
325
- choices=["Full", "50%", "33%", "25%", "Custom"],
326
- value="Full"
327
- )
328
- custom_resize_percentage = gr.Slider(
329
- label="Custom resize (%)",
330
- minimum=1,
331
- maximum=100,
332
- step=1,
333
- value=50,
334
- visible=False
335
- )
336
-
337
- with gr.Column():
338
- preview_button = gr.Button("Preview alignment and mask")
339
-
340
-
341
- gr.Examples(
342
- examples=[
343
- ["./examples/example_1.webp", 1280, 720, "Middle"],
344
- ["./examples/example_2.jpg", 1440, 810, "Left"],
345
- ["./examples/example_3.jpg", 1024, 1024, "Top"],
346
- ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
347
- ],
348
- inputs=[input_image, width_slider, height_slider, alignment_dropdown],
349
- )
350
-
351
-
352
-
353
- with gr.Column():
354
- result = ImageSlider(label="Generated Image", interactive=False, type="pil", slider_color="pink")
355
- use_as_input_button = gr.Button("Use as Input Image", visible=False)
356
-
357
- history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
358
- preview_image = gr.Image(label="Preview")
359
-
360
-
361
-
362
- def use_output_as_input(output_image):
363
- """Sets the generated output as the new input image."""
364
- return gr.update(value=output_image[1])
365
-
366
- use_as_input_button.click(
367
- fn=use_output_as_input,
368
- inputs=[result],
369
- outputs=[input_image]
370
- )
371
-
372
- target_ratio.change(
373
- fn=preload_presets,
374
- inputs=[target_ratio, width_slider, height_slider],
375
- outputs=[width_slider, height_slider, settings_panel],
376
- queue=False
377
- )
378
-
379
- width_slider.change(
380
- fn=select_the_right_preset,
381
- inputs=[width_slider, height_slider],
382
- outputs=[target_ratio],
383
- queue=False
384
- )
385
-
386
- height_slider.change(
387
- fn=select_the_right_preset,
388
- inputs=[width_slider, height_slider],
389
- outputs=[target_ratio],
390
- queue=False
391
- )
392
 
393
- resize_option.change(
394
- fn=toggle_custom_resize_slider,
395
- inputs=[resize_option],
396
- outputs=[custom_resize_percentage],
397
- queue=False
398
- )
399
-
400
- run_button.click( # Clear the result
401
- fn=clear_result,
402
- inputs=None,
403
- outputs=result,
404
- ).then( # Generate the new image
405
- fn=infer,
406
- inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
407
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
408
- overlap_left, overlap_right, overlap_top, overlap_bottom],
409
- outputs=result,
410
- ).then( # Update the history gallery
411
- fn=lambda x, history: update_history(x[1], history),
412
- inputs=[result, history_gallery],
413
- outputs=history_gallery,
414
- ).then( # Show the "Use as Input Image" button
415
- fn=lambda: gr.update(visible=True),
416
- inputs=None,
417
- outputs=use_as_input_button,
418
- )
419
 
420
- prompt_input.submit( # Clear the result
421
- fn=clear_result,
422
- inputs=None,
423
- outputs=result,
424
- ).then( # Generate the new image
425
- fn=infer,
426
- inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
427
- resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
428
- overlap_left, overlap_right, overlap_top, overlap_bottom],
429
- outputs=result,
430
- ).then( # Update the history gallery
431
- fn=lambda x, history: update_history(x[1], history),
432
- inputs=[result, history_gallery],
433
- outputs=history_gallery,
434
- ).then( # Show the "Use as Input Image" button
435
- fn=lambda: gr.update(visible=True),
436
- inputs=None,
437
- outputs=use_as_input_button,
438
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
- preview_button.click(
441
- fn=preview_image_and_mask,
442
- inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
443
- overlap_left, overlap_right, overlap_top, overlap_bottom],
444
- outputs=preview_image,
445
- queue=False
446
- )
447
 
448
  demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)
 
2
  import spaces
3
  import torch
4
  from diffusers import AutoencoderKL, TCDScheduler
5
+ # ControlNet weights are loaded further down via ControlNetModel_Union.from_pretrained.
 
 
 
6
  from controlnet_union import ControlNetModel_Union
7
  from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
8
+ from gradio_imageslider import ImageSlider
9
+ from huggingface_hub import hf_hub_download
10
 
11
  from PIL import Image, ImageDraw
12
  import numpy as np
13
 
14
# --- Load ControlNet and SDXL Fill Pipeline ---
# The ControlNet is resolved through from_pretrained (no manual hf_hub_download step).
# NOTE(review): the previous revision of this file explicitly loaded
# "config_promax.json" and "diffusion_pytorch_model_promax.safetensors" from the
# same repo. from_pretrained resolves the repo's default config/weights instead;
# confirm that the default checkpoint is the one the fill pipeline expects and
# that an "fp16" variant actually exists in the repo.
controlnet_model = ControlNetModel_Union.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0",
    torch_dtype=torch.float16,
    variant="fp16"
).to("cuda")

# Separate VAE checkpoint, loaded in float16 and handed to the pipeline below.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16
).to("cuda")

# Fill pipeline assembled from the base checkpoint plus the VAE and ControlNet above.
pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet_model,
    variant="fp16",
).to("cuda")
# Replace the pipeline's default scheduler with TCD, reusing the existing scheduler config.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
35
 
36
# --- Utility functions ---
def can_expand(source_width, source_height, target_width, target_height, alignment):
    """Return whether the source leaves room to grow along the aligned axis.

    "Left"/"Right" alignments need the source to be narrower than the target;
    "Top"/"Bottom" alignments need it to be shorter. Any other alignment
    (e.g. "Middle") is always considered expandable.
    """
    no_horizontal_room = alignment in ("Left", "Right") and source_width >= target_width
    no_vertical_room = alignment in ("Top", "Bottom") and source_height >= target_height
    return not (no_horizontal_room or no_vertical_room)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
def prepare_image_and_mask(image, width, height, overlap_percentage,
                           resize_option, custom_resize_percentage,
                           alignment, overlap_left, overlap_right,
                           overlap_top, overlap_bottom):
    """Place the source image on a white target canvas and build the outpaint mask.

    Args:
        image: Source PIL image.
        width, height: Size of the target canvas in pixels.
        overlap_percentage: Share of the resized source (per axis) that the
            mask may reclaim on sides whose overlap flag is set.
        resize_option: "Full", "50%", "33%" or "25%"; any other value selects
            ``custom_resize_percentage``.
        custom_resize_percentage: Percentage used for non-preset resize options.
        alignment: "Middle", "Left", "Right" or "Top"; any other value is
            treated as "Bottom".
        overlap_left, overlap_right, overlap_top, overlap_bottom: Whether the
            mask overlaps the source on that side.

    Returns:
        ``(background, mask)``: an RGB canvas of size ``(width, height)`` with
        the resized source pasted in, and an "L" mask that is 0 over the part
        of the source to keep and 255 everywhere else.

    Raises:
        gr.Error: If no input image was supplied.
    """
    if image is None:
        # Without this guard the first attribute access below fails with an
        # opaque AttributeError in the UI.
        raise gr.Error("Please provide an input image first.")

    # Canvas sizes must be integers for Image.new.
    width, height = int(width), int(height)
    target = (width, height)

    # Size of the source when fitted inside the canvas (aspect ratio preserved).
    fit_scale = min(width / image.width, height / image.height)
    fit_w, fit_h = int(image.width * fit_scale), int(image.height * fit_scale)

    # Optional extra shrink, expressed as a percentage of the fitted size.
    preset_percentages = {"Full": 100, "50%": 50, "33%": 33, "25%": 25}
    pct = preset_percentages.get(resize_option, custom_resize_percentage)

    # Never go below 64 px per side.
    rw = max(int(fit_w * pct / 100), 64)
    rh = max(int(fit_h * pct / 100), 64)

    # Resample once, straight to the final size, instead of fit-then-shrink:
    # same dimensions, one lossy LANCZOS pass instead of two.
    src = image.resize((rw, rh), Image.LANCZOS)

    # Overlap in pixels, at least one pixel per axis.
    ox = max(int(rw * overlap_percentage / 100), 1)
    oy = max(int(rh * overlap_percentage / 100), 1)

    # Top-left corner of the pasted source for the requested alignment.
    centered_x, centered_y = (width - rw) // 2, (height - rh) // 2
    if alignment == "Middle":
        mx, my = centered_x, centered_y
    elif alignment == "Left":
        mx, my = 0, centered_y
    elif alignment == "Right":
        mx, my = width - rw, centered_y
    elif alignment == "Top":
        mx, my = centered_x, 0
    else:
        mx, my = centered_x, height - rh

    # Keep the paste position inside the canvas.
    mx = max(0, min(mx, width - rw))
    my = max(0, min(my, height - rh))

    bg = Image.new("RGB", target, (255, 255, 255))
    bg.paste(src, (mx, my))

    mask = Image.new("L", target, 255)
    draw = ImageDraw.Draw(mask)

    # Sides without overlap still give up a thin strip to the mask, except on
    # the side the image is aligned to, where the kept region runs to the
    # image edge.
    edge_gap = 2

    def inset(use_overlap, overlap_px, aligned_to_side):
        if use_overlap:
            return overlap_px
        return 0 if aligned_to_side else edge_gap

    lx = mx + inset(overlap_left, ox, alignment == "Left")
    rx = mx + rw - inset(overlap_right, ox, alignment == "Right")
    ty = my + inset(overlap_top, oy, alignment == "Top")
    by = my + rh - inset(overlap_bottom, oy, alignment == "Bottom")

    # Black (0) marks the part of the source that is kept.
    draw.rectangle([(lx, ty), (rx, by)], fill=0)
    return bg, mask
95
+
96
+
97
def preview_image_and_mask(*args):
    """Render the prepared canvas with the masked (to-be-generated) area tinted red.

    All positional arguments are forwarded unchanged to
    ``prepare_image_and_mask``. Returns an RGBA image of the canvas size.
    """
    canvas, mask = prepare_image_and_mask(*args)

    # Semi-transparent red, applied only where the mask is non-zero.
    tint = Image.new("RGBA", canvas.size, (255, 0, 0, 64))
    tinted_region = Image.new("RGBA", canvas.size, (0, 0, 0, 0))
    tinted_region.paste(tint, (0, 0), mask)

    base = canvas.copy().convert("RGBA")
    return Image.alpha_composite(base, tinted_region)
104
+
105
# --- Inference: returns an image pair for the comparison slider ---
@spaces.GPU(duration=24)
def infer(image, width, height, overlap_percentage, num_inference_steps,
          resize_option, custom_resize_percentage, prompt_input,
          alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
    """Outpaint ``image`` onto a ``width`` x ``height`` canvas.

    The layout arguments are passed straight to ``prepare_image_and_mask``;
    ``prompt_input`` is combined with a fixed quality suffix and
    ``num_inference_steps`` is forwarded to the pipeline.

    Returns:
        ``[background, result]`` where ``background`` is the source placed on
        the white target canvas and ``result`` is the same canvas with the
        masked area replaced by the generated content.

    Raises:
        gr.Error: If no input image was supplied or the pipeline produced no
            output.
    """
    if image is None:
        raise gr.Error("Please provide an input image before generating.")

    background, mask = prepare_image_and_mask(
        image, width, height, overlap_percentage,
        resize_option, custom_resize_percentage,
        alignment, overlap_left, overlap_right,
        overlap_top, overlap_bottom
    )
    # `alignment` is fully consumed by prepare_image_and_mask; nothing below
    # depends on it.

    # Conditioning image: the canvas with the masked area blacked out.
    cnet_image = background.copy()
    cnet_image.paste(0, (0, 0), mask)

    final_prompt = f"{prompt_input} , high quality, 4k"
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(final_prompt, "cuda", True)

    # The pipeline call is iterated; only the last yielded image is kept.
    generated = None
    for generated in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=cnet_image,
        num_inference_steps=num_inference_steps
    ):
        pass

    if generated is None:
        raise gr.Error("The pipeline did not produce an image.")

    # Composite the generated pixels into the masked area only.
    generated = generated.convert("RGBA")
    cnet_image.paste(generated, (0, 0), mask)

    return [background, cnet_image]
143
 
 
144
 
145
def clear_result():
    """Return a Gradio update that resets the bound output component to no value."""
    cleared = gr.update(value=None)
    return cleared
147
 
148
def preload_presets(ratio, w, h):
    """Map the selected aspect-ratio preset to slider values.

    Returns ``(width, height, panel_update)``. Known presets return their
    fixed size with a no-op update; any other ratio keeps the current ``w``
    and ``h`` and returns an update with ``open=True``.
    """
    fixed_sizes = {
        "9:16": (720, 1280),
        "16:9": (1280, 720),
        "1:1": (1024, 1024),
    }
    if ratio in fixed_sizes:
        preset_w, preset_h = fixed_sizes[ratio]
        return preset_w, preset_h, gr.update()
    return w, h, gr.update(open=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
def select_the_right_preset(w, h):
    """Return the preset label matching ``(w, h)``, or "Custom" if none does."""
    labels_by_size = {
        (720, 1280): "9:16",
        (1280, 720): "16:9",
        (1024, 1024): "1:1",
    }
    return labels_by_size.get((w, h), "Custom")
 
 
159
 
160
def toggle_custom_resize_slider(opt):
    """Return an update that makes the bound component visible only for "Custom"."""
    is_custom = opt == "Custom"
    return gr.update(visible=is_custom)
162
 
163
def update_history(img, history):
    """Return the history with ``img`` prepended, newest first.

    Args:
        img: The new image, or the comparison slider's ``(before, after)``
            pair, in which case only the last element (the generated result)
            is stored. ``None`` or an empty pair adds nothing.
        history: Existing gallery entries, or ``None``/empty for no history.

    Returns:
        A new list; the ``history`` argument is not modified.
    """
    # Copy so the caller's list is never mutated in place.
    entries = list(history) if history else []

    # The slider value is a pair; only the generated image belongs in history.
    if isinstance(img, (list, tuple)):
        img = img[-1] if img else None

    # The slider holds None after it has been cleared; nothing to record then.
    if img is None:
        return entries

    entries.insert(0, img)
    return entries
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
# Stylesheet passed to gr.Blocks: pins the app container to a fixed width.
css = ".gradio-container { width: 1200px !important; }"
# HTML heading rendered at the top of the page via gr.HTML.
title = "<h1 align='center'>Diffusers Image Outpaint Lightning</h1>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
# NOTE(review): component nesting below is reconstructed from a diff view that
# lost indentation; confirm the placement of the preview button and examples.
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            prompt_input = gr.Textbox(label="Prompt (Optional)")
            run_button = gr.Button("Generate")

            target_ratio = gr.Radio(
                choices=["9:16", "16:9", "1:1", "Custom"],
                value="9:16",
                label="Expected Ratio",
            )
            alignment_dropdown = gr.Dropdown(
                choices=["Middle", "Left", "Right", "Top", "Bottom"],
                value="Middle",
                label="Alignment",
            )

            with gr.Accordion(label="Advanced settings", open=False) as adv:
                width_slider = gr.Slider(minimum=720, maximum=1536, step=8, value=720, label="Target Width")
                height_slider = gr.Slider(minimum=720, maximum=1536, step=8, value=1280, label="Target Height")
                num_steps = gr.Slider(minimum=4, maximum=12, step=1, value=8, label="Steps")
                overlap_pct = gr.Slider(minimum=1, maximum=50, step=1, value=10, label="Mask overlap (%)")
                overlap_top = gr.Checkbox(label="Overlap Top", value=True)
                overlap_right = gr.Checkbox(label="Overlap Right", value=True)
                overlap_left = gr.Checkbox(label="Overlap Left", value=True)
                overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
                resize_opt = gr.Radio(
                    choices=["Full", "50%", "33%", "25%", "Custom"],
                    value="Full",
                    label="Resize input image",
                )
                custom_resize = gr.Slider(
                    minimum=1, maximum=100, step=1, value=50,
                    visible=False, label="Custom resize (%)",
                )

            preview_btn = gr.Button("Preview alignment and mask")

            gr.Examples(
                examples=[
                    ["./examples/example_1.webp", 1280, 720, "Middle"],
                    ["./examples/example_2.jpg", 1440, 810, "Left"],
                    ["./examples/example_3.jpg", 1024, 1024, "Top"],
                    ["./examples/example_3.jpg", 1024, 1024, "Bottom"],
                ],
                inputs=[input_image, width_slider, height_slider, alignment_dropdown],
            )

        # Right column: outputs.
        with gr.Column():
            result = ImageSlider(label="Comparison", interactive=False, type="pil", slider_color="pink")
            history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain")
            preview_image = gr.Image(label="Preview")

    # --- Callbacks ---
    # Layout arguments shared by the preview and generation handlers, in the
    # order prepare_image_and_mask expects them.
    layout_tail = [alignment_dropdown, overlap_left, overlap_right, overlap_top, overlap_bottom]

    # One ordered chain: clear the slider, generate, then record the result.
    # Two independent click handlers would not guarantee that the clear runs
    # before the new result arrives.
    run_button.click(
        clear_result, inputs=None, outputs=result
    ).then(
        infer,
        inputs=[input_image, width_slider, height_slider, overlap_pct, num_steps,
                resize_opt, custom_resize, prompt_input] + layout_tail,
        outputs=result,
    ).success(
        # Only after a successful generation: store the generated image (the
        # second element of the slider pair), not the whole pair.
        lambda pair, history: update_history(pair[1], history),
        inputs=[result, history_gallery],
        outputs=history_gallery,
    )

    target_ratio.change(
        preload_presets,
        [target_ratio, width_slider, height_slider],
        [width_slider, height_slider, adv],
    )
    width_slider.change(select_the_right_preset, [width_slider, height_slider], target_ratio)
    height_slider.change(select_the_right_preset, [width_slider, height_slider], target_ratio)
    resize_opt.change(toggle_custom_resize_slider, resize_opt, custom_resize)
    preview_btn.click(
        preview_image_and_mask,
        [input_image, width_slider, height_slider, overlap_pct, resize_opt, custom_resize] + layout_tail,
        preview_image,
    )

demo.queue(max_size=20).launch(share=False, ssr_mode=False, show_error=True)