# Hugging Face Spaces status: Running on Zero
import os
import subprocess
import sys

print("Starting NAG Video Demo application...")

# Add the directory containing this script to the import path so the generated
# ``src`` package can be imported. ``__file__`` is undefined when the code is
# executed interactively (e.g. pasted into a REPL), so fall back to the CWD.
try:
    current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    current_dir = os.getcwd()
sys.path.insert(0, current_dir)
print(f"Added {current_dir} to Python path")

# Create src directory structure.
# NOTE: ``src`` is created relative to the current working directory, which only
# coincides with ``current_dir`` when the app is started from its own folder.
os.makedirs("src", exist_ok=True)

# Install required packages (best effort: a failed install is not fatal here).
# ``python -m pip`` targets the interpreter running this script, whereas a bare
# ``pip`` on PATH may belong to a different Python installation.
subprocess.run([sys.executable, "-m", "pip", "install", "safetensors"], check=False)

# Create __init__.py so ``src`` is an importable package.
with open("src/__init__.py", "w", encoding="utf-8") as f:
    f.write("")
print("Creating NAG transformer module...")
# Create transformer_wan_nag.py
# The module source below is (re)written on every start-up; the written file is
# what ``from src.transformer_wan_nag import ...`` loads later.
os.makedirs("src", exist_ok=True)
with open("src/transformer_wan_nag.py", "w", encoding="utf-8") as f:
    f.write('''
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
import torch.nn.functional as F


class NagWanTransformer3DModel(nn.Module):
    """NAG-enhanced Transformer for video generation (simplified demo).

    Despite its name this is a small residual 3D-convolution network with a
    timestep embedding. It exposes ``config``, ``dtype``, ``device``,
    ``attn_processors`` and ``set_attn_processor`` so it can stand in for a
    diffusers transformer inside a pipeline.
    """

    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 4,
        hidden_size: int = 64,
        num_layers: int = 1,
        num_heads: int = 4,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        # Reported dtype, tracked manually (see the ``dtype`` property / ``to``).
        self._dtype = torch.float32
        # NOTE: num_layers is accepted for signature compatibility only; the
        # network always has exactly one middle convolution.
        # Minimal config object for callers that read ``model.config.<attr>``.
        self.config = type('Config', (), {
            'in_channels': in_channels,
            'out_channels': out_channels,
            'hidden_size': hidden_size,
            'num_attention_heads': num_heads,
            'attention_head_dim': hidden_size // num_heads,
        })()
        # Simple conv layers for demo
        self.conv_in = nn.Conv3d(in_channels, hidden_size, kernel_size=3, padding=1)
        self.conv_mid = nn.Conv3d(hidden_size, hidden_size, kernel_size=3, padding=1)
        self.conv_out = nn.Conv3d(hidden_size, out_channels, kernel_size=3, padding=1)
        # Maps a scalar timestep to a per-channel bias of size hidden_size.
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        # Inference-only demo: put this module *and* its children in eval mode
        # (assigning ``self.training = False`` would only affect the parent).
        self.eval()

    @property
    def dtype(self):
        """Return the dtype of the model (as last set via ``to`` or the setter)."""
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        """Set the reported dtype (does not cast the parameters)."""
        self._dtype = value

    @property
    def device(self):
        """Device of the model parameters.

        diffusers' ``DiffusionPipeline.device`` reads ``module.device`` from
        registered components, which plain ``nn.Module`` subclasses lack.
        """
        return next(self.parameters()).device

    def to(self, *args, **kwargs):
        """Move/cast the module and keep the reported ``dtype`` in sync."""
        result = super().to(*args, **kwargs)
        for arg in args:
            if isinstance(arg, torch.dtype):
                self._dtype = arg
        if kwargs.get('dtype') is not None:
            self._dtype = kwargs['dtype']
        return result

    @staticmethod
    def attn_processors():
        """Compatibility stub: this model has no attention processors."""
        return {}

    @staticmethod
    def set_attn_processor(processor):
        """Compatibility stub: attention processors are ignored."""
        pass

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """Predict the denoising target for ``hidden_states``.

        Args:
            hidden_states: 5-D latent video (B, C_in, F, H, W), as required by
                the Conv3d layers.
            timestep: scalar, single-value or (B,) timestep(s); a single value
                is shared by the whole batch. ``None`` skips the embedding.
            encoder_hidden_states, attention_mask, **kwargs: accepted for
                pipeline compatibility and ignored by this demo model.

        Returns:
            Tensor of shape (B, C_out, F, H, W). The input is added back as a
            residual when the input and output shapes match.

        Raises:
            ValueError: if ``timestep`` holds neither 1 nor B values.
        """
        batch_size = hidden_states.shape[0]
        h = self.conv_in(hidden_states)

        if timestep is not None:
            timestep = torch.as_tensor(timestep, device=hidden_states.device).reshape(-1)
            if timestep.shape[0] == 1:
                # ``expand`` (unlike ``repeat``) yields exactly batch_size entries.
                timestep = timestep.expand(batch_size)
            elif timestep.shape[0] != batch_size:
                raise ValueError(
                    f"timestep has {timestep.shape[0]} values but the batch size is {batch_size}"
                )
            # Scale timesteps (assumed to lie in [0, 1000]) to about [0, 1] and
            # match the embedding weights' dtype (e.g. after ``.to(torch.float16)``).
            t_in = (timestep.float() / 1000.0).reshape(-1, 1)
            t_in = t_in.to(self.time_embed[0].weight.dtype)
            t_emb = self.time_embed(t_in).reshape(batch_size, -1, 1, 1, 1)
            h = h + t_emb

        h = F.silu(h)
        h = self.conv_mid(h)
        h = F.silu(h)
        h = self.conv_out(h)
        # Residual connection is only possible when channel counts agree.
        if h.shape == hidden_states.shape:
            h = h + hidden_states
        return h
''')
print("Creating NAG pipeline module...")
# Create pipeline_wan_nag.py
# The module source below is (re)written on every start-up; the written file is
# what ``from src.pipeline_wan_nag import ...`` loads later.
os.makedirs("src", exist_ok=True)
with open("src/pipeline_wan_nag.py", "w", encoding="utf-8") as f:
    f.write('''
import inspect
from types import SimpleNamespace
from typing import List, Optional, Union, Tuple, Callable, Dict, Any

import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline
from diffusers.utils import logging
from transformers import CLIPTextModel, CLIPTokenizer

logger = logging.get_logger(__name__)


class NAGWanPipeline(DiffusionPipeline):
    """NAG-enhanced pipeline for video generation (simplified demo).

    ``text_encoder`` and ``tokenizer`` may be ``None``: prompts are then
    replaced by zero placeholder embeddings, which is sufficient for a
    transformer that ignores ``encoder_hidden_states``.
    """

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        transformer,
        scheduler,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Spatial down-sampling factor between pixels and latents.
        if hasattr(self.vae, 'config') and hasattr(self.vae.config, 'block_out_channels'):
            self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        else:
            self.vae_scale_factor = 8  # Default value for most VAEs

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the text encoder, tokenizer and scheduler from a checkpoint.

        ``vae`` and ``transformer`` must be passed in ``kwargs`` (they default
        to ``None``); ``torch_dtype`` applies to the text encoder only.
        """
        vae = kwargs.pop("vae", None)
        transformer = kwargs.pop("transformer", None)
        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=torch_dtype,
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
        )
        from diffusers import UniPCMultistepScheduler
        scheduler = UniPCMultistepScheduler.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="scheduler",
        )
        return cls(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

    def _resolve_device(self):
        """Return the execution device, tolerating plain ``nn.Module`` parts.

        ``DiffusionPipeline._execution_device`` reads ``module.device``, which
        plain ``nn.Module`` components may not define; fall back to the device
        of the first parameter of the transformer or VAE.
        """
        try:
            return self._execution_device
        except AttributeError:
            for module in (self.transformer, self.vae):
                if isinstance(module, torch.nn.Module):
                    for param in module.parameters():
                        return param.device
            return torch.device("cpu")

    def _scheduler_step_kwargs(self, generator, eta):
        """Build the optional kwargs accepted by ``self.scheduler.step``.

        Not every scheduler takes ``eta``/``generator`` (``DDPMScheduler.step``
        has no ``eta``), so only the parameters present in its signature are
        forwarded.
        """
        params = inspect.signature(self.scheduler.step).parameters
        kwargs = {}
        if "eta" in params:
            kwargs["eta"] = eta
        if "generator" in params:
            kwargs["generator"] = generator
        return kwargs

    def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt=None):
        """Encode ``prompt`` (plus the unconditional prompt for CFG).

        Returns a tensor whose leading dimension is the batch size, doubled as
        ``[uncond, cond]`` when classifier-free guidance is enabled.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        if self.tokenizer is None or self.text_encoder is None:
            # No text encoder registered (demo mode): return a zero placeholder
            # with the batch layout the real path would produce.
            num_rows = batch_size * (2 if do_classifier_free_guidance else 1)
            dtype = getattr(self.transformer, "dtype", None)
            if not isinstance(dtype, torch.dtype):
                dtype = torch.float32
            return torch.zeros(num_rows, 1, 1, device=device, dtype=dtype)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_inputs.input_ids.to(device))[0]
        if do_classifier_free_guidance:
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif isinstance(negative_prompt, str):
                # Broadcast a single negative prompt over the whole batch.
                uncond_tokens = [negative_prompt] * batch_size
            else:
                uncond_tokens = negative_prompt
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
        return text_embeddings

    @staticmethod
    def _apply_nag(noise_pred, nag_scale, nag_tau, nag_alpha):
        """Demo NAG step: return ``noise_pred * (1 + nag_scale * guidance)``.

        ``guidance`` is ``nag_alpha`` times the mean of a softmax over the
        flattened positions of each (batch, channel) slice. A softmax sums to
        1, so that mean is exactly ``1 / N`` (N = positions per slice): in
        effect a uniform rescale of the prediction.
        """
        b, c = noise_pred.shape[:2]
        flat = noise_pred.reshape(b, c, -1)
        attention = F.softmax(F.normalize(flat, dim=-1) * nag_tau, dim=-1)
        guidance = attention.mean(dim=-1, keepdim=True) * nag_alpha
        guidance = guidance.reshape(b, c, *([1] * (noise_pred.ndim - 2)))
        return noise_pred + nag_scale * guidance * noise_pred

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        nag_negative_prompt: Optional[Union[str, List[str]]] = None,
        nag_scale: float = 0.0,
        nag_tau: float = 3.5,
        nag_alpha: float = 0.5,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_frames: int = 16,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        """Generate video frames for ``prompt``.

        ``nag_negative_prompt``, when given, replaces ``negative_prompt``; the
        negative prompt is only used when ``guidance_scale`` > 1
        (classifier-free guidance). ``nag_scale`` > 0 enables ``_apply_nag``.
        ``eta`` is forwarded only to schedulers whose ``step`` accepts it.
        ``output_type`` is accepted but ignored.

        Returns:
            An object with ``frames``: one list per batch item of uint8
            (H, W, C) arrays; or ``(frames,)`` when ``return_dict`` is False.

        Raises:
            ValueError: if ``prompt`` is None.
        """
        if prompt is None:
            raise ValueError("`prompt` must be a string or a list of strings")
        if nag_negative_prompt is not None:
            negative_prompt = nag_negative_prompt

        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._resolve_device()
        do_classifier_free_guidance = guidance_scale > 1.0

        text_embeddings = self._encode_prompt(
            prompt, device, do_classifier_free_guidance, negative_prompt
        )

        # Prepare latents: (B, C_latent, F, H / scale, W / scale)
        vae_config = getattr(self.vae, "config", None)
        num_channels_latents = getattr(vae_config, "latent_channels", 4)
        shape = (
            batch_size,
            num_channels_latents,
            int(num_frames),
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if latents is None:
            latents = torch.randn(
                shape,
                generator=generator,
                device=device,
                dtype=text_embeddings.dtype,
            )
        else:
            latents = latents.to(device=device, dtype=text_embeddings.dtype)
        latents = latents * self.scheduler.init_noise_sigma

        self.scheduler.set_timesteps(num_inference_steps, device=device)
        extra_step_kwargs = self._scheduler_step_kwargs(generator, eta)

        for i, t in enumerate(self.scheduler.timesteps):
            # Duplicate the latents for the [uncond, cond] halves of CFG.
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.transformer(
                latent_model_input,
                timestep=t,
                encoder_hidden_states=text_embeddings,
            )
            if nag_scale > 0:
                noise_pred = self._apply_nag(noise_pred, nag_scale, nag_tau, nag_alpha)
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # Decode latents to pixels in [0, 1].
        scaling_factor = getattr(vae_config, "scaling_factor", 0.18215)  # SD default
        latents = 1 / scaling_factor * latents
        video = self.vae.decode(latents).sample
        video = (video / 2 + 0.5).clamp(0, 1)

        # (B, C, F, H, W) float -> (B, F, H, W, C) uint8
        video = video.cpu().float().numpy()
        video = (video * 255).round().astype("uint8")
        video = video.transpose(0, 2, 3, 4, 1)
        frames = [list(batch_video) for batch_video in video]

        if not return_dict:
            return (frames,)
        return SimpleNamespace(frames=frames)
''')
print("NAG modules created successfully!")

