# Source: Chain-of-Zoom / lora/lora_layers.py
# (uploaded by alexnasa, commit 0301e15 - "Upload 54 files")
from typing import List, Optional, Set, Type, Union
import torch
from torch import nn
class LoraInjectedLinear(nn.Module):
    """
    Linear layer with LoRA injection.
    Taken from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py

    Computes ``linear(x) + dropout(lora_up(selector(lora_down(x)))) * scale``
    and returns the result as float16. ``lora_up`` is zero-initialised, so a
    freshly built layer produces the same values as its wrapped ``nn.Linear``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the wrapped ``nn.Linear`` has a bias term.
        r: LoRA rank; must not exceed ``min(in_features, out_features)``.
        dropout_p: Dropout probability applied to the LoRA branch output.
        scale: Multiplier applied to the LoRA branch.

    Raises:
        ValueError: If ``r`` exceeds ``min(in_features, out_features)``.
    """

    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
    ):
        super().__init__()
        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
            )
        self.r = r
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        # Identity until set_selector_from_diag installs a diagonal gate.
        self.selector = nn.Identity()
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        # Zero-initialised up-projection: the LoRA branch starts as a no-op.
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return the base projection plus the scaled LoRA branch, as float16.

        The input is cast to the dtype of the wrapped linear's weight rather
        than to a fixed float32. For float32 weights this is identical to
        ``input.float()``. It also keeps the layer usable when its parameters
        live in another dtype, where a fixed float32 cast would cause a dtype
        mismatch in the matmul. The output is always float16.
        """
        x = input.to(self.linear.weight.dtype)
        lora_out = self.lora_up(self.selector(self.lora_down(x)))
        return (self.linear(x) + self.dropout(lora_out) * self.scale).half()

    def realize_as_lora(self):
        """Return ``(up_weight * scale, down_weight)`` as detached tensors."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Replace the identity selector with a diagonal ``r x r`` linear map.

        Args:
            diag: 1-D tensor of size ``(r,)`` holding the per-rank gates.

        The new selector weight is moved to the device and dtype of
        ``lora_up.weight``.
        """
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Linear(self.r, self.r, bias=False)
        self.selector.weight.data = torch.diag(diag).to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)
class LoraInjectedConv2d(nn.Module):
    """
    2-D convolution with LoRA injection.

    Computes ``conv(x) + dropout(lora_up(selector(lora_down(x)))) * scale``.
    ``lora_down`` uses the same kernel/stride/padding/dilation/groups as the
    wrapped convolution but outputs ``r`` channels. ``lora_up`` is a 1x1
    convolution back to ``out_channels``. ``lora_up`` is zero-initialised,
    so a freshly built layer produces the same values as its wrapped
    ``nn.Conv2d``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: Forwarded to the wrapped ``nn.Conv2d``.
        r: LoRA rank; must not exceed ``min(in_channels, out_channels)``.
        dropout_p: Dropout probability applied to the LoRA branch output.
        scale: Multiplier applied to the LoRA branch.

    Raises:
        ValueError: If ``r`` exceeds ``min(in_channels, out_channels)``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        r: int = 4,
        dropout_p: float = 0.1,
        scale: float = 1.0,
    ):
        super().__init__()
        if r > min(in_channels, out_channels):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
            )
        self.r = r
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Down-projection mirrors the base conv geometry so spatial sizes match.
        self.lora_down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=r,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(
            in_channels=r,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # Identity until set_selector_from_diag installs a diagonal gate.
        self.selector = nn.Identity()
        self.scale = scale
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        # Zero-initialised up-projection: the LoRA branch starts as a no-op.
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        """Return the base convolution plus the scaled LoRA branch."""
        return (
            self.conv(input)
            + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
            * self.scale
        )

    def realize_as_lora(self):
        """Return ``(up_weight * scale, down_weight)`` as detached tensors."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Replace the identity selector with a diagonal 1x1 convolution.

        Args:
            diag: 1-D tensor of size ``(r,)`` holding the per-rank gates.

        ``nn.Conv2d`` weights are 4-D ``(out, in, kH, kW)``. The diagonal
        matrix is therefore reshaped to ``(r, r, 1, 1)``. A bare 2-D
        ``torch.diag`` would make the selector's forward pass fail. The
        weight is moved to the device and dtype of ``lora_up.weight``.
        """
        # diag is a 1D tensor of size (r,)
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(
            in_channels=self.r,
            out_channels=self.r,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        weight = torch.diag(diag).reshape(self.r, self.r, 1, 1)
        # same device + dtype as lora_up
        self.selector.weight.data = weight.to(self.lora_up.weight.device).to(
            self.lora_up.weight.dtype
        )