import logging
from functools import wraps
from pathlib import Path
from typing import Optional, TypeVar

from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from torch import nn

from animatediff import HF_HUB_CACHE, HF_MODULE_REPO, get_dir
from animatediff.settings import CKPT_EXTENSIONS
from animatediff.utils.huggingface import get_hf_pipeline, get_hf_pipeline_sdxl
from animatediff.utils.util import path_from_cwd

logger = logging.getLogger(__name__)

data_dir = get_dir("data")
checkpoint_dir = data_dir.joinpath("models/sd")
pipeline_dir = data_dir.joinpath("models/huggingface")

T = TypeVar("T", bound=nn.Module)


def nop_train(self: T, mode: bool = True) -> T:
    """No-op for monkeypatching train() call to prevent unfreezing module."""
    return self


def get_base_model(model_name_or_path: str, local_dir: Path, force: bool = False, is_sdxl: bool = False) -> Path:
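    """Resolve a base model path, downloading it from the HuggingFace Hub when given a repo ID."""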
    model_name_or_path = Path(model_name_or_path)

    model_save_dir = local_dir.joinpath(str(model_name_or_path).split("/")[-1]).resolve()
    model_is_repo_id = not model_name_or_path.joinpath("model_index.json").exists()

    if model_is_repo_id:
        logger.debug("Base model is a HuggingFace repo ID")
        if model_save_dir.joinpath("model_index.json").exists():
            logger.debug(f"Base model already downloaded to: {path_from_cwd(model_save_dir)}")
        else:
            logger.info(f"Downloading base model from {model_name_or_path}...")
            if is_sdxl:
                _ = get_hf_pipeline_sdxl(model_name_or_path, model_save_dir, save=True, force_download=force)
            else:
                _ = get_hf_pipeline(model_name_or_path, model_save_dir, save=True, force_download=force)
        model_name_or_path = model_save_dir

    return Path(model_name_or_path)


def fix_checkpoint_if_needed(checkpoint: Path, debug: bool):
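    """Check whether a single-file checkpoint loads with diffusers; if it does not, rename the VAE
    mid-block attention keys back to the legacy layout and save the result as a `*_fixed` copy.
    With debug=True, only dump the tensor names and shapes.
    """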
    def dump(loaded):
        for a in loaded:
            logger.info(f"{a} {loaded[a].shape}")

    if debug:
        from safetensors.torch import load_file, save_file

        loaded = load_file(checkpoint, "cpu")
        dump(loaded)
        return

    try:
        pipeline = StableDiffusionPipeline.from_single_file(
            pretrained_model_link_or_path=str(checkpoint.absolute()),
            local_files_only=False,
            load_safety_checker=False,
        )
        logger.info("This file works fine.")
        return
    except Exception:
        from safetensors.torch import load_file, save_file

        loaded = load_file(checkpoint, "cpu")
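        # The single-file loader expects the legacy q/k/v/proj_out key names for the VAE
        # mid-block attention; some checkpoints appear to ship them under the newer
        # to_q/to_k/to_v/to_out.0 names instead, so map those keys back before saving a fixed copy.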
        convert_table_bias = {
            "first_stage_model.decoder.mid.attn_1.to_k.bias": "first_stage_model.decoder.mid.attn_1.k.bias",
            "first_stage_model.decoder.mid.attn_1.to_out.0.bias": "first_stage_model.decoder.mid.attn_1.proj_out.bias",
            "first_stage_model.decoder.mid.attn_1.to_q.bias": "first_stage_model.decoder.mid.attn_1.q.bias",
            "first_stage_model.decoder.mid.attn_1.to_v.bias": "first_stage_model.decoder.mid.attn_1.v.bias",
            "first_stage_model.encoder.mid.attn_1.to_k.bias": "first_stage_model.encoder.mid.attn_1.k.bias",
            "first_stage_model.encoder.mid.attn_1.to_out.0.bias": "first_stage_model.encoder.mid.attn_1.proj_out.bias",
            "first_stage_model.encoder.mid.attn_1.to_q.bias": "first_stage_model.encoder.mid.attn_1.q.bias",
            "first_stage_model.encoder.mid.attn_1.to_v.bias": "first_stage_model.encoder.mid.attn_1.v.bias",
        }

        convert_table_weight = {
            "first_stage_model.decoder.mid.attn_1.to_k.weight": "first_stage_model.decoder.mid.attn_1.k.weight",
            "first_stage_model.decoder.mid.attn_1.to_out.0.weight": "first_stage_model.decoder.mid.attn_1.proj_out.weight",
            "first_stage_model.decoder.mid.attn_1.to_q.weight": "first_stage_model.decoder.mid.attn_1.q.weight",
            "first_stage_model.decoder.mid.attn_1.to_v.weight": "first_stage_model.decoder.mid.attn_1.v.weight",
            "first_stage_model.encoder.mid.attn_1.to_k.weight": "first_stage_model.encoder.mid.attn_1.k.weight",
            "first_stage_model.encoder.mid.attn_1.to_out.0.weight": "first_stage_model.encoder.mid.attn_1.proj_out.weight",
            "first_stage_model.encoder.mid.attn_1.to_q.weight": "first_stage_model.encoder.mid.attn_1.q.weight",
            "first_stage_model.encoder.mid.attn_1.to_v.weight": "first_stage_model.encoder.mid.attn_1.v.weight",
        }
        for a in list(loaded.keys()):
            if a in convert_table_bias:
                new_key = convert_table_bias[a]
                loaded[new_key] = loaded.pop(a)
            elif a in convert_table_weight:
                new_key = convert_table_weight[a]
                item = loaded.pop(a)
                if len(item.shape) == 2:
                    # Linear-style weights are (out, in); the legacy VAE attention layers are
                    # 1x1 convolutions, so expand to (out, in, 1, 1).
                    item = item.unsqueeze(dim=-1).unsqueeze(dim=-1)
                loaded[new_key] = item

        new_path = str(checkpoint.parent / checkpoint.stem) + "_fixed" + checkpoint.suffix

        logger.info(f"Saving file to {new_path}")
        save_file(loaded, Path(new_path))


def checkpoint_to_pipeline(
    checkpoint: Path,
    target_dir: Optional[Path] = None,
    save: bool = True,
) -> tuple[StableDiffusionPipeline, Path]:
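    """Convert a single-file checkpoint into a diffusers pipeline; returns the pipeline and its save directory."""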
    logger.debug(f"Converting checkpoint {path_from_cwd(checkpoint)}")
    if target_dir is None:
        target_dir = pipeline_dir.joinpath(checkpoint.stem)

    pipeline = StableDiffusionPipeline.from_single_file(
        pretrained_model_link_or_path=str(checkpoint.absolute()),
        local_files_only=False,
        load_safety_checker=False,
    )

    if save:
        target_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)
    return pipeline, target_dir


def checkpoint_to_pipeline_sdxl(
    checkpoint: Path,
    target_dir: Optional[Path] = None,
    save: bool = True,
) -> tuple[StableDiffusionXLPipeline, Path]:
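    """Convert a single-file SDXL checkpoint into a diffusers pipeline; returns the pipeline and its save directory."""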
    logger.debug(f"Converting checkpoint {path_from_cwd(checkpoint)}")
    if target_dir is None:
        target_dir = pipeline_dir.joinpath(checkpoint.stem)

    pipeline = StableDiffusionXLPipeline.from_single_file(
        pretrained_model_link_or_path=str(checkpoint.absolute()),
        local_files_only=False,
        load_safety_checker=False,
    )

    if save:
        target_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)
    return pipeline, target_dir


def get_checkpoint_weights(checkpoint: Path):
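    """Extract the UNet, text encoder, and VAE state dicts from a single-file checkpoint."""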
    temp_pipeline: StableDiffusionPipeline
    temp_pipeline, _ = checkpoint_to_pipeline(checkpoint, save=False)
    unet_state_dict = temp_pipeline.unet.state_dict()
    tenc_state_dict = temp_pipeline.text_encoder.state_dict()
    vae_state_dict = temp_pipeline.vae.state_dict()
    return unet_state_dict, tenc_state_dict, vae_state_dict


def get_checkpoint_weights_sdxl(checkpoint: Path):
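    """Extract the UNet, both text encoders, and VAE state dicts from a single-file SDXL checkpoint."""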
    temp_pipeline: StableDiffusionXLPipeline
    temp_pipeline, _ = checkpoint_to_pipeline_sdxl(checkpoint, save=False)
    unet_state_dict = temp_pipeline.unet.state_dict()
    tenc_state_dict = temp_pipeline.text_encoder.state_dict()
    tenc2_state_dict = temp_pipeline.text_encoder_2.state_dict()
    vae_state_dict = temp_pipeline.vae.state_dict()
    return unet_state_dict, tenc_state_dict, tenc2_state_dict, vae_state_dict


def ensure_motion_modules(
    repo_id: str = HF_MODULE_REPO,
    fp16: bool = False,
    force: bool = False,
):
    """Retrieve the motion modules from the HuggingFace Hub."""
    module_files = ["mm_sd_v14.safetensors", "mm_sd_v15.safetensors"]
    module_dir = get_dir("data/models/motion-module")
    for file in module_files:
        target_path = module_dir.joinpath(file)
        if fp16:
            target_path = target_path.with_suffix(".fp16.safetensors")
        if target_path.exists() and not force:
            logger.debug(f"File {path_from_cwd(target_path)} already exists, skipping download")
        else:
            result = hf_hub_download(
                repo_id=repo_id,
                filename=target_path.name,
                cache_dir=HF_HUB_CACHE,
                local_dir=module_dir,
                local_dir_use_symlinks=False,
                resume_download=True,
            )
            logger.debug(f"Downloaded {path_from_cwd(result)}")