ginipick committed
Commit b90de71 · verified · 1 Parent(s): 77ed819

Update app.py

Files changed (1):
  app.py +99 -41
app.py CHANGED
@@ -2,6 +2,8 @@
 import os
 import sys
 
+print("Starting NAG Video Demo application...")
+
 # Add current directory to Python path
 try:
     current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -46,6 +48,7 @@ class NagWanTransformer3DModel(nn.Module):
         self.out_channels = out_channels
         self.hidden_size = hidden_size
         self.training = False
+        self._dtype = torch.float32  # Add dtype attribute
 
         # Dummy config for compatibility
         self.config = type('Config', (), {
@@ -67,6 +70,27 @@ class NagWanTransformer3DModel(nn.Module):
             nn.SiLU(),
             nn.Linear(hidden_size, hidden_size),
         )
+
+    @property
+    def dtype(self):
+        """Return the dtype of the model"""
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value):
+        """Set the dtype of the model"""
+        self._dtype = value
+
+    def to(self, *args, **kwargs):
+        """Override to method to handle dtype"""
+        result = super().to(*args, **kwargs)
+        # Update dtype if moving to a specific dtype
+        for arg in args:
+            if isinstance(arg, torch.dtype):
+                self._dtype = arg
+        if 'dtype' in kwargs:
+            self._dtype = kwargs['dtype']
+        return result
 
     @staticmethod
     def attn_processors():
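
Note: the new dtype property and to() override keep a private _dtype attribute in sync with wherever the module is cast or moved, so pipeline code can query model.dtype on these hand-rolled modules. A minimal, self-contained sketch of the pattern (DtypeTracked is a hypothetical toy class, not part of the commit; the commit applies the same property/setter/to triplet to NagWanTransformer3DModel here and to DemoVAE further down):

import torch
import torch.nn as nn

class DtypeTracked(nn.Module):
    def __init__(self):
        super().__init__()
        self._dtype = torch.float32
        self.linear = nn.Linear(4, 4)

    @property
    def dtype(self):
        return self._dtype

    def to(self, *args, **kwargs):
        result = super().to(*args, **kwargs)
        # Mirror any dtype passed positionally or as a keyword
        for arg in args:
            if isinstance(arg, torch.dtype):
                self._dtype = arg
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        return result

m = DtypeTracked().to(torch.float16)
print(m.dtype, m.linear.weight.dtype)  # torch.float16 torch.float16

Both call styles, to(torch.float16) positionally and to(dtype=torch.float16), update the tracked value.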
@@ -423,8 +447,8 @@ from mmaudio.model.utils.features_utils import FeaturesUtils
 
 # Constants
 MOD_VALUE = 32
-DEFAULT_DURATION_SECONDS = 2
-DEFAULT_STEPS = 2
+DEFAULT_DURATION_SECONDS = 1
+DEFAULT_STEPS = 1
 DEFAULT_SEED = 2025
 DEFAULT_H_SLIDER_VALUE = 128
 DEFAULT_W_SLIDER_VALUE = 128
@@ -434,9 +458,9 @@ SLIDER_MIN_H, SLIDER_MAX_H = 128, 256
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 256
 MAX_SEED = np.iinfo(np.int32).max
 
-FIXED_FPS = 16
+FIXED_FPS = 8  # Reduced FPS for demo
 MIN_FRAMES_MODEL = 8
-MAX_FRAMES_MODEL = 129
+MAX_FRAMES_MODEL = 32  # Reduced max frames for demo
 
 DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
 
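Note: these constants halve the frame rate and shrink the frame budget, and the defaults above drop duration and steps to 1. Assuming the app derives its frame count as duration times FIXED_FPS clamped to the model range (a guess at logic not shown in this diff; frames_for is a hypothetical helper), the new values cap a clip at 32 frames, i.e. 4 seconds at 8 fps:

FIXED_FPS = 8
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 32

def frames_for(duration_seconds: float) -> int:
    # Clamp duration * fps into the range the model can handle
    return int(max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, round(duration_seconds * FIXED_FPS))))

print(frames_for(1))   # 8
print(frames_for(2))   # 16
print(frames_for(10))  # 32 (clamped)
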
@@ -454,6 +478,7 @@ print("Creating demo models...")
 class DemoVAE(nn.Module):
     def __init__(self):
         super().__init__()
+        self._dtype = torch.float32  # Add dtype attribute
         self.encoder = nn.Sequential(
             nn.Conv2d(3, 64, 3, padding=1),
             nn.ReLU(),
@@ -470,6 +495,27 @@ class DemoVAE(nn.Module):
             'latent_channels': 4,
         })()
 
+    @property
+    def dtype(self):
+        """Return the dtype of the model"""
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value):
+        """Set the dtype of the model"""
+        self._dtype = value
+
+    def to(self, *args, **kwargs):
+        """Override to method to handle dtype"""
+        result = super().to(*args, **kwargs)
+        # Update dtype if moving to a specific dtype
+        for arg in args:
+            if isinstance(arg, torch.dtype):
+                self._dtype = arg
+        if 'dtype' in kwargs:
+            self._dtype = kwargs['dtype']
+        return result
+
     def encode(self, x):
         # Simple encoding
         encoded = self.encoder(x)
@@ -519,18 +565,19 @@ pipe = NAGWanPipeline(
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 print(f"Using device: {device}")
 
-# Move models to device
-vae = vae.to(device)
-transformer = transformer.to(device)
+# Move models to device with explicit dtype
+vae = vae.to(device).to(torch.float32)
+transformer = transformer.to(device).to(torch.float32)
 
-if device == 'cuda':
-    pipe.to("cuda")
-    print("Pipeline moved to CUDA")
-else:
-    pipe.to("cpu")
-    print("Warning: CUDA not available, using CPU (will be slow)")
+# Now move pipeline to device (it will handle the components)
+try:
+    pipe = pipe.to(device)
+    print(f"Pipeline moved to {device}")
+except Exception as e:
+    print(f"Warning: Could not move pipeline to {device}: {e}")
+    # Manually set device
+    pipe._execution_device = device
 
-# Skip LoRA for demo version
 print("Demo version ready!")
 
 # Check if transformer has the required methods
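
Note: the device move is now defensive: both demo modules are cast to float32 explicitly, and if the pipeline's own to() raises, the code falls back to tagging the execution device by hand. A standalone sketch of that fallback shape (DummyPipe is hypothetical; _execution_device mirrors the diffusers-style private field the diff assigns):

import torch

class DummyPipe:
    # Stand-in for a pipeline whose to() may reject custom components
    def to(self, device):
        raise RuntimeError("custom components do not support .to()")

pipe = DummyPipe()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
    pipe = pipe.to(device)
    print(f"Pipeline moved to {device}")
except Exception as e:
    print(f"Warning: could not move pipeline to {device}: {e}")
    pipe._execution_device = device  # tag the device manually instead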
@@ -748,6 +795,8 @@ def generate_video(
         if hasattr(pipe, 'vae'):
             pipe.vae = pipe.vae.to(device).to(torch.float32)
 
+        print(f"Generating video: {target_w}x{target_h}, {num_frames} frames, seed {current_seed}")
+
         with torch.inference_mode():
             nag_output_frames_list = pipe(
                 prompt=prompt,
@@ -785,13 +834,21 @@
 
     except Exception as e:
         print(f"Error generating video: {e}")
+        import traceback
+        traceback.print_exc()
+
         # Return a simple error video
-        error_frames = np.zeros((1, 64, 64, 3), dtype=np.uint8)
-        error_frames[:, :, :] = [255, 0, 0]  # Red frame
+        error_frames = []
+        for i in range(8):  # Create 8 frames
+            frame = np.zeros((128, 128, 3), dtype=np.uint8)
+            frame[:, :] = [255, 0, 0]  # Red frame
+            # Add error text
+            error_frames.append(frame)
+
         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
             error_video_path = tmpfile.name
-        export_to_video([error_frames[0]], error_video_path, fps=1)
-        return error_video_path, None, current_seed
+        export_to_video(error_frames, error_video_path, fps=FIXED_FPS)
+        return error_video_path, None, 0
 
 def update_audio_visibility(audio_mode):
     return gr.update(visible=(audio_mode == "Enable Audio"))
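
Note: the error path now returns a short red clip at FIXED_FPS instead of a single frame at 1 fps. A self-contained sketch of the same fallback; it assumes a recent diffusers build, whose export_to_video rescales NumPy frames from [0, 1] by 255, so the sketch passes float frames (the commit itself passes uint8, as shown above):

import tempfile
import numpy as np
from diffusers.utils import export_to_video

FIXED_FPS = 8
# Eight solid red 128x128 frames, float in [0, 1]
error_frames = [np.full((128, 128, 3), (1.0, 0.0, 0.0), dtype=np.float32) for _ in range(8)]

with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
    error_video_path = tmpfile.name
export_to_video(error_frames, error_video_path, fps=FIXED_FPS)
print(f"Wrote fallback clip to {error_video_path}")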
@@ -800,8 +857,8 @@ def update_audio_visibility(audio_mode):
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     with gr.Column(elem_classes="container"):
         gr.HTML("""
-            <h1 class="main-title">🎬 NAG Video Demo with Audio</h1>
-            <p class="subtitle">Lightweight Text-to-Video with Normalized Attention Guidance + MMAudio</p>
+            <h1 class="main-title">🎬 NAG Video Demo</h1>
+            <p class="subtitle">Simple Text-to-Video with NAG + Audio Generation</p>
        """)
 
        gr.HTML("""
@@ -818,8 +875,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
         with gr.Group(elem_classes="prompt-container"):
             prompt = gr.Textbox(
                 label="✨ Video Prompt",
-                placeholder="Describe your video scene in detail...",
-                lines=3,
+                value=default_prompt,
+                placeholder="Describe your video scene...",
+                lines=2,
                 elem_classes="prompt-input"
             )
 
@@ -831,11 +889,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
             )
             nag_scale = gr.Slider(
                 label="NAG Scale",
-                minimum=1.0,
+                minimum=0.0,
                 maximum=20.0,
                 step=0.25,
-                value=11.0,
-                info="Higher values = stronger guidance"
+                value=5.0,
+                info="Higher values = stronger guidance (0 = no NAG)"
             )
 
             audio_mode = gr.Radio(
@@ -866,9 +924,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 )
                 audio_steps = gr.Slider(
                     minimum=1,
-                    maximum=50,
+                    maximum=25,
                     step=1,
-                    value=25,
+                    value=10,
                     label="🚀 Audio Steps"
                 )
                 audio_cfg_strength = gr.Slider(
@@ -885,7 +943,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
             with gr.Row():
                 duration_seconds_input = gr.Slider(
                     minimum=1,
-                    maximum=4,
+                    maximum=2,
                     step=1,
                     value=DEFAULT_DURATION_SECONDS,
                     label="📱 Duration (seconds)",
@@ -893,7 +951,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 )
                 steps_slider = gr.Slider(
                     minimum=1,
-                    maximum=4,
+                    maximum=2,
                     step=1,
                     value=DEFAULT_STEPS,
                     label="🔄 Inference Steps",
@@ -964,18 +1022,18 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     gr.Markdown("### 🎯 Example Prompts")
     gr.Examples(
         examples=[
-            ["A ginger cat passionately plays electric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights cast dramatic shadows.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-             128, 128, 2,
-             2, DEFAULT_SEED, False,
-             "Enable Audio", "electric guitar riffs, cat meowing", default_audio_negative_prompt, -1, 25, 4.5],
-            ["A red vintage Porsche convertible flying over a rugged coastal cliff. Monstrous waves violently crashing against the rocks below. A lighthouse stands tall atop the cliff.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-             128, 128, 2,
-             2, DEFAULT_SEED, False,
-             "Enable Audio", "car engine roaring, ocean waves crashing, wind", default_audio_negative_prompt, -1, 25, 4.5],
-            ["Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-             128, 128, 2,
-             2, DEFAULT_SEED, False,
-             "Video Only", "", default_audio_negative_prompt, -1, 25, 4.5],
+            ["A cat playing guitar on stage", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+             128, 128, 1,
+             1, DEFAULT_SEED, False,
+             "Enable Audio", "guitar music", default_audio_negative_prompt, -1, 10, 4.5],
+            ["A red car driving on a cliff road", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+             128, 128, 1,
+             1, DEFAULT_SEED, False,
+             "Enable Audio", "car engine, wind", default_audio_negative_prompt, -1, 10, 4.5],
+            ["Glowing jellyfish floating in the sky", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+             128, 128, 1,
+             1, DEFAULT_SEED, False,
+             "Video Only", "", default_audio_negative_prompt, -1, 10, 4.5],
         ],
         fn=generate_video,
         inputs=[prompt, nag_negative_prompt, nag_scale,
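
Note: each gr.Examples row fills the components listed in inputs positionally, which is why every trimmed row above still carries one value per input of generate_video (the inputs list is truncated in this hunk). A minimal toy illustration of that positional mapping (standalone sketch, not the demo's real layout):

import gradio as gr

with gr.Blocks() as toy:
    prompt = gr.Textbox(label="Prompt")
    nag_scale = gr.Slider(0.0, 20.0, value=5.0, label="NAG Scale")
    # Each example row fills the inputs components in order
    gr.Examples(
        examples=[["A cat playing guitar on stage", 5]],
        inputs=[prompt, nag_scale],
    )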
 