multimodalart (HF Staff) committed
Commit b720739 · verified
1 Parent(s): e37adb1

Update app.py

Files changed (1)
  1. app.py +427 -0
app.py CHANGED
@@ -0,0 +1,427 @@
+from huggingface_hub import snapshot_download, hf_hub_download
+
+snapshot_download(
+    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+    local_dir="wan_models/Wan2.1-T2V-1.3B",
+    local_dir_use_symlinks=False,
+    resume_download=True,
+    repo_type="model"
+)
+
+hf_hub_download(
+    repo_id="gdhe17/Self-Forcing",
+    filename="checkpoints/self_forcing_dmd.pt",
+    local_dir=".",
+    local_dir_use_symlinks=False
+)
+
+import os
+import re
+import random
+import argparse
+import hashlib
+import urllib.request
+from PIL import Image
+import spaces
+import numpy as np
+import torch
+import gradio as gr
+from omegaconf import OmegaConf
+from tqdm import tqdm
+import imageio  # Added for final video rendering
+
+# FastRTC imports
+from fastrtc import WebRTC, get_turn_credentials
+from fastrtc.utils import AdditionalOutputs, CloseStream
+
+# Original project imports
+from pipeline import CausalInferencePipeline
+from demo_utils.constant import ZERO_VAE_CACHE
+from demo_utils.vae_block3 import VAEDecoderWrapper
+from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
+from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
+
+# --- Argument Parsing ---
+parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with FastRTC")
+parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
+parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
+parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt', help="Path to the model checkpoint.")
+parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
+parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
+parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
+args = parser.parse_args()
+
+# --- Global Setup & Model Loading ---
+print(f"CUDA device: {gpu}")
+print(f'Initial Free VRAM: {get_cuda_free_memory_gb(gpu):.2f} GB')
+LOW_MEMORY = get_cuda_free_memory_gb(gpu) < 40
+
+# Load configs
+try:
+    config = OmegaConf.load(args.config_path)
+    default_config = OmegaConf.load("configs/default_config.yaml")
+    config = OmegaConf.merge(default_config, config)
+except FileNotFoundError as e:
+    print(f"Error loading config file: {e}\nPlease ensure config files are in the correct path.")
+    exit(1)
+
+# Initialize Models
+print("Initializing models...")
+text_encoder = WanTextEncoder()
+transformer = WanDiffusionWrapper(is_causal=True)
+
+try:
+    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
+    transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
+except FileNotFoundError as e:
+    print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
+    exit(1)
+
+# Prepare models for inference
+text_encoder.eval().to(dtype=torch.bfloat16).requires_grad_(False)
+transformer.eval().to(dtype=torch.float16).requires_grad_(False)
+
+if LOW_MEMORY:
+    print("Low memory mode enabled. Using dynamic model swapping.")
+    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
+else:
+    text_encoder.to(gpu)
+transformer.to(gpu)
+
+# --- VAE Decoder Management ---
+APP_STATE = {
+    "torch_compile_applied": False,
+    "fp8_applied": False,
+    "current_use_taehv": False,
+    "current_vae_decoder": None,
+}
+
+def initialize_vae_decoder(use_taehv=False, use_trt=False):
+    global APP_STATE
+
+    if use_trt:
+        from demo_utils.vae import VAETRTWrapper
+        print("Initializing TensorRT VAE Decoder...")
+        vae_decoder = VAETRTWrapper()
+        APP_STATE["current_use_taehv"] = False
+    elif use_taehv:
+        print("Initializing TAEHV VAE Decoder...")
+        from demo_utils.taehv import TAEHV
+        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
+        if not os.path.exists(taehv_checkpoint_path):
+            print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
+            os.makedirs("checkpoints", exist_ok=True)
+            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
+            try:
+                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to download taew2_1.pth: {e}")
+
+        class DotDict(dict): __getattr__ = dict.get
+
+        class TAEHVDiffusersWrapper(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.dtype = torch.float16
+                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
+                self.config = DotDict(scaling_factor=1.0)
+            def decode(self, latents, return_dict=None):
+                return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)
+
+        vae_decoder = TAEHVDiffusersWrapper()
+        APP_STATE["current_use_taehv"] = True
+    else:
+        print("Initializing Default VAE Decoder...")
+        vae_decoder = VAEDecoderWrapper()
+        try:
+            vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
+            decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
+            vae_decoder.load_state_dict(decoder_state_dict)
+        except FileNotFoundError:
+            print("Warning: Default VAE weights not found.")
+        APP_STATE["current_use_taehv"] = False
+
+    vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)
+    APP_STATE["current_vae_decoder"] = vae_decoder
+    print(f"✅ VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
+
+# Initialize with default VAE
+initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
+
+# --- Additional Outputs Handler ---
+def handle_additional_outputs(status_html_update, video_update, webrtc_output):
+    return status_html_update, video_update, webrtc_output
+
+# --- FastRTC Video Generation Handler ---
+@torch.no_grad()
+@spaces.GPU
+def video_generation_handler(prompt, seed, enable_torch_compile, enable_fp8, use_taehv, progress=gr.Progress()):
+    """
+    Generator function that yields BGR NumPy frames for real-time streaming.
+    Returns cleanly when done - no infinite loops.
+    """
+    global APP_STATE
+
+    if seed == -1:
+        seed = random.randint(0, 2**32 - 1)
+
+    print(f"🎬 Starting video generation with prompt: '{prompt}' and seed: {seed}")
+
+    # --- Model & Pipeline Configuration ---
+    if use_taehv != APP_STATE["current_use_taehv"]:
+        print(f"🔄 Switching VAE to {'TAEHV' if use_taehv else 'Default VAE'}")
+        initialize_vae_decoder(use_taehv=use_taehv, use_trt=args.trt)
+
+    pipeline = CausalInferencePipeline(
+        config, device=gpu, generator=transformer, text_encoder=text_encoder,
+        vae=APP_STATE["current_vae_decoder"]
+    )
+
+    if enable_fp8 and not APP_STATE["fp8_applied"]:
+        print("⚡ Applying FP8 Quantization...")
+        from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8Weight, PerTensor
+        quantize_(pipeline.generator.model, Float8DynamicActivationFloat8Weight(granularity=PerTensor()))
+        APP_STATE["fp8_applied"] = True
+
+    if enable_torch_compile and not APP_STATE["torch_compile_applied"]:
+        print("🔥 Applying torch.compile (this may take a few minutes)...")
+        pipeline.generator.model = torch.compile(pipeline.generator.model, mode="max-autotune-no-cudagraphs")
+        if not use_taehv and not LOW_MEMORY and not args.trt:
+            pipeline.vae.decoder = torch.compile(pipeline.vae.decoder, mode="max-autotune-no-cudagraphs")
+        APP_STATE["torch_compile_applied"] = True
+
+    print("🔤 Encoding text prompt...")
+    conditional_dict = text_encoder(text_prompts=[prompt])
+    for key, value in conditional_dict.items():
+        conditional_dict[key] = value.to(dtype=torch.float16)
+
+    # --- Generation Loop ---
+    rnd = torch.Generator(gpu).manual_seed(int(seed))
+    pipeline._initialize_kv_cache(1, torch.float16, gpu)
+    pipeline._initialize_crossattn_cache(1, torch.float16, gpu)
+    noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
+
+    vae_cache, latents_cache = None, None
+    if not APP_STATE["current_use_taehv"] and not args.trt:
+        vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
+
+    num_blocks = 7
+    current_start_frame = 0
+    all_num_frames = [pipeline.num_frame_per_block] * num_blocks
+
+    total_frames_yielded = 0
+    all_frames_for_video = []  # To collect frames for final video
+
+    for idx, current_num_frames in enumerate(all_num_frames):
+        print(f"📦 Processing block {idx+1}/{num_blocks} with {current_num_frames} frames")
+
+        noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]
+
+        for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
+            timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
+            _, denoised_pred = pipeline.generator(
+                noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
+                timestep=timestep, kv_cache=pipeline.kv_cache1,
+                crossattn_cache=pipeline.crossattn_cache,
+                current_start=current_start_frame * pipeline.frame_seq_length
+            )
+            if step_idx < len(pipeline.denoising_step_list) - 1:
+                next_timestep = pipeline.denoising_step_list[step_idx + 1]
+                noisy_input = pipeline.scheduler.add_noise(
+                    denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
+                    next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
+                ).unflatten(0, denoised_pred.shape[:2])
+
+        if idx < len(all_num_frames) - 1:
+            pipeline.generator(
+                noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
+                timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
+                crossattn_cache=pipeline.crossattn_cache,
+                current_start=current_start_frame * pipeline.frame_seq_length,
+            )
+
+        # Decode to pixels
+        if args.trt:
+            pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
+        elif APP_STATE["current_use_taehv"]:
+            if latents_cache is None:
+                latents_cache = denoised_pred
+            else:
+                denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
+                latents_cache = denoised_pred[:, -3:]
+            pixels = pipeline.vae.decode(denoised_pred)
+        else:
+            pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)
+
+        # Handle frame skipping for first block
+        if idx == 0 and not args.trt:
+            pixels = pixels[:, 3:]
+        elif APP_STATE["current_use_taehv"] and idx > 0:
+            pixels = pixels[:, 12:]
+
+        print(f"📹 Decoded pixels shape: {pixels.shape}")
+
+        # Yield individual frames WITH status updates
+        for frame_idx in range(pixels.shape[1]):
+            frame_tensor = pixels[0, frame_idx]  # Get single frame [C, H, W]
+
+            # Normalize from [-1, 1] to [0, 255]
+            frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
+            frame_np = frame_np.to(torch.uint8).cpu().numpy()
+
+            # Convert from CHW to HWC format
+            frame_np = np.transpose(frame_np, (1, 2, 0))  # CHW -> HWC
+
+            all_frames_for_video.append(frame_np)
+
+            # Convert RGB to BGR for FastRTC (OpenCV format)
+            frame_bgr = frame_np[:, :, ::-1]  # RGB -> BGR
+
+            total_frames_yielded += 1
+            print(f"📺 Yielding frame {total_frames_yielded}: shape {frame_bgr.shape}, dtype {frame_bgr.dtype}")
+
+            # Calculate progress
+            total_expected_frames = num_blocks * pipeline.num_frame_per_block
+            current_frame_count = (idx * pipeline.num_frame_per_block) + frame_idx + 1
+            frame_progress = 100 * (current_frame_count / total_expected_frames)
+
+            # --- REVISED HTML START ---
+            if frame_idx == pixels.shape[1] - 1 and idx + 1 == num_blocks:  # last frame
+                status_html = (
+                    f"<div style='padding: 16px; border: 1px solid #198754; background-color: #d1e7dd; border-radius: 8px; font-family: sans-serif; text-align: center;'>"
+                    f" <h4 style='margin: 0 0 8px 0; color: #0f5132; font-size: 18px;'>🎉 Generation Complete!</h4>"
+                    f" <p style='margin: 0; color: #0f5132;'>"
+                    f" Total frames: {total_frames_yielded}. The final video is now available."
+                    f" </p>"
+                    f"</div>"
+                )
+
+                print("💾 Saving final rendered video...")
+                video_update = gr.update()  # Default to no-op
+                try:
+                    video_path = f"gradio_tmp/{seed}_{hashlib.md5(prompt.encode()).hexdigest()}.mp4"
+                    imageio.mimwrite(video_path, all_frames_for_video, fps=15, quality=8)
+                    print(f"✅ Video saved to {video_path}")
+                    video_update = gr.update(value=video_path, visible=True)
+                except Exception as e:
+                    print(f"⚠️ Could not save final video: {e}")
+
+                yield frame_bgr, AdditionalOutputs(status_html, video_update, gr.update(visible=False))
+                yield CloseStream("🎉 Video generation completed successfully!")
+                return
+            else:  # Regular frames - simpler status
+                status_html = (
+                    f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
+                    f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
+                    # Correctly implemented progress bar
+                    f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
+                    f" <div style='width: {frame_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
+                    f" </div>"
+                    f" <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
+                    f" Block {idx+1}/{num_blocks}   |   Frame {total_frames_yielded}   |   {frame_progress:.1f}%"
+                    f" </p>"
+                    f"</div>"
+                )
+            # --- REVISED HTML END ---
+
+            yield frame_bgr, AdditionalOutputs(status_html, gr.update(visible=False), gr.update(visible=True))
+
+        current_start_frame += current_num_frames
+
+    print(f"✅ Video generation completed! Total frames yielded: {total_frames_yielded}")
+
+    # Signal completion
+    yield CloseStream("🎉 Video generation completed successfully!")
+
+# --- Gradio UI Layout ---
+with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing FastRTC Demo") as demo:
+    gr.Markdown("# 🚀 Self-Forcing Video Generation with FastRTC Streaming")
+    gr.Markdown("*Real-time video generation streaming via WebRTC*")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            gr.Markdown("### 📝 Configure Generation")
+            with gr.Group():
+                prompt = gr.Textbox(
+                    label="Prompt",
+                    placeholder="A stylish woman walks down a Tokyo street...",
+                    lines=4,
+                    value="A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage."
+                )
+                gr.Examples(
+                    examples=[
+                        "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse.",
+                        "A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves.",
+                        "A drone shot of a surfer riding a wave on a sunny day. The camera follows the surfer as they carve through the water.",
+                    ],
+                    inputs=[prompt]
+                )
+
+            with gr.Row():
+                seed = gr.Number(label="Seed", value=-1, info="Use -1 for a random seed.")
+
+            with gr.Accordion("⚙️ Performance Options", open=False):
+                gr.Markdown("*These optimizations are applied once per session*")
+                with gr.Row():
+                    torch_compile_toggle = gr.Checkbox(label="🔥 torch.compile", value=False)
+                    fp8_toggle = gr.Checkbox(label="⚡ FP8 Quantization", value=False, visible=not args.trt)
+                    taehv_toggle = gr.Checkbox(label="⚡ TAEHV VAE", value=False, visible=not args.trt)
+
+            start_btn = gr.Button("🎬 Start Generation", variant="primary", size="lg")
+
+        with gr.Column(scale=3):
+            gr.Markdown("### 📺 Live Video Stream")
+            gr.Markdown("*Click 'Start Generation' to begin streaming*")
+
+            try:
+                rtc_config = get_turn_credentials()
+            except Exception as e:
+                print(f"Warning: Could not get TURN credentials: {e}")
+                rtc_config = None
+
+            webrtc_output = WebRTC(
+                label="Generated Video Stream",
+                modality="video",
+                mode="receive",  # Server sends video to client
+                height=480,
+                width=832,
+                rtc_configuration=rtc_config,
+                elem_id="video_stream"
+            )
+
+            final_video = gr.Video(label="Final Rendered Video", visible=False, interactive=False)
+
+            status_html = gr.HTML(
+                value="<div style='text-align: center; padding: 20px; color: #666;'>Ready to start generation...</div>",
+                label="Generation Status"
+            )
+
+
+
+    # Connect the generator to the WebRTC stream
+    webrtc_output.stream(
+        fn=video_generation_handler,
+        inputs=[prompt, seed, torch_compile_toggle, fp8_toggle, taehv_toggle],
+        outputs=[webrtc_output],
+        time_limit=300,  # 5 minutes max
+        trigger=start_btn.click,
+    )
+    # MODIFIED: Handle additional outputs (status updates AND final video)
+    webrtc_output.on_additional_outputs(
+        fn=handle_additional_outputs,
+        outputs=[status_html, final_video, webrtc_output]
+    )
+
+# --- Launch App ---
+if __name__ == "__main__":
+    if os.path.exists("gradio_tmp"):
+        import shutil
+        shutil.rmtree("gradio_tmp")
+    os.makedirs("gradio_tmp", exist_ok=True)
+
+    demo.queue().launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share,
+        show_error=True
+    )