bryanbrunetti committed on
Commit
bf4853f
·
1 Parent(s): 6d1ea5f

lora dropdown

Browse files
Files changed (2) hide show
  1. app.py +34 -33
  2. requirements.txt +6 -2
app.py CHANGED
@@ -9,24 +9,36 @@ from huggingface_hub import hf_hub_download
9
  import os
10
 
11
  dtype = torch.bfloat16
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
13
 
14
  MAX_SEED = np.iinfo(np.int32).max
15
  MAX_IMAGE_SIZE = 2048
16
 
 
17
  # Initialize the pipeline globally
18
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
19
 
20
 
21
  @spaces.GPU(duration=300)
22
- def infer(prompt, lora_model, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0,
23
  num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
24
  global pipe
25
 
26
- # Load LoRA if specified
27
- if lora_model:
 
28
  try:
29
- pipe.load_lora_weights(lora_model)
 
30
  except Exception as e:
31
  return None, seed, f"Failed to load LoRA model: {str(e)}"
32
 
@@ -45,7 +57,7 @@ def infer(prompt, lora_model, seed=42, randomize_seed=False, width=1024, height=
45
  ).images[0]
46
 
47
  # Unload LoRA weights after generation
48
- if lora_model:
49
  pipe.unload_lora_weights()
50
 
51
  return image, seed, "Image generated successfully."
@@ -53,12 +65,6 @@ def infer(prompt, lora_model, seed=42, randomize_seed=False, width=1024, height=
53
  return None, seed, f"Error during image generation: {str(e)}"
54
 
55
 
56
- examples = [
57
- ["a tiny astronaut hatching from an egg on the moon", ""],
58
- ["a cat holding a sign that says hello world", ""],
59
- ["an anime illustration of a wiener schnitzel", ""],
60
- ]
61
-
62
  css = """
63
  #col-container {
64
  margin: 0 auto;
@@ -68,11 +74,6 @@ css = """
68
 
69
  with gr.Blocks(css=css) as demo:
70
  with gr.Column(elem_id="col-container"):
71
- gr.Markdown(f"""# FLUX.1 [dev] with LoRA Support
72
- 12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
73
- [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
74
- """)
75
-
76
  with gr.Row():
77
  prompt = gr.Text(
78
  label="Prompt",
@@ -83,13 +84,17 @@ with gr.Blocks(css=css) as demo:
83
  )
84
  run_button = gr.Button("Run", scale=0)
85
 
86
- lora_model = gr.Text(
87
- label="LoRA Model ID (optional)",
88
- placeholder="Enter Hugging Face LoRA model ID",
89
- )
 
 
 
 
 
90
 
91
  result = gr.Image(label="Result", show_label=False)
92
- output_message = gr.Textbox(label="Output Message")
93
 
94
  with gr.Accordion("Advanced Settings", open=False):
95
  seed = gr.Slider(
@@ -106,18 +111,19 @@ with gr.Blocks(css=css) as demo:
106
  minimum=256,
107
  maximum=MAX_IMAGE_SIZE,
108
  step=32,
109
- value=1024,
110
  )
111
  height = gr.Slider(
112
  label="Height",
113
  minimum=256,
114
  maximum=MAX_IMAGE_SIZE,
115
  step=32,
116
- value=1024,
117
  )
118
  with gr.Row():
119
  guidance_scale = gr.Slider(
120
  label="Guidance Scale",
 
121
  minimum=1,
122
  maximum=15,
123
  step=0.1,
@@ -125,24 +131,19 @@ with gr.Blocks(css=css) as demo:
125
  )
126
  num_inference_steps = gr.Slider(
127
  label="Number of inference steps",
 
128
  minimum=1,
129
  maximum=50,
130
  step=1,
131
  value=28,
132
  )
133
 
134
- gr.Examples(
135
- examples=examples,
136
- fn=infer,
137
- inputs=[prompt, lora_model],
138
- outputs=[result, seed, output_message],
139
- cache_examples="lazy"
140
- )
141
-
142
  gr.on(
143
  triggers=[run_button.click, prompt.submit],
144
  fn=infer,
145
- inputs=[prompt, lora_model, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
146
  outputs=[result, seed, output_message]
147
  )
148
 
 
9
  import os
10
 
11
  dtype = torch.bfloat16
12
+
13
+ if torch.cuda.is_available():
14
+ device = "cuda"
15
+ elif torch.backends.mps.is_available():
16
+ if not torch.backends.mps.is_built():
17
+ print("MPS not available because the current PyTorch install was not "
18
+ "built with MPS enabled.")
19
+ device = "mps"
20
+ else:
21
+ device = "cpu"
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  MAX_IMAGE_SIZE = 2048
25
 
26
+ print(f"Device is: {device}")
27
  # Initialize the pipeline globally
28
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
29
 
30
 
31
  @spaces.GPU(duration=300)
32
+ def infer(prompt, lora_models, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0,
33
  num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
34
  global pipe
35
 
36
+ # Load LoRAs if specified
37
+ if lora_models:
38
+ print(f"Loading Loras: {lora_models}")
39
  try:
40
+ for lora_model in lora_models:
41
+ pipe.load_lora_weights(lora_model)
42
  except Exception as e:
43
  return None, seed, f"Failed to load LoRA model: {str(e)}"
44
 
 
57
  ).images[0]
58
 
59
  # Unload LoRA weights after generation
60
+ if lora_models:
61
  pipe.unload_lora_weights()
62
 
63
  return image, seed, "Image generated successfully."
 
65
  return None, seed, f"Error during image generation: {str(e)}"
66
 
67
 
 
 
 
 
 
 
68
  css = """
69
  #col-container {
70
  margin: 0 auto;
 
74
 
75
  with gr.Blocks(css=css) as demo:
76
  with gr.Column(elem_id="col-container"):
 
 
 
 
 
77
  with gr.Row():
78
  prompt = gr.Text(
79
  label="Prompt",
 
84
  )
85
  run_button = gr.Button("Run", scale=0)
86
 
87
+ # lora_model = gr.Text(
88
+ # label="LoRA Model ID (optional)",
89
+ # placeholder="Enter Hugging Face LoRA model ID",
90
+ # )
91
+ lora_models = gr.Dropdown([
92
+ ("jerry", "bryanbrunetti/jerryfluxlora"),
93
+ ("sebastian", "bryanbrunetti/sebastianfluxlora"),
94
+ ("sadie", "bryanbrunetti/sadiefluxlora")
95
+ ], multiselect=True, info="load lora (optional) use the name in the prompt", label="Choose People")
96
 
97
  result = gr.Image(label="Result", show_label=False)
 
98
 
99
  with gr.Accordion("Advanced Settings", open=False):
100
  seed = gr.Slider(
 
111
  minimum=256,
112
  maximum=MAX_IMAGE_SIZE,
113
  step=32,
114
+ value=512,
115
  )
116
  height = gr.Slider(
117
  label="Height",
118
  minimum=256,
119
  maximum=MAX_IMAGE_SIZE,
120
  step=32,
121
+ value=512,
122
  )
123
  with gr.Row():
124
  guidance_scale = gr.Slider(
125
  label="Guidance Scale",
126
+ info="How close to follow prompt",
127
  minimum=1,
128
  maximum=15,
129
  step=0.1,
 
131
  )
132
  num_inference_steps = gr.Slider(
133
  label="Number of inference steps",
134
+ info="higher = more details",
135
  minimum=1,
136
  maximum=50,
137
  step=1,
138
  value=28,
139
  )
140
 
141
+ output_message = gr.Textbox(label="Output Message")
142
+
 
 
 
 
 
 
143
  gr.on(
144
  triggers=[run_button.click, prompt.submit],
145
  fn=infer,
146
+ inputs=[prompt, lora_models, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
147
  outputs=[result, seed, output_message]
148
  )
149
 
requirements.txt CHANGED
@@ -1,7 +1,11 @@
1
  accelerate
2
  torch
3
  transformers==4.42.4
4
- xformers
5
  sentencepiece
6
  peft
7
- diffusers
 
 
 
 
 
 
1
  accelerate
2
  torch
3
  transformers==4.42.4
 
4
  sentencepiece
5
  peft
6
+ diffusers~=0.30.0
7
+ gradio~=4.42.0
8
+ spaces~=0.29.3
9
+ protobuf
10
+ numpy~=1.26.4
11
+ huggingface_hub~=0.24.6