prithivMLmods committed
Commit 678ab98 · verified · 1 Parent(s): 68cf392

Update app.py

Files changed (1)
  1. app.py +386 -320
app.py CHANGED
@@ -1,349 +1,415 @@
  import os
- import random
- import uuid
  import json
  import time
- import asyncio
  import re
- from threading import Thread

- import gradio as gr
- import spaces
- import torch
- import numpy as np
- from PIL import Image
- import edge_tts

- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
-     Qwen2VLForConditionalGeneration,
-     AutoProcessor,
- )
- from transformers.image_utils import load_image
- from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

- # -----------------------
- # Progress Bar Helper
- # -----------------------
- def progress_bar_html(label: str) -> str:
-     """
-     Returns an HTML snippet for a thin progress bar with a label.
-     The progress bar is styled as a dark red animated bar.
-     """
-     return f'''
- <div style="display: flex; align-items: center;">
-     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-     <div style="width: 110px; height: 5px; background-color: #DDA0DD; border-radius: 2px; overflow: hidden;">
-         <div style="width: 100%; height: 100%; background-color: #FF00FF; animation: loading 1.5s linear infinite;"></div>
-     </div>
- </div>
- <style>
- @keyframes loading {{
-     0% {{ transform: translateX(-100%); }}
-     100% {{ transform: translateX(100%); }}
- }}
- </style>
- '''

- # -----------------------
- # Text Generation Setup
- # -----------------------
- model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
- )
- model.eval()

- TTS_VOICES = [
-     "en-US-JennyNeural",  # @tts1
-     "en-US-GuyNeural",    # @tts2
- ]

- # -----------------------
- # Multimodal OCR Setup
- # -----------------------
- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
- model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()

- async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-     """Convert text to speech using Edge TTS and save as MP3"""
-     communicate = edge_tts.Communicate(text, voice)
-     await communicate.save(output_file)
-     return output_file

- def clean_chat_history(chat_history):
-     """
-     Filter out any chat entries whose "content" is not a string.
-     """
-     cleaned = []
-     for msg in chat_history:
-         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-             cleaned.append(msg)
-     return cleaned

- # -----------------------
- # Stable Diffusion Image Generation Setup
- # -----------------------

- MAX_SEED = np.iinfo(np.int32).max
- USE_TORCH_COMPILE = False
- ENABLE_CPU_OFFLOAD = False

- if torch.cuda.is_available():
-     pipe = StableDiffusionXLPipeline.from_pretrained(
-         "SG161222/RealVisXL_V4.0_Lightning",
-         torch_dtype=torch.float16,
-         use_safetensors=True,
-     )
-     pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-
-     # LoRA options with one example for each.
-     LORA_OPTIONS = {
-         "Realism": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
-         "Pixar": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
-         "Photoshoot": ("prithivMLmods/Canopus-Photo-Shoot-Mini-LoRA", "Canopus-Photo-Shoot-Mini-LoRA.safetensors", "photo"),
-         "Clothing": ("prithivMLmods/Canopus-Clothing-Adp-LoRA", "Canopus-Dress-Clothing-LoRA.safetensors", "clth"),
-         "Interior": ("prithivMLmods/Canopus-Interior-Architecture-0.1", "Canopus-Interior-Architecture-0.1δ.safetensors", "arch"),
-         "Fashion": ("prithivMLmods/Canopus-Fashion-Product-Dilation", "Canopus-Fashion-Product-Dilation.safetensors", "fashion"),
-         "Minimalistic": ("prithivMLmods/Pegasi-Minimalist-Image-Style", "Pegasi-Minimalist-Image-Style.safetensors", "minimalist"),
-         "Modern": ("prithivMLmods/Canopus-Modern-Clothing-Design", "Canopus-Modern-Clothing-Design.safetensors", "mdrnclth"),
-         "Animaliea": ("prithivMLmods/Canopus-Animaliea-Artism", "Canopus-Animaliea-Artism.safetensors", "Animaliea"),
-         "Wallpaper": ("prithivMLmods/Canopus-Liquid-Wallpaper-Art", "Canopus-Liquid-Wallpaper-Minimalize-LoRA.safetensors", "liquid"),
-         "Cars": ("prithivMLmods/Canes-Cars-Model-LoRA", "Canes-Cars-Model-LoRA.safetensors", "car"),
-         "PencilArt": ("prithivMLmods/Canopus-Pencil-Art-LoRA", "Canopus-Pencil-Art-LoRA.safetensors", "Pencil Art"),
-         "ArtMinimalistic": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
-     }

-     # Load all LoRA weights
-     for model_name, weight_name, adapter_name in LORA_OPTIONS.values():
-         pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
-     pipe.to("cuda")
- else:
-     pipe = StableDiffusionXLPipeline.from_pretrained(
-         "SG161222/RealVisXL_V4.0_Lightning",
-         torch_dtype=torch.float32,
-         use_safetensors=True,
-     ).to(device)

- def save_image(img: Image.Image) -> str:
-     """Save a PIL image with a unique filename and return the path."""
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name

- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed

- @spaces.GPU(duration=180, enable_queue=True)
- def generate_image(
-     prompt: str,
-     negative_prompt: str = "",
-     seed: int = 0,
-     width: int = 1024,
-     height: int = 1024,
-     guidance_scale: float = 3.0,
-     randomize_seed: bool = True,
-     lora_model: str = "Realism",
-     progress=gr.Progress(track_tqdm=True),
- ):
-     seed = int(randomize_seed_fn(seed, randomize_seed))
-     effective_negative_prompt = negative_prompt  # Use provided negative prompt if any
-     model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
-     pipe.set_adapters(adapter_name)
-     outputs = pipe(
-         prompt=prompt,
-         negative_prompt=effective_negative_prompt,
-         width=width,
-         height=height,
-         guidance_scale=guidance_scale,
-         num_inference_steps=28,
-         num_images_per_prompt=1,
-         cross_attention_kwargs={"scale": 0.65},
-         output_type="pil",
-     )
-     images = outputs.images
-     image_paths = [save_image(img) for img in images]
-     return image_paths, seed

- # -----------------------
- # Main Chat/Generation Function
- # -----------------------
- @spaces.GPU
- def generate(
-     input_dict: dict,
-     chat_history: list[dict],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     repetition_penalty: float = 1.2,
- ):
-     """
-     Generates chatbot responses with support for multimodal input, TTS, and image generation.
-     Special commands:
-     - "@tts1" or "@tts2": triggers text-to-speech.
-     - "@<lora_command>": triggers image generation using the new LoRA pipeline.
-       Available commands (case-insensitive): @realism, @pixar, @photoshoot, @clothing, @interior, @fashion,
-       @minimalistic, @modern, @animaliea, @wallpaper, @cars, @pencilart, @artminimalistic.
-     """
-     text = input_dict["text"]
-     files = input_dict.get("files", [])
-
-     # Check for image generation command based on LoRA tags.
-     lora_mapping = { key.lower(): key for key in LORA_OPTIONS }
-     for key_lower, key in lora_mapping.items():
-         command_tag = "@" + key_lower
-         if text.strip().lower().startswith(command_tag):
-             prompt_text = text.strip()[len(command_tag):].strip()
-             yield progress_bar_html(f"Processing Image Generation ({key} style)")
-             image_paths, used_seed = generate_image(
-                 prompt=prompt_text,
-                 negative_prompt="",
-                 seed=1,
-                 width=1024,
-                 height=1024,
-                 guidance_scale=3,
-                 randomize_seed=True,
-                 lora_model=key,
-             )
-             yield progress_bar_html("Finalizing Image Generation")
-             yield gr.Image(image_paths[0])
-             return
-
-     # Check for TTS command (@tts1 or @tts2)
-     tts_prefix = "@tts"
-     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
-     if is_tts and voice_index:
-         voice = TTS_VOICES[voice_index - 1]
-         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-         conversation = [{"role": "user", "content": text}]
-     else:
-         voice = None
-         text = text.replace(tts_prefix, "").strip()
-         conversation = clean_chat_history(chat_history)
-         conversation.append({"role": "user", "content": text})
-
-     if files:
-         if len(files) > 1:
-             images = [load_image(image) for image in files]
-         elif len(files) == 1:
-             images = [load_image(files[0])]
-         else:
-             images = []
-         messages = [{
-             "role": "user",
-             "content": [
-                 *[{"type": "image", "image": image} for image in images],
-                 {"type": "text", "text": text},
-             ]
-         }]
-         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-         thread.start()

-         buffer = ""
-         yield progress_bar_html("Processing with Qwen2VL Ocr")
-         for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
      else:
-         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-         input_ids = input_ids.to(model.device)
-         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             "input_ids": input_ids,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "top_p": top_p,
-             "top_k": top_k,
-             "temperature": temperature,
-             "num_beams": 1,
-             "repetition_penalty": repetition_penalty,
-         }
-         t = Thread(target=model.generate, kwargs=generation_kwargs)
-         t.start()

-         outputs = []
-         for new_text in streamer:
-             outputs.append(new_text)
-             yield "".join(outputs)

-         final_response = "".join(outputs)
-         yield final_response

-     if is_tts and voice:
-         output_file = asyncio.run(text_to_speech(final_response, voice))
-         yield gr.Audio(output_file, autoplay=True)

- # -----------------------
- # Gradio Chat Interface
- # -----------------------
- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-     ],
-     examples=[
-         ["@realism Chocolate dripping from a donut, hyper-realistic"],
-         ["@pixar Young man with wavy hair in an armchair, Pixar style"],
-         ["@realism Futuristic cityscape with neon lights"],
-         ["Write the python array rotation program"],
-         ["@photoshoot Dramatic portrait lighting"],
-         [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
-         ["@tts1 Who was Nikola Tesla and how did he die?"],
-         ["@clothing Fashionable streetwear in the city"],
-         ["@interior Modern minimalist living room"],
-         ["@fashion Runway model in haute couture"],
-         ["@minimalistic Elegant serene landscape"],
-         ["@modern Abstract geometric art"],
-         ["@animaliea Cute animal portrait, vibrant colors"],
-         ["@wallpaper Scenic mountain desktop wallpaper"],
-         ["@cars Sleek sports car on city streets"],
-         ["@pencilart Detailed historic building sketch"],
-         ["@artminimalistic Subtle minimalist artwork"],
-         ["@tts2 What causes rainbows?"],
-     ],
-     cache_examples=False,
-     type="messages",
-     description="# **Gen Vision Sdxl** `tts: @tts1 @tts2` \n `image-tags: @realism, @pixar, @photoshoot, @clothing, @interior, @fashion, @minimalistic, @modern, @animaliea, @wallpaper, @cars, @pencilart, @artminimalistic` \n \n `default: chat, image inference`",
-     fill_height=True,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="text, image-infer, image-generation, tts"),
-     stop_btn="Stop Generation",
-     multimodal=True,
- )

- if __name__ == "__main__":
-     demo.queue(max_size=20).launch(share=True)
 
  import os
+ import gradio as gr
  import json
+ import logging
+ import torch
+ from typing import Optional  # required by the Optional[...] annotation in generate_sdxl_images
+ from PIL import Image
+ import spaces
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+ from diffusers.utils import load_image
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
+ import copy
+ import random
  import time
  import re

+ # Load LoRAs from JSON file
+ with open('loras.json', 'r') as f:
+     loras = json.load(f)

+ # Initialize the base model for SDXL
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ base_model = "stabilityai/stable-diffusion-xl-base-1.0"

+ # Load SDXL pipelines
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+     base_model,
+     torch_dtype=dtype,
+     use_safetensors=True
+ ).to(device)

+ pipe_i2i = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+     base_model,
+     torch_dtype=dtype,
+     use_safetensors=True
+ ).to(device)

+ MAX_SEED = 2**32 - 1

+ # Custom SDXL generation function for live preview
+ @torch.inference_mode()
+ def generate_sdxl_images(
+     pipe,
+     prompt: str,
+     height: int = 1024,
+     width: int = 1024,
+     num_inference_steps: int = 50,
+     guidance_scale: float = 7.5,
+     generator: Optional[torch.Generator] = None,
+     output_type: str = "pil",
+ ):
+     # Encode prompt
+     prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
+         prompt=prompt,
+         num_images_per_prompt=1,
+         do_classifier_free_guidance=True,
+     )
+     # Prepare latents
+     latents = pipe.prepare_latents(
+         batch_size=1,
+         num_channels_latents=pipe.unet.config.in_channels,
+         height=height,
+         width=width,
+         dtype=prompt_embeds.dtype,
+         device=pipe.device,
+         generator=generator,
+     )
+     # Prepare timesteps
+     pipe.scheduler.set_timesteps(num_inference_steps, device=pipe.device)
+     timesteps = pipe.scheduler.timesteps
+     # Prepare guidance
+     do_classifier_free_guidance = guidance_scale > 1.0
+     if do_classifier_free_guidance:
+         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+         pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
+     # Denoising loop
+     for i, t in enumerate(timesteps):
+         # Expand latents for guidance
+         latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+         # Predict noise
+         noise_pred = pipe.unet(
+             latent_model_input,
+             t,
+             encoder_hidden_states=prompt_embeds,
+             added_cond_kwargs={"text_embeds": pooled_prompt_embeds},
+         ).sample
+         # Perform guidance
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+         # Step scheduler
+         latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
+         # Decode latents to image every step
+         image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
+         yield pipe.image_processor.postprocess(image, output_type=output_type)[0]
+     # Final image
+     image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
+     yield pipe.image_processor.postprocess(image, output_type=output_type)[0]

+ class calculateDuration:
+     def __init__(self, activity_name=""):
+         self.activity_name = activity_name

+     def __enter__(self):
+         self.start_time = time.time()
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.end_time = time.time()
+         self.elapsed_time = self.end_time - self.start_time
+         if self.activity_name:
+             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+         else:
+             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")

+ def update_selection(evt: gr.SelectData, width, height):
+     selected_lora = loras[evt.index]
+     new_placeholder = f"Type a prompt for {selected_lora['title']}"
+     lora_repo = selected_lora["repo"]
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
+     if "aspect" in selected_lora:
+         if selected_lora["aspect"] == "portrait":
+             width = 768
+             height = 1024
+         elif selected_lora["aspect"] == "landscape":
+             width = 1024
+             height = 768
+         else:
+             width = 1024
+             height = 1024
+     return (
+         gr.update(placeholder=new_placeholder),
+         updated_text,
+         evt.index,
+         width,
+         height,
+     )

+ @spaces.GPU(duration=70)
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
+     pipe.to("cuda")
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+     with calculateDuration("Generating image"):
+         for img in generate_sdxl_images(
+             pipe,
+             prompt=prompt_mash,
+             num_inference_steps=steps,
+             guidance_scale=cfg_scale,
+             width=width,
+             height=height,
+             generator=generator,
+             output_type="pil",
+         ):
+             yield img

+ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+     pipe_i2i.to("cuda")
+     image_input = load_image(image_input_path)
+     final_image = pipe_i2i(
+         prompt=prompt_mash,
+         image=image_input,
+         strength=image_strength,
+         num_inference_steps=steps,
+         guidance_scale=cfg_scale,
+         width=width,
+         height=height,
+         generator=generator,
+         output_type="pil",
+     ).images[0]
+     return final_image

+ @spaces.GPU(duration=70)
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+     if selected_index is None:
+         raise gr.Error("You must select a LoRA before proceeding.")
+     selected_lora = loras[selected_index]
+     lora_path = selected_lora["repo"]
+     trigger_word = selected_lora["trigger_word"]
+     if trigger_word:
+         if "trigger_position" in selected_lora and selected_lora["trigger_position"] == "prepend":
+             prompt_mash = f"{trigger_word} {prompt}"
+         else:
+             prompt_mash = f"{prompt} {trigger_word}"
+     else:
+         prompt_mash = prompt

+     # Unload previous LoRA weights
+     with calculateDuration("Unloading LoRA"):
+         pipe.unload_lora_weights()
+         pipe_i2i.unload_lora_weights()
+
+     # Load LoRA weights and set adapter scale
+     with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+         weight_name = selected_lora.get("weights", None)
+         adapter_name = "lora"
+         pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=adapter_name)
+         pipe.set_adapters([adapter_name], [lora_scale])
+         pipe_i2i.load_lora_weights(lora_path, weight_name=weight_name, adapter_name=adapter_name)
+         pipe_i2i.set_adapters([adapter_name], [lora_scale])
+
+     # Set random seed
+     with calculateDuration("Randomizing seed"):
+         if randomize_seed:
+             seed = random.randint(0, MAX_SEED)
+
+     if image_input is not None:
+         final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
+         yield final_image, seed, gr.update(visible=False)
+     else:
+         image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
+         final_image = None
+         step_counter = 0
+         for image in image_generator:
+             step_counter += 1
+             final_image = image
+             progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
+             yield image, seed, gr.update(value=progress_bar, visible=True)
+         yield final_image, seed, gr.update(value=progress_bar, visible=False)

+ def get_huggingface_safetensors(link):
+     split_link = link.split("/")
+     if len(split_link) != 2:
+         raise Exception("Invalid Hugging Face repository link format.")

+     # Load model card
+     model_card = ModelCard.load(link)
+     base_model = model_card.data.get("base_model")
+     print(base_model)

+     # Validate model type for SDXL
+     if base_model != "stabilityai/stable-diffusion-xl-base-1.0":
+         raise Exception("Not an SDXL LoRA!")

+     # Extract image and trigger word
+     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+     trigger_word = model_card.data.get("instance_prompt", "")
+     image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None

+     # Initialize Hugging Face file system
+     fs = HfFileSystem()
+     try:
+         list_of_files = fs.ls(link, detail=False)
+         safetensors_name = None
+         highest_trained_file = None
+         highest_steps = -1
+         last_safetensors_file = None
+         step_pattern = re.compile(r"_0{3,}\d+")  # Detects step count `_000...`
+
+         for file in list_of_files:
+             filename = file.split("/")[-1]
+             if filename.endswith(".safetensors"):
+                 last_safetensors_file = filename
+                 match = step_pattern.search(filename)
+                 if not match:
+                     safetensors_name = filename
+                     break
+                 else:
+                     steps = int(match.group().lstrip("_"))
+                     if steps > highest_steps:
+                         highest_trained_file = filename
+                         highest_steps = steps
+             if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
+                 image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
+
+         if not safetensors_name:
+             safetensors_name = highest_trained_file if highest_trained_file else last_safetensors_file
+         if not safetensors_name:
+             raise Exception("No valid *.safetensors file found in the repository.")
+     except Exception as e:
+         print(e)
+         raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")

+     return split_link[1], link, safetensors_name, trigger_word, image_url
+
+ def check_custom_model(link):
+     if link.startswith("https://"):
+         if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
+             link_split = link.split("huggingface.co/")
+             return get_huggingface_safetensors(link_split[1])
+     else:
+         return get_huggingface_safetensors(link)
+
+ def add_custom_lora(custom_lora):
+     global loras
+     if custom_lora:
+         try:
+             title, repo, path, trigger_word, image = check_custom_model(custom_lora)
+             print(f"Loaded custom LoRA: {repo}")
+             card = f'''
+             <div class="custom_lora_card">
+                 <span>Loaded custom LoRA:</span>
+                 <div class="card_internal">
+                     <img src="{image}" />
+                     <div>
+                         <h3>{title}</h3>
+                         <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
+                     </div>
+                 </div>
+             </div>
+             '''
+             existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
+             if not existing_item_index:
+                 new_item = {
+                     "image": image,
+                     "title": title,
+                     "repo": repo,
+                     "weights": path,
+                     "trigger_word": trigger_word
+                 }
+                 print(new_item)
+                 existing_item_index = len(loras)
+                 loras.append(new_item)
+             return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
+         except Exception as e:
+             gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-SDXL LoRA")
+             return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-SDXL LoRA"), gr.update(visible=True), gr.update(), "", None, ""
      else:
+         return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
+
+ def remove_custom_lora():
+     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

+ run_lora.zerogpu = True

+ css = '''
+ #gen_btn{height: 100%}
+ #gen_column{align-self: stretch}
+ #title{text-align: center}
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
+ #title img{width: 100px; margin-right: 0.5em}
+ #gallery .grid-wrap{height: 10vh}
+ #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
+ .card_internal{display: flex;height: 100px;margin-top: .5em}
+ .card_internal img{margin-right: 1em}
+ .styler{--form-gap-width: 0px !important}
+ #progress{height:30px}
+ #progress .generating{display:none}
+ .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
+ .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
+ '''
+ font = [gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
+ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(60, 60)) as app:
+     title = gr.HTML(
+         """<h1>SDXL LoRA DLC</h1>""",
+         elem_id="title",
+     )
+     selected_index = gr.State(None)
+     with gr.Row():
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
+         with gr.Column(scale=1, elem_id="gen_column"):
+             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
+     with gr.Row():
+         with gr.Column():
+             selected_info = gr.Markdown("")
+             gallery = gr.Gallery(
+                 [(item["image"], item["title"]) for item in loras],
+                 label="LoRA Gallery",
+                 allow_preview=False,
+                 columns=3,
+                 elem_id="gallery",
+                 show_share_button=False
+             )
+             with gr.Group():
+                 custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/sdxl-lora-model")
+                 gr.Markdown("[Check the list of SDXL LoRAs](https://huggingface.co/models?other=base_model:stabilityai/stable-diffusion-xl-base-1.0)", elem_id="lora_list")
+             custom_lora_info = gr.HTML(visible=False)
+             custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
+         with gr.Column():
+             progress_bar = gr.Markdown(elem_id="progress", visible=False)
+             result = gr.Image(label="Generated Image")

+     with gr.Row():
+         with gr.Accordion("Advanced Settings", open=False):
+             with gr.Row():
+                 input_image = gr.Image(label="Input image", type="filepath")
+                 image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
+             with gr.Column():
+                 with gr.Row():
+                     cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=7.5)
+                     steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=30)
+
+                 with gr.Row():
+                     width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
+                     height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
+
+                 with gr.Row():
+                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+                     lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=1.0)

+     gallery.select(
+         update_selection,
+         inputs=[width, height],
+         outputs=[prompt, selected_info, selected_index, width, height]
+     )
+     custom_lora.input(
+         add_custom_lora,
+         inputs=[custom_lora],
+         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
+     )
+     custom_lora_button.click(
+         remove_custom_lora,
+         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
+     )
+     gr.on(
+         triggers=[generate_button.click, prompt.submit],
+         fn=run_lora,
+         inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
+         outputs=[result, seed, progress_bar]
+     )

+ app.queue()
+ app.launch()
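
The rewritten app.py expects a loras.json file next to it, which this commit does not include. A minimal sketch of what one entry could look like, using only the keys the code above reads ("image", "title", "repo", "weights", "trigger_word", plus the optional "aspect" and "trigger_position"); the repository id, filename, and values below are hypothetical:

    import json

    example_loras = [
        {
            "image": "https://huggingface.co/username/sdxl-example-lora/resolve/main/preview.png",  # gallery thumbnail
            "title": "Example SDXL LoRA",                # caption shown in the LoRA Gallery
            "repo": "username/sdxl-example-lora",        # Hugging Face repo id passed to load_lora_weights
            "weights": "sdxl-example-lora.safetensors",  # optional weight_name; omit to let diffusers pick one
            "trigger_word": "xmpl style",                # appended (or prepended) to the prompt by run_lora
            "aspect": "portrait",                        # optional: update_selection switches the size to 768x1024
            "trigger_position": "prepend"                # optional: put the trigger word before the prompt
        }
    ]

    # Write the file the app loads at startup with open('loras.json', 'r')
    with open("loras.json", "w") as f:
        json.dump(example_loras, f, indent=2)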
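
get_huggingface_safetensors picks a weight file by preferring a .safetensors name without a step suffix and otherwise keeping the checkpoint with the highest _000NNN step count. A self-contained sketch of that selection rule using the same regex; pick_safetensors and the filenames are illustrative and not part of the commit:

    import re

    step_pattern = re.compile(r"_0{3,}\d+")  # same pattern used in get_huggingface_safetensors

    def pick_safetensors(filenames):
        best, best_steps, last = None, -1, None
        for name in filenames:
            if not name.endswith(".safetensors"):
                continue
            last = name
            match = step_pattern.search(name)
            if not match:
                return name  # a file with no step suffix wins immediately
            steps = int(match.group().lstrip("_"))
            if steps > best_steps:
                best, best_steps = name, steps
        # fall back to the most-trained checkpoint, then to the last .safetensors seen
        return best or last

    print(pick_safetensors(["lora_000500.safetensors", "lora_000900.safetensors"]))  # lora_000900.safetensors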