ginipick committed
Commit 6d34eee · verified · 1 Parent(s): a88fb30

Update app.py

Files changed (1)
  1. app.py +269 -843
app.py CHANGED
@@ -1,431 +1,23 @@
1
- # Create src directory structure
2
- import os
3
- import sys
4
-
5
- print("Starting NAG Video Demo application...")
6
-
7
- # Add current directory to Python path
8
- try:
9
- current_dir = os.path.dirname(os.path.abspath(__file__))
10
- except:
11
- current_dir = os.getcwd()
12
-
13
- sys.path.insert(0, current_dir)
14
- print(f"Added {current_dir} to Python path")
15
-
16
- os.makedirs("src", exist_ok=True)
17
-
18
- # Install required packages
19
- os.system("pip install safetensors")
20
-
21
- # Create __init__.py
22
- with open("src/__init__.py", "w") as f:
23
- f.write("")
24
-
25
- print("Creating NAG transformer module...")
26
-
27
- # Create transformer_wan_nag.py
28
- with open("src/transformer_wan_nag.py", "w") as f:
29
- f.write('''
30
- import torch
31
- import torch.nn as nn
32
- from typing import Optional, Dict, Any
33
- import torch.nn.functional as F
34
-
35
- class NagWanTransformer3DModel(nn.Module):
36
- """NAG-enhanced Transformer for video generation (simplified demo)"""
37
-
38
- def __init__(
39
- self,
40
- in_channels: int = 4,
41
- out_channels: int = 4,
42
- hidden_size: int = 64,
43
- num_layers: int = 1,
44
- num_heads: int = 4,
45
- ):
46
- super().__init__()
47
- self.in_channels = in_channels
48
- self.out_channels = out_channels
49
- self.hidden_size = hidden_size
50
- self.training = False
51
- self._dtype = torch.float32 # Add dtype attribute
52
-
53
- # Dummy config for compatibility
54
- self.config = type('Config', (), {
55
- 'in_channels': in_channels,
56
- 'out_channels': out_channels,
57
- 'hidden_size': hidden_size,
58
- 'num_attention_heads': num_heads,
59
- 'attention_head_dim': hidden_size // num_heads,
60
- })()
61
-
62
- # Simple conv layers for demo
63
- self.conv_in = nn.Conv3d(in_channels, hidden_size, kernel_size=3, padding=1)
64
- self.conv_mid = nn.Conv3d(hidden_size, hidden_size, kernel_size=3, padding=1)
65
- self.conv_out = nn.Conv3d(hidden_size, out_channels, kernel_size=3, padding=1)
66
-
67
- # Time embedding
68
- self.time_embed = nn.Sequential(
69
- nn.Linear(1, hidden_size),
70
- nn.SiLU(),
71
- nn.Linear(hidden_size, hidden_size),
72
- )
73
-
74
- @property
75
- def dtype(self):
76
- """Return the dtype of the model"""
77
- return self._dtype
78
-
79
- @dtype.setter
80
- def dtype(self, value):
81
- """Set the dtype of the model"""
82
- self._dtype = value
83
-
84
- def to(self, *args, **kwargs):
85
- """Override to method to handle dtype"""
86
- result = super().to(*args, **kwargs)
87
- # Update dtype if moving to a specific dtype
88
- for arg in args:
89
- if isinstance(arg, torch.dtype):
90
- self._dtype = arg
91
- if 'dtype' in kwargs:
92
- self._dtype = kwargs['dtype']
93
- return result
94
-
95
- @staticmethod
96
- def attn_processors():
97
- return {}
98
-
99
- @staticmethod
100
- def set_attn_processor(processor):
101
- pass
102
-
103
- def forward(
104
- self,
105
- hidden_states: torch.Tensor,
106
- timestep: Optional[torch.Tensor] = None,
107
- encoder_hidden_states: Optional[torch.Tensor] = None,
108
- attention_mask: Optional[torch.Tensor] = None,
109
- **kwargs
110
- ):
111
- # Simple forward pass for demo
112
- batch_size = hidden_states.shape[0]
113
-
114
- # Time embedding
115
- if timestep is not None:
116
- # Ensure timestep is the right shape
117
- if timestep.ndim == 0:
118
- timestep = timestep.unsqueeze(0)
119
- if timestep.shape[0] != batch_size:
120
- timestep = timestep.repeat(batch_size)
121
-
122
- # Normalize timestep to [0, 1]
123
- t_emb = timestep.float() / 1000.0
124
- t_emb = t_emb.view(-1, 1)
125
- t_emb = self.time_embed(t_emb)
126
-
127
- # Reshape for broadcasting
128
- t_emb = t_emb.view(batch_size, -1, 1, 1, 1)
129
-
130
- # Simple convolutions
131
- h = self.conv_in(hidden_states)
132
-
133
- # Add time embedding if available
134
- if timestep is not None:
135
- h = h + t_emb
136
-
137
- h = F.silu(h)
138
- h = self.conv_mid(h)
139
- h = F.silu(h)
140
- h = self.conv_out(h)
141
-
142
- # Add residual connection
143
- h = h + hidden_states
144
-
145
- return h
146
- ''')
147
-
148
- print("Creating NAG pipeline module...")
149
-
150
- # Create pipeline_wan_nag.py
151
- with open("src/pipeline_wan_nag.py", "w") as f:
152
- f.write('''
153
- import torch
154
- import torch.nn.functional as F
155
- from typing import List, Optional, Union, Tuple, Callable, Dict, Any
156
- from diffusers import DiffusionPipeline
157
- from diffusers.utils import logging, export_to_video
158
- from diffusers.schedulers import KarrasDiffusionSchedulers
159
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
160
- from transformers import CLIPTextModel, CLIPTokenizer
161
- import numpy as np
162
-
163
- logger = logging.get_logger(__name__)
164
-
165
- class NAGWanPipeline(DiffusionPipeline):
166
- """NAG-enhanced pipeline for video generation"""
167
-
168
- def __init__(
169
- self,
170
- vae,
171
- text_encoder,
172
- tokenizer,
173
- transformer,
174
- scheduler,
175
- ):
176
- super().__init__()
177
- self.register_modules(
178
- vae=vae,
179
- text_encoder=text_encoder,
180
- tokenizer=tokenizer,
181
- transformer=transformer,
182
- scheduler=scheduler,
183
- )
184
- # Set vae scale factor
185
- if hasattr(self.vae, 'config') and hasattr(self.vae.config, 'block_out_channels'):
186
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
187
- else:
188
- self.vae_scale_factor = 8 # Default value for most VAEs
189
-
190
- @classmethod
191
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
192
- """Load pipeline from pretrained model"""
193
- vae = kwargs.pop("vae", None)
194
- transformer = kwargs.pop("transformer", None)
195
- torch_dtype = kwargs.pop("torch_dtype", torch.float32)
196
-
197
- # Load text encoder and tokenizer
198
- text_encoder = CLIPTextModel.from_pretrained(
199
- pretrained_model_name_or_path,
200
- subfolder="text_encoder",
201
- torch_dtype=torch_dtype
202
- )
203
- tokenizer = CLIPTokenizer.from_pretrained(
204
- pretrained_model_name_or_path,
205
- subfolder="tokenizer"
206
- )
207
-
208
- # Load scheduler
209
- from diffusers import UniPCMultistepScheduler
210
- scheduler = UniPCMultistepScheduler.from_pretrained(
211
- pretrained_model_name_or_path,
212
- subfolder="scheduler"
213
- )
214
-
215
- return cls(
216
- vae=vae,
217
- text_encoder=text_encoder,
218
- tokenizer=tokenizer,
219
- transformer=transformer,
220
- scheduler=scheduler,
221
- )
222
-
223
- def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt=None):
224
- """Encode text prompt to embeddings"""
225
- batch_size = len(prompt) if isinstance(prompt, list) else 1
226
-
227
- text_inputs = self.tokenizer(
228
- prompt,
229
- padding="max_length",
230
- max_length=self.tokenizer.model_max_length,
231
- truncation=True,
232
- return_tensors="pt",
233
- )
234
- text_input_ids = text_inputs.input_ids
235
- text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
236
-
237
- if do_classifier_free_guidance:
238
- uncond_tokens = [""] * batch_size if negative_prompt is None else negative_prompt
239
- uncond_input = self.tokenizer(
240
- uncond_tokens,
241
- padding="max_length",
242
- max_length=self.tokenizer.model_max_length,
243
- truncation=True,
244
- return_tensors="pt",
245
- )
246
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
247
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
248
-
249
- return text_embeddings
250
-
251
- @torch.no_grad()
252
- def __call__(
253
- self,
254
- prompt: Union[str, List[str]] = None,
255
- nag_negative_prompt: Optional[Union[str, List[str]]] = None,
256
- nag_scale: float = 0.0,
257
- nag_tau: float = 3.5,
258
- nag_alpha: float = 0.5,
259
- height: Optional[int] = 512,
260
- width: Optional[int] = 512,
261
- num_frames: int = 16,
262
- num_inference_steps: int = 50,
263
- guidance_scale: float = 7.5,
264
- negative_prompt: Optional[Union[str, List[str]]] = None,
265
- eta: float = 0.0,
266
- generator: Optional[torch.Generator] = None,
267
- latents: Optional[torch.FloatTensor] = None,
268
- output_type: Optional[str] = "pil",
269
- return_dict: bool = True,
270
- callback: Optional[Callable] = None,
271
- callback_steps: int = 1,
272
- **kwargs,
273
- ):
274
- # Use NAG negative prompt if provided
275
- if nag_negative_prompt is not None:
276
- negative_prompt = nag_negative_prompt
277
-
278
- # Setup
279
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
280
- device = self._execution_device
281
- do_classifier_free_guidance = guidance_scale > 1.0
282
-
283
- # Encode prompt
284
- text_embeddings = self._encode_prompt(
285
- prompt, device, do_classifier_free_guidance, negative_prompt
286
- )
287
-
288
- # Prepare latents
289
- if hasattr(self.vae.config, 'latent_channels'):
290
- num_channels_latents = self.vae.config.latent_channels
291
- else:
292
- num_channels_latents = 4 # Default for most VAEs
293
- shape = (
294
- batch_size,
295
- num_channels_latents,
296
- num_frames,
297
- height // self.vae_scale_factor,
298
- width // self.vae_scale_factor,
299
- )
300
-
301
- if latents is None:
302
- latents = torch.randn(
303
- shape,
304
- generator=generator,
305
- device=device,
306
- dtype=text_embeddings.dtype,
307
- )
308
- latents = latents * self.scheduler.init_noise_sigma
309
-
310
- # Set timesteps
311
- self.scheduler.set_timesteps(num_inference_steps, device=device)
312
- timesteps = self.scheduler.timesteps
313
-
314
- # Denoising loop with NAG
315
- for i, t in enumerate(timesteps):
316
- # Expand for classifier free guidance
317
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
318
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
319
-
320
- # Predict noise residual
321
- noise_pred = self.transformer(
322
- latent_model_input,
323
- timestep=t,
324
- encoder_hidden_states=text_embeddings,
325
- )
326
-
327
- # Apply NAG
328
- if nag_scale > 0:
329
- # Compute attention-based guidance
330
- b, c, f, h, w = noise_pred.shape
331
- noise_flat = noise_pred.view(b, c, -1)
332
-
333
- # Normalize and compute attention
334
- noise_norm = F.normalize(noise_flat, dim=-1)
335
- attention = F.softmax(noise_norm * nag_tau, dim=-1)
336
-
337
- # Apply guidance
338
- guidance = attention.mean(dim=-1, keepdim=True) * nag_alpha
339
- guidance = guidance.unsqueeze(-1).unsqueeze(-1)
340
- noise_pred = noise_pred + nag_scale * guidance * noise_pred
341
-
342
- # Classifier free guidance
343
- if do_classifier_free_guidance:
344
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
345
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
346
-
347
- # Compute previous noisy sample
348
- latents = self.scheduler.step(noise_pred, t, latents, eta=eta, generator=generator).prev_sample
349
-
350
- # Callback
351
- if callback is not None and i % callback_steps == 0:
352
- callback(i, t, latents)
353
-
354
- # Decode latents
355
- if hasattr(self.vae.config, 'scaling_factor'):
356
- latents = 1 / self.vae.config.scaling_factor * latents
357
- else:
358
- latents = 1 / 0.18215 * latents # Default SD scaling factor
359
- video = self.vae.decode(latents).sample
360
- video = (video / 2 + 0.5).clamp(0, 1)
361
-
362
- # Convert to output format
363
- video = video.cpu().float().numpy()
364
- video = (video * 255).round().astype("uint8")
365
- video = video.transpose(0, 2, 3, 4, 1)
366
-
367
- frames = []
368
- for batch_idx in range(video.shape[0]):
369
- batch_frames = [video[batch_idx, i] for i in range(video.shape[1])]
370
- frames.append(batch_frames)
371
-
372
- if not return_dict:
373
- return (frames,)
374
-
375
- return type('PipelineOutput', (), {'frames': frames})()
376
- ''')
377
-
378
- print("NAG modules created successfully!")
379
-
380
- # Ensure files are written and synced
381
- import time
382
- time.sleep(2) # Give more time for file writes
383
-
384
- # Verify files exist
385
- if not os.path.exists("src/transformer_wan_nag.py"):
386
- raise RuntimeError("transformer_wan_nag.py not created")
387
- if not os.path.exists("src/pipeline_wan_nag.py"):
388
- raise RuntimeError("pipeline_wan_nag.py not created")
389
-
390
- print("Files verified, importing modules...")
391
-
392
- # Now import and run the main application
393
  import types
