Spaces:
Runtime error
Runtime error
Create linfusion.py
Browse files- src/linfusion/linfusion.py +116 -0
src/linfusion/linfusion.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers.models.attention_processor import Attention
|
2 |
+
from diffusers import ModelMixin, ConfigMixin
|
3 |
+
import functools
|
4 |
+
|
5 |
+
from .attention import GeneralizedLinearAttention
|
6 |
+
|
7 |
+
|
8 |
+
# Base pipelines with a published LinFusion checkpoint. All of these are
# SD-1.5-derived models, so they share one set of pretrained LinFusion weights.
model_dict = {
    base_model: "Yuanshi/LinFusion-1-5"
    for base_model in (
        "runwayml/stable-diffusion-v1-5",
        "stablediffusionapi/realistic-vision-v51",
        "Lykon/dreamshaper-8",
    )
}
|
13 |
+
|
14 |
+
|
15 |
+
def replace_submodule(model, module_name, new_submodule):
    """Replace the submodule at dotted path ``module_name`` inside ``model``.

    E.g. ``replace_submodule(unet, "down_blocks.0.attn1", mod)`` sets
    ``unet.down_blocks.0.attn1 = mod``.

    Fix over the original: ``module_name.rsplit(".", 1)`` raised ValueError
    (unpacking) for a top-level name with no dot; ``rpartition`` handles that
    case by setting the attribute directly on ``model``.
    """
    path, _, attr = module_name.rpartition(".")
    # Walk the dotted prefix to the parent; an empty prefix means the
    # attribute lives on the root model itself.
    parent_module = functools.reduce(getattr, path.split("."), model) if path else model
    setattr(parent_module, attr, new_submodule)
|
19 |
+
|
20 |
+
|
21 |
+
class LinFusion(ModelMixin, ConfigMixin):
    """Container of generalized linear-attention modules that replace the
    self-attention (``attn1``) modules of a diffusers UNet.

    Each entry of ``modules_list`` describes one substitution: the dotted
    submodule name inside the UNet, its feature dimension ``dim_n``, its head
    count, and the hidden width of the learned projection.
    """

    def __init__(self, modules_list, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # Maps UNet submodule name (e.g. "down_blocks.0....attn1") to its
        # replacement linear-attention module.
        self.modules_dict = {}
        # Persist the construction recipe so `from_pretrained` can rebuild it.
        self.register_to_config(modules_list=modules_list)

        for i, attention_config in enumerate(modules_list):
            dim_n = attention_config["dim_n"]
            heads = attention_config["heads"]
            projection_mid_dim = attention_config["projection_mid_dim"]
            linear_attention = GeneralizedLinearAttention(
                query_dim=dim_n,
                out_dim=dim_n,
                dim_head=dim_n // heads,
                projection_mid_dim=projection_mid_dim,
            )
            # Register under a numeric name so the parameters are tracked and
            # included in save/load; modules_dict keeps the UNet path mapping.
            self.add_module(f"{i}", linear_attention)
            self.modules_dict[attention_config["module_name"]] = linear_attention

    @classmethod
    def get_default_config(
        cls,
        pipeline=None,
        unet=None,
    ):
        """
        Get the default configuration for the LinFusion model.

        (The `projection_mid_dim` is same as the `query_dim` by default.)
        """
        # Fix: the original asserted `pipeline.unet is not None`, which itself
        # raised AttributeError when both arguments were None.
        assert unet is not None or pipeline is not None
        unet = unet or pipeline.unet
        modules_list = []
        for module_name, module in unet.named_modules():
            if not isinstance(module, Attention):
                continue
            # Only self-attention modules ("attn1") are replaced; cross
            # attention ("attn2") is left untouched.
            if "attn1" not in module_name:
                continue
            dim_n = module.to_q.weight.shape[0]
            modules_list.append(
                {
                    "module_name": module_name,
                    "dim_n": dim_n,
                    "heads": module.heads,
                    # None -> GeneralizedLinearAttention defaults it to query_dim.
                    "projection_mid_dim": None,
                }
            )
        return {"modules_list": modules_list}

    @classmethod
    def construct_for(
        cls,
        pipeline=None,
        unet=None,
        load_pretrained=True,
        pretrained_model_name_or_path=None,
    ) -> "LinFusion":
        """
        Construct a LinFusion object for the given pipeline.

        Either ``pipeline`` or ``unet`` must be provided. When only ``unet``
        is given, ``pretrained_model_name_or_path`` must be supplied explicitly
        (there is no pipeline repo name to look up in ``model_dict``).
        """
        assert unet is not None or pipeline is not None
        unet = unet or pipeline.unet
        # Resolve device/dtype from whichever source is available.
        # NOTE(review): assumes `unet` exposes `.device`/`.dtype` like a
        # diffusers ModelMixin — confirm for non-diffusers UNets.
        device = pipeline.device if pipeline is not None else unet.device
        dtype = pipeline.dtype if pipeline is not None else unet.dtype

        if load_pretrained:
            # Load from pretrained
            if not pretrained_model_name_or_path:
                if pipeline is None:
                    raise Exception(
                        "LinFusion not found, please provide the path."
                    )
                # NOTE(review): `_internal_dict._name_or_path` is private
                # diffusers API — may break across versions.
                pipe_name_path = pipeline._internal_dict._name_or_path
                pretrained_model_name_or_path = model_dict.get(pipe_name_path, None)
                if pretrained_model_name_or_path:
                    print(
                        f"Matching LinFusion '{pretrained_model_name_or_path}' for pipeline '{pipe_name_path}'."
                    )
                else:
                    raise Exception(
                        f"LinFusion not found for pipeline [{pipe_name_path}], please provide the path."
                    )
            linfusion = (
                cls.from_pretrained(pretrained_model_name_or_path)
                .to(device)
                .to(dtype)
            )
        else:
            # Create from scratch without pretrained parameters.
            # Fix: the original called `get_default_config(pipeline)` and
            # dropped `unet`, crashing when only `unet` was supplied.
            default_config = cls.get_default_config(pipeline=pipeline, unet=unet)
            linfusion = cls(**default_config).to(device).to(dtype)
        linfusion.mount_to(unet)
        return linfusion

    def mount_to(self, unet) -> None:
        """
        Mounts the modules in the `modules_dict` to the given `unet`,
        replacing the corresponding attention submodules in place.
        """
        for module_name, module in self.modules_dict.items():
            replace_submodule(unet, module_name, module)
|