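"""LinFusion: replace the self-attention modules of a Stable Diffusion UNet
with generalized linear attention, optionally loading pretrained LinFusion
weights from the Hugging Face Hub."""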
import functools

from diffusers import ConfigMixin, ModelMixin
from diffusers.models.attention_processor import Attention

from .attention import GeneralizedLinearAttention


# Maps supported base diffusion models to the corresponding pretrained
# LinFusion weights on the Hugging Face Hub.
model_dict = {
    "SG161222/Realistic_Vision_V4.0_noVAE": "Yuanshi/LinFusion-1-5",
    "Lykon/dreamshaper-8": "Yuanshi/LinFusion-1-5",
    "CompVis/stable-diffusion-v1-4": "Yuanshi/LinFusion-1-5",
}


def replace_submodule(model, module_name, new_submodule):
    """Replace the submodule at the dotted path `module_name` in `model` with `new_submodule`."""
    path, attr = module_name.rsplit(".", 1)
    parent_module = functools.reduce(getattr, path.split("."), model)
    setattr(parent_module, attr, new_submodule)

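# For illustration (hypothetical module and values), swapping one nested
# attention layer of a UNet:
#
#     new_attn = GeneralizedLinearAttention(query_dim=320, out_dim=320, dim_head=40)
#     replace_submodule(unet, "down_blocks.0.attentions.0.transformer_blocks.0.attn1", new_attn)
#
# `functools.reduce(getattr, ...)` walks the dotted path down to the parent
# module; `setattr` then swaps the final attribute in place.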

class LinFusion(ModelMixin, ConfigMixin):
    def __init__(self, modules_list, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.modules_dict = {}
        self.register_to_config(modules_list=modules_list)

        for i, attention_config in enumerate(modules_list):
            dim_n = attention_config["dim_n"]
            heads = attention_config["heads"]
            projection_mid_dim = attention_config["projection_mid_dim"]
            linear_attention = GeneralizedLinearAttention(
                query_dim=dim_n,
                out_dim=dim_n,
                dim_head=dim_n // heads,
                projection_mid_dim=projection_mid_dim,
            )
            # Register as a child module (so its parameters are tracked) and
            # remember which UNet submodule it will replace at mount time.
            self.add_module(f"{i}", linear_attention)
            self.modules_dict[attention_config["module_name"]] = linear_attention

    @classmethod
    def get_default_config(
        cls,
        pipeline=None,
        unet=None,
    ):
        """
        Get the default configuration for the LinFusion model.
        (The `projection_mid_dim` is same as the `query_dim` by default.)
        """
        assert unet is not None or (pipeline is not None and pipeline.unet is not None)
        unet = unet or pipeline.unet
        modules_list = []
        for module_name, module in unet.named_modules():
            if not isinstance(module, Attention):
                continue
            # Only self-attention layers ("attn1") are linearized; cross-attention
            # layers ("attn2") are left untouched.
            if "attn1" not in module_name:
                continue
            # Infer the feature dimension from the query projection.
            dim_n = module.to_q.weight.shape[0]
            modules_list.append(
                {
                    "module_name": module_name,
                    "dim_n": dim_n,
                    "heads": module.heads,
                    "projection_mid_dim": None,
                }
            )
        return {"modules_list": modules_list}

    @classmethod
    def construct_for(
        cls,
        pipeline=None,
        unet=None,
        load_pretrained=True,
        pretrained_model_name_or_path=None,
    ) -> "LinFusion":
        """
        Construct a LinFusion object for the given pipeline.
        """
        assert unet is not None or (pipeline is not None and pipeline.unet is not None)
        unet = unet or pipeline.unet
        if load_pretrained:
            # Load pretrained LinFusion weights matching the base pipeline.
            pipe_name_path = pipeline._internal_dict._name_or_path
            if not pretrained_model_name_or_path:
                pretrained_model_name_or_path = model_dict.get(pipe_name_path, None)
                if pretrained_model_name_or_path:
                    print(
                        f"Matching LinFusion '{pretrained_model_name_or_path}' for pipeline '{pipe_name_path}'."
                    )
                else:
                    raise ValueError(
                        f"No pretrained LinFusion found for pipeline [{pipe_name_path}]; "
                        "please provide `pretrained_model_name_or_path` explicitly."
                    )
            linfusion = (
                LinFusion.from_pretrained(pretrained_model_name_or_path)
                .to(pipeline.device)
                .to(pipeline.dtype)
            )
        else:
            # Create from scratch without pretrained parameters
            default_config = LinFusion.get_default_config(pipeline=pipeline, unet=unet)
            linfusion = (
                LinFusion(**default_config).to(pipeline.device).to(pipeline.dtype)
            )
        linfusion.mount_to(unet)
        return linfusion

    def mount_to(self, unet) -> None:
        """
        Mounts the modules in `modules_dict` onto the given `unet`, replacing
        its original self-attention modules in place.
        """
        for module_name, module in self.modules_dict.items():
            replace_submodule(unet, module_name, module)
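

if __name__ == "__main__":
    # A minimal end-to-end sketch, not part of the library: it assumes a CUDA
    # device, the `diffusers` package, and that this file is executed as a
    # module within its package so the relative import of
    # `GeneralizedLinearAttention` resolves.
    # "Lykon/dreamshaper-8" is one of the base models listed in `model_dict`.
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "Lykon/dreamshaper-8", torch_dtype=torch.float16
    ).to("cuda")
    # Swaps every self-attention in `pipe.unet` for its linear counterpart and
    # loads the matching pretrained weights ("Yuanshi/LinFusion-1-5").
    LinFusion.construct_for(pipe)
    image = pipe("a photo of an astronaut riding a horse").images[0]
    image.save("astronaut.png")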