jbilcke-hf (HF Staff) committed
Commit 3a23852 · Parent(s): b2c19b1
Files changed (1): app.py (+8 -2)
app.py CHANGED
@@ -94,10 +94,16 @@ APP_STATE = {
 # I've tried to enable it, but I didn't notice a significant performance improvement..
 ENABLE_TORCH_COMPILATION = False
 
+# “default”: The default mode, used when no mode parameter is specified. It provides a good balance between performance and overhead.
+# “reduce-overhead”: Minimizes Python-related overhead using CUDA graphs. However, it may increase memory usage.
+# “max-autotune”: Uses Triton or template-based matrix multiplications on supported devices. It takes longer to compile but optimizes for the fastest possible execution. On GPUs it enables CUDA graphs by default.
+# “max-autotune-no-cudagraphs”: Similar to “max-autotune”, but without CUDA graphs.
+TORCH_COMPILATION_MODE = "default"
+
 # Apply torch.compile for maximum performance
 if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
     print("🚀 Applying torch.compile for speed optimization...")
-    transformer.compile(mode="max-autotune-no-cudagraphs")
+    transformer.compile(mode=TORCH_COMPILATION_MODE)
     APP_STATE["torch_compile_applied"] = True
     print("✅ torch.compile applied to transformer")
 
@@ -199,7 +205,7 @@ def initialize_vae_decoder(use_taehv=False, use_trt=False):
     # Apply torch.compile to VAE decoder if enabled (following demo.py pattern)
     if APP_STATE["torch_compile_applied"] and not use_taehv and not use_trt:
         print("🚀 Applying torch.compile to VAE decoder...")
-        vae_decoder.compile(mode="max-autotune-no-cudagraphs")
+        vae_decoder.compile(mode=TORCH_COMPILATION_MODE)
         print("✅ torch.compile applied to VAE decoder")
 
     APP_STATE["current_vae_decoder"] = vae_decoder