# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Set, Optional, Type

import torch
import torch.nn as nn


SELF_ATTENTION_MODULES = {'Attention', 'NormAttention'}
CROSS_ATTENTION_MODULES = {'CrossAttention', 'NormCrossAttention'}
ATTENTION_MODULES = SELF_ATTENTION_MODULES | CROSS_ATTENTION_MODULES
MLP_MODULES = {'Mlp', 'GatedMlp', 'SwiGLUFFNFused'} # SwiGLUFFNFused is from DINOv2
TRANSFORMER_MODULES = ATTENTION_MODULES | MLP_MODULES


def get_LoRA_module_names(id: str) -> Set[str]:
    """ Returns a list of module names that are LoRA-adapted for the given id. """
    id = id.lower()
    if id in ['selfattn', 'selfattention', 'self_attn', 'self_attention']:
        return SELF_ATTENTION_MODULES
    elif id in ['crossattn', 'crossattention', 'cross_attn', 'cross_attention']:
        return CROSS_ATTENTION_MODULES
    elif id in ['attn', 'attention']:
        return ATTENTION_MODULES
    elif id in ['mlp']:
        return MLP_MODULES
    elif id in ['all', 'transformer']:
        return TRANSFORMER_MODULES
    else:
        raise ValueError(f'Unknown LoRA module id {id}.')


class LoRAWrapper(nn.Module):
    """Low-Rank Adaptation Wrapper for linear layers.
    See https://arxiv.org/abs/2106.09685
    
    Args:
        linear: nn.Linear layer to wrap
        rank: Rank of adaptation matrix B@A
        scale: x = W_0@x + scale * B@A@x
        num_packed_linear: Set to > 1 when wrapping packed linear layers, e.g. fused kv or qkv
            attention projections. Weights are initialized as if num_packed_linear = 1, but the
            LoRA bottleneck is num_packed_linear times larger.
    """
    def __init__(self, linear: nn.Module, rank: int = 4, scale: float = 1.0, num_packed_linear: int = 1):
        super().__init__()
        self.rank = rank
        self.scale = scale
        self.in_features, self.out_features = linear.in_features, linear.out_features
        assert num_packed_linear * rank <= min(self.in_features, self.out_features), \
            f'LoRA bottleneck dim {num_packed_linear} * {rank} must be less than or equal to {min(self.in_features, self.out_features)}'
        
        self.linear = linear
        self.lora_down = nn.Linear(self.in_features, num_packed_linear*rank, bias=False)
        self.lora_up = nn.Linear(num_packed_linear*rank, self.out_features, bias=False)

        nn.init.normal_(self.lora_down.weight, std=1/rank)
        nn.init.zeros_(self.lora_up.weight)
        
    def fuse_LoRA_into_linear(self) -> nn.Linear:
        """ Returns a single nn.Linear layer with the LoRA matrix fused into the original one. """
        fused_linear = nn.Linear(self.in_features, self.out_features, bias=self.linear.bias is not None)
        fused_linear.weight.data = self.linear.weight + self.scale * (self.lora_up.weight @ self.lora_down.weight)
        if self.linear.bias is not None:
            fused_linear.bias.data = self.linear.bias
        return fused_linear

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ LoRA adapted linear layer forward pass. """
        return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
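
# Example usage of LoRAWrapper (an illustrative sketch; the layer sizes below are arbitrary
# and not taken from any particular model):
#
#   qkv = nn.Linear(768, 3 * 768)                       # packed qkv projection
#   lora_qkv = LoRAWrapper(qkv, rank=4, scale=1.0, num_packed_linear=3)
#   y = lora_qkv(torch.randn(2, 16, 768))               # shape (2, 16, 2304)
#
# Because lora_up is zero-initialized, the wrapper reproduces the original linear exactly at
# initialization; the output only changes once the LoRA weights are trained.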
    

def _find_modules(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoRAWrapper],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns all matching modules, along with the parents of those modules and the
    names they are referenced by.
    
    Adapted from https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
    """
    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # Otherwise, naively iterate over all modules.
        ancestors = [module for module in model.modules()]

    # For each target find every linear_class module that isn't a child of a LoRA layer
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if any([isinstance(module, _class) for _class in search_class]):
                # Find the direct parent if this is a descendant, not a child, of target
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                # Skip this linear if it's a child of a LoRA layer
                if exclude_children_of and any(
                    [isinstance(parent, _class) for _class in exclude_children_of]
                ):
                    continue
                # Otherwise, yield it
                yield parent, name, module
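
# For example, with ancestor_class = {'Attention'}, _find_modules yields tuples like
# (attention_block, 'qkv', attention_block.qkv) for every nn.Linear directly or indirectly
# contained in an Attention block, skipping linears that already live inside a LoRAWrapper.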
                

def inject_trainable_LoRA(
    model: nn.Module, 
    rank: int = 4, 
    scale: float = 1.0,
    target_replace_modules: Set[str] = ATTENTION_MODULES
) -> None:
    """Replaces all linear layers of the specified modules with LoRA-adapted linear layers.
    Modifies the model in-place.
    
    Args:
        model: nn.Module to modify
        rank: Rank of adaptation matrix B@A
        scale: x = W_0@x + scale * B@A@x
        target_replace_modules: Set of module names to replace linear layers in.
    """
    for _module, name, _child_module in _find_modules(
        model, target_replace_modules, search_class=[nn.Linear]
    ):
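        # Heuristic: detect packed projections by their attribute name (in any letter order).
        # A fused 'qkv' linear gets a 3x larger LoRA bottleneck, fused pairs such as 'kv',
        # 'qk' or 'qv' get a 2x larger one, and every other linear uses the plain rank.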
        if sorted(name) == sorted('qkv'):
            num_packed_linear = 3
        elif sorted(name) in [sorted('kv'), sorted('qk'), sorted('qv')]:
            num_packed_linear = 2
        else:
            num_packed_linear = 1
        
        _module._modules[name] = LoRAWrapper(_child_module, rank=rank, scale=scale, num_packed_linear=num_packed_linear)
        

def fuse_LoRA_into_linear(
    model: nn.Module,
    target_replace_modules: Set[str] = ATTENTION_MODULES
) -> None:
    """Fuses all LoRA-adapted linear layers back into single linear layers.
    Modifies the model in-place.

    Args:
        model: nn.Module to modify
        target_replace_modules: Set of module names to replace linear layers in.
    """
    for _module, name, _child_module in _find_modules(
        model, target_replace_modules, search_class=[LoRAWrapper]
    ):
        _module._modules[name] = _module._modules[name].fuse_LoRA_into_linear()


def unfreeze_all_LoRA_layers(model: nn.Module) -> None:
    """ Unfreezes all LoRA-adapted linear layers. """
    for name, param in model.named_parameters():
        if 'lora' in name:
            param.requires_grad = True
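

# A minimal end-to-end sketch of the intended workflow. The toy `Attention` block below is
# purely illustrative (only its class name matters for matching against ATTENTION_MODULES)
# and is not part of this module or of any particular model.
if __name__ == '__main__':
    class Attention(nn.Module):
        def __init__(self, dim: int = 64):
            super().__init__()
            self.qkv = nn.Linear(dim, 3 * dim)  # packed qkv -> wrapped with num_packed_linear=3
            self.proj = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            attn = (q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5).softmax(dim=-1)
            return self.proj(attn @ v)

    model = nn.Sequential(Attention(), Attention())
    x = torch.randn(2, 16, 64)
    out_before = model(x)

    # 1) Replace the linear layers inside attention blocks with LoRA-wrapped versions.
    inject_trainable_LoRA(model, rank=4, scale=1.0, target_replace_modules=SELF_ATTENTION_MODULES)

    # 2) Train only the LoRA parameters: freeze everything, then unfreeze the LoRA layers.
    for param in model.parameters():
        param.requires_grad = False
    unfreeze_all_LoRA_layers(model)

    # Since lora_up is zero-initialized, the LoRA-injected model matches the original exactly.
    assert torch.allclose(model(x), out_before)

    # 3) After training, fuse the LoRA weights back into plain nn.Linear layers for deployment.
    fuse_LoRA_into_linear(model, target_replace_modules=SELF_ATTENTION_MODULES)
    assert torch.allclose(model(x), out_before, atol=1e-6)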