Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -192,7 +192,7 @@ def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     return final_peft_state_dict
 
 
-def apply_manual_diff_patches(pipe_model, patches):
+def apply_manual_diff_patches(pipe_model: torch.nn.Module, patches: Dict[str, torch.Tensor]):
     """
     Manually applies diff_b/diff patches to the model.
     Assumes PEFT LoRA layers have already been loaded.
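Note: the new signature annotates `patches` as `Dict[str, torch.Tensor]`; that annotation assumes the typing import is available near the top of app.py (not shown in this hunk), e.g.:

    from typing import Dict  # assumed import for the new annotation; torch is already imported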
@@ -204,87 +204,95 @@ def apply_manual_diff_patches(pipe_model, patches):
     logger.info(f"Applying {len(patches)} manual diff patches...")
     patched_keys_count = 0
     unpatched_keys_count = 0
+    skipped_keys_details = []
 
     for key, diff_tensor in patches.items():
         try:
+            # key is like "transformer.blocks.0.attn1.to_q.bias"
+            current_module = pipe_model # Starts from pipe.transformer
+            path_parts = key.split('.')[1:] # Remove "transformer." prefix for getattr navigation
+            # e.g., ["blocks", "0", "attn1", "to_q", "bias"]
 
-            # Navigate to the parent module
-            #
+            # Navigate to the parent module of the parameter
+            # Example: for "blocks.0.attn1.to_q.bias", parent_module_path is "blocks.0.attn1.to_q"
+            parent_module_path = path_parts[:-1]
+            param_name_to_patch = path_parts[-1] # "bias" or "weight"
+
+            for part in parent_module_path:
+                if hasattr(current_module, part):
+                    current_module = getattr(current_module, part)
+                elif hasattr(current_module, 'base_layer') and hasattr(current_module.base_layer, part):
+                    # This case is unlikely here as we are navigating *to* the layer,
+                    # not trying to access a sub-component of a base_layer.
+                    # PEFT wrapping affects the layer itself, not its parent structure.
+                    current_module = getattr(current_module.base_layer, part)
                 else:
+                    raise AttributeError(f"Submodule '{part}' not found in path '{'.'.join(parent_module_path)}' within {key}")
+
+            # Now, current_module is the layer whose parameter we want to patch
+            # e.g., if key was transformer.blocks.0.attn1.to_q.bias,
+            # current_module is the to_q Linear layer (or LoraLayer wrapping it)
+
+            layer_to_modify = current_module
+            # If PEFT wrapped the Linear layer (common for attention q,k,v,o and ffn projections)
+            if hasattr(layer_to_modify, "base_layer") and isinstance(layer_to_modify.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
+                actual_param_owner = layer_to_modify.base_layer
+            else: # For non-wrapped layers like LayerNorm, or if it's already the base_layer
+                actual_param_owner = layer_to_modify
 
-            # If PEFT wrapped it, the actual nn.Linear or nn.LayerNorm is in `base_layer`
-            if hasattr(target_layer, "base_layer") and isinstance(target_layer.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
-                layer_to_modify = target_layer.base_layer
-            else:
-                layer_to_modify = target_layer
+            if not hasattr(actual_param_owner, param_name_to_patch):
+                skipped_keys_details.append(f"Key: {key}, Reason: Parameter '{param_name_to_patch}' not found in layer '{actual_param_owner}'. Layer type: {type(actual_param_owner)}")
+                unpatched_keys_count += 1
+                continue
 
+            original_param = getattr(actual_param_owner, param_name_to_patch)
+
+            if original_param is None and param_name_to_patch == "bias":
+                logger.info(f"Key '{key}': Original bias is None. Attempting to initialize.")
+                if isinstance(actual_param_owner, torch.nn.Linear) or isinstance(actual_param_owner, torch.nn.LayerNorm):
+                    # For LayerNorm, bias exists if elementwise_affine=True (default).
+                    # If it was False, we are making it affine by adding a bias.
+                    # For Linear, if bias was False, we are adding one.
+                    actual_param_owner.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
+                    original_param = actual_param_owner.bias
+                    logger.info(f"Key '{key}': Initialized bias for {type(actual_param_owner)}.")
+                else:
+                    skipped_keys_details.append(f"Key: {key}, Reason: Original bias is None and layer '{actual_param_owner}' is not Linear or LayerNorm. Cannot initialize.")
+                    unpatched_keys_count +=1
+                    continue
+
+            # Special handling for RMSNorm which typically has no bias
+            if isinstance(actual_param_owner, torch.nn.RMSNorm) and param_name_to_patch == "bias":
+                skipped_keys_details.append(f"Key: {key}, Reason: Layer '{actual_param_owner}' is RMSNorm which has no bias parameter. Skipping bias diff.")
                 unpatched_keys_count +=1
                 continue
 
-            original_param = getattr(layer_to_modify, param_name)
-            if original_param is None and param_name == "bias":
-                # If bias is None (e.g., LayerNorm with elementwise_affine=False, or Linear(bias=False)),
-                # we might need to initialize it if the diff expects to add to it.
-                # For Linear layers, if bias was False, it should remain False unless LoRA intends to add one.
-                # For LayerNorm, if elementwise_affine was False, adding a bias diff means it becomes affine.
-                if isinstance(layer_to_modify, torch.nn.Linear):
-                    if layer_to_modify.bias is None: # Check if bias was intentionally None
-                        logger.warning(f"Original layer {layer_to_modify} for key '{key}' has no bias. Creating one to apply diff_b. This might be unintended if bias=False was set.")
-                        layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                        original_param = layer_to_modify.bias
-                    else: # Should not happen if original_param was None but layer_to_modify.bias isn't
-                        pass
-                elif isinstance(layer_to_modify, torch.nn.LayerNorm):
-                    if not layer_to_modify.elementwise_affine:
-                        logger.warning(f"LayerNorm {layer_to_modify} for key '{key}' was not elementwise_affine. Applying bias diff will make it effectively affine for bias.")
-                        # LayerNorm bias is initialized to zeros if elementwise_affine is True
-                        layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                        original_param = layer_to_modify.bias
-                    # Also need to ensure weight exists if a weight diff is applied later
-                    if param_name == "bias" and not hasattr(layer_to_modify, "weight"):
-                        layer_to_modify.weight = torch.nn.Parameter(torch.ones_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype)) # Norm weights init to 1
 
             if original_param is not None:
                 if original_param.shape != diff_tensor.shape:
-                    unpatched_keys_count +=1
+                    skipped_keys_details.append(f"Key: {key}, Reason: Shape mismatch. Model param: {original_param.shape}, LoRA diff: {diff_tensor.shape}. Layer: {actual_param_owner}")
+                    unpatched_keys_count += 1
                    continue
                with torch.no_grad():
                    original_param.add_(diff_tensor.to(original_param.device, original_param.dtype))
-                logger.info(f"Successfully applied diff to '{key}'")
-                patched_keys_count +=1
+                # logger.info(f"Successfully applied diff to '{key}'") # Too verbose, will log summary
+                patched_keys_count += 1
             else:
-                unpatched_keys_count +=1
+                skipped_keys_details.append(f"Key: {key}, Reason: Original parameter '{param_name_to_patch}' is None and was not initialized. Layer: {actual_param_owner}")
+                unpatched_keys_count += 1
 
         except AttributeError as e:
-            unpatched_keys_count +=1
+            skipped_keys_details.append(f"Key: {key}, Reason: AttributeError - {e}")
+            unpatched_keys_count += 1
         except Exception as e:
-            unpatched_keys_count +=1
-    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
+            skipped_keys_details.append(f"Key: {key}, Reason: General Exception - {e}")
+            unpatched_keys_count += 1
 
+    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
+    if unpatched_keys_count > 0:
+        logger.warning("Details of unpatched/skipped keys:")
+        for detail in skipped_keys_details:
+            logger.warning(f" - {detail}")
 
 # --- Model Loading ---
 logger.info(f"Loading VAE for {MODEL_ID}...")
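With the new navigation and bias handling above, the helper can be pointed at the pipeline's transformer with a plain dict mapping fully-qualified parameter names to diff tensors. A minimal usage sketch, assuming the key layout described in the diff's comments and that the dict is built from the LoRA file's leftover `diff`/`diff_b` entries; `pipe` and the example tensor shape are illustrative, not taken from this diff:

    # Hypothetical call site; `pipe` stands for the pipeline object defined later in app.py.
    manual_patches = {
        # One bias diff taken from the LoRA checkpoint; the shape must match the model parameter.
        "transformer.blocks.0.attn1.to_q.bias": torch.zeros(3072),
    }
    apply_manual_diff_patches(pipe.transformer, manual_patches)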
@@ -411,6 +419,7 @@ with gr.Blocks() as demo:
             width_input,
             num_frames_input,
             guidance_scale_input,
+            steps,
             fps_input
         ],
         outputs=video_output
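The new `steps` entry in the `inputs=` list implies a matching Gradio component and an extra parameter in the click handler, since Gradio passes the `inputs=` list to the handler positionally. A minimal sketch under those assumptions (only the name `steps` appears in this diff; the slider range, label, and handler signature below are illustrative):

    import gradio as gr  # already imported by app.py

    # Hypothetical slider component bound to the new input slot.
    steps = gr.Slider(minimum=1, maximum=50, step=1, value=25, label="Inference steps")

    # The click handler must accept the value in the same position,
    # i.e. between guidance_scale and fps, e.g.:
    def generate_video(prompt, height, width, num_frames, guidance_scale, steps, fps):
        ...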