mantrakp commited on
Commit
5a59c13
β€’
1 Parent(s): c7554bf

Refactor UI structure and import spaces module

Browse files
Files changed (2) hide show
  1. app2.py +721 -0
  2. src/tasks/images/init_sys.py +1 -1
app2.py ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Testing one file gradio app for zero gpu spaces not working as expected.
2
+ # Check here for the issue:
3
+ import gc
4
+ import json
5
+ import random
6
+ from typing import List, Optional
7
+
8
+ import spaces
9
+ import gradio as gr
10
+ from huggingface_hub import ModelCard
11
+ import torch
12
+ import numpy as np
13
+ from pydantic import BaseModel
14
+ from PIL import Image
15
+ from diffusers import (
16
+ FluxPipeline,
17
+ FluxImg2ImgPipeline,
18
+ FluxInpaintPipeline,
19
+ FluxControlNetPipeline,
20
+ StableDiffusionXLPipeline,
21
+ StableDiffusionXLImg2ImgPipeline,
22
+ StableDiffusionXLInpaintPipeline,
23
+ StableDiffusionXLControlNetPipeline,
24
+ StableDiffusionXLControlNetImg2ImgPipeline,
25
+ StableDiffusionXLControlNetInpaintPipeline,
26
+ AutoPipelineForText2Image,
27
+ AutoPipelineForImage2Image,
28
+ AutoPipelineForInpainting,
29
+ DiffusionPipeline,
30
+ AutoencoderKL,
31
+ FluxControlNetModel,
32
+ FluxMultiControlNetModel,
33
+ ControlNetModel,
34
+ )
35
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
36
+ from huggingface_hub import hf_hub_download
37
+ from transformers import CLIPFeatureExtractor
38
+ from photomaker import FaceAnalysis2
39
+ from diffusers.schedulers import *
40
+ from huggingface_hub import hf_hub_download
41
+ from safetensors.torch import load_file
42
+ from controlnet_aux.processor import Processor
43
+ from photomaker import (
44
+ PhotoMakerStableDiffusionXLPipeline,
45
+ PhotoMakerStableDiffusionXLControlNetPipeline,
46
+ analyze_faces
47
+ )
48
+ from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl, get_weighted_text_embeddings_flux1
49
+
50
+
51
+ # Initialize System
52
+ def load_sd():
53
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ device = "cuda" if torch.cuda.is_available() else "cpu"
55
+
56
+ # Models
57
+ models = [
58
+ {
59
+ "repo_id": "black-forest-labs/FLUX.1-dev",
60
+ "loader": "flux",
61
+ "compute_type": torch.bfloat16,
62
+ },
63
+ {
64
+ "repo_id": "SG161222/RealVisXL_V4.0",
65
+ "loader": "xl",
66
+ "compute_type": torch.float16,
67
+ }
68
+ ]
69
+
70
+ for model in models:
71
+ try:
72
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
73
+ model['repo_id'],
74
+ torch_dtype = model['compute_type'],
75
+ safety_checker = None,
76
+ variant = "fp16"
77
+ ).to(device)
78
+ model["pipeline"].enable_model_cpu_offload()
79
+ except:
80
+ model["pipeline"] = AutoPipelineForText2Image.from_pretrained(
81
+ model['repo_id'],
82
+ torch_dtype = model['compute_type'],
83
+ safety_checker = None
84
+ ).to(device)
85
+ model["pipeline"].enable_model_cpu_offload()
86
+
87
+
88
+ # VAE n Refiner
89
+ sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device)
90
+ refiner = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=sdxl_vae, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to(device)
91
+ refiner.enable_model_cpu_offload()
92
+
93
+
94
+ # Safety Checker
95
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
96
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
97
+
98
+
99
+ # Controlnets
100
+ controlnet_models = [
101
+ {
102
+ "repo_id": "xinsir/controlnet-depth-sdxl-1.0",
103
+ "name": "depth_xl",
104
+ "layers": ["depth"],
105
+ "loader": "xl",
106
+ "compute_type": torch.float16,
107
+ },
108
+ {
109
+ "repo_id": "xinsir/controlnet-canny-sdxl-1.0",
110
+ "name": "canny_xl",
111
+ "layers": ["canny"],
112
+ "loader": "xl",
113
+ "compute_type": torch.float16,
114
+ },
115
+ {
116
+ "repo_id": "xinsir/controlnet-openpose-sdxl-1.0",
117
+ "name": "openpose_xl",
118
+ "layers": ["pose"],
119
+ "loader": "xl",
120
+ "compute_type": torch.float16,
121
+ },
122
+ {
123
+ "repo_id": "xinsir/controlnet-scribble-sdxl-1.0",
124
+ "name": "scribble_xl",
125
+ "layers": ["scribble"],
126
+ "loader": "xl",
127
+ "compute_type": torch.float16,
128
+ },
129
+ {
130
+ "repo_id": "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
131
+ "name": "flux1_union_pro",
132
+ "layers": ["canny_fl", "tile_fl", "depth_fl", "blur_fl", "pose_fl", "gray_fl", "low_quality_fl"],
133
+ "loader": "flux-multi",
134
+ "compute_type": torch.bfloat16,
135
+ }
136
+ ]
137
+
138
+ for controlnet in controlnet_models:
139
+ if controlnet["loader"] == "xl":
140
+ controlnet["controlnet"] = ControlNetModel.from_pretrained(
141
+ controlnet["repo_id"],
142
+ torch_dtype = controlnet['compute_type']
143
+ ).to(device)
144
+ elif controlnet["loader"] == "flux-multi":
145
+ controlnet["controlnet"] = FluxMultiControlNetModel([FluxControlNetModel.from_pretrained(
146
+ controlnet["repo_id"],
147
+ torch_dtype = controlnet['compute_type']
148
+ ).to(device)])
149
+ #TODO: Add support for flux only controlnet
150
+
151
+
152
+ # Face Detection (for PhotoMaker)
153
+ face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition'])
154
+ face_detector.prepare(ctx_id=0, det_size=(640, 640))
155
+
156
+
157
+ # PhotoMaker V2 (for SDXL only)
158
+ photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker-V2", filename="photomaker-v2.bin", repo_type="model")
159
+
160
+ return device, models, sdxl_vae, refiner, safety_checker, feature_extractor, controlnet_models, face_detector, photomaker_ckpt
161
+
162
+
163
+ device, models, sdxl_vae, refiner, safety_checker, feature_extractor, controlnet_models, face_detector, photomaker_ckpt = load_sd()
164
+
165
+
166
+ # Models
167
+ class ControlNetReq(BaseModel):
168
+ controlnets: List[str] # ["canny", "tile", "depth"]
169
+ control_images: List[Image.Image]
170
+ controlnet_conditioning_scale: List[float]
171
+
172
+ class Config:
173
+ arbitrary_types_allowed=True
174
+
175
+
176
+ class SDReq(BaseModel):
177
+ model: str = ""
178
+ prompt: str = ""
179
+ negative_prompt: Optional[str] = "black-forest-labs/FLUX.1-dev"
180
+ fast_generation: Optional[bool] = True
181
+ loras: Optional[list] = []
182
+ embeddings: Optional[list] = []
183
+ resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
184
+ scheduler: Optional[str] = "euler_fl"
185
+ height: int = 1024
186
+ width: int = 1024
187
+ num_images_per_prompt: int = 1
188
+ num_inference_steps: int = 8
189
+ guidance_scale: float = 3.5
190
+ seed: Optional[int] = 0
191
+ refiner: bool = False
192
+ vae: bool = True
193
+ controlnet_config: Optional[ControlNetReq] = None
194
+ photomaker_images: Optional[List[Image.Image]] = None
195
+
196
+ class Config:
197
+ arbitrary_types_allowed=True
198
+
199
+
200
+ class SDImg2ImgReq(SDReq):
201
+ image: Image.Image
202
+ strength: float = 1.0
203
+
204
+ class Config:
205
+ arbitrary_types_allowed=True
206
+
207
+
208
+ class SDInpaintReq(SDImg2ImgReq):
209
+ mask_image: Image.Image
210
+
211
+ class Config:
212
+ arbitrary_types_allowed=True
213
+
214
+
215
+ # Helper functions
216
+ def get_controlnet(controlnet_config: ControlNetReq):
217
+ control_mode = []
218
+ controlnet = []
219
+
220
+ for m in controlnet_models:
221
+ for c in controlnet_config.controlnets:
222
+ if c in m["layers"]:
223
+ control_mode.append(m["layers"].index(c))
224
+ controlnet.append(m["controlnet"])
225
+
226
+ return controlnet, control_mode
227
+
228
+
229
+ def get_pipe(request: SDReq | SDImg2ImgReq | SDInpaintReq):
230
+ for m in models:
231
+ if m["repo_id"] == request.model:
232
+ pipeline = m['pipeline']
233
+ controlnet, control_mode = get_controlnet(request.controlnet_config) if request.controlnet_config else (None, None)
234
+
235
+ pipe_args = {
236
+ "pipeline": pipeline,
237
+ "control_mode": control_mode,
238
+ }
239
+ if request.controlnet_config:
240
+ pipe_args["controlnet"] = controlnet
241
+
242
+ if not request.photomaker_images:
243
+ if isinstance(request, SDReq):
244
+ pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)
245
+ elif isinstance(request, SDImg2ImgReq):
246
+ pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
247
+ elif isinstance(request, SDInpaintReq):
248
+ pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
249
+ else:
250
+ raise ValueError(f"Unknown request type: {type(request)}")
251
+ elif isinstance(request, any([PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline])):
252
+ if request.controlnet_config:
253
+ pipe_args['pipeline'] = PhotoMakerStableDiffusionXLControlNetPipeline.from_pipe(**pipe_args)
254
+ else:
255
+ pipe_args['pipeline'] = PhotoMakerStableDiffusionXLPipeline.from_pipe(**pipe_args)
256
+ else:
257
+ raise ValueError(f"Invalid request type: {type(request)}")
258
+
259
+ return pipe_args
260
+
261
+
262
+ def load_scheduler(pipeline, scheduler):
263
+ schedulers = {
264
+ "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
265
+ "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
266
+ "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
267
+ "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
268
+ "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
269
+ "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
270
+ "dpm2": (KDPM2DiscreteScheduler, {}),
271
+ "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
272
+ "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
273
+ "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
274
+ "euler": (EulerDiscreteScheduler, {}),
275
+ "euler_a": (EulerAncestralDiscreteScheduler, {}),
276
+ "heun": (HeunDiscreteScheduler, {}),
277
+ "lms": (LMSDiscreteScheduler, {}),
278
+ "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
279
+ "deis": (DEISMultistepScheduler, {}),
280
+ "unipc": (UniPCMultistepScheduler, {}),
281
+ "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
282
+ }
283
+ scheduler_class, kwargs = schedulers.get(scheduler, (None, {}))
284
+
285
+ if scheduler_class is not None:
286
+ scheduler = scheduler_class.from_config(pipeline.scheduler.config, **kwargs)
287
+ else:
288
+ raise ValueError(f"Unknown scheduler: {scheduler}")
289
+
290
+ return scheduler
291
+
292
+
293
+ def load_loras(pipeline, loras, fast_generation):
294
+ for i, lora in enumerate(loras):
295
+ pipeline.load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
296
+ adapter_names = [f"lora_{i}" for i in range(len(loras))]
297
+ adapter_weights = [lora['weight'] for lora in loras]
298
+
299
+ if fast_generation:
300
+ hyper_lora = hf_hub_download(
301
+ "ByteDance/Hyper-SD",
302
+ "Hyper-FLUX.1-dev-8steps-lora.safetensors" if isinstance(pipeline, FluxPipeline) else "Hyper-SDXL-2steps-lora.safetensors"
303
+ )
304
+ hyper_weight = 0.125 if isinstance(pipeline, FluxPipeline) else 1.0
305
+ pipeline.load_lora_weights(hyper_lora, adapter_name="hyper_lora")
306
+ adapter_names.append("hyper_lora")
307
+ adapter_weights.append(hyper_weight)
308
+
309
+ pipeline.set_adapters(adapter_names, adapter_weights)
310
+
311
+
312
+ def load_xl_embeddings(pipeline, embeddings):
313
+ for embedding in embeddings:
314
+ state_dict = load_file(hf_hub_download(embedding['repo_id']))
315
+ pipeline.load_textual_inversion(state_dict['clip_g'], token=embedding['token'], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
316
+ pipeline.load_textual_inversion(state_dict["clip_l"], token=embedding['token'], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
317
+
318
+
319
+ def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
320
+ for image in images:
321
+ if resize_mode == "resize_only":
322
+ image = image.resize((width, height))
323
+ elif resize_mode == "crop_and_resize":
324
+ image = image.crop((0, 0, width, height))
325
+ elif resize_mode == "resize_and_fill":
326
+ image = image.resize((width, height), Image.Resampling.LANCZOS)
327
+
328
+ return images
329
+
330
+
331
+ def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
332
+ response_images = []
333
+ control_images = resize_images(control_images, height, width, resize_mode)
334
+ for controlnet, image in zip(controlnets, control_images):
335
+ if controlnet == "canny" or controlnet == "canny_xs" or controlnet == "canny_fl":
336
+ processor = Processor('canny')
337
+ elif controlnet == "depth" or controlnet == "depth_xs" or controlnet == "depth_fl":
338
+ processor = Processor('depth_midas')
339
+ elif controlnet == "pose" or controlnet == "pose_fl":
340
+ processor = Processor('openpose_full')
341
+ elif controlnet == "scribble":
342
+ processor = Processor('scribble')
343
+ else:
344
+ raise ValueError(f"Invalid Controlnet: {controlnet}")
345
+
346
+ response_images.append(processor(image, to_pil=True))
347
+
348
+ return response_images
349
+
350
+
351
+ def check_image_safety(images: List[Image.Image]):
352
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
353
+ has_nsfw_concepts = safety_checker(
354
+ images=[images],
355
+ clip_input=safety_checker_input.pixel_values.to("cuda"),
356
+ )
357
+
358
+ return has_nsfw_concepts[1]
359
+
360
+
361
+ def get_prompt_attention(pipeline, prompt, negative_prompt):
362
+ if isinstance(pipeline, (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)):
363
+ prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
364
+ return prompt_embeds, None, pooled_prompt_embeds, None
365
+ elif isinstance(pipeline, StableDiffusionXLPipeline):
366
+ prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt)
367
+ return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
368
+ else:
369
+ raise ValueError(f"Invalid pipeline type: {type(pipeline)}")
370
+
371
+
372
+ def get_photomaker_images(photomaker_images: List[Image.Image], height: int, width: int, resize_mode: str):
373
+ image_input_ids = []
374
+ image_id_embeds = []
375
+ photomaker_images = resize_images(photomaker_images, height, width, resize_mode)
376
+
377
+ for image in photomaker_images:
378
+ image_input_ids.append(img)
379
+ img = np.array(image)[:, :, ::-1]
380
+ faces = analyze_faces(face_detector, image)
381
+ if len(faces) > 0:
382
+ image_id_embeds.append(torch.from_numpy(faces[0]['embeddings']))
383
+ else:
384
+ raise ValueError("No face detected in the image")
385
+
386
+ return image_input_ids, image_id_embeds
387
+
388
+
389
+ def cleanup(pipeline, loras = None, embeddings = None):
390
+ if loras:
391
+ pipeline.disable_lora()
392
+ pipeline.unload_lora_weights()
393
+ if embeddings:
394
+ pipeline.unload_textual_inversion()
395
+ gc.collect()
396
+ torch.cuda.empty_cache()
397
+
398
+
399
+ # Gen function
400
+ def gen_img(
401
+ request: SDReq | SDImg2ImgReq | SDInpaintReq
402
+ ):
403
+ pipeline_args = get_pipe(request)
404
+ pipeline = pipeline_args['pipeline']
405
+ try:
406
+ pipeline.scheduler = load_scheduler(pipeline, request.scheduler)
407
+
408
+ load_loras(pipeline, request.loras, request.fast_generation)
409
+ load_xl_embeddings(pipeline, request.embeddings)
410
+
411
+ control_images = get_controlnet_images(request.controlnet_config.controlnets, request.controlnet_config.control_images, request.height, request.width, request.resize_mode) if request.controlnet_config else None
412
+ photomaker_images, photomaker_id_embeds = get_photomaker_images(request.photomaker_images, request.height, request.width) if request.photomaker_images else (None, None)
413
+
414
+ positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
415
+
416
+ # Common args
417
+ args = {
418
+ 'prompt_embeds': positive_prompt_embeds,
419
+ 'pooled_prompt_embeds': positive_prompt_pooled,
420
+ 'height': request.height,
421
+ 'width': request.width,
422
+ 'num_images_per_prompt': request.num_images_per_prompt,
423
+ 'num_inference_steps': request.num_inference_steps,
424
+ 'guidance_scale': request.guidance_scale,
425
+ 'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
426
+ }
427
+
428
+ if isinstance(pipeline, any([StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline,
429
+ StableDiffusionXLControlNetPipeline, StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline])):
430
+ args['clip_skip'] = request.clip_skip
431
+ args['negative_prompt_embeds'] = negative_prompt_embeds
432
+ args['negative_pooled_prompt_embeds'] = negative_prompt_pooled
433
+
434
+ if isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
435
+ args['control_mode'] = pipeline_args['control_mode']
436
+ args['control_image'] = control_images
437
+ args['controlnet_conditioning_scale'] = request.controlnet_conditioning_scale
438
+
439
+ if not isinstance(pipeline, FluxControlNetPipeline) and request.controlnet_config:
440
+ args['controlnet_conditioning_scale'] = request.controlnet_conditioning_scale
441
+
442
+ if isinstance(request, SDReq):
443
+ args['image'] = control_images
444
+ elif isinstance(request, (SDImg2ImgReq, SDInpaintReq)):
445
+ args['control_image'] = control_images
446
+
447
+ if request.photomaker_images and isinstance(pipeline, any([PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline])):
448
+ args['input_id_images'] = photomaker_images
449
+ args['input_id_embeds'] = photomaker_id_embeds
450
+ args['start_merge_step'] = 10
451
+
452
+ if isinstance(request, SDImg2ImgReq):
453
+ args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
454
+ args['strength'] = request.strength
455
+ elif isinstance(request, SDInpaintReq):
456
+ args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
457
+ args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)
458
+ args['strength'] = request.strength
459
+
460
+ images = pipeline(**args).images
461
+
462
+ if request.refiner:
463
+ images = refiner(
464
+ prompt=request.prompt,
465
+ num_inference_steps=40,
466
+ denoising_start=0.7,
467
+ image=images.images
468
+ ).images
469
+
470
+ cleanup(pipeline, request.loras, request.embeddings)
471
+
472
+ return images
473
+ except Exception as e:
474
+ cleanup(pipeline, request.loras, request.embeddings)
475
+ raise ValueError(f"Error generating image: {e}") from e
476
+
477
+
478
+ # CSS
479
+ css = """
480
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@300;400;600&display=swap');
481
+ body {
482
+ font-family: 'Poppins', sans-serif !important;
483
+ }
484
+ .center-content {
485
+ text-align: center;
486
+ max-width: 600px;
487
+ margin: 0 auto;
488
+ padding: 20px;
489
+ }
490
+ .center-content h1 {
491
+ font-weight: 600;
492
+ margin-bottom: 1rem;
493
+ }
494
+ .center-content p {
495
+ margin-bottom: 1.5rem;
496
+ }
497
+ """
498
+
499
+
500
+ flux_models = ["black-forest-labs/FLUX.1-dev"]
501
+ with open("data/images/loras/flux.json", "r") as f:
502
+ loras = json.load(f)
503
+
504
+
505
+ # Main Gradio app
506
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
507
+ # Header
508
+ with gr.Column(elem_classes="center-content"):
509
+ gr.Markdown("""
510
+ # πŸš€ AAI: All AI
511
+ Unleash your creativity with our multi-modal AI platform.
512
+ [![Sync code to HF Space](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml/badge.svg)](https://github.com/mantrakp04/aai/actions/workflows/hf-space.yml)
513
+ """)
514
+
515
+ # Tabs
516
+ with gr.Tabs():
517
+ with gr.Tab(label="πŸ–ΌοΈ Image"):
518
+ with gr.Tabs():
519
+ with gr.Tab("Flux"):
520
+ """
521
+ Create the image tab for Generative Image Generation Models
522
+
523
+ Args:
524
+ models: list
525
+ A list containing the models repository paths
526
+ gap_iol, gap_la, gap_le, gap_eio, gap_io: Optional[List[dict]]
527
+ A list of dictionaries containing the title and component for the custom gradio component
528
+ Example:
529
+ def gr_comp():
530
+ gr.Label("Hello World")
531
+
532
+ [
533
+ {
534
+ 'title': "Title",
535
+ 'component': gr_comp()
536
+ }
537
+ ]
538
+ loras: list
539
+ A list of dictionaries containing the image and title for the Loras Gallery
540
+ Generally a loaded json file from the data folder
541
+
542
+ """
543
+ def process_gaps(gaps: List[dict]):
544
+ for gap in gaps:
545
+ with gr.Accordion(gap['title']):
546
+ gap['component']
547
+
548
+
549
+ with gr.Row():
550
+ with gr.Column():
551
+ with gr.Group() as image_options:
552
+ model = gr.Dropdown(label="Models", choices=flux_models, value=flux_models[0], interactive=True)
553
+ prompt = gr.Textbox(lines=5, label="Prompt")
554
+ negative_prompt = gr.Textbox(label="Negative Prompt")
555
+ fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) πŸ§ͺ")
556
+
557
+
558
+ with gr.Accordion("Loras", open=True): # Lora Gallery
559
+ lora_gallery = gr.Gallery(
560
+ label="Gallery",
561
+ value=[(lora['image'], lora['title']) for lora in loras],
562
+ allow_preview=False,
563
+ columns=[3],
564
+ type="pil"
565
+ )
566
+
567
+ with gr.Group():
568
+ with gr.Column():
569
+ with gr.Row():
570
+ custom_lora = gr.Textbox(label="Custom Lora", info="Enter a Huggingface repo path")
571
+ selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")
572
+
573
+ custom_lora_info = gr.HTML(visible=False)
574
+ add_lora = gr.Button(value="Add LoRA")
575
+
576
+ enabled_loras = gr.State(value=[])
577
+ with gr.Group():
578
+ with gr.Row():
579
+ for i in range(6): # only support max 6 loras due to inference time
580
+ with gr.Column():
581
+ with gr.Column(scale=2):
582
+ globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
583
+ with gr.Column():
584
+ globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)
585
+
586
+
587
+ with gr.Accordion("Embeddings", open=False): # Embeddings
588
+ gr.Label("To be implemented")
589
+
590
+
591
+ with gr.Accordion("Image Options"): # Image Options
592
+ with gr.Tabs():
593
+ image_options = {
594
+ "img2img": "Upload Image",
595
+ "inpaint": "Upload Image",
596
+ "canny": "Upload Image",
597
+ "pose": "Upload Image",
598
+ "depth": "Upload Image",
599
+ }
600
+
601
+ for image_option, label in image_options.items():
602
+ with gr.Tab(image_option):
603
+ if not image_option in ['inpaint', 'scribble']:
604
+ globals()[f"{image_option}_image"] = gr.Image(label=label, type="pil")
605
+ elif image_option in ['inpaint', 'scribble']:
606
+ globals()[f"{image_option}_image"] = gr.ImageEditor(
607
+ label=label,
608
+ image_mode='RGB',
609
+ layers=False,
610
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
611
+ interactive=True,
612
+ type="pil",
613
+ )
614
+
615
+ # Image Strength (Co-relates to controlnet strength, strength for img2img n inpaint)
616
+ globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)
617
+
618
+ resize_mode = gr.Radio(
619
+ label="Resize Mode",
620
+ choices=["crop and resize", "resize only", "resize and fill"],
621
+ value="resize and fill",
622
+ interactive=True
623
+ )
624
+
625
+
626
+ with gr.Column():
627
+ with gr.Group():
628
+ output_images = gr.Gallery(
629
+ label="Output Images",
630
+ value=[],
631
+ allow_preview=True,
632
+ type="pil",
633
+ interactive=False,
634
+ )
635
+ generate_images = gr.Button(value="Generate Images", variant="primary")
636
+
637
+ with gr.Accordion("Advance Settings", open=True):
638
+ with gr.Row():
639
+ scheduler = gr.Dropdown(
640
+ label="Scheduler",
641
+ choices = [
642
+ "fm_euler"
643
+ ],
644
+ value="fm_euler",
645
+ interactive=True
646
+ )
647
+
648
+ with gr.Row():
649
+ for column in range(2):
650
+ with gr.Column():
651
+ options = [
652
+ ("Height", "image_height", 64, 1024, 64, 1024, True),
653
+ ("Width", "image_width", 64, 1024, 64, 1024, True),
654
+ ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
655
+ ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
656
+ ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
657
+ ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 3.5, True),
658
+ ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
659
+ ]
660
+ for label, var_name, min_val, max_val, step, value, visible in options[column::2]:
661
+ globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)
662
+
663
+ with gr.Row():
664
+ refiner = gr.Checkbox(
665
+ label="Refiner πŸ§ͺ",
666
+ value=False,
667
+ )
668
+ vae = gr.Checkbox(
669
+ label="VAE",
670
+ value=True,
671
+ )
672
+
673
+
674
+ # Events
675
+ # Base Options
676
+ fast_generation.change(update_fast_generation, [model, fast_generation], [image_guidance_scale, image_num_inference_steps]) # Fast Generation # type: ignore
677
+
678
+
679
+ # Lora Gallery
680
+ lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
681
+ custom_lora.change(update_selected_lora, custom_lora, [custom_lora, selected_lora])
682
+ add_lora.click(add_to_enabled_loras, [model, selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
683
+ enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5]) # type: ignore
684
+
685
+ for i in range(6):
686
+ globals()[f"lora_remove_{i}"].click(
687
+ lambda enabled_loras, index=i: remove_from_enabled_loras(enabled_loras, index),
688
+ [enabled_loras],
689
+ [enabled_loras]
690
+ )
691
+
692
+
693
+ # Generate Image
694
+ generate_images.click(
695
+ generate_image, # type: ignore
696
+ [
697
+ model, prompt, negative_prompt, fast_generation, enabled_loras,
698
+ lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, # type: ignore
699
+ img2img_image, inpaint_image, canny_image, pose_image, depth_image, # type: ignore
700
+ img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength, # type: ignore
701
+ resize_mode,
702
+ scheduler, image_height, image_width, image_num_images_per_prompt, # type: ignore
703
+ image_num_inference_steps, image_guidance_scale, image_seed, # type: ignore
704
+ refiner, vae
705
+ ],
706
+ [output_images]
707
+ )
708
+ with gr.Tab("SDXL"):
709
+ gr.Label("To be implemented")
710
+ with gr.Tab(label="🎡 Audio"):
711
+ gr.Label("Coming soon!")
712
+ with gr.Tab(label="🎬 Video"):
713
+ gr.Label("Coming soon!")
714
+ with gr.Tab(label="πŸ“„ Text"):
715
+ gr.Label("Coming soon!")
716
+
717
+
718
+ demo.launch(
719
+ share=False,
720
+ debug=True,
721
+ )
src/tasks/images/init_sys.py CHANGED
@@ -1,4 +1,4 @@
1
-
2
  import torch
3
  from diffusers import (
4
  DiffusionPipeline,
 
1
+ import spaces
2
  import torch
3
  from diffusers import (
4
  DiffusionPipeline,