3DFuse / lora_util.py
jyseo's picture
first commit
d661b19
from lora_diffusion.cli_lora_add import *
from lora_diffusion.lora import *
from lora_diffusion.to_ckpt_v2 import *
def monkeypatch_or_replace_safeloras(models, safeloras):
    """Apply every LoRA found in ``safeloras`` to the matching model on ``models``.

    ``safeloras`` is handed to ``parse_safeloras``; each entry name it returns
    is looked up as an attribute of ``models``. Entries whose attribute is
    missing or ``None`` are reported on stdout and skipped; every other entry
    is forwarded to ``monkeypatch_or_replace_lora_extended``.

    Args:
        models: Object exposing the models to patch as attributes named after
            the LoRA entries.
        safeloras: LoRA container accepted by ``parse_safeloras``.
    """
    for name, (lora, ranks, target) in parse_safeloras(safeloras).items():
        model = getattr(models, name, None)

        # Compare against None explicitly: truth-testing would also skip a
        # model that is present but whose class defines __len__/__bool__ and
        # happens to evaluate as empty/false.
        if model is None:
            print(f"No model provided for {name}, contained in Lora")
            continue

        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
def parse_safeloras(
    safeloras,
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
    """
    Turn an in-memory Lora container into per-module Parameters.

    ``safeloras`` is indexed with two keys: ``'metadata'`` and ``'weights'``.
    Weight keys have the form ``"<module>:<index>:<direction>"``. For every
    module, ``metadata[<module>]`` is either ``EMBED_FLAG`` (the module is
    skipped) or a JSON document describing the replacement targets, and
    ``metadata["<module>:<index>:rank"]`` holds the rank of each pair.

    Output is a dictionary of {
        "module name": (
            [list of weights],
            [list of ranks],
            target_replacement_modules
        )
    }

    In the weight list, slot ``2 * index + 1`` receives the tensor whose
    direction is ``"down"`` and slot ``2 * index`` receives the other one.

    Raises:
        ValueError: if a module has no (or an empty) metadata entry.
    """

    def module_of(tensor_key):
        # Everything before the first ':' identifies the module.
        return tensor_key.split(":")[0]

    metadata = safeloras["metadata"]
    tensors = safeloras["weights"]

    parsed = {}
    # groupby only merges adjacent items, hence the sort on the same key.
    ordered_keys = sorted(tensors.keys(), key=module_of)
    for name, grouped in groupby(ordered_keys, module_of):
        info = metadata.get(name)

        if not info:
            raise ValueError(
                f"Tensor {name} has no metadata - is this a Lora safetensor?"
            )

        # Textual Inversion embeds are not Loras; leave them out.
        if info == EMBED_FLAG:
            continue

        target = json.loads(info)

        # Slots are filled by index below, so both lists are sized up front.
        member_keys = list(grouped)
        ranks = [4] * (len(member_keys) // 2)
        weights = [None] * len(member_keys)

        for key in member_keys:
            _, raw_index, direction = key.split(":")
            pair = int(raw_index)

            ranks[pair] = int(metadata[f"{name}:{pair}:rank"])

            slot = 2 * pair + int(direction == "down")
            weights[slot] = nn.parameter.Parameter(tensors[key])

        parsed[name] = (weights, ranks, target)

    return parsed
def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """
    Collect the Textual Inversion embeds stored in an in-memory Lora container.

    ``safeloras`` is indexed with ``'metadata'`` and ``'weights'``. A weight is
    kept only when its key has a non-empty metadata entry equal to
    ``EMBED_FLAG``; the result maps each such key (the embed token) to its
    stored value.
    """
    metadata = safeloras["metadata"]
    tensors = safeloras["weights"]

    def is_embed(key):
        # Missing and empty metadata entries are both treated as "not an embed".
        flag = metadata[key] if key in metadata else None
        return bool(flag) and flag == EMBED_FLAG

    return {key: tensors[key] for key in tensors.keys() if is_embed(key)}
def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    """
    Patch ``pipe`` with the Loras (and optionally the embeds) in a Lora container.

    ``maybe_unet_path`` is passed unchanged to
    ``monkeypatch_or_replace_safeloras`` and ``parse_safeloras_embeds``, so it
    must be whatever those functions accept. When ``patch_ti`` is true the
    parsed embeds are additionally installed into ``pipe.text_encoder`` /
    ``pipe.tokenizer`` via ``apply_learned_embed_in_clip``.

    ``r``, ``patch_unet``, ``patch_text``, ``unet_target_replace_module`` and
    ``text_target_replace_module`` are accepted but never read by this body.

    Returns:
        The embed dictionary produced by ``parse_safeloras_embeds`` (returned
        whether or not ``patch_ti`` is set).
    """
    monkeypatch_or_replace_safeloras(pipe, maybe_unet_path)
    learned_embeds = parse_safeloras_embeds(maybe_unet_path)

    if not patch_ti:
        return learned_embeds

    apply_learned_embed_in_clip(
        learned_embeds,
        pipe.text_encoder,
        pipe.tokenizer,
        token=token,
        idempotent=idempotent_token,
    )
    return learned_embeds
def lora_convert(model_path, as_half):
    """
    Modified version of lora_diffusion.to_ckpt_v2.convert_to_ckpt

    Loads the UNet, VAE and text-encoder ``.bin`` weights found under
    ``model_path``, runs each through its ``convert_*_state_dict`` function,
    prefixes the keys and merges everything into one state dict, which is
    returned instead of being written to disk.

    Args:
        model_path: Directory containing ``unet/``, ``vae/`` and
            ``text_encoder/`` sub-folders with their weight files.
        as_half: When truthy, every value of the merged dict is replaced by
            the result of its ``.half()`` call.

    Returns:
        The merged state dict.
    """
    assert model_path is not None, "Must provide a model path!"

    def load_component(subfolder, filename, convert, prefix):
        # Load one sub-model on CPU, convert it, then namespace its keys.
        raw = torch.load(osp.join(model_path, subfolder, filename), map_location="cpu")
        return {prefix + name: tensor for name, tensor in convert(raw).items()}

    # Same order as the merged checkpoint: UNet, then VAE, then text encoder.
    merged = {}
    merged.update(
        load_component(
            "unet",
            "diffusion_pytorch_model.bin",
            convert_unet_state_dict,
            "model.diffusion_model.",
        )
    )
    merged.update(
        load_component(
            "vae",
            "diffusion_pytorch_model.bin",
            convert_vae_state_dict,
            "first_stage_model.",
        )
    )
    merged.update(
        load_component(
            "text_encoder",
            "pytorch_model.bin",
            convert_text_enc_state_dict,
            "cond_stage_model.transformer.",
        )
    )

    if not as_half:
        return merged
    return {name: tensor.half() for name, tensor in merged.items()}
def merge(path_1: str,
          path_2: str,
          alpha_1: float = 0.5,
          ):
    """
    Merge a Lora into a Stable Diffusion pipeline and return checkpoint dicts.

    The pipeline is loaded from ``path_1`` on CPU, patched through
    ``patch_pipe`` (with ``patch_ti=False``), the Lora is collapsed into the
    UNet and text encoder with weight ``alpha_1`` and the Lora hooks are
    removed. The pipeline is then saved to a scratch directory so that
    ``lora_convert`` can rebuild a single half-precision state dict from it.

    Args:
        path_1: Argument for ``StableDiffusionPipeline.from_pretrained``.
        path_2: Forwarded unchanged to ``patch_pipe`` as ``maybe_unet_path``.
            NOTE(review): annotated ``str``, but in this file ``patch_pipe``
            hands it to parsers that index it with ``'metadata'`` and
            ``'weights'`` - callers presumably pass that mapping; confirm.
        alpha_1: Weight passed to ``collapse_lora``.

    Returns:
        ``(state_dict, ret)`` where ``state_dict`` is the output of
        ``lora_convert`` and ``ret`` is a dict carrying the stacked embeds
        returned by ``patch_pipe`` under ``"string_to_param"``.
    """
    # Function-scope import so the block does not depend on the star imports
    # at the top of the file re-exporting ``tempfile``.
    import tempfile

    loaded_pipeline = StableDiffusionPipeline.from_pretrained(
        path_1,
    ).to("cpu")

    tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)

    collapse_lora(loaded_pipeline.unet, alpha_1)
    collapse_lora(loaded_pipeline.text_encoder, alpha_1)

    monkeypatch_remove_lora(loaded_pipeline.unet)
    monkeypatch_remove_lora(loaded_pipeline.text_encoder)

    # A uniquely named scratch directory (still created in the working
    # directory, like the former fixed "./merge.tmp") keeps concurrent calls
    # from clobbering each other, and the context manager removes it even
    # when saving or converting raises.
    with tempfile.TemporaryDirectory(
        prefix="merge.", suffix=".tmp", dir="."
    ) as tmp_output:
        loaded_pipeline.save_pretrained(tmp_output)
        state_dict = lora_convert(tmp_output, as_half=True)

    # Stack the embeds in sorted-token order so the layout is deterministic.
    keys = sorted(tok_dict.keys())
    tok_catted = torch.stack([tok_dict[k] for k in keys])
    ret = {
        # 265 is hard-coded here; presumably the tokenizer id of "*" - TODO confirm.
        "string_to_token": {"*": torch.tensor(265)},
        "string_to_param": {"*": tok_catted},
        "name": "",
    }

    return state_dict, ret