"""
This module contains the implementation of the LoraPlus optimizer.
"""

from __future__ import annotations

from operator import attrgetter

import torch.nn as nn
from torch.optim import Optimizer
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

from ..peft_model import PeftModel
from ..tuners.lora.layer import Embedding


def create_loraplus_optimizer(
    model: PeftModel, optimizer_cls: type[Optimizer], *, lr: float, loraplus_lr_ratio: float, **kwargs
) -> Optimizer:
    """
    Creates a LoraPlus optimizer.

    LoRA+: Efficient Low Rank Adaptation of Large Models: https://arxiv.org/abs/2402.12354

    Reference: https://github.com/nikhil-ghosh-berkeley/loraplus/

    Args:
        model (`torch.nn.Module`): The model to be optimized.
        optimizer_cls (`torch.optim.Optimizer`): The optimizer class to be used.
        lr (`float`): The learning rate to be used for the optimizer.
        loraplus_lr_ratio (`float`):
            The ratio of learning rates ηB/ηA, where ηA (lr) is passed in as the optimizer learning rate. Should be
            ≥ 1 and set in tandem with the optimizer learning rate (lr): use a larger ratio when the task is more
            difficult and the model needs to update its features to learn well. In that case, it also helps to make
            the learning rate slightly smaller (e.g., by a factor of 2) than typical vanilla LoRA learning rates.
        loraplus_lr_embedding (`float`, *optional*):
            If LoRA modules are added to embedding layers, you can specify a different learning rate for them.
            Defaults to 1e-6.
        kwargs (`dict`): Additional keyword arguments to be passed to the optimizer.

    Returns:
        `torch.optim.Optimizer`: An instance of the specified optimizer class configured with the model's parameters
        organized into groups with custom learning rates.
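
    Example:

        A minimal usage sketch (the model id and the hyperparameter values below are illustrative, and the example
        assumes this helper is exported as `peft.optimizers.create_loraplus_optimizer`):

        ```py
        import torch
        from transformers import AutoModelForCausalLM

        from peft import LoraConfig, get_peft_model
        from peft.optimizers import create_loraplus_optimizer

        base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))

        optimizer = create_loraplus_optimizer(
            model=peft_model,
            optimizer_cls=torch.optim.AdamW,
            lr=5e-5,
            loraplus_lr_ratio=16,
            loraplus_weight_decay=0.01,
        )
        ```

        The returned optimizer can then be passed to `transformers.Trainer` via its `optimizers=(optimizer, None)`
        argument, or used directly in a custom training loop.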
    """
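    # Parameters eligible for weight decay: everything that is not inside a LayerNorm
    # module and is not a bias. This is only used below to split the lora_B group.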
    decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    param_groups = {
        "groupA": {},
        "groupB": {},
        "groupB_no_decay": {},
        "embedding": {},
    }
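    # Sort the trainable parameters into the LoRA+ groups: LoRA embedding layers get
    # their own learning rate, lora_B weights (and any other trainable 1-D parameters)
    # are split by weight-decay eligibility, and everything else (including lora_A)
    # stays in groupA at the base learning rate.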
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        module = attrgetter(name)(model)
        if isinstance(module, Embedding):
            param_groups["embedding"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param
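    # Forward the base learning rate to the optimizer and pop the LoRA+-specific
    # hyperparameters so they are not passed to the optimizer constructor.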
    kwargs["lr"] = lr
    loraplus_weight_decay = kwargs.pop("loraplus_weight_decay", 0.0)
    loraplus_lr_embedding = kwargs.pop("loraplus_lr_embedding", 1e-6)
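    # Build the optimizer parameter groups. Following LoRA+, the lora_B groups train
    # with a learning rate scaled by loraplus_lr_ratio relative to the base lr, and
    # the no-decay group additionally skips weight decay.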
    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": loraplus_lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": loraplus_weight_decay,
            "lr": lr * loraplus_lr_ratio,
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * loraplus_lr_ratio,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **kwargs)
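    # For bitsandbytes 8-bit optimizers, keep the optimizer state of embedding weights
    # in 32 bits (via GlobalOptimManager), since embeddings are sensitive to quantized
    # optimizer states.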
    eight_bit_names = ["Adam8bit", "AdamW8bit", "PagedAdam8bit", "PagedAdamW8bit"]
    if optimizer_cls.__name__ in eight_bit_names:
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
        for module in model.modules():
            if isinstance(module, nn.Embedding):
                manager.register_module_override(module, "weight", {"optim_bits": 32})
    return optimizer