seawolf2357 committed
Commit 2d767e1 · verified · 1 Parent(s): 6a6682f

Update app.py

Files changed (1)
  1. app.py +538 -232
app.py CHANGED
@@ -1,13 +1,19 @@
+import torch
+from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
+from diffusers.utils import export_to_video
+from transformers import CLIPVisionModel
+import gradio as gr
+import tempfile
 import spaces
+from huggingface_hub import hf_hub_download
+import numpy as np
+from PIL import Image
+import random
 import logging
-from datetime import datetime
-from pathlib import Path
-
-import gradio as gr
-import torch
 import torchaudio
 import os
 
+# MMAudio imports
 try:
     import mmaudio
 except ImportError:
@@ -20,262 +26,562 @@ from mmaudio.model.flow_matching import FlowMatching
 from mmaudio.model.networks import MMAudio, get_my_mmaudio
 from mmaudio.model.sequence_config import SequenceConfig
 from mmaudio.model.utils.features_utils import FeaturesUtils
-import tempfile
 
+# Video generation model setup
+MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+LORA_REPO_ID = "Kijai/WanVideo_comfy"
+LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
+
+image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanImageToVideoPipeline.from_pretrained(
+    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
+)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
+pipe.to("cuda")
+
+causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
+pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
+pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
+pipe.fuse_lora()
+
+# Audio generation model setup
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 
 log = logging.getLogger()
-
 device = 'cuda'
 dtype = torch.bfloat16
 
-model: ModelConfig = all_model_cfg['large_44k_v2']
-model.download_if_needed()
-output_dir = Path('./output/gradio')
-
+audio_model: ModelConfig = all_model_cfg['large_44k_v2']
+audio_model.download_if_needed()
+
 setup_eval_logging()
 
-def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
-    seq_cfg = model.seq_cfg
-
-    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
-    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
-    log.info(f'Loaded weights from {model.model_path}')
-
-    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
-                                  synchformer_ckpt=model.synchformer_ckpt,
+def get_audio_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
+    seq_cfg = audio_model.seq_cfg
+    net: MMAudio = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
+    net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
+    log.info(f'Loaded weights from {audio_model.model_path}')
+
+    feature_utils = FeaturesUtils(tod_vae_ckpt=audio_model.vae_path,
+                                  synchformer_ckpt=audio_model.synchformer_ckpt,
                                   enable_conditions=True,
-                                  mode=model.mode,
-                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                  mode=audio_model.mode,
+                                  bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
                                   need_vae_encoder=False)
     feature_utils = feature_utils.to(device, dtype).eval()
-
     return net, feature_utils, seq_cfg
 
-net, feature_utils, seq_cfg = get_model()
-
+audio_net, audio_feature_utils, audio_seq_cfg = get_audio_model()
+
+# Constants
+MOD_VALUE = 32
+DEFAULT_H_SLIDER_VALUE = 512
+DEFAULT_W_SLIDER_VALUE = 896
+NEW_FORMULA_MAX_AREA = 480.0 * 832.0
+
+SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
+SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
+MAX_SEED = np.iinfo(np.int32).max
+
+FIXED_FPS = 24
+MIN_FRAMES_MODEL = 8
+MAX_FRAMES_MODEL = 81
+
+default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards, watermark, text, signature"
+default_audio_prompt = ""
+default_audio_negative_prompt = "music"
+
+# CSS
+custom_css = """
+/* Overall background gradient */
+.gradio-container {
+    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 25%, #f093fb 50%, #f5576c 75%, #fa709a 100%) !important;
+    background-size: 400% 400% !important;
+    animation: gradientShift 15s ease infinite !important;
+}
+
+@keyframes gradientShift {
+    0% { background-position: 0% 50%; }
+    50% { background-position: 100% 50%; }
+    100% { background-position: 0% 50%; }
+}
+
+/* Main container style */
+.main-container {
+    backdrop-filter: blur(10px);
+    background: rgba(255, 255, 255, 0.1) !important;
+    border-radius: 20px !important;
+    padding: 30px !important;
+    box-shadow: 0 8px 32px 0 rgba(31, 38, 135, 0.37) !important;
+    border: 1px solid rgba(255, 255, 255, 0.18) !important;
+}
+
+/* Header style */
+h1 {
+    background: linear-gradient(45deg, #ffffff, #f0f0f0) !important;
+    -webkit-background-clip: text !important;
+    -webkit-text-fill-color: transparent !important;
+    background-clip: text !important;
+    font-weight: 800 !important;
+    font-size: 2.5rem !important;
+    text-align: center !important;
+    margin-bottom: 2rem !important;
+    text-shadow: 2px 2px 4px rgba(0,0,0,0.1) !important;
+}
+
+/* Component container style */
+.input-container, .output-container {
+    background: rgba(255, 255, 255, 0.08) !important;
+    border-radius: 15px !important;
+    padding: 20px !important;
+    margin: 10px 0 !important;
+    backdrop-filter: blur(5px) !important;
+    border: 1px solid rgba(255, 255, 255, 0.1) !important;
+}
+
+/* Input field style */
+input, textarea, .gr-box {
+    background: rgba(255, 255, 255, 0.9) !important;
+    border: 1px solid rgba(255, 255, 255, 0.3) !important;
+    border-radius: 10px !important;
+    color: #333 !important;
+    transition: all 0.3s ease !important;
+}
+
+input:focus, textarea:focus {
+    background: rgba(255, 255, 255, 1) !important;
+    border-color: #667eea !important;
+    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
+}
+
+/* Button style */
+.generate-btn {
+    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+    color: white !important;
+    font-weight: 600 !important;
+    font-size: 1.1rem !important;
+    padding: 12px 30px !important;
+    border-radius: 50px !important;
+    border: none !important;
+    cursor: pointer !important;
+    transition: all 0.3s ease !important;
+    box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
+}
+
+.generate-btn:hover {
+    transform: translateY(-2px) !important;
+    box-shadow: 0 6px 20px rgba(102, 126, 234, 0.6) !important;
+}
+
+/* Slider style */
+input[type="range"] {
+    background: transparent !important;
+}
+
+input[type="range"]::-webkit-slider-track {
+    background: rgba(255, 255, 255, 0.3) !important;
+    border-radius: 5px !important;
+    height: 6px !important;
+}
+
+input[type="range"]::-webkit-slider-thumb {
+    background: linear-gradient(135deg, #667eea, #764ba2) !important;
+    border: 2px solid white !important;
+    border-radius: 50% !important;
+    cursor: pointer !important;
+    width: 18px !important;
+    height: 18px !important;
+    -webkit-appearance: none !important;
+}
+
+/* Accordion style */
+.gr-accordion {
+    background: rgba(255, 255, 255, 0.05) !important;
+    border-radius: 10px !important;
+    border: 1px solid rgba(255, 255, 255, 0.1) !important;
+    margin: 15px 0 !important;
+}
+
+/* Label style */
+label {
+    color: #ffffff !important;
+    font-weight: 500 !important;
+    font-size: 0.95rem !important;
+    margin-bottom: 5px !important;
+}
+
+/* Image upload area */
+.image-upload {
+    border: 2px dashed rgba(255, 255, 255, 0.3) !important;
+    border-radius: 15px !important;
+    background: rgba(255, 255, 255, 0.05) !important;
+    transition: all 0.3s ease !important;
+}
+
+.image-upload:hover {
+    border-color: rgba(255, 255, 255, 0.5) !important;
+    background: rgba(255, 255, 255, 0.1) !important;
+}
+
+/* Video output area */
+video {
+    border-radius: 15px !important;
+    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3) !important;
+}
+
+/* Examples section style */
+.gr-examples {
+    background: rgba(255, 255, 255, 0.05) !important;
+    border-radius: 15px !important;
+    padding: 20px !important;
+    margin-top: 20px !important;
+}
+
+/* Checkbox style */
+input[type="checkbox"] {
+    accent-color: #667eea !important;
+}
+
+/* Radio button style */
+input[type="radio"] {
+    accent-color: #667eea !important;
+}
+
+/* Responsive adjustments */
+@media (max-width: 768px) {
+    h1 { font-size: 2rem !important; }
+    .main-container { padding: 20px !important; }
+}
+"""
+
+def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
+                                  min_slider_h, max_slider_h,
+                                  min_slider_w, max_slider_w,
+                                  default_h, default_w):
+    orig_w, orig_h = pil_image.size
+    if orig_w <= 0 or orig_h <= 0:
+        return default_h, default_w
+
+    aspect_ratio = orig_h / orig_w
+
+    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
+    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
+
+    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
+    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
+
+    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
+    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
+
+    return new_h, new_w
+
+def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
+    if uploaded_pil_image is None:
+        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+    try:
+        new_h, new_w = _calculate_new_dimensions_wan(
+            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
+            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
+            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
+        )
+        return gr.update(value=new_h), gr.update(value=new_w)
+    except Exception as e:
+        gr.Warning("Error attempting to calculate new dimensions")
+        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+
+def get_duration(input_image, prompt, height, width,
+                 negative_prompt, duration_seconds,
+                 guidance_scale, steps,
+                 seed, randomize_seed,
+                 audio_mode, audio_prompt, audio_negative_prompt,
+                 audio_seed, audio_steps, audio_cfg_strength,
+                 progress):
+    base_duration = 60
+    if steps > 4 and duration_seconds > 2:
+        base_duration = 90
+    elif steps > 4 or duration_seconds > 2:
+        base_duration = 75
+
+    # Add extra time for audio generation
+    if audio_mode == "Enable Audio":
+        base_duration += 60
+
+    return base_duration
 
-@spaces.GPU(duration=120)
 @torch.inference_mode()
-def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
-                   cfg_strength: float, duration: float):
-
+def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
+                       audio_seed, audio_steps, audio_cfg_strength):
+    """Add audio to video using MMAudio"""
     rng = torch.Generator(device=device)
-    if seed >= 0:
-        rng.manual_seed(seed)
+    if audio_seed >= 0:
+        rng.manual_seed(audio_seed)
     else:
         rng.seed()
-    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
-
-    video_info = load_video(video, duration)
-    clip_frames = video_info.clip_frames
-    sync_frames = video_info.sync_frames
+
+    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=audio_steps)
+
+    video_info = load_video(video_path, duration_sec)
+    clip_frames = video_info.clip_frames.unsqueeze(0)
+    sync_frames = video_info.sync_frames.unsqueeze(0)
     duration = video_info.duration_sec
-    clip_frames = clip_frames.unsqueeze(0)
-    sync_frames = sync_frames.unsqueeze(0)
-    seq_cfg.duration = duration
-    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
-
-    audios = generate(clip_frames,
-                      sync_frames, [prompt],
-                      negative_text=[negative_prompt],
-                      feature_utils=feature_utils,
-                      net=net,
-                      fm=fm,
-                      rng=rng,
-                      cfg_strength=cfg_strength)
-    audio = audios.float().cpu()[0]
-
-    # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
-    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
-    # output_dir.mkdir(exist_ok=True, parents=True)
-    # video_save_path = output_dir / f'{current_time_string}.mp4'
-    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
-    log.info(f'Saved video to {video_save_path}')
-    return video_save_path
-
-
-@spaces.GPU(duration=120)
-@torch.inference_mode()
-def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
-                  duration: float):
-
-    rng = torch.Generator(device=device)
-    if seed >= 0:
-        rng.manual_seed(seed)
-    else:
-        rng.seed()
-    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
-
-    clip_frames = sync_frames = None
-    seq_cfg.duration = duration
-    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
-
+    audio_seq_cfg.duration = duration
+    audio_net.update_seq_lengths(audio_seq_cfg.latent_seq_len, audio_seq_cfg.clip_seq_len, audio_seq_cfg.sync_seq_len)
+
     audios = generate(clip_frames,
-                      sync_frames, [prompt],
-                      negative_text=[negative_prompt],
-                      feature_utils=feature_utils,
-                      net=net,
+                      sync_frames, [audio_prompt],
+                      negative_text=[audio_negative_prompt],
+                      feature_utils=audio_feature_utils,
+                      net=audio_net,
                       fm=fm,
                       rng=rng,
-                      cfg_strength=cfg_strength)
+                      cfg_strength=audio_cfg_strength)
     audio = audios.float().cpu()[0]
 
-    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
-    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
-    log.info(f'Saved audio to {audio_save_path}')
-    return audio_save_path
-
-
-video_to_audio_tab = gr.Interface(
-    fn=video_to_audio,
-    description="""
-    Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
-    Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
-
-    Ho Kei Cheng, Masato Ishii, Akio Hayakawa, Takashi Shibuya, Alexander Schwing, Yuki Mitsufuji
-
-    University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
-
-    CVPR 2025
-
-    NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
-    Doing so does not improve results.
-
-    The model has been trained on 8-second videos. Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
-    """,
-    inputs=[
-        gr.Video(),
-        gr.Text(label='Prompt'),
-        gr.Text(label='Negative prompt', value='music'),
-        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
-        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
-        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
-        gr.Number(label='Duration (sec)', value=8, minimum=1),
-    ],
-    outputs='playable_video',
-    cache_examples=False,
-    title='MMAudio — Video-to-Audio Synthesis',
-    examples=[
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
-            'waves, seagulls',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
-            '',
-            'music',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
-            'bubbles',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
-            'Indian holy music',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
-            'galloping',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
-            'waves, storm',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
-            '',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
-            'storm',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
-            '',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
-            'typing',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-        [
-            'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
-            '',
-            '',
-            0,
-            25,
-            4.5,
-            10,
-        ],
-    ])
-
-text_to_audio_tab = gr.Interface(
-    fn=text_to_audio,
-    inputs=[
-        gr.Text(label='Prompt'),
-        gr.Text(label='Negative prompt'),
-        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
-        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
-        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
-        gr.Number(label='Duration (sec)', value=8, minimum=1),
-    ],
-    outputs='audio',
-    cache_examples=False,
-    title='MMAudio — Text-to-Audio Synthesis',
-)
+    # Save video with audio
+    video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
+    make_video(video_info, video_with_audio_path, audio, sampling_rate=audio_seq_cfg.sampling_rate)
+
+    return video_with_audio_path
+
+@spaces.GPU(duration=get_duration)
+def generate_video(input_image, prompt, height, width,
+                   negative_prompt, duration_seconds,
+                   guidance_scale, steps,
+                   seed, randomize_seed,
+                   audio_mode, audio_prompt, audio_negative_prompt,
+                   audio_seed, audio_steps, audio_cfg_strength,
+                   progress=gr.Progress(track_tqdm=True)):
+
+    if input_image is None:
+        raise gr.Error("Please upload an input image.")
+
+    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
+    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
+    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
+    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+
+    resized_image = input_image.resize((target_w, target_h))
+
+    # Generate video
+    with torch.inference_mode():
+        output_frames_list = pipe(
+            image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
+            height=target_h, width=target_w, num_frames=num_frames,
+            guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
+            generator=torch.Generator(device="cuda").manual_seed(current_seed)
+        ).frames[0]
+
+    # Save video without audio
+    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+        video_path = tmpfile.name
+    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
+
+    # Generate audio if enabled
+    video_with_audio_path = None
+    if audio_mode == "Enable Audio":
+        progress(0.5, desc="Generating audio...")
+        video_with_audio_path = add_audio_to_video(
+            video_path, duration_seconds,
+            audio_prompt, audio_negative_prompt,
+            audio_seed, audio_steps, audio_cfg_strength
+        )
+
+    return video_path, video_with_audio_path, current_seed
+
+def update_audio_visibility(audio_mode):
+    """Update visibility of audio-related components"""
+    return gr.update(visible=(audio_mode == "Enable Audio"))
+
+with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+    with gr.Column(elem_classes=["main-container"]):
+        gr.Markdown("# ✨ Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA + Audio")
+
+        with gr.Row():
+            with gr.Column(elem_classes=["input-container"]):
+                input_image_component = gr.Image(
+                    type="pil",
+                    label="🖼️ Input Image (auto-resized to target H/W)",
+                    elem_classes=["image-upload"]
+                )
+                prompt_input = gr.Textbox(
+                    label="✏️ Prompt",
+                    value=default_prompt_i2v,
+                    lines=2
+                )
+                duration_seconds_input = gr.Slider(
+                    minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1),
+                    maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1),
+                    step=0.1,
+                    value=2,
+                    label="⏱️ Duration (seconds)",
+                    info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
+                )
+
+                # Audio mode radio button
+                audio_mode = gr.Radio(
+                    choices=["Video Only", "Enable Audio"],
+                    value="Video Only",
+                    label="🎵 Audio Mode",
+                    info="Enable to add audio to your generated video"
+                )
+
+                # Audio settings (initially hidden)
+                with gr.Column(visible=False) as audio_settings:
+                    audio_prompt = gr.Textbox(
+                        label="🎵 Audio Prompt",
+                        value=default_audio_prompt,
+                        placeholder="Describe the audio you want (e.g., 'waves, seagulls', 'footsteps on gravel')",
+                        lines=2
+                    )
+                    audio_negative_prompt = gr.Textbox(
+                        label="❌ Audio Negative Prompt",
+                        value=default_audio_negative_prompt,
+                        lines=2
+                    )
+                    with gr.Row():
+                        audio_seed = gr.Number(
+                            label="🎲 Audio Seed",
+                            value=-1,
+                            precision=0,
+                            minimum=-1
+                        )
+                        audio_steps = gr.Slider(
+                            minimum=1,
+                            maximum=50,
+                            step=1,
+                            value=25,
+                            label="🚀 Audio Steps"
+                        )
+                        audio_cfg_strength = gr.Slider(
+                            minimum=1.0,
+                            maximum=10.0,
+                            step=0.5,
+                            value=4.5,
+                            label="🎯 Audio Guidance"
+                        )
+
+                with gr.Accordion("⚙️ Advanced Settings", open=False):
+                    negative_prompt_input = gr.Textbox(
+                        label="❌ Negative Prompt",
+                        value=default_negative_prompt,
+                        lines=3
+                    )
+                    seed_input = gr.Slider(
+                        label="🎲 Seed",
+                        minimum=0,
+                        maximum=MAX_SEED,
+                        step=1,
+                        value=42,
+                        interactive=True
+                    )
+                    randomize_seed_checkbox = gr.Checkbox(
+                        label="🔀 Randomize seed",
+                        value=True,
+                        interactive=True
+                    )
+                    with gr.Row():
+                        height_input = gr.Slider(
+                            minimum=SLIDER_MIN_H,
+                            maximum=SLIDER_MAX_H,
+                            step=MOD_VALUE,
+                            value=DEFAULT_H_SLIDER_VALUE,
+                            label=f"📏 Output Height (multiple of {MOD_VALUE})"
+                        )
+                        width_input = gr.Slider(
+                            minimum=SLIDER_MIN_W,
+                            maximum=SLIDER_MAX_W,
+                            step=MOD_VALUE,
+                            value=DEFAULT_W_SLIDER_VALUE,
+                            label=f"📏 Output Width (multiple of {MOD_VALUE})"
+                        )
+                    steps_slider = gr.Slider(
+                        minimum=1,
+                        maximum=30,
+                        step=1,
+                        value=4,
+                        label="🚀 Inference Steps"
+                    )
+                    guidance_scale_input = gr.Slider(
+                        minimum=0.0,
+                        maximum=20.0,
+                        step=0.5,
+                        value=1.0,
+                        label="🎯 Guidance Scale",
+                        visible=False
+                    )
+
+                generate_button = gr.Button(
+                    "🎬 Generate Video",
+                    variant="primary",
+                    elem_classes=["generate-btn"]
+                )
+
+            with gr.Column(elem_classes=["output-container"]):
+                video_output = gr.Video(
+                    label="🎥 Generated Video",
+                    autoplay=True,
+                    interactive=False
+                )
+                video_with_audio_output = gr.Video(
+                    label="🎥 Generated Video with Audio",
+                    autoplay=True,
+                    interactive=False,
+                    visible=False
+                )
+
+        # Event handlers
+        audio_mode.change(
+            fn=update_audio_visibility,
+            inputs=[audio_mode],
+            outputs=[audio_settings, video_with_audio_output]
+        )
+
+        input_image_component.upload(
+            fn=handle_image_upload_for_dims_wan,
+            inputs=[input_image_component, height_input, width_input],
+            outputs=[height_input, width_input]
+        )
+
+        input_image_component.clear(
+            fn=handle_image_upload_for_dims_wan,
+            inputs=[input_image_component, height_input, width_input],
+            outputs=[height_input, width_input]
+        )
+
+        ui_inputs = [
+            input_image_component, prompt_input, height_input, width_input,
+            negative_prompt_input, duration_seconds_input,
+            guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox,
+            audio_mode, audio_prompt, audio_negative_prompt,
+            audio_seed, audio_steps, audio_cfg_strength
+        ]
+        generate_button.click(
+            fn=generate_video,
+            inputs=ui_inputs,
+            outputs=[video_output, video_with_audio_output, seed_input]
+        )
+
+        with gr.Column():
+            gr.Examples(
+                examples=[
+                    ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512,
+                     default_negative_prompt, 2, 1.0, 4, 42, False,
+                     "Video Only", "", default_audio_negative_prompt, -1, 25, 4.5],
+                    ["forg.jpg", "the frog jumps around", 448, 832,
+                     default_negative_prompt, 2, 1.0, 4, 42, False,
+                     "Enable Audio", "frog croaking, water splashing", default_audio_negative_prompt, -1, 25, 4.5],
+                ],
+                inputs=ui_inputs,
+                outputs=[video_output, video_with_audio_output, seed_input],
+                fn=generate_video,
+                cache_examples="lazy",
+                label="🌟 Example Gallery"
+            )
 
 if __name__ == "__main__":
-    gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab],
-                       ['Video-to-Audio', 'Text-to-Audio']).launch(allowed_paths=[output_dir])
+    demo.queue().launch()
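
A note on the ZeroGPU allocation pattern used above: the new `generate_video` is decorated with `@spaces.GPU(duration=get_duration)`, i.e. the GPU lease is computed per request from the same arguments the handler receives (more steps, longer clips, or enabling audio all raise the budget), instead of the fixed `duration=120` the old handlers used. A minimal sketch of that pattern, with hypothetical function names, assuming the `spaces` package available on Hugging Face Spaces:

import spaces

def my_duration(prompt: str, steps: int, duration_seconds: float) -> int:
    # Hypothetical budget, mirroring get_duration() in the diff: scale the
    # GPU lease with the step count and the requested clip length.
    if steps > 4 and duration_seconds > 2:
        return 90
    if steps > 4 or duration_seconds > 2:
        return 75
    return 60

@spaces.GPU(duration=my_duration)  # duration takes a fixed number of seconds or a callable with the same arguments as the decorated function
def my_generate(prompt: str, steps: int = 4, duration_seconds: float = 2.0):
    ...  # GPU-bound work (e.g. the diffusion call) goes here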