from safetensors.torch import load_file
import torch
from tqdm import tqdm

__all__ = [
    'flux_load_lora'
]


def is_int(d):
    """Return True if `d` parses as an integer (used for ModuleList indices)."""
    try:
        int(d)
        return True
    except (TypeError, ValueError):
        return False


def flux_load_lora(self, lora_file, lora_weight=1.0):
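    """Manually merge a LoRA checkpoint into this Flux pipeline's weights.

    The transformer (DiT) deltas are read via ``self.lora_state_dict`` using the
    diffusers-style ``transformer.*.lora_A/lora_B`` keys; the text-encoder deltas
    are read straight from the safetensors file using the ``lora_te1_*`` /
    ``lora_up`` / ``lora_down`` / ``alpha`` naming. ``lora_weight`` scales the
    merged deltas.
    """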
    device = self.transformer.device

    # DiT (transformer) part
    state_dict, network_alphas = self.lora_state_dict(lora_file, return_alphas=True)
    state_dict = {k:v.to(device) for k,v in state_dict.items()}
    
    model = self.transformer
    keys = list(state_dict.keys())
    keys = [k for k in keys if k.startswith('transformer.')]

    for k_lora in tqdm(keys, total=len(keys), desc=f"loading lora in transformer ..."):
        v_lora = state_dict[k_lora]

        # Only process lora_B (up) weights; the matching lora_A weight is looked up by key below
        if '.lora_A.weight' in k_lora:
            continue
        if '.alpha' in k_lora:
            continue

        k_lora_name = k_lora.replace("transformer.", "")
        k_lora_name = k_lora_name.replace(".lora_B.weight", "")
        attr_name_list = k_lora_name.split('.')

        cur_attr = model
        latest_attr_name = ''
        for attr_name in attr_name_list:
            if is_int(attr_name):
                cur_attr = cur_attr[int(attr_name)]
                latest_attr_name = ''
            else:
                try:
                    if latest_attr_name != '':
                        cur_attr = cur_attr.__getattr__(f"{latest_attr_name}.{attr_name}")
                    else:
                        cur_attr = cur_attr.__getattr__(attr_name)
                    latest_attr_name = ''
                except AttributeError:
                    if latest_attr_name != '':
                        latest_attr_name = f"{latest_attr_name}.{attr_name}"
                    else:
                        latest_attr_name = attr_name

        up_w = v_lora
        down_w = state_dict[k_lora.replace('.lora_B.weight', '.lora_A.weight')]

        # Merge the delta weight. For 2-D linear weights the einsum reduces to
        # "ij,jk->ik" (up @ down); the extra letters extend it to weights with
        # more dimensions. The product is computed in fp32, then cast back.
        einsum_a = "ijabcdefg"
        einsum_b = "jkabcdefg"
        einsum_res = "ikabcdefg"
        length_shape = len(up_w.shape)
        einsum_str = f"{einsum_a[:length_shape]},{einsum_b[:length_shape]}->{einsum_res[:length_shape]}"
        dtype = cur_attr.weight.data.dtype
        d_w = torch.einsum(einsum_str, up_w.to(torch.float32), down_w.to(torch.float32)).to(dtype)
        cur_attr.weight.data = cur_attr.weight.data + d_w * lora_weight

    # Text encoder part (raw safetensors keys prefixed with 'lora_te1_')
    raw_state_dict = load_file(lora_file)
    raw_state_dict = {k:v.to(device) for k,v in raw_state_dict.items()}

    # text encoder
    state_dict = {k:v for k,v in raw_state_dict.items() if 'lora_te1_' in k}
    model = self.text_encoder
    keys = list(state_dict.keys())
    keys = [k for k in keys if k.startswith('lora_te1_')]

    for k_lora in tqdm(keys, total=len(keys), desc=f"loading lora in text_encoder ..."):
        v_lora = state_dict[k_lora]

        # Only process lora_up weights; lora_down and alpha are looked up by key below
        if '.lora_down.weight' in k_lora:
            continue
        if '.alpha' in k_lora:
            continue

        k_lora_name = k_lora.replace("lora_te1_", "")
        k_lora_name = k_lora_name.replace(".lora_up.weight", "")
        attr_name_list = k_lora_name.split('_')

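        # Walk the module tree. These keys use '_' as the path separator, so single
        # tokens such as 'text' or 'self' are not valid attributes; failed lookups
        # are accumulated and re-joined with '_' until a real submodule name
        # (e.g. 'text_model', 'self_attn', 'q_proj') resolves.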
        cur_attr = model
        latest_attr_name = ''
        for attr_name in attr_name_list:
            if is_int(attr_name):
                cur_attr = cur_attr[int(attr_name)]
                latest_attr_name = ''
            else:
                try:
                    if latest_attr_name != '':
                        cur_attr = cur_attr.__getattr__(f"{latest_attr_name}_{attr_name}")
                    else:
                        cur_attr = cur_attr.__getattr__(attr_name)
                    latest_attr_name = ''
                except AttributeError:
                    if latest_attr_name != '':
                        latest_attr_name = f"{latest_attr_name}_{attr_name}"
                    else:
                        latest_attr_name = attr_name

        up_w = v_lora
        down_w = state_dict[k_lora.replace('.lora_up.weight', '.lora_down.weight')]
        
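        # If the checkpoint stores an '.alpha' for this layer, the effective
        # LoRA scale is alpha / rank, where rank is the inner dimension of up_w.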
        alpha = state_dict.get(k_lora.replace('.lora_up.weight', '.alpha'), None)
        if alpha is None:
            lora_scale = 1
        else:
            rank = up_w.shape[1]
            lora_scale = alpha / rank
        
        # Merge the delta weight, scaled by the checkpoint alpha (if any) and
        # the user-supplied lora_weight; computed in fp32, then cast back.
        einsum_a = "ijabcdefg"
        einsum_b = "jkabcdefg"
        einsum_res = "ikabcdefg"
        length_shape = len(up_w.shape)
        einsum_str = f"{einsum_a[:length_shape]},{einsum_b[:length_shape]}->{einsum_res[:length_shape]}"
        dtype = cur_attr.weight.data.dtype
        d_w = torch.einsum(einsum_str, up_w.to(torch.float32), down_w.to(torch.float32)).to(dtype)
        cur_attr.weight.data = cur_attr.weight.data + d_w * lora_scale * lora_weight
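

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal example of how this loader might be wired up, assuming a diffusers
# FluxPipeline. The model id and the checkpoint path "my_lora.safetensors" are
# placeholders for whatever base model and LoRA file you actually use.
if __name__ == "__main__":
    import types
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Bind flux_load_lora as a method so `self` resolves to the pipeline.
    pipe.flux_load_lora = types.MethodType(flux_load_lora, pipe)
    pipe.flux_load_lora("my_lora.safetensors", lora_weight=0.8)

    image = pipe("a watercolor fox, soft light", num_inference_steps=28).images[0]
    image.save("lora_sample.png")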