""" | |
Demo for Self-Forcing. | |
""" | |
import os | |
import re | |
import random | |
import time | |
import base64 | |
import argparse | |
import hashlib | |
import subprocess | |
import urllib.request | |
from io import BytesIO | |
from PIL import Image | |
import numpy as np | |
import torch | |
from omegaconf import OmegaConf | |
from flask import Flask, render_template, jsonify | |
from flask_socketio import SocketIO, emit | |
import queue | |
from threading import Thread, Event | |
from pipeline import CausalInferencePipeline | |
from demo_utils.constant import ZERO_VAE_CACHE | |
from demo_utils.vae_block3 import VAEDecoderWrapper | |
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder | |
from demo_utils.utils import generate_timestamp | |
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation | |
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=5001)
parser.add_argument('--host', type=str, default='0.0.0.0')
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt')
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml')
parser.add_argument('--trt', action='store_true')
args = parser.parse_args()

print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
low_memory = get_cuda_free_memory_gb(gpu) < 40

# Load models
config = OmegaConf.load(args.config_path)
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, config)

text_encoder = WanTextEncoder()

# Global variables for dynamic model switching
current_vae_decoder = None
current_use_taehv = False
fp8_applied = False
torch_compile_applied = False
frame_number = 0
anim_name = ""
frame_rate = 6
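
# Note: frame_number and anim_name drive the on-disk JPEG naming in tensor_to_base64_frame(),
# and frame_rate is the playback rate passed to ffmpeg when the final MP4 is assembled.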

def initialize_vae_decoder(use_taehv=False, use_trt=False):
    """Initialize the VAE decoder for the selected option (default VAE, TAEHV, or TensorRT)."""
    global current_vae_decoder, current_use_taehv

    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        current_vae_decoder = VAETRTWrapper()
        return current_vae_decoder

    if use_taehv:
        from demo_utils.taehv import TAEHV
        # Download taew2_1.pth into the checkpoints folder if it is missing
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"taew2_1.pth not found at {taehv_checkpoint_path}. Downloading...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
                print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
            except Exception as e:
                print(f"Failed to download taew2_1.pth: {e}")
                raise

        class DotDict(dict):
            __getattr__ = dict.__getitem__
            __setattr__ = dict.__setitem__

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)

            def decode(self, latents, return_dict=None):
                # latents shape: (n, c, t, h, w)
                # parallel=False keeps memory low; parallel=True is faster at higher memory cost
                return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)

        current_vae_decoder = TAEHVDiffusersWrapper()
    else:
        current_vae_decoder = VAEDecoderWrapper()
        vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
        decoder_state_dict = {}
        for key, value in vae_state_dict.items():
            if 'decoder.' in key or 'conv2' in key:
                decoder_state_dict[key] = value
        current_vae_decoder.load_state_dict(decoder_state_dict)

    current_vae_decoder.eval()
    current_vae_decoder.to(dtype=torch.float16)
    current_vae_decoder.requires_grad_(False)
    current_vae_decoder.to(gpu)
    current_use_taehv = use_taehv

    print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
    return current_vae_decoder
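
# Three decoder paths are supported:
#   * default: the Wan2.1 VAE decoder, built from the decoder.* / conv2 keys of Wan2.1_VAE.pth
#   * --trt:   a TensorRT-compiled decoder (demo_utils.vae.VAETRTWrapper)
#   * TAEHV:   madebyollin's tiny autoencoder checkpoint (taew2_1.pth), a lighter-weight approximate decoder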

# Initialize with the default VAE decoder
vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

transformer = WanDiffusionWrapper(is_causal=True)
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
transformer.load_state_dict(state_dict['generator_ema'])

text_encoder.eval()
transformer.eval()

transformer.to(dtype=torch.float16)
text_encoder.to(dtype=torch.bfloat16)

text_encoder.requires_grad_(False)
transformer.requires_grad_(False)

pipeline = CausalInferencePipeline(
    config,
    device=gpu,
    generator=transformer,
    text_encoder=text_encoder,
    vae=vae_decoder
)

if low_memory:
    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
else:
    text_encoder.to(gpu)
transformer.to(gpu)

# Flask and SocketIO setup
app = Flask(__name__)
app.config['SECRET_KEY'] = 'frontend_buffered_demo'
socketio = SocketIO(app, cors_allowed_origins="*")

generation_active = False
stop_event = Event()
frame_send_queue = queue.Queue()
sender_thread = None
models_compiled = False
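
# Streaming architecture: generate_video_stream() runs as a SocketIO background task and pushes
# decoded frames into frame_send_queue; frame_sender_worker() drains the queue on its own thread
# and emits each frame to the browser as a 'frame_ready' event, so slow network I/O never blocks
# the GPU generation loop.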

def tensor_to_base64_frame(frame_tensor):
    """Convert a single frame tensor to a base64 JPEG data URI, and save a numbered JPEG to disk."""
    global frame_number, anim_name

    # Map from [-1, 1] to [0, 255]
    frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
    frame = frame.to(torch.uint8).cpu().numpy()

    # CHW -> HWC
    if len(frame.shape) == 3:
        frame = np.transpose(frame, (1, 2, 0))

    # Convert to PIL Image
    if frame.shape[2] == 3:  # RGB
        image = Image.fromarray(frame, 'RGB')
    else:  # Handle other formats
        image = Image.fromarray(frame)

    # Encode as base64 and also save a numbered JPEG for the final MP4
    buffer = BytesIO()
    image.save(buffer, format='JPEG', quality=100)
    if not os.path.exists("./images/%s" % anim_name):
        os.makedirs("./images/%s" % anim_name)
    frame_number += 1
    image.save("./images/%s/%s_%03d.jpg" % (anim_name, anim_name, frame_number))
    img_str = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"

def frame_sender_worker():
    """Background thread that drains the frame send queue without blocking generation."""
    global frame_send_queue, generation_active, stop_event

    print("📡 Frame sender thread started")

    while True:
        frame_data = None
        try:
            # Get frame data from the queue
            frame_data = frame_send_queue.get(timeout=1.0)
            if frame_data is None:  # Shutdown signal
                frame_send_queue.task_done()
                break

            frame_tensor, frame_index, block_index, job_id = frame_data

            # Convert tensor to base64
            base64_frame = tensor_to_base64_frame(frame_tensor)

            # Send via SocketIO
            try:
                socketio.emit('frame_ready', {
                    'data': base64_frame,
                    'frame_index': frame_index,
                    'block_index': block_index,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"⚠️ Failed to send frame {frame_index}: {e}")

            frame_send_queue.task_done()

        except queue.Empty:
            # No frames for a second; exit once generation has finished and the queue is drained
            if not generation_active and frame_send_queue.empty():
                break
        except Exception as e:
            print(f"❌ Frame sender error: {e}")
            # Mark the task as done even on error so queue.join() cannot hang
            if frame_data is not None:
                try:
                    frame_send_queue.task_done()
                except Exception as e:
                    print(f"❌ Failed to mark frame task as done: {e}")
            break

    print("📡 Frame sender thread stopped")

def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
    """Generate video block by block and push frames to the frontend as soon as they are decoded."""
    global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name

    try:
        generation_active = True
        stop_event.clear()
        job_id = generate_timestamp()

        # Start the frame sender thread if it is not already running
        if sender_thread is None or not sender_thread.is_alive():
            sender_thread = Thread(target=frame_sender_worker, daemon=True)
            sender_thread.start()

        # Emit progress updates
        def emit_progress(message, progress):
            try:
                socketio.emit('progress', {
                    'message': message,
                    'progress': progress,
                    'job_id': job_id
                })
            except Exception as e:
                print(f"❌ Failed to emit progress: {e}")

        emit_progress('Starting generation...', 0)

        # Handle VAE decoder switching
        if use_taehv != current_use_taehv:
            emit_progress('Switching VAE decoder...', 2)
            print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
            current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
            # Update the pipeline with the new VAE decoder
            pipeline.vae = current_vae_decoder

        # Handle FP8 quantization
        if enable_fp8 and not fp8_applied:
            emit_progress('Applying FP8 quantization...', 3)
            print("🔧 Applying FP8 quantization to transformer")
            from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
            quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
            fp8_applied = True

        # Text encoding
        emit_progress('Encoding text prompt...', 8)
        conditional_dict = text_encoder(text_prompts=[prompt])
        for key, value in conditional_dict.items():
            conditional_dict[key] = value.to(dtype=torch.float16)
        if low_memory:
            gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
            move_model_to_device_with_memory_preservation(
                text_encoder, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)

        # Handle torch.compile if enabled
        torch_compile_applied = enable_torch_compile
        if enable_torch_compile and not models_compiled:
            # Compile the transformer and the VAE decoder
            transformer.compile(mode="max-autotune-no-cudagraphs")
            if not current_use_taehv and not low_memory and not args.trt:
                current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")

        # Initialize generation
        emit_progress('Initializing generation...', 12)
        rnd = torch.Generator(gpu).manual_seed(seed)
        # all_latents = torch.zeros([1, 21, 16, 60, 104], device=gpu, dtype=torch.bfloat16)
        pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
        pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)
        noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
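
        # Latent noise layout: [batch, temporal latents, channels, height/8, width/8].
        # The 21 temporal latents and 60x104 spatial grid correspond to the demo's 480x832 output;
        # the exact frame count depends on the VAE's temporal compression factor.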

        # Generation parameters
        num_blocks = 7
        current_start_frame = 0
        num_input_frames = 0
        all_num_frames = [pipeline.num_frame_per_block] * num_blocks
        if current_use_taehv:
            vae_cache = None
        else:
            vae_cache = ZERO_VAE_CACHE
            for i in range(len(vae_cache)):
                vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)

        total_frames_sent = 0
        generation_start_time = time.time()

        emit_progress('Generating frames... (frontend handles timing)', 15)
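
        # Blocks are generated autoregressively: each block of pipeline.num_frame_per_block latent
        # frames is denoised, decoded, and streamed before the next block starts, conditioned on the
        # previous blocks through the transformer's KV cache.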
        for idx, current_num_frames in enumerate(all_num_frames):
            if not generation_active or stop_event.is_set():
                break

            progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15

            # Special message for the first block when torch.compile is enabled
            if idx == 0 and torch_compile_applied and not models_compiled:
                emit_progress(
                    f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
                print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
                models_compiled = True
            else:
                emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
                print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")

            block_start_time = time.time()
            noisy_input = noise[:, current_start_frame -
                                num_input_frames:current_start_frame + current_num_frames - num_input_frames]

            # Denoising loop
            denoising_start = time.time()
            for index, current_timestep in enumerate(pipeline.denoising_step_list):
                if not generation_active or stop_event.is_set():
                    break

                timestep = torch.ones([1, current_num_frames], device=noise.device,
                                      dtype=torch.int64) * current_timestep

                if index < len(pipeline.denoising_step_list) - 1:
                    _, denoised_pred = transformer(
                        noisy_image_or_video=noisy_input,
                        conditional_dict=conditional_dict,
                        timestep=timestep,
                        kv_cache=pipeline.kv_cache1,
                        crossattn_cache=pipeline.crossattn_cache,
                        current_start=current_start_frame * pipeline.frame_seq_length
                    )
                    next_timestep = pipeline.denoising_step_list[index + 1]
                    noisy_input = pipeline.scheduler.add_noise(
                        denoised_pred.flatten(0, 1),
                        torch.randn_like(denoised_pred.flatten(0, 1)),
                        next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                    ).unflatten(0, denoised_pred.shape[:2])
                else:
                    _, denoised_pred = transformer(
                        noisy_image_or_video=noisy_input,
                        conditional_dict=conditional_dict,
                        timestep=timestep,
                        kv_cache=pipeline.kv_cache1,
                        crossattn_cache=pipeline.crossattn_cache,
                        current_start=current_start_frame * pipeline.frame_seq_length
                    )

            if not generation_active or stop_event.is_set():
                break

            denoising_time = time.time() - denoising_start
            print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")
            # Record output
            # all_latents[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred

            # Update the KV cache for the next block
            if idx != len(all_num_frames) - 1:
                transformer(
                    noisy_image_or_video=denoised_pred,
                    conditional_dict=conditional_dict,
                    timestep=torch.zeros_like(timestep),
                    kv_cache=pipeline.kv_cache1,
                    crossattn_cache=pipeline.crossattn_cache,
                    current_start=current_start_frame * pipeline.frame_seq_length,
                )
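            # The timestep-zero pass above re-runs the transformer on the finished block purely to
            # write its clean latents into kv_cache1, so later blocks can attend to this context;
            # it is skipped after the final block.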

            # Decode to pixels and send frames immediately
            print(f"🎨 Decoding block {idx+1} to pixels...")
            decode_start = time.time()
            if args.trt:
                all_current_pixels = []
                for i in range(denoised_pred.shape[1]):
                    is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
                        torch.tensor(0.0).cuda().half()
                    outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
                    # outputs = vae_decoder.forward(denoised_pred.float(), *vae_cache)
                    current_pixels, vae_cache = outputs[0], outputs[1:]
                    print(current_pixels.max(), current_pixels.min())
                    all_current_pixels.append(current_pixels.clone())
                pixels = torch.cat(all_current_pixels, dim=1)
                if idx == 0:
                    pixels = pixels[:, 3:, :, :, :]  # Skip the first 3 frames of the first block
            else:
                if current_use_taehv:
                    if vae_cache is None:
                        vae_cache = denoised_pred
                    else:
                        denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
                        vae_cache = denoised_pred[:, -3:, :, :, :]
                    pixels = current_vae_decoder.decode(denoised_pred)
                    print(f"denoised_pred shape: {denoised_pred.shape}")
                    print(f"pixels shape: {pixels.shape}")
                    if idx == 0:
                        pixels = pixels[:, 3:, :, :, :]  # Skip the first 3 frames of the first block
                    else:
                        pixels = pixels[:, 12:, :, :, :]  # Drop the frames decoded from the cached latents
                else:
                    pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
                    if idx == 0:
                        pixels = pixels[:, 3:, :, :, :]  # Skip the first 3 frames of the first block

            decode_time = time.time() - decode_start
            print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")

            # Queue frames for non-blocking sending
            block_frames = pixels.shape[1]
            print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
            queue_start = time.time()

            for frame_idx in range(block_frames):
                if not generation_active or stop_event.is_set():
                    break
                frame_tensor = pixels[0, frame_idx].cpu()
                # Queue the frame without blocking the generation loop
                frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
                total_frames_sent += 1

            queue_time = time.time() - queue_start
            block_time = time.time() - block_start_time
            print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")

            current_start_frame += current_num_frames

        generation_time = time.time() - generation_start_time
        print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")

        # Wait for all frames to be sent before completing
        emit_progress('Waiting for all frames to be sent...', 97)
        print("⏳ Waiting for all frames to be sent...")
        frame_send_queue.join()  # Wait for all queued frames to be processed
        print("✅ All frames sent successfully!")

        # Assemble the saved JPEG frames into an MP4; make sure the output directory exists
        os.makedirs("./videos", exist_ok=True)
        generate_mp4_from_images("./images", "./videos/" + anim_name + ".mp4", frame_rate)

        # Final progress update
        emit_progress('Generation complete!', 100)
        try:
            socketio.emit('generation_complete', {
                'message': 'Video generation completed!',
                'total_frames': total_frames_sent,
                'generation_time': f"{generation_time:.2f}s",
                'job_id': job_id
            })
        except Exception as e:
            print(f"❌ Failed to emit generation complete: {e}")

    except Exception as e:
        print(f"❌ Generation failed: {e}")
        try:
            socketio.emit('error', {
                'message': f'Generation failed: {str(e)}',
                'job_id': job_id
            })
        except Exception as e:
            print(f"❌ Failed to emit error: {e}")
    finally:
        generation_active = False
        stop_event.set()
        # Signal the sender thread to shut down once the queued frames are processed
        try:
            frame_send_queue.put(None)
        except Exception as e:
            print(f"❌ Failed to put None in frame_send_queue: {e}")

def generate_mp4_from_images(image_directory, output_video_path, fps=24):
    """
    Generate an MP4 video from a directory of numbered images.

    :param image_directory: Path to the directory containing the per-animation image folders.
    :param output_video_path: Path where the output MP4 will be saved.
    :param fps: Frames per second for the output video.
    """
    global anim_name
    # Construct the ffmpeg command
    cmd = [
        'ffmpeg',
        '-framerate', str(fps),
        '-i', os.path.join(image_directory, anim_name + '/' + anim_name + '_%03d.jpg'),  # Adjust the pattern if necessary
        '-c:v', 'libx264',
        '-pix_fmt', 'yuv420p',
        output_video_path
    ]
    try:
        subprocess.run(cmd, check=True)
        print(f"Video saved to {output_video_path}")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred: {e}")

def calculate_sha256(data):
    """Return the SHA-256 hex digest of a string or bytes object."""
    # Convert data to bytes if it is not already
    if isinstance(data, str):
        data = data.encode()
    # Calculate the SHA-256 hash
    sha256_hash = hashlib.sha256(data).hexdigest()
    return sha256_hash

# Socket.IO event handlers (event names are assumed to match the names emitted by the frontend)
@socketio.on('connect')
def handle_connect():
    print('Client connected')
    emit('status', {'message': 'Connected to frontend-buffered demo server'})


@socketio.on('disconnect')
def handle_disconnect():
    print('Client disconnected')


@socketio.on('start_generation')
def handle_start_generation(data):
    global generation_active, frame_number, anim_name, frame_rate
    frame_number = 0

    if generation_active:
        emit('error', {'message': 'Generation already in progress'})
        return

    prompt = data.get('prompt', '')
    seed = data.get('seed', -1)
    if seed == -1:
        seed = random.randint(0, 2**32)

    # Use the prompt text up to the first punctuation mark (or the first line) for the animation name
    words_up_to_punctuation = re.split(r'[^\w\s]', prompt)[0].strip() if prompt else ''
    if not words_up_to_punctuation:
        words_up_to_punctuation = re.split(r'[\n\r]', prompt)[0].strip()
    # Calculate the SHA-256 hash of the entire prompt
    sha256_hash = calculate_sha256(prompt)
    # Build anim_name from the extracted words, the seed, and the first 10 characters of the hash
    anim_name = f"{words_up_to_punctuation[:20]}_{str(seed)}_{sha256_hash[:10]}"

    generation_active = True
    generation_start_time = time.time()
    enable_torch_compile = data.get('enable_torch_compile', False)
    enable_fp8 = data.get('enable_fp8', False)
    use_taehv = data.get('use_taehv', False)
    frame_rate = data.get('fps', 6)

    if not prompt:
        emit('error', {'message': 'Prompt is required'})
        return

    # Start generation in a background task
    socketio.start_background_task(generate_video_stream, prompt, seed,
                                   enable_torch_compile, enable_fp8, use_taehv)

    emit('status', {'message': 'Generation started - frames will be sent immediately'})


@socketio.on('stop_generation')
def handle_stop_generation():
    global generation_active, stop_event, frame_send_queue
    generation_active = False
    stop_event.set()

    # Signal the sender thread to stop (processed after the currently queued frames)
    try:
        frame_send_queue.put(None)
    except Exception as e:
        print(f"❌ Failed to put None in frame_send_queue: {e}")

    emit('status', {'message': 'Generation stopped'})

# Web routes (the /api/status path is an assumption; adjust it to match the frontend)
@app.route('/')
def index():
    return render_template('demo.html')


@app.route('/api/status')
def api_status():
    return jsonify({
        'generation_active': generation_active,
        'free_vram_gb': get_cuda_free_memory_gb(gpu),
        'fp8_applied': fp8_applied,
        'torch_compile_applied': torch_compile_applied,
        'current_use_taehv': current_use_taehv
    })

if __name__ == '__main__':
    print(f"🚀 Starting demo on http://{args.host}:{args.port}")
    socketio.run(app, host=args.host, port=args.port, debug=False)