# Files written inside ``with open(...)`` blocks are flushed and closed when the
# block exits, so they can be checked immediately; no sleep is needed.
if not os.path.exists("src/transformer_wan_nag.py"):
    raise RuntimeError("transformer_wan_nag.py not created")
if not os.path.exists("src/pipeline_wan_nag.py"):
    raise RuntimeError("pipeline_wan_nag.py not created")
print("Files verified, importing modules...")

# Now import and run the main application.
# NOTE: keep this order — ZeroGPU's ``spaces`` package is meant to be imported
# before torch/diffusers (which imports torch) touch CUDA.
import importlib
import types
import random
import spaces
import torch
import torch.nn as nn
import numpy as np
from diffusers import AutoencoderKL, UniPCMultistepScheduler, DDPMScheduler
from diffusers.utils import export_to_video
import gradio as gr
import tempfile
from huggingface_hub import hf_hub_download
import logging
import gc
import time

# The import system caches directory listings; modules written after the
# interpreter started may be missed until those caches are invalidated.
importlib.invalidate_caches()
try:
    # Import our custom modules
    from src.pipeline_wan_nag import NAGWanPipeline
    from src.transformer_wan_nag import NagWanTransformer3DModel
    print("Successfully imported NAG modules")
except Exception as e:
    print(f"Error importing NAG modules: {e}")
    print("Retrying import after invalidating import caches...")
    importlib.invalidate_caches()
    try:
        from src.pipeline_wan_nag import NAGWanPipeline
        from src.transformer_wan_nag import NagWanTransformer3DModel
        print("Successfully imported NAG modules on second attempt")
    except Exception as retry_error:
        print(f"Second import attempt failed: {retry_error}")
        print("Failed to import modules. Please restart the application.")
        sys.exit(1)
