# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Set, Optional, Type

import torch
import torch.nn as nn

SELF_ATTENTION_MODULES = {'Attention', 'NormAttention'}
CROSS_ATTENTION_MODULES = {'CrossAttention', 'NormCrossAttention'}
ATTENTION_MODULES = SELF_ATTENTION_MODULES | CROSS_ATTENTION_MODULES
MLP_MODULES = {'Mlp', 'GatedMlp', 'SwiGLUFFNFused'}  # SwiGLUFFNFused is from DINOv2
TRANSFORMER_MODULES = ATTENTION_MODULES | MLP_MODULES


def get_LoRA_module_names(id: str) -> Set[str]:
    """ Returns the set of module class names that are LoRA-adapted for the given id. """
    id = id.lower()
    if id in ['selfattn', 'selfattention', 'self_attn', 'self_attention']:
        return SELF_ATTENTION_MODULES
    elif id in ['crossattn', 'crossattention', 'cross_attn', 'cross_attention']:
        return CROSS_ATTENTION_MODULES
    elif id in ['attn', 'attention']:
        return ATTENTION_MODULES
    elif id in ['mlp']:
        return MLP_MODULES
    elif id in ['all', 'transformer']:
        return TRANSFORMER_MODULES
    else:
        raise ValueError(f'Unknown LoRA module id {id}.')
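
# Example (illustrative sketch, not part of the public API): the id aliases
# resolve to the module-name sets defined above.
#
#   get_LoRA_module_names('self_attn')    # -> {'Attention', 'NormAttention'}
#   get_LoRA_module_names('mlp')          # -> {'Mlp', 'GatedMlp', 'SwiGLUFFNFused'}
#   get_LoRA_module_names('transformer')  # -> union of all attention and MLP module names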


class LoRAWrapper(nn.Module):
    """Low-Rank Adaptation wrapper for linear layers.

    See https://arxiv.org/abs/2106.09685

    Args:
        linear: nn.Linear layer to wrap
        rank: Rank of adaptation matrix B@A
        scale: x = W_0@x + scale * B@A@x
        num_packed_linear: Set to > 1 when wrapping e.g. packed kv or qkv attention weights.
            Weights will be initialized as if num_packed_linear = 1, but the LoRA bottleneck
            will be num_packed_linear times larger.
    """
    def __init__(self, linear: nn.Module, rank: int = 4, scale: float = 1.0, num_packed_linear: int = 1):
        super().__init__()
        self.rank = rank
        self.scale = scale
        self.in_features, self.out_features = linear.in_features, linear.out_features
        assert num_packed_linear * rank <= min(self.in_features, self.out_features), \
            f'LoRA rank {num_packed_linear} * {rank} must be less than or equal to {min(self.in_features, self.out_features)}'

        self.linear = linear
        self.lora_down = nn.Linear(self.in_features, num_packed_linear * rank, bias=False)
        self.lora_up = nn.Linear(num_packed_linear * rank, self.out_features, bias=False)

        # lora_up is zero-initialized, so the LoRA branch is a no-op at the start of training
        nn.init.normal_(self.lora_down.weight, std=1/rank)
        nn.init.zeros_(self.lora_up.weight)

    def fuse_LoRA_into_linear(self) -> nn.Linear:
        """ Returns a single nn.Linear layer with the LoRA matrix fused into the original one. """
        fused_linear = nn.Linear(self.in_features, self.out_features, bias=self.linear.bias is not None)
        fused_linear.weight.data = self.linear.weight + self.scale * (self.lora_up.weight @ self.lora_down.weight)
        if self.linear.bias is not None:
            fused_linear.bias.data = self.linear.bias
        return fused_linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ LoRA adapted linear layer forward pass. """
        return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
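
# Minimal usage sketch (the layer sizes and tensors below are illustrative only):
#
#   base = nn.Linear(64, 64)
#   lora = LoRAWrapper(base, rank=4, scale=1.0)
#   x = torch.randn(2, 64)
#   # Because lora_up is zero-initialized, the wrapper matches the base layer at init:
#   assert torch.allclose(lora(x), base(x))
#   # After training the LoRA parameters, the adapter can be folded back into a plain Linear:
#   fused = lora.fuse_LoRA_into_linear()
#   assert torch.allclose(fused(x), lora(x), atol=1e-6)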


def _find_modules(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoRAWrapper],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.

    Adapted from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
    """
    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # Fall back to naively iterating over all modules
        ancestors = [module for module in model.modules()]

    # For each target find every linear_class module that isn't a child of a LoRA layer
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if any([isinstance(module, _class) for _class in search_class]):
                # Find the direct parent if this is a descendant, not a child, of target
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                # Skip this linear if it's a child of a LoRA layer
                if exclude_children_of and any(
                    [isinstance(parent, _class) for _class in exclude_children_of]
                ):
                    continue
                # Otherwise, yield it
                yield parent, name, module
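
# Illustrative sketch of how _find_modules is typically consumed (the `model`
# below is an assumption, not something defined in this file):
#
#   for parent, name, linear in _find_modules(model, ATTENTION_MODULES, search_class=[nn.Linear]):
#       print(f'{parent.__class__.__name__}.{name}: {linear}')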


def inject_trainable_LoRA(
    model: nn.Module,
    rank: int = 4,
    scale: float = 1.0,
    target_replace_modules: Set[str] = ATTENTION_MODULES
) -> None:
    """Replaces all linear layers of the specified modules with LoRA-adapted linear layers.
    Modifies the model in-place.

    Args:
        model: nn.Module to modify
        rank: Rank of adaptation matrix B@A
        scale: x = W_0@x + scale * B@A@x
        target_replace_modules: Set of module names to replace linear layers in.
    """
    for _module, name, _child_module in _find_modules(
        model, target_replace_modules, search_class=[nn.Linear]
    ):
        # Packed attention projections (e.g. 'qkv', 'kv') get a proportionally larger LoRA bottleneck
        if sorted(name) == sorted('qkv'):
            num_packed_linear = 3
        elif sorted(name) in [sorted('kv'), sorted('qk'), sorted('qv')]:
            num_packed_linear = 2
        else:
            num_packed_linear = 1

        _module._modules[name] = LoRAWrapper(_child_module, rank=rank, scale=scale, num_packed_linear=num_packed_linear)
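
# Sketch of a typical call (the freezing policy and rank shown here are
# illustrative choices, not requirements of this function):
#
#   for p in model.parameters():
#       p.requires_grad = False          # freeze pretrained weights
#   inject_trainable_LoRA(model, rank=8, target_replace_modules=ATTENTION_MODULES)
#   unfreeze_all_LoRA_layers(model)      # only LoRA parameters remain trainable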


def fuse_LoRA_into_linear(
    model: nn.Module,
    target_replace_modules: Set[str] = ATTENTION_MODULES
) -> None:
    """Fuses all LoRA-adapted linear layers back into single linear layers.
    Modifies the model in-place.

    Args:
        model: nn.Module to modify
        target_replace_modules: Set of module names whose LoRA-adapted linear layers should be fused.
    """
    for _module, name, _child_module in _find_modules(
        model, target_replace_modules, search_class=[LoRAWrapper]
    ):
        _module._modules[name] = _module._modules[name].fuse_LoRA_into_linear()
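
# After training, fusing removes the extra LoRA matmuls so inference runs at the
# cost of the original model (illustrative; typically done before exporting or saving weights):
#
#   fuse_LoRA_into_linear(model, target_replace_modules=ATTENTION_MODULES)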


def unfreeze_all_LoRA_layers(model: nn.Module) -> None:
    """ Unfreezes all LoRA-adapted linear layers. """
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
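

if __name__ == '__main__':
    # Minimal self-contained sanity check (illustrative only; the toy Attention class
    # below is a stand-in for a real transformer block and is not part of this codebase).
    class Attention(nn.Module):  # class name matches SELF_ATTENTION_MODULES
        def __init__(self, dim: int):
            super().__init__()
            self.qkv = nn.Linear(dim, 3 * dim)
            self.proj = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
            return self.proj(attn @ v)

    model = nn.Sequential(Attention(dim=32), Attention(dim=32))
    x = torch.randn(2, 8, 32)
    with torch.no_grad():
        out_before = model(x)

    # Freeze everything, inject LoRA, then unfreeze only the LoRA parameters
    for p in model.parameters():
        p.requires_grad = False
    inject_trainable_LoRA(model, rank=4, scale=1.0, target_replace_modules=SELF_ATTENTION_MODULES)
    unfreeze_all_LoRA_layers(model)
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    assert all('lora' in n for n in trainable)

    # At initialization the LoRA branch is zero, so outputs are unchanged;
    # the same holds after fusing the (still zero) adapters back into the linears.
    with torch.no_grad():
        out_lora = model(x)
    fuse_LoRA_into_linear(model, target_replace_modules=SELF_ATTENTION_MODULES)
    with torch.no_grad():
        out_fused = model(x)
    assert torch.allclose(out_before, out_lora, atol=1e-6)
    assert torch.allclose(out_before, out_fused, atol=1e-5)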