dangthr committed on
Commit 2a7f487 · verified · 1 Parent(s): f55df05

Update app.py

Files changed (1): app.py +137 -217
app.py CHANGED
@@ -1,189 +1,123 @@
  import os
  import sys
- sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
-
- # wan2.2-main/gradio_ti2v.py
- import gradio as gr
  import torch
  from huggingface_hub import snapshot_download
  from PIL import Image
- import random
- import numpy as np
- import spaces
-
- import wan
- from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
- from wan.utils.utils import cache_video
-
  import gc

- # --- 1. Global Setup and Model Loading ---
-
- print("Starting Gradio App for Wan 2.2 TI2V-5B...")
-
- # Download model snapshots from Hugging Face Hub
- repo_id = "Wan-AI/Wan2.2-TI2V-5B"
- print(f"Downloading/loading checkpoints for {repo_id}...")
- ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
- print(f"Using checkpoints from {ckpt_dir}")
-
- # Load the model configuration
- TASK_NAME = 'ti2v-5B'
- cfg = WAN_CONFIGS[TASK_NAME]
- FIXED_FPS = 24
- MIN_FRAMES_MODEL = 8
- MAX_FRAMES_MODEL = 121
-
- # Dimension calculation constants
- MOD_VALUE = 32
- DEFAULT_H_SLIDER_VALUE = 704
- DEFAULT_W_SLIDER_VALUE = 1280
- NEW_FORMULA_MAX_AREA = 1280.0 * 704.0
-
- SLIDER_MIN_H, SLIDER_MAX_H = 128, 1280
- SLIDER_MIN_W, SLIDER_MAX_W = 128, 1280
-
- # Instantiate the pipeline in the global scope
- print("Initializing WanTI2V pipeline...")
- device = "cuda" if torch.cuda.is_available() else "cpu"
- device_id = 0 if torch.cuda.is_available() else -1
- pipeline = wan.WanTI2V(
-     config=cfg,
-     checkpoint_dir=ckpt_dir,
-     device_id=device_id,
-     rank=0,
-     t5_fsdp=False,
-     dit_fsdp=False,
-     use_sp=False,
-     t5_cpu=False,
-     init_on_cpu=False,
-     convert_model_dtype=True,
- )
- print("Pipeline initialized and ready.")
-
- # --- Helper Functions (from Wan 2.1 Fast demo) ---
- def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
-                                   min_slider_h, max_slider_h,
-                                   min_slider_w, max_slider_w,
-                                   default_h, default_w):
-     orig_w, orig_h = pil_image.size
-     if orig_w <= 0 or orig_h <= 0:
-         return default_h, default_w

-     aspect_ratio = orig_h / orig_w
-
-     calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
-     calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))

-     calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
-     calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
-
-     new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
-     new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
-
-     return new_h, new_w

- def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
      """
-     Handle image upload and calculate appropriate dimensions for video generation.
-
-     Args:
-         uploaded_pil_image: The uploaded image (PIL Image or numpy array)
-         current_h_val: Current height slider value
-         current_w_val: Current width slider value
-
-     Returns:
-         Tuple of gr.update objects for height and width sliders
      """
-     if uploaded_pil_image is None:
-         return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
      try:
-         # Convert numpy array to PIL Image if needed
-         if hasattr(uploaded_pil_image, 'shape'):  # numpy array
-             pil_image = Image.fromarray(uploaded_pil_image).convert("RGB")
-         else:  # already PIL Image
-             pil_image = uploaded_pil_image
-
-         new_h, new_w = _calculate_new_dimensions_wan(
-             pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
-             SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
-             DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
-         )
-         return gr.update(value=new_h), gr.update(value=new_w)
      except Exception as e:
-         gr.Warning("Error attempting to calculate new dimensions")
-         return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)

- def get_duration(image,
-                  prompt,
-                  height,
-                  width,
-                  duration_seconds,
-                  sampling_steps,
-                  guide_scale,
-                  shift,
-                  seed,
-                  progress):
-     """Calculate dynamic GPU duration based on parameters."""
-     return sampling_steps * 15

- # --- 2. Gradio Inference Function ---
- @spaces.GPU(duration=get_duration)
- def generate_video(
-     image,
-     prompt,
-     height,
-     width,
-     duration_seconds,
-     sampling_steps=38,
-     guide_scale=cfg.sample_guide_scale,
-     shift=cfg.sample_shift,
-     seed=42,
-     progress=gr.Progress(track_tqdm=True)
- ):
      """
-     Generate a video from text prompt and optional image using the Wan 2.2 TI2V model.
-
-     Args:
-         image: Optional input image (numpy array) for image-to-video generation
-         prompt: Text prompt describing the desired video
-         height: Target video height in pixels
-         width: Target video width in pixels
-         duration_seconds: Desired video duration in seconds
-         sampling_steps: Number of denoising steps for video generation
-         guide_scale: Guidance scale for classifier-free guidance
-         shift: Sample shift parameter for the model
-         seed: Random seed for reproducibility (-1 for random)
-         progress: Gradio progress tracker
-
-     Returns:
-         Path to the generated video file
      """
      if seed == -1:
          seed = random.randint(0, sys.maxsize)

-     # Ensure dimensions are multiples of MOD_VALUE
      target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
      target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

-     input_image = None
-     if image is not None:
-         input_image = Image.fromarray(image).convert("RGB")
-         # Resize image to match target dimensions
-         input_image = input_image.resize((target_w, target_h))
-
-     # Calculate number of frames based on duration
      num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

-     # Create size string for the pipeline
-     size_str = f"{target_h}*{target_w}"

      video_tensor = pipeline.generate(
          input_prompt=prompt,
-         img=input_image,  # Pass None for T2V, an Image for I2V
          size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
          max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
-         frame_num=num_frames,  # Use calculated frames instead of cfg.frame_num
          shift=shift,
          sample_solver='unipc',
          sampling_steps=int(sampling_steps),
@@ -192,83 +126,69 @@ def generate_video(
          offload_model=True
      )

-     # Save the video to a temporary file
      video_path = cache_video(
-         tensor=video_tensor[None],  # Add a batch dimension
-         save_file=None,  # cache_video will create a temp file
          fps=cfg.sample_fps,
          normalize=True,
          value_range=(-1, 1)
      )
      del video_tensor
      gc.collect()
-     return video_path
-

- # --- 3. Gradio Interface ---
- css = ".gradio-container {max-width: 1100px !important; margin: 0 auto} #output_video {height: 500px;} #input_image {height: 500px;}"

- with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
-     gr.Markdown("# Wan 2.2 TI2V 5B")
-     gr.Markdown("generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model**,[[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B),[[paper]](https://arxiv.org/abs/2503.20314)")

-     with gr.Row():
-         with gr.Column(scale=2):
-             image_input = gr.Image(type="numpy", label="Optional (blank = text-to-image)", elem_id="input_image")
-             prompt_input = gr.Textbox(label="Prompt", value="A beautiful waterfall in a lush jungle, cinematic.", lines=3)
-             duration_input = gr.Slider(
-                 minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
-                 maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
-                 step=0.1,
-                 value=2.0,
-                 label="Duration (seconds)",
-                 info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
-             )
-
-             with gr.Accordion("Advanced Settings", open=False):
-                 with gr.Row():
-                     height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
-                     width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
-                 steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
-                 scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
-                 shift_input = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
-                 seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)

-         with gr.Column(scale=2):
-             video_output = gr.Video(label="Generated Video", elem_id="output_video")
-             run_button = gr.Button("Generate Video", variant="primary")
-
-     # Add image upload handler
-     image_input.upload(
-         fn=handle_image_upload_for_dims_wan,
-         inputs=[image_input, height_input, width_input],
-         outputs=[height_input, width_input]
      )

-     image_input.clear(
-         fn=handle_image_upload_for_dims_wan,
-         inputs=[image_input, height_input, width_input],
-         outputs=[height_input, width_input]
      )
-
-     example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
-     gr.Examples(
-         examples=[
-             [example_image_path, "The cat removes the glasses from its eyes.", 1088, 800, 1.5],
-             [None, "A cinematic shot of a boat sailing on a calm sea at sunset.", 704, 1280, 2.0],
-             [None, "Drone footage flying over a futuristic city with flying cars.", 704, 1280, 2.0],
-         ],
-         inputs=[image_input, prompt_input, height_input, width_input, duration_input],
-         outputs=video_output,
-         fn=generate_video,
-         cache_examples="lazy",
      )

-     run_button.click(
-         fn=generate_video,
-         inputs=[image_input, prompt_input, height_input, width_input, duration_input, steps_input, scale_input, shift_input, seed_input],
-         outputs=video_output
-     )

  if __name__ == "__main__":
-     demo.launch(mcp_server=True)
 
  import os
  import sys
+ import argparse
+ import random
+ import numpy as np
  import torch
  from huggingface_hub import snapshot_download
  from PIL import Image
  import gc

+ # Add the directory containing this file to the Python path
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

+ # Import the required modules from the 'wan' library
+ import wan
+ from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
+ from wan.utils.utils import cache_video

+ # --- 1. Model downloader ---

+ def download_models():
      """
+     Download and cache the required model from the Hugging Face Hub.
      """
+     repo_id = "Wan-AI/Wan2.2-TI2V-5B"
+     print(f"Downloading model checkpoints for {repo_id}...")
      try:
+         ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
+         print(f"✅ Model downloaded successfully to: {ckpt_dir}")
      except Exception as e:
+         print(f"❌ Error while downloading the model: {e}")
+         sys.exit(1)

+ # --- 2. Video generation function ---

+ def generate_video_cli(prompt: str):
      """
+     Generate a video from a text prompt using command-line settings.
      """
+     print("🎬 Starting the video generation process...")
+
+     # --- Setup ---
+     print("Loading model configuration...")
+     repo_id = "Wan-AI/Wan2.2-TI2V-5B"
+     # Make sure the model has been downloaded; otherwise download it now.
+     try:
+         # snapshot_download checks the local cache and skips the download if the files already exist
+         ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
+     except Exception as e:
+         print("❌ Could not find or download the model. Run `python app.py --downloader` first.")
+         print(f"Error details: {e}")
+         sys.exit(1)
+
+     print(f"Using checkpoints from {ckpt_dir}")
+
+     TASK_NAME = 'ti2v-5B'
+     cfg = WAN_CONFIGS[TASK_NAME]
+
+     # --- Generation parameters (defaults taken from the original script) ---
+     height = 704
+     width = 1280
+     duration_seconds = 2.0
+     sampling_steps = 38
+     guide_scale = cfg.sample_guide_scale
+     shift = cfg.sample_shift
+     seed = -1  # -1 means a random seed
+     image = None  # the CLI version does not handle image input yet
+
+     # --- Processing ---
      if seed == -1:
          seed = random.randint(0, sys.maxsize)
+         print(f"Using random seed: {seed}")

+     # Make sure the dimensions are valid
+     MOD_VALUE = 32
      target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
      target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

+     # Compute the number of frames
+     FIXED_FPS = 24
+     MIN_FRAMES_MODEL = 8
+     MAX_FRAMES_MODEL = 121
      num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+     print(f"Generating {num_frames} frames ({duration_seconds}s @ {FIXED_FPS}fps) at {target_w}x{target_h}.")

+     # --- Initialize the pipeline ---
+     print("Initializing the WanTI2V pipeline... (this may take a while)")
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     device_id = 0 if torch.cuda.is_available() else -1
+     if device == "cpu":
+         print("⚠️ Warning: no GPU detected. Running on the CPU will be very slow.")
+
+     try:
+         pipeline = wan.WanTI2V(
+             config=cfg,
+             checkpoint_dir=ckpt_dir,
+             device_id=device_id,
+             rank=0,
+             t5_fsdp=False,
+             dit_fsdp=False,
+             use_sp=False,
+             t5_cpu=False,
+             init_on_cpu=False,
+             convert_model_dtype=True,
+         )
+         print("Pipeline initialized.")
+     except Exception as e:
+         print(f"❌ Failed to initialize the pipeline: {e}")
+         sys.exit(1)

+     # --- Generate the video ---
+     print(f"Generating video for prompt: '{prompt}'")
+     size_str = f"{target_h}*{target_w}"
+
      video_tensor = pipeline.generate(
          input_prompt=prompt,
+         img=image,
          size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
          max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
+         frame_num=num_frames,
          shift=shift,
          sample_solver='unipc',
          sampling_steps=int(sampling_steps),
          offload_model=True
      )

+     # --- Save the video ---
+     print("Saving the video...")
+
+     # Build a filesystem-safe filename from the prompt
+     safe_prompt = "".join([c for c in prompt if c.isalnum() or c == ' ']).rstrip()
+     safe_prompt = safe_prompt.replace(" ", "_")
+     output_filename = f"{safe_prompt[:50]}_{seed}.mp4"
+     output_path = os.path.join(os.getcwd(), output_filename)  # save in the current working directory
+
      video_path = cache_video(
+         tensor=video_tensor[None],
+         save_file=output_path,  # explicit save path
          fps=cfg.sample_fps,
          normalize=True,
          value_range=(-1, 1)
      )
+
+     # --- Cleanup ---
+     del pipeline
      del video_tensor
      gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()

+     print(f"✅ Video generation complete! Saved to: {video_path}")

+ # --- 3. Main entry point ---

+ def main():
+     """
+     Parse the command-line arguments and run the corresponding action.
+     """
+     parser = argparse.ArgumentParser(
+         description="Wan 2.2 TI2V-5B command-line tool: generate a video from text or download the model.",
+         formatter_class=argparse.RawTextHelpFormatter
      )

+     parser.add_argument(
+         '--prompt',
+         nargs='+',
+         type=str,
+         help="Text prompt for video generation.\nExample: --prompt A beautiful waterfall"
      )
+
+     parser.add_argument(
+         '--downloader',
+         action='store_true',
+         help="If set, only download the required model and then exit."
      )

+     args = parser.parse_args()
+
+     if args.downloader:
+         download_models()
+     elif args.prompt:
+         # Join the list of words into a single prompt string.
+         # This handles 'prompt text', "prompt text", and bare prompt text alike.
+         prompt_text = " ".join(args.prompt)
+         generate_video_cli(prompt_text)
+     else:
+         print("No action specified. Provide --prompt or use the --downloader flag.")
+         parser.print_help()

  if __name__ == "__main__":
+     main()
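
Usage sketch (not part of the commit; the two flags below are the ones defined by the argparse setup above, the exact prompt text is only an example):

    # download and cache the model once
    python app.py --downloader

    # generate a clip from a text prompt
    python app.py --prompt A cinematic shot of a boat sailing on a calm sea at sunset

Because --prompt is declared with nargs='+', the words following it are joined into a single prompt string, so quoting is optional. The finished clip is written to the current working directory as <sanitized-prompt>_<seed>.mp4.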