# MMAudio imports
try:
    import mmaudio
except ImportError:
    # Best-effort install of the local checkout with this interpreter's pip
    # (a bare ``pip`` on PATH may belong to a different Python).
    import importlib
    import subprocess
    subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=False)
    # Let the import system notice the newly installed files.
    # NOTE(review): editable installs register themselves through a .pth file
    # that an already-running interpreter may not process; if this import
    # still fails, a restart is needed.
    importlib.invalidate_caches()
    import mmaudio

# Set environment variables
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
# NOTE(review): huggingface_hub and diffusers are imported earlier in this file
# and may already have resolved their cache location, in which case this
# setting does not affect them — confirm, or set it before those imports.
os.environ['HF_HUB_CACHE'] = '/tmp/hub'

from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Height/width requests are snapped down to multiples of this value.
MOD_VALUE = 32

# Generation defaults
DEFAULT_DURATION_SECONDS = 1
DEFAULT_STEPS = 1
DEFAULT_SEED = 2025

# Resolution sliders (pixels)
DEFAULT_H_SLIDER_VALUE = DEFAULT_W_SLIDER_VALUE = 128
NEW_FORMULA_MAX_AREA = 128.0 * 128.0
SLIDER_MIN_H, SLIDER_MAX_H = 128, 256
SLIDER_MIN_W, SLIDER_MAX_W = 128, 256

