multimodalart (HF Staff) committed
Commit 481a175 · verified · 1 Parent(s): 6e67c38

Update app.py

Files changed (1)
  1. app.py +13 -37
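
In brief, this commit hoists construction of the CausalInferencePipeline out of the per-request handler into module scope, removes the per-session torch.compile / FP8 / TAEHV toggles along with their global APP_STATE bookkeeping, and passes the cache device as a keyword argument. After the change, the top level of app.py builds the pipeline once, roughly as below; config, transformer, text_encoder, and gpu are all defined earlier in the file:

    # Built once at import time, right after the default VAE decoder is
    # initialized; every streaming request reuses this single instance.
    pipeline = CausalInferencePipeline(
        config, device=gpu, generator=transformer, text_encoder=text_encoder,
        vae=APP_STATE["current_vae_decoder"]
    )
    pipeline.to(gpu)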
app.py CHANGED
@@ -87,8 +87,6 @@ APP_STATE = {
 }
 
 def initialize_vae_decoder(use_taehv=False, use_trt=False):
-    global APP_STATE
-
     if use_trt:
         from demo_utils.vae import VAETRTWrapper
         print("Initializing TensorRT VAE Decoder...")
@@ -138,6 +136,13 @@ def initialize_vae_decoder(use_taehv=False, use_trt=False):
 # Initialize with default VAE
 initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
 
+pipeline = CausalInferencePipeline(
+    config, device=gpu, generator=transformer, text_encoder=text_encoder,
+    vae=APP_STATE["current_vae_decoder"]
+)
+
+pipeline.to(gpu)
+
 # --- Additional Outputs Handler ---
 def handle_additional_outputs(status_html_update, video_update, webrtc_output):
     return status_html_update, video_update, webrtc_output
@@ -145,41 +150,17 @@ def handle_additional_outputs(status_html_update, video_update, webrtc_output):
 # --- FastRTC Video Generation Handler ---
 @torch.no_grad()
 @spaces.GPU
-def video_generation_handler(prompt, seed, enable_torch_compile, enable_fp8, use_taehv, progress=gr.Progress()):
+def video_generation_handler(prompt, seed, progress=gr.Progress()):
     """
     Generator function that yields BGR NumPy frames for real-time streaming.
     Returns cleanly when done - no infinite loops.
     """
-    global APP_STATE
 
     if seed == -1:
         seed = random.randint(0, 2**32 - 1)
 
     print(f"🎬 Starting video generation with prompt: '{prompt}' and seed: {seed}")
 
-    # --- Model & Pipeline Configuration ---
-    if use_taehv != APP_STATE["current_use_taehv"]:
-        print(f"🔄 Switching VAE to {'TAEHV' if use_taehv else 'Default VAE'}")
-        initialize_vae_decoder(use_taehv=use_taehv, use_trt=args.trt)
-
-    pipeline = CausalInferencePipeline(
-        config, device=gpu, generator=transformer, text_encoder=text_encoder,
-        vae=APP_STATE["current_vae_decoder"]
-    )
-
-    if enable_fp8 and not APP_STATE["fp8_applied"]:
-        print("⚡ Applying FP8 Quantization...")
-        from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8Weight, PerTensor
-        quantize_(pipeline.generator.model, Float8DynamicActivationFloat8Weight(granularity=PerTensor()))
-        APP_STATE["fp8_applied"] = True
-
-    if enable_torch_compile and not APP_STATE["torch_compile_applied"]:
-        print("🔥 Applying torch.compile (this may take a few minutes)...")
-        pipeline.generator.model = torch.compile(pipeline.generator.model, mode="max-autotune-no-cudagraphs")
-        if not use_taehv and not LOW_MEMORY and not args.trt:
-            pipeline.vae.decoder = torch.compile(pipeline.vae.decoder, mode="max-autotune-no-cudagraphs")
-        APP_STATE["torch_compile_applied"] = True
-
     print("🔤 Encoding text prompt...")
     conditional_dict = text_encoder(text_prompts=[prompt])
     for key, value in conditional_dict.items():
@@ -187,14 +168,14 @@ def video_generation_handler(prompt, seed, enable_torch_compile, enable_fp8, use
 
     # --- Generation Loop ---
     rnd = torch.Generator(gpu).manual_seed(int(seed))
-    pipeline._initialize_kv_cache(1, torch.float16, gpu)
-    pipeline._initialize_crossattn_cache(1, torch.float16, gpu)
+    pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
+    pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
     noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
 
     vae_cache, latents_cache = None, None
     if not APP_STATE["current_use_taehv"] and not args.trt:
         vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
-
+
     num_blocks = 7
     current_start_frame = 0
     all_num_frames = [pipeline.num_frame_per_block] * num_blocks
@@ -303,7 +284,6 @@ def video_generation_handler(prompt, seed, enable_torch_compile, enable_fp8, use
         status_html = (
             f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
             f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
-            # Correctly implemented progress bar
             f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
             f" <div style='width: {frame_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
             f" </div>"
@@ -352,11 +332,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing FastRTC Demo") as dem
 
             with gr.Accordion("⚙️ Performance Options", open=False):
                 gr.Markdown("*These optimizations are applied once per session*")
-                with gr.Row():
-                    torch_compile_toggle = gr.Checkbox(label="🔥 torch.compile", value=False)
-                    fp8_toggle = gr.Checkbox(label="⚡ FP8 Quantization", value=False, visible=not args.trt)
-                    taehv_toggle = gr.Checkbox(label="⚡ TAEHV VAE", value=False, visible=not args.trt)
-
+
             start_btn = gr.Button("🎬 Start Generation", variant="primary", size="lg")
 
         with gr.Column(scale=3):
@@ -385,7 +361,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing FastRTC Demo") as dem
     # Connect the generator to the WebRTC stream
     webrtc_output.stream(
         fn=video_generation_handler,
-        inputs=[prompt, seed, torch_compile_toggle, fp8_toggle, taehv_toggle],
+        inputs=[prompt, seed],
         outputs=[webrtc_output],
         time_limit=300,  # 5 minutes max
         trigger=start_btn.click,
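
With the optimization toggles gone, the handler and its FastRTC wiring reduce to two inputs. A minimal sketch of how the pieces fit together after this commit, using the prompt, seed, start_btn, and webrtc_output components from the UI block (handler body elided):

    @torch.no_grad()
    @spaces.GPU
    def video_generation_handler(prompt, seed, progress=gr.Progress()):
        # Generator: yields BGR NumPy frames for real-time streaming,
        # then returns cleanly once all blocks are decoded.
        ...

    webrtc_output.stream(
        fn=video_generation_handler,
        inputs=[prompt, seed],   # matches the new two-argument signature
        outputs=[webrtc_output],
        time_limit=300,          # 5 minutes max
        trigger=start_btn.click,
    )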
 
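A small fix rides along: the private cache initializers are now called with the device as a keyword. Assuming the helpers keep the signature _initialize_kv_cache(batch_size, dtype, device), the two forms are equivalent today, but the keyword form stays correct if later positional parameters are ever added or reordered:

    pipeline._initialize_kv_cache(1, torch.float16, gpu)         # before
    pipeline._initialize_kv_cache(1, torch.float16, device=gpu)  # after this commit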