# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
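"""
LoRA support for linear layers quantized with torchao.

`dispatch_torchao` wraps linear layers whose weight is a torchao `AffineQuantizedTensor` or
`LinearActivationQuantizedTensor` in a `TorchaoLoraLinear` layer. Merging and unmerging dequantize the base weight,
add or subtract the LoRA delta, and re-quantize the module with `torchao.quantize_`.

Example (an illustrative sketch only; it assumes a transformers model quantized via `TorchAoConfig` with int8
weight-only quantization, which happens outside of this file):

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, TorchAoConfig

    quant_config = TorchAoConfig(quant_type="int8_weight_only")
    base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=quant_config)
    peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))
    # the quantized nn.Linear modules of the base model are now wrapped in TorchaoLoraLinear
"""
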
from __future__ import annotations

import warnings
from typing import Any, Optional

import torch

# from torch import nn
from peft.import_utils import is_torchao_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

from .config import LoraConfig
from .layer import Linear


class TorchaoLoraLinear(Linear):
    """LoRA layer implementation for Linear layers using torchao data"""

    def __init__(self, *args, get_apply_tensor_subclass, **kwargs):
        # Declaring get_apply_tensor_subclass as an explicit keyword-only argument is not strictly necessary, as it
        # would be captured by **kwargs anyway, but this way a missing argument raises an error early.
        if kwargs.get("lora_bias", False):
            raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")

        super().__init__(*args, **kwargs)
        self.get_apply_tensor_subclass = get_apply_tensor_subclass
        self._check_dtype_supported()

    def _check_dtype_supported(self):
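        """Raise an error if the quantized weight data is not int8, since other dtypes are not supported yet."""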
        # TODO: Not required once int4_weight_only is properly supported by torchao
        base_layer = self.get_base_layer()
        weight = base_layer.weight
        if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8):
            raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
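        """
        Merge the active adapter weights into the base weights.

        The quantized base weight is dequantized, the LoRA delta weight is added, and the result is re-quantized
        with `quantize_` using `get_apply_tensor_subclass`.

        Args:
            safe_merge (`bool`, *optional*, defaults to `False`):
                If `True`, check the merged weights for NaNs before persisting them and raise an error if any are
                found.
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
        """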
        from torchao import quantize_

        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        self._check_dtype_supported()

        base_layer = self.get_base_layer()

        for active_adapter in adapter_names:
            # re-read the weight on each iteration, as it is replaced by its re-quantized version at the end of the
            # loop body; this way, merging multiple adapters in sequence works
            weight = base_layer.weight
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support merging."
                )
                raise NotImplementedError(msg) from exc

            weight += self.get_delta_weight(active_adapter)

            if safe_merge and not torch.isfinite(weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )
            # TODO: once (if) torchao supports directly mutating the data, use that instead.
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
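        """
        Unmerge all merged adapter layers from the base weights.

        For each merged adapter, the quantized base weight is dequantized, the LoRA delta weight is subtracted, and
        the result is re-quantized with `quantize_`.
        """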
        from torchao import quantize_

        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.lora_A.keys():
                continue

            base_layer = self.get_base_layer()
            weight = base_layer.weight
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support unmerging."
                )
                raise NotImplementedError(msg) from exc

            weight -= self.get_delta_weight(active_adapter)
            # Overriding weight.data directly does not work, as the tensor retains the old data. Therefore, we delete
            # the weight, assign the dequantized result, and re-quantize the module via quantize_, which takes a
            # module as input.
            # TODO: once (if) torchao supports directly mutating the data, use that instead.
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

    def __repr__(self) -> str:
        rep = super().__repr__()
        return rep.replace("lora.Linear", f"lora.{self.__class__.__name__}")


def dispatch_torchao(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config: LoraConfig,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
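    """
    Dispatch `target` to a `TorchaoLoraLinear` layer if its base weight is a torchao quantized tensor.

    Returns the new module, or `None` if the target is not torchao-quantized (or torchao is not installed), so that
    other dispatchers can handle it.
    """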
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if not hasattr(target_base_layer, "weight"):
        return new_module

    if not is_torchao_available():
        return new_module

    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization import LinearActivationQuantizedTensor

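    # only layers whose weight is one of the torchao tensor subclasses are handled here; other layer types fall
    # through to the remaining dispatchers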
    if isinstance(target_base_layer.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)):
        new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)

    return new_module