# mypy: ignore-errors

import abc
import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import itertools
import logging
import operator
import re
import sys
import types
from typing import List, NamedTuple, Optional, Union

from torch.utils._sympy.value_ranges import ValueRanges

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

import torch
from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    DimDynamic,
    RelaxedUnspecConstraint,
    StatefulSymbolicContext,
    SubclassSymbolicContext,
    SymbolicContext,
)
from torch.fx.immutable_collections import immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef

from .. import config, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented
from ..guards import GuardBuilder, install_guard, make_dupe_guard
from ..side_effects import SideEffects
from ..source import (
    AttrSource,
    ConstantSource,
    ConstDictKeySource,
    ConvertIntSource,
    GetItemSource,
    is_constant_source,
    is_from_defaults,
    LocalSource,
    NumpyTensorSource,
    RandomValueSource,
    Source,
    TupleIteratorGetItemSource,
)
from ..trace_rules import is_callable_allowed, is_numpy
from ..utils import (
    build_checkpoint_variable,
    clone_input,
    common_constant_types,
    get_fake_value,
    get_static_address_type,
    is_function_or_wrapper,
    is_namedtuple,
    is_typing,
    is_utils_checkpoint,
    istype,
    odict_values,
    preserve_rng_state,
    tensor_always_has_static_shape,
    tuple_iterator,
    tuple_iterator_getitem,
    tuple_iterator_len,
    unwrap_with_attr_name_if_wrapper,
    wrap_fake_exception,
)
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
    AutocastModeVariable,
    EventVariable,
    NullContextVariable,
    PreserveVersionContextVariable,
    StreamContextVariable,
    StreamVariable,
)
from .dicts import (
    ConstDictVariable,
    DataClassVariable,
    DefaultDictVariable,
    HFPretrainedConfigVariable,
    PythonSysModulesVariable,
    SetVariable,
)
from .distributed import (
    DeviceMeshVariable,
    PlacementClassVariable,
    PlacementVariable,
    ProcessGroupVariable,
)
from .functions import (
    CollectiveFunctionRewriteVariable,
    FunctoolsPartialVariable,
    TritonKernelVariable,
    UserMethodVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .iter import ItertoolsVariable
from .lazy import LazyVariableTracker
from .lists import (
    BaseListVariable,
    ListVariable,
    NamedTupleVariable,
    RangeVariable,
    RestrictedListSubclassVariable,
    SizeVariable,
    SliceVariable,
    TupleIteratorVariable,
    TupleVariable,
)
from .misc import (
    AutogradFunctionContextVariable,
    AutogradFunctionVariable,
    ComptimeVariable,
    DebuggingVariable,
    GetAttrVariable,
    GetSetDescriptorVariable,
    InspectSignatureVariable,
    LambdaVariable,
    MethodWrapperVariable,
    NumpyVariable,
    PythonModuleVariable,
    SavedTensorBox,
    TypingVariable,
)
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorSubclassVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable
from .torch_function import build_torch_function_fn, TensorWithTFOverrideVariable
from .user_defined import (
    KeyedJaggedTensorVariable,
    UserDefinedClassVariable,
    UserDefinedObjectVariable,
)
log = logging.getLogger(__name__)

DimList = List


class _missing:
    pass


@dataclasses.dataclass
class GraphArg:
    source: Source
    # TODO: storing a SymInt here but not a FakeTensor is a pretty strange
    # thing to do.  Probably should have example (which stores an int) and
    # fake_example
    _example: Union[TensorWeakRef, torch.SymInt]
    is_unspecialized: bool
    fake_tensor: Optional[torch._subclasses.fake_tensor.FakeTensor]
    # UnspecializedPythonVariable often masquerades as a tensor.
    # We MUST NOT generate shape guard code
    # that actually tries to access tensor properties on these values.
    # is_tensor lets us tell if this graph arg actually is a tensor
    # or not.
    is_tensor: bool = True
    # Sometimes, the Tensor we pass to example is freshly allocated (smh).
    # Then we cannot only keep a weak reference to it.  This lets you
    # stash a strong reference too.
    example_strong_ref: Optional[torch.Tensor] = None

    @property
    def example(self):
        if isinstance(self._example, TensorWeakRef):
            r = self._example()
            assert r is not None
            return r
        else:
            return self._example

    def __post_init__(self):
        if isinstance(self._example, torch.Tensor):
            self._example = TensorWeakRef(self._example)
            assert is_fake(self.fake_tensor)

    def reconstruct(self, codegen):
        self.source.reconstruct(codegen)

    def erase(self):
        self._example = None
        self.example_strong_ref = None

    def __eq__(self, other):
        return self.source.name() == other.source.name()


class BackwardStateGraphArg(GraphArg):
    def __init__(self):
        super().__init__(
            source=None,
            _example=BackwardState(),
            is_unspecialized=False,
            fake_tensor=None,
            is_tensor=False,
        )

    def reconstruct(self, codegen):
        assert codegen.tx.output.backward_state_var
        codegen.load_import_from(BackwardState.__module__, "BackwardState")
        codegen.call_function(0, True)
        codegen.dup_top()
        codegen.store(codegen.tx.output.backward_state_var)


@dataclasses.dataclass
class FrameStateSizeEntry:
    scalar: Optional[int]
    size: Optional[List[int]]
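

# A minimal usage sketch of the builder below (illustrative only; `tx` and
# `some_python_value` are placeholders for objects that exist during tracing):
#
#   vt = VariableBuilder(tx, LocalSource("x"))(some_python_value)
#
# This dispatches through VariableBuilder._wrap() and returns a VariableTracker
# with the appropriate guards already installed against the given source.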
class VariableBuilder:
    """Wrap a python value in a VariableTracker() instance"""

    def __init__(
        self,
        tx,
        source: Source,
    ):
        assert (
            source is not None
        ), "Consider SourcelessBuilder for ephemeral objects, usually objects created locally."
        assert TracingContext.try_get() is not None, "Expected active TracingContext"
        super().__init__()
        self.tx = tx
        self.source = source
        self.name = source.name()

    def __call__(self, value):
        if value in self.tx.output.side_effects:
            side_effect_result = self.tx.output.side_effects[value]
            dup_guard = make_dupe_guard(self.source, side_effect_result.source)
            if dup_guard:
                self.install_guards(dup_guard)
            return side_effect_result
        vt = self._wrap(value)
        vt.source = self.source
        if self._can_lift_attrs_to_inputs(vt):
            vt = self.tx.output.side_effects.track_object_existing(value, vt)
        return vt

    def _can_lift_attrs_to_inputs(self, vt):
        if type(vt) in [
            TensorVariable,
            TensorWithTFOverrideVariable,
            UserDefinedObjectVariable,
            NumpyNdarrayVariable,
        ]:
            return True
        return False
    @staticmethod
    def _common_constants():
        return {
            # We zero-one specialize shapes, so specialize these constants
            # too
            0,
            1,
            # NB: There used to be more constants here, but honestly it was
            # pretty confusing.  Note we specialize floats by default, and
            # DON'T specialize ints by default.  This all only matters with
            # dynamic_shapes
        }
    def get_source(self):
        return self.source

    def install_guards(self, *guards):
        source = self.get_source()
        if (
            isinstance(source, ConstantSource)
            or source.guard_source() == GuardSource.CONSTANT
        ):
            return None
        install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
        return {}

    def set_source_and_track_mutable(self, value, var):
        assert isinstance(var, VariableTracker)
        var.source = self.source
        return self.tx.output.side_effects.track_mutable(value, var)
    @classmethod
    @functools.lru_cache(None)
    def _type_dispatch(cls):
        # NB: Careful not to close over self to avoid ref cycle from lru_cache
        entries = [
            (
                (
                    torch.Tensor,
                    torch.nn.Parameter,
                    torch._subclasses.FakeTensor,
                    torch._subclasses.functional_tensor.FunctionalTensor,
                ),
                cls.wrap_tensor,
            ),
            (
                (tuple, list, odict_values, collections.deque, torch.Size),
                cls.wrap_listlike,
            ),
            (tuple_iterator, cls.wrap_tuple_iterator),
            ((slice, range), cls.wrap_slice_range),
            (tuple(common_constant_types), cls.wrap_literal),
        ]

        if config.trace_numpy and np:
            entries.append((np.ndarray, cls.wrap_numpy_ndarray))

        result = {}
        for ts, fn in entries:
            for t in ts if isinstance(ts, tuple) else (ts,):
                assert t not in result
                result[t] = fn
        return result
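
    # Rough shape of the table built above (a sketch, not an exhaustive listing):
    #
    #   {
    #       torch.Tensor: VariableBuilder.wrap_tensor,
    #       torch.nn.Parameter: VariableBuilder.wrap_tensor,
    #       list: VariableBuilder.wrap_listlike,
    #       tuple: VariableBuilder.wrap_listlike,
    #       slice: VariableBuilder.wrap_slice_range,
    #       ...
    #   }
    #
    # Only exact type() matches hit this table; subclasses fall through to the
    # slower checks in _wrap() below.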
    @classmethod
    @functools.lru_cache(None)
    def _id_dispatch(cls):
        from ..comptime import comptime

        entries = [
            (
                inspect.signature,
                lambda self, value: LambdaVariable(
                    InspectSignatureVariable.create,
                    source=self.source,
                    **self.install_guards(GuardBuilder.CLOSURE_MATCH),
                ),
            ),
            (comptime, lambda self, value: ComptimeVariable()),
            (
                dataclasses.fields,
                lambda self, value: LambdaVariable(
                    _dataclasses_fields_lambda,
                    source=self.source,
                    **self.install_guards(GuardBuilder.FUNCTION_MATCH),
                ),
            ),
        ]

        result = {}
        for ts, fn in entries:
            for t in ts if isinstance(ts, (tuple, list)) else (ts,):
                assert t not in result
                result[id(t)] = fn
        return result
def _wrap(self, value): | |
# import here to avoid circular dependencies | |
from torch.utils._triton import has_triton | |
if has_triton(): | |
from triton.runtime.autotuner import Autotuner | |
from triton.runtime.jit import JITFunction | |
else: | |
class JITFunction: | |
pass | |
class Autotuner: | |
pass | |
# Handle exact type() match | |
type_dispatch = self._type_dispatch().get(type(value)) | |
if type_dispatch is not None: | |
return type_dispatch(self, value) | |
# Handle exact id() match | |
id_dispatch = self._id_dispatch().get(id(value)) | |
if id_dispatch is not None: | |
return id_dispatch(self, value) | |
# Note - There are some nested values where types mismatch! | |
# We want to get those out and wrap those. | |
value = inspect.getattr_static(value, "_torchdynamo_inline", value) | |
# Everything else (NB: order matters!) | |
if is_traceable_wrapper_subclass(value) or istype( | |
value, config.traceable_tensor_subclasses | |
): | |
return self.wrap_tensor(value) | |
elif is_namedtuple(value): | |
return self.wrap_listlike(value) | |
elif value is torch.utils._pytree.SUPPORTED_NODES: | |
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509) | |
# under the assumption that the values themselves don't change. | |
self.install_guards(GuardBuilder.DICT_VERSION) | |
result = { | |
ConstantVariable.create(k): UserDefinedObjectVariable( | |
v, | |
source=GetItemSource( | |
self.get_source(), ConstDictKeySource(self.get_source(), i) | |
), | |
) | |
for i, (k, v) in enumerate(value.items()) | |
} | |
return ConstDictVariable(result, type(value)) | |
elif value is sys.modules: | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return PythonSysModulesVariable(source=self.source) | |
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)): | |
if not value and self.get_source().is_nn_module(): | |
# It is faster to guard on 'false' property than to guard | |
# on actual dict keys, but we can't do this fast guard in general because | |
# it omits a crucial type check that ensures the value is actually still a dict at runtime. | |
# Why is this OK for (specialized) nnmodules? We set up a setattr hook | |
# to check for module property mutations, which does a reasonable, | |
# but not completely secure job ensuring a property wasn't changed. | |
self.install_guards(GuardBuilder.BOOL_FALSE) | |
else: | |
self.install_guards(GuardBuilder.DICT_LENGTH) | |
            # Optimisation for the common case: all-constant keys (strings, ints, etc.)
all_const = all(ConstantVariable.is_literal(k) for k in value.keys()) | |
if all_const: | |
self.install_guards(GuardBuilder.DICT_CONST_KEYS) | |
# We need all the keys to be hashable. We do this within the | |
# _HashableTracker class in dicts.py | |
def build_key_value(i, k, v): | |
if all_const: | |
key = ConstantVariable.create(k) | |
source_key = k | |
else: | |
source_key = ConstDictKeySource(self.get_source(), i) | |
key = LazyVariableTracker.create(k, source_key) | |
source_value = GetItemSource(self.get_source(), source_key) | |
value = LazyVariableTracker.create(v, source_value) | |
return key, value | |
result = dict( | |
build_key_value(i, k, v) for i, (k, v) in enumerate(value.items()) | |
) | |
if istype(value, collections.defaultdict): | |
factory_source = AttrSource(self.source, "default_factory") | |
result = DefaultDictVariable( | |
result, | |
type(value), | |
default_factory=VariableBuilder(self.tx, factory_source)( | |
value.default_factory | |
), | |
source=self.source, | |
) | |
else: | |
result = ConstDictVariable(result, type(value), source=self.source) | |
return self.set_source_and_track_mutable(value, result) | |
elif isinstance(value, torch.nn.Module): | |
return self.wrap_module(value) | |
elif ConstantVariable.is_literal(value): # non-atomic literals | |
return self.wrap_literal(value) | |
elif istype(value, frozenset) and ( | |
ConstantVariable.is_literal(x) for x in value | |
): | |
# For frozenset, we can guard by object ID instead of value | |
# equality, this allows us to handle non-literal values | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return ConstantVariable.create(value=value, source=self.source) | |
elif isinstance(value, enum.Enum): | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return EnumVariable(value=value, source=self.source) | |
elif DebuggingVariable.is_reorderable_logging_function(value): | |
# Put this above builtin_callable so that print() can be handled | |
# along with other builtin debugging functions | |
self.install_guards(GuardBuilder.BUILTIN_MATCH) | |
return DebuggingVariable(value, source=self.source) | |
elif is_utils_checkpoint(value): | |
return build_checkpoint_variable(source=self.source) | |
elif isinstance(value, functools.partial): | |
func_src = AttrSource(self.get_source(), "func") | |
func_obj = VariableBuilder(self.tx, func_src)(value.func) | |
args = [] | |
args_source = AttrSource(self.get_source(), "args") | |
for i, arg in enumerate(value.args): | |
args.append( | |
VariableBuilder(self.tx, GetItemSource(args_source, i))(arg) | |
) | |
keywords = {} | |
keywords_source = AttrSource(self.get_source(), "keywords") | |
for k, v in value.keywords.items(): | |
if not ConstantVariable.is_literal(k): | |
unimplemented("functools.partial with non-literal keyword") | |
keywords[k] = VariableBuilder( | |
self.tx, GetItemSource(keywords_source, k) | |
)(v) | |
install_guard( | |
self.get_source().make_guard(GuardBuilder.TYPE_MATCH), | |
keywords_source.make_guard(GuardBuilder.DICT_KEYS), | |
args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), | |
) | |
return FunctoolsPartialVariable(func_obj, args, keywords) | |
elif is_typing(value): | |
# typing.List, typing.Mapping, etc. | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return TypingVariable( | |
value, | |
source=self.source, | |
) | |
elif np is not None and isinstance(value, np.generic): | |
# numpy array scalars: convert to 0D arrays | |
return self.wrap_numpy_ndarray(np.asarray(value)) | |
elif is_numpy(value): | |
assert np | |
self.install_guards( | |
GuardBuilder.FUNCTION_MATCH | |
if callable(value) | |
else GuardBuilder.TYPE_MATCH | |
) | |
return NumpyVariable(value, source=self.source) | |
# NB: These can't be put in type_dispatch, they have to run later | |
elif CollectiveFunctionRewriteVariable.can_rewrite(value): | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return CollectiveFunctionRewriteVariable.create( | |
self.tx, | |
value, | |
source=self.source, | |
) | |
elif istype(value, torch.autograd.function.FunctionMeta): | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return AutogradFunctionVariable( | |
value, | |
source=self.source, | |
) | |
elif isinstance(value, torch.autograd.function.FunctionCtx): | |
saved_tensors_source = AttrSource(self.source, "saved_tensors") | |
install_guard( | |
self.source.make_guard(GuardBuilder.TYPE_MATCH), | |
saved_tensors_source.make_guard(GuardBuilder.SEQUENCE_LENGTH), | |
) | |
saved_tensors = [ | |
VariableBuilder(self.tx, GetItemSource(saved_tensors_source, n))(v) | |
for n, v in enumerate(value.saved_tensors) | |
] | |
return self.tx.output.side_effects.track_object_existing( | |
value, | |
AutogradFunctionContextVariable( | |
value, | |
source=self.source, | |
saved_tensors=SavedTensorBox(saved_tensors), | |
), | |
) | |
elif ( | |
isinstance(value, types.MethodType) | |
and istype( | |
getattr(value, "__self__", None), torch.autograd.function.FunctionMeta | |
) | |
and getattr(value, "__name__", "") == "apply" | |
and value == getattr(value.__self__, "apply", None) | |
): | |
# handle aliased autograd function `apply` calls | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return GetAttrVariable( | |
AutogradFunctionVariable( | |
value.__self__, source=AttrSource(self.source, member="__self__") | |
), | |
"apply", | |
) | |
elif callable(value) and trace_rules.lookup_callable(value) is not None: | |
if is_callable_allowed(value): | |
self.tx.output.has_user_defined_allowed_in_graph = True | |
return trace_rules.lookup_callable(value).create_with_source( | |
value, source=self.source | |
) | |
elif np and isinstance(value, np.number): | |
return self.wrap_unspecialized_primitive(value) | |
elif DataClassVariable.is_matching_object(value): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
return DataClassVariable.wrap(self, value) | |
elif HFPretrainedConfigVariable.is_matching_object(value): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
return HFPretrainedConfigVariable(value) | |
elif isinstance(value, HigherOrderOperator): | |
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.NAME_MATCH) | |
return TorchHigherOrderOperatorVariable.make(value, source=self.source) | |
elif isinstance(value, torch.cuda.StreamContext): | |
self.install_guards(GuardBuilder.ID_MATCH) | |
stream_source = AttrSource(self.source, "stream") | |
stream_var = VariableBuilder(self.tx, stream_source)(value.stream) | |
return StreamContextVariable.create(self.tx, stream_var) | |
elif isinstance(value, _StreamBase): | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return StreamVariable( | |
None, | |
value, | |
value.device, | |
source=self.source, | |
) | |
elif isinstance(value, (torch._C._SDPAParams)): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
return SDPAParamsVariable.create(self.tx, value, self.source) | |
elif isinstance(value, _EventBase): | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return EventVariable( | |
None, | |
value, | |
source=self.source, | |
) | |
elif ( | |
isinstance(value, torch._C._TensorMeta) | |
and value in config.traceable_tensor_subclasses | |
): | |
return TensorSubclassVariable(value, source=self.source) | |
elif ( | |
istype(value, contextlib.nullcontext) | |
and inspect.getattr_static(value, "enter_result", None) is None | |
): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
return NullContextVariable(source=self.source) | |
elif KeyedJaggedTensorVariable.is_matching_object(value): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
result = KeyedJaggedTensorVariable(value, source=self.source) | |
# TODO: this doing it manually is bad | |
return self.tx.output.side_effects.track_object_existing(value, result) | |
elif isinstance(value, torch.optim.Optimizer): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
return OptimizerVariable(value, source=self.source) | |
elif ProcessGroupVariable.is_process_group(value): | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return ProcessGroupVariable(value, source=self.source) | |
elif DeviceMeshVariable.is_device_mesh(value): | |
# TODO: see if we need to add custom guard instead of a simple ID_MATCH | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return DeviceMeshVariable(value, source=self.source) | |
elif PlacementClassVariable.is_placement_type(value): | |
# TODO: see if we need to add custom guard instead of a simple ID_MATCH | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return PlacementClassVariable(value, source=self.source) | |
elif PlacementVariable.is_placement(value): | |
# TODO: see if we need to add custom guard instead of a simple ID_MATCH | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return PlacementVariable( | |
value, | |
source=self.source, | |
) | |
elif istype(value, type) and value in itertools.__dict__.values(): | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return ItertoolsVariable(value, source=self.source) | |
elif isinstance(value, torch.SymBool): | |
# Note: the idea here is to re-use the infra we've built for SymInt by simulating the | |
# user provided SymBool with a SymInt in dynamo. | |
# Concretely, | |
# 1. We create a SymInt in dynamo's shape_env, whose source is constructed as ConvertIntSource(self.source). | |
# so that guards on the SymInts can be effectively applied on the original SymBool in user program. | |
# 2. We create a SymBool based on the SymInt in dynamo's ShapeEnv. Because the original user program | |
# depends on the value being a SymBool. This allows dynamo to interpret the user's program correctly. | |
value_hint = value.node.require_hint() | |
new_source = ConvertIntSource(self.source) | |
new_symint = self.tx.output.shape_env.create_unspecified_symint_and_symbol( | |
int(value_hint), | |
new_source, | |
dynamic_dim=DimDynamic.DYNAMIC, | |
) | |
sym_node_proxy = self.tx.output.root_tracer.create_graph_input( | |
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), | |
type(new_symint), | |
source=new_source, | |
) | |
sym_node_proxy.node.meta["grapharg"] = GraphArg( | |
new_source, | |
new_symint, | |
False, | |
None, | |
is_tensor=False, | |
example_strong_ref=new_symint, | |
) | |
self.tx.output.bound_symbols.add(new_symint.node.expr) | |
self.tx.output.tracked_fakes.append( | |
TrackedFake(new_symint, new_source, None) | |
) | |
return SymNodeVariable( | |
sym_node_proxy, | |
new_symint == 1, | |
) | |
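        # Sketch of the SymBool-as-SymInt simulation above (illustrative, with a
        # hypothetical user value `b: torch.SymBool` traced from source `s`):
        #
        #   i = shape_env.create_unspecified_symint_and_symbol(
        #       int(b.node.require_hint()), ConvertIntSource(s), dynamic_dim=DimDynamic.DYNAMIC
        #   )
        #   # dynamo then models `b` as the SymBool expression `i == 1`
        #
        # so any guard emitted on `i` is re-expressed, via ConvertIntSource, in
        # terms of the original SymBool in the user's frame.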
elif isinstance(value, (JITFunction, Autotuner)): | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return TritonKernelVariable( | |
value, | |
None, # No kernel idx provided | |
None, # No grid provided | |
source=self.source, | |
) | |
elif isinstance(value, torch.amp.autocast_mode.autocast): | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return AutocastModeVariable( | |
target_values=[ | |
value.device, | |
value.fast_dtype, | |
value._enabled, | |
value._cache_enabled, | |
], | |
source=self.source, | |
) | |
elif TorchCtxManagerClassVariable.is_matching_cls(value): | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return TorchCtxManagerClassVariable(value, source=self.source) | |
elif is_function_or_wrapper(value): | |
value, attr_name = unwrap_with_attr_name_if_wrapper(value) | |
# For these wrappers, Dynamo points to the wrapped function, | |
# so source needs to be updated as well. | |
if attr_name is not None: | |
self.source = AttrSource(self.source, attr_name) | |
return trace_rules.lookup(value).create_with_source( | |
value, source=self.source | |
) | |
# Don't use istype, since some python modules are not subclasses of types.ModuleType directly. | |
# E.g, type(torch.ops) -> <class 'torch._ops._Ops'>, | |
# type(torch.backends.cudnn) -> <class 'torch.backends.cudnn.CudnnModule'> | |
elif isinstance(value, (types.ModuleType, replay_record.DummyModule)): | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return PythonModuleVariable( | |
value, | |
source=self.source, | |
) | |
elif isinstance(value, types.MethodType) and isinstance( | |
value.__self__, (torch.nn.Module, torch.utils._pytree.TreeSpec) | |
): | |
# don't let MethodTypes fall through to UserDefinedObject, | |
# which doesn't support 'CALL_FUNCTION' | |
# TODO(whc): Why do we limit this to methods on NNModules? | |
# I don't have a good reason for this, but it preserves the existing behavior | |
# for MBartForConditionalGeneration, which generates many graph breaks and OOMs otherwise. | |
# I suspect we probably want to relax this check and dig deeper there. | |
# In order to construct a MethodVariable in Dynamo, we start with an actual method obj from python, | |
# but need to separately wrap its underlying `__func__` and its `self` argument. We wrap `self` here | |
# and then `__func__` gets wrapped inside UserMethodVariable. | |
self_obj = VariableBuilder( | |
self.tx, source=AttrSource(self.source, "__self__") | |
)(value.__self__) | |
assert self_obj and isinstance( | |
self_obj, VariableTracker | |
), "Failed to produce a valid self obj" | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return UserMethodVariable( | |
value.__func__, | |
self_obj, | |
source=self.source, | |
) | |
elif isinstance(value, types.GetSetDescriptorType): | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return GetSetDescriptorVariable(value) | |
elif isinstance(value, types.MethodWrapperType): | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return MethodWrapperVariable(value) | |
elif issubclass(type(value), type): | |
if value in (torch.utils.hooks.BackwardHook, torch.nn.Parameter): | |
# TODO(jansel): combine this case with the one above | |
return trace_rules.lookup(value).create_with_source( | |
value, source=self.source | |
) | |
if value is torch.autograd._unsafe_preserve_version_counter: | |
self.install_guards(GuardBuilder.FUNCTION_MATCH) | |
return PreserveVersionContextVariable.constructor(self.tx) | |
# This is a userdefined class, so install an ID_MATCH even if its a | |
# global variable. | |
self.install_guards(GuardBuilder.ID_MATCH) | |
return UserDefinedClassVariable( | |
value, | |
source=self.source, | |
) | |
elif RestrictedListSubclassVariable.is_matching_cls(type(value)): | |
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) | |
return self.set_source_and_track_mutable( | |
value, | |
RestrictedListSubclassVariable( | |
[ | |
LazyVariableTracker.create( | |
value=value[i], source=GetItemSource(self.source, i) | |
) | |
for i in range(len(value)) | |
], | |
user_cls=type(value), | |
user_cls_source=AttrSource(self.source, "__class__"), | |
), | |
) | |
else: | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
result = UserDefinedObjectVariable(value, source=self.source) | |
if not SideEffects.cls_supports_mutation_side_effects(type(value)): | |
# don't allow STORE_ATTR mutation with custom __setattr__ | |
return result | |
return self.tx.output.side_effects.track_object_existing(value, result) | |
def wrap_listlike(self, value: Union[tuple, list, odict_values, NamedTuple]): | |
if config.specialize_int and type(value) is torch.Size: | |
self.install_guards(GuardBuilder.CONSTANT_MATCH) | |
return ConstantVariable.create(value=value) | |
# One can index a tensor with a list/tuple. Therefore, we need to | |
# have a stricter match. | |
self.install_guards(GuardBuilder.SEQUENCE_LENGTH) | |
for item in value: | |
if item is value: | |
unimplemented("list elements are pointing to the list itself") | |
output = [ | |
LazyVariableTracker.create(item, source=GetItemSource(self.get_source(), i)) | |
for i, item in enumerate(value) | |
] | |
result = BaseListVariable.cls_for_instance(value)( | |
output, mutable_local=MutableLocal() | |
) | |
if istype(value, list): | |
return self.set_source_and_track_mutable(value, result) | |
return result | |
def wrap_tuple_iterator(self, value: tuple_iterator): | |
self.install_guards(GuardBuilder.TUPLE_ITERATOR_LEN) | |
output = [ | |
VariableBuilder(self.tx, TupleIteratorGetItemSource(self.get_source(), i))( | |
tuple_iterator_getitem(value, i) | |
) | |
for i in range(tuple_iterator_len(value)) | |
] | |
result = TupleIteratorVariable( | |
output, mutable_local=MutableLocal(), source=self.source | |
) | |
return self.set_source_and_track_mutable(value, result) | |
def wrap_slice_range(self, value: Union[slice, range]): | |
items = [ | |
VariableBuilder(self.tx, AttrSource(self.get_source(), k))( | |
getattr(value, k) | |
) | |
for k in ("start", "stop", "step") | |
] | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
if isinstance(value, slice): | |
return SliceVariable(items, source=self.source) | |
else: | |
return RangeVariable(items, source=self.source) | |
def wrap_module(self, value: torch.nn.Module): | |
from ..eval_frame import OptimizedModule | |
if istype(value, OptimizedModule): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
self.source = AttrSource(self.source, "_orig_mod") | |
return self.wrap_module(value._orig_mod) | |
if ( | |
isinstance(value, (torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM)) | |
and not config.allow_rnn | |
): | |
unimplemented("TorchDynamo purposely graph breaks on RNN, GRU, LSTMs") | |
if mutation_guard.is_dynamic_nn_module(value): | |
# created dynamically, don't specialize on it | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
result = UnspecializedNNModuleVariable(value, source=self.source) | |
if not SideEffects.cls_supports_mutation_side_effects(type(value)): | |
# don't allow STORE_ATTR mutation with custom __setattr__ | |
return result | |
return self.tx.output.side_effects.track_object_existing(value, result) | |
elif issubclass( | |
value.__class__, torch.nn.parallel.distributed.DistributedDataParallel | |
): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
return UnspecializedNNModuleVariable(value) | |
elif getattr(value, "_is_fsdp_managed_module", False): | |
# See note [Dynamo treats FSDP wrapped modules as UnspecializedNNModule] | |
# in fully_sharded_data_parallel.py for more information | |
# we can't do this assert inside FSDP constructor, | |
# since we don't know yet whether dynamo will be used | |
assert getattr( | |
value, "_fsdp_use_orig_params", False | |
), "Dynamo only supports FSDP with use_orig_params=True" | |
# Note on FSDP guarding | |
            # 1. We expect FSDP wrapping mutates an nn module irreversibly (no way to de-wrap).
# 2. Eager FSDP already assumes (requires, but without enforcement) that users don't mutate their | |
# model parameters/structure after FSDP wrapping, because FSDP wouldn't notice or update its FlatParams. | |
# | |
# Due to (1), once we enter this path we expect not to go back nor have to guard on type | |
# or _is_fsdp_managed_module. | |
# | |
# TODO(whc) We could add a guard on the opposite case, where a user compiled/ran | |
# pre-FSDP-wrapped model, then wrapped, to ensure that we recompile with the FSDP handling. | |
# | |
# Due to (2), we skip guards on inner contents of fsdp_managed modules, by using FSDPNNModuleSource as the | |
# guard source. This behavior is gated on config.skip_fsdp_guards. | |
# | |
# ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps | |
# them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager) | |
self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH) | |
return FSDPManagedNNModuleVariable(value, source=self.get_source()) | |
else: | |
return self.tx.output.register_attr_or_module( | |
value, | |
self.name, | |
source=self.get_source(), | |
# Guards are added inside register_attr_or_module | |
) | |
def wrap_literal(self, value): | |
unspec = not config.specialize_int | |
if unspec and type(value) is int: | |
# unspecializing int by default, but still | |
# specialize for the following conditions | |
if not TracingContext.get().force_unspec_int_unbacked_size_like and ( | |
value in self._common_constants() | |
# Assume integers from global variables want to be specialized | |
or not self.source.guard_source().is_local() | |
# Assume that integers that came from NN modules want to be | |
# specialized (as we don't expect users to be changing the | |
# NN modules on the fly) | |
or self.source.guard_source().is_nn_module() | |
or is_from_defaults(self.source) | |
): | |
self.install_guards(GuardBuilder.CONSTANT_MATCH) | |
return ConstantVariable.create(value=value, source=self.source) | |
else: | |
return self.wrap_unspecialized_primitive(value) | |
else: | |
self.install_guards(GuardBuilder.CONSTANT_MATCH) | |
return ConstantVariable.create(value=value) | |
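
    # Illustrative consequence of the logic above (assuming the default
    # config.specialize_int=False and a plain local int input):
    #
    #   f(x, n) with n = 7    -> n goes through wrap_unspecialized_primitive()
    #                            and becomes a symbolic (dynamic) graph input
    #   f(x, n) with n = 0/1  -> n hits _common_constants() and is specialized
    #                            to a ConstantVariable with a CONSTANT_MATCH guard
    #
    # Ints reached from globals, NN module attributes, or function defaults are
    # specialized as well.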
def assert_not_wrapped_by_this_graph(self, value: torch.Tensor): | |
if is_fake(value) and maybe_get_fake_mode(value) is self.tx.fake_mode: | |
            raise InternalTorchDynamoError(
                "Cannot wrap a Tensor that has already been "
                "wrapped by this instance of Dynamo"
            )
def wrap_tensor(self, value: torch.Tensor): | |
source = self.get_source() | |
# We cannot already be tracking the tensor, which implies | |
# it would have already been wrapped | |
assert value not in self.tx.output.side_effects | |
if ( | |
source.guard_source().is_nn_module() | |
or get_static_address_type(value) is not None | |
) and not source.guard_source().is_fsdp_module(): | |
self.assert_not_wrapped_by_this_graph(value) | |
return self.tx.output.register_attr_or_module( | |
value, self.name, source=source | |
) | |
if is_constant_source(source): | |
self.assert_not_wrapped_by_this_graph(value) | |
return self.tx.output.register_attr_or_module( | |
value, | |
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), | |
source=source, | |
# Guards are added inside register_attr_or_module | |
) | |
if type(value) in config.traceable_tensor_subclasses: | |
# Ordinarily, we would fakeify a tensor so that it can get dynamic | |
# shapes and be computed on without triggering actual operations. | |
            # However, how can we fakeify a tensor subclass?  Neither ordinary
            # inheritance nor multiple inheritance will work.
# | |
# Instead, our plan is to *manually simulate* the tensor subclass | |
# inheriting from a fake tensor with dynamo. This means our | |
# data representation for a tensor subclass will be a fake tensor | |
# + tensor subclass type + any extra data the subclass may have | |
# been storing on the tensor. Because all Python accesses are | |
# mediated through TensorWithTFOverrideVariable, we can ensure | |
# that we dispatch differently, e.g., according to | |
# __torch_function__ | |
# | |
# To simplify things for now, the __dict__ tracking bits haven't | |
# been implemented yet, but they can be added into this design at | |
# a later point in time. | |
subclass_type = type(value) | |
else: | |
assert type(value) in ( | |
torch.Tensor, | |
torch.nn.Parameter, | |
torch._subclasses.fake_tensor.FakeTensor, | |
torch._subclasses.functional_tensor.FunctionalTensor, | |
) or is_traceable_wrapper_subclass(value), type(value) | |
subclass_type = None | |
# NB: this just says we accessed a tensor from the same source again | |
# (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). | |
# This is distinct from two distinct sources mapping to the same | |
# Tensor (per id())! No guard is necessary here. See below for the | |
# other case. | |
is_duplicate_tensor = source in self.tx.output.input_source_to_var | |
if is_duplicate_tensor: | |
return self.tx.output.input_source_to_var[source] | |
# By this point, we should have deduplicated all tensors | |
self.assert_not_wrapped_by_this_graph(value) | |
# tx.output has multiple tracers if we're introspecting HigherOrderOperator. | |
# When we've discovered an untracked tensor, then we actually need | |
# to get Dynamo to track the tensor (which is what this function does) | |
# and put it as a graph input on the root tracer. Later on, | |
# if the input is actually used in the body of the HigherOrderOperator, | |
# then the relevant SubgraphTracer will lift it to being an input of | |
# the subgraph. | |
# See NOTE [HigherOrderOperator tracing design] for more details. | |
tensor_proxy = self.tx.output.root_tracer.create_graph_input( | |
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source | |
) | |
options = {} | |
if type(value) in config.traceable_tensor_subclasses: | |
options["torch_function_fn"] = build_torch_function_fn( | |
self.tx, value, self.source | |
) | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
if ( | |
isinstance(value, torch.Tensor) | |
and value.is_nested | |
and not isinstance(value, torch.nested._internal.nested_tensor.NestedTensor) | |
): | |
unimplemented("torch.compile does not support strided NestedTensor") | |
if is_sparse_any(value): | |
unimplemented( | |
f"torch.compile does not support sparse Tensor with {value.layout} layout" | |
) | |
tensor_variable = wrap_fx_proxy( | |
tx=self.tx, | |
proxy=tensor_proxy, | |
example_value=value, | |
subclass_type=subclass_type, | |
source=source, | |
**options, | |
) | |
self.install_guards( | |
functools.partial( | |
GuardBuilder.TENSOR_MATCH, | |
value=value | |
if isinstance(source, NumpyTensorSource) | |
else TensorWeakRef(value), | |
) | |
) | |
# We install TYPE_MATCH guards for traceable wrapper subclass object, | |
# and recursively install corresponding guard for each inner attribute. | |
if is_traceable_wrapper_subclass(value): | |
self.install_guards(GuardBuilder.TYPE_MATCH) | |
attrs, _ = value.__tensor_flatten__() | |
for attr in attrs: | |
inner_value = getattr(value, attr) | |
inner_source = AttrSource(self.source, attr) | |
VariableBuilder(self.tx, inner_source)(inner_value).recursive_realize() | |
self.tx.output.input_source_to_var[source] = tensor_variable | |
assert "tensor_dict" not in tensor_proxy.node.meta | |
tensor_proxy.node.meta["tensor_dict"] = value.__dict__.copy() | |
# Note: this information is conveyed via subclass_type now | |
fake_tensor_value = tensor_variable.proxy.node.meta["example_value"] | |
if maybe_get_fake_mode(fake_tensor_value) is not self.tx.fake_mode: | |
raise InternalTorchDynamoError("Wrapped Tensor must be this graph's fake") | |
grapharg = GraphArg(source, value, False, fake_tensor_value) | |
tensor_proxy.node.meta["grapharg"] = grapharg | |
self.tx.output.add_symbol_bindings(grapharg) | |
return tensor_variable | |
def wrap_numpy_ndarray(self, value): | |
assert np is not None | |
assert isinstance(value, np.ndarray) | |
source = NumpyTensorSource(self.get_source()) | |
from torch._numpy import _util | |
readonly = not value.flags.writeable | |
if readonly: | |
try: | |
value.flags.writeable = True | |
except ValueError: | |
                # One cannot easily make nditer elements writable,
                # but a warning here is not the end of the world
assert isinstance(value.base, np.nditer) | |
pass | |
try: | |
tensor_value = _util._try_convert_to_tensor(value) | |
if readonly: | |
from torch._prims_common import clone_preserve_strides | |
tensor_value = clone_preserve_strides(tensor_value) | |
except NotImplementedError as e: | |
# failed to convert to tensor, graph break | |
unimplemented(str(e)) | |
# We do this because we want the full behavior of guarding the numpy ndarray as if it were | |
# a tensor. It's a little annoying to make a VT to throw out, but there's so many side effects here | |
# that there's not another great way to do this atm. | |
# This creates the right graphargs, as well as registration for guards in tensor names and shape env. | |
VariableBuilder(self.tx, source)(tensor_value).recursive_realize() | |
proxy = self.tx.output.root_tracer.create_graph_input( | |
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(tensor_value), source=source | |
) | |
options = {"source": source} | |
numpy_ndarray_variable = wrap_fx_proxy_cls( | |
target_cls=NumpyNdarrayVariable, | |
tx=self.tx, | |
proxy=proxy, | |
example_value=tensor_value, | |
**options, | |
) | |
self.tx.output.input_source_to_var[source] = numpy_ndarray_variable | |
example_value = numpy_ndarray_variable.proxy.node.meta["example_value"] | |
# is_unspecialized should be true because we are wrapping a np.ndarray as argument input, and it needs to be | |
# converted to a tensor. | |
grapharg = GraphArg( | |
source, | |
tensor_value, | |
is_unspecialized=True, | |
fake_tensor=example_value, | |
is_tensor=True, | |
example_strong_ref=tensor_value, | |
) | |
proxy.node.meta["grapharg"] = grapharg | |
return numpy_ndarray_variable | |
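
    # Rough data flow of wrap_numpy_ndarray above (a sketch):
    #
    #   np.ndarray arg --torch._numpy._util._try_convert_to_tensor--> torch.Tensor
    #                  --> graph input proxy (source = NumpyTensorSource(original source))
    #                  --> NumpyNdarrayVariable wrapping that proxy
    #
    # The GraphArg is marked is_unspecialized=True because the runtime argument
    # is still an ndarray and must be converted to a tensor before the graph runs.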
def wrap_unspecialized_primitive(self, value): | |
if self.name in self.tx.output.unspec_variable_map: | |
return self.tx.output.unspec_variable_map[self.name] | |
else: | |
shape_env = self.tx.output.shape_env | |
if TracingContext.get().force_unspec_int_unbacked_size_like and isinstance( | |
value, int | |
): | |
wrapped_value = shape_env.create_unbacked_symint() | |
_constrain_range_for_size(wrapped_value) | |
self.tx.output.bound_symbols.add(wrapped_value.node.expr) | |
self.tx.output.tracked_fakes.append( | |
TrackedFake(wrapped_value, self.source, None) | |
) | |
# NB: We do not do float. For motivation, see | |
# https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit | |
# but the general idea is that we generate kernels that can | |
# take unspecialized floats and use them in sizevar computation | |
elif ( | |
isinstance(value, int) | |
and not is_constant_source(self.get_source()) | |
and not isinstance(self.get_source(), RandomValueSource) | |
): | |
if torch._dynamo.config.specialize_int: | |
# If specialize_int is False, also return | |
# a constant (but this should have been handled | |
# in the caller, TBH) | |
self.install_guards(GuardBuilder.CONSTANT_MATCH) | |
return ConstantVariable.create(value=value, source=self.source) | |
name = self.source.name() | |
if name not in self.tx.output.frame_state: | |
# Note - this essentially means that if this name gets reused as a tensor, | |
# it will start fully dynamic. That should always be a safe option, and not awfully inefficient. | |
                    # Alternatively, if we want to improve perf here, we can add a third state of unset, but I am not
# sure that is necessary for now. | |
frame_state_entry = FrameStateSizeEntry(scalar=value, size=None) | |
else: | |
frame_state_entry = self.tx.output.frame_state[name] | |
if frame_state_entry.scalar != value: | |
log.debug( | |
"automatic dynamic int %s val %s != %s", | |
name, | |
value, | |
frame_state_entry.scalar, | |
) | |
frame_state_entry.scalar = None | |
self.tx.output.frame_state[name] = frame_state_entry | |
# TODO: This should be dynamic, as we in general do not | |
# know if bare integers are actually going to be sizevars | |
# and it is inappropriate to eagerly duck size them with | |
# real sizevars | |
if ( | |
config.automatic_dynamic_shapes and frame_state_entry.scalar is None | |
) or not config.assume_static_by_default: | |
dynamic_dim = DimDynamic.DYNAMIC | |
else: # assume_static_by_default | |
# TODO: dynamic_dim = DimDynamic.STATIC should work but | |
# for some reason it doesn't | |
self.install_guards(GuardBuilder.CONSTANT_MATCH) | |
return ConstantVariable.create(value=value) | |
wrapped_value = shape_env.create_unspecified_symint_and_symbol( | |
value, | |
source=self.source, | |
dynamic_dim=dynamic_dim, | |
) | |
self.tx.output.bound_symbols.add(wrapped_value.node.expr) | |
self.tx.output.tracked_fakes.append( | |
TrackedFake(wrapped_value, self.source, None) | |
) | |
else: | |
wrapped_value = torch.tensor(value) | |
if not isinstance(self.get_source(), RandomValueSource): | |
install_guard(self.get_source().make_guard(GuardBuilder.TYPE_MATCH)) | |
options = {"source": self.get_source()} | |
if isinstance(wrapped_value, torch.Tensor): | |
options.update({"raw_value": value}) | |
proxy = self.tx.output.root_tracer.create_graph_input( | |
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), | |
type(wrapped_value), | |
source=self.get_source(), | |
) | |
unspec_var = wrap_fx_proxy_cls( | |
UnspecializedPythonVariable, | |
tx=self.tx, | |
proxy=proxy, | |
example_value=wrapped_value, | |
**options, | |
) | |
self.tx.output.unspec_variable_map[self.name] = unspec_var | |
if not is_constant_source(self.get_source()): | |
if self.tx.export and not isinstance(self.get_source(), LocalSource): | |
raise AssertionError( | |
"Dynamo attempts to add additional input during export: value={}, source={}".format( | |
wrapped_value, self.get_source() | |
) | |
) | |
fake_tensor_value = None | |
if isinstance(unspec_var, ConstantVariable): | |
example_value = unspec_var.value | |
else: | |
example_value = unspec_var.proxy.node.meta["example_value"] | |
if is_fake(example_value): | |
fake_tensor_value = example_value | |
                assert fake_tensor_value.fake_mode is self.tx.fake_mode, (
                    f"fake mode ({fake_tensor_value.fake_mode}) from fake tensor metadata doesn't match mode "
                    f"({self.tx.fake_mode}) from InstructionTranslator"
                )
proxy.node.meta["grapharg"] = GraphArg( | |
self.get_source(), | |
wrapped_value, | |
isinstance(wrapped_value, torch.Tensor), | |
fake_tensor_value, | |
is_tensor=False, | |
example_strong_ref=wrapped_value, | |
) | |
return unspec_var | |
def _dataclasses_fields_lambda(obj): | |
if isinstance(obj, UserDefinedObjectVariable): | |
value = obj.value | |
elif isinstance(obj, DataClassVariable): | |
value = obj.user_cls | |
else: | |
unimplemented(f"Dataclass fields handling fails for type {obj}") | |
items = [] | |
for field in dataclasses.fields(value): | |
source = None | |
if obj.source: | |
source = GetItemSource( | |
AttrSource(obj.source, "__dataclass_fields__"), field.name | |
) | |
items.append(UserDefinedObjectVariable(field, source=source)) | |
return TupleVariable(items) | |
def wrap_fx_proxy(tx, proxy, example_value=None, subclass_type=None, **options): | |
kwargs = { | |
"tx": tx, | |
"proxy": proxy, | |
"example_value": example_value, | |
"subclass_type": subclass_type, | |
**options, | |
} | |
if subclass_type is None: | |
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) | |
else: | |
result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs) | |
result.install_global(tx) | |
return result | |
# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable | |
# Should be compositional instead | |
# | |
# This is a horribly complicated function that does too many things.  To
# explain what it does, let's first talk about the classic usage of
# wrap_fx_proxy for a TensorVariable.  There are two primary modes of use:
# | |
# 1. Wrapping a pre-existing Tensor. In this case, example_value is set | |
# to the pre-existing Tensor. (Note that this example_value will NOT | |
# be the final example_value we put into node.meta['example_value'], | |
# instead it is converted into a fake tensor using | |
# wrap_to_fake_tensor_and_record and registered as a graph input.) | |
# | |
# 2. "Wrapping" the result of some Tensor operation Dynamo traced over. In | |
# this case, example_value is None (and we are going to figure it out | |
# ourselves using FakeTensors, via get_fake_value, which will run | |
# the operation represented by the (singular!) FX node referenced by | |
# the passed in proxy.) | |
# | |
# The expectation is you end up with a Tensor output, and everything is | |
# straightforwardly traced into the graph. | |
# | |
# In all cases, the returned `TensorVariable` subclass will have an `example_value` | |
# and that `example_value` must be a `FakeTensor` produced by the currently running | |
# instance of Dynamo. | |
# | |
# Upon closer inspection, you may notice that there are a slurry of non-Tensor | |
# output cases. What gives? Well, we sometimes trace operations into the | |
# graph that don't involve tensors. | |
# | |
# * Some operators return tuples; we need to recursively handle their | |
# contents | |
# | |
# * Some operators have side effects that will affect subsequent AOTAutograd | |
# tracing but don't otherwise return anything. | |
# | |
# * Some operators return symbolic ints/floats/bools which can go in the | |
# graph and be traced (but only if they're actually symbolic! If they're | |
# static you don't want to put them in the graph, which means you | |
# shouldn't call this function.) | |
# | |
# The common theme is that you only use this function WHEN YOU ARE TRACING | |
# SOMETHING INTO THE GRAPH. This is sort of obvious, because you can't call | |
# this function without a proxy. | |
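#
# A condensed sketch of the two modes described above (illustrative only;
# `placeholder_proxy`, `call_function_proxy`, `real_tensor`, and `src` are
# placeholders for objects Dynamo creates during tracing):
#
#   # (1) wrapping a pre-existing tensor that is becoming a graph input:
#   #     example_value is the real tensor; it is fakeified via
#   #     wrap_to_fake_tensor_and_record before being stored in node.meta.
#   vt = wrap_fx_proxy(tx, placeholder_proxy, example_value=real_tensor, source=src)
#
#   # (2) wrapping the result of a traced op: example_value is None and is
#   #     computed by running the single FX node under the fake tensor mode.
#   vt = wrap_fx_proxy(tx, call_function_proxy)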
def wrap_fx_proxy_cls( | |
target_cls, tx, proxy, example_value=None, subclass_type=None, **options | |
): | |
from ..symbolic_convert import InstructionTranslatorBase | |
assert isinstance(tx, InstructionTranslatorBase) | |
if "guards" in options and options["guards"] is not None: | |
tx.output.guards.update(options["guards"]) | |
assert "example_value" not in proxy.node.meta, f"{proxy.node.meta['example_value']}" | |
initial_example_value = example_value | |
def _clone_input(value): | |
if isinstance(value, torch.Tensor): | |
# tensor subclasses will not be converted to FakeTensors and need to be cloned | |
if not ( | |
isinstance(value, FakeTensor) | |
or ( | |
# Is functional tensor fakeified by this instance of Dynamo | |
torch._is_functional_tensor(value) | |
and maybe_get_fake_mode(value) is tx.fake_mode | |
) | |
or value.is_nested | |
): | |
# NB: ensure strides are preserved | |
value = clone_input(value) | |
return value | |
with preserve_rng_state(): | |
if example_value is None: | |
# only allow_non_graph_fake in this instance because we handle the non-fake | |
# cases properly below. | |
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) | |
# Handle recursive calls here | |
elif maybe_get_fake_mode(example_value) is tx.fake_mode: | |
pass | |
elif isinstance(example_value, torch.Tensor): | |
if tx.export: | |
# The legacy behavior for real value cache with subclasses was | |
# to perform a clone WITHOUT preserving the subclass. It's | |
# not entirely clear this is what you actually want though. | |
with torch._C.DisableTorchFunctionSubclass(): | |
proxy.tracer.real_value_cache[proxy.node] = _clone_input( | |
example_value | |
) | |
# NB: If we're ignoring subclass, then the expectation is you will | |
# take the returned TensorVariable and wrap it into a more | |
# accurate TensorVariable that is able to track subclass-ness; | |
# otherwise this is wrong! | |
kwargs = { | |
"is_tensor": target_cls | |
in (TensorVariable, TensorWithTFOverrideVariable), | |
} | |
assert "source" in options and options["source"] is not None | |
kwargs["source"] = options["source"] | |
example_value = wrap_to_fake_tensor_and_record( | |
example_value, tx=tx, **kwargs | |
) | |
if isinstance(example_value, torch.Tensor) and ( | |
maybe_get_fake_mode(example_value) is not tx.fake_mode | |
): | |
raise InternalTorchDynamoError( | |
"`example_value` needs to be a `FakeTensor`" | |
f"wrapped by this instance of Dynamo. Found: {example_value}" | |
) | |
if isinstance(example_value, torch.Tensor): | |
is_parameter = isinstance(example_value, torch.nn.Parameter) | |
# NB: In most (all?) cases, this does not actually do a clone. | |
# (WARNING: this means that if we mutate metadata on the fake | |
# tensor, the stored example value will update too!) | |
example_value = _clone_input(example_value) | |
proxy.node.meta["example_value"] = example_value | |
specialized_props = target_cls.specialize(example_value) | |
# TODO: not sure about this fake mode test | |
if ( | |
isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) | |
and example_value.fake_mode is tx.fake_mode | |
): | |
tensor_type = subclass_type if subclass_type else torch.Tensor | |
specialized_props["class_type"] = ( | |
torch.nn.Parameter if is_parameter else tensor_type | |
) | |
options.update(specialized_props) | |
return target_cls(proxy, **options) | |
elif ( | |
hasattr(proxy.node.target, "__name__") | |
and proxy.node.target.__name__ == "set_state" | |
and isinstance(proxy.node.target.__self__, torch._C.Generator) | |
or proxy.node.target == torch.random.set_rng_state | |
): | |
return TorchInGraphFunctionVariable(proxy.node.target) | |
elif ( | |
proxy.node.target == torch._C._DisableFuncTorch | |
or proxy.node.target == torch.cuda._is_in_bad_fork | |
): | |
return UserDefinedObjectVariable(example_value) | |
elif istype(example_value, torch.Size) and all( | |
isinstance(x, int) for x in example_value | |
): | |
sizes = [ConstantVariable.create(x) for x in example_value] | |
return SizeVariable(sizes, **options) | |
elif isinstance(example_value, (tuple, list)): | |
proxy.node.meta["example_value"] = example_value | |
unpacked = [] | |
for i, val in enumerate(example_value): | |
if val is None: | |
# nn.MultiheadAttention() can return None, see issue #175 | |
unpacked.append( | |
ConstantVariable.create(None, **options), | |
) | |
else: | |
unpacked.append( | |
wrap_fx_proxy_cls( | |
target_cls, | |
tx, | |
proxy.tracer.create_proxy( | |
"call_function", operator.getitem, (proxy, i), {} | |
), | |
example_value=val, | |
**options, | |
) | |
) | |
if isinstance(example_value, torch.Size): | |
# NB: Keep the old proxy around. See SizeVariable for an | |
# explanation why | |
return SizeVariable(unpacked, proxy, **options) | |
elif istype(example_value, tuple): | |
return TupleVariable(unpacked, **options) | |
elif istype(example_value, (list, immutable_list)): | |
return ListVariable(unpacked, mutable_local=MutableLocal(), **options) | |
else: | |
assert example_value.__class__.__module__ == "torch.return_types" or hasattr( | |
example_value, "_fields" | |
), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}" | |
return NamedTupleVariable(unpacked, example_value.__class__, **options) | |
elif example_value is None or proxy.node.target is torch.manual_seed: | |
return ConstantVariable.create(None, **options) | |
elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): | |
proxy.node.meta["example_value"] = example_value | |
return SymNodeVariable(proxy, example_value, **options) | |
elif ( | |
inspect.isclass(proxy.node.target) | |
and issubclass(proxy.node.target, _StreamBase) | |
) or proxy.node.target in [ | |
device_interface.current_stream | |
for _, device_interface in get_registered_device_interfaces() | |
]: | |
proxy.node.meta["example_value"] = example_value | |
return StreamVariable(proxy, example_value, example_value.device, **options) | |
elif ( | |
inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, _EventBase) | |
) or proxy.node.target in [ | |
device_interface.Event | |
for _, device_interface in get_registered_device_interfaces() | |
]: | |
proxy.node.meta["example_value"] = example_value | |
return EventVariable(proxy, example_value, **options) | |
elif proxy.node.target == "query" and proxy.node.op == "call_method": | |
proxy.node.meta["example_value"] = example_value | |
return ConstantVariable(example_value, **options) | |
elif ( | |
example_value is not None | |
and isinstance(example_value, _EventBase) | |
and proxy.node.target == "record_event" | |
and proxy.node.op == "call_method" | |
): | |
proxy.node.meta["example_value"] = example_value | |
return EventVariable(proxy, example_value, **options) | |
elif isinstance(example_value, int) and proxy.node.target in [ | |
torch.sym_int, | |
getattr, | |
operator.getitem, | |
torch._utils._element_size, | |
torch.seed, | |
operator.mod, | |
torch._C._functorch._vmap_increment_nesting, | |
torch._C._functorch._vmap_decrement_nesting, | |
torch._functorch.vmap._validate_and_get_batch_size, | |
torch._C._functorch._grad_increment_nesting, | |
torch._C._functorch._grad_decrement_nesting, | |
# some mac builds are missing torch.distributed.get_rank() | |
getattr(torch.distributed, "get_rank", _missing), | |
getattr(torch.distributed, "get_world_size", _missing), | |
# This always wants to be in the graph, even if the constraint | |
# results in a constant int | |
torch._constrain_as_value, | |
torch._constrain_as_size, | |
]: | |
proxy.node.meta["example_value"] = example_value | |
return ConstantVariable.create(example_value, **options) | |
elif isinstance(example_value, torch.backends.cuda.SDPAParams): | |
from .sdpa import SDPAParamsVariable | |
proxy.node.meta["example_value"] = example_value | |
return SDPAParamsVariable(proxy, **options) | |
elif isinstance(example_value, bool) and proxy.node.target in [ | |
torch.backends.cuda.can_use_flash_attention, | |
torch.backends.cuda.can_use_efficient_attention, | |
]: | |
proxy.node.meta["example_value"] = example_value | |
return ConstantVariable.create(example_value, **options) | |
else: | |
unimplemented( | |
"torch.* op returned non-Tensor " | |
+ f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" | |
) | |
# Tracks the sources of all fake tensors we wrap in Dynamo. | |
# Used by shape guard computation. | |
@dataclasses.dataclass(frozen=True)
class TrackedFake:
fake: Union[FakeTensor, SymInt] | |
source: Source | |
# Is None when fake is SymInt | |
symbolic_context: Optional[SymbolicContext] | |
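
    # Two TrackedFakes are considered equal when they wrap the very same fake
    # object and originate from the same-named source.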
def __hash__(self) -> int: | |
return hash((self.fake, self.source.name())) | |
def __eq__(self, other: object) -> bool: | |
if isinstance(other, TrackedFake): | |
return self.fake is other.fake and self.source.name() == other.source.name() | |
return False | |
# Performs automatic dynamic dim determination. | |
# Returns a SymbolicContext | |
def _automatic_dynamic( | |
e, tx, source, static_shapes, outer_only=False | |
) -> SymbolicContext: | |
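    """
    Decide, per dimension of ``e``, whether its size should be treated as STATIC,
    DYNAMIC, or DUCK-sized, and which range constraint (if any) applies to it.
    Roughly, the precedence is: export constraints > mark_dynamic/mark_static
    annotations > automatic dynamic (driven by the frame_state size history) >
    duck sizing.
    """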
# strided NT not supported | |
if e.is_nested and not isinstance( | |
e, torch.nested._internal.nested_tensor.NestedTensor | |
): | |
unimplemented("torch.compile does not support strided NestedTensor") | |
name = source.name() | |
prior_policy = tx.output.tracing_context.tensor_to_context.get(e, None) | |
shape_env_to_source_to_symbol_cache = ( | |
prior_policy.shape_env_to_source_to_symbol_cache if prior_policy else None | |
) | |
# Get base context if the tensor is a view | |
view_base_context: Optional[SymbolicContext] = None | |
if e._is_view(): | |
base_source = AttrSource(source, "_base") | |
view_base_context = _automatic_dynamic(e._base, tx, base_source, static_shapes) | |
if is_traceable_wrapper_subclass(e) and not outer_only: | |
# Get symbolic context for outer tensor | |
outer_context = _automatic_dynamic( | |
e, tx, source, static_shapes, outer_only=True | |
) | |
# Get symbolic contexts for inner tensors | |
attrs, _ = type(e).__tensor_flatten__(e) | |
inner_contexts = {} # mapping from attr -> symbolic context | |
for attr in attrs: | |
inner_tensor = getattr(e, attr) | |
inner_source = AttrSource(source, attr) | |
inner_context = _automatic_dynamic( | |
inner_tensor, tx, inner_source, static_shapes | |
) | |
inner_contexts[attr] = inner_context | |
return SubclassSymbolicContext( | |
dynamic_sizes=outer_context.dynamic_sizes, | |
constraint_sizes=outer_context.constraint_sizes, | |
view_base_context=view_base_context, | |
tensor_source=outer_context.tensor_source, | |
shape_env_to_source_to_symbol_cache=outer_context.shape_env_to_source_to_symbol_cache, | |
inner_contexts=inner_contexts, | |
) | |
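    # Tensors with always-static shapes short-circuit here: every dim is STATIC
    # and no constraints are attached.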
if static_shapes: | |
return StatefulSymbolicContext( | |
dynamic_sizes=[DimDynamic.STATIC] * e.dim(), | |
constraint_sizes=[None] * e.dim(), | |
view_base_context=view_base_context, | |
tensor_source=source, | |
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, | |
) | |
# We preserve the dynamism of inputs. For example, when users call | |
# make_fx(torch.cond, tracing_mode="symbolic")(*args), inputs have SymInt sizes. | |
from torch.fx.experimental.symbolic_shapes import is_nested_int | |
if any(isinstance(s, SymInt) and not is_nested_int(s) for s in e.size()): | |
return StatefulSymbolicContext( | |
dynamic_sizes=[ | |
DimDynamic.DYNAMIC if isinstance(s, SymInt) else DimDynamic.STATIC | |
for s in e.size() | |
], | |
constraint_sizes=[None] * e.dim(), | |
view_base_context=view_base_context, | |
tensor_source=source, | |
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, | |
) | |
# Prep for automatic dynamic | |
frame_state_entry = None | |
if name not in tx.output.frame_state: | |
# If there is no entry for this source, add the tensor to frame state with its current static size. | |
# E.g., {} -> {"x": [2, 4]} | |
frame_state_entry = FrameStateSizeEntry(None, None) | |
frame_state_entry.size = list(e.size()) | |
else: | |
frame_state_entry = tx.output.frame_state[name] | |
if frame_state_entry.size is not None: | |
if e.ndim != len(frame_state_entry.size): | |
# If there is already an entry, and the dim mismatches, replace the frame state entry with None. | |
# E.g. {"x": [2, 3, 4]} -> {"x": None} | |
log.debug( | |
"automatic dynamic %s dim %s != %s", | |
name, | |
e.ndim, | |
frame_state_entry.size, | |
) | |
frame_state_entry.size = None | |
else: | |
# If there is already an entry, and the dim matches, for every size in the frame state which | |
# disagrees with the current static size, replace it with None. E.g., {"x": [2, 3]} -> {"x": [2, None]} | |
for i, dim in enumerate(frame_state_entry.size): | |
if dim is not None and e.size()[i] != dim: | |
log.debug( | |
"automatic dynamic %s size(%s) %s != %s", | |
name, | |
i, | |
e.size(i), | |
dim, | |
) | |
frame_state_entry.size[i] = None | |
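    # Gather any export-time constraints that apply to this tensor, keyed by dim.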
# TODO: index export_constraints ahead of time so we don't have to | |
# do a linear scan every time here | |
t_id = id(e) | |
dim2constraint = {} | |
def update_dim2constraint(dim, constraint_range, debug_name): | |
if dim in dim2constraint: | |
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint | |
old_constraint_range, old_debug_name = dim2constraint[dim] | |
new_constraint_range = StrictMinMaxConstraint( | |
vr=constraint_range.vr & old_constraint_range.vr, | |
warn_only=False, | |
) | |
# It is possible for (non-None) old_debug_name and debug_name to be different | |
            # but this will only happen when the corresponding Dims can be derived equal.
new_debug_name = old_debug_name or debug_name | |
dim2constraint[dim] = new_constraint_range, new_debug_name | |
else: | |
dim2constraint[dim] = constraint_range, debug_name | |
if tx.output.export_constraints: | |
for constraint in tx.output.export_constraints: | |
if constraint.t_id == t_id: | |
update_dim2constraint( | |
constraint.dim, constraint.constraint_range, constraint.debug_name | |
) | |
if constraint.shared is not None and constraint.shared.t_id == t_id: | |
# We process constraint ranges for each shared dimension separately | |
# so that we can directly check range constraint violations on them | |
# without looking up which other shared dimensions have this info. | |
# In other words, for this t_id, we will have processed all of its | |
# constraint ranges, no matter where / how they were specified, by | |
                # the end of this loop.
update_dim2constraint( | |
constraint.shared.dim, | |
constraint.constraint_range, | |
constraint.debug_name, | |
) | |
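    # Walk each dim, deciding its DimDynamic kind and attaching any constraint.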
dynamic_dims = [] | |
constraint_dims = [] | |
for i in range(e.dim()): | |
# NB: mark dynamic has precedence over static | |
marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set()) | |
marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set()) | |
marked_static = i in getattr(e, "_dynamo_static_indices", set()) | |
        # NB: both static and dynamic user marks take precedence over
        # automatic dynamic
automatic_dynamic = config.automatic_dynamic_shapes and ( | |
frame_state_entry.size is None or frame_state_entry.size[i] is None | |
) | |
        # Reflect the user directive in the frame_state: a dim marked dynamic is
        # always recorded as None (i.e. treated as "seen to vary")
if frame_state_entry.size and marked_dynamic: | |
log.debug("automatic dynamic %s marked dynamic", name) | |
frame_state_entry.size[i] = None | |
# We will process constraints first, as they will imply that we | |
# have a dynamic dimension | |
# Precedence: export constraints > eager constraints | |
constraint = dim2constraint.get(i) | |
if constraint is None: | |
if marked_dynamic and not config.allow_ignore_mark_dynamic: | |
if hasattr(e, "_dynamo_dynamic_range"): | |
dim_range = [ | |
dr for dr in e._dynamo_dynamic_range if dr.dim == i | |
].pop() | |
if dim_range.min is None and dim_range.max is None: | |
constraint_dim = RelaxedUnspecConstraint(warn_only=False) | |
else: | |
from torch.fx.experimental.symbolic_shapes import ( | |
StrictMinMaxConstraint, | |
) | |
constraint_dim = StrictMinMaxConstraint( | |
vr=ValueRanges(lower=dim_range.min, upper=dim_range.max), | |
warn_only=False, | |
) | |
else: | |
constraint_dim = RelaxedUnspecConstraint(warn_only=False) | |
elif not marked_static and automatic_dynamic: | |
constraint_dim = RelaxedUnspecConstraint(warn_only=True) | |
else: | |
constraint_dim = None | |
else: | |
constraint_dim, debug_name = constraint | |
if debug_name is not None: | |
dim_name = f"{name}.size()[{i}]" | |
tx.output.shape_env.source_name_to_debug_name[dim_name] = debug_name | |
constraint_dims.append(constraint_dim) | |
# Now, figure out if the dim is dynamic/duck/static | |
if ( | |
constraint_dim is not None | |
or marked_dynamic | |
or marked_weak_dynamic | |
or is_nested_int(e.shape[i]) | |
): | |
# NB: We could assert static_shapes is False here, but it | |
# seems better to allow the user to override symbolic_context in this | |
# case | |
dynamic = DimDynamic.DYNAMIC | |
elif static_shapes or config.assume_static_by_default or marked_static: | |
dynamic = DimDynamic.STATIC | |
else: | |
dynamic = DimDynamic.DUCK | |
dynamic_dims.append(dynamic) | |
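    # Persist the updated size history so later frames can detect which dims vary.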
tx.output.frame_state[name] = frame_state_entry | |
return StatefulSymbolicContext( | |
dynamic_sizes=dynamic_dims, | |
constraint_sizes=constraint_dims, | |
view_base_context=view_base_context, | |
tensor_source=source, | |
shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, | |
) | |
# See note [Tensor Fakification and Symbol Caching] | |
def wrap_to_fake_tensor_and_record( | |
e, tx, *, source: Optional[Source], is_tensor: bool, parent_context=None | |
): | |
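    """
    Fakify ``e`` under ``tx.fake_mode`` (recursing into traceable wrapper
    subclasses), record its symbolic context and source for later shape-guard
    computation, and return the resulting fake tensor. Non-tensor values are
    returned unchanged.
    """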
if ( | |
type(e) in (torch.Tensor, torch.nn.Parameter, FakeTensor) | |
or isinstance(e, torch.Tensor) | |
or is_traceable_wrapper_subclass(e) | |
): | |
assert source is not None | |
static_shapes, reason = tensor_always_has_static_shape( | |
e, is_tensor, guard_source=source.guard_source() | |
) | |
if not parent_context: | |
symbolic_context = _automatic_dynamic(e, tx, source, static_shapes) | |
else: | |
# Parent contexts are passed in when we are recursively creating | |
# fake tensors for subclasses. A better design would be not to create a | |
# parent/child relationship, but to recursively call _automatic_dynamic | |
# as we recursively call wrap_to_fake_tensor_and_record. This runs | |
            # into bugs around how meta_utils creates fake tensors for tensor
            # subclasses. Ideally, dynamo would drive both the recursive
# wrap_to_fake_tensor_and_record and _automatic_dynamic policy creation. | |
assert isinstance(source, AttrSource) | |
inner_context_name = source.member | |
symbolic_context = parent_context.inner_contexts[inner_context_name] | |
log.debug( | |
"wrap_to_fake %s %s %s %s", | |
source.name(), | |
tuple(e.shape), | |
symbolic_context, | |
type(e), | |
) | |
fake_e = wrap_fake_exception( | |
lambda: tx.fake_mode.from_tensor( | |
e, | |
source=source, | |
symbolic_context=symbolic_context, | |
) | |
) | |
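        # For traceable wrapper subclasses, recursively fakify the inner tensors so
        # each inner tensor's symbolic context (taken from the outer
        # SubclassSymbolicContext) is recorded as well.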
if is_traceable_wrapper_subclass(fake_e): | |
attrs, _ = fake_e.__tensor_flatten__() | |
for attr in attrs: | |
fake_inner = getattr(fake_e, attr) | |
inner = getattr(e, attr) | |
inner_source = AttrSource(source, attr) | |
wrap_to_fake_tensor_and_record( | |
inner, | |
tx, | |
source=inner_source, | |
is_tensor=isinstance(fake_inner, torch.Tensor), | |
parent_context=symbolic_context, | |
) | |
tx.output.tracing_context.tensor_to_context[e] = symbolic_context | |
tx.output.tensor_weakref_to_sizes_strides[e] = { | |
"size": fake_e.size(), | |
"stride": fake_e.stride(), | |
} | |
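        # Only record tracked fakes that can contribute to symbolic shape guards:
        # skip constant sources and static nn.Module tensors.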
if ( | |
is_tensor | |
and not (static_shapes and source.is_nn_module()) | |
and not is_constant_source(source) | |
): | |
tx.output.tracked_fakes.append( | |
TrackedFake(fake_e, source, symbolic_context) | |
) | |
tx.output.tracked_fakes_id_to_source[id(e)].append(source) | |
return fake_e | |
else: | |
return e | |
class SourcelessBuilder: | |
""" | |
Like builder, but stateless and does not require a source. Useful for simple type->VT objects, or objects | |
that are being created/evaporated during inlining (ex: consider a locally made list of tensors we then iterate over | |
.), such a list should not show up as an artifact from inputs, nor in reconstruction, nor in the graph. However, | |
there may be reasons to represent it as a ListVariable internally. | |
NOTE - Objects produced here are born UNGUARDED due to the nature of sources! | |
NOTE - This class is very new! It will have some rough edges, but it was created to stem the bleeding of giant | |
if/else type->VariableTracker trees that were cropping up all over dynamo. | |
""" | |
def __call__(self, tx, value) -> VariableTracker: | |
if isinstance(value, VariableTracker): | |
# This is always valid to call, and useful for recursive calls. | |
return value | |
if isinstance(value, dataclasses._HAS_DEFAULT_FACTORY_CLASS): | |
return UserDefinedObjectVariable(value) | |
if ConstantVariable.is_literal(value): | |
return SourcelessBuilder.wrap_constant_literal(value) | |
elif callable(value) and trace_rules.lookup_callable(value) is not None: | |
if is_callable_allowed(value): | |
                tx.output.has_user_defined_allowed_in_graph = True
return trace_rules.lookup_callable(value)(value) | |
elif is_function_or_wrapper(value): | |
return trace_rules.lookup(value)(value) | |
elif isinstance(value, enum.Enum): | |
return EnumVariable(value) | |
elif isinstance(value, (type, abc.ABCMeta)): | |
return UserDefinedClassVariable(value) | |
elif isinstance(value, dict): | |
items = {self(tx, k): self(tx, v) for k, v in value.items()} | |
return ConstDictVariable(items, mutable_local=MutableLocal()) | |
elif isinstance(value, set): | |
            # NB: value is a set here, so the iteration below is non-deterministic!
return SetVariable( | |
[self(tx, x) for x in value], mutable_local=MutableLocal() | |
) | |
elif isinstance(value, (tuple, list)): | |
cls = BaseListVariable.cls_for(type(value)) | |
return cls([self(tx, x) for x in value], mutable_local=MutableLocal()) | |
elif isinstance(value, types.MethodWrapperType): | |
return MethodWrapperVariable(value) | |
elif PlacementVariable.is_placement(value): | |
return PlacementVariable(value) | |
elif DeviceMeshVariable.is_device_mesh(value): | |
return DeviceMeshVariable(value) | |
unimplemented(f"Unexpected type in sourceless builder {type(value)}") | |

    @staticmethod
    def wrap_constant_literal(value):
assert ConstantVariable.is_literal(value) | |
return ConstantVariable.create(value=value) | |