from __future__ import annotations

import warnings
from typing import Any, Optional

import torch

from peft.import_utils import is_torchao_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

from .config import LoraConfig
from .layer import Linear

|
class TorchaoLoraLinear(Linear):
    """LoRA layer implementation for Linear layers using torchao data"""

    def __init__(self, *args, get_apply_tensor_subclass, **kwargs):
        if kwargs.get("lora_bias", False):
            raise ValueError(f"{self.__class__.__name__} does not support lora_bias yet, set it to False")

        super().__init__(*args, **kwargs)
        self.get_apply_tensor_subclass = get_apply_tensor_subclass
        self._check_dtype_supported()
|
    def _check_dtype_supported(self):
        # this layer only supports int8-quantized torchao weights for now
        base_layer = self.get_base_layer()
        weight = base_layer.weight
        if (
            # newer torchao releases expose the quantized data via `tensor_impl`
            (hasattr(weight, "tensor_impl") and (weight.tensor_impl.data.dtype != torch.int8))
            # older torchao releases expose it via `layout_tensor`
            or (hasattr(weight, "layout_tensor") and (weight.layout_tensor.data.dtype != torch.int8))
        ):
            raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")
|
    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        from torchao import quantize_

        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        self._check_dtype_supported()

        base_layer = self.get_base_layer()
        weight = base_layer.weight

        for active_adapter in adapter_names:
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support merging."
                )
                raise NotImplementedError(msg) from exc

            weight += self.get_delta_weight(active_adapter)

            # with safe_merge, verify that the merged weights contain no NaNs/infs before committing them
            if safe_merge and not torch.isfinite(weight).all():
                raise ValueError(
                    f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                )

            # torchao tensors cannot be modified in place, so re-assign the dequantized weight and quantize it again
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight

            self.merged_adapters.append(active_adapter)
|
    def unmerge(self) -> None:
        from torchao import quantize_

        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter not in self.lora_A.keys():
                continue

            base_layer = self.get_base_layer()
            weight = base_layer.weight
            try:
                weight = weight.dequantize()
            except NotImplementedError as exc:
                msg = (
                    f"Weights of type {type(weight).__name__} do not support dequantization (yet), which is needed to "
                    "support unmerging."
                )
                raise NotImplementedError(msg) from exc

            weight -= self.get_delta_weight(active_adapter)

            # as in merge(): torchao tensors cannot be modified in place, so re-assign the dequantized weight and
            # quantize it again
            del base_layer.weight
            base_layer.weight = weight
            quantize_(base_layer, self.get_apply_tensor_subclass())
            del weight
|
    def __repr__(self) -> str:
        rep = super().__repr__()
        return rep.replace("lora.Linear", f"lora.{self.__class__.__name__}")
|
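# How this layer is typically reached in practice. The following is a hedged sketch, not executed here; the model
# name, quantization type, and LoRA settings are placeholders: quantize a transformers model with torchao via
# `TorchAoConfig`, wrap it with PEFT, and `dispatch_torchao` (below) then swaps the torchao-quantized nn.Linear
# modules for `TorchaoLoraLinear`.
#
#     from transformers import AutoModelForCausalLM, TorchAoConfig
#     from peft import LoraConfig, get_peft_model
#
#     base_model = AutoModelForCausalLM.from_pretrained(
#         "some-org/some-model", quantization_config=TorchAoConfig("int8_weight_only")
#     )
#     peft_model = get_peft_model(base_model, LoraConfig(r=8, lora_alpha=16))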
|
def dispatch_torchao(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config: LoraConfig,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    """Return a TorchaoLoraLinear wrapping `target` if its base layer holds torchao-quantized weights, else None."""
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if not hasattr(target_base_layer, "weight"):
        return new_module

    if not is_torchao_available():
        return new_module

    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization import LinearActivationQuantizedTensor

    if isinstance(target_base_layer.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)):
        new_module = TorchaoLoraLinear(target, adapter_name, **kwargs)

    return new_module
|
|
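# Minimal usage sketch (illustration only, not part of the library API): quantize a single nn.Linear with torchao's
# int8_weight_only, dispatch it through `dispatch_torchao`, and round-trip merge()/unmerge(). The layer size, rank,
# scaling, and adapter name below are arbitrary placeholder values.
if __name__ == "__main__":
    import torch.nn as nn
    from torchao.quantization import int8_weight_only, quantize_

    base = nn.Linear(32, 32)
    quantize_(base, int8_weight_only())  # the weight becomes a torchao AffineQuantizedTensor

    module = dispatch_torchao(
        base,
        adapter_name="default",
        lora_config=LoraConfig(r=8, lora_alpha=16),
        r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        get_apply_tensor_subclass=int8_weight_only,
    )
    assert isinstance(module, TorchaoLoraLinear)

    module.merge()    # dequantize -> add the LoRA delta -> re-quantize
    module.unmerge()  # dequantize -> subtract the LoRA delta -> re-quantize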