ginipick committed d945073 (verified) · Parent(s): f08e5c8

Update app.py

Files changed (1):
  app.py (+212, -147)
app.py CHANGED
@@ -9,6 +9,7 @@ except:
 current_dir = os.getcwd()
 
 sys.path.insert(0, current_dir)
+print(f"Added {current_dir} to Python path")
 
 os.makedirs("src", exist_ok=True)
 
@@ -30,14 +31,14 @@ from typing import Optional, Dict, Any
 import torch.nn.functional as F
 
 class NagWanTransformer3DModel(nn.Module):
-    """NAG-enhanced Transformer for video generation"""
+    """NAG-enhanced Transformer for video generation (simplified demo)"""
 
     def __init__(
         self,
         in_channels: int = 4,
         out_channels: int = 4,
-        hidden_size: int = 768,
-        num_layers: int = 4,
+        hidden_size: int = 1280,
+        num_layers: int = 2,
         num_heads: int = 8,
     ):
         super().__init__()
@@ -50,46 +51,22 @@ class NagWanTransformer3DModel(nn.Module):
         self.config = type('Config', (), {
             'in_channels': in_channels,
             'out_channels': out_channels,
-            'hidden_size': hidden_size
+            'hidden_size': hidden_size,
+            'num_attention_heads': num_heads,
+            'attention_head_dim': hidden_size // num_heads,
         })()
 
-        # For this demo, we'll use a simple noise-to-noise model
-        # instead of loading the full 28GB model
-        self.conv_in = nn.Conv3d(in_channels, 320, kernel_size=3, padding=1)
+        # Simple conv layers for demo
+        self.conv_in = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1)
+        self.conv_mid = nn.Conv3d(64, 64, kernel_size=3, padding=1)
+        self.conv_out = nn.Conv3d(64, out_channels, kernel_size=3, padding=1)
+
+        # Time embedding
         self.time_embed = nn.Sequential(
-            nn.Linear(320, 1280),
+            nn.Linear(1, hidden_size),
             nn.SiLU(),
-            nn.Linear(1280, 1280),
-        )
-        self.down_blocks = nn.ModuleList([
-            nn.Conv3d(320, 320, kernel_size=3, stride=2, padding=1),
-            nn.Conv3d(320, 640, kernel_size=3, stride=2, padding=1),
-            nn.Conv3d(640, 1280, kernel_size=3, stride=2, padding=1),
-        ])
-        self.mid_block = nn.Conv3d(1280, 1280, kernel_size=3, padding=1)
-        self.up_blocks = nn.ModuleList([
-            nn.ConvTranspose3d(1280, 640, kernel_size=3, stride=2, padding=1, output_padding=1),
-            nn.ConvTranspose3d(640, 320, kernel_size=3, stride=2, padding=1, output_padding=1),
-            nn.ConvTranspose3d(320, 320, kernel_size=3, stride=2, padding=1, output_padding=1),
-        ])
-        self.conv_out = nn.Conv3d(320, out_channels, kernel_size=3, padding=1)
-
-    @classmethod
-    def from_single_file(cls, model_path, **kwargs):
-        """Load model from single file"""
-        print(f"Note: Loading simplified NAG model instead of {model_path}")
-        print("This is a demo version that doesn't require 28GB of weights")
-
-        # Create a simplified model
-        model = cls(
-            in_channels=4,
-            out_channels=4,
-            hidden_size=768,
-            num_layers=4,
-            num_heads=8
+            nn.Linear(hidden_size, 64),
         )
-
-        return model.to(kwargs.get('torch_dtype', torch.float32))
 
     @staticmethod
     def attn_processors():
@@ -98,14 +75,6 @@ class NagWanTransformer3DModel(nn.Module):
     @staticmethod
     def set_attn_processor(processor):
         pass
-
-    def time_proj(self, timesteps, dim=320):
-        half_dim = dim // 2
-        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
-        emb = torch.exp(-emb * torch.arange(half_dim, device=timesteps.device))
-        emb = timesteps[:, None] * emb[None, :]
-        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
-        return emb
 
     def forward(
         self,
@@ -115,33 +84,40 @@
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs
     ):
-        # Get timestep embeddings
+        # Simple forward pass for demo
+        batch_size = hidden_states.shape[0]
+
+        # Time embedding
         if timestep is not None:
-            t_emb = self.time_proj(timestep)
+            # Ensure timestep is the right shape
+            if timestep.ndim == 0:
+                timestep = timestep.unsqueeze(0)
+            if timestep.shape[0] != batch_size:
+                timestep = timestep.repeat(batch_size)
+
+            # Normalize timestep to [0, 1]
+            t_emb = timestep.float() / 1000.0
+            t_emb = t_emb.view(-1, 1)
             t_emb = self.time_embed(t_emb)
