import functools

from diffusers import ModelMixin, ConfigMixin
from diffusers.models.attention_processor import Attention

from .attention import GeneralizedLinearAttention

# Maps base Stable Diffusion checkpoints to their matching pretrained
# LinFusion weights on the Hugging Face Hub.
model_dict = {
    "SG161222/Realistic_Vision_V4.0_noVAE": "Yuanshi/LinFusion-1-5",
    "Lykon/dreamshaper-8": "Yuanshi/LinFusion-1-5",
    "CompVis/stable-diffusion-v1-4": "Yuanshi/LinFusion-1-5",
}


def replace_submodule(model, module_name, new_submodule):
    """Replace the submodule at dotted path `module_name` inside `model`."""
    path, attr = module_name.rsplit(".", 1)
    parent_module = functools.reduce(getattr, path.split("."), model)
    setattr(parent_module, attr, new_submodule)
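
# Illustrative call (the module path below is hypothetical, not read from a
# real UNet):
#
#   replace_submodule(unet, "down_blocks.0.attentions.0.transformer_blocks.0.attn1", new_attn)
#
# `functools.reduce(getattr, ...)` walks every dotted component except the
# last to reach the parent module; `setattr` then swaps the final attribute
# in place. Note that `rsplit(".", 1)` assumes the path contains at least
# one dot, i.e. the target is not a direct child of `model`.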


class LinFusion(ModelMixin, ConfigMixin):
    def __init__(self, modules_list, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.modules_dict = {}
        self.register_to_config(modules_list=modules_list)
        for i, attention_config in enumerate(modules_list):
            dim_n = attention_config["dim_n"]
            heads = attention_config["heads"]
            projection_mid_dim = attention_config["projection_mid_dim"]
            linear_attention = GeneralizedLinearAttention(
                query_dim=dim_n,
                out_dim=dim_n,
                dim_head=dim_n // heads,
                projection_mid_dim=projection_mid_dim,
            )
            # Register as a child module (so its parameters land in
            # `state_dict`) and index it by the UNet module path it replaces.
            self.add_module(f"{i}", linear_attention)
            self.modules_dict[attention_config["module_name"]] = linear_attention

    @classmethod
    def get_default_config(
        cls,
        pipeline=None,
        unet=None,
    ):
        """
        Get the default configuration for the LinFusion model.
        (The `projection_mid_dim` is the same as the `query_dim` by default.)
        """
        assert unet is not None or (
            pipeline is not None and pipeline.unet is not None
        )
        unet = unet or pipeline.unet
        modules_list = []
        for module_name, module in unet.named_modules():
            if not isinstance(module, Attention):
                continue
            # Only self-attention ("attn1") layers are linearized;
            # cross-attention layers are left untouched.
            if "attn1" not in module_name:
                continue
            # The query projection's output width is the attention feature dim.
            dim_n = module.to_q.weight.shape[0]
            modules_list.append(
                {
                    "module_name": module_name,
                    "dim_n": dim_n,
                    "heads": module.heads,
                    "projection_mid_dim": None,
                }
            )
        return {"modules_list": modules_list}
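
    # A resulting `modules_list` entry looks like the following (values are
    # illustrative for an SD1.5-style UNet, not read from a real checkpoint):
    #
    #   {"module_name": "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
    #    "dim_n": 320, "heads": 8, "projection_mid_dim": None}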

    @classmethod
    def construct_for(
        cls,
        pipeline=None,
        unet=None,
        load_pretrained=True,
        pretrained_model_name_or_path=None,
    ) -> "LinFusion":
        """
        Construct a LinFusion object for the given pipeline and mount it onto
        the pipeline's UNet (or the `unet` passed in directly).
        """
        assert unet is not None or (
            pipeline is not None and pipeline.unet is not None
        )
        unet = unet or pipeline.unet
        if load_pretrained:
            # Load pretrained weights; `pipeline` is required here to resolve
            # the checkpoint name, device, and dtype.
            pipe_name_path = pipeline._internal_dict._name_or_path
            if not pretrained_model_name_or_path:
                pretrained_model_name_or_path = model_dict.get(pipe_name_path, None)
            if pretrained_model_name_or_path:
                print(
                    f"Matching LinFusion '{pretrained_model_name_or_path}' "
                    f"for pipeline '{pipe_name_path}'."
                )
            else:
                raise ValueError(
                    f"LinFusion not found for pipeline [{pipe_name_path}], "
                    f"please provide the path."
                )
            linfusion = (
                LinFusion.from_pretrained(pretrained_model_name_or_path)
                .to(pipeline.device)
                .to(pipeline.dtype)
            )
        else:
            # Create from scratch without pretrained parameters.
            default_config = LinFusion.get_default_config(
                pipeline=pipeline, unet=unet
            )
            linfusion = (
                LinFusion(**default_config).to(pipeline.device).to(pipeline.dtype)
            )
        linfusion.mount_to(unet)
        return linfusion

    def mount_to(self, unet) -> None:
        """
        Mount the modules in `modules_dict` onto the given `unet`,
        replacing the original self-attention modules in place.
        """
        for module_name, module in self.modules_dict.items():
            replace_submodule(unet, module_name, module)
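
# Minimal usage sketch (assumes the `diffusers` library, a CUDA device, and
# the checkpoints listed in `model_dict` above; prompt and model choice are
# illustrative):
#
#   import torch
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "Lykon/dreamshaper-8", torch_dtype=torch.float16
#   ).to("cuda")
#   linfusion = LinFusion.construct_for(pipe)  # swaps attn1 modules in pipe.unet
#   image = pipe("a photo of an astronaut riding a horse").images[0]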