394
  import random
395
  import spaces
396
  import torch
397
- import torch.nn as nn
398
  import numpy as np
399
- from diffusers import AutoencoderKL, UniPCMultistepScheduler, DDPMScheduler
 
400
  from diffusers.utils import export_to_video
 
401
  import gradio as gr
402
  import tempfile
403
  from huggingface_hub import hf_hub_download
404
- import logging
405
- import gc
406
-
407
- # Ensure src files are created
408
- import time
409
- time.sleep(1) # Give a moment for file writes to complete
410
 
411
- try:
412
- # Import our custom modules
413
- from src.pipeline_wan_nag import NAGWanPipeline
414
- from src.transformer_wan_nag import NagWanTransformer3DModel
415
- print("Successfully imported NAG modules")
416
- except Exception as e:
417
- print(f"Error importing NAG modules: {e}")
418
- print("Attempting to recreate modules...")
419
- # Wait a bit and try again
420
- import time
421
- time.sleep(3)
422
- try:
423
- from src.pipeline_wan_nag import NAGWanPipeline
424
- from src.transformer_wan_nag import NagWanTransformer3DModel
425
- print("Successfully imported NAG modules on second attempt")
426
- except:
427
- print("Failed to import modules. Please restart the application.")
428
- sys.exit(1)
429
 
