|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import math |
|
import warnings |
|
from typing import Any, Optional, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from accelerate.utils.imports import is_xpu_available |
|
from torch import svd_lowrank |
|
from transformers.pytorch_utils import Conv1D |
|
|
|
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge |
|
from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type |
|
from peft.utils.other import transpose |
|
|
|
from .config import LoraConfig |
|
from .dora import DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer, _DoraConvNdLayer |
|
|
|
|
|
class LoraLayer(BaseTunerLayer): |
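    # All names of layers that may contain (trainable) adapter weights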
|
|
|
    adapter_layer_names = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B", "lora_route")
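    # All names of other parameters that may contain adapter-related parameters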
|
|
|
other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout") |
|
|
|
def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None: |
|
self.base_layer = base_layer |
|
self.r = {} |
|
self.lora_alpha = {} |
|
self.scaling = {} |
|
self.lora_dropout = nn.ModuleDict({}) |
|
self.lora_A = nn.ModuleDict({}) |
|
self.lora_B = nn.ModuleDict({}) |
|
|
|
self.lora_embedding_A = nn.ParameterDict({}) |
|
self.lora_embedding_B = nn.ParameterDict({}) |
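        # routers for the mixture-of-experts (MoE) LoRA variant, one nn.Linear per routed adapter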
|
|
|
self.lora_route = nn.ModuleDict({}) |
|
|
|
self._disable_adapters = False |
|
self.merged_adapters = [] |
|
self.use_dora: dict[str, bool] = {} |
|
self.lora_bias: dict[str, bool] = {} |
|
self.lora_magnitude_vector = torch.nn.ModuleDict() |
|
self._caches: dict[str, Any] = {} |
|
self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload |
|
self.kwargs = kwargs |
|
|
|
base_layer = self.get_base_layer() |
|
if isinstance(base_layer, nn.Linear): |
|
in_features, out_features = base_layer.in_features, base_layer.out_features |
|
elif isinstance(base_layer, nn.Conv2d): |
|
in_features, out_features = base_layer.in_channels, base_layer.out_channels |
|
elif isinstance(base_layer, nn.Conv3d): |
|
in_features, out_features = base_layer.in_channels, base_layer.out_channels |
|
elif isinstance(base_layer, nn.Embedding): |
|
in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim |
|
elif isinstance(base_layer, Conv1D): |
|
in_features, out_features = ( |
|
base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape |
|
) |
|
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): |
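            # QuantLinear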
|
|
|
in_features, out_features = base_layer.infeatures, base_layer.outfeatures |
|
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): |
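            # Megatron ColumnParallelLinear, RowParallelLinear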
|
|
|
in_features, out_features = base_layer.input_size, base_layer.output_size |
|
elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear": |
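            # AQLM QuantLinear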
|
|
|
in_features, out_features = base_layer.in_features, base_layer.out_features |
|
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": |
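            # AWQ layers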
|
|
|
in_features, out_features = base_layer.in_features, base_layer.out_features |
|
elif base_layer.__class__.__name__ == "EetqLinear": |
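            # Eetq layers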
|
|
|
in_features, out_features = base_layer.in_features, base_layer.out_features |
|
elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear": |
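            # HQQ layers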
|
|
|
in_features, out_features = base_layer.in_features, base_layer.out_features |
|
else: |
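            # possibly support user-provided custom layer types using dynamic dispatch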
|
|
|
if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"): |
|
in_features, out_features = base_layer.in_features, base_layer.out_features |
|
else: |
|
in_features, out_features = None, None |
|
warnings.warn( |
|
f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning |
|
) |
|
|
|
self.in_features = in_features |
|
self.out_features = out_features |
|
|
|
def update_layer( |
|
self, |
|
adapter_name, |
|
r, |
|
lora_alpha, |
|
lora_dropout, |
|
init_lora_weights, |
|
use_rslora, |
|
use_dora: bool = False, |
|
lora_bias: bool = False, |
|
): |
|
|
|
if r <= 0: |
|
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") |
|
|
|
self.r[adapter_name] = r |
|
self.lora_alpha[adapter_name] = lora_alpha |
|
if lora_dropout > 0.0: |
|
lora_dropout_layer = nn.Dropout(p=lora_dropout) |
|
else: |
|
lora_dropout_layer = nn.Identity() |
|
|
|
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) |
|
|
|
self.lora_A[adapter_name] = nn.Linear(self.in_features, r, bias=False) |
|
self.lora_B[adapter_name] = nn.Linear(r, self.out_features, bias=lora_bias) |
|
self.lora_bias[adapter_name] = lora_bias |
|
|
|
if use_rslora: |
|
self.scaling[adapter_name] = lora_alpha / math.sqrt(r) |
|
else: |
|
self.scaling[adapter_name] = lora_alpha / r |
|
|
|
|
|
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"): |
|
with gather_params_ctx(self.get_base_layer().weight): |
|
self.pissa_init(adapter_name, init_lora_weights) |
|
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora": |
|
with gather_params_ctx(self.get_base_layer().weight): |
|
self.olora_init(adapter_name) |
|
elif init_lora_weights == "loftq": |
|
with gather_params_ctx(self.get_base_layer().weight): |
|
self.loftq_init(adapter_name) |
|
elif init_lora_weights == "eva": |
|
nn.init.zeros_(self.lora_B[adapter_name].weight) |
|
elif init_lora_weights: |
|
self.reset_lora_parameters(adapter_name, init_lora_weights) |
|
|
|
self._move_adapter_to_device_of_base_layer(adapter_name) |
|
|
|
if use_dora: |
|
self.dora_init(adapter_name) |
|
self.use_dora[adapter_name] = True |
|
else: |
|
self.use_dora[adapter_name] = False |
|
|
|
self.set_adapter(self.active_adapters) |
|
|
|
def reset_lora_parameters(self, adapter_name, init_lora_weights): |
|
if init_lora_weights is False: |
|
return |
|
|
|
if adapter_name in self.lora_A.keys(): |
|
if init_lora_weights is True: |
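                # initialize A the same way as the default for nn.Linear and B to zeros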
|
|
|
|
|
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) |
|
elif init_lora_weights.lower() == "gaussian": |
|
nn.init.normal_(self.lora_A[adapter_name].weight, std=1 / self.r[adapter_name]) |
|
else: |
|
raise ValueError(f"Unknown initialization {init_lora_weights=}") |
|
nn.init.zeros_(self.lora_B[adapter_name].weight) |
|
if self.lora_bias[adapter_name]: |
|
nn.init.zeros_(self.lora_B[adapter_name].bias) |
|
if adapter_name in self.lora_embedding_A.keys(): |
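            # initialize A to zeros and B the same way as the default for nn.Embedding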
|
|
|
|
|
nn.init.zeros_(self.lora_embedding_A[adapter_name]) |
|
nn.init.normal_(self.lora_embedding_B[adapter_name]) |
|
if self.lora_bias[adapter_name]: |
|
|
|
nn.init.zeros_(self.lora_embedding_B[adapter_name].bias) |
|
|
|
def olora_init(self, adapter_name): |
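        # OLoRA: factor the base weight with a QR decomposition, use the top-r rows of R as lora_A and the first r
        # columns of Q as lora_B, then subtract the scaled product from the base weight so the initial output is unchanged.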
|
base_layer = self.get_base_layer() |
|
orig_weight = base_layer.weight |
|
bnb_param_type = get_bnb_param_type(orig_weight) |
|
dtype = orig_weight.dtype |
|
|
|
if bnb_param_type: |
|
|
|
weight_tensor = dequantize_module_weight(base_layer) |
|
elif dtype in [torch.float32, torch.float16, torch.bfloat16]: |
|
weight_tensor = orig_weight |
|
else: |
|
raise TypeError(f"Unsupported data type for the base layer. Got {dtype}.") |
|
|
|
scale_factor = self.scaling[adapter_name] |
|
r = self.r[adapter_name] |
|
weight_tensor = weight_tensor.to(torch.float32) |
|
Q, R = torch.linalg.qr(weight_tensor.data) |
|
|
|
Qr, Rr = Q[:, :r], R[:r] |
|
|
|
self.lora_A[adapter_name].weight.data = Rr.contiguous() |
|
self.lora_B[adapter_name].weight.data = Qr.contiguous() |
|
|
|
weight_tensor.data -= scale_factor * self.lora_B[adapter_name].weight @ self.lora_A[adapter_name].weight |
|
if bnb_param_type == "4bit": |
|
weight_tensor = orig_weight.__class__( |
|
weight_tensor, |
|
quant_type=orig_weight.quant_type, |
|
quant_storage=orig_weight.quant_storage, |
|
compress_statistics=orig_weight.compress_statistics, |
|
module=orig_weight.module, |
|
).to(orig_weight.device) |
|
base_layer.weight = weight_tensor |
|
elif bnb_param_type == "8bit": |
|
weight_tensor = orig_weight.__class__( |
|
weight_tensor, |
|
requires_grad=orig_weight.requires_grad, |
|
has_fp16_weights=orig_weight.has_fp16_weights, |
|
).to(orig_weight.device) |
|
base_layer.weight = weight_tensor |
|
else: |
|
weight_tensor = weight_tensor.to(dtype) |
|
base_layer.weight.data = weight_tensor |
|
|
|
def pissa_init(self, adapter_name, init_lora_weights): |
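        # PiSSA: take the top-r singular triplets of the base weight (full SVD for "pissa", randomized SVD for
        # "pissa_niter_<n>"), split them into lora_A/lora_B, and store the residual back into the base weight.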
|
weight = self.get_base_layer().weight |
|
dtype = weight.dtype |
|
if dtype not in [torch.float32, torch.float16, torch.bfloat16]: |
|
raise TypeError( |
|
"Please initialize PiSSA under float32, float16, or bfloat16. " |
|
"Subsequently, re-quantize the residual model to help minimize quantization errors." |
|
) |
|
weight = transpose(weight.to(torch.float32), self.fan_in_fan_out) |
|
if init_lora_weights == "pissa": |
|
|
|
V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False) |
|
Vr = V[:, : self.r[adapter_name]] |
|
Sr = S[: self.r[adapter_name]] |
|
Sr /= self.scaling[adapter_name] |
|
Uhr = Uh[: self.r[adapter_name]] |
|
elif len(init_lora_weights.split("_niter_")) == 2: |
|
Vr, Sr, Ur = svd_lowrank( |
|
weight.data, self.r[adapter_name], niter=int(init_lora_weights.split("_niter_")[-1]) |
|
) |
|
Sr /= self.scaling[adapter_name] |
|
Uhr = Ur.t() |
|
else: |
|
raise ValueError( |
|
f"init_lora_weights should be 'pissa' or 'pissa_niter_[number of iters]', got {init_lora_weights} instead." |
|
) |
|
|
|
lora_A = torch.diag(torch.sqrt(Sr)) @ Uhr |
|
lora_B = Vr @ torch.diag(torch.sqrt(Sr)) |
|
self.lora_A[adapter_name].weight.data = lora_A |
|
self.lora_B[adapter_name].weight.data = lora_B |
|
weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A |
|
weight = transpose(weight.to(dtype), self.fan_in_fan_out) |
|
self.get_base_layer().weight.data = weight |
|
|
|
def loftq_init(self, adapter_name): |
|
from peft.utils.loftq_utils import loftq_init |
|
|
|
weight = self.get_base_layer().weight |
|
kwargs = { |
|
"num_bits": self.kwargs.get("loftq_bits", 4), |
|
"reduced_rank": self.r[adapter_name], |
|
"num_iter": self.kwargs.get("loftq_iter", 1), |
|
} |
|
|
|
qweight, lora_A, lora_B = loftq_init(weight, **kwargs) |
|
if adapter_name in self.lora_A.keys(): |
|
|
|
self.lora_A[adapter_name].weight.data = lora_A |
|
self.lora_B[adapter_name].weight.data = lora_B |
|
if adapter_name in self.lora_embedding_A.keys(): |
|
|
|
self.lora_embedding_A[adapter_name].weight.data = lora_A |
|
self.lora_embedding_B[adapter_name].weight.data = lora_B |
|
self.get_base_layer().weight.data = qweight |
|
|
|
def dora_init(self, adapter_name: str) -> None: |
|
if not self.lora_magnitude_vector: |
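            # first DoRA layer being added; register lora_magnitude_vector as an adapter layer name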
|
|
|
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",) |
|
|
|
dora_layer = DoraLinearLayer(fan_in_fan_out=getattr(self, "fan_in_fan_out", False)) |
|
lora_A = self.lora_A[adapter_name].weight |
|
lora_B = self.lora_B[adapter_name].weight |
|
place_on_cpu = self.ephemeral_gpu_offload and (lora_A.device.type == "cpu" or lora_B.device.type == "cpu") |
|
if self.ephemeral_gpu_offload: |
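            # with ephemeral GPU offload, move lora_A and lora_B onto the same accelerator (CUDA or XPU) for the init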
|
if lora_A.device.type in ["cuda", "xpu"]: |
|
lora_B = lora_B.to(lora_A.device) |
|
else: |
|
if lora_B.device.type not in ["cuda", "xpu"]: |
|
if is_xpu_available(): |
|
lora_B = lora_B.to("xpu") |
|
else: |
|
lora_B = lora_B.to("cuda") |
|
lora_A = lora_A.to(lora_B.device) |
|
scaling = self.scaling[adapter_name] |
|
dora_layer.update_layer( |
|
base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling, place_on_cpu=place_on_cpu |
|
) |
|
self.lora_magnitude_vector[adapter_name] = dora_layer |
|
|
|
def _cache_store(self, key: str, value: Any) -> None: |
|
self._caches[key] = value |
|
|
|
def _cache_pop(self, key: str) -> Any: |
|
value = self._caches.pop(key) |
|
return value |
|
|
|
def set_scale(self, adapter, scale): |
|
if adapter not in self.scaling: |
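            # ignore the case where the adapter is not in the layer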
|
|
|
return |
|
self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter] |
|
|
|
def scale_layer(self, scale: float) -> None: |
|
if scale == 1: |
|
return |
|
|
|
for active_adapter in self.active_adapters: |
|
if active_adapter not in self.lora_A.keys(): |
|
continue |
|
|
|
self.scaling[active_adapter] *= scale |
|
|
|
def unscale_layer(self, scale=None) -> None: |
|
for active_adapter in self.active_adapters: |
|
if active_adapter not in self.lora_A.keys(): |
|
continue |
|
|
|
if scale is None: |
|
self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter] |
|
else: |
|
self.scaling[active_adapter] /= scale |
|
|
|
def _check_forward_args(self, x, *args, **kwargs): |
|
"""Check if the arguments are compatible with the configs and state of the model""" |
|
adapter_names = kwargs.get("adapter_names", None) |
|
if adapter_names is None: |
|
return |
|
|
|
if len(x) != len(adapter_names): |
|
msg = ( |
|
"Length of `adapter_names` should be the same as the number of inputs, but got " |
|
f"{len(adapter_names)} and {len(x)} respectively." |
|
) |
|
raise ValueError(msg) |
|
|
|
if self.merged: |
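            # It is unclear what would be the right thing to do if users pass adapter_names while adapters are
            # merged, so raising an error is the safest option.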
|
|
|
|
|
msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first." |
|
raise ValueError(msg) |
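        # DoRA is not supported with mixed adapter batches; "__base__" is the placeholder for the base model.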
|
|
|
|
|
|
|
unique_adapters = {name for name in adapter_names if name != "__base__"} |
|
for adapter_name in unique_adapters: |
|
if self.use_dora.get(adapter_name, False): |
|
msg = "Cannot pass `adapter_names` when DoRA is enabled." |
|
raise ValueError(msg) |
|
|
|
def _mixed_batch_forward( |
|
self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any |
|
) -> torch.Tensor: |
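        # This method handles the case when users pass `adapter_names`, which allows mixing different adapters
        # in the same batch at inference time.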
|
|
|
|
|
result = self.base_layer(x, *args, **kwargs) |
|
torch_result_dtype = result.dtype |
|
|
|
unique_adapters = set(adapter_names) |
|
sub_batch_indices_list = [] |
|
for adapter in unique_adapters: |
|
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) |
|
|
|
for i, active_adapter in enumerate(unique_adapters): |
|
if active_adapter == "__base__": |
|
continue |
|
if active_adapter not in self.lora_A.keys(): |
|
continue |
|
|
|
lora_A = self.lora_A[active_adapter] |
|
lora_B = self.lora_B[active_adapter] |
|
dropout = self.lora_dropout[active_adapter] |
|
scaling = self.scaling[active_adapter] |
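            # get the sub-batch for this adapter, pass it through the LoRA layers and add the result to the
            # corresponding rows of the base layer output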
|
|
|
|
|
|
|
sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype) |
|
lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling |
|
result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype) |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Linear(nn.Module, LoraLayer): |
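    # LoRA implemented in a dense layer, optionally extended with a mixture-of-experts routing path (see `moe_lora`)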
|
|
|
def __init__( |
|
self, |
|
base_layer, |
|
adapter_name: str, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
fan_in_fan_out: bool = False, |
|
is_target_conv_1d_layer: bool = False, |
|
init_lora_weights: Union[bool, str] = True, |
|
use_rslora: bool = False, |
|
use_dora: bool = False, |
|
lora_bias: bool = False, |
|
**kwargs, |
|
) -> None: |
|
super().__init__() |
|
LoraLayer.__init__(self, base_layer, **kwargs) |
|
self.fan_in_fan_out = fan_in_fan_out |
|
|
|
self._active_adapter = adapter_name |
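        # mixture-of-experts (MoE) routing hyperparameters, read from kwargs with defaults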
|
self.num_experts = kwargs.get("num_experts", 1) |
|
self.expert_rank = kwargs.get("expert_rank", 4) |
|
self.expert_alpha = kwargs.get("expert_alpha", 4) |
|
self.top_k = kwargs.get("top_k", 4) |
|
self.blc_alpha = kwargs.get("blc_alpha", 0.0) |
|
self.blc_weight = kwargs.get("blc_weight", 0.0) |
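        # Only modules whose key matches a feed-forward ("ff.net") or output projection ("proj_out") pattern get
        # the MoE-LoRA routing path; all other targets fall back to a standard LoRA update.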
|
|
|
if "ff.net" in kwargs["current_key"] or "proj_out" in kwargs["current_key"]: |
|
self.moe_lora = True |
|
self.update_moe_layer( |
|
adapter_name, |
|
r, |
|
lora_alpha=lora_alpha, |
|
lora_dropout=lora_dropout, |
|
init_lora_weights=init_lora_weights, |
|
use_rslora=use_rslora, |
|
use_dora=use_dora, |
|
lora_bias=lora_bias, |
|
num_experts=self.num_experts, |
|
expert_rank=self.expert_rank, |
|
expert_alpha=self.expert_alpha, |
|
) |
|
else: |
|
self.moe_lora = False |
|
|
|
self.update_layer( |
|
adapter_name, |
|
r, |
|
lora_alpha=lora_alpha, |
|
lora_dropout=lora_dropout, |
|
init_lora_weights=init_lora_weights, |
|
use_rslora=use_rslora, |
|
use_dora=use_dora, |
|
lora_bias=lora_bias, |
|
) |
|
self.is_target_conv_1d_layer = is_target_conv_1d_layer |
|
|
|
|
|
|
|
def update_moe_layer( |
|
self, |
|
adapter_name, |
|
r, |
|
lora_alpha, |
|
lora_dropout, |
|
init_lora_weights, |
|
use_rslora, |
|
use_dora: bool = False, |
|
lora_bias: bool = False, |
|
num_experts: int = 1, |
|
expert_rank: int = 4, |
|
expert_alpha: float = 4, |
|
): |
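        """
        Create `num_experts` LoRA experts (rank `expert_rank`, scaling `expert_alpha`) plus a routing layer for
        `adapter_name`. Note that the experts use `expert_rank`/`expert_alpha` rather than `r`/`lora_alpha`.
        """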
|
expert_list = [] |
|
for i in range(num_experts): |
|
expert_list.append(f"expert_{i}") |
|
|
|
        if r <= 0 or num_experts <= 0 or expert_rank <= 0:
            raise ValueError(
                "`r`, `num_experts` and `expert_rank` should be positive integers, but got "
                f"r={r}, num_experts={num_experts} and expert_rank={expert_rank}"
            )
        if self.top_k > num_experts:
            raise ValueError(f"`top_k` ({self.top_k}) should not be larger than `num_experts` ({num_experts})")
|
|
|
self.r[adapter_name] = expert_rank |
|
self.lora_alpha[adapter_name] = expert_alpha |
|
if lora_dropout > 0.0: |
|
lora_dropout_layer = nn.Dropout(p=lora_dropout) |
|
else: |
|
lora_dropout_layer = nn.Identity() |
|
|
|
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) |
|
|
|
|
|
for i in range(num_experts): |
|
expert_name = expert_list[i] |
|
self.lora_A[expert_name] = nn.Linear(self.in_features, expert_rank, bias=False) |
|
self.lora_B[expert_name] = nn.Linear(expert_rank, self.out_features, bias=lora_bias) |
|
self.r[expert_name] = expert_rank |
|
self.lora_alpha[expert_name] = expert_alpha |
|
self.lora_bias[expert_name] = lora_bias |
|
self.lora_dropout.update(nn.ModuleDict({expert_name: lora_dropout_layer})) |
|
self.scaling[expert_name] = expert_alpha / expert_rank |
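        # router mapping hidden states to one logit per expert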
|
|
|
self.lora_route[adapter_name] = nn.Linear(self.in_features, num_experts, bias=False) |
|
self.lora_bias[adapter_name] = lora_bias |
|
|
|
if use_rslora: |
|
self.scaling[adapter_name] = expert_alpha / math.sqrt(expert_rank) |
|
else: |
|
self.scaling[adapter_name] = expert_alpha / expert_rank |
|
|
|
|
|
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"): |
|
with gather_params_ctx(self.get_base_layer().weight): |
|
self.pissa_init(adapter_name, init_lora_weights) |
|
elif isinstance(init_lora_weights, str) and init_lora_weights.lower() == "olora": |
|
with gather_params_ctx(self.get_base_layer().weight): |
|
self.olora_init(adapter_name) |
|
elif init_lora_weights == "loftq": |
|
with gather_params_ctx(self.get_base_layer().weight): |
|
self.loftq_init(adapter_name) |
|
elif init_lora_weights == "eva": |
|
nn.init.zeros_(self.lora_B[adapter_name].weight) |
|
elif init_lora_weights: |
|
self.reset_lora_parameters(adapter_name, init_lora_weights) |
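            # also initialize every expert's A/B matrices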
|
for i in range(num_experts): |
|
expert_name = f"expert_{i}" |
|
self.reset_lora_parameters(expert_name, init_lora_weights) |
|
|
|
self._move_adapter_to_device_of_base_layer(adapter_name) |
|
for i in range(num_experts): |
|
expert_name = expert_list[i] |
|
self._move_adapter_to_device_of_base_layer(expert_name) |
|
|
|
if use_dora: |
|
self.dora_init(adapter_name) |
|
self.use_dora[adapter_name] = True |
|
else: |
|
self.use_dora[adapter_name] = False |
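        # activate the routing adapter together with all of its experts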
|
|
|
|
|
        self.set_adapter(self.active_adapters + expert_list)
|
|
|
|
|
|
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: |
|
""" |
|
Merge the active adapter weights into the base weights |
|
|
|
Args: |
|
safe_merge (`bool`, *optional*): |
|
If True, the merge operation will be performed in a copy of the original weights and check for NaNs |
|
before merging the weights. This is useful if you want to check if the merge operation will produce |
|
NaNs. Defaults to `False`. |
|
adapter_names (`list[str]`, *optional*): |
|
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults |
|
to `None`. |
|
""" |
|
adapter_names = check_adapters_to_merge(self, adapter_names) |
|
if not adapter_names: |
|
|
|
return |
|
|
|
for active_adapter in adapter_names: |
|
if active_adapter in self.lora_A.keys(): |
|
base_layer = self.get_base_layer() |
|
if safe_merge: |
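                    # Note that safe_merge will be slower than the normal merge
                    # because of the copy operation.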
|
|
|
|
|
orig_weights = base_layer.weight.data.clone() |
|
delta_weight = self.get_delta_weight(active_adapter) |
|
if not self.use_dora[active_adapter]: |
|
orig_weights += delta_weight |
|
else: |
|
|
|
|
|
weight_norm = ( |
|
self.lora_magnitude_vector[active_adapter] |
|
.get_weight_norm(orig_weights, transpose(delta_weight, self.fan_in_fan_out), scaling=1) |
|
.detach() |
|
) |
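                        # We need to cache weight_norm because it has to be based on the original weights; it
                        # cannot be recomputed from the merged weights when unmerging.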
|
|
|
|
|
|
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm) |
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm |
|
dora_factor = transpose(dora_factor.view(-1, 1), self.fan_in_fan_out) |
|
orig_weights = dora_factor * (orig_weights + delta_weight) |
|
|
|
if not torch.isfinite(orig_weights).all(): |
|
raise ValueError( |
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" |
|
) |
|
|
|
base_layer.weight.data = orig_weights |
|
|
|
if self.lora_bias[active_adapter]: |
|
new_bias = base_layer.bias + self.lora_B[active_adapter].bias |
|
if not torch.isfinite(new_bias).all(): |
|
raise ValueError( |
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" |
|
) |
|
base_layer.bias.data = new_bias |
|
|
|
else: |
|
delta_weight = self.get_delta_weight(active_adapter) |
|
if not self.use_dora[active_adapter]: |
|
base_layer.weight.data += delta_weight |
|
else: |
|
|
|
|
|
weight_norm = ( |
|
self.lora_magnitude_vector[active_adapter] |
|
.get_weight_norm( |
|
base_layer.weight, transpose(delta_weight, self.fan_in_fan_out), scaling=1 |
|
) |
|
.detach() |
|
) |
|
|
|
|
|
|
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm) |
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm |
|
new_weight = dora_factor.view(-1, 1) * (base_layer.weight.data + delta_weight) |
|
base_layer.weight.data = new_weight |
|
|
|
if self.lora_bias[active_adapter]: |
|
base_layer.bias.data += self.lora_B[active_adapter].bias |
|
|
|
self.merged_adapters.append(active_adapter) |
|
|
|
def unmerge(self) -> None: |
|
""" |
|
This method unmerges all merged adapter layers from the base weights. |
|
""" |
|
if not self.merged: |
|
warnings.warn("Already unmerged. Nothing to do.") |
|
return |
|
while len(self.merged_adapters) > 0: |
|
active_adapter = self.merged_adapters.pop() |
|
if active_adapter in self.lora_A.keys(): |
|
weight = self.get_base_layer().weight |
|
delta_weight = self.get_delta_weight(active_adapter) |
|
if not self.use_dora[active_adapter]: |
|
weight.data -= delta_weight |
|
else: |
|
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") |
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm |
|
weight_orig = weight.data / dora_factor.view(-1, 1) - delta_weight |
|
weight.data = weight_orig |
|
|
|
if self.lora_bias[active_adapter]: |
|
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias |
|
|
|
def get_delta_weight(self, adapter) -> torch.Tensor: |
|
""" |
|
Compute the delta weight for the given adapter. |
|
|
|
Args: |
|
adapter (str): |
|
The name of the adapter for which the delta weight should be computed. |
|
""" |
|
device = self.lora_B[adapter].weight.device |
|
dtype = self.lora_A[adapter].weight.dtype |
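        # If the adapter weights are (b)float16 on CPU, cast to float32 for the matmul (half precision matmuls are
        # slow or unsupported on CPU) and cast back afterwards.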
|
|
|
|
|
|
|
|
|
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) |
|
|
|
weight_A = self.lora_A[adapter].weight |
|
weight_B = self.lora_B[adapter].weight |
|
|
|
if cast_to_fp32: |
|
weight_A = weight_A.float() |
|
weight_B = weight_B.float() |
|
|
|
output_tensor = transpose(weight_B @ weight_A, self.fan_in_fan_out) * self.scaling[adapter] |
|
|
|
if cast_to_fp32: |
|
output_tensor = output_tensor.to(dtype=dtype) |
|
|
|
|
|
self.lora_A[adapter].weight.data = weight_A.to(dtype) |
|
self.lora_B[adapter].weight.data = weight_B.to(dtype) |
|
|
|
return output_tensor |
|
|
|
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: |
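        # layers flagged as MoE in __init__ delegate to the routed forward pass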
|
|
|
if self.moe_lora: |
|
return self.moe_forward(x, *args, **kwargs) |
|
|
|
self._check_forward_args(x, *args, **kwargs) |
|
adapter_names = kwargs.pop("adapter_names", None) |
|
|
|
if self.disable_adapters: |
|
if self.merged: |
|
self.unmerge() |
|
result = self.base_layer(x, *args, **kwargs) |
|
elif adapter_names is not None: |
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) |
|
elif self.merged: |
|
result = self.base_layer(x, *args, **kwargs) |
|
else: |
|
result = self.base_layer(x, *args, **kwargs) |
|
torch_result_dtype = result.dtype |
|
for active_adapter in self.active_adapters: |
|
if active_adapter not in self.lora_A.keys(): |
|
continue |
|
lora_A = self.lora_A[active_adapter] |
|
lora_B = self.lora_B[active_adapter] |
|
dropout = self.lora_dropout[active_adapter] |
|
scaling = self.scaling[active_adapter] |
|
x = x.to(lora_A.weight.dtype) |
|
|
|
if not self.use_dora[active_adapter]: |
|
result = result + lora_B(lora_A(dropout(x))) * scaling |
|
else: |
|
if isinstance(dropout, nn.Identity) or not self.training: |
|
base_result = result |
|
else: |
|
x = dropout(x) |
|
base_result = None |
|
|
|
result = result + self.lora_magnitude_vector[active_adapter]( |
|
x, |
|
lora_A=lora_A, |
|
lora_B=lora_B, |
|
scaling=scaling, |
|
base_layer=self.get_base_layer(), |
|
base_result=base_result, |
|
) |
|
|
|
result = result.to(torch_result_dtype) |
|
|
|
return result |
|
|
|
def moe_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: |
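        """Forward pass that routes the input through the LoRA experts, either token-wise or sequence-wise."""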
|
        moe_type = "token_wise"  # routing granularity; the "sequence_wise" branch below routes once per sequence
|
|
|
self._check_forward_args(x, *args, **kwargs) |
|
adapter_names = kwargs.pop("adapter_names", None) |
|
|
|
if self.disable_adapters: |
|
if self.merged: |
|
self.unmerge() |
|
result = self.base_layer(x, *args, **kwargs) |
|
elif adapter_names is not None: |
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) |
|
elif self.merged: |
|
result = self.base_layer(x, *args, **kwargs) |
|
else: |
|
if moe_type == "token_wise": |
|
result = self.base_layer(x, *args, **kwargs) |
|
torch_result_dtype = result.dtype |
|
                # the MoE path routes through the first active adapter's router only
                active_adapter_name = self.active_adapters[0]
                route_logits = self.lora_route[active_adapter_name](x)
|
|
|
|
|
top_k_probs, top_k_indices = torch.topk(route_logits, k=self.top_k, dim=-1) |
|
|
|
top_k_probs = F.softmax(top_k_probs, dim=-1, dtype=torch.float32).to(result.dtype) |
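                # scatter the top-k probabilities back into a dense (batch, seq, num_experts) routing weight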
|
|
|
|
|
route_weight = torch.zeros_like(route_logits) |
|
                route_weight = route_weight.scatter_(-1, top_k_indices, top_k_probs)
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(self.num_experts): |
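                    # add each expert's LoRA update, weighted by its routing probability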
|
expert_name = f"expert_{i}" |
|
lora_A = self.lora_A[expert_name] |
|
lora_B = self.lora_B[expert_name] |
|
scaling = self.scaling[expert_name] |
|
dropout = self.lora_dropout[expert_name] |
|
                    result += lora_B(lora_A(dropout(x))) * scaling * torch.unsqueeze(route_weight[:, :, i], -1)
|
|
|
result = result.to(torch_result_dtype) |
|
|
|
elif moe_type == "sequence_wise": |
|
result = self.base_layer(x, *args, **kwargs) |
|
torch_result_dtype = result.dtype |
|
                active_adapter_name = self.active_adapters[0]
                # route on the first token of each sequence, then broadcast the weights over the sequence length
                route_logits = self.lora_route[active_adapter_name](x[:, 0])
                route_logits = route_logits.unsqueeze(1).repeat(1, x.shape[1], 1)
|
|
|
|
|
top_k_probs, top_k_indices = torch.topk(route_logits, k=self.top_k, dim=-1) |
|
|
|
top_k_probs = F.softmax(top_k_probs, dim=-1, dtype=torch.float32).to(result.dtype) |
|
|
|
|
|
route_weight = torch.zeros_like(route_logits) |
|
                route_weight = route_weight.scatter_(-1, top_k_indices, top_k_probs)
|
|
|
|
|
|
|
|
|
for i in range(self.num_experts): |
|
expert_name = f"expert_{i}" |
|
lora_A = self.lora_A[expert_name] |
|
lora_B = self.lora_B[expert_name] |
|
scaling = self.scaling[expert_name] |
|
dropout = self.lora_dropout[expert_name] |
|
                    result += lora_B(lora_A(dropout(x))) * scaling * torch.unsqueeze(route_weight[:, :, i], -1)
|
|
|
result = result.to(torch_result_dtype) |
|
|
|
return result |
|
|
|
def __repr__(self) -> str: |
|
rep = super().__repr__() |
|
return "lora." + rep |
|
|
|
|
|
class Embedding(nn.Module, LoraLayer): |
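    # LoRA implemented in an Embedding layer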
|
|
|
def __init__( |
|
self, |
|
base_layer: nn.Module, |
|
adapter_name: str, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
init_lora_weights: Union[bool, str] = True, |
|
use_rslora: bool = False, |
|
use_dora: bool = False, |
|
lora_bias: bool = False, |
|
**kwargs, |
|
) -> None: |
|
if lora_bias: |
|
|
|
raise ValueError(f"lora_bias={lora_bias} is not supported for {self.__class__.__name__}.") |
|
|
|
super().__init__() |
|
LoraLayer.__init__(self, base_layer) |
|
|
|
self._active_adapter = adapter_name |
|
self.update_layer( |
|
adapter_name, |
|
r, |
|
lora_alpha=lora_alpha, |
|
lora_dropout=lora_dropout, |
|
init_lora_weights=init_lora_weights, |
|
use_rslora=use_rslora, |
|
use_dora=use_dora, |
|
lora_bias=lora_bias, |
|
) |
|
|
|
def update_layer( |
|
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias |
|
): |
|
if r <= 0: |
|
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") |
|
|
|
self.r[adapter_name] = r |
|
self.lora_alpha[adapter_name] = lora_alpha |
|
if lora_dropout > 0.0: |
|
lora_dropout_layer = nn.Dropout(p=lora_dropout) |
|
else: |
|
lora_dropout_layer = nn.Identity() |
|
|
|
self.lora_dropout[adapter_name] = lora_dropout_layer |
|
|
|
weight_A = torch.randn((r, self.in_features)) |
|
weight_B = torch.randn((self.out_features, r)) |
|
self.lora_embedding_A[adapter_name] = nn.Parameter(weight_A) |
|
self.lora_embedding_B[adapter_name] = nn.Parameter(weight_B) |
|
self.lora_bias[adapter_name] = lora_bias |
|
|
|
if use_rslora: |
|
self.scaling[adapter_name] = lora_alpha / math.sqrt(r) |
|
else: |
|
self.scaling[adapter_name] = lora_alpha / r |
|
|
|
if init_lora_weights == "loftq": |
|
self.loftq_init(adapter_name) |
|
elif init_lora_weights: |
|
self.reset_lora_parameters(adapter_name, init_lora_weights) |
|
|
|
|
|
self._move_adapter_to_device_of_base_layer(adapter_name) |
|
|
|
if use_dora: |
|
self.dora_init(adapter_name) |
|
self.use_dora[adapter_name] = True |
|
else: |
|
self.use_dora[adapter_name] = False |
|
|
|
self.set_adapter(self.active_adapters) |
|
|
|
def dora_init(self, adapter_name: str) -> None: |
|
        if not self.lora_magnitude_vector:
|
|
|
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",) |
|
|
|
dora_layer = DoraEmbeddingLayer(fan_in_fan_out=True) |
|
lora_embedding_A = self.lora_embedding_A[adapter_name] |
|
lora_embedding_B = self.lora_embedding_B[adapter_name] |
|
scaling = self.scaling[adapter_name] |
|
dora_layer.update_layer( |
|
base_layer=self.get_base_layer(), lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling |
|
) |
|
self.lora_magnitude_vector[adapter_name] = dora_layer |
|
|
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: |
|
""" |
|
Merge the active adapter weights into the base weights |
|
|
|
Args: |
|
safe_merge (`bool`, *optional*): |
|
If True, the merge operation will be performed in a copy of the original weights and check for NaNs |
|
before merging the weights. This is useful if you want to check if the merge operation will produce |
|
NaNs. Defaults to `False`. |
|
adapter_names (`list[str]`, *optional*): |
|
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults |
|
to `None`. |
|
""" |
|
adapter_names = check_adapters_to_merge(self, adapter_names) |
|
if not adapter_names: |
|
|
|
return |
|
|
|
for active_adapter in adapter_names: |
|
if active_adapter in self.lora_embedding_A.keys(): |
|
base_layer = self.get_base_layer() |
|
if safe_merge: |
|
|
|
|
|
orig_weights = base_layer.weight.data.clone() |
|
orig_weights += self.get_delta_weight(active_adapter) |
|
|
|
if not torch.isfinite(orig_weights).all(): |
|
raise ValueError( |
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" |
|
) |
|
|
|
base_layer.weight.data = orig_weights |
|
else: |
|
base_layer.weight.data += self.get_delta_weight(active_adapter) |
|
self.merged_adapters.append(active_adapter) |
|
|
|
def unmerge(self) -> None: |
|
""" |
|
This method unmerges all merged adapter layers from the base weights. |
|
""" |
|
if not self.merged: |
|
warnings.warn("Already unmerged. Nothing to do.") |
|
return |
|
while len(self.merged_adapters) > 0: |
|
active_adapter = self.merged_adapters.pop() |
|
if active_adapter in self.lora_embedding_A.keys(): |
|
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter) |
|
|
|
def get_delta_weight(self, adapter) -> torch.Tensor: |
|
""" |
|
Compute the delta weight for the given adapter. |
|
|
|
Args: |
|
adapter (str): |
|
The name of the adapter for which the delta weight should be computed. |
|
""" |
|
device = self.lora_embedding_B[adapter].device |
|
dtype = self.lora_embedding_A[adapter].dtype |
|
|
|
|
|
|
|
|
|
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) |
|
|
|
weight_A = self.lora_embedding_A[adapter] |
|
weight_B = self.lora_embedding_B[adapter] |
|
|
|
if cast_to_fp32: |
|
weight_A = weight_A.float() |
|
weight_B = weight_B.float() |
|
|
|
output_tensor = transpose(weight_B @ weight_A, True) * self.scaling[adapter] |
|
|
|
if cast_to_fp32: |
|
output_tensor = output_tensor.to(dtype=dtype) |
|
|
|
|
|
self.lora_embedding_A[adapter] = weight_A.to(dtype) |
|
self.lora_embedding_B[adapter] = weight_B.to(dtype) |
|
|
|
return output_tensor |
|
|
|
def _mixed_batch_forward( |
|
self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any |
|
) -> torch.Tensor: |
|
|
|
|
|
result = self.base_layer(x, *args, **kwargs) |
|
|
|
unique_adapters = set(adapter_names) |
|
sub_batch_indices_list = [] |
|
for adapter in unique_adapters: |
|
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) |
|
|
|
for i, active_adapter in enumerate(unique_adapters): |
|
if active_adapter == "__base__": |
|
continue |
|
if active_adapter not in self.lora_embedding_A.keys(): |
|
continue |
|
|
|
embedding_A = self.lora_embedding_A[active_adapter].T |
|
embedding_B = self.lora_embedding_B[active_adapter].T |
|
scaling = self.scaling[active_adapter] |
|
|
|
|
|
|
|
sub_batch = x[sub_batch_indices_list[i]] |
|
after_A = self._embed(sub_batch, embedding_A) |
|
result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling |
|
|
|
return result |
|
|
|
def _embed(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: |
|
base_layer = self.get_base_layer() |
|
return F.embedding( |
|
input, |
|
weight, |
|
padding_idx=base_layer.padding_idx, |
|
max_norm=base_layer.max_norm, |
|
norm_type=base_layer.norm_type, |
|
scale_grad_by_freq=base_layer.scale_grad_by_freq, |
|
sparse=base_layer.sparse, |
|
) |
|
|
|
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: |
|
|
|
self._check_forward_args(x, *args, **kwargs) |
|
adapter_names = kwargs.pop("adapter_names", None) |
|
|
|
if self.disable_adapters: |
|
if self.merged: |
|
self.unmerge() |
|
result = self.base_layer(x, *args, **kwargs) |
|
elif adapter_names is not None: |
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) |
|
elif self.merged: |
|
result = self.base_layer(x, *args, **kwargs) |
|
else: |
|
result = self.base_layer(x, *args, **kwargs) |
|
torch_result_dtype = result.dtype |
|
for active_adapter in self.active_adapters: |
|
if active_adapter not in self.lora_embedding_A: |
|
continue |
|
embedding_A = self.lora_embedding_A[active_adapter].T |
|
embedding_B = self.lora_embedding_B[active_adapter].T |
|
scaling = self.scaling[active_adapter] |
|
|
|
if not self.use_dora[active_adapter]: |
|
after_A = self._embed(x, embedding_A) |
|
result = result + (after_A @ embedding_B) * scaling |
|
else: |
|
mag_norm_scale, dora_result = self.lora_magnitude_vector[active_adapter]( |
|
x, |
|
lora_A=embedding_A, |
|
lora_B=embedding_B, |
|
scaling=scaling, |
|
base_layer=self.get_base_layer(), |
|
embed_fn=self._embed, |
|
) |
|
result = mag_norm_scale * result + dora_result |
|
result = result.to(torch_result_dtype) |
|
|
|
return result |
|
|
|
def __repr__(self) -> str: |
|
rep = super().__repr__() |
|
return "lora." + rep |
|
|
|
|
|
class _ConvNd(nn.Module, LoraLayer): |
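    # LoRA implemented in a Conv*d layer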
|
|
|
def __init__( |
|
self, |
|
base_layer: nn.Module, |
|
adapter_name: str, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
init_lora_weights: Union[bool, str] = True, |
|
use_rslora: bool = False, |
|
use_dora: bool = False, |
|
lora_bias: bool = False, |
|
**kwargs, |
|
) -> None: |
|
super().__init__() |
|
LoraLayer.__init__(self, base_layer) |
|
|
|
self._active_adapter = adapter_name |
|
self._kernel_dim = base_layer.weight.dim() |
|
|
|
self.update_layer( |
|
adapter_name, |
|
r, |
|
lora_alpha=lora_alpha, |
|
lora_dropout=lora_dropout, |
|
init_lora_weights=init_lora_weights, |
|
use_rslora=use_rslora, |
|
use_dora=use_dora, |
|
lora_bias=lora_bias, |
|
) |
|
|
|
def update_layer( |
|
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias |
|
): |
|
if r <= 0: |
|
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") |
|
|
|
self.r[adapter_name] = r |
|
self.lora_alpha[adapter_name] = lora_alpha |
|
if lora_dropout > 0.0: |
|
lora_dropout_layer = nn.Dropout(p=lora_dropout) |
|
else: |
|
lora_dropout_layer = nn.Identity() |
|
|
|
self.lora_dropout[adapter_name] = lora_dropout_layer |
|
|
|
base_layer = self.get_base_layer() |
|
kernel_size = base_layer.kernel_size |
|
stride = base_layer.stride |
|
padding = base_layer.padding |
|
conv_layer = type(base_layer) |
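        # lora_B uses a 1x...x1 kernel and unit stride, so it only mixes channels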
|
out_kernel = out_stride = (1,) * (self._kernel_dim - 2) |
|
self.lora_A[adapter_name] = conv_layer(self.in_features, r, kernel_size, stride, padding, bias=False) |
|
self.lora_B[adapter_name] = conv_layer(r, self.out_features, out_kernel, out_stride, bias=lora_bias) |
|
self.lora_bias[adapter_name] = lora_bias |
|
|
|
if use_rslora: |
|
self.scaling[adapter_name] = lora_alpha / math.sqrt(r) |
|
else: |
|
self.scaling[adapter_name] = lora_alpha / r |
|
|
|
if init_lora_weights == "loftq": |
|
self.loftq_init(adapter_name) |
|
elif init_lora_weights: |
|
self.reset_lora_parameters(adapter_name, init_lora_weights) |
|
|
|
|
|
self._move_adapter_to_device_of_base_layer(adapter_name) |
|
|
|
if use_dora: |
|
self.dora_init(adapter_name) |
|
self.use_dora[adapter_name] = True |
|
else: |
|
self.use_dora[adapter_name] = False |
|
|
|
self.set_adapter(self.active_adapters) |
|
|
|
def _get_dora_factor_view(self): |
|
return (-1,) + (1,) * (self._kernel_dim - 1) |
|
|
|
def dora_init(self, adapter_name: str) -> None: |
|
        if not self.lora_magnitude_vector:
|
|
|
self.adapter_layer_names = self.adapter_layer_names[:] + ("lora_magnitude_vector",) |
|
|
|
dora_layer_class = self._get_dora_layer_class() |
|
dora_layer = dora_layer_class(fan_in_fan_out=False) |
|
lora_A = self.lora_A[adapter_name].weight |
|
lora_B = self.lora_B[adapter_name].weight |
|
scaling = self.scaling[adapter_name] |
|
dora_layer.update_layer(base_layer=self.get_base_layer(), lora_A=lora_A, lora_B=lora_B, scaling=scaling) |
|
self.lora_magnitude_vector[adapter_name] = dora_layer |
|
|
|
def _get_dora_layer_class(self) -> type[_DoraConvNdLayer]: |
|
|
|
raise NotImplementedError |
|
|
|
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: |
|
""" |
|
Merge the active adapter weights inside the base weights |
|
|
|
Args: |
|
safe_merge (`bool`, *optional*): |
|
If True, the merge operation will be performed in a copy of the original weights and check for NaNs |
|
before merging the weights. This is useful if you want to check if the merge operation will produce |
|
NaNs. Defaults to `False`. |
|
adapter_names (`list[str]`, *optional*): |
|
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults |
|
to `None`. |
|
""" |
|
adapter_names = check_adapters_to_merge(self, adapter_names) |
|
if not adapter_names: |
|
|
|
return |
|
|
|
for active_adapter in adapter_names: |
|
if active_adapter in self.lora_A.keys(): |
|
base_layer = self.get_base_layer() |
|
if safe_merge: |
|
|
|
|
|
orig_weights = base_layer.weight.data.clone() |
|
delta_weight = self.get_delta_weight(active_adapter) |
|
|
|
if not self.use_dora[active_adapter]: |
|
orig_weights += delta_weight |
|
else: |
|
|
|
|
|
weight_norm = ( |
|
self.lora_magnitude_vector[active_adapter] |
|
.get_weight_norm(orig_weights, delta_weight, scaling=1) |
|
.detach() |
|
) |
|
|
|
|
|
|
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm) |
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm |
|
orig_weights = dora_factor.view(*self._get_dora_factor_view()) * (orig_weights + delta_weight) |
|
|
|
if not torch.isfinite(orig_weights).all(): |
|
raise ValueError( |
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" |
|
) |
|
base_layer.weight.data = orig_weights |
|
|
|
if self.lora_bias[active_adapter]: |
|
new_bias = base_layer.bias + self.lora_B[active_adapter].bias |
|
if not torch.isfinite(new_bias).all(): |
|
raise ValueError( |
|
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" |
|
) |
|
base_layer.bias.data = new_bias |
|
|
|
else: |
|
delta_weight = self.get_delta_weight(active_adapter) |
|
if not self.use_dora[active_adapter]: |
|
base_layer.weight.data += delta_weight |
|
else: |
|
|
|
|
|
weight_norm = ( |
|
self.lora_magnitude_vector[active_adapter] |
|
.get_weight_norm(base_layer.weight, delta_weight, scaling=1) |
|
.detach() |
|
) |
|
|
|
|
|
|
|
self._cache_store(f"{active_adapter}-weight_norm", weight_norm) |
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm |
|
new_weight = dora_factor.view(*self._get_dora_factor_view()) * ( |
|
base_layer.weight.data + delta_weight |
|
) |
|
base_layer.weight.data = new_weight |
|
|
|
if self.lora_bias[active_adapter]: |
|
base_layer.bias.data += self.lora_B[active_adapter].bias |
|
|
|
self.merged_adapters.append(active_adapter) |
|
|
|
def unmerge(self) -> None: |
|
""" |
|
This method unmerges all merged adapter layers from the base weights. |
|
""" |
|
if not self.merged: |
|
warnings.warn("Already unmerged. Nothing to do.") |
|
return |
|
while len(self.merged_adapters) > 0: |
|
active_adapter = self.merged_adapters.pop() |
|
if active_adapter in self.lora_A.keys(): |
|
weight = self.get_base_layer().weight |
|
delta_weight = self.get_delta_weight(active_adapter) |
|
if not self.use_dora[active_adapter]: |
|
weight.data -= delta_weight |
|
else: |
|
weight_norm = self._cache_pop(f"{active_adapter}-weight_norm") |
|
dora_factor = self.lora_magnitude_vector[active_adapter].weight / weight_norm |
|
weight_orig = weight.data / dora_factor.view(*self._get_dora_factor_view()) - delta_weight |
|
weight.data = weight_orig |
|
|
|
if self.lora_bias[active_adapter]: |
|
self.get_base_layer().bias.data -= self.lora_B[active_adapter].bias |
|
|
|
def get_delta_weight(self, adapter) -> torch.Tensor: |
|
""" |
|
Compute the delta weight for the given adapter. |
|
|
|
Args: |
|
adapter (str): |
|
The name of the adapter for which the delta weight should be computed. |
|
""" |
|
device = self.lora_B[adapter].weight.device |
|
dtype = self.lora_A[adapter].weight.dtype |
|
|
|
|
|
|
|
|
|
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) |
|
|
|
weight_A = self.lora_A[adapter].weight |
|
weight_B = self.lora_B[adapter].weight |
|
|
|
if cast_to_fp32: |
|
weight_A = weight_A.float() |
|
weight_B = weight_B.float() |
|
|
|
|
|
if self.get_base_layer().weight.size()[2:4] == (1, 1): |
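            # 1x1 conv kernel: the delta reduces to a plain matrix product over channels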
|
|
|
output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( |
|
3 |
|
) * self.scaling[adapter] |
|
else: |
|
output_tensor = ( |
|
self.conv_fn( |
|
weight_A.transpose(0, 1), |
|
weight_B, |
|
).transpose(0, 1) |
|
* self.scaling[adapter] |
|
) |
|
|
|
if cast_to_fp32: |
|
output_tensor = output_tensor.to(dtype=dtype) |
|
|
|
|
|
self.lora_A[adapter].weight.data = weight_A.to(dtype) |
|
self.lora_B[adapter].weight.data = weight_B.to(dtype) |
|
|
|
return output_tensor |
|
|
|
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: |
|
self._check_forward_args(x, *args, **kwargs) |
|
adapter_names = kwargs.pop("adapter_names", None) |
|
|
|
if self.disable_adapters: |
|
if self.merged: |
|
self.unmerge() |
|
result = self.base_layer(x, *args, **kwargs) |
|
elif adapter_names is not None: |
|
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) |
|
elif self.merged: |
|
result = self.base_layer(x, *args, **kwargs) |
|
else: |
|
result = self.base_layer(x, *args, **kwargs) |
|
torch_result_dtype = result.dtype |
|
|
|
for active_adapter in self.active_adapters: |
|
if active_adapter not in self.lora_A.keys(): |
|
continue |
|
lora_A = self.lora_A[active_adapter] |
|
lora_B = self.lora_B[active_adapter] |
|
dropout = self.lora_dropout[active_adapter] |
|
scaling = self.scaling[active_adapter] |
|
x = x.to(lora_A.weight.dtype) |
|
|
|
if not self.use_dora[active_adapter]: |
|
result = result + lora_B(lora_A(dropout(x))) * scaling |
|
else: |
|
x = dropout(x) |
|
result = result + self.lora_magnitude_vector[active_adapter]( |
|
x, |
|
lora_A=lora_A, |
|
lora_B=lora_B, |
|
scaling=scaling, |
|
base_layer=self.get_base_layer(), |
|
) |
|
|
|
result = result.to(torch_result_dtype) |
|
return result |
|
|
|
def __repr__(self) -> str: |
|
rep = super().__repr__() |
|
return "lora." + rep |
|
|
|
|
|
class Conv2d(_ConvNd): |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
if not self._kernel_dim == 4: |
|
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}") |
|
self.conv_fn = F.conv2d |
|
|
|
def _get_dora_layer_class(self): |
|
return DoraConv2dLayer |
|
|
|
|
|
class Conv3d(_ConvNd): |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
if not self._kernel_dim == 5: |
|
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}") |
|
self.conv_fn = F.conv3d |
|
|
|
def _get_dora_layer_class(self): |
|
return DoraConv3dLayer |
|
|
|
|
|
def dispatch_default( |
|
target: torch.nn.Module, |
|
adapter_name: str, |
|
lora_config: LoraConfig, |
|
**kwargs, |
|
) -> Optional[torch.nn.Module]: |
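    """Create the matching default LoRA module (Embedding, Conv2d, Conv3d or Linear) for `target`; return None if unsupported."""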
|
new_module = None |
|
|
|
if isinstance(target, BaseTunerLayer): |
|
target_base_layer = target.get_base_layer() |
|
else: |
|
target_base_layer = target |
|
|
|
if isinstance(target_base_layer, torch.nn.Embedding): |
|
embedding_kwargs = kwargs.copy() |
|
embedding_kwargs.pop("fan_in_fan_out", None) |
|
embedding_kwargs.update(lora_config.loftq_config) |
|
new_module = Embedding(target, adapter_name, **embedding_kwargs) |
|
elif isinstance(target_base_layer, torch.nn.Conv2d): |
|
kwargs.update(lora_config.loftq_config) |
|
new_module = Conv2d(target, adapter_name, **kwargs) |
|
elif isinstance(target_base_layer, torch.nn.Conv3d): |
|
kwargs.update(lora_config.loftq_config) |
|
new_module = Conv3d(target, adapter_name, **kwargs) |
|
elif isinstance(target_base_layer, torch.nn.Linear): |
|
if kwargs["fan_in_fan_out"]: |
|
warnings.warn( |
|
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " |
|
"Setting fan_in_fan_out to False." |
|
) |
|
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False |
|
kwargs.update(lora_config.loftq_config) |
|
new_module = Linear(target, adapter_name, **kwargs) |
|
elif isinstance(target_base_layer, Conv1D): |
|
if not kwargs["fan_in_fan_out"]: |
|
warnings.warn( |
|
"fan_in_fan_out is set to False but the target module is `Conv1D`. " "Setting fan_in_fan_out to True." |
|
) |
|
kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True |
|
kwargs.update(lora_config.loftq_config) |
|
new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs) |
|
|
|
return new_module |
|
|