# Largest seed offered in the UI (int32 maximum).
MAX_SEED = np.iinfo(np.int32).max

# Frame rate and frame-count bounds (reduced for the demo)
FIXED_FPS = 8
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 32

DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"

# Checkpoint identifiers of the full model, kept for reference only; the demo
# does not download them.
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
# Initialize models
print("Creating demo models...")


class DemoVAE(nn.Module):
    """Tiny randomly-initialised stand-in for a video VAE (demo only).

    ``decode`` maps 4-channel latents to 3-channel output in [-1, 1] and
    upsamples spatially by ``SCALE_FACTOR`` so the output matches the
    ``height``/``width`` the pipeline was asked for (the pipeline divides those
    by its ``vae_scale_factor``, which falls back to 8 for this VAE because its
    config has no ``block_out_channels``).
    """

    # Spatial upsampling performed by decode(); must equal the pipeline's
    # vae_scale_factor.
    SCALE_FACTOR = 8

    def __init__(self):
        super().__init__()
        self._dtype = torch.float32  # reported dtype, see the ``dtype`` property
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 4, 3, padding=1)
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(4, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1),
            nn.Tanh()  # Output in [-1, 1]
        )
        # Minimal config read by the pipeline (latent scaling / channel count).
        self.config = type('Config', (), {
            'scaling_factor': 0.18215,
            'latent_channels': 4,
        })()

    @property
    def dtype(self):
        """Return the dtype of the model (as last set via ``to`` or the setter)."""
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        """Set the reported dtype (does not cast the parameters)."""
        self._dtype = value

    @property
    def device(self):
        """Device of the parameters (diffusers pipelines read ``module.device``)."""
        param = next(self.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def to(self, *args, **kwargs):
        """Move/cast the module and keep the reported ``dtype`` in sync."""
        result = super().to(*args, **kwargs)
        for arg in args:
            if isinstance(arg, torch.dtype):
                self._dtype = arg
        if kwargs.get('dtype') is not None:
            self._dtype = kwargs['dtype']
        return result

    def encode(self, x):
        """Encode images (B, 3, H, W); result exposes ``latent_dist.sample()``."""
        encoded = self.encoder(x)
        return type('EncoderOutput', (), {'latent_dist': type('LatentDist', (), {'sample': lambda: encoded})()})()

    def _decode_images(self, z):
        """Decode (N, 4, h, w) latents to (N, 3, h * SCALE_FACTOR, w * SCALE_FACTOR)."""
        decoded = self.decoder(z)
        return torch.nn.functional.interpolate(decoded, scale_factor=self.SCALE_FACTOR, mode="nearest")

    def decode(self, z):
        """Decode latents; result exposes ``.sample``.

        Accepts video latents (B, C, F, h, w), returning (B, 3, F, 8h, 8w), or
        image latents (B, C, h, w), returning (B, 3, 8h, 8w).
        """
        if z.dim() == 5:  # Video: decode every frame as an image
            b, c, f, h, w = z.shape
            frames = z.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
            decoded = self._decode_images(frames)
            decoded = decoded.reshape(
                b, f, 3, h * self.SCALE_FACTOR, w * self.SCALE_FACTOR
            ).permute(0, 2, 1, 3, 4)
        else:  # Image: (B, C, H, W)
            decoded = self._decode_images(z)
        return type('DecoderOutput', (), {'sample': decoded})()


vae = DemoVAE()
print("Creating simplified NAG transformer model...")
# Deliberately tiny (hidden size 64, one layer, four heads) so the demo needs
# no large checkpoint downloads.
transformer = NagWanTransformer3DModel(
    in_channels=4,
    out_channels=4,
    hidden_size=64,
    num_layers=1,
    num_heads=4,
)

print("Creating pipeline...")
# DDPM noise schedule (scaled-linear betas, epsilon prediction, no clipping).
demo_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    prediction_type="epsilon",
)
# No text encoder/tokenizer are registered for the demo.
pipe = NAGWanPipeline(
    vae=vae,
    text_encoder=None,
    tokenizer=None,
    transformer=transformer,
    scheduler=demo_scheduler,
)
# Move to appropriate device (CUDA when available).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Move models to device with explicit dtype before handing over to the pipeline.
vae = vae.to(device).to(torch.float32)
transformer = transformer.to(device).to(torch.float32)

