jbilcke-hf HF Staff committed on
Commit
54eccd3
·
verified ·
1 Parent(s): b55bb25

Create app_broken_lora.py

Files changed (1)
  1. app_broken_lora.py +654 -0
app_broken_lora.py ADDED
@@ -0,0 +1,654 @@
+ import subprocess
+ # Not sure why this works in the original Space but fails with "pip not found" in mine.
+ #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ import os
+ from huggingface_hub import snapshot_download, hf_hub_download
+
+ # Configuration for data paths
+ DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
+ WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
+ OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')
+
+ snapshot_download(
+     repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+     local_dir=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B"),
+     local_dir_use_symlinks=False,
+     resume_download=True,
+     repo_type="model"
+ )
+
+ hf_hub_download(
+     repo_id="gdhe17/Self-Forcing",
+     filename="checkpoints/self_forcing_dmd.pt",
+     local_dir=OTHER_MODELS_PATH,
+     local_dir_use_symlinks=False
+ )
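+ # Note: local_dir_use_symlinks and resume_download appear to be deprecated no-ops in recent
+ # huggingface_hub releases; they are kept above for compatibility with older versions.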
+ import re
+ import random
+ import argparse
+ import hashlib
+ import urllib.request
+ import time
+ from PIL import Image
+ import torch
+ import gradio as gr
+ from omegaconf import OmegaConf
+ from tqdm import tqdm
+ import imageio
+ import av
+ import uuid
+ import tempfile
+ import shutil
+ from pathlib import Path
+ from typing import Dict, Any, List, Optional, Tuple, Union
+
+ from pipeline import CausalInferencePipeline
+ from demo_utils.constant import ZERO_VAE_CACHE
+ from demo_utils.vae_block3 import VAEDecoderWrapper
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
+
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM #, BitsAndBytesConfig
+ import numpy as np
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # LoRA Storage Configuration
+ STORAGE_PATH = Path(DATA_ROOT) / "storage"
+ LORA_PATH = STORAGE_PATH / "loras"
+ OUTPUT_PATH = STORAGE_PATH / "output"
+
+ # Create necessary directories
+ STORAGE_PATH.mkdir(parents=True, exist_ok=True)
+ LORA_PATH.mkdir(parents=True, exist_ok=True)
+ OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
+
+ # Global variables for LoRA management
+ current_lora_id = None
+ current_lora_path = None
+
+ # --- Argument Parsing ---
+ parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
+ parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
+ parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
+ parser.add_argument("--checkpoint_path", type=str, default=os.path.join(OTHER_MODELS_PATH, 'checkpoints', 'self_forcing_dmd.pt'), help="Path to the model checkpoint.")
+ parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
+ parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
+ parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
+ parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
+ args = parser.parse_args()
+
+ gpu = "cuda"
+
+ try:
+     config = OmegaConf.load(args.config_path)
+     default_config = OmegaConf.load("configs/default_config.yaml")
+     config = OmegaConf.merge(default_config, config)
+ except FileNotFoundError as e:
+     print(f"Error loading config file: {e}\nPlease ensure the config files are in the correct path.")
+     exit(1)
+
+ # Initialize Models
+ print("Initializing models...")
+ text_encoder = WanTextEncoder()
+ transformer = WanDiffusionWrapper(is_causal=True)
+
+ try:
+     state_dict = torch.load(args.checkpoint_path, map_location="cpu")
+     transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
+ except FileNotFoundError as e:
+     print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
+     exit(1)
+
+ text_encoder.eval().to(dtype=torch.float16).requires_grad_(False)
+ transformer.eval().to(dtype=torch.float16).requires_grad_(False)
+
+ text_encoder.to(gpu)
+ transformer.to(gpu)
+
+ APP_STATE = {
+     "torch_compile_applied": False,
+     "fp8_applied": False,
+     "current_use_taehv": False,
+     "current_vae_decoder": None,
+ }
+
+ # I've tried enabling it, but I didn't notice a significant performance improvement.
+ ENABLE_TORCH_COMPILATION = False
+
+ # LOW_MEMORY is used by the TAEHV decode path in initialize_vae_decoder() below but was never
+ # defined in this file; default it to False here so that path does not raise a NameError.
+ LOW_MEMORY = False
+
+ # "default": The default mode, used when no mode parameter is specified. It provides a good balance between performance and overhead.
+ # "reduce-overhead": Minimizes Python-related overhead using CUDA graphs. However, it may increase memory usage.
+ # "max-autotune": Uses Triton or template-based matrix multiplications on supported devices. It takes longer to compile but optimizes for the fastest possible execution. On GPUs it enables CUDA graphs by default.
+ # "max-autotune-no-cudagraphs": Similar to "max-autotune", but without CUDA graphs.
+ TORCH_COMPILATION_MODE = "default"
+
+ # Apply torch.compile for maximum performance
+ if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
+     print("🚀 Applying torch.compile for speed optimization...")
+     transformer.compile(mode=TORCH_COMPILATION_MODE)
+     APP_STATE["torch_compile_applied"] = True
+     print("✅ torch.compile applied to transformer")
+
+ def upload_lora_file(file: tempfile._TemporaryFileWrapper) -> Tuple[str, str]:
+     """Upload a LoRA file and return a hash-based ID for future reference"""
+     if file is None:
+         return "", ""
+
+     try:
+         # Calculate SHA256 hash of the file
+         sha256_hash = hashlib.sha256()
+         with open(file.name, "rb") as f:
+             for chunk in iter(lambda: f.read(4096), b""):
+                 sha256_hash.update(chunk)
+         file_hash = sha256_hash.hexdigest()
+
+         # Create destination path using hash
+         dest_path = LORA_PATH / f"{file_hash}.safetensors"
+
+         # Check if file already exists
+         if dest_path.exists():
+             print("LoRA file already exists!")
+             return file_hash, file_hash
+
+         # Copy the file to the destination
+         shutil.copy(file.name, dest_path)
+
+         print("LoRA file uploaded!")
+         return file_hash, file_hash
+     except Exception as e:
+         print(f"Error uploading LoRA file: {e}")
+         raise gr.Error(f"Failed to upload LoRA file: {str(e)}")
+
+ def get_lora_file_path(lora_id: Optional[str]) -> Optional[Path]:
+     """Get the path to a LoRA file from its hash-based ID"""
+     if not lora_id:
+         return None
+
+     # Check if file exists
+     lora_path = LORA_PATH / f"{lora_id}.safetensors"
+     if lora_path.exists():
+         return lora_path
+
+     return None
+
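+ # Note: this helper only records which LoRA file is currently selected (ID and path); it never
+ # actually attaches the weights to the transformer (there is no load/merge call anywhere below),
+ # which appears to be why this file is named app_broken_lora.py.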
+ def manage_lora_weights(lora_id: Optional[str], lora_weight: float) -> Tuple[bool, Optional[Path]]:
+     """Manage LoRA weights for the transformer model"""
+     global current_lora_id, current_lora_path
+
+     # Determine if we should use LoRA
+     using_lora = lora_id is not None and lora_id.strip() != "" and lora_weight > 0
+
+     # If not using LoRA but we have one loaded, clear it
+     if not using_lora and current_lora_id is not None:
+         print("Clearing current LoRA")
+         current_lora_id = None
+         current_lora_path = None
+         return False, None
+
+     # If using LoRA, check if we need to change weights
+     if using_lora:
+         lora_path = get_lora_file_path(lora_id)
+
+         if not lora_path:
+             print(f"No LoRA file with ID '{lora_id}' was found. Using base model instead.")
+
+             # If we had a LoRA loaded, clear it
+             if current_lora_id is not None:
+                 print("Clearing current LoRA")
+                 current_lora_id = None
+                 current_lora_path = None
+
+             return False, None
+
+         # If LoRA ID changed, update
+         if lora_id != current_lora_id:
+             print(f"Loading LoRA {lora_id}...")
+             current_lora_id = lora_id
+             current_lora_path = lora_path
+         else:
+             print("Using the already-loaded LoRA!")
+
+         return True, lora_path
+
+     return False, None
+
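+ # Each block of frames is encoded as a self-contained MPEG-TS segment, so the streaming
+ # gr.Video component can play chunks back-to-back as they arrive.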
+ def frames_to_ts_file(frames, filepath, fps=15):
+     """
+     Convert frames directly to .ts file using PyAV.
+
+     Args:
+         frames: List of numpy arrays (HWC, RGB, uint8)
+         filepath: Output file path
+         fps: Frames per second
+
+     Returns:
+         The filepath of the created file
+     """
+     if not frames:
+         return filepath
+
+     height, width = frames[0].shape[:2]
+
+     # Create container for MPEG-TS format
+     container = av.open(filepath, mode='w', format='mpegts')
+
+     # Add video stream with optimized settings for streaming
+     stream = container.add_stream('h264', rate=fps)
+     stream.width = width
+     stream.height = height
+     stream.pix_fmt = 'yuv420p'
+
+     # Optimize for low latency streaming
+     stream.options = {
+         'preset': 'ultrafast',
+         'tune': 'zerolatency',
+         'crf': '23',
+         'profile': 'baseline',
+         'level': '3.0'
+     }
+
+     try:
+         for frame_np in frames:
+             frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
+             frame = frame.reformat(format=stream.pix_fmt)
+             for packet in stream.encode(frame):
+                 container.mux(packet)
+
+         for packet in stream.encode():
+             container.mux(packet)
+
+     finally:
+         container.close()
+
+     return filepath
+
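+ # Three decoder back-ends are supported: the TensorRT wrapper (--trt), the tiny TAEHV
+ # autoencoder (faster but lower fidelity), and the default causal Wan VAE decoder, which is
+ # decoded block by block through the vae_cache initialised from ZERO_VAE_CACHE.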
+ def initialize_vae_decoder(use_taehv=False, use_trt=False):
+     if use_trt:
+         from demo_utils.vae import VAETRTWrapper
+         print("Initializing TensorRT VAE Decoder...")
+         vae_decoder = VAETRTWrapper()
+         APP_STATE["current_use_taehv"] = False
+     elif use_taehv:
+         print("Initializing TAEHV VAE Decoder...")
+         from demo_utils.taehv import TAEHV
+         taehv_checkpoint_path = "checkpoints/taew2_1.pth"
+         if not os.path.exists(taehv_checkpoint_path):
+             print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
+             os.makedirs("checkpoints", exist_ok=True)
+             download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
+             try:
+                 urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
+             except Exception as e:
+                 raise RuntimeError(f"Failed to download taew2_1.pth: {e}")
+
+         class DotDict(dict): __getattr__ = dict.get
+
+         class TAEHVDiffusersWrapper(torch.nn.Module):
+             def __init__(self):
+                 super().__init__()
+                 self.dtype = torch.float16
+                 self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
+                 self.config = DotDict(scaling_factor=1.0)
+             def decode(self, latents, return_dict=None):
+                 return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)
+
+         vae_decoder = TAEHVDiffusersWrapper()
+         APP_STATE["current_use_taehv"] = True
+     else:
+         print("Initializing Default VAE Decoder...")
+         vae_decoder = VAEDecoderWrapper()
+         try:
+             vae_state_dict = torch.load(os.path.join(WAN_MODELS_PATH, 'Wan2.1-T2V-1.3B', 'Wan2.1_VAE.pth'), map_location="cpu")
+             decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
+             vae_decoder.load_state_dict(decoder_state_dict)
+         except FileNotFoundError:
+             print("Warning: Default VAE weights not found.")
+         APP_STATE["current_use_taehv"] = False
+
+     vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)
+
+     # Apply torch.compile to VAE decoder if enabled (following demo.py pattern)
+     if APP_STATE["torch_compile_applied"] and not use_taehv and not use_trt:
+         print("🚀 Applying torch.compile to VAE decoder...")
+         vae_decoder.compile(mode=TORCH_COMPILATION_MODE)
+         print("✅ torch.compile applied to VAE decoder")
+
+     APP_STATE["current_vae_decoder"] = vae_decoder
+     print(f"✅ VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
+
+ # Initialize with default VAE
+ initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
+
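+ # Note: this assignment shadows the (unused) `pipeline` helper imported from transformers above.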
+ pipeline = CausalInferencePipeline(
+     config, device=gpu, generator=transformer, text_encoder=text_encoder,
+     vae=APP_STATE["current_vae_decoder"]
+ )
+
+ pipeline.to(dtype=torch.float16).to(gpu)
+
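+ # Note: the width/height sliders are accepted below but not yet wired up; the latent noise shape
+ # is hardcoded to 60x104 (roughly 832x480 pixels after VAE upsampling), so the output resolution
+ # does not follow the UI values.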
+ @torch.no_grad()
+ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5, lora_id=None, lora_weight=0.0):
+     """
+     Generator function that yields .ts video chunks using PyAV for streaming.
+     """
+     if seed == -1:
+         seed = random.randint(0, 2**32 - 1)
+
+     # print(f"🎬 Starting PyAV streaming: seed: {seed}, duration: {duration}s")
+
+     # Handle LoRA weights
+     using_lora, lora_path = manage_lora_weights(lora_id, lora_weight)
+     if using_lora:
+         print(f"🎨 Using LoRA with weight factor {lora_weight}")
+     else:
+         print("🎨 Using base model (no LoRA)")
+
+     # Setup
+     conditional_dict = text_encoder(text_prompts=[prompt])
+     for key, value in conditional_dict.items():
+         conditional_dict[key] = value.to(dtype=torch.float16)
+
+     rnd = torch.Generator(gpu).manual_seed(int(seed))
+     pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
+     pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
+     noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
+
+     vae_cache, latents_cache = None, None
+     if not APP_STATE["current_use_taehv"] and not args.trt:
+         vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
+
+     # Calculate number of blocks based on duration
+     # Current setup generates approximately 5 seconds with 7 blocks,
+     # so we scale proportionally
+     base_duration = 5.0  # seconds
+     base_blocks = 7
+     num_blocks = max(1, int(base_blocks * duration / base_duration))
+
+     current_start_frame = 0
+     all_num_frames = [pipeline.num_frame_per_block] * num_blocks
+
+     total_frames_yielded = 0
+
+     # Ensure temp directory exists
+     os.makedirs("gradio_tmp", exist_ok=True)
+
+     # Generation loop
+     for idx, current_num_frames in enumerate(all_num_frames):
+         print(f"📦 Processing block {idx+1}/{num_blocks}")
+
+         noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]
+
+         # Denoising steps
+         for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
+             timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
+             _, denoised_pred = pipeline.generator(
+                 noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
+                 timestep=timestep, kv_cache=pipeline.kv_cache1,
+                 crossattn_cache=pipeline.crossattn_cache,
+                 current_start=current_start_frame * pipeline.frame_seq_length
+             )
+             if step_idx < len(pipeline.denoising_step_list) - 1:
+                 next_timestep = pipeline.denoising_step_list[step_idx + 1]
+                 noisy_input = pipeline.scheduler.add_noise(
+                     denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
+                     next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
+                 ).unflatten(0, denoised_pred.shape[:2])
+
+         if idx < len(all_num_frames) - 1:
+             pipeline.generator(
+                 noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
+                 timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
+                 crossattn_cache=pipeline.crossattn_cache,
+                 current_start=current_start_frame * pipeline.frame_seq_length,
+             )
+
+         # Decode to pixels
+         if args.trt:
+             pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
+         elif APP_STATE["current_use_taehv"]:
+             if latents_cache is None:
+                 latents_cache = denoised_pred
+             else:
+                 denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
+                 latents_cache = denoised_pred[:, -3:]
+             pixels = pipeline.vae.decode(denoised_pred)
+         else:
+             pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)
+
+         # Handle frame skipping
+         if idx == 0 and not args.trt:
+             pixels = pixels[:, 3:]
+         elif APP_STATE["current_use_taehv"] and idx > 0:
+             pixels = pixels[:, 12:]
+
+         print(f"🔍 DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")
+
+         # Process all frames from this block at once
+         all_frames_from_block = []
+         for frame_idx in range(pixels.shape[1]):
+             frame_tensor = pixels[0, frame_idx]
+
+             # Convert to numpy (HWC, RGB, uint8)
+             frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
+             frame_np = frame_np.to(torch.uint8).cpu().numpy()
+             frame_np = np.transpose(frame_np, (1, 2, 0))  # CHW -> HWC
+
+             all_frames_from_block.append(frame_np)
+             total_frames_yielded += 1
+
+             # Yield status update for each frame (cute tracking!)
+             blocks_completed = idx
+             current_block_progress = (frame_idx + 1) / pixels.shape[1]
+             total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
+
+             # Cap at 100% to avoid going over
+             total_progress = min(total_progress, 100.0)
+
+             frame_status_html = (
+                 f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
+                 f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
+                 f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
+                 f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
+                 f" </div>"
+                 f" <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
+                 f" Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
+                 f" </p>"
+                 f"</div>"
+             )
+
+             # Yield None for video but update status (frame-by-frame tracking)
+             yield None, frame_status_html
+
+         # Encode entire block as one chunk
+         if all_frames_from_block:
+             print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
+
+             try:
+                 chunk_uuid = str(uuid.uuid4())[:8]
+                 ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
+                 ts_path = os.path.join("gradio_tmp", ts_filename)
+
+                 frames_to_ts_file(all_frames_from_block, ts_path, fps)
+
+                 # Calculate final progress for this block
+                 total_progress = (idx + 1) / num_blocks * 100
+
+                 # Yield the actual video chunk
+                 yield ts_path, gr.update()
+
+             except Exception as e:
+                 print(f"⚠️ Error encoding block {idx}: {e}")
+                 import traceback
+                 traceback.print_exc()
+
+         current_start_frame += current_num_frames
+
+     # Final completion status
+     final_status_html = (
+         f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
+         f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
+         f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
+         f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
+         f" </div>"
+         f" <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
+         f" <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
+         f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
+         f" </p>"
+         f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
+         f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
+         f" </p>"
+         f" </div>"
+         f"</div>"
+     )
+     yield None, final_status_html
+     print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
+
+ # --- Gradio UI Layout ---
+ with gr.Blocks(title="Wan2.1 1.3B LoRA Self-Forcing streaming demo") as demo:
+     gr.Markdown("# 🚀 Run Any LoRA in near real-time!")
+     gr.Markdown("Real-time video generation with distilled Wan2.1 1.3B and LoRA [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
+
+     with gr.Tabs():
+         # LoRA Upload Tab
+         with gr.TabItem("1️⃣ Upload LoRA"):
+             gr.Markdown("## Upload LoRA Weights")
+             gr.Markdown("Upload your custom LoRA weights file to use for generation. The file will be automatically stored and you'll receive a unique hash-based ID.")
+
+             with gr.Row():
+                 lora_file = gr.File(label="LoRA File (safetensors format)")
+
+             with gr.Row():
+                 lora_id_output = gr.Textbox(label="LoRA Hash ID (use this in the generation tab)", interactive=False)
+
+         # Video Generation Tab
+         with gr.TabItem("2️⃣ Generate Video"):
+             with gr.Row():
+                 with gr.Column(scale=2):
+                     with gr.Group():
+                         prompt = gr.Textbox(
+                             label="Prompt",
+                             placeholder="A stylish woman walks down a Tokyo street...",
+                             lines=4,
+                             value=""
+                         )
+
+                     start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
+
+                     gr.Markdown("### ⚙️ Settings")
+                     with gr.Row():
+                         seed = gr.Number(
+                             label="Seed",
+                             value=-1,
+                             info="Use -1 for random seed",
+                             precision=0
+                         )
+                         fps = gr.Slider(
+                             label="Playback FPS",
+                             minimum=1,
+                             maximum=30,
+                             value=args.fps,
+                             step=1,
+                             visible=False,
+                             info="Frames per second for playback"
+                         )
+
+                     with gr.Row():
+                         duration = gr.Slider(
+                             label="Duration (seconds)",
+                             minimum=1,
+                             maximum=5,
+                             value=3,
+                             step=1,
+                             info="Video duration in seconds"
+                         )
+
+                     with gr.Row():
+                         width = gr.Slider(
+                             label="Width",
+                             minimum=224,
+                             maximum=720,
+                             value=400,
+                             step=8,
+                             info="Video width in pixels (8px steps)"
+                         )
+                         height = gr.Slider(
+                             label="Height",
+                             minimum=224,
+                             maximum=720,
+                             value=224,
+                             step=8,
+                             info="Video height in pixels (8px steps)"
+                         )
+
+                     gr.Markdown("### 🎨 LoRA Settings")
+                     lora_id = gr.Textbox(
+                         label="LoRA ID (from upload tab)",
+                         placeholder="Enter your LoRA ID here...",
+                     )
+
+                     lora_weight = gr.Slider(
+                         label="LoRA Weight",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.01,
+                         value=1.0,
+                         info="Strength of LoRA influence"
+                     )
+
+                 with gr.Column(scale=3):
+                     gr.Markdown("### 📺 Video Stream")
+
+                     streaming_video = gr.Video(
+                         label="Live Stream",
+                         streaming=True,
+                         loop=True,
+                         height=400,
+                         autoplay=True,
+                         show_label=False
+                     )
+
+                     status_display = gr.HTML(
+                         value=(
+                             "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
+                             "🎬 Ready to start streaming...<br>"
+                             "<small>Configure your prompt and click 'Start Streaming'</small>"
+                             "</div>"
+                         ),
+                         label="Generation Status"
+                     )
+
+     # Connect the generator to the streaming video
+     start_btn.click(
+         fn=video_generation_handler_streaming,
+         inputs=[prompt, seed, fps, width, height, duration, lora_id, lora_weight],
+         outputs=[streaming_video, status_display]
+     )
+
+     # Connect LoRA upload to both display fields
+     lora_file.change(
+         fn=upload_lora_file,
+         inputs=[lora_file],
+         outputs=[lora_id_output, lora_id]
+     )
+
+
+ # --- Launch App ---
+ if __name__ == "__main__":
+     if os.path.exists("gradio_tmp"):
+         import shutil
+         shutil.rmtree("gradio_tmp")
+     os.makedirs("gradio_tmp", exist_ok=True)
+
+     print("🚀 Starting Self-Forcing Streaming Demo")
+     print(f"📁 Temporary files will be stored in: gradio_tmp/")
+     print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
+     print(f"⚡ GPU acceleration: {gpu}")
+
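+     # mcp_server=True assumes a Gradio build with MCP support (the gradio[mcp] extra);
+     # launch() on older Gradio versions does not accept this argument.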
+     demo.queue().launch(
+         server_name=args.host,
+         server_port=args.port,
+         share=args.share,
+         show_error=True,
+         max_threads=40,
+         mcp_server=True
+     )