# iLoRA/model/peft/tuners/adalora.py
# Forked, with bug fixes, from https://github.com/AkaliKong/iLoRA
import re
import warnings
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D
from ..import_utils import is_bnb_4bit_available, is_bnb_available
from ..utils import (
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
PeftType,
_freeze_adapter,
_get_submodules,
transpose,
)
from .lora import (
LoraConfig,
LoraLayer,
LoraModel,
mark_only_lora_as_trainable,
)
if is_bnb_available():
import bitsandbytes as bnb
@dataclass
class AdaLoraConfig(LoraConfig):
"""
    This is the configuration class to store the configuration of an [`AdaLoraModel`].
Args:
        target_r (`int`): The target average rank of the incremental matrices.
        init_r (`int`): The initial rank of each incremental matrix.
        tinit (`int`): The number of warmup steps for initial fine-tuning.
        tfinal (`int`): The number of final fine-tuning steps during which the rank budget is fixed at the target.
        deltaT (`int`): The time interval between two budget allocations.
        beta1 (`float`): The hyperparameter of EMA for sensitivity smoothing.
        beta2 (`float`): The hyperparameter of EMA for uncertainty quantification.
orth_reg_weight (`float`): The coefficient of orthogonal regularization.
total_step (`int`): The total training steps that should be specified before training.
rank_pattern (`list`): The allocated rank for each weight matrix by RankAllocator.
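
    An illustrative configuration sketch (assuming the standard `peft` import path; target
    module names such as "q" and "v" depend on the base model)::

        >>> from peft import AdaLoraConfig
        >>> config = AdaLoraConfig(
        ...     task_type="SEQ_2_SEQ_LM", init_r=12, target_r=8, tinit=200, tfinal=500,
        ...     total_step=3000, target_modules=["q", "v"], lora_dropout=0.01,
        ... )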
"""
target_r: int = field(default=8, metadata={"help": "Target Lora matrix dimension."})
    init_r: int = field(default=12, metadata={"help": "Initial Lora matrix dimension."})
tinit: int = field(default=0, metadata={"help": "The steps of initial warmup."})
    tfinal: int = field(default=0, metadata={"help": "The steps of final fine-tuning."})
deltaT: int = field(default=1, metadata={"help": "Step interval of rank allocation."})
beta1: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."})
beta2: float = field(default=0.85, metadata={"help": "Hyperparameter of EMA."})
orth_reg_weight: float = field(default=0.5, metadata={"help": "The orthogonal regularization coefficient."})
total_step: Optional[int] = field(default=None, metadata={"help": "The total training steps."})
rank_pattern: Optional[dict] = field(default=None, metadata={"help": "The saved rank pattern."})
def __post_init__(self):
self.peft_type = PeftType.ADALORA
class AdaLoraModel(LoraModel):
"""
    Creates an AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
https://openreview.net/pdf?id=lq62uWRJjiY
Args:
model ([`transformers.PreTrainedModel`]): The model to be adapted.
config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
Returns:
`torch.nn.Module`: The AdaLora model.
    Example::

        >>> from transformers import AutoModelForSeq2SeqLM
        >>> from peft import AdaLoraConfig, AdaLoraModel
        >>> config = AdaLoraConfig(
        ...     peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", init_r=12, lora_alpha=32,
        ...     target_modules=["q", "v"], lora_dropout=0.01,
        ... )
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        >>> model = AdaLoraModel(model, {"default": config}, "default")
**Attributes**:
- **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model.
"""
def __init__(self, model, config, adapter_name):
nn.Module.__init__(self)
self.model = model
self.peft_config = config
self.add_adapter(adapter_name, self.peft_config[adapter_name])
def add_adapter(self, adapter_name, config=None):
if config is not None:
model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config
config = self._prepare_adalora_config(config, model_config)
self.peft_config[adapter_name] = config
self._find_and_replace(adapter_name)
if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
raise ValueError(
"AdaLoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters."
)
        trainable_mode_counter = 0
        for config in self.peft_config.values():
            if not config.inference_mode:
                trainable_mode_counter += 1
        if trainable_mode_counter > 1:
raise ValueError(
"AdaLoraModel supports only 1 trainable adapter. "
"When using multiple adapters, set inference_mode to True for all adapters except the one you want to train."
)
mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
if self.peft_config[adapter_name].inference_mode:
_freeze_adapter(self.model, adapter_name)
else:
self.trainable_adapter_name = adapter_name
self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name)
def _find_and_replace(self, adapter_name):
lora_config = self.peft_config[adapter_name]
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
if (loaded_in_8bit or loaded_in_4bit) and not is_bnb_available():
            raise ImportError(
                "To use AdaLora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. "
                "You can install it with `pip install bitsandbytes`."
            )
is_target_modules_in_base_model = False
kwargs = {
"r": lora_config.init_r,
"lora_alpha": lora_config.lora_alpha,
"lora_dropout": lora_config.lora_dropout,
"fan_in_fan_out": lora_config.fan_in_fan_out,
"init_lora_weights": lora_config.init_lora_weights,
}
key_list = [key for key, _ in self.model.named_modules()]
for key in key_list:
if isinstance(lora_config.target_modules, str):
target_module_found = re.fullmatch(lora_config.target_modules, key)
else:
target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
if target_module_found:
if not is_target_modules_in_base_model:
is_target_modules_in_base_model = True
parent, target, target_name = _get_submodules(self.model, key)
bias = target.bias is not None
if isinstance(target, LoraLayer):
target.update_layer(
adapter_name,
lora_config.init_r,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
else:
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
"memory_efficient_backward": target.state.memory_efficient_backward,
"threshold": target.state.threshold,
"index": target.index,
}
)
new_module = SVDLinear8bitLt(
adapter_name, target.in_features, target.out_features, bias=bias, **kwargs
)
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{
"compute_dtype": target.compute_dtype,
"compress_statistics": target.weight.compress_statistics,
"quant_type": target.weight.quant_type,
}
)
new_module = SVDLinear4bit(
adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs
)
else:
if isinstance(target, torch.nn.Linear):
in_features, out_features = target.in_features, target.out_features
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
elif isinstance(target, Conv1D):
in_features, out_features = (
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
)
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. "
"Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
else:
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
)
new_module = SVDLinear(adapter_name, in_features, out_features, bias=bias, **kwargs)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
f"Target modules {lora_config.target_modules} not found in the base model. "
f"Please check the target modules and try again."
)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.model, name)
def forward(self, *args, **kwargs):
outputs = self.model.forward(*args, **kwargs)
# Calculate the orthogonal regularization
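        # For the trainable adapter, AdaLoRA adds an orthogonality penalty on the SVD factors:
        # ||A A^T - I||_F for every lora_A and ||B^T B - I||_F for every lora_B, averaged over
        # the number of such matrices and scaled by `orth_reg_weight`.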
orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight
assert orth_reg_weight > 0
if hasattr(outputs, "loss"):
regu_loss = 0
num_param = 0
for n, p in self.model.named_parameters():
if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:
para_cov = p @ p.T if "lora_A" in n else p.T @ p
I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov))
I.requires_grad = False
num_param += 1
regu_loss += torch.norm(para_cov - I, p="fro")
if num_param > 0:
regu_loss = regu_loss / num_param
else:
regu_loss = 0
outputs.loss += orth_reg_weight * regu_loss
return outputs
def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name):
lora_config = self.peft_config[adapter_name]
for name, rank_idx in rank_pattern.items():
if isinstance(rank_idx, list):
rank = sum(rank_idx)
elif isinstance(rank_idx, torch.Tensor):
rank_idx = rank_idx.view(-1)
rank = rank_idx.sum().item()
else:
                raise ValueError("Unexpected type of rank_idx")
key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
_, target, _ = _get_submodules(self.model, key)
lora_E_weights = target.lora_E[adapter_name][rank_idx]
lora_A_weights = target.lora_A[adapter_name][rank_idx]
lora_B_weights = target.lora_B[adapter_name][:, rank_idx]
ranknum = target.ranknum[adapter_name]
target.update_layer(
adapter_name,
rank,
lora_config.lora_alpha,
lora_config.lora_dropout,
lora_config.init_lora_weights,
)
with torch.no_grad():
if rank > 0:
target.lora_E[adapter_name].copy_(lora_E_weights)
target.lora_A[adapter_name].copy_(lora_A_weights)
target.lora_B[adapter_name].copy_(lora_B_weights)
                    # The scaling is exactly the same as before
target.ranknum[adapter_name].copy_(ranknum)
def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name):
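        """
        Prune the saved adapter tensors in `state_dict` to match the ranks allocated by the
        RankAllocator: rows of `lora_E` / `lora_A` (and columns of `lora_B`) whose rank index is
        masked out in `rank_pattern` are dropped.
        """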
for name, rank_idx in rank_pattern.items():
rank = sum(rank_idx)
prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
for layer in ["lora_E", "lora_A", "lora_B"]:
key = f"base_model.model.{prefix}.{layer}.{adapter_name}"
if layer != "lora_B":
state_dict[key] = (
state_dict[key][rank_idx] if rank != state_dict[key].shape[0] else state_dict[key]
)
else:
state_dict[key] = (
state_dict[key][:, rank_idx] if rank != state_dict[key].shape[1] else state_dict[key]
)
return state_dict
def update_and_allocate(self, global_step):
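        """
        Update the importance scores and (re-)allocate the rank budget for `global_step`.

        A minimal training-loop sketch (illustrative only; `optimizer`, `batch`, and the
        surrounding loop are assumed to come from the caller, and `model` is assumed to be the
        PeftModel wrapping this AdaLoraModel):

            >>> outputs = model(**batch)
            >>> outputs.loss.backward()
            >>> optimizer.step()
            >>> model.base_model.update_and_allocate(global_step)
            >>> optimizer.zero_grad()
        """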
lora_config = self.peft_config[self.trainable_adapter_name]
# Update the importance score and allocate the budget
if global_step < lora_config.total_step - lora_config.tfinal:
_, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step)
if rank_pattern:
lora_config.rank_pattern = rank_pattern
# Finalize the budget allocation
elif global_step == lora_config.total_step - lora_config.tfinal:
_, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, force_mask=True)
            # for some reason, this freezes the trainable parameters and nothing gets updated
# self.resize_modules_by_rank_pattern(rank_pattern, self.trainable_adapter_name)
lora_config.rank_pattern = rank_pattern
self.rankallocator.reset_ipt()
# Currently using inefficient way to mask the unimportant weights using the rank pattern
# due to problem mentioned above
elif global_step > lora_config.total_step - lora_config.tfinal:
self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern)
# Pass the function and do forward propagation
else:
return None
@staticmethod
def _prepare_adalora_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[
model_config["model_type"]
]
return peft_config
class AdaLoraLayer(LoraLayer):
def __init__(
self,
in_features: int,
out_features: int,
):
super().__init__(in_features, out_features)
self.lora_E = nn.ParameterDict({})
self.lora_A = nn.ParameterDict({})
self.lora_B = nn.ParameterDict({})
self.ranknum = nn.ParameterDict({})
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
            # Use nn.Identity so the no-dropout case is still an nn.Module and can be stored in the ModuleDict
            lora_dropout_layer = nn.Identity()
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
# Right singular vectors
self.lora_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(r, self.in_features))}))
# Singular values
self.lora_E.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(r, 1))}))
# Left singular vectors
self.lora_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(self.out_features, r))}))
# The current rank
self.ranknum.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(1), requires_grad=False)}))
self.ranknum[adapter_name].data.fill_(float(r))
self.ranknum[adapter_name].requires_grad = False
self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
nn.init.zeros_(self.lora_E[adapter_name])
nn.init.normal_(self.lora_A[adapter_name], mean=0.0, std=0.02)
nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02)
class SVDLinear(nn.Linear, AdaLoraLayer):
# SVD-based adaptation by a dense layer
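    # The weight update is parameterized as delta_W = lora_B @ (lora_E * lora_A): an SVD-like
    # factorization P * Lambda * Q in which lora_E holds the trainable "singular values"
    # (broadcast over the rows of lora_A), scaled by scaling / ranknum in merge() and forward().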
def __init__(
self,
adapter_name: str,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False,
**kwargs,
):
init_lora_weights = kwargs.pop("init_lora_weights", True)
nn.Linear.__init__(self, in_features, out_features, **kwargs)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
self.fan_in_fan_out = fan_in_fan_out
if fan_in_fan_out:
self.weight.data = self.weight.data.T
nn.Linear.reset_parameters(self)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
def merge(self):
if self.active_adapter not in self.lora_A.keys():
return
if self.merged:
warnings.warn("Already merged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
self.weight.data += (
transpose(
self.lora_B[self.active_adapter]
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]),
self.fan_in_fan_out,
)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
self.merged = True
def unmerge(self):
if self.active_adapter not in self.lora_A.keys():
return
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
if self.r[self.active_adapter] > 0:
self.weight.data -= (
transpose(
self.lora_B[self.active_adapter]
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter])
)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
self.merged = False
def forward(self, x: torch.Tensor):
if self.active_adapter not in self.lora_A.keys():
return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
if self.disable_adapters:
if self.r[self.active_adapter] > 0 and self.merged:
self.unmerge()
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
elif self.r[self.active_adapter] > 0 and not self.merged:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
result += (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
else:
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
return result
if is_bnb_available():
class SVDLinear8bitLt(bnb.nn.Linear8bitLt, AdaLoraLayer):
# Low-rank matrix for SVD-based adaptation
def __init__(
self,
adapter_name,
in_features,
out_features,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
bnb.nn.Linear8bitLt.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
has_fp16_weights=kwargs.get("has_fp16_weights", True),
memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
threshold=kwargs.get("threshold", 0.0),
index=kwargs.get("index", None),
)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
return result
elif self.r[self.active_adapter] > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
output = (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
).to(expected_dtype)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
else:
output = (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
result = result + output
return result
if is_bnb_4bit_available():
class SVDLinear4bit(bnb.nn.Linear4bit, AdaLoraLayer):
# Low-rank matrix for SVD-based adaptation
def __init__(
self,
adapter_name,
in_features,
out_features,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
**kwargs,
):
bnb.nn.Linear4bit.__init__(
self,
in_features,
out_features,
bias=kwargs.get("bias", True),
compute_dtype=kwargs.get("compute_dtype", torch.float32),
compress_statistics=kwargs.get("compress_statistics", True),
quant_type=kwargs.get("quant_type", "nf4"),
)
AdaLoraLayer.__init__(self, in_features=in_features, out_features=out_features)
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
def forward(self, x: torch.Tensor):
result = super().forward(x)
if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
return result
elif self.r[self.active_adapter] > 0:
if not torch.is_autocast_enabled():
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
output = (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
).to(expected_dtype)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
else:
output = (
(
self.lora_dropout[self.active_adapter](x)
@ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]).T
@ self.lora_B[self.active_adapter].T
)
* self.scaling[self.active_adapter]
/ (self.ranknum[self.active_adapter] + 1e-5)
)
result = result + output
return result
class RankAllocator(object):
"""
The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY
Args:
config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
model: the model that we apply AdaLoRA to.
"""
def __init__(self, model, peft_config, adapter_name):
self.peft_config = peft_config
self.adapter_name = adapter_name
self.beta1 = peft_config.beta1
self.beta2 = peft_config.beta2
assert self.beta1 > 0 and self.beta1 < 1
assert self.beta2 > 0 and self.beta2 < 1
self.reset_ipt()
self._set_budget_scheduler(model)
def set_total_step(self, total_step):
self.peft_config.total_step = total_step
def reset_ipt(self):
self.ipt = {}
self.exp_avg_ipt = {}
self.exp_avg_unc = {}
def _set_budget_scheduler(self, model):
self.init_bgt = 0
self.name_set = set()
for n, p in model.named_parameters():
if f"lora_A.{self.adapter_name}" in n:
self.init_bgt += p.size(0)
self.name_set.add(n.replace("lora_A", "%s"))
self.name_set = sorted(self.name_set)
# The total final rank budget
self.target_bgt = self.peft_config.target_r * len(self.name_set)
def budget_schedule(self, step: int):
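        """
        Cubic budget schedule from the AdaLoRA paper: keep the full initial budget for the first
        `tinit` steps, fix the budget at the target for the last `tfinal` steps, and in between
        decay it as

            budget(t) = target + (init - target) * (1 - (t - tinit) / (total_step - tfinal - tinit)) ** 3
        """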
tinit = self.peft_config.tinit
tfinal = self.peft_config.tfinal
total_step = self.peft_config.total_step
# Initial warmup
if step <= tinit:
budget = self.init_bgt
mask_ind = False
# Final fine-tuning
elif step > total_step - tfinal:
budget = self.target_bgt
mask_ind = True
else:
# Budget decreasing with a cubic scheduler
mul_coeff = 1 - (step - tinit) / (total_step - tfinal - tinit)
budget = int((self.init_bgt - self.target_bgt) * (mul_coeff**3) + self.target_bgt)
            mask_ind = step % self.peft_config.deltaT == 0
return budget, mask_ind
def update_ipt(self, model):
# Update the sensitivity and uncertainty for every weight
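        # Per-entry sensitivity is approximated by |w * grad(w)|; exp_avg_ipt smooths it with an
        # EMA (beta1), and exp_avg_unc tracks the local uncertainty as an EMA (beta2) of the
        # absolute deviation |ipt - exp_avg_ipt|.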
for n, p in model.named_parameters():
if "lora_" in n and self.adapter_name in n:
if n not in self.ipt:
self.ipt[n] = torch.zeros_like(p)
self.exp_avg_ipt[n] = torch.zeros_like(p)
self.exp_avg_unc[n] = torch.zeros_like(p)
with torch.no_grad():
self.ipt[n] = (p * p.grad).abs().detach()
# Sensitivity smoothing
self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
# Uncertainty quantification
self.exp_avg_unc[n] = (
self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
)
def _element_score(self, n):
return self.exp_avg_ipt[n] * self.exp_avg_unc[n]
def _combine_ipt(self, ipt_E, ipt_AB):
ipt_AB = ipt_AB.sum(dim=1, keepdim=False)
sum_ipt = ipt_E.view(-1) + ipt_AB.view(-1)
return sum_ipt
def mask_to_budget(self, model, budget):
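        """
        Score each rank-1 triplet by combining element scores s(.) = exp_avg_ipt * exp_avg_unc:

            S_i = s(lora_E_i) + mean_j s(lora_A_ij) + mean_j s(lora_B_ji)

        then zero out the lora_E entries of the lowest-scoring triplets so that `budget` triplets
        remain active, and return the resulting rank pattern.
        """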
value_ipt = {}
vector_ipt = {}
triplet_ipt = {}
# Get the importance score for A, E, B
for n, p in model.named_parameters():
if f"lora_A.{self.adapter_name}" in n:
entry_ipt = self._element_score(n)
comb_ipt = torch.mean(entry_ipt, dim=1, keepdim=True)
name_m = n.replace("lora_A", "%s")
if name_m not in vector_ipt:
vector_ipt[name_m] = [comb_ipt]
else:
vector_ipt[name_m].append(comb_ipt)
if f"lora_B.{self.adapter_name}" in n:
entry_ipt = self._element_score(n)
comb_ipt = torch.mean(entry_ipt, dim=0, keepdim=False).view(-1, 1)
name_m = n.replace("lora_B", "%s")
if name_m not in vector_ipt:
vector_ipt[name_m] = [comb_ipt]
else:
vector_ipt[name_m].append(comb_ipt)
if f"lora_E.{self.adapter_name}" in n:
entry_ipt = self._element_score(n)
name_m = n.replace("lora_E", "%s")
value_ipt[name_m] = entry_ipt
all_score = []
# Calculate the score for each triplet
for name_m in vector_ipt:
ipt_E = value_ipt[name_m]
ipt_AB = torch.cat(vector_ipt[name_m], dim=1)
sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
name_E = name_m % "lora_E"
triplet_ipt[name_E] = sum_ipt.view(-1, 1)
all_score.append(sum_ipt.view(-1))
# Get the threshold by ranking ipt
mask_threshold = torch.kthvalue(
torch.cat(all_score),
k=self.init_bgt - budget,
)[0].item()
rank_pattern = {}
# Mask the unimportant triplets
with torch.no_grad():
for n, p in model.named_parameters():
if f"lora_E.{self.adapter_name}" in n:
p.masked_fill_(triplet_ipt[n] <= mask_threshold, 0.0)
rank_pattern[n] = (~(triplet_ipt[n] <= mask_threshold)).view(-1).tolist()
return rank_pattern
def update_and_allocate(self, model, global_step, force_mask=False):
        # Update the importance score and allocate the budget
if global_step < self.peft_config.total_step - self.peft_config.tfinal:
self.update_ipt(model)
budget, mask_ind = self.budget_schedule(global_step)
# Allocate the budget according to importance scores
if mask_ind or force_mask:
rank_pattern = self.mask_to_budget(model, budget)
else:
rank_pattern = None
return budget, rank_pattern
def mask_using_rank_pattern(self, model, rank_pattern):
# Mask the unimportant triplets
is_adapter_name_truncated = False
if self.adapter_name not in next(iter(rank_pattern.keys())):
is_adapter_name_truncated = True
with torch.no_grad():
for n, p in model.named_parameters():
if f"lora_E.{self.adapter_name}" in n:
key = n if not is_adapter_name_truncated else n.replace(f".{self.adapter_name}", "")
mask = torch.Tensor(rank_pattern[key]).unsqueeze(-1).to(p.device)
p.masked_fill_(~mask.bool(), 0.0)