HaloMaster's picture
Duplicate from ysharma/Low-rank-Adaptation
fd74f7e
import math
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.nn as nn
class LoraInjectedLinear(nn.Module):
def __init__(self, in_features, out_features, bias=False, r=4):
super().__init__()
if r > min(in_features, out_features):
raise ValueError(
f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
)
self.linear = nn.Linear(in_features, out_features, bias)
self.lora_down = nn.Linear(in_features, r, bias=False)
self.lora_up = nn.Linear(r, out_features, bias=False)
self.scale = 1.0
nn.init.normal_(self.lora_down.weight, std=1 / r**2)
nn.init.zeros_(self.lora_up.weight)
def forward(self, input):
return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale
def inject_trainable_lora(
model: nn.Module,
target_replace_module: List[str] = ["CrossAttention", "Attention"],
r: int = 4,
loras=None, # path to lora .pt
):
"""
inject lora into model, and returns lora parameter groups.
"""
require_grad_params = []
names = []
if loras != None:
loras = torch.load(loras)
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
for name, _child_module in _module.named_modules():
if _child_module.__class__.__name__ == "Linear":
weight = _child_module.weight
bias = _child_module.bias
_tmp = LoraInjectedLinear(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r,
)
_tmp.linear.weight = weight
if bias is not None:
_tmp.linear.bias = bias
# switch the module
_module._modules[name] = _tmp
require_grad_params.append(
_module._modules[name].lora_up.parameters()
)
require_grad_params.append(
_module._modules[name].lora_down.parameters()
)
if loras != None:
_module._modules[name].lora_up.weight = loras.pop(0)
_module._modules[name].lora_down.weight = loras.pop(0)
_module._modules[name].lora_up.weight.requires_grad = True
_module._modules[name].lora_down.weight.requires_grad = True
names.append(name)
return require_grad_params, names
def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
loras = []
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
for _child_module in _module.modules():
if _child_module.__class__.__name__ == "LoraInjectedLinear":
loras.append((_child_module.lora_up, _child_module.lora_down))
if len(loras) == 0:
raise ValueError("No lora injected.")
return loras
def save_lora_weight(
model, path="./lora.pt", target_replace_module=["CrossAttention", "Attention"]
):
weights = []
for _up, _down in extract_lora_ups_down(
model, target_replace_module=target_replace_module
):
weights.append(_up.weight)
weights.append(_down.weight)
torch.save(weights, path)
def save_lora_as_json(model, path="./lora.json"):
weights = []
for _up, _down in extract_lora_ups_down(model):
weights.append(_up.weight.detach().cpu().numpy().tolist())
weights.append(_down.weight.detach().cpu().numpy().tolist())
import json
with open(path, "w") as f:
json.dump(weights, f)
def weight_apply_lora(
model, loras, target_replace_module=["CrossAttention", "Attention"], alpha=1.0
):
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
for _child_module in _module.modules():
if _child_module.__class__.__name__ == "Linear":
weight = _child_module.weight
up_weight = loras.pop(0).detach().to(weight.device)
down_weight = loras.pop(0).detach().to(weight.device)
# W <- W + U * D
weight = weight + alpha * (up_weight @ down_weight).type(
weight.dtype
)
_child_module.weight = nn.Parameter(weight)
def monkeypatch_lora(
model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
for name, _child_module in _module.named_modules():
if _child_module.__class__.__name__ == "Linear":
weight = _child_module.weight
bias = _child_module.bias
_tmp = LoraInjectedLinear(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r=r,
)
_tmp.linear.weight = weight
if bias is not None:
_tmp.linear.bias = bias
# switch the module
_module._modules[name] = _tmp
up_weight = loras.pop(0)
down_weight = loras.pop(0)
_module._modules[name].lora_up.weight = nn.Parameter(
up_weight.type(weight.dtype)
)
_module._modules[name].lora_down.weight = nn.Parameter(
down_weight.type(weight.dtype)
)
_module._modules[name].to(weight.device)
def monkeypatch_replace_lora(
model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
for name, _child_module in _module.named_modules():
if _child_module.__class__.__name__ == "LoraInjectedLinear":
weight = _child_module.linear.weight
bias = _child_module.linear.bias
_tmp = LoraInjectedLinear(
_child_module.linear.in_features,
_child_module.linear.out_features,
_child_module.linear.bias is not None,
r=r,
)
_tmp.linear.weight = weight
if bias is not None:
_tmp.linear.bias = bias
# switch the module
_module._modules[name] = _tmp
up_weight = loras.pop(0)
down_weight = loras.pop(0)
_module._modules[name].lora_up.weight = nn.Parameter(
up_weight.type(weight.dtype)
)
_module._modules[name].lora_down.weight = nn.Parameter(
down_weight.type(weight.dtype)
)
_module._modules[name].to(weight.device)
def monkeypatch_add_lora(
model,
loras,
target_replace_module=["CrossAttention", "Attention"],
alpha: float = 1.0,
beta: float = 1.0,
):
for _module in model.modules():
if _module.__class__.__name__ in target_replace_module:
for name, _child_module in _module.named_modules():
if _child_module.__class__.__name__ == "LoraInjectedLinear":
weight = _child_module.linear.weight
up_weight = loras.pop(0)
down_weight = loras.pop(0)
_module._modules[name].lora_up.weight = nn.Parameter(
up_weight.type(weight.dtype).to(weight.device) * alpha
+ _module._modules[name].lora_up.weight.to(weight.device) * beta
)
_module._modules[name].lora_down.weight = nn.Parameter(
down_weight.type(weight.dtype).to(weight.device) * alpha
+ _module._modules[name].lora_down.weight.to(weight.device)
* beta
)
_module._modules[name].to(weight.device)
def tune_lora_scale(model, alpha: float = 1.0):
for _module in model.modules():
if _module.__class__.__name__ == "LoraInjectedLinear":
_module.scale = alpha
def _text_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
def _ti_lora_path(path: str) -> str:
assert path.endswith(".pt"), "Only .pt files are supported"
return ".".join(path.split(".")[:-1] + ["ti", "pt"])
def load_learned_embed_in_clip(
learned_embeds_path, text_encoder, tokenizer, token=None, idempotent=False
):
loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
# separate token and the embeds
trained_token = list(loaded_learned_embeds.keys())[0]
embeds = loaded_learned_embeds[trained_token]
# cast to dtype of text_encoder
dtype = text_encoder.get_input_embeddings().weight.dtype
# add the token in tokenizer
token = token if token is not None else trained_token
num_added_tokens = tokenizer.add_tokens(token)
i = 1
if num_added_tokens == 0 and idempotent:
return token
while num_added_tokens == 0:
print(f"The tokenizer already contains the token {token}.")
token = f"{token[:-1]}-{i}>"
print(f"Attempting to add the token {token}.")
num_added_tokens = tokenizer.add_tokens(token)
i += 1
# resize the token embeddings
text_encoder.resize_token_embeddings(len(tokenizer))
# get the id for the token and assign the embeds
token_id = tokenizer.convert_tokens_to_ids(token)
text_encoder.get_input_embeddings().weight.data[token_id] = embeds
return token
def patch_pipe(
pipe,
unet_path,
token,
alpha: float = 1.0,
r: int = 4,
patch_text=False,
patch_ti=False,
idempotent_token=True,
):
ti_path = _ti_lora_path(unet_path)
text_path = _text_lora_path(unet_path)
unet_has_lora = False
text_encoder_has_lora = False
for _module in pipe.unet.modules():
if _module.__class__.__name__ == "LoraInjectedLinear":
unet_has_lora = True
for _module in pipe.text_encoder.modules():
if _module.__class__.__name__ == "LoraInjectedLinear":
text_encoder_has_lora = True
if not unet_has_lora:
monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
else:
monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)
if patch_text:
if not text_encoder_has_lora:
monkeypatch_lora(
pipe.text_encoder,
torch.load(text_path),
target_replace_module=["CLIPAttention"],
r=r,
)
else:
monkeypatch_replace_lora(
pipe.text_encoder,
torch.load(text_path),
target_replace_module=["CLIPAttention"],
r=r,
)
if patch_ti:
token = load_learned_embed_in_clip(
ti_path,
pipe.text_encoder,
pipe.tokenizer,
token,
idempotent=idempotent_token,
)