430
  # MMAudio imports
431
  try:
@@ -434,217 +26,209 @@ except ImportError:
434
  os.system("pip install -e .")
435
  import mmaudio
436
 
437
- # Set environment variables
438
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
439
- os.environ['HF_HUB_CACHE'] = '/tmp/hub'
440
-
441
- from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
442
- setup_eval_logging)
443
  from mmaudio.model.flow_matching import FlowMatching
444
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
445
  from mmaudio.model.sequence_config import SequenceConfig
446
  from mmaudio.model.utils.features_utils import FeaturesUtils
447
 
448
- # Constants
449
  MOD_VALUE = 32
450
- DEFAULT_DURATION_SECONDS = 1
451
- DEFAULT_STEPS = 1
452
  DEFAULT_SEED = 2025
453
- DEFAULT_H_SLIDER_VALUE = 128
454
- DEFAULT_W_SLIDER_VALUE = 128
455
- NEW_FORMULA_MAX_AREA = 128.0 * 128.0
456
 
457
- SLIDER_MIN_H, SLIDER_MAX_H = 128, 256
458
- SLIDER_MIN_W, SLIDER_MAX_W = 128, 256
459
  MAX_SEED = np.iinfo(np.int32).max
460
 
461
- FIXED_FPS = 8 # Reduced FPS for demo
462
  MIN_FRAMES_MODEL = 8
463
- MAX_FRAMES_MODEL = 32 # Reduced max frames for demo
464
 
465
  DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
 
466
 
467
- # Note: Model IDs are kept for reference but not used in demo
468
  MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
469
  SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
470
  SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
471
  LORA_REPO_ID = "Kijai/WanVideo_comfy"
472
  LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
473
 
474
- # Initialize models
475
- print("Creating demo models...")
476
-
477
- # Create a simple VAE-like model for demo
478
- class DemoVAE(nn.Module):
479
- def __init__(self):
480
- super().__init__()
481
- self._dtype = torch.float32 # Add dtype attribute
482
- self.encoder = nn.Sequential(
483
- nn.Conv2d(3, 64, 3, padding=1),
484
- nn.ReLU(),
485
- nn.Conv2d(64, 4, 3, padding=1)
486
- )
487
- self.decoder = nn.Sequential(
488
- nn.Conv2d(4, 64, 3, padding=1),
489
- nn.ReLU(),
490
- nn.Conv2d(64, 3, 3, padding=1),
491
- nn.Tanh() # Output in [-1, 1]
492
- )
493
- self.config = type('Config', (), {
494
- 'scaling_factor': 0.18215,
495
- 'latent_channels': 4,
496
- })()
497
-
498
- @property
499
- def dtype(self):
500
- """Return the dtype of the model"""
501
- return self._dtype
502
-
503
- @dtype.setter
504
- def dtype(self, value):
505
- """Set the dtype of the model"""
506
- self._dtype = value
507
-
508
- def to(self, *args, **kwargs):
509
- """Override to method to handle dtype"""
510
- result = super().to(*args, **kwargs)
511
- # Update dtype if moving to a specific dtype
512
- for arg in args:
513
- if isinstance(arg, torch.dtype):
514
- self._dtype = arg
515
- if 'dtype' in kwargs:
516
- self._dtype = kwargs['dtype']
517
- return result
518
-
519
- def encode(self, x):
520
- # Simple encoding
521
- encoded = self.encoder(x)
522
- return type('EncoderOutput', (), {'latent_dist': type('LatentDist', (), {'sample': lambda: encoded})()})()
523
-
524
- def decode(self, z):
525
- # Simple decoding
526
- # Handle different input shapes
527
- if z.dim() == 5: # Video: (B, C, F, H, W)
528
- b, c, f, h, w = z.shape
529
- z = z.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
530
- decoded = self.decoder(z)
531
- decoded = decoded.reshape(b, f, 3, h * 8, w * 8).permute(0, 2, 1, 3, 4)
532
- else: # Image: (B, C, H, W)
533
- decoded = self.decoder(z)
534
- return type('DecoderOutput', (), {'sample': decoded})()
535
-
536
- vae = DemoVAE()
537
-
538
- print("Creating simplified NAG transformer model...")
539
- transformer = NagWanTransformer3DModel(
540
- in_channels=4,
541
- out_channels=4,
542
- hidden_size=64, # Reduced from 1280 for demo
543
- num_layers=1, # Reduced for demo
544
- num_heads=4 # Reduced for demo
545
- )
546
-
547
- print("Creating pipeline...")
548
- # Create a minimal pipeline for demo
549
- pipe = NAGWanPipeline(
550
- vae=vae,
551
- text_encoder=None,
552
- tokenizer=None,
553
- transformer=transformer,
554
- scheduler=DDPMScheduler(
555
- num_train_timesteps=1000,
556
- beta_start=0.00085,
557
- beta_end=0.012,
558
- beta_schedule="scaled_linear",
559
- clip_sample=False,
560
- prediction_type="epsilon",
561
- )
562
- )
563
-
564
- # Move to appropriate device
565
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
566
- print(f"Using device: {device}")
567
-
568
- # Move models to device with explicit dtype
569
- vae = vae.to(device).to(torch.float32)
570
- transformer = transformer.to(device).to(torch.float32)
571
-
572
- # Now move pipeline to device (it will handle the components)
573
- try:
574
- pipe = pipe.to(device)
575
- print(f"Pipeline moved to {device}")
576
- except Exception as e:
577
- print(f"Warning: Could not move pipeline to {device}: {e}")
578
- # Manually set device
579
- pipe._execution_device = device
580
-
581
- print("Demo version ready!")
582
-
583
- # Check if transformer has the required methods
584
- if hasattr(transformer, 'attn_processors'):
585
- pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
586
- if hasattr(transformer, 'set_attn_processor'):
587
- pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
588
-
589
- # Audio model setup
590
  torch.backends.cuda.matmul.allow_tf32 = True
