# weights2weights / lora_w2w.py
# Last commit by amildravid4292: "Update lora_w2w.py" (e789a6b, verified)
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
import os
import math
from typing import Optional, List, Type, Set, Literal
import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file
# Prefix prepended to every generated LoRA module name.
LORA_PREFIX_UNET = "lora_unet"

# Class names of UNet sub-modules whose linear-like children get a LoRA hook.
# ("Transformer2DModel" was considered here too -- "apparently it's this
# one?" -- since it covers attn1 and attn2, but it is left disabled.)
UNET_TARGET_REPLACE_MODULE_TRANSFORMER = ["Attention"]

# Convolutional counterparts (locon, 3clier).
UNET_TARGET_REPLACE_MODULE_CONV = [
    "ResnetBlock2D", "Downsample2D", "Upsample2D", "DownBlock2D", "UpBlock2D",
]

# Target list used when the network builds its modules.
DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER

# Accepted values for ``train_method``.
# Options that are listed but not enabled: xlayer, outxattn, outsattn,
# inxattn, inmidsattn, selflayer.
TRAINING_METHODS = Literal[
    "noxattn",              # train all layers except x-attns and time_embed layers
    "innoxattn",            # skip modules whose name contains "attn2"
    "selfattn",             # ESD-u, train only self attention layers
    "xattn",                # ESD-x, train only x attention layers
    "full",                 # train all layers
    "xattn-strict",         # q and k values
    "noxattn-hspace",
    "noxattn-hspace-last",
]
class LoRAModule(nn.Module):
    """Additive low-rank branch hooked onto an existing linear module.

    The original module is not replaced. ``apply_to`` swaps its ``forward``
    for this object's ``forward``, which returns the original output plus a
    low-rank term rebuilt on every call from ``proj``, ``v``, ``std`` and
    ``mean``.

    All tensors handed to the constructor are cast to bfloat16 and moved to
    the current CUDA device.
    """

    def __init__(
        self,
        lora_name,
        proj,
        v,
        mean,
        std,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """Slice the per-layer statistics and compute the output scale.

        Args:
            lora_name: Unique name for this hook.
            proj: Coefficient tensor multiplied against ``v`` in ``forward``.
            v: Basis rows for this layer. The first ``in_features`` rows feed
                the input-side factor, the remainder the output-side factor.
            mean: Per-row offsets, split the same way as ``v``.
            std: Per-row scales, split the same way as ``v``.
            org_module: Module to hook; must expose ``in_features`` and
                ``out_features``.
            multiplier: Runtime strength applied to the low-rank term.
            lora_dim: Rank used for the ``alpha / lora_dim`` scale.
            alpha: If 0 or None, alpha is the rank (no scaling). A
                single-element tensor is also accepted.

        NOTE(review): ``v``/``mean``/``std`` are assumed to hold
        ``in_features + out_features`` entries along dim 0 -- confirm with
        the caller.
        """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        self.in_dim = org_module.in_features
        self.out_dim = org_module.out_features

        self.proj = proj.bfloat16().cuda()

        # Leading in_dim entries belong to the input-side factor, the rest to
        # the output-side factor.
        split = self.in_dim
        self.mean1 = mean[0:split].bfloat16().cuda()
        self.mean2 = mean[split:].bfloat16().cuda()
        self.std1 = std[0:split].bfloat16().cuda()
        self.std2 = std[split:].bfloat16().cuda()
        self.v1 = v[0:split].bfloat16().cuda()
        self.v2 = v[split:].bfloat16().cuda()

        if isinstance(alpha, torch.Tensor):
            # Go through float32 on the CPU: NumPy conversion is unavailable
            # for CUDA and bfloat16 tensors, and a plain Python number is all
            # that is needed for the scale.
            alpha = alpha.detach().float().cpu().item()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim

        self.multiplier = multiplier
        self.org_module = org_module

    def apply_to(self):
        """Redirect the wrapped module's ``forward`` to this object."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        # Drop the reference so the wrapped module is not registered as a
        # sub-module of this one.
        del self.org_module

    def forward(self, x):
        """Return the original output plus the scaled low-rank term."""
        # NOTE(review): the two ``self.proj @ self.vN.T`` factors were
        # restored from a garbled source line; they are inferred from the
        # otherwise unused v1/v2 attributes -- confirm against upstream.
        down = (self.proj @ self.v1.T) * self.std1 + self.mean1
        up = (self.proj @ self.v2.T) * self.std2 + self.mean2
        return (
            self.org_forward(x)
            + (x @ down.T) @ up * self.multiplier * self.scale
        )
class LoRAw2w(nn.Module):
    """Network of ``LoRAModule`` hooks driven by one shared ``proj`` tensor.

    On construction the UNet is scanned for target modules, one hook is built
    per matching linear-like child, each hook receives its own consecutive
    slice of ``v``/``mean``/``std``, and the hooks are installed on the UNet
    and registered as sub-modules of this network.

    Used as a context manager, the network sets every hook's multiplier to
    the slider value on entry and to 0 on exit.
    """

    def __init__(
        self,
        proj,
        mean,
        std,
        v,
        unet: "UNet2DConditionModel",
        rank: int = 4,
        multiplier: float = 1.0,
        alpha: float = 1.0,
        train_method: "TRAINING_METHODS" = "full",
    ) -> None:
        """Build and install the hooks.

        Args:
            proj: Tensor wrapped in an ``nn.Parameter`` and handed to every
                hook.
            mean: Offsets, stored as the ``mean`` buffer.
            std: Scales, stored as the ``std`` buffer.
            v: Basis, stored as the ``v`` buffer.
            unet: Model whose modules are scanned and hooked.
            rank: Rank forwarded to every hook.
            multiplier: Initial strength forwarded to every hook.
            alpha: Scaling numerator forwarded to every hook.
            train_method: Selects which modules are hooked.
        """
        super().__init__()
        self.lora_scale = 1
        self.multiplier = multiplier
        self.lora_dim = rank
        self.alpha = alpha

        self.proj = torch.nn.Parameter(proj)
        self.register_buffer("mean", self._as_buffer_tensor(mean))
        self.register_buffer("std", self._as_buffer_tensor(std))
        self.register_buffer("v", self._as_buffer_tensor(v))

        self.module = LoRAModule

        self.unet_loras = self.create_modules(
            LORA_PREFIX_UNET,
            unet,
            DEFAULT_TARGET_REPLACE,
            self.lora_dim,
            self.multiplier,
            train_method=train_method,
        )

        self.lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in self.lora_names
            ), f"duplicated lora name: {lora.lora_name}. {self.lora_names}"
            self.lora_names.add(lora.lora_name)

        for lora in self.unet_loras:
            lora.apply_to()
            self.add_module(
                lora.lora_name,
                lora,
            )

        del unet

        torch.cuda.empty_cache()

    @staticmethod
    def _as_buffer_tensor(data):
        """Return a detached copy of ``data`` as a tensor.

        Tensors are cloned directly; ``torch.tensor`` on an existing tensor
        produces the same copy but emits a copy-construct warning.
        """
        if isinstance(data, torch.Tensor):
            return data.detach().clone()
        return torch.tensor(data)

    @staticmethod
    def _skip_module_name(train_method, name):
        """Tell whether ``name`` is excluded by ``train_method``.

        Raises:
            NotImplementedError: If ``train_method`` is not recognised.
        """
        if train_method in ("noxattn", "noxattn-hspace", "noxattn-hspace-last"):
            # Train everything except cross attention and time embed.
            return "attn2" in name or "time_embed" in name
        if train_method == "innoxattn":
            return "attn2" in name
        if train_method == "selfattn":
            return "attn1" not in name
        if train_method in ("xattn", "xattn-strict"):
            return "to_k" in name
        if train_method == "full":
            # Train everything.
            return False
        raise NotImplementedError(
            f"train_method: {train_method} is not implemented."
        )

    def reset(self):
        """Give every hook a fresh Parameter built from the bfloat16 ``proj``."""
        for lora in self.unet_loras:
            lora.proj = torch.nn.Parameter(self.proj.bfloat16())

    def create_modules(
        self,
        prefix: str,
        root_module: nn.Module,
        target_replace_modules: List[str],
        rank: int,
        multiplier: float,
        train_method: "TRAINING_METHODS",
    ) -> list:
        """Create one hook per selected linear-like child of ``root_module``.

        Modules are visited so that ``mid_block`` entries come before
        ``up_block`` entries. Each created hook consumes the next
        ``in_features + out_features`` entries of ``v``/``mean``/``std``.

        Args:
            prefix: Leading part of every generated hook name.
            root_module: Module tree to scan.
            target_replace_modules: Class names whose children are hooked.
            rank: Rank forwarded to each hook.
            multiplier: Strength forwarded to each hook.
            train_method: Module-selection strategy.

        Returns:
            The created hooks, in visiting order, without duplicate names.

        Raises:
            NotImplementedError: If ``train_method`` is not recognised.
        """
        named = list(root_module.named_modules())

        mid_start = next(
            (i for i, (n, _) in enumerate(named) if "mid_block" in n), 0
        )
        up_start = next(
            (i for i, (n, _) in enumerate(named) if "up_block" in n), 0
        )
        # Move the mid block (and whatever follows it) in front of the up
        # blocks. Only do so when mid really comes after up; otherwise the
        # splice would repeat the mid block and shift every later slice.
        if mid_start > up_start:
            named = named[:up_start] + named[mid_start:] + named[up_start:mid_start]

        hookable = ("Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv")
        loras = []
        seen_names = set()
        counter = 0  # offset of the next unused entry in v / mean / std

        for name, module in named:
            if self._skip_module_name(train_method, name):
                continue
            if module.__class__.__name__ not in target_replace_modules:
                continue

            for child_name, child_module in module.named_modules():
                if child_module.__class__.__name__ not in hookable:
                    continue

                # NOTE(review): the "to_k" check is assumed to be nested
                # under xattn-strict together with the "out" check -- confirm.
                if train_method == "xattn-strict" and (
                    "out" in child_name or "to_k" in child_name
                ):
                    continue
                if train_method == "noxattn-hspace" and "mid_block" not in name:
                    continue
                if train_method == "noxattn-hspace-last" and (
                    "mid_block" not in name
                    or ".1" not in name
                    or "conv2" not in child_name
                ):
                    continue

                lora_name = (prefix + "." + name + "." + child_name).replace(".", "_")
                # A child reachable through two target modules is hooked once
                # and consumes a single slice.
                if lora_name in seen_names:
                    continue

                combined_dim = child_module.in_features + child_module.out_features
                window = slice(counter, counter + combined_dim)
                lora = self.module(
                    lora_name,
                    self.proj,
                    self.v[window],
                    self.mean[window],
                    self.std[window],
                    child_module,
                    multiplier,
                    rank,
                    self.alpha,
                )
                counter += combined_dim

                seen_names.add(lora_name)
                loras.append(lora)

        return loras

    def prepare_optimizer_params(self):
        """Return optimizer parameter groups for the hooks.

        Returns:
            A list holding one ``{"params": [...]}`` group when hooks exist,
            otherwise an empty list. A parameter shared by several hooks is
            listed once, since optimizers warn about duplicates in a group.
        """
        all_params = []

        if self.unet_loras:  # in practice this is the only group
            params = []
            seen_ids = set()
            for lora in self.unet_loras:
                for param in lora.parameters():
                    if id(param) in seen_ids:
                        continue
                    seen_ids.add(id(param))
                    params.append(param)
            all_params.append({"params": params})

        return all_params

    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        """Write this network's state dict to ``file``.

        Args:
            file: Destination path; a ``.safetensors`` extension selects
                ``save_file``, anything else ``torch.save``.
            dtype: When given, every entry is detached, cloned, moved to the
                CPU and cast to this dtype before saving.
            metadata: Passed to ``save_file`` for ``.safetensors`` output.
        """
        state_dict = self.state_dict()

        if dtype is not None:
            for key, tensor in list(state_dict.items()):
                state_dict[key] = tensor.detach().clone().to("cpu").to(dtype)

        if os.path.splitext(file)[1] == ".safetensors":
            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def set_lora_slider(self, scale):
        """Set the strength applied while the network is entered."""
        self.lora_scale = scale

    def __enter__(self):
        """Switch every hook on at the slider strength and return ``self``."""
        for lora in self.unet_loras:
            lora.multiplier = 1.0 * self.lora_scale
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Switch every hook off by zeroing its multiplier."""
        for lora in self.unet_loras:
            lora.multiplier = 0