import traceback |
|
import pickle |
|
import math |
|
import re |
|
import warnings |
|
from dataclasses import asdict, dataclass, field |
|
from enum import Enum |
|
from typing import List, Optional, Tuple, Union |
|
import itertools |
|
import copy |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers.pytorch_utils import Conv1D |
|
from .gating import GATING_TO_MODEL_MAPPING |
|
|
|
from ..import_utils import is_bnb_4bit_available, is_bnb_available |
|
from ..utils import ( |
|
COMMON_LAYERS_PATTERN, |
|
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, |
|
ModulesToSaveWrapper, |
|
PeftConfig, |
|
PeftType, |
|
_freeze_adapter, |
|
_get_submodules, |
|
transpose, |
|
) |
|
|
|
if is_bnb_available(): |
|
import bitsandbytes as bnb |
|
|
|
|
|
@dataclass |
|
class MoeLoraConfig(PeftConfig): |
|
""" |
|
This is the configuration class to store the configuration of a [`MoeLoraModel`]. |
|
|
|
Args: |
|
r (`int`): Lora attention dimension. |
|
target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to. |
|
lora_alpha (`int`): The alpha parameter for Lora scaling. |
|
lora_dropout (`float`): The dropout probability for Lora layers. |
|
fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). |
|
For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.: |
|
bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only' |
|
modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable |
|
and saved in the final checkpoint. |
|
layers_to_transform (`Union[List[int],int]`): |
|
The layer indexes to transform, if this argument is specified, it will apply the LoRA transformations on |
|
the layer indexes that are specified in this list. If a single integer is passed, it will apply the LoRA |
|
transformations on the layer at this index. |
|
layers_pattern (`str`): |
|
The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer |
|
pattern is not in the common layers pattern. |
|
""" |
|
|
|
r: int = field(default=8, metadata={"help": "Lora attention dimension"}) |
|
num_moe: int = field(default=8, metadata={"help": "Num experts of MoeLora"}) |
|
gating: str = field(default="Standard", metadata={"help": "Select the gating network for each MoeLora"}) |
|
target_modules: Optional[Union[List[str], str]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "List of module names or regex expression of the module names to replace with Lora." |
|
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' " |
|
}, |
|
) |
|
lora_alpha: int = field(default=8, metadata={"help": "Lora alpha"}) |
|
lora_dropout: float = field(default=0.0, metadata={"help": "Lora dropout"}) |
|
fan_in_fan_out: bool = field( |
|
default=False, |
|
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"}, |
|
) |
|
bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"}) |
|
modules_to_save: Optional[List[str]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " |
|
"For example, in Sequence Classification or Token Classification tasks, " |
|
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved." |
|
}, |
|
) |
|
init_lora_weights: bool = field( |
|
default=True, |
|
metadata={"help": "Whether to initialize the weights of the Lora layers."}, |
|
) |
|
layers_to_transform: Optional[Union[List, int]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index." |
|
}, |
|
) |
|
layers_pattern: Optional[str] = field( |
|
default=None, |
|
metadata={ |
|
"help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." |
|
}, |
|
) |
|
|
|
def __post_init__(self): |
|
self.peft_type = PeftType.MOELORA |
|
|
|
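# A minimal usage sketch (the target module names below are illustrative and depend on the
# base model architecture):
#
#     config = MoeLoraConfig(
#         task_type="CAUSAL_LM",
#         r=8,
#         num_moe=4,
#         gating="Standard",
#         target_modules=["q_proj", "v_proj"],
#         lora_alpha=16,
#         lora_dropout=0.05,
#     )
#
# Note that `r` is the total LoRA rank shared across the experts, so each expert effectively
# works with rank `r // num_moe`.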
|
|
class MoeLoraModel(torch.nn.Module): |
|
""" |
|
    Creates a Mixture-of-Experts Low-Rank Adapter (MoeLora) model from a pretrained transformers model.

    Args:
        model ([`~transformers.PreTrainedModel`]): The model to be adapted.
        config ([`MoeLoraConfig`]): The configuration of the MoeLora model.

    Returns:
        `torch.nn.Module`: The MoeLora model.
|
|
|
Example: |
|
|
|
```py |
|
>>> from transformers import AutoModelForSeq2SeqLM |
|
>>> from moelora import MoeLoraModel, MoeLoraConfig |
|
|
|
        >>> config = MoeLoraConfig(
        ...     peft_type="MOELORA",
|
... task_type="SEQ_2_SEQ_LM", |
|
... r=8, |
|
... lora_alpha=32, |
|
... target_modules=["q", "v"], |
|
... lora_dropout=0.01, |
|
... ) |
|
|
|
>>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
|
>>> lora_model = MoeLoraModel(config, model) |
|
``` |
|
|
|
```py |
|
>>> import transformers |
|
>>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_int8_training |
|
|
|
>>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"] |
|
>>> config = LoraConfig( |
|
... r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM" |
|
... ) |
|
|
|
>>> model = transformers.GPTJForCausalLM.from_pretrained( |
|
... "kakaobrain/kogpt", |
|
... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b |
|
... pad_token_id=tokenizer.eos_token_id, |
|
... use_cache=False, |
|
... device_map={"": rank}, |
|
... torch_dtype=torch.float16, |
|
... load_in_8bit=True, |
|
... ) |
|
>>> model = prepare_model_for_int8_training(model) |
|
>>> lora_model = get_peft_model(model, config) |
|
``` |
|
|
|
**Attributes**: |
|
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. |
|
        - **peft_config** ([`MoeLoraConfig`]): The configuration of the MoeLora model.
|
""" |
|
|
|
def __init__(self, model, config, adapter_name): |
|
super().__init__() |
|
self.model = model |
|
|
|
|
|
|
|
|
|
self.self_model_generate = self.model.generate |
|
self.model.generate = self.generate |
|
|
|
"""self.self_model_forward = self.model.forward |
|
self.model.forward = self.forward""" |
|
|
|
self.peft_config = config |
|
self.global_user_embeds = [] |
|
self.gate_weights = [] |
|
self.add_adapter(adapter_name, self.peft_config[adapter_name]) |
|
|
|
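    # `forward` and `generate` stash the per-call `user_embeds` and `gate_weights` tensors
    # into shared lists before delegating to the wrapped model; the injected MoeLora layers
    # read element 0 of these lists inside their own forward passes.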
def forward(self, **kwargs): |
|
self.global_user_embeds.clear() |
|
user_embeds = kwargs['user_embeds'] |
|
self.global_user_embeds.extend([user_embeds]) |
|
|
|
self.gate_weights.clear() |
|
gate_weights = kwargs['gate_weights'] |
|
self.gate_weights.extend([gate_weights]) |
|
|
|
return self.model(**kwargs) |
|
|
|
|
|
def generate(self, **kwargs): |
|
self.global_user_embeds.clear() |
|
user_embeds = kwargs['user_embeds'] |
|
self.global_user_embeds.extend([user_embeds]) |
|
|
|
self.gate_weights.clear() |
|
gate_weights = kwargs['gate_weights'] |
|
|
|
self.gate_weights.extend([gate_weights]) |
|
|
|
return self.self_model_generate(**kwargs) |
|
|
|
def add_adapter(self, adapter_name, config=None): |
|
if config is not None: |
|
model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config |
|
config = self._prepare_moelora_config(config, model_config) |
|
self.peft_config[adapter_name] = config |
|
self._find_and_replace(adapter_name) |
|
if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none": |
|
raise ValueError( |
|
"MoeLoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters." |
|
) |
|
mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) |
|
if self.peft_config[adapter_name].inference_mode: |
|
_freeze_adapter(self.model, adapter_name) |
|
|
|
def _check_quantization_dependency(self): |
|
loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False) |
|
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) |
|
if (loaded_in_4bit or loaded_in_8bit) and not is_bnb_available(): |
|
raise ImportError( |
|
"To use Lora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. " |
|
"You can install it with `pip install bitsandbytes`." |
|
) |
|
|
|
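    # A string `target_modules` is treated as a full-match regex against the module name;
    # a list is matched by suffix. When `layers_to_transform` is set, the layer index parsed
    # from the module name must also be in the allowed set.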
def _check_target_module_exists(self, moelora_config, key): |
|
if isinstance(moelora_config.target_modules, str): |
|
target_module_found = re.fullmatch(moelora_config.target_modules, key) |
|
else: |
|
target_module_found = any(key.endswith(target_key) for target_key in moelora_config.target_modules) |
|
is_using_layer_indexes = getattr(moelora_config, "layers_to_transform", None) is not None |
|
layer_indexing_pattern = getattr(moelora_config, "layers_pattern", None) |
|
|
|
if is_using_layer_indexes and target_module_found: |
|
layers_pattern = COMMON_LAYERS_PATTERN if layer_indexing_pattern is None else layer_indexing_pattern |
|
layers_pattern = [layers_pattern] if isinstance(layers_pattern, str) else layers_pattern |
|
|
|
for pattern in layers_pattern: |
|
                layer_index = re.match(rf".*\.{pattern}\.(\d+)\.*", key)
|
if layer_index is not None: |
|
layer_index = int(layer_index.group(1)) |
|
if isinstance(moelora_config.layers_to_transform, int): |
|
target_module_found = layer_index == moelora_config.layers_to_transform |
|
else: |
|
target_module_found = layer_index in moelora_config.layers_to_transform |
|
|
|
break |
|
else: |
|
target_module_found = False |
|
return target_module_found |
|
|
|
def _create_new_module(self, moelora_config, adapter_name, target, **kwargs): |
|
bias = hasattr(target, "bias") and target.bias is not None |
|
kwargs = { |
|
"r": moelora_config.r, |
|
|
|
"num_moe": moelora_config.num_moe, |
|
"gating": moelora_config.gating, |
|
"global_user_embeds": self.global_user_embeds, |
|
|
|
"gate_weights": self.gate_weights, |
|
|
|
"lora_alpha": moelora_config.lora_alpha, |
|
"lora_dropout": moelora_config.lora_dropout, |
|
"fan_in_fan_out": moelora_config.fan_in_fan_out, |
|
"init_lora_weights": moelora_config.init_lora_weights, |
|
} |
|
loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False) |
|
loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) |
|
|
|
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): |
|
eightbit_kwargs = kwargs.copy() |
|
eightbit_kwargs.update( |
|
{ |
|
"has_fp16_weights": target.state.has_fp16_weights, |
|
"memory_efficient_backward": target.state.memory_efficient_backward, |
|
"threshold": target.state.threshold, |
|
"index": target.index, |
|
} |
|
) |
|
eightbit_kwargs.pop('fan_in_fan_out', None) |
|
eightbit_kwargs.pop('init_lora_weights', None) |
|
new_module = Linear8bitLt( |
|
adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs |
|
) |
|
elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit): |
|
fourbit_kwargs = kwargs.copy() |
|
fourbit_kwargs.update( |
|
{ |
|
"compute_dtype": target.compute_dtype, |
|
"compress_statistics": target.weight.compress_statistics, |
|
"quant_type": target.weight.quant_type, |
|
} |
|
) |
|
new_module = Linear4bit(adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs) |
|
elif isinstance(target, torch.nn.Embedding): |
|
embedding_kwargs = kwargs.copy() |
|
embedding_kwargs.pop("fan_in_fan_out", None) |
|
in_features, out_features = target.num_embeddings, target.embedding_dim |
|
new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs) |
|
elif isinstance(target, torch.nn.Conv2d): |
|
out_channels, in_channels = target.weight.size()[:2] |
|
kernel_size = target.weight.size()[2:] |
|
stride = target.stride |
|
padding = target.padding |
|
new_module = Conv2d(adapter_name, in_channels, out_channels, kernel_size, stride, padding, **kwargs) |
|
else: |
|
if isinstance(target, torch.nn.Linear): |
|
in_features, out_features = target.in_features, target.out_features |
|
if kwargs["fan_in_fan_out"]: |
|
warnings.warn( |
|
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " |
|
"Setting fan_in_fan_out to False." |
|
) |
|
kwargs["fan_in_fan_out"] = moelora_config.fan_in_fan_out = False |
|
elif isinstance(target, Conv1D): |
|
in_features, out_features = ( |
|
target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape |
|
) |
|
if not kwargs["fan_in_fan_out"]: |
|
warnings.warn( |
|
"fan_in_fan_out is set to False but the target module is `Conv1D`. " |
|
"Setting fan_in_fan_out to True." |
|
) |
|
kwargs["fan_in_fan_out"] = moelora_config.fan_in_fan_out = True |
|
else: |
|
raise ValueError( |
|
f"Target module {target} is not supported. " |
|
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." |
|
) |
|
new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs) |
|
|
|
return new_module |
|
|
|
def _find_and_replace(self, adapter_name): |
|
moelora_config = self.peft_config[adapter_name] |
|
self._check_quantization_dependency() |
|
is_target_modules_in_base_model = False |
|
key_list = [key for key, _ in self.model.named_modules()] |
|
|
|
for key in key_list: |
|
if not self._check_target_module_exists(moelora_config, key): |
|
continue |
|
|
|
is_target_modules_in_base_model = True |
|
parent, target, target_name = _get_submodules(self.model, key) |
|
|
|
if isinstance(target, MoeLoraLayer) and isinstance(target, torch.nn.Conv2d): |
|
target.update_layer_conv2d( |
|
adapter_name, |
|
moelora_config.r, |
|
moelora_config.lora_alpha, |
|
moelora_config.lora_dropout, |
|
moelora_config.init_lora_weights, |
|
) |
|
elif isinstance(target, MoeLoraLayer): |
|
target.update_layer( |
|
adapter_name, |
|
moelora_config.r, |
|
moelora_config.num_moe, |
|
moelora_config.gating, |
|
moelora_config.lora_alpha, |
|
moelora_config.lora_dropout, |
|
moelora_config.init_lora_weights, |
|
) |
|
else: |
|
new_module = self._create_new_module(moelora_config, adapter_name, target) |
|
self._replace_module(parent, target_name, new_module, target) |
|
|
|
if not is_target_modules_in_base_model: |
|
raise ValueError( |
|
f"Target modules {moelora_config.target_modules} not found in the base model. " |
|
f"Please check the target modules and try again." |
|
) |
|
|
|
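    # Swap the original module for the MoeLora replacement, reusing the pretrained weight/bias
    # (and quantization state, if any) and moving the new lora_/gating submodules onto the same
    # device as the original weights.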
def _replace_module(self, parent_module, child_name, new_module, old_module): |
|
setattr(parent_module, child_name, new_module) |
|
new_module.weight = old_module.weight |
|
if hasattr(old_module, "bias"): |
|
if old_module.bias is not None: |
|
new_module.bias = old_module.bias |
|
|
|
if getattr(old_module, "state", None) is not None: |
|
new_module.state = old_module.state |
|
new_module.to(old_module.weight.device) |
|
|
|
|
|
for name, module in new_module.named_modules(): |
|
if "lora_" in name: |
|
module.to(old_module.weight.device) |
|
|
|
if "gating" in name: |
|
module.to(old_module.weight.device) |
|
|
|
if "ranknum" in name: |
|
module.to(old_module.weight.device) |
|
|
|
def __getattr__(self, name: str): |
|
"""Forward missing attributes to the wrapped module.""" |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.model, name) |
|
|
|
def get_peft_config_as_dict(self, inference: bool = False): |
|
config_dict = {} |
|
for key, value in self.peft_config.items(): |
|
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} |
|
if inference: |
|
config["inference_mode"] = True |
|
config_dict[key] = config |
|
        return config_dict
|
|
|
def _set_adapter_layers(self, enabled=True): |
|
for module in self.model.modules(): |
|
if isinstance(module, MoeLoraLayer): |
|
module.disable_adapters = False if enabled else True |
|
|
|
def enable_adapter_layers(self): |
|
self._set_adapter_layers(enabled=True) |
|
|
|
def disable_adapter_layers(self): |
|
self._set_adapter_layers(enabled=False) |
|
|
|
def set_adapter(self, adapter_name): |
|
for module in self.model.modules(): |
|
if isinstance(module, MoeLoraLayer): |
|
if module.merged: |
|
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") |
|
module.unmerge() |
|
module.active_adapter = adapter_name |
|
|
|
def merge_adapter(self): |
|
for module in self.model.modules(): |
|
if isinstance(module, MoeLoraLayer): |
|
module.merge() |
|
|
|
def unmerge_adapter(self): |
|
for module in self.model.modules(): |
|
if isinstance(module, MoeLoraLayer): |
|
module.unmerge() |
|
|
|
@staticmethod |
|
def _prepare_moelora_config(peft_config, model_config): |
|
if peft_config.target_modules is None: |
|
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: |
|
raise ValueError("Please specify `target_modules` in `peft_config`") |
|
peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] |
|
return peft_config |
|
|
|
def merge_and_unload(self): |
|
r""" |
|
        This method merges the LoRA layers into the base model. This is needed if someone wants to use the base model
|
as a standalone model. |
|
""" |
|
if getattr(self.config, "model_type", None) == "gpt2": |
|
raise ValueError("GPT2 models are not supported for merging LORA layers") |
|
|
|
if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False): |
|
raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode") |
|
|
|
key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] |
|
for key in key_list: |
|
try: |
|
parent, target, target_name = _get_submodules(self.model, key) |
|
except AttributeError: |
|
continue |
|
if isinstance(target, MoeLoraLayer): |
|
if isinstance(target, nn.Embedding): |
|
new_module = torch.nn.Embedding(target.in_features, target.out_features) |
|
elif isinstance(target, nn.Conv2d): |
|
new_module = torch.nn.Conv2d( |
|
target.in_channels, |
|
target.out_channels, |
|
kernel_size=target.kernel_size, |
|
stride=target.stride, |
|
padding=target.padding, |
|
dilation=target.dilation, |
|
) |
|
else: |
|
bias = target.bias is not None |
|
new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias) |
|
target.merge() |
|
self._replace_module(parent, target_name, new_module, target) |
|
|
|
|
|
if isinstance(target, ModulesToSaveWrapper): |
|
setattr(parent, target_name, target.modules_to_save[target.active_adapter]) |
|
|
|
return self.model |
|
|
|
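    # Builds a new adapter whose lora_A/lora_B weights are a weighted sum of the given adapters
    # (lora_A additionally scaled by each source adapter's `scaling`), then freezes it.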
def add_weighted_adapter(self, adapters, weights, adapter_name): |
|
if len({self.peft_config[adapter].r for adapter in adapters}) != 1: |
|
raise ValueError("All adapters must have the same r value") |
|
        # Deep-copy so that mutating `lora_alpha` below does not alter the source adapter's config.
        self.peft_config[adapter_name] = copy.deepcopy(self.peft_config[adapters[0]])
|
self.peft_config[adapter_name].lora_alpha = self.peft_config[adapters[0]].r |
|
self._find_and_replace(adapter_name) |
|
mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) |
|
_freeze_adapter(self.model, adapter_name) |
|
key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] |
|
for key in key_list: |
|
_, target, _ = _get_submodules(self.model, key) |
|
if isinstance(target, MoeLoraLayer): |
|
if adapter_name in target.lora_A: |
|
target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0 |
|
target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0 |
|
for adapter, weight in zip(adapters, weights): |
|
if adapter not in target.lora_A: |
|
continue |
|
target.lora_A[adapter_name].weight.data += ( |
|
target.lora_A[adapter].weight.data * weight * target.scaling[adapter] |
|
) |
|
target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight |
|
|
|
elif adapter_name in target.lora_embedding_A: |
|
target.lora_embedding_A[adapter_name].data = target.lora_embedding_A[adapter_name].data * 0.0 |
|
target.lora_embedding_B[adapter_name].data = target.lora_embedding_B[adapter_name].data * 0.0 |
|
for adapter, weight in zip(adapters, weights): |
|
if adapter not in target.lora_embedding_A: |
|
continue |
|
target.lora_embedding_A[adapter_name].data += ( |
|
target.lora_embedding_A[adapter].data * weight * target.scaling[adapter] |
|
) |
|
target.lora_embedding_B[adapter_name].data += target.lora_embedding_B[adapter].data * weight |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
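# Freeze every parameter except the LoRA matrices and the gating networks; the `bias` mode
# optionally re-enables gradients for bias terms ("all") or only for biases that live inside
# MoeLora layers ("lora_only").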
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: |
|
for n, p in model.named_parameters(): |
|
if "lora_" not in n and "gating" not in n: |
|
p.requires_grad = False |
|
if bias == "none": |
|
return |
|
elif bias == "all": |
|
for n, p in model.named_parameters(): |
|
if "bias" in n: |
|
p.requires_grad = True |
|
elif bias == "lora_only": |
|
for m in model.modules(): |
|
if isinstance(m, MoeLoraLayer) and hasattr(m, "bias") and m.bias is not None: |
|
m.bias.requires_grad = True |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
class MoeLoraLayer: |
|
def __init__(self, in_features: int, out_features: int, **kwargs): |
|
self.r = {} |
|
|
|
self.num_moe = {} |
|
|
|
self.lora_alpha = {} |
|
self.scaling = {} |
|
self.lora_dropout = nn.ModuleDict({}) |
|
self.lora_A = nn.ModuleDict({}) |
|
self.lora_B = nn.ModuleDict({}) |
|
|
|
self.gating = nn.ModuleDict({}) |
|
|
|
|
|
self.lora_embedding_A = nn.ParameterDict({}) |
|
self.lora_embedding_B = nn.ParameterDict({}) |
|
|
|
self.merged = False |
|
self.disable_adapters = False |
|
self.in_features = in_features |
|
self.out_features = out_features |
|
self.kwargs = kwargs |
|
|
|
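    # Registers one expert-aware LoRA pair plus a gating network for `adapter_name`.
    # The single lora_A/lora_B projection of rank `r` is later reshaped into `num_moe` experts
    # of rank `r // num_moe` (so `r` is assumed to be divisible by `num_moe`), and the scaling
    # is lora_alpha / (r // num_moe).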
def update_layer(self, adapter_name, r, num_moe, gating, lora_alpha, lora_dropout, |
|
init_lora_weights): |
|
self.r[adapter_name] = r |
|
|
|
self.num_moe[adapter_name] = num_moe |
|
|
|
self.lora_alpha[adapter_name] = lora_alpha |
|
|
|
self.gating.update( |
|
nn.ModuleDict({adapter_name: GATING_TO_MODEL_MAPPING[gating](num_moe=num_moe, dim=self.in_features)})) |
|
self.gating[adapter_name].to(self.weight.device) |
|
|
|
if lora_dropout > 0.0: |
|
lora_dropout_layer = nn.Dropout(p=lora_dropout) |
|
else: |
|
lora_dropout_layer = nn.Identity() |
|
|
|
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) |
|
|
|
if r > 0: |
|
self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)})) |
|
self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)})) |
|
self.scaling[adapter_name] = lora_alpha / (r // num_moe) |
|
if init_lora_weights: |
|
self.reset_lora_parameters(adapter_name) |
|
self.to(self.weight.device) |
|
|
|
def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): |
|
self.r[adapter_name] = r |
|
self.lora_alpha[adapter_name] = lora_alpha |
|
if lora_dropout > 0.0: |
|
lora_dropout_layer = nn.Dropout(p=lora_dropout) |
|
else: |
|
lora_dropout_layer = nn.Identity() |
|
|
|
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) |
|
|
|
if r > 0: |
|
kernel_size = self.kwargs["kernel_size"] |
|
stride = self.kwargs["stride"] |
|
padding = self.kwargs["padding"] |
|
self.lora_A.update( |
|
nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)}) |
|
) |
|
self.lora_B.update( |
|
nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)}) |
|
) |
|
self.scaling[adapter_name] = lora_alpha / r |
|
if init_lora_weights: |
|
self.reset_lora_parameters(adapter_name) |
|
self.to(self.weight.device) |
|
|
|
def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): |
|
self.r[adapter_name] = r |
|
self.lora_alpha[adapter_name] = lora_alpha |
|
if lora_dropout > 0.0: |
|
lora_dropout_layer = nn.Dropout(p=lora_dropout) |
|
else: |
|
lora_dropout_layer = nn.Identity() |
|
|
|
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) |
|
|
|
if r > 0: |
|
self.lora_embedding_A.update( |
|
nn.ParameterDict({adapter_name: nn.Parameter(self.weight.new_zeros((r, self.in_features)))}) |
|
) |
|
self.lora_embedding_B.update( |
|
nn.ParameterDict({adapter_name: nn.Parameter(self.weight.new_zeros((self.out_features, r)))}) |
|
) |
|
self.scaling[adapter_name] = lora_alpha / r |
|
if init_lora_weights: |
|
self.reset_lora_parameters(adapter_name) |
|
self.to(self.weight.device) |
|
|
|
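    # lora_A is Kaiming-initialized and lora_B zero-initialized, so a freshly added adapter
    # starts as a no-op on top of the frozen base weights.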
def reset_lora_parameters(self, adapter_name): |
|
if adapter_name in self.lora_A.keys(): |
|
|
|
nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) |
|
nn.init.zeros_(self.lora_B[adapter_name].weight) |
|
if adapter_name in self.lora_embedding_A.keys(): |
|
|
|
nn.init.zeros_(self.lora_embedding_A[adapter_name]) |
|
nn.init.normal_(self.lora_embedding_B[adapter_name]) |
|
|
|
|
|
class Linear(nn.Linear, MoeLoraLayer): |
|
|
|
def __init__( |
|
self, |
|
adapter_name: str, |
|
in_features: int, |
|
out_features: int, |
|
r: int = 0, |
|
|
|
|
|
num_moe: int = 0, |
|
gating: str = "", |
|
global_user_embeds: List = [], |
|
gate_weights: List = [], |
|
|
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
fan_in_fan_out: bool = False, |
|
|
|
**kwargs, |
|
): |
|
init_lora_weights = kwargs.pop("init_lora_weights", True) |
|
|
|
nn.Linear.__init__(self, in_features, out_features, **kwargs) |
|
MoeLoraLayer.__init__(self, in_features=in_features, out_features=out_features) |
|
|
|
self.weight.requires_grad = False |
|
        self.gating.requires_grad_(True)
|
self.fan_in_fan_out = fan_in_fan_out |
|
if fan_in_fan_out: |
|
self.weight.data = self.weight.data.T |
|
nn.Linear.reset_parameters(self) |
|
self.update_layer(adapter_name, r, num_moe, gating, lora_alpha, lora_dropout, init_lora_weights) |
|
self.active_adapter = adapter_name |
|
self.global_user_embeds = global_user_embeds |
|
self.gate_weights = gate_weights |
|
|
|
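    # Merging folds scaling * (lora_B @ lora_A) into the frozen weight, as in plain LoRA;
    # note that the per-expert gating applied in `forward` is not reflected in the merged weight.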
def merge(self): |
|
if self.active_adapter not in self.lora_A.keys(): |
|
return |
|
if self.merged: |
|
warnings.warn("Already merged. Nothing to do.") |
|
return |
|
if self.r[self.active_adapter] > 0: |
|
self.weight.data += ( |
|
transpose( |
|
self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight, |
|
self.fan_in_fan_out, |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
self.merged = True |
|
|
|
def unmerge(self): |
|
if self.active_adapter not in self.lora_A.keys(): |
|
return |
|
if not self.merged: |
|
warnings.warn("Already unmerged. Nothing to do.") |
|
return |
|
if self.r[self.active_adapter] > 0: |
|
self.weight.data -= ( |
|
transpose( |
|
self.lora_B[self.active_adapter].weight @ self.lora_A[self.active_adapter].weight, |
|
self.fan_in_fan_out, |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
self.merged = False |
|
|
|
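    # lora_B.weight has shape (out_features, r); transposing and reshaping it to
    # (num_moe, r // num_moe, out_features) lets a single einsum apply each expert's slice of B
    # to its chunk of the lora_A output:
    # (batch, seq, num_moe, r // num_moe) -> (batch, seq, num_moe, out_features).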
def calculate_B(self, A_out): |
|
batch_size, seq_len, n, r = A_out.size() |
|
weight = self.lora_B[self.active_adapter].weight.t().reshape(n, r, -1) |
|
return torch.einsum('ijkl, klm->ijkm', A_out, weight) |
|
|
|
|
|
    def forward(self, x: torch.Tensor, **kwargs):
        previous_dtype = x.dtype
        batch_size, seq_len, _ = x.size()

        try:
            if self.active_adapter not in self.lora_A.keys():
                return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
            if self.disable_adapters:
                if self.r[self.active_adapter] > 0 and self.merged:
                    self.unmerge()
                result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
            elif self.r[self.active_adapter] > 0 and not self.merged:
                result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
                x = x.to(self.lora_A[self.active_adapter].weight.dtype)

                # Project the (dropped-out) input with lora_A and split the result into one
                # chunk per expert: (batch, seq, num_moe, r // num_moe).
                A_out = self.lora_A[self.active_adapter](
                    self.lora_dropout[self.active_adapter](x)
                ).reshape(batch_size, seq_len, self.num_moe[self.active_adapter], -1)
                B_out = self.calculate_B(A_out)

                # Combine the expert outputs with the gate weights precomputed by the wrapping
                # MoeLoraModel and shared through self.gate_weights.
                gate_weights = self.gate_weights[0]
                Gate = gate_weights.unsqueeze(-1)
                result += (B_out * Gate).sum(dim=-2) * self.scaling[self.active_adapter]
            else:
                result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

            result = result.to(previous_dtype)
            return result
        except Exception:
            # Surface the full traceback instead of silently returning a wrong-shaped tensor.
            traceback.print_exc()
            raise
|
|
|
|
|
class Embedding(nn.Embedding, MoeLoraLayer): |
|
|
|
def __init__( |
|
self, |
|
adapter_name: str, |
|
num_embeddings: int, |
|
embedding_dim: int, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
**kwargs, |
|
): |
|
init_lora_weights = kwargs.pop("init_lora_weights", True) |
|
|
|
nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) |
|
MoeLoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim) |
|
|
|
self.weight.requires_grad = False |
|
|
|
nn.Embedding.reset_parameters(self) |
|
self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) |
|
self.active_adapter = adapter_name |
|
|
|
def unmerge(self, mode: bool = True): |
|
if not self.merged: |
|
warnings.warn("Already unmerged. Nothing to do.") |
|
return |
|
if self.r[self.active_adapter] > 0: |
|
self.weight.data -= ( |
|
transpose( |
|
self.lora_embedding_B[self.active_adapter] @ self.lora_embedding_A[self.active_adapter], True |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
self.merged = False |
|
|
|
def merge(self): |
|
if self.merged: |
|
warnings.warn("Already merged. Nothing to do.") |
|
return |
|
if self.r[self.active_adapter] > 0: |
|
self.weight.data += ( |
|
transpose( |
|
self.lora_embedding_B[self.active_adapter] @ self.lora_embedding_A[self.active_adapter], True |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
self.merged = True |
|
|
|
def forward(self, x: torch.Tensor): |
|
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            return nn.Embedding.forward(self, x)
|
|
|
elif self.r[self.active_adapter] > 0 and not self.merged: |
|
result = nn.Embedding.forward(self, x) |
|
if self.r[self.active_adapter] > 0: |
|
after_A = F.embedding( |
|
x, |
|
self.lora_embedding_A[self.active_adapter].T, |
|
self.padding_idx, |
|
self.max_norm, |
|
self.norm_type, |
|
self.scale_grad_by_freq, |
|
self.sparse, |
|
) |
|
result += (after_A @ self.lora_embedding_B[self.active_adapter].T) * self.scaling[self.active_adapter] |
|
return result |
|
else: |
|
return nn.Embedding.forward(self, x) |
|
|
|
|
|
class Conv2d(nn.Conv2d, MoeLoraLayer): |
|
|
|
def __init__( |
|
self, |
|
adapter_name: str, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: Union[int, Tuple[int]], |
|
stride: Union[int, Tuple[int]] = 1, |
|
padding: Union[int, Tuple[int]] = 0, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
**kwargs, |
|
): |
|
init_lora_weights = kwargs.pop("init_lora_weights", True) |
|
|
|
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding) |
|
MoeLoraLayer.__init__( |
|
self, |
|
in_features=in_channels, |
|
out_features=out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
) |
|
|
|
self.weight.requires_grad = False |
|
|
|
nn.Conv2d.reset_parameters(self) |
|
self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) |
|
self.active_adapter = adapter_name |
|
|
|
def merge(self): |
|
if self.active_adapter not in self.lora_A.keys(): |
|
return |
|
if self.merged: |
|
warnings.warn("Already merged. Nothing to do.") |
|
return |
|
if self.r[self.active_adapter] > 0: |
|
|
|
if self.weight.size()[2:4] == (1, 1): |
|
|
|
self.weight.data += ( |
|
self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2) |
|
@ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2) |
|
).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter] |
|
else: |
|
|
|
self.weight.data += ( |
|
F.conv2d( |
|
self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3), |
|
self.lora_B[self.active_adapter].weight, |
|
).permute(1, 0, 2, 3) |
|
* self.scaling[self.active_adapter] |
|
) |
|
self.merged = True |
|
|
|
def unmerge(self): |
|
if self.active_adapter not in self.lora_A.keys(): |
|
return |
|
if not self.merged: |
|
warnings.warn("Already unmerged. Nothing to do.") |
|
return |
|
if self.r[self.active_adapter] > 0: |
|
if self.weight.size()[2:4] == (1, 1): |
|
|
|
self.weight.data -= ( |
|
self.lora_B[self.active_adapter].weight.squeeze(3).squeeze(2) |
|
@ self.lora_A[self.active_adapter].weight.squeeze(3).squeeze(2) |
|
).unsqueeze(2).unsqueeze(3) * self.scaling[self.active_adapter] |
|
else: |
|
|
|
                self.weight.data -= (
|
F.conv2d( |
|
self.lora_A[self.active_adapter].weight.permute(1, 0, 2, 3), |
|
self.lora_B[self.active_adapter].weight, |
|
).permute(1, 0, 2, 3) |
|
* self.scaling[self.active_adapter] |
|
) |
|
self.merged = False |
|
|
|
def forward(self, x: torch.Tensor): |
|
previous_dtype = x.dtype |
|
|
|
if self.active_adapter not in self.lora_A.keys(): |
|
return F.conv2d( |
|
x, |
|
self.weight, |
|
bias=self.bias, |
|
stride=self.stride, |
|
padding=self.padding, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
) |
|
if self.disable_adapters: |
|
if self.r[self.active_adapter] > 0 and self.merged: |
|
self.unmerge() |
|
result = F.conv2d( |
|
x, |
|
self.weight, |
|
bias=self.bias, |
|
stride=self.stride, |
|
padding=self.padding, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
) |
|
elif self.r[self.active_adapter] > 0 and not self.merged: |
|
result = F.conv2d( |
|
x, |
|
self.weight, |
|
bias=self.bias, |
|
stride=self.stride, |
|
padding=self.padding, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
) |
|
|
|
x = x.to(self.lora_A[self.active_adapter].weight.dtype) |
|
|
|
result += ( |
|
self.lora_B[self.active_adapter]( |
|
self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x)) |
|
) |
|
* self.scaling[self.active_adapter] |
|
) |
|
else: |
|
result = F.conv2d( |
|
x, |
|
self.weight, |
|
bias=self.bias, |
|
stride=self.stride, |
|
padding=self.padding, |
|
dilation=self.dilation, |
|
groups=self.groups, |
|
) |
|
|
|
result = result.to(previous_dtype) |
|
|
|
return result |
|
|
|
|
|
if is_bnb_available(): |
|
|
|
class Linear8bitLt(bnb.nn.Linear8bitLt, MoeLoraLayer): |
|
|
|
def __init__( |
|
self, |
|
adapter_name: str, |
|
in_features: int, |
|
out_features: int, |
|
r: int = 0, |
|
|
|
|
|
num_moe: int = 0, |
|
gating: str = "", |
|
global_user_embeds: List = [], |
|
loss_fn: str = "", |
|
|
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
**kwargs, |
|
): |
|
            bnb.nn.Linear8bitLt.__init__(
                self,
                in_features,
                out_features,
                bias=kwargs.get("bias", True),
                has_fp16_weights=kwargs.get("has_fp16_weights", True),
                memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
                threshold=kwargs.get("threshold", 0.0),
                index=kwargs.get("index", None),
            )
|
MoeLoraLayer.__init__(self, in_features=in_features, out_features=out_features) |
|
|
|
|
|
self.weight.requires_grad = False |
|
init_lora_weights = kwargs.pop("init_lora_weights", True) |
|
|
|
|
|
|
|
self.update_layer(adapter_name, r, num_moe, gating, lora_alpha, lora_dropout, init_lora_weights) |
|
self.active_adapter = adapter_name |
|
self.global_user_embeds = global_user_embeds |
|
|
|
def calculate_B(self, A_out): |
|
|
|
batch_size, seq_len, n, r = A_out.size() |
|
weight = self.lora_B[self.active_adapter].weight.t().reshape(n, r, -1) |
|
return torch.einsum('ijkl, klm->ijkm', A_out, weight) |
|
|
|
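        # Unlike the dense `Linear` layer, which consumes precomputed gate weights, the 8-bit
        # path computes the gate from `user_embeds` via the adapter's gating network before
        # combining the expert outputs.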
def forward(self, x: torch.Tensor): |
|
user_embeds = self.global_user_embeds[0] |
|
result = super().forward(x) |
|
|
|
batch_size, seq_len, _ = x.size() |
|
|
|
if self.disable_adapters or self.active_adapter not in self.lora_A.keys(): |
|
return result |
|
elif self.r[self.active_adapter] > 0: |
|
if not torch.is_autocast_enabled(): |
|
expected_dtype = result.dtype |
|
|
|
if x.dtype != torch.float32: |
|
x = x.float() |
|
A_out = self.lora_A[self.active_adapter]( |
|
self.lora_dropout[self.active_adapter](x) |
|
).reshape(batch_size, seq_len, self.num_moe[self.active_adapter], -1) |
|
B_out = self.calculate_B(A_out) |
|
Gate = self.gating[self.active_adapter](user_embeds).unsqueeze(-1) |
|
output = ((B_out * Gate).sum(dim=-2).to(expected_dtype) * self.scaling[self.active_adapter]) |
|
|
|
else: |
|
A_out = self.lora_A[self.active_adapter]( |
|
self.lora_dropout[self.active_adapter](x) |
|
).reshape(batch_size, seq_len, self.num_moe[self.active_adapter], -1) |
|
B_out = self.calculate_B(A_out) |
|
Gate = self.gating[self.active_adapter](user_embeds).unsqueeze(-1) |
|
output = ((B_out * Gate).sum(dim=-2) * self.scaling[self.active_adapter]) |
|
|
|
result += output |
|
return result |
|
|
|
|
|
if is_bnb_4bit_available(): |
|
|
|
class Linear4bit(bnb.nn.Linear4bit, MoeLoraLayer): |
|
|
|
            def __init__(
                self,
                adapter_name,
                in_features,
                out_features,
                r: int = 0,
                num_moe: int = 0,
                gating: str = "",
                global_user_embeds: List = [],
                lora_alpha: int = 1,
                lora_dropout: float = 0.0,
                **kwargs,
            ):
                bnb.nn.Linear4bit.__init__(
                    self,
                    in_features,
                    out_features,
                    bias=kwargs.get("bias", True),
                    compute_dtype=kwargs.get("compute_dtype", torch.float32),
                    compress_statistics=kwargs.get("compress_statistics", True),
                    quant_type=kwargs.get("quant_type", "nf4"),
                )
                MoeLoraLayer.__init__(self, in_features=in_features, out_features=out_features)

                # Freeze the pretrained weight matrix.
                self.weight.requires_grad = False

                init_lora_weights = kwargs.pop("init_lora_weights", True)
                self.update_layer(adapter_name, r, num_moe, gating, lora_alpha, lora_dropout, init_lora_weights)
                self.active_adapter = adapter_name
                self.global_user_embeds = global_user_embeds

            def calculate_B(self, A_out):
                # Apply each expert's slice of lora_B to its chunk of the lora_A output.
                batch_size, seq_len, n, r = A_out.size()
                weight = self.lora_B[self.active_adapter].weight.t().reshape(n, r, -1)
                return torch.einsum('ijkl, klm->ijkm', A_out, weight)
|
|
|
def forward(self, x: torch.Tensor): |
|
result = super().forward(x) |
|
user_embeds = self.global_user_embeds[0] |
|
|
|
if self.disable_adapters or self.active_adapter not in self.lora_A.keys(): |
|
return result |
|
elif self.r[self.active_adapter] > 0: |
|
result = result.clone() |
|
if not torch.is_autocast_enabled(): |
|
expected_dtype = result.dtype |
|
x = x.to(self.lora_A[self.active_adapter].weight.dtype) |
|
batch_size, seq_len, _ = x.size() |
|
A_out = self.lora_A[self.active_adapter]( |
|
self.lora_dropout[self.active_adapter](x) |
|
).reshape(batch_size, seq_len, self.num_moe[self.active_adapter], -1) |
|
B_out = self.calculate_B(A_out) |
|
Gate = self.gating[self.active_adapter](user_embeds).unsqueeze(-1) |
|
output = ((B_out * Gate).sum(dim=-2).to(expected_dtype) * self.scaling[self.active_adapter]) |
|
else: |
|
x = x.to(self.lora_A[self.active_adapter].weight.dtype) |
|
batch_size, seq_len, _ = x.size() |
|
A_out = self.lora_A[self.active_adapter]( |
|
self.lora_dropout[self.active_adapter](x) |
|
).reshape(batch_size, seq_len, self.num_moe[self.active_adapter], -1) |
|
B_out = self.calculate_B(A_out) |
|
Gate = self.gating[self.active_adapter](user_embeds).unsqueeze(-1) |
|
output = ((B_out * Gate).sum(dim=-2) * self.scaling[self.active_adapter]) |
|
result += output |
|
return result |
|
|