from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import nn

from peft.utils.integrations import dequantize_module_weight, gather_params_ctx
from peft.utils.other import transpose


class DoraLinearLayer(nn.Module):
    def __init__(self, fan_in_fan_out):
        super().__init__()
        self.fan_in_fan_out = fan_in_fan_out

    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        # calculate the L2 norm of the adapted weight per output feature (the column-wise norm from the DoRA paper)
        weight = transpose(weight, self.fan_in_fan_out)
        weight = weight + scaling * lora_weight
        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
        return weight_norm

    def update_layer(self, *, base_layer, lora_A, lora_B, scaling, place_on_cpu=False) -> None:
        # fp16 matmuls can be unreliable on some backends, so compute the norm in fp32 and cast back afterwards
        dtype_is_fp16 = lora_A.dtype == torch.float16
        if dtype_is_fp16:
            lora_A = lora_A.float()
            lora_B = lora_B.float()

        with gather_params_ctx(base_layer.parameters()):
            if base_layer.__class__.__name__ == "Linear4bit":
                # work on a copy of the bnb 4-bit layer so that dequantization leaves the original base layer untouched
                base_layer = deepcopy(base_layer)

            weight = dequantize_module_weight(base_layer)
            if weight.data.ndim >= 4:
                # conv layers: flatten the kernels so the LoRA delta can be computed as a matmul, then reshape back
                lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1))
                lora_weight = lora_weight.reshape(weight.shape)
            else:
                lora_weight = lora_B @ lora_A

            if dtype_is_fp16:
                lora_weight = lora_weight.half()
            weight_norm = self.get_weight_norm(weight.to(lora_A.device), lora_weight, scaling)

        if place_on_cpu:
            weight_norm = weight_norm.to("cpu")
        # the DoRA magnitude vector is initialized to the norm of the merged weight
        self.weight = nn.Parameter(weight_norm, requires_grad=True)

    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, base_result=None):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output. See the illustrative usage sketch after this class definition.
        """
        # Compute the effective LoRA weight (equivalent to lora_B.weight @ lora_A.weight) by passing an identity
        # matrix through the LoRA modules, i.e. using their forward instead of their raw weights.
        x_eye = torch.eye(lora_A.weight.shape[1], device=lora_A.weight.device, dtype=x.dtype)
        lora_weight = lora_B(lora_A(x_eye)).T

        magnitude = self.weight
        weight = dequantize_module_weight(base_layer)
        weight = weight.to(x.dtype)
        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
        # Following section 4.3 of the DoRA paper (https://arxiv.org/abs/2402.09353), the norm of the adapted weight
        # is treated as a constant: it is detached from the gradient graph, so it reflects the LoRA updates but does
        # not receive any gradient during backpropagation.
        weight_norm = weight_norm.detach()
        mag_norm_scale = (magnitude / weight_norm).view(1, -1)

        lora_result = lora_B(lora_A(x))

        bias = None
        if base_result is not None:
            # The base output was already computed by the caller. The bias is not part of the DoRA weight
            # decomposition, so it must not be rescaled and is removed here.
            bias = base_layer.bias
            if bias is not None:
                base_result = base_result - bias
        else:
            base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))

        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
        return result_dora

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora.dora." + rep


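# The following is a minimal, hypothetical usage sketch (not part of the PEFT API) showing how `DoraLinearLayer`
# composes with a frozen base `nn.Linear`. The helper name `_dora_linear_example` and all shapes are illustrative
# assumptions; inside PEFT, the LoRA layer itself wires up `update_layer` and `forward`.
def _dora_linear_example():
    torch.manual_seed(0)
    in_features, out_features, rank, scaling = 8, 4, 2, 1.0
    base = nn.Linear(in_features, out_features, bias=False)
    lora_A = nn.Linear(in_features, rank, bias=False)
    lora_B = nn.Linear(rank, out_features, bias=False)

    dora = DoraLinearLayer(fan_in_fan_out=False)
    # initializes the magnitude vector from the norm of W + scaling * B @ A
    dora.update_layer(base_layer=base, lora_A=lora_A.weight, lora_B=lora_B.weight, scaling=scaling)

    x = torch.randn(3, in_features)
    # the DoRA output is the base layer output plus the extra term returned here
    result = base(x) + dora(x, lora_A=lora_A, lora_B=lora_B, scaling=scaling, base_layer=base)

    # sanity check against the explicit DoRA formula: W' = m * (W + scaling * B @ A) / ||W + scaling * B @ A||
    merged = base.weight + scaling * (lora_B.weight @ lora_A.weight)
    dora_weight = dora.weight.unsqueeze(1) * merged / merged.norm(dim=1, keepdim=True)
    assert torch.allclose(result, x @ dora_weight.T, atol=1e-5)
    return result

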
class DoraEmbeddingLayer(DoraLinearLayer):
    def forward(self, x, *, lora_A, lora_B, scaling, base_layer, embed_fn):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. The returned `mag_norm_scale` should be used
        to rescale the base layer output before adding `result_dora` on top. See the illustrative usage sketch after
        this class definition.
        """
        lora_weight = (lora_A @ lora_B).T
        magnitude = self.weight
        weight = base_layer.weight
        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
        # As in DoraLinearLayer.forward, the norm of the adapted weight is treated as a constant (section 4.3 of the
        # DoRA paper) and therefore detached from the gradient graph.
        weight_norm = weight_norm.detach()
        mag_norm_scale = magnitude / weight_norm
        result_dora = mag_norm_scale * (embed_fn(x, lora_A) @ lora_B) * scaling
        return mag_norm_scale, result_dora

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora.dora." + rep


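# A minimal, hypothetical sketch (not part of the PEFT API) of how `DoraEmbeddingLayer` is meant to be driven. The
# helper name `_dora_embedding_example` and the shapes are illustrative assumptions; in PEFT, `forward` is typically
# called with the transposed LoRA embedding matrices and an `embed_fn` that forwards the base layer's embedding
# options (padding_idx etc.). Here a plain `F.embedding` stands in for it.
def _dora_embedding_example():
    torch.manual_seed(0)
    num_embeddings, embedding_dim, rank, scaling = 10, 6, 2, 1.0
    base = nn.Embedding(num_embeddings, embedding_dim)
    lora_embedding_A = torch.randn(rank, num_embeddings)
    lora_embedding_B = torch.randn(embedding_dim, rank)

    dora = DoraEmbeddingLayer(fan_in_fan_out=True)
    dora.update_layer(base_layer=base, lora_A=lora_embedding_A, lora_B=lora_embedding_B, scaling=scaling)

    token_ids = torch.tensor([[1, 2, 3]])
    mag_norm_scale, result_dora = dora(
        token_ids,
        lora_A=lora_embedding_A.T,
        lora_B=lora_embedding_B.T,
        scaling=scaling,
        base_layer=base,
        embed_fn=F.embedding,
    )
    # the base embedding output is rescaled by the magnitude/norm ratio, then the DoRA term is added
    return mag_norm_scale * base(token_ids) + result_dora

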
class _DoraConvNdLayer(DoraLinearLayer):
    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
        # calculate the L2 norm of the adapted weight per output channel, keeping the result broadcastable with the
        # conv output (shape (1, out_channels, 1, ...))
        weight = weight + scaling * lora_weight
        dim = tuple(range(1, weight.dim()))
        weight_norm = weight.norm(p=2, dim=dim, keepdim=True).transpose(1, 0)
        return weight_norm

    def forward(self, x, *, lora_A, lora_B, scaling, base_layer):
        """
        For DoRA, calculate the extra output from LoRA with DoRA applied. This should be added on top of the base layer
        output. See the illustrative usage sketch at the end of this module.
        """
        weight = base_layer.weight
        # flatten the conv kernels so the LoRA delta weight can be computed as a matmul, then reshape to kernel shape
        lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1))
        lora_weight = lora_weight.reshape(weight.shape)
        magnitude = self.weight
        weight_norm = self.get_weight_norm(weight, lora_weight.detach(), scaling)
        # As in DoraLinearLayer.forward, the norm of the adapted weight is treated as a constant (section 4.3 of the
        # DoRA paper) and therefore detached from the gradient graph.
        weight_norm = weight_norm.detach()
        mag_norm_scale = magnitude / weight_norm
        result_dora = (mag_norm_scale - 1) * (
            self.conv_fn(
                x,
                weight,
                bias=None,
                stride=base_layer.stride,
                padding=base_layer.padding,
                dilation=base_layer.dilation,
                groups=base_layer.groups,
            )
        ) + mag_norm_scale * lora_B(lora_A(x)) * scaling

        return result_dora

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora.dora." + rep


class DoraConv2dLayer(_DoraConvNdLayer):
    def __init__(self, fan_in_fan_out):
        super().__init__(fan_in_fan_out)
        self.conv_fn = F.conv2d


class DoraConv3dLayer(_DoraConvNdLayer):
    def __init__(self, fan_in_fan_out):
        super().__init__(fan_in_fan_out)
        self.conv_fn = F.conv3d
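

# A minimal, hypothetical sketch (not part of the PEFT API) of how `DoraConv2dLayer` composes with a frozen base
# `nn.Conv2d`. The helper name `_dora_conv2d_example` and the shapes are illustrative assumptions; the LoRA
# convolutions mirror PEFT's usual setup (lora_A uses the base kernel size, lora_B is a 1x1 convolution).
def _dora_conv2d_example():
    torch.manual_seed(0)
    in_channels, out_channels, rank, scaling = 4, 8, 2, 1.0
    base = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
    lora_A = nn.Conv2d(in_channels, rank, kernel_size=3, padding=1, bias=False)
    lora_B = nn.Conv2d(rank, out_channels, kernel_size=1, bias=False)

    dora = DoraConv2dLayer(fan_in_fan_out=False)
    # initializes the per-channel magnitude vector from the norm of the merged kernel
    dora.update_layer(base_layer=base, lora_A=lora_A.weight, lora_B=lora_B.weight, scaling=scaling)

    x = torch.randn(2, in_channels, 16, 16)
    # the DoRA output is the base convolution output plus the extra term returned here
    return base(x) + dora(x, lora_A=lora_A, lora_B=lora_B, scaling=scaling, base_layer=base)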