# Copyright (c) 2021 microsoft
# 2023 Alan (alanfangemail@gmail.com)
# -----------------------------------------------------------------------------
# Licensed under the MIT License (MIT). See LICENSE in the repo root for
# license information.
# -----------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from typing import List


class LoRALayer():
    """Mixin holding the LoRA hyper-parameters and merge state.

    Args:
        r: LoRA rank. Subclasses only create the low-rank factors when
            ``r > 0``.
        lora_alpha: numerator of the update scaling; subclasses use
            ``lora_alpha / r``.
        lora_dropout: dropout probability stored as ``self.lora_dropout``;
            ``0`` (or less) stores an identity function instead.
        merge_weights: when True, switching to eval mode folds the low-rank
            update into the frozen weight and switching back to train mode
            removes it again.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = self.identity
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

    def identity(self, x):
        """Return ``x`` unchanged (stand-in when LoRA dropout is disabled)."""
        return x


class Embedding(nn.Embedding, LoRALayer):
    """``nn.Embedding`` plus a trainable low-rank update ``(B @ A)^T``.

    When ``r > 0`` the pretrained ``weight`` is frozen and only ``lora_A``
    (``r x num_embeddings``) and ``lora_B`` (``embedding_dim x r``) train.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 r: int = 0,
                 lora_alpha: int = 1,
                 merge_weights: bool = True,
                 **kwargs):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=0,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        # NOTE(review): this re-runs nn.Embedding.reset_parameters, which
        # re-initializes ``weight``; a ``_weight`` passed through kwargs would
        # presumably be overwritten -- confirm this is intended.
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the embedding table and, if present, the LoRA factors."""
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # A starts at zero and B is drawn from a normal distribution, so
            # the initial update B @ A is zero and the layer initially matches
            # the frozen embedding.
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        """Set train/eval mode, un-merging or merging the LoRA update.

        Returns ``self`` so that ``nn.Module.eval()`` (which returns
        ``self.train(False)``) keeps returning the module.
        """
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    temp = (self.lora_B @ self.lora_A).transpose(0, 1)
                    self.weight.data -= temp * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    temp = (self.lora_B @ self.lora_A).transpose(0, 1)
                    self.weight.data += temp * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Look up ``x``, adding the scaled low-rank term unless merged."""
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            # Look up rows of A^T with the same embedding options as the base.
            after_A = F.embedding(x, self.lora_A.transpose(0, 1),
                                  self.padding_idx, self.max_norm,
                                  self.norm_type, self.scale_grad_by_freq,
                                  self.sparse)
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        return nn.Embedding.forward(self, x)


