# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Set, Optional, Type

import torch
import torch.nn as nn


SELF_ATTENTION_MODULES = {'Attention', 'NormAttention'}
CROSS_ATTENTION_MODULES = {'CrossAttention', 'NormCrossAttention'}
ATTENTION_MODULES = SELF_ATTENTION_MODULES | CROSS_ATTENTION_MODULES
MLP_MODULES = {'Mlp', 'GatedMlp', 'SwiGLUFFNFused'}  # SwiGLUFFNFused is from DINOv2
TRANSFORMER_MODULES = ATTENTION_MODULES | MLP_MODULES

def get_LoRA_module_names(id: str) -> Set[str]:
""" Returns a list of module names that are LoRA-adapted for the given id. """
id = id.lower()
if id in ['selfattn', 'selfattention', 'self_attn', 'self_attention']:
return SELF_ATTENTION_MODULES
elif id in ['crossattn', 'crossattention', 'cross_attn', 'cross_attention']:
return CROSS_ATTENTION_MODULES
elif id in ['attn', 'attention']:
return ATTENTION_MODULES
elif id in ['mlp']:
return MLP_MODULES
elif id in ['all', 'transformer']:
return TRANSFORMER_MODULES
else:
raise ValueError(f'Unknown LoRA module id {id}.')
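

# Illustrative sketch, not part of the original module: a quick check of how the id
# strings above map onto the module-name sets. The function name is hypothetical and
# exists only for demonstration.
def _example_get_LoRA_module_names() -> None:
    assert get_LoRA_module_names('attn') == ATTENTION_MODULES
    assert get_LoRA_module_names('Self_Attention') == SELF_ATTENTION_MODULES  # ids are case-insensitive
    assert get_LoRA_module_names('mlp') == MLP_MODULES
    assert get_LoRA_module_names('all') == TRANSFORMER_MODULES
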
class LoRAWrapper(nn.Module):
"""Low-Rank Adaptation Wrapper for linear layers.
See https://arxiv.org/abs/2106.09685
Args:
linear: nn.Linear layer to wrap
rank: Rank of adaptation matrix B@A
scale: x = W_0@x + scale * B@A@x
        num_packed_linear: Set to > 1 when wrapping packed projections, e.g. kv or qkv attention weights.
Weights will be initialized as if num_packed_linear = 1, but the LoRA bottleneck will
be num_packed_linear times larger.
"""
def __init__(self, linear: nn.Module, rank: int = 4, scale: float = 1.0, num_packed_linear: int = 1):
super().__init__()
self.rank = rank
self.scale = scale
self.in_features, self.out_features = linear.in_features, linear.out_features
        assert num_packed_linear * rank <= min(self.in_features, self.out_features), \
            f'LoRA rank {num_packed_linear} * {rank} must be less than or equal to {min(self.in_features, self.out_features)}'
self.linear = linear
self.lora_down = nn.Linear(self.in_features, num_packed_linear*rank, bias=False)
self.lora_up = nn.Linear(num_packed_linear*rank, self.out_features, bias=False)
nn.init.normal_(self.lora_down.weight, std=1/rank)
nn.init.zeros_(self.lora_up.weight)

    def fuse_LoRA_into_linear(self) -> nn.Linear:
""" Returns a single nn.Linear layer with the LoRA matrix fused into the original one. """
fused_linear = nn.Linear(self.in_features, self.out_features, bias=self.linear.bias is not None)
fused_linear.weight.data = self.linear.weight + self.scale * (self.lora_up.weight @ self.lora_down.weight)
if self.linear.bias is not None:
fused_linear.bias.data = self.linear.bias
return fused_linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
""" LoRA adapted linear layer forward pass. """
return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
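

# Illustrative sketch, not part of the original module: wrapping a single nn.Linear.
# Because lora_up is zero-initialized, the wrapper is an exact no-op at initialization,
# and fuse_LoRA_into_linear() recovers a plain nn.Linear that matches the adapted
# forward pass. The function below is hypothetical and only for demonstration.
def _example_wrap_single_linear() -> None:
    torch.manual_seed(0)
    base = nn.Linear(64, 64)
    wrapped = LoRAWrapper(base, rank=4, scale=1.0)

    x = torch.randn(2, 64)
    # B (lora_up) is zero, so the adapted layer initially equals the base layer.
    assert torch.allclose(wrapped(x), base(x))

    # Stand-in for fine-tuning: give the up-projection non-zero weights.
    nn.init.normal_(wrapped.lora_up.weight, std=0.02)
    fused = wrapped.fuse_LoRA_into_linear()
    # W_0 + scale * B @ A computes the same outputs as the unfused wrapper.
    assert torch.allclose(fused(x), wrapped(x), atol=1e-5)
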
def _find_modules(
model,
ancestor_class: Optional[Set[str]] = None,
search_class: List[Type[nn.Module]] = [nn.Linear],
exclude_children_of: Optional[List[Type[nn.Module]]] = [LoRAWrapper],
):
"""
Find all modules of a certain class (or union of classes) that are direct or
indirect descendants of other modules of a certain class (or union of classes).
    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.
Adapted from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
"""
# Get the targets we should replace all linears under
if ancestor_class is not None:
ancestors = (
module
for module in model.modules()
if module.__class__.__name__ in ancestor_class
)
else:
        # No ancestor filter given: naively iterate over all modules.
ancestors = [module for module in model.modules()]
# For each target find every linear_class module that isn't a child of a LoRA layer
for ancestor in ancestors:
for fullname, module in ancestor.named_modules():
if any([isinstance(module, _class) for _class in search_class]):
# Find the direct parent if this is a descendant, not a child, of target
*path, name = fullname.split(".")
parent = ancestor
while path:
parent = parent.get_submodule(path.pop(0))
# Skip this linear if it's a child of a LoRA layer
if exclude_children_of and any(
[isinstance(parent, _class) for _class in exclude_children_of]
):
continue
# Otherwise, yield it
yield parent, name, module
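

# Illustrative sketch, not part of the original module: _find_modules matches on
# __class__.__name__, so the toy class below is deliberately named 'Attention' to
# hit SELF_ATTENTION_MODULES. Both the class and the function are hypothetical.
def _example_find_modules() -> None:
    class Attention(nn.Module):
        def __init__(self, dim: int = 32):
            super().__init__()
            self.qkv = nn.Linear(dim, 3 * dim)
            self.proj = nn.Linear(dim, dim)

    model = nn.Sequential(Attention(), nn.Linear(32, 32))
    found = sorted(name for _, name, _ in _find_modules(model, ATTENTION_MODULES))
    # Only linears under the 'Attention' ancestor are returned; the stand-alone
    # nn.Linear(32, 32) is not.
    assert found == ['proj', 'qkv']
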
def inject_trainable_LoRA(
model: nn.Module,
rank: int = 4,
scale: float = 1.0,
target_replace_modules: Set[str] = ATTENTION_MODULES
) -> None:
"""Replaces all linear layers of the specified modules with LoRA-adapted linear layers.
Modifies the model in-place.
Args:
model: nn.Module to modify
rank: Rank of adaptation matrix B@A
scale: x = W_0@x + scale * B@A@x
target_replace_modules: Set of module names to replace linear layers in.
"""
for _module, name, _child_module in _find_modules(
model, target_replace_modules, search_class=[nn.Linear]
):
if sorted(name) == sorted('qkv'):
num_packed_linear = 3
elif sorted(name) in [sorted('kv'), sorted('qk'), sorted('qv')]:
num_packed_linear = 2
else:
num_packed_linear = 1
_module._modules[name] = LoRAWrapper(_child_module, rank=rank, scale=scale, num_packed_linear=num_packed_linear)
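

# Illustrative sketch, not part of the original module: injecting LoRA into a toy
# block named 'Attention' (the class name is what target_replace_modules matches on).
# Packed projections are detected from the attribute name, e.g. 'qkv' gets a
# 3x wider LoRA bottleneck.
def _example_inject_trainable_LoRA() -> None:
    class Attention(nn.Module):
        def __init__(self, dim: int = 32):
            super().__init__()
            self.qkv = nn.Linear(dim, 3 * dim)
            self.proj = nn.Linear(dim, dim)

    block = Attention()
    inject_trainable_LoRA(block, rank=4, scale=1.0)
    assert isinstance(block.qkv, LoRAWrapper) and block.qkv.lora_down.out_features == 3 * 4
    assert isinstance(block.proj, LoRAWrapper) and block.proj.lora_down.out_features == 4
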
def fuse_LoRA_into_linear(
model: nn.Module,
target_replace_modules: Set[str] = ATTENTION_MODULES
) -> None:
"""Fuses all LoRA-adapted linear layers back into single linear layers.
Modifies the model in-place.
Args:
model: nn.Module to modify
        target_replace_modules: Set of module names in which to fuse LoRA layers back into linear layers.
"""
for _module, name, _child_module in _find_modules(
model, target_replace_modules, search_class=[LoRAWrapper]
):
_module._modules[name] = _module._modules[name].fuse_LoRA_into_linear()
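

# Illustrative sketch, not part of the original module: after (stand-in) training,
# fusing replaces each LoRAWrapper with a plain nn.Linear that produces the same
# outputs, so no extra inference cost remains.
def _example_fuse_LoRA_into_linear() -> None:
    torch.manual_seed(0)

    class Attention(nn.Module):  # class name must appear in target_replace_modules
        def __init__(self, dim: int = 32):
            super().__init__()
            self.qkv = nn.Linear(dim, 3 * dim)

    block = Attention()
    inject_trainable_LoRA(block, rank=4)
    nn.init.normal_(block.qkv.lora_up.weight, std=0.02)  # stand-in for fine-tuning

    x = torch.randn(2, 32)
    before = block.qkv(x)  # LoRAWrapper forward: W_0 @ x + scale * B @ A @ x
    fuse_LoRA_into_linear(block)
    assert isinstance(block.qkv, nn.Linear)  # wrapper replaced by a plain linear
    assert torch.allclose(block.qkv(x), before, atol=1e-5)
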
def unfreeze_all_LoRA_layers(model: nn.Module) -> None:
""" Unfreezes all LoRA-adapted linear layers. """
for name, param in model.named_parameters():
if 'lora' in name:
param.requires_grad = True
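

# Illustrative sketch, not part of the original module: the usual LoRA fine-tuning
# setup. Freeze every parameter, then unfreeze only the injected adapters so that
# just the 'lora_down' / 'lora_up' weights receive gradients.
def _example_LoRA_finetuning_setup() -> None:
    class Attention(nn.Module):
        def __init__(self, dim: int = 32):
            super().__init__()
            self.qkv = nn.Linear(dim, 3 * dim)
            self.proj = nn.Linear(dim, dim)

    model = Attention()
    inject_trainable_LoRA(model, rank=4)

    for param in model.parameters():
        param.requires_grad = False
    unfreeze_all_LoRA_layers(model)

    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    assert trainable and all('lora' in n for n in trainable)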