multimodalart HF Staff commited on
Commit
8575388
·
verified ·
1 Parent(s): c103ac7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -60
app.py CHANGED
@@ -60,14 +60,14 @@ pipe.set_adapters(["causvid_lora"], adapter_weights=[1.0])
60
  MOD_VALUE = 32
61
  MOD_VALUE_H = MOD_VALUE_W = MOD_VALUE
62
 
63
- DEFAULT_H_SLIDER_VALUE = 512
64
- DEFAULT_W_SLIDER_VALUE = 896
65
 
66
  # New fixed max_area for the calculation formula
67
- NEW_FORMULA_MAX_AREA = float(480 * 832)
68
 
69
  SLIDER_MIN_H = 128
70
- SLIDER_MAX_H = 896
71
  SLIDER_MIN_W = 128
72
  SLIDER_MAX_W = 896
73
 
@@ -82,36 +82,22 @@ def _calculate_new_dimensions_wan(pil_image: Image.Image, mod_val: int, calculat
82
  return default_h, default_w
83
 
84
  aspect_ratio = orig_h / orig_w
85
-
86
- # New calculation logic as per user request:
87
- # height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
88
- # width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
89
-
90
- # Calculate sqrt terms
91
  sqrt_h_term = np.sqrt(calculation_max_area * aspect_ratio)
92
  sqrt_w_term = np.sqrt(calculation_max_area / aspect_ratio)
93
 
94
- # Apply the formula: round(sqrt_term) then floor_division by mod_val, then multiply by mod_val
95
  calc_h = round(sqrt_h_term) // mod_val * mod_val
96
  calc_w = round(sqrt_w_term) // mod_val * mod_val
97
 
98
- # Ensure calculated dimensions are at least mod_val (since round(...) // mod_val * mod_val can yield 0 if round(sqrt_term) < mod_val)
99
  calc_h = mod_val if calc_h < mod_val else calc_h
100
  calc_w = mod_val if calc_w < mod_val else calc_w
101
 
102
- # Determine effective min/max dimensions from slider limits, ensuring they are multiples of mod_val.
103
- # Slider min values (min_slider_h, min_slider_w) are assumed to be multiples of mod_val.
104
  effective_min_h = min_slider_h
105
  effective_min_w = min_slider_w
106
 
107
- # Slider max values (max_slider_h, max_slider_w) might not be multiples of mod_val.
108
- # The actual maximum value a slider can output is (its_max_limit // mod_val) * mod_val.
109
  effective_max_h_from_slider = (max_slider_h // mod_val) * mod_val
110
  effective_max_w_from_slider = (max_slider_w // mod_val) * mod_val
111
-
112
- # Clip calc_h and calc_w (which are already multiples of mod_val)
113
- # to the effective slider range (which are also multiples of mod_val).
114
- # The results (new_h, new_w) will therefore also be multiples of mod_val.
115
  new_h = int(np.clip(calc_h, effective_min_h, effective_max_h_from_slider))
116
  new_w = int(np.clip(calc_w, effective_min_w, effective_max_w_from_slider))
117
 
@@ -144,32 +130,43 @@ def handle_image_upload_for_dims_wan(uploaded_pil_image: Image.Image | None, cur
144
  # --- Gradio Interface Function ---
145
  @spaces.GPU
146
  def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
147
- height: int, width: int, num_frames: int,
148
- guidance_scale: float, steps: int, fps_for_conditioning_and_export: int,
149
- progress=gr.Progress(track_tqdm=True)):
150
  if input_image is None:
151
  raise gr.Error("Please upload an input image.")
152
 
 
 
 
 
 
153
  logger.info("Starting video generation...")
154
  logger.info(f" Input Image: Uploaded (Original size: {input_image.size if input_image else 'N/A'})")
155
  logger.info(f" Prompt: {prompt}")
156
  logger.info(f" Negative Prompt: {negative_prompt if negative_prompt else 'None'}")
157
  logger.info(f" Target Output Height: {height}, Target Output Width: {width}")
158
- logger.info(f" Num Frames: {num_frames}, FPS for conditioning & export: {fps_for_conditioning_and_export}")
159
- logger.info(f" Guidance Scale: {guidance_scale}, Steps: {steps}")
160
 
161
  target_height = int(height)
162
  target_width = int(width)
163
- num_frames = int(num_frames)
164
- fps_val = int(fps_for_conditioning_and_export)
165
  guidance_scale_val = float(guidance_scale)
166
  steps_val = int(steps)
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  # Ensure dimensions are compatible.
169
- # With the updated _calculate_new_dimensions_wan, height and width from sliders
170
- # (after image upload auto-adjustment) should already be multiples of MOD_VALUE.
171
- # This block acts as a safeguard if values come from direct slider interaction
172
- # before an image upload, or if something unexpected happens.
173
  if target_height % MOD_VALUE_H != 0:
174
  logger.warning(f"Height {target_height} is not a multiple of {MOD_VALUE_H}. Adjusting...")
175
  target_height = (target_height // MOD_VALUE_H) * MOD_VALUE_H
@@ -177,7 +174,6 @@ def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
177
  logger.warning(f"Width {target_width} is not a multiple of {MOD_VALUE_W}. Adjusting...")
178
  target_width = (target_width // MOD_VALUE_W) * MOD_VALUE_W
179
 
180
- # Ensure minimum size (should already be handled by _calculate_new_dimensions_wan and slider mins)
181
  target_height = max(MOD_VALUE_H, target_height if target_height > 0 else MOD_VALUE_H)
182
  target_width = max(MOD_VALUE_W, target_width if target_width > 0 else MOD_VALUE_W)
183
 
@@ -192,17 +188,16 @@ def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
192
  negative_prompt=negative_prompt,
193
  height=target_height,
194
  width=target_width,
195
- num_frames=num_frames,
196
  guidance_scale=guidance_scale_val,
197
  num_inference_steps=steps_val,
198
- generator=torch.Generator(device="cuda").manual_seed(0) # Consider making seed configurable
199
  ).frames[0]
200
 
201
- # Using a temporary file for video export
202
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
203
  video_path = tmpfile.name
204
 
205
- export_to_video(output_frames_list, video_path, fps=fps_val)
206
  logger.info(f"Video successfully generated and saved to {video_path}")
207
  return video_path
208
 
@@ -213,19 +208,14 @@ penguin_image_url = "https://huggingface.co/datasets/huggingface/documentation-i
213
 
214
  with gr.Blocks() as demo:
215
  gr.Markdown(f"""
216
- # Image-to-Video with Wan 2.1 I2V (14B) + CausVid LoRA
217
- Powered by `diffusers` and `{MODEL_ID}`.
218
- Model is loaded into memory when the app starts. This might take a few minutes.
219
- Ensure you have a GPU with sufficient VRAM (e.g., ~24GB+ for these default settings).
220
- Output Height and Width will be multiples of **{MOD_VALUE}**.
221
- Uploading an image will suggest dimensions based on its aspect ratio and a pre-defined target pixel area ({NEW_FORMULA_MAX_AREA:.0f} pixels),
222
- clamped to slider limits.
223
  """)
224
  with gr.Row():
225
  with gr.Column():
226
  input_image_component = gr.Image(type="pil", label="Input Image (will be resized to target H/W)")
227
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)
228
-
 
229
  with gr.Accordion("Advanced Settings", open=False):
230
  negative_prompt_input = gr.Textbox(
231
  label="Negative Prompt (Optional)",
@@ -235,41 +225,35 @@ with gr.Blocks() as demo:
235
  with gr.Row():
236
  height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
237
  width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
238
- with gr.Row():
239
- num_frames_input = gr.Slider(minimum=8, maximum=81, step=1, value=41, label="Number of Frames") # Max 81 for this model
240
- fps_input = gr.Slider(minimum=5, maximum=30, step=1, value=24, label="FPS (for conditioning & export)")
241
- steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps") # WanI2V is good with few steps
242
- guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale") # Low CFG usually better for I2V
243
 
244
  generate_button = gr.Button("Generate Video", variant="primary")
245
 
246
  with gr.Column():
247
  video_output = gr.Video(label="Generated Video", interactive=False)
248
 
249
- # Connect image upload to dimension auto-adjustment
250
  input_image_component.upload(
251
  fn=handle_image_upload_for_dims_wan,
252
- inputs=[input_image_component, height_input, width_input], # Pass current slider values for fallback on error
253
  outputs=[height_input, width_input]
254
  )
255
- # Also trigger on clear, though handle_image_upload_for_dims_wan handles None input
256
  input_image_component.clear(
257
  fn=handle_image_upload_for_dims_wan,
258
  inputs=[input_image_component, height_input, width_input],
259
  outputs=[height_input, width_input]
260
  )
261
 
262
-
263
  inputs_for_click_and_examples = [
264
  input_image_component,
265
  prompt_input,
266
  negative_prompt_input,
267
  height_input,
268
  width_input,
269
- num_frames_input,
270
  guidance_scale_input,
271
- steps_slider,
272
- fps_input
273
  ]
274
 
275
  generate_button.click(
@@ -280,13 +264,13 @@ with gr.Blocks() as demo:
280
 
281
  gr.Examples(
282
  examples=[
283
- [penguin_image_url, "a penguin playfully dancing in the snow, Antarctica", default_negative_prompt, DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, 41, 1.0, 4, 24],
284
- ["https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0001.jpg", "the frog jumps around", default_negative_prompt, 384, 640, 60, 1.0, 4, 24],
285
  ],
286
- inputs=inputs_for_click_and_examples,
287
  outputs=video_output,
288
  fn=generate_video,
289
- cache_examples="lazy"
290
  )
291
 
292
  if __name__ == "__main__":
 
60
  MOD_VALUE = 32
61
  MOD_VALUE_H = MOD_VALUE_W = MOD_VALUE
62
 
63
+ DEFAULT_H_SLIDER_VALUE = 512
64
+ DEFAULT_W_SLIDER_VALUE = 896
65
 
66
  # New fixed max_area for the calculation formula
67
+ NEW_FORMULA_MAX_AREA = float(480 * 832)
68
 
69
  SLIDER_MIN_H = 128
70
+ SLIDER_MAX_H = 896
71
  SLIDER_MIN_W = 128
72
  SLIDER_MAX_W = 896
73
 
 
82
  return default_h, default_w
83
 
84
  aspect_ratio = orig_h / orig_w
85
+
 
 
 
 
 
86
  sqrt_h_term = np.sqrt(calculation_max_area * aspect_ratio)
87
  sqrt_w_term = np.sqrt(calculation_max_area / aspect_ratio)
88
 
 
89
  calc_h = round(sqrt_h_term) // mod_val * mod_val
90
  calc_w = round(sqrt_w_term) // mod_val * mod_val
91
 
 
92
  calc_h = mod_val if calc_h < mod_val else calc_h
93
  calc_w = mod_val if calc_w < mod_val else calc_w
94
 
 
 
95
  effective_min_h = min_slider_h
96
  effective_min_w = min_slider_w
97
 
 
 
98
  effective_max_h_from_slider = (max_slider_h // mod_val) * mod_val
99
  effective_max_w_from_slider = (max_slider_w // mod_val) * mod_val
100
+
 
 
 
101
  new_h = int(np.clip(calc_h, effective_min_h, effective_max_h_from_slider))
102
  new_w = int(np.clip(calc_w, effective_min_w, effective_max_w_from_slider))
103
 
 
130
  # --- Gradio Interface Function ---
131
  @spaces.GPU
132
  def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
133
+ height: int, width: int, duration_seconds: float, # Changed from num_frames
134
+ guidance_scale: float, steps: int,
135
+ progress=gr.Progress(track_tqdm=True)): # Removed fps_for_conditioning_and_export
136
  if input_image is None:
137
  raise gr.Error("Please upload an input image.")
138
 
139
+ # Constants for frame calculation
140
+ FIXED_FPS = 24
141
+ MIN_FRAMES_MODEL = 8 # Based on original num_frames_input slider min
142
+ MAX_FRAMES_MODEL = 81 # Based on original num_frames_input slider max
143
+
144
  logger.info("Starting video generation...")
145
  logger.info(f" Input Image: Uploaded (Original size: {input_image.size if input_image else 'N/A'})")
146
  logger.info(f" Prompt: {prompt}")
147
  logger.info(f" Negative Prompt: {negative_prompt if negative_prompt else 'None'}")
148
  logger.info(f" Target Output Height: {height}, Target Output Width: {width}")
 
 
149
 
150
  target_height = int(height)
151
  target_width = int(width)
152
+ # duration_seconds is already float
 
153
  guidance_scale_val = float(guidance_scale)
154
  steps_val = int(steps)
155
 
156
+ # Calculate number of frames based on duration and fixed FPS
157
+ num_frames_for_pipeline = int(round(duration_seconds * FIXED_FPS))
158
+ # Clamp num_frames to be within model's supported range
159
+ num_frames_for_pipeline = max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, num_frames_for_pipeline))
160
+ # Ensure at least MIN_FRAMES_MODEL if rounding leads to a very small number (or zero)
161
+ if num_frames_for_pipeline < MIN_FRAMES_MODEL:
162
+ num_frames_for_pipeline = MIN_FRAMES_MODEL
163
+
164
+ logger.info(f" Duration: {duration_seconds:.1f}s, Fixed FPS (conditioning & export): {FIXED_FPS}")
165
+ logger.info(f" Calculated Num Frames: {num_frames_for_pipeline} (clamped to [{MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL}])")
166
+ logger.info(f" Guidance Scale: {guidance_scale_val}, Steps: {steps_val}")
167
+
168
+
169
  # Ensure dimensions are compatible.
 
 
 
 
170
  if target_height % MOD_VALUE_H != 0:
171
  logger.warning(f"Height {target_height} is not a multiple of {MOD_VALUE_H}. Adjusting...")
172
  target_height = (target_height // MOD_VALUE_H) * MOD_VALUE_H
 
174
  logger.warning(f"Width {target_width} is not a multiple of {MOD_VALUE_W}. Adjusting...")
175
  target_width = (target_width // MOD_VALUE_W) * MOD_VALUE_W
176
 
 
177
  target_height = max(MOD_VALUE_H, target_height if target_height > 0 else MOD_VALUE_H)
178
  target_width = max(MOD_VALUE_W, target_width if target_width > 0 else MOD_VALUE_W)
179
 
 
188
  negative_prompt=negative_prompt,
189
  height=target_height,
190
  width=target_width,
191
+ num_frames=num_frames_for_pipeline, # Use calculated and clamped num_frames
192
  guidance_scale=guidance_scale_val,
193
  num_inference_steps=steps_val,
194
+ generator=torch.Generator(device="cuda").manual_seed(0)
195
  ).frames[0]
196
 
 
197
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
198
  video_path = tmpfile.name
199
 
200
+ export_to_video(output_frames_list, video_path, fps=FIXED_FPS) # Use fixed FPS for export
201
  logger.info(f"Video successfully generated and saved to {video_path}")
202
  return video_path
203
 
 
208
 
209
  with gr.Blocks() as demo:
210
  gr.Markdown(f"""
211
+ # Fast 4 steps Wan 2.1 I2V (14B) with CausVid LoRA
 
 
 
 
 
 
212
  """)
213
  with gr.Row():
214
  with gr.Column():
215
  input_image_component = gr.Image(type="pil", label="Input Image (will be resized to target H/W)")
216
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)
217
+ duration_seconds_input = gr.Slider(minimum=0.4, maximum=3.3, step=0.1, value=1.7, label="Duration (seconds)", info="The CausVid LoRA was trained on 24fps, Wan has 81 maximum frames limit, limiting the maximum to 3.3s")
218
+
219
  with gr.Accordion("Advanced Settings", open=False):
220
  negative_prompt_input = gr.Textbox(
221
  label="Negative Prompt (Optional)",
 
225
  with gr.Row():
226
  height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
227
  width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
228
+
229
+ steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
230
+ guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale", visible=False)
 
 
231
 
232
  generate_button = gr.Button("Generate Video", variant="primary")
233
 
234
  with gr.Column():
235
  video_output = gr.Video(label="Generated Video", interactive=False)
236
 
 
237
  input_image_component.upload(
238
  fn=handle_image_upload_for_dims_wan,
239
+ inputs=[input_image_component, height_input, width_input],
240
  outputs=[height_input, width_input]
241
  )
 
242
  input_image_component.clear(
243
  fn=handle_image_upload_for_dims_wan,
244
  inputs=[input_image_component, height_input, width_input],
245
  outputs=[height_input, width_input]
246
  )
247
 
 
248
  inputs_for_click_and_examples = [
249
  input_image_component,
250
  prompt_input,
251
  negative_prompt_input,
252
  height_input,
253
  width_input,
254
+ duration_seconds_input,
255
  guidance_scale_input,
256
+ steps_slider
 
257
  ]
258
 
259
  generate_button.click(
 
264
 
265
  gr.Examples(
266
  examples=[
267
+ [penguin_image_url, "a penguin playfully dancing in the snow, Antarctica", default_negative_prompt, 896, 512, 2, 1.0, 4],
268
+ ["https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0001.jpg", "the frog jumps around", default_negative_prompt, 448, 832, 2, 1.0, 4],
269
  ],
270
+ inputs=inputs_for_click_and_examples,
271
  outputs=video_output,
272
  fn=generate_video,
273
+ cache_examples="lazy"
274
  )
275
 
276
  if __name__ == "__main__":