591
  torch.backends.cudnn.allow_tf32 = True
592
-
593
  log = logging.getLogger()
594
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
595
  dtype = torch.bfloat16
596
 
597
- # Global audio model variables
598
- audio_model = None
599
- audio_net = None
600
- audio_feature_utils = None
601
- audio_seq_cfg = None
602
 
603
- def load_audio_model():
604
- global audio_model, audio_net, audio_feature_utils, audio_seq_cfg
 
605
 
606
- if audio_net is None:
607
- audio_model = all_model_cfg['small_16k']
608
- audio_model.download_if_needed()
609
- setup_eval_logging()
610
-
611
- seq_cfg = audio_model.seq_cfg
612
- net = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
613
- net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
614
- log.info(f'Loaded weights from {audio_model.model_path}')
615
-
616
- feature_utils = FeaturesUtils(tod_vae_ckpt=audio_model.vae_path,
617
- synchformer_ckpt=audio_model.synchformer_ckpt,
618
- enable_conditions=True,
619
- mode=audio_model.mode,
620
- bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
621
- need_vae_encoder=False)
622
- feature_utils = feature_utils.to(device, dtype).eval()
623
-
624
- audio_net = net
625
- audio_feature_utils = feature_utils
626
- audio_seq_cfg = seq_cfg
627
 
628
- return audio_net, audio_feature_utils, audio_seq_cfg
629
 
630
- # Helper functions
631
- def cleanup_temp_files():
632
- temp_dir = tempfile.gettempdir()
633
- for filename in os.listdir(temp_dir):
634
- filepath = os.path.join(temp_dir, filename)
635
- try:
636
- if filename.endswith(('.mp4', '.flac', '.wav')):
637
- os.remove(filepath)
638
- except:
639
- pass
640
 
641
- def clear_cache():
642
- if torch.cuda.is_available():
643
- torch.cuda.empty_cache()
644
- torch.cuda.synchronize()
645
- gc.collect()
646
 
647
- # CSS
648
  css = """
649
  .container {
650
  max-width: 1400px;
@@ -716,237 +300,63 @@ css = """
716
  margin: 10px 0;
717
  border-left: 4px solid #667eea;
718
  }
719
  """
720
 
721
- # RIGHT AFTER the css definition, ADD these lines:
722
- default_prompt = "A serene beach with waves gently rolling onto the shore"
723
- default_audio_prompt = ""
724
- default_audio_negative_prompt = "music"
725
-
726
-
727
- def get_duration(
728
- prompt,
729
- nag_negative_prompt, nag_scale,
730
- height, width, duration_seconds,
731
- steps,
732
- seed, randomize_seed,
733
- audio_mode, audio_prompt, audio_negative_prompt,
734
- audio_seed, audio_steps, audio_cfg_strength,
735
- ):
736
- # Simplified duration calculation for demo
737
- duration = int(duration_seconds) * int(steps) + 10
738
- if audio_mode == "Enable Audio":
739
- duration += 30 # Reduced from 60 for demo
740
- return min(duration, 60) # Cap at 60 seconds for demo
741
-
742
- @torch.inference_mode()
743
- def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
744
- audio_seed, audio_steps, audio_cfg_strength):
745
- net, feature_utils, seq_cfg = load_audio_model()
746
-
747
- rng = torch.Generator(device=device)
748
- if audio_seed >= 0:
749
- rng.manual_seed(audio_seed)
750
- else:
751
- rng.seed()
752
-
753
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=audio_steps)
754
-
755
- video_info = load_video(video_path, duration_sec)
756
- clip_frames = video_info.clip_frames.unsqueeze(0)
757
- sync_frames = video_info.sync_frames.unsqueeze(0)
758
- duration = video_info.duration_sec
759
- seq_cfg.duration = duration
760
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
761
-
762
- audios = generate(clip_frames,
763
- sync_frames, [audio_prompt],
764
- negative_text=[audio_negative_prompt],
765
- feature_utils=feature_utils,
766
- net=net,
767
- fm=fm,
768
- rng=rng,
769
- cfg_strength=audio_cfg_strength)
770
- audio = audios.float().cpu()[0]
771
-
772
- video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
773
- make_video(video_info, video_with_audio_path, audio, sampling_rate=seq_cfg.sampling_rate)
774
-
775
- return video_with_audio_path
776
-
777
- @spaces.GPU(duration=get_duration)
778
- def generate_video(
779
- prompt,
780
- nag_negative_prompt, nag_scale,
781
- height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE, duration_seconds=DEFAULT_DURATION_SECONDS,
782
- steps=DEFAULT_STEPS,
783
- seed=DEFAULT_SEED, randomize_seed=False,
784
- audio_mode="Video Only", audio_prompt="", audio_negative_prompt="music",
785
- audio_seed=-1, audio_steps=25, audio_cfg_strength=4.5,
786
- ):
787
- try:
788
- target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
789
- target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
790
-
791
- num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
792
-
793
- current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
794
-
795
- # Ensure transformer is on the right device and dtype
796
- if hasattr(pipe, 'transformer'):
797
- pipe.transformer = pipe.transformer.to(device).to(torch.float32)
798
- if hasattr(pipe, 'vae'):
799
- pipe.vae = pipe.vae.to(device).to(torch.float32)
800
-
801
- print(f"Generating video: {target_w}x{target_h}, {num_frames} frames, seed {current_seed}")
802
-
803
- with torch.inference_mode():
804
- nag_output_frames_list = pipe(
805
- prompt=prompt,
806
- nag_negative_prompt=nag_negative_prompt,
807
- nag_scale=nag_scale,
808
- nag_tau=3.5,
809
- nag_alpha=0.5,
810
- height=target_h, width=target_w, num_frames=num_frames,
811
- guidance_scale=0.,
812
- num_inference_steps=int(steps),
813
- generator=torch.Generator(device=device).manual_seed(current_seed)
814
- ).frames[0]
815
-
816
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
817
- nag_video_path = tmpfile.name
818
- export_to_video(nag_output_frames_list, nag_video_path, fps=FIXED_FPS)
819
-
820
- # Generate audio if enabled
821
- video_with_audio_path = None
822
- if audio_mode == "Enable Audio":
823
- try:
824
- video_with_audio_path = add_audio_to_video(
825
- nag_video_path, duration_seconds,
826
- audio_prompt, audio_negative_prompt,
827
- audio_seed, audio_steps, audio_cfg_strength
828
- )
829
- except Exception as e:
830
- print(f"Warning: Could not generate audio: {e}")
831
- video_with_audio_path = None
832
-
833
- clear_cache()
834
- cleanup_temp_files()
835
-
836
- return nag_video_path, video_with_audio_path, current_seed
837
-
838
- except Exception as e:
839
- print(f"Error generating video: {e}")
840
- import traceback
841
- traceback.print_exc()
842
-
843
- # Return a simple error video
844
- error_frames = []
845
- for i in range(8): # Create 8 frames
846
- frame = np.zeros((128, 128, 3), dtype=np.uint8)
847
- frame[:, :] = [255, 0, 0] # Red frame
848
- # Add error text
849
- error_frames.append(frame)
850
-
851
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
852
- error_video_path = tmpfile.name
853
- export_to_video(error_frames, error_video_path, fps=FIXED_FPS)
854
- return error_video_path, None, 0
855
-
856
- def update_audio_visibility(audio_mode):
857
- return gr.update(visible=(audio_mode == "Enable Audio"))
858
-
859
- # Build interface
860
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
861
  with gr.Column(elem_classes="container"):
862
  gr.HTML("""
