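"""CLI for merging and applying LoRA weights, exposed via `fire`.

Three modes are supported: "lpl" interpolates two LoRA .pt files, "upl"
bakes a LoRA into a full Stable Diffusion pipeline, and "upl-ckpt-v2"
additionally converts the result to a single .ckpt file. The CLI flags
mirror the signature of `add()` below, e.g. (module path illustrative):

    python -m lora_diffusion.cli_lora_add ./a.pt ./b.pt ./merged.pt --alpha=0.5 --mode=lpl
"""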
from typing import Literal
import os
import shutil
import fire
from diffusers import StableDiffusionPipeline

import torch
from .lora import weight_apply_lora
from .to_ckpt_v2 import convert_to_ckpt


def _text_lora_path(path: str) -> str:
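    """Derive the companion text-encoder LoRA path: "foo.pt" -> "foo.text_encoder.pt"."""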
    assert path.endswith(".pt"), "Only .pt files are supported"
    return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])


def add(
    path_1: str,
    path_2: str,
    output_path: str,
    alpha: float = 0.5,
    mode: Literal[
        "lpl",
        "upl",
        "upl-ckpt-v2",
    ] = "lpl",
    with_text_lora: bool = False,
):
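    """Merge or apply LoRA weights.

    Modes:
        "lpl": interpolate two LoRA .pt files as
            `alpha * path_1 + (1 - alpha) * path_2` and save to output_path.
        "upl": apply the LoRA at path_2 (scaled by alpha) to the pipeline at
            path_1 and save the result in diffusers format.
        "upl-ckpt-v2": like "upl", but convert the result to a single .ckpt.

    If with_text_lora is set, the companion `.text_encoder.pt` files are
    merged or applied as well.
    """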
    print(f"LoRA add, mode: {mode}")
    if mode == "lpl":
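        # Merge the UNet LoRA pair, plus the text-encoder pair when requested.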
        for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
            [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
            if with_text_lora
            else []
        ):
            print("Loading", _path_1, _path_2)
            out_list = []
            if opt == "text_encoder":
                if not os.path.exists(_path_1):
                    print(f"No text encoder found in {_path_1}, skipping...")
                    continue
                if not os.path.exists(_path_2):
                    print(f"No text encoder found in {_path_2}, skipping...")
                    continue

            l1 = torch.load(_path_1)
            l2 = torch.load(_path_2)

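            # Each saved list alternates the two weight tensors of a LoRA
            # layer, so zip even/odd entries to recover the pairs.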
            l1pairs = zip(l1[::2], l1[1::2])
            l2pairs = zip(l2[::2], l2[1::2])

            for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
                # Interpolate both tensors of the pair: alpha weights the
                # first LoRA, (1 - alpha) the second.
                x1.data = alpha * x1.data + (1 - alpha) * x2.data
                y1.data = alpha * y1.data + (1 - alpha) * y2.data

                out_list.append(x1)
                out_list.append(y1)

            if opt == "unet":
                print("Saving merged UNET to", output_path)
                torch.save(out_list, output_path)

            elif opt == "text_encoder":
                print("Saving merged text encoder to", _text_lora_path(output_path))
                torch.save(
                    out_list,
                    _text_lora_path(output_path),
                )

    elif mode == "upl":
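        # Load the full pipeline on CPU, bake the LoRA residuals directly
        # into its weights, and save in diffusers format.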
        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
        if with_text_lora:
            weight_apply_lora(
                loaded_pipeline.text_encoder,
                torch.load(_text_lora_path(path_2)),
                alpha=alpha,
                target_replace_module=["CLIPAttention"],
            )

        loaded_pipeline.save_pretrained(output_path)

    elif mode == "upl-ckpt-v2":
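        # Same merge as "upl", but export a single .ckpt via a temporary
        # diffusers folder.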
        loaded_pipeline = StableDiffusionPipeline.from_pretrained(
            path_1,
        ).to("cpu")

        weight_apply_lora(loaded_pipeline.unet, torch.load(path_2), alpha=alpha)
        if with_text_lora:
            weight_apply_lora(
                loaded_pipeline.text_encoder,
                torch.load(_text_lora_path(path_2)),
                alpha=alpha,
                target_replace_module=["CLIPAttention"],
            )

        _tmp_output = output_path + ".tmp"

        loaded_pipeline.save_pretrained(_tmp_output)
        convert_to_ckpt(_tmp_output, output_path, as_half=True)
        # remove the tmp_output folder
        shutil.rmtree(_tmp_output)

    else:
        raise ValueError(f"Unknown mode {mode}")


def main():
    fire.Fire(add)
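

# Allow running this module directly, in addition to any installed entry point.
if __name__ == "__main__":
    main()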