bryanbrunetti commited on
Commit
56a7978
·
1 Parent(s): b6b5406

lora fixes

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -3,10 +3,7 @@ import numpy as np
3
  import random
4
  import spaces
5
  import torch
6
- from diffusers import DiffusionPipeline#, FlowMatchEulerDiscreteScheduler
7
- # from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
8
- # from huggingface_hub import hf_hub_download
9
- # import os
10
 
11
  dtype = torch.bfloat16
12
 
@@ -27,6 +24,12 @@ print(f"Device is: {device}")
27
  # Initialize the pipeline globally
28
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
29
 
 
 
 
 
 
 
30
 
31
  @spaces.GPU(duration=120)
32
  def infer(prompt, lora_models, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0,
@@ -35,10 +38,10 @@ def infer(prompt, lora_models, seed=42, randomize_seed=False, width=1024, height
35
 
36
  # Load LoRAs if specified
37
  if lora_models:
38
- print(f"Loading Loras: {lora_models}")
39
  try:
40
  for lora_model in lora_models:
41
- pipe.load_lora_weights(lora_model)
 
42
  except Exception as e:
43
  return None, seed, f"Failed to load LoRA model: {str(e)}"
44
 
@@ -88,11 +91,8 @@ with gr.Blocks(css=css) as demo:
88
  # label="LoRA Model ID (optional)",
89
  # placeholder="Enter Hugging Face LoRA model ID",
90
  # )
91
- lora_models = gr.Dropdown([
92
- ("jerry", "bryanbrunetti/jerryfluxlora"),
93
- ("sebastian", "bryanbrunetti/sebastianfluxlora"),
94
- ("sadie", "bryanbrunetti/sadiefluxlora")
95
- ], multiselect=True, info="load lora (optional) use the name in the prompt", label="Choose People")
96
 
97
  result = gr.Image(label="Result", show_label=False)
98
 
@@ -139,7 +139,7 @@ with gr.Blocks(css=css) as demo:
139
  )
140
 
141
  output_message = gr.Textbox(label="Output Message")
142
-
143
  gr.on(
144
  triggers=[run_button.click, prompt.submit],
145
  fn=infer,
@@ -147,4 +147,4 @@ with gr.Blocks(css=css) as demo:
147
  outputs=[result, seed, output_message]
148
  )
149
 
150
- demo.launch()
 
3
  import random
4
  import spaces
5
  import torch
6
+ from diffusers import DiffusionPipeline
 
 
 
7
 
8
  dtype = torch.bfloat16
9
 
 
24
  # Initialize the pipeline globally
25
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
26
 
27
+ lora_weights = {
28
+ "jerry": {"path": "bryanbrunetti/jerryfluxlora", "weight_name": "jerry_000003000.safetensors"},
29
+ "sebastian": {"path": "bryanbrunetti/sebastianfluxlora", "weight_name": "sadie_flux.safetensors"},
30
+ "sadie": {"path": "bryanbrunetti/sadiefluxlora", "weight_name": "sebastian_flux_000001500.safetensors"},
31
+ }
32
+
33
 
34
  @spaces.GPU(duration=120)
35
  def infer(prompt, lora_models, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0,
 
38
 
39
  # Load LoRAs if specified
40
  if lora_models:
 
41
  try:
42
  for lora_model in lora_models:
43
+ print(f"loading LoRA: {lora_model}")
44
+ pipe.load_lora_weights(lora_weights[lora_model], weight_name=lora_weights[lora_model]["weight_name"])
45
  except Exception as e:
46
  return None, seed, f"Failed to load LoRA model: {str(e)}"
47
 
 
91
  # label="LoRA Model ID (optional)",
92
  # placeholder="Enter Hugging Face LoRA model ID",
93
  # )
94
+ lora_models = gr.Dropdown(list(lora_weights.keys()), multiselect=True,
95
+ info="Load LoRA (optional) use the name in the prompt", label="Choose LoRAs")
 
 
 
96
 
97
  result = gr.Image(label="Result", show_label=False)
98
 
 
139
  )
140
 
141
  output_message = gr.Textbox(label="Output Message")
142
+
143
  gr.on(
144
  triggers=[run_button.click, prompt.submit],
145
  fn=infer,
 
147
  outputs=[result, seed, output_message]
148
  )
149
 
150
+ demo.launch()