linoyts (HF Staff) committed
Commit 3823e02 · verified · 1 Parent(s): 7dbbdc6

add lora gallery

Files changed (1):
  app.py +344 -81
app.py CHANGED
@@ -1,116 +1,379 @@
import gradio as gr
import numpy as np
-
import spaces
import torch
import random
from PIL import Image
-
-#from kontext_pipeline import FluxKontextPipeline
-from pipeline_flux_kontext import FluxKontextPipeline
from diffusers import FluxTransformer2DModel
from diffusers.utils import load_image

-from huggingface_hub import hf_hub_download
-
-
kontext_path = hf_hub_download(repo_id="diffusers/kontext-v2", filename="dev-opt-2-a-3.safetensors")
-
MAX_SEED = np.iinfo(np.int32).max

transformer = FluxTransformer2DModel.from_single_file(kontext_path, torch_dtype=torch.bfloat16)
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")

@spaces.GPU
-def infer(input_image, prompt, seed=42, randomize_seed=False, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
-
    input_image = input_image.convert("RGB")
-    # original_width, original_height = input_image.size
-
-    # if original_width >= original_height:
-    #     new_width = 1024
-    #     new_height = int(original_height * (new_width / original_width))
-    #     new_height = round(new_height / 64) * 64
-    # else:
-    #     new_height = 1024
-    #     new_width = int(original_width * (new_height / original_height))
-    #     new_width = round(new_width / 64) * 64
-
-    #input_image_resized = input_image.resize((new_width, new_height), Image.LANCZOS)
-    image = pipe(
-        image=input_image,
-        prompt=prompt,
-        guidance_scale=guidance_scale,
-        # width=new_width,
-        # height=new_height,
-        generator=torch.Generator().manual_seed(seed),
-    ).images[0]
-    return image, seed, gr.update(visible=True)

-css="""
-#col-container {
-    margin: 0 auto;
-    max-width: 960px;
}
"""

with gr.Blocks(css=css) as demo:

-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""# FLUX.1 Kontext [dev]
-        """)
-
-        with gr.Row():
-            with gr.Column():
-                input_image = gr.Image(label="Upload the image for editing", type="pil")
-                with gr.Row():
-                    prompt = gr.Text(
-                        label="Prompt",
-                        show_label=False,
-                        max_lines=1,
-                        placeholder="Enter your prompt for editing (e.g., 'Remove glasses', 'Add a hat')",
-                        container=False,
-                    )
-                    run_button = gr.Button("Run", scale=0)
-                with gr.Accordion("Advanced Settings", open=False):
-
-                    seed = gr.Slider(
-                        label="Seed",
-                        minimum=0,
-                        maximum=MAX_SEED,
-                        step=1,
-                        value=0,
-                    )
-
-                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-                    guidance_scale = gr.Slider(
-                        label="Guidance Scale",
-                        minimum=1,
-                        maximum=10,
-                        step=0.1,
-                        value=2.5,
-                    )
-
-            with gr.Column():
-                result = gr.Image(label="Result", show_label=False, interactive=False)
-                reuse_button = gr.Button("Reuse this image", visible=False)
-

    gr.on(
        triggers=[run_button.click, prompt.submit],
-        fn = infer,
-        inputs = [input_image, prompt, seed, randomize_seed, guidance_scale],
-        outputs = [result, seed, reuse_button]
    )
    reuse_button.click(
-        fn = lambda image: image,
-        inputs = [result],
-        outputs = [input_image]
    )

demo.launch()
 
import gradio as gr
import numpy as np
import spaces
import torch
import random
+import json
+import os
from PIL import Image
+from kontext_pipeline import FluxKontextPipeline
from diffusers import FluxTransformer2DModel
from diffusers.utils import load_image
+from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
+from safetensors.torch import load_file
+import requests
+import re

+# Load Kontext model
kontext_path = hf_hub_download(repo_id="diffusers/kontext-v2", filename="dev-opt-2-a-3.safetensors")
MAX_SEED = np.iinfo(np.int32).max

transformer = FluxTransformer2DModel.from_single_file(kontext_path, torch_dtype=torch.bfloat16)
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")

+# Load LoRA data (you'll need to create this JSON file or modify to load your LoRAs)
+try:
+    with open("flux_loras.json", "r") as file:
+        data = json.load(file)
+        flux_loras_raw = [
+            {
+                "image": item["image"],
+                "title": item["title"],
+                "repo": item["repo"],
+                "trigger_word": item.get("trigger_word", ""),
+                "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
+                "likes": item.get("likes", 0),
+                "downloads": item.get("downloads", 0),
+            }
+            for item in data
+        ]
+except FileNotFoundError:
+    # Default LoRAs if JSON file doesn't exist
+    flux_loras_raw = [
+        {
+            "image": "https://via.placeholder.com/300x300?text=LoRA+1",
+            "title": "Example LoRA 1",
+            "repo": "example/lora1",
+            "trigger_word": "style1",
+            "weights": "pytorch_lora_weights.safetensors",
+            "likes": 100,
+            "downloads": 500,
+        },
+        {
+            "image": "https://via.placeholder.com/300x300?text=LoRA+2",
+            "title": "Example LoRA 2",
+            "repo": "example/lora2",
+            "trigger_word": "style2",
+            "weights": "pytorch_lora_weights.safetensors",
+            "likes": 80,
+            "downloads": 300,
+        }
+    ]
+
+# Global variables for LoRA management
+current_lora = None
+lora_cache = {}
+
+def load_lora_weights(repo_id, weights_filename):
+    """Load LoRA weights from HuggingFace"""
+    try:
+        if repo_id not in lora_cache:
+            lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
+            lora_cache[repo_id] = lora_path
+        return lora_cache[repo_id]
+    except Exception as e:
+        print(f"Error loading LoRA from {repo_id}: {e}")
+        return None
+
+def update_selection(selected_state: gr.SelectData, flux_loras):
+    """Update UI when a LoRA is selected"""
+    if selected_state.index >= len(flux_loras):
+        return "### No LoRA selected", gr.update(), selected_state
+
+    lora_repo = flux_loras[selected_state.index]["repo"]
+    trigger_word = flux_loras[selected_state.index]["trigger_word"]
+
+    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
+    new_placeholder = f"Enter your editing prompt{f' (use {trigger_word} for best results)' if trigger_word else ''}"
+
+    return updated_text, gr.update(placeholder=new_placeholder), selected_state
+
+def get_huggingface_lora(link):
+    """Download LoRA from HuggingFace link"""
+    split_link = link.split("/")
+    if len(split_link) == 2:
+        try:
+            model_card = ModelCard.load(link)
+            trigger_word = model_card.data.get("instance_prompt", "")
+
+            fs = HfFileSystem()
+            list_of_files = fs.ls(link, detail=False)
+            safetensors_file = None
+
+            for file in list_of_files:
+                if file.endswith(".safetensors") and "lora" in file.lower():
+                    safetensors_file = file.split("/")[-1]
+                    break
+
+            if not safetensors_file:
+                safetensors_file = "pytorch_lora_weights.safetensors"
+
+            return split_link[1], safetensors_file, trigger_word
+        except Exception as e:
+            raise Exception(f"Error loading LoRA: {e}")
+    else:
+        raise Exception("Invalid HuggingFace repository format")
+
+def load_custom_lora(link):
+    """Load custom LoRA from user input"""
+    if not link:
+        return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on a LoRA in the gallery to select it"
+
+    try:
+        repo_name, weights_file, trigger_word = get_huggingface_lora(link)
+
+        card = f'''
+        <div style="border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 10px 0;">
+            <span><strong>Loaded custom LoRA:</strong></span>
+            <div style="margin-top: 8px;">
+                <h4>{repo_name}</h4>
+                <small>{"Using: <code><b>"+trigger_word+"</b></code> as trigger word" if trigger_word else "No trigger word found"}</small>
+            </div>
+        </div>
+        '''
+
+        custom_lora_data = {
+            "repo": link,
+            "weights": weights_file,
+            "trigger_word": trigger_word
+        }
+
+        return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}"
+
+    except Exception as e:
+        return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### Click on a LoRA in the gallery to select it"
+
+def remove_custom_lora():
+    """Remove custom LoRA"""
+    return "", gr.update(visible=False), gr.update(visible=False), None
+
+def classify_gallery(flux_loras):
+    """Sort gallery by likes"""
+    sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
+    return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
+
@spaces.GPU
+def infer_with_lora(input_image, prompt, selected_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
+    """Generate image with selected LoRA"""
+    global current_lora, pipe

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
+
+    # Determine which LoRA to use
+    lora_to_use = None
+    if custom_lora:
+        lora_to_use = custom_lora
+    elif selected_state and flux_loras:
+        selected_index = selected_state.index if hasattr(selected_state, 'index') else None
+        if selected_index is not None and selected_index < len(flux_loras):
+            lora_to_use = flux_loras[selected_index]
+
+    # Load LoRA if needed
+    if lora_to_use and lora_to_use != current_lora:
+        try:
+            # Unload current LoRA
+            if current_lora:
+                pipe.unload_lora_weights()
+
+            # Load new LoRA
+            lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
+            if lora_path:
+                pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
+                current_lora = lora_to_use
+
+                # Add trigger word to prompt if available
+                trigger_word = lora_to_use.get("trigger_word", "")
+                if trigger_word and trigger_word not in prompt:
+                    prompt = f"{trigger_word} {prompt}"
+
+        except Exception as e:
+            print(f"Error loading LoRA: {e}")
+            # Continue without LoRA
+
+    # Set LoRA scale if LoRA is loaded
+    if current_lora and hasattr(pipe, 'set_adapters'):
+        try:
+            pipe.set_adapters("selected_lora", adapter_weights=[lora_scale])
+        except:
+            # Fallback for older diffusers versions
+            pass
+
    input_image = input_image.convert("RGB")
+
+    try:
+        image = pipe(
+            image=input_image,
+            prompt=prompt,
+            guidance_scale=guidance_scale,
+            generator=torch.Generator().manual_seed(seed),
+        ).images[0]
+
+        return image, seed, gr.update(visible=True)
+
+    except Exception as e:
+        print(f"Error during inference: {e}")
+        return None, seed, gr.update(visible=False)

+# CSS styling
+css = """
+#main_app {
+    display: flex;
+    gap: 20px;
+}
+#box_column {
+    min-width: 400px;
+}
+#gallery_box {
+    border: 1px solid #ddd;
+    border-radius: 8px;
+    padding: 15px;
+}
+#gallery {
+    height: 400px;
+}
+#selected_lora {
+    color: #2563eb;
+    font-weight: bold;
+}
+#prompt {
+    flex-grow: 1;
+}
+#run_button {
+    background: linear-gradient(45deg, #2563eb, #3b82f6);
+    color: white;
+    border: none;
+    padding: 8px 16px;
+    border-radius: 6px;
+    font-weight: bold;
+}
+.custom_lora_card {
+    background: #f8fafc;
+    border: 1px solid #e2e8f0;
+    border-radius: 8px;
+    padding: 12px;
+    margin: 8px 0;
}
"""

+# Create Gradio interface
with gr.Blocks(css=css) as demo:
+    gr_flux_loras = gr.State(value=flux_loras_raw)

+    title = gr.HTML(
+        """<h1> FLUX.1 Kontext Portrait 👩🏻‍🎤
+        <br><small style="font-size: 13px; opacity: 0.75;"></small></h1>""",
+    )
+
+    selected_state = gr.State()
+    custom_loaded_lora = gr.State()
+
+    with gr.Row(elem_id="main_app"):
+        with gr.Column(scale=4, elem_id="box_column"):
+            with gr.Group(elem_id="gallery_box"):
+                input_image = gr.Image(label="Upload image for editing", type="pil", height=250)
+
+                gallery = gr.Gallery(
+                    label="Pick a LoRA style from the gallery",
+                    allow_preview=False,
+                    columns=3,
+                    elem_id="gallery",
+                    show_share_button=False,
+                    height=400
+                )
+
+                custom_model = gr.Textbox(
+                    label="Or enter a custom HuggingFace FLUX LoRA",
+                    placeholder="e.g., username/lora-name",
+                    visible=False
+                )
+                custom_model_card = gr.HTML(visible=False)
+                custom_model_button = gr.Button("Remove custom LoRA", visible=False)

+        with gr.Column(scale=5):
+            with gr.Row():
+                prompt = gr.Textbox(
+                    label="Editing Prompt",
+                    show_label=False,
+                    lines=1,
+                    max_lines=1,
+                    placeholder="Enter your editing prompt (e.g., 'Remove glasses', 'Add a hat')",
+                    elem_id="prompt"
+                )
+                run_button = gr.Button("Generate", elem_id="run_button")
+
+            result = gr.Image(label="Generated Image", interactive=False)
+            reuse_button = gr.Button("Reuse this image", visible=False)
+
+            with gr.Accordion("Advanced Settings", open=False):
+                lora_scale = gr.Slider(
+                    label="LoRA Scale",
+                    minimum=0,
+                    maximum=2,
+                    step=0.1,
+                    value=1.0,
+                    info="Controls the strength of the LoRA effect"
+                )
+                seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=MAX_SEED,
+                    step=1,
+                    value=0,
+                )
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                guidance_scale = gr.Slider(
+                    label="Guidance Scale",
+                    minimum=1,
+                    maximum=10,
+                    step=0.1,
+                    value=2.5,
+                )
+
+    prompt_title = gr.Markdown(
+        value="### Click on a LoRA in the gallery to select it",
+        visible=True,
+        elem_id="selected_lora",
+    )

+    # Event handlers
+    custom_model.input(
+        fn=load_custom_lora,
+        inputs=[custom_model],
+        outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title],
+    )
+
+    custom_model_button.click(
+        fn=remove_custom_lora,
+        outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora]
+    )
+
+    gallery.select(
+        fn=update_selection,
+        inputs=[gr_flux_loras],
+        outputs=[prompt_title, prompt, selected_state],
+        show_progress=False
+    )
+
    gr.on(
        triggers=[run_button.click, prompt.submit],
+        fn=infer_with_lora,
+        inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
+        outputs=[result, seed, reuse_button]
    )
+
    reuse_button.click(
+        fn=lambda image: image,
+        inputs=[result],
+        outputs=[input_image]
+    )
+
+    # Initialize gallery
+    demo.load(
+        fn=classify_gallery,
+        inputs=[gr_flux_loras],
+        outputs=[gallery, gr_flux_loras]
    )

+demo.queue(default_concurrency_limit=None)
demo.launch()
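
Note: the gallery is populated from a flux_loras.json file next to app.py. Per entry, only image, title, and repo are required; trigger_word, weights, likes, and downloads fall back to the defaults visible in the loader above. A minimal sketch of a matching file, written via Python for concreteness (the repo id and image URL below are hypothetical placeholders, not real models):

    # Minimal sketch: writes an example flux_loras.json matching the schema
    # read at startup. "your-username/your-flux-lora" and the preview URL
    # are hypothetical placeholders.
    import json

    example_loras = [
        {
            "image": "https://example.com/preview.png",     # gallery thumbnail
            "title": "Example Portrait LoRA",               # gallery caption
            "repo": "your-username/your-flux-lora",         # HF repo id (hypothetical)
            "trigger_word": "xyz style",                    # optional; prepended to prompts
            "weights": "pytorch_lora_weights.safetensors",  # optional; default shown
            "likes": 0,                                     # optional; used to sort the gallery
            "downloads": 0,                                 # optional
        }
    ]

    with open("flux_loras.json", "w") as f:
        json.dump(example_loras, f, indent=2)

Entries are displayed sorted by likes (see classify_gallery), so the counts only affect ordering.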
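
Note: the hot-swap inside infer_with_lora reduces to the standard diffusers adapter calls already used in the diff (unload_lora_weights, load_lora_weights, set_adapters). A minimal sketch of the same sequence outside Gradio, assuming pipe is the FluxKontextPipeline constructed above and that it exposes these diffusers LoRA-mixin methods; the repo id is again a hypothetical placeholder:

    from huggingface_hub import hf_hub_download

    # Fetch the LoRA weights file; hf_hub_download caches repeat downloads,
    # and the app adds its own lora_cache dict on top of that.
    lora_path = hf_hub_download(
        repo_id="your-username/your-flux-lora",      # hypothetical repo id
        filename="pytorch_lora_weights.safetensors",
    )

    pipe.unload_lora_weights()                       # drop any previously loaded adapter
    pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
    pipe.set_adapters("selected_lora", adapter_weights=[1.0])  # 1.0 matches the UI's default LoRA Scale

set_adapters requires a PEFT-enabled diffusers install, which is why the app guards the call with hasattr plus a try/except fallback.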