Keltezaa commited on
Commit
414f16f
·
verified ·
1 Parent(s): 44a3094

Upload Do not edit app.py

Browse files
Files changed (1) hide show
  1. Do not edit app.py +586 -0
Do not edit app.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import json
4
+ import logging
5
+ import torch
6
+ from PIL import Image
7
+ import spaces
8
+ from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
9
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
10
+ from diffusers.utils import load_image
11
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
12
+ import copy
13
+ import random
14
+ import time
15
+ import requests
16
+ import pandas as pd
17
+
18
+ #Load prompts for randomization
19
+ df = pd.read_csv('prompts.csv', header=None)
20
+ prompt_values = df.values.flatten()
21
+
22
+ # Load LoRAs from JSON file
23
+ with open('loras.json', 'r') as f:
24
+ loras = json.load(f)
25
+
26
+ # Initialize the base model
27
+ dtype = torch.bfloat16
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ base_model = "black-forest-labs/FLUX.1-dev"
30
+
31
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
32
+ good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
33
+ pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
34
+ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
35
+ base_model,
36
+ vae=good_vae,
37
+ transformer=pipe.transformer,
38
+ text_encoder=pipe.text_encoder,
39
+ tokenizer=pipe.tokenizer,
40
+ text_encoder_2=pipe.text_encoder_2,
41
+ tokenizer_2=pipe.tokenizer_2,
42
+ torch_dtype=dtype
43
+ )
44
+
45
+ MAX_SEED = 2**32 - 1
46
+
47
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
48
+
49
+ class calculateDuration:
50
+ def __init__(self, activity_name=""):
51
+ self.activity_name = activity_name
52
+
53
+ def __enter__(self):
54
+ self.start_time = time.time()
55
+ return self
56
+
57
+ def __exit__(self, exc_type, exc_value, traceback):
58
+ self.end_time = time.time()
59
+ self.elapsed_time = self.end_time - self.start_time
60
+ if self.activity_name:
61
+ print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
62
+ else:
63
+ print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
64
+
65
+ def download_file(url, directory=None):
66
+ if directory is None:
67
+ directory = os.getcwd() # Use current working directory if not specified
68
+
69
+ # Get the filename from the URL
70
+ filename = url.split('/')[-1]
71
+
72
+ # Full path for the downloaded file
73
+ filepath = os.path.join(directory, filename)
74
+
75
+ # Download the file
76
+ response = requests.get(url)
77
+ response.raise_for_status() # Raise an exception for bad status codes
78
+
79
+ # Write the content to the file
80
+ with open(filepath, 'wb') as file:
81
+ file.write(response.content)
82
+
83
+ return filepath
84
+
85
+ def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
86
+ selected_index = evt.index
87
+ selected_indices = selected_indices or []
88
+ if selected_index in selected_indices:
89
+ selected_indices.remove(selected_index)
90
+ else:
91
+ if len(selected_indices) < 2:
92
+ selected_indices.append(selected_index)
93
+ else:
94
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
95
+ return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update()
96
+
97
+ selected_info_1 = "Select a LoRA 1"
98
+ selected_info_2 = "Select a LoRA 2"
99
+ lora_scale_1 = 1.15
100
+ lora_scale_2 = 1.15
101
+ lora_image_1 = None
102
+ lora_image_2 = None
103
+ if len(selected_indices) >= 1:
104
+ lora1 = loras_state[selected_indices[0]]
105
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
106
+ lora_image_1 = lora1['image']
107
+ if len(selected_indices) >= 2:
108
+ lora2 = loras_state[selected_indices[1]]
109
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
110
+ lora_image_2 = lora2['image']
111
+
112
+ if selected_indices:
113
+ last_selected_lora = loras_state[selected_indices[-1]]
114
+ new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
115
+ else:
116
+ new_placeholder = "Type a prompt after selecting a LoRA"
117
+
118
+ return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2
119
+
120
+ def remove_lora_1(selected_indices, loras_state):
121
+ if len(selected_indices) >= 1:
122
+ selected_indices.pop(0)
123
+ selected_info_1 = "Select a LoRA 1"
124
+ selected_info_2 = "Select a LoRA 2"
125
+ lora_scale_1 = 1.15
126
+ lora_scale_2 = 1.15
127
+ lora_image_1 = None
128
+ lora_image_2 = None
129
+ if len(selected_indices) >= 1:
130
+ lora1 = loras_state[selected_indices[0]]
131
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
132
+ lora_image_1 = lora1['image']
133
+ if len(selected_indices) >= 2:
134
+ lora2 = loras_state[selected_indices[1]]
135
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
136
+ lora_image_2 = lora2['image']
137
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
138
+
139
+ def remove_lora_2(selected_indices, loras_state):
140
+ if len(selected_indices) >= 2:
141
+ selected_indices.pop(1)
142
+ selected_info_1 = "Select LoRA 1"
143
+ selected_info_2 = "Select LoRA 2"
144
+ lora_scale_1 = 1.15
145
+ lora_scale_2 = 1.15
146
+ lora_image_1 = None
147
+ lora_image_2 = None
148
+ if len(selected_indices) >= 1:
149
+ lora1 = loras_state[selected_indices[0]]
150
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
151
+ lora_image_1 = lora1['image']
152
+ if len(selected_indices) >= 2:
153
+ lora2 = loras_state[selected_indices[1]]
154
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
155
+ lora_image_2 = lora2['image']
156
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
157
+
158
+ def randomize_loras(selected_indices, loras_state):
159
+ if len(loras_state) < 2:
160
+ raise gr.Error("Not enough LoRAs to randomize.")
161
+ selected_indices = random.sample(range(len(loras_state)), 2)
162
+ lora1 = loras_state[selected_indices[0]]
163
+ lora2 = loras_state[selected_indices[1]]
164
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
165
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
166
+ lora_scale_1 = 1.15
167
+ lora_scale_2 = 1.15
168
+ lora_image_1 = lora1['image']
169
+ lora_image_2 = lora2['image']
170
+ random_prompt = random.choice(prompt_values)
171
+ return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt
172
+
173
+ def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
174
+ if custom_lora:
175
+ try:
176
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora)
177
+ print(f"Loaded custom LoRA: {repo}")
178
+ existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
179
+ if existing_item_index is None:
180
+ if repo.endswith(".safetensors") and repo.startswith("http"):
181
+ repo = download_file(repo)
182
+ new_item = {
183
+ "image": image if image else "/home/user/app/custom.png",
184
+ "title": title,
185
+ "repo": repo,
186
+ "weights": path,
187
+ "trigger_word": trigger_word
188
+ }
189
+ print(f"New LoRA: {new_item}")
190
+ existing_item_index = len(current_loras)
191
+ current_loras.append(new_item)
192
+
193
+ # Update gallery
194
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
195
+ # Update selected_indices if there's room
196
+ if len(selected_indices) < 2:
197
+ selected_indices.append(existing_item_index)
198
+ else:
199
+ gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
200
+
201
+ # Update selected_info and images
202
+ selected_info_1 = "Select a LoRA 1"
203
+ selected_info_2 = "Select a LoRA 2"
204
+ lora_scale_1 = 1.15
205
+ lora_scale_2 = 1.15
206
+ lora_image_1 = None
207
+ lora_image_2 = None
208
+ if len(selected_indices) >= 1:
209
+ lora1 = current_loras[selected_indices[0]]
210
+ selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
211
+ lora_image_1 = lora1['image'] if lora1['image'] else None
212
+ if len(selected_indices) >= 2:
213
+ lora2 = current_loras[selected_indices[1]]
214
+ selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
215
+ lora_image_2 = lora2['image'] if lora2['image'] else None
216
+ print("Finished adding custom LoRA")
217
+ return (
218
+ current_loras,
219
+ gr.update(value=gallery_items),
220
+ selected_info_1,
221
+ selected_info_2,
222
+ selected_indices,
223
+ lora_scale_1,
224
+ lora_scale_2,
225
+ lora_image_1,
226
+ lora_image_2
227
+ )
228
+ except Exception as e:
229
+ print(e)
230
+ gr.Warning(str(e))
231
+ return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
232
+ else:
233
+ return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
234
+
235
+ def remove_custom_lora(selected_indices, current_loras, gallery):
236
+ if current_loras:
237
+ custom_lora_repo = current_loras[-1]['repo']
238
+ # Remove from loras list
239
+ current_loras = current_loras[:-1]
240
+ # Remove from selected_indices if selected
241
+ custom_lora_index = len(current_loras)
242
+ if custom_lora_index in selected_indices:
243
+ selected_indices.remove(custom_lora_index)
244
+ # Update gallery
245
+ gallery_items = [(item["image"], item["title"]) for item in current_loras]
246
+ # Update selected_info and images
247
+ selected_info_1 = "Select a LoRA 1"
248
+ selected_info_2 = "Select a LoRA 2"
249
+ lora_scale_1 = 1.15
250
+ lora_scale_2 = 1.15
251
+ lora_image_1 = None
252
+ lora_image_2 = None
253
+ if len(selected_indices) >= 1:
254
+ lora1 = current_loras[selected_indices[0]]
255
+ selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}]({lora1['repo']}) ✨"
256
+ lora_image_1 = lora1['image']
257
+ if len(selected_indices) >= 2:
258
+ lora2 = current_loras[selected_indices[1]]
259
+ selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}]({lora2['repo']}) ✨"
260
+ lora_image_2 = lora2['image']
261
+ return (
262
+ current_loras,
263
+ gr.update(value=gallery_items),
264
+ selected_info_1,
265
+ selected_info_2,
266
+ selected_indices,
267
+ lora_scale_1,
268
+ lora_scale_2,
269
+ lora_image_1,
270
+ lora_image_2
271
+ )
272
+
273
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
274
+ print("Generating image...")
275
+ pipe.to("cuda")
276
+ generator = torch.Generator(device="cuda").manual_seed(seed)
277
+ with calculateDuration("Generating image"):
278
+ # Generate image
279
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
280
+ prompt=prompt_mash,
281
+ num_inference_steps=steps,
282
+ guidance_scale=cfg_scale,
283
+ width=width,
284
+ height=height,
285
+ generator=generator,
286
+ joint_attention_kwargs={"scale": 1.0},
287
+ output_type="pil",
288
+ good_vae=good_vae,
289
+ ):
290
+ yield img
291
+
292
+ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
293
+ pipe_i2i.to("cuda")
294
+ generator = torch.Generator(device="cuda").manual_seed(seed)
295
+ image_input = load_image(image_input_path)
296
+ final_image = pipe_i2i(
297
+ prompt=prompt_mash,
298
+ image=image_input,
299
+ strength=image_strength,
300
+ num_inference_steps=steps,
301
+ guidance_scale=cfg_scale,
302
+ width=width,
303
+ height=height,
304
+ generator=generator,
305
+ joint_attention_kwargs={"scale": 1.0},
306
+ output_type="pil",
307
+ ).images[0]
308
+ return final_image
309
+
310
+ @spaces.GPU(duration=75)
311
+ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
312
+ if not selected_indices:
313
+ raise gr.Error("You must select at least one LoRA before proceeding.")
314
+
315
+ selected_loras = [loras_state[idx] for idx in selected_indices]
316
+
317
+ # Build the prompt with trigger words
318
+ prepends = []
319
+ appends = []
320
+ for lora in selected_loras:
321
+ trigger_word = lora.get('trigger_word', '')
322
+ if trigger_word:
323
+ if lora.get("trigger_position") == "prepend":
324
+ prepends.append(trigger_word)
325
+ else:
326
+ appends.append(trigger_word)
327
+ prompt_mash = " ".join(prepends + [prompt] + appends)
328
+ print("Prompt Mash: ", prompt_mash)
329
+ # Unload previous LoRA weights
330
+ with calculateDuration("Unloading LoRA"):
331
+ pipe.unload_lora_weights()
332
+ pipe_i2i.unload_lora_weights()
333
+
334
+ print(pipe.get_active_adapters())
335
+ # Load LoRA weights with respective scales
336
+ lora_names = []
337
+ lora_weights = []
338
+ with calculateDuration("Loading LoRA weights"):
339
+ for idx, lora in enumerate(selected_loras):
340
+ lora_name = f"lora_{idx}"
341
+ lora_names.append(lora_name)
342
+ print(f"Lora Name: {lora_name}")
343
+ lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
344
+ lora_path = lora['repo']
345
+ weight_name = lora.get("weights")
346
+ print(f"Lora Path: {lora_path}")
347
+ pipe_to_use = pipe_i2i if image_input is not None else pipe
348
+ pipe_to_use.load_lora_weights(
349
+ lora_path,
350
+ weight_name=weight_name if weight_name else None,
351
+ low_cpu_mem_usage=True,
352
+ adapter_name=lora_name
353
+ )
354
+ # if image_input is not None: pipe_i2i = pipe_to_use
355
+ # else: pipe = pipe_to_use
356
+ print("Loaded LoRAs:", lora_names)
357
+ print("Adapter weights:", lora_weights)
358
+ if image_input is not None:
359
+ pipe_i2i.set_adapters(lora_names, adapter_weights=lora_weights)
360
+ else:
361
+ pipe.set_adapters(lora_names, adapter_weights=lora_weights)
362
+ print(pipe.get_active_adapters())
363
+ # Set random seed for reproducibility
364
+ with calculateDuration("Randomizing seed"):
365
+ if randomize_seed:
366
+ seed = random.randint(0, MAX_SEED)
367
+
368
+ # Generate image
369
+ if image_input is not None:
370
+ final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
371
+ yield final_image, seed, gr.update(visible=False)
372
+ else:
373
+ image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
374
+ # Consume the generator to get the final image
375
+ final_image = None
376
+ step_counter = 0
377
+ for image in image_generator:
378
+ step_counter += 1
379
+ final_image = image
380
+ progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
381
+ yield image, seed, gr.update(value=progress_bar, visible=True)
382
+ yield final_image, seed, gr.update(value=progress_bar, visible=False)
383
+
384
+ run_lora.zerogpu = True
385
+
386
+ def get_huggingface_safetensors(link):
387
+ split_link = link.split("/")
388
+ if len(split_link) == 2:
389
+ model_card = ModelCard.load(link)
390
+ base_model = model_card.data.get("base_model")
391
+ print(f"Base model: {base_model}")
392
+ if base_model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
393
+ raise Exception("Not a FLUX LoRA!")
394
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
395
+ trigger_word = model_card.data.get("instance_prompt", "")
396
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
397
+ fs = HfFileSystem()
398
+ safetensors_name = None
399
+ try:
400
+ list_of_files = fs.ls(link, detail=False)
401
+ for file in list_of_files:
402
+ if file.endswith(".safetensors"):
403
+ safetensors_name = file.split("/")[-1]
404
+ if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
405
+ image_elements = file.split("/")
406
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
407
+ except Exception as e:
408
+ print(e)
409
+ raise gr.Error("Invalid Hugging Face repository with a *.safetensors LoRA")
410
+ if not safetensors_name:
411
+ raise gr.Error("No *.safetensors file found in the repository")
412
+ return split_link[1], link, safetensors_name, trigger_word, image_url
413
+ else:
414
+ raise gr.Error("Invalid Hugging Face repository link")
415
+
416
+ def check_custom_model(link):
417
+ if link.endswith(".safetensors"):
418
+ # Treat as direct link to the LoRA weights
419
+ title = os.path.basename(link)
420
+ repo = link
421
+ path = None # No specific weight name
422
+ trigger_word = ""
423
+ image_url = None
424
+ return title, repo, path, trigger_word, image_url
425
+ elif link.startswith("https://"):
426
+ if "huggingface.co" in link:
427
+ link_split = link.split("huggingface.co/")
428
+ return get_huggingface_safetensors(link_split[1])
429
+ else:
430
+ raise Exception("Unsupported URL")
431
+ else:
432
+ # Assume it's a Hugging Face model path
433
+ return get_huggingface_safetensors(link)
434
+
435
+ def update_history(new_image, history):
436
+ """Updates the history gallery with the new image."""
437
+ if history is None:
438
+ history = []
439
+ history.insert(0, new_image)
440
+ return history
441
+
442
+ css = '''
443
+ #gen_btn{height: 100%}
444
+ #title{text-align: center}
445
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
446
+ #title img{width: 100px; margin-right: 0.25em}
447
+ #gallery .grid-wrap{height: 5vh}
448
+ #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
449
+ .custom_lora_card{margin-bottom: 1em}
450
+ .card_internal{display: flex;height: 100px;margin-top: .5em}
451
+ .card_internal img{margin-right: 1em}
452
+ .styler{--form-gap-width: 0px !important}
453
+ #progress{height:30px}
454
+ #progress .generating{display:none}
455
+ .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
456
+ .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
457
+ #component-8, .button_total{height: 100%; align-self: stretch;}
458
+ #loaded_loras [data-testid="block-info"]{font-size:80%}
459
+ #custom_lora_structure{background: var(--block-background-fill)}
460
+ #custom_lora_btn{margin-top: auto;margin-bottom: 11px}
461
+ #random_btn{font-size: 300%}
462
+ #component-11{align-self: stretch;}
463
+ '''
464
+
465
+ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
466
+ title = gr.HTML(
467
+ """<h1><img src="https://i.imgur.com/wMh2Oek.png" alt="LoRA"> LoRA Lab [beta]</h1><br><span style="
468
+ margin-top: -25px !important;
469
+ display: block;
470
+ margin-left: 37px;
471
+ ">Mix and match any FLUX[dev] LoRAs</span>""",
472
+ elem_id="title",
473
+ )
474
+ loras_state = gr.State(loras)
475
+ selected_indices = gr.State([])
476
+ with gr.Row():
477
+ with gr.Column(scale=3):
478
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
479
+ with gr.Column(scale=1):
480
+ generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
481
+ with gr.Row(elem_id="loaded_loras"):
482
+ with gr.Column(scale=1, min_width=25):
483
+ randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
484
+ with gr.Column(scale=8):
485
+ with gr.Row():
486
+ with gr.Column(scale=0, min_width=50):
487
+ lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
488
+ with gr.Column(scale=3, min_width=100):
489
+ selected_info_1 = gr.Markdown("Select a LoRA 1")
490
+ with gr.Column(scale=5, min_width=50):
491
+ lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
492
+ with gr.Row():
493
+ remove_button_1 = gr.Button("Remove", size="sm")
494
+ with gr.Column(scale=8):
495
+ with gr.Row():
496
+ with gr.Column(scale=0, min_width=50):
497
+ lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
498
+ with gr.Column(scale=3, min_width=100):
499
+ selected_info_2 = gr.Markdown("Select a LoRA 2")
500
+ with gr.Column(scale=5, min_width=50):
501
+ lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
502
+ with gr.Row():
503
+ remove_button_2 = gr.Button("Remove", size="sm")
504
+ with gr.Row():
505
+ with gr.Column():
506
+ with gr.Group():
507
+ with gr.Row(elem_id="custom_lora_structure"):
508
+ custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="multimodalart/vintage-ads-flux", scale=3, min_width=150)
509
+ add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
510
+ remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
511
+ gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
512
+ gallery = gr.Gallery(
513
+ [(item["image"], item["title"]) for item in loras],
514
+ label="Or pick from the LoRA Explorer gallery",
515
+ allow_preview=False,
516
+ columns=5,
517
+ elem_id="gallery",
518
+ show_share_button=False,
519
+ interactive=False
520
+ )
521
+ with gr.Column():
522
+ progress_bar = gr.Markdown(elem_id="progress", visible=False)
523
+ result = gr.Image(label="Generated Image", interactive=False, show_share_button=False)
524
+ with gr.Accordion("History", open=False):
525
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
526
+
527
+ with gr.Row():
528
+ with gr.Accordion("Advanced Settings", open=False):
529
+ with gr.Row():
530
+ input_image = gr.Image(label="Input image", type="filepath", show_share_button=False)
531
+ image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
532
+ with gr.Column():
533
+ with gr.Row():
534
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
535
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
536
+
537
+ with gr.Row():
538
+ width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
539
+ height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
540
+
541
+ with gr.Row():
542
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
543
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
544
+
545
+ gallery.select(
546
+ update_selection,
547
+ inputs=[selected_indices, loras_state, width, height],
548
+ outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2])
549
+ remove_button_1.click(
550
+ remove_lora_1,
551
+ inputs=[selected_indices, loras_state],
552
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
553
+ )
554
+ remove_button_2.click(
555
+ remove_lora_2,
556
+ inputs=[selected_indices, loras_state],
557
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
558
+ )
559
+ randomize_button.click(
560
+ randomize_loras,
561
+ inputs=[selected_indices, loras_state],
562
+ outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, prompt]
563
+ )
564
+ add_custom_lora_button.click(
565
+ add_custom_lora,
566
+ inputs=[custom_lora, selected_indices, loras_state, gallery],
567
+ outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
568
+ )
569
+ remove_custom_lora_button.click(
570
+ remove_custom_lora,
571
+ inputs=[selected_indices, loras_state, gallery],
572
+ outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
573
+ )
574
+ gr.on(
575
+ triggers=[generate_button.click, prompt.submit],
576
+ fn=run_lora,
577
+ inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
578
+ outputs=[result, seed, progress_bar]
579
+ ).then(
580
+ fn=lambda x, history: update_history(x, history),
581
+ inputs=[result, history_gallery],
582
+ outputs=history_gallery,
583
+ )
584
+
585
+ app.queue()
586
+ app.launch()