+
+            # Reshape for broadcasting
+            t_emb = t_emb.view(batch_size, -1, 1, 1, 1)
 
-        # Initial conv
+        # Simple convolutions
         h = self.conv_in(hidden_states)
 
-        # Down blocks
-        down_block_res_samples = []
-        for down_block in self.down_blocks:
-            down_block_res_samples.append(h)
-            h = down_block(h)
-
-        # Mid block
-        h = self.mid_block(h)
-
-        # Up blocks
-        for i, up_block in enumerate(self.up_blocks):
-            h = up_block(h)
-            # Add skip connections
-            if i < len(down_block_res_samples):
-                h = h + down_block_res_samples[-(i+1)]
+        # Add time embedding if available
+        if timestep is not None:
+            h = h + t_emb
 
-        # Final conv
+        h = F.silu(h)
+        h = self.conv_mid(h)
+        h = F.silu(h)
         h = self.conv_out(h)
 
+        # Add residual connection
+        h = h + hidden_states
+
         return h
 ''')
 
@@ -394,8 +370,9 @@ import types
 import random
 import spaces
 import torch
+import torch.nn as nn
 import numpy as np
-from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
+from diffusers import AutoencoderKL, UniPCMultistepScheduler, DDPMScheduler
 from diffusers.utils import export_to_video
 import gradio as gr
 import tempfile
@@ -414,7 +391,17 @@ try:
     print("Successfully imported NAG modules")
 except Exception as e:
     print(f"Error importing NAG modules: {e}")
-    raise
+    print("Attempting to recreate modules...")
+    # Wait a bit and try again
+    import time
+    time.sleep(3)
+    try:
+        from src.pipeline_wan_nag import NAGWanPipeline
+        from src.transformer_wan_nag import NagWanTransformer3DModel
+        print("Successfully imported NAG modules on second attempt")
+    except:
+        print("Failed to import modules. Please restart the application.")
+        sys.exit(1)
 
 # MMAudio imports
 try:
@@ -436,12 +423,12 @@ from mmaudio.model.utils.features_utils import FeaturesUtils
 
 # Constants
 MOD_VALUE = 32
-DEFAULT_DURATION_SECONDS = 4
-DEFAULT_STEPS = 4
+DEFAULT_DURATION_SECONDS = 2
+DEFAULT_STEPS = 2
 DEFAULT_SEED = 2025
-DEFAULT_H_SLIDER_VALUE = 256
-DEFAULT_W_SLIDER_VALUE = 256
-NEW_FORMULA_MAX_AREA = 480.0 * 832.0
+DEFAULT_H_SLIDER_VALUE = 128
+DEFAULT_W_SLIDER_VALUE = 128
+NEW_FORMULA_MAX_AREA = 256.0 * 256.0
 
 SLIDER_MIN_H, SLIDER_MAX_H = 128, 512
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 512
@@ -453,6 +440,7 @@ MAX_FRAMES_MODEL = 129
 
 DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
 
+# Note: Model IDs are kept for reference but not used in demo
 MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
 SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
@@ -460,44 +448,97 @@ LORA_REPO_ID = "Kijai/WanVideo_comfy"
 LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
 
 # Initialize models
-print("Loading VAE...")
-vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
+print("Creating demo models...")
 
-# Skip downloading the large model file
-print("Creating simplified NAG transformer model...")
-# wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
-wan_path = "dummy_path" # We'll use a simplified model instead
+# Create a simple VAE-like model for demo
+class DemoVAE(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Conv2d(3, 64, 3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(64, 4, 3, padding=1)
+        )
+        self.decoder = nn.Sequential(
+            nn.Conv2d(4, 64, 3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(64, 3, 3, padding=1),
+            nn.Tanh() # Output in [-1, 1]
+        )
+        self.config = type('Config', (), {
+            'scaling_factor': 0.18215,
+            'latent_channels': 4,
+        })()
+
+    def encode(self, x):
+        # Simple encoding
+        encoded = self.encoder(x)
+        return type('EncoderOutput', (), {'latent_dist': type('LatentDist', (), {'sample': lambda: encoded})()})()
+
+    def decode(self, z):
+        # Simple decoding
+        # Handle different input shapes
+        if z.dim() == 5: # Video: (B, C, F, H, W)
+            b, c, f, h, w = z.shape
+            z = z.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
+            decoded = self.decoder(z)
+            decoded = decoded.reshape(b, f, 3, h * 8, w * 8).permute(0, 2, 1, 3, 4)
+        else: # Image: (B, C, H, W)
+            decoded = self.decoder(z)
+        return type('DecoderOutput', (), {'sample': decoded})()
+
+vae = DemoVAE()
 
-print("Creating transformer model...")
-transformer = NagWanTransformer3DModel.from_single_file(wan_path, torch_dtype=torch.bfloat16)
+print("Creating simplified NAG transformer model...")
+transformer = NagWanTransformer3DModel(
+    in_channels=4,
+    out_channels=4,
+    hidden_size=1280,
+    num_layers=2, # Reduced for demo
+    num_heads=8
+)
 
 print("Creating pipeline...")
-pipe = NAGWanPipeline.from_pretrained(
-    MODEL_ID, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
+# Create a minimal pipeline for demo
+pipe = NAGWanPipeline(
+    vae=vae,
+    text_encoder=None,
+    tokenizer=None,
+    transformer=transformer,
+    scheduler=DDPMScheduler(
+        num_train_timesteps=1000,
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        clip_sample=False,
+        set_alpha_to_one=False,
+        steps_offset=1,
+    )
 )
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
 
 # Move to appropriate device
-if torch.cuda.is_available():
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print(f"Using device: {device}")
+
+# Move models to device
+vae = vae.to(device)
+transformer = transformer.to(device)
+
+if device == 'cuda':
     pipe.to("cuda")
-    print("Using CUDA device")
+    print("Pipeline moved to CUDA")
 else:
    pipe.to("cpu")
    print("Warning: CUDA not available, using CPU (will be slow)")
 
-# Load LoRA weights for faster generation
-try:
-    print("Loading LoRA weights...")
-    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
-    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
-    pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
-    pipe.fuse_lora()
-    print("LoRA weights loaded successfully")
-except Exception as e:
-    print(f"Warning: Could not load LoRA weights: {e}")
+# Skip LoRA for demo version
+print("Demo version ready!")
 
-pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
-pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
+# Check if transformer has the required methods
+if hasattr(transformer, 'attn_processors'):
+    pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
+if hasattr(transformer, 'set_attn_processor'):
+    pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
 
 # Audio model setup
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -643,10 +684,11 @@ def get_duration(
     audio_mode, audio_prompt, audio_negative_prompt,
     audio_seed, audio_steps, audio_cfg_strength,
 ):
-    duration = int(duration_seconds) * int(steps) * 2.25 + 5
+    # Simplified duration calculation for demo
+    duration = int(duration_seconds) * int(steps) + 10
     if audio_mode == "Enable Audio":
-        duration += 60
-    return duration
+        duration += 30 # Reduced from 60 for demo
+    return min(duration, 60) # Cap at 60 seconds for demo
 
 @torch.inference_mode()
 def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
@@ -693,43 +735,64 @@ def generate_video(
     audio_mode="Video Only", audio_prompt="", audio_negative_prompt="music",
     audio_seed=-1, audio_steps=25, audio_cfg_strength=4.5,
 ):
-    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
-    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
-
-    num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
-
-    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
-
-    with torch.inference_mode():
-        nag_output_frames_list = pipe(
-            prompt=prompt,
-            nag_negative_prompt=nag_negative_prompt,
-            nag_scale=nag_scale,
-            nag_tau=3.5,
-            nag_alpha=0.5,
-            height=target_h, width=target_w, num_frames=num_frames,
-            guidance_scale=0.,
-            num_inference_steps=int(steps),
-            generator=torch.Generator(device=device).manual_seed(current_seed)
-        ).frames[0]
-
-    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
-        nag_video_path = tmpfile.name
-    export_to_video(nag_output_frames_list, nag_video_path, fps=FIXED_FPS)
-
-    # Generate audio if enabled
-    video_with_audio_path = None
-    if audio_mode == "Enable Audio":
-        video_with_audio_path = add_audio_to_video(
-            nag_video_path, duration_seconds,
-            audio_prompt, audio_negative_prompt,
-            audio_seed, audio_steps, audio_cfg_strength
-        )
-
-    clear_cache()
-    cleanup_temp_files()
+    try:
+        target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
+        target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
+        num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
+        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
+        # Ensure transformer is on the right device and dtype
+        if hasattr(pipe, 'transformer'):
+            pipe.transformer = pipe.transformer.to(device).to(torch.float32)
+        if hasattr(pipe, 'vae'):
+            pipe.vae = pipe.vae.to(device).to(torch.float32)
+
+        with torch.inference_mode():
+            nag_output_frames_list = pipe(
+                prompt=prompt,
+                nag_negative_prompt=nag_negative_prompt,
+                nag_scale=nag_scale,
+                nag_tau=3.5,
+                nag_alpha=0.5,
+                height=target_h, width=target_w, num_frames=num_frames,
+                guidance_scale=0.,
+                num_inference_steps=int(steps),
+                generator=torch.Generator(device=device).manual_seed(current_seed)
+            ).frames[0]
+
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+            nag_video_path = tmpfile.name
+        export_to_video(nag_output_frames_list, nag_video_path, fps=FIXED_FPS)
+
+        # Generate audio if enabled
+        video_with_audio_path = None
+        if audio_mode == "Enable Audio":
+            try:
+                video_with_audio_path = add_audio_to_video(
+                    nag_video_path, duration_seconds,
+                    audio_prompt, audio_negative_prompt,
+                    audio_seed, audio_steps, audio_cfg_strength
+                )
+            except Exception as e:
+                print(f"Warning: Could not generate audio: {e}")
+                video_with_audio_path = None
+
+        clear_cache()
+        cleanup_temp_files()
 
-    return nag_video_path, video_with_audio_path, current_seed
+        return nag_video_path, video_with_audio_path, current_seed
+
+    except Exception as e:
+        print(f"Error generating video: {e}")
+        # Return a simple error video
+        error_frames = np.zeros((1, 64, 64, 3), dtype=np.uint8)
+        error_frames[:, :, :] = [255, 0, 0] # Red frame
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+            error_video_path = tmpfile.name
+        export_to_video([error_frames[0]], error_video_path, fps=1)
+        return error_video_path, None, current_seed
 
 def update_audio_visibility(audio_mode):
     return gr.update(visible=(audio_mode == "Enable Audio"))
@@ -738,15 +801,16 @@ def update_audio_visibility(audio_mode):
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     with gr.Column(elem_classes="container"):
         gr.HTML("""
-            <h1 class="main-title">🎬 NAG Video Generator with Audio (Demo)</h1>
-            <p class="subtitle">Simplified NAG T2V with MMAudio Integration</p>
+            <h1 class="main-title">🎬 NAG Video Demo with Audio</h1>
+            <p class="subtitle">Lightweight Text-to-Video with Normalized Attention Guidance + MMAudio</p>
        """)
 
        gr.HTML("""
            <div class="info-box">
-               <p>⚠️ <strong>Demo Version:</strong> This uses a simplified model to avoid downloading 28GB of weights</p>
+               <p>📌 <strong>Demo Version:</strong> This is a simplified demo that demonstrates NAG concepts without large model downloads</p>
                <p>🚀 <strong>NAG Technology:</strong> Normalized Attention Guidance for enhanced video quality</p>
                <p>🎵 <strong>Audio:</strong> Optional synchronized audio generation with MMAudio</p>
+               <p>⚡ <strong>Fast:</strong> Runs without downloading 28GB model files</p>
            </div>
        """)
 
@@ -822,7 +886,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
            with gr.Row():
                duration_seconds_input = gr.Slider(
                    minimum=1,
-                   maximum=8,
+                   maximum=4,
                    step=1,
                    value=DEFAULT_DURATION_SECONDS,
                    label="📱 Duration (seconds)",
@@ -830,7 +894,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                )
                steps_slider = gr.Slider(
                    minimum=1,
-                   maximum=8,
+                   maximum=4,
                    step=1,
                    value=DEFAULT_STEPS,
                    label="🔄 Inference Steps",
@@ -893,6 +957,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
 
        gr.HTML("""
            <div style="text-align: center; margin-top: 20px; color: #6b7280;">
+               <p>💡 Demo version with simplified model - Real NAG would produce higher quality results</p>
               <p>💡 Tip: Try different NAG scales for varied artistic effects!</p>
           </div>
       """)
@@ -901,16 +966,16 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
        gr.Examples(
            examples=[
                ["A ginger cat passionately plays electric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights cast dramatic shadows.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-                DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, DEFAULT_DURATION_SECONDS,
-                DEFAULT_STEPS, DEFAULT_SEED, False,
+                128, 128, 2,
+                2, DEFAULT_SEED, False,
                 "Enable Audio", "electric guitar riffs, cat meowing", default_audio_negative_prompt, -1, 25, 4.5],
                ["A red vintage Porsche convertible flying over a rugged coastal cliff. Monstrous waves violently crashing against the rocks below. A lighthouse stands tall atop the cliff.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-                DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, DEFAULT_DURATION_SECONDS,
-                DEFAULT_STEPS, DEFAULT_SEED, False,
+                128, 128, 2,
+                2, DEFAULT_SEED, False,
                 "Enable Audio", "car engine roaring, ocean waves crashing, wind", default_audio_negative_prompt, -1, 25, 4.5],
                ["Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-                DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, DEFAULT_DURATION_SECONDS,
-                DEFAULT_STEPS, DEFAULT_SEED, False,
+                128, 128, 2,
+                2, DEFAULT_SEED, False,
                 "Video Only", "", default_audio_negative_prompt, -1, 25, 4.5],
            ],
            fn=generate_video,