Spaces: Running on Zero

Update app.py
Browse files

app.py CHANGED
@@ -9,6 +9,7 @@ except:
 current_dir = os.getcwd()
 
 sys.path.insert(0, current_dir)
+print(f"Added {current_dir} to Python path")
 
 os.makedirs("src", exist_ok=True)
 
@@ -30,14 +31,14 @@ from typing import Optional, Dict, Any
 import torch.nn.functional as F
 
 class NagWanTransformer3DModel(nn.Module):
-    """NAG-enhanced Transformer for video generation"""
+    """NAG-enhanced Transformer for video generation (simplified demo)"""
 
     def __init__(
         self,
         in_channels: int = 4,
         out_channels: int = 4,
-        hidden_size: int =
-        num_layers: int =
+        hidden_size: int = 1280,
+        num_layers: int = 2,
         num_heads: int = 8,
     ):
         super().__init__()
@@ -50,46 +51,22 @@ class NagWanTransformer3DModel(nn.Module):
         self.config = type('Config', (), {
             'in_channels': in_channels,
             'out_channels': out_channels,
-            'hidden_size': hidden_size
+            'hidden_size': hidden_size,
+            'num_attention_heads': num_heads,
+            'attention_head_dim': hidden_size // num_heads,
         })()
 
-        #
-
-        self.
+        # Simple conv layers for demo
+        self.conv_in = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1)
+        self.conv_mid = nn.Conv3d(64, 64, kernel_size=3, padding=1)
+        self.conv_out = nn.Conv3d(64, out_channels, kernel_size=3, padding=1)
+
+        # Time embedding
         self.time_embed = nn.Sequential(
-            nn.Linear(
+            nn.Linear(1, hidden_size),
             nn.SiLU(),
-            nn.Linear(
-        )
-        self.down_blocks = nn.ModuleList([
-            nn.Conv3d(320, 320, kernel_size=3, stride=2, padding=1),
-            nn.Conv3d(320, 640, kernel_size=3, stride=2, padding=1),
-            nn.Conv3d(640, 1280, kernel_size=3, stride=2, padding=1),
-        ])
-        self.mid_block = nn.Conv3d(1280, 1280, kernel_size=3, padding=1)
-        self.up_blocks = nn.ModuleList([
-            nn.ConvTranspose3d(1280, 640, kernel_size=3, stride=2, padding=1, output_padding=1),
-            nn.ConvTranspose3d(640, 320, kernel_size=3, stride=2, padding=1, output_padding=1),
-            nn.ConvTranspose3d(320, 320, kernel_size=3, stride=2, padding=1, output_padding=1),
-        ])
-        self.conv_out = nn.Conv3d(320, out_channels, kernel_size=3, padding=1)
-
-    @classmethod
-    def from_single_file(cls, model_path, **kwargs):
-        """Load model from single file"""
-        print(f"Note: Loading simplified NAG model instead of {model_path}")
-        print("This is a demo version that doesn't require 28GB of weights")
-
-        # Create a simplified model
-        model = cls(
-            in_channels=4,
-            out_channels=4,
-            hidden_size=768,
-            num_layers=4,
-            num_heads=8
+            nn.Linear(hidden_size, 64),
         )
-
-        return model.to(kwargs.get('torch_dtype', torch.float32))
 
     @staticmethod
     def attn_processors():
@@ -98,14 +75,6 @@ class NagWanTransformer3DModel(nn.Module):
     @staticmethod
     def set_attn_processor(processor):
         pass
-
-    def time_proj(self, timesteps, dim=320):
-        half_dim = dim // 2
-        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
-        emb = torch.exp(-emb * torch.arange(half_dim, device=timesteps.device))
-        emb = timesteps[:, None] * emb[None, :]
-        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
-        return emb
 
     def forward(
         self,
@@ -115,33 +84,40 @@ class NagWanTransformer3DModel(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         **kwargs
     ):
-        #
+        # Simple forward pass for demo
+        batch_size = hidden_states.shape[0]
+
+        # Time embedding
         if timestep is not None:
-
+            # Ensure timestep is the right shape
+            if timestep.ndim == 0:
+                timestep = timestep.unsqueeze(0)
+            if timestep.shape[0] != batch_size:
+                timestep = timestep.repeat(batch_size)
+
+            # Normalize timestep to [0, 1]
+            t_emb = timestep.float() / 1000.0
+            t_emb = t_emb.view(-1, 1)
             t_emb = self.time_embed(t_emb)
+
+            # Reshape for broadcasting
+            t_emb = t_emb.view(batch_size, -1, 1, 1, 1)
 
-        #
+        # Simple convolutions
         h = self.conv_in(hidden_states)
 
-        #
-
-
-            down_block_res_samples.append(h)
-            h = down_block(h)
-
-        # Mid block
-        h = self.mid_block(h)
-
-        # Up blocks
-        for i, up_block in enumerate(self.up_blocks):
-            h = up_block(h)
-            # Add skip connections
-            if i < len(down_block_res_samples):
-                h = h + down_block_res_samples[-(i+1)]
+        # Add time embedding if available
+        if timestep is not None:
+            h = h + t_emb
 
-
+        h = F.silu(h)
+        h = self.conv_mid(h)
+        h = F.silu(h)
         h = self.conv_out(h)
 
+        # Add residual connection
+        h = h + hidden_states
+
         return h
 ''')
 
@@ -394,8 +370,9 @@ import types
 import random
 import spaces
 import torch
+import torch.nn as nn
 import numpy as np
-from diffusers import
+from diffusers import AutoencoderKL, UniPCMultistepScheduler, DDPMScheduler
 from diffusers.utils import export_to_video
 import gradio as gr
 import tempfile
@@ -414,7 +391,17 @@ try:
     print("Successfully imported NAG modules")
 except Exception as e:
     print(f"Error importing NAG modules: {e}")
-
+    print("Attempting to recreate modules...")
+    # Wait a bit and try again
+    import time
+    time.sleep(3)
+    try:
+        from src.pipeline_wan_nag import NAGWanPipeline
+        from src.transformer_wan_nag import NagWanTransformer3DModel
+        print("Successfully imported NAG modules on second attempt")
+    except:
+        print("Failed to import modules. Please restart the application.")
+        sys.exit(1)
 
 # MMAudio imports
 try:
@@ -436,12 +423,12 @@ from mmaudio.model.utils.features_utils import FeaturesUtils
 
 # Constants
 MOD_VALUE = 32
-DEFAULT_DURATION_SECONDS =
-DEFAULT_STEPS =
+DEFAULT_DURATION_SECONDS = 2
+DEFAULT_STEPS = 2
 DEFAULT_SEED = 2025
-DEFAULT_H_SLIDER_VALUE =
-DEFAULT_W_SLIDER_VALUE =
-NEW_FORMULA_MAX_AREA =
+DEFAULT_H_SLIDER_VALUE = 128
+DEFAULT_W_SLIDER_VALUE = 128
+NEW_FORMULA_MAX_AREA = 256.0 * 256.0
 
 SLIDER_MIN_H, SLIDER_MAX_H = 128, 512
 SLIDER_MIN_W, SLIDER_MAX_W = 128, 512
@@ -453,6 +440,7 @@ MAX_FRAMES_MODEL = 129
 
 DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, worst quality, poorly drawn, low resolution, blurry, lack of details"
 
+# Note: Model IDs are kept for reference but not used in demo
 MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
 SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
@@ -460,44 +448,97 @@ LORA_REPO_ID = "Kijai/WanVideo_comfy"
 LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
 
 # Initialize models
-print("
-vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
+print("Creating demo models...")
 
-#
-
-
-
+# Create a simple VAE-like model for demo
+class DemoVAE(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Conv2d(3, 64, 3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(64, 4, 3, padding=1)
+        )
+        self.decoder = nn.Sequential(
+            nn.Conv2d(4, 64, 3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(64, 3, 3, padding=1),
+            nn.Tanh()  # Output in [-1, 1]
+        )
+        self.config = type('Config', (), {
+            'scaling_factor': 0.18215,
+            'latent_channels': 4,
+        })()
+
+    def encode(self, x):
+        # Simple encoding
+        encoded = self.encoder(x)
+        return type('EncoderOutput', (), {'latent_dist': type('LatentDist', (), {'sample': lambda: encoded})()})()
+
+    def decode(self, z):
+        # Simple decoding
+        # Handle different input shapes
+        if z.dim() == 5:  # Video: (B, C, F, H, W)
+            b, c, f, h, w = z.shape
+            z = z.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
+            decoded = self.decoder(z)
+            decoded = decoded.reshape(b, f, 3, h * 8, w * 8).permute(0, 2, 1, 3, 4)
+        else:  # Image: (B, C, H, W)
+            decoded = self.decoder(z)
+        return type('DecoderOutput', (), {'sample': decoded})()
+
+vae = DemoVAE()
 
-print("Creating transformer model...")
-transformer = NagWanTransformer3DModel
+print("Creating simplified NAG transformer model...")
+transformer = NagWanTransformer3DModel(
+    in_channels=4,
+    out_channels=4,
+    hidden_size=1280,
+    num_layers=2,  # Reduced for demo
+    num_heads=8
+)
 
 print("Creating pipeline...")
-
-
+# Create a minimal pipeline for demo
+pipe = NAGWanPipeline(
+    vae=vae,
+    text_encoder=None,
+    tokenizer=None,
+    transformer=transformer,
+    scheduler=DDPMScheduler(
+        num_train_timesteps=1000,
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        clip_sample=False,
+        set_alpha_to_one=False,
+        steps_offset=1,
+    )
 )
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
 
 # Move to appropriate device
-if torch.cuda.is_available()
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print(f"Using device: {device}")
+
+# Move models to device
+vae = vae.to(device)
+transformer = transformer.to(device)
+
+if device == 'cuda':
     pipe.to("cuda")
-    print("
+    print("Pipeline moved to CUDA")
 else:
     pipe.to("cpu")
     print("Warning: CUDA not available, using CPU (will be slow)")
 
-#
-
-    print("Loading LoRA weights...")
-    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
-    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
-    pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
-    pipe.fuse_lora()
-    print("LoRA weights loaded successfully")
-except Exception as e:
-    print(f"Warning: Could not load LoRA weights: {e}")
+# Skip LoRA for demo version
+print("Demo version ready!")
 
-
-
+# Check if transformer has the required methods
+if hasattr(transformer, 'attn_processors'):
+    pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
+if hasattr(transformer, 'set_attn_processor'):
+    pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
 
 # Audio model setup
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -643,10 +684,11 @@ def get_duration(
     audio_mode, audio_prompt, audio_negative_prompt,
     audio_seed, audio_steps, audio_cfg_strength,
 ):
-
+    # Simplified duration calculation for demo
+    duration = int(duration_seconds) * int(steps) + 10
     if audio_mode == "Enable Audio":
-        duration += 60
-    return duration
+        duration += 30  # Reduced from 60 for demo
+    return min(duration, 60)  # Cap at 60 seconds for demo
 
 @torch.inference_mode()
 def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
@@ -693,43 +735,64 @@ def generate_video(
     audio_mode="Video Only", audio_prompt="", audio_negative_prompt="music",
     audio_seed=-1, audio_steps=25, audio_cfg_strength=4.5,
 ):
-    nag_video_path
+    try:
+        target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
+        target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
+        num_frames = np.clip(int(round(int(duration_seconds) * FIXED_FPS) + 1), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
+        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
+        # Ensure transformer is on the right device and dtype
+        if hasattr(pipe, 'transformer'):
+            pipe.transformer = pipe.transformer.to(device).to(torch.float32)
+        if hasattr(pipe, 'vae'):
+            pipe.vae = pipe.vae.to(device).to(torch.float32)
+
+        with torch.inference_mode():
+            nag_output_frames_list = pipe(
+                prompt=prompt,
+                nag_negative_prompt=nag_negative_prompt,
+                nag_scale=nag_scale,
+                nag_tau=3.5,
+                nag_alpha=0.5,
+                height=target_h, width=target_w, num_frames=num_frames,
+                guidance_scale=0.,
+                num_inference_steps=int(steps),
+                generator=torch.Generator(device=device).manual_seed(current_seed)
+            ).frames[0]
+
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+            nag_video_path = tmpfile.name
+        export_to_video(nag_output_frames_list, nag_video_path, fps=FIXED_FPS)
+
+        # Generate audio if enabled
+        video_with_audio_path = None
+        if audio_mode == "Enable Audio":
+            try:
+                video_with_audio_path = add_audio_to_video(
+                    nag_video_path, duration_seconds,
+                    audio_prompt, audio_negative_prompt,
+                    audio_seed, audio_steps, audio_cfg_strength
+                )
+            except Exception as e:
+                print(f"Warning: Could not generate audio: {e}")
+                video_with_audio_path = None
+
+        clear_cache()
+        cleanup_temp_files()
 
+        return nag_video_path, video_with_audio_path, current_seed
+
+    except Exception as e:
+        print(f"Error generating video: {e}")
+        # Return a simple error video
+        error_frames = np.zeros((1, 64, 64, 3), dtype=np.uint8)
+        error_frames[:, :, :] = [255, 0, 0]  # Red frame
+        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+            error_video_path = tmpfile.name
+        export_to_video([error_frames[0]], error_video_path, fps=1)
+        return error_video_path, None, current_seed
 
 def update_audio_visibility(audio_mode):
     return gr.update(visible=(audio_mode == "Enable Audio"))
@@ -738,15 +801,16 @@ def update_audio_visibility(audio_mode):
 with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     with gr.Column(elem_classes="container"):
         gr.HTML("""
-            <h1 class="main-title">🎬 NAG Video
-            <p class="subtitle">
+            <h1 class="main-title">🎬 NAG Video Demo with Audio</h1>
+            <p class="subtitle">Lightweight Text-to-Video with Normalized Attention Guidance + MMAudio</p>
         """)
 
         gr.HTML("""
            <div class="info-box">
-                <p
+                <p>📌 <strong>Demo Version:</strong> This is a simplified demo that demonstrates NAG concepts without large model downloads</p>
                 <p>🚀 <strong>NAG Technology:</strong> Normalized Attention Guidance for enhanced video quality</p>
                 <p>🎵 <strong>Audio:</strong> Optional synchronized audio generation with MMAudio</p>
+                <p>⚡ <strong>Fast:</strong> Runs without downloading 28GB model files</p>
            </div>
         """)
 
@@ -822,7 +886,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
             with gr.Row():
                 duration_seconds_input = gr.Slider(
                     minimum=1,
-                    maximum=
+                    maximum=4,
                     step=1,
                     value=DEFAULT_DURATION_SECONDS,
                     label="📱 Duration (seconds)",
@@ -830,7 +894,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
                 )
                 steps_slider = gr.Slider(
                     minimum=1,
-                    maximum=
+                    maximum=4,
                     step=1,
                     value=DEFAULT_STEPS,
                     label="🔄 Inference Steps",
@@ -893,6 +957,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
 
         gr.HTML("""
             <div style="text-align: center; margin-top: 20px; color: #6b7280;">
+                <p>💡 Demo version with simplified model - Real NAG would produce higher quality results</p>
                 <p>💡 Tip: Try different NAG scales for varied artistic effects!</p>
             </div>
         """)
@@ -901,16 +966,16 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
     gr.Examples(
         examples=[
             ["A ginger cat passionately plays electric guitar with intensity and emotion on a stage. The background is shrouded in deep darkness. Spotlights cast dramatic shadows.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-
-
+             128, 128, 2,
+             2, DEFAULT_SEED, False,
             "Enable Audio", "electric guitar riffs, cat meowing", default_audio_negative_prompt, -1, 25, 4.5],
             ["A red vintage Porsche convertible flying over a rugged coastal cliff. Monstrous waves violently crashing against the rocks below. A lighthouse stands tall atop the cliff.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-
-
+             128, 128, 2,
+             2, DEFAULT_SEED, False,
             "Enable Audio", "car engine roaring, ocean waves crashing, wind", default_audio_negative_prompt, -1, 25, 4.5],
             ["Enormous glowing jellyfish float slowly across a sky filled with soft clouds. Their tentacles shimmer with iridescent light as they drift above a peaceful mountain landscape.", DEFAULT_NAG_NEGATIVE_PROMPT, 11,
-
-
+             128, 128, 2,
+             2, DEFAULT_SEED, False,
             "Video Only", "", default_audio_negative_prompt, -1, 25, 4.5],
         ],
         fn=generate_video,
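A minimal sketch of exercising the simplified demo pipeline added in this commit outside the Gradio UI; it mirrors the pipe(...) call inside generate_video. It assumes app.py has already executed, so pipe, device, FIXED_FPS, and DEFAULT_NAG_NEGATIVE_PROMPT are in scope; the prompt text and output filename are illustrative only.

# Minimal sketch, assuming app.py's globals (pipe, device, FIXED_FPS, DEFAULT_NAG_NEGATIVE_PROMPT) are in scope.
import torch
from diffusers.utils import export_to_video

prompt = "A ginger cat passionately plays electric guitar on a stage."  # illustrative

with torch.inference_mode():
    frames = pipe(
        prompt=prompt,
        nag_negative_prompt=DEFAULT_NAG_NEGATIVE_PROMPT,
        nag_scale=11.0,                  # same NAG scale as the bundled examples
        nag_tau=3.5,
        nag_alpha=0.5,
        height=128, width=128,           # demo slider defaults
        num_frames=FIXED_FPS * 2 + 1,    # roughly 2 seconds of video
        guidance_scale=0.,
        num_inference_steps=2,           # DEFAULT_STEPS in this demo
        generator=torch.Generator(device=device).manual_seed(2025),
    ).frames[0]

export_to_video(frames, "nag_demo.mp4", fps=FIXED_FPS)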