863
- <h1 class="main-title">🎬 NAG Video Demo</h1>
864
- <p class="subtitle">Simple Text-to-Video with NAG + Audio Generation</p>
865
  """)
866
 
867
  gr.HTML("""
868
  <div class="info-box">
869
- <p>📌 <strong>Demo Version:</strong> This is a simplified demo that demonstrates NAG concepts without large model downloads</p>
870
- <p>🚀 <strong>NAG Technology:</strong> Normalized Attention Guidance for enhanced video quality</p>
871
- <p>🎵 <strong>Audio:</strong> Optional synchronized audio generation with MMAudio</p>
872
- <p>⚡ <strong>Fast:</strong> Runs without downloading 28GB model files</p>
873
  </div>
874
  """)
875
-
876
  with gr.Row():
877
  with gr.Column(scale=1):
878
  with gr.Group(elem_classes="prompt-container"):
879
  prompt = gr.Textbox(
880
- label="✨ Video Prompt",
881
- value=default_prompt,
882
- placeholder="Describe your video scene...",
883
- lines=2,
884
  elem_classes="prompt-input"
885
  )
886
 
887
- with gr.Accordion("🎨 Advanced Prompt Settings", open=False):
888
  nag_negative_prompt = gr.Textbox(
889
- label="Negative Prompt",
890
  value=DEFAULT_NAG_NEGATIVE_PROMPT,
891
  lines=2,
892
  )
893
  nag_scale = gr.Slider(
894
  label="NAG Scale",
895
- minimum=0.0,
896
  maximum=20.0,
897
  step=0.25,
898
- value=5.0,
899
- info="Higher values = stronger guidance (0 = no NAG)"
900
  )
901
-
902
- audio_mode = gr.Radio(
903
- choices=["Video Only", "Enable Audio"],
904
- value="Video Only",
905
- label="🎵 Audio Mode",
906
- info="Enable to add audio to your generated video"
907
- )
908
 
909
- with gr.Column(visible=False) as audio_settings:
910
- audio_prompt = gr.Textbox(
911
- label="🎵 Audio Prompt",
912
- value=default_audio_prompt,
913
- placeholder="Describe the audio (e.g., 'waves, seagulls', 'footsteps')",
914
- lines=2
915
- )
916
- audio_negative_prompt = gr.Textbox(
917
- label="❌ Audio Negative Prompt",
918
- value=default_audio_negative_prompt,
919
- lines=2
920
- )
921
- with gr.Row():
922
- audio_seed = gr.Number(
923
- label="🎲 Audio Seed",
924
- value=-1,
925
- precision=0,
926
- minimum=-1
927
- )
928
- audio_steps = gr.Slider(
929
- minimum=1,
930
- maximum=25,
931
- step=1,
932
- value=10,
933
- label="🚀 Audio Steps"
934
- )
935
- audio_cfg_strength = gr.Slider(
936
- minimum=1.0,
937
- maximum=10.0,
938
- step=0.5,
939
- value=4.5,
940
- label="🎯 Audio Guidance"
941
- )
942
-
943
  with gr.Group(elem_classes="settings-panel"):
944
  gr.Markdown("### ⚙️ Video Settings")
945
 
946
  with gr.Row():
947
  duration_seconds_input = gr.Slider(
948
  minimum=1,
949
- maximum=2,
950
  step=1,
951
  value=DEFAULT_DURATION_SECONDS,
952
  label="📱 Duration (seconds)",
@@ -954,7 +364,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
954
  )
955
  steps_slider = gr.Slider(
956
  minimum=1,
957
- maximum=2,
958
  step=1,
959
  value=DEFAULT_STEPS,
960
  label="🔄 Inference Steps",
@@ -993,81 +403,97 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
993
  value=True,
994
  interactive=True
995
  )
996
-
997
  generate_button = gr.Button(
998
- "🎬 Generate Video",
999
  variant="primary",
1000
  elem_classes="generate-btn"
1001
  )
1002
-
1003
  with gr.Column(scale=1):
1004
- nag_video_output = gr.Video(
1005
- label="Generated Video",
1006
  autoplay=True,
1007
  interactive=False,
1008
  elem_classes="video-output"
1009
  )
1010
- video_with_audio_output = gr.Video(
1011
- label="🎥 Generated Video with Audio",
1012
- autoplay=True,
1013
- interactive=False,
1014
- visible=False,
1015
- elem_classes="video-output"
1016
- )
1017
 
1018
  gr.HTML("""