try:
    pipe = pipe.to(device)
    print(f"Pipeline moved to {device}")
except Exception as e:
    # ``_execution_device`` is a read-only property of DiffusionPipeline and
    # cannot be assigned; the components were already moved explicitly above.
    print(f"Warning: Could not move pipeline to {device}: {e}")
print("Demo version ready!")

# NOTE: ``attn_processors`` / ``set_attn_processor`` are already staticmethods
# on NagWanTransformer3DModel. Re-assigning the bare functions onto the class
# (as done previously) turned them into instance methods, so calling them
# raised TypeError; no patching is needed.
# --- Audio model setup ------------------------------------------------------
# Enable TF32 for CUDA matmuls and cuDNN convolutions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
dtype = torch.bfloat16  # precision used for the MMAudio models

# Lazily populated by load_audio_model() on first use.
audio_model = audio_net = audio_feature_utils = audio_seq_cfg = None
def load_audio_model():
    """Load the MMAudio ``small_16k`` network and feature utilities once.

    The loaded objects are cached in module globals; every call returns the
    ``(audio_net, audio_feature_utils, audio_seq_cfg)`` triple.
    """
    global audio_model, audio_net, audio_feature_utils, audio_seq_cfg

    if audio_net is not None:
        return audio_net, audio_feature_utils, audio_seq_cfg

    audio_model = all_model_cfg['small_16k']
    audio_model.download_if_needed()
    setup_eval_logging()

    cfg = audio_model.seq_cfg
    network = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
    network.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {audio_model.model_path}')

    features = FeaturesUtils(
        tod_vae_ckpt=audio_model.vae_path,
        synchformer_ckpt=audio_model.synchformer_ckpt,
        enable_conditions=True,
        mode=audio_model.mode,
        bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
        need_vae_encoder=False,
    )
    features = features.to(device, dtype).eval()

    audio_net, audio_feature_utils, audio_seq_cfg = network, features, cfg
    return audio_net, audio_feature_utils, audio_seq_cfg
# Helper functions
def cleanup_temp_files(max_age_seconds=3600.0):
    """Delete stale ``.mp4``/``.flac``/``.wav`` files from the system temp dir.

    Only regular files directly inside ``tempfile.gettempdir()`` are
    considered. Files modified within the last ``max_age_seconds`` are kept,
    so outputs that were just written (and may still be served) survive.
    Filesystem errors are ignored: this is best-effort housekeeping.

    Args:
        max_age_seconds: minimum age (by modification time) of a file before
            it is removed.
    """
    import time

    temp_dir = tempfile.gettempdir()
    cutoff = time.time() - max_age_seconds
    for filename in os.listdir(temp_dir):
        if not filename.endswith(('.mp4', '.flac', '.wav')):
            continue
        filepath = os.path.join(temp_dir, filename)
        try:
            if os.path.isfile(filepath) and os.path.getmtime(filepath) <= cutoff:
                os.remove(filepath)
        except OSError:
            # File vanished concurrently or is not removable; skip it.
            pass
