import json
from pathlib import Path
from typing import Optional

import safetensors.torch
from diffusers import DiffusionPipeline
from diffusers.loaders.lora_pipeline import _LOW_CPU_MEM_USAGE_DEFAULT_LORA
from huggingface_hub import repo_exists, snapshot_download
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

from finetrainers.logging import get_logger
from finetrainers.utils import find_files

logger = get_logger()

def load_lora_weights(
    pipeline: DiffusionPipeline, pretrained_model_name_or_path: str, adapter_name: Optional[str] = None, **kwargs
) -> None:
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)

    # Resolve the checkpoint location: use a local directory as-is, otherwise verify that the
    # Hub repo exists and download a local snapshot of it.
    is_local_file_path = Path(pretrained_model_name_or_path).is_dir()
    if not is_local_file_path:
        does_repo_exist = repo_exists(pretrained_model_name_or_path, repo_type="model")
        if not does_repo_exist:
            raise ValueError(f"Model repo {pretrained_model_name_or_path} does not exist on the Hub or locally.")
        else:
            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path, repo_type="model")

    # Load the LoRA state dict and keep only the transformer weights, stripping the
    # "transformer." prefix so the keys match what PEFT expects.
    prefix = "transformer"
    state_dict = pipeline.lora_state_dict(pretrained_model_name_or_path)
    state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

    # Read the serialized LoRA config from the metadata of the first .safetensors file found.
    file_list = find_files(pretrained_model_name_or_path, "*.safetensors", depth=1)
    if len(file_list) == 0:
        raise ValueError(f"No .safetensors files found in {pretrained_model_name_or_path}.")
    if len(file_list) > 1:
        logger.warning(
            f"Multiple .safetensors files found in {pretrained_model_name_or_path}. Using the first one: {file_list[0]}."
        )
    with safetensors.torch.safe_open(file_list[0], framework="pt") as f:
        metadata = f.metadata()
    metadata = json.loads(metadata["lora_config"])

    # Rebuild the LoraConfig, inject the adapter into the transformer, and load the weights into it.
    transformer = pipeline.transformer
    if adapter_name is None:
        adapter_name = "default"
    lora_config = LoraConfig(**metadata)
    inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
    result = set_peft_model_state_dict(
        transformer,
        state_dict,
        adapter_name=adapter_name,
        ignore_mismatched_sizes=False,
        low_cpu_mem_usage=low_cpu_mem_usage,
    )
    logger.debug(
        f"Loaded LoRA weights from {pretrained_model_name_or_path} into {pipeline.__class__.__name__}. Result: {result}"
    )
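

# Minimal usage sketch (not part of the original module): the base-model id and LoRA
# checkpoint id below are hypothetical placeholders, so substitute the DiffusionPipeline
# and finetrainers LoRA checkpoint you actually trained before running this.
if __name__ == "__main__":
    # Hypothetical identifiers, for illustration only.
    pipeline = DiffusionPipeline.from_pretrained("some-org/some-base-model")
    load_lora_weights(pipeline, "some-org/some-lora-checkpoint", adapter_name="default")
    # After this call, pipeline.transformer carries the injected PEFT adapter weights.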