import functools
from typing import Callable

import torch
from accelerate.logging import get_logger
from peft.tuners.tuners_utils import BaseTunerLayer

from .constants import FINETRAINERS_LOG_LEVEL


logger = get_logger("finetrainers")  # pylint: disable=invalid-name
logger.setLevel(FINETRAINERS_LOG_LEVEL)


def perform_peft_patches() -> None:
    """Apply runtime patches to PEFT that are needed when training in low-precision dtypes."""
    _perform_patch_move_adapter_to_device_of_base_layer()


def _perform_patch_move_adapter_to_device_of_base_layer() -> None:
    # The method is left unpatched when training in torch.float32 or torch.bfloat16, since training in those
    # dtypes works as-is. When the model weights are in torch.float16, torch.float8_e4m3fn or torch.float8_e5m2,
    # the method must be patched so that the LoRA weights are not downcast from their higher-precision dtype.
    BaseTunerLayer._move_adapter_to_device_of_base_layer = _patched_move_adapter_to_device_of_base_layer(
        BaseTunerLayer._move_adapter_to_device_of_base_layer
    )
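
# Illustrative sketch (not part of this module; `transformer` and `lora_config` are assumed names,
# and a PEFT/diffusers-style `add_adapter` call is assumed): with the patch applied, attaching a
# LoRA adapter to a model whose base weights are in a low-precision dtype keeps the adapter
# weights in their original, higher-precision dtype.
#
#     perform_peft_patches()
#     transformer.add_adapter(lora_config)  # LoRA weights stay in e.g. torch.float32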


def _patched_move_adapter_to_device_of_base_layer(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with DisableTensorToDtype():
            return func(self, *args, **kwargs)

    return wrapper


class DisableTensorToDtype:
    """Context manager that temporarily makes `torch.Tensor.to` ignore any `dtype` argument."""

    def __enter__(self):
        self.original_to = torch.Tensor.to

        def modified_to(tensor, *args, **kwargs):
            # Strip any dtype argument, positional or keyword, so that no dtype cast is performed
            args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
            kwargs.pop("dtype", None)
            return self.original_to(tensor, *args, **kwargs)

        torch.Tensor.to = modified_to

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.Tensor.to = self.original_to
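

# Illustrative usage of DisableTensorToDtype (a sketch for clarity, not part of the public API):
#
#     t = torch.randn(2, 2, dtype=torch.float32)
#     with DisableTensorToDtype():
#         u = t.to(torch.float16)  # the dtype argument is stripped, so `u` stays in float32
#     v = t.to(torch.float16)      # outside the context manager, the cast happens as usual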