def clear_cache():
    """Free cached CUDA memory (when a GPU is present), then run the GC."""
    if not torch.cuda.is_available():
        gc.collect()
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    gc.collect()
# CSS for the Gradio interface (classes referenced via ``elem_classes``).
css = """
.container {
    max-width: 1400px;
    margin: auto;
    padding: 20px;
}
.main-title {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 2.5em;
    font-weight: bold;
    margin-bottom: 10px;
}
.subtitle {
    text-align: center;
    color: #6b7280;
    margin-bottom: 30px;
}
.prompt-container {
    background: linear-gradient(135deg, #f3f4f6 0%, #e5e7eb 100%);
    border-radius: 15px;
    padding: 20px;
    margin-bottom: 20px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    font-size: 1.2em;
    font-weight: bold;
    padding: 15px 30px;
    border-radius: 10px;
    border: none;
    cursor: pointer;
    transition: all 0.3s ease;
    width: 100%;
    margin-top: 20px;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.4);
}
.video-output {
    border-radius: 15px;
    overflow: hidden;
    box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
    background: #1a1a1a;
    padding: 10px;
}
.settings-panel {
    background: #f9fafb;
    border-radius: 15px;
    padding: 20px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.05);
}
.slider-container {
    background: white;
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 15px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
.info-box {
    background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    border-left: 4px solid #667eea;
}
"""
# Initial values for the video prompt and the audio prompt boxes.
default_prompt = "A serene beach with waves gently rolling onto the shore"
default_audio_prompt, default_audio_negative_prompt = "", "music"
def get_duration(
    prompt,
    nag_negative_prompt, nag_scale,
    height, width, duration_seconds,
    steps,
    seed, randomize_seed,
    audio_mode, audio_prompt, audio_negative_prompt,
    audio_seed, audio_steps, audio_cfg_strength,
):
    """Estimate the time budget (seconds) for one ``generate_video`` call.

    Only ``duration_seconds``, ``steps`` and ``audio_mode`` influence the
    result; the remaining parameters mirror ``generate_video``'s signature
    (presumably so this can serve as a dynamic ``spaces.GPU`` duration —
    confirm). The estimate is capped at 60 seconds.
    """
    audio_extra = 30 if audio_mode == "Enable Audio" else 0
    estimate = 10 + int(duration_seconds) * int(steps) + audio_extra
    return min(60, estimate)
def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
                       audio_seed, audio_steps, audio_cfg_strength):
    """Generate an MMAudio soundtrack for ``video_path`` and mux it into a new MP4.

    Args:
        video_path: path of the source video.
        duration_sec: duration passed to ``load_video``.
        audio_prompt / audio_negative_prompt: text conditioning for MMAudio.
        audio_seed: RNG seed; ``None`` or a negative value draws a random seed.
        audio_steps: number of Euler flow-matching steps.
        audio_cfg_strength: classifier-free guidance strength.

    Returns:
        Path of a new temporary ``.mp4`` file (``delete=False``; the caller owns it).
    """
    net, feature_utils, seq_cfg = load_audio_model()

    rng = torch.Generator(device=device)
    if audio_seed is not None and audio_seed >= 0:
        # UI numbers may arrive as floats; Generator.manual_seed needs an int.
        rng.manual_seed(int(audio_seed))
    else:
        rng.seed()

    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(audio_steps))

    video_info = load_video(video_path, duration_sec)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # NOTE: seq_cfg is the cached config shared across calls; its duration is
    # overwritten for every video before the sequence lengths are updated.
    seq_cfg.duration = video_info.duration_sec
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [audio_prompt],
                      negative_text=[audio_negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=audio_cfg_strength)
    audio = audios.float().cpu()[0]

    # Reserve an output path and close the handle before make_video writes it.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
        video_with_audio_path = tmp.name
    make_video(video_info, video_with_audio_path, audio, sampling_rate=seq_cfg.sampling_rate)
    return video_with_audio_path
def _frames_for_export(frames):
    """Convert uint8 frames to float32 in [0, 1] for ``export_to_video``.

    diffusers' ``export_to_video`` multiplies ndarray frames by 255 before
    writing, so uint8 input would wrap around; other frames pass unchanged.
    """
    converted = []
    for frame in frames:
        array = np.asarray(frame)
        converted.append(array.astype(np.float32) / 255.0 if array.dtype == np.uint8 else array)
    return converted


