import logging
from pathlib import Path
from typing import Optional, TypeVar

from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from torch import nn

from animatediff import HF_HUB_CACHE, HF_MODULE_REPO, get_dir
from animatediff.settings import CKPT_EXTENSIONS
from animatediff.utils.huggingface import get_hf_pipeline, get_hf_pipeline_sdxl
from animatediff.utils.util import path_from_cwd

logger = logging.getLogger(__name__)

data_dir = get_dir("data")
checkpoint_dir = data_dir.joinpath("models/sd")
pipeline_dir = data_dir.joinpath("models/huggingface")

# for the nop_train() monkeypatch
T = TypeVar("T", bound=nn.Module)


def nop_train(self: T, mode: bool = True) -> T:
    """No-op replacement for train(), monkeypatched in to prevent unfreezing the module."""
    return self

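# A minimal sketch of how nop_train() is meant to be applied (the actual call
# sites live elsewhere in animatediff; `unet` here is illustrative):
#
#   unet.train = nop_train.__get__(unet)  # bind as a method of this instance
#   unet.train()                          # now a no-op; the module stays frozen
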
def get_base_model(model_name_or_path: str, local_dir: Path, force: bool = False, is_sdxl: bool = False) -> Path:
    model_name_or_path = Path(model_name_or_path)
    model_save_dir = local_dir.joinpath(str(model_name_or_path).split("/")[-1]).resolve()
    # a local diffusers pipeline already has a model_index.json; anything else is treated as a HF repo ID
    model_is_repo_id = not model_name_or_path.joinpath("model_index.json").exists()
    # if we have a HF repo ID, download it
    if model_is_repo_id:
        logger.debug("Base model is a HuggingFace repo ID")
        if model_save_dir.joinpath("model_index.json").exists():
            logger.debug(f"Base model already downloaded to: {path_from_cwd(model_save_dir)}")
        else:
            logger.info(f"Downloading base model from {model_name_or_path}...")
            if is_sdxl:
                _ = get_hf_pipeline_sdxl(model_name_or_path, model_save_dir, save=True, force_download=force)
            else:
                _ = get_hf_pipeline(model_name_or_path, model_save_dir, save=True, force_download=force)
        model_name_or_path = model_save_dir
    return Path(model_name_or_path)

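# Illustrative usage (repo ID and paths are examples, not fixed by this module):
#
#   base_path = get_base_model("runwayml/stable-diffusion-v1-5", local_dir=pipeline_dir)
#   # the first call downloads the pipeline; later calls find model_index.json and skip it.
#   # passing a local diffusers directory (one containing model_index.json) downloads nothing.
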
def fix_checkpoint_if_needed(checkpoint: Path, debug: bool):
    def dump(loaded):
        for key in loaded:
            logger.info(f"{key} {loaded[key].shape}")

    if debug:
        from safetensors.torch import load_file

        loaded = load_file(checkpoint, "cpu")
        dump(loaded)
        return

    try:
        # if diffusers can load the checkpoint as-is, there is nothing to fix
        pipeline = StableDiffusionPipeline.from_single_file(
            pretrained_model_link_or_path=str(checkpoint.absolute()),
            local_files_only=False,
            load_safety_checker=False,
        )
        logger.info("This file works fine.")
        return
    except Exception:
        from safetensors.torch import load_file, save_file

        loaded = load_file(checkpoint, "cpu")
        # some checkpoints carry diffusers-style VAE attention keys (to_q/to_k/to_v/to_out.0)
        # under ldm-style names; rename them to the q/k/v/proj_out keys the loader expects
        convert_table_bias = {
            "first_stage_model.decoder.mid.attn_1.to_k.bias": "first_stage_model.decoder.mid.attn_1.k.bias",
            "first_stage_model.decoder.mid.attn_1.to_out.0.bias": "first_stage_model.decoder.mid.attn_1.proj_out.bias",
            "first_stage_model.decoder.mid.attn_1.to_q.bias": "first_stage_model.decoder.mid.attn_1.q.bias",
            "first_stage_model.decoder.mid.attn_1.to_v.bias": "first_stage_model.decoder.mid.attn_1.v.bias",
            "first_stage_model.encoder.mid.attn_1.to_k.bias": "first_stage_model.encoder.mid.attn_1.k.bias",
            "first_stage_model.encoder.mid.attn_1.to_out.0.bias": "first_stage_model.encoder.mid.attn_1.proj_out.bias",
            "first_stage_model.encoder.mid.attn_1.to_q.bias": "first_stage_model.encoder.mid.attn_1.q.bias",
            "first_stage_model.encoder.mid.attn_1.to_v.bias": "first_stage_model.encoder.mid.attn_1.v.bias",
        }
        convert_table_weight = {
            "first_stage_model.decoder.mid.attn_1.to_k.weight": "first_stage_model.decoder.mid.attn_1.k.weight",
            "first_stage_model.decoder.mid.attn_1.to_out.0.weight": "first_stage_model.decoder.mid.attn_1.proj_out.weight",
            "first_stage_model.decoder.mid.attn_1.to_q.weight": "first_stage_model.decoder.mid.attn_1.q.weight",
            "first_stage_model.decoder.mid.attn_1.to_v.weight": "first_stage_model.decoder.mid.attn_1.v.weight",
            "first_stage_model.encoder.mid.attn_1.to_k.weight": "first_stage_model.encoder.mid.attn_1.k.weight",
            "first_stage_model.encoder.mid.attn_1.to_out.0.weight": "first_stage_model.encoder.mid.attn_1.proj_out.weight",
            "first_stage_model.encoder.mid.attn_1.to_q.weight": "first_stage_model.encoder.mid.attn_1.q.weight",
            "first_stage_model.encoder.mid.attn_1.to_v.weight": "first_stage_model.encoder.mid.attn_1.v.weight",
        }
        for key in list(loaded.keys()):
            if key in convert_table_bias:
                loaded[convert_table_bias[key]] = loaded.pop(key)
            elif key in convert_table_weight:
                item = loaded.pop(key)
                if len(item.shape) == 2:
                    # linear weight [out, in] -> 1x1 conv weight [out, in, 1, 1]
                    item = item.unsqueeze(dim=-1).unsqueeze(dim=-1)
                loaded[convert_table_weight[key]] = item

        new_path = str(checkpoint.parent / checkpoint.stem) + "_fixed" + checkpoint.suffix
        logger.info(f"Saving file to {new_path}")
        save_file(loaded, Path(new_path))

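# Worked example of the rename/reshape above (shapes illustrative for the SD VAE,
# whose mid-block attention is 512-channel):
#   "...decoder.mid.attn_1.to_k.weight"   [512, 512]        (diffusers-style linear)
#     -> "...decoder.mid.attn_1.k.weight" [512, 512, 1, 1]  (ldm-style 1x1 conv)
# The repaired state dict is written next to the original as "<name>_fixed<ext>";
# the source checkpoint itself is never modified.
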
def checkpoint_to_pipeline(
    checkpoint: Path,
    target_dir: Optional[Path] = None,
    save: bool = True,
) -> tuple[StableDiffusionPipeline, Path]:
    logger.debug(f"Converting checkpoint {path_from_cwd(checkpoint)}")
    if target_dir is None:
        target_dir = pipeline_dir.joinpath(checkpoint.stem)
    pipeline = StableDiffusionPipeline.from_single_file(
        pretrained_model_link_or_path=str(checkpoint.absolute()),
        local_files_only=False,
        load_safety_checker=False,
    )
    if save:
        target_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)
    return pipeline, target_dir

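# Illustrative usage (the path is an example; the SDXL variant below follows the
# same pattern):
#
#   pipe, out_dir = checkpoint_to_pipeline(Path("data/models/sd/some_model.safetensors"))
#   # converts the single-file checkpoint and, with save=True, writes a
#   # diffusers-layout pipeline under data/models/huggingface/some_model/
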
def checkpoint_to_pipeline_sdxl(
    checkpoint: Path,
    target_dir: Optional[Path] = None,
    save: bool = True,
) -> tuple[StableDiffusionXLPipeline, Path]:
    logger.debug(f"Converting checkpoint {path_from_cwd(checkpoint)}")
    if target_dir is None:
        target_dir = pipeline_dir.joinpath(checkpoint.stem)
    pipeline = StableDiffusionXLPipeline.from_single_file(
        pretrained_model_link_or_path=str(checkpoint.absolute()),
        local_files_only=False,
        load_safety_checker=False,
    )
    if save:
        target_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)
    return pipeline, target_dir

def get_checkpoint_weights(checkpoint: Path):
    temp_pipeline: StableDiffusionPipeline
    temp_pipeline, _ = checkpoint_to_pipeline(checkpoint, save=False)
    unet_state_dict = temp_pipeline.unet.state_dict()
    tenc_state_dict = temp_pipeline.text_encoder.state_dict()
    vae_state_dict = temp_pipeline.vae.state_dict()
    return unet_state_dict, tenc_state_dict, vae_state_dict

def get_checkpoint_weights_sdxl(checkpoint: Path):
    temp_pipeline: StableDiffusionXLPipeline
    temp_pipeline, _ = checkpoint_to_pipeline_sdxl(checkpoint, save=False)
    unet_state_dict = temp_pipeline.unet.state_dict()
    tenc_state_dict = temp_pipeline.text_encoder.state_dict()
    tenc2_state_dict = temp_pipeline.text_encoder_2.state_dict()
    vae_state_dict = temp_pipeline.vae.state_dict()
    return unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict

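# Both helpers do a throwaway conversion (save=False) purely to pull state dicts, e.g.:
#
#   unet_sd, tenc_sd, vae_sd = get_checkpoint_weights(Path("data/models/sd/some_model.safetensors"))
#
# The SDXL variant additionally returns the second text encoder's state dict.
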
def ensure_motion_modules(
    repo_id: str = HF_MODULE_REPO,
    fp16: bool = False,
    force: bool = False,
):
    """Retrieve the motion modules from the HuggingFace Hub."""
    module_files = ["mm_sd_v14.safetensors", "mm_sd_v15.safetensors"]
    module_dir = get_dir("data/models/motion-module")
    for file in module_files:
        target_path = module_dir.joinpath(file)
        if fp16:
            target_path = target_path.with_suffix(".fp16.safetensors")
        if target_path.exists() and not force:
            logger.debug(f"File {path_from_cwd(target_path)} already exists, skipping download")
        else:
            result = hf_hub_download(
                repo_id=repo_id,
                filename=target_path.name,
                cache_dir=HF_HUB_CACHE,
                local_dir=module_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
            )
            logger.debug(f"Downloaded {path_from_cwd(result)}")