Commit
·
b55bb25
1
Parent(s):
f5f96d3
ok
Browse files- app.py +242 -159
- app_last_working.py +0 -460
- demo.py +0 -631
- utils/wan_wrapper.py +11 -5
app.py
CHANGED
@@ -2,11 +2,17 @@ import subprocess
|
|
2 |
# not sure why it works in the original space but says "pip not found" in mine
|
3 |
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
4 |
|
|
|
5 |
from huggingface_hub import snapshot_download, hf_hub_download
|
6 |
|
|
|
|
|
|
|
|
|
|
|
7 |
snapshot_download(
|
8 |
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
|
9 |
-
local_dir="
|
10 |
local_dir_use_symlinks=False,
|
11 |
resume_download=True,
|
12 |
repo_type="model"
|
@@ -15,11 +21,9 @@ snapshot_download(
|
|
15 |
hf_hub_download(
|
16 |
repo_id="gdhe17/Self-Forcing",
|
17 |
filename="checkpoints/self_forcing_dmd.pt",
|
18 |
-
local_dir=
|
19 |
local_dir_use_symlinks=False
|
20 |
)
|
21 |
-
|
22 |
-
import os
|
23 |
import re
|
24 |
import random
|
25 |
import argparse
|
@@ -34,6 +38,10 @@ from tqdm import tqdm
|
|
34 |
import imageio
|
35 |
import av
|
36 |
import uuid
|
|
|
|
|
|
|
|
|
37 |
|
38 |
from pipeline import CausalInferencePipeline
|
39 |
from demo_utils.constant import ZERO_VAE_CACHE
|
@@ -45,11 +53,25 @@ import numpy as np
|
|
45 |
|
46 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
# --- Argument Parsing ---
|
49 |
parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
|
50 |
parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
|
51 |
parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
|
52 |
-
parser.add_argument("--checkpoint_path", type=str, default='
|
53 |
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
|
54 |
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
|
55 |
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
|
@@ -107,6 +129,89 @@ if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
|
|
107 |
APP_STATE["torch_compile_applied"] = True
|
108 |
print("✅ torch.compile applied to transformer")
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
111 |
"""
|
112 |
Convert frames directly to .ts file using PyAV.
|
@@ -193,7 +298,7 @@ def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
|
193 |
print("Initializing Default VAE Decoder...")
|
194 |
vae_decoder = VAEDecoderWrapper()
|
195 |
try:
|
196 |
-
vae_state_dict = torch.load('
|
197 |
decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
|
198 |
vae_decoder.load_state_dict(decoder_state_dict)
|
199 |
except FileNotFoundError:
|
@@ -222,26 +327,22 @@ pipeline = CausalInferencePipeline(
|
|
222 |
pipeline.to(dtype=torch.float16).to(gpu)
|
223 |
|
224 |
@torch.no_grad()
|
225 |
-
def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5,
|
226 |
"""
|
227 |
Generator function that yields .ts video chunks using PyAV for streaming.
|
228 |
-
Now optimized for block-based processing with smart buffering.
|
229 |
"""
|
230 |
if seed == -1:
|
231 |
seed = random.randint(0, 2**32 - 1)
|
232 |
|
233 |
-
print(f"🎬 Starting PyAV streaming:
|
234 |
-
|
235 |
-
#
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
)
|
243 |
-
yield None, buffering_status_html
|
244 |
-
|
245 |
# Setup
|
246 |
conditional_dict = text_encoder(text_prompts=[prompt])
|
247 |
for key, value in conditional_dict.items():
|
@@ -260,7 +361,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
|
|
260 |
# Current setup generates approximately 5 seconds with 7 blocks
|
261 |
# So we scale proportionally
|
262 |
base_duration = 5.0 # seconds
|
263 |
-
base_blocks =
|
264 |
num_blocks = max(1, int(base_blocks * duration / base_duration))
|
265 |
|
266 |
current_start_frame = 0
|
@@ -270,13 +371,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
|
|
270 |
|
271 |
# Ensure temp directory exists
|
272 |
os.makedirs("gradio_tmp", exist_ok=True)
|
273 |
-
|
274 |
-
# Buffer management - collect chunks before streaming
|
275 |
-
buffer_chunks = []
|
276 |
-
buffer_duration = 0.0
|
277 |
-
frames_per_second = fps
|
278 |
-
streaming_started = False
|
279 |
-
|
280 |
# Generation loop
|
281 |
for idx, current_num_frames in enumerate(all_num_frames):
|
282 |
print(f"📦 Processing block {idx+1}/{num_blocks}")
|
@@ -375,45 +470,11 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
|
|
375 |
|
376 |
frames_to_ts_file(all_frames_from_block, ts_path, fps)
|
377 |
|
378 |
-
# Calculate
|
379 |
-
|
380 |
|
381 |
-
#
|
382 |
-
|
383 |
-
buffer_duration += chunk_duration
|
384 |
-
|
385 |
-
# Check if we have enough buffered content to start streaming
|
386 |
-
if not streaming_started and buffer_duration >= buffering:
|
387 |
-
print(f"🚀 Buffer filled ({buffer_duration:.2f}s >= {buffering}s), starting stream!")
|
388 |
-
streaming_started = True
|
389 |
-
|
390 |
-
# Stream all buffered chunks
|
391 |
-
for buffered_chunk in buffer_chunks:
|
392 |
-
yield buffered_chunk, gr.update()
|
393 |
-
|
394 |
-
# Clear buffer since we've streamed it
|
395 |
-
buffer_chunks.clear()
|
396 |
-
buffer_duration = 0.0
|
397 |
-
|
398 |
-
elif streaming_started:
|
399 |
-
# Stream immediately if we're already streaming
|
400 |
-
yield ts_path, gr.update()
|
401 |
-
elif buffering == 0:
|
402 |
-
# No buffering requested, stream immediately
|
403 |
-
yield ts_path, gr.update()
|
404 |
-
else:
|
405 |
-
# Still buffering, show progress
|
406 |
-
buffer_progress = (buffer_duration / buffering) * 100
|
407 |
-
buffering_progress_html = (
|
408 |
-
f"<div style='padding: 10px; border: 1px solid #ffc107; background: #fff3cd; border-radius: 8px; font-family: sans-serif;'>"
|
409 |
-
f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>⏳ Buffering... ({buffer_duration:.1f}s/{buffering}s)</p>"
|
410 |
-
f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
|
411 |
-
f" <div style='width: {buffer_progress:.1f}%; height: 20px; background-color: #ffc107; transition: width 0.2s;'></div>"
|
412 |
-
f" </div>"
|
413 |
-
f" <p style='margin: 4px 0 0 0; color: #856404; font-size: 14px;'>Generating content for smooth playback...</p>"
|
414 |
-
f"</div>"
|
415 |
-
)
|
416 |
-
yield None, buffering_progress_html
|
417 |
|
418 |
except Exception as e:
|
419 |
print(f"⚠️ Error encoding block {idx}: {e}")
|
@@ -422,12 +483,6 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
|
|
422 |
|
423 |
current_start_frame += current_num_frames
|
424 |
|
425 |
-
# Stream any remaining buffered content
|
426 |
-
if buffer_chunks:
|
427 |
-
print(f"🎬 Streaming remaining {len(buffer_chunks)} buffered chunks")
|
428 |
-
for buffered_chunk in buffer_chunks:
|
429 |
-
yield buffered_chunk, gr.update()
|
430 |
-
|
431 |
# Final completion status
|
432 |
final_status_html = (
|
433 |
f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
|
@@ -449,104 +504,132 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
|
|
449 |
print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
|
450 |
|
451 |
# --- Gradio UI Layout ---
|
452 |
-
with gr.Blocks(title="Self-Forcing
|
453 |
-
gr.Markdown("# 🚀
|
454 |
-
gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
|
455 |
|
456 |
-
with gr.
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
placeholder="A stylish woman walks down a Tokyo street...",
|
462 |
-
lines=4,
|
463 |
-
value=""
|
464 |
-
)
|
465 |
-
|
466 |
-
start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
|
467 |
|
468 |
-
gr.Markdown("### ⚙️ Settings")
|
469 |
with gr.Row():
|
470 |
-
|
471 |
-
|
472 |
-
value=-1,
|
473 |
-
info="Use -1 for random seed",
|
474 |
-
precision=0
|
475 |
-
)
|
476 |
-
fps = gr.Slider(
|
477 |
-
label="Playback FPS",
|
478 |
-
minimum=1,
|
479 |
-
maximum=30,
|
480 |
-
value=args.fps,
|
481 |
-
step=1,
|
482 |
-
visible=False,
|
483 |
-
info="Frames per second for playback"
|
484 |
-
)
|
485 |
-
|
486 |
with gr.Row():
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
value=5,
|
492 |
-
step=1,
|
493 |
-
info="Video duration in seconds"
|
494 |
-
)
|
495 |
-
buffering = gr.Slider(
|
496 |
-
label="Buffering (seconds)",
|
497 |
-
minimum=0,
|
498 |
-
maximum=5,
|
499 |
-
value=2,
|
500 |
-
step=0.5,
|
501 |
-
info="Wait time before starting stream"
|
502 |
-
)
|
503 |
-
|
504 |
with gr.Row():
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
|
536 |
-
|
537 |
-
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
543 |
|
544 |
# Connect the generator to the streaming video
|
545 |
start_btn.click(
|
546 |
fn=video_generation_handler_streaming,
|
547 |
-
inputs=[prompt, seed, fps, width, height, duration,
|
548 |
outputs=[streaming_video, status_display]
|
549 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
550 |
|
551 |
|
552 |
# --- Launch App ---
|
|
|
2 |
# not sure why it works in the original space but says "pip not found" in mine
|
3 |
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
4 |
|
5 |
+
import os
|
6 |
from huggingface_hub import snapshot_download, hf_hub_download
|
7 |
|
8 |
+
# Configuration for data paths
|
9 |
+
DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
|
10 |
+
WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
|
11 |
+
OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')
|
12 |
+
|
13 |
snapshot_download(
|
14 |
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
|
15 |
+
local_dir=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B"),
|
16 |
local_dir_use_symlinks=False,
|
17 |
resume_download=True,
|
18 |
repo_type="model"
|
|
|
21 |
hf_hub_download(
|
22 |
repo_id="gdhe17/Self-Forcing",
|
23 |
filename="checkpoints/self_forcing_dmd.pt",
|
24 |
+
local_dir=OTHER_MODELS_PATH,
|
25 |
local_dir_use_symlinks=False
|
26 |
)
|
|
|
|
|
27 |
import re
|
28 |
import random
|
29 |
import argparse
|
|
|
38 |
import imageio
|
39 |
import av
|
40 |
import uuid
|
41 |
+
import tempfile
|
42 |
+
import shutil
|
43 |
+
from pathlib import Path
|
44 |
+
from typing import Dict, Any, List, Optional, Tuple, Union
|
45 |
|
46 |
from pipeline import CausalInferencePipeline
|
47 |
from demo_utils.constant import ZERO_VAE_CACHE
|
|
|
53 |
|
54 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
55 |
|
56 |
+
# LoRA Storage Configuration
|
57 |
+
STORAGE_PATH = Path(DATA_ROOT) / "storage"
|
58 |
+
LORA_PATH = STORAGE_PATH / "loras"
|
59 |
+
OUTPUT_PATH = STORAGE_PATH / "output"
|
60 |
+
|
61 |
+
# Create necessary directories
|
62 |
+
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
|
63 |
+
LORA_PATH.mkdir(parents=True, exist_ok=True)
|
64 |
+
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
|
65 |
+
|
66 |
+
# Global variables for LoRA management
|
67 |
+
current_lora_id = None
|
68 |
+
current_lora_path = None
|
69 |
+
|
70 |
# --- Argument Parsing ---
|
71 |
parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
|
72 |
parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
|
73 |
parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
|
74 |
+
parser.add_argument("--checkpoint_path", type=str, default=os.path.join(OTHER_MODELS_PATH, 'checkpoints', 'self_forcing_dmd.pt'), help="Path to the model checkpoint.")
|
75 |
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
|
76 |
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
|
77 |
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
|
|
|
129 |
APP_STATE["torch_compile_applied"] = True
|
130 |
print("✅ torch.compile applied to transformer")
|
131 |
|
132 |
+
def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
|
133 |
+
"""Upload a LoRA file and return a hash-based ID for future reference"""
|
134 |
+
if file is None:
|
135 |
+
return "", ""
|
136 |
+
|
137 |
+
try:
|
138 |
+
# Calculate SHA256 hash of the file
|
139 |
+
sha256_hash = hashlib.sha256()
|
140 |
+
with open(file.name, "rb") as f:
|
141 |
+
for chunk in iter(lambda: f.read(4096), b""):
|
142 |
+
sha256_hash.update(chunk)
|
143 |
+
file_hash = sha256_hash.hexdigest()
|
144 |
+
|
145 |
+
# Create destination path using hash
|
146 |
+
dest_path = LORA_PATH / f"{file_hash}.safetensors"
|
147 |
+
|
148 |
+
# Check if file already exists
|
149 |
+
if dest_path.exists():
|
150 |
+
print(f"LoRA file already exists!")
|
151 |
+
return file_hash, file_hash
|
152 |
+
|
153 |
+
# Copy the file to the destination
|
154 |
+
shutil.copy(file.name, dest_path)
|
155 |
+
|
156 |
+
print(f"LoRA file uploaded!")
|
157 |
+
return file_hash, file_hash
|
158 |
+
except Exception as e:
|
159 |
+
print(f"Error uploading LoRA file: {e}")
|
160 |
+
raise gr.Error(f"Failed to upload LoRA file: {str(e)}")
|
161 |
+
|
162 |
+
def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
|
163 |
+
"""Get the path to a LoRA file from its hash-based ID"""
|
164 |
+
if not lora_id:
|
165 |
+
return None
|
166 |
+
|
167 |
+
# Check if file exists
|
168 |
+
lora_path = LORA_PATH / f"{lora_id}.safetensors"
|
169 |
+
if lora_path.exists():
|
170 |
+
return lora_path
|
171 |
+
|
172 |
+
return None
|
173 |
+
|
174 |
+
def manage_lora_weights(lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
|
175 |
+
"""Manage LoRA weights for the transformer model"""
|
176 |
+
global current_lora_id, current_lora_path
|
177 |
+
|
178 |
+
# Determine if we should use LoRA
|
179 |
+
using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0
|
180 |
+
|
181 |
+
# If not using LoRA but we have one loaded, clear it
|
182 |
+
if not using_lora and current_lora_id is not None:
|
183 |
+
print(f"Clearing current LoRA")
|
184 |
+
current_lora_id = None
|
185 |
+
current_lora_path = None
|
186 |
+
return False, None
|
187 |
+
|
188 |
+
# If using LoRA, check if we need to change weights
|
189 |
+
if using_lora:
|
190 |
+
lora_path = get_lora_file_path(lora_id)
|
191 |
+
|
192 |
+
if not lora_path:
|
193 |
+
print(f"A LoRA file with this ID was found. Using base model instead.")
|
194 |
+
|
195 |
+
# If we had a LoRA loaded, clear it
|
196 |
+
if current_lora_id is not None:
|
197 |
+
print(f"Clearing current LoRA")
|
198 |
+
current_lora_id = None
|
199 |
+
current_lora_path = None
|
200 |
+
|
201 |
+
return False, None
|
202 |
+
|
203 |
+
# If LoRA ID changed, update
|
204 |
+
if lora_id != current_lora_id:
|
205 |
+
print(f"Loading LoRA..")
|
206 |
+
current_lora_id = lora_id
|
207 |
+
current_lora_path = lora_path
|
208 |
+
else:
|
209 |
+
print(f"Using a LoRA!")
|
210 |
+
|
211 |
+
return True, lora_path
|
212 |
+
|
213 |
+
return False, None
|
214 |
+
|
215 |
def frames_to_ts_file(frames, filepath, fps = 15):
|
216 |
"""
|
217 |
Convert frames directly to .ts file using PyAV.
|
|
|
298 |
print("Initializing Default VAE Decoder...")
|
299 |
vae_decoder = VAEDecoderWrapper()
|
300 |
try:
|
301 |
+
vae_state_dict = torch.load(os.path.join(WAN_MODELS_PATH, 'Wan2.1-T2V-1.3B', 'Wan2.1_VAE.pth'), map_location="cpu")
|
302 |
decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
|
303 |
vae_decoder.load_state_dict(decoder_state_dict)
|
304 |
except FileNotFoundError:
|
|
|
327 |
pipeline.to(dtype=torch.float16).to(gpu)
|
328 |
|
329 |
@torch.no_grad()
|
330 |
+
def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5, lora_id=None, lora_weight=0.0):
|
331 |
"""
|
332 |
Generator function that yields .ts video chunks using PyAV for streaming.
|
|
|
333 |
"""
|
334 |
if seed == -1:
|
335 |
seed = random.randint(0, 2**32 - 1)
|
336 |
|
337 |
+
# print(f"🎬 Starting PyAV streaming: seed: {seed}, duration: {duration}s")
|
338 |
+
|
339 |
+
# Handle LoRA weights
|
340 |
+
using_lora, lora_path = manage_lora_weights(lora_id, lora_weight)
|
341 |
+
if using_lora:
|
342 |
+
print(f"🎨 Using LoRA with weight factor {lora_weight}")
|
343 |
+
else:
|
344 |
+
print("🎨 Using base model (no LoRA)")
|
345 |
+
|
|
|
|
|
|
|
346 |
# Setup
|
347 |
conditional_dict = text_encoder(text_prompts=[prompt])
|
348 |
for key, value in conditional_dict.items():
|
|
|
361 |
# Current setup generates approximately 5 seconds with 7 blocks
|
362 |
# So we scale proportionally
|
363 |
base_duration = 5.0 # seconds
|
364 |
+
base_blocks = 8
|
365 |
num_blocks = max(1, int(base_blocks * duration / base_duration))
|
366 |
|
367 |
current_start_frame = 0
|
|
|
371 |
|
372 |
# Ensure temp directory exists
|
373 |
os.makedirs("gradio_tmp", exist_ok=True)
|
374 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
# Generation loop
|
376 |
for idx, current_num_frames in enumerate(all_num_frames):
|
377 |
print(f"📦 Processing block {idx+1}/{num_blocks}")
|
|
|
470 |
|
471 |
frames_to_ts_file(all_frames_from_block, ts_path, fps)
|
472 |
|
473 |
+
# Calculate final progress for this block
|
474 |
+
total_progress = (idx + 1) / num_blocks * 100
|
475 |
|
476 |
+
# Yield the actual video chunk
|
477 |
+
yield ts_path, gr.update()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
478 |
|
479 |
except Exception as e:
|
480 |
print(f"⚠️ Error encoding block {idx}: {e}")
|
|
|
483 |
|
484 |
current_start_frame += current_num_frames
|
485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
# Final completion status
|
487 |
final_status_html = (
|
488 |
f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
|
|
|
504 |
print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
|
505 |
|
506 |
# --- Gradio UI Layout ---
|
507 |
+
with gr.Blocks(title="Wan2.1 1.3B LoRA Self-Forcing streaming demo") as demo:
|
508 |
+
gr.Markdown("# 🚀 Run Any LoRA in near real-time!")
|
509 |
+
gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B and LoRA [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
|
510 |
|
511 |
+
with gr.Tabs():
|
512 |
+
# LoRA Upload Tab
|
513 |
+
with gr.TabItem("1️⃣ Upload LoRA"):
|
514 |
+
gr.Markdown("## Upload LoRA Weights")
|
515 |
+
gr.Markdown("Upload your custom LoRA weights file to use for generation. The file will be automatically stored and you'll receive a unique hash-based ID.")
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
|
|
|
517 |
with gr.Row():
|
518 |
+
lora_file = gr.File(label="LoRA File (safetensors format)")
|
519 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
520 |
with gr.Row():
|
521 |
+
lora_id_output = gr.Textbox(label="LoRA Hash ID (use this in the generation tab)", interactive=False)
|
522 |
+
|
523 |
+
# Video Generation Tab
|
524 |
+
with gr.TabItem("2️⃣ Generate Video"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
525 |
with gr.Row():
|
526 |
+
with gr.Column(scale=2):
|
527 |
+
with gr.Group():
|
528 |
+
prompt = gr.Textbox(
|
529 |
+
label="Prompt",
|
530 |
+
placeholder="A stylish woman walks down a Tokyo street...",
|
531 |
+
lines=4,
|
532 |
+
value=""
|
533 |
+
)
|
534 |
+
|
535 |
+
start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
|
536 |
+
|
537 |
+
gr.Markdown("### ⚙️ Settings")
|
538 |
+
with gr.Row():
|
539 |
+
seed = gr.Number(
|
540 |
+
label="Seed",
|
541 |
+
value=-1,
|
542 |
+
info="Use -1 for random seed",
|
543 |
+
precision=0
|
544 |
+
)
|
545 |
+
fps = gr.Slider(
|
546 |
+
label="Playback FPS",
|
547 |
+
minimum=1,
|
548 |
+
maximum=30,
|
549 |
+
value=args.fps,
|
550 |
+
step=1,
|
551 |
+
visible=False,
|
552 |
+
info="Frames per second for playback"
|
553 |
+
)
|
554 |
+
|
555 |
+
with gr.Row():
|
556 |
+
duration = gr.Slider(
|
557 |
+
label="Duration (seconds)",
|
558 |
+
minimum=1,
|
559 |
+
maximum=5,
|
560 |
+
value=3,
|
561 |
+
step=1,
|
562 |
+
info="Video duration in seconds"
|
563 |
+
)
|
564 |
+
|
565 |
+
with gr.Row():
|
566 |
+
width = gr.Slider(
|
567 |
+
label="Width",
|
568 |
+
minimum=224,
|
569 |
+
maximum=720,
|
570 |
+
value=400,
|
571 |
+
step=8,
|
572 |
+
info="Video width in pixels (8px steps)"
|
573 |
+
)
|
574 |
+
height = gr.Slider(
|
575 |
+
label="Height",
|
576 |
+
minimum=224,
|
577 |
+
maximum=720,
|
578 |
+
value=224,
|
579 |
+
step=8,
|
580 |
+
info="Video height in pixels (8px steps)"
|
581 |
+
)
|
582 |
+
|
583 |
+
gr.Markdown("### 🎨 LoRA Settings")
|
584 |
+
lora_id = gr.Textbox(
|
585 |
+
label="LoRA ID (from upload tab)",
|
586 |
+
placeholder="Enter your LoRA ID here...",
|
587 |
+
)
|
588 |
+
|
589 |
+
lora_weight = gr.Slider(
|
590 |
+
label="LoRA Weight",
|
591 |
+
minimum=0.0,
|
592 |
+
maximum=1.0,
|
593 |
+
step=0.01,
|
594 |
+
value=1.0,
|
595 |
+
info="Strength of LoRA influence"
|
596 |
+
)
|
597 |
+
|
598 |
+
with gr.Column(scale=3):
|
599 |
+
gr.Markdown("### 📺 Video Stream")
|
600 |
+
|
601 |
+
streaming_video = gr.Video(
|
602 |
+
label="Live Stream",
|
603 |
+
streaming=True,
|
604 |
+
loop=True,
|
605 |
+
height=400,
|
606 |
+
autoplay=True,
|
607 |
+
show_label=False
|
608 |
+
)
|
609 |
+
|
610 |
+
status_display = gr.HTML(
|
611 |
+
value=(
|
612 |
+
"<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
|
613 |
+
"🎬 Ready to start streaming...<br>"
|
614 |
+
"<small>Configure your prompt and click 'Start Streaming'</small>"
|
615 |
+
"</div>"
|
616 |
+
),
|
617 |
+
label="Generation Status"
|
618 |
+
)
|
619 |
|
620 |
# Connect the generator to the streaming video
|
621 |
start_btn.click(
|
622 |
fn=video_generation_handler_streaming,
|
623 |
+
inputs=[prompt, seed, fps, width, height, duration, lora_id, lora_weight],
|
624 |
outputs=[streaming_video, status_display]
|
625 |
)
|
626 |
+
|
627 |
+
# Connect LoRA upload to both display fields
|
628 |
+
lora_file.change(
|
629 |
+
fn=upload_lora_file,
|
630 |
+
inputs=[lora_file],
|
631 |
+
outputs=[lora_id_output, lora_id]
|
632 |
+
)
|
633 |
|
634 |
|
635 |
# --- Launch App ---
|
app_last_working.py
DELETED
@@ -1,460 +0,0 @@
|
|
1 |
-
import subprocess
|
2 |
-
# not sure why it works in the original space but says "pip not found" in mine
|
3 |
-
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
4 |
-
|
5 |
-
from huggingface_hub import snapshot_download, hf_hub_download
|
6 |
-
|
7 |
-
snapshot_download(
|
8 |
-
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
|
9 |
-
local_dir="wan_models/Wan2.1-T2V-1.3B",
|
10 |
-
local_dir_use_symlinks=False,
|
11 |
-
resume_download=True,
|
12 |
-
repo_type="model"
|
13 |
-
)
|
14 |
-
|
15 |
-
hf_hub_download(
|
16 |
-
repo_id="gdhe17/Self-Forcing",
|
17 |
-
filename="checkpoints/self_forcing_dmd.pt",
|
18 |
-
local_dir=".",
|
19 |
-
local_dir_use_symlinks=False
|
20 |
-
)
|
21 |
-
|
22 |
-
import os
|
23 |
-
import re
|
24 |
-
import random
|
25 |
-
import argparse
|
26 |
-
import hashlib
|
27 |
-
import urllib.request
|
28 |
-
import time
|
29 |
-
from PIL import Image
|
30 |
-
import torch
|
31 |
-
import gradio as gr
|
32 |
-
from omegaconf import OmegaConf
|
33 |
-
from tqdm import tqdm
|
34 |
-
import imageio
|
35 |
-
import av
|
36 |
-
import uuid
|
37 |
-
|
38 |
-
from pipeline import CausalInferencePipeline
|
39 |
-
from demo_utils.constant import ZERO_VAE_CACHE
|
40 |
-
from demo_utils.vae_block3 import VAEDecoderWrapper
|
41 |
-
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
|
42 |
-
|
43 |
-
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM #, BitsAndBytesConfig
|
44 |
-
import numpy as np
|
45 |
-
|
46 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
47 |
-
|
48 |
-
# --- Argument Parsing ---
|
49 |
-
parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
|
50 |
-
parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
|
51 |
-
parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
|
52 |
-
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt', help="Path to the model checkpoint.")
|
53 |
-
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
|
54 |
-
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
|
55 |
-
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
|
56 |
-
parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
|
57 |
-
args = parser.parse_args()
|
58 |
-
|
59 |
-
gpu = "cuda"
|
60 |
-
|
61 |
-
try:
|
62 |
-
config = OmegaConf.load(args.config_path)
|
63 |
-
default_config = OmegaConf.load("configs/default_config.yaml")
|
64 |
-
config = OmegaConf.merge(default_config, config)
|
65 |
-
except FileNotFoundError as e:
|
66 |
-
print(f"Error loading config file: {e}\n. Please ensure config files are in the correct path.")
|
67 |
-
exit(1)
|
68 |
-
|
69 |
-
# Initialize Models
|
70 |
-
print("Initializing models...")
|
71 |
-
text_encoder = WanTextEncoder()
|
72 |
-
transformer = WanDiffusionWrapper(is_causal=True)
|
73 |
-
|
74 |
-
try:
|
75 |
-
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
|
76 |
-
transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
|
77 |
-
except FileNotFoundError as e:
|
78 |
-
print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
|
79 |
-
exit(1)
|
80 |
-
|
81 |
-
text_encoder.eval().to(dtype=torch.float16).requires_grad_(False)
|
82 |
-
transformer.eval().to(dtype=torch.float16).requires_grad_(False)
|
83 |
-
|
84 |
-
text_encoder.to(gpu)
|
85 |
-
transformer.to(gpu)
|
86 |
-
|
87 |
-
APP_STATE = {
|
88 |
-
"torch_compile_applied": False,
|
89 |
-
"fp8_applied": False,
|
90 |
-
"current_use_taehv": False,
|
91 |
-
"current_vae_decoder": None,
|
92 |
-
}
|
93 |
-
|
94 |
-
def frames_to_ts_file(frames, filepath, fps = 15):
|
95 |
-
"""
|
96 |
-
Convert frames directly to .ts file using PyAV.
|
97 |
-
|
98 |
-
Args:
|
99 |
-
frames: List of numpy arrays (HWC, RGB, uint8)
|
100 |
-
filepath: Output file path
|
101 |
-
fps: Frames per second
|
102 |
-
|
103 |
-
Returns:
|
104 |
-
The filepath of the created file
|
105 |
-
"""
|
106 |
-
if not frames:
|
107 |
-
return filepath
|
108 |
-
|
109 |
-
height, width = frames[0].shape[:2]
|
110 |
-
|
111 |
-
# Create container for MPEG-TS format
|
112 |
-
container = av.open(filepath, mode='w', format='mpegts')
|
113 |
-
|
114 |
-
# Add video stream with optimized settings for streaming
|
115 |
-
stream = container.add_stream('h264', rate=fps)
|
116 |
-
stream.width = width
|
117 |
-
stream.height = height
|
118 |
-
stream.pix_fmt = 'yuv420p'
|
119 |
-
|
120 |
-
# Optimize for low latency streaming
|
121 |
-
stream.options = {
|
122 |
-
'preset': 'ultrafast',
|
123 |
-
'tune': 'zerolatency',
|
124 |
-
'crf': '23',
|
125 |
-
'profile': 'baseline',
|
126 |
-
'level': '3.0'
|
127 |
-
}
|
128 |
-
|
129 |
-
try:
|
130 |
-
for frame_np in frames:
|
131 |
-
frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
|
132 |
-
frame = frame.reformat(format=stream.pix_fmt)
|
133 |
-
for packet in stream.encode(frame):
|
134 |
-
container.mux(packet)
|
135 |
-
|
136 |
-
for packet in stream.encode():
|
137 |
-
container.mux(packet)
|
138 |
-
|
139 |
-
finally:
|
140 |
-
container.close()
|
141 |
-
|
142 |
-
return filepath
|
143 |
-
|
144 |
-
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
145 |
-
if use_trt:
|
146 |
-
from demo_utils.vae import VAETRTWrapper
|
147 |
-
print("Initializing TensorRT VAE Decoder...")
|
148 |
-
vae_decoder = VAETRTWrapper()
|
149 |
-
APP_STATE["current_use_taehv"] = False
|
150 |
-
elif use_taehv:
|
151 |
-
print("Initializing TAEHV VAE Decoder...")
|
152 |
-
from demo_utils.taehv import TAEHV
|
153 |
-
taehv_checkpoint_path = "checkpoints/taew2_1.pth"
|
154 |
-
if not os.path.exists(taehv_checkpoint_path):
|
155 |
-
print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
|
156 |
-
os.makedirs("checkpoints", exist_ok=True)
|
157 |
-
download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
|
158 |
-
try:
|
159 |
-
urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
|
160 |
-
except Exception as e:
|
161 |
-
raise RuntimeError(f"Failed to download taew2_1.pth: {e}")
|
162 |
-
|
163 |
-
class DotDict(dict): __getattr__ = dict.get
|
164 |
-
|
165 |
-
class TAEHVDiffusersWrapper(torch.nn.Module):
|
166 |
-
def __init__(self):
|
167 |
-
super().__init__()
|
168 |
-
self.dtype = torch.float16
|
169 |
-
self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
|
170 |
-
self.config = DotDict(scaling_factor=1.0)
|
171 |
-
def decode(self, latents, return_dict=None):
|
172 |
-
return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)
|
173 |
-
|
174 |
-
vae_decoder = TAEHVDiffusersWrapper()
|
175 |
-
APP_STATE["current_use_taehv"] = True
|
176 |
-
else:
|
177 |
-
print("Initializing Default VAE Decoder...")
|
178 |
-
vae_decoder = VAEDecoderWrapper()
|
179 |
-
try:
|
180 |
-
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
|
181 |
-
decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
|
182 |
-
vae_decoder.load_state_dict(decoder_state_dict)
|
183 |
-
except FileNotFoundError:
|
184 |
-
print("Warning: Default VAE weights not found.")
|
185 |
-
APP_STATE["current_use_taehv"] = False
|
186 |
-
|
187 |
-
vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)
|
188 |
-
APP_STATE["current_vae_decoder"] = vae_decoder
|
189 |
-
print(f"✅ VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
|
190 |
-
|
191 |
-
# Initialize with default VAE
|
192 |
-
initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
|
193 |
-
|
194 |
-
pipeline = CausalInferencePipeline(
|
195 |
-
config, device=gpu, generator=transformer, text_encoder=text_encoder,
|
196 |
-
vae=APP_STATE["current_vae_decoder"]
|
197 |
-
)
|
198 |
-
|
199 |
-
pipeline.to(dtype=torch.float16).to(gpu)
|
200 |
-
|
201 |
-
@torch.no_grad()
|
202 |
-
def video_generation_handler_streaming(prompt, seed=42, fps=15):
|
203 |
-
"""
|
204 |
-
Generator function that yields .ts video chunks using PyAV for streaming.
|
205 |
-
Now optimized for block-based processing.
|
206 |
-
"""
|
207 |
-
if seed == -1:
|
208 |
-
seed = random.randint(0, 2**32 - 1)
|
209 |
-
|
210 |
-
print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
|
211 |
-
|
212 |
-
# Setup
|
213 |
-
conditional_dict = text_encoder(text_prompts=[prompt])
|
214 |
-
for key, value in conditional_dict.items():
|
215 |
-
conditional_dict[key] = value.to(dtype=torch.float16)
|
216 |
-
|
217 |
-
rnd = torch.Generator(gpu).manual_seed(int(seed))
|
218 |
-
pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
|
219 |
-
pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
|
220 |
-
noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
|
221 |
-
|
222 |
-
vae_cache, latents_cache = None, None
|
223 |
-
if not APP_STATE["current_use_taehv"] and not args.trt:
|
224 |
-
vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
|
225 |
-
|
226 |
-
num_blocks = 7
|
227 |
-
current_start_frame = 0
|
228 |
-
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
229 |
-
|
230 |
-
total_frames_yielded = 0
|
231 |
-
|
232 |
-
# Ensure temp directory exists
|
233 |
-
os.makedirs("gradio_tmp", exist_ok=True)
|
234 |
-
|
235 |
-
# Generation loop
|
236 |
-
for idx, current_num_frames in enumerate(all_num_frames):
|
237 |
-
print(f"📦 Processing block {idx+1}/{num_blocks}")
|
238 |
-
|
239 |
-
noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]
|
240 |
-
|
241 |
-
# Denoising steps
|
242 |
-
for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
|
243 |
-
timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
|
244 |
-
_, denoised_pred = pipeline.generator(
|
245 |
-
noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
|
246 |
-
timestep=timestep, kv_cache=pipeline.kv_cache1,
|
247 |
-
crossattn_cache=pipeline.crossattn_cache,
|
248 |
-
current_start=current_start_frame * pipeline.frame_seq_length
|
249 |
-
)
|
250 |
-
if step_idx < len(pipeline.denoising_step_list) - 1:
|
251 |
-
next_timestep = pipeline.denoising_step_list[step_idx + 1]
|
252 |
-
noisy_input = pipeline.scheduler.add_noise(
|
253 |
-
denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
|
254 |
-
next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
|
255 |
-
).unflatten(0, denoised_pred.shape[:2])
|
256 |
-
|
257 |
-
if idx < len(all_num_frames) - 1:
|
258 |
-
pipeline.generator(
|
259 |
-
noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
|
260 |
-
timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
|
261 |
-
crossattn_cache=pipeline.crossattn_cache,
|
262 |
-
current_start=current_start_frame * pipeline.frame_seq_length,
|
263 |
-
)
|
264 |
-
|
265 |
-
# Decode to pixels
|
266 |
-
if args.trt:
|
267 |
-
pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
|
268 |
-
elif APP_STATE["current_use_taehv"]:
|
269 |
-
if latents_cache is None:
|
270 |
-
latents_cache = denoised_pred
|
271 |
-
else:
|
272 |
-
denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
|
273 |
-
latents_cache = denoised_pred[:, -3:]
|
274 |
-
pixels = pipeline.vae.decode(denoised_pred)
|
275 |
-
else:
|
276 |
-
pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)
|
277 |
-
|
278 |
-
# Handle frame skipping
|
279 |
-
if idx == 0 and not args.trt:
|
280 |
-
pixels = pixels[:, 3:]
|
281 |
-
elif APP_STATE["current_use_taehv"] and idx > 0:
|
282 |
-
pixels = pixels[:, 12:]
|
283 |
-
|
284 |
-
print(f"🔍 DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")
|
285 |
-
|
286 |
-
# Process all frames from this block at once
|
287 |
-
all_frames_from_block = []
|
288 |
-
for frame_idx in range(pixels.shape[1]):
|
289 |
-
frame_tensor = pixels[0, frame_idx]
|
290 |
-
|
291 |
-
# Convert to numpy (HWC, RGB, uint8)
|
292 |
-
frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
|
293 |
-
frame_np = frame_np.to(torch.uint8).cpu().numpy()
|
294 |
-
frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
|
295 |
-
|
296 |
-
all_frames_from_block.append(frame_np)
|
297 |
-
total_frames_yielded += 1
|
298 |
-
|
299 |
-
# Yield status update for each frame (cute tracking!)
|
300 |
-
blocks_completed = idx
|
301 |
-
current_block_progress = (frame_idx + 1) / pixels.shape[1]
|
302 |
-
total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
|
303 |
-
|
304 |
-
# Cap at 100% to avoid going over
|
305 |
-
total_progress = min(total_progress, 100.0)
|
306 |
-
|
307 |
-
frame_status_html = (
|
308 |
-
f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
|
309 |
-
f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
|
310 |
-
f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
|
311 |
-
f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
|
312 |
-
f" </div>"
|
313 |
-
f" <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
|
314 |
-
f" Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
|
315 |
-
f" </p>"
|
316 |
-
f"</div>"
|
317 |
-
)
|
318 |
-
|
319 |
-
# Yield None for video but update status (frame-by-frame tracking)
|
320 |
-
yield None, frame_status_html
|
321 |
-
|
322 |
-
# Encode entire block as one chunk immediately
|
323 |
-
if all_frames_from_block:
|
324 |
-
print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
|
325 |
-
|
326 |
-
try:
|
327 |
-
chunk_uuid = str(uuid.uuid4())[:8]
|
328 |
-
ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
|
329 |
-
ts_path = os.path.join("gradio_tmp", ts_filename)
|
330 |
-
|
331 |
-
frames_to_ts_file(all_frames_from_block, ts_path, fps)
|
332 |
-
|
333 |
-
# Calculate final progress for this block
|
334 |
-
total_progress = (idx + 1) / num_blocks * 100
|
335 |
-
|
336 |
-
# Yield the actual video chunk
|
337 |
-
yield ts_path, gr.update()
|
338 |
-
|
339 |
-
except Exception as e:
|
340 |
-
print(f"⚠️ Error encoding block {idx}: {e}")
|
341 |
-
import traceback
|
342 |
-
traceback.print_exc()
|
343 |
-
|
344 |
-
current_start_frame += current_num_frames
|
345 |
-
|
346 |
-
# Final completion status
|
347 |
-
final_status_html = (
|
348 |
-
f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
|
349 |
-
f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
|
350 |
-
f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
|
351 |
-
f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
|
352 |
-
f" </div>"
|
353 |
-
f" <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
|
354 |
-
f" <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
|
355 |
-
f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
|
356 |
-
f" </p>"
|
357 |
-
f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
|
358 |
-
f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
|
359 |
-
f" </p>"
|
360 |
-
f" </div>"
|
361 |
-
f"</div>"
|
362 |
-
)
|
363 |
-
yield None, final_status_html
|
364 |
-
print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
|
365 |
-
|
366 |
-
# --- Gradio UI Layout ---
|
367 |
-
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
368 |
-
gr.Markdown("# 🚀 Self-Forcing Video Generation")
|
369 |
-
gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
|
370 |
-
|
371 |
-
with gr.Row():
|
372 |
-
with gr.Column(scale=2):
|
373 |
-
with gr.Group():
|
374 |
-
prompt = gr.Textbox(
|
375 |
-
label="Prompt",
|
376 |
-
placeholder="A stylish woman walks down a Tokyo street...",
|
377 |
-
lines=4,
|
378 |
-
value=""
|
379 |
-
)
|
380 |
-
|
381 |
-
start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
|
382 |
-
|
383 |
-
gr.Markdown("### 🎯 Examples")
|
384 |
-
gr.Examples(
|
385 |
-
examples=[
|
386 |
-
"A close-up shot of a ceramic teacup slowly pouring water into a glass mug.",
|
387 |
-
"A playful cat is seen playing an electronic guitar, strumming the strings with its front paws. The cat has distinctive black facial markings and a bushy tail. It sits comfortably on a small stool, its body slightly tilted as it focuses intently on the instrument. The setting is a cozy, dimly lit room with vintage posters on the walls, adding a retro vibe. The cat's expressive eyes convey a sense of joy and concentration. Medium close-up shot, focusing on the cat's face and hands interacting with the guitar.",
|
388 |
-
"A dynamic over-the-shoulder perspective of a chef meticulously plating a dish in a bustling kitchen. The chef, a middle-aged woman, deftly arranges ingredients on a pristine white plate. Her hands move with precision, each gesture deliberate and practiced. The background shows a crowded kitchen with steaming pots, whirring blenders, and the clatter of utensils. Bright lights highlight the scene, casting shadows across the busy workspace. The camera angle captures the chef's detailed work from behind, emphasizing his skill and dedication.",
|
389 |
-
],
|
390 |
-
inputs=[prompt],
|
391 |
-
)
|
392 |
-
|
393 |
-
gr.Markdown("### ⚙️ Settings")
|
394 |
-
with gr.Row():
|
395 |
-
seed = gr.Number(
|
396 |
-
label="Seed",
|
397 |
-
value=-1,
|
398 |
-
info="Use -1 for random seed",
|
399 |
-
precision=0
|
400 |
-
)
|
401 |
-
fps = gr.Slider(
|
402 |
-
label="Playback FPS",
|
403 |
-
minimum=1,
|
404 |
-
maximum=30,
|
405 |
-
value=args.fps,
|
406 |
-
step=1,
|
407 |
-
visible=False,
|
408 |
-
info="Frames per second for playback"
|
409 |
-
)
|
410 |
-
|
411 |
-
with gr.Column(scale=3):
|
412 |
-
gr.Markdown("### 📺 Video Stream")
|
413 |
-
|
414 |
-
streaming_video = gr.Video(
|
415 |
-
label="Live Stream",
|
416 |
-
streaming=True,
|
417 |
-
loop=True,
|
418 |
-
height=400,
|
419 |
-
autoplay=True,
|
420 |
-
show_label=False
|
421 |
-
)
|
422 |
-
|
423 |
-
status_display = gr.HTML(
|
424 |
-
value=(
|
425 |
-
"<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
|
426 |
-
"🎬 Ready to start streaming...<br>"
|
427 |
-
"<small>Configure your prompt and click 'Start Streaming'</small>"
|
428 |
-
"</div>"
|
429 |
-
),
|
430 |
-
label="Generation Status"
|
431 |
-
)
|
432 |
-
|
433 |
-
# Connect the generator to the streaming video
|
434 |
-
start_btn.click(
|
435 |
-
fn=video_generation_handler_streaming,
|
436 |
-
inputs=[prompt, seed, fps],
|
437 |
-
outputs=[streaming_video, status_display]
|
438 |
-
)
|
439 |
-
|
440 |
-
|
441 |
-
# --- Launch App ---
|
442 |
-
if __name__ == "__main__":
|
443 |
-
if os.path.exists("gradio_tmp"):
|
444 |
-
import shutil
|
445 |
-
shutil.rmtree("gradio_tmp")
|
446 |
-
os.makedirs("gradio_tmp", exist_ok=True)
|
447 |
-
|
448 |
-
print("🚀 Starting Self-Forcing Streaming Demo")
|
449 |
-
print(f"📁 Temporary files will be stored in: gradio_tmp/")
|
450 |
-
print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
|
451 |
-
print(f"⚡ GPU acceleration: {gpu}")
|
452 |
-
|
453 |
-
demo.queue().launch(
|
454 |
-
server_name=args.host,
|
455 |
-
server_port=args.port,
|
456 |
-
share=args.share,
|
457 |
-
show_error=True,
|
458 |
-
max_threads=40,
|
459 |
-
mcp_server=True
|
460 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo.py
DELETED
@@ -1,631 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Demo for Self-Forcing.
|
3 |
-
"""
|
4 |
-
|
5 |
-
import os
|
6 |
-
import re
|
7 |
-
import random
|
8 |
-
import time
|
9 |
-
import base64
|
10 |
-
import argparse
|
11 |
-
import hashlib
|
12 |
-
import subprocess
|
13 |
-
import urllib.request
|
14 |
-
from io import BytesIO
|
15 |
-
from PIL import Image
|
16 |
-
import numpy as np
|
17 |
-
import torch
|
18 |
-
from omegaconf import OmegaConf
|
19 |
-
from flask import Flask, render_template, jsonify
|
20 |
-
from flask_socketio import SocketIO, emit
|
21 |
-
import queue
|
22 |
-
from threading import Thread, Event
|
23 |
-
|
24 |
-
from pipeline import CausalInferencePipeline
|
25 |
-
from demo_utils.constant import ZERO_VAE_CACHE
|
26 |
-
from demo_utils.vae_block3 import VAEDecoderWrapper
|
27 |
-
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
|
28 |
-
from demo_utils.utils import generate_timestamp
|
29 |
-
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation
|
30 |
-
|
31 |
-
# Parse arguments
|
32 |
-
parser = argparse.ArgumentParser()
|
33 |
-
parser.add_argument('--port', type=int, default=5001)
|
34 |
-
parser.add_argument('--host', type=str, default='0.0.0.0')
|
35 |
-
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt')
|
36 |
-
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml')
|
37 |
-
parser.add_argument('--trt', action='store_true')
|
38 |
-
args = parser.parse_args()
|
39 |
-
|
40 |
-
print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
|
41 |
-
low_memory = get_cuda_free_memory_gb(gpu) < 40
|
42 |
-
|
43 |
-
# Load models
|
44 |
-
config = OmegaConf.load(args.config_path)
|
45 |
-
default_config = OmegaConf.load("configs/default_config.yaml")
|
46 |
-
config = OmegaConf.merge(default_config, config)
|
47 |
-
|
48 |
-
text_encoder = WanTextEncoder()
|
49 |
-
|
50 |
-
# Global variables for dynamic model switching
|
51 |
-
current_vae_decoder = None
|
52 |
-
current_use_taehv = False
|
53 |
-
fp8_applied = False
|
54 |
-
torch_compile_applied = False
|
55 |
-
global frame_number
|
56 |
-
frame_number = 0
|
57 |
-
anim_name = ""
|
58 |
-
frame_rate = 6
|
59 |
-
|
60 |
-
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
61 |
-
"""Initialize VAE decoder based on the selected option"""
|
62 |
-
global current_vae_decoder, current_use_taehv
|
63 |
-
|
64 |
-
if use_trt:
|
65 |
-
from demo_utils.vae import VAETRTWrapper
|
66 |
-
current_vae_decoder = VAETRTWrapper()
|
67 |
-
return current_vae_decoder
|
68 |
-
|
69 |
-
if use_taehv:
|
70 |
-
from demo_utils.taehv import TAEHV
|
71 |
-
# Check if taew2_1.pth exists in checkpoints folder, download if missing
|
72 |
-
taehv_checkpoint_path = "checkpoints/taew2_1.pth"
|
73 |
-
if not os.path.exists(taehv_checkpoint_path):
|
74 |
-
print(f"taew2_1.pth not found in checkpoints folder {taehv_checkpoint_path}. Downloading...")
|
75 |
-
os.makedirs("checkpoints", exist_ok=True)
|
76 |
-
download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
|
77 |
-
try:
|
78 |
-
urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
|
79 |
-
print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
|
80 |
-
except Exception as e:
|
81 |
-
print(f"Failed to download taew2_1.pth: {e}")
|
82 |
-
raise
|
83 |
-
|
84 |
-
class DotDict(dict):
|
85 |
-
__getattr__ = dict.__getitem__
|
86 |
-
__setattr__ = dict.__setitem__
|
87 |
-
|
88 |
-
class TAEHVDiffusersWrapper(torch.nn.Module):
|
89 |
-
def __init__(self):
|
90 |
-
super().__init__()
|
91 |
-
self.dtype = torch.float16
|
92 |
-
self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
|
93 |
-
self.config = DotDict(scaling_factor=1.0)
|
94 |
-
|
95 |
-
def decode(self, latents, return_dict=None):
|
96 |
-
# n, c, t, h, w = latents.shape
|
97 |
-
# low-memory, set parallel=True for faster + higher memory
|
98 |
-
return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)
|
99 |
-
|
100 |
-
current_vae_decoder = TAEHVDiffusersWrapper()
|
101 |
-
else:
|
102 |
-
current_vae_decoder = VAEDecoderWrapper()
|
103 |
-
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
|
104 |
-
decoder_state_dict = {}
|
105 |
-
for key, value in vae_state_dict.items():
|
106 |
-
if 'decoder.' in key or 'conv2' in key:
|
107 |
-
decoder_state_dict[key] = value
|
108 |
-
current_vae_decoder.load_state_dict(decoder_state_dict)
|
109 |
-
|
110 |
-
current_vae_decoder.eval()
|
111 |
-
current_vae_decoder.to(dtype=torch.float16)
|
112 |
-
current_vae_decoder.requires_grad_(False)
|
113 |
-
current_vae_decoder.to(gpu)
|
114 |
-
current_use_taehv = use_taehv
|
115 |
-
|
116 |
-
print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
|
117 |
-
return current_vae_decoder
|
118 |
-
|
119 |
-
|
120 |
-
# Initialize with default VAE
|
121 |
-
vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
|
122 |
-
|
123 |
-
transformer = WanDiffusionWrapper(is_causal=True)
|
124 |
-
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
|
125 |
-
transformer.load_state_dict(state_dict['generator_ema'])
|
126 |
-
|
127 |
-
text_encoder.eval()
|
128 |
-
transformer.eval()
|
129 |
-
|
130 |
-
transformer.to(dtype=torch.float16)
|
131 |
-
text_encoder.to(dtype=torch.bfloat16)
|
132 |
-
|
133 |
-
text_encoder.requires_grad_(False)
|
134 |
-
transformer.requires_grad_(False)
|
135 |
-
|
136 |
-
pipeline = CausalInferencePipeline(
|
137 |
-
config,
|
138 |
-
device=gpu,
|
139 |
-
generator=transformer,
|
140 |
-
text_encoder=text_encoder,
|
141 |
-
vae=vae_decoder
|
142 |
-
)
|
143 |
-
|
144 |
-
if low_memory:
|
145 |
-
DynamicSwapInstaller.install_model(text_encoder, device=gpu)
|
146 |
-
else:
|
147 |
-
text_encoder.to(gpu)
|
148 |
-
transformer.to(gpu)
|
149 |
-
|
150 |
-
# Flask and SocketIO setup
|
151 |
-
app = Flask(__name__)
|
152 |
-
app.config['SECRET_KEY'] = 'frontend_buffered_demo'
|
153 |
-
socketio = SocketIO(app, cors_allowed_origins="*")
|
154 |
-
|
155 |
-
generation_active = False
|
156 |
-
stop_event = Event()
|
157 |
-
frame_send_queue = queue.Queue()
|
158 |
-
sender_thread = None
|
159 |
-
models_compiled = False
|
160 |
-
|
161 |
-
|
162 |
-
def tensor_to_base64_frame(frame_tensor):
|
163 |
-
"""Convert a single frame tensor to base64 image string."""
|
164 |
-
global frame_number, anim_name
|
165 |
-
# Clamp and normalize to 0-255
|
166 |
-
frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
|
167 |
-
frame = frame.to(torch.uint8).cpu().numpy()
|
168 |
-
|
169 |
-
# CHW -> HWC
|
170 |
-
if len(frame.shape) == 3:
|
171 |
-
frame = np.transpose(frame, (1, 2, 0))
|
172 |
-
|
173 |
-
# Convert to PIL Image
|
174 |
-
if frame.shape[2] == 3: # RGB
|
175 |
-
image = Image.fromarray(frame, 'RGB')
|
176 |
-
else: # Handle other formats
|
177 |
-
image = Image.fromarray(frame)
|
178 |
-
|
179 |
-
# Convert to base64
|
180 |
-
buffer = BytesIO()
|
181 |
-
image.save(buffer, format='JPEG', quality=100)
|
182 |
-
if not os.path.exists("./images/%s" % anim_name):
|
183 |
-
os.makedirs("./images/%s" % anim_name)
|
184 |
-
frame_number += 1
|
185 |
-
image.save("./images/%s/%s_%03d.jpg" % (anim_name, anim_name, frame_number))
|
186 |
-
img_str = base64.b64encode(buffer.getvalue()).decode()
|
187 |
-
return f"data:image/jpeg;base64,{img_str}"
|
188 |
-
|
189 |
-
|
190 |
-
def frame_sender_worker():
|
191 |
-
"""Background thread that processes frame send queue non-blocking."""
|
192 |
-
global frame_send_queue, generation_active, stop_event
|
193 |
-
|
194 |
-
print("📡 Frame sender thread started")
|
195 |
-
|
196 |
-
while True:
|
197 |
-
frame_data = None
|
198 |
-
try:
|
199 |
-
# Get frame data from queue
|
200 |
-
frame_data = frame_send_queue.get(timeout=1.0)
|
201 |
-
|
202 |
-
if frame_data is None: # Shutdown signal
|
203 |
-
frame_send_queue.task_done() # Mark shutdown signal as done
|
204 |
-
break
|
205 |
-
|
206 |
-
frame_tensor, frame_index, block_index, job_id = frame_data
|
207 |
-
|
208 |
-
# Convert tensor to base64
|
209 |
-
base64_frame = tensor_to_base64_frame(frame_tensor)
|
210 |
-
|
211 |
-
# Send via SocketIO
|
212 |
-
try:
|
213 |
-
socketio.emit('frame_ready', {
|
214 |
-
'data': base64_frame,
|
215 |
-
'frame_index': frame_index,
|
216 |
-
'block_index': block_index,
|
217 |
-
'job_id': job_id
|
218 |
-
})
|
219 |
-
except Exception as e:
|
220 |
-
print(f"⚠️ Failed to send frame {frame_index}: {e}")
|
221 |
-
|
222 |
-
frame_send_queue.task_done()
|
223 |
-
|
224 |
-
except queue.Empty:
|
225 |
-
# Check if we should continue running
|
226 |
-
if not generation_active and frame_send_queue.empty():
|
227 |
-
break
|
228 |
-
except Exception as e:
|
229 |
-
print(f"❌ Frame sender error: {e}")
|
230 |
-
# Make sure to mark task as done even if there's an error
|
231 |
-
if frame_data is not None:
|
232 |
-
try:
|
233 |
-
frame_send_queue.task_done()
|
234 |
-
except Exception as e:
|
235 |
-
print(f"❌ Failed to mark frame task as done: {e}")
|
236 |
-
break
|
237 |
-
|
238 |
-
print("📡 Frame sender thread stopped")
|
239 |
-
|
240 |
-
|
241 |
-
@torch.no_grad()
|
242 |
-
def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
|
243 |
-
"""Generate video and push frames immediately to frontend."""
|
244 |
-
global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name
|
245 |
-
|
246 |
-
try:
|
247 |
-
generation_active = True
|
248 |
-
stop_event.clear()
|
249 |
-
job_id = generate_timestamp()
|
250 |
-
|
251 |
-
# Start frame sender thread if not already running
|
252 |
-
if sender_thread is None or not sender_thread.is_alive():
|
253 |
-
sender_thread = Thread(target=frame_sender_worker, daemon=True)
|
254 |
-
sender_thread.start()
|
255 |
-
|
256 |
-
# Emit progress updates
|
257 |
-
def emit_progress(message, progress):
|
258 |
-
try:
|
259 |
-
socketio.emit('progress', {
|
260 |
-
'message': message,
|
261 |
-
'progress': progress,
|
262 |
-
'job_id': job_id
|
263 |
-
})
|
264 |
-
except Exception as e:
|
265 |
-
print(f"❌ Failed to emit progress: {e}")
|
266 |
-
|
267 |
-
emit_progress('Starting generation...', 0)
|
268 |
-
|
269 |
-
# Handle VAE decoder switching
|
270 |
-
if use_taehv != current_use_taehv:
|
271 |
-
emit_progress('Switching VAE decoder...', 2)
|
272 |
-
print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
|
273 |
-
current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
|
274 |
-
# Update pipeline with new VAE decoder
|
275 |
-
pipeline.vae = current_vae_decoder
|
276 |
-
|
277 |
-
# Handle FP8 quantization
|
278 |
-
if enable_fp8 and not fp8_applied:
|
279 |
-
emit_progress('Applying FP8 quantization...', 3)
|
280 |
-
print("🔧 Applying FP8 quantization to transformer")
|
281 |
-
from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
|
282 |
-
quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
|
283 |
-
fp8_applied = True
|
284 |
-
|
285 |
-
# Text encoding
|
286 |
-
emit_progress('Encoding text prompt...', 8)
|
287 |
-
conditional_dict = text_encoder(text_prompts=[prompt])
|
288 |
-
for key, value in conditional_dict.items():
|
289 |
-
conditional_dict[key] = value.to(dtype=torch.float16)
|
290 |
-
if low_memory:
|
291 |
-
gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
|
292 |
-
move_model_to_device_with_memory_preservation(
|
293 |
-
text_encoder,target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
|
294 |
-
|
295 |
-
# Handle torch.compile if enabled
|
296 |
-
torch_compile_applied = enable_torch_compile
|
297 |
-
if enable_torch_compile and not models_compiled:
|
298 |
-
# Compile transformer and decoder
|
299 |
-
transformer.compile(mode="max-autotune-no-cudagraphs")
|
300 |
-
if not current_use_taehv and not low_memory and not args.trt:
|
301 |
-
current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")
|
302 |
-
|
303 |
-
# Initialize generation
|
304 |
-
emit_progress('Initializing generation...', 12)
|
305 |
-
|
306 |
-
rnd = torch.Generator(gpu).manual_seed(seed)
|
307 |
-
# all_latents = torch.zeros([1, 21, 16, 60, 104], device=gpu, dtype=torch.bfloat16)
|
308 |
-
|
309 |
-
pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
|
310 |
-
pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)
|
311 |
-
|
312 |
-
noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
|
313 |
-
|
314 |
-
# Generation parameters
|
315 |
-
num_blocks = 7
|
316 |
-
current_start_frame = 0
|
317 |
-
num_input_frames = 0
|
318 |
-
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
319 |
-
if current_use_taehv:
|
320 |
-
vae_cache = None
|
321 |
-
else:
|
322 |
-
vae_cache = ZERO_VAE_CACHE
|
323 |
-
for i in range(len(vae_cache)):
|
324 |
-
vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)
|
325 |
-
|
326 |
-
total_frames_sent = 0
|
327 |
-
generation_start_time = time.time()
|
328 |
-
|
329 |
-
emit_progress('Generating frames... (frontend handles timing)', 15)
|
330 |
-
|
331 |
-
for idx, current_num_frames in enumerate(all_num_frames):
|
332 |
-
if not generation_active or stop_event.is_set():
|
333 |
-
break
|
334 |
-
|
335 |
-
progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15
|
336 |
-
|
337 |
-
# Special message for first block with torch.compile
|
338 |
-
if idx == 0 and torch_compile_applied and not models_compiled:
|
339 |
-
emit_progress(
|
340 |
-
f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
|
341 |
-
print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
|
342 |
-
models_compiled = True
|
343 |
-
else:
|
344 |
-
emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
|
345 |
-
print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")
|
346 |
-
|
347 |
-
block_start_time = time.time()
|
348 |
-
|
349 |
-
noisy_input = noise[:, current_start_frame -
|
350 |
-
num_input_frames:current_start_frame + current_num_frames - num_input_frames]
|
351 |
-
|
352 |
-
# Denoising loop
|
353 |
-
denoising_start = time.time()
|
354 |
-
for index, current_timestep in enumerate(pipeline.denoising_step_list):
|
355 |
-
if not generation_active or stop_event.is_set():
|
356 |
-
break
|
357 |
-
|
358 |
-
timestep = torch.ones([1, current_num_frames], device=noise.device,
|
359 |
-
dtype=torch.int64) * current_timestep
|
360 |
-
|
361 |
-
if index < len(pipeline.denoising_step_list) - 1:
|
362 |
-
_, denoised_pred = transformer(
|
363 |
-
noisy_image_or_video=noisy_input,
|
364 |
-
conditional_dict=conditional_dict,
|
365 |
-
timestep=timestep,
|
366 |
-
kv_cache=pipeline.kv_cache1,
|
367 |
-
crossattn_cache=pipeline.crossattn_cache,
|
368 |
-
current_start=current_start_frame * pipeline.frame_seq_length
|
369 |
-
)
|
370 |
-
next_timestep = pipeline.denoising_step_list[index + 1]
|
371 |
-
noisy_input = pipeline.scheduler.add_noise(
|
372 |
-
denoised_pred.flatten(0, 1),
|
373 |
-
torch.randn_like(denoised_pred.flatten(0, 1)),
|
374 |
-
next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
|
375 |
-
).unflatten(0, denoised_pred.shape[:2])
|
376 |
-
else:
|
377 |
-
_, denoised_pred = transformer(
|
378 |
-
noisy_image_or_video=noisy_input,
|
379 |
-
conditional_dict=conditional_dict,
|
380 |
-
timestep=timestep,
|
381 |
-
kv_cache=pipeline.kv_cache1,
|
382 |
-
crossattn_cache=pipeline.crossattn_cache,
|
383 |
-
current_start=current_start_frame * pipeline.frame_seq_length
|
384 |
-
)
|
385 |
-
|
386 |
-
if not generation_active or stop_event.is_set():
|
387 |
-
break
|
388 |
-
|
389 |
-
denoising_time = time.time() - denoising_start
|
390 |
-
print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")
|
391 |
-
|
392 |
-
# Record output
|
393 |
-
# all_latents[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
|
394 |
-
|
395 |
-
# Update KV cache for next block
|
396 |
-
if idx != len(all_num_frames) - 1:
|
397 |
-
transformer(
|
398 |
-
noisy_image_or_video=denoised_pred,
|
399 |
-
conditional_dict=conditional_dict,
|
400 |
-
timestep=torch.zeros_like(timestep),
|
401 |
-
kv_cache=pipeline.kv_cache1,
|
402 |
-
crossattn_cache=pipeline.crossattn_cache,
|
403 |
-
current_start=current_start_frame * pipeline.frame_seq_length,
|
404 |
-
)
|
405 |
-
|
406 |
-
# Decode to pixels and send frames immediately
|
407 |
-
print(f"🎨 Decoding block {idx+1} to pixels...")
|
408 |
-
decode_start = time.time()
|
409 |
-
if args.trt:
|
410 |
-
all_current_pixels = []
|
411 |
-
for i in range(denoised_pred.shape[1]):
|
412 |
-
is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
|
413 |
-
torch.tensor(0.0).cuda().half()
|
414 |
-
outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
|
415 |
-
# outputs = vae_decoder.forward(denoised_pred.float(), *vae_cache)
|
416 |
-
current_pixels, vae_cache = outputs[0], outputs[1:]
|
417 |
-
print(current_pixels.max(), current_pixels.min())
|
418 |
-
all_current_pixels.append(current_pixels.clone())
|
419 |
-
pixels = torch.cat(all_current_pixels, dim=1)
|
420 |
-
if idx == 0:
|
421 |
-
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
|
422 |
-
else:
|
423 |
-
if current_use_taehv:
|
424 |
-
if vae_cache is None:
|
425 |
-
vae_cache = denoised_pred
|
426 |
-
else:
|
427 |
-
denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
|
428 |
-
vae_cache = denoised_pred[:, -3:, :, :, :]
|
429 |
-
pixels = current_vae_decoder.decode(denoised_pred)
|
430 |
-
print(f"denoised_pred shape: {denoised_pred.shape}")
|
431 |
-
print(f"pixels shape: {pixels.shape}")
|
432 |
-
if idx == 0:
|
433 |
-
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
|
434 |
-
else:
|
435 |
-
pixels = pixels[:, 12:, :, :, :]
|
436 |
-
|
437 |
-
else:
|
438 |
-
pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
|
439 |
-
if idx == 0:
|
440 |
-
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
|
441 |
-
|
442 |
-
decode_time = time.time() - decode_start
|
443 |
-
print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")
|
444 |
-
|
445 |
-
# Queue frames for non-blocking sending
|
446 |
-
block_frames = pixels.shape[1]
|
447 |
-
print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
|
448 |
-
queue_start = time.time()
|
449 |
-
|
450 |
-
for frame_idx in range(block_frames):
|
451 |
-
if not generation_active or stop_event.is_set():
|
452 |
-
break
|
453 |
-
|
454 |
-
frame_tensor = pixels[0, frame_idx].cpu()
|
455 |
-
|
456 |
-
# Queue frame data in non-blocking way
|
457 |
-
frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
|
458 |
-
total_frames_sent += 1
|
459 |
-
|
460 |
-
queue_time = time.time() - queue_start
|
461 |
-
block_time = time.time() - block_start_time
|
462 |
-
print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")
|
463 |
-
|
464 |
-
current_start_frame += current_num_frames
|
465 |
-
|
466 |
-
generation_time = time.time() - generation_start_time
|
467 |
-
print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")
|
468 |
-
|
469 |
-
# Wait for all frames to be sent before completing
|
470 |
-
emit_progress('Waiting for all frames to be sent...', 97)
|
471 |
-
print("⏳ Waiting for all frames to be sent...")
|
472 |
-
frame_send_queue.join() # Wait for all queued frames to be processed
|
473 |
-
print("✅ All frames sent successfully!")
|
474 |
-
|
475 |
-
generate_mp4_from_images("./images","./videos/"+anim_name+".mp4", frame_rate )
|
476 |
-
# Final progress update
|
477 |
-
emit_progress('Generation complete!', 100)
|
478 |
-
|
479 |
-
try:
|
480 |
-
socketio.emit('generation_complete', {
|
481 |
-
'message': 'Video generation completed!',
|
482 |
-
'total_frames': total_frames_sent,
|
483 |
-
'generation_time': f"{generation_time:.2f}s",
|
484 |
-
'job_id': job_id
|
485 |
-
})
|
486 |
-
except Exception as e:
|
487 |
-
print(f"❌ Failed to emit generation complete: {e}")
|
488 |
-
|
489 |
-
except Exception as e:
|
490 |
-
print(f"❌ Generation failed: {e}")
|
491 |
-
try:
|
492 |
-
socketio.emit('error', {
|
493 |
-
'message': f'Generation failed: {str(e)}',
|
494 |
-
'job_id': job_id
|
495 |
-
})
|
496 |
-
except Exception as e:
|
497 |
-
print(f"❌ Failed to emit error: {e}")
|
498 |
-
finally:
|
499 |
-
generation_active = False
|
500 |
-
stop_event.set()
|
501 |
-
|
502 |
-
# Clean up sender thread
|
503 |
-
try:
|
504 |
-
frame_send_queue.put(None)
|
505 |
-
except Exception as e:
|
506 |
-
print(f"❌ Failed to put None in frame_send_queue: {e}")
|
507 |
-
|
508 |
-
|
509 |
-
def generate_mp4_from_images(image_directory, output_video_path, fps=24):
|
510 |
-
"""
|
511 |
-
Generate an MP4 video from a directory of images ordered alphabetically.
|
512 |
-
|
513 |
-
:param image_directory: Path to the directory containing images.
|
514 |
-
:param output_video_path: Path where the output MP4 will be saved.
|
515 |
-
:param fps: Frames per second for the output video.
|
516 |
-
"""
|
517 |
-
global anim_name
|
518 |
-
# Construct the ffmpeg command
|
519 |
-
cmd = [
|
520 |
-
'ffmpeg',
|
521 |
-
'-framerate', str(fps),
|
522 |
-
'-i', os.path.join(image_directory, anim_name+'/'+anim_name+'_%03d.jpg'), # Adjust the pattern if necessary
|
523 |
-
'-c:v', 'libx264',
|
524 |
-
'-pix_fmt', 'yuv420p',
|
525 |
-
output_video_path
|
526 |
-
]
|
527 |
-
try:
|
528 |
-
subprocess.run(cmd, check=True)
|
529 |
-
print(f"Video saved to {output_video_path}")
|
530 |
-
except subprocess.CalledProcessError as e:
|
531 |
-
print(f"An error occurred: {e}")
|
532 |
-
|
533 |
-
def calculate_sha256(data):
|
534 |
-
# Convert data to bytes if it's not already
|
535 |
-
if isinstance(data, str):
|
536 |
-
data = data.encode()
|
537 |
-
# Calculate SHA-256 hash
|
538 |
-
sha256_hash = hashlib.sha256(data).hexdigest()
|
539 |
-
return sha256_hash
|
540 |
-
|
541 |
-
# Socket.IO event handlers
|
542 |
-
@socketio.on('connect')
|
543 |
-
def handle_connect():
|
544 |
-
print('Client connected')
|
545 |
-
emit('status', {'message': 'Connected to frontend-buffered demo server'})
|
546 |
-
|
547 |
-
|
548 |
-
@socketio.on('disconnect')
|
549 |
-
def handle_disconnect():
|
550 |
-
print('Client disconnected')
|
551 |
-
|
552 |
-
|
553 |
-
@socketio.on('start_generation')
|
554 |
-
def handle_start_generation(data):
|
555 |
-
global generation_active, frame_number, anim_name, frame_rate
|
556 |
-
|
557 |
-
frame_number = 0
|
558 |
-
if generation_active:
|
559 |
-
emit('error', {'message': 'Generation already in progress'})
|
560 |
-
return
|
561 |
-
|
562 |
-
prompt = data.get('prompt', '')
|
563 |
-
|
564 |
-
seed = data.get('seed', -1)
|
565 |
-
if seed==-1:
|
566 |
-
seed = random.randint(0, 2**32)
|
567 |
-
|
568 |
-
# Extract words up to the first punctuation or newline
|
569 |
-
words_up_to_punctuation = re.split(r'[^\w\s]', prompt)[0].strip() if prompt else ''
|
570 |
-
if not words_up_to_punctuation:
|
571 |
-
words_up_to_punctuation = re.split(r'[\n\r]', prompt)[0].strip()
|
572 |
-
|
573 |
-
# Calculate SHA-256 hash of the entire prompt
|
574 |
-
sha256_hash = calculate_sha256(prompt)
|
575 |
-
|
576 |
-
# Create anim_name with the extracted words and first 10 characters of the hash
|
577 |
-
anim_name = f"{words_up_to_punctuation[:20]}_{str(seed)}_{sha256_hash[:10]}"
|
578 |
-
|
579 |
-
generation_active = True
|
580 |
-
generation_start_time = time.time()
|
581 |
-
enable_torch_compile = data.get('enable_torch_compile', False)
|
582 |
-
enable_fp8 = data.get('enable_fp8', False)
|
583 |
-
use_taehv = data.get('use_taehv', False)
|
584 |
-
frame_rate = data.get('fps', 6)
|
585 |
-
|
586 |
-
if not prompt:
|
587 |
-
emit('error', {'message': 'Prompt is required'})
|
588 |
-
return
|
589 |
-
|
590 |
-
# Start generation in background thread
|
591 |
-
socketio.start_background_task(generate_video_stream, prompt, seed,
|
592 |
-
enable_torch_compile, enable_fp8, use_taehv)
|
593 |
-
emit('status', {'message': 'Generation started - frames will be sent immediately'})
|
594 |
-
|
595 |
-
|
596 |
-
@socketio.on('stop_generation')
|
597 |
-
def handle_stop_generation():
|
598 |
-
global generation_active, stop_event, frame_send_queue
|
599 |
-
generation_active = False
|
600 |
-
stop_event.set()
|
601 |
-
|
602 |
-
# Signal sender thread to stop (will be processed after current frames)
|
603 |
-
try:
|
604 |
-
frame_send_queue.put(None)
|
605 |
-
except Exception as e:
|
606 |
-
print(f"❌ Failed to put None in frame_send_queue: {e}")
|
607 |
-
|
608 |
-
emit('status', {'message': 'Generation stopped'})
|
609 |
-
|
610 |
-
# Web routes
|
611 |
-
|
612 |
-
|
613 |
-
@app.route('/')
|
614 |
-
def index():
|
615 |
-
return render_template('demo.html')
|
616 |
-
|
617 |
-
|
618 |
-
@app.route('/api/status')
|
619 |
-
def api_status():
|
620 |
-
return jsonify({
|
621 |
-
'generation_active': generation_active,
|
622 |
-
'free_vram_gb': get_cuda_free_memory_gb(gpu),
|
623 |
-
'fp8_applied': fp8_applied,
|
624 |
-
'torch_compile_applied': torch_compile_applied,
|
625 |
-
'current_use_taehv': current_use_taehv
|
626 |
-
})
|
627 |
-
|
628 |
-
|
629 |
-
if __name__ == '__main__':
|
630 |
-
print(f"🚀 Starting demo on http://{args.host}:{args.port}")
|
631 |
-
socketio.run(app, host=args.host, port=args.port, debug=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/wan_wrapper.py
CHANGED
@@ -1,8 +1,14 @@
|
|
1 |
import types
|
2 |
from typing import List, Optional
|
|
|
3 |
import torch
|
4 |
from torch import nn
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
from utils.scheduler import SchedulerInterface, FlowMatchScheduler
|
7 |
from wan.modules.tokenizers import HuggingfaceTokenizer
|
8 |
from wan.modules.model import WanModel, RegisterTokens, GanAttentionBlock
|
@@ -22,12 +28,12 @@ class WanTextEncoder(torch.nn.Module):
|
|
22 |
device=torch.device('cpu')
|
23 |
).eval().requires_grad_(False)
|
24 |
self.text_encoder.load_state_dict(
|
25 |
-
torch.load("
|
26 |
map_location='cpu', weights_only=False)
|
27 |
)
|
28 |
|
29 |
self.tokenizer = HuggingfaceTokenizer(
|
30 |
-
name="
|
31 |
|
32 |
@property
|
33 |
def device(self):
|
@@ -66,7 +72,7 @@ class WanVAEWrapper(torch.nn.Module):
|
|
66 |
|
67 |
# init model
|
68 |
self.model = _video_vae(
|
69 |
-
pretrained_path="
|
70 |
z_dim=16,
|
71 |
).eval().requires_grad_(False)
|
72 |
|
@@ -125,9 +131,9 @@ class WanDiffusionWrapper(torch.nn.Module):
|
|
125 |
|
126 |
if is_causal:
|
127 |
self.model = CausalWanModel.from_pretrained(
|
128 |
-
|
129 |
else:
|
130 |
-
self.model = WanModel.from_pretrained(
|
131 |
self.model.eval()
|
132 |
|
133 |
# For non-causal diffusion, all frames share the same timestep
|
|
|
1 |
import types
|
2 |
from typing import List, Optional
|
3 |
+
import os
|
4 |
import torch
|
5 |
from torch import nn
|
6 |
|
7 |
+
# Configuration for data paths
|
8 |
+
DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
|
9 |
+
WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
|
10 |
+
OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')
|
11 |
+
|
12 |
from utils.scheduler import SchedulerInterface, FlowMatchScheduler
|
13 |
from wan.modules.tokenizers import HuggingfaceTokenizer
|
14 |
from wan.modules.model import WanModel, RegisterTokens, GanAttentionBlock
|
|
|
28 |
device=torch.device('cpu')
|
29 |
).eval().requires_grad_(False)
|
30 |
self.text_encoder.load_state_dict(
|
31 |
+
torch.load(os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B", "models_t5_umt5-xxl-enc-bf16.pth"),
|
32 |
map_location='cpu', weights_only=False)
|
33 |
)
|
34 |
|
35 |
self.tokenizer = HuggingfaceTokenizer(
|
36 |
+
name=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B", "google", "umt5-xxl") + "/", seq_len=512, clean='whitespace')
|
37 |
|
38 |
@property
|
39 |
def device(self):
|
|
|
72 |
|
73 |
# init model
|
74 |
self.model = _video_vae(
|
75 |
+
pretrained_path=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B", "Wan2.1_VAE.pth"),
|
76 |
z_dim=16,
|
77 |
).eval().requires_grad_(False)
|
78 |
|
|
|
131 |
|
132 |
if is_causal:
|
133 |
self.model = CausalWanModel.from_pretrained(
|
134 |
+
os.path.join(WAN_MODELS_PATH, model_name) + "/", local_attn_size=local_attn_size, sink_size=sink_size)
|
135 |
else:
|
136 |
+
self.model = WanModel.from_pretrained(os.path.join(WAN_MODELS_PATH, model_name) + "/")
|
137 |
self.model.eval()
|
138 |
|
139 |
# For non-causal diffusion, all frames share the same timestep
|