multimodalart (HF Staff) committed
Commit c103ac7 · verified · 1 Parent(s): cb85dbc

Update app.py

Files changed (1)
  1. app.py +84 -43
app.py CHANGED
@@ -1,11 +1,10 @@
  import torch
  from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
- from diffusers.utils import export_to_video, load_image
+ from diffusers.utils import export_to_video
  from transformers import CLIPVisionModel
  import gradio as gr
  import tempfile
- import os
- import spaces # Assuming this is for Hugging Face Spaces GPU decorator
+ import spaces
  from huggingface_hub import hf_hub_download
  import logging
  import numpy as np
@@ -25,21 +24,21 @@ logger.info(f"Loading Image Encoder for {MODEL_ID}...")
  image_encoder = CLIPVisionModel.from_pretrained(
  MODEL_ID,
  subfolder="image_encoder",
- torch_dtype=torch.float32
+ torch_dtype=torch.float32 # Using float32 for image encoder as sometimes bfloat16/float16 can be problematic
  )

  logger.info(f"Loading VAE for {MODEL_ID}...")
  vae = AutoencoderKLWan.from_pretrained(
  MODEL_ID,
  subfolder="vae",
- torch_dtype=torch.float32
+ torch_dtype=torch.float32 # Using float32 for VAE for precision
  )
  logger.info(f"Loading Pipeline {MODEL_ID}...")
  pipe = WanImageToVideoPipeline.from_pretrained(
  MODEL_ID,
  vae=vae,
  image_encoder=image_encoder,
- torch_dtype=torch.bfloat16
+ torch_dtype=torch.bfloat16 # Main pipeline can use bfloat16 for speed/memory
  )
  flow_shift = 8.0
  pipe.scheduler = UniPCMultistepScheduler.from_config(
@@ -57,44 +56,68 @@ pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
  logger.info("Setting LoRA adapter...")
  pipe.set_adapters(["causvid_lora"], adapter_weights=[1.0])

- MOD_VALUE = 128
+ # --- Constants for Dimension Calculation ---
+ MOD_VALUE = 32
  MOD_VALUE_H = MOD_VALUE_W = MOD_VALUE

- DEFAULT_H_SLIDER_VALUE = 384 # (3 * 128)
- DEFAULT_W_SLIDER_VALUE = 640 # (5 * 128)
- DEFAULT_TARGET_AREA = float(DEFAULT_H_SLIDER_VALUE * DEFAULT_W_SLIDER_VALUE)
+ DEFAULT_H_SLIDER_VALUE = 512
+ DEFAULT_W_SLIDER_VALUE = 896
+
+ # New fixed max_area for the calculation formula
+ NEW_FORMULA_MAX_AREA = float(480 * 832)

  SLIDER_MIN_H = 128
- SLIDER_MAX_H = 512
+ SLIDER_MAX_H = 896
  SLIDER_MIN_W = 128
- SLIDER_MAX_W = 854
+ SLIDER_MAX_W = 896

- def _calculate_new_dimensions_wan(pil_image: Image.Image, mod_val: int, target_area: float,
- min_h: int, max_h: int, min_w: int, max_w: int,
+ def _calculate_new_dimensions_wan(pil_image: Image.Image, mod_val: int, calculation_max_area: float,
+ min_slider_h: int, max_slider_h: int,
+ min_slider_w: int, max_slider_w: int,
  default_h: int, default_w: int) -> tuple[int, int]:
  orig_w, orig_h = pil_image.size

- if orig_w == 0 or orig_h == 0:
- logger.warning("Uploaded image has zero width or height. Using default slider dimensions.")
+ if orig_w <= 0 or orig_h <= 0: # Changed to <= 0 for robustness
+ logger.warning(f"Uploaded image has non-positive width or height ({orig_w}x{orig_h}). Using default slider dimensions.")
  return default_h, default_w

  aspect_ratio = orig_h / orig_w

- ideal_h = np.sqrt(target_area * aspect_ratio)
- ideal_w = np.sqrt(target_area / aspect_ratio)
+ # New calculation logic as per user request:
+ # height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ # width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+
+ # Calculate sqrt terms
+ sqrt_h_term = np.sqrt(calculation_max_area * aspect_ratio)
+ sqrt_w_term = np.sqrt(calculation_max_area / aspect_ratio)
+
+ # Apply the formula: round(sqrt_term) then floor_division by mod_val, then multiply by mod_val
+ calc_h = round(sqrt_h_term) // mod_val * mod_val
+ calc_w = round(sqrt_w_term) // mod_val * mod_val

- calc_h = round(ideal_h / mod_val) * mod_val
- calc_w = round(ideal_w / mod_val) * mod_val
+ # Ensure calculated dimensions are at least mod_val (since round(...) // mod_val * mod_val can yield 0 if round(sqrt_term) < mod_val)
+ calc_h = mod_val if calc_h < mod_val else calc_h
+ calc_w = mod_val if calc_w < mod_val else calc_w

- calc_h = mod_val if calc_h < mod_val else calc_h # Ensure at least one mod_val unit
- calc_w = mod_val if calc_w < mod_val else calc_w # Ensure at least one mod_val unit
+ # Determine effective min/max dimensions from slider limits, ensuring they are multiples of mod_val.
+ # Slider min values (min_slider_h, min_slider_w) are assumed to be multiples of mod_val.
+ effective_min_h = min_slider_h
+ effective_min_w = min_slider_w

- new_h = int(np.clip(calc_h, min_h, max_h))
- new_w = int(np.clip(calc_w, min_w, max_w))
+ # Slider max values (max_slider_h, max_slider_w) might not be multiples of mod_val.
+ # The actual maximum value a slider can output is (its_max_limit // mod_val) * mod_val.
+ effective_max_h_from_slider = (max_slider_h // mod_val) * mod_val
+ effective_max_w_from_slider = (max_slider_w // mod_val) * mod_val

- logger.info(f"Auto-dim: Original {orig_w}x{orig_h} (AR: {aspect_ratio:.2f}). Target Area: {target_area}.")
- logger.info(f"Auto-dim: Ideal HxW: {ideal_h:.0f}x{ideal_w:.0f}. Rounded (step {mod_val}): {calc_h}x{calc_w}.")
- logger.info(f"Auto-dim: Clamped HxW: {new_h}x{new_w} (H_range:[{min_h}-{max_h}], W_range:[{min_w}-{max_w}]).")
+ # Clip calc_h and calc_w (which are already multiples of mod_val)
+ # to the effective slider range (which are also multiples of mod_val).
+ # The results (new_h, new_w) will therefore also be multiples of mod_val.
+ new_h = int(np.clip(calc_h, effective_min_h, effective_max_h_from_slider))
+ new_w = int(np.clip(calc_w, effective_min_w, effective_max_w_from_slider))
+
+ logger.info(f"Auto-dim: Original {orig_w}x{orig_h} (AR: {aspect_ratio:.2f}). Max Area for calc: {calculation_max_area}.")
+ logger.info(f"Auto-dim: Sqrt terms HxW: {sqrt_h_term:.0f}x{sqrt_w_term:.0f}. Calculated (round(sqrt_term)//{mod_val}*{mod_val}): {calc_h}x{calc_w}.")
+ logger.info(f"Auto-dim: Clamped HxW: {new_h}x{new_w} (Effective H_range:[{effective_min_h}-{effective_max_h_from_slider}], Effective W_range:[{effective_min_w}-{effective_max_w_from_slider}]).")

  return new_h, new_w

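The hunk above swaps the old round-to-nearest sizing for a floor-to-multiple rule driven by a fixed target area. Below is a minimal standalone sketch of that rule; the helper name suggest_dims and the 1280x720 input are illustrative, while the constants and slider bounds are the ones from this diff.

import numpy as np

MOD_VALUE = 32
NEW_FORMULA_MAX_AREA = float(480 * 832)  # 399360 target pixels

def suggest_dims(orig_w: int, orig_h: int,
                 mod_val: int = MOD_VALUE,
                 max_area: float = NEW_FORMULA_MAX_AREA) -> tuple[int, int]:
    # Same formula as the hunk: round the ideal edge, then floor it to a multiple of mod_val.
    aspect_ratio = orig_h / orig_w
    calc_h = round(np.sqrt(max_area * aspect_ratio)) // mod_val * mod_val
    calc_w = round(np.sqrt(max_area / aspect_ratio)) // mod_val * mod_val
    # Keep at least one mod_val unit, then clamp to the slider range used in the app (128-896).
    calc_h = max(mod_val, calc_h)
    calc_w = max(mod_val, calc_w)
    return int(np.clip(calc_h, 128, 896)), int(np.clip(calc_w, 128, 896))

print(suggest_dims(1280, 720))  # (448, 832): both multiples of 32, and 448 * 832 = 372736 <= 399360
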
@@ -105,8 +128,8 @@ def handle_image_upload_for_dims_wan(uploaded_pil_image: Image.Image | None, cur
  try:
  new_h, new_w = _calculate_new_dimensions_wan(
  uploaded_pil_image,
- MOD_VALUE, # Use the globally determined MOD_VALUE
- DEFAULT_TARGET_AREA,
+ MOD_VALUE,
+ NEW_FORMULA_MAX_AREA, # Use the globally defined max_area for the new formula
  SLIDER_MIN_H, SLIDER_MAX_H,
  SLIDER_MIN_W, SLIDER_MAX_W,
  DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
@@ -114,11 +137,12 @@ def handle_image_upload_for_dims_wan(uploaded_pil_image: Image.Image | None, cur
  return gr.update(value=new_h), gr.update(value=new_w)
  except Exception as e:
  logger.error(f"Error auto-adjusting H/W from image: {e}", exc_info=True)
+ # Fallback to default slider values on error, as in the original code
  return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)


  # --- Gradio Interface Function ---
- @spaces.GPU # type: ignore
+ @spaces.GPU
  def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
  height: int, width: int, num_frames: int,
  guidance_scale: float, steps: int, fps_for_conditioning_and_export: int,
@@ -141,16 +165,21 @@ def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
  guidance_scale_val = float(guidance_scale)
  steps_val = int(steps)

- # Ensure dimensions are compatible (already handled by slider steps and auto-adjustment)
+ # Ensure dimensions are compatible.
+ # With the updated _calculate_new_dimensions_wan, height and width from sliders
+ # (after image upload auto-adjustment) should already be multiples of MOD_VALUE.
+ # This block acts as a safeguard if values come from direct slider interaction
+ # before an image upload, or if something unexpected happens.
  if target_height % MOD_VALUE_H != 0:
  logger.warning(f"Height {target_height} is not a multiple of {MOD_VALUE_H}. Adjusting...")
  target_height = (target_height // MOD_VALUE_H) * MOD_VALUE_H
  if target_width % MOD_VALUE_W != 0:
  logger.warning(f"Width {target_width} is not a multiple of {MOD_VALUE_W}. Adjusting...")
  target_width = (target_width // MOD_VALUE_W) * MOD_VALUE_W
-
- target_height = max(MOD_VALUE_H, target_height) # Ensure minimum size
- target_width = max(MOD_VALUE_W, target_width) # Ensure minimum size
+
+ # Ensure minimum size (should already be handled by _calculate_new_dimensions_wan and slider mins)
+ target_height = max(MOD_VALUE_H, target_height if target_height > 0 else MOD_VALUE_H)
+ target_width = max(MOD_VALUE_W, target_width if target_width > 0 else MOD_VALUE_W)


  resized_image = input_image.resize((target_width, target_height))
@@ -166,9 +195,10 @@ def generate_video(input_image: Image.Image, prompt: str, negative_prompt: str,
  num_frames=num_frames,
  guidance_scale=guidance_scale_val,
  num_inference_steps=steps_val,
- generator=torch.Generator(device="cuda").manual_seed(0)
+ generator=torch.Generator(device="cuda").manual_seed(0) # Consider making seed configurable
  ).frames[0]

+ # Using a temporary file for video export
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
  video_path = tmpfile.name

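The export call itself lies outside this hunk; below is a rough sketch of how the frames returned by the pipeline are typically written to that temporary path with diffusers' export_to_video. The helper name save_frames_to_mp4 is illustrative, not part of the commit.

import tempfile
from diffusers.utils import export_to_video

def save_frames_to_mp4(frames, fps: int) -> str:
    # Reserve a named .mp4 path, then let export_to_video write the frames to it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(frames, video_path, fps=fps)
    return video_path
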
@@ -187,10 +217,12 @@ with gr.Blocks() as demo:
  Powered by `diffusers` and `{MODEL_ID}`.
  Model is loaded into memory when the app starts. This might take a few minutes.
  Ensure you have a GPU with sufficient VRAM (e.g., ~24GB+ for these default settings).
- Output Height and Width must be multiples of **{MOD_VALUE}**. Uploading an image will suggest dimensions based on its aspect ratio and a target area.
+ Output Height and Width will be multiples of **{MOD_VALUE}**.
+ Uploading an image will suggest dimensions based on its aspect ratio and a pre-defined target pixel area ({NEW_FORMULA_MAX_AREA:.0f} pixels),
+ clamped to slider limits.
  """)
  with gr.Row():
- with gr.Column(scale=2):
+ with gr.Column():
  input_image_component = gr.Image(type="pil", label="Input Image (will be resized to target H/W)")
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)

@@ -204,22 +236,30 @@ with gr.Blocks() as demo:
  height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
  width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
  with gr.Row():
- num_frames_input = gr.Slider(minimum=8, maximum=81, step=1, value=41, label="Number of Frames")
+ num_frames_input = gr.Slider(minimum=8, maximum=81, step=1, value=41, label="Number of Frames") # Max 81 for this model
  fps_input = gr.Slider(minimum=5, maximum=30, step=1, value=24, label="FPS (for conditioning & export)")
- steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
- guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale")
+ steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps") # WanI2V is good with few steps
+ guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="Guidance Scale") # Low CFG usually better for I2V

  generate_button = gr.Button("Generate Video", variant="primary")

- with gr.Column(scale=3):
+ with gr.Column():
  video_output = gr.Video(label="Generated Video", interactive=False)

+ # Connect image upload to dimension auto-adjustment
  input_image_component.upload(
+ fn=handle_image_upload_for_dims_wan,
+ inputs=[input_image_component, height_input, width_input], # Pass current slider values for fallback on error
+ outputs=[height_input, width_input]
+ )
+ # Also trigger on clear, though handle_image_upload_for_dims_wan handles None input
+ input_image_component.clear(
  fn=handle_image_upload_for_dims_wan,
  inputs=[input_image_component, height_input, width_input],
  outputs=[height_input, width_input]
  )

+
  inputs_for_click_and_examples = [
  input_image_component,
  prompt_input,
@@ -240,12 +280,13 @@ with gr.Blocks() as demo:

  gr.Examples(
  examples=[
- [penguin_image_url, "a penguin playfully dancing in the snow, Antarctica", default_negative_prompt, DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, 25, 1.0, 4, 16],
+ [penguin_image_url, "a penguin playfully dancing in the snow, Antarctica", default_negative_prompt, DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, 41, 1.0, 4, 24],
+ ["https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0001.jpg", "the frog jumps around", default_negative_prompt, 384, 640, 60, 1.0, 4, 24],
  ],
  inputs=inputs_for_click_and_examples,
  outputs=video_output,
  fn=generate_video,
- cache_examples=False
+ cache_examples="lazy"
  )

  if __name__ == "__main__":