Spaces:
Build error
Build error
# Split a DyLoRA model into separate LoRA models of lower rank (output ranks are always lower than the original rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
import argparse | |
import math | |
import os | |
import torch | |
from safetensors.torch import load_file, save_file, safe_open | |
from tqdm import tqdm | |
from library import train_util, model_util | |
import numpy as np | |
def load_state_dict(file_name):
    """Load a model state dict and, when available, its metadata.

    Args:
        file_name: Path of the checkpoint to read.

    Returns:
        A ``(state_dict, metadata)`` tuple. ``metadata`` is whatever the
        safetensors header reports for safetensors files, and ``None`` for
        every other file type.
    """
    if not model_util.is_safetensors(file_name):
        # NOTE(review): depending on the installed torch version's defaults,
        # torch.load may unpickle arbitrary objects -- only load checkpoints
        # from trusted sources.
        return torch.load(file_name, map_location="cpu"), None

    tensors = load_file(file_name)
    # load_file returns only the tensors, so the file is opened a second time
    # to read the metadata.
    with safe_open(file_name, framework="pt") as handle:
        header_metadata = handle.metadata()
    return tensors, header_metadata
def save_to_file(file_name, model, metadata):
    """Write a state dict to disk in the format implied by ``file_name``.

    Args:
        file_name: Destination path.
        model: State dict to write.
        metadata: Metadata passed to the safetensors writer; it is not used
            when the destination is not a safetensors file.
    """
    if not model_util.is_safetensors(file_name):
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)
def split_lora_model(lora_sd, unit):
    """Slice a LoRA state dict into a series of lower-rank state dicts.

    The maximum rank is the largest first-dimension size among all tensors
    whose key contains ``"lora_down"``. For every rank ``unit, 2 * unit, ...``
    strictly below that maximum, a new state dict is built in which

    * ``lora_down`` tensors keep their first ``rank`` rows,
    * ``lora_up`` tensors keep their first ``rank`` columns,
    * every other entry (e.g. alpha) is carried over unchanged.

    Args:
        lora_sd: Mapping of parameter name to tensor.
        unit: Rank step; must be a positive integer.

    Returns:
        ``(max_rank, split_models)`` where ``split_models`` is a list of
        ``(state_dict, rank, new_alpha)`` tuples. ``new_alpha`` is always
        ``None`` because alpha is not rescaled.

    Raises:
        ValueError: If ``unit`` is ``None`` or smaller than 1.
    """
    # A missing unit used to fail with an opaque TypeError, and a zero or
    # negative unit made the splitting loop run forever; reject both up front.
    if unit is None or unit < 1:
        raise ValueError(f"unit must be a positive integer, got {unit!r}")

    # The rank of a LoRA module is the first dimension of its down weight.
    max_rank = max(
        (value.size()[0] for key, value in lora_sd.items() if "lora_down" in key),
        default=0,
    )
    print(f"Max rank: {max_rank}")

    split_models = []
    new_alpha = None
    for rank in range(unit, max_rank, unit):
        print(f"Splitting rank {rank}")
        new_sd = {}
        for key, value in lora_sd.items():
            if "lora_down" in key:
                new_sd[key] = value[:rank].contiguous()
            elif "lora_up" in key:
                new_sd[key] = value[:, :rank].contiguous()
            else:
                # Remaining entries are copied as-is. Rescaling alpha by
                # sqrt(original_rank / rank) was tried by the original author
                # and, for reasons not understood, gave wrong results.
                new_sd[key] = value

        split_models.append((new_sd, rank, new_alpha))

    return max_rank, split_models
def split(args):
    """Split the model at ``args.model`` into lower-rank files.

    Loads the model, slices it with ``split_lora_model`` in steps of
    ``args.unit``, and writes one file per resulting rank. Each output name is
    ``args.save_to`` with ``-<rank, zero-padded to 4 digits>`` inserted before
    the extension.

    Args:
        args: Namespace providing ``model``, ``unit`` and ``save_to``.
    """
    print("loading Model...")
    lora_sd, metadata = load_state_dict(args.model)

    print("Splitting Model...")
    original_rank, split_models = split_lora_model(lora_sd, args.unit)

    # load_state_dict returns None metadata for non-safetensors input, so
    # normalise it before reading from it (previously this crashed on .get()).
    base_metadata = {} if metadata is None else metadata
    comment = base_metadata.get("ss_training_comment", "")

    for state_dict, new_rank, _new_alpha in split_models:
        # Start every output from a fresh copy so files do not leak values
        # into each other and the loaded metadata is never mutated.
        new_metadata = dict(base_metadata)

        new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
        new_metadata["ss_network_dim"] = str(new_rank)

        # Hash with the metadata that is actually saved, and store the hashes
        # there. Previously they were computed from, and written to, the
        # original metadata, so the saved file never got its own hashes.
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
        new_metadata["sshs_model_hash"] = model_hash
        new_metadata["sshs_legacy_hash"] = legacy_hash

        filename, ext = os.path.splitext(args.save_to)
        model_file_name = filename + f"-{new_rank:04d}{ext}"

        print(f"saving model to: {model_file_name}")
        save_to_file(model_file_name, state_dict, new_metadata)
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the DyLoRA splitting script.

    Returns:
        A parser accepting ``--unit`` (int), ``--save_to`` (str) and
        ``--model`` (str). All three are optional and default to ``None``.
    """
    # (flag, value type, help text) for every supported option.
    option_specs = (
        ("--unit", int, "size of rank to split into / rankを分割するサイズ"),
        (
            "--save_to",
            str,
            "destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
        ),
        (
            "--model",
            str,
            "DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, description in option_specs:
        cli.add_argument(flag, type=value_type, default=None, help=description)
    return cli
if __name__ == "__main__":
    # Script entry point: parse the command line and run the split.
    cli_args = setup_parser().parse_args()
    split(cli_args)