Commit 8df45ae (verified) · Parent: 721f7aa
Committed by multimodalart (HF Staff)

Update app.py

Files changed (1): app.py (+70 -61)
app.py CHANGED
@@ -192,7 +192,7 @@ def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     return final_peft_state_dict
 
 
-def apply_manual_diff_patches(pipe_model, patches):
+def apply_manual_diff_patches(pipe_model: torch.nn.Module, patches: Dict[str, torch.Tensor]):
     """
     Manually applies diff_b/diff patches to the model.
     Assumes PEFT LoRA layers have already been loaded.
@@ -204,87 +204,95 @@ def apply_manual_diff_patches(pipe_model, patches):
     logger.info(f"Applying {len(patches)} manual diff patches...")
     patched_keys_count = 0
     unpatched_keys_count = 0
+    skipped_keys_details = []
 
     for key, diff_tensor in patches.items():
         try:
-            module_to_patch = pipe_model
-            attrs = key.split(".")
+            # key is like "transformer.blocks.0.attn1.to_q.bias"
+            current_module = pipe_model # Starts from pipe.transformer
+            path_parts = key.split('.')[1:] # Remove "transformer." prefix for getattr navigation
+            # e.g., ["blocks", "0", "attn1", "to_q", "bias"]
 
-            # Navigate to the parent module
-            # e.g., key = "transformer.blocks.0.attn1.to_q.bias"
-            # attrs[:-1] would be ["transformer", "blocks", "0", "attn1", "to_q"]
-            for attr_name in attrs[:-1]:
-                if hasattr(module_to_patch, attr_name):
-                    module_to_patch = getattr(module_to_patch, attr_name)
+            # Navigate to the parent module of the parameter
+            # Example: for "blocks.0.attn1.to_q.bias", parent_module_path is "blocks.0.attn1.to_q"
+            parent_module_path = path_parts[:-1]
+            param_name_to_patch = path_parts[-1] # "bias" or "weight"
+
+            for part in parent_module_path:
+                if hasattr(current_module, part):
+                    current_module = getattr(current_module, part)
+                elif hasattr(current_module, 'base_layer') and hasattr(current_module.base_layer, part):
+                    # This case is unlikely here as we are navigating *to* the layer,
+                    # not trying to access a sub-component of a base_layer.
+                    # PEFT wrapping affects the layer itself, not its parent structure.
+                    current_module = getattr(current_module.base_layer, part)
                 else:
-                    # If it's a PEFT wrapped layer, try to access its base_layer
-                    if hasattr(module_to_patch, 'base_layer') and hasattr(module_to_patch.base_layer, attr_name):
-                        module_to_patch = getattr(module_to_patch.base_layer, attr_name)
-                    else:
-                        raise AttributeError(f"Submodule {attr_name} not found in {module_to_patch}")
+                    raise AttributeError(f"Submodule '{part}' not found in path '{'.'.join(parent_module_path)}' within {key}")
 
-            param_name = attrs[-1] # "bias" or "weight"
-
-            # Access the target layer (it might be a PEFT LoraLayer or a regular nn.Module)
-            target_layer = module_to_patch
-
-            # If PEFT wrapped it, the actual nn.Linear or nn.LayerNorm is in `base_layer`
-            if hasattr(target_layer, "base_layer") and isinstance(target_layer.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
-                layer_to_modify = target_layer.base_layer
-            else:
-                layer_to_modify = target_layer
+            # Now, current_module is the layer whose parameter we want to patch
+            # e.g., if key was transformer.blocks.0.attn1.to_q.bias,
+            # current_module is the to_q Linear layer (or LoraLayer wrapping it)
+
+            layer_to_modify = current_module
+            # If PEFT wrapped the Linear layer (common for attention q,k,v,o and ffn projections)
+            if hasattr(layer_to_modify, "base_layer") and isinstance(layer_to_modify.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
+                actual_param_owner = layer_to_modify.base_layer
+            else: # For non-wrapped layers like LayerNorm, or if it's already the base_layer
+                actual_param_owner = layer_to_modify
 
-            if not hasattr(layer_to_modify, param_name):
-                logger.error(f"Parameter '{param_name}' not found in layer '{layer_to_modify}' for key '{key}'. Skipping.")
+            if not hasattr(actual_param_owner, param_name_to_patch):
+                skipped_keys_details.append(f"Key: {key}, Reason: Parameter '{param_name_to_patch}' not found in layer '{actual_param_owner}'. Layer type: {type(actual_param_owner)}")
+                unpatched_keys_count += 1
+                continue
+
+            original_param = getattr(actual_param_owner, param_name_to_patch)
+
+            if original_param is None and param_name_to_patch == "bias":
+                logger.info(f"Key '{key}': Original bias is None. Attempting to initialize.")
+                if isinstance(actual_param_owner, torch.nn.Linear) or isinstance(actual_param_owner, torch.nn.LayerNorm):
+                    # For LayerNorm, bias exists if elementwise_affine=True (default).
+                    # If it was False, we are making it affine by adding a bias.
+                    # For Linear, if bias was False, we are adding one.
+                    actual_param_owner.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
+                    original_param = actual_param_owner.bias
+                    logger.info(f"Key '{key}': Initialized bias for {type(actual_param_owner)}.")
+                else:
+                    skipped_keys_details.append(f"Key: {key}, Reason: Original bias is None and layer '{actual_param_owner}' is not Linear or LayerNorm. Cannot initialize.")
+                    unpatched_keys_count +=1
+                    continue
+
+            # Special handling for RMSNorm which typically has no bias
+            if isinstance(actual_param_owner, torch.nn.RMSNorm) and param_name_to_patch == "bias":
+                skipped_keys_details.append(f"Key: {key}, Reason: Layer '{actual_param_owner}' is RMSNorm which has no bias parameter. Skipping bias diff.")
                 unpatched_keys_count +=1
                 continue
 
-            original_param = getattr(layer_to_modify, param_name)
-
-            if original_param is None and param_name == "bias":
-                # If bias is None (e.g., LayerNorm with elementwise_affine=False, or Linear(bias=False)),
-                # we might need to initialize it if the diff expects to add to it.
-                # For Linear layers, if bias was False, it should remain False unless LoRA intends to add one.
-                # For LayerNorm, if elementwise_affine was False, adding a bias diff means it becomes affine.
-                if isinstance(layer_to_modify, torch.nn.Linear):
-                    if layer_to_modify.bias is None: # Check if bias was intentionally None
-                        logger.warning(f"Original layer {layer_to_modify} for key '{key}' has no bias. Creating one to apply diff_b. This might be unintended if bias=False was set.")
-                        layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                        original_param = layer_to_modify.bias
-                    else: # Should not happen if original_param was None but layer_to_modify.bias isn't
-                        pass
-                elif isinstance(layer_to_modify, torch.nn.LayerNorm):
-                    if not layer_to_modify.elementwise_affine:
-                        logger.warning(f"LayerNorm {layer_to_modify} for key '{key}' was not elementwise_affine. Applying bias diff will make it effectively affine for bias.")
-                        # LayerNorm bias is initialized to zeros if elementwise_affine is True
-                        layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                        original_param = layer_to_modify.bias
-                        # Also need to ensure weight exists if a weight diff is applied later
-                        if param_name == "bias" and not hasattr(layer_to_modify, "weight"):
-                            layer_to_modify.weight = torch.nn.Parameter(torch.ones_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype)) # Norm weights init to 1
 
             if original_param is not None:
                 if original_param.shape != diff_tensor.shape:
-                    logger.error(f"Shape mismatch for key '{key}': model param '{original_param.shape}', LoRA diff '{diff_tensor.shape}'. Skipping.")
-                    unpatched_keys_count +=1
+                    skipped_keys_details.append(f"Key: {key}, Reason: Shape mismatch. Model param: {original_param.shape}, LoRA diff: {diff_tensor.shape}. Layer: {actual_param_owner}")
+                    unpatched_keys_count += 1
                     continue
                 with torch.no_grad():
                     original_param.add_(diff_tensor.to(original_param.device, original_param.dtype))
-                    logger.info(f"Successfully applied diff to '{key}'")
-                    patched_keys_count +=1
+                    # logger.info(f"Successfully applied diff to '{key}'") # Too verbose, will log summary
+                    patched_keys_count += 1
             else:
-                logger.warning(f"Original parameter '{param_name}' is None for key '{key}' and was not initialized. Cannot apply diff. Skipping.")
-                unpatched_keys_count +=1
-
+                skipped_keys_details.append(f"Key: {key}, Reason: Original parameter '{param_name_to_patch}' is None and was not initialized. Layer: {actual_param_owner}")
+                unpatched_keys_count += 1
 
         except AttributeError as e:
-            logger.error(f"AttributeError: Could not find module or parameter for key '{key}'. Error: {e}. Skipping.")
-            unpatched_keys_count +=1
+            skipped_keys_details.append(f"Key: {key}, Reason: AttributeError - {e}")
+            unpatched_keys_count += 1
         except Exception as e:
-            logger.error(f"General error applying patch for key '{key}': {e}. Skipping.")
-            unpatched_keys_count +=1
-    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
+            skipped_keys_details.append(f"Key: {key}, Reason: General Exception - {e}")
+            unpatched_keys_count += 1
 
+    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
+    if unpatched_keys_count > 0:
+        logger.warning("Details of unpatched/skipped keys:")
+        for detail in skipped_keys_details:
+            logger.warning(f" - {detail}")
 
 # --- Model Loading ---
 logger.info(f"Loading VAE for {MODEL_ID}...")
@@ -411,6 +419,7 @@ with gr.Blocks() as demo:
             width_input,
             num_frames_input,
            guidance_scale_input,
+            steps,
             fps_input
         ],
         outputs=video_output
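
For context, a minimal sketch (not part of the commit) of how the patched function expects its inputs: the keys of the patches dict keep the "transformer." prefix, while the function itself is handed the transformer module, so the prefix is stripped during attribute navigation. The tiny module tree, the key, and the zero-valued diff tensor below are hypothetical stand-ins for pipe.transformer and a real diff_b entry from a Wan LoRA file; apply_manual_diff_patches and its logger are assumed to be in scope from app.py.

import torch

class TinyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)  # stands in for an attention projection

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = TinyAttention()

class TinyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([TinyBlock()])

transformer = TinyTransformer()

# Keys keep the "transformer." prefix; the function strips it via key.split('.')[1:]
# and then walks blocks -> 0 -> attn1 -> to_q before patching "bias".
patches = {
    "transformer.blocks.0.attn1.to_q.bias": torch.zeros(8),
}

apply_manual_diff_patches(transformer, patches)
# Should log something like: "Manual patching summary: 1 keys patched, 0 keys failed or skipped."

With a real pipeline the call would be made against pipe.transformer after the PEFT LoRA weights have been loaded, with the patches dict built from the diff/diff_b entries of the LoRA state dict.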