1019
  <div style="text-align: center; margin-top: 20px; color: #6b7280;">
1020
- <p>💡 Demo version with simplified model - Real NAG would produce higher quality results</p>
1021
- <p>💡 Tip: Try different NAG scales for varied artistic effects!</p>
1022
  </div>
1023
  """)
1024
-
1025
  gr.Markdown("### 🎯 Example Prompts")
1026
  gr.Examples(
1027
- examples=[
1028
- ["A cat playing guitar on stage", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
1029
- 128, 128, 1,
1030
- 1, DEFAULT_SEED, False,
1031
- "Enable Audio", "guitar music", default_audio_negative_prompt, -1, 10, 4.5],
1032
- ["A red car driving on a cliff road", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
1033
- 128, 128, 1,
1034
- 1, DEFAULT_SEED, False,
1035
- "Enable Audio", "car engine, wind", default_audio_negative_prompt, -1, 10, 4.5],
1036
- ["Glowing jellyfish floating in the sky", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
1037
- 128, 128, 1,
1038
- 1, DEFAULT_SEED, False,
1039
- "Video Only", "", default_audio_negative_prompt, -1, 10, 4.5],
1040
- ],
1041
- fn=generate_video,
1042
- inputs=[prompt, nag_negative_prompt, nag_scale,
1043
  height_input, width_input, duration_seconds_input,
1044
- steps_slider, seed_input, randomize_seed_checkbox,
1045
- audio_mode, audio_prompt, audio_negative_prompt,
1046
- audio_seed, audio_steps, audio_cfg_strength],
1047
- outputs=[nag_video_output, video_with_audio_output, seed_input],
1048
  cache_examples="lazy"
1049
  )
1050
-
1051
- # Event handlers
1052
- audio_mode.change(
1053
- fn=update_audio_visibility,
1054
- inputs=[audio_mode],
1055
- outputs=[audio_settings, video_with_audio_output]
1056
- )
1057
-
1058
  ui_inputs = [
1059
  prompt,
1060
  nag_negative_prompt, nag_scale,
1061
  height_input, width_input, duration_seconds_input,
1062
  steps_slider,
1063
  seed_input, randomize_seed_checkbox,
1064
- audio_mode, audio_prompt, audio_negative_prompt,
1065
- audio_seed, audio_steps, audio_cfg_strength,
1066
  ]
 
1067
  generate_button.click(
1068
- fn=generate_video,
1069
  inputs=ui_inputs,
1070
- outputs=[nag_video_output, video_with_audio_output, seed_input],
1071
  )
1072
 
1073
  if __name__ == "__main__":
1
  import types
2
  import random
3
  import spaces
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ from datetime import datetime
8
+
9
  import torch
 
10
  import numpy as np
11
+ import torchaudio
12
+ from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
13
  from diffusers.utils import export_to_video
14
+ from diffusers import AutoModel
15
  import gradio as gr
16
  import tempfile
17
  from huggingface_hub import hf_hub_download
18
 
19
+ from src.pipeline_wan_nag import NAGWanPipeline
20
+ from src.transformer_wan_nag import NagWanTransformer3DModel
21
 
22
  # MMAudio imports
23
  try:
 
26
  os.system("pip install -e .")
27
  import mmaudio
28
 
29
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate as mmaudio_generate,
30
+ load_video, make_video, setup_eval_logging)
 
 
 
 
31
  from mmaudio.model.flow_matching import FlowMatching
32
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
33
  from mmaudio.model.sequence_config import SequenceConfig
34
  from mmaudio.model.utils.features_utils import FeaturesUtils
35
 
36
+ # NAG Video Settings
37
  MOD_VALUE = 32
38
+ DEFAULT_DURATION_SECONDS = 4
39
+ DEFAULT_STEPS = 4
40
  DEFAULT_SEED = 2025
41
+ DEFAULT_H_SLIDER_VALUE = 480
42
+ DEFAULT_W_SLIDER_VALUE = 832
43
+ NEW_FORMULA_MAX_AREA = 480.0 * 832.0
44
 
45
+ SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
46
+ SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
47
  MAX_SEED = np.iinfo(np.int32).max
48
 
49
+ FIXED_FPS = 16
50
  MIN_FRAMES_MODEL = 8
51
+ MAX_FRAMES_MODEL = 129
52
 
53
  DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
54
+ DEFAULT_AUDIO_NEGATIVE_PROMPT = "music"
55
 
56
+ # NAG Model Settings
57
  MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
58
  SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
59
  SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
60
  LORA_REPO_ID = "Kijai/WanVideo_comfy"
61
  LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
62
 
63
+ # MMAudio Settings
64
  torch.backends.cuda.matmul.allow_tf32 = True
65
  torch.backends.cudnn.allow_tf32 = True
 
66
  log = logging.getLogger()
67
+ device = 'cuda'
68
  dtype = torch.bfloat16
69
+ audio_model_config: ModelConfig = all_model_cfg['large_44k_v2']
70
+ audio_model_config.download_if_needed()
71
+ setup_eval_logging()
72
+
73
+ # Initialize NAG Video Model
74
+ vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
75
+ wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
76
+ transformer = NagWanTransformer3DModel.from_single_file(wan_path, torch_dtype=torch.bfloat16)
77
+ pipe = NAGWanPipeline.from_pretrained(
78
+ MODEL_ID, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
79
+ )
80
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
81
+ pipe.to("cuda")
82
 
83
+ pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
84
+ pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
85
+ pipe.transformer.__class__.forward = NagWanTransformer3DModel.forward
 
 
86
 
87
+ # Initialize MMAudio Model
88
+ def get_mmaudio_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
89
+ seq_cfg = audio_model_config.seq_cfg
90
 
91
+ net: MMAudio = get_my_mmaudio(audio_model_config.model_name).to(device, dtype).eval()
92
+ net.load_weights(torch.load(audio_model_config.model_path, map_location=device, weights_only=True))
93
+ log.info(f'Loaded MMAudio weights from {audio_model_config.model_path}')
94
+
95
+ feature_utils = FeaturesUtils(tod_vae_ckpt=audio_model_config.vae_path,
96
+ synchformer_ckpt=audio_model_config.synchformer_ckpt,
97
+ enable_conditions=True,
98
+ mode=audio_model_config.mode,
99
+ bigvgan_vocoder_ckpt=audio_model_config.bigvgan_16k_path,
100
+ need_vae_encoder=False)
101
+ feature_utils = feature_utils.to(device, dtype).eval()
 
102
 
103
+ return net, feature_utils, seq_cfg
104
 
105
+ audio_net, audio_feature_utils, audio_seq_cfg = get_mmaudio_model()
106
 
107
+ # Audio generation function
108
+ @torch.inference_mode()
109
+ def add_audio_to_video(video_path, prompt, audio_negative_prompt, audio_steps, audio_cfg_strength, duration):
110
+ """Generate and add audio to video using MMAudio"""
111
+ rng = torch.Generator(device=device)
112
+ rng.seed() # Random seed for audio
113
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=audio_steps)
114
+
115
+ video_info = load_video(video_path, duration)
116
+ clip_frames = video_info.clip_frames
117
+ sync_frames = video_info.sync_frames
118
+ duration = video_info.duration_sec
119
+ clip_frames = clip_frames.unsqueeze(0)
120
+ sync_frames = sync_frames.unsqueeze(0)
121
+ audio_seq_cfg.duration = duration
122
+ audio_net.update_seq_lengths(audio_seq_cfg.latent_seq_len, audio_seq_cfg.clip_seq_len, audio_seq_cfg.sync_seq_len)
123
+
124
+ audios = mmaudio_generate(clip_frames,
125
+ sync_frames, [prompt],
126
+ negative_text=[audio_negative_prompt],
127
+ feature_utils=audio_feature_utils,
128
+ net=audio_net,
129
+ fm=fm,
130
+ rng=rng,
131
+ cfg_strength=audio_cfg_strength)
132
+ audio = audios.float().cpu()[0]
133
+
134
+ # Create video with audio
135
+ video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
136
+ make_video(video_info, video_with_audio_path, audio, sampling_rate=audio_seq_cfg.sampling_rate)
137
+
138
+ return video_with_audio_path
139
+
140
+ # Combined generation function
141
+ def get_duration(prompt, nag_negative_prompt, nag_scale, height, width, duration_seconds,
142
+ steps, seed, randomize_seed, enable_audio, audio_negative_prompt,
143
+ audio_steps, audio_cfg_strength):
144
+ # Calculate total duration including audio processing if enabled
145
+ video_duration = int(duration_seconds) * int(steps) * 2.25 + 5
146
+ audio_duration = 30 if enable_audio else 0 # Additional time for audio processing
147
+ return video_duration + audio_duration
148
 
149
+ @spaces.GPU(duration=get_duration)
150
+ def generate_video_with_audio(
151
+ prompt,
152
+ nag_negative_prompt, nag_scale,
153
+ height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE, duration_seconds=DEFAULT_DURATION_SECONDS,
154
+ steps=DEFAULT_STEPS,
155
+ seed=DEFAULT_SEED, randomize_seed=False,
156
+ enable_audio=True, audio_negative_prompt=DEFAULT_AUDIO_NEGATIVE_PROMPT,
157
+ audio_steps=25, audio_cfg_strength=4.5,
158
+ ):
159
+ # Generate video first
160
+ target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
161
+ target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
162
+
163
+ num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
164
+
165
+ current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
166
+
167
+ with torch.inference_mode():
168
+ nag_output_frames_list = pipe(
169
+ prompt=prompt,
170
+ nag_negative_prompt=nag_negative_prompt,
171
+ nag_scale=nag_scale,
172
+ nag_tau=3.5,
173
+ nag_alpha=0.5,
174
+ height=target_h, width=target_w, num_frames=num_frames,
175
+ guidance_scale=0.,
176
+ num_inference_steps=int(steps),
177
+ generator=torch.Generator(device="cuda").manual_seed(current_seed)
178
+ ).frames[0]
179
+
180
+ # Save initial video without audio
181
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
182
+ temp_video_path = tmpfile.name
183
+ export_to_video(nag_output_frames_list, temp_video_path, fps=FIXED_FPS)
184
+
185
+ # Add audio if enabled
186
+ if enable_audio:
187
+ try:
188
+ final_video_path = add_audio_to_video(
189
+ temp_video_path,
190
+ prompt, # Use the same prompt for audio generation
191
+ audio_negative_prompt,
192
+ audio_steps,
193
+ audio_cfg_strength,
194
+ duration_seconds
195
+ )
196
+ # Clean up temp video
197
+ if os.path.exists(temp_video_path):
198
+ os.remove(temp_video_path)
199
+ except Exception as e:
200
+ log.error(f"Audio generation failed: {e}")
201
+ final_video_path = temp_video_path
202
+ else:
203
+ final_video_path = temp_video_path
204
+
205
+ return final_video_path, current_seed
206
+
207
+ # Example generation function
208
+ def generate_with_example(prompt, nag_negative_prompt, nag_scale):
209
+ video_path, seed = generate_video_with_audio(
210
+ prompt=prompt,
211
+ nag_negative_prompt=nag_negative_prompt, nag_scale=nag_scale,
212
+ height=DEFAULT_H_SLIDER_VALUE, width=DEFAULT_W_SLIDER_VALUE,
213
+ duration_seconds=DEFAULT_DURATION_SECONDS,
214
+ steps=DEFAULT_STEPS,
215
+ seed=DEFAULT_SEED, randomize_seed=False,
216
+ enable_audio=True, audio_negative_prompt=DEFAULT_AUDIO_NEGATIVE_PROMPT,
217
+ audio_steps=25, audio_cfg_strength=4.5,
218
+ )
219
+ return video_path, \
220
+ DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, \
221
+ DEFAULT_DURATION_SECONDS, DEFAULT_STEPS, seed, \
222
+ True, DEFAULT_AUDIO_NEGATIVE_PROMPT, 25, 4.5
223
+
224
+ # Examples with audio descriptions
225
+ examples = [
226
+ ["A ginger cat passionately plays electric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights cast dramatic shadows.", DEFAULT_NAG_NEGATIVE_PROMPT, 11],
227
+ ["A red vintage Porsche convertible flying over a rugged coastal cliff. Monstrous waves violently crashing against the rocks below. A lighthouse stands tall atop the cliff.", DEFAULT_NAG_NEGATIVE_PROMPT, 11],
228
+ ["Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape. Magical and dreamlike, captured in a wide shot. Surreal realism style with detailed textures.", DEFAULT_NAG_NEGATIVE_PROMPT, 11],
229
+ ]
230
+
231
+ # CSS styling
232
  css = """