@spaces.GPU(duration=get_duration)
def generate_video(
    prompt,
    nag_negative_prompt, nag_scale,
    height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE, duration_seconds=DEFAULT_DURATION_SECONDS,
    steps=DEFAULT_STEPS,
    seed=DEFAULT_SEED, randomize_seed=False,
    audio_mode="Video Only", audio_prompt="", audio_negative_prompt="music",
    audio_seed=-1, audio_steps=25, audio_cfg_strength=4.5,
):
    """Generate a short clip with the demo NAG pipeline, optionally adding audio.

    Runs under ``spaces.GPU`` (a GPU is attached on ZeroGPU hardware) with a
    time budget computed by ``get_duration`` from the same arguments.

    Returns:
        ``(video_path, video_with_audio_path_or_None, seed_used)``. On any error
        a short solid-red clip is returned with seed ``0``.
    """
    try:
        # Remove stale outputs *before* writing new ones: cleaning up after
        # export would delete the very files returned to Gradio below.
        cleanup_temp_files()

        # Snap the size to multiples of MOD_VALUE and clamp the frame count.
        target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
        target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
        num_frames = int(np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        # Ensure the components are on the right device and dtype (``to`` moves
        # nn.Modules in place).
        for component in (getattr(pipe, 'transformer', None), getattr(pipe, 'vae', None)):
            if component is not None:
                component.to(device).to(torch.float32)

        print(f"Generating video: {target_w}x{target_h}, {num_frames} frames, seed {current_seed}")
        with torch.inference_mode():
            nag_output_frames_list = pipe(
                prompt=prompt,
                nag_negative_prompt=nag_negative_prompt,
                nag_scale=nag_scale,
                nag_tau=3.5,
                nag_alpha=0.5,
                height=target_h, width=target_w, num_frames=num_frames,
                guidance_scale=0.,
                num_inference_steps=int(steps),
                generator=torch.Generator(device=device).manual_seed(current_seed)
            ).frames[0]

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            nag_video_path = tmpfile.name
        export_to_video(_frames_for_export(nag_output_frames_list), nag_video_path, fps=FIXED_FPS)

        # Generate audio if enabled; failures only drop the audio track.
        video_with_audio_path = None
        if audio_mode == "Enable Audio":
            try:
                video_with_audio_path = add_audio_to_video(
                    nag_video_path, duration_seconds,
                    audio_prompt, audio_negative_prompt,
                    audio_seed, audio_steps, audio_cfg_strength
                )
            except Exception as e:
                print(f"Warning: Could not generate audio: {e}")
                video_with_audio_path = None

        clear_cache()
        return nag_video_path, video_with_audio_path, current_seed

    except Exception as e:
        print(f"Error generating video: {e}")
        import traceback
        traceback.print_exc()
        # Fallback: 8 solid-red 128x128 frames, already float in [0, 1] as
        # export_to_video expects for ndarray input.
        red_frame = np.zeros((128, 128, 3), dtype=np.float32)
        red_frame[..., 0] = 1.0
        error_frames = [red_frame.copy() for _ in range(8)]
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            error_video_path = tmpfile.name
        export_to_video(error_frames, error_video_path, fps=FIXED_FPS)
        return error_video_path, None, 0
def update_audio_visibility(audio_mode):
    """Show or hide the audio widgets depending on the selected audio mode.

    The ``audio_mode.change`` listener has two outputs (the audio settings
    column and the with-audio video player), so one update is returned for
    each; returning a single update made Gradio fail to map the outputs.
    """
    visible = audio_mode == "Enable Audio"
    return gr.update(visible=visible), gr.update(visible=visible)
# Build interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        gr.HTML("""
        <h1 class="main-title">🎬 NAG Video Demo</h1>
        <p class="subtitle">Simple Text-to-Video with NAG + Audio Generation</p>
        """)
        gr.HTML("""
        <div class="info-box">
        <p>🚀 <strong>Demo Version:</strong> This is a simplified demo that demonstrates NAG concepts without large model downloads</p>
        <p>🔍 <strong>NAG Technology:</strong> Normalized Attention Guidance for enhanced video quality</p>
        <p>🎵 <strong>Audio:</strong> Optional synchronized audio generation with MMAudio</p>
        <p>⚡ <strong>Fast:</strong> Runs without downloading 28GB model files</p>
        </div>
        """)

        with gr.Row():
            # Left column: prompts and settings.
            with gr.Column(scale=1):
                with gr.Group(elem_classes="prompt-container"):
                    prompt = gr.Textbox(
                        label="✨ Video Prompt",
                        value=default_prompt,
                        placeholder="Describe your video scene...",
                        lines=2,
                        elem_classes="prompt-input"
                    )
                    with gr.Accordion("🎨 Advanced Prompt Settings", open=False):
                        nag_negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            value=DEFAULT_NAG_NEGATIVE_PROMPT,
                            lines=2,
                        )
                        nag_scale = gr.Slider(
                            label="NAG Scale",
                            minimum=0.0,
                            maximum=20.0,
                            step=0.25,
                            value=5.0,
                            info="Higher values = stronger guidance (0 = no NAG)"
                        )
                    audio_mode = gr.Radio(
                        choices=["Video Only", "Enable Audio"],
                        value="Video Only",
                        label="🎵 Audio Mode",
                        info="Enable to add audio to your generated video"
                    )
                    # Hidden until "Enable Audio" is selected.
                    with gr.Column(visible=False) as audio_settings:
                        audio_prompt = gr.Textbox(
                            label="🎵 Audio Prompt",
                            value=default_audio_prompt,
                            placeholder="Describe the audio (e.g., 'waves, seagulls', 'footsteps')",
                            lines=2
                        )
                        audio_negative_prompt = gr.Textbox(
                            label="❌ Audio Negative Prompt",
                            value=default_audio_negative_prompt,
                            lines=2
                        )
                        with gr.Row():
                            audio_seed = gr.Number(
                                label="🎲 Audio Seed",
                                value=-1,
                                precision=0,
                                minimum=-1
                            )
                            audio_steps = gr.Slider(
                                minimum=1,
                                maximum=25,
                                step=1,
                                value=10,
                                label="🔄 Audio Steps"
                            )
                            audio_cfg_strength = gr.Slider(
                                minimum=1.0,
                                maximum=10.0,
                                step=0.5,
                                value=4.5,
                                label="🎯 Audio Guidance"
                            )

                with gr.Group(elem_classes="settings-panel"):
                    gr.Markdown("### ⚙️ Video Settings")
                    with gr.Row():
                        duration_seconds_input = gr.Slider(
                            minimum=1,
                            maximum=2,
                            step=1,
                            value=DEFAULT_DURATION_SECONDS,
                            label="⏱️ Duration (seconds)",
                            elem_classes="slider-container"
                        )
                        steps_slider = gr.Slider(
                            minimum=1,
                            maximum=2,
                            step=1,
                            value=DEFAULT_STEPS,
                            label="🔄 Inference Steps",
                            elem_classes="slider-container"
                        )
                    with gr.Row():
                        height_input = gr.Slider(
                            minimum=SLIDER_MIN_H,
                            maximum=SLIDER_MAX_H,
                            step=MOD_VALUE,
                            value=DEFAULT_H_SLIDER_VALUE,
                            label=f"📏 Height (×{MOD_VALUE})",
                            elem_classes="slider-container"
                        )
                        width_input = gr.Slider(
                            minimum=SLIDER_MIN_W,
                            maximum=SLIDER_MAX_W,
                            step=MOD_VALUE,
                            value=DEFAULT_W_SLIDER_VALUE,
                            label=f"📐 Width (×{MOD_VALUE})",
                            elem_classes="slider-container"
                        )
                    with gr.Row():
                        seed_input = gr.Slider(
                            label="🌱 Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=DEFAULT_SEED,
                            interactive=True
                        )
                        randomize_seed_checkbox = gr.Checkbox(
                            label="🎲 Random Seed",
                            value=True,
                            interactive=True
                        )

                generate_button = gr.Button(
                    "🎬 Generate Video",
                    variant="primary",
                    elem_classes="generate-btn"
                )

            # Right column: outputs.
            with gr.Column(scale=1):
                nag_video_output = gr.Video(
                    label="Generated Video",
                    autoplay=True,
                    interactive=False,
                    elem_classes="video-output"
                )
                video_with_audio_output = gr.Video(
                    label="🎥 Generated Video with Audio",
                    autoplay=True,
                    interactive=False,
                    visible=False,
                    elem_classes="video-output"
                )
                gr.HTML("""
                <div style="text-align: center; margin-top: 20px; color: #6b7280;">
                <p>💡 Demo version with simplified model - Real NAG would produce higher quality results</p>
                <p>💡 Tip: Try different NAG scales for varied artistic effects!</p>
                </div>
                """)

        gr.Markdown("### 🎯 Example Prompts")
        gr.Examples(
            examples=[
                ["A cat playing guitar on stage", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
                 128, 128, 1,
                 1, DEFAULT_SEED, False,
                 "Enable Audio", "guitar music", default_audio_negative_prompt, -1, 10, 4.5],
                ["A red car driving on a cliff road", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
                 128, 128, 1,
                 1, DEFAULT_SEED, False,
                 "Enable Audio", "car engine, wind", default_audio_negative_prompt, -1, 10, 4.5],
                ["Glowing jellyfish floating in the sky", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
                 128, 128, 1,
                 1, DEFAULT_SEED, False,
                 "Video Only", "", default_audio_negative_prompt, -1, 10, 4.5],
            ],
            fn=generate_video,
            inputs=[prompt, nag_negative_prompt, nag_scale,
                    height_input, width_input, duration_seconds_input,
                    steps_slider, seed_input, randomize_seed_checkbox,
                    audio_mode, audio_prompt, audio_negative_prompt,
                    audio_seed, audio_steps, audio_cfg_strength],
            outputs=[nag_video_output, video_with_audio_output, seed_input],
            cache_examples="lazy"
        )

    # Event handlers (must be registered inside the Blocks context).
    audio_mode.change(
        fn=update_audio_visibility,
        inputs=[audio_mode],
        outputs=[audio_settings, video_with_audio_output]
    )

    # Order must match generate_video's positional parameters.
    ui_inputs = [
        prompt,
        nag_negative_prompt, nag_scale,
        height_input, width_input, duration_seconds_input,
        steps_slider,
        seed_input, randomize_seed_checkbox,
        audio_mode, audio_prompt, audio_negative_prompt,
        audio_seed, audio_steps, audio_cfg_strength,
    ]
    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=[nag_video_output, video_with_audio_output, seed_input],
    )
if __name__ == "__main__":
    # Enable the request queue, then start the web server.
    demo.queue()
    demo.launch()