Aston-xMAD's picture
init commit
b37c16f verified
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import collections
import functools
import inspect
import math
import operator
import os
import random
import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union
import torch
from torch import nn
from torch.fx import Graph, GraphModule, Proxy, Tracer
from torch.fx._compatibility import compatibility
from torch.fx.proxy import ParameterProxy
from .. import PretrainedConfig, PreTrainedModel, logging
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_BACKBONE_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_CTC_MAPPING_NAMES,
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_IMAGE_MAPPING_NAMES,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
MODEL_FOR_MASKED_LM_MAPPING_NAMES,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
MODEL_FOR_PRETRAINING_MAPPING_NAMES,
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES,
)
from ..pytorch_utils import is_torch_greater_or_equal_than_2_0
from ..utils import (
ENV_VARS_TRUE_VALUES,
TORCH_FX_REQUIRED_VERSION,
get_torch_version,
is_peft_available,
is_torch_fx_available,
)
# PEFT is optional: PeftModel is only imported when the package is available.
if is_peft_available():
    from peft import PeftModel

logger = logging.get_logger(__name__)
# True when the FX_DEBUG_MODE environment variable (upper-cased) is one of ENV_VARS_TRUE_VALUES.
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
def _generate_supported_model_class_names(
    model_name: Type[PretrainedConfig],
    supported_tasks: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    """
    Collect the model class names registered for `model_name` across the requested tasks.

    Args:
        model_name: Key looked up in each task's `*_MAPPING_NAMES` table.
        supported_tasks: A task name, a list of task names, or `None` to consider every known task.

    Returns:
        The class names found, in task order; tasks with no entry for `model_name` are skipped.

    Raises:
        KeyError: If a requested task is not one of the known task names.
    """
    task_mapping = {
        "default": MODEL_MAPPING_NAMES,
        "pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
        "next-sentence-prediction": MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
        "masked-lm": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
        "causal-lm": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
        "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
        "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
        "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
        "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
        "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
        "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
        "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
        "masked-image-modeling": MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES,
        "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
        "ctc": MODEL_FOR_CTC_MAPPING_NAMES,
        "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
        "semantic-segmentation": MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
        "backbone": MODEL_FOR_BACKBONE_MAPPING_NAMES,
        "image-feature-extraction": MODEL_FOR_IMAGE_MAPPING_NAMES,
    }

    if supported_tasks is None:
        supported_tasks = task_mapping.keys()
    elif isinstance(supported_tasks, str):
        supported_tasks = [supported_tasks]

    # Look the model up in every requested task table and keep only the truthy hits.
    candidates = (task_mapping[task].get(model_name, None) for task in supported_tasks)
    return [class_name for class_name in candidates if class_name]
# Model types whose classes are looked up through `_generate_supported_model_class_names`.
# An entry is either a model type string (all tasks) or a dict of keyword arguments for that helper.
_REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
    "altclip",
    "albert",
    "bart",
    "bert",
    "blenderbot",
    "blenderbot-small",
    "bloom",
    "clip",
    "convnext",
    "deberta",
    "deberta-v2",
    "dinov2",
    "distilbert",
    "donut-swin",
    "electra",
    "gpt2",
    "gpt_neo",
    "gptj",
    "hubert",
    "layoutlm",
    "llama",
    "cohere",
    "lxmert",
    "m2m_100",
    "marian",
    "mbart",
    "megatron-bert",
    "mistral",
    "mixtral",
    "mobilebert",
    "mt5",
    "nezha",
    "opt",
    "pegasus",
    "plbart",
    "qwen2",
    "qwen2_moe",
    "resnet",
    "roberta",
    "segformer",
    "speech_to_text",
    "speech_to_text_2",
    "swin",
    "t5",
    "trocr",
    "vit",
    "xglm",
    "wav2vec2",
    #    "xlnet",
]

# Model types for which a `past_key_values` dummy input can be generated (checked in `_generate_dummy_input`).
_FX_SUPPORTED_MODELS_WITH_KV_CACHE = ["llama", "opt"]

# Expand every entry above into the concrete model class names.
_REGULAR_SUPPORTED_MODELS = []
for item in _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS:
    if isinstance(item, dict):
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(**item))
    else:
        _REGULAR_SUPPORTED_MODELS.extend(_generate_supported_model_class_names(item))

# Class names listed explicitly, in addition to those derived from the model types above.
_SPECIAL_SUPPORTED_MODELS = [
    "CLIPTextModel",
    "CLIPTextModelWithProjection",
    "CLIPVisionModel",
    "CLIPVisionModelWithProjection",
    "AltCLIPTextModel",
    "AltCLIPVisionModel",
    "GitVisionModel",
    "GPT2DoubleHeadsModel",
    "Speech2Text2Decoder",
    "TrOCRDecoder",
    "PeftModelForCausalLM",
    "PeftModelForSeq2SeqLM",
    # TODO: add support for them as it should be quite easy to do so (small blocking issues).
    # XLNetForQuestionAnswering,
]

# Deduplicated, alphabetically sorted tuple of every supported class name.
_SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPORTED_MODELS)))
def torch_nn_embedding(self, input):
    """Meta override for `nn.Embedding`: an empty meta tensor of shape `input.shape + (weight.shape[-1],)`."""
    embedding_dim = self.weight.shape[-1]
    output_shape = (*input.shape, embedding_dim)
    return torch.empty(output_shape, device="meta", dtype=self.weight.dtype)
