Spaces:
Build error
Build error
File size: 2,209 Bytes
11c2c17 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import os, sys
sys.path.insert(0, os.getcwd())
import argparse
def get_args(argv=None):
    """Parse command-line options for merging a LyCORIS model into a base model.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; ``None`` (the default) keeps the normal CLI
            behaviour.

    Returns:
        argparse.Namespace with ``base_model``, ``lycoris_model``,
        ``output_name``, ``is_v2``, ``device``, ``dtype`` and ``weight``.
    """
    parser = argparse.ArgumentParser()
    # Positional arguments without ``nargs='?'`` are always required, so a
    # ``default=`` on them would be silently ignored; none is given here.
    parser.add_argument(
        "base_model", help="The model you want to merge with loha",
        type=str
    )
    parser.add_argument(
        "lycoris_model", help="the lyco model you want to merge into sd model",
        type=str
    )
    # nargs='?' makes the output path optional so its default is actually used.
    parser.add_argument(
        "output_name", help="the output model",
        nargs='?', default='./out.pt', type=str
    )
    parser.add_argument(
        "--is_v2", help="Your base model is sd v2 or not",
        default=False, action="store_true"
    )
    parser.add_argument(
        "--device", help="Which device you want to use to merge the weight",
        default='cpu', type=str
    )
    parser.add_argument(
        "--dtype", help='dtype to save',
        default='float', type=str
    )
    parser.add_argument(
        "--weight", help='weight for the lyco model to merge',
        default=1.0, type=float
    )
    return parser.parse_args(argv)
# Parse CLI arguments at import time, before the lycoris/torch imports below;
# presumably so `--help` and usage errors exit without loading heavy modules.
# Note: this means merely importing this module parses sys.argv.
ARGS = get_args()
from lycoris_utils import merge
from lycoris.kohya_model_utils import (
load_models_from_stable_diffusion_checkpoint,
save_stable_diffusion_checkpoint,
load_file
)
import torch
# Exact ``--dtype`` names understood by the saver. Shorthands such as
# ``fp16`` / ``bf16`` are handled by the fallback in ``_resolve_dtype``.
_DTYPES = {
    'float': torch.float,
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64,
    'bfloat': torch.bfloat16,
    'bfloat16': torch.bfloat16,
}


def _resolve_dtype(name):
    """Map a ``--dtype`` string to a ``torch.dtype``; raise ValueError if unknown."""
    dtype = _DTYPES.get(name)
    if dtype is None:
        # Legacy shorthand normalisation (fp16 -> float16, bf16 -> bfloat16).
        # Only applied after the exact lookup fails, because on full names it
        # corrupts the input ('bfloat16' would become 'bfloatloat16').
        dtype = _DTYPES.get(name.replace('fp', 'float').replace('bf', 'bfloat'))
    if dtype is None:
        # Report what the user typed (the old message always printed "None").
        raise ValueError(f'Cannot Find the dtype "{name}"')
    return dtype


def _load_lycoris(path):
    """Load LyCORIS weights from a ``.safetensors`` file or a torch pickle file."""
    if path.rsplit('.', 1)[-1] == 'safetensors':
        return load_file(path)
    # NOTE(review): torch.load unpickles arbitrary objects and can execute
    # code; only load files from trusted sources.
    return torch.load(path)


def main():
    """Merge ARGS.lycoris_model into ARGS.base_model and save it to ARGS.output_name.

    Raises:
        ValueError: if ``--dtype`` is not a recognised dtype name.
    """
    # Validate --dtype before the slow model loads so a typo fails fast.
    dtype = _resolve_dtype(ARGS.dtype)
    base = load_models_from_stable_diffusion_checkpoint(ARGS.is_v2, ARGS.base_model)
    lyco = _load_lycoris(ARGS.lycoris_model)
    # ``base`` itself is what gets saved afterwards, so merge() presumably
    # updates its modules in place — verify against lycoris_utils.merge.
    merge(
        base,
        lyco,
        ARGS.weight,
        ARGS.device
    )
    # NOTE(review): element order of ``base`` assumed to be
    # (text_encoder, vae, unet) as in kohya's model_util — confirm.
    save_stable_diffusion_checkpoint(
        ARGS.is_v2, ARGS.output_name,
        base[0], base[2],
        None, 0, 0, dtype,
        base[1]
    )
# Script entry point; the stray " |" scrape residue after main() was a syntax error.
if __name__ == '__main__':
    main()