233
  .container {
234
  max-width: 1400px;
 
300
  margin: 10px 0;
301
  border-left: 4px solid #667eea;
302
  }
303
+ .audio-settings {
304
+ background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%);
305
+ border-radius: 10px;
306
+ padding: 15px;
307
+ margin-top: 10px;
308
+ border-left: 4px solid #f59e0b;
309
+ }
310
  """
311
 
312
+ # Gradio interface
313
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
314
  with gr.Column(elem_classes="container"):
315
  gr.HTML("""
316
+ <h1 class="main-title">🎬 NAG Video Generator with Auto Audio</h1>
317
+ <p class="subtitle">Fast 4-step Wan2.1-T2V-14B with NAG + Automatic Audio Generation</p>
318
  """)
319
 
320
  gr.HTML("""
321
  <div class="info-box">
322
+ <p>🚀 <strong>Powered by:</strong> NAG + CausVid LoRA for video + MMAudio for automatic audio synthesis</p>
323
+ <p>⚡ <strong>Speed:</strong> Generate videos with synchronized audio in one click!</p>
324
+ <p>🎵 <strong>Audio:</strong> Automatically generates matching audio based on your video prompt</p>
 
325
  </div>
326
  """)
327
+
328
  with gr.Row():
329
  with gr.Column(scale=1):
330
  with gr.Group(elem_classes="prompt-container"):
331
  prompt = gr.Textbox(
332
+ label="✨ Video Prompt (also used for audio generation)",
333
+ placeholder="Describe your video scene in detail...",
334
+ lines=3,
 
335
  elem_classes="prompt-input"
336
  )
337
 
338
+ with gr.Accordion("🎨 Advanced Video Settings", open=False):
339
  nag_negative_prompt = gr.Textbox(
340
+ label="Video Negative Prompt",
341
  value=DEFAULT_NAG_NEGATIVE_PROMPT,
342
  lines=2,
343
  )
344
  nag_scale = gr.Slider(
345
  label="NAG Scale",
346
+ minimum=1.0,
347
  maximum=20.0,
348
  step=0.25,
349
+ value=11.0,
350
+ info="Higher values = stronger guidance"
351
  )
 
352
 
 
353
  with gr.Group(elem_classes="settings-panel"):
354
  gr.Markdown("### ⚙️ Video Settings")
355
 
356
  with gr.Row():
357
  duration_seconds_input = gr.Slider(
358
  minimum=1,
359
+ maximum=8,
360
  step=1,
361
  value=DEFAULT_DURATION_SECONDS,
362
  label="📱 Duration (seconds)",
 
364
  )
365
  steps_slider = gr.Slider(
366
  minimum=1,
367
+ maximum=8,
368
  step=1,
369
  value=DEFAULT_STEPS,
370
  label="🔄 Inference Steps",
 
403
  value=True,
404
  interactive=True
405
  )
406
+
407
+ with gr.Group(elem_classes="audio-settings"):
408
+ gr.Markdown("### 🎵 Audio Generation Settings")
409
+
410
+ enable_audio = gr.Checkbox(
411
+ label="🔊 Enable Automatic Audio Generation",
412
+ value=True,
413
+ interactive=True
414
+ )
415
+
416
+ with gr.Column(visible=True) as audio_settings_group:
417
+ audio_negative_prompt = gr.Textbox(
418
+ label="Audio Negative Prompt",
419
+ value=DEFAULT_AUDIO_NEGATIVE_PROMPT,
420
+ placeholder="Elements to avoid in audio (e.g., music, speech)",
421
+ )
422
+
423
+ with gr.Row():
424
+ audio_steps = gr.Slider(
425
+ minimum=10,
426
+ maximum=50,
427
+ step=5,
428
+ value=25,
429
+ label="🎚️ Audio Steps",
430
+ info="More steps = better quality"
431
+ )
432
+ audio_cfg_strength = gr.Slider(
433
+ minimum=1.0,
434
+ maximum=10.0,
435
+ step=0.5,
436
+ value=4.5,
437
+ label="🎛️ Audio Guidance",
438
+ info="Strength of prompt guidance"
439
+ )
440
+
441
+ # Toggle audio settings visibility
442
+ enable_audio.change(
443
+ fn=lambda x: gr.update(visible=x),
444
+ inputs=[enable_audio],
445
+ outputs=[audio_settings_group]
446
+ )
447
+
448
  generate_button = gr.Button(
449
+ "🎬 Generate Video with Audio",
450
  variant="primary",
451
  elem_classes="generate-btn"
452
  )
453
+
454
  with gr.Column(scale=1):
455
+ video_output = gr.Video(
456
+ label="Generated Video with Audio",
457
  autoplay=True,
458
  interactive=False,
459
  elem_classes="video-output"
460
  )
461
 
462
  gr.HTML("""
