from typing import List, Optional, Set, Type, Union

import torch
from torch import nn


class LoraInjectedLinear(nn.Module):
    """
    Linear layer with LoRA injection.

    Taken from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py

    The module wraps a base ``nn.Linear`` and adds a low-rank branch
    ``lora_up(selector(lora_down(x)))`` whose (dropout-regularised) output is
    multiplied by ``scale`` and summed with the base output.
    """

    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
    ):
        """
        Args:
            in_features: Input size of the base linear layer.
            out_features: Output size of the base linear layer.
            bias: Whether the base linear layer has a bias term.
            r: Rank of the low-rank branch.
            dropout_p: Dropout probability applied to the low-rank output.
            scale: Multiplier applied to the low-rank branch.

        Raises:
            ValueError: If ``r`` exceeds ``min(in_features, out_features)``.
        """
        super().__init__()

        max_rank = min(in_features, out_features)
        if r > max_rank:
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {max_rank}"
            )

        self.r = r
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        # Identity until set_selector_from_diag() installs a diagonal map.
        self.selector = nn.Identity()

        # lora_up starts at zero, so the low-rank branch contributes nothing
        # until it has been trained.
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return base + scaled low-rank output.

        The input is cast to float32 for the computation and the result is
        cast to float16 before being returned.
        """
        x = input.float()
        low_rank = self.lora_up(self.selector(self.lora_down(x)))
        return (self.linear(x) + self.dropout(low_rank) * self.scale).half()

    def realize_as_lora(self):
        """Return ``(lora_up.weight * scale, lora_down.weight)`` as raw tensors.

        The selector is not folded into the returned weights.
        """
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Replace the selector with a bias-free linear map whose weight is ``diag(diag)``.

        Args:
            diag: 1D tensor of size ``(r,)`` holding the diagonal entries.
        """
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Linear(self.r, self.r, bias=False)
        # same device + dtype as lora_up
        self.selector.weight.data = torch.diag(diag).to(
            device=self.lora_up.weight.device, dtype=self.lora_up.weight.dtype
        )


class LoraInjectedConv2d(nn.Module):
    """
    Conv2d layer with LoRA injection.

    The module wraps a base ``nn.Conv2d`` and adds a low-rank branch: a
    ``lora_down`` convolution that shares the base layer's kernel size,
    stride, padding, dilation and groups but outputs ``r`` channels, an
    optional selector, and a 1x1 ``lora_up`` convolution back to
    ``out_channels``. The (dropout-regularised) branch output is multiplied
    by ``scale`` and summed with the base output.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        r: int = 4,
        dropout_p: float = 0.1,
        scale: float = 1.0,
    ):
        """
        Args:
            in_channels: Input channels of the base convolution.
            out_channels: Output channels of the base convolution.
            kernel_size: Kernel size of the base and ``lora_down`` convolutions.
            stride: Stride of the base and ``lora_down`` convolutions.
            padding: Padding of the base and ``lora_down`` convolutions.
            dilation: Dilation of the base and ``lora_down`` convolutions.
            groups: Groups of the base and ``lora_down`` convolutions.
            bias: Whether the base convolution has a bias term.
            r: Rank (channel count) of the low-rank branch.
            dropout_p: Dropout probability applied to the low-rank output.
            scale: Multiplier applied to the low-rank branch.

        Raises:
            ValueError: If ``r`` exceeds ``min(in_channels, out_channels)``.
        """
        super().__init__()

        max_rank = min(in_channels, out_channels)
        if r > max_rank:
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {max_rank}"
            )

        self.r = r
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Mirrors the base convolution's geometry so both branches produce
        # feature maps with the same spatial size.
        self.lora_down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=r,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(
            in_channels=r,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Identity until set_selector_from_diag() installs a diagonal map.
        self.selector = nn.Identity()
        self.scale = scale

        # lora_up starts at zero, so the low-rank branch contributes nothing
        # until it has been trained.
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return base convolution output plus the scaled low-rank output."""
        low_rank = self.lora_up(self.selector(self.lora_down(input)))
        return self.conv(input) + self.dropout(low_rank) * self.scale

    def realize_as_lora(self):
        """Return ``(lora_up.weight * scale, lora_down.weight)`` as raw tensors.

        The selector is not folded into the returned weights.
        """
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Replace the selector with a 1x1 convolution whose kernel is ``diag(diag)``.

        Args:
            diag: 1D tensor of size ``(r,)`` holding the per-channel factors.
        """
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(
            in_channels=self.r,
            out_channels=self.r,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # A Conv2d kernel must be 4-D (out, in, kH, kW); torch.diag() yields a
        # 2-D (r, r) matrix, so add the two trailing 1x1 spatial dimensions.
        kernel = torch.diag(diag).reshape(self.r, self.r, 1, 1)

        # same device + dtype as lora_up
        self.selector.weight.data = kernel.to(
            device=self.lora_up.weight.device, dtype=self.lora_up.weight.dtype
        )