KingNish committed · verified
Commit 8effdde · Parent(s): 627bc27

Update app.py
Files changed (1): app.py (+21, -38)
app.py CHANGED
@@ -4,16 +4,12 @@ import random
 import spaces
 import torch
 import time
-import logging
 from diffusers import DiffusionPipeline, AutoencoderTiny
 # Using AttnProcessor2_0 for potential speedup with PyTorch 2.x
 from diffusers.models.attention_processor import AttnProcessor2_0
 # Assuming custom_pipeline defines FluxWithCFGPipeline correctly
 from custom_pipeline import FluxWithCFGPipeline
 
-# --- Setup Logging ---
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
 # --- Torch Optimizations ---
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
@@ -34,50 +30,36 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 pipe = None # Initialize pipe to None
 
 try:
-    logging.info("Loading diffusion pipeline...")
     pipe = FluxWithCFGPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
     )
-    logging.info("Loading VAE...")
     pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
 
-    logging.info(f"Moving pipeline to {device}...")
     pipe.to(device)
 
     # Apply optimizations
-    logging.info("Setting attention processor...")
     pipe.unet.set_attn_processor(AttnProcessor2_0())
     pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too
 
-    logging.info("Loading and fusing LoRA...")
     pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
     pipe.set_adapters(["better"], adapter_weights=[1.0])
    pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0) # Fuse for potential speedup
     pipe.unload_lora_weights() # Unload after fusing
-    logging.info("LoRA fused and unloaded.")
 
     # --- Compilation (Major Speed Optimization) ---
-    # logging.info("Compiling VAE Decoder...")
-    # pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
-    # logging.info("Compiling VAE Encoder...")
-    # pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)
-    # logging.info("Model compilation finished.")
+    pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
+    pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)
 
     # Clear cache after setup
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        logging.info("CUDA cache cleared after setup.")
 
 except Exception as e:
-    logging.error(f"Error during model loading or setup: {e}", exc_info=True)
-    # Display error in Gradio if UI is already built, otherwise just log and exit.
-    # For simplicity here, we'll rely on the Gradio UI showing an error if `pipe` is None later.
-    # If running script directly, consider `sys.exit()`
-    # raise gr.Error(f"Failed to load models. Check logs for details. Error: {e}")
+    print(e)
 
 
 # --- Inference Function ---
-@spaces.GPU(duration=30) # Slightly increased duration buffer
+@spaces.GPU()
 def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, randomize_seed: bool = False, num_inference_steps: int = DEFAULT_INFERENCE_STEPS, is_enhance: bool = False):
     """Generates an image using the FLUX pipeline with error handling."""
 
@@ -85,10 +67,7 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
         raise gr.Error("Diffusion pipeline failed to load. Cannot generate images.")
 
     if not prompt or prompt.strip() == "":
-        # Return a blank image or previous result if prompt is empty?
-        # For now, raise warning and return None.
         gr.Warning("Prompt is empty. Please enter a description.")
-        # Returning None for image, original seed, and error message
         return None, seed, "Error: Empty prompt"
 
     start_time = time.time()
@@ -105,8 +84,6 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
     # Clamp steps
     steps_to_use = max(MIN_INFERENCE_STEPS, min(steps_to_use, MAX_INFERENCE_STEPS))
 
-    logging.info(f"Generating image with prompt: '{prompt}', seed: {seed}, size: {width}x{height}, steps: {steps_to_use}")
-
     try:
         # Ensure generator is on the correct device
         generator = torch.Generator(device=device).manual_seed(int(float(seed)))
@@ -127,18 +104,15 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
 
         latency = time.time() - start_time
         latency_str = f"Latency: {latency:.2f} seconds (Steps: {steps_to_use})"
-        logging.info(f"Image generated successfully. {latency_str}")
         return result_img, seed, latency_str
 
     except torch.cuda.OutOfMemoryError as e:
-        logging.error(f"CUDA OutOfMemoryError: {e}", exc_info=True)
         # Clear cache and suggest reducing size/steps
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         raise gr.Error("GPU ran out of memory. Try reducing the image width/height or the number of inference steps.")
 
     except Exception as e:
-        logging.error(f"Error during image generation: {e}", exc_info=True)
         # Clear cache just in case
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -150,14 +124,12 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
 # It's triggered by changes in prompt or sliders when realtime is enabled.
 def handle_realtime_update(realtime_enabled: bool, prompt: str, seed: int, width: int, height: int, randomize_seed: bool, num_inference_steps: int):
     if realtime_enabled and pipe is not None:
-        logging.debug("Realtime update triggered.")
         # Call generate_image directly. Errors within generate_image will be caught and raised as gr.Error.
         # We don't set is_enhance=True for realtime updates.
         return generate_image(prompt, seed, width, height, randomize_seed, num_inference_steps, is_enhance=False)
     else:
         # If realtime is disabled or pipe failed, don't update the image, seed, or latency.
         # Return gr.update() for each output component to indicate no change.
-        logging.debug("Realtime update skipped (disabled or pipe error).")
         return gr.update(), gr.update(), gr.update()
 
 
@@ -225,7 +197,8 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         queue=False,
-        concurrency_limit=None
+        concurrency_limit=None,
+        fn_kwargs={"is_enhance": True} # Pass the flag to indicate enhance
     )
 
     generateBtn.click(
@@ -251,9 +224,8 @@ with gr.Blocks() as demo:
         concurrency_limit=None
     )
 
-    def realtime_generation(*args):
-        if args[0]: # If realtime is enabled
-            return next(generate_image(*args[1:]))
+    # Removed the intermediate realtime_generation function.
+    # handle_realtime_update checks the realtime toggle internally.
 
     prompt.submit(
         fn=generate_image,
@@ -266,7 +238,7 @@ with gr.Blocks() as demo:
 
     for component in [prompt, width, height, num_inference_steps]:
         component.input(
-            fn=realtime_generation,
+            fn=handle_realtime_update, # Call the wrapper that checks the toggle
             inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
             outputs=[result, seed, latency],
             show_progress="hidden",
@@ -274,6 +246,17 @@ with gr.Blocks() as demo:
             queue=False,
             concurrency_limit=None
         )
+
+    # Also trigger realtime on seed change if randomize is off
+    seed.input(
+        fn=handle_realtime_update,
+        inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
+        outputs=[result, seed, latency],
+        show_progress="hidden",
+        trigger_mode="always_last",
+        queue=False,
+        concurrency_limit=None
+    )
 
 # Launch the app
-demo.launch()
+demo.launch()
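
The main functional change in this commit is enabling torch.compile on the TAEF1 tiny VAE's encoder and decoder. Below is a minimal standalone sketch of that idea, not taken from this repository: it assumes diffusers' AutoencoderTiny with the madebyollin/taef1 checkpoint, and the 16-channel, 1/8-resolution latent shape used by FLUX.1 is an illustrative assumption; adjust for your model.

# Standalone sketch: compile the tiny FLUX VAE (TAEF1) encoder/decoder with torch.compile,
# mirroring the two lines this commit un-comments. Latent shape below is an assumption
# (FLUX.1-style 16 latent channels, 8x spatial downscale), not taken from app.py.
import torch
from diffusers import AutoencoderTiny

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead", fullgraph=True)
vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead", fullgraph=True)

# 128x128 latents correspond to roughly a 1024x1024 decoded image at an 8x downscale.
latents = torch.randn(1, 16, 128, 128, device=device, dtype=dtype)
with torch.no_grad():
    vae.decode(latents)                 # first call pays the one-time compilation cost
    image = vae.decode(latents).sample  # later calls reuse the compiled graph
print(image.shape)                      # expected roughly [1, 3, 1024, 1024]

The "reduce-overhead" mode trades a slow first call for lower per-call overhead afterwards, which is why the compile happens once at startup in app.py rather than inside generate_image.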