multimodalart (HF Staff) committed
Commit afdfe21 · verified · 1 Parent(s): 5d9e5c2

Update app.py

Files changed (1)
  1. app.py +335 -30
app.py CHANGED
@@ -1,53 +1,364 @@
1
  import torch
2
  from diffusers import AutoencoderKLWan, WanPipeline
3
  from diffusers.utils import export_to_video
 
4
  import gradio as gr
5
  import tempfile
6
  import os
7
  import spaces
8
  from huggingface_hub import hf_hub_download
 
9
 
10
- # --- Global Model Loading (runs once when the script starts) ---
11
  MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 
 
12
 
 
 
 
13
 
14
- print(f"Loading VAE for {MODEL_ID}...")
15
- # Use float32 for VAE for stability, as recommended in some cases
16
  vae = AutoencoderKLWan.from_pretrained(
17
  MODEL_ID,
18
  subfolder="vae",
19
- torch_dtype=torch.float32
20
  )
21
- print(f"Loading Pipeline {MODEL_ID}...")
22
- # Use bfloat16 for the main pipeline for memory efficiency and speed
23
  pipe = WanPipeline.from_pretrained(
24
  MODEL_ID,
25
  vae=vae,
26
- torch_dtype=torch.bfloat16
27
  )
28
- print("Moving pipeline to CUDA...")
29
  pipe.to("cuda")
30
- causvid_path = hf_hub_download(repo_id="Kijai/WanVideo_comfy", filename="Wan21_CausVid_14B_T2V_lora_rank32.safetensors")
31
- pipe.load_lora_weights(causvid_path)
32
 
33
  # --- Gradio Interface Function ---
34
  @spaces.GPU
35
  def generate_video(prompt, negative_prompt, height, width, num_frames, guidance_scale, fps):
36
 
37
- print("Starting video generation...")
38
- print(f" Prompt: {prompt}")
39
- print(f" Negative Prompt: {negative_prompt if negative_prompt else 'None'}")
40
- print(f" Height: {height}, Width: {width}")
41
- print(f" Num Frames: {num_frames}, FPS: {fps}")
42
- print(f" Guidance Scale: {guidance_scale}")
43
-
44
- # Ensure height and width are multiples of 8 (common requirement for VAEs)
45
  height = (int(height) // 8) * 8
46
  width = (int(width) // 8) * 8
47
  num_frames = int(num_frames)
48
  fps = int(fps)
49
 
50
- with torch.inference_mode(): # Conserve memory
51
  output_frames_list = pipe(
52
  prompt=prompt,
53
  negative_prompt=negative_prompt,
@@ -55,35 +366,30 @@ def generate_video(prompt, negative_prompt, height, width, num_frames, guidance_
55
  width=width,
56
  num_frames=num_frames,
57
  guidance_scale=float(guidance_scale),
58
- # num_inference_steps=25 # Default is 25, can be exposed if needed
59
  ).frames
60
 
61
  if not output_frames_list or not output_frames_list[0]:
62
  raise gr.Error("Model returned empty frames. Check parameters or try a different prompt.")
63
-
64
- output_frames = output_frames_list[0] # The actual list of PIL Image frames
65
 
66
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
67
  video_path = tmpfile.name
68
-
69
  export_to_video(output_frames, video_path, fps=fps)
70
- print(f"Video successfully generated and saved to {video_path}")
71
-
72
  return video_path
73
 
74
-
75
  # --- Gradio UI Definition ---
76
  default_prompt = "A cat walks on the grass, realistic"
77
  default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
78
 
79
  with gr.Blocks() as demo:
80
  gr.Markdown(f"""
81
- # Text-to-Video with Wan 2.1 (14B)
82
  Powered by `diffusers` and `{MODEL_ID}`.
83
  Model is loaded into memory when the app starts. This might take a few minutes.
84
  Ensure you have a GPU with sufficient VRAM (e.g., ~24GB+ for these default settings).
85
  """)
86
-
87
  with gr.Row():
88
  with gr.Column(scale=2):
89
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt, lines=3)
@@ -129,9 +435,8 @@ with gr.Blocks() as demo:
129
  inputs=[prompt_input, negative_prompt_input, height_input, width_input, num_frames_input, guidance_scale_input, fps_input],
130
  outputs=video_output,
131
  fn=generate_video,
132
- cache_examples=False # Caching videos can take a lot of space quickly
133
  )
134
 
135
  if __name__ == "__main__":
136
- # The share=True option will create a public temporary link if you run this on Colab or similar
137
  demo.queue().launch(share=True, debug=True)
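
In the new version listed below, the single `pipe.load_lora_weights(causvid_path)` call is replaced by a custom key converter plus a helper, `apply_manual_diff_patches`, that adds the LoRA's `diff_b`/`diff` tensors onto the base weights in place. A minimal standalone sketch of that in-place patching idea, with a toy `nn.Linear` and made-up tensors standing in for the real Wan transformer:

# Minimal sketch of the in-place bias patching performed by apply_manual_diff_patches
# in the new file below. The layer and diff tensor are toy stand-ins, not the commit's
# actual modules or checkpoint data.
import torch

layer = torch.nn.Linear(8, 8)        # stand-in for e.g. transformer.blocks.0.attn1.to_q
diff_b = torch.randn(8) * 0.01       # stand-in for a LoRA "diff_b" tensor

with torch.no_grad():
    # Same pattern as the helper: match device/dtype, then add in place.
    layer.bias.add_(diff_b.to(layer.bias.device, layer.bias.dtype))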
 
1
  import torch
2
  from diffusers import AutoencoderKLWan, WanPipeline
3
  from diffusers.utils import export_to_video
4
+ from typing import Dict  # needed for the Dict type hints in the custom LoRA converter below
5
  import gradio as gr
6
  import tempfile
7
  import os
8
  import spaces
9
  from huggingface_hub import hf_hub_download
10
+ import logging # For better logging
11
 
12
+ # --- Global Model Loading & LoRA Handling ---
13
  MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
14
+ LORA_REPO_ID = "Kijai/WanVideo_comfy"
15
+ LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
16
 
17
+ # Configure logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
 
21
+ # This dictionary will store the manual patches extracted by the converter
22
+ MANUAL_PATCHES_STORE = {}
23
+
24
+ def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
25
+ """
26
+ Custom converter for Wan 2.1 T2V LoRA.
27
+ Separates LoRA A/B weights for PEFT and diff_b/diff for manual patching.
28
+ Stores diff_b/diff in the global MANUAL_PATCHES_STORE.
29
+ """
30
+ global MANUAL_PATCHES_STORE
31
+ MANUAL_PATCHES_STORE.clear() # Clear previous patches if any
32
+
33
+ converted_state_dict_for_peft = {}
34
+ manual_diff_patches = {}
35
+
36
+ # Strip "diffusion_model." prefix
37
+ original_state_dict = {
38
+ k[len("diffusion_model.") :]: v
39
+ for k, v in state_dict.items()
40
+ if k.startswith("diffusion_model.")
41
+ }
42
+
43
+ # --- Determine number of blocks ---
44
+ block_indices = set()
45
+ for k_orig in original_state_dict:
46
+ if "blocks." in k_orig:
47
+ try:
48
+ block_idx_str = k_orig.split("blocks.")[1].split(".")[0]
49
+ if block_idx_str.isdigit():
50
+ block_indices.add(int(block_idx_str))
51
+ except (IndexError, ValueError) as e:
52
+ logger.warning(f"Could not parse block index from key: {k_orig} due to {e}")
53
+
54
+ num_transformer_blocks = max(block_indices) + 1 if block_indices else 0
55
+ if not block_indices and any("blocks." in k for k in original_state_dict):
56
+ logger.warning("Found 'blocks.' in keys but could not determine num_transformer_blocks reliably.")
57
+
58
+
59
+ # --- Convert Transformer Blocks (blocks.0 to blocks.N-1) ---
60
+ for i in range(num_transformer_blocks):
61
+ # Self-attention (attn1 in Diffusers DiT)
62
+ for lora_key_part, diffusers_layer_name in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
63
+ orig_lora_down_key = f"blocks.{i}.self_attn.{lora_key_part}.lora_down.weight"
64
+ orig_lora_up_key = f"blocks.{i}.self_attn.{lora_key_part}.lora_up.weight"
65
+ target_base_key_peft = f"blocks.{i}.attn1.{diffusers_layer_name}"
66
+ target_base_key_manual = f"transformer.blocks.{i}.attn1.{diffusers_layer_name}"
67
+
68
+ if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
69
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
70
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
71
+
72
+ orig_diff_b_key = f"blocks.{i}.self_attn.{lora_key_part}.diff_b"
73
+ if orig_diff_b_key in original_state_dict:
74
+ manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
75
+
76
+ # Cross-attention (attn2 in Diffusers DiT)
77
+ for lora_key_part, diffusers_layer_name in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
78
+ orig_lora_down_key = f"blocks.{i}.cross_attn.{lora_key_part}.lora_down.weight"
79
+ orig_lora_up_key = f"blocks.{i}.cross_attn.{lora_key_part}.lora_up.weight"
80
+ target_base_key_peft = f"blocks.{i}.attn2.{diffusers_layer_name}"
81
+ target_base_key_manual = f"transformer.blocks.{i}.attn2.{diffusers_layer_name}"
82
+
83
+ if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
84
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
85
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
86
+
87
+ orig_diff_b_key = f"blocks.{i}.cross_attn.{lora_key_part}.diff_b"
88
+ if orig_diff_b_key in original_state_dict:
89
+ manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
90
+
91
+ # FFN
92
+ for original_ffn_idx, diffusers_ffn_path_part in zip(["0", "2"], ["net.0.proj", "net.2"]):
93
+ orig_lora_down_key = f"blocks.{i}.ffn.{original_ffn_idx}.lora_down.weight"
94
+ orig_lora_up_key = f"blocks.{i}.ffn.{original_ffn_idx}.lora_up.weight"
95
+ target_base_key_peft = f"blocks.{i}.ffn.{diffusers_ffn_path_part}"
96
+ target_base_key_manual = f"transformer.blocks.{i}.ffn.{diffusers_ffn_path_part}"
97
+
98
+ if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
99
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
100
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
101
+
102
+ orig_diff_b_key = f"blocks.{i}.ffn.{original_ffn_idx}.diff_b"
103
+ if orig_diff_b_key in original_state_dict:
104
+ manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
105
+
106
+ # Norm layers within blocks
107
+ # LoRA has `norm3.diff` and `norm3.diff_b`. Wan2.1 base DiTBlock has `norm2`.
108
+ norm3_diff_key = f"blocks.{i}.norm3.diff"
109
+ norm3_diff_b_key = f"blocks.{i}.norm3.diff_b"
110
+ target_norm_key_base_manual = f"transformer.blocks.{i}.norm2" # Diffusers DiTBlock's second norm
111
+ if norm3_diff_key in original_state_dict:
112
+ manual_diff_patches[f"{target_norm_key_base_manual}.weight"] = original_state_dict.pop(norm3_diff_key)
113
+ if norm3_diff_b_key in original_state_dict:
114
+ manual_diff_patches[f"{target_norm_key_base_manual}.bias"] = original_state_dict.pop(norm3_diff_b_key)
115
+
116
+ # Attention QK norms
117
+ for attn_type, diffusers_attn_block in zip(["self_attn", "cross_attn"], ["attn1", "attn2"]):
118
+ for norm_target_suffix in ["norm_q", "norm_k"]:
119
+ orig_norm_diff_key = f"blocks.{i}.{attn_type}.{norm_target_suffix}.diff"
120
+ target_norm_key_manual = f"transformer.blocks.{i}.{diffusers_attn_block}.{norm_target_suffix}.weight"
121
+ if orig_norm_diff_key in original_state_dict:
122
+ manual_diff_patches[target_norm_key_manual] = original_state_dict.pop(orig_norm_diff_key)
123
+
124
+ # --- Convert Non-Block Components ---
125
+ # Patch Embedding
126
+ patch_emb_diff_b_key = "patch_embedding.diff_b"
127
+ if patch_emb_diff_b_key in original_state_dict:
128
+ manual_diff_patches["transformer.patch_embedding.bias"] = original_state_dict.pop(patch_emb_diff_b_key)
129
+
130
+ # Text Embedding
131
+ for orig_idx, diffusers_linear_idx in zip(["0", "2"], ["linear_1", "linear_2"]):
132
+ orig_lora_down_key = f"text_embedding.{orig_idx}.lora_down.weight"
133
+ orig_lora_up_key = f"text_embedding.{orig_idx}.lora_up.weight"
134
+ target_base_key_peft = f"condition_embedder.text_embedder.{diffusers_linear_idx}"
135
+ target_base_key_manual = f"transformer.condition_embedder.text_embedder.{diffusers_linear_idx}"
136
+ if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
137
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
138
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
139
+ orig_diff_b_key = f"text_embedding.{orig_idx}.diff_b"
140
+ if orig_diff_b_key in original_state_dict:
141
+ manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
142
+
143
+ # Time Embedding
144
+ for orig_idx, diffusers_linear_idx in zip(["0", "2"], ["linear_1", "linear_2"]):
145
+ orig_lora_down_key = f"time_embedding.{orig_idx}.lora_down.weight"
146
+ orig_lora_up_key = f"time_embedding.{orig_idx}.lora_up.weight"
147
+ target_base_key_peft = f"condition_embedder.time_embedder.{diffusers_linear_idx}"
148
+ target_base_key_manual = f"transformer.condition_embedder.time_embedder.{diffusers_linear_idx}"
149
+ if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
150
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
151
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
152
+ orig_diff_b_key = f"time_embedding.{orig_idx}.diff_b"
153
+ if orig_diff_b_key in original_state_dict:
154
+ manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
155
+
156
+ # Time Projection
157
+ orig_lora_down_key = "time_projection.1.lora_down.weight"
158
+ orig_lora_up_key = "time_projection.1.lora_up.weight"
159
+ target_base_key_peft = "condition_embedder.time_proj"
160
+ target_base_key_manual = "transformer.condition_embedder.time_proj"
161
+ if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
162
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
163
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
164
+ orig_diff_b_key = "time_projection.1.diff_b"
165
+ if orig_diff_b_key in original_state_dict:
166
+ manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
167
+
168
+ # Head
169
+ orig_lora_down_key = "head.head.lora_down.weight"
170
+ orig_lora_up_key = "head.head.lora_up.weight"
171
+ target_base_key_peft = "proj_out" # Directly under transformer in Diffusers DiT
172
+ target_base_key_manual = "transformer.proj_out"
173
+ if orig_lora_down_key in original_state_dict and orig_lora_up_key in original_state_dict:
174
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_A.weight"] = original_state_dict.pop(orig_lora_down_key)
175
+ converted_state_dict_for_peft[f"{target_base_key_peft}.lora_B.weight"] = original_state_dict.pop(orig_lora_up_key)
176
+ orig_diff_b_key = "head.head.diff_b"
177
+ if orig_diff_b_key in original_state_dict:
178
+ manual_diff_patches[f"{target_base_key_manual}.bias"] = original_state_dict.pop(orig_diff_b_key)
179
+
180
+ # Log any remaining keys from the original LoRA after stripping "diffusion_model."
181
+ if len(original_state_dict) > 0:
182
+ logger.warning(
183
+ f"Following keys from LoRA (after stripping 'diffusion_model.') were not converted or explicitly handled for PEFT/manual patching: {original_state_dict.keys()}"
184
+ )
185
+
186
+ # Add "transformer." prefix for Diffusers LoraLoaderMixin to the PEFT keys
187
+ final_peft_state_dict = {}
188
+ for k_peft, v_peft in converted_state_dict_for_peft.items():
189
+ final_peft_state_dict[f"transformer.{k_peft}"] = v_peft
190
+
191
+ MANUAL_PATCHES_STORE = manual_diff_patches # Store for later use
192
+ return final_peft_state_dict
193
+
194
+
195
+ def apply_manual_diff_patches(pipe_model: torch.nn.Module, patches: Dict[str, torch.Tensor]):
196
+ """
197
+ Manually applies diff_b/diff patches to the model.
198
+ Assumes PEFT LoRA layers have already been loaded.
199
+ """
200
+ if not patches:
201
+ logger.info("No manual diff patches to apply.")
202
+ return
203
+
204
+ logger.info(f"Applying {len(patches)} manual diff patches...")
205
+ patched_keys_count = 0
206
+ unpatched_keys_count = 0
207
+
208
+ for key, diff_tensor in patches.items():
209
+ try:
210
+ module_to_patch = pipe_model
211
+ attrs = key.split(".")
212
+
213
+ # Navigate to the parent module
214
+ # e.g., key = "transformer.blocks.0.attn1.to_q.bias"
215
+ # attrs[:-1] would be ["transformer", "blocks", "0", "attn1", "to_q"]
216
+ for attr_name in attrs[:-1]:
217
+ if hasattr(module_to_patch, attr_name):
218
+ module_to_patch = getattr(module_to_patch, attr_name)
219
+ else:
220
+ # If it's a PEFT wrapped layer, try to access its base_layer
221
+ if hasattr(module_to_patch, 'base_layer') and hasattr(module_to_patch.base_layer, attr_name):
222
+ module_to_patch = getattr(module_to_patch.base_layer, attr_name)
223
+ else:
224
+ raise AttributeError(f"Submodule {attr_name} not found in {module_to_patch}")
225
+
226
+ param_name = attrs[-1] # "bias" or "weight"
227
+
228
+ # Access the target layer (it might be a PEFT LoraLayer or a regular nn.Module)
229
+ target_layer = module_to_patch
230
+
231
+ # If PEFT wrapped it, the actual nn.Linear or nn.LayerNorm is in `base_layer`
232
+ if hasattr(target_layer, "base_layer") and isinstance(target_layer.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
233
+ layer_to_modify = target_layer.base_layer
234
+ else:
235
+ layer_to_modify = target_layer
236
+
237
+ if not hasattr(layer_to_modify, param_name):
238
+ logger.error(f"Parameter '{param_name}' not found in layer '{layer_to_modify}' for key '{key}'. Skipping.")
239
+ unpatched_keys_count +=1
240
+ continue
241
+
242
+ original_param = getattr(layer_to_modify, param_name)
243
+
244
+ if original_param is None and param_name == "bias":
245
+ # If bias is None (e.g., LayerNorm with elementwise_affine=False, or Linear(bias=False)),
246
+ # we might need to initialize it if the diff expects to add to it.
247
+ # For Linear layers, if bias was False, it should remain False unless LoRA intends to add one.
248
+ # For LayerNorm, if elementwise_affine was False, adding a bias diff means it becomes affine.
249
+ if isinstance(layer_to_modify, torch.nn.Linear):
250
+ if layer_to_modify.bias is None: # Check if bias was intentionally None
251
+ logger.warning(f"Original layer {layer_to_modify} for key '{key}' has no bias. Creating one to apply diff_b. This might be unintended if bias=False was set.")
252
+ layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
253
+ original_param = layer_to_modify.bias
254
+ else: # Should not happen if original_param was None but layer_to_modify.bias isn't
255
+ pass
256
+ elif isinstance(layer_to_modify, torch.nn.LayerNorm):
257
+ if not layer_to_modify.elementwise_affine:
258
+ logger.warning(f"LayerNorm {layer_to_modify} for key '{key}' was not elementwise_affine. Applying bias diff will make it effectively affine for bias.")
259
+ # LayerNorm bias is initialized to zeros if elementwise_affine is True
260
+ layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
261
+ original_param = layer_to_modify.bias
262
+ # Also need to ensure weight exists if a weight diff is applied later
263
+ if param_name == "bias" and not hasattr(layer_to_modify, "weight"):
264
+ layer_to_modify.weight = torch.nn.Parameter(torch.ones_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype)) # Norm weights init to 1
265
+
266
+ if original_param is not None:
267
+ if original_param.shape != diff_tensor.shape:
268
+ logger.error(f"Shape mismatch for key '{key}': model param '{original_param.shape}', LoRA diff '{diff_tensor.shape}'. Skipping.")
269
+ unpatched_keys_count +=1
270
+ continue
271
+ with torch.no_grad():
272
+ original_param.add_(diff_tensor.to(original_param.device, original_param.dtype))
273
+ logger.info(f"Successfully applied diff to '{key}'")
274
+ patched_keys_count +=1
275
+ else:
276
+ logger.warning(f"Original parameter '{param_name}' is None for key '{key}' and was not initialized. Cannot apply diff. Skipping.")
277
+ unpatched_keys_count +=1
278
+
279
+
280
+ except AttributeError as e:
281
+ logger.error(f"AttributeError: Could not find module or parameter for key '{key}'. Error: {e}. Skipping.")
282
+ unpatched_keys_count +=1
283
+ except Exception as e:
284
+ logger.error(f"General error applying patch for key '{key}': {e}. Skipping.")
285
+ unpatched_keys_count +=1
286
+ logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
287
+
288
+
289
+ # --- Model Loading ---
290
+ logger.info(f"Loading VAE for {MODEL_ID}...")
291
  vae = AutoencoderKLWan.from_pretrained(
292
  MODEL_ID,
293
  subfolder="vae",
294
+ torch_dtype=torch.float32 # float32 for VAE stability
295
  )
296
+ logger.info(f"Loading Pipeline {MODEL_ID}...")
 
297
  pipe = WanPipeline.from_pretrained(
298
  MODEL_ID,
299
  vae=vae,
300
+ torch_dtype=torch.bfloat16 # bfloat16 for pipeline
301
  )
302
+ logger.info("Moving pipeline to CUDA...")
303
  pipe.to("cuda")
304
+
305
+ # --- LoRA Loading ---
306
+ logger.info(f"Downloading LoRA {LORA_FILENAME} from {LORA_REPO_ID}...")
307
+ causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
308
+
309
+ logger.info("Loading LoRA weights with custom converter...")
310
+ # The load_lora_weights will use the lora_converters mechanism if available.
311
+ # We need to ensure our custom converter is registered or passed correctly.
312
+ # Since WanPipeline inherits from a LoraLoaderMixin that might have its own
313
+ # lora_state_dict, we need to be careful.
314
+ # A robust way is to load the state_dict, convert it, then load the converted dict.
315
+
316
+ lora_state_dict_raw = WanPipeline.lora_state_dict(causvid_path) # This might already do some conversion
317
+
318
+ # If WanPipeline.lora_state_dict doesn't directly call our specific wan converter,
319
+ # we might need to load the raw safetensors file and then call our converter.
320
+ # Let's assume for now that lora_state_dict loads it and we then pass it to our converter.
321
+ # If WanPipeline's lora_state_dict already calls a wan-specific converter,
322
+ # then we need to inject our custom one there, which is not possible without modifying the library.
323
+
324
+ # Alternative: Load raw state_dict and then convert
325
+ from safetensors.torch import load_file as load_safetensors
326
+ raw_lora_state_dict = load_safetensors(causvid_path)
327
+
328
+ # Now call our custom converter which will populate MANUAL_PATCHES_STORE
329
+ peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_state_dict)
330
+
331
+ # Load the LoRA A/B matrices using PEFT
332
+ if peft_state_dict:
333
+ pipe.load_lora_weights(
334
+ peft_state_dict, # Pass the dictionary directly
335
+ adapter_name="causvid_lora"
336
+ )
337
+ logger.info("PEFT LoRA A/B weights loaded.")
338
+ else:
339
+ logger.warning("No PEFT-compatible LoRA weights found after conversion.")
340
+
341
+ # Apply manual diff_b and diff patches
342
+ apply_manual_diff_patches(pipe, MANUAL_PATCHES_STORE) # patch keys are prefixed with "transformer.", so navigation must start from the pipeline
343
+ logger.info("Manual diff_b/diff patches applied.")
344
+
345
 
346
  # --- Gradio Interface Function ---
347
  @spaces.GPU
348
  def generate_video(prompt, negative_prompt, height, width, num_frames, guidance_scale, fps):
349
+ logger.info("Starting video generation...")
350
+ logger.info(f" Prompt: {prompt}")
351
+ logger.info(f" Negative Prompt: {negative_prompt if negative_prompt else 'None'}")
352
+ logger.info(f" Height: {height}, Width: {width}")
353
+ logger.info(f" Num Frames: {num_frames}, FPS: {fps}")
354
+ logger.info(f" Guidance Scale: {guidance_scale}")
355

356
  height = (int(height) // 8) * 8
357
  width = (int(width) // 8) * 8
358
  num_frames = int(num_frames)
359
  fps = int(fps)
360
 
361
+ with torch.inference_mode():
362
  output_frames_list = pipe(
363
  prompt=prompt,
364
  negative_prompt=negative_prompt,
 
366
  width=width,
367
  num_frames=num_frames,
368
  guidance_scale=float(guidance_scale),
 
369
  ).frames
370
 
371
  if not output_frames_list or not output_frames_list[0]:
372
  raise gr.Error("Model returned empty frames. Check parameters or try a different prompt.")
373
+ output_frames = output_frames_list[0]
 
374
 
375
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
376
  video_path = tmpfile.name
 
377
  export_to_video(output_frames, video_path, fps=fps)
378
+ logger.info(f"Video successfully generated and saved to {video_path}")
 
379
  return video_path
380
 
 
381
  # --- Gradio UI Definition ---
382
  default_prompt = "A cat walks on the grass, realistic"
383
  default_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
384
 
385
  with gr.Blocks() as demo:
386
  gr.Markdown(f"""
387
+ # Text-to-Video with Wan 2.1 (14B) + CausVid LoRA
388
  Powered by `diffusers` and `{MODEL_ID}`.
389
  Model is loaded into memory when the app starts. This might take a few minutes.
390
  Ensure you have a GPU with sufficient VRAM (e.g., ~24GB+ for these default settings).
391
  """)
392
+ # ... (rest of your Gradio UI definition remains the same) ...
393
  with gr.Row():
394
  with gr.Column(scale=2):
395
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt, lines=3)
 
435
  inputs=[prompt_input, negative_prompt_input, height_input, width_input, num_frames_input, guidance_scale_input, fps_input],
436
  outputs=video_output,
437
  fn=generate_video,
438
+ cache_examples=False
439
  )
440
 
441
  if __name__ == "__main__":
 
442
  demo.queue().launch(share=True, debug=True)
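
For reference, a minimal sketch of what the custom converter produces, assuming `_custom_convert_non_diffusers_wan_lora_to_diffusers` and `MANUAL_PATCHES_STORE` were factored out into an importable module (importing app.py directly would trigger the model loading at the top of the file). The keys and shapes below are illustrative, not read from the real checkpoint:

import torch

# Toy ComfyUI-style Wan LoRA entries (illustrative shapes only).
toy_lora = {
    "diffusion_model.blocks.0.self_attn.q.lora_down.weight": torch.zeros(32, 5120),
    "diffusion_model.blocks.0.self_attn.q.lora_up.weight": torch.zeros(5120, 32),
    "diffusion_model.blocks.0.self_attn.q.diff_b": torch.zeros(5120),
}

# Assumes the converter and MANUAL_PATCHES_STORE are importable (see note above).
peft_sd = _custom_convert_non_diffusers_wan_lora_to_diffusers(toy_lora)
print(sorted(peft_sd))
# ['transformer.blocks.0.attn1.to_q.lora_A.weight',
#  'transformer.blocks.0.attn1.to_q.lora_B.weight']
print(sorted(MANUAL_PATCHES_STORE))
# ['transformer.blocks.0.attn1.to_q.bias']

The `lora_A`/`lora_B` entries are what `pipe.load_lora_weights` receives, while the `.bias` entry is applied afterwards by `apply_manual_diff_patches`.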