Ashoka74 committed on
Commit c1ad03b · verified · 1 Parent(s): 67198ef

Update gradio_demo.py

Files changed (1)
  1. gradio_demo.py +1104 -1103
gradio_demo.py CHANGED
@@ -1,1103 +1,1104 @@
1
- import os
2
- import math
3
- import gradio as gr
4
- import numpy as np
5
- import torch
6
- import safetensors.torch as sf
7
- import db_examples
8
- import datetime
9
- from pathlib import Path
10
- from io import BytesIO
11
-
12
- from PIL import Image
13
- from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
14
- from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
15
- from diffusers.models.attention_processor import AttnProcessor2_0
16
- from transformers import CLIPTextModel, CLIPTokenizer
17
- from briarmbg import BriaRMBG
18
- from enum import Enum
19
- from torch.hub import download_url_to_file
20
-
21
- from torch.hub import download_url_to_file
22
- import cv2
23
-
24
- from typing import Optional
25
-
26
- from Depth.depth_anything_v2.dpt import DepthAnythingV2
27
-
28
-
29
-
30
- # from FLORENCE
31
- import spaces
32
- import supervision as sv
33
- import torch
34
- from PIL import Image
35
-
36
- from utils.sam import load_sam_image_model, run_sam_inference
37
-
38
-
39
- try:
40
- import xformers
41
- import xformers.ops
42
- XFORMERS_AVAILABLE = True
43
- print("xformers is available - Using memory efficient attention")
44
- except ImportError:
45
- XFORMERS_AVAILABLE = False
46
- print("xformers not available - Using default attention")
47
-
48
- # Memory optimizations for RTX 2070
49
- torch.backends.cudnn.benchmark = True
50
- if torch.cuda.is_available():
51
- torch.backends.cuda.matmul.allow_tf32 = True
52
- torch.backends.cudnn.allow_tf32 = True
53
- # Set a smaller attention slice size for RTX 2070
54
- torch.backends.cuda.max_split_size_mb = 512
55
- device = torch.device('cuda')
56
- else:
57
- device = torch.device('cpu')
58
-
59
- # 'stablediffusionapi/realistic-vision-v51'
60
- # 'runwayml/stable-diffusion-v1-5'
61
- sd15_name = 'stablediffusionapi/realistic-vision-v51'
62
- tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
63
- text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
64
- vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
65
- unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
66
- rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
67
-
68
- model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
69
- model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
70
- model = model.to(device)
71
- model.eval()
72
-
73
- # Change UNet
74
-
75
- with torch.no_grad():
76
- new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
77
- new_conv_in.weight.zero_()
78
- new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
79
- new_conv_in.bias = unet.conv_in.bias
80
- unet.conv_in = new_conv_in
81
-
82
-
83
- unet_original_forward = unet.forward
84
-
85
-
86
- def enable_efficient_attention():
87
- if XFORMERS_AVAILABLE:
88
- try:
89
- # RTX 2070 specific settings
90
- unet.set_use_memory_efficient_attention_xformers(True)
91
- vae.set_use_memory_efficient_attention_xformers(True)
92
- print("Enabled xformers memory efficient attention")
93
- except Exception as e:
94
- print(f"Xformers error: {e}")
95
- print("Falling back to sliced attention")
96
- # Use sliced attention for RTX 2070
97
- unet.set_attention_slice_size(4)
98
- vae.set_attention_slice_size(4)
99
- unet.set_attn_processor(AttnProcessor2_0())
100
- vae.set_attn_processor(AttnProcessor2_0())
101
- else:
102
- # Fallback for when xformers is not available
103
- print("Using sliced attention")
104
- unet.set_attention_slice_size(4)
105
- vae.set_attention_slice_size(4)
106
- unet.set_attn_processor(AttnProcessor2_0())
107
- vae.set_attn_processor(AttnProcessor2_0())
108
-
109
- # Add memory clearing function
110
- def clear_memory():
111
- if torch.cuda.is_available():
112
- torch.cuda.empty_cache()
113
- torch.cuda.synchronize()
114
-
115
- # Enable efficient attention
116
- enable_efficient_attention()
117
-
118
-
119
- def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
120
- c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
121
- c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
122
- new_sample = torch.cat([sample, c_concat], dim=1)
123
- kwargs['cross_attention_kwargs'] = {}
124
- return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
125
-
126
-
127
- unet.forward = hooked_unet_forward
128
-
129
- # Load
130
-
131
- model_path = './models/iclight_sd15_fc.safetensors'
132
- # model_path = './models/iclight_sd15_fbc.safetensors'
133
-
134
-
135
- # if not os.path.exists(model_path):
136
- # download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
137
-
138
- sd_offset = sf.load_file(model_path)
139
- sd_origin = unet.state_dict()
140
- keys = sd_origin.keys()
141
- sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
142
- unet.load_state_dict(sd_merged, strict=True)
143
- del sd_offset, sd_origin, sd_merged, keys
144
-
145
- # Device
146
-
147
- # device = torch.device('cuda')
148
- # text_encoder = text_encoder.to(device=device, dtype=torch.float16)
149
- # vae = vae.to(device=device, dtype=torch.bfloat16)
150
- # unet = unet.to(device=device, dtype=torch.float16)
151
- # rmbg = rmbg.to(device=device, dtype=torch.float32)
152
-
153
-
154
- # Device and dtype setup
155
- device = torch.device('cuda')
156
- dtype = torch.float16 # RTX 2070 works well with float16
157
-
158
- # Memory optimizations for RTX 2070
159
- torch.backends.cudnn.benchmark = True
160
- if torch.cuda.is_available():
161
- torch.backends.cuda.matmul.allow_tf32 = True
162
- torch.backends.cudnn.allow_tf32 = True
163
- # Set a very small attention slice size for RTX 2070 to avoid OOM
164
- torch.backends.cuda.max_split_size_mb = 128
165
-
166
- # Move models to device with consistent dtype
167
- text_encoder = text_encoder.to(device=device, dtype=dtype)
168
- vae = vae.to(device=device, dtype=dtype) # Changed from bfloat16 to float16
169
- unet = unet.to(device=device, dtype=dtype)
170
- rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
171
-
172
-
173
- ddim_scheduler = DDIMScheduler(
174
- num_train_timesteps=1000,
175
- beta_start=0.00085,
176
- beta_end=0.012,
177
- beta_schedule="scaled_linear",
178
- clip_sample=False,
179
- set_alpha_to_one=False,
180
- steps_offset=1,
181
- )
182
-
183
- euler_a_scheduler = EulerAncestralDiscreteScheduler(
184
- num_train_timesteps=1000,
185
- beta_start=0.00085,
186
- beta_end=0.012,
187
- steps_offset=1
188
- )
189
-
190
- dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
191
- num_train_timesteps=1000,
192
- beta_start=0.00085,
193
- beta_end=0.012,
194
- algorithm_type="sde-dpmsolver++",
195
- use_karras_sigmas=True,
196
- steps_offset=1
197
- )
198
-
199
- # Pipelines
200
-
201
- t2i_pipe = StableDiffusionPipeline(
202
- vae=vae,
203
- text_encoder=text_encoder,
204
- tokenizer=tokenizer,
205
- unet=unet,
206
- scheduler=dpmpp_2m_sde_karras_scheduler,
207
- safety_checker=None,
208
- requires_safety_checker=False,
209
- feature_extractor=None,
210
- image_encoder=None
211
- )
212
-
213
- i2i_pipe = StableDiffusionImg2ImgPipeline(
214
- vae=vae,
215
- text_encoder=text_encoder,
216
- tokenizer=tokenizer,
217
- unet=unet,
218
- scheduler=dpmpp_2m_sde_karras_scheduler,
219
- safety_checker=None,
220
- requires_safety_checker=False,
221
- feature_extractor=None,
222
- image_encoder=None
223
- )
224
-
225
-
226
- @torch.inference_mode()
227
- def encode_prompt_inner(txt: str):
228
- max_length = tokenizer.model_max_length
229
- chunk_length = tokenizer.model_max_length - 2
230
- id_start = tokenizer.bos_token_id
231
- id_end = tokenizer.eos_token_id
232
- id_pad = id_end
233
-
234
- def pad(x, p, i):
235
- return x[:i] if len(x) >= i else x + [p] * (i - len(x))
236
-
237
- tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
238
- chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
239
- chunks = [pad(ck, id_pad, max_length) for ck in chunks]
240
-
241
- token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
242
- conds = text_encoder(token_ids).last_hidden_state
243
-
244
- return conds
245
-
246
-
247
- @torch.inference_mode()
248
- def encode_prompt_pair(positive_prompt, negative_prompt):
249
- c = encode_prompt_inner(positive_prompt)
250
- uc = encode_prompt_inner(negative_prompt)
251
-
252
- c_len = float(len(c))
253
- uc_len = float(len(uc))
254
- max_count = max(c_len, uc_len)
255
- c_repeat = int(math.ceil(max_count / c_len))
256
- uc_repeat = int(math.ceil(max_count / uc_len))
257
- max_chunk = max(len(c), len(uc))
258
-
259
- c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
260
- uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
261
-
262
- c = torch.cat([p[None, ...] for p in c], dim=1)
263
- uc = torch.cat([p[None, ...] for p in uc], dim=1)
264
-
265
- return c, uc
266
-
267
-
268
- @torch.inference_mode()
269
- def pytorch2numpy(imgs, quant=True):
270
- results = []
271
- for x in imgs:
272
- y = x.movedim(0, -1)
273
-
274
- if quant:
275
- y = y * 127.5 + 127.5
276
- y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
277
- else:
278
- y = y * 0.5 + 0.5
279
- y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
280
-
281
- results.append(y)
282
- return results
283
-
284
-
285
- @torch.inference_mode()
286
- def numpy2pytorch(imgs):
287
- h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
288
- h = h.movedim(-1, 1)
289
- return h
290
-
291
-
292
- def resize_and_center_crop(image, target_width, target_height):
293
- pil_image = Image.fromarray(image)
294
- original_width, original_height = pil_image.size
295
- scale_factor = max(target_width / original_width, target_height / original_height)
296
- resized_width = int(round(original_width * scale_factor))
297
- resized_height = int(round(original_height * scale_factor))
298
- resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
299
- left = (resized_width - target_width) / 2
300
- top = (resized_height - target_height) / 2
301
- right = (resized_width + target_width) / 2
302
- bottom = (resized_height + target_height) / 2
303
- cropped_image = resized_image.crop((left, top, right, bottom))
304
- return np.array(cropped_image)
305
-
306
-
307
- def resize_without_crop(image, target_width, target_height):
308
- pil_image = Image.fromarray(image)
309
- resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
310
- return np.array(resized_image)
311
-
312
-
313
- @torch.inference_mode()
314
- def run_rmbg(img, sigma=0.0):
315
- # Convert RGBA to RGB if needed
316
- if img.shape[-1] == 4:
317
- # Use white background for alpha composition
318
- alpha = img[..., 3:] / 255.0
319
- rgb = img[..., :3]
320
- white_bg = np.ones_like(rgb) * 255
321
- img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
322
-
323
- H, W, C = img.shape
324
- assert C == 3
325
- k = (256.0 / float(H * W)) ** 0.5
326
- feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
327
- feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
328
- alpha = rmbg(feed)[0][0]
329
- alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
330
- alpha = alpha.movedim(1, -1)[0]
331
- alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
332
-
333
- # Create RGBA image
334
- rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
335
- result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
336
- return result.clip(0, 255).astype(np.uint8), rgba
337
- @torch.inference_mode()
338
- def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
339
- clear_memory()
340
-
341
- # Get input dimensions
342
- input_height, input_width = input_fg.shape[:2]
343
-
344
- bg_source = BGSource(bg_source)
345
-
346
-
347
- if bg_source == BGSource.UPLOAD:
348
- pass
349
- elif bg_source == BGSource.UPLOAD_FLIP:
350
- input_bg = np.fliplr(input_bg)
351
- elif bg_source == BGSource.GREY:
352
- input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
353
- elif bg_source == BGSource.LEFT:
354
- gradient = np.linspace(255, 0, input_width)
355
- image = np.tile(gradient, (input_height, 1))
356
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
357
- elif bg_source == BGSource.RIGHT:
358
- gradient = np.linspace(0, 255, input_width)
359
- image = np.tile(gradient, (input_height, 1))
360
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
361
- elif bg_source == BGSource.TOP:
362
- gradient = np.linspace(255, 0, input_height)[:, None]
363
- image = np.tile(gradient, (1, input_width))
364
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
365
- elif bg_source == BGSource.BOTTOM:
366
- gradient = np.linspace(0, 255, input_height)[:, None]
367
- image = np.tile(gradient, (1, input_width))
368
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
369
- else:
370
- raise ValueError('Wrong initial latent!')
371
-
372
- rng = torch.Generator(device=device).manual_seed(int(seed))
373
-
374
- # Use input dimensions directly
375
- fg = resize_without_crop(input_fg, input_width, input_height)
376
-
377
- concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
378
- concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
379
-
380
- conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
381
-
382
- if input_bg is None:
383
- latents = t2i_pipe(
384
- prompt_embeds=conds,
385
- negative_prompt_embeds=unconds,
386
- width=input_width,
387
- height=input_height,
388
- num_inference_steps=steps,
389
- num_images_per_prompt=num_samples,
390
- generator=rng,
391
- output_type='latent',
392
- guidance_scale=cfg,
393
- cross_attention_kwargs={'concat_conds': concat_conds},
394
- ).images.to(vae.dtype) / vae.config.scaling_factor
395
- else:
396
- bg = resize_without_crop(input_bg, input_width, input_height)
397
- bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
398
- bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
399
- latents = i2i_pipe(
400
- image=bg_latent,
401
- strength=lowres_denoise,
402
- prompt_embeds=conds,
403
- negative_prompt_embeds=unconds,
404
- width=input_width,
405
- height=input_height,
406
- num_inference_steps=int(round(steps / lowres_denoise)),
407
- num_images_per_prompt=num_samples,
408
- generator=rng,
409
- output_type='latent',
410
- guidance_scale=cfg,
411
- cross_attention_kwargs={'concat_conds': concat_conds},
412
- ).images.to(vae.dtype) / vae.config.scaling_factor
413
-
414
- pixels = vae.decode(latents).sample
415
- pixels = pytorch2numpy(pixels)
416
- pixels = [resize_without_crop(
417
- image=p,
418
- target_width=int(round(input_width * highres_scale / 64.0) * 64),
419
- target_height=int(round(input_height * highres_scale / 64.0) * 64))
420
- for p in pixels]
421
-
422
- pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
423
- latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
424
- latents = latents.to(device=unet.device, dtype=unet.dtype)
425
-
426
- highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
427
-
428
- fg = resize_without_crop(input_fg, highres_width, highres_height)
429
- concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
430
- concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
431
-
432
- latents = i2i_pipe(
433
- image=latents,
434
- strength=highres_denoise,
435
- prompt_embeds=conds,
436
- negative_prompt_embeds=unconds,
437
- width=highres_width,
438
- height=highres_height,
439
- num_inference_steps=int(round(steps / highres_denoise)),
440
- num_images_per_prompt=num_samples,
441
- generator=rng,
442
- output_type='latent',
443
- guidance_scale=cfg,
444
- cross_attention_kwargs={'concat_conds': concat_conds},
445
- ).images.to(vae.dtype) / vae.config.scaling_factor
446
-
447
- pixels = vae.decode(latents).sample
448
- pixels = pytorch2numpy(pixels)
449
-
450
- # Resize back to input dimensions
451
- pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
452
- pixels = np.stack(pixels)
453
-
454
- return pixels
455
-
456
- @torch.inference_mode()
457
- def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
458
- clear_memory()
459
- bg_source = BGSource(bg_source)
460
-
461
- if bg_source == BGSource.UPLOAD:
462
- pass
463
- elif bg_source == BGSource.UPLOAD_FLIP:
464
- input_bg = np.fliplr(input_bg)
465
- elif bg_source == BGSource.GREY:
466
- input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
467
- elif bg_source == BGSource.LEFT:
468
- gradient = np.linspace(224, 32, image_width)
469
- image = np.tile(gradient, (image_height, 1))
470
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
471
- elif bg_source == BGSource.RIGHT:
472
- gradient = np.linspace(32, 224, image_width)
473
- image = np.tile(gradient, (image_height, 1))
474
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
475
- elif bg_source == BGSource.TOP:
476
- gradient = np.linspace(224, 32, image_height)[:, None]
477
- image = np.tile(gradient, (1, image_width))
478
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
479
- elif bg_source == BGSource.BOTTOM:
480
- gradient = np.linspace(32, 224, image_height)[:, None]
481
- image = np.tile(gradient, (1, image_width))
482
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
483
- else:
484
- raise ValueError('Wrong background source!')
485
-
486
- rng = torch.Generator(device=device).manual_seed(seed)
487
-
488
- fg = resize_and_center_crop(input_fg, image_width, image_height)
489
- bg = resize_and_center_crop(input_bg, image_width, image_height)
490
- concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
491
- concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
492
- concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
493
-
494
- conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
495
-
496
- latents = t2i_pipe(
497
- prompt_embeds=conds,
498
- negative_prompt_embeds=unconds,
499
- width=image_width,
500
- height=image_height,
501
- num_inference_steps=steps,
502
- num_images_per_prompt=num_samples,
503
- generator=rng,
504
- output_type='latent',
505
- guidance_scale=cfg,
506
- cross_attention_kwargs={'concat_conds': concat_conds},
507
- ).images.to(vae.dtype) / vae.config.scaling_factor
508
-
509
- pixels = vae.decode(latents).sample
510
- pixels = pytorch2numpy(pixels)
511
- pixels = [resize_without_crop(
512
- image=p,
513
- target_width=int(round(image_width * highres_scale / 64.0) * 64),
514
- target_height=int(round(image_height * highres_scale / 64.0) * 64))
515
- for p in pixels]
516
-
517
- pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
518
- latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
519
- latents = latents.to(device=unet.device, dtype=unet.dtype)
520
-
521
- image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
522
- fg = resize_and_center_crop(input_fg, image_width, image_height)
523
- bg = resize_and_center_crop(input_bg, image_width, image_height)
524
- concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
525
- concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
526
- concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
527
-
528
- latents = i2i_pipe(
529
- image=latents,
530
- strength=highres_denoise,
531
- prompt_embeds=conds,
532
- negative_prompt_embeds=unconds,
533
- width=image_width,
534
- height=image_height,
535
- num_inference_steps=int(round(steps / highres_denoise)),
536
- num_images_per_prompt=num_samples,
537
- generator=rng,
538
- output_type='latent',
539
- guidance_scale=cfg,
540
- cross_attention_kwargs={'concat_conds': concat_conds},
541
- ).images.to(vae.dtype) / vae.config.scaling_factor
542
-
543
- pixels = vae.decode(latents).sample
544
- pixels = pytorch2numpy(pixels, quant=False)
545
-
546
- clear_memory()
547
- return pixels, [fg, bg]
548
-
549
-
550
- @torch.inference_mode()
551
- def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
552
- input_fg, matting = run_rmbg(input_fg)
553
- results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
554
- return input_fg, results
555
-
556
-
557
-
558
- @torch.inference_mode()
559
- def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
560
- bg_source = BGSource(bg_source)
561
-
562
- # Convert numerical inputs to appropriate types
563
- image_width = int(image_width)
564
- image_height = int(image_height)
565
- num_samples = int(num_samples)
566
- seed = int(seed)
567
- steps = int(steps)
568
- cfg = float(cfg)
569
- highres_scale = float(highres_scale)
570
- highres_denoise = float(highres_denoise)
571
-
572
- if bg_source == BGSource.UPLOAD:
573
- pass
574
- elif bg_source == BGSource.UPLOAD_FLIP:
575
- input_bg = np.fliplr(input_bg)
576
- elif bg_source == BGSource.GREY:
577
- input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
578
- elif bg_source == BGSource.LEFT:
579
- gradient = np.linspace(224, 32, image_width)
580
- image = np.tile(gradient, (image_height, 1))
581
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
582
- elif bg_source == BGSource.RIGHT:
583
- gradient = np.linspace(32, 224, image_width)
584
- image = np.tile(gradient, (image_height, 1))
585
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
586
- elif bg_source == BGSource.TOP:
587
- gradient = np.linspace(224, 32, image_height)[:, None]
588
- image = np.tile(gradient, (1, image_width))
589
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
590
- elif bg_source == BGSource.BOTTOM:
591
- gradient = np.linspace(32, 224, image_height)[:, None]
592
- image = np.tile(gradient, (1, image_width))
593
- input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
594
- else:
595
- raise ValueError('Wrong background source!')
596
-
597
- input_fg, matting = run_rmbg(input_fg)
598
- results, extra_images = process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
599
- results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
600
- final_results = results + extra_images
601
-
602
- # Save the generated images
603
- save_images(results, prefix="relight")
604
-
605
- return results
606
-
607
-
608
- quick_prompts = [
609
- 'sunshine from window',
610
- 'neon light, city',
611
- 'sunset over sea',
612
- 'golden time',
613
- 'sci-fi RGB glowing, cyberpunk',
614
- 'natural lighting',
615
- 'warm atmosphere, at home, bedroom',
616
- 'magic lit',
617
- 'evil, gothic, Yharnam',
618
- 'light and shadow',
619
- 'shadow from window',
620
- 'soft studio lighting',
621
- 'home atmosphere, cozy bedroom illumination',
622
- 'neon, Wong Kar-wai, warm'
623
- ]
624
- quick_prompts = [[x] for x in quick_prompts]
625
-
626
-
627
- quick_subjects = [
628
- 'modern sofa, high quality leather',
629
- 'elegant dining table, polished wood',
630
- 'luxurious bed, premium mattress',
631
- 'minimalist office desk, clean design',
632
- 'vintage wooden cabinet, antique finish',
633
- ]
634
- quick_subjects = [[x] for x in quick_subjects]
635
-
636
-
637
- class BGSource(Enum):
638
- UPLOAD = "Use Background Image"
639
- UPLOAD_FLIP = "Use Flipped Background Image"
640
- LEFT = "Left Light"
641
- RIGHT = "Right Light"
642
- TOP = "Top Light"
643
- BOTTOM = "Bottom Light"
644
- GREY = "Ambient"
645
-
646
- # Add save function
647
- def save_images(images, prefix="relight"):
648
- # Create output directory if it doesn't exist
649
- output_dir = Path("outputs")
650
- output_dir.mkdir(exist_ok=True)
651
-
652
- # Create timestamp for unique filenames
653
- timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
654
-
655
- saved_paths = []
656
- for i, img in enumerate(images):
657
- if isinstance(img, np.ndarray):
658
- # Convert to PIL Image if numpy array
659
- img = Image.fromarray(img)
660
-
661
- # Create filename with timestamp
662
- filename = f"{prefix}_{timestamp}_{i+1}.png"
663
- filepath = output_dir / filename
664
-
665
- # Save image
666
- img.save(filepath)
667
-
668
-
669
- # print(f"Saved {len(saved_paths)} images to {output_dir}")
670
- return saved_paths
671
-
672
-
673
- class MaskMover:
674
- def __init__(self):
675
- self.extracted_fg = None
676
- self.original_fg = None # Store original foreground
677
-
678
- def set_extracted_fg(self, fg_image):
679
- """Store the extracted foreground with alpha channel"""
680
- if isinstance(fg_image, np.ndarray):
681
- self.extracted_fg = fg_image.copy()
682
- self.original_fg = fg_image.copy()
683
- else:
684
- self.extracted_fg = np.array(fg_image)
685
- self.original_fg = np.array(fg_image)
686
- return self.extracted_fg
687
-
688
- def create_composite(self, background, x_pos, y_pos, scale=1.0):
689
- """Create composite with foreground at specified position"""
690
- if self.original_fg is None or background is None:
691
- return background
692
-
693
- # Convert inputs to PIL Images
694
- if isinstance(background, np.ndarray):
695
- bg = Image.fromarray(background).convert('RGBA')
696
- else:
697
- bg = background.convert('RGBA')
698
-
699
- if isinstance(self.original_fg, np.ndarray):
700
- fg = Image.fromarray(self.original_fg).convert('RGBA')
701
- else:
702
- fg = self.original_fg.convert('RGBA')
703
-
704
- # Scale the foreground size
705
- new_width = int(fg.width * scale)
706
- new_height = int(fg.height * scale)
707
- fg = fg.resize((new_width, new_height), Image.LANCZOS)
708
-
709
- # Center the scaled foreground at the position
710
- x = int(x_pos - new_width / 2)
711
- y = int(y_pos - new_height / 2)
712
-
713
- # Create composite
714
- result = bg.copy()
715
- result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
716
-
717
- return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
718
-
719
- def get_depth(image):
720
- if image is None:
721
- return None
722
- # Convert from PIL/gradio format to cv2
723
- raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
724
- # Get depth map
725
- depth = model.infer_image(raw_img) # HxW raw depth map
726
- # Normalize depth for visualization
727
- depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
728
- # Convert to RGB for display
729
- depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
730
- depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
731
- return Image.fromarray(depth_colored)
732
-
733
-
734
- from PIL import Image
735
-
736
- def compress_image(image):
737
- # Convert Gradio image (numpy array) to PIL Image
738
- img = Image.fromarray(image)
739
-
740
- # Resize image if dimensions are too large
741
- max_size = 1024 # Maximum dimension size
742
- if img.width > max_size or img.height > max_size:
743
- ratio = min(max_size/img.width, max_size/img.height)
744
- new_size = (int(img.width * ratio), int(img.height * ratio))
745
- img = img.resize(new_size, Image.Resampling.LANCZOS)
746
-
747
- quality = 95 # Start with high quality
748
- img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
749
-
750
- # Check file size and adjust quality if necessary
751
- while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
752
- quality -= 5 # Decrease quality
753
- img.save("compressed_image.jpg", "JPEG", quality=quality)
754
- if quality < 20: # Prevent quality from going too low
755
- break
756
-
757
- # Convert back to numpy array for Gradio
758
- compressed_img = np.array(Image.open("compressed_image.jpg"))
759
- return compressed_img
760
-
761
-
762
- block = gr.Blocks().queue()
763
- with block:
764
- with gr.Tab("Text"):
765
- with gr.Row():
766
- gr.Markdown("## Product Placement from Text")
767
- with gr.Row():
768
- with gr.Column():
769
- with gr.Row():
770
- input_fg = gr.Image(type="numpy", label="Image", height=480)
771
- output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
772
- with gr.Group():
773
- prompt = gr.Textbox(label="Prompt")
774
- bg_source = gr.Radio(choices=[e.value for e in BGSource],
775
- value=BGSource.GREY.value,
776
- label="Lighting Preference (Initial Latent)", type='value')
777
- example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
778
- example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
779
- relight_button = gr.Button(value="Relight")
780
-
781
- with gr.Group():
782
- with gr.Row():
783
- num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
784
- seed = gr.Number(label="Seed", value=12345, precision=0)
785
-
786
- with gr.Row():
787
- image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
788
- image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
789
-
790
- with gr.Accordion("Advanced options", open=False):
791
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
792
- cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
793
- lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
794
- highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
795
- highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
796
- a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
797
- n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
798
- with gr.Column():
799
- result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
800
- with gr.Row():
801
- dummy_image_for_outputs = gr.Image(visible=False, label='Result')
802
- # gr.Examples(
803
- # fn=lambda *args: ([args[-1]], None),
804
- # examples=db_examples.foreground_conditioned_examples,
805
- # inputs=[
806
- # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
807
- # ],
808
- # outputs=[result_gallery, output_bg],
809
- # run_on_click=True, examples_per_page=1024
810
- # )
811
- ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
812
- relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
813
- example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
814
- example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
815
-
816
- with gr.Tab("Background", visible=False):
817
- mask_mover = MaskMover()
818
-
819
-
820
- with gr.Row():
821
- gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
822
- gr.Markdown("💾 Generated images are automatically saved to 'outputs' folder")
823
-
824
- with gr.Row():
825
- with gr.Column():
826
- # Step 1: Input and Extract
827
- with gr.Row():
828
- with gr.Group():
829
- gr.Markdown("### Step 1: Extract Foreground")
830
- input_image = gr.Image(type="numpy", label="Input Image", height=480)
831
- # find_objects_button = gr.Button(value="Find Objects")
832
- extract_button = gr.Button(value="Remove Background")
833
- extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
834
-
835
- with gr.Row():
836
- # Step 2: Background and Position
837
- with gr.Group():
838
- gr.Markdown("### Step 2: Position on Background")
839
- input_bg = gr.Image(type="numpy", label="Background Image", height=480)
840
-
841
- with gr.Row():
842
- x_slider = gr.Slider(
843
- minimum=0,
844
- maximum=1000,
845
- label="X Position",
846
- value=500,
847
- visible=False
848
- )
849
- y_slider = gr.Slider(
850
- minimum=0,
851
- maximum=1000,
852
- label="Y Position",
853
- value=500,
854
- visible=False
855
- )
856
- fg_scale_slider = gr.Slider(
857
- label="Foreground Scale",
858
- minimum=0.01,
859
- maximum=3.0,
860
- value=1.0,
861
- step=0.01
862
- )
863
-
864
- editor = gr.ImageEditor(
865
- type="numpy",
866
- label="Position Foreground",
867
- height=480,
868
- visible=False
869
- )
870
- get_depth_button = gr.Button(value="Get Depth")
871
- depth_image = gr.Image(type="numpy", label="Depth Image", height=480)
872
-
873
- # Step 3: Relighting Options
874
- with gr.Group():
875
- gr.Markdown("### Step 3: Relighting Settings")
876
- prompt = gr.Textbox(label="Prompt")
877
- bg_source = gr.Radio(
878
- choices=[e.value for e in BGSource],
879
- value=BGSource.UPLOAD.value,
880
- label="Background Source",
881
- type='value'
882
- )
883
-
884
- example_prompts = gr.Dataset(
885
- samples=quick_prompts,
886
- label='Prompt Quick List',
887
- components=[prompt]
888
- )
889
- # bg_gallery = gr.Gallery(
890
- # height=450,
891
- # label='Background Quick List',
892
- # value=db_examples.bg_samples,
893
- # columns=5,
894
- # allow_preview=False
895
- # )
896
- relight_button_bg = gr.Button(value="Relight")
897
-
898
- # Additional settings
899
- with gr.Group():
900
- with gr.Row():
901
- num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
902
- seed = gr.Number(label="Seed", value=12345, precision=0)
903
- with gr.Row():
904
- image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
905
- image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
906
-
907
- with gr.Accordion("Advanced options", open=False):
908
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
909
- cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
910
- highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
911
- highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
912
- a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
913
- n_prompt = gr.Textbox(
914
- label="Negative Prompt",
915
- value='lowres, bad anatomy, bad hands, cropped, worst quality'
916
- )
917
-
918
- with gr.Column():
919
- result_gallery = gr.Image(height=832, label='Outputs')
920
-
921
- def extract_foreground(image):
922
- if image is None:
923
- return None, gr.update(visible=True), gr.update(visible=True)
924
- result, rgba = run_rmbg(image)
925
- mask_mover.set_extracted_fg(rgba)
926
-
927
- return result, gr.update(visible=True), gr.update(visible=True)
928
-
929
-
930
- original_bg = None
931
-
932
- extract_button.click(
933
- fn=extract_foreground,
934
- inputs=[input_image],
935
- outputs=[extracted_fg, x_slider, y_slider]
936
- )
937
-
938
- # find_objects_button.click(
939
- # fn=find_objects,
940
- # inputs=[input_image],
941
- # outputs=[extracted_fg]
942
- # )
943
-
944
- get_depth_button.click(
945
- fn=get_depth,
946
- inputs=[input_bg],
947
- outputs=[depth_image]
948
- )
949
-
950
- # def update_position(background, x_pos, y_pos, scale):
951
- # """Update composite when position changes"""
952
- # global original_bg
953
- # if background is None:
954
- # return None
955
-
956
- # if original_bg is None:
957
- # original_bg = background.copy()
958
-
959
- # # Convert string values to float
960
- # x_pos = float(x_pos)
961
- # y_pos = float(y_pos)
962
- # scale = float(scale)
963
-
964
- # return mask_mover.create_composite(original_bg, x_pos, y_pos, scale)
965
-
966
- class BackgroundManager:
967
- def __init__(self):
968
- self.original_bg = None
969
-
970
- def update_position(self, background, x_pos, y_pos, scale):
971
- """Update composite when position changes"""
972
- if background is None:
973
- return None
974
-
975
- if self.original_bg is None:
976
- self.original_bg = background.copy()
977
-
978
- # Convert string values to float
979
- x_pos = float(x_pos)
980
- y_pos = float(y_pos)
981
- scale = float(scale)
982
-
983
- return mask_mover.create_composite(self.original_bg, x_pos, y_pos, scale)
984
-
985
- # Create an instance of BackgroundManager
986
- bg_manager = BackgroundManager()
987
-
988
-
989
- x_slider.change(
990
- fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
991
- inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
992
- outputs=[input_bg]
993
- )
994
-
995
- y_slider.change(
996
- fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
997
- inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
998
- outputs=[input_bg]
999
- )
1000
-
1001
- fg_scale_slider.change(
1002
- fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1003
- inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1004
- outputs=[input_bg]
1005
- )
1006
-
1007
- # Update inputs list to include fg_scale_slider
1008
-
1009
- def process_relight_with_position(*args):
1010
- if mask_mover.extracted_fg is None:
1011
- gr.Warning("Please extract foreground first")
1012
- return None
1013
-
1014
- background = args[1] # Get background image
1015
- x_pos = float(args[-3]) # x_slider value
1016
- y_pos = float(args[-2]) # y_slider value
1017
- scale = float(args[-1]) # fg_scale_slider value
1018
-
1019
- # Get original foreground size after scaling
1020
- fg = Image.fromarray(mask_mover.original_fg)
1021
- new_width = int(fg.width * scale)
1022
- new_height = int(fg.height * scale)
1023
-
1024
- # Calculate crop region around foreground position
1025
- crop_x = int(x_pos - new_width/2)
1026
- crop_y = int(y_pos - new_height/2)
1027
- crop_width = new_width
1028
- crop_height = new_height
1029
-
1030
- # Add padding for context (20% extra on each side)
1031
- padding = 0.2
1032
- crop_x = int(crop_x - crop_width * padding)
1033
- crop_y = int(crop_y - crop_height * padding)
1034
- crop_width = int(crop_width * (1 + 2 * padding))
1035
- crop_height = int(crop_height * (1 + 2 * padding))
1036
-
1037
- # Ensure crop dimensions are multiples of 8
1038
- crop_width = ((crop_width + 7) // 8) * 8
1039
- crop_height = ((crop_height + 7) // 8) * 8
1040
-
1041
- # Ensure crop region is within image bounds
1042
- bg_height, bg_width = background.shape[:2]
1043
- crop_x = max(0, min(crop_x, bg_width - crop_width))
1044
- crop_y = max(0, min(crop_y, bg_height - crop_height))
1045
-
1046
- # Get actual crop dimensions after boundary check
1047
- crop_width = min(crop_width, bg_width - crop_x)
1048
- crop_height = min(crop_height, bg_height - crop_y)
1049
-
1050
- # Ensure dimensions are multiples of 8 again
1051
- crop_width = (crop_width // 8) * 8
1052
- crop_height = (crop_height // 8) * 8
1053
-
1054
- # Crop region from background
1055
- crop_region = background[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width]
1056
-
1057
- # Create composite in cropped region
1058
- fg_local_x = int(new_width/2 + crop_width*padding)
1059
- fg_local_y = int(new_height/2 + crop_height*padding)
1060
- cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)
1061
-
1062
- # Process the cropped region
1063
- crop_args = list(args)
1064
- crop_args[0] = cropped_composite
1065
- crop_args[1] = crop_region
1066
- crop_args[3] = crop_width
1067
- crop_args[4] = crop_height
1068
- crop_args = crop_args[:-3] # Remove position and scale arguments
1069
-
1070
- # Get relit result
1071
- relit_crop = process_relight_bg(*crop_args)[0]
1072
-
1073
- # Resize relit result to match crop dimensions if needed
1074
- if relit_crop.shape[:2] != (crop_height, crop_width):
1075
- relit_crop = resize_without_crop(relit_crop, crop_width, crop_height)
1076
-
1077
- # Place relit crop back into original background
1078
- result = background.copy()
1079
- result[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width] = relit_crop
1080
-
1081
- return result
1082
-
1083
- ips_bg = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source]
1084
-
1085
- # Update button click events with new inputs list
1086
- relight_button_bg.click(
1087
- fn=process_relight_with_position,
1088
- inputs=ips_bg,
1089
- outputs=[result_gallery]
1090
- )
1091
-
1092
-
1093
- example_prompts.click(
1094
- fn=lambda x: x[0],
1095
- inputs=example_prompts,
1096
- outputs=prompt,
1097
- show_progress=False,
1098
- queue=False
1099
- )
1100
-
1101
-
1102
-
1103
- block.launch(server_name='0.0.0.0', share=True)
 
 
1
+ import spaces
2
+ import os
3
+ import math
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import safetensors.torch as sf
8
+ import db_examples
9
+ import datetime
10
+ from pathlib import Path
11
+ from io import BytesIO
12
+
13
+ from PIL import Image
14
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
15
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
16
+ from diffusers.models.attention_processor import AttnProcessor2_0
17
+ from transformers import CLIPTextModel, CLIPTokenizer
18
+ from briarmbg import BriaRMBG
19
+ from enum import Enum
20
+ from torch.hub import download_url_to_file
21
+
22
+ from torch.hub import download_url_to_file
23
+ import cv2
24
+
25
+ from typing import Optional
26
+
27
+ from Depth.depth_anything_v2.dpt import DepthAnythingV2
28
+
29
+
30
+
31
+ # from FLORENCE
32
+
33
+ import supervision as sv
34
+ import torch
35
+ from PIL import Image
36
+
37
+ from utils.sam import load_sam_image_model, run_sam_inference
38
+
39
+
40
+ try:
41
+ import xformers
42
+ import xformers.ops
43
+ XFORMERS_AVAILABLE = True
44
+ print("xformers is available - Using memory efficient attention")
45
+ except ImportError:
46
+ XFORMERS_AVAILABLE = False
47
+ print("xformers not available - Using default attention")
48
+
49
+ # Memory optimizations for RTX 2070
50
+ torch.backends.cudnn.benchmark = True
51
+ if torch.cuda.is_available():
52
+ torch.backends.cuda.matmul.allow_tf32 = True
53
+ torch.backends.cudnn.allow_tf32 = True
54
+ # Set a smaller attention slice size for RTX 2070
55
+ torch.backends.cuda.max_split_size_mb = 512
56
+ device = torch.device('cuda')
57
+ else:
58
+ device = torch.device('cpu')
59
+
60
+ # 'stablediffusionapi/realistic-vision-v51'
61
+ # 'runwayml/stable-diffusion-v1-5'
62
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
63
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
64
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
65
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
66
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
67
+ rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
68
+
69
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
70
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
71
+ model = model.to(device)
72
+ model.eval()
73
+
74
+ # Change UNet
75
+
76
+ with torch.no_grad():
77
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
78
+ new_conv_in.weight.zero_()
79
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
80
+ new_conv_in.bias = unet.conv_in.bias
81
+ unet.conv_in = new_conv_in
82
+
83
+
84
+ unet_original_forward = unet.forward
85
+
86
+
87
+ def enable_efficient_attention():
88
+ if XFORMERS_AVAILABLE:
89
+ try:
90
+ # RTX 2070 specific settings
91
+ unet.set_use_memory_efficient_attention_xformers(True)
92
+ vae.set_use_memory_efficient_attention_xformers(True)
93
+ print("Enabled xformers memory efficient attention")
94
+ except Exception as e:
95
+ print(f"Xformers error: {e}")
96
+ print("Falling back to sliced attention")
97
+ # Use sliced attention for RTX 2070
98
+ unet.set_attention_slice_size(4)
99
+ vae.set_attention_slice_size(4)
100
+ unet.set_attn_processor(AttnProcessor2_0())
101
+ vae.set_attn_processor(AttnProcessor2_0())
102
+ else:
103
+ # Fallback for when xformers is not available
104
+ print("Using sliced attention")
105
+ unet.set_attention_slice_size(4)
106
+ vae.set_attention_slice_size(4)
107
+ unet.set_attn_processor(AttnProcessor2_0())
108
+ vae.set_attn_processor(AttnProcessor2_0())
109
+
110
+ # Add memory clearing function
111
+ def clear_memory():
112
+ if torch.cuda.is_available():
113
+ torch.cuda.empty_cache()
114
+ torch.cuda.synchronize()
115
+
116
+ # Enable efficient attention
117
+ enable_efficient_attention()
118
+
119
+
120
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
121
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
122
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
123
+ new_sample = torch.cat([sample, c_concat], dim=1)
124
+ kwargs['cross_attention_kwargs'] = {}
125
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
126
+
127
+
128
+ unet.forward = hooked_unet_forward
129
+
130
+ # Load
131
+
132
+ model_path = './models/iclight_sd15_fc.safetensors'
133
+ # model_path = './models/iclight_sd15_fbc.safetensors'
134
+
135
+
136
+ # if not os.path.exists(model_path):
137
+ # download_url_to_file(url='https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fc.safetensors', dst=model_path)
138
+
139
+ sd_offset = sf.load_file(model_path)
140
+ sd_origin = unet.state_dict()
141
+ keys = sd_origin.keys()
142
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
143
+ unet.load_state_dict(sd_merged, strict=True)
144
+ del sd_offset, sd_origin, sd_merged, keys
145
+
146
+ # Device
147
+
148
+ # device = torch.device('cuda')
149
+ # text_encoder = text_encoder.to(device=device, dtype=torch.float16)
150
+ # vae = vae.to(device=device, dtype=torch.bfloat16)
151
+ # unet = unet.to(device=device, dtype=torch.float16)
152
+ # rmbg = rmbg.to(device=device, dtype=torch.float32)
153
+
154
+
155
+ # Device and dtype setup
156
+ device = torch.device('cuda')
157
+ dtype = torch.float16 # RTX 2070 works well with float16
158
+
159
+ # Memory optimizations for RTX 2070
160
+ torch.backends.cudnn.benchmark = True
161
+ if torch.cuda.is_available():
162
+ torch.backends.cuda.matmul.allow_tf32 = True
163
+ torch.backends.cudnn.allow_tf32 = True
164
+ # Set a very small attention slice size for RTX 2070 to avoid OOM
165
+ torch.backends.cuda.max_split_size_mb = 128
166
+
167
+ # Move models to device with consistent dtype
168
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
169
+ vae = vae.to(device=device, dtype=dtype) # Changed from bfloat16 to float16
170
+ unet = unet.to(device=device, dtype=dtype)
171
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
172
+
173
+
174
+ ddim_scheduler = DDIMScheduler(
175
+ num_train_timesteps=1000,
176
+ beta_start=0.00085,
177
+ beta_end=0.012,
178
+ beta_schedule="scaled_linear",
179
+ clip_sample=False,
180
+ set_alpha_to_one=False,
181
+ steps_offset=1,
182
+ )
183
+
184
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
185
+ num_train_timesteps=1000,
186
+ beta_start=0.00085,
187
+ beta_end=0.012,
188
+ steps_offset=1
189
+ )
190
+
191
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
192
+ num_train_timesteps=1000,
193
+ beta_start=0.00085,
194
+ beta_end=0.012,
195
+ algorithm_type="sde-dpmsolver++",
196
+ use_karras_sigmas=True,
197
+ steps_offset=1
198
+ )
199
+
200
+ # Pipelines
201
+
202
+ t2i_pipe = StableDiffusionPipeline(
203
+ vae=vae,
204
+ text_encoder=text_encoder,
205
+ tokenizer=tokenizer,
206
+ unet=unet,
207
+ scheduler=dpmpp_2m_sde_karras_scheduler,
208
+ safety_checker=None,
209
+ requires_safety_checker=False,
210
+ feature_extractor=None,
211
+ image_encoder=None
212
+ )
213
+
214
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
215
+ vae=vae,
216
+ text_encoder=text_encoder,
217
+ tokenizer=tokenizer,
218
+ unet=unet,
219
+ scheduler=dpmpp_2m_sde_karras_scheduler,
220
+ safety_checker=None,
221
+ requires_safety_checker=False,
222
+ feature_extractor=None,
223
+ image_encoder=None
224
+ )
225
+
226
+
227
+ @torch.inference_mode()
228
+ def encode_prompt_inner(txt: str):
229
+ max_length = tokenizer.model_max_length
230
+ chunk_length = tokenizer.model_max_length - 2
231
+ id_start = tokenizer.bos_token_id
232
+ id_end = tokenizer.eos_token_id
233
+ id_pad = id_end
234
+
235
+ def pad(x, p, i):
236
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
237
+
238
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
239
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
240
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
241
+
242
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
243
+ conds = text_encoder(token_ids).last_hidden_state
244
+
245
+ return conds
246
+
247
+
248
+ @torch.inference_mode()
249
+ def encode_prompt_pair(positive_prompt, negative_prompt):
250
+ c = encode_prompt_inner(positive_prompt)
251
+ uc = encode_prompt_inner(negative_prompt)
252
+
253
+ c_len = float(len(c))
254
+ uc_len = float(len(uc))
255
+ max_count = max(c_len, uc_len)
256
+ c_repeat = int(math.ceil(max_count / c_len))
257
+ uc_repeat = int(math.ceil(max_count / uc_len))
258
+ max_chunk = max(len(c), len(uc))
259
+
260
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
261
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
262
+
263
+ c = torch.cat([p[None, ...] for p in c], dim=1)
264
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
265
+
266
+ return c, uc
267
+
268
+
269
+ @torch.inference_mode()
270
+ def pytorch2numpy(imgs, quant=True):
271
+ results = []
272
+ for x in imgs:
273
+ y = x.movedim(0, -1)
274
+
275
+ if quant:
276
+ y = y * 127.5 + 127.5
277
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
278
+ else:
279
+ y = y * 0.5 + 0.5
280
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
281
+
282
+ results.append(y)
283
+ return results
284
+
285
+
286
+ @torch.inference_mode()
287
+ def numpy2pytorch(imgs):
288
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
289
+ h = h.movedim(-1, 1)
290
+ return h
291
+
292
+
293
+ def resize_and_center_crop(image, target_width, target_height):
294
+ pil_image = Image.fromarray(image)
295
+ original_width, original_height = pil_image.size
296
+ scale_factor = max(target_width / original_width, target_height / original_height)
297
+ resized_width = int(round(original_width * scale_factor))
298
+ resized_height = int(round(original_height * scale_factor))
299
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
300
+ left = (resized_width - target_width) / 2
301
+ top = (resized_height - target_height) / 2
302
+ right = (resized_width + target_width) / 2
303
+ bottom = (resized_height + target_height) / 2
304
+ cropped_image = resized_image.crop((left, top, right, bottom))
305
+ return np.array(cropped_image)
306
+
307
+
308
+ def resize_without_crop(image, target_width, target_height):
309
+ pil_image = Image.fromarray(image)
310
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
311
+ return np.array(resized_image)
312
+
313
+
314
+ @torch.inference_mode()
315
+ def run_rmbg(img, sigma=0.0):
316
+ # Convert RGBA to RGB if needed
317
+ if img.shape[-1] == 4:
318
+ # Use white background for alpha composition
319
+ alpha = img[..., 3:] / 255.0
320
+ rgb = img[..., :3]
321
+ white_bg = np.ones_like(rgb) * 255
322
+ img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
323
+
324
+ H, W, C = img.shape
325
+ assert C == 3
326
+ k = (256.0 / float(H * W)) ** 0.5
327
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
328
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
329
+ alpha = rmbg(feed)[0][0]
330
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
331
+ alpha = alpha.movedim(1, -1)[0]
332
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
333
+
334
+ # Create RGBA image
335
+ rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
336
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
337
+ return result.clip(0, 255).astype(np.uint8), rgba
338
+ @torch.inference_mode()
339
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
340
+ clear_memory()
341
+
342
+ # Get input dimensions
343
+ input_height, input_width = input_fg.shape[:2]
344
+
345
+ bg_source = BGSource(bg_source)
346
+
347
+
348
+ if bg_source == BGSource.UPLOAD:
349
+ pass
350
+ elif bg_source == BGSource.UPLOAD_FLIP:
351
+ input_bg = np.fliplr(input_bg)
352
+ elif bg_source == BGSource.GREY:
353
+ input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
354
+ elif bg_source == BGSource.LEFT:
355
+ gradient = np.linspace(255, 0, input_width)
356
+ image = np.tile(gradient, (input_height, 1))
357
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
358
+ elif bg_source == BGSource.RIGHT:
359
+ gradient = np.linspace(0, 255, input_width)
360
+ image = np.tile(gradient, (input_height, 1))
361
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
362
+ elif bg_source == BGSource.TOP:
363
+ gradient = np.linspace(255, 0, input_height)[:, None]
364
+ image = np.tile(gradient, (1, input_width))
365
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
366
+ elif bg_source == BGSource.BOTTOM:
367
+ gradient = np.linspace(0, 255, input_height)[:, None]
368
+ image = np.tile(gradient, (1, input_width))
369
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
370
+ else:
371
+ raise ValueError('Wrong initial latent!')
372
+
373
+ rng = torch.Generator(device=device).manual_seed(int(seed))
374
+
375
+ # Use input dimensions directly
376
+ fg = resize_without_crop(input_fg, input_width, input_height)
377
+
378
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
379
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
380
+
381
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
382
+
383
+ if input_bg is None:
384
+ latents = t2i_pipe(
385
+ prompt_embeds=conds,
386
+ negative_prompt_embeds=unconds,
387
+ width=input_width,
388
+ height=input_height,
389
+ num_inference_steps=steps,
390
+ num_images_per_prompt=num_samples,
391
+ generator=rng,
392
+ output_type='latent',
393
+ guidance_scale=cfg,
394
+ cross_attention_kwargs={'concat_conds': concat_conds},
395
+ ).images.to(vae.dtype) / vae.config.scaling_factor
396
+ else:
397
+ bg = resize_without_crop(input_bg, input_width, input_height)
398
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
399
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
400
+ latents = i2i_pipe(
401
+ image=bg_latent,
402
+ strength=lowres_denoise,
403
+ prompt_embeds=conds,
404
+ negative_prompt_embeds=unconds,
405
+ width=input_width,
406
+ height=input_height,
407
+ num_inference_steps=int(round(steps / lowres_denoise)),
408
+ num_images_per_prompt=num_samples,
409
+ generator=rng,
410
+ output_type='latent',
411
+ guidance_scale=cfg,
412
+ cross_attention_kwargs={'concat_conds': concat_conds},
413
+ ).images.to(vae.dtype) / vae.config.scaling_factor
414
+
415
+ pixels = vae.decode(latents).sample
416
+ pixels = pytorch2numpy(pixels)
417
+ pixels = [resize_without_crop(
418
+ image=p,
419
+ target_width=int(round(input_width * highres_scale / 64.0) * 64),
420
+ target_height=int(round(input_height * highres_scale / 64.0) * 64))
421
+ for p in pixels]
422
+
423
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
424
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
425
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
426
+
427
+ highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
428
+
429
+ fg = resize_without_crop(input_fg, highres_width, highres_height)
430
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
431
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
432
+
433
+ latents = i2i_pipe(
434
+ image=latents,
435
+ strength=highres_denoise,
436
+ prompt_embeds=conds,
437
+ negative_prompt_embeds=unconds,
438
+ width=highres_width,
439
+ height=highres_height,
440
+ num_inference_steps=int(round(steps / highres_denoise)),
441
+ num_images_per_prompt=num_samples,
442
+ generator=rng,
443
+ output_type='latent',
444
+ guidance_scale=cfg,
445
+ cross_attention_kwargs={'concat_conds': concat_conds},
446
+ ).images.to(vae.dtype) / vae.config.scaling_factor
447
+
448
+ pixels = vae.decode(latents).sample
449
+ pixels = pytorch2numpy(pixels)
450
+
451
+ # Resize back to input dimensions
452
+ pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
453
+ pixels = np.stack(pixels)
454
+
455
+ return pixels
456
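+ # process() runs a two-stage pipeline: a low-resolution pass (text-to-image,
+ # or image-to-image seeded with the chosen background at `lowres_denoise`),
+ # followed by an upscale to a multiple of 64, re-encoding through the VAE and
+ # an image-to-image refinement at `highres_denoise`. The VAE-encoded
+ # foreground is passed through cross_attention_kwargs ('concat_conds') at
+ # both stages, and the output is resized back to the original input size.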
+
457
+ @torch.inference_mode()
458
+ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
459
+ clear_memory()
460
+ bg_source = BGSource(bg_source)
461
+
462
+ if bg_source == BGSource.UPLOAD:
463
+ pass
464
+ elif bg_source == BGSource.UPLOAD_FLIP:
465
+ input_bg = np.fliplr(input_bg)
466
+ elif bg_source == BGSource.GREY:
467
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
468
+ elif bg_source == BGSource.LEFT:
469
+ gradient = np.linspace(224, 32, image_width)
470
+ image = np.tile(gradient, (image_height, 1))
471
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
472
+ elif bg_source == BGSource.RIGHT:
473
+ gradient = np.linspace(32, 224, image_width)
474
+ image = np.tile(gradient, (image_height, 1))
475
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
476
+ elif bg_source == BGSource.TOP:
477
+ gradient = np.linspace(224, 32, image_height)[:, None]
478
+ image = np.tile(gradient, (1, image_width))
479
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
480
+ elif bg_source == BGSource.BOTTOM:
481
+ gradient = np.linspace(32, 224, image_height)[:, None]
482
+ image = np.tile(gradient, (1, image_width))
483
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
484
+ else:
485
+ raise ValueError('Wrong background source!')
486
+
487
+ rng = torch.Generator(device=device).manual_seed(seed)
488
+
489
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
490
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
491
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
492
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
493
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
494
+
495
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
496
+
497
+ latents = t2i_pipe(
498
+ prompt_embeds=conds,
499
+ negative_prompt_embeds=unconds,
500
+ width=image_width,
501
+ height=image_height,
502
+ num_inference_steps=steps,
503
+ num_images_per_prompt=num_samples,
504
+ generator=rng,
505
+ output_type='latent',
506
+ guidance_scale=cfg,
507
+ cross_attention_kwargs={'concat_conds': concat_conds},
508
+ ).images.to(vae.dtype) / vae.config.scaling_factor
509
+
510
+ pixels = vae.decode(latents).sample
511
+ pixels = pytorch2numpy(pixels)
512
+ pixels = [resize_without_crop(
513
+ image=p,
514
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
515
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
516
+ for p in pixels]
517
+
518
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
519
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
520
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
521
+
522
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
523
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
524
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
525
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
526
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
527
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
528
+
529
+ latents = i2i_pipe(
530
+ image=latents,
531
+ strength=highres_denoise,
532
+ prompt_embeds=conds,
533
+ negative_prompt_embeds=unconds,
534
+ width=image_width,
535
+ height=image_height,
536
+ num_inference_steps=int(round(steps / highres_denoise)),
537
+ num_images_per_prompt=num_samples,
538
+ generator=rng,
539
+ output_type='latent',
540
+ guidance_scale=cfg,
541
+ cross_attention_kwargs={'concat_conds': concat_conds},
542
+ ).images.to(vae.dtype) / vae.config.scaling_factor
543
+
544
+ pixels = vae.decode(latents).sample
545
+ pixels = pytorch2numpy(pixels, quant=False)
546
+
547
+ clear_memory()
548
+ return pixels, [fg, bg]
549
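+ # process_bg differs from process() in that it conditions on both the
+ # foreground and the background: the two VAE latents are concatenated along
+ # dim=1 before being handed to the UNet, and the function returns the raw
+ # float pixels (quant=False) together with the resized [fg, bg] pair.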
+
550
+
551
+ @torch.inference_mode()
552
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
553
+ input_fg, matting = run_rmbg(input_fg)
554
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
555
+ return input_fg, results
556
+
557
+
558
+
559
+ @torch.inference_mode()
560
+ def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
561
+ bg_source = BGSource(bg_source)
562
+
563
+ # Convert numerical inputs to appropriate types
564
+ image_width = int(image_width)
565
+ image_height = int(image_height)
566
+ num_samples = int(num_samples)
567
+ seed = int(seed)
568
+ steps = int(steps)
569
+ cfg = float(cfg)
570
+ highres_scale = float(highres_scale)
571
+ highres_denoise = float(highres_denoise)
572
+
573
+ if bg_source == BGSource.UPLOAD:
574
+ pass
575
+ elif bg_source == BGSource.UPLOAD_FLIP:
576
+ input_bg = np.fliplr(input_bg)
577
+ elif bg_source == BGSource.GREY:
578
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
579
+ elif bg_source == BGSource.LEFT:
580
+ gradient = np.linspace(224, 32, image_width)
581
+ image = np.tile(gradient, (image_height, 1))
582
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
583
+ elif bg_source == BGSource.RIGHT:
584
+ gradient = np.linspace(32, 224, image_width)
585
+ image = np.tile(gradient, (image_height, 1))
586
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
587
+ elif bg_source == BGSource.TOP:
588
+ gradient = np.linspace(224, 32, image_height)[:, None]
589
+ image = np.tile(gradient, (1, image_width))
590
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
591
+ elif bg_source == BGSource.BOTTOM:
592
+ gradient = np.linspace(32, 224, image_height)[:, None]
593
+ image = np.tile(gradient, (1, image_width))
594
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
595
+ else:
596
+ raise ValueError('Wrong background source!')
597
+
598
+ input_fg, matting = run_rmbg(input_fg)
599
+ results, extra_images = process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
600
+ results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
601
+ final_results = results + extra_images # currently unused; only the relit results are returned below
602
+
603
+ # Save the generated images
604
+ save_images(results, prefix="relight")
605
+
606
+ return results
607
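+ # process_bg returns unquantized float images, which are scaled by 255 and
+ # clipped to uint8 here before being written to disk by save_images(); only
+ # the relit images (not the fg/bg conditioning pair) are returned to the UI.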
+
608
+
609
+ quick_prompts = [
610
+ 'sunshine from window',
611
+ 'neon light, city',
612
+ 'sunset over sea',
613
+ 'golden time',
614
+ 'sci-fi RGB glowing, cyberpunk',
615
+ 'natural lighting',
616
+ 'warm atmosphere, at home, bedroom',
617
+ 'magic lit',
618
+ 'evil, gothic, Yharnam',
619
+ 'light and shadow',
620
+ 'shadow from window',
621
+ 'soft studio lighting',
622
+ 'home atmosphere, cozy bedroom illumination',
623
+ 'neon, Wong Kar-wai, warm'
624
+ ]
625
+ quick_prompts = [[x] for x in quick_prompts]
626
+
627
+
628
+ quick_subjects = [
629
+ 'modern sofa, high quality leather',
630
+ 'elegant dining table, polished wood',
631
+ 'luxurious bed, premium mattress',
632
+ 'minimalist office desk, clean design',
633
+ 'vintage wooden cabinet, antique finish',
634
+ ]
635
+ quick_subjects = [[x] for x in quick_subjects]
636
+
637
+
638
+ class BGSource(Enum):
639
+ UPLOAD = "Use Background Image"
640
+ UPLOAD_FLIP = "Use Flipped Background Image"
641
+ LEFT = "Left Light"
642
+ RIGHT = "Right Light"
643
+ TOP = "Top Light"
644
+ BOTTOM = "Bottom Light"
645
+ GREY = "Ambient"
646
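+ # The enum values double as the radio-button labels, so the UI string can be
+ # mapped straight back to a member, e.g. BGSource("Left Light") is
+ # BGSource.LEFT. This is what process(), process_bg() and process_relight_bg()
+ # rely on when they call BGSource(bg_source).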
+
647
+ # Add save function
648
+ def save_images(images, prefix="relight"):
649
+ # Create output directory if it doesn't exist
650
+ output_dir = Path("outputs")
651
+ output_dir.mkdir(exist_ok=True)
652
+
653
+ # Create timestamp for unique filenames
654
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
655
+
656
+ saved_paths = []
657
+ for i, img in enumerate(images):
658
+ if isinstance(img, np.ndarray):
659
+ # Convert to PIL Image if numpy array
660
+ img = Image.fromarray(img)
661
+
662
+ # Create filename with timestamp
663
+ filename = f"{prefix}_{timestamp}_{i+1}.png"
664
+ filepath = output_dir / filename
665
+
666
+ # Save image
667
+ img.save(filepath)
+ saved_paths.append(str(filepath))
668
+
669
+
670
+ # print(f"Saved {len(saved_paths)} images to {output_dir}")
671
+ return saved_paths
672
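+ # Files are written as outputs/<prefix>_<YYYYmmdd_HHMMSS>_<n>.png, e.g.
+ # outputs/relight_20240101_120000_1.png, and the list of saved paths is
+ # returned to the caller.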
+
673
+
674
+ class MaskMover:
675
+ def __init__(self):
676
+ self.extracted_fg = None
677
+ self.original_fg = None # Store original foreground
678
+
679
+ def set_extracted_fg(self, fg_image):
680
+ """Store the extracted foreground with alpha channel"""
681
+ if isinstance(fg_image, np.ndarray):
682
+ self.extracted_fg = fg_image.copy()
683
+ self.original_fg = fg_image.copy()
684
+ else:
685
+ self.extracted_fg = np.array(fg_image)
686
+ self.original_fg = np.array(fg_image)
687
+ return self.extracted_fg
688
+
689
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
690
+ """Create composite with foreground at specified position"""
691
+ if self.original_fg is None or background is None:
692
+ return background
693
+
694
+ # Convert inputs to PIL Images
695
+ if isinstance(background, np.ndarray):
696
+ bg = Image.fromarray(background).convert('RGBA')
697
+ else:
698
+ bg = background.convert('RGBA')
699
+
700
+ if isinstance(self.original_fg, np.ndarray):
701
+ fg = Image.fromarray(self.original_fg).convert('RGBA')
702
+ else:
703
+ fg = self.original_fg.convert('RGBA')
704
+
705
+ # Scale the foreground size
706
+ new_width = int(fg.width * scale)
707
+ new_height = int(fg.height * scale)
708
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
709
+
710
+ # Center the scaled foreground at the position
711
+ x = int(x_pos - new_width / 2)
712
+ y = int(y_pos - new_height / 2)
713
+
714
+ # Create composite
715
+ result = bg.copy()
716
+ result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
717
+
718
+ return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
719
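+ # create_composite scales the stored RGBA cut-out with LANCZOS, centers it at
+ # (x_pos, y_pos) and pastes it onto the background using its own alpha channel
+ # as the paste mask, returning a plain RGB numpy array for Gradio.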
+
720
+ def get_depth(image):
721
+ if image is None:
722
+ return None
723
+ # Convert from PIL/gradio format to cv2
724
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
725
+ # Get depth map
726
+ depth = model.infer_image(raw_img) # HxW raw depth map
727
+ # Normalize depth for visualization
728
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
729
+ # Convert to RGB for display
730
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
731
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
732
+ return Image.fromarray(depth_colored)
733
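+ # get_depth runs Depth Anything V2 on a BGR copy of the image, min-max
+ # normalizes the raw depth map to 0-255 and colors it with cv2's INFERNO
+ # colormap; in this demo the depth view is only displayed and is not used
+ # further in the relighting pipeline.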
+
734
+
735
+ from PIL import Image
736
+
737
+ def compress_image(image):
738
+ # Convert Gradio image (numpy array) to PIL Image
739
+ img = Image.fromarray(image)
740
+
741
+ # Resize image if dimensions are too large
742
+ max_size = 1024 # Maximum dimension size
743
+ if img.width > max_size or img.height > max_size:
744
+ ratio = min(max_size/img.width, max_size/img.height)
745
+ new_size = (int(img.width * ratio), int(img.height * ratio))
746
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
747
+
748
+ quality = 95 # Start with high quality
749
+ img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
750
+
751
+ # Check file size and adjust quality if necessary
752
+ while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
753
+ quality -= 5 # Decrease quality
754
+ img.save("compressed_image.jpg", "JPEG", quality=quality)
755
+ if quality < 20: # Prevent quality from going too low
756
+ break
757
+
758
+ # Convert back to numpy array for Gradio
759
+ compressed_img = np.array(Image.open("compressed_image.jpg"))
760
+ return compressed_img
761
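+ # compress_image is a helper that appears not to be wired to the UI below: it
+ # caps the longest side at 1024 px, then lowers JPEG quality in steps of 5
+ # until the file is under ~100 KB or quality reaches 20.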
+
762
+
763
+ block = gr.Blocks().queue()
764
+ with block:
765
+ with gr.Tab("Text"):
766
+ with gr.Row():
767
+ gr.Markdown("## Product Placement from Text")
768
+ with gr.Row():
769
+ with gr.Column():
770
+ with gr.Row():
771
+ input_fg = gr.Image(type="numpy", label="Image", height=480)
772
+ output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
773
+ with gr.Group():
774
+ prompt = gr.Textbox(label="Prompt")
775
+ bg_source = gr.Radio(choices=[e.value for e in BGSource],
776
+ value=BGSource.GREY.value,
777
+ label="Lighting Preference (Initial Latent)", type='value')
778
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
779
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
780
+ relight_button = gr.Button(value="Relight")
781
+
782
+ with gr.Group():
783
+ with gr.Row():
784
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
785
+ seed = gr.Number(label="Seed", value=12345, precision=0)
786
+
787
+ with gr.Row():
788
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
789
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
790
+
791
+ with gr.Accordion("Advanced options", open=False):
792
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
793
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
794
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
795
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
796
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
797
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
798
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
799
+ with gr.Column():
800
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
801
+ with gr.Row():
802
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
803
+ # gr.Examples(
804
+ # fn=lambda *args: ([args[-1]], None),
805
+ # examples=db_examples.foreground_conditioned_examples,
806
+ # inputs=[
807
+ # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
808
+ # ],
809
+ # outputs=[result_gallery, output_bg],
810
+ # run_on_click=True, examples_per_page=1024
811
+ # )
812
+ ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
813
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[output_bg, result_gallery])
814
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
815
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
816
+
817
+ with gr.Tab("Background", visible=False):
818
+ mask_mover = MaskMover()
819
+
820
+
821
+ with gr.Row():
822
+ gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
823
+ gr.Markdown("πŸ’Ύ Generated images are automatically saved to 'outputs' folder")
824
+
825
+ with gr.Row():
826
+ with gr.Column():
827
+ # Step 1: Input and Extract
828
+ with gr.Row():
829
+ with gr.Group():
830
+ gr.Markdown("### Step 1: Extract Foreground")
831
+ input_image = gr.Image(type="numpy", label="Input Image", height=480)
832
+ # find_objects_button = gr.Button(value="Find Objects")
833
+ extract_button = gr.Button(value="Remove Background")
834
+ extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
835
+
836
+ with gr.Row():
837
+ # Step 2: Background and Position
838
+ with gr.Group():
839
+ gr.Markdown("### Step 2: Position on Background")
840
+ input_bg = gr.Image(type="numpy", label="Background Image", height=480)
841
+
842
+ with gr.Row():
843
+ x_slider = gr.Slider(
844
+ minimum=0,
845
+ maximum=1000,
846
+ label="X Position",
847
+ value=500,
848
+ visible=False
849
+ )
850
+ y_slider = gr.Slider(
851
+ minimum=0,
852
+ maximum=1000,
853
+ label="Y Position",
854
+ value=500,
855
+ visible=False
856
+ )
857
+ fg_scale_slider = gr.Slider(
858
+ label="Foreground Scale",
859
+ minimum=0.01,
860
+ maximum=3.0,
861
+ value=1.0,
862
+ step=0.01
863
+ )
864
+
865
+ editor = gr.ImageEditor(
866
+ type="numpy",
867
+ label="Position Foreground",
868
+ height=480,
869
+ visible=False
870
+ )
871
+ get_depth_button = gr.Button(value="Get Depth")
872
+ depth_image = gr.Image(type="numpy", label="Depth Image", height=480)
873
+
874
+ # Step 3: Relighting Options
875
+ with gr.Group():
876
+ gr.Markdown("### Step 3: Relighting Settings")
877
+ prompt = gr.Textbox(label="Prompt")
878
+ bg_source = gr.Radio(
879
+ choices=[e.value for e in BGSource],
880
+ value=BGSource.UPLOAD.value,
881
+ label="Background Source",
882
+ type='value'
883
+ )
884
+
885
+ example_prompts = gr.Dataset(
886
+ samples=quick_prompts,
887
+ label='Prompt Quick List',
888
+ components=[prompt]
889
+ )
890
+ # bg_gallery = gr.Gallery(
891
+ # height=450,
892
+ # label='Background Quick List',
893
+ # value=db_examples.bg_samples,
894
+ # columns=5,
895
+ # allow_preview=False
896
+ # )
897
+ relight_button_bg = gr.Button(value="Relight")
898
+
899
+ # Additional settings
900
+ with gr.Group():
901
+ with gr.Row():
902
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
903
+ seed = gr.Number(label="Seed", value=12345, precision=0)
904
+ with gr.Row():
905
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
906
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
907
+
908
+ with gr.Accordion("Advanced options", open=False):
909
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
910
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
911
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
912
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
913
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
914
+ n_prompt = gr.Textbox(
915
+ label="Negative Prompt",
916
+ value='lowres, bad anatomy, bad hands, cropped, worst quality'
917
+ )
918
+
919
+ with gr.Column():
920
+ result_gallery = gr.Image(height=832, label='Outputs')
921
+
922
+ def extract_foreground(image):
923
+ if image is None:
924
+ return None, gr.update(visible=True), gr.update(visible=True)
925
+ result, rgba = run_rmbg(image)
926
+ mask_mover.set_extracted_fg(rgba)
927
+
928
+ return result, gr.update(visible=True), gr.update(visible=True)
929
+
930
+
931
+ original_bg = None
932
+
933
+ extract_button.click(
934
+ fn=extract_foreground,
935
+ inputs=[input_image],
936
+ outputs=[extracted_fg, x_slider, y_slider]
937
+ )
938
+
939
+ # find_objects_button.click(
940
+ # fn=find_objects,
941
+ # inputs=[input_image],
942
+ # outputs=[extracted_fg]
943
+ # )
944
+
945
+ get_depth_button.click(
946
+ fn=get_depth,
947
+ inputs=[input_bg],
948
+ outputs=[depth_image]
949
+ )
950
+
951
+ # def update_position(background, x_pos, y_pos, scale):
952
+ # """Update composite when position changes"""
953
+ # global original_bg
954
+ # if background is None:
955
+ # return None
956
+
957
+ # if original_bg is None:
958
+ # original_bg = background.copy()
959
+
960
+ # # Convert string values to float
961
+ # x_pos = float(x_pos)
962
+ # y_pos = float(y_pos)
963
+ # scale = float(scale)
964
+
965
+ # return mask_mover.create_composite(original_bg, x_pos, y_pos, scale)
966
+
967
+ class BackgroundManager:
968
+ def __init__(self):
969
+ self.original_bg = None
970
+
971
+ def update_position(self, background, x_pos, y_pos, scale):
972
+ """Update composite when position changes"""
973
+ if background is None:
974
+ return None
975
+
976
+ if self.original_bg is None:
977
+ self.original_bg = background.copy()
978
+
979
+ # Convert string values to float
980
+ x_pos = float(x_pos)
981
+ y_pos = float(y_pos)
982
+ scale = float(scale)
983
+
984
+ return mask_mover.create_composite(self.original_bg, x_pos, y_pos, scale)
985
+
986
+ # Create an instance of BackgroundManager
987
+ bg_manager = BackgroundManager()
988
+
989
+
990
+ x_slider.change(
991
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
992
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
993
+ outputs=[input_bg]
994
+ )
995
+
996
+ y_slider.change(
997
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
998
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
999
+ outputs=[input_bg]
1000
+ )
1001
+
1002
+ fg_scale_slider.change(
1003
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1004
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1005
+ outputs=[input_bg]
1006
+ )
1007
+
1008
+ # Update inputs list to include fg_scale_slider
1009
+
1010
+ def process_relight_with_position(*args):
1011
+ if mask_mover.extracted_fg is None:
1012
+ gr.Warning("Please extract foreground first")
1013
+ return None
1014
+
1015
+ background = args[1] # Get background image
1016
+ x_pos = float(args[-3]) # x_slider value
1017
+ y_pos = float(args[-2]) # y_slider value
1018
+ scale = float(args[-1]) # fg_scale_slider value
1019
+
1020
+ # Get original foreground size after scaling
1021
+ fg = Image.fromarray(mask_mover.original_fg)
1022
+ new_width = int(fg.width * scale)
1023
+ new_height = int(fg.height * scale)
1024
+
1025
+ # Calculate crop region around foreground position
1026
+ crop_x = int(x_pos - new_width/2)
1027
+ crop_y = int(y_pos - new_height/2)
1028
+ crop_width = new_width
1029
+ crop_height = new_height
1030
+
1031
+ # Add padding for context (20% extra on each side)
1032
+ padding = 0.2
1033
+ crop_x = int(crop_x - crop_width * padding)
1034
+ crop_y = int(crop_y - crop_height * padding)
1035
+ crop_width = int(crop_width * (1 + 2 * padding))
1036
+ crop_height = int(crop_height * (1 + 2 * padding))
1037
+
1038
+ # Ensure crop dimensions are multiples of 8
1039
+ crop_width = ((crop_width + 7) // 8) * 8
1040
+ crop_height = ((crop_height + 7) // 8) * 8
1041
+
1042
+ # Ensure crop region is within image bounds
1043
+ bg_height, bg_width = background.shape[:2]
1044
+ crop_x = max(0, min(crop_x, bg_width - crop_width))
1045
+ crop_y = max(0, min(crop_y, bg_height - crop_height))
1046
+
1047
+ # Get actual crop dimensions after boundary check
1048
+ crop_width = min(crop_width, bg_width - crop_x)
1049
+ crop_height = min(crop_height, bg_height - crop_y)
1050
+
1051
+ # Ensure dimensions are multiples of 8 again
1052
+ crop_width = (crop_width // 8) * 8
1053
+ crop_height = (crop_height // 8) * 8
1054
+
1055
+ # Crop region from background
1056
+ crop_region = background[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width]
1057
+
1058
+ # Create composite in cropped region
1059
+ fg_local_x = int(new_width/2 + crop_width*padding)
1060
+ fg_local_y = int(new_height/2 + crop_height*padding)
1061
+ cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)
1062
+
1063
+ # Process the cropped region
1064
+ crop_args = list(args)
1065
+ crop_args[0] = cropped_composite
1066
+ crop_args[1] = crop_region
1067
+ crop_args[3] = crop_width
1068
+ crop_args[4] = crop_height
1069
+ crop_args = crop_args[:-3] # Remove position and scale arguments
1070
+
1071
+ # Get relit result
1072
+ relit_crop = process_relight_bg(*crop_args)[0]
1073
+
1074
+ # Resize relit result to match crop dimensions if needed
1075
+ if relit_crop.shape[:2] != (crop_height, crop_width):
1076
+ relit_crop = resize_without_crop(relit_crop, crop_width, crop_height)
1077
+
1078
+ # Place relit crop back into original background
1079
+ result = background.copy()
1080
+ result[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width] = relit_crop
1081
+
1082
+ return result
1083
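+ # process_relight_with_position crops an 8-pixel-aligned window (with ~20%
+ # padding) around the placed foreground, builds the composite inside that
+ # window, relights just that window via process_relight_bg, and pastes the
+ # relit patch back into the full background. The trailing x/y/scale slider
+ # values in ips_bg are read here and stripped (crop_args[:-3]) before the
+ # relight call.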
+
1084
+ ips_bg = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source, x_slider, y_slider, fg_scale_slider]
1085
+
1086
+ # Update button click events with new inputs list
1087
+ relight_button_bg.click(
1088
+ fn=process_relight_with_position,
1089
+ inputs=ips_bg,
1090
+ outputs=[result_gallery]
1091
+ )
1092
+
1093
+
1094
+ example_prompts.click(
1095
+ fn=lambda x: x[0],
1096
+ inputs=example_prompts,
1097
+ outputs=prompt,
1098
+ show_progress=False,
1099
+ queue=False
1100
+ )
1101
+
1102
+
1103
+
1104
+ block.launch(server_name='0.0.0.0', share=True)