jbilcke-hf (HF Staff) committed
Commit b55bb25 · 1 Parent(s): f5f96d3
Files changed (4)
  1. app.py +242 -159
  2. app_last_working.py +0 -460
  3. demo.py +0 -631
  4. utils/wan_wrapper.py +11 -5
app.py CHANGED
@@ -2,11 +2,17 @@ import subprocess
2
  # not sure why it works in the original space but says "pip not found" in mine
3
  #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
4
 
 
5
  from huggingface_hub import snapshot_download, hf_hub_download
6
 
7
  snapshot_download(
8
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
9
- local_dir="wan_models/Wan2.1-T2V-1.3B",
10
  local_dir_use_symlinks=False,
11
  resume_download=True,
12
  repo_type="model"
@@ -15,11 +21,9 @@ snapshot_download(
15
  hf_hub_download(
16
  repo_id="gdhe17/Self-Forcing",
17
  filename="checkpoints/self_forcing_dmd.pt",
18
- local_dir=".",
19
  local_dir_use_symlinks=False
20
  )
21
-
22
- import os
23
  import re
24
  import random
25
  import argparse
@@ -34,6 +38,10 @@ from tqdm import tqdm
34
  import imageio
35
  import av
36
  import uuid
37
 
38
  from pipeline import CausalInferencePipeline
39
  from demo_utils.constant import ZERO_VAE_CACHE
@@ -45,11 +53,25 @@ import numpy as np
45
 
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
 
48
  # --- Argument Parsing ---
49
  parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
50
  parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
51
  parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
52
- parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt', help="Path to the model checkpoint.")
53
  parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
54
  parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
55
  parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
@@ -107,6 +129,89 @@ if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
107
  APP_STATE["torch_compile_applied"] = True
108
  print("✅ torch.compile applied to transformer")
109
110
  def frames_to_ts_file(frames, filepath, fps = 15):
111
  """
112
  Convert frames directly to .ts file using PyAV.
@@ -193,7 +298,7 @@ def initialize_vae_decoder(use_taehv=False, use_trt=False):
193
  print("Initializing Default VAE Decoder...")
194
  vae_decoder = VAEDecoderWrapper()
195
  try:
196
- vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
197
  decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
198
  vae_decoder.load_state_dict(decoder_state_dict)
199
  except FileNotFoundError:
@@ -222,26 +327,22 @@ pipeline = CausalInferencePipeline(
222
  pipeline.to(dtype=torch.float16).to(gpu)
223
 
224
  @torch.no_grad()
225
- def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5, buffering=2):
226
  """
227
  Generator function that yields .ts video chunks using PyAV for streaming.
228
- Now optimized for block-based processing with smart buffering.
229
  """
230
  if seed == -1:
231
  seed = random.randint(0, 2**32 - 1)
232
 
233
- print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}, duration: {duration}s, buffering: {buffering}s")
234
-
235
- # Show initial buffering status but don't wait - start generating immediately
236
- if buffering > 0:
237
- buffering_status_html = (
238
- f"<div style='padding: 10px; border: 1px solid #ffc107; background: #fff3cd; border-radius: 8px; font-family: sans-serif;'>"
239
- f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>⏳ Buffering...</p>"
240
- f" <p style='margin: 0; color: #856404; font-size: 14px;'>Generating content, will stream when {buffering} seconds of video is ready</p>"
241
- f"</div>"
242
- )
243
- yield None, buffering_status_html
244
-
245
  # Setup
246
  conditional_dict = text_encoder(text_prompts=[prompt])
247
  for key, value in conditional_dict.items():
@@ -260,7 +361,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
260
  # Current setup generates approximately 5 seconds with 7 blocks
261
  # So we scale proportionally
262
  base_duration = 5.0 # seconds
263
- base_blocks = 7
264
  num_blocks = max(1, int(base_blocks * duration / base_duration))
265
 
266
  current_start_frame = 0
@@ -270,13 +371,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
270
 
271
  # Ensure temp directory exists
272
  os.makedirs("gradio_tmp", exist_ok=True)
273
-
274
- # Buffer management - collect chunks before streaming
275
- buffer_chunks = []
276
- buffer_duration = 0.0
277
- frames_per_second = fps
278
- streaming_started = False
279
-
280
  # Generation loop
281
  for idx, current_num_frames in enumerate(all_num_frames):
282
  print(f"📦 Processing block {idx+1}/{num_blocks}")
@@ -375,45 +470,11 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
375
 
376
  frames_to_ts_file(all_frames_from_block, ts_path, fps)
377
 
378
- # Calculate duration of this chunk
379
- chunk_duration = len(all_frames_from_block) / frames_per_second
380
 
381
- # Add to buffer
382
- buffer_chunks.append(ts_path)
383
- buffer_duration += chunk_duration
384
-
385
- # Check if we have enough buffered content to start streaming
386
- if not streaming_started and buffer_duration >= buffering:
387
- print(f"🚀 Buffer filled ({buffer_duration:.2f}s >= {buffering}s), starting stream!")
388
- streaming_started = True
389
-
390
- # Stream all buffered chunks
391
- for buffered_chunk in buffer_chunks:
392
- yield buffered_chunk, gr.update()
393
-
394
- # Clear buffer since we've streamed it
395
- buffer_chunks.clear()
396
- buffer_duration = 0.0
397
-
398
- elif streaming_started:
399
- # Stream immediately if we're already streaming
400
- yield ts_path, gr.update()
401
- elif buffering == 0:
402
- # No buffering requested, stream immediately
403
- yield ts_path, gr.update()
404
- else:
405
- # Still buffering, show progress
406
- buffer_progress = (buffer_duration / buffering) * 100
407
- buffering_progress_html = (
408
- f"<div style='padding: 10px; border: 1px solid #ffc107; background: #fff3cd; border-radius: 8px; font-family: sans-serif;'>"
409
- f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>⏳ Buffering... ({buffer_duration:.1f}s/{buffering}s)</p>"
410
- f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
411
- f" <div style='width: {buffer_progress:.1f}%; height: 20px; background-color: #ffc107; transition: width 0.2s;'></div>"
412
- f" </div>"
413
- f" <p style='margin: 4px 0 0 0; color: #856404; font-size: 14px;'>Generating content for smooth playback...</p>"
414
- f"</div>"
415
- )
416
- yield None, buffering_progress_html
417
 
418
  except Exception as e:
419
  print(f"⚠️ Error encoding block {idx}: {e}")
@@ -422,12 +483,6 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
422
 
423
  current_start_frame += current_num_frames
424
 
425
- # Stream any remaining buffered content
426
- if buffer_chunks:
427
- print(f"🎬 Streaming remaining {len(buffer_chunks)} buffered chunks")
428
- for buffered_chunk in buffer_chunks:
429
- yield buffered_chunk, gr.update()
430
-
431
  # Final completion status
432
  final_status_html = (
433
  f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
@@ -449,104 +504,132 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
449
  print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
450
 
451
  # --- Gradio UI Layout ---
452
- with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
453
- gr.Markdown("# 🚀 Self-Forcing Video Generation")
454
- gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
455
 
456
- with gr.Row():
457
- with gr.Column(scale=2):
458
- with gr.Group():
459
- prompt = gr.Textbox(
460
- label="Prompt",
461
- placeholder="A stylish woman walks down a Tokyo street...",
462
- lines=4,
463
- value=""
464
- )
465
-
466
- start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
467
 
468
- gr.Markdown("### ⚙️ Settings")
469
  with gr.Row():
470
- seed = gr.Number(
471
- label="Seed",
472
- value=-1,
473
- info="Use -1 for random seed",
474
- precision=0
475
- )
476
- fps = gr.Slider(
477
- label="Playback FPS",
478
- minimum=1,
479
- maximum=30,
480
- value=args.fps,
481
- step=1,
482
- visible=False,
483
- info="Frames per second for playback"
484
- )
485
-
486
  with gr.Row():
487
- duration = gr.Slider(
488
- label="Duration (seconds)",
489
- minimum=1,
490
- maximum=10,
491
- value=5,
492
- step=1,
493
- info="Video duration in seconds"
494
- )
495
- buffering = gr.Slider(
496
- label="Buffering (seconds)",
497
- minimum=0,
498
- maximum=5,
499
- value=2,
500
- step=0.5,
501
- info="Wait time before starting stream"
502
- )
503
-
504
  with gr.Row():
505
- width = gr.Slider(
506
- label="Width",
507
- minimum=320,
508
- maximum=720,
509
- value=400,
510
- step=8,
511
- info="Video width in pixels (8px steps)"
512
- )
513
- height = gr.Slider(
514
- label="Height",
515
- minimum=320,
516
- maximum=720,
517
- value=224,
518
- step=8,
519
- info="Video height in pixels (8px steps)"
520
- )
521
-
522
- with gr.Column(scale=3):
523
- gr.Markdown("### 📺 Video Stream")
524
-
525
- streaming_video = gr.Video(
526
- label="Live Stream",
527
- streaming=True,
528
- loop=True,
529
- height=400,
530
- autoplay=True,
531
- show_label=False
532
- )
533
-
534
- status_display = gr.HTML(
535
- value=(
536
- "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
537
- "🎬 Ready to start streaming...<br>"
538
- "<small>Configure your prompt and click 'Start Streaming'</small>"
539
- "</div>"
540
- ),
541
- label="Generation Status"
542
- )
543
 
544
  # Connect the generator to the streaming video
545
  start_btn.click(
546
  fn=video_generation_handler_streaming,
547
- inputs=[prompt, seed, fps, width, height, duration, buffering],
548
  outputs=[streaming_video, status_display]
549
  )
550
 
551
 
552
  # --- Launch App ---
 
2
  # not sure why it works in the original space but says "pip not found" in mine
3
  #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
4
 
5
+ import os
6
  from huggingface_hub import snapshot_download, hf_hub_download
7
 
8
+ # Configuration for data paths
9
+ DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
10
+ WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
11
+ OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')
12
+
13
  snapshot_download(
14
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
15
+ local_dir=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B"),
16
  local_dir_use_symlinks=False,
17
  resume_download=True,
18
  repo_type="model"
 
21
  hf_hub_download(
22
  repo_id="gdhe17/Self-Forcing",
23
  filename="checkpoints/self_forcing_dmd.pt",
24
+ local_dir=OTHER_MODELS_PATH,
25
  local_dir_use_symlinks=False
26
  )
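[Aside, not part of the diff] The new DATA_ROOT indirection lets the Space relocate all downloads (for example onto a persistent volume) with a single environment variable; by default everything still lands next to the app. A small sketch of how the paths resolve, assuming DATA_ROOT=/data is exported before launch (the /data value is only an example):

# Sketch: path layout produced by the new DATA_ROOT configuration.
# Launching with e.g.  DATA_ROOT=/data python app.py  yields:
#   /data/wan_models/Wan2.1-T2V-1.3B                      (snapshot_download target)
#   /data/other_models/checkpoints/self_forcing_dmd.pt    (hf_hub_download result)
import os
DATA_ROOT = os.path.normpath(os.getenv("DATA_ROOT", "."))
print(os.path.join(DATA_ROOT, "wan_models", "Wan2.1-T2V-1.3B"))
print(os.path.join(DATA_ROOT, "other_models", "checkpoints", "self_forcing_dmd.pt"))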
 
 
27
  import re
28
  import random
29
  import argparse
 
38
  import imageio
39
  import av
40
  import uuid
41
+ import tempfile
42
+ import shutil
43
+ from pathlib import Path
44
+ from typing import Dict, Any, List, Optional, Tuple, Union
45
 
46
  from pipeline import CausalInferencePipeline
47
  from demo_utils.constant import ZERO_VAE_CACHE
 
53
 
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
55
 
56
+ # LoRA Storage Configuration
57
+ STORAGE_PATH = Path(DATA_ROOT) / "storage"
58
+ LORA_PATH = STORAGE_PATH / "loras"
59
+ OUTPUT_PATH = STORAGE_PATH / "output"
60
+
61
+ # Create necessary directories
62
+ STORAGE_PATH.mkdir(parents=True, exist_ok=True)
63
+ LORA_PATH.mkdir(parents=True, exist_ok=True)
64
+ OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
65
+
66
+ # Global variables for LoRA management
67
+ current_lora_id = None
68
+ current_lora_path = None
69
+
70
  # --- Argument Parsing ---
71
  parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
72
  parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
73
  parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
74
+ parser.add_argument("--checkpoint_path", type=str, default=os.path.join(OTHER_MODELS_PATH, 'checkpoints', 'self_forcing_dmd.pt'), help="Path to the model checkpoint.")
75
  parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
76
  parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
77
  parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
 
129
  APP_STATE["torch_compile_applied"] = True
130
  print("✅ torch.compile applied to transformer")
131
 
132
+ def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
133
+ """Upload a LoRA file and return a hash-based ID for future reference"""
134
+ if file is None:
135
+ return "", ""
136
+
137
+ try:
138
+ # Calculate SHA256 hash of the file
139
+ sha256_hash = hashlib.sha256()
140
+ with open(file.name, "rb") as f:
141
+ for chunk in iter(lambda: f.read(4096), b""):
142
+ sha256_hash.update(chunk)
143
+ file_hash = sha256_hash.hexdigest()
144
+
145
+ # Create destination path using hash
146
+ dest_path = LORA_PATH / f"{file_hash}.safetensors"
147
+
148
+ # Check if file already exists
149
+ if dest_path.exists():
150
+ print(f"LoRA file already exists!")
151
+ return file_hash, file_hash
152
+
153
+ # Copy the file to the destination
154
+ shutil.copy(file.name, dest_path)
155
+
156
+ print(f"LoRA file uploaded!")
157
+ return file_hash, file_hash
158
+ except Exception as e:
159
+ print(f"Error uploading LoRA file: {e}")
160
+ raise gr.Error(f"Failed to upload LoRA file: {str(e)}")
161
+
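[Aside, not part of the diff] Because the LoRA ID is simply the SHA-256 digest of the uploaded file, it can be computed ahead of time to predict the ID the app will hand back. A standalone sketch ("my_lora.safetensors" is a placeholder filename, not from the repo):

# Standalone sketch: reproduce the hash-based LoRA ID locally.
import hashlib

def lora_id_for(path: str, chunk_size: int = 4096) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

print(lora_id_for("my_lora.safetensors"))  # same value upload_lora_file would return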
162
+ def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
163
+ """Get the path to a LoRA file from its hash-based ID"""
164
+ if not lora_id:
165
+ return None
166
+
167
+ # Check if file exists
168
+ lora_path = LORA_PATH / f"{lora_id}.safetensors"
169
+ if lora_path.exists():
170
+ return lora_path
171
+
172
+ return None
173
+
174
+ def manage_lora_weights(lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
175
+ """Manage LoRA weights for the transformer model"""
176
+ global current_lora_id, current_lora_path
177
+
178
+ # Determine if we should use LoRA
179
+ using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0
180
+
181
+ # If not using LoRA but we have one loaded, clear it
182
+ if not using_lora and current_lora_id is not None:
183
+ print(f"Clearing current LoRA")
184
+ current_lora_id = None
185
+ current_lora_path = None
186
+ return False, None
187
+
188
+ # If using LoRA, check if we need to change weights
189
+ if using_lora:
190
+ lora_path = get_lora_file_path(lora_id)
191
+
192
+ if not lora_path:
193
+ print(f"A LoRA file with this ID was found. Using base model instead.")
194
+
195
+ # If we had a LoRA loaded, clear it
196
+ if current_lora_id is not None:
197
+ print(f"Clearing current LoRA")
198
+ current_lora_id = None
199
+ current_lora_path = None
200
+
201
+ return False, None
202
+
203
+ # If LoRA ID changed, update
204
+ if lora_id != current_lora_id:
205
+ print(f"Loading LoRA..")
206
+ current_lora_id = lora_id
207
+ current_lora_path = lora_path
208
+ else:
209
+ print(f"Using a LoRA!")
210
+
211
+ return True, lora_path
212
+
213
+ return False, None
214
+
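[Aside, not part of the diff] Note that manage_lora_weights only tracks which LoRA file is currently selected and returns its path; this diff does not show the weights actually being applied to the transformer. Purely for illustration, a generic way to fold LoRA matrices into matching linear weights is sketched below; the lora_A/lora_B key naming and the plain W += weight * (B @ A) scaling are assumptions, not taken from this repository:

# Illustrative only: merge LoRA low-rank updates into a model's linear weights.
# Assumes keys of the form "<layer>.lora_A.weight" / "<layer>.lora_B.weight";
# real checkpoints may use different naming and scaling (alpha/rank) conventions.
import torch
from safetensors.torch import load_file

def merge_lora(model: torch.nn.Module, lora_path: str, weight: float = 1.0) -> None:
    lora_sd = load_file(lora_path)
    params = dict(model.named_parameters())
    for name, a in lora_sd.items():
        if not name.endswith(".lora_A.weight"):
            continue
        base = name[: -len(".lora_A.weight")] + ".weight"
        b = lora_sd.get(name.replace(".lora_A.", ".lora_B."))
        p = params.get(base)
        if p is None or b is None:
            continue
        # W <- W + weight * (B @ A), the standard LoRA update
        p.data += weight * (b.to(p.device, p.dtype) @ a.to(p.device, p.dtype))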
215
  def frames_to_ts_file(frames, filepath, fps = 15):
216
  """
217
  Convert frames directly to .ts file using PyAV.
 
298
  print("Initializing Default VAE Decoder...")
299
  vae_decoder = VAEDecoderWrapper()
300
  try:
301
+ vae_state_dict = torch.load(os.path.join(WAN_MODELS_PATH, 'Wan2.1-T2V-1.3B', 'Wan2.1_VAE.pth'), map_location="cpu")
302
  decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
303
  vae_decoder.load_state_dict(decoder_state_dict)
304
  except FileNotFoundError:
 
327
  pipeline.to(dtype=torch.float16).to(gpu)
328
 
329
  @torch.no_grad()
330
+ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5, lora_id=None, lora_weight=0.0):
331
  """
332
  Generator function that yields .ts video chunks using PyAV for streaming.
 
333
  """
334
  if seed == -1:
335
  seed = random.randint(0, 2**32 - 1)
336
 
337
+ # print(f"🎬 Starting PyAV streaming: seed: {seed}, duration: {duration}s")
338
+
339
+ # Handle LoRA weights
340
+ using_lora, lora_path = manage_lora_weights(lora_id, lora_weight)
341
+ if using_lora:
342
+ print(f"🎨 Using LoRA with weight factor {lora_weight}")
343
+ else:
344
+ print("🎨 Using base model (no LoRA)")
345
+
346
  # Setup
347
  conditional_dict = text_encoder(text_prompts=[prompt])
348
  for key, value in conditional_dict.items():
 
361
  # Current setup generates approximately 5 seconds with 7 blocks
362
  # So we scale proportionally
363
  base_duration = 5.0 # seconds
364
+ base_blocks = 8
365
  num_blocks = max(1, int(base_blocks * duration / base_duration))
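[Aside, not part of the diff] The comment above still describes the earlier 7-block setup, but with the new base_blocks = 8 the duration-to-blocks scaling works out as follows (a quick worked example):

# duration -> number of generation blocks with base_blocks = 8, base_duration = 5.0
for duration in (1, 3, 5):
    print(duration, "s ->", max(1, int(8 * duration / 5.0)), "blocks")
# 1 s -> 1 block, 3 s -> 4 blocks, 5 s -> 8 blocks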
366
 
367
  current_start_frame = 0
 
371
 
372
  # Ensure temp directory exists
373
  os.makedirs("gradio_tmp", exist_ok=True)
374
+
375
  # Generation loop
376
  for idx, current_num_frames in enumerate(all_num_frames):
377
  print(f"📦 Processing block {idx+1}/{num_blocks}")
 
470
 
471
  frames_to_ts_file(all_frames_from_block, ts_path, fps)
472
 
473
+ # Calculate final progress for this block
474
+ total_progress = (idx + 1) / num_blocks * 100
475
 
476
+ # Yield the actual video chunk
477
+ yield ts_path, gr.update()
478
 
479
  except Exception as e:
480
  print(f"⚠️ Error encoding block {idx}: {e}")
 
483
 
484
  current_start_frame += current_num_frames
485
 
 
486
  # Final completion status
487
  final_status_html = (
488
  f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
 
504
  print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
505
 
506
  # --- Gradio UI Layout ---
507
+ with gr.Blocks(title="Wan2.1 1.3B LoRA Self-Forcing streaming demo") as demo:
508
+ gr.Markdown("# 🚀 Run Any LoRA in near real-time!")
509
+ gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B and LoRA [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
510
 
511
+ with gr.Tabs():
512
+ # LoRA Upload Tab
513
+ with gr.TabItem("1️⃣ Upload LoRA"):
514
+ gr.Markdown("## Upload LoRA Weights")
515
+ gr.Markdown("Upload your custom LoRA weights file to use for generation. The file will be automatically stored and you'll receive a unique hash-based ID.")
516
 
 
517
  with gr.Row():
518
+ lora_file = gr.File(label="LoRA File (safetensors format)")
519
+
520
  with gr.Row():
521
+ lora_id_output = gr.Textbox(label="LoRA Hash ID (use this in the generation tab)", interactive=False)
522
+
523
+ # Video Generation Tab
524
+ with gr.TabItem("2️⃣ Generate Video"):
 
525
  with gr.Row():
526
+ with gr.Column(scale=2):
527
+ with gr.Group():
528
+ prompt = gr.Textbox(
529
+ label="Prompt",
530
+ placeholder="A stylish woman walks down a Tokyo street...",
531
+ lines=4,
532
+ value=""
533
+ )
534
+
535
+ start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
536
+
537
+ gr.Markdown("### ⚙️ Settings")
538
+ with gr.Row():
539
+ seed = gr.Number(
540
+ label="Seed",
541
+ value=-1,
542
+ info="Use -1 for random seed",
543
+ precision=0
544
+ )
545
+ fps = gr.Slider(
546
+ label="Playback FPS",
547
+ minimum=1,
548
+ maximum=30,
549
+ value=args.fps,
550
+ step=1,
551
+ visible=False,
552
+ info="Frames per second for playback"
553
+ )
554
+
555
+ with gr.Row():
556
+ duration = gr.Slider(
557
+ label="Duration (seconds)",
558
+ minimum=1,
559
+ maximum=5,
560
+ value=3,
561
+ step=1,
562
+ info="Video duration in seconds"
563
+ )
564
+
565
+ with gr.Row():
566
+ width = gr.Slider(
567
+ label="Width",
568
+ minimum=224,
569
+ maximum=720,
570
+ value=400,
571
+ step=8,
572
+ info="Video width in pixels (8px steps)"
573
+ )
574
+ height = gr.Slider(
575
+ label="Height",
576
+ minimum=224,
577
+ maximum=720,
578
+ value=224,
579
+ step=8,
580
+ info="Video height in pixels (8px steps)"
581
+ )
582
+
583
+ gr.Markdown("### 🎨 LoRA Settings")
584
+ lora_id = gr.Textbox(
585
+ label="LoRA ID (from upload tab)",
586
+ placeholder="Enter your LoRA ID here...",
587
+ )
588
+
589
+ lora_weight = gr.Slider(
590
+ label="LoRA Weight",
591
+ minimum=0.0,
592
+ maximum=1.0,
593
+ step=0.01,
594
+ value=1.0,
595
+ info="Strength of LoRA influence"
596
+ )
597
+
598
+ with gr.Column(scale=3):
599
+ gr.Markdown("### 📺 Video Stream")
600
+
601
+ streaming_video = gr.Video(
602
+ label="Live Stream",
603
+ streaming=True,
604
+ loop=True,
605
+ height=400,
606
+ autoplay=True,
607
+ show_label=False
608
+ )
609
+
610
+ status_display = gr.HTML(
611
+ value=(
612
+ "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
613
+ "🎬 Ready to start streaming...<br>"
614
+ "<small>Configure your prompt and click 'Start Streaming'</small>"
615
+ "</div>"
616
+ ),
617
+ label="Generation Status"
618
+ )
619
 
620
  # Connect the generator to the streaming video
621
  start_btn.click(
622
  fn=video_generation_handler_streaming,
623
+ inputs=[prompt, seed, fps, width, height, duration, lora_id, lora_weight],
624
  outputs=[streaming_video, status_display]
625
  )
626
+
627
+ # Connect LoRA upload to both display fields
628
+ lora_file.change(
629
+ fn=upload_lora_file,
630
+ inputs=[lora_file],
631
+ outputs=[lora_id_output, lora_id]
632
+ )
633
 
634
 
635
  # --- Launch App ---
app_last_working.py DELETED
@@ -1,460 +0,0 @@
1
- import subprocess
2
- # not sure why it works in the original space but says "pip not found" in mine
3
- #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
4
-
5
- from huggingface_hub import snapshot_download, hf_hub_download
6
-
7
- snapshot_download(
8
- repo_id="Wan-AI/Wan2.1-T2V-1.3B",
9
- local_dir="wan_models/Wan2.1-T2V-1.3B",
10
- local_dir_use_symlinks=False,
11
- resume_download=True,
12
- repo_type="model"
13
- )
14
-
15
- hf_hub_download(
16
- repo_id="gdhe17/Self-Forcing",
17
- filename="checkpoints/self_forcing_dmd.pt",
18
- local_dir=".",
19
- local_dir_use_symlinks=False
20
- )
21
-
22
- import os
23
- import re
24
- import random
25
- import argparse
26
- import hashlib
27
- import urllib.request
28
- import time
29
- from PIL import Image
30
- import torch
31
- import gradio as gr
32
- from omegaconf import OmegaConf
33
- from tqdm import tqdm
34
- import imageio
35
- import av
36
- import uuid
37
-
38
- from pipeline import CausalInferencePipeline
39
- from demo_utils.constant import ZERO_VAE_CACHE
40
- from demo_utils.vae_block3 import VAEDecoderWrapper
41
- from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
42
-
43
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM #, BitsAndBytesConfig
44
- import numpy as np
45
-
46
- device = "cuda" if torch.cuda.is_available() else "cpu"
47
-
48
- # --- Argument Parsing ---
49
- parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
50
- parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
51
- parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
52
- parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt', help="Path to the model checkpoint.")
53
- parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
54
- parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
55
- parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
56
- parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
57
- args = parser.parse_args()
58
-
59
- gpu = "cuda"
60
-
61
- try:
62
- config = OmegaConf.load(args.config_path)
63
- default_config = OmegaConf.load("configs/default_config.yaml")
64
- config = OmegaConf.merge(default_config, config)
65
- except FileNotFoundError as e:
66
- print(f"Error loading config file: {e}\n. Please ensure config files are in the correct path.")
67
- exit(1)
68
-
69
- # Initialize Models
70
- print("Initializing models...")
71
- text_encoder = WanTextEncoder()
72
- transformer = WanDiffusionWrapper(is_causal=True)
73
-
74
- try:
75
- state_dict = torch.load(args.checkpoint_path, map_location="cpu")
76
- transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
77
- except FileNotFoundError as e:
78
- print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
79
- exit(1)
80
-
81
- text_encoder.eval().to(dtype=torch.float16).requires_grad_(False)
82
- transformer.eval().to(dtype=torch.float16).requires_grad_(False)
83
-
84
- text_encoder.to(gpu)
85
- transformer.to(gpu)
86
-
87
- APP_STATE = {
88
- "torch_compile_applied": False,
89
- "fp8_applied": False,
90
- "current_use_taehv": False,
91
- "current_vae_decoder": None,
92
- }
93
-
94
- def frames_to_ts_file(frames, filepath, fps = 15):
95
- """
96
- Convert frames directly to .ts file using PyAV.
97
-
98
- Args:
99
- frames: List of numpy arrays (HWC, RGB, uint8)
100
- filepath: Output file path
101
- fps: Frames per second
102
-
103
- Returns:
104
- The filepath of the created file
105
- """
106
- if not frames:
107
- return filepath
108
-
109
- height, width = frames[0].shape[:2]
110
-
111
- # Create container for MPEG-TS format
112
- container = av.open(filepath, mode='w', format='mpegts')
113
-
114
- # Add video stream with optimized settings for streaming
115
- stream = container.add_stream('h264', rate=fps)
116
- stream.width = width
117
- stream.height = height
118
- stream.pix_fmt = 'yuv420p'
119
-
120
- # Optimize for low latency streaming
121
- stream.options = {
122
- 'preset': 'ultrafast',
123
- 'tune': 'zerolatency',
124
- 'crf': '23',
125
- 'profile': 'baseline',
126
- 'level': '3.0'
127
- }
128
-
129
- try:
130
- for frame_np in frames:
131
- frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
132
- frame = frame.reformat(format=stream.pix_fmt)
133
- for packet in stream.encode(frame):
134
- container.mux(packet)
135
-
136
- for packet in stream.encode():
137
- container.mux(packet)
138
-
139
- finally:
140
- container.close()
141
-
142
- return filepath
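[Aside, not part of the diff] frames_to_ts_file expects a list of HWC uint8 RGB frames and writes a low-latency MPEG-TS/H.264 segment. A quick standalone sanity check with synthetic frames (assumes PyAV and NumPy are installed; dimensions should be even for yuv420p):

# Sanity check: 30 random 400x224 frames -> ~2 s .ts segment at 15 fps.
import numpy as np
frames = [np.random.randint(0, 256, (224, 400, 3), dtype=np.uint8) for _ in range(30)]
frames_to_ts_file(frames, "test_chunk.ts", fps=15)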
143
-
144
- def initialize_vae_decoder(use_taehv=False, use_trt=False):
145
- if use_trt:
146
- from demo_utils.vae import VAETRTWrapper
147
- print("Initializing TensorRT VAE Decoder...")
148
- vae_decoder = VAETRTWrapper()
149
- APP_STATE["current_use_taehv"] = False
150
- elif use_taehv:
151
- print("Initializing TAEHV VAE Decoder...")
152
- from demo_utils.taehv import TAEHV
153
- taehv_checkpoint_path = "checkpoints/taew2_1.pth"
154
- if not os.path.exists(taehv_checkpoint_path):
155
- print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
156
- os.makedirs("checkpoints", exist_ok=True)
157
- download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
158
- try:
159
- urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
160
- except Exception as e:
161
- raise RuntimeError(f"Failed to download taew2_1.pth: {e}")
162
-
163
- class DotDict(dict): __getattr__ = dict.get
164
-
165
- class TAEHVDiffusersWrapper(torch.nn.Module):
166
- def __init__(self):
167
- super().__init__()
168
- self.dtype = torch.float16
169
- self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
170
- self.config = DotDict(scaling_factor=1.0)
171
- def decode(self, latents, return_dict=None):
172
- return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)
173
-
174
- vae_decoder = TAEHVDiffusersWrapper()
175
- APP_STATE["current_use_taehv"] = True
176
- else:
177
- print("Initializing Default VAE Decoder...")
178
- vae_decoder = VAEDecoderWrapper()
179
- try:
180
- vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
181
- decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
182
- vae_decoder.load_state_dict(decoder_state_dict)
183
- except FileNotFoundError:
184
- print("Warning: Default VAE weights not found.")
185
- APP_STATE["current_use_taehv"] = False
186
-
187
- vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)
188
- APP_STATE["current_vae_decoder"] = vae_decoder
189
- print(f"✅ VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
190
-
191
- # Initialize with default VAE
192
- initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
193
-
194
- pipeline = CausalInferencePipeline(
195
- config, device=gpu, generator=transformer, text_encoder=text_encoder,
196
- vae=APP_STATE["current_vae_decoder"]
197
- )
198
-
199
- pipeline.to(dtype=torch.float16).to(gpu)
200
-
201
- @torch.no_grad()
202
- def video_generation_handler_streaming(prompt, seed=42, fps=15):
203
- """
204
- Generator function that yields .ts video chunks using PyAV for streaming.
205
- Now optimized for block-based processing.
206
- """
207
- if seed == -1:
208
- seed = random.randint(0, 2**32 - 1)
209
-
210
- print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
211
-
212
- # Setup
213
- conditional_dict = text_encoder(text_prompts=[prompt])
214
- for key, value in conditional_dict.items():
215
- conditional_dict[key] = value.to(dtype=torch.float16)
216
-
217
- rnd = torch.Generator(gpu).manual_seed(int(seed))
218
- pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
219
- pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
220
- noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
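[Aside, not part of the diff] The fixed latent shape here determines the whole clip: 21 latent frames are consumed as 7 blocks of pipeline.num_frame_per_block = 3, with a 16-channel latent of 60×104. A small sanity check of that relationship (the 8× spatial factor used to estimate pixel resolution is the usual Wan2.1 VAE stride and is an assumption here):

# Sanity-check the latent geometry used for the noise tensor above.
latent_frames, channels, lat_h, lat_w = 21, 16, 60, 104
num_blocks, frames_per_block = 7, 3
assert latent_frames == num_blocks * frames_per_block
print("approx pixel size:", lat_h * 8, "x", lat_w * 8)  # ~480 x 832 (assumed 8x VAE stride)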
221
-
222
- vae_cache, latents_cache = None, None
223
- if not APP_STATE["current_use_taehv"] and not args.trt:
224
- vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
225
-
226
- num_blocks = 7
227
- current_start_frame = 0
228
- all_num_frames = [pipeline.num_frame_per_block] * num_blocks
229
-
230
- total_frames_yielded = 0
231
-
232
- # Ensure temp directory exists
233
- os.makedirs("gradio_tmp", exist_ok=True)
234
-
235
- # Generation loop
236
- for idx, current_num_frames in enumerate(all_num_frames):
237
- print(f"📦 Processing block {idx+1}/{num_blocks}")
238
-
239
- noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]
240
-
241
- # Denoising steps
242
- for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
243
- timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
244
- _, denoised_pred = pipeline.generator(
245
- noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
246
- timestep=timestep, kv_cache=pipeline.kv_cache1,
247
- crossattn_cache=pipeline.crossattn_cache,
248
- current_start=current_start_frame * pipeline.frame_seq_length
249
- )
250
- if step_idx < len(pipeline.denoising_step_list) - 1:
251
- next_timestep = pipeline.denoising_step_list[step_idx + 1]
252
- noisy_input = pipeline.scheduler.add_noise(
253
- denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
254
- next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
255
- ).unflatten(0, denoised_pred.shape[:2])
256
-
257
- if idx < len(all_num_frames) - 1:
258
- pipeline.generator(
259
- noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
260
- timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
261
- crossattn_cache=pipeline.crossattn_cache,
262
- current_start=current_start_frame * pipeline.frame_seq_length,
263
- )
264
-
265
- # Decode to pixels
266
- if args.trt:
267
- pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
268
- elif APP_STATE["current_use_taehv"]:
269
- if latents_cache is None:
270
- latents_cache = denoised_pred
271
- else:
272
- denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
273
- latents_cache = denoised_pred[:, -3:]
274
- pixels = pipeline.vae.decode(denoised_pred)
275
- else:
276
- pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)
277
-
278
- # Handle frame skipping
279
- if idx == 0 and not args.trt:
280
- pixels = pixels[:, 3:]
281
- elif APP_STATE["current_use_taehv"] and idx > 0:
282
- pixels = pixels[:, 12:]
283
-
284
- print(f"🔍 DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")
285
-
286
- # Process all frames from this block at once
287
- all_frames_from_block = []
288
- for frame_idx in range(pixels.shape[1]):
289
- frame_tensor = pixels[0, frame_idx]
290
-
291
- # Convert to numpy (HWC, RGB, uint8)
292
- frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
293
- frame_np = frame_np.to(torch.uint8).cpu().numpy()
294
- frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
295
-
296
- all_frames_from_block.append(frame_np)
297
- total_frames_yielded += 1
298
-
299
- # Yield status update for each frame (cute tracking!)
300
- blocks_completed = idx
301
- current_block_progress = (frame_idx + 1) / pixels.shape[1]
302
- total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
303
-
304
- # Cap at 100% to avoid going over
305
- total_progress = min(total_progress, 100.0)
306
-
307
- frame_status_html = (
308
- f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
309
- f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
310
- f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
311
- f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
312
- f" </div>"
313
- f" <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
314
- f" Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
315
- f" </p>"
316
- f"</div>"
317
- )
318
-
319
- # Yield None for video but update status (frame-by-frame tracking)
320
- yield None, frame_status_html
321
-
322
- # Encode entire block as one chunk immediately
323
- if all_frames_from_block:
324
- print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
325
-
326
- try:
327
- chunk_uuid = str(uuid.uuid4())[:8]
328
- ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
329
- ts_path = os.path.join("gradio_tmp", ts_filename)
330
-
331
- frames_to_ts_file(all_frames_from_block, ts_path, fps)
332
-
333
- # Calculate final progress for this block
334
- total_progress = (idx + 1) / num_blocks * 100
335
-
336
- # Yield the actual video chunk
337
- yield ts_path, gr.update()
338
-
339
- except Exception as e:
340
- print(f"⚠️ Error encoding block {idx}: {e}")
341
- import traceback
342
- traceback.print_exc()
343
-
344
- current_start_frame += current_num_frames
345
-
346
- # Final completion status
347
- final_status_html = (
348
- f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
349
- f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
350
- f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
351
- f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
352
- f" </div>"
353
- f" <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
354
- f" <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
355
- f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
356
- f" </p>"
357
- f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
358
- f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
359
- f" </p>"
360
- f" </div>"
361
- f"</div>"
362
- )
363
- yield None, final_status_html
364
- print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
365
-
366
- # --- Gradio UI Layout ---
367
- with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
368
- gr.Markdown("# 🚀 Self-Forcing Video Generation")
369
- gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
370
-
371
- with gr.Row():
372
- with gr.Column(scale=2):
373
- with gr.Group():
374
- prompt = gr.Textbox(
375
- label="Prompt",
376
- placeholder="A stylish woman walks down a Tokyo street...",
377
- lines=4,
378
- value=""
379
- )
380
-
381
- start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
382
-
383
- gr.Markdown("### 🎯 Examples")
384
- gr.Examples(
385
- examples=[
386
- "A close-up shot of a ceramic teacup slowly pouring water into a glass mug.",
387
- "A playful cat is seen playing an electronic guitar, strumming the strings with its front paws. The cat has distinctive black facial markings and a bushy tail. It sits comfortably on a small stool, its body slightly tilted as it focuses intently on the instrument. The setting is a cozy, dimly lit room with vintage posters on the walls, adding a retro vibe. The cat's expressive eyes convey a sense of joy and concentration. Medium close-up shot, focusing on the cat's face and hands interacting with the guitar.",
388
- "A dynamic over-the-shoulder perspective of a chef meticulously plating a dish in a bustling kitchen. The chef, a middle-aged woman, deftly arranges ingredients on a pristine white plate. Her hands move with precision, each gesture deliberate and practiced. The background shows a crowded kitchen with steaming pots, whirring blenders, and the clatter of utensils. Bright lights highlight the scene, casting shadows across the busy workspace. The camera angle captures the chef's detailed work from behind, emphasizing his skill and dedication.",
389
- ],
390
- inputs=[prompt],
391
- )
392
-
393
- gr.Markdown("### ⚙️ Settings")
394
- with gr.Row():
395
- seed = gr.Number(
396
- label="Seed",
397
- value=-1,
398
- info="Use -1 for random seed",
399
- precision=0
400
- )
401
- fps = gr.Slider(
402
- label="Playback FPS",
403
- minimum=1,
404
- maximum=30,
405
- value=args.fps,
406
- step=1,
407
- visible=False,
408
- info="Frames per second for playback"
409
- )
410
-
411
- with gr.Column(scale=3):
412
- gr.Markdown("### 📺 Video Stream")
413
-
414
- streaming_video = gr.Video(
415
- label="Live Stream",
416
- streaming=True,
417
- loop=True,
418
- height=400,
419
- autoplay=True,
420
- show_label=False
421
- )
422
-
423
- status_display = gr.HTML(
424
- value=(
425
- "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
426
- "🎬 Ready to start streaming...<br>"
427
- "<small>Configure your prompt and click 'Start Streaming'</small>"
428
- "</div>"
429
- ),
430
- label="Generation Status"
431
- )
432
-
433
- # Connect the generator to the streaming video
434
- start_btn.click(
435
- fn=video_generation_handler_streaming,
436
- inputs=[prompt, seed, fps],
437
- outputs=[streaming_video, status_display]
438
- )
439
-
440
-
441
- # --- Launch App ---
442
- if __name__ == "__main__":
443
- if os.path.exists("gradio_tmp"):
444
- import shutil
445
- shutil.rmtree("gradio_tmp")
446
- os.makedirs("gradio_tmp", exist_ok=True)
447
-
448
- print("🚀 Starting Self-Forcing Streaming Demo")
449
- print(f"📁 Temporary files will be stored in: gradio_tmp/")
450
- print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
451
- print(f"⚡ GPU acceleration: {gpu}")
452
-
453
- demo.queue().launch(
454
- server_name=args.host,
455
- server_port=args.port,
456
- share=args.share,
457
- show_error=True,
458
- max_threads=40,
459
- mcp_server=True
460
- )
demo.py DELETED
@@ -1,631 +0,0 @@
1
- """
2
- Demo for Self-Forcing.
3
- """
4
-
5
- import os
6
- import re
7
- import random
8
- import time
9
- import base64
10
- import argparse
11
- import hashlib
12
- import subprocess
13
- import urllib.request
14
- from io import BytesIO
15
- from PIL import Image
16
- import numpy as np
17
- import torch
18
- from omegaconf import OmegaConf
19
- from flask import Flask, render_template, jsonify
20
- from flask_socketio import SocketIO, emit
21
- import queue
22
- from threading import Thread, Event
23
-
24
- from pipeline import CausalInferencePipeline
25
- from demo_utils.constant import ZERO_VAE_CACHE
26
- from demo_utils.vae_block3 import VAEDecoderWrapper
27
- from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
28
- from demo_utils.utils import generate_timestamp
29
- from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation
30
-
31
- # Parse arguments
32
- parser = argparse.ArgumentParser()
33
- parser.add_argument('--port', type=int, default=5001)
34
- parser.add_argument('--host', type=str, default='0.0.0.0')
35
- parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt')
36
- parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml')
37
- parser.add_argument('--trt', action='store_true')
38
- args = parser.parse_args()
39
-
40
- print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
41
- low_memory = get_cuda_free_memory_gb(gpu) < 40
42
-
43
- # Load models
44
- config = OmegaConf.load(args.config_path)
45
- default_config = OmegaConf.load("configs/default_config.yaml")
46
- config = OmegaConf.merge(default_config, config)
47
-
48
- text_encoder = WanTextEncoder()
49
-
50
- # Global variables for dynamic model switching
51
- current_vae_decoder = None
52
- current_use_taehv = False
53
- fp8_applied = False
54
- torch_compile_applied = False
55
- global frame_number
56
- frame_number = 0
57
- anim_name = ""
58
- frame_rate = 6
59
-
60
- def initialize_vae_decoder(use_taehv=False, use_trt=False):
61
- """Initialize VAE decoder based on the selected option"""
62
- global current_vae_decoder, current_use_taehv
63
-
64
- if use_trt:
65
- from demo_utils.vae import VAETRTWrapper
66
- current_vae_decoder = VAETRTWrapper()
67
- return current_vae_decoder
68
-
69
- if use_taehv:
70
- from demo_utils.taehv import TAEHV
71
- # Check if taew2_1.pth exists in checkpoints folder, download if missing
72
- taehv_checkpoint_path = "checkpoints/taew2_1.pth"
73
- if not os.path.exists(taehv_checkpoint_path):
74
- print(f"taew2_1.pth not found in checkpoints folder {taehv_checkpoint_path}. Downloading...")
75
- os.makedirs("checkpoints", exist_ok=True)
76
- download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
77
- try:
78
- urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
79
- print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
80
- except Exception as e:
81
- print(f"Failed to download taew2_1.pth: {e}")
82
- raise
83
-
84
- class DotDict(dict):
85
- __getattr__ = dict.__getitem__
86
- __setattr__ = dict.__setitem__
87
-
88
- class TAEHVDiffusersWrapper(torch.nn.Module):
89
- def __init__(self):
90
- super().__init__()
91
- self.dtype = torch.float16
92
- self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
93
- self.config = DotDict(scaling_factor=1.0)
94
-
95
- def decode(self, latents, return_dict=None):
96
- # n, c, t, h, w = latents.shape
97
- # low-memory, set parallel=True for faster + higher memory
98
- return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)
99
-
100
- current_vae_decoder = TAEHVDiffusersWrapper()
101
- else:
102
- current_vae_decoder = VAEDecoderWrapper()
103
- vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
104
- decoder_state_dict = {}
105
- for key, value in vae_state_dict.items():
106
- if 'decoder.' in key or 'conv2' in key:
107
- decoder_state_dict[key] = value
108
- current_vae_decoder.load_state_dict(decoder_state_dict)
109
-
110
- current_vae_decoder.eval()
111
- current_vae_decoder.to(dtype=torch.float16)
112
- current_vae_decoder.requires_grad_(False)
113
- current_vae_decoder.to(gpu)
114
- current_use_taehv = use_taehv
115
-
116
- print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
117
- return current_vae_decoder
118
-
119
-
120
- # Initialize with default VAE
121
- vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
122
-
123
- transformer = WanDiffusionWrapper(is_causal=True)
124
- state_dict = torch.load(args.checkpoint_path, map_location="cpu")
125
- transformer.load_state_dict(state_dict['generator_ema'])
126
-
127
- text_encoder.eval()
128
- transformer.eval()
129
-
130
- transformer.to(dtype=torch.float16)
131
- text_encoder.to(dtype=torch.bfloat16)
132
-
133
- text_encoder.requires_grad_(False)
134
- transformer.requires_grad_(False)
135
-
136
- pipeline = CausalInferencePipeline(
137
- config,
138
- device=gpu,
139
- generator=transformer,
140
- text_encoder=text_encoder,
141
- vae=vae_decoder
142
- )
143
-
144
- if low_memory:
145
- DynamicSwapInstaller.install_model(text_encoder, device=gpu)
146
- else:
147
- text_encoder.to(gpu)
148
- transformer.to(gpu)
149
-
150
- # Flask and SocketIO setup
151
- app = Flask(__name__)
152
- app.config['SECRET_KEY'] = 'frontend_buffered_demo'
153
- socketio = SocketIO(app, cors_allowed_origins="*")
154
-
155
- generation_active = False
156
- stop_event = Event()
157
- frame_send_queue = queue.Queue()
158
- sender_thread = None
159
- models_compiled = False
160
-
161
-
162
- def tensor_to_base64_frame(frame_tensor):
163
- """Convert a single frame tensor to base64 image string."""
164
- global frame_number, anim_name
165
- # Clamp and normalize to 0-255
166
- frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
167
- frame = frame.to(torch.uint8).cpu().numpy()
168
-
169
- # CHW -> HWC
170
- if len(frame.shape) == 3:
171
- frame = np.transpose(frame, (1, 2, 0))
172
-
173
- # Convert to PIL Image
174
- if frame.shape[2] == 3: # RGB
175
- image = Image.fromarray(frame, 'RGB')
176
- else: # Handle other formats
177
- image = Image.fromarray(frame)
178
-
179
- # Convert to base64
180
- buffer = BytesIO()
181
- image.save(buffer, format='JPEG', quality=100)
182
- if not os.path.exists("./images/%s" % anim_name):
183
- os.makedirs("./images/%s" % anim_name)
184
- frame_number += 1
185
- image.save("./images/%s/%s_%03d.jpg" % (anim_name, anim_name, frame_number))
186
- img_str = base64.b64encode(buffer.getvalue()).decode()
187
- return f"data:image/jpeg;base64,{img_str}"
188
-
189
-
190
- def frame_sender_worker():
191
- """Background thread that processes frame send queue non-blocking."""
192
- global frame_send_queue, generation_active, stop_event
193
-
194
- print("📡 Frame sender thread started")
195
-
196
- while True:
197
- frame_data = None
198
- try:
199
- # Get frame data from queue
200
- frame_data = frame_send_queue.get(timeout=1.0)
201
-
202
- if frame_data is None: # Shutdown signal
203
- frame_send_queue.task_done() # Mark shutdown signal as done
204
- break
205
-
206
- frame_tensor, frame_index, block_index, job_id = frame_data
207
-
208
- # Convert tensor to base64
209
- base64_frame = tensor_to_base64_frame(frame_tensor)
210
-
211
- # Send via SocketIO
212
- try:
213
- socketio.emit('frame_ready', {
214
- 'data': base64_frame,
215
- 'frame_index': frame_index,
216
- 'block_index': block_index,
217
- 'job_id': job_id
218
- })
219
- except Exception as e:
220
- print(f"⚠️ Failed to send frame {frame_index}: {e}")
221
-
222
- frame_send_queue.task_done()
223
-
224
- except queue.Empty:
225
- # Check if we should continue running
226
- if not generation_active and frame_send_queue.empty():
227
- break
228
- except Exception as e:
229
- print(f"❌ Frame sender error: {e}")
230
- # Make sure to mark task as done even if there's an error
231
- if frame_data is not None:
232
- try:
233
- frame_send_queue.task_done()
234
- except Exception as e:
235
- print(f"❌ Failed to mark frame task as done: {e}")
236
- break
237
-
238
- print("📡 Frame sender thread stopped")
239
-
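[Aside, not part of the diff] frame_sender_worker pushes every decoded frame to the browser as a 'frame_ready' SocketIO event whose payload carries a base64 JPEG data URL plus frame/block indices and the job id. A minimal Python consumer for the same stream (assumes the python-socketio client package and the server's default port 5001):

# Minimal consumer for the 'frame_ready' events emitted by the worker above.
# Requires:  pip install "python-socketio[client]"
import base64
import socketio

sio = socketio.Client()

@sio.on("frame_ready")
def on_frame(msg):
    _, b64 = msg["data"].split(",", 1)        # strip the "data:image/jpeg;base64," prefix
    jpeg_bytes = base64.b64decode(b64)
    print(f"frame {msg['frame_index']} (block {msg['block_index']}): {len(jpeg_bytes)} bytes")

sio.connect("http://localhost:5001")
sio.wait()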
240
-
241
- @torch.no_grad()
242
- def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
243
- """Generate video and push frames immediately to frontend."""
244
- global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name
245
-
246
- try:
247
- generation_active = True
248
- stop_event.clear()
249
- job_id = generate_timestamp()
250
-
251
- # Start frame sender thread if not already running
252
- if sender_thread is None or not sender_thread.is_alive():
253
- sender_thread = Thread(target=frame_sender_worker, daemon=True)
254
- sender_thread.start()
255
-
256
- # Emit progress updates
257
- def emit_progress(message, progress):
258
- try:
259
- socketio.emit('progress', {
260
- 'message': message,
261
- 'progress': progress,
262
- 'job_id': job_id
263
- })
264
- except Exception as e:
265
- print(f"❌ Failed to emit progress: {e}")
266
-
267
- emit_progress('Starting generation...', 0)
268
-
269
- # Handle VAE decoder switching
270
- if use_taehv != current_use_taehv:
271
- emit_progress('Switching VAE decoder...', 2)
272
- print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
273
- current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
274
- # Update pipeline with new VAE decoder
275
- pipeline.vae = current_vae_decoder
276
-
277
- # Handle FP8 quantization
278
- if enable_fp8 and not fp8_applied:
279
- emit_progress('Applying FP8 quantization...', 3)
280
- print("🔧 Applying FP8 quantization to transformer")
281
- from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
282
- quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
283
- fp8_applied = True
284
-
285
- # Text encoding
286
- emit_progress('Encoding text prompt...', 8)
287
- conditional_dict = text_encoder(text_prompts=[prompt])
288
- for key, value in conditional_dict.items():
289
- conditional_dict[key] = value.to(dtype=torch.float16)
290
- if low_memory:
291
- gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
292
- move_model_to_device_with_memory_preservation(
293
- text_encoder,target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
294
-
295
- # Handle torch.compile if enabled
296
- torch_compile_applied = enable_torch_compile
297
- if enable_torch_compile and not models_compiled:
298
- # Compile transformer and decoder
299
- transformer.compile(mode="max-autotune-no-cudagraphs")
300
-        if not current_use_taehv and not low_memory and not args.trt:
-            current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")
-
-        # Initialize generation
-        emit_progress('Initializing generation...', 12)
-
-        rnd = torch.Generator(gpu).manual_seed(seed)
-        # all_latents = torch.zeros([1, 21, 16, 60, 104], device=gpu, dtype=torch.bfloat16)
-
-        pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
-        pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)
-
-        noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
-
-        # Generation parameters
-        num_blocks = 7
-        current_start_frame = 0
-        num_input_frames = 0
-        all_num_frames = [pipeline.num_frame_per_block] * num_blocks
-        if current_use_taehv:
-            vae_cache = None
-        else:
-            vae_cache = ZERO_VAE_CACHE
-            for i in range(len(vae_cache)):
-                vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)
-
-        total_frames_sent = 0
-        generation_start_time = time.time()
-
-        emit_progress('Generating frames... (frontend handles timing)', 15)
-
-        for idx, current_num_frames in enumerate(all_num_frames):
-            if not generation_active or stop_event.is_set():
-                break
-
-            progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15
-
-            # Special message for first block with torch.compile
-            if idx == 0 and torch_compile_applied and not models_compiled:
-                emit_progress(
-                    f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
-                print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
-                models_compiled = True
-            else:
-                emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
-                print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")
-
-            block_start_time = time.time()
-
-            noisy_input = noise[:, current_start_frame -
-                                num_input_frames:current_start_frame + current_num_frames - num_input_frames]
-
-            # Denoising loop
-            denoising_start = time.time()
-            for index, current_timestep in enumerate(pipeline.denoising_step_list):
-                if not generation_active or stop_event.is_set():
-                    break
-
-                timestep = torch.ones([1, current_num_frames], device=noise.device,
-                                      dtype=torch.int64) * current_timestep
-
-                if index < len(pipeline.denoising_step_list) - 1:
-                    _, denoised_pred = transformer(
-                        noisy_image_or_video=noisy_input,
-                        conditional_dict=conditional_dict,
-                        timestep=timestep,
-                        kv_cache=pipeline.kv_cache1,
-                        crossattn_cache=pipeline.crossattn_cache,
-                        current_start=current_start_frame * pipeline.frame_seq_length
-                    )
-                    next_timestep = pipeline.denoising_step_list[index + 1]
-                    noisy_input = pipeline.scheduler.add_noise(
-                        denoised_pred.flatten(0, 1),
-                        torch.randn_like(denoised_pred.flatten(0, 1)),
-                        next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
-                    ).unflatten(0, denoised_pred.shape[:2])
-                else:
-                    _, denoised_pred = transformer(
-                        noisy_image_or_video=noisy_input,
-                        conditional_dict=conditional_dict,
-                        timestep=timestep,
-                        kv_cache=pipeline.kv_cache1,
-                        crossattn_cache=pipeline.crossattn_cache,
-                        current_start=current_start_frame * pipeline.frame_seq_length
-                    )
-
-            if not generation_active or stop_event.is_set():
-                break
-
-            denoising_time = time.time() - denoising_start
-            print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")
-
-            # Record output
-            # all_latents[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
-
-            # Update KV cache for next block
-            if idx != len(all_num_frames) - 1:
-                transformer(
-                    noisy_image_or_video=denoised_pred,
-                    conditional_dict=conditional_dict,
-                    timestep=torch.zeros_like(timestep),
-                    kv_cache=pipeline.kv_cache1,
-                    crossattn_cache=pipeline.crossattn_cache,
-                    current_start=current_start_frame * pipeline.frame_seq_length,
-                )
-
-            # Decode to pixels and send frames immediately
-            print(f"🎨 Decoding block {idx+1} to pixels...")
-            decode_start = time.time()
-            if args.trt:
-                all_current_pixels = []
-                for i in range(denoised_pred.shape[1]):
-                    is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
-                        torch.tensor(0.0).cuda().half()
-                    outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
-                    # outputs = vae_decoder.forward(denoised_pred.float(), *vae_cache)
-                    current_pixels, vae_cache = outputs[0], outputs[1:]
-                    print(current_pixels.max(), current_pixels.min())
-                    all_current_pixels.append(current_pixels.clone())
-                pixels = torch.cat(all_current_pixels, dim=1)
-                if idx == 0:
-                    pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block
-            else:
-                if current_use_taehv:
-                    if vae_cache is None:
-                        vae_cache = denoised_pred
-                    else:
-                        denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
-                    vae_cache = denoised_pred[:, -3:, :, :, :]
-                    pixels = current_vae_decoder.decode(denoised_pred)
-                    print(f"denoised_pred shape: {denoised_pred.shape}")
-                    print(f"pixels shape: {pixels.shape}")
-                    if idx == 0:
-                        pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block
-                    else:
-                        pixels = pixels[:, 12:, :, :, :]
-
-                else:
-                    pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
-                    if idx == 0:
-                        pixels = pixels[:, 3:, :, :, :]  # Skip first 3 frames of first block
-
-            decode_time = time.time() - decode_start
-            print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")
-
-            # Queue frames for non-blocking sending
-            block_frames = pixels.shape[1]
-            print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
-            queue_start = time.time()
-
-            for frame_idx in range(block_frames):
-                if not generation_active or stop_event.is_set():
-                    break
-
-                frame_tensor = pixels[0, frame_idx].cpu()
-
-                # Queue frame data in non-blocking way
-                frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
-                total_frames_sent += 1
-
-            queue_time = time.time() - queue_start
-            block_time = time.time() - block_start_time
-            print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")
-
-            current_start_frame += current_num_frames
-
-        generation_time = time.time() - generation_start_time
-        print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")
-
-        # Wait for all frames to be sent before completing
-        emit_progress('Waiting for all frames to be sent...', 97)
-        print("⏳ Waiting for all frames to be sent...")
-        frame_send_queue.join()  # Wait for all queued frames to be processed
-        print("✅ All frames sent successfully!")
-
-        generate_mp4_from_images("./images", "./videos/" + anim_name + ".mp4", frame_rate)
-        # Final progress update
-        emit_progress('Generation complete!', 100)
-
-        try:
-            socketio.emit('generation_complete', {
-                'message': 'Video generation completed!',
-                'total_frames': total_frames_sent,
-                'generation_time': f"{generation_time:.2f}s",
-                'job_id': job_id
-            })
-        except Exception as e:
-            print(f"❌ Failed to emit generation complete: {e}")
-
-    except Exception as e:
-        print(f"❌ Generation failed: {e}")
-        try:
-            socketio.emit('error', {
-                'message': f'Generation failed: {str(e)}',
-                'job_id': job_id
-            })
-        except Exception as e:
-            print(f"❌ Failed to emit error: {e}")
-    finally:
-        generation_active = False
-        stop_event.set()
-
-        # Clean up sender thread
-        try:
-            frame_send_queue.put(None)
-        except Exception as e:
-            print(f"❌ Failed to put None in frame_send_queue: {e}")
-
-
-def generate_mp4_from_images(image_directory, output_video_path, fps=24):
-    """
-    Generate an MP4 video from a directory of images ordered alphabetically.
-
-    :param image_directory: Path to the directory containing images.
-    :param output_video_path: Path where the output MP4 will be saved.
-    :param fps: Frames per second for the output video.
-    """
-    global anim_name
-    # Construct the ffmpeg command
-    cmd = [
-        'ffmpeg',
-        '-framerate', str(fps),
-        '-i', os.path.join(image_directory, anim_name + '/' + anim_name + '_%03d.jpg'),  # Adjust the pattern if necessary
-        '-c:v', 'libx264',
-        '-pix_fmt', 'yuv420p',
-        output_video_path
-    ]
-    try:
-        subprocess.run(cmd, check=True)
-        print(f"Video saved to {output_video_path}")
-    except subprocess.CalledProcessError as e:
-        print(f"An error occurred: {e}")
-
-def calculate_sha256(data):
-    # Convert data to bytes if it's not already
-    if isinstance(data, str):
-        data = data.encode()
-    # Calculate SHA-256 hash
-    sha256_hash = hashlib.sha256(data).hexdigest()
-    return sha256_hash
-
-# Socket.IO event handlers
-@socketio.on('connect')
-def handle_connect():
-    print('Client connected')
-    emit('status', {'message': 'Connected to frontend-buffered demo server'})
-
-
-@socketio.on('disconnect')
-def handle_disconnect():
-    print('Client disconnected')
-
-
-@socketio.on('start_generation')
-def handle_start_generation(data):
-    global generation_active, frame_number, anim_name, frame_rate
-
-    frame_number = 0
-    if generation_active:
-        emit('error', {'message': 'Generation already in progress'})
-        return
-
-    prompt = data.get('prompt', '')
-
-    seed = data.get('seed', -1)
-    if seed == -1:
-        seed = random.randint(0, 2**32)
-
-    # Extract words up to the first punctuation or newline
-    words_up_to_punctuation = re.split(r'[^\w\s]', prompt)[0].strip() if prompt else ''
-    if not words_up_to_punctuation:
-        words_up_to_punctuation = re.split(r'[\n\r]', prompt)[0].strip()
-
-    # Calculate SHA-256 hash of the entire prompt
-    sha256_hash = calculate_sha256(prompt)
-
-    # Create anim_name with the extracted words and first 10 characters of the hash
-    anim_name = f"{words_up_to_punctuation[:20]}_{str(seed)}_{sha256_hash[:10]}"
-
-    generation_active = True
-    generation_start_time = time.time()
-    enable_torch_compile = data.get('enable_torch_compile', False)
-    enable_fp8 = data.get('enable_fp8', False)
-    use_taehv = data.get('use_taehv', False)
-    frame_rate = data.get('fps', 6)
-
-    if not prompt:
-        emit('error', {'message': 'Prompt is required'})
-        return
-
-    # Start generation in background thread
-    socketio.start_background_task(generate_video_stream, prompt, seed,
-                                   enable_torch_compile, enable_fp8, use_taehv)
-    emit('status', {'message': 'Generation started - frames will be sent immediately'})
-
-
-@socketio.on('stop_generation')
-def handle_stop_generation():
-    global generation_active, stop_event, frame_send_queue
-    generation_active = False
-    stop_event.set()
-
-    # Signal sender thread to stop (will be processed after current frames)
-    try:
-        frame_send_queue.put(None)
-    except Exception as e:
-        print(f"❌ Failed to put None in frame_send_queue: {e}")
-
-    emit('status', {'message': 'Generation stopped'})
-
-# Web routes
-
-
-@app.route('/')
-def index():
-    return render_template('demo.html')
-
-
-@app.route('/api/status')
-def api_status():
-    return jsonify({
-        'generation_active': generation_active,
-        'free_vram_gb': get_cuda_free_memory_gb(gpu),
-        'fp8_applied': fp8_applied,
-        'torch_compile_applied': torch_compile_applied,
-        'current_use_taehv': current_use_taehv
-    })
-
-
-if __name__ == '__main__':
-    print(f"🚀 Starting demo on http://{args.host}:{args.port}")
-    socketio.run(app, host=args.host, port=args.port, debug=False)
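
The removed demo.py streamed frames to the browser through a producer/consumer queue: the generation loop enqueued decoded frames, a sender thread drained frame_send_queue, a None sentinel told the sender to stop, and frame_send_queue.join() blocked until every queued frame had been sent. Below is a minimal, self-contained sketch of that pattern only; encode_frame and send_to_client are illustrative placeholders, not functions from this repository.

# Sketch of the queue-based frame streaming used by the removed demo.py.
# A worker thread drains the queue, a None sentinel stops it, and join()
# waits until every queued frame has been handled.
import queue
import threading

frame_send_queue = queue.Queue()

def encode_frame(tensor):
    # Placeholder: the real demo encoded the frame (e.g. to JPEG) before sending.
    return b""

def send_to_client(payload, index):
    # Placeholder: the real demo emitted the frame over Socket.IO.
    print(f"sent frame {index} ({len(payload)} bytes)")

def sender_worker():
    while True:
        item = frame_send_queue.get()
        if item is None:                 # sentinel -> shut the worker down
            frame_send_queue.task_done()
            break
        frame_tensor, frame_index = item
        send_to_client(encode_frame(frame_tensor), frame_index)
        frame_send_queue.task_done()

threading.Thread(target=sender_worker, daemon=True).start()
for i in range(3):
    frame_send_queue.put((object(), i))  # producer side: queue frames as they decode
frame_send_queue.join()                  # wait for all queued frames to be processed
frame_send_queue.put(None)               # then signal the worker to exit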
 
utils/wan_wrapper.py CHANGED
@@ -1,8 +1,14 @@
 import types
 from typing import List, Optional
+import os
 import torch
 from torch import nn
 
+# Configuration for data paths
+DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
+WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
+OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')
+
 from utils.scheduler import SchedulerInterface, FlowMatchScheduler
 from wan.modules.tokenizers import HuggingfaceTokenizer
 from wan.modules.model import WanModel, RegisterTokens, GanAttentionBlock
@@ -22,12 +28,12 @@ class WanTextEncoder(torch.nn.Module):
             device=torch.device('cpu')
         ).eval().requires_grad_(False)
         self.text_encoder.load_state_dict(
-            torch.load("wan_models/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
+            torch.load(os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B", "models_t5_umt5-xxl-enc-bf16.pth"),
                        map_location='cpu', weights_only=False)
         )
 
         self.tokenizer = HuggingfaceTokenizer(
-            name="wan_models/Wan2.1-T2V-1.3B/google/umt5-xxl/", seq_len=512, clean='whitespace')
+            name=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B", "google", "umt5-xxl") + "/", seq_len=512, clean='whitespace')
 
     @property
     def device(self):
@@ -66,7 +72,7 @@ class WanVAEWrapper(torch.nn.Module):
 
         # init model
         self.model = _video_vae(
-            pretrained_path="wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
+            pretrained_path=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B", "Wan2.1_VAE.pth"),
             z_dim=16,
         ).eval().requires_grad_(False)
 
@@ -125,9 +131,9 @@ class WanDiffusionWrapper(torch.nn.Module):
 
         if is_causal:
             self.model = CausalWanModel.from_pretrained(
-                f"wan_models/{model_name}/", local_attn_size=local_attn_size, sink_size=sink_size)
+                os.path.join(WAN_MODELS_PATH, model_name) + "/", local_attn_size=local_attn_size, sink_size=sink_size)
         else:
-            self.model = WanModel.from_pretrained(f"wan_models/{model_name}/")
+            self.model = WanModel.from_pretrained(os.path.join(WAN_MODELS_PATH, model_name) + "/")
         self.model.eval()
 
         # For non-causal diffusion, all frames share the same timestep
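
The utils/wan_wrapper.py change resolves every model path from a DATA_ROOT environment variable (defaulting to the current directory) instead of the hard-coded wan_models/ prefix. A quick sketch of how the resolved paths behave follows; pointing DATA_ROOT at a persistent volume such as /data is an assumption about deployment, not something this commit requires.

# Sketch of the DATA_ROOT-based path resolution introduced in utils/wan_wrapper.py.
# The "/data" value is a hypothetical persistent-storage mount, not part of the commit.
import os

os.environ.setdefault("DATA_ROOT", "/data")
DATA_ROOT = os.path.normpath(os.getenv("DATA_ROOT", "."))
WAN_MODELS_PATH = os.path.join(DATA_ROOT, "wan_models")

print(WAN_MODELS_PATH)
# -> /data/wan_models
print(os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B", "Wan2.1_VAE.pth"))
# -> /data/wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth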