Create app_broken_lora.py
app_broken_lora.py +654 -0
app_broken_lora.py
ADDED
@@ -0,0 +1,654 @@
import subprocess
# not sure why it works in the original space but says "pip not found" in mine
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

import os
from huggingface_hub import snapshot_download, hf_hub_download

# Configuration for data paths
DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')

snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B"),
    local_dir_use_symlinks=False,
    resume_download=True,
    repo_type="model"
)

hf_hub_download(
    repo_id="gdhe17/Self-Forcing",
    filename="checkpoints/self_forcing_dmd.pt",
    local_dir=OTHER_MODELS_PATH,
    local_dir_use_symlinks=False
)
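
# For reference, the downloads above are expected to land at (relative to DATA_ROOT):
#   wan_models/Wan2.1-T2V-1.3B/...                 (Wan2.1 T2V 1.3B snapshot, incl. Wan2.1_VAE.pth)
#   other_models/checkpoints/self_forcing_dmd.pt   (Self-Forcing DMD generator checkpoint)
# These locations match the defaults used by --checkpoint_path and the VAE loader below.
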
import re
import random
import argparse
import hashlib
import urllib.request
import time
from PIL import Image
import torch
import gradio as gr
from omegaconf import OmegaConf
from tqdm import tqdm
import imageio
import av
import uuid
import tempfile
import shutil
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Union

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM #, BitsAndBytesConfig
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

# LoRA Storage Configuration
STORAGE_PATH = Path(DATA_ROOT) / "storage"
LORA_PATH = STORAGE_PATH / "loras"
OUTPUT_PATH = STORAGE_PATH / "output"

# Create necessary directories
STORAGE_PATH.mkdir(parents=True, exist_ok=True)
LORA_PATH.mkdir(parents=True, exist_ok=True)
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

# Global variables for LoRA management
current_lora_id = None
current_lora_path = None

# --- Argument Parsing ---
parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
parser.add_argument("--checkpoint_path", type=str, default=os.path.join(OTHER_MODELS_PATH, 'checkpoints', 'self_forcing_dmd.pt'), help="Path to the model checkpoint.")
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
args = parser.parse_args()

gpu = "cuda"

try:
    config = OmegaConf.load(args.config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)
except FileNotFoundError as e:
    print(f"Error loading config file: {e}\nPlease ensure config files are in the correct path.")
    exit(1)

# Initialize Models
print("Initializing models...")
text_encoder = WanTextEncoder()
transformer = WanDiffusionWrapper(is_causal=True)

try:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
except FileNotFoundError as e:
    print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
    exit(1)

text_encoder.eval().to(dtype=torch.float16).requires_grad_(False)
transformer.eval().to(dtype=torch.float16).requires_grad_(False)

text_encoder.to(gpu)
transformer.to(gpu)

APP_STATE = {
    "torch_compile_applied": False,
    "fp8_applied": False,
    "current_use_taehv": False,
    "current_vae_decoder": None,
}

# Decode TAEHV frames sequentially (instead of in parallel) when memory is tight.
# Referenced by the TAEHV decoder wrapper in initialize_vae_decoder() below.
LOW_MEMORY = False

# I've tried to enable it, but I didn't notice a significant performance improvement.
ENABLE_TORCH_COMPILATION = False

# "default": The default mode, used when no mode parameter is specified. It provides a good balance between performance and overhead.
# "reduce-overhead": Minimizes Python-related overhead using CUDA graphs. However, it may increase memory usage.
# "max-autotune": Uses Triton or template-based matrix multiplications on supported devices. It takes longer to compile but optimizes for the fastest possible execution. On GPUs it enables CUDA graphs by default.
# "max-autotune-no-cudagraphs": Similar to "max-autotune", but without CUDA graphs.
TORCH_COMPILATION_MODE = "default"

# Apply torch.compile for maximum performance
if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
    print("Applying torch.compile for speed optimization...")
    transformer.compile(mode=TORCH_COMPILATION_MODE)
    APP_STATE["torch_compile_applied"] = True
    print("torch.compile applied to transformer")

def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
    """Upload a LoRA file and return a hash-based ID for future reference"""
    if file is None:
        return "", ""

    try:
        # Calculate SHA256 hash of the file
        sha256_hash = hashlib.sha256()
        with open(file.name, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                sha256_hash.update(chunk)
        file_hash = sha256_hash.hexdigest()

        # Create destination path using hash
        dest_path = LORA_PATH / f"{file_hash}.safetensors"

        # Check if file already exists
        if dest_path.exists():
            print("LoRA file already exists!")
            return file_hash, file_hash

        # Copy the file to the destination
        shutil.copy(file.name, dest_path)

        print("LoRA file uploaded!")
        return file_hash, file_hash
    except Exception as e:
        print(f"Error uploading LoRA file: {e}")
        raise gr.Error(f"Failed to upload LoRA file: {str(e)}")

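# Note: the LoRA ID is simply the SHA-256 digest of the uploaded file, so it can be
# computed ahead of time, e.g. (hypothetical filename):
#   sha256sum my_lora.safetensors
# and reused across sessions as long as the file remains in storage/loras/.
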
def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
    """Get the path to a LoRA file from its hash-based ID"""
    if not lora_id:
        return None

    # Check if file exists
    lora_path = LORA_PATH / f"{lora_id}.safetensors"
    if lora_path.exists():
        return lora_path

    return None

def manage_lora_weights(lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
    """Manage LoRA weights for the transformer model.

    Note: this only tracks which LoRA is selected and returns its file path; nothing in
    this file actually applies the weights to the transformer.
    """
    global current_lora_id, current_lora_path

    # Determine if we should use LoRA
    using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0

    # If not using LoRA but we have one loaded, clear it
    if not using_lora and current_lora_id is not None:
        print("Clearing current LoRA")
        current_lora_id = None
        current_lora_path = None
        return False, None

    # If using LoRA, check if we need to change weights
    if using_lora:
        lora_path = get_lora_file_path(lora_id)

        if not lora_path:
            print("No LoRA file with this ID was found. Using the base model instead.")

            # If we had a LoRA loaded, clear it
            if current_lora_id is not None:
                print("Clearing current LoRA")
                current_lora_id = None
                current_lora_path = None

            return False, None

        # If LoRA ID changed, update
        if lora_id != current_lora_id:
            print("Loading LoRA..")
            current_lora_id = lora_id
            current_lora_path = lora_path
        else:
            print("Using a LoRA!")

        return True, lora_path

    return False, None

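# Sketch (not wired in): one way the returned lora_path could actually be applied.
# This assumes a PEFT/diffusers-style LoRA saved as .safetensors with "lora_A"/"lora_B"
# key naming, and that the WanDiffusionWrapper exposes its torch modules; the names and
# key layout below are illustrative assumptions, not part of this repo.
#
#   from safetensors.torch import load_file
#
#   def apply_lora_weights(model: torch.nn.Module, lora_path: Path, weight: float):
#       lora_sd = load_file(str(lora_path))
#       for name, module in model.named_modules():
#           a = lora_sd.get(f"{name}.lora_A.weight")   # (r, in_features)
#           b = lora_sd.get(f"{name}.lora_B.weight")   # (out_features, r)
#           if a is None or b is None or not hasattr(module, "weight"):
#               continue
#           if (b.shape[0], a.shape[1]) != tuple(module.weight.shape):
#               continue
#           # Merge W' = W + weight * (B @ A) into the base weight in place.
#           delta = (b.to(torch.float32) @ a.to(torch.float32)) * weight
#           module.weight.data += delta.to(module.weight.dtype).to(module.weight.device)
#
# A real implementation would also need to undo the merge (or keep a copy of the base
# weights) whenever the LoRA is cleared or its weight factor changes.
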
def frames_to_ts_file(frames, filepath, fps=15):
    """
    Convert frames directly to .ts file using PyAV.

    Args:
        frames: List of numpy arrays (HWC, RGB, uint8)
        filepath: Output file path
        fps: Frames per second

    Returns:
        The filepath of the created file
    """
    if not frames:
        return filepath

    height, width = frames[0].shape[:2]

    # Create container for MPEG-TS format
    container = av.open(filepath, mode='w', format='mpegts')

    # Add video stream with optimized settings for streaming
    stream = container.add_stream('h264', rate=fps)
    stream.width = width
    stream.height = height
    stream.pix_fmt = 'yuv420p'

    # Optimize for low latency streaming
    stream.options = {
        'preset': 'ultrafast',
        'tune': 'zerolatency',
        'crf': '23',
        'profile': 'baseline',
        'level': '3.0'
    }

    try:
        for frame_np in frames:
            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
            frame = frame.reformat(format=stream.pix_fmt)
            for packet in stream.encode(frame):
                container.mux(packet)

        for packet in stream.encode():
            container.mux(packet)

    finally:
        container.close()

    return filepath

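# Example (sketch): encoding a short test clip from synthetic frames.
#   test_frames = [np.full((224, 400, 3), i * 16, dtype=np.uint8) for i in range(16)]
#   frames_to_ts_file(test_frames, "gradio_tmp/test_chunk.ts", fps=15)
# The resulting MPEG-TS chunk can then be inspected with e.g. `ffprobe gradio_tmp/test_chunk.ts`.
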
def initialize_vae_decoder(use_taehv=False, use_trt=False):
    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        print("Initializing TensorRT VAE Decoder...")
        vae_decoder = VAETRTWrapper()
        APP_STATE["current_use_taehv"] = False
    elif use_taehv:
        print("Initializing TAEHV VAE Decoder...")
        from demo_utils.taehv import TAEHV
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
            except Exception as e:
                raise RuntimeError(f"Failed to download taew2_1.pth: {e}")

        class DotDict(dict): __getattr__ = dict.get

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)
            def decode(self, latents, return_dict=None):
                return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)

        vae_decoder = TAEHVDiffusersWrapper()
        APP_STATE["current_use_taehv"] = True
    else:
        print("Initializing Default VAE Decoder...")
        vae_decoder = VAEDecoderWrapper()
        try:
            vae_state_dict = torch.load(os.path.join(WAN_MODELS_PATH, 'Wan2.1-T2V-1.3B', 'Wan2.1_VAE.pth'), map_location="cpu")
            decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
            vae_decoder.load_state_dict(decoder_state_dict)
        except FileNotFoundError:
            print("Warning: Default VAE weights not found.")
        APP_STATE["current_use_taehv"] = False

    vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)

    # Apply torch.compile to VAE decoder if enabled (following demo.py pattern)
    if APP_STATE["torch_compile_applied"] and not use_taehv and not use_trt:
        print("Applying torch.compile to VAE decoder...")
        vae_decoder.compile(mode=TORCH_COMPILATION_MODE)
        print("torch.compile applied to VAE decoder")

    APP_STATE["current_vae_decoder"] = vae_decoder
    print(f"VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")

# Initialize with default VAE
initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

# Note: this rebinds `pipeline` (previously the transformers.pipeline function imported
# above) to the causal inference pipeline instance used throughout the rest of the file.
pipeline = CausalInferencePipeline(
    config, device=gpu, generator=transformer, text_encoder=text_encoder,
    vae=APP_STATE["current_vae_decoder"]
)

pipeline.to(dtype=torch.float16).to(gpu)

@torch.no_grad()
def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5, lora_id=None, lora_weight=0.0):
    """
    Generator function that yields .ts video chunks using PyAV for streaming.
    """
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    # print(f"Starting PyAV streaming: seed: {seed}, duration: {duration}s")

    # Handle LoRA weights
    using_lora, lora_path = manage_lora_weights(lora_id, lora_weight)
    if using_lora:
        print(f"Using LoRA with weight factor {lora_weight}")
    else:
        print("Using base model (no LoRA)")

    # Setup
    conditional_dict = text_encoder(text_prompts=[prompt])
    for key, value in conditional_dict.items():
        conditional_dict[key] = value.to(dtype=torch.float16)

    rnd = torch.Generator(gpu).manual_seed(int(seed))
    pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
    pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
    # NOTE: the latent shape is hard-coded, so the width/height arguments are currently
    # accepted but ignored; the output resolution is fixed by this tensor.
    noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)

    vae_cache, latents_cache = None, None
    if not APP_STATE["current_use_taehv"] and not args.trt:
        vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]

    # Calculate number of blocks based on duration.
    # The base setup generates roughly 5 seconds with 8 blocks, so we scale proportionally.
    base_duration = 5.0  # seconds
    base_blocks = 8
    num_blocks = max(1, int(base_blocks * duration / base_duration))
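    # Example: with the UI default duration=3 this gives int(8 * 3 / 5) = 4 blocks,
    # i.e. 4 * pipeline.num_frame_per_block latent frames before VAE decoding.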

    current_start_frame = 0
    all_num_frames = [pipeline.num_frame_per_block] * num_blocks

    total_frames_yielded = 0

    # Ensure temp directory exists
    os.makedirs("gradio_tmp", exist_ok=True)

    # Generation loop
    for idx, current_num_frames in enumerate(all_num_frames):
        print(f"Processing block {idx+1}/{num_blocks}")

        noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]

        # Denoising steps
        for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
            timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
            _, denoised_pred = pipeline.generator(
                noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
                timestep=timestep, kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length
            )
            if step_idx < len(pipeline.denoising_step_list) - 1:
                next_timestep = pipeline.denoising_step_list[step_idx + 1]
                noisy_input = pipeline.scheduler.add_noise(
                    denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
                    next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                ).unflatten(0, denoised_pred.shape[:2])

        if idx < len(all_num_frames) - 1:
            pipeline.generator(
                noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
                timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length,
            )

        # Decode to pixels
        if args.trt:
            pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
        elif APP_STATE["current_use_taehv"]:
            if latents_cache is None:
                latents_cache = denoised_pred
            else:
                denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
                latents_cache = denoised_pred[:, -3:]
            pixels = pipeline.vae.decode(denoised_pred)
        else:
            pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)

        # Handle frame skipping
        if idx == 0 and not args.trt:
            pixels = pixels[:, 3:]
        elif APP_STATE["current_use_taehv"] and idx > 0:
            pixels = pixels[:, 12:]

        print(f"DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")

        # Process all frames from this block at once
        all_frames_from_block = []
        for frame_idx in range(pixels.shape[1]):
            frame_tensor = pixels[0, frame_idx]

            # Convert to numpy (HWC, RGB, uint8)
            frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
            frame_np = frame_np.to(torch.uint8).cpu().numpy()
            frame_np = np.transpose(frame_np, (1, 2, 0))  # CHW -> HWC

            all_frames_from_block.append(frame_np)
            total_frames_yielded += 1

            # Yield status update for each frame (cute tracking!)
            blocks_completed = idx
            current_block_progress = (frame_idx + 1) / pixels.shape[1]
            total_progress = (blocks_completed + current_block_progress) / num_blocks * 100

            # Cap at 100% to avoid going over
            total_progress = min(total_progress, 100.0)

            frame_status_html = (
                f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
                f"  <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
                f"  <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
                f"    <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
                f"  </div>"
                f"  <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
                f"    Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
                f"  </p>"
                f"</div>"
            )

            # Yield None for video but update status (frame-by-frame tracking)
            yield None, frame_status_html

        # Encode entire block as one chunk
        if all_frames_from_block:
            print(f"Encoding block {idx} with {len(all_frames_from_block)} frames")

            try:
                chunk_uuid = str(uuid.uuid4())[:8]
                ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
                ts_path = os.path.join("gradio_tmp", ts_filename)

                frames_to_ts_file(all_frames_from_block, ts_path, fps)

                # Calculate final progress for this block
                total_progress = (idx + 1) / num_blocks * 100

                # Yield the actual video chunk
                yield ts_path, gr.update()

            except Exception as e:
                print(f"Error encoding block {idx}: {e}")
                import traceback
                traceback.print_exc()

        current_start_frame += current_num_frames

    # Final completion status
    final_status_html = (
        f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
        f"  <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
        f"    <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
        f"  </div>"
        f"  <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
        f"    <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
        f"      Generated {total_frames_yielded} frames across {num_blocks} blocks"
        f"    </p>"
        f"    <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
        f"      Playback: {fps} FPS • Format: MPEG-TS/H.264"
        f"    </p>"
        f"  </div>"
        f"</div>"
    )
    yield None, final_status_html
    print(f"PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")

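# Example (sketch): the handler above is a plain generator, so it can be smoke-tested
# without the UI; it yields (ts_chunk_path_or_None, status_html) tuples.
#   for ts_path, _status in video_generation_handler_streaming("a cat surfing a wave", seed=0, duration=3):
#       if ts_path:
#           print("wrote chunk:", ts_path)
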
# --- Gradio UI Layout ---
with gr.Blocks(title="Wan2.1 1.3B LoRA Self-Forcing streaming demo") as demo:
    gr.Markdown("# Run Any LoRA in near real-time!")
    gr.Markdown("Real-time video generation with distilled Wan2.1 1.3B and LoRA [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")

    with gr.Tabs():
        # LoRA Upload Tab
        with gr.TabItem("1. Upload LoRA"):
            gr.Markdown("## Upload LoRA Weights")
            gr.Markdown("Upload your custom LoRA weights file to use for generation. The file will be automatically stored and you'll receive a unique hash-based ID.")

            with gr.Row():
                lora_file = gr.File(label="LoRA File (safetensors format)")

            with gr.Row():
                lora_id_output = gr.Textbox(label="LoRA Hash ID (use this in the generation tab)", interactive=False)

        # Video Generation Tab
        with gr.TabItem("2. Generate Video"):
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Group():
                        prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="A stylish woman walks down a Tokyo street...",
                            lines=4,
                            value=""
                        )

                        start_btn = gr.Button("Start Streaming", variant="primary", size="lg")

                    gr.Markdown("### Settings")
                    with gr.Row():
                        seed = gr.Number(
                            label="Seed",
                            value=-1,
                            info="Use -1 for random seed",
                            precision=0
                        )
                        fps = gr.Slider(
                            label="Playback FPS",
                            minimum=1,
                            maximum=30,
                            value=args.fps,
                            step=1,
                            visible=False,
                            info="Frames per second for playback"
                        )

                    with gr.Row():
                        duration = gr.Slider(
                            label="Duration (seconds)",
                            minimum=1,
                            maximum=5,
                            value=3,
                            step=1,
                            info="Video duration in seconds"
                        )

                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=224,
                            maximum=720,
                            value=400,
                            step=8,
                            info="Video width in pixels (8px steps)"
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=224,
                            maximum=720,
                            value=224,
                            step=8,
                            info="Video height in pixels (8px steps)"
                        )

                    gr.Markdown("### LoRA Settings")
                    lora_id = gr.Textbox(
                        label="LoRA ID (from upload tab)",
                        placeholder="Enter your LoRA ID here...",
                    )

                    lora_weight = gr.Slider(
                        label="LoRA Weight",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        value=1.0,
                        info="Strength of LoRA influence"
                    )

                with gr.Column(scale=3):
                    gr.Markdown("### Video Stream")

                    streaming_video = gr.Video(
                        label="Live Stream",
                        streaming=True,
                        loop=True,
                        height=400,
                        autoplay=True,
                        show_label=False
                    )

                    status_display = gr.HTML(
                        value=(
                            "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
                            "Ready to start streaming...<br>"
                            "<small>Configure your prompt and click 'Start Streaming'</small>"
                            "</div>"
                        ),
                        label="Generation Status"
                    )

    # Connect the generator to the streaming video
    start_btn.click(
        fn=video_generation_handler_streaming,
        inputs=[prompt, seed, fps, width, height, duration, lora_id, lora_weight],
        outputs=[streaming_video, status_display]
    )

    # Connect LoRA upload to both display fields
    lora_file.change(
        fn=upload_lora_file,
        inputs=[lora_file],
        outputs=[lora_id_output, lora_id]
    )


# --- Launch App ---
if __name__ == "__main__":
    if os.path.exists("gradio_tmp"):
        import shutil
        shutil.rmtree("gradio_tmp")
    os.makedirs("gradio_tmp", exist_ok=True)

    print("Starting Self-Forcing Streaming Demo")
    print("Temporary files will be stored in: gradio_tmp/")
    print("Chunk encoding: PyAV (MPEG-TS/H.264)")
    print(f"GPU acceleration: {gpu}")

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        show_error=True,
        max_threads=40,
        mcp_server=True
    )