|
|
|
from __future__ import annotations |
|
|
|
import contextlib |
|
|
|
import dataclasses |
|
import warnings |
|
import weakref |
|
from dataclasses import dataclass |
|
from typing import ( |
|
Any, |
|
Callable, |
|
ClassVar, |
|
ContextManager, |
|
Dict, |
|
List, |
|
Optional, |
|
Tuple, |
|
Type, |
|
TYPE_CHECKING, |
|
Union, |
|
) |
|
from typing_extensions import TypeAlias |
|
|
|
import torch |
|
from torch._C._autograd import CreationMeta |
|
from torch._C._functorch import ( |
|
_add_batch_dim, |
|
_unwrap_functional_tensor, |
|
_wrap_functional_tensor, |
|
get_unwrapped, |
|
is_batchedtensor, |
|
is_functorch_wrapped_tensor, |
|
is_gradtrackingtensor, |
|
is_legacy_batchedtensor, |
|
maybe_get_bdim, |
|
maybe_get_level, |
|
peek_interpreter_stack, |
|
) |
|
from torch._logging import trace_structured |
|
from torch.utils._mode_utils import no_dispatch |
|
|
|
from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
|
from torch.utils.weak import WeakIdKeyDictionary |
|
|
|
if TYPE_CHECKING: |
|
from torch._C._functorch import CInterpreter |
|
from torch._guards import Source |
|
|
|
|
|
from torch._subclasses.fake_tensor import FakeTensorMode |
|
|
|
|
|
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext |
|
|
|
DimList = List |
|
|
|
|
|
def safe_is_leaf(t): |
|
try: |
|
return t.is_leaf |
|
    except RuntimeError:
        # inference-mode tensors can trigger this
        return False
|
|
|
|
|
def safe_grad(t): |
|
with warnings.catch_warnings(): |
|
warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") |
|
return t.grad |
|
|
|
|
|
def assert_eq(a, b): |
|
assert a == b, f"{a} != {b}" |
|
|
|
|
|
def assert_metadata_eq( |
|
assert_eq, |
|
m1: Union[MetaTensorDesc, torch.Tensor], |
|
m2: torch.Tensor, |
|
*, |
|
skip_symbolic=False, |
|
skip_leaf=False, |
|
): |
|
if isinstance(m1, torch.Tensor): |
|
m1 = MetaTensorDescriber().describe_tensor(m1) |
|
|
|
def go(m1, m2): |
|
assert_eq(m1.dtype, m2.dtype) |
|
if not skip_symbolic: |
|
assert_eq(m1.shape, m2.shape) |
|
assert_eq(m1.requires_grad, m2.requires_grad) |
|
if not skip_leaf: |
|
assert_eq(m1.is_leaf, m2.is_leaf) |
|
|
|
|
|
assert_eq(m1.is_sparse, m2.is_sparse) |
|
assert_eq(m1.is_inference, m2.is_inference()) |
|
assert_eq(m1.is_conj, m2.is_conj()) |
|
assert_eq(m1.is_neg, m2.is_neg()) |
|
assert_eq(m1.grad is not None, safe_grad(m2) is not None) |
|
if m1.grad is not None: |
|
go(m1.grad, safe_grad(m2)) |
|
if m1.is_sparse: |
|
assert_eq(m1.dense_dim, m2.dense_dim()) |
|
assert_eq(m1.sparse_dim, m2.sparse_dim()) |
|
assert_eq(m1.is_coalesced, m2.is_coalesced()) |
|
else: |
|
if not skip_symbolic: |
|
assert_eq(m1.stride, m2.stride()) |
|
assert_eq(m1.storage_offset, m2.storage_offset()) |
|
assert_eq(m1.is_view, m2._is_view()) |
|
if m1.is_view: |
|
go(m1.base, m2._base) |
|
|
|
|
|
|
|
|
|
return go(m1, m2) |
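# A minimal usage sketch (illustrative; the tensor below is hypothetical):
# compare the metadata of a real tensor against its meta-device copy.
#
#   x = torch.randn(3)
#   assert_metadata_eq(assert_eq, x, x.to("meta"), skip_symbolic=True)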
|
|
|
|
|
def is_sparse_coo(t): |
|
return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo |
|
|
|
|
|
def is_sparse_compressed_layout(layout): |
|
return layout in { |
|
torch.sparse_csr, |
|
torch.sparse_csc, |
|
torch.sparse_bsr, |
|
torch.sparse_bsc, |
|
} |
|
|
|
|
|
def is_sparse_compressed(t): |
|
return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout) |
|
|
|
|
|
def is_sparse_any(t): |
|
return is_sparse_coo(t) or is_sparse_compressed(t) |
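# Quick reference for the layout predicates above (illustrative):
#
#   is_sparse_any(torch.eye(2).to_sparse())      # True  (sparse_coo)
#   is_sparse_any(torch.eye(2).to_sparse_csr())  # True  (compressed)
#   is_sparse_any(torch.eye(2))                  # False (strided)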
|
|
|
|
|
|
|
MetaStorageId: TypeAlias = int |
|
MetaTensorId: TypeAlias = int |
|
|
|
|
|
DESCRIBER_NEXT_ID = 0 |
|
|
|
|
|
class MetaTensorDescriber: |
|
""" |
|
Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc |
|
for it, which is enough information to reconstruct a meta tensor/fake tensor |
|
corresponding to a Tensor as faithfully as possible. |
|
|
|
This is a stateful conversion object because we keep track of the IDs |
|
of the tensors/storages passed to us, so we can consistently give |
|
the same ID when we see the same tensor/storage. |
|
""" |
|
|
|
def __init__(self, *, copy_data=False): |
|
global DESCRIBER_NEXT_ID |
|
self.id = DESCRIBER_NEXT_ID |
|
DESCRIBER_NEXT_ID += 1 |
|
self.next_tensor_id: MetaTensorId = 0 |
|
self.next_storage_id: MetaStorageId = 0 |
|
|
|
self.lookup_tensor = WeakIdKeyDictionary() |
|
|
|
self.lookup_storage = WeakIdKeyDictionary() |
|
self.copy_data = copy_data |
|
self.traced_tensors = set() |
|
self.traced_storages = set() |
|
|
|
def get_tensor_id(self, t: torch.Tensor): |
|
if t not in self.lookup_tensor: |
|
self.lookup_tensor[t] = self.next_tensor_id |
|
self.next_tensor_id += 1 |
|
return self.lookup_tensor[t] |
|
|
|
def get_storage_id(self, s: torch.UntypedStorage): |
|
if s not in self.lookup_storage: |
|
self.lookup_storage[s] = self.next_storage_id |
|
self.next_storage_id += 1 |
|
return self.lookup_storage[s] |
|
|
|
def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False): |
|
r = MetaStorageDesc( |
|
id=self.get_storage_id(s), |
|
size=s.size(), |
|
|
|
|
|
data=s if self.copy_data else None, |
|
) |
|
if trace and r.id not in self.traced_storages: |
|
trace_structured( |
|
"describe_storage", |
|
metadata_fn=lambda: r.as_json(self.id), |
|
) |
|
self.traced_storages.add(r.id) |
|
return r |
|
|
|
def describe_tensor( |
|
self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False |
|
): |
|
is_leaf = safe_is_leaf(t) |
|
is_view = t._is_view() |
|
is_sparse = t.is_sparse |
|
layout = t.layout |
|
is_nested = t.is_nested |
|
is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t) |
|
is_functorch_wrapped = is_functorch_wrapped_tensor(t) |
|
is_mkldnn = t.is_mkldnn |
|
is_batchedtensor_v = is_batchedtensor(t) |
|
is_legacy_batchedtensor_v = is_legacy_batchedtensor(t) |
|
is_gradtrackingtensor_v = is_gradtrackingtensor(t) |
|
is_functorch_batched_or_grad = is_batchedtensor_v or is_gradtrackingtensor_v |
|
is_functional = torch._is_functional_tensor(t) |
|
|
|
storage = None |
|
|
|
|
|
|
|
storage_offset = 0 |
|
if not ( |
|
is_sparse |
|
or is_sparse_compressed_layout(layout) |
|
or (is_nested and not is_traceable_wrapper_subclass_v) |
|
or is_mkldnn |
|
|
|
|
|
or is_functorch_wrapped |
|
or is_legacy_batchedtensor_v |
|
): |
|
|
|
|
|
storage = self.describe_storage(t.untyped_storage(), trace=trace) |
|
storage_offset = t.storage_offset() |
|
|
|
stride = None |
|
if not ( |
|
is_sparse |
|
or is_sparse_compressed_layout(layout) |
|
or (is_nested and not is_traceable_wrapper_subclass_v) |
|
): |
|
|
|
|
|
|
|
stride = t.stride() |
|
|
|
|
|
|
|
|
|
unwrapped = None |
|
autograd_meta_from = None |
|
current_level = None |
|
if is_batchedtensor_v or is_gradtrackingtensor_v: |
|
unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace) |
|
|
|
|
|
elif is_functional and t.device.type not in ("xla", "lazy"): |
|
if t._is_view(): |
|
raise RuntimeError( |
|
"Cannot safely fakify a view because this process drops the view information right now." |
|
) |
|
if not is_functorch_wrapped: |
|
torch._sync(t) |
|
unwrapped = self.describe_tensor( |
|
torch._from_functional_tensor(t), trace=trace |
|
) |
|
autograd_meta_from = t |
|
else: |
|
reapply_views = torch._C._functionalization_reapply_views_tls() |
|
|
|
unwrapped = self.describe_tensor( |
|
_unwrap_functional_tensor(t, reapply_views), trace=trace |
|
) |
|
|
|
|
|
|
|
current_level = torch._C._functorch.current_level() |
|
|
|
maybe_functorch_stack = None |
|
if is_functorch_wrapped: |
|
with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack: |
|
pass |
|
|
|
attrs = None |
|
ctx = None |
|
type_v = None |
|
if is_traceable_wrapper_subclass_v: |
|
assert hasattr(t, "__tensor_flatten__") |
|
raw_attrs, ctx = t.__tensor_flatten__() |
|
attrs = { |
|
attr: self.describe_tensor(getattr(t, attr), trace=trace) |
|
for attr in raw_attrs |
|
} |
|
type_v = type(t) |
|
|
|
|
|
|
|
r = MetaTensorDesc( |
|
id=self.get_tensor_id(t), |
|
storage=storage, |
|
is_inference=t.is_inference(), |
|
is_leaf=is_leaf, |
|
requires_grad=t.requires_grad, |
|
|
|
|
|
|
|
|
|
|
|
|
|
ndim=t.dim(), |
|
dtype=t.dtype, |
|
is_sparse=is_sparse, |
|
is_mkldnn=is_mkldnn, |
|
is_functorch_wrapped=is_functorch_wrapped, |
|
is_batchedtensor=is_batchedtensor_v, |
|
is_legacy_batchedtensor=is_legacy_batchedtensor_v, |
|
is_gradtrackingtensor=is_gradtrackingtensor_v, |
|
is_view=is_view, |
|
is_conj=t.is_conj(), |
|
is_neg=t.is_neg(), |
|
is_parameter=isinstance(t, torch.nn.Parameter), |
|
is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v, |
|
is_nested=is_nested, |
|
is_functional=is_functional, |
|
layout=layout, |
|
device=t.device, |
|
size=t.size(), |
|
stride=stride, |
|
storage_offset=storage_offset, |
|
dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())), |
|
sparse_dim=t.sparse_dim() |
|
if t.is_sparse or is_sparse_compressed(t) |
|
else None, |
|
dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None, |
|
is_coalesced=t.is_coalesced() if t.is_sparse else None, |
|
|
|
|
|
|
|
crow_indices=self.describe_tensor( |
|
t.crow_indices(), recurse=False, trace=trace |
|
) |
|
if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} |
|
else None, |
|
col_indices=self.describe_tensor( |
|
t.col_indices(), recurse=False, trace=trace |
|
) |
|
if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} |
|
else None, |
|
ccol_indices=self.describe_tensor( |
|
t.ccol_indices(), recurse=False, trace=trace |
|
) |
|
if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} |
|
else None, |
|
row_indices=self.describe_tensor( |
|
t.row_indices(), recurse=False, trace=trace |
|
) |
|
if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} |
|
else None, |
|
values=self.describe_tensor(t.values(), recurse=False, trace=trace) |
|
if recurse and is_sparse_compressed(t) |
|
else None, |
|
grad=self.describe_tensor(safe_grad(t), trace=trace) |
|
if safe_grad(t) is not None |
|
else None, |
|
creation_meta=torch._C._autograd._get_creation_meta(t) |
|
if t._is_view() |
|
else None, |
|
unwrapped=unwrapped, |
|
level=maybe_get_level(t) |
|
if is_batchedtensor_v or is_gradtrackingtensor_v |
|
else None, |
|
bdim=maybe_get_bdim(t) if is_batchedtensor_v else None, |
|
base=self.describe_tensor(t._base, trace=trace) |
|
if recurse and t._is_view() and t._base is not None |
|
else None, |
|
fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t), |
|
view_func=t._view_func_unsafe, |
|
attrs=attrs, |
|
ctx=ctx, |
|
type=type_v, |
|
|
|
|
|
|
|
functorch_stack=maybe_functorch_stack, |
|
autograd_meta_from=autograd_meta_from, |
|
current_level=current_level, |
|
data=t if self.copy_data else None, |
|
) |
|
if trace and r.id not in self.traced_tensors: |
|
trace_structured( |
|
"describe_tensor", |
|
metadata_fn=lambda: r.as_json(self.id), |
|
) |
|
self.traced_tensors.add(r.id) |
|
return r |
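    # Views are described recursively (illustrative sketch, hypothetical
    # tensor):
    #
    #   v = torch.randn(4)[1:3]  # a strided view
    #   desc = MetaTensorDescriber().describe_tensor(v)
    #   assert desc.is_view and desc.base is not None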
|
|
|
|
|
@dataclass(frozen=True) |
|
class MetaStorageDesc: |
|
id: MetaStorageId |
|
size: int |
|
|
|
|
|
data: Optional[torch.UntypedStorage] |
|
|
|
def as_json(self, describer_id): |
|
return { |
|
"id": self.id, |
|
"describer_id": describer_id, |
|
"size": self.size if isinstance(self.size, int) else repr(self.size), |
|
} |
|
|
|
|
|
@dataclass(frozen=True) |
|
class MetaTensorDesc: |
|
id: MetaTensorId |
|
ndim: int |
|
dtype: torch.dtype |
|
device: torch.device |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
size: Tuple[int, ...] |
|
dynamo_dynamic_indices: List[int] |
|
|
|
layout: torch.layout = torch.strided |
|
is_inference: bool = False |
|
is_leaf: bool = False |
|
requires_grad: bool = False |
|
is_sparse: bool = False |
|
is_mkldnn: bool = False |
|
is_functorch_wrapped: bool = False |
|
is_batchedtensor: bool = False |
|
is_legacy_batchedtensor: bool = False |
|
is_gradtrackingtensor: bool = False |
|
is_view: bool = False |
|
is_nested: bool = False |
|
is_traceable_wrapper_subclass: bool = False |
|
is_functional: bool = False |
|
is_conj: bool = False |
|
is_neg: bool = False |
|
is_parameter: bool = False |
|
stride: Optional[Tuple[int, ...]] = None |
|
storage_offset: int = 0 |
|
|
|
|
|
|
|
|
|
storage: Optional[MetaStorageDesc] = None |
|
sparse_dim: Optional[int] = None |
|
dense_dim: Optional[int] = None |
|
is_coalesced: Optional[bool] = None |
|
crow_indices: Optional[MetaTensorDesc] = None |
|
col_indices: Optional[MetaTensorDesc] = None |
|
ccol_indices: Optional[MetaTensorDesc] = None |
|
row_indices: Optional[MetaTensorDesc] = None |
|
values: Optional[MetaTensorDesc] = None |
|
unwrapped: Optional[MetaTensorDesc] = None |
|
bdim: Optional[int] = None |
|
base: Optional[MetaTensorDesc] = None |
|
attrs: Optional[Dict[str, MetaTensorDesc]] = None |
|
creation_meta: Optional[CreationMeta] = None |
|
grad: Optional[MetaTensorDesc] = None |
|
|
|
|
|
|
|
_UNSERIALIZABLE: ClassVar[List[str]] = [ |
|
"ctx", |
|
"type", |
|
"fake_mode", |
|
"view_func", |
|
"level", |
|
"current_level", |
|
"functorch_stack", |
|
"autograd_meta_from", |
|
"data", |
|
] |
|
|
|
ctx: Optional[object] = None |
|
type: Optional[Type] = None |
|
fake_mode: Optional[FakeTensorMode] = None |
|
view_func: Optional[ |
|
Callable[ |
|
[ |
|
torch.Tensor, |
|
Callable[[int], int], |
|
Callable[[torch.Tensor], torch.Tensor], |
|
], |
|
torch.Tensor, |
|
] |
|
] = None |
|
|
|
|
|
level: Optional[int] = None |
|
current_level: Optional[int] = None |
|
functorch_stack: Optional[List[CInterpreter]] = None |
|
autograd_meta_from: Optional[torch.Tensor] = None |
|
|
|
|
|
|
|
|
|
data: Optional[torch.Tensor] = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def as_json(self, describer_id): |
|
def json(k, v): |
|
|
|
|
|
if k in ["data", "autograd_meta_from"]: |
|
return None |
|
if k in set(MetaTensorDesc._UNSERIALIZABLE): |
|
return repr(v) |
|
if isinstance(v, (torch.device, torch.dtype, torch.layout)): |
|
return repr(v) |
|
if isinstance(v, torch.SymInt): |
|
return repr(v) |
|
if isinstance(v, (tuple, list)): |
|
return [json(k, v1) for v1 in v] |
|
if isinstance(v, (MetaStorageDesc, MetaTensorDesc)): |
|
return v.id |
|
if isinstance(v, CreationMeta): |
|
return str(v) |
|
if k == "attrs" and isinstance(v, dict): |
|
return {k1: v1.id for k1, v1 in v.items()} |
|
return v |
|
|
|
r = { |
|
field.name: json(field.name, getattr(self, field.name)) |
|
for field in dataclasses.fields(self) |
|
if not ( |
|
getattr(self, field.name) is field.default |
|
or ( |
|
field.name == "dynamo_dynamic_indices" |
|
and not getattr(self, field.name) |
|
) |
|
) |
|
} |
|
r.update({"describer_id": describer_id}) |
|
return r |
|
|
|
@property |
|
def shape(self): |
|
return self.size |
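    # Field-access sketch (illustrative; hypothetical tensor):
    #
    #   desc = MetaTensorDescriber().describe_tensor(torch.zeros(2, 3))
    #   desc.shape   # torch.Size([2, 3]) -- alias of desc.size
    #   desc.stride  # (3, 1)
    #   desc.dtype, desc.device, desc.requires_grad, desc.is_leaf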
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copy/clone helpers used when copy_data is set.  They deliberately no-op for
# anything that is not a plain torch.Tensor (e.g. wrapper subclasses), since
# such objects may not carry plain data that can be copied directly.
def _safe_copy(dst, src):
    if type(src) is not torch.Tensor:
        return
    dst.copy_(src)
|
|
|
|
|
def _safe_clone(src): |
|
if type(src) is not torch.Tensor: |
|
return None |
|
return src.clone() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MetaConverter: |
|
def __init__(self, *, copy_data: bool = False): |
|
|
|
self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary() |
|
|
|
|
|
self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary() |
|
self.hit = 0 |
|
self.miss = 0 |
|
self.del_hook = None |
|
self.arg_cnt = 0 |
|
|
|
|
|
|
|
|
|
self.copy_data = copy_data |
|
self.describer = MetaTensorDescriber(copy_data=copy_data) |
|
|
|
def successful(self): |
|
return self.hit > 0 and self.miss == 0 |
|
|
|
def get_tensor_memo(self, t: MetaTensorDesc): |
|
return self.tensor_memo.get(t.id, None) |
|
|
|
def set_tensor_memo(self, t: MetaTensorDesc, v): |
|
self.tensor_memo[t.id] = v |
|
|
|
def get_storage_memo(self, s: MetaStorageDesc): |
|
return self.storage_memo.get(s.id, None) |
|
|
|
def set_storage_memo(self, s: MetaStorageDesc, v): |
|
self.storage_memo[s.id] = v |
|
|
|
    def meta_storage(self, s: MetaStorageDesc, callback):
        # Storages are memoized by their describer-assigned id, so tensors that
        # alias the same real storage end up sharing a single meta storage.
|
|
|
|
|
if self.get_storage_memo(s) is None: |
|
r_s = callback( |
|
lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"), |
|
).untyped_storage() |
|
if self.copy_data: |
|
|
|
|
|
with torch.no_grad(), no_dispatch(): |
|
assert s.data is not None |
|
r_s.real_storage = s.data.clone() |
|
self.set_storage_memo(s, r_s) |
|
return r_s |
|
else: |
|
return self.get_storage_memo(s) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def meta_tensor( |
|
self, |
|
t: MetaTensorDesc, |
|
shape_env: Optional[ShapeEnv] = None, |
|
callback=lambda t: t(), |
|
source: Optional[Source] = None, |
|
symbolic_context: Optional[SymbolicContext] = None, |
|
): |
|
if source is None: |
|
from torch._dynamo.source import ConstantSource |
|
|
|
|
|
source = ConstantSource( |
|
f"__meta_utils_unknown_tensor{len(self.tensor_memo)}" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
assert not torch._C._dispatch_tls_local_exclude_set().has( |
|
torch._C.DispatchKey.Python |
|
) |
|
arg_cnt = self.arg_cnt |
|
self.arg_cnt += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
maybe_suppress: Callable[[], Any] = contextlib.nullcontext |
|
if shape_env is not None: |
|
maybe_suppress = shape_env.suppress_guards |
|
|
|
def sym_sizes_strides_storage_offset( |
|
t: MetaTensorDesc, src, symbolic_context=symbolic_context |
|
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: |
|
assert t.stride is not None |
|
if shape_env is not None: |
|
fake_mode = t.fake_mode |
|
if fake_mode is not None and fake_mode.shape_env is shape_env: |
|
|
|
|
|
return (t.size, t.stride, t.storage_offset) |
|
else: |
|
|
|
t_size = tuple( |
|
shape_env._maybe_specialize_sym_int_with_hint(sz) |
|
for sz in t.size |
|
) |
|
t_stride = tuple( |
|
shape_env._maybe_specialize_sym_int_with_hint(sd) |
|
for sd in t.stride |
|
) |
|
t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint( |
|
t.storage_offset |
|
) |
|
return shape_env._create_symbolic_sizes_strides_storage_offset( |
|
t_size, |
|
t_stride, |
|
t_storage_offset, |
|
[d in t.dynamo_dynamic_indices for d in range(t.ndim)], |
|
src, |
|
symbolic_context=symbolic_context, |
|
) |
|
else: |
|
return (t.size, t.stride, t.storage_offset) |
|
|
|
def empty_create( |
|
inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context |
|
): |
|
( |
|
inner_sizes, |
|
inner_strides, |
|
inner_storage_offset, |
|
) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context) |
|
return torch.empty_strided( |
|
inner_sizes, |
|
inner_strides, |
|
dtype=inner_t.dtype, |
|
device="meta", |
|
) |
|
|
|
|
|
|
|
def empty_create_subclass( |
|
t: MetaTensorDesc, |
|
outer_size, |
|
outer_stride, |
|
symbolic_context=symbolic_context, |
|
callback=callback, |
|
source=source, |
|
): |
|
from torch._dynamo.source import AttrSource |
|
from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext |
|
|
|
assert t.attrs is not None |
|
assert t.type is not None |
|
|
|
|
|
|
|
assert symbolic_context is None or isinstance( |
|
symbolic_context, SubclassSymbolicContext |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outer_size = outer_size if outer_size is not None else t.size |
|
outer_stride = outer_stride if outer_stride is not None else t.stride |
|
|
|
def transform(attr, inner_t): |
|
r = callback( |
|
lambda: empty_create( |
|
inner_t, |
|
AttrSource(source, attr), |
|
symbolic_context=( |
|
None |
|
if symbolic_context is None |
|
else symbolic_context.inner_contexts[attr] |
|
), |
|
) |
|
) |
|
if self.copy_data: |
|
with torch.no_grad(), no_dispatch(): |
|
r.real_tensor = torch.empty_strided( |
|
inner_t.size, |
|
inner_t.stride, |
|
dtype=inner_t.dtype, |
|
device=inner_t.device, |
|
) |
|
assert inner_t.data is not None |
|
_safe_copy(r.real_tensor, inner_t.data) |
|
return r |
|
|
|
transformed_tensors_dict = { |
|
attr: transform(attr, inner_t) for attr, inner_t in t.attrs.items() |
|
} |
|
|
|
sub = t.type.__tensor_unflatten__( |
|
transformed_tensors_dict, t.ctx, outer_size, outer_stride |
|
) |
|
|
|
|
|
|
|
|
|
assert sub.shape == outer_size, ( |
|
f"Expected return value from {t.type}__tensor_unflatten__() to have " |
|
f"shape equal to {outer_size}, but got: {sub.shape}" |
|
) |
|
assert sub.stride() == outer_stride, ( |
|
f"Expected return value from {t.type}__tensor_unflatten__() to have " |
|
f"stride equal to {outer_stride}, but got: {sub.stride()}" |
|
) |
|
|
|
return sub |
|
|
|
|
|
|
|
|
|
|
|
        # Create a SymbolicContext that marks every dimension as dynamic.  This
        # is used for tensors discovered during view replay (see
        # tensor_visitor_fn below), which have no user-provided context.
        def all_dynamic_symbolic_context(
            t: MetaTensorDesc, source, shape_env, callback
        ):
|
from torch._dynamo.source import AttrSource |
|
from torch.fx.experimental.symbolic_shapes import ( |
|
DimDynamic, |
|
StatelessSymbolicContext, |
|
SubclassSymbolicContext, |
|
) |
|
|
|
view_base_context: Optional[SymbolicContext] = None |
|
if t.is_view: |
|
assert t.base is not None |
|
view_base_context = all_dynamic_symbolic_context( |
|
t.base, AttrSource(source, "_base"), shape_env, callback |
|
) |
|
|
|
t_symbolic_context: SymbolicContext |
|
t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim |
|
if t.is_traceable_wrapper_subclass: |
|
assert t.attrs is not None |
|
inner_contexts: Dict[str, SymbolicContext] = {} |
|
for attr, inner in t.attrs.items(): |
|
assert isinstance(attr, str) |
|
inner_contexts[attr] = all_dynamic_symbolic_context( |
|
inner, AttrSource(source, attr), shape_env, callback |
|
) |
|
t_symbolic_context = SubclassSymbolicContext( |
|
dynamic_sizes=t_dynamic_sizes, |
|
constraint_sizes=[None] * t.ndim, |
|
inner_contexts=inner_contexts, |
|
tensor_source=source, |
|
view_base_context=view_base_context, |
|
) |
|
else: |
|
t_symbolic_context = StatelessSymbolicContext( |
|
dynamic_sizes=t_dynamic_sizes, |
|
constraint_sizes=[None] * t.ndim, |
|
view_base_context=view_base_context, |
|
) |
|
|
|
return t_symbolic_context |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def view_from_base( |
|
base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env |
|
): |
|
|
|
(sizes, strides, storage_offset) = sym_sizes_strides_storage_offset( |
|
t, source |
|
) |
|
if ( |
|
not t.is_traceable_wrapper_subclass |
|
and not is_traceable_wrapper_subclass(base) |
|
): |
|
|
|
|
|
|
|
with maybe_suppress(): |
|
return base.as_strided(sizes, strides, storage_offset) |
|
|
|
from torch._dynamo.source import EphemeralSource |
|
from torch.fx.experimental.symbolic_shapes import ( |
|
StatelessSymbolicContext, |
|
sym_eq, |
|
) |
|
|
|
def symint_visitor_fn(s): |
|
nonlocal symbolic_context |
|
from torch.fx.experimental.symbolic_shapes import DimDynamic |
|
|
|
all_static_sizes = ( |
|
symbolic_context is not None |
|
and isinstance(symbolic_context, StatelessSymbolicContext) |
|
and all( |
|
x is DimDynamic.STATIC for x in symbolic_context.dynamic_sizes |
|
) |
|
) |
|
|
|
if all_static_sizes or shape_env is None: |
|
return s |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sym_source = EphemeralSource("symint_visitor_fn") |
|
symbol = shape_env.create_symbol(s, sym_source) |
|
return shape_env.create_symintnode(symbol, hint=s, source=sym_source) |
|
|
|
real_to_fake_mapping = {} |
|
if t.is_traceable_wrapper_subclass: |
|
assert t.attrs is not None |
|
|
|
|
|
assert t.type is not None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fake_t = empty_create_subclass( |
|
t, outer_size=sizes, outer_stride=strides |
|
) |
|
attrs, _ = fake_t.__tensor_flatten__() |
|
for attr in attrs: |
|
real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr) |
|
|
|
def tensor_visitor_fn( |
|
visited_t: torch.Tensor, |
|
|
|
|
|
shape_env=shape_env, |
|
callback=callback, |
|
): |
|
|
|
if visited_t is None: |
|
return None |
|
|
|
|
|
|
|
|
|
|
|
visited_id = self.describer.get_tensor_id(visited_t) |
|
fake_visited_t = real_to_fake_mapping.get(visited_id, None) |
|
if fake_visited_t is not None: |
|
return fake_visited_t |
|
|
|
visited_desc = self.describer.describe_tensor(visited_t) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
temp_source = EphemeralSource("tensor_visitor_fn") |
|
return self.meta_tensor( |
|
visited_desc, |
|
shape_env, |
|
callback, |
|
source=temp_source, |
|
symbolic_context=all_dynamic_symbolic_context( |
|
visited_desc, temp_source, shape_env, callback |
|
), |
|
) |
|
|
|
|
|
|
|
assert t.view_func is not None |
|
|
|
|
|
fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn) |
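            # Sanity-check that the replayed view matches the symbolic
            # sizes/strides/storage offset computed above.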
|
|
|
|
|
|
|
|
|
torch._check(sym_eq(fake_t.size(), sizes)) |
|
torch._check(sym_eq(fake_t.stride(), strides)) |
|
torch._check(sym_eq(fake_t.storage_offset(), storage_offset)) |
|
return fake_t |
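        # Main construction path: build the fake tensor for this descriptor at
        # most once, dispatching on layout / wrapper type below, and memoize
        # the result (see set_tensor_memo at the end of this method).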
|
|
|
if self.get_tensor_memo(t) is None: |
|
GRAD_TENSOR_SENTINEL_VALUE = -2 |
|
|
|
with torch.inference_mode(t.is_inference): |
|
if t.is_sparse: |
|
is_leaf = t.is_leaf |
|
|
|
|
|
|
|
|
|
r = callback( |
|
lambda: torch.ops.aten._sparse_coo_tensor_with_dims( |
|
t.sparse_dim, |
|
t.dense_dim, |
|
t.size, |
|
dtype=t.dtype, |
|
layout=torch.sparse_coo, |
|
device="meta", |
|
) |
|
) |
|
if self.copy_data: |
|
|
|
assert t.data is not None |
|
with torch.no_grad(), no_dispatch(): |
|
r.real_tensor = _safe_clone(t.data) |
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
|
|
|
|
|
|
|
|
|
|
|
r._coalesced_(t.is_coalesced) |
|
if t.requires_grad: |
|
r.requires_grad = True |
|
if t.requires_grad and not is_leaf: |
|
|
|
|
|
|
|
|
|
r = r.clone() |
|
with torch.enable_grad(): |
|
r._coalesced_(t.is_coalesced) |
|
elif is_sparse_compressed_layout(t.layout): |
|
is_leaf = t.is_leaf |
|
|
|
if t.layout in {torch.sparse_bsr, torch.sparse_bsc}: |
|
assert t.sparse_dim is not None |
|
assert t.dense_dim is not None |
|
assert t.values is not None |
|
batch_dim = t.ndim - t.sparse_dim - t.dense_dim |
|
blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3] |
|
else: |
|
blocksize = () |
|
if t.layout in {torch.sparse_csr, torch.sparse_bsr}: |
|
assert t.crow_indices is not None |
|
index_dtype = t.crow_indices.dtype |
|
else: |
|
assert t.ccol_indices is not None |
|
index_dtype = t.ccol_indices.dtype |
|
|
|
r = callback( |
|
lambda: torch.ops.aten._sparse_compressed_tensor_with_dims( |
|
0, |
|
t.dense_dim, |
|
t.shape, |
|
blocksize, |
|
index_dtype, |
|
layout=t.layout, |
|
dtype=t.dtype, |
|
device="meta", |
|
) |
|
) |
|
if self.copy_data: |
|
|
|
assert t.data is not None |
|
with torch.no_grad(), no_dispatch(): |
|
r.real_tensor = _safe_clone(t.data) |
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
|
if t.requires_grad: |
|
r.requires_grad = True |
|
if t.requires_grad and not is_leaf: |
|
r = torch._C._functions.DelayedError( |
|
"Internal error: Tried to backward() through example input", |
|
1, |
|
)(r) |
|
elif t.is_nested and not t.is_traceable_wrapper_subclass: |
|
|
|
|
|
|
|
from torch._dynamo.exc import unimplemented |
|
|
|
unimplemented( |
|
"strided nested tensors are not supported by meta conversion" |
|
) |
|
elif t.is_mkldnn: |
|
is_leaf = t.is_leaf |
|
sizes, strides, _storage_offset = sym_sizes_strides_storage_offset( |
|
t, source |
|
) |
|
|
|
|
|
r = callback( |
|
lambda: torch.empty_strided( |
|
sizes, strides, dtype=t.dtype, device="meta" |
|
) |
|
) |
|
if self.copy_data: |
|
with torch.no_grad(), no_dispatch(): |
|
assert t.size is not None |
|
assert t.stride is not None |
|
r.real_tensor = torch.empty_strided( |
|
t.size, t.stride, dtype=t.dtype, device=t.device |
|
) |
|
assert t.data is not None |
|
_safe_copy(r.real_tensor, t.data) |
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
|
if t.requires_grad: |
|
r.requires_grad = True |
|
if t.requires_grad and not is_leaf: |
|
r = torch._C._functions.DelayedError( |
|
"Internal error: Tried to backward() through example input", |
|
1, |
|
)(r) |
|
elif t.is_functorch_wrapped: |
|
if t.is_view: |
|
from torch._dynamo.exc import unimplemented |
|
|
|
unimplemented( |
|
"view functorch tensors are not supported by meta conversion" |
|
) |
|
|
|
|
|
|
|
def _to_fake_tensor(t: MetaTensorDesc): |
|
|
|
|
|
if t.is_batchedtensor: |
|
assert t.unwrapped is not None |
|
assert t.level is not None |
|
assert t.bdim is not None |
|
ft = _to_fake_tensor(t.unwrapped) |
|
lvl = t.level |
|
bdim = t.bdim |
|
|
|
|
|
|
|
|
|
with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack( |
|
t.functorch_stack |
|
): |
|
r = _add_batch_dim(ft, bdim, lvl) |
|
elif t.is_gradtrackingtensor: |
|
assert t.unwrapped is not None |
|
assert t.level is not None |
|
disable_functorch = torch._C._DisableFuncTorch |
|
with disable_functorch(): |
|
ft = _to_fake_tensor(t.unwrapped) |
|
lvl = t.level |
|
if lvl == GRAD_TENSOR_SENTINEL_VALUE: |
|
r = ft |
|
else: |
|
with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack( |
|
t.functorch_stack |
|
): |
|
r = torch._C._functorch._wrap_for_grad(ft, lvl) |
|
|
|
is_leaf = t.is_leaf |
|
if t.requires_grad and safe_is_leaf(r): |
|
r.requires_grad = True |
|
elif t.requires_grad and not is_leaf: |
|
r = torch._C._functions.DelayedError( |
|
"Internal error: Tried to backward() through example input", |
|
1, |
|
)( |
|
r |
|
) |
|
elif t.is_functional: |
|
assert t.unwrapped is not None |
|
assert t.current_level is not None |
|
ft = self.meta_tensor( |
|
t.unwrapped, |
|
shape_env=shape_env, |
|
callback=callback, |
|
|
|
|
|
|
|
|
|
source=source, |
|
symbolic_context=symbolic_context, |
|
) |
|
r = _wrap_functional_tensor(ft, t.current_level) |
|
|
|
else: |
|
assert t.stride is not None |
|
|
|
sizes = t.size |
|
strides = t.stride |
|
r = callback( |
|
lambda: torch.empty_strided( |
|
sizes, |
|
strides, |
|
dtype=t.dtype, |
|
device="meta", |
|
) |
|
) |
|
if self.copy_data: |
|
with torch.no_grad(), no_dispatch(): |
|
r.real_tensor = torch.empty_strided( |
|
t.size, |
|
t.stride, |
|
dtype=t.dtype, |
|
device=t.device, |
|
) |
|
assert t.data is not None |
|
_safe_copy(r.real_tensor, t.data) |
|
return r |
|
|
|
r = _to_fake_tensor(t) |
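                    # r is now the fake tensor with the original functorch
                    # wrappers (batch dims, grad tracking, functionalization)
                    # re-applied around a fake base tensor.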
|
|
|
elif t.is_functional and t.device.type not in ["xla", "lazy"]: |
|
assert t.unwrapped is not None |
|
assert not t.is_functorch_wrapped |
|
unwrapped = self.meta_tensor( |
|
t.unwrapped, |
|
shape_env=shape_env, |
|
callback=callback, |
|
source=source, |
|
symbolic_context=symbolic_context, |
|
) |
|
r = torch._to_functional_tensor(unwrapped) |
|
torch._mirror_autograd_meta_to(t.autograd_meta_from, r) |
|
|
|
elif t.is_view: |
|
|
|
|
|
|
|
|
|
|
|
assert t.base is not None |
|
|
|
base_symbolic_context = None |
|
if shape_env and symbolic_context is not None: |
|
from torch.fx.experimental.symbolic_shapes import ( |
|
StatelessSymbolicContext, |
|
) |
|
|
|
assert isinstance(symbolic_context, StatelessSymbolicContext) |
|
|
|
|
|
|
|
if symbolic_context.view_base_context is not None: |
|
base_symbolic_context = symbolic_context.view_base_context |
|
|
|
base = self.meta_tensor( |
|
t.base, |
|
shape_env, |
|
callback, |
|
source=torch._dynamo.source.AttrSource(source, "_base"), |
|
symbolic_context=base_symbolic_context, |
|
) |
|
|
|
def is_c_of_r(complex_dtype, real_dtype): |
|
return ( |
|
utils.is_complex_dtype(complex_dtype) |
|
and utils.corresponding_real_dtype(complex_dtype) |
|
== real_dtype |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded( |
|
torch._C.DispatchKey.ADInplaceOrView |
|
) |
|
torch._C._dispatch_tls_set_dispatch_key_excluded( |
|
torch._C.DispatchKey.ADInplaceOrView, False |
|
) |
|
try: |
|
if base.dtype == t.dtype: |
|
pass |
|
elif is_c_of_r(base.dtype, t.dtype): |
|
base = torch.view_as_real(base) |
|
elif is_c_of_r(t.dtype, base.dtype): |
|
base = torch.view_as_complex(base) |
|
else: |
|
|
|
|
|
|
|
base = base.view(t.dtype) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if t.is_leaf: |
|
|
|
|
|
with torch.no_grad(): |
|
r = view_from_base(base, t) |
|
|
|
r.requires_grad = t.requires_grad |
|
else: |
|
if t.base.requires_grad == t.requires_grad: |
|
|
|
with torch.enable_grad(): |
|
r = view_from_base(base, t) |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
|
|
|
|
|
|
assert t.requires_grad |
|
with torch.no_grad(): |
|
mid = base.view(base.shape) |
|
mid.requires_grad = t.requires_grad |
|
with torch.enable_grad(): |
|
r = view_from_base(mid, t) |
|
|
|
|
|
|
|
assert t.creation_meta is not None |
|
torch._C._autograd._set_creation_meta(r, t.creation_meta) |
|
finally: |
|
torch._C._dispatch_tls_set_dispatch_key_excluded( |
|
torch._C.DispatchKey.ADInplaceOrView, old_exclude |
|
) |
|
|
|
else: |
|
is_leaf = t.is_leaf |
|
|
|
|
|
if ( |
|
not (t.is_batchedtensor or t.is_gradtrackingtensor) |
|
and t.is_functorch_wrapped |
|
) or t.is_legacy_batchedtensor: |
|
return NotImplemented |
|
|
|
( |
|
sizes, |
|
strides, |
|
storage_offset, |
|
) = sym_sizes_strides_storage_offset(t, source, symbolic_context) |
|
|
|
|
|
|
|
if t.is_traceable_wrapper_subclass: |
|
r = empty_create_subclass( |
|
t, outer_size=sizes, outer_stride=strides |
|
) |
|
else: |
|
r = callback( |
|
lambda: torch.empty_strided( |
|
sizes, |
|
strides, |
|
dtype=t.dtype, |
|
device="meta", |
|
) |
|
) |
|
if self.copy_data: |
|
with torch.no_grad(), no_dispatch(): |
|
assert t.size is not None |
|
assert t.stride is not None |
|
r.real_tensor = torch.empty_strided( |
|
t.size, t.stride, dtype=t.dtype, device=t.device |
|
) |
|
_safe_copy(r.real_tensor, t.data) |
|
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
|
if t.requires_grad: |
|
r.requires_grad = t.requires_grad |
|
if not is_leaf: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r = torch._C._functions.DelayedError( |
|
"Internal error: Tried to backward() through example input", |
|
1, |
|
)(r) |
|
|
|
s = t.storage |
|
assert s is not None |
|
if s.id not in self.storage_memo and ( |
|
r.is_nested |
|
or ( |
|
r.stride() == strides |
|
and r.storage_offset() == storage_offset |
|
) |
|
): |
|
|
|
self.set_storage_memo(s, r.untyped_storage()) |
|
if self.copy_data: |
|
r.untyped_storage().real_storage = ( |
|
r.real_tensor.untyped_storage() |
|
) |
|
else: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r_s = self.meta_storage(s, callback=callback) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext() |
|
from torch._subclasses.fake_tensor import ( |
|
in_kernel_invocation_manager, |
|
maybe_get_fake_mode, |
|
) |
|
|
|
mb_fake_mode = maybe_get_fake_mode(r) |
|
if mb_fake_mode is not None: |
|
maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode) |
|
with torch.no_grad(), maybe_suppress(): |
|
with maybe_fake_mgr: |
|
r.set_(r_s, storage_offset, sizes, strides) |
|
if self.copy_data: |
|
with torch.no_grad(), no_dispatch(): |
|
r.real_tensor.set_( |
|
r_s.real_storage, |
|
t.storage_offset, |
|
t.size, |
|
t.stride, |
|
) |
|
|
|
if t.grad is not None: |
|
from torch._dynamo.source import AttrSource |
|
|
|
|
|
|
|
r.grad = self.meta_tensor( |
|
t.grad, |
|
shape_env, |
|
callback, |
|
source=AttrSource(source, "grad"), |
|
symbolic_context=symbolic_context, |
|
) |
|
torch._C._set_conj(r, t.is_conj) |
|
torch._C._set_neg(r, t.is_neg) |
|
|
|
skip_leaf = ( |
|
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE |
|
) |
|
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf) |
|
|
|
|
|
|
|
if t.storage is not None and t.storage.size == 0: |
|
r.untyped_storage().resize_(0) |
|
|
|
if t.is_parameter: |
|
r._is_param = True |
|
|
|
self.set_tensor_memo(t, r) |
|
|
|
return self.get_tensor_memo(t) |
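    # __call__ below is the public entry point: it describes `t`, filters out
    # unsupported inputs, suspends functionalization and the functorch
    # interpreter stack as needed, and then defers to meta_tensor() above.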
|
|
|
def __call__( |
|
self, |
|
t, |
|
shape_env=None, |
|
*, |
|
callback=lambda t: t(), |
|
source=None, |
|
symbolic_context=None, |
|
|
|
|
|
|
|
trace=True, |
|
): |
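        # Filter out inputs we don't support: unsupported tensors (and
        # tensor-likes) return NotImplemented; other objects are passed
        # through unchanged.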
|
|
|
|
|
|
|
|
|
|
|
if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t): |
|
if ( |
|
|
|
|
|
|
|
|
|
t.device.type == "lazy" |
|
or |
|
|
|
t.is_quantized |
|
or |
|
|
|
|
|
(t._is_view() and t._base is not None and t._base.is_sparse) |
|
): |
|
self.miss += 1 |
|
return NotImplemented |
|
else: |
|
self.hit += 1 |
|
elif torch.overrides.is_tensor_like(t): |
|
self.miss += 1 |
|
return NotImplemented |
|
else: |
|
|
|
return t |
|
|
|
if source is None: |
|
trace = False |
|
|
|
|
|
|
|
t_desc = self.describer.describe_tensor(t, trace=trace) |
|
|
|
if trace: |
|
trace_structured( |
|
"describe_source", |
|
metadata_fn=lambda: { |
|
"describer_id": self.describer.id, |
|
"id": t_desc.id, |
|
"source": source.name(), |
|
}, |
|
) |
|
|
|
|
|
|
|
|
|
with contextlib.ExitStack() as exit_stack: |
|
exit_stack.enter_context(torch._dispatch.python.suspend_functionalization()) |
|
st = peek_interpreter_stack() |
|
if st is not None: |
|
exit_stack.enter_context( |
|
torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() |
|
) |
|
|
|
r = self.meta_tensor( |
|
t_desc, |
|
shape_env=shape_env, |
|
callback=callback, |
|
source=source, |
|
symbolic_context=symbolic_context, |
|
) |
|
|
|
if type(t) is torch.nn.Parameter: |
|
|
|
|
|
r._is_param = True |
|
|
|
|
|
return r |
|
|
|
|
|
# Kept at the bottom of the module: `utils` is only referenced lazily inside
# meta_tensor (is_c_of_r), which helps avoid import-ordering issues.
import torch._prims_common as utils
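# End-to-end usage sketch (illustrative only; the tensor and variable names
# below are hypothetical and assume the default callback):
#
#   converter = MetaConverter()
#   x = torch.randn(4, 4, requires_grad=True)
#   m = converter(x)
#   assert m.device.type == "meta"
#   assert m.shape == x.shape and m.requires_grad == x.requires_grad
#   assert converter(x) is m        # conversions are memoized per input tensor
#   assert converter.successful()   # at least one hit and no misses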
|
|