keisanmono commited on
Commit
af68d8f
·
1 Parent(s): 1986a1b

优化生成速度

Browse files
Files changed (3) hide show
  1. simple_app.py +20 -9
  2. wan/image2video.py +4 -0
  3. wan/text2video.py +4 -0
simple_app.py CHANGED
@@ -13,9 +13,10 @@ snapshot_download(
13
 
14
  SIZE_OPTIONS = ["480*832", "832*480"] # 分辨率选项
15
  FRAME_NUM_OPTIONS = [10, 20, 30, 40, 50, 60, 81] # 帧数选项
 
16
 
17
 
18
- def infer(prompt, video_size, frame_num, progress=gr.Progress(track_tqdm=True)): # 添加 video_size, frame_num 参数
19
  # Configuration:
20
  total_process_steps = 11 # Total INFO messages expected
21
  irrelevant_steps = 4 # First 4 INFO messages are ignored
@@ -51,10 +52,13 @@ def infer(prompt, video_size, frame_num, progress=gr.Progress(track_tqdm=True)):
51
  "generate", # using -u for unbuffered output
52
  "--task",
53
  "t2v-1.3B",
 
54
  "--size",
55
- video_size, # 使用 WebUI 传递的分辨率参数
56
  "--frame_num",
57
- str(frame_num), # 使用 WebUI 传递的帧数参数 (注意转为字符串)
 
 
58
  "--ckpt_dir",
59
  "./Wan2.1-T2V-1.3B",
60
  "--t5_cpu",
@@ -206,23 +210,30 @@ with gr.Blocks() as demo:
206
 
207
  video_size_dropdown = gr.Dropdown(
208
  choices=SIZE_OPTIONS,
209
- value=SIZE_OPTIONS[0], # 默认选择第一个分辨率
210
  label="Video Size (Resolution)",
211
  )
212
  frame_num_slider = gr.Slider(
213
- minimum=FRAME_NUM_OPTIONS[0], # 最小帧数
214
- maximum=FRAME_NUM_OPTIONS[-1], # 最大帧数
215
- value=FRAME_NUM_OPTIONS[2], # 默认帧数 (例如,选择列表的第三个)
216
- step=1, # 步长为 1
217
  label="Frame Number (Video Length)",
218
  )
 
 
 
 
 
 
 
219
 
220
  submit_btn = gr.Button("Submit")
221
  video_res = gr.Video(label="Generated Video")
222
 
223
  submit_btn.click(
224
  fn=infer,
225
- inputs=[prompt, video_size_dropdown, frame_num_slider], # inputs 添加 video_size_dropdown, frame_num_slider
226
  outputs=[video_res],
227
  )
228
 
 
13
 
14
  SIZE_OPTIONS = ["480*832", "832*480"] # 分辨率选项
15
  FRAME_NUM_OPTIONS = [10, 20, 30, 40, 50, 60, 81] # 帧数选项
16
+ SAMPLING_STEPS_OPTIONS = [5, 10, 15, 20, 25, 30, 40, 50] # 采样步数选项
17
 
18
 
19
+ def infer(prompt, video_size, frame_num, sampling_steps, progress=gr.Progress(track_tqdm=True)):
20
  # Configuration:
21
  total_process_steps = 11 # Total INFO messages expected
22
  irrelevant_steps = 4 # First 4 INFO messages are ignored
 
52
  "generate", # using -u for unbuffered output
53
  "--task",
54
  "t2v-1.3B",
55
+ "--fp16", # Enable FP16 for acceleration
56
  "--size",
57
+ video_size,
58
  "--frame_num",
59
+ str(frame_num),
60
+ "--sample_steps",
61
+ str(sampling_steps), # Add sampling steps
62
  "--ckpt_dir",
63
  "./Wan2.1-T2V-1.3B",
64
  "--t5_cpu",
 
210
 
211
  video_size_dropdown = gr.Dropdown(
212
  choices=SIZE_OPTIONS,
213
+ value=SIZE_OPTIONS[0],
214
  label="Video Size (Resolution)",
215
  )
216
  frame_num_slider = gr.Slider(
217
+ minimum=FRAME_NUM_OPTIONS[0],
218
+ maximum=FRAME_NUM_OPTIONS[-1],
219
+ value=FRAME_NUM_OPTIONS[2],
220
+ step=1,
221
  label="Frame Number (Video Length)",
222
  )
223
+ sampling_steps_slider = gr.Slider(
224
+ minimum=SAMPLING_STEPS_OPTIONS[0],
225
+ maximum=SAMPLING_STEPS_OPTIONS[-1],
226
+ value=SAMPLING_STEPS_OPTIONS[1], # Default to 10 steps
227
+ step=1,
228
+ label="Sampling Steps (Fewer steps = Faster, Lower quality)",
229
+ )
230
 
231
  submit_btn = gr.Button("Submit")
232
  video_res = gr.Video(label="Generated Video")
233
 
234
  submit_btn.click(
235
  fn=infer,
236
+ inputs=[prompt, video_size_dropdown, frame_num_slider, sampling_steps_slider],
237
  outputs=[video_res],
238
  )
239
 
wan/image2video.py CHANGED
@@ -123,6 +123,10 @@ class WanI2V:
123
  else:
124
  if not init_on_cpu:
125
  self.model.to(self.device)
 
 
 
 
126
 
127
  self.sample_neg_prompt = config.sample_neg_prompt
128
 
 
123
  else:
124
  if not init_on_cpu:
125
  self.model.to(self.device)
126
+ try:
127
+ self.model.enable_xformers_memory_efficient_attention()
128
+ except Exception as e:
129
+ logging.warning(f"Could not enable xformers memory efficient attention: {e}")
130
 
131
  self.sample_neg_prompt = config.sample_neg_prompt
132
 
wan/text2video.py CHANGED
@@ -104,6 +104,10 @@ class WanT2V:
104
  self.model = shard_fn(self.model)
105
  else:
106
  self.model.to(self.device)
 
 
 
 
107
 
108
  self.sample_neg_prompt = config.sample_neg_prompt
109
 
 
104
  self.model = shard_fn(self.model)
105
  else:
106
  self.model.to(self.device)
107
+ try:
108
+ self.model.enable_xformers_memory_efficient_attention()
109
+ except Exception as e:
110
+ logging.warning(f"Could not enable xformers memory efficient attention: {e}")
111
 
112
  self.sample_neg_prompt = config.sample_neg_prompt
113