# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import warnings
from typing import Any, Optional

import torch
from peft.import_utils import is_torchao_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

from .config import LoraConfig
from .layer import Linear

class TorchaoLoraLinear(Linear):
    """LoRA layer implementation for Linear layers whose base weights are quantized with torchao."""

    def __init__(self, *args, get_apply_tensor_subclass, **kwargs):
        # get_apply_tensor_subclass is a required keyword-only argument: storing it via kwargs would also work, but
        # making it explicit lets us error out early if it is missing.
        if kwargs.get("lora_bias", False):
            raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")

        super().__init__(*args, **kwargs)
        self.get_apply_tensor_subclass = get_apply_tensor_subclass
        self._check_dtype_supported()

    def _check_dtype_supported(self):
        # TODO: Not required once int4_weight_only is properly supported by torchao
        base_layer = self.get_base_layer()
        weight = base_layer.weight
        if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8):
            raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights.

        The quantized base weight is dequantized, the LoRA delta is added, and the result is re-quantized with
        `get_apply_tensor_subclass`.

        Args:
            safe_merge (`bool`, *optional*, defaults to `False`):
                If `True`, check the dequantized weight for NaNs before merging and raise an error if any are found.
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
        """
        from torchao import quantize_

        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        self._check_dtype_supported()

        base_layer = self.get_base_layer()
        weight = base_layer.weight

        for active_adapter in adapter_names:
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support merging."
                )
                raise NotImplementedError(msg) from exc

            if safe_merge and not torch.isfinite(weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            weight += self.get_delta_weight(active_adapter)
            # TODO: once (if) torchao supports directly mutating the data, use that instead.
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

            self.merged_adapters.append(active_adapter)

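    # Usage note (illustrative): merge() above and unmerge() below are typically reached through
    # PeftModel.merge_adapter() and PeftModel.unmerge_adapter(); both paths dequantize the base weight,
    # apply or remove the LoRA delta, and re-quantize it via get_apply_tensor_subclass().
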
    def unmerge(self) -> None:
        """Unmerge all merged adapter layers from the base weights."""
        from torchao import quantize_

        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.lora_A.keys():
                continue

            base_layer = self.get_base_layer()
            weight = base_layer.weight
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support unmerging."
                )
                raise NotImplementedError(msg) from exc

            weight -= self.get_delta_weight(active_adapter)
            # We go through a dummy module because overriding the weight.data does not work, the tensor retains the
            # old data. Therefore, we need to go through quantize_, which takes a module as input, and we need to
            # delete and re-assign the weight.
            # TODO: once (if) torchao supports directly mutating the data, use that instead.
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

    def __repr__(self) -> str:
        rep = super().__repr__()
        return rep.replace("lora.Linear", f"lora.{self.__class__.__name__}")

def dispatch_torchao(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config: LoraConfig,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    """Return a TorchaoLoraLinear wrapping `target` if its base layer holds a torchao-quantized weight, else None."""
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if not hasattr(target_base_layer, "weight"):
        return new_module

    if not is_torchao_available():
        return new_module

    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization import LinearActivationQuantizedTensor

    if isinstance(target_base_layer.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)):
        new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)

    return new_module
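
# Usage sketch (illustrative; the model name and the exact quantization API depend on the installed
# transformers/torchao versions): dispatch_torchao is not called directly by users. LoraModel invokes it
# while injecting adapters into a model whose Linear layers were quantized with torchao, e.g.:
#
#     from transformers import AutoModelForCausalLM, TorchAoConfig
#     from peft import LoraConfig, get_peft_model
#
#     quant_config = TorchAoConfig(quant_type="int8_weight_only")
#     base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=quant_config)
#     peft_model = get_peft_model(base_model, LoraConfig(target_modules=["q_proj", "v_proj"]))
#
# Each targeted Linear layer is then replaced by a TorchaoLoraLinear that wraps the quantized base layer.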