import contextlib
import warnings
import weakref

from typing import ContextManager, Dict, List, Optional, Tuple, TYPE_CHECKING

import torch
from torch._C._functorch import (
    _add_batch_dim,
    _unwrap_functional_tensor,
    _wrap_functional_tensor,
    current_level,
    get_unwrapped,
    is_batchedtensor,
    is_functorch_wrapped_tensor,
    is_gradtrackingtensor,
    maybe_get_bdim,
    maybe_get_level,
    peek_interpreter_stack,
    TransformType,
)
from torch._guards import Source
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    transform_subclass,
)
from torch.utils.weak import WeakIdRef

if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features.
    # Do not import unconditionally, as they import sympy and importing sympy is very slow.
    from torch.fx.experimental.symbolic_shapes import SymbolicContext

DimList = List


def safe_is_leaf(t):
    try:
        return t.is_leaf
    except RuntimeError:
        # inference mode can trigger this
        return False


def safe_grad(t):
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        return t.grad


def assert_eq(a, b):
    assert a == b, f"{a} != {b}"


def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False):
    def go(m1, m2):
        assert_eq(m1.dtype, m2.dtype)
        if not skip_symbolic:
            assert_eq(m1.shape, m2.shape)
        assert_eq(m1.requires_grad, m2.requires_grad)
        assert_eq(m1.is_leaf, m2.is_leaf)
        assert_eq(m1.grad_fn is None, m2.grad_fn is None)
        assert_eq(m1.is_sparse, m2.is_sparse)
        assert_eq(m1.is_inference(), m2.is_inference())
        assert_eq(m1.is_conj(), m2.is_conj())
        assert_eq(m1.is_neg(), m2.is_neg())
        assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None)
        if safe_grad(m1) is not None:
            go(safe_grad(m1), safe_grad(m2))
        if m1.is_sparse:
            assert_eq(m1.dense_dim(), m2.dense_dim())
            assert_eq(m1.sparse_dim(), m2.sparse_dim())
            assert_eq(m1.is_coalesced(), m2.is_coalesced())
        else:
            if not skip_symbolic:
                assert_eq(m1.stride(), m2.stride())
                assert_eq(m1.storage_offset(), m2.storage_offset())
            assert_eq(m1._is_view(), m2._is_view())
            if m1._is_view():
                go(m1._base, m2._base)
        # TODO: test if is resizable (no direct query for this atm)
        # TODO: audit AutogradMeta to see if it matches
        # TODO: test forward AD

    return go(m1, m2)


def is_sparse_coo(t):
    return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo


def is_sparse_compressed(t):
    return isinstance(t, torch.Tensor) and t.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }


def is_sparse_any(t):
    return is_sparse_coo(t) or is_sparse_compressed(t)


# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure.  The operation model is that you
# allocate one of these, and then call it repeatedly on all the tensors you
# want to convert.  It's important to use the same object for tensors you
# want to share storage, because this is how we correlate shared storages to
# the same meta storages.  This class holds weak references to cached tensors
# and tensor storages.
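#
# A minimal usage sketch (illustrative only; the tensor values and the slicing
# view below are arbitrary, and the default callback is used):
#
#   converter = MetaConverter()
#   x = torch.randn(4, 4)
#   y = x[:2]                 # a view that shares x's storage
#   meta_x = converter(x)
#   meta_y = converter(y)
#   # Because the same converter handled both tensors, meta_y views into the
#   # same meta storage as meta_x, mirroring the relationship between x and y.
#   assert converter.successful()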
class MetaConverter:
    def __init__(self):
        self.storage_memo = {}
        self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
        self.maybe_storages_to_delete = []
        self.check_expired_frequency = 128
        self.check_expired_count = 0
        self.hit = 0
        self.miss = 0
        self.del_hook = None
        self.arg_cnt = 0

    def successful(self):
        return self.hit > 0 and self.miss == 0

    def check_for_expired_weak_storages(self):
        new_li = []
        stor_to_delete = []
        for obj in self.maybe_storages_to_delete:
            if not obj.expired():
                new_li.append(obj)
            else:
                stor_to_delete.append(obj)
        for obj in stor_to_delete:
            self.storage_memo.pop(obj, None)
        self.maybe_storages_to_delete = new_li

        # If for some reason we have acquired many storages which have not expired,
        # even though a tensor with their storage has expired (aliasing or otherwise),
        # check for expired storages less often so as to bound the amount of work we
        # do checking for expired storages.
        self.check_expired_frequency = max(
            self.check_expired_frequency, len(self.maybe_storages_to_delete)
        )

    def get_tensor_memo(self, t):
        return self.tensor_memo.get(WeakIdRef(t), None)

    def set_tensor_memo(self, t, v):
        # hold a weak ref to self, otherwise it will be kept alive
        # by the del_ten closure
        self_weak_ref = weakref.ref(self)
        if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
            weak_st = None
        else:
            weak_st = StorageWeakRef(t._typed_storage())
        tensor_ref_key = WeakIdRef(t)

        def del_ten():
            # tensor outlives the converter
            self_ref = self_weak_ref()
            if self_ref is None:
                return
            # on shutdown, tensor_ref_key may not be in memo
            self_ref.tensor_memo.pop(tensor_ref_key, None)
            if weak_st and weak_st.expired():
                self_ref.storage_memo.pop(weak_st, None)
            elif weak_st is not None:
                # [expired-storages]
                # NB: even though the tensor has died,
                # the deallocation of its storage can take longer,
                # even when the storage has no other uses/views.
                # In this case, the StorageWeakRef object will be kept alive
                # longer than it needs to be, however the storage itself
                # will be deallocated.  We retain the possibly dead storages
                # and periodically check if any of them are expired and
                # can be freed.
                self_ref.maybe_storages_to_delete.append(weak_st)

        weakref.finalize(t, del_ten)
        self.tensor_memo[tensor_ref_key] = v

    # NB: doesn't actually return a storage, because meta storage is
    # not supported
    def meta_storage(self, s, callback):
        # NB: TypedStorage is freshly allocated and cannot be used as a hash
        # key index.

        # Use a weak ref to s in order to not leak memory
        swr = StorageWeakRef(s)
        if swr not in self.storage_memo:
            self.storage_memo[swr] = callback(
                lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
            ).untyped_storage()
        return self.storage_memo[swr]

    # This function assumes that it's possible to do the conversion.
    # NB: name here is used in a conventional way by Dynamo; it corresponds
    # precisely to the Source.name() of the tensor we're fakeifying and
    # corresponds to a valid Python expression.  When we construct sub-names
    # as part of this process, we will maintain this invariant!  (Even though
    # other users of this may not need this property to be upheld.)
    def meta_tensor(
        self,
        t,
        shape_env=None,
        callback=lambda t: t(),
        source: Optional[Source] = None,
        symbolic_context: Optional["SymbolicContext"] = None,
    ):
        if source is None:
            from torch._dynamo.source import ConstantSource

            # TODO: make a dedicated UnknownSource for this?
            source = ConstantSource(
                f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
            )

        # This indicates you set no_dispatch() before calling into this
        # function.
        # This is an error: we may be creating fake tensors and
        # will perform operations on them which need fake tensor mode to
        # be active.  You will segfault if you are in a no_dispatch() block.
        assert not torch._C._dispatch_tls_local_exclude_set().has(
            torch._C.DispatchKey.Python
        )
        arg_cnt = self.arg_cnt
        self.arg_cnt += 1

        # When we make as_strided calls, we end up generating a guard
        # that the new as_strided tensor is in bounds for the old storage
        # for the base (since as_strided calls can "bust" out of their
        # bounding box.)  This guard is unnecessary: if a user is able
        # to provide us a tensor with the view base set up this way, we
        # don't need to produce a guard, because the fact that they
        # were able to produce the view base means it's in bounds.
        #
        # Now, ordinarily, this guard would be harmless.  However, the
        # generated guard refers to variables bound on the base variable.
        # At the moment, Dynamo doesn't actually guard on x._base, because
        # according to Voz this results in a lot of spurious invalidations,
        # and also if the user doesn't directly make use of _base, it's
        # pointless anyway (because programs should be parametric over
        # whether or not the input tensor is a view or not--unless you're
        # mutating the input, but that's a whole 'nother ballgame).  So
        # for expediency, we suppress these guards so we don't have to
        # deal with this (yet, anyway.)
        #
        # NB: An old version of this code suppressed guards for ALL operations
        # happening during meta conversion, not just as_strided calls.
        # This is too aggressive: we do duck sizing and 0/1 simplification
        # as we allocate variables, and we do need to register guards for
        # these cases.
        maybe_suppress = contextlib.nullcontext
        if shape_env is not None:
            maybe_suppress = shape_env.suppress_guards

        def sym_sizes_strides_storage_offset(
            t, src, symbolic_context=symbolic_context
        ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
            if shape_env is not None:
                fake_mode = torch._subclasses.fake_tensor.maybe_get_fake_mode(t)
                if fake_mode is not None and fake_mode.shape_env is shape_env:
                    # Don't reallocate the sizes; the shape envs are the same,
                    # so reuse the old sizes/strides/etc
                    return (t.size(), t.stride(), t.storage_offset())
                else:
                    return shape_env.create_symbolic_sizes_strides_storage_offset(
                        t,
                        src,
                        symbolic_context=symbolic_context,
                    )
            else:
                assert symbolic_context is None
            return (t.size(), t.stride(), t.storage_offset())

        def empty_create(inner_t, inner_src, symbolic_context=symbolic_context):
            (
                inner_sizes,
                inner_strides,
                inner_storage_offset,
            ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
            return torch.empty_strided(
                inner_sizes,
                inner_strides,
                dtype=inner_t.dtype,
                device="meta",
            )

        # Creates a subclass instance with empty inner tensors according to the specified
        # symbolic context.
        def empty_create_subclass(
            t,
            outer_size,
            outer_stride,
            symbolic_context=symbolic_context,
            callback=callback,
            source=source,
        ):
            from torch._dynamo.source import AttrSource
            from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext

            assert symbolic_context is None or isinstance(
                symbolic_context, SubclassSymbolicContext
            )

            # Note: transform_subclass will use __tensor_unflatten__ to generate
            # a fresh subclass wrapper with outer sizes / strides according to the
            # outer symbolic context (passed in to this function).  Inner size / stride
            # / storage offset symbols are allocated according to the appropriate inner
            # symbolic contexts, after which the checks in transform_subclass() will
            # relate them to the outer metadata where possible.
            return transform_subclass(
                t,
                lambda attr, inner_t: callback(
                    lambda: empty_create(
                        inner_t,
                        AttrSource(source, attr),
                        symbolic_context=(
                            None
                            if symbolic_context is None
                            else symbolic_context.inner_contexts[attr]
                        ),
                    )
                ),
                outer_size=outer_size,
                outer_stride=outer_stride,
            )

        # Returns an all-dynamic symbolic context used for metafying the given tensor with
        # fully dynamic dims. This is useful when fake-ifying intermediate tensors in
        # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
        # don't want to over-specialize during view replay.
        def all_dynamic_symbolic_context(t, source, shape_env, callback):
            from torch._dynamo.source import AttrSource
            from torch.fx.experimental.symbolic_shapes import (
                DimDynamic,
                StatelessSymbolicContext,
                SubclassSymbolicContext,
                SymbolicContext,
            )

            view_base_context: Optional[SymbolicContext] = None
            if t._is_view():
                view_base_context = all_dynamic_symbolic_context(
                    t._base, AttrSource(source, "_base"), shape_env, callback
                )

            t_symbolic_context: SymbolicContext
            t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.dim()
            if is_traceable_wrapper_subclass(t):
                inner_contexts: Dict[str, SymbolicContext] = {}
                attrs, _ = t.__tensor_flatten__()
                for attr in attrs:
                    assert isinstance(attr, str)
                    inner = getattr(t, attr)
                    inner_contexts[attr] = all_dynamic_symbolic_context(
                        inner, AttrSource(source, attr), shape_env, callback
                    )
                t_symbolic_context = SubclassSymbolicContext(
                    dynamic_sizes=t_dynamic_sizes,
                    constraint_sizes=[None] * t.dim(),
                    inner_contexts=inner_contexts,
                    tensor_source=source,
                    view_base_context=view_base_context,
                )
            else:
                t_symbolic_context = StatelessSymbolicContext(
                    dynamic_sizes=t_dynamic_sizes,
                    constraint_sizes=[None] * t.dim(),
                    view_base_context=view_base_context,
                )

            return t_symbolic_context

        # Returns a fake-ified version of an input view tensor t, given an already fake-ified
        # base. At a high level, we want two things:
        #   1. fake_t should have the same view relationship to the given fake base as the
        #      input t has to its _base.
        #   2. fake_t should have symbolic sizes / strides / storage offset according to the
        #      appropriate symbolic context (i.e. from the automatic dynamic algorithm).
        #
        # We currently take different strategies across view types:
        #   * For dense -> dense views, accomplish both (1) and (2) simultaneously via an
        #     as_strided() call on the fake-ified base, passing symbolic metadata.
        #   * For views involving subclasses, perform view replay using view funcs to
        #     achieve (1). It's necessary for (2) to swap out any closed-over state in
        #     the view funcs with symbolicized SymInts and fake-ified tensors. Doing this
        #     avoids specialization (and thus over-eager simplification of symbols) that
        #     could occur during view replay on the fake-ified base.
        #
        # Examples:
        #   * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled
        #     with an as_strided() call on the fake base passing symbolic metadata.
        #   * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg
        #     is made symbolic to avoid invalid specialization and view replay is then
        #     done to reconstruct the view.
        #   * _nested_from_jagged(values, offsets) is a dense -> subclass view
        #     that returns a subclass instance from a dense values tensor. The offsets
        #     tensor is closed over in the view func, as it can be considered view metadata.
        #     First, the offsets tensor is fake-ified according to the inner symbolic
        #     context and with the correct relationship to the outer size / stride metadata.
        #     Then view replay is done, swapping in the fake offsets so the view replay output
        #     is fully fake with no invalid specialization.
        def view_from_base(base, t, source=source, shape_env=shape_env):
            # fake-ify t's metadata according to the outer symbolic context
            (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
                t, source
            )
            if not is_traceable_wrapper_subclass(
                t
            ) and not is_traceable_wrapper_subclass(base):
                # The dense -> dense view case uses as_strided() to construct the view
                # relationship.
                # TODO: Change this logic to use view replay for consistency?
                # It's likely there is no view func available.
                return base.as_strided(sizes, strides, storage_offset)

            from torch._dynamo.source import EphemeralSource
            from torch.fx.experimental.symbolic_shapes import sym_eq

            def symint_visitor_fn(s):
                if shape_env is None:
                    return s

                # NB: The symbol here is expected to be simplified out because we a priori
                # allocate inner and outer symbols according to the appropriate symbolic
                # contexts and prefer those over this symbol during symbol simplification
                # (via usage of EphemeralSource below).  This -shouldn't- happen, but if
                # this symbol somehow leaks out beyond the view tensor's shape metadata, our
                # assumption of it being simplified out will fail and it may be guarded on,
                # which will hard error.
                sym_source = EphemeralSource("symint_visitor_fn")
                symbol = shape_env.create_symbol(s, sym_source)
                return shape_env.create_symintnode(symbol, hint=s, source=sym_source)

            real_to_fake_mapping = {}
            if is_traceable_wrapper_subclass(t):
                # Fake-ify t naively here; this is only done so we can get fake-ified inner
                # tensors with the correct relationships to the outer sizes / strides for use
                # in view replay.  It's done beforehand here because it's not easy to do when
                # visiting tensors one-by-one during view replay.
                #
                # Example:
                #   Consider a Dense -> NJT view.  NJT has (values, offsets) components and we
                #   want a view of values with the offsets closed over.  As the offsets component
                #   is needed to describe the output view, it's important that it's fake-ified
                #   correctly.
                fake_t = empty_create_subclass(
                    t, outer_size=sizes, outer_stride=strides
                )
                attrs, _ = fake_t.__tensor_flatten__()
                for attr in attrs:
                    real_to_fake_mapping[getattr(t, attr)] = getattr(fake_t, attr)

            def tensor_visitor_fn(
                visited_t, shape_env=shape_env, callback=callback, source=source
            ):
                # It's possible to close over an undefined tensor (e.g. NJT's lengths).
                if visited_t is None:
                    return None

                # Fake inner tensors of view subclasses will come from the mapping built above.
                fake_visited_t = real_to_fake_mapping.get(visited_t, None)
                if fake_visited_t is not None:
                    return fake_visited_t

                # For other closed-over tensor state, fake-ify it as all dynamic with an
                # ephemeral source.  This avoids invalid specialization during view replay.
                # If we find that in practice the usage of ephemeral sources isn't enough
                # to guarantee that we don't have guards on these symbols, we may need to
                # explicitly suppress guards (as is done for _base in the dense -> dense
                # view case).
                temp_source = EphemeralSource("tensor_visitor_fn")
                return self.meta_tensor(
                    visited_t,
                    shape_env,
                    callback,
                    source=temp_source,
                    symbolic_context=all_dynamic_symbolic_context(
                        visited_t, temp_source, shape_env, callback
                    ),
                )

            # Replay the view, swapping out any non-symbolic SymInts or real tensors
            # for symbolic SymInts or fake tensors.
            fake_t = t._view_func_unsafe(base, symint_visitor_fn, tensor_visitor_fn)

            # Ensure the output has symbolic shapes according to the outer symbolic context.
            # These checks should simplify out any symbols created for closed-over view func
            # SymInts.
            torch._check(sym_eq(fake_t.size(), sizes))
            torch._check(sym_eq(fake_t.stride(), strides))
            torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
            return fake_t

        # see [expired-storages]
        self.check_expired_count += 1
        if self.check_expired_count >= self.check_expired_frequency:
            self.check_for_expired_weak_storages()
            self.check_expired_count = 0

        if self.get_tensor_memo(t) is None:
            with torch.inference_mode(t.is_inference()):
                if t.is_sparse:
                    is_leaf = safe_is_leaf(t)

                    # The lambda function below is similar to
                    # `t.to(device='meta')` except the latter
                    # preserves the nnz value
                    r = callback(
                        lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
                            t.sparse_dim(),
                            t.dense_dim(),
                            t.shape,
                            dtype=t.dtype,
                            layout=torch.sparse_coo,
                            device="meta",
                        )
                    )
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    # Note [is_coalesced is dispatched]
                    # Strangely enough, is_coalesced() is a dispatched operator,
                    # which means that it will get caught by fake tensor mode.
                    # Ordinarily this would error, but there's some logic in
                    # fake tensor that ensures this doesn't happen.
                    r._coalesced_(t.is_coalesced())
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        with torch.enable_grad():
                            r = r.clone()
                            r._coalesced_(t.is_coalesced())
                elif is_sparse_compressed(t):
                    is_leaf = safe_is_leaf(t)

                    def mk_meta():
                        nnz = 0
                        batch_dim = t.ndim - t.sparse_dim() - t.dense_dim()
                        batch_size = t.shape[:batch_dim]
                        if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                            index_dtype = t.crow_indices().dtype
                            compressed_indices = torch.empty(
                                t.crow_indices().shape, device="meta", dtype=index_dtype
                            )
                            plain_indices = torch.empty(
                                (*t.col_indices().shape[:-1], nnz),
                                device="meta",
                                dtype=index_dtype,
                            )
                        else:
                            index_dtype = t.ccol_indices().dtype
                            compressed_indices = torch.empty(
                                t.ccol_indices().shape, device="meta", dtype=index_dtype
                            )
                            plain_indices = torch.empty(
                                (*t.row_indices().shape[:-1], nnz),
                                device="meta",
                                dtype=index_dtype,
                            )
                        values_shape = t.values().shape
                        values = torch.empty(
                            (
                                *values_shape[:batch_dim],
                                nnz,
                                *values_shape[batch_dim + 1 :],
                            ),
                            dtype=t.dtype,
                            device="meta",
                        )
                        return torch.ops.aten.sparse_compressed_tensor(
                            compressed_indices,
                            plain_indices,
                            values,
                            t.shape,
                            layout=t.layout,
                            dtype=t.dtype,
                            device="meta",
                        )

                    # `mk_meta()` is similar to `t.to(device='meta')`
                    # except `to('meta')` preserves the nnz value while
                    # the `mk_meta` result has nnz == 0.
                    r = callback(mk_meta)
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        with torch.enable_grad():
                            r = r.clone()
                elif t.is_nested and not is_traceable_wrapper_subclass(t):
                    # TODO: Handle this better in Dynamo?
                    # There are checks there now, but this can still be triggered by a dense
                    # tensor graph input that is a view of a strided NT.
                    from torch._dynamo.exc import unimplemented

                    unimplemented(
                        "strided nested tensors are not supported by meta conversion"
                    )
                elif t.is_mkldnn:
                    is_leaf = safe_is_leaf(t)
                    sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
                        t, source
                    )
                    r = callback(
                        lambda: torch.empty_strided(
                            sizes, strides, dtype=t.dtype, device="meta"
                        )
                    )
                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        with torch.enable_grad():
                            r = r.clone()
                elif is_functorch_wrapped_tensor(t):
                    if t._is_view():
                        from torch._dynamo.exc import unimplemented

                        unimplemented(
                            "view functorch tensors are not supported by meta conversion"
                        )

                    # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
                    # in a FakeTensor
                    def _to_fake_tensor(t):
                        if is_batchedtensor(t):
                            ft = _to_fake_tensor(get_unwrapped(t))
                            lvl = maybe_get_level(t)
                            bdim = maybe_get_bdim(t)
                            r = _add_batch_dim(ft, bdim, lvl)
                        elif is_gradtrackingtensor(t):
                            disable_functorch = torch._C._DisableFuncTorch
                            with disable_functorch():
                                ft = _to_fake_tensor(get_unwrapped(t))
                            lvl = torch._C._functorch.maybe_get_level(t)
                            r = torch._C._functorch._wrap_for_grad(ft, lvl)

                            is_leaf = safe_is_leaf(t)
                            if t.requires_grad and safe_is_leaf(r):
                                r.requires_grad = True
                            elif t.requires_grad and not is_leaf:
                                with torch.enable_grad():
                                    r = r.clone()
                        else:
                            sizes = t.size()
                            strides = t.stride()
                            r = callback(
                                lambda: torch.empty_strided(
                                    sizes,
                                    strides,
                                    dtype=t.dtype,
                                    device="meta",
                                )
                            )
                        return r

                    r = _to_fake_tensor(t)

                elif t._is_view():
                    # Construct views in two steps: recursively meta-fy their
                    # base, and then create view(s) off that.  NB: doing it
                    # directly from storage is WRONG because this won't cause
                    # version counters to get shared.
                    assert t._is_view()

                    base_symbolic_context = None
                    if shape_env and symbolic_context is not None:
                        from torch.fx.experimental.symbolic_shapes import (
                            StatelessSymbolicContext,
                        )

                        assert isinstance(symbolic_context, StatelessSymbolicContext)
                        # NB: This should generally be set when the input is a view,
                        # but the exception right now is for fake-ifying grads, which is
                        # a work in progress.
                        if symbolic_context.view_base_context is not None:
                            base_symbolic_context = symbolic_context.view_base_context

                    base = self.meta_tensor(
                        t._base,
                        shape_env,
                        callback,
                        source=torch._dynamo.source.AttrSource(source, "_base"),
                        symbolic_context=base_symbolic_context,
                    )

                    def is_c_of_r(complex_dtype, real_dtype):
                        return (
                            utils.is_complex_dtype(complex_dtype)
                            and utils.corresponding_real_dtype(complex_dtype)
                            == real_dtype
                        )

                    # In some situations, MetaConverter may be called in a
                    # context where autograd is disabled.  For the _is_view
                    # assert to pass, we have to set up the autograd view
                    # metadata anyway.  Do this by re-enabling the
                    # ADInplaceOrView key.  This is kind of a hack.
                    old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView
                    )
                    torch._C._dispatch_tls_set_dispatch_key_excluded(
                        torch._C.DispatchKey.ADInplaceOrView, False
                    )
                    try:
                        if base.dtype == t.dtype:
                            pass
                        elif is_c_of_r(base.dtype, t.dtype):
                            base = torch.view_as_real(base)
                        elif is_c_of_r(t.dtype, base.dtype):
                            base = torch.view_as_complex(base)
                        else:
                            # This is not guaranteed to succeed.  If it fails, it
                            # means there is another dtype-converting view function
                            # that hasn't been handled here
                            base = base.view(t.dtype)

                        # This is very tricky.  Naively, you might expect this
                        # to hold:
                        #
                        #   if t.requires_grad and not safe_is_leaf(t)
                        #       assert t._base.requires_grad
                        #
                        # But it's not true!
                        # As you can see in the following
                        # program:
                        #
                        #   x = torch.zeros(4)
                        #   y = x.view(1, 4)
                        #   y.requires_grad = True
                        #   z = y.view(1, 1, 4)
                        #   assert z._base is x
                        #
                        # So we may have to do *two* views out of the base to
                        # recreate this situation.
                        if safe_is_leaf(t):
                            # Leaf views that track view metadata are created by
                            # creating a view inside a no_grad block
                            with torch.no_grad(), maybe_suppress():
                                r = view_from_base(base, t)
                            # As it's a leaf, we can directly assign requires_grad
                            r.requires_grad = t.requires_grad
                        else:
                            if t._base.requires_grad == t.requires_grad:
                                # Easy case, just run the view op
                                with torch.enable_grad(), maybe_suppress():
                                    r = view_from_base(base, t)

                                # NB: We don't actually faithfully replicate
                                # autograd connectivity, but that doesn't matter
                                # today.  See the following for more info:
                                # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
                            else:
                                # Obscure case.  Create a leaf view and give it the
                                # correct requires_grad, then do the final view.
                                # NB: Can't have a non-leaf without requiring grad!
                                assert t.requires_grad
                                with torch.no_grad():
                                    mid = base.view(base.shape)
                                mid.requires_grad = t.requires_grad
                                with torch.enable_grad(), maybe_suppress():
                                    r = view_from_base(mid, t)
                            # The CreationMeta influences whether or not inplace
                            # mutation is an error or not.  So we need to make
                            # sure we properly propagate this as well.
                            torch._C._autograd._set_creation_meta(
                                r, torch._C._autograd._get_creation_meta(t)
                            )
                    finally:
                        torch._C._dispatch_tls_set_dispatch_key_excluded(
                            torch._C.DispatchKey.ADInplaceOrView, old_exclude
                        )

                else:
                    is_leaf = safe_is_leaf(t)

                    (
                        sizes,
                        strides,
                        storage_offset,
                    ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)

                    # If we have a subclass that desugars into dense tensors,
                    # perform our callback on each inner tensor.
                    if is_traceable_wrapper_subclass(t):
                        r = empty_create_subclass(
                            t, outer_size=sizes, outer_stride=strides
                        )
                    else:
                        r = callback(
                            lambda: torch.empty_strided(
                                sizes,
                                strides,
                                dtype=t.dtype,
                                device="meta",
                            )
                        )

                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
                    if t.requires_grad:
                        r.requires_grad = t.requires_grad
                        if not is_leaf:
                            # Fake up some autograd history.
                            with torch.enable_grad():
                                # preserve_format is the default, but we want to
                                # emphasize how important it is to preserve
                                # format here
                                r = r.clone(memory_format=torch.preserve_format)

                    # Graph-break for wrapped tensors
                    if not (
                        is_batchedtensor(t) or is_gradtrackingtensor(t)
                    ) and torch._C._functorch.is_functorch_wrapped_tensor(t):
                        return NotImplemented

                    s = t.untyped_storage()
                    swr = StorageWeakRef(s)
                    if swr not in self.storage_memo and (
                        r.is_nested
                        or (
                            r.stride() == strides
                            and r.storage_offset() == storage_offset
                        )
                    ):
                        # You're normal and happy, install the fresh storage into the memo
                        self.storage_memo[swr] = r.untyped_storage()
                    else:
                        # You're in crazy town; somehow you gave us a tensor
                        # that wasn't a view, but had a nonzero storage offset,
                        # nontrivial strides (such that clone() couldn't
                        # preserve them), or already aliases another
                        # tensor's storage.  The most typical way to end
                        # up here is with set_.  So use set_ to bludgeon this
                        # in.
                        r_s = self.meta_storage(s, callback=callback)
                        # NB: In principle, this should always work, but there
                        # is some subtle difference in the autograd metadata
                        # that means we will backprop the set_ call, even if
                        # r is declared as an input to grad.
                        # See https://github.com/pytorch/pytorch/issues/87956
                        # for the reproducer.
                        # NB: The in_kernel_invocation_manager here is necessary
                        # for fake tensor.
                        # If we run the set_ call with fake
                        # tensor on, r will improperly report that it is NOT a
                        # meta tensor but a cpu tensor, and then the set_ call
                        # will fail due to device mismatch.  no_dispatch() is
                        # not enough, because the fake tensor will still claim
                        # to be a CPU tensor and you'll end up in the CPU
                        # kernel.  Arguably this is a hack; a cleaner way to
                        # solve this is to have a FakeStorage concept which
                        # would report its CPU device--no problem now!  But
                        # this is difficult to do because we don't have storage
                        # subclasses.  Relevant test is
                        # DynamicShapesFunctionTests::test_add_dynamic_shapes in
                        # test/dynamo/test_dynamic_shapes.py
                        maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
                        from torch._subclasses.fake_tensor import (
                            in_kernel_invocation_manager,
                            maybe_get_fake_mode,
                        )

                        mb_fake_mode = maybe_get_fake_mode(r)
                        if mb_fake_mode is not None:
                            maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
                        with maybe_fake_mgr, torch.no_grad():
                            r.set_(r_s, storage_offset, sizes, strides)

                if safe_grad(t) is not None:
                    from torch._dynamo.source import AttrSource

                    # TODO: Use a valid grad-specific symbolic context instead of recycling
                    # the one from t.  This isn't correct if e.g. t._is_view() != t.grad._is_view().
                    r.grad = self.meta_tensor(
                        safe_grad(t),
                        shape_env,
                        callback,
                        source=AttrSource(source, "grad"),
                        symbolic_context=symbolic_context,
                    )
                torch._C._set_conj(r, t.is_conj())
                torch._C._set_neg(r, t.is_neg())
            # This can be skipped if necessary for performance reasons
            assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
            self.set_tensor_memo(t, r)

        return self.get_tensor_memo(t)

    def __call__(
        self,
        t,
        shape_env=None,
        *,
        callback=lambda t: t(),
        source=None,
        symbolic_context=None,
    ):
        # TODO: zero tensors?  We appear to have eliminated them by
        # excluding complex for now
        if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
            if t.device.type != "xla" and any(
                [
                    t.is_quantized,
                    t._is_view() and t._base is not None and t._base.is_sparse,
                    torch._is_functional_tensor(t),
                    t.device.type in ("lazy",),
                    # We need a way to test if a tensor is batched but there
                    # is no official API to do it
                    # torch._C._is_batched(t),
                ]
            ):
                # TODO: sparse should support meta
                # NB: technically to('meta') does work, but our logging
                # instrumentation will see the meta conversions and the
                # tests all break, so we just exclude this.  In any case
                # the to() conversion isn't really right anyhow.

                if torch._is_functional_tensor(t) and t.device.type != "lazy":
                    if t._is_view():
                        raise RuntimeError(
                            "Cannot safely fakify a view because this process drops the view information right now."
                        )

                    st = peek_interpreter_stack()
                    assert (
                        st is None or st.key() == TransformType.Functionalize
                    ), "Expect st to be either None or have Functionalize transform key."
                    if st is None:
                        # the case of AOTAutograd
                        torch._sync(t)
                        unwrap_t = torch._from_functional_tensor(t)
                        with torch._dispatch.python.suspend_functionalization():
                            fake_t = self.meta_tensor(
                                unwrap_t,
                                shape_env=shape_env,
                                callback=callback,
                                source=source,
                                symbolic_context=symbolic_context,
                            )
                        out = torch._to_functional_tensor(fake_t)
                        torch._mirror_autograd_meta_to(fake_t, out)
                        return out
                    else:
                        # the case of torch.func.functionalize
                        reapply_views = torch._C._functionalization_reapply_views_tls()
                        unwrap_t = _unwrap_functional_tensor(t, reapply_views)
                        pop_st_ctx = (
                            torch._functorch.pyfunctorch.temporarily_pop_interpreter_stack()
                        )
                        with pop_st_ctx:
                            fake_t = self.meta_tensor(
                                unwrap_t,
                                shape_env=shape_env,
                                callback=callback,
                                source=source,
                                symbolic_context=symbolic_context,
                            )
                        return _wrap_functional_tensor(fake_t, current_level())
                self.miss += 1
                return NotImplemented
            else:
                self.hit += 1
                disable_functorch = torch._C._DisableFuncTorch
                with disable_functorch():
                    r = self.meta_tensor(
                        t,
                        shape_env=shape_env,
                        callback=callback,
                        source=source,
                        symbolic_context=symbolic_context,
                    )
                if type(t) is torch.nn.Parameter:
                    # NB: Cannot directly use the Parameter constructor
                    # because that would force a detach, which is not desirable
                    r._is_param = True
                return r
        elif torch.overrides.is_tensor_like(t):
            self.miss += 1
            return NotImplemented
        else:
            # non-Tensor types don't count as hit or miss
            return t


# NB: This import intentionally lives at the bottom of the file; it is used by
# is_c_of_r() above, and importing it eagerly at the top can introduce an
# import cycle.
import torch._prims_common as utils
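
# A rough illustration of __call__'s return conventions above (a sketch for
# readers of this file, not part of the module's API; the inputs are arbitrary):
#
#   m = MetaConverter()
#   r = m(torch.randn(3))   # supported tensor -> meta tensor, counted as a hit
#   r = m(5)                # non-Tensor input -> returned unchanged, no hit/miss
#   # Unsupported tensors (e.g. quantized ones) return NotImplemented and are
#   # counted as misses, so m.successful() would then report False.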