KingNish committed on
Commit 6934cc4 · verified · 1 Parent(s): 0a437ba

Update app.py

Files changed (1)
  1. app.py +22 -57
app.py CHANGED
@@ -11,8 +11,8 @@ from diffusers.models.attention_processor import AttnProcessor2_0
 from custom_pipeline import FluxWithCFGPipeline

 # --- Torch Optimizations ---
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
+# torch.backends.cuda.matmul.allow_tf32 = True
+# torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions

 # --- Constants ---
 MAX_SEED = np.iinfo(np.int32).max
@@ -27,39 +27,30 @@ ENHANCE_STEPS = 2 # Fixed steps for the enhance button
 # --- Device and Model Setup ---
 dtype = torch.float16
 device = "cuda" if torch.cuda.is_available() else "cpu"
-pipe = None # Initialize pipe to None

-try:
-    pipe = FluxWithCFGPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
-    )
-    pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
-
-    pipe.to(device)
+pipe = FluxWithCFGPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
+)
+pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)

-    # Apply optimizations
-    pipe.unet.set_attn_processor(AttnProcessor2_0())
-    pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too
+pipe.to(device)

-    pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
-    pipe.set_adapters(["better"], adapter_weights=[1.0])
-    pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0) # Fuse for potential speedup
-    pipe.unload_lora_weights() # Unload after fusing
+# Apply optimizations
+pipe.unet.set_attn_processor(AttnProcessor2_0())
+pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too

-    # --- Compilation (Major Speed Optimization) ---
-    pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
-    pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)
+pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
+pipe.set_adapters(["better"], adapter_weights=[1.0])
+pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0) # Fuse for potential speedup
+pipe.unload_lora_weights() # Unload after fusing

-    # Clear cache after setup
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
-except Exception as e:
-    print(e)
+# --- Compilation (Major Speed Optimization) ---
+pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
+pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)


 # --- Inference Function ---
-@spaces.GPU() # Slightly increased duration buffer
+@spaces.GPU
 def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, randomize_seed: bool = False, num_inference_steps: int = DEFAULT_INFERENCE_STEPS, is_enhance: bool = False):
     """Generates an image using the FLUX pipeline with error handling."""

@@ -119,20 +110,6 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
         raise gr.Error(f"An error occurred during generation: {e}")


-# --- Real-time Generation Wrapper ---
-# This function checks the realtime toggle before calling the main generation function.
-# It's triggered by changes in prompt or sliders when realtime is enabled.
-def handle_realtime_update(realtime_enabled: bool, prompt: str, seed: int, width: int, height: int, randomize_seed: bool, num_inference_steps: int):
-    if realtime_enabled and pipe is not None:
-        # Call generate_image directly. Errors within generate_image will be caught and raised as gr.Error.
-        # We don't set is_enhance=True for realtime updates.
-        return generate_image(prompt, seed, width, height, randomize_seed, num_inference_steps, is_enhance=False)
-    else:
-        # If realtime is disabled or pipe failed, don't update the image, seed, or latency.
-        # Return gr.update() for each output component to indicate no change.
-        return gr.update(), gr.update(), gr.update()
-
-
 # --- Example Prompts ---
 examples = [
     "a tiny astronaut hatching from an egg on the moon",
@@ -195,9 +172,7 @@ with gr.Blocks() as demo:
         fn=generate_image,
         inputs=[prompt, seed, width, height],
         outputs=[result, seed, latency],
-        show_progress="full",
-        queue=False,
-        concurrency_limit=None,
+        show_progress="full"
     )

     generateBtn.click(
@@ -206,7 +181,6 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         api_name="RealtimeFlux",
-        queue=False
     )

     def update_ui(realtime_enabled):
@@ -222,21 +196,14 @@ with gr.Blocks() as demo:
     realtime.change(
         fn=update_ui,
         inputs=[realtime],
-        outputs=[prompt, generateBtn],
-        queue=False,
-        concurrency_limit=None
+        outputs=[prompt, generateBtn]
     )

-    # Removed the intermediate realtime_generation function.
-    # handle_realtime_update checks the realtime toggle internally.
-
     prompt.submit(
         fn=generate_image,
         inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
         outputs=[result, seed, latency],
-        show_progress="full",
-        queue=False,
-        concurrency_limit=None
+        show_progress="full"
     )

     for component in [prompt, width, height, num_inference_steps]:
@@ -245,9 +212,7 @@ with gr.Blocks() as demo:
             inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
             outputs=[result, seed, latency],
             show_progress="hidden",
-            trigger_mode="always_last",
-            queue=False,
-            concurrency_limit=None
+            trigger_mode="always_last"
         )

 # Launch the app
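
For reference, a minimal self-contained sketch of the model setup this revision lands on. It assumes the stock diffusers FluxPipeline in place of this Space's custom FluxWithCFGPipeline; the checkpoint, TAESD VAE, LoRA repo, and torch.compile settings are taken from the file above, while the example prompt and call arguments (2 steps, guidance_scale=0.0) are illustrative only.

# Sketch only: stock FluxPipeline stands in for the repo's FluxWithCFGPipeline.
import torch
from diffusers import FluxPipeline, AutoencoderTiny

dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load FLUX.1-schnell and swap in the tiny TAESD VAE used by the Space for faster decoding.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
pipe.to(device)

# Load the realism LoRA, fuse it into the base weights, then drop the adapter bookkeeping.
pipe.load_lora_weights(
    "hugovntr/flux-schnell-realism",
    weight_name="schnell-realism_v2.3.safetensors",
    adapter_name="better",
)
pipe.set_adapters(["better"], adapter_weights=[1.0])
pipe.fuse_lora(lora_scale=1.0)
pipe.unload_lora_weights()

# Compile the tiny VAE's encoder and decoder, mirroring the setup in app.py.
pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)

# schnell is distilled for few-step, guidance-free sampling.
image = pipe(
    "a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=2,
    guidance_scale=0.0,
    width=1024,
    height=1024,
).images[0]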