Spaces:
Configuration error
from typing import Literal, Union, Dict | |
import fire | |
from diffusers import StableDiffusionPipeline | |
import torch | |
from .lora import tune_lora_scale, weight_apply_lora | |
def add(
    path_1: str,
    path_2: str,
    output_path: str = "./merged_lora.pt",
    alpha: float = 0.5,
    mode: Literal["lpl", "upl"] = "lpl",
):
    """Merge two LoRA weight files, or bake a LoRA into a diffusers pipeline.

    Modes:
        * ``"lpl"`` (LoRA + LoRA): load two sequences of tensors with
          ``torch.load`` and interpolate them element-wise as
          ``alpha * path_1 + (1 - alpha) * path_2``. The tensors from
          ``path_1`` are updated in place and the merged list is written to
          ``output_path`` with ``torch.save``.
        * ``"upl"`` (UNet + LoRA): load the ``StableDiffusionPipeline`` at
          ``path_1`` on CPU, apply the LoRA weights from ``path_2`` to its
          UNet via ``weight_apply_lora``, and write the whole pipeline with
          ``save_pretrained``. A trailing ``.pt`` is stripped from
          ``output_path`` because the result is a directory, not one file.

    Args:
        path_1: First LoRA file (``"lpl"``) or pretrained pipeline path /
            identifier passed to ``from_pretrained`` (``"upl"``).
        path_2: LoRA weight file to merge in.
        output_path: Destination file (``"lpl"``) or directory (``"upl"``).
        alpha: ``"lpl"``: weight given to ``path_1`` (``path_2`` receives
            ``1 - alpha``). ``"upl"``: forwarded unchanged to
            ``weight_apply_lora``.
        mode: ``"lpl"`` or ``"upl"`` as described above.

    Raises:
        ValueError: If ``mode`` is not recognised, or (``"lpl"`` only) if the
            two files hold different numbers of tensors, hold an odd number
            of tensors, or contain tensors of different shapes at the same
            position.
    """
    if mode not in ("lpl", "upl"):
        # Without this check an unknown mode matched neither branch and the
        # command finished "successfully" without writing anything.
        raise ValueError(f"Unknown mode {mode!r}; expected 'lpl' or 'upl'.")

    if mode == "lpl":
        l1 = torch.load(path_1)
        l2 = torch.load(path_2)

        # zip() silently truncates to the shorter input, so mismatched or
        # odd-length files would otherwise yield a partial merge.
        if len(l1) != len(l2):
            raise ValueError(
                f"LoRA files hold different numbers of tensors: "
                f"{len(l1)} in {path_1!r} vs {len(l2)} in {path_2!r}."
            )
        if len(l1) % 2:
            raise ValueError(
                f"Expected an even number of tensors (consecutive pairs), "
                f"got {len(l1)}."
            )

        # The files are laid out as consecutive (even, odd) index pairs.
        # NOTE(review): presumably each pair is a LoRA (up, down) factor
        # pair. Interpolating the factors separately is not the same as
        # interpolating their product.
        # Interpolation itself is element-wise, so walking both lists in
        # lockstep yields exactly the same output order as walking the pairs.
        out_list = []
        for idx, (t1, t2) in enumerate(zip(l1, l2)):
            # Broadcastable but unequal shapes would otherwise silently
            # produce a tensor of the wrong shape.
            if t1.shape != t2.shape:
                raise ValueError(
                    f"Shape mismatch at tensor {idx}: "
                    f"{tuple(t1.shape)} vs {tuple(t2.shape)}."
                )
            t1.data = alpha * t1.data + (1 - alpha) * t2.data
            out_list.append(t1)

        torch.save(out_list, output_path)
    else:  # mode == "upl"
        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")
        # The pipeline lives on CPU, so load the LoRA there too. Without
        # map_location, tensors saved from a GPU run are restored onto that
        # GPU (device mismatch) or fail to load on a CPU-only machine.
        lora_weights = torch.load(path_2, map_location="cpu")
        weight_apply_lora(loaded_pipeline.unet, lora_weights, alpha=alpha)

        if output_path.endswith(".pt"):
            output_path = output_path[:-3]
        loaded_pipeline.save_pretrained(output_path)
def main():
    """Command-line entry point: expose ``add`` as a python-fire CLI."""
    fire.Fire(component=add)