# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import warnings
from typing import Any, Optional

import torch

from peft.import_utils import is_torchao_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

from .config import LoraConfig
from .layer import Linear


class TorchaoLoraLinear(Linear):
    """LoRA layer implementation for Linear layers whose weights are quantized with torchao."""

    def __init__(self, *args, get_apply_tensor_subclass, **kwargs):
        # get_apply_tensor_subclass is declared explicitly instead of being swallowed by **kwargs; this is not
        # strictly necessary, as kwargs are stored either way, but we want to error early if it is missing.
        if kwargs.get("lora_bias", False):
            raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")

        super().__init__(*args, **kwargs)
        self.get_apply_tensor_subclass = get_apply_tensor_subclass
        self._check_dtype_supported()

    def _check_dtype_supported(self):
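        """Check that the quantized weight's data dtype is supported (currently only int8)."""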
        # TODO: Not required once int4_weight_only is properly supported by torchao
        base_layer = self.get_base_layer()
        weight = base_layer.weight
        if hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8):
            raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
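        """
        Merge the active adapter weights into the base weights.

        Args:
            safe_merge (`bool`, *optional*):
                If `True`, check the merged weights for NaNs before persisting them and raise if any are found.
                Defaults to `False`.
            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.
        """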
        from torchao import quantize_

        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        self._check_dtype_supported()

        base_layer = self.get_base_layer()

        for active_adapter in adapter_names:
            # re-read the weight on every iteration, as quantize_ below replaces it on the base layer
            weight = base_layer.weight
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support merging."
                )
                raise NotImplementedError(msg) from exc

            weight += self.get_delta_weight(active_adapter)
            if safe_merge and not torch.isfinite(weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            # TODO: once (if) torchao supports directly mutating the data, use that instead.
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

            self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
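        """
        Unmerge all merged adapter layers from the base weights.
        """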
        from torchao import quantize_

        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.lora_A.keys():
                continue

            base_layer = self.get_base_layer()
            weight = base_layer.weight
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support unmerging."
                )
                raise NotImplementedError(msg) from exc

            weight -= self.get_delta_weight(active_adapter)
            # Overriding weight.data does not work here, as the tensor would retain its old data. Therefore, we
            # go through quantize_, which takes a module as input, and delete and re-assign the weight on the
            # base layer.
            # TODO: once (if) torchao supports directly mutating the data, use that instead.
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

    def __repr__(self) -> str:
        rep = super().__repr__()
        return rep.replace("lora.Linear", f"lora.{self.__class__.__name__}")


def dispatch_torchao(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config: LoraConfig,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
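    """
    Dispatcher used when replacing target modules: return a `TorchaoLoraLinear` wrapping `target` if the base
    layer's weight is a torchao quantized tensor, otherwise return `None` so other dispatchers can handle it.
    """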
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if not hasattr(target_base_layer, "weight"):
        return new_module

    if not is_torchao_available():
        return new_module

    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization import LinearActivationQuantizedTensor

    if isinstance(target_base_layer.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)):
        new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)

    return new_module