# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Optional, Tuple, Type

import torch
from accelerate.logging import get_logger

from ..constants import FINETRAINERS_LOG_LEVEL
from .hooks import HookRegistry, ModelHook


logger = get_logger("finetrainers")  # pylint: disable=invalid-name
logger.setLevel(FINETRAINERS_LOG_LEVEL)

# fmt: off
_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)

_DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm")
# fmt: on


class LayerwiseUpcastingHook(ModelHook):
    r"""
    A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
    for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
    footprint.
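
    Example:
        A minimal sketch of attaching this hook directly through the registry used in this file (in practice it is
        attached via `apply_layerwise_upcasting` / `apply_layerwise_upcasting_hook` below; `torch.float8_e4m3fn` is
        illustrative and assumes a PyTorch build that provides float8 dtypes):

        ```python
        >>> import torch
        >>> linear = torch.nn.Linear(16, 16)
        >>> registry = HookRegistry.check_if_exists_or_initialize(linear)
        >>> hook = LayerwiseUpcastingHook(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16, non_blocking=False)
        >>> registry.register_hook(hook, "layerwise_upcasting")
        ```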
"""
_is_stateful = False
def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
self.storage_dtype = storage_dtype
self.compute_dtype = compute_dtype
self.non_blocking = non_blocking
def initialize_hook(self, module: torch.nn.Module):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
return args, kwargs
def post_forward(self, module: torch.nn.Module, output):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return output


def apply_layerwise_upcasting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
    skip_modules_pattern: Optional[Tuple[str, ...]] = _DEFAULT_SKIP_MODULES_PATTERN,
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    non_blocking: bool = False,
    _prefix: str = "",
) -> None:
    r"""
    Applies layerwise upcasting to a given module. The module expected here is a Diffusers `ModelMixin`, but it can
    be any `nn.Module` that uses Diffusers layers or PyTorch primitives.

    Args:
        module (`torch.nn.Module`):
            The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
            precision dtype for storage.
        storage_dtype (`torch.dtype`):
            The dtype to cast the module to before/after the forward pass for storage.
        compute_dtype (`torch.dtype`):
            The dtype to cast the module to during the forward pass for computation.
        skip_modules_pattern (`Tuple[str, ...]`, defaults to `("pos_embed", "patch_embed", "norm")`):
            A tuple of patterns matched against module names to skip during the layerwise upcasting process.
        skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
            A tuple of module classes to skip during the layerwise upcasting process.
        non_blocking (`bool`, defaults to `False`):
            If `True`, the weight casting operations are non-blocking.
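
    Example:
        A minimal sketch of applying layerwise upcasting to a small model (`torch.nn.Sequential` stands in for any
        supported model; `torch.float8_e4m3fn` assumes a PyTorch build with float8 dtypes, the input is created in
        the compute dtype so the upcast linear layers receive matching activations, and the forward pass relies on
        the `HookRegistry` from `.hooks` wrapping the module's forward as the rest of this file expects):

        ```python
        >>> import torch
        >>> model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.SiLU(), torch.nn.Linear(16, 16))
        >>> # Store the linear weights in float8, upcast to bfloat16 only for the forward pass
        >>> apply_layerwise_upcasting(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
        >>> with torch.no_grad():
        ...     output = model(torch.randn(1, 16, dtype=torch.bfloat16))
        ```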
"""
if skip_modules_classes is None and skip_modules_pattern is None:
apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
)
if should_skip:
logger.debug(f'Skipping layerwise upcasting for layer "{_prefix}"')
return
if isinstance(module, _SUPPORTED_PYTORCH_LAYERS):
logger.debug(f'Applying layerwise upcasting to layer "{_prefix}"')
apply_layerwise_upcasting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
for name, submodule in module.named_children():
layer_name = f"{_prefix}.{name}" if _prefix else name
apply_layerwise_upcasting(
submodule,
storage_dtype,
compute_dtype,
skip_modules_pattern,
skip_modules_classes,
non_blocking,
_prefix=layer_name,
)


def apply_layerwise_upcasting_hook(
    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
) -> None:
    r"""
    Applies a `LayerwiseUpcastingHook` to a given module.

    Args:
        module (`torch.nn.Module`):
            The module to attach the hook to.
        storage_dtype (`torch.dtype`):
            The dtype to cast the module to before the forward pass.
        compute_dtype (`torch.dtype`):
            The dtype to cast the module to during the forward pass.
        non_blocking (`bool`):
            If `True`, the weight casting operations are non-blocking.
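
    Example:
        A minimal sketch of attaching the hook to a single supported layer (the dtypes are illustrative;
        `torch.float8_e4m3fn` assumes a PyTorch build that provides float8 dtypes):

        ```python
        >>> import torch
        >>> layer = torch.nn.Linear(16, 16)
        >>> apply_layerwise_upcasting_hook(layer, torch.float8_e4m3fn, torch.bfloat16, non_blocking=False)
        ```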
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerwiseUpcastingHook(storage_dtype, compute_dtype, non_blocking)
registry.register_hook(hook, "layerwise_upcasting")