Uhhy commited on
Commit
e87b198
·
verified ·
1 Parent(s): 9116275

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -9
app.py CHANGED
@@ -8,6 +8,8 @@ import gradio as gr
8
  import os
9
  from spaces import GPU
10
  from dotenv import load_dotenv
 
 
11
 
12
  load_dotenv()
13
 
@@ -88,7 +90,14 @@ def remove_duplicates(text):
88
  seen_lines.add(line)
89
  return '\n'.join(unique_lines)
90
 
91
- @GPU(duration=1)
 
 
 
 
 
 
 
92
  def generate_model_response(model, inputs):
93
  try:
94
  response = model(inputs)
@@ -97,6 +106,22 @@ def generate_model_response(model, inputs):
97
  print(f"Error generating model response: {e}")
98
  return ""
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def remove_repetitive_responses(responses):
101
  unique_responses = {}
102
  for response in responses:
@@ -124,15 +149,88 @@ async def process_message(message):
124
  """
125
  return formatted_response, curl_command
126
 
 
 
 
 
 
127
 
128
- iface = gr.Interface(
129
- fn=process_message,
130
- inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."),
131
- outputs=[gr.Markdown(), gr.Textbox(label="cURL command")],
132
- title="Multi-Model LLM API",
133
- description="Enter a message and get responses from multiple LLMs.",
134
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  if __name__ == "__main__":
137
  port = int(os.environ.get("PORT", 7860))
138
- iface.launch(server_port=port)
 
8
  import os
9
  from spaces import GPU
10
  from dotenv import load_dotenv
11
+ import torch
12
+ from diffusers import DiffusionPipeline
13
 
14
  load_dotenv()
15
 
 
90
  seen_lines.add(line)
91
  return '\n'.join(unique_lines)
92
 
93
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
94
+ device = "cuda" if torch.cuda.is_available() else "cpu"
95
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
96
+ MAX_SEED = np.iinfo(np.int32).max
97
+ MAX_IMAGE_SIZE = 2048
98
+
99
+ @spaces.GPU()
100
+ @GPU(duration=1)
101
  def generate_model_response(model, inputs):
102
  try:
103
  response = model(inputs)
 
106
  print(f"Error generating model response: {e}")
107
  return ""
108
 
109
+ @spaces.GPU()
110
+ @GPU(duration=1)
111
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
112
+ if randomize_seed:
113
+ seed = random.randint(0, MAX_SEED)
114
+ generator = torch.Generator(device=device).manual_seed(seed)
115
+ image = pipe(
116
+ prompt=prompt,
117
+ width=width,
118
+ height=height,
119
+ num_inference_steps=num_inference_steps,
120
+ generator=generator,
121
+ guidance_scale=0.0
122
+ ).images[0]
123
+ return image, seed
124
+
125
  def remove_repetitive_responses(responses):
126
  unique_responses = {}
127
  for response in responses:
 
149
  """
150
  return formatted_response, curl_command
151
 
152
+ examples = [
153
+ "a tiny astronaut hatching from an egg on the moon",
154
+ "a cat holding a sign that says hello world",
155
+ "an anime illustration of a wiener schnitzel",
156
+ ]
157
 
158
+ css="""
159
+ #col-container {
160
+ margin: 0 auto;
161
+ max-width: 520px;
162
+ }
163
+ """
164
+
165
+ with gr.Blocks(css=css) as demo:
166
+ with gr.Column(elem_id="col-container"):
167
+ gr.Markdown(f"""# FLUX.1 [schnell]
168
+ 12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
169
+ [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
170
+ """)
171
+
172
+ with gr.Row():
173
+ prompt = gr.Text(
174
+ label="Prompt",
175
+ show_label=False,
176
+ max_lines=1,
177
+ placeholder="Enter your prompt",
178
+ container=False,
179
+ )
180
+ run_button = gr.Button("Run", scale=0)
181
+
182
+ result = gr.Image(label="Result", show_label=False)
183
+
184
+ with gr.Accordion("Advanced Settings", open=False):
185
+ seed = gr.Slider(
186
+ label="Seed",
187
+ minimum=0,
188
+ maximum=MAX_SEED,
189
+ step=1,
190
+ value=0,
191
+ )
192
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
193
+
194
+ with gr.Row():
195
+ width = gr.Slider(
196
+ label="Width",
197
+ minimum=256,
198
+ maximum=MAX_IMAGE_SIZE,
199
+ step=32,
200
+ value=1024,
201
+ )
202
+ height = gr.Slider(
203
+ label="Height",
204
+ minimum=256,
205
+ maximum=MAX_IMAGE_SIZE,
206
+ step=32,
207
+ value=1024,
208
+ )
209
+
210
+ with gr.Row():
211
+ num_inference_steps = gr.Slider(
212
+ label="Number of inference steps",
213
+ minimum=1,
214
+ maximum=50,
215
+ step=1,
216
+ value=4,
217
+ )
218
+
219
+ gr.Examples(
220
+ examples=examples,
221
+ fn=infer,
222
+ inputs=[prompt],
223
+ outputs=[result, seed],
224
+ cache_examples="lazy"
225
+ )
226
+
227
+ gr.on(
228
+ triggers=[run_button.click, prompt.submit],
229
+ fn=infer,
230
+ inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
231
+ outputs=[result, seed]
232
+ )
233
 
234
  if __name__ == "__main__":
235
  port = int(os.environ.get("PORT", 7860))
236
+ demo.launch(server_port=port)