def torch_nn_functional_embedding(
    input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
    """
    Meta override for `torch.nn.functional.embedding`.

    Only `input` and `weight` are used: the result is an empty meta tensor of shape
    `input.shape + (weight.shape[-1],)` with the dtype of `weight`.
    """
    embedding_dim = weight.shape[-1]
    output_shape = (*input.shape, embedding_dim)
    return torch.empty(output_shape, device="meta", dtype=weight.dtype)
def torch_nn_layernorm(self, input):
    """Meta override for `nn.LayerNorm`: the input is returned unchanged."""
    normalized = input
    return normalized
def torch_nn_groupnorm(self, input):
    """Meta override for `nn.GroupNorm`: the input is returned unchanged."""
    normalized = input
    return normalized
def torch_nn_linear(self, input):
    """Meta override for `nn.Linear`: the last input dimension is replaced by `self.out_features`."""
    leading_dims = tuple(input.shape[:-1])
    output_shape = leading_dims + (self.out_features,)
    return torch.empty(output_shape, device="meta")
def torch_relu(x):
    """Meta override for `torch.relu`: the input is returned unchanged."""
    activated = x
    return activated
def torch_nn_relu(self, x):
    """Meta override for `nn.ReLU`: the input is returned unchanged."""
    activated = x
    return activated
def torch_nn_functional_relu(x, inplace=False):
    """
    Meta override for `torch.nn.functional.relu`.

    Args:
        x: The input; returned unchanged.
        inplace: Must be falsy, the in-place variant is not supported.

    Raises:
        ValueError: If `inplace` is truthy.
    """
    # The guard must reject the in-place variant, not the default out-of-place call.
    if inplace:
        raise ValueError("Don't support in-place functional.relu for MetaTensor analysis")
    return x
def torch_where(condition, x, y):
    """
    Meta override for `torch.where(condition, x, y)`.

    torch.where returns the broadcasted tensor of condition, x, and y, so the three operands are moved to
    the meta device and added together to obtain a result with the broadcast shape.
    """
    cond_meta, x_meta, y_meta = (operand.to(device="meta") for operand in (condition, x, y))
    return cond_meta + x_meta + y_meta
def torch_abs(input, *, out=None):
    """
    Meta override for `torch.abs`: the input is returned unchanged.

    Raises:
        ValueError: If `out` is provided.
    """
    if out is None:
        return input
    raise ValueError("Don't support in-place abs for MetaTensor analysis")
def torch_arange(*args, **kwargs):
    """
    Meta override for `torch.arange`.

    Positional arguments follow `torch.arange`: `(end)`, `(start, end)` or `(start, end, step)`.
    A `step` keyword overrides the positional/default step and `dtype` is forwarded to the result.
    Float bounds/steps are truncated to `int` before the length is computed.

    Returns:
        An empty 1-D meta tensor of length `ceil((end - start) / step)`.
    """
    n = len(args)
    step = 1
    if n == 1:
        start = 0
        end = args[0]
    elif n == 2:
        start, end = args
    else:
        start, end, step = args
    # Resolve the keyword step first so that it goes through the same float handling as the others.
    step = kwargs.get("step", step)
    if isinstance(start, float):
        start = int(start)
    if isinstance(end, float):
        end = int(end)
    if isinstance(step, float):
        step = int(step)
    dtype = kwargs.get("dtype")
    # Ceiling division: arange(0, 5, 2) yields 3 elements, not 5 // 2 == 2.
    length = -((start - end) // step)
    return torch.empty(length, dtype=dtype, device="meta")
def torch_full(*args, **kwargs):
    """
    Meta override for `torch.full`: same call, but with the fill value forced to 1 and the device to "meta".
    """
    positional = list(args)
    # We set the fill value to 1 as its value is not important as long as it's not a tensor on the `meta` device.
    if len(positional) > 1:
        positional[1] = 1
    else:
        kwargs["fill_value"] = 1
    # Any caller-supplied device is dropped in favor of the meta device.
    meta_kwargs = {key: value for key, value in kwargs.items() if key != "device"}
    return torch.full(*positional, **meta_kwargs, device="meta")
def torch_cat(tensors, dim=None, axis=None, *, out=None):
    """
    Meta override for `torch.cat`.

    `axis` is used when `dim` is not given; both default to 0. The result takes the shape of the first
    tensor, with the concatenation dimension set to the sum of that dimension over all tensors.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    if dim < 0:
        dim += tensors[0].dim()
    first_shape = list(tensors[0].shape)
    total_along_dim = sum(tensor.shape[dim] for tensor in tensors)
    final_shape = first_shape[:dim] + [total_along_dim] + first_shape[dim + 1 :]
    return torch.empty(final_shape, device="meta")
def torch_stack(tensors, dim=None, axis=None, *, out=None):
    """
    Meta override for `torch.stack`.

    `axis` is used when `dim` is not given; both default to 0. The result has the shape of the first tensor
    with a new dimension of size `len(tensors)` inserted at `dim`.
    """
    if dim is None:
        dim = 0 if axis is None else axis
    if dim < 0:
        # Negative dims are relative to the output rank, which is one larger than the input rank.
        dim += tensors[0].dim() + 1
    stacked_shape = list(tensors[0].shape)
    stacked_shape.insert(dim, len(tensors))
    return torch.empty(stacked_shape, device="meta")
def torch_add(input, other, *, alpha=1, out=None):
    """
    Meta override for `torch.add`: computes the broadcast shape of the two operands.

    Args:
        input, other: Tensors or Python scalars. When one is not a tensor, the result has the shape of the other.
        alpha, out: Accepted for signature compatibility; not used.

    Returns:
        An empty meta tensor with the broadcast shape.
    """
    if not isinstance(input, torch.Tensor):
        return torch.empty_like(other, device="meta")
    if not isinstance(other, torch.Tensor):
        return torch.empty_like(input, device="meta")
    max_length = max(input.dim(), other.dim())
    # Broadcasting aligns trailing dimensions, so the shorter shape is padded with ones on the left.
    input_shape = [1] * (max_length - input.dim()) + list(input.shape)
    other_shape = [1] * (max_length - other.dim()) + list(other.shape)
    # A size-1 dimension stretches to the other operand's size (including 0).
    shape = [b if a == 1 else a for a, b in zip(input_shape, other_shape)]
    return torch.empty(shape, device="meta")
def torch_mul(input, other, *, out=None):
    """Meta override for `torch.mul`: returns the result of `torch_add(input, other, out=out)`."""
    return torch_add(input, other, out=out)
def torch_tensor_mul(self, other):
    """Meta override for `torch.Tensor.mul`: forwards to `torch_mul`."""
    return torch_mul(self, other)
def torch_matmul(input, other, *, out=None):
    """
    Meta override for `torch.matmul`: infers the output shape without doing the product.

    1-D operands are temporarily promoted to matrices (a leading 1 for `input`, a trailing 1 for
    `other`), batch dimensions are broadcast, and the promoted dimensions are removed again.

    Args:
        input, other: Tensors with at least one dimension.
        out: Accepted for signature compatibility; not used.

    Returns:
        A 0-dim meta tensor for the 1-D x 1-D case, otherwise an empty meta tensor of the product's shape.
    """
    d1 = input.dim()
    d2 = other.dim()
    if d1 == 1 and d2 == 1:
        # Dot product of two vectors: scalar result.
        return torch.tensor(0.0, device="meta")

    shape1 = list(input.shape)
    shape2 = list(other.shape)
    if d1 == 1:
        shape1 = [1] + shape1
    if d2 == 1:
        shape2 = shape2 + [1]

    # Broadcast the batch dimensions (everything but the last two), aligned on the right.
    batch1, batch2 = shape1[:-2], shape2[:-2]
    num_batch_dims = max(len(batch1), len(batch2))
    batch1 = [1] * (num_batch_dims - len(batch1)) + batch1
    batch2 = [1] * (num_batch_dims - len(batch2)) + batch2
    shape = [max(a, b) for a, b in zip(batch1, batch2)]
    shape += [shape1[-2], shape2[-1]]

    # Drop the dimensions that were only added to promote 1-D operands.
    if d1 == 1:
        shape.pop(-2)
    if d2 == 1:
        shape.pop(-1)
    return torch.empty(*shape, device="meta")
def torch_bmm(input, mat2, *, out=None):
    """
    Meta override for `torch.bmm`.

    Both operands must unpack into exactly three dimensions; the result is an empty meta tensor of shape
    `(input.shape[0], input.shape[1], mat2.shape[2])`.

    Raises:
        ValueError: If `out` is provided.
    """
    if out is not None:
        raise ValueError("Don't support in-place bmm for MetaTensor analysis")
    batch, rows, _inner = input.shape
    _batch, _inner2, cols = mat2.shape
    return torch.empty(batch, rows, cols, device="meta")
def torch_baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None):
    """
    Meta override for `torch.baddbmm`: returns `torch_bmm(batch1, batch2)`.

    `input`, `beta` and `alpha` do not influence the result.

    Raises:
        ValueError: If `out` is provided.
    """
    if out is not None:
        raise ValueError("Don't support in-place baddbmm for MetaTensor analysis")
    return torch_bmm(batch1, batch2)
def torch_tensor_baddbmm(self, batch1, batch2, *, beta=1, alpha=1, out=None):
    """Meta override for `torch.Tensor.baddbmm`: forwards all arguments to `torch_baddbmm`."""
    return torch_baddbmm(self, batch1, batch2, beta=beta, alpha=alpha, out=out)
def torch_einsum(equation, *operands):
    """
    Meta override for `torch.einsum`.

    The einsum is actually run on uninitialized CPU copies of the operands and the result is moved to the
    meta device.
    """
    # TODO: infer shape without performing the computation, this might be quite hard.
    concrete_operands = [torch.empty_like(operand, device="cpu") for operand in operands]
    result = torch.einsum(equation, *concrete_operands)
    return result.to("meta")
def torch_tensor_repeat(self, *sizes):
    """
    Meta override for `torch.Tensor.repeat`.

    Args:
        sizes: Repetition counts, given either as separate integers or as a single list/tuple
            (`torch.Size` included).

    Returns:
        An empty meta tensor whose i-th dimension is the (left-padded) input dimension times `sizes[i]`.
    """
    if len(sizes) == 1 and isinstance(sizes[0], (list, tuple)):
        sizes = tuple(sizes[0])
    shape = list(self.shape)
    # `repeat` may be given more sizes than the tensor has dims: new leading dims of size 1 are assumed.
    missing_dims = len(sizes) - len(shape)
    if missing_dims > 0:
        shape = [1] * missing_dims + shape
    for i, x in enumerate(sizes):
        shape[i] *= x
    return torch.empty(shape, device="meta")
def torch_repeat_interleave(*args, dim=None, output_size=None):
    """
    Meta override for `torch.repeat_interleave`.

    Args:
        args: Either `(repeats,)` or `(input, repeats[, dim])`.
        dim: Dimension along which to repeat; when omitted (and not passed positionally) the input is
            treated as flattened.
        output_size: Size of the repeated dimension, used when `repeats` is not a single value.

    Returns:
        An empty meta tensor with the inferred shape.
    """
    num_args = len(args)
    if num_args == 1:
        shape = [output_size if output_size is not None else args[0].sum()]
    else:
        shape = list(args[0].shape)
        if dim is None:
            if num_args > 2:
                dim = args[2]
            else:
                # Without a dim the input is flattened, so the base length is the number of elements.
                shape = [args[0].numel()]
                dim = 0
        repeats = args[1]
        if isinstance(repeats, int) or torch.numel(repeats) == 1:
            shape[dim] *= int(repeats)
        else:
            shape[dim] = output_size if output_size is not None else repeats.sum()
    return torch.empty(*shape, device="meta")
def torch_index_select(input, dim, index, *, out=None):
    """Meta override for `torch.index_select`: the input shape with dimension `dim` set to `len(index)`."""
    selected_shape = list(input.shape)
    num_selected = len(index)
    selected_shape[dim] = num_selected
    return torch.empty(*selected_shape, device="meta")
def torch_tensor_index_select(self, dim, index):
    """Meta override for `torch.Tensor.index_select`: forwards to `torch_index_select`."""
    return torch_index_select(self, dim, index)
def torch_gather(input, dim, index, *, sparse_grad=False, out=None):
    """
    Meta override for `torch.gather`.

    The output of `torch.gather` always has the same shape as `index` (which may be smaller than `input`
    along the non-gathered dimensions), so only `index` determines the result shape.

    Args:
        input, dim, sparse_grad, out: Accepted for signature compatibility; they do not affect the shape.
        index: The index tensor.

    Returns:
        An empty meta tensor shaped like `index`.
    """
    shape = list(index.shape)
    return torch.empty(*shape, device="meta")
def torch_tensor_gather(self, dim, index):
    """Meta override for `torch.Tensor.gather`: forwards to `torch_gather`."""
    return torch_gather(self, dim, index)
def torch_roll(input, shifts, dims=None):
    """Meta override for `torch.roll`: the input is returned unchanged."""
    rolled = input
    return rolled
def torch_flip(input, dims):
    """Meta override for `torch.flip`: the input is returned unchanged."""
    flipped = input
    return flipped
def torch_tensor_flip(self, dims):
    """Meta override for `torch.Tensor.flip`: the tensor itself is returned unchanged."""
    flipped = self
    return flipped
def torch_nn_conv1d(self, input):
    """
    Meta override for `nn.Conv1d`.

    With `padding == "same"` the length is kept; otherwise it is computed from the module's padding
    (`"valid"` meaning zero padding), dilation, kernel size and stride. The channel dimension (`-2`) is set
    to `self.out_channels`.
    """
    padding = (0, 0) if self.padding == "valid" else self.padding
    output_shape = list(input.shape)
    if padding != "same":
        receptive_span = self.dilation[0] * (self.kernel_size[0] - 1)
        padded_length = input.shape[-1] + 2 * padding[0]
        output_shape[-1] = math.floor((padded_length - receptive_span - 1) / self.stride[0] + 1)
    output_shape[-2] = self.out_channels
    return torch.empty(output_shape, device="meta")
def torch_nn_conv2d(self, input):
    """
    Meta override for `nn.Conv2d`.

    With `padding == "same"` the spatial size is kept; otherwise height and width are computed from the
    module's padding (`"valid"` meaning zero padding), dilation, kernel size and stride. The channel
    dimension (`-3`) is set to `self.out_channels`.
    """
    h_in, w_in = input.shape[-2:]
    padding = (0, 0) if self.padding == "valid" else self.padding
    output_shape = list(input.shape)
    if padding != "same":
        spatial_dims = []
        for axis, size_in in enumerate((h_in, w_in)):
            receptive_span = self.dilation[axis] * (self.kernel_size[axis] - 1)
            padded_size = size_in + 2 * padding[axis]
            spatial_dims.append(math.floor((padded_size - receptive_span - 1) / self.stride[axis] + 1))
        output_shape[-2:] = spatial_dims
    output_shape[-3] = self.out_channels
    return torch.empty(output_shape, device="meta")
def torch_squeeze(input, dim=None):
    """
    Meta override for `torch.squeeze`.

    Without `dim`, every size-1 dimension is removed; with `dim`, that dimension is removed only if its
    size is 1.
    """
    shape = list(input.shape)
    if dim is None:
        shape = [size for size in shape if size != 1]
    else:
        if dim < 0:
            dim += input.dim()
        if shape[dim] == 1:
            del shape[dim]
    return torch.empty(shape, device="meta")
def torch_tensor_squeeze(self, dim=None):
    """Meta override for `torch.Tensor.squeeze`: forwards to `torch_squeeze`."""
    return torch_squeeze(self, dim)
def torch_unsqueeze(input, dim):
    """Meta override for `torch.unsqueeze`: the input shape with a size-1 dimension inserted at `dim`."""
    if dim < 0:
        # Negative dims are relative to the output rank, which is one larger than the input rank.
        dim += input.dim() + 1
    expanded_shape = list(input.shape)
    expanded_shape.insert(dim, 1)
    return torch.empty(expanded_shape, device="meta")
def torch_tensor_unsqueeze(self, dim):
    """Meta override for `torch.Tensor.unsqueeze`: forwards to `torch_unsqueeze`."""
    return torch_unsqueeze(self, dim)
def torch_unique_consecutive(input, **kwargs):
    """
    Meta override for `torch.unique_consecutive`.

    The operation is run on a CPU tensor of zeros shaped like `input`, and the result is moved to the meta
    device.

    Returns:
        A meta tensor, or a tuple of meta tensors when the underlying call returns several outputs
        (e.g. with `return_inverse` / `return_counts`).
    """
    output = torch.unique_consecutive(torch.zeros_like(input, device="cpu"), **kwargs)
    if isinstance(output, torch.Tensor):
        return output.to("meta")
    # Several outputs: move each of them to the meta device.
    return tuple(item.to("meta") for item in output)
def torch_nn_functional_one_hot(tensor, num_classes=-1):
    """
    Meta override for `torch.nn.functional.one_hot`: the input shape with a trailing `num_classes` dimension.

    Raises:
        ValueError: If `num_classes` is negative.
    """
    if num_classes < 0:
        raise ValueError("Don't support automatic num_classes inference for MetaTensor analysis")
    one_hot_shape = [*tensor.shape, num_classes]
    return torch.empty(one_hot_shape, device="meta")
def torch_nn_functional_scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
):
    """
    Meta override for `torch.nn.functional.scaled_dot_product_attention`.

    The result keeps the leading dimensions and second-to-last dimension of `query` and takes its last
    dimension from `value`; the other arguments are not used.
    """
    leading_dims = list(query.shape[:-2])
    output_shape = leading_dims + [query.shape[-2], value.shape[-1]]
    return torch.empty(tuple(output_shape), device="meta")
def torch_nn_mseloss(self, input, target):
    """Meta override for `nn.MSELoss`: shaped like `target` when `reduction == "none"`, else `(1,)`."""
    loss_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(loss_shape, device="meta")
def torch_nn_crossentropyloss(self, input, target):
    """Meta override for `nn.CrossEntropyLoss`: shaped like `target` when `reduction == "none"`, else `(1,)`."""
    loss_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(loss_shape, device="meta")
def torch_nn_bcewithlogitsloss(self, input, target):
    """Meta override for `nn.BCEWithLogitsLoss`: shaped like `target` when `reduction == "none"`, else `(1,)`."""
    loss_shape = target.shape if self.reduction == "none" else (1,)
    return torch.empty(loss_shape, device="meta")
def operator_getitem(a, b):
def to_concrete(t):
if isinstance(t, torch.Tensor):
concrete = torch.ones_like(t, device="cpu")
if concrete.dtype in [torch.float16, torch.float32, torch.float64, torch.int32]:
concrete = concrete.to(torch.int64)
return concrete
return t
if isinstance(a, torch.Tensor):
# TODO: infer shape without performing the computation.
if isinstance(b, tuple):
b = tuple(map(to_concrete, b))
else:
b = to_concrete(b)
return operator.getitem(torch.empty_like(a, device="cpu"), b).to("meta")
return operator.getitem(a, b)
# Maps torch callables / module classes to the hand-written functions above that compute their output
# metadata. `HFTracer.create_proxy` looks targets up here and falls back to the target itself when absent.
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
    torch.nn.Embedding: torch_nn_embedding,
    torch.nn.functional.embedding: torch_nn_functional_embedding,
    torch.nn.LayerNorm: torch_nn_layernorm,
    torch.nn.GroupNorm: torch_nn_groupnorm,
    torch.nn.Linear: torch_nn_linear,
    torch.relu: torch_relu,
    torch.nn.functional.relu: torch_nn_functional_relu,
    torch.nn.ReLU: torch_nn_relu,
    torch.where: torch_where,
    torch.abs: torch_abs,
    torch.arange: torch_arange,
    torch.full: torch_full,
    torch.cat: torch_cat,
    torch.stack: torch_stack,
    torch.add: torch_add,
    torch.mul: torch_mul,
    torch.Tensor.mul: torch_tensor_mul,
    torch.matmul: torch_matmul,
    torch.bmm: torch_bmm,
    torch.baddbmm: torch_baddbmm,
    torch.Tensor.baddbmm: torch_tensor_baddbmm,
    torch.einsum: torch_einsum,
    torch.Tensor.repeat: torch_tensor_repeat,
    torch.repeat_interleave: torch_repeat_interleave,
    torch.roll: torch_roll,
    torch.flip: torch_flip,
    torch.Tensor.flip: torch_tensor_flip,
    torch.index_select: torch_index_select,
    torch.Tensor.index_select: torch_tensor_index_select,
    torch.gather: torch_gather,
    torch.Tensor.gather: torch_tensor_gather,
    torch.nn.Conv1d: torch_nn_conv1d,
    torch.nn.Conv2d: torch_nn_conv2d,
    torch.squeeze: torch_squeeze,
    torch.Tensor.squeeze: torch_tensor_squeeze,
    torch.unsqueeze: torch_unsqueeze,
    torch.Tensor.unsqueeze: torch_tensor_unsqueeze,
    torch.unique_consecutive: torch_unique_consecutive,
    torch.nn.functional.one_hot: torch_nn_functional_one_hot,
    torch.nn.MSELoss: torch_nn_mseloss,
    torch.nn.CrossEntropyLoss: torch_nn_crossentropyloss,
    torch.nn.BCEWithLogitsLoss: torch_nn_bcewithlogitsloss,
    operator.getitem: operator_getitem,
}

# The scaled_dot_product_attention override is only registered when torch >= 2.0 is detected.
if is_torch_greater_or_equal_than_2_0:
    _MANUAL_META_OVERRIDES[
        torch.nn.functional.scaled_dot_product_attention
    ] = torch_nn_functional_scaled_dot_product_attention
class HFProxy(Proxy):
    """
    Proxy that uses metadata to handle data-dependent control-flow.

    When metadata has been installed with `install_metadata`, `len()`, `bool()` and `in` are answered from
    that metadata; otherwise they fall back to the `torch.fx.Proxy` behavior.
    """

    def install_metadata(self, metadata):
        """Attach `metadata` (the value computed for this proxy during tracing) to the proxy."""
        self._metadata = metadata

    @property
    def shape(self):
        # Recorded in the graph as a `size()` method call on this proxy.
        return self.tracer.create_proxy("call_method", "size", (self,), {})

    @property
    def device(self):
        # Hack so we can track when devices are used. During meta-tensor propagation,
        # replace these values with a constant 'meta'
        return MetaDeviceAttribute(self, "device")

    def __len__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return len(self._metadata)
        return super().__len__()

    def __bool__(self):
        if hasattr(self, "_metadata") and self._metadata is not None:
            # `__bool__` must return a real bool: returning the raw metadata (an int, a list, ...)
            # would raise a TypeError.
            return bool(self._metadata)
        return super().__bool__()

    def __getattr__(self, k):
        # `__getattr__` only runs when normal lookup failed; for `_metadata` defer to
        # `__getattribute__` instead of creating an HFAttribute for it.
        if k == "_metadata":
            return self.__getattribute__(k)
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return HFAttribute(self, k)

    def __setitem__(self, indices, values):
        return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})

    def __contains__(self, key):
        if hasattr(self, "_metadata") and self._metadata is not None:
            return key in self._metadata
        return super().__contains__(key)
class HFAttribute(HFProxy):
    """
    Proxy for the attribute `attr` of another proxy `root`.

    The graph node is only created when `node` is first accessed; calling the attribute records a
    `call_method` on `root` instead.
    """

    def __init__(self, root, attr: str):
        self.root, self.attr = root, attr
        self.tracer = root.tracer
        self._node = None

        # Propagate metadata: the attribute's metadata is the same attribute read on the root's metadata.
        if hasattr(root, "_metadata"):
            self.install_metadata(getattr(root._metadata, attr))

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            getattr_proxy = self.tracer.create_proxy(
                "call_function", builtins.getattr, (self.root, self.attr), {}
            )
            self._node = getattr_proxy.node
        return self._node

    def __call__(self, *args, **kwargs):
        method_args = (self.root,) + args
        return self.tracer.create_proxy("call_method", self.attr, method_args, kwargs)
class MetaDeviceAttribute(HFAttribute):
    """Marker subclass returned by `HFProxy.device`; `_proxies_to_metas` maps instances to the string "meta"."""

    pass
def _proxies_to_metas(v):
    """
    Returns the underlying metadata for HFProxies, and behaves like the identity for the others.

    `MetaDeviceAttribute` instances are mapped to the string "meta".

    Raises:
        RuntimeError: If `v` is a proxy that carries no metadata.
    """
    if isinstance(v, MetaDeviceAttribute):
        return "meta"
    if not isinstance(v, torch.fx.Proxy):
        return v
    has_metadata = isinstance(v, HFProxy) and hasattr(v, "_metadata")
    if has_metadata:
        return v._metadata
    raise RuntimeError(f"No metadata was found for {v}")
def _gen_constructor_wrapper(target):
@functools.wraps(target)
def wrapper(*args, **kwargs):
proxy = None
def check_has_proxy(v):
if isinstance(v, Proxy):
nonlocal proxy
proxy = v
torch.fx.node.map_aggregate(args, check_has_proxy)
torch.fx.node.map_aggregate(kwargs, check_has_proxy)
if proxy is not None:
return proxy.tracer.create_proxy("call_function", target, args, kwargs)
else:
return target(*args, **kwargs)
return wrapper, target
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
if forbidden_values is None:
forbidden_values = []
value = random.randint(low, high)
while value in forbidden_values:
value = random.randint(low, high)
return value
class HFTracer(Tracer):
"""
Tracer that is able to symbolically trace models from the library. To do that, it uses the HFProxy instead of the
regular PyTorch torch.fx.Proxy.
"""
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes: bool = True
allow_insert_stateless_mods: bool = True
_TORCH_METHODS_TO_PATCH = [
"arange",
"zeros",
"ones",
"full",
"full_like",
"eye",
"empty",
"tensor",
"clamp",
"finfo",
"tril",
]
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
    """
    Initialize the tracer, forwarding `autowrap_modules` and `autowrap_functions` to `torch.fx.Tracer`.

    Raises:
        ImportError: If `is_torch_fx_available()` reports that the installed torch is not usable.
    """
    super().__init__(autowrap_modules=autowrap_modules, autowrap_functions=autowrap_functions)

    if not is_torch_fx_available():
        raise ImportError(
            f"Found an incompatible version of torch. Found version {get_torch_version()}, but only version "
            f"{TORCH_FX_REQUIRED_VERSION} is supported."
        )
def _generate_dummy_input(
    self, model: PreTrainedModel, input_name: str, shape: List[int], input_names: List[str]
) -> Dict[str, torch.Tensor]:
    """
    Generates dummy input for model inference recording.

    Args:
        model: The model being traced; its class name, `device` and `config` drive the generation.
        input_name: Name of the input to generate. The branch is chosen by exact match for label-like
            names and by substring match otherwise, in the order written below.
        shape: Base shape of the input (`shape[0]` is used as the batch size).
        input_names: All the input names being generated; only used to detect `past_key_values`, which
            lengthens mask inputs by the KV cache length.

    Returns:
        A dict mapping input names to dummy values (tensors, or a tuple of tensor pairs for
        `past_key_values`). For label-like inputs the keys written depend on the model class, not only on
        `input_name`.

    Raises:
        ValueError: For sequence classification models whose `config.problem_type` is missing or unknown.
        NotImplementedError: For label inputs of an unhandled model class, or for `past_key_values` with a
            model type not in `_FX_SUPPORTED_MODELS_WITH_KV_CACHE`.
    """
    # Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
    # from pickle, or from the "__class__" attribute in the general case.
    model_class_name = getattr(model, "class_for_deserialization", model.__class__).__name__
    device = model.device
    inputs_dict = {}

    # when tracing a model with KV cache, we simply need to ensure that the KV cache length is larger than one to
    # rightfully pass certain controlflows (Example: https://github.com/huggingface/transformers/blob/5c8d941d66734811d2ef6f57f15b44f7fb7a98c4/src/transformers/modeling_attn_mask_utils.py#L162).
    # After tracing, the model can then still be used with arbitrary lengths different than the one used during tracing.
    kv_cache_length = 5

    if input_name in ["labels", "start_positions", "end_positions"]:
        batch_size = shape[0]
        if model_class_name in [
            *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
            *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
            *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
            *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
            *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
        ]:
            # One label per example.
            inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
        elif model_class_name in [
            *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
            *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES),
            "XLNetForQuestionAnswering",
        ]:
            # Question answering models get both span boundaries at once.
            inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
            inputs_dict["end_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device)
        elif model_class_name in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES):
            if not hasattr(model.config, "problem_type") or model.config.problem_type is None:
                raise ValueError(
                    "Could not retrieve the problem type for the sequence classification task, please set "
                    'model.config.problem_type to one of the following values: "regression", '
                    '"single_label_classification", or "multi_label_classification".'
                )

            # The label shape and dtype depend on the configured problem type.
            if model.config.problem_type == "regression":
                labels_shape = (batch_size, model.config.num_labels)
                labels_dtype = torch.float32
            elif model.config.problem_type == "single_label_classification":
                labels_shape = (batch_size,)
                labels_dtype = torch.long
            elif model.config.problem_type == "multi_label_classification":
                labels_shape = (batch_size, model.config.num_labels)
                labels_dtype = torch.float32
            else:
                raise ValueError(
                    'Expected model.config.problem_type to be either: "regression", "single_label_classification"'
                    f', or "multi_label_classification", but "{model.config.problem_type}" was provided.'
                )
            inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)

        elif model_class_name in [
            *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
            *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
            *get_values(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES),
            *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
            *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
            *get_values(MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES),
            "GPT2DoubleHeadsModel",
            "PeftModelForCausalLM",
            "PeftModelForSeq2SeqLM",
        ]:
            # Labels with the full input shape.
            inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
        elif model_class_name in [*get_values(MODEL_FOR_CTC_MAPPING_NAMES)]:
            inputs_dict["labels"] = torch.zeros(shape, dtype=torch.float32, device=device)
        else:
            raise NotImplementedError(
                f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
            )
    elif "pixel_values" in input_name:
        batch_size = shape[0]
        # Look for the image size on the config, then on its vision/encoder sub-config, else pick a random one.
        image_size = getattr(model.config, "image_size", None)
        if image_size is None:
            if hasattr(model.config, "vision_config"):
                image_size = model.config.vision_config.image_size
            elif hasattr(model.config, "encoder"):
                image_size = model.config.encoder.image_size
            else:
                image_size = (_generate_random_int(), _generate_random_int())

        # If no num_channels is in the config, use some arbitrary value.
        num_channels = getattr(model.config, "num_channels", 3)
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        height, width = image_size
        inputs_dict[input_name] = torch.zeros(
            batch_size, num_channels, height, width, dtype=torch.float32, device=device
        )
    elif "bbox" in input_name:
        inputs_dict[input_name] = torch.zeros(*shape, 4, dtype=torch.float, device=device)
    elif "input_features" in input_name:
        inputs_dict[input_name] = torch.zeros(
            *shape, model.config.input_feat_per_channel, dtype=torch.float, device=device
        )
    elif "visual_feats" in input_name:
        inputs_dict[input_name] = torch.zeros(
            shape
            + [
                model.config.visual_feat_dim,
            ],
            dtype=torch.float,
            device=device,
        )
    elif "visual_pos" in input_name:
        inputs_dict[input_name] = torch.zeros(
            shape
            + [
                model.config.visual_pos_dim,
            ],
            dtype=torch.float,
            device=device,
        )
    elif "inputs" in input_name:
        inputs_dict[input_name] = torch.zeros(*shape, dtype=torch.float, device=device)
    elif "input_values" in input_name:
        batch_size, _ = shape
        # Generating big sequence length for audio inputs.
        seq_length = _generate_random_int(low=10000, high=20000)
        inputs_dict[input_name] = torch.zeros(batch_size, seq_length, dtype=torch.float, device=device)
    elif "mask" in input_name:
        # With a KV cache, the mask also has to cover the cached positions.
        if "past_key_values" in input_names:
            mask_shape = [shape[0], shape[1] + kv_cache_length]
        else:
            mask_shape = shape

        inputs_dict[input_name] = torch.zeros(mask_shape, dtype=torch.long, device=device)
    elif "ids" in input_name:
        inputs_dict[input_name] = torch.zeros(shape, dtype=torch.long, device=device)
    elif "past_key_values" in input_name:
        if model.config.model_type not in _FX_SUPPORTED_MODELS_WITH_KV_CACHE:
            raise NotImplementedError(
                f"Symbolic trace with past_key_values input is not supported yet for the model {model.config.model_type}. Please open an issue or a PR in Transformers repository if you would like to see the support added."
            )
        num_heads = model.config.num_attention_heads
        head_dim = model.config.hidden_size // model.config.num_attention_heads

        # One pair of random tensors per hidden layer, each of shape (batch, num_heads, kv_cache_length, head_dim).
        cache_shape = (shape[0], num_heads, kv_cache_length, head_dim)
        pkv = tuple(
            (
                torch.rand(cache_shape, dtype=torch.float, device=device),
                torch.rand(cache_shape, dtype=torch.float, device=device),
            )
            for i in range(model.config.num_hidden_layers)
        )
        inputs_dict[input_name] = pkv
    else:
        # Fallback: a float tensor with a trailing hidden-size dimension.
        shape_with_hidden_size = shape + [model.config.hidden_size]
        inputs_dict[input_name] = torch.zeros(shape_with_hidden_size, dtype=torch.float, device=device)

    return inputs_dict
def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
    """
    Create the proxy through `torch.fx.Tracer.create_proxy`, then try to attach metadata to it.

    For placeholders found in `self.meta_args`, the stored value is installed directly. For the other
    node kinds, the operation is evaluated on the metadata of its arguments (using
    `_MANUAL_META_OVERRIDES` when an override exists) and the result is installed on the proxy.
    Any exception raised while computing the metadata is swallowed (and reported as a warning in debug
    mode): the proxy is then returned without metadata.
    """
    rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)

    if kind == "placeholder" and target in self.meta_args:
        rv.install_metadata(self.meta_args[target])
        return rv

    if target in self.orig_fns:
        # NOTE: tensor constructors in PyTorch define the `device` argument as
        # *kwargs-only*. That is why this works. If you add methods to
        # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
        # this will break and you will likely see issues where we cannot infer
        # the size of the output.
        if "device" in kwargs:
            kwargs["device"] = "meta"

    try:
        # Replace every proxy found in the arguments by its metadata.
        args_metas = torch.fx.node.map_aggregate(args, _proxies_to_metas)
        kwargs_metas = torch.fx.node.map_aggregate(kwargs, _proxies_to_metas)

        if kind == "call_function":
            meta_target = _MANUAL_META_OVERRIDES.get(target, target)
            meta_out = meta_target(*args_metas, **kwargs_metas)
            if isinstance(meta_out, torch.Tensor):
                meta_out = meta_out.to(device="meta")
        elif kind == "call_method":
            # Resolve the method on the class of the first argument's metadata.
            method = getattr(args_metas[0].__class__, target)
            meta_target = _MANUAL_META_OVERRIDES.get(method, method)
            meta_out = meta_target(*args_metas, **kwargs_metas)
        elif kind == "call_module":
            if not hasattr(self, "orig_forward"):
                raise AttributeError(f"{self} does not have an attribute called orig_forward")
            # Attribute accesses made while computing metadata must return real values, not proxies.
            self._disable_module_getattr = True
            try:
                mod = self.root.get_submodule(target)
                mod_type = type(mod)
                if mod_type in _MANUAL_META_OVERRIDES:
                    meta_out = _MANUAL_META_OVERRIDES[mod_type](mod, *args_metas, **kwargs_metas)
                else:
                    meta_out = self.orig_forward(*args_metas, **kwargs_metas)
            finally:
                self._disable_module_getattr = False
        elif kind == "get_attr":
            self._disable_module_getattr = True
            try:
                # Walk the dotted path from the root module down to the attribute.
                attr_itr = self.root
                atoms = target.split(".")
                for atom in atoms:
                    attr_itr = getattr(attr_itr, atom)
                if isinstance(attr_itr, torch.Tensor):
                    meta_out = attr_itr.to(device="meta")
                else:
                    meta_out = attr_itr
            finally:
                self._disable_module_getattr = False
        else:
            return rv

        if not isinstance(rv, Proxy):
            raise ValueError("Don't support composite output yet")
        rv.install_metadata(meta_out)
    except Exception as e:
        # Best effort: failing to compute metadata must not abort tracing.
        if _IS_IN_DEBUG_MODE:
            warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")

    return rv
# Replaced by .getattr from PyTorch 1.13
def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
    """
    Return a `get_attr` proxy for `attr_val` when it is a parameter (or, if `proxy_buffer_attributes` is
    set, a buffer) of the root module; otherwise return `attr_val` itself.

    Proxies are cached by name in `parameter_proxy_cache`. When `_disable_module_getattr` is set (as
    `create_proxy` does while computing metadata), `attr_val` is returned untouched.
    """
    if getattr(self, "_disable_module_getattr", False):
        return attr_val
    else:

        def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
            # Search by identity for `attr_val` among the (name, tensor) pairs.
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        # Only pass `proxy_factory_fn` when `create_proxy` accepts it.
                        if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
                            kwargs["proxy_factory_fn"] = (
                                None
                                if not self.param_shapes_constant
                                else lambda node: ParameterProxy(self, node, n, attr_val)
                            )
                        val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs)  # type: ignore[arg-type]
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_parameters(), parameter_proxy_cache
            )
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(
                attr_val, self.root.named_buffers(), parameter_proxy_cache
            )
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        return attr_val
# Needed for PyTorch 1.13+
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
    """Forward to `_module_getattr` with the same arguments."""
    return self._module_getattr(attr, attr_val, parameter_proxy_cache)
def call_module(self, m, forward, args, kwargs):
    """Remember `forward` as `self.orig_forward` (read by `create_proxy` for `call_module` nodes), then defer to the parent tracer."""
    self.orig_forward = forward
    return super().call_module(m, forward, args, kwargs)
def proxy(self, node):
    """Wrap `node` in an `HFProxy` bound to this tracer."""
    return HFProxy(node, self)
def trace(
    self,
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    dummy_inputs: Optional[Dict[str, Any]] = None,
    complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
) -> Graph:
    """
    Traces `root` and returns the corresponding FX `torch.fx.Graph` representation. `root` can either be a
    `torch.nn.Module` instance or a Python callable. Note that after this call, `self.root` may be different from
    the `root` passed in here. For example, when a free function is passed to `trace()`, we will create a
    `torch.nn.Module` instance to use as the root and add embedded constants to.

    Args:
        root (`torch.nn.Module` or `Callable`):
            Either a `torch.nn.Module` or a function to be traced through. If root is not a
            [`~transformers.PreTrainedModel`], then `dummy_inputs` must be passed, otherwise tracing will fail.
        concrete_args (`Dict[str, Any]`, *optional*):
            Concrete arguments that should not be treated as Proxies. The dictionary passed by the caller is
            copied and never modified.
        dummy_inputs (`Dict[str, Any]`, *optional*):
            The dummy inputs needed to handle data-dependent control-flow if `root` is not a
            [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
            [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
        complete_concrete_args_with_inputs_not_in_dummy_inputs (`bool`, *optional*, defaults to `True`):
            If `True`, and `dummy_inputs` is specified, every argument that `root` can take that is not in
            `dummy_inputs` and not in `concrete_args` will be added to `concrete_args`, otherwise does nothing.

    Returns:
        `torch.fx.Graph`:
            A FX `torch.fx.Graph` representing the semantics of the passed-in `root`.

    Raises:
        ValueError: If `concrete_args` has to be completed and a parameter that is not in `dummy_inputs` has no
            default value.
        RuntimeError: If a dummy input has to be generated but `root` is neither one of `self.supported_archs`
            nor a deserialized traced module.
    """
    sig = inspect.signature(root.forward if isinstance(root, torch.nn.Module) else root)

    # Work on a private copy: the completion step below adds entries, and the caller's dictionary must not be
    # altered as a side effect of tracing.
    concrete_args = {} if concrete_args is None else dict(concrete_args)

    if dummy_inputs is not None and complete_concrete_args_with_inputs_not_in_dummy_inputs:
        for param in sig.parameters.values():
            if param.name in dummy_inputs:
                continue
            if param.default is inspect.Parameter.empty:
                raise ValueError(f"You need to specify a default value for the parameter {param.name}.")
        concrete_args.update(
            {
                p.name: p.default
                for p in sig.parameters.values()
                if (p.name not in dummy_inputs and p.name not in concrete_args)
            }
        )

    # Everything that is not a concrete argument is a real input of the traced graph.
    input_names = sig.parameters.keys() - concrete_args.keys()

    # Creating a random input shape to generate dummy inputs.
    batch_size = _generate_random_int()
    sequence_length = _generate_random_int()
    shape = [batch_size, sequence_length]

    if root.__class__.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
        num_choices = _generate_random_int(low=2, high=5)
        shape.insert(1, num_choices)

    # User provided dummy inputs take precedence, the missing ones are generated.
    inputs = dict(dummy_inputs) if dummy_inputs is not None else {}
    for input_name in input_names:
        if input_name in inputs:
            continue
        # We enforce that root must either be a PreTrainedModel or deserialized from a serialized traced model to
        # be able to use HFTracer._generate_dummy_input.
        if isinstance(root, self.supported_archs) or type(root).__qualname__.startswith(
            ("_deserialize_graph_module", "_CodeOnlyModule")
        ):
            inputs.update(self._generate_dummy_input(root, input_name, shape, input_names=input_names))
        else:
            raise RuntimeError(
                f"Could not generate input named {input_name} because root is not a"
                " transformers.PreTrainedModel."
            )

    # Tensors are moved to the "meta" device, every other value is kept as is.
    concrete_metas = {
        input_name: input_.to("meta") if isinstance(input_, torch.Tensor) else input_
        for input_name, input_ in inputs.items()
    }
    # A `**kwargs`-like parameter that is not a traced input gets an empty dictionary as metadata.
    for param in sig.parameters.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD and param.name not in input_names:
            concrete_metas[f"**{param.name}"] = {}
    self.meta_args = concrete_metas

    # Each entry is a `(wrapper, original)` pair for one of the torch functions listed in
    # `self._TORCH_METHODS_TO_PATCH`.
    self.patched_torch_methods = {
        target: _gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
    }
    self.orig_fns = set()

    for name, (wrapper, orig) in self.patched_torch_methods.items():
        setattr(torch, name, wrapper)
        self.orig_fns.add(orig)

    try:
        self.graph = super().trace(root, concrete_args=concrete_args)
    finally:
        # The original torch functions are restored whether tracing succeeded or not.
        for name, (_, orig) in self.patched_torch_methods.items():
            setattr(torch, name, orig)

    # This is necessary because concrete args are added as input to the traced module since
    # https://github.com/pytorch/pytorch/pull/55888.
    for node in self.graph.nodes:
        if node.op == "placeholder":
            # Removing default values for inputs as the forward pass will fail with them.
            if node.target in input_names:
                node.args = ()
                # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                # It cannot infer on the attributes and methods the input should have, and fails.
                node.type = torch.Tensor
            # It is a concrete arg so it is not used and should be removed.
            else:
                # Breadth-first collection of the placeholder and of everything that uses it; a deque gives
                # constant time pops from the left while preserving the original visiting order.
                to_visit = collections.deque([node])
                to_delete = collections.OrderedDict()
                while to_visit:
                    n = to_visit.popleft()
                    to_delete[n] = None
                    to_visit.extend(n.users.keys())

                # Erase in reverse discovery order.
                for user in reversed(to_delete.keys()):
                    self.graph.erase_node(user)

        # TODO: solves GraphModule creation.
        # Without this, return type annotation "Tuple" is causing code execution failure.
        if node.op == "output":
            node.type = None

    return self.graph
def _stateless_mod_instanciation_depends_on_proxies(self, mod: nn.Module) -> bool:
    """
    Tell whether at least one value stored in the instance dictionary of `mod` is a `Proxy`.

    A module holding a `Proxy` was instantiated with input-dependent attributes, and therefore cannot be a leaf
    module.
    """
    for value in vars(mod).values():
        if isinstance(value, Proxy):
            return True
    return False
def _insert_module_as_submodule(self, mod: nn.Module) -> str:
    """
    Helper method which tries to insert a module that was not declared as submodule.

    Candidate names have the form `<lowercased class name>_<index>`, with the index starting at 0. The first
    name that is either free on `self.root`, or already bound to `mod` itself, is used.

    Args:
        mod (`torch.nn.Module`): The module to register on `self.root`.

    Returns:
        `str`: The attribute name under which `mod` is available on `self.root`, or an empty string when `mod`
        holds `Proxy` attributes and therefore cannot be inserted.
    """
    # If one of the module attributes is a Proxy, it means that its instantiation is input-dependent.
    # It is not possible to insert such modules, those should be traced through.
    if self._stateless_mod_instanciation_depends_on_proxies(mod):
        return ""

    mod_name = mod.__class__.__name__.lower()
    idx = 0
    path = f"{mod_name}_{idx}"
    while hasattr(self.root, path):
        # No need to add multiple instances of the same module.
        if getattr(self.root, path) is mod:
            return path
        # Advance the index *before* rebuilding the name so that each candidate is probed exactly once.
        idx += 1
        path = f"{mod_name}_{idx}"

    self.root.add_module(path, mod)
    return path
def path_of_module(self, mod: nn.Module) -> str:
    """
    Helper method to find the qualified name of `mod` in the Module hierarchy of `root`. For example, if `root` has
    a submodule named `foo`, which has a submodule named `bar`, passing `bar` into this function will return the
    string "foo.bar".

    When the parent lookup raises a `NameError`, and `self.allow_insert_stateless_mods` is set, a module that has
    neither parameters nor buffers is inserted with `_insert_module_as_submodule` and the resulting path is
    returned. In every other case the `NameError` is re-raised.

    Args:
        mod (`torch.nn.Module`): The `Module` to retrieve the qualified name for.
    """
    try:
        return super().path_of_module(mod)
    except NameError as e:
        if not self.allow_insert_stateless_mods:
            raise e
        # Only modules without any parameter and without any buffer are eligible for insertion.
        owns_parameters = any(True for _ in mod.parameters())
        if owns_parameters or any(True for _ in mod.buffers()):
            raise e
        return self._insert_module_as_submodule(mod)
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
    """
    Tell whether `m` must be treated as a leaf module.

    A module that holds `Proxy` attributes is never a leaf; for any other module the decision is left to the
    parent implementation.
    """
    if self._stateless_mod_instanciation_depends_on_proxies(m):
        return False
    return super().is_leaf_module(m, module_qualified_name)
@compatibility(is_backward_compatible=True)
def keys(self, obj: "Proxy") -> Any:
    """Called when a proxy object has its keys() method called.

    This is what happens when ** is called on a proxy. This should return an iterator if ** is supposed to work in
    your custom tracer.

    The call is recorded through `HFAttribute(obj, "keys")()`. When the node behind `obj` targets "**kwargs", the
    `_metadata` of that result is returned instead of the result itself.
    """
    keys_call = HFAttribute(obj, "keys")()
    targets_var_keyword = obj.node.target == "**kwargs"
    return keys_call._metadata if targets_var_keyword else keys_call
def get_concrete_args(model: nn.Module, input_names: List[str]):
    """
    Build the concrete arguments of `model.forward`: every parameter that is not listed in `input_names`, mapped to
    its default value.

    Args:
        model (`torch.nn.Module`): The model whose `forward` signature is inspected.
        input_names (`List[str]`): Names of the parameters that must remain real inputs.

    Returns:
        `Dict[str, Any]`: Parameter name to default value, for each parameter not in `input_names`.

    Raises:
        ValueError: If at least one name in `input_names` is not a parameter of `model.forward`.
    """
    forward_params = inspect.signature(model.forward).parameters

    if not set(input_names).issubset(set(forward_params.keys())):
        if len(input_names) == 1:
            formatted_input_names = input_names[0]
        else:
            formatted_input_names = ", ".join(input_names)
        formatted_allowed_input_names = ", ".join(forward_params.keys())
        raise ValueError(
            f"The model does not have input(s) named: {formatted_input_names}, expected a subset of the following:"
            f" {formatted_allowed_input_names}"
        )

    concrete = {}
    for parameter in forward_params.values():
        if parameter.name not in input_names:
            concrete[parameter.name] = parameter.default
    return concrete
def is_model_supported(model: PreTrainedModel):
    """Return `True` when the class name of `model` is listed in `_SUPPORTED_MODELS`, `False` otherwise."""
    class_name = model.__class__.__name__
    return class_name in _SUPPORTED_MODELS
def check_if_model_is_supported(model: PreTrainedModel):
    """
    Make sure that `model` is supported, as reported by `is_model_supported`.

    Raises:
        NotImplementedError: If the model is not supported. The message lists the content of
            `_SUPPORTED_MODELS`.
    """
    if is_model_supported(model):
        return
    supported_model_names = ", ".join(_SUPPORTED_MODELS)
    raise NotImplementedError(
        f"Model {model.__class__.__name__} is not supported yet, supported models: {supported_model_names}"
    )
def symbolic_trace(
    model: PreTrainedModel,
    input_names: Optional[List[str]] = None,
    disable_check: bool = False,
    tracer_cls: Type[HFTracer] = HFTracer,
) -> GraphModule:
    """
    Performs symbolic tracing on the model.

    Args:
        model ([`PretrainedModel`]):
            The model to trace.
        input_names (`List[str]`, *optional*):
            The names of the inputs of the traced model. If unset, model.dummy_inputs.keys() are used instead.
        disable_check (`bool`, *optional*, defaults to `False`):
            If `True`, no check is done before trying to trace the model, this is mostly useful for debugging
            purposes.
        tracer_cls (`Type[HFTracer]`, *optional*, defaults to `HFTracer`):
            The tracer class to use for instantiating the tracer. If unset, `HFTracer` is used instead.

    Returns:
        `torch.fx.GraphModule`: A GraphModule constructed by recording operations seen while tracing the model.

    Example:

        ```python
        from transformers.utils.fx import symbolic_trace

        traced_model = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])
        ```
    """
    traced_input_names = list(model.dummy_inputs.keys() if input_names is None else input_names)

    # Every forward parameter that is not a traced input becomes a concrete argument.
    concrete_args = get_concrete_args(model, traced_input_names)

    if not disable_check:
        check_if_model_is_supported(model)

    # Tracing.
    graph = tracer_cls().trace(model, concrete_args=concrete_args)
    graph_module = torch.fx.GraphModule(model, graph)

    graph_module.config = model.config
    # The model class must be stored as an attribute to allow model deserialization, which uses trace, and thus
    # _generate_dummy_input, where the model class is needed.
    graph_module.class_for_deserialization = model.__class__
    graph_module.device = model.device

    return graph_module