class Linear(nn.Linear, LoRALayer):
    """``nn.Linear`` plus a trainable low-rank update ``B @ A``.

    When ``r > 0`` the pretrained ``weight`` is frozen and only ``lora_A``
    (``r x in_features``) and ``lora_B`` (``out_features x r``) train.
    ``fan_in_fan_out=True`` is for replacing layers that store their weight
    as (fan_in, fan_out): the weight is transposed once at construction and
    ``T`` transposes it back wherever it is used.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.,
            fan_in_fan_out: bool = False,
            # Set this to True if the layer to replace stores weight like
            # (fan_in, fan_out)
            merge_weights: bool = True,
            **kwargs):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros(
                (out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Re-initialize the linear layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial update B @ A is zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def T(self, w):
        """Transpose ``w`` when the stored weight layout is (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def train(self, mode: bool = True):
        """Set train/eval mode, un-merging or merging the LoRA update.

        Returns ``self`` so that ``nn.Module.eval()`` keeps returning the
        module.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    temp = self.T(self.lora_B @ self.lora_A)
                    self.weight.data -= temp * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    temp = self.T(self.lora_B @ self.lora_A)
                    self.weight.data += temp * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the frozen layer plus (if unmerged) the scaled LoRA branch."""
        result = F.linear(x, self.T(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1)
                       @ self.lora_B.transpose(0, 1)) * self.scaling
        return result


class MergedLinear(nn.Linear, LoRALayer):
    """``nn.Linear`` whose output is split into equal slices, with LoRA on a
    subset of them (e.g. a fused q/k/v projection).

    ``enable_lora`` has one flag per slice; the output features are divided
    into ``len(enable_lora)`` equal slices and only slices flagged True
    receive a low-rank update. The per-slice updates are computed together
    with a grouped ``conv1d`` and zero-padded back to ``out_features``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 r: int = 0,
                 lora_alpha: int = 1,
                 lora_dropout: float = 0.,
                 enable_lora: List[bool] = None,
                 fan_in_fan_out: bool = False,
                 merge_weights: bool = True,
                 **kwargs):
        if enable_lora is None:
            enable_lora = [False]
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros(
                    (out_features // len(enable_lora) * sum(enable_lora), r)))
            # weights for Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Boolean mask over output features marking the LoRA-enabled
            # slices; used by zero_pad to scatter the compact delta.
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        """Re-initialize the linear layer and, if present, the LoRA factors."""
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial update is zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def zero_pad(self, x):
        """Scatter the compact per-slice rows of ``x`` into a zero tensor
        with ``out_features`` rows, at the LoRA-enabled positions."""
        result = x.new_zeros((len(self.lora_ind), *x.size()[1:]))
        result[self.lora_ind] = x
        return result

    def T(self, w):
        """Transpose ``w`` when the stored weight layout is (fan_in, fan_out)."""
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def merge_AB(self):
        """Return the full-size update ``B @ A`` in the stored weight layout.

        A grouped 1x1 ``conv1d`` multiplies each slice's B block with its A
        block; the result only has rows for enabled slices, so it is
        zero-padded to ``out_features`` rows before being laid out like
        ``self.weight``.
        """
        delta_w = F.conv1d(self.lora_A.unsqueeze(0),
                           self.lora_B.unsqueeze(-1),
                           groups=sum(self.enable_lora)).squeeze(0)
        return self.T(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        """Set train/eval mode, un-merging or merging the LoRA update.

        Returns ``self`` so that ``nn.Module.eval()`` keeps returning the
        module.
        """
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data -= self.merge_AB() * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0 and any(self.enable_lora):
                    self.weight.data += self.merge_AB() * self.scaling
                self.merged = True
        return self

    def forward(self, x: torch.Tensor):
        """Apply the frozen layer plus (if unmerged) the scaled LoRA branch."""
        result = F.linear(x, self.T(self.weight), bias=self.bias)
        # lora_A/lora_B only exist when r > 0 and some slice is enabled.
        if self.merged or self.r <= 0 or not any(self.enable_lora):
            return result
        # (in_features, out_features) regardless of fan_in_fan_out.
        delta = self.T(self.merge_AB().T)
        result += self.lora_dropout(x) @ delta * self.scaling
        return result


class ConvLoRA(nn.Module, LoRALayer):
    """Wraps a convolution module with a trainable low-rank weight update.

    The update is ``(lora_B @ lora_A)`` reshaped to the conv weight's shape.
    ``kernel_size`` must be a single int. ``lora_A`` has shape
    ``(r*k, in_channels * k**(spatial_dims - 1))`` and ``lora_B`` has shape
    ``(out_channels // groups * k, r*k)``, so their product has exactly as
    many elements as the conv weight for 1-, 2- and 3-d convolutions (for
    2-d this is ``(r*k, in_channels*k)``). ``lora_dropout`` is stored but
    not applied by this class.
    """

    def __init__(self,
                 conv_module,
                 in_channels,
                 out_channels,
                 kernel_size,
                 r=0,
                 lora_alpha=1,
                 lora_dropout=0.,
                 merge_weights=True,
                 **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size,
                                **kwargs)
        LoRALayer.__init__(self,
                           r=r,
                           lora_alpha=lora_alpha,
                           lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            # Conv weight is (out, in // groups, k, ..., k); the leading two
            # dims are channels, the rest are spatial.
            spatial_dims = self.conv.weight.dim() - 2
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros(
                    (r * kernel_size,
                     in_channels * kernel_size ** (spatial_dims - 1))))
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros(
                    (out_channels // self.conv.groups * kernel_size,
                     r * kernel_size)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False

    def reset_parameters(self):
        """Re-initialize the wrapped conv and, if present, the LoRA factors."""
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to
            # zero, so the initial update is zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _delta_weight(self):
        """Scaled low-rank update reshaped to the conv weight's shape."""
        return (self.lora_B @ self.lora_A).view(
            self.conv.weight.shape) * self.scaling

    def train(self, mode=True):
        """Set train/eval mode, un-merging or merging the LoRA update.

        Returns ``self`` so that ``nn.Module.eval()`` keeps returning the
        module.
        """
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged
                    self.conv.weight.data -= self._delta_weight()
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights and mark it
                    self.conv.weight.data += self._delta_weight()
                self.merged = True
        return self

    def forward(self, x):
        """Convolve with the frozen weight plus (if unmerged) the update."""
        if self.r > 0 and not self.merged:
            return self.conv._conv_forward(
                x, self.conv.weight + self._delta_weight(), self.conv.bias)
        return self.conv(x)


class Conv2d(ConvLoRA):
    """LoRA-wrapped ``nn.Conv2d``."""

    def __init__(self, *args, **kwargs):
        super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)


class Conv1d(ConvLoRA):
    """LoRA-wrapped ``nn.Conv1d``."""

    def __init__(self, *args, **kwargs):
        super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)


# Can Extend to other ones like this


class Conv3d(ConvLoRA):
    """LoRA-wrapped ``nn.Conv3d``."""

    def __init__(self, *args, **kwargs):
        super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)