Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -11,8 +11,8 @@ from diffusers.models.attention_processor import AttnProcessor2_0
 from custom_pipeline import FluxWithCFGPipeline

 # --- Torch Optimizations ---
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
+# torch.backends.cuda.matmul.allow_tf32 = True
+# torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions

 # --- Constants ---
 MAX_SEED = np.iinfo(np.int32).max
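
Note on the two flags this hunk comments out: `allow_tf32` permits TF32 tensor-core math in matmuls on Ampere-or-newer GPUs, and `cudnn.benchmark` lets cuDNN autotune convolution algorithms for repeated input shapes. A minimal sketch of how these flags are typically gated on CUDA availability (illustration only, not part of the commit):

import torch

# Illustration only: opt in to faster, slightly lower-precision matmuls and cuDNN autotuning
# when a CUDA device is actually present.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matmuls on Ampere+ GPUs
    torch.backends.cudnn.benchmark = True         # autotune convolutions; best with static input shapes
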
@@ -27,39 +27,30 @@ ENHANCE_STEPS = 2 # Fixed steps for the enhance button
 # --- Device and Model Setup ---
 dtype = torch.float16
 device = "cuda" if torch.cuda.is_available() else "cpu"
-pipe = None # Initialize pipe to None

-
-
-
-
-pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
-
-pipe.to(device)
+pipe = FluxWithCFGPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
+)
+pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)

-
-pipe.unet.set_attn_processor(AttnProcessor2_0())
-pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too
+pipe.to(device)

-
-
-
-pipe.unload_lora_weights() # Unload after fusing
+# Apply optimizations
+pipe.unet.set_attn_processor(AttnProcessor2_0())
+pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too

-
-
-
+pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
+pipe.set_adapters(["better"], adapter_weights=[1.0])
+pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0) # Fuse for potential speedup
+pipe.unload_lora_weights() # Unload after fusing

-
-
-
-
-except Exception as e:
-    print(e)
+# --- Compilation (Major Speed Optimization) ---
+pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
+pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)


 # --- Inference Function ---
-@spaces.GPU
+@spaces.GPU
 def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, randomize_seed: bool = False, num_inference_steps: int = DEFAULT_INFERENCE_STEPS, is_enhance: bool = False):
     """Generates an image using the FLUX pipeline with error handling."""

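
The new setup block follows a common diffusers recipe: load the base pipeline in fp16, swap in the tiny VAE, load a LoRA as a named adapter, fuse it into the base weights, drop the adapter bookkeeping, then `torch.compile` the VAE. A minimal sketch of that recipe, assuming the stock `FluxPipeline` rather than the repo's custom `FluxWithCFGPipeline`, and using the standard diffusers `fuse_lora(lora_scale=...)` keyword (the diff itself passes `adapter_name=["better"]`):

import torch
from diffusers import FluxPipeline, AutoencoderTiny

# Sketch only: FLUX.1-schnell in fp16 with the tiny VAE, a fused realism LoRA, and a compiled VAE.
dtype = torch.float16
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
pipe.to("cuda")

# Load the LoRA as a named adapter, bake it into the base weights, then discard the adapter state.
pipe.load_lora_weights(
    "hugovntr/flux-schnell-realism",
    weight_name="schnell-realism_v2.3.safetensors",
    adapter_name="better",
)
pipe.set_adapters(["better"], adapter_weights=[1.0])
pipe.fuse_lora(lora_scale=1.0)   # standard diffusers keyword; see the diff for the variant used here
pipe.unload_lora_weights()

# torch.compile trades a slow first call for faster steady-state encoding/decoding.
pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)

With `mode="reduce-overhead"`, the first generation pays the compilation cost, so running a warmup pass before serving requests is common.
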
@@ -119,20 +110,6 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
         raise gr.Error(f"An error occurred during generation: {e}")


-# --- Real-time Generation Wrapper ---
-# This function checks the realtime toggle before calling the main generation function.
-# It's triggered by changes in prompt or sliders when realtime is enabled.
-def handle_realtime_update(realtime_enabled: bool, prompt: str, seed: int, width: int, height: int, randomize_seed: bool, num_inference_steps: int):
-    if realtime_enabled and pipe is not None:
-        # Call generate_image directly. Errors within generate_image will be caught and raised as gr.Error.
-        # We don't set is_enhance=True for realtime updates.
-        return generate_image(prompt, seed, width, height, randomize_seed, num_inference_steps, is_enhance=False)
-    else:
-        # If realtime is disabled or pipe failed, don't update the image, seed, or latency.
-        # Return gr.update() for each output component to indicate no change.
-        return gr.update(), gr.update(), gr.update()
-
-
 # --- Example Prompts ---
 examples = [
     "a tiny astronaut hatching from an egg on the moon",
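
The deleted `handle_realtime_update` wrapper relied on Gradio's no-op return: a handler can return `gr.update()` for each of its outputs to leave those components unchanged. A minimal sketch of that gating pattern with hypothetical names (not the app's code):

import gradio as gr

def fake_generate(prompt: str):
    # Hypothetical stand-in for the real generation call.
    return f"image for: {prompt}", 42, "12 ms"

def maybe_update(enabled: bool, prompt: str):
    if enabled and prompt.strip():
        return fake_generate(prompt)
    # gr.update() with no arguments means "leave this output component as it is".
    return gr.update(), gr.update(), gr.update()
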
@@ -195,9 +172,7 @@ with gr.Blocks() as demo:
         fn=generate_image,
         inputs=[prompt, seed, width, height],
         outputs=[result, seed, latency],
-        show_progress="full"
-        queue=False,
-        concurrency_limit=None,
+        show_progress="full"
     )

     generateBtn.click(
@@ -206,7 +181,6 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         api_name="RealtimeFlux",
-        queue=False
     )

     def update_ui(realtime_enabled):
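
Because the click listener keeps `api_name="RealtimeFlux"`, the generation event remains exposed as a named endpoint on the Space. A hedged sketch of calling it with `gradio_client`; the Space ID below is a placeholder, and the endpoint's exact argument list should be checked with `view_api()` rather than assumed:

from gradio_client import Client

# Sketch only: the Space ID is a placeholder, not taken from the diff.
client = Client("your-username/your-realtime-flux-space")
client.view_api()  # inspect the real signature of the /RealtimeFlux endpoint first

result = client.predict(
    "a tiny astronaut hatching from an egg on the moon",  # prompt; remaining args depend on the endpoint
    api_name="/RealtimeFlux",
)
print(result)
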
@@ -222,21 +196,14 @@ with gr.Blocks() as demo:
     realtime.change(
         fn=update_ui,
         inputs=[realtime],
-        outputs=[prompt, generateBtn]
-        queue=False,
-        concurrency_limit=None
+        outputs=[prompt, generateBtn]
     )

-    # Removed the intermediate realtime_generation function.
-    # handle_realtime_update checks the realtime toggle internally.
-
     prompt.submit(
         fn=generate_image,
         inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
         outputs=[result, seed, latency],
-        show_progress="full"
-        queue=False,
-        concurrency_limit=None
+        show_progress="full"
     )

     for component in [prompt, width, height, num_inference_steps]:
@@ -245,9 +212,7 @@ with gr.Blocks() as demo:
             inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
             outputs=[result, seed, latency],
             show_progress="hidden",
-            trigger_mode="always_last"
-            queue=False,
-            concurrency_limit=None
+            trigger_mode="always_last"
         )

 # Launch the app
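
The remaining UI hunks all make the same adjustment: the listeners keep `show_progress`, `api_name`, and `trigger_mode="always_last"` while dropping the `queue=False` / `concurrency_limit=None` keyword arguments. A minimal sketch of the resulting live-update wiring with placeholder components (not the app's actual layout):

import gradio as gr

# Sketch only: re-run the handler on every edit, but keep only the latest pending trigger.
def echo(text: str) -> str:
    return text.upper()

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Result")
    box.change(
        fn=echo,
        inputs=[box],
        outputs=[out],
        show_progress="hidden",       # don't flash the progress overlay on every keystroke
        trigger_mode="always_last",   # drop intermediate events, run once for the latest value
    )

demo.launch()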