import functools

import torch
from accelerate.logging import get_logger
from peft.tuners.tuners_utils import BaseTunerLayer

from .constants import FINETRAINERS_LOG_LEVEL

logger = get_logger("finetrainers")  # pylint: disable=invalid-name
logger.setLevel(FINETRAINERS_LOG_LEVEL)


def perform_peft_patches() -> None:
    _perform_patch_move_adapter_to_device_of_base_layer()


def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
    # We don't patch the method for torch.float32 and torch.bfloat16 because it is fine to train in those dtypes. If
    # the model weights are stored in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2, however, we patch this
    # method so that the LoRA weights are not converted down from their higher-precision dtype.
    BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
        BaseTunerLayer._move_adapter_to_device_of_base_layer
    )
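
# Usage sketch (an illustrative assumption, not part of the original module): apply the patch once, before LoRA
# adapters are attached to a model whose weights live in torch.float16 or a float8 dtype, for example:
#
#     perform_peft_patches()
#     transformer.to(torch.float8_e4m3fn)
#     transformer.add_adapter(lora_config)  # adapter weights keep their own higher-precision dtype
#
# `transformer` and `lora_config` are hypothetical names used only for this example.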


def _patched_move_adapter_to_device_of_base_layer(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with DisableTensorToDtype():
            return func(self, *args, **kwargs)

    return wrapper
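

# Descriptive note (added for clarity): the context manager below temporarily monkey-patches torch.Tensor.to so that
# any dtype argument, positional or keyword, is dropped; device moves and other arguments still reach the original
# method.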
class DisableTensorToDtype:
    def __enter__(self):
        self.original_to = torch.Tensor.to

        def modified_to(tensor, *args, **kwargs):
            # remove dtype from args if present
            args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
            if "dtype" in kwargs:
                kwargs.pop("dtype")
            return self.original_to(tensor, *args, **kwargs)

        torch.Tensor.to = modified_to

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.Tensor.to = self.original_to
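

if __name__ == "__main__":
    # Minimal self-check (an illustrative addition, not part of the original module): inside DisableTensorToDtype,
    # dtype casts requested via .to(...) are ignored, while normal behaviour returns once the context exits.
    x = torch.ones(2, dtype=torch.float32)
    with DisableTensorToDtype():
        y = x.to(dtype=torch.float16, device="cpu")
    assert y.dtype == torch.float32  # the dtype argument was dropped inside the context
    z = x.to(dtype=torch.float16)
    assert z.dtype == torch.float16  # casting works normally again outside the context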