import torch
from torch import nn
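
# LoRA (Low-Rank Adaptation): rather than fine-tune a full weight matrix W,
# learn a low-rank update so the layer computes y = W x + b + scale * B(A(x)),
# where A is (rank x in_features) and B is (out_features x rank) with
# rank << min(in_features, out_features). Only A and B need gradients.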


def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every nn.Linear in `module` with a LinearLora of rank `max_rank`."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            # Reuse the original parameters so the base projection is unchanged.
            new_lora.weight = child.weight
            new_lora.bias = child.bias if child.bias is not None else None
            setattr(module, name, new_lora)
        else:
            # Not a Linear itself: recurse into the child's submodules.
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )
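

# A minimal helper sketch, not part of the original file: `lora_parameters` is a
# hypothetical name introduced here to illustrate collecting just the trainable
# low-rank factors so the base weights can stay frozen during fine-tuning.
def lora_parameters(module: nn.Module):
    for submodule in module.modules():
        if isinstance(submodule, LinearLora):
            yield from submodule.lora_A.parameters()
            yield from submodule.lora_B.parameters()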


class LinearLora(nn.Linear):
    """An nn.Linear with an additive low-rank branch: y = W x + b + scale * B(A(x))."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        # A rank above min(out_features, in_features) adds parameters without
        # adding expressiveness, so clamp it.
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank

        # Low-rank factors: lora_A projects down to `rank`, lora_B projects back up.
        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )

    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Base projection plus the scaled low-rank update.
        base_out = super().forward(input)
        lora_out = self.lora_B(self.lora_A(input))
        return base_out + lora_out * self.scale
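

# A minimal usage sketch, assuming a toy two-layer MLP; the shapes, rank, and
# training setup below are illustrative assumptions, not part of the original file.
if __name__ == "__main__":
    model = nn.Sequential(
        nn.Linear(64, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
    replace_linear_with_lora(model, max_rank=8, scale=1.0)

    # Freeze everything, then unfreeze only the LoRA factors (via the
    # illustrative `lora_parameters` helper above), so fine-tuning touches
    # only a small fraction of the parameters.
    for p in model.parameters():
        p.requires_grad = False
    for p in lora_parameters(model):
        p.requires_grad = True

    x = torch.randn(4, 64)
    print(model(x).shape)  # torch.Size([4, 10])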