# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Optional, Tuple, Type, Union

import torch

from ..utils import get_logger, is_peft_available, is_peft_version
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
torch.nn.Linear,
)
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on

_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")

if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
    from peft.helpers import disable_input_dtype_casting
    from peft.tuners.tuners_utils import BaseTunerLayer


class LayerwiseCastingHook(ModelHook):
    r"""
    A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
    for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
    footprint.
    """

    _is_stateful = False

    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
self.storage_dtype = storage_dtype
self.compute_dtype = compute_dtype
        self.non_blocking = non_blocking

    def initialize_hook(self, module: torch.nn.Module):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
        return module

    def deinitalize_hook(self, module: torch.nn.Module):
        raise NotImplementedError(
            "LayerwiseCastingHook does not support deinitialization. A model once enabled with layerwise casting will "
            "have cast its weights to a lower precision dtype for storage. Casting them back to the original dtype "
            "will lead to precision loss, which might have an impact on the model's generation quality. The model should "
            "be re-initialized and loaded in the original dtype."
        )

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
return output
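
# Illustrative sketch (not part of the library API): what LayerwiseCastingHook does to a plain
# torch.nn.Linear once registered via `apply_layerwise_casting_hook` (defined below). Weights sit
# in `storage_dtype` between calls and are only upcast to `compute_dtype` around each forward.
#
#   linear = torch.nn.Linear(8, 8, dtype=torch.bfloat16)
#   apply_layerwise_casting_hook(linear, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16, non_blocking=False)
#   assert linear.weight.dtype == torch.float8_e4m3fn    # at rest: storage dtype
#   _ = linear(torch.randn(1, 8, dtype=torch.bfloat16))  # pre_forward upcasts, post_forward downcasts again
#   assert linear.weight.dtype == torch.float8_e4m3fn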


class PeftInputAutocastDisableHook(ModelHook):
    r"""
    A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
    casts the inputs to the weight dtype of the module, which can lead to precision loss.

    The reasons for needing this are:
    - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
      inputs will be cast to the (possibly lower-precision) storage dtype. Reference:
      https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
    - We could, on our end, use something like accelerate's `send_to_device` but for dtypes, to ensure the inputs are
      always cast to the computation dtype. However, there are two goals we are hoping to achieve:
      1. Making forward implementations independent of device/dtype casting operations as much as possible.
      2. Performing inference without losing information from casting to different precisions. With the current
         PEFT implementation (as linked in the reference above), and assuming layerwise casting inference with
         storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
         torch.float8_e4m3fn in the LoRA layer. They are then upcast back to torch.bfloat16 when the forward pass
         continues in the PEFT linear forward or the Diffusers layer forward, via the `send_to_dtype` operation of
         LayerwiseCastingHook. This is a lossy operation and results in poorer generation quality.
    """

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        with disable_input_dtype_casting(module):
            return self.fn_ref.original_forward(*args, **kwargs)
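
# Why PeftInputAutocastDisableHook exists, in one illustrative (non-library) snippet: without it, PEFT
# would cast a bfloat16 input down to the float8 weight dtype and back, which is lossy.
#
#   x = torch.randn(4, dtype=torch.bfloat16)
#   lossy = x.to(torch.float8_e4m3fn).to(torch.bfloat16)  # round-trip through the storage dtype
#   # torch.equal(x, lossy) is False in general; disabling PEFT's input casting keeps x in bfloat16.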
def apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: Union[str, Tuple[str, ...]] = "auto",
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
) -> None:
r"""
Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
nn.Module using diffusers layers or pytorch primitives.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> apply_layerwise_casting(
... transformer,
... storage_dtype=torch.float8_e4m3fn,
... compute_dtype=torch.bfloat16,
... skip_modules_pattern=["patch_embed", "norm", "proj_out"],
... non_blocking=True,
... )
```
Args:
module (`torch.nn.Module`):
The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
precision dtype for storage.
storage_dtype (`torch.dtype`):
The dtype to cast the module to before/after the forward pass for storage.
compute_dtype (`torch.dtype`):
The dtype to cast the module to during the forward pass for computation.
skip_modules_pattern (`Tuple[str, ...]`, defaults to `"auto"`):
A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
to `"auto"`, the default patterns are used. If set to `None`, no modules are skipped. If set to `None`
alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
instead of its internal submodules.
skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, defaults to `None`):
A list of module classes to skip during the layerwise casting process.
non_blocking (`bool`, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
if skip_modules_pattern == "auto":
skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
if skip_modules_classes is None and skip_modules_pattern is None:
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
_apply_layerwise_casting(
module,
storage_dtype,
compute_dtype,
skip_modules_pattern,
skip_modules_classes,
non_blocking,
)
    _disable_peft_input_autocast(module)


def _apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
compute_dtype: torch.dtype,
skip_modules_pattern: Optional[Tuple[str, ...]] = None,
skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
non_blocking: bool = False,
_prefix: str = "",
) -> None:
    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
    )

    if should_skip:
        logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
        return

    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
        logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
        apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
        return

    for name, submodule in module.named_children():
layer_name = f"{_prefix}.{name}" if _prefix else name
_apply_layerwise_casting(
submodule,
storage_dtype,
compute_dtype,
skip_modules_pattern,
skip_modules_classes,
non_blocking,
_prefix=layer_name,
)
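
# Illustrative note on the skip patterns used by `_apply_layerwise_casting` above (not part of the
# library): patterns are applied with `re.search` against the fully qualified module name, so an
# unanchored pattern matches anywhere in the name while anchored patterns only match top-level modules.
#
#   re.search("norm", "transformer_blocks.0.norm1")   # match -> submodule is skipped
#   re.search("^proj_out$", "blocks.0.proj_out")      # no match -> submodule still gets the hook
#   re.search("^proj_out$", "proj_out")               # match -> top-level proj_out is skipped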
def apply_layerwise_casting_hook(
module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
) -> None:
r"""
    Applies a `LayerwiseCastingHook` to a given module.

    Args:
module (`torch.nn.Module`):
The module to attach the hook to.
storage_dtype (`torch.dtype`):
The dtype to cast the module to before the forward pass.
compute_dtype (`torch.dtype`):
The dtype to cast the module to during the forward pass.
non_blocking (`bool`):
If `True`, the weight casting operations are non-blocking.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
    registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)


def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
for submodule in module.modules():
if (
hasattr(submodule, "_diffusers_hook")
and submodule._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
):
return True
    return False


def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
if not _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
return
for submodule in module.modules():
if isinstance(submodule, BaseTunerLayer) and _is_layerwise_casting_active(submodule):
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = PeftInputAutocastDisableHook()
registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
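

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library): run this module with
    # `python -m` so the relative imports above resolve; assumes a PyTorch build that supports
    # casting to torch.float8_e4m3fn. Weights rest in the storage dtype while the forward pass
    # still computes in the compute dtype.
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4)).to(torch.bfloat16)
    apply_layerwise_casting(
        model,
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
    )
    print(model[0].weight.dtype)  # torch.float8_e4m3fn (storage dtype at rest)
    out = model(torch.randn(2, 16, dtype=torch.bfloat16))
    print(out.dtype)  # torch.bfloat16 (compute dtype during the forward pass)
    print(model[0].weight.dtype)  # torch.float8_e4m3fn again after post_forward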