Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -2,6 +2,8 @@
 import os
 import sys
 
+print("Starting NAG Video Demo application...")
+
 # Add current directory to Python path
 try:
     current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -46,6 +48,7 @@ class NagWanTransformer3DModel(nn.Module):
         self.out_channels = out_channels
         self.hidden_size = hidden_size
         self.training = False
+        self._dtype = torch.float32  # Add dtype attribute
 
         # Dummy config for compatibility
         self.config = type('Config', (), {
@@ -67,6 +70,27 @@ class NagWanTransformer3DModel(nn.Module):
             nn.SiLU(),
             nn.Linear(hidden_size, hidden_size),
         )
+
+    @property
+    def dtype(self):
+        """Return the dtype of the model"""
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value):
+        """Set the dtype of the model"""
+        self._dtype = value
+
+    def to(self, *args, **kwargs):
+        """Override to method to handle dtype"""
+        result = super().to(*args, **kwargs)
+        # Update dtype if moving to a specific dtype
+        for arg in args:
+            if isinstance(arg, torch.dtype):
+                self._dtype = arg
+        if 'dtype' in kwargs:
+            self._dtype = kwargs['dtype']
+        return result
 
     @staticmethod
     def attn_processors():
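The block above gives the stub transformer a tracked `dtype` plus a `to()` override, so diffusers-style code that reads `module.dtype` after `module.to(...)` sees a consistent value. A minimal usage sketch; the constructor arguments are assumptions, not shown in this diff:

```python
import torch

# Hypothetical construction; the real argument names live elsewhere in app.py.
model = NagWanTransformer3DModel(in_channels=4, out_channels=4, hidden_size=64)
print(model.dtype)             # torch.float32, the default set in __init__
model = model.to(torch.float16)
print(model.dtype)             # torch.float16, recorded by the to() override
```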
@@ -423,8 +447,8 @@ from mmaudio.model.utils.features_utils import FeaturesUtils
 
 # Constants
 MOD_VALUE = 32
-DEFAULT_DURATION_SECONDS =
-DEFAULT_STEPS =
+DEFAULT_DURATION_SECONDS = 1
+DEFAULT_STEPS = 1
 DEFAULT_SEED = 2025
 DEFAULT_H_SLIDER_VALUE = 128
 DEFAULT_W_SLIDER_VALUE = 128
@@ -434,9 +458,9 @@ SLIDER_MIN_H, SLIDER_MAX_H = 128, 256
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 256
 MAX_SEED = np.iinfo(np.int32).max
 
-FIXED_FPS =
+FIXED_FPS = 8  # Reduced FPS for demo
 MIN_FRAMES_MODEL = 8
-MAX_FRAMES_MODEL =
+MAX_FRAMES_MODEL = 32  # Reduced max frames for demo
 
 DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
 
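With `FIXED_FPS = 8` and `MAX_FRAMES_MODEL = 32`, a one-to-two-second clip stays between 8 and 16 frames. A sketch of the usual duration-to-frames clamp; the clamp itself is outside this diff, so treat its exact form as an assumption:

```python
FIXED_FPS = 8
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 32

def num_frames_for(duration_seconds: float) -> int:
    # Round to whole frames, then clamp to the model's supported range.
    frames = round(duration_seconds * FIXED_FPS)
    return max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, frames))

assert num_frames_for(1) == 8   # DEFAULT_DURATION_SECONDS
assert num_frames_for(2) == 16  # the duration slider maximum set later in this diff
```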
@@ -454,6 +478,7 @@ print("Creating demo models...")
 class DemoVAE(nn.Module):
     def __init__(self):
         super().__init__()
+        self._dtype = torch.float32  # Add dtype attribute
         self.encoder = nn.Sequential(
             nn.Conv2d(3, 64, 3, padding=1),
             nn.ReLU(),
@@ -470,6 +495,27 @@ class DemoVAE(nn.Module):
             'latent_channels': 4,
         })()
 
+    @property
+    def dtype(self):
+        """Return the dtype of the model"""
+        return self._dtype
+
+    @dtype.setter
+    def dtype(self, value):
+        """Set the dtype of the model"""
+        self._dtype = value
+
+    def to(self, *args, **kwargs):
+        """Override to method to handle dtype"""
+        result = super().to(*args, **kwargs)
+        # Update dtype if moving to a specific dtype
+        for arg in args:
+            if isinstance(arg, torch.dtype):
+                self._dtype = arg
+        if 'dtype' in kwargs:
+            self._dtype = kwargs['dtype']
+        return result
+
     def encode(self, x):
         # Simple encoding
         encoded = self.encoder(x)
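The same twenty-line `dtype`/`to()` block now appears verbatim in both `NagWanTransformer3DModel` and `DemoVAE`. One possible refactor, not part of this diff, is to pull it into a mixin:

```python
import torch
import torch.nn as nn

class DtypeTrackingMixin:
    """Tracks a dtype attribute through .to() calls for demo stubs."""
    _dtype = torch.float32

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self._dtype = value

    def to(self, *args, **kwargs):
        result = super().to(*args, **kwargs)
        for arg in args:
            if isinstance(arg, torch.dtype):
                self._dtype = arg
        if 'dtype' in kwargs:
            self._dtype = kwargs['dtype']
        return result

class DemoVAE(DtypeTrackingMixin, nn.Module):
    ...  # encoder/decoder as above
```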
@@ -519,18 +565,19 @@ pipe = NAGWanPipeline(
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 print(f"Using device: {device}")
 
-# Move models to device
-vae = vae.to(device)
-transformer = transformer.to(device)
+# Move models to device with explicit dtype
+vae = vae.to(device).to(torch.float32)
+transformer = transformer.to(device).to(torch.float32)
 
-
-
-
-
-
-print("Warning:
+# Now move pipeline to device (it will handle the components)
+try:
+    pipe = pipe.to(device)
+    print(f"Pipeline moved to {device}")
+except Exception as e:
+    print(f"Warning: Could not move pipeline to {device}: {e}")
+    # Manually set device
+    pipe._execution_device = device
 
-# Skip LoRA for demo version
 print("Demo version ready!")
 
 # Check if transformer has the required methods
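The try/except wrapper matters because custom pipelines can fail `.to(device)` when a component lacks standard module semantics. Note that `_execution_device` is a private diffusers attribute, so the fallback is best-effort and may break across library versions. A condensed sketch of the pattern:

```python
import torch

def move_pipeline(pipe, device: str):
    """Best-effort device move. Assumption: the pipeline honors a writable
    _execution_device attribute, as the diff above relies on."""
    try:
        return pipe.to(device)
    except Exception as e:
        print(f"Warning: could not move pipeline to {device}: {e}")
        pipe._execution_device = torch.device(device)
        return pipe
```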
@@ -748,6 +795,8 @@ def generate_video(
     if hasattr(pipe, 'vae'):
         pipe.vae = pipe.vae.to(device).to(torch.float32)
 
+    print(f"Generating video: {target_w}x{target_h}, {num_frames} frames, seed {current_seed}")
+
     with torch.inference_mode():
         nag_output_frames_list = pipe(
             prompt=prompt,
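`torch.inference_mode()` is a stricter, faster variant of `no_grad`: tensors created inside the block skip autograd tracking entirely. A quick illustration:

```python
import torch

with torch.inference_mode():
    y = torch.ones(2, 2) * 3
print(y.requires_grad)  # False: no graph was built inside the block
```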
@@ -785,13 +834,21 @@
 
     except Exception as e:
         print(f"Error generating video: {e}")
+        import traceback
+        traceback.print_exc()
+
         # Return a simple error video
-        error_frames =
-
+        error_frames = []
+        for i in range(8):  # Create 8 frames
+            frame = np.zeros((128, 128, 3), dtype=np.uint8)
+            frame[:, :] = [255, 0, 0]  # Red frame
+            # Add error text
+            error_frames.append(frame)
+
         with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
             error_video_path = tmpfile.name
-        export_to_video(
-        return error_video_path, None,
+        export_to_video(error_frames, error_video_path, fps=FIXED_FPS)
+        return error_video_path, None, 0
 
 def update_audio_visibility(audio_mode):
     return gr.update(visible=(audio_mode == "Enable Audio"))
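One caveat with the red-frame fallback: depending on the diffusers version, `export_to_video` assumes float frames in [0, 1] and multiplies by 255 internally, so `uint8` frames can wrap around. A float-based variant of the fallback may be safer; a sketch, not the author's code:

```python
import numpy as np

error_frames = []
for _ in range(8):
    frame = np.zeros((128, 128, 3), dtype=np.float32)
    frame[:, :] = [1.0, 0.0, 0.0]  # solid red, already in [0, 1]
    error_frames.append(frame)
```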
@@ -800,8 +857,8 @@ def update_audio_visibility(audio_mode):
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     with gr.Column(elem_classes="container"):
         gr.HTML("""
-            <h1 class="main-title">🎬 NAG Video Demo
-            <p class="subtitle">
+            <h1 class="main-title">🎬 NAG Video Demo</h1>
+            <p class="subtitle">Simple Text-to-Video with NAG + Audio Generation</p>
         """)
 
         gr.HTML("""
@@ -818,8 +875,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 with gr.Group(elem_classes="prompt-container"):
                     prompt = gr.Textbox(
                         label="✨ Video Prompt",
-
-
+                        value=default_prompt,
+                        placeholder="Describe your video scene...",
+                        lines=2,
                         elem_classes="prompt-input"
                     )
 
@@ -831,11 +889,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                     )
                     nag_scale = gr.Slider(
                         label="NAG Scale",
-                        minimum=
+                        minimum=0.0,
                         maximum=20.0,
                         step=0.25,
-                        value=
-                        info="Higher values = stronger guidance"
+                        value=5.0,
+                        info="Higher values = stronger guidance (0 = no NAG)"
                     )
 
                     audio_mode = gr.Radio(
@@ -866,9 +924,9 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                     )
                     audio_steps = gr.Slider(
                         minimum=1,
-                        maximum=
+                        maximum=25,
                         step=1,
-                        value=
+                        value=10,
                         label="🚀 Audio Steps"
                     )
                     audio_cfg_strength = gr.Slider(
@@ -885,7 +943,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 with gr.Row():
                     duration_seconds_input = gr.Slider(
                         minimum=1,
-                        maximum=
+                        maximum=2,
                         step=1,
                         value=DEFAULT_DURATION_SECONDS,
                         label="📱 Duration (seconds)",
@@ -893,7 +951,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                     )
                     steps_slider = gr.Slider(
                         minimum=1,
-                        maximum=
+                        maximum=2,
                         step=1,
                         value=DEFAULT_STEPS,
                         label="🔄 Inference Steps",
@@ -964,18 +1022,18 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
             gr.Markdown("### 🎯 Example Prompts")
             gr.Examples(
                 examples=[
-                    ["A
-                    128, 128,
-
-                    "Enable Audio", "
-                    ["A red
-                    128, 128,
-
-                    "Enable Audio", "car engine
-                    ["
-                    128, 128,
-
-                    "Video Only", "", default_audio_negative_prompt, -1,
+                    ["A cat playing guitar on stage", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+                     128, 128, 1,
+                     1, DEFAULT_SEED, False,
+                     "Enable Audio", "guitar music", default_audio_negative_prompt, -1, 10, 4.5],
+                    ["A red car driving on a cliff road", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+                     128, 128, 1,
+                     1, DEFAULT_SEED, False,
+                     "Enable Audio", "car engine, wind", default_audio_negative_prompt, -1, 10, 4.5],
+                    ["Glowing jellyfish floating in the sky", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
+                     128, 128, 1,
+                     1, DEFAULT_SEED, False,
+                     "Video Only", "", default_audio_negative_prompt, -1, 10, 4.5],
                 ],
                 fn=generate_video,
                 inputs=[prompt, nag_negative_prompt, nag_scale,
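Each example row feeds `gr.Examples` positionally, so it must stay in lockstep with the `inputs` list (truncated above). A cheap guard against drift; the expected length of 15 is inferred from the rows in this diff:

```python
example_row = ["A cat playing guitar on stage", DEFAULT_NAG_NEGATIVE_PROMPT, 5,
               128, 128, 1,
               1, DEFAULT_SEED, False,
               "Enable Audio", "guitar music", default_audio_negative_prompt, -1, 10, 4.5]
assert len(example_row) == 15  # must match len(inputs) wired to gr.Examples
```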