File size: 2,295 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from comfy.model_detection import count_blocks
from comfy.utils import flux_to_diffusers

def _merge_lora_pairs(state_dict):
    """Collapse every ``lora_A``/``lora_B`` pair into one dense delta ``B @ A``.

    The result is keyed by the ``lora_A`` key with its ``lora_A.`` component
    removed (``...to_q.lora_A.weight`` -> ``...to_q.weight``). Keys that do
    not contain ``lora_A`` are ignored. No alpha/rank scaling is applied.

    Raises:
        KeyError: if a ``lora_A`` tensor has no matching ``lora_B`` tensor.
    """
    deltas = {}
    for lora_a in state_dict:
        if "lora_A" not in lora_a:
            continue
        lora_b = lora_a.replace("lora_A", "lora_B")
        if lora_b not in state_dict:
            # Name both keys so a malformed checkpoint is easy to diagnose.
            raise KeyError(f"Invalid LoRA checkpoint. {lora_a} {lora_b}")
        deltas[lora_a.replace("lora_A.", "")] = state_dict[lora_b] @ state_dict[lora_a]
    return deltas


def _infer_hidden_size(deltas):
    """Return the hidden size implied by the merged LoRA deltas.

    Prefers ``x_embedder.weight``. If that key is absent, falls back to any
    ``attn.to_q.weight`` delta: its row count is the length of the query slice
    inside the fused qkv tensor, which is ``hidden_size``.

    Raises:
        KeyError: if neither kind of key is present.
    """
    embedder = deltas.get("x_embedder.weight")
    if embedder is not None:
        return embedder.shape[0]
    for key, value in deltas.items():
        if key.endswith(".attn.to_q.weight"):
            return value.shape[0]
    raise KeyError(
        "Cannot infer hidden_size: LoRA contains neither x_embedder.weight "
        "nor any attn.to_q.weight")


def _assign_fused_slice(out_sd, target, offset, template, value):
    """Write ``value`` into slice ``offset`` of the fused tensor ``out_sd[target]``.

    ``offset`` is ``(dim, start, length)``. The fused tensor is created from
    ``template`` and grown along ``dim`` on demand. It is zero-filled, so
    slices this LoRA does not provide are zero rather than uninitialised
    memory. Its size along ``dim`` is only as large as the furthest slice
    written so far.
    """
    dim, start, length = offset
    fused = out_sd.get(target)
    if fused is None:
        fused = torch.zeros_like(template)
    needed = start + length
    if fused.shape[dim] < needed:
        shape = list(template.shape)
        shape[dim] = needed
        grown = torch.zeros(shape, device=template.device, dtype=template.dtype)
        # Keep previously written slices, copying along the offset dimension.
        grown.narrow(dim, 0, fused.shape[dim]).copy_(fused)
        fused = grown
    fused.narrow(dim, start, length).copy_(value)
    out_sd[target] = fused


def convert_diffusers_flux_lora(state_dict, output_prefix=""):
    """Convert a diffusers-format Flux LoRA into dense weight deltas.

    Each ``lora_A``/``lora_B`` pair is multiplied out into a full delta
    (``B @ A``). The deltas are then renamed and, where the mapping requires,
    packed into fused tensors, using the key map returned by
    ``flux_to_diffusers``.

    Args:
        state_dict: mapping of LoRA key -> tensor. Every ``transformer.``
            occurrence in a key is stripped first.
        output_prefix: forwarded to ``flux_to_diffusers`` to prefix output keys.

    Returns:
        dict mapping each target key from ``flux_to_diffusers`` to its delta
        tensor. Diffusers keys with no entry in that map are dropped.

    Raises:
        KeyError: if a ``lora_B`` tensor is missing, or the hidden size cannot
            be inferred.
    """
    state_dict = {key.replace('transformer.', ''): value for key, value in state_dict.items()}
    deltas = _merge_lora_pairs(state_dict)

    depth = count_blocks(deltas, 'transformer_blocks.{}.')
    depth_single_blocks = count_blocks(deltas, 'single_transformer_blocks.{}.')
    hidden_size = _infer_hidden_size(deltas)
    sd_map = flux_to_diffusers(
        {"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size},
        output_prefix=output_prefix,
    )

    out_sd = {}
    for diffusers_key, target in sd_map.items():
        weight = deltas.pop(diffusers_key, None)
        if weight is None:
            continue
        if isinstance(target, str):
            # Plain rename.
            out_sd[target] = weight
            continue
        # Tuple entry: (target_key, offset_or_None[, transform]).
        name, offset = target[0], target[1]
        fun = target[2] if len(target) > 2 else (lambda a: a)
        value = fun(weight)
        if offset is None:
            out_sd[name] = value
        else:
            _assign_fused_slice(out_sd, name, offset, weight, value)
    # Anything still left in `deltas` had no mapping and is intentionally dropped.
    return out_sd