463
  <div style="text-align: center; margin-top: 20px; color: #6b7280;">
464
+ <p>💡 Tip: The same prompt is used for both video and audio generation!</p>
465
+ <p>🎧 Audio is automatically matched to the visual content</p>
466
  </div>
467
  """)
468
+
469
  gr.Markdown("### 🎯 Example Prompts")
470
  gr.Examples(
471
+ examples=examples,
472
+ fn=generate_with_example,
473
+ inputs=[prompt, nag_negative_prompt, nag_scale],
474
+ outputs=[
475
+ video_output,
476
  height_input, width_input, duration_seconds_input,
477
+ steps_slider, seed_input,
478
+ enable_audio, audio_negative_prompt, audio_steps, audio_cfg_strength
479
+ ],
 
480
  cache_examples="lazy"
481
  )
482
+
483
+ # Connect UI elements
484
  ui_inputs = [
485
  prompt,
486
  nag_negative_prompt, nag_scale,
487
  height_input, width_input, duration_seconds_input,
488
  steps_slider,
489
  seed_input, randomize_seed_checkbox,
490
+ enable_audio, audio_negative_prompt, audio_steps, audio_cfg_strength,
 
491
  ]
492
+
493
  generate_button.click(
494
+ fn=generate_video_with_audio,
495
  inputs=ui_inputs,
496
+ outputs=[video_output, seed_input],
497
  )
498
 
499
  if __name__ == "__main__":