Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -4,16 +4,12 @@ import random
 import spaces
 import torch
 import time
-import logging
 from diffusers import DiffusionPipeline, AutoencoderTiny
 # Using AttnProcessor2_0 for potential speedup with PyTorch 2.x
 from diffusers.models.attention_processor import AttnProcessor2_0
 # Assuming custom_pipeline defines FluxWithCFGPipeline correctly
 from custom_pipeline import FluxWithCFGPipeline
 
-# --- Setup Logging ---
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
 # --- Torch Optimizations ---
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
@@ -34,50 +30,36 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 pipe = None # Initialize pipe to None
 
 try:
-    logging.info("Loading diffusion pipeline...")
     pipe = FluxWithCFGPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
     )
-    logging.info("Loading VAE...")
     pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
 
-    logging.info(f"Moving pipeline to {device}...")
     pipe.to(device)
 
     # Apply optimizations
-    logging.info("Setting attention processor...")
     pipe.unet.set_attn_processor(AttnProcessor2_0())
     pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too
 
-    logging.info("Loading and fusing LoRA...")
     pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
     pipe.set_adapters(["better"], adapter_weights=[1.0])
    pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0) # Fuse for potential speedup
     pipe.unload_lora_weights() # Unload after fusing
-    logging.info("LoRA fused and unloaded.")
 
     # --- Compilation (Major Speed Optimization) ---
-
-
-    # logging.info("Compiling VAE Encoder...")
-    # pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)
-    # logging.info("Model compilation finished.")
+    pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
+    pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)
 
     # Clear cache after setup
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        logging.info("CUDA cache cleared after setup.")
 
 except Exception as e:
-
-    # Display error in Gradio if UI is already built, otherwise just log and exit.
-    # For simplicity here, we'll rely on the Gradio UI showing an error if `pipe` is None later.
-    # If running script directly, consider `sys.exit()`
-    # raise gr.Error(f"Failed to load models. Check logs for details. Error: {e}")
+    print(e)
 
 
 # --- Inference Function ---
-@spaces.GPU(
+@spaces.GPU() # Slightly increased duration buffer
 def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, randomize_seed: bool = False, num_inference_steps: int = DEFAULT_INFERENCE_STEPS, is_enhance: bool = False):
     """Generates an image using the FLUX pipeline with error handling."""
 
@@ -85,10 +67,7 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
         raise gr.Error("Diffusion pipeline failed to load. Cannot generate images.")
 
     if not prompt or prompt.strip() == "":
-        # Return a blank image or previous result if prompt is empty?
-        # For now, raise warning and return None.
         gr.Warning("Prompt is empty. Please enter a description.")
-        # Returning None for image, original seed, and error message
         return None, seed, "Error: Empty prompt"
 
     start_time = time.time()
@@ -105,8 +84,6 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
     # Clamp steps
     steps_to_use = max(MIN_INFERENCE_STEPS, min(steps_to_use, MAX_INFERENCE_STEPS))
 
-    logging.info(f"Generating image with prompt: '{prompt}', seed: {seed}, size: {width}x{height}, steps: {steps_to_use}")
-
    try:
         # Ensure generator is on the correct device
         generator = torch.Generator(device=device).manual_seed(int(float(seed)))
@@ -127,18 +104,15 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
 
         latency = time.time() - start_time
         latency_str = f"Latency: {latency:.2f} seconds (Steps: {steps_to_use})"
-        logging.info(f"Image generated successfully. {latency_str}")
         return result_img, seed, latency_str
 
     except torch.cuda.OutOfMemoryError as e:
-        logging.error(f"CUDA OutOfMemoryError: {e}", exc_info=True)
         # Clear cache and suggest reducing size/steps
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         raise gr.Error("GPU ran out of memory. Try reducing the image width/height or the number of inference steps.")
 
     except Exception as e:
-        logging.error(f"Error during image generation: {e}", exc_info=True)
         # Clear cache just in case
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -150,14 +124,12 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
 # It's triggered by changes in prompt or sliders when realtime is enabled.
 def handle_realtime_update(realtime_enabled: bool, prompt: str, seed: int, width: int, height: int, randomize_seed: bool, num_inference_steps: int):
     if realtime_enabled and pipe is not None:
-        logging.debug("Realtime update triggered.")
         # Call generate_image directly. Errors within generate_image will be caught and raised as gr.Error.
         # We don't set is_enhance=True for realtime updates.
         return generate_image(prompt, seed, width, height, randomize_seed, num_inference_steps, is_enhance=False)
     else:
         # If realtime is disabled or pipe failed, don't update the image, seed, or latency.
         # Return gr.update() for each output component to indicate no change.
-        logging.debug("Realtime update skipped (disabled or pipe error).")
         return gr.update(), gr.update(), gr.update()
 
 
@@ -225,7 +197,8 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         queue=False,
-        concurrency_limit=None
+        concurrency_limit=None,
+        fn_kwargs={"is_enhance": True} # Pass the flag to indicate enhance
     )
 
     generateBtn.click(
@@ -251,9 +224,8 @@ with gr.Blocks() as demo:
         concurrency_limit=None
     )
 
-
-
-            return next(generate_image(*args[1:]))
+    # Removed the intermediate realtime_generation function.
+    # handle_realtime_update checks the realtime toggle internally.
 
     prompt.submit(
         fn=generate_image,
@@ -266,7 +238,7 @@ with gr.Blocks() as demo:
 
     for component in [prompt, width, height, num_inference_steps]:
         component.input(
-            fn=
+            fn=handle_realtime_update, # Call the wrapper that checks the toggle
             inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
             outputs=[result, seed, latency],
             show_progress="hidden",
@@ -274,6 +246,17 @@ with gr.Blocks() as demo:
             queue=False,
             concurrency_limit=None
         )
+
+    # Also trigger realtime on seed change if randomize is off
+    seed.input(
+        fn=handle_realtime_update,
+        inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
+        outputs=[result, seed, latency],
+        show_progress="hidden",
+        trigger_mode="always_last",
+        queue=False,
+        concurrency_limit=None
+    )
 
 # Launch the app
-demo.launch()
+demo.launch()
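A note on the torch.compile calls enabled above: with mode="reduce-overhead" and fullgraph=True, the first call to a compiled module pays the graph-capture and compilation cost, and later calls with the same input shapes reuse the compiled graph. The snippet below is a minimal, self-contained sketch of that behaviour and of the warm-up pattern often added at startup; it is not part of this commit, and the stand-in nn.Sequential and latent shape are illustrative only (app.py compiles the real TAESD pipe.vae.encoder/decoder).

```python
import torch
import torch.nn as nn

# Stand-in for pipe.vae.decoder; the real module in app.py is the TAESD ("taef1") decoder.
decoder = nn.Sequential(
    nn.Conv2d(16, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(64, 3, kernel_size=3, padding=1),
)
decoder = torch.compile(decoder, mode="reduce-overhead", fullgraph=True)

latents = torch.randn(1, 16, 128, 128)  # illustrative shape only
with torch.inference_mode():
    _ = decoder(latents)  # first call: graph capture + compilation (slow, one-time)
    _ = decoder(latents)  # later calls with the same shape reuse the compiled graph
```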
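On the decorator change from @spaces.GPU( to @spaces.GPU(): on ZeroGPU Spaces, spaces.GPU marks a function whose body needs a GPU, and it optionally accepts a duration argument when the default allocation window is too short for one call. The sketch below shows that general usage; the duration value and the function body are examples, not taken from this commit.

```python
import spaces


@spaces.GPU(duration=90)  # example value; omit to use the default ZeroGPU window
def generate(prompt: str) -> str:
    # On ZeroGPU hardware the decorated call runs with a CUDA device attached;
    # the diffusion pipeline invocation would go here.
    return prompt
```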