# mypy: disable-error-code="method-assign"

import copy
import functools
import getpass
import inspect
import itertools
import logging
import os
import re
import subprocess
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, TypeVar

import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils
from torch import Tensor
from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter

from . import config
from .utils import clone_inputs, get_debug_dir

log = logging.getLogger(__name__)

T = TypeVar("T")


inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

if use_buck:
    import libfb.py.build_info


extra_deps = []
extra_imports = ""
if use_buck:
    extra_deps = [
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//")  # type: ignore[possibly-undefined]
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])


BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]

class BuckTargetWriter:
    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
        self.target = self.py_file.replace(".py", "")

        # Get main_module path from fbcode
        self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self.path[self.path.find("fbcode.") :]
        self.path = self.path[7:]

        # Get cmd line path
        tmp = self.subdir
        tmp = tmp[tmp.find("fbcode/") :][7:]
        self.cmd_line_path = f"//{tmp}:{self.target}"

    def build(self):
        extra_cpp_deps = "\n".join([f'        "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
        "{cur_target}",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
    par_style = "xar",
)
"""
        )

    def write(self, print_msg=True):
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        # log.warning("Wrote isolation TARGETS file at %s", target_file)
        cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
        if print_msg:
            log.warning(
                "Found an example that reproduces the error. Run this cmd to repro - %s",
                " ".join(cmd_split),
            )
        return cmd_split
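
# Illustrative use of BuckTargetWriter (a sketch; it is only meaningful inside
# fbcode, and the path below is hypothetical):
#
#     writer = BuckTargetWriter("/path/to/fbcode/some/dir/repro.py")
#     cmd = writer.write(print_msg=False)  # writes a TARGETS file next to repro.py
#     subprocess.run(cmd)                  # e.g. buck2 run @mode/dev-nosan //some/dir:repro
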

def minifier_dir():
    path = os.path.join(get_debug_dir(), "minifier")
    if path is None:
        path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path

MAX_CONSTANT_NUMEL_INLINE = 4


class NNModuleToString:
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        cant_convert = set()
        for _, module in gm.named_children():
            if type(module) not in NNModuleToString.safe_reprs:
                cant_convert.add(module)

        if len(cant_convert) > 0:
            log.warning("We have not tested reprs of some modules - %s", cant_convert)
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self):
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            # Serialize full data for small buffers
            if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
                from torch._tensor_str import PRINT_OPTS

                assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
                tensor_str = repr(buffer)
            elif torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            maybe_device = ""
            if param.is_cuda:
                maybe_device = ', device="cuda"'
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # TODO - Keep this code for now. But, I don't think we will need this.
        # attrs = dir(gm)
        # for attr in attrs:
        #     if "_tensor_constant" in attr:
        #         val = getattr(gm, attr)
        #         model_str += f"    {attr} = {val!r}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str
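
# For orientation, NNModuleToString.convert(gm) returns Python source roughly
# along these lines (a sketch for a hypothetical gm with one Linear child and
# one parameter; the exact output depends on the module and its fx code):
#
#     from torch.nn import *
#     class Repro(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.fc = Linear(in_features=8, out_features=4, bias=True)
#             self.weight = torch.nn.Parameter(torch.randn([4, 8], dtype=torch.float32))
#
#         def forward(self, x):
#             ...
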

# subprocess is expensive
@functools.lru_cache(None)
def _cuda_system_info_comment():
    if not torch.cuda.is_available():
        return "# torch.cuda.is_available()==False, no GPU info collected\n"

    model_str = "# CUDA Info: \n"
    try:
        cuda_version_out = subprocess.check_output(["nvcc", "--version"])
        cuda_version_lines = cuda_version_out.decode().split("\n")
        comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
        model_str += f"{comment}\n"
    except (FileNotFoundError, subprocess.CalledProcessError):
        model_str += "# nvcc not found\n"

    gpu_names = Counter(
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    )

    model_str += "# GPU Hardware Info: \n"
    for name, count in gpu_names.items():
        model_str += f"# {name} : {count} \n"
    model_str += "\n"
    return model_str

def generate_config_string(*, stable_output=False):
    import torch._functorch.config
    import torch._inductor.config

    if stable_output:
        return "# config omitted due to stable_output=True"

    experimental_config = torch.fx.experimental._config.codegen_config()  # type: ignore[attr-defined]
    return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
{experimental_config}
"""

def get_minifier_repro_path():
    return os.path.join(minifier_dir(), "minifier_launcher.py")


def helper_for_dump_minify(contents):
    minified_repro_path = get_minifier_repro_path()
    log.warning("Writing minified repro to:\n%s", minified_repro_path)

    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)
    except OSError as e:
        log.exception(e)
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e

class AccuracyError(Exception):
    pass


def clone_inputs_retaining_gradness(example_inputs):
    """
    This clone_inputs is different from utils.clone_input. In the minifier,
    all the tensors are leaf tensors when creating a new graph, so we set the
    requires_grad field without checking whether the tensor is a leaf.
    """
    cloned_inputs = clone_inputs(example_inputs)
    for idx in range(len(example_inputs)):
        if isinstance(cloned_inputs[idx], torch.Tensor):
            cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
    return cloned_inputs

def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.

    When disable_clone is True, we will use args as-is without cloning.
    This is higher fidelity but we may destroy the args in the process.
    """
    from torch._functorch.aot_autograd import make_boxed_func

    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)
    if not disable_clone:
        args = clone_inputs_retaining_gradness(args)

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)

    # TorchInductor returned callable expects lists. So, boxing the call.
    orig_named_parameters = getattr(gm, "named_parameters", None)
    orig_named_buffers = getattr(gm, "named_buffers", None)
    if not hasattr(gm, "_boxed_call") and (
        orig_named_parameters is not None or orig_named_buffers is not None
    ):
        gm = make_boxed_func(gm)
        if orig_named_parameters is not None:
            gm.named_parameters = orig_named_parameters
        if orig_named_buffers is not None:
            gm.named_buffers = orig_named_buffers

    out = gm(args)
    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, args)
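
# Minimal sketch of how run_fwd_maybe_bwd is typically driven (`mod` is a
# hypothetical module; args are passed as a list because the callable is boxed):
#
#     res = run_fwd_maybe_bwd(mod, [torch.randn(4, 8, requires_grad=True)])
#
# With only_fwd=True the raw forward output is returned; otherwise the output
# is reduced to a scalar loss, backward is run, and results plus gradients are
# collected via collect_results.
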

def same_two_models(
    gm,
    opt_gm,
    example_inputs,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    """
    Check whether two models have the same accuracy.

    require_fp64: if True, raise an error if we are unable to calculate the fp64 reference
    ignore_non_fp: if True, do not compare outputs which are not floating point.  This
        is mostly useful for the minifier (which wants to avoid quantizing floating point
        error into integer/boolean error)
    """
    from .eval_frame import OptimizedModule
    from .testing import (
        named_buffers_for_optimized_module,
        named_parameters_for_optimized_module,
    )
    from .utils import same

    if isinstance(gm, OptimizedModule):
        gm.named_parameters = named_parameters_for_optimized_module(gm)
        gm.named_buffers = named_buffers_for_optimized_module(gm)

    if isinstance(opt_gm, OptimizedModule):
        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
        opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm)

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    fp64_ref = None
    if config.same_two_models_use_fp64:
        try:
            fp64_model, fp64_examples = cast_to_fp64(
                copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
            )
            fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
        except Exception:
            if require_fp64:
                raise RuntimeError("Could not generate fp64 outputs")  # noqa: TRY200
            log.warning("Could not generate fp64 outputs")

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception as e:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph."
        )
        return True

    passing = same(
        ref,
        res,
        fp64_ref,
        tol=config.repro_tolerance,
        equal_nan=True,
        ignore_non_fp=ignore_non_fp,
    )
    return passing

def cast_dtype_args_to_fp64(model):
    for node in model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.prims.convert_element_type.default
        ):
            assert len(node.args) == 2
            if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
                node.args = (node.args[0], torch.float64)
        if node.op == "call_function":
            dtype = node.kwargs.get("dtype")
            if dtype is not None and is_float_dtype(dtype):
                new_kwargs = dict(node.kwargs)
                new_kwargs["dtype"] = torch.float64
                node.kwargs = new_kwargs

    model.graph.lint()
    model.recompile()
    return model

def cast_to(dtype, model, inputs):
    from torch.utils._pytree import tree_map

    model = model.to(dtype)
    if dtype == torch.float64:
        # If casting to fp64 for accuracy comparison, we need to
        # replace dtype arguments embedded in the graph with fp64
        model = cast_dtype_args_to_fp64(model)

    inputs = tree_map(
        lambda x: x.to(dtype)
        if isinstance(x, torch.Tensor) and x.is_floating_point()
        else x,
        inputs,
    )
    return model, inputs


def cast_to_fp64(model, inputs):
    return cast_to(torch.float64, model, inputs)

def backend_accuracy_fails(
    gm,
    example_inputs,
    compiler_fn,
    only_fwd=False,
    *,
    require_fp64=False,
    ignore_non_fp=False,
):
    try:
        compiled_gm = compiler_fn(
            copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
        )
        return not same_two_models(
            gm,
            compiled_gm,
            example_inputs,
            only_fwd,
            require_fp64=require_fp64,
            ignore_non_fp=ignore_non_fp,
        )
    except Exception as e:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            "While minifying the program in accuracy minification mode, "
            "ran into a runtime exception which is likely an unrelated issue."
            " Skipping this graph"
        )
        return False
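
# Sketch of how backend_accuracy_fails is typically used by a minification loop
# (`my_backend`, `gm`, and `example_inputs` are hypothetical):
#
#     if backend_accuracy_fails(gm, example_inputs, my_backend, ignore_non_fp=True):
#         ...  # keep minifying: the compiled graph diverges from eager
#
# It returns True when same_two_models reports a mismatch, and False both when
# the outputs match and when compilation/execution raises an unrelated error.
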

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#                           REPRO SUPPORT CODE
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Helper functions for computing what the default values of tensor
# values should be.  These all coincide with factory functions, e.g., torch.empty


def _stride_or_default(
    stride: Optional["torch._prims_common.StrideType"],
    *,
    shape: "torch._prims_common.ShapeType",
) -> "torch._prims_common.StrideType":
    return stride if stride is not None else utils.make_contiguous_strides_for(shape)


def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
    return lambda x: x if x is not None else d


_dtype_or_default = _mk_defaulter(torch.float32)
_device_or_default = _mk_defaulter(torch.device("cpu"))
_storage_offset_or_default = _mk_defaulter(0)
_requires_grad_or_default = _mk_defaulter(False)
_is_leaf_or_default = _mk_defaulter(False)
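
# The defaulters double as "is this the default?" checks: calling one with None
# returns the default value, so comparisons like
# `_dtype_or_default(None) != t.dtype` (used by InputWriter below) are true only
# when the attribute deviates from the default and must be spelled out in the
# generated repro.  For example:
#
#     _dtype_or_default(None)           # torch.float32
#     _dtype_or_default(torch.float16)  # torch.float16
#     _device_or_default(None)          # device(type='cpu')
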

class NopInputReader:
    def __init__(self):
        self.total = 0

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        self.total += 1

    def tensor(self, *args, **kwargs):
        pass

    def symint(self, *args, **kwargs):
        pass

# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
    def __init__(self, save_dir=None, *, pbar=None):
        # If None, we will generate random data instead.  It's important
        # to natively support this use case as it will allow people to
        # share repros without including the real data, if the problem
        # reproduces even on random data.
        if save_dir is None:
            log.warning("no save_dir specified, will generate random data")
        self.store = ContentStoreReader(save_dir) if save_dir is not None else None
        self.args = []
        self.pbar = pbar

    def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
        if self.pbar is not None:
            self.pbar.update(1)
        device = _device_or_default(device)
        dtype_hint = _dtype_or_default(dtype_hint)
        if self.store is not None and storage_hash is not None:
            try:
                storage = self.store.read_storage(storage_hash)
            except FileNotFoundError:
                pass
            else:
                if device != storage.device:
                    log.warning("device mismatch: %s != %s", device, storage.device)
                    # TODO: transfer it to the right device?  But failing this
                    # way would be very mysterious!  Would have been better
                    # not to store device in the serialized format...
                return storage
        log.warning("could not load %s, generating random data instead", storage_hash)
        shape = (nbytes // dtype_hint.itemsize,)
        stride = _stride_or_default(None, shape=shape)
        return rand_strided(shape, stride, dtype_hint, device).untyped_storage()

    def tensor(
        self,
        storage,
        shape,
        stride=None,
        *,
        storage_offset=None,
        dtype=None,
        requires_grad=None,
        is_leaf=None,
        **metadata,
    ):
        stride = _stride_or_default(stride, shape=shape)
        storage_offset = _storage_offset_or_default(storage_offset)
        dtype = _dtype_or_default(dtype)
        is_leaf = _is_leaf_or_default(is_leaf)
        requires_grad = _requires_grad_or_default(requires_grad)
        t = torch.tensor(
            [], dtype=dtype, device=storage.device, requires_grad=requires_grad
        )
        with torch.no_grad():
            t.set_(storage, storage_offset, shape, stride)
        if not is_leaf:
            # Fake up some autograd history in a very naughty way
            with torch.enable_grad():
                t = t.clone(memory_format=torch.preserve_format)
            with torch.no_grad():
                t.set_(storage, storage_offset, shape, stride)
        assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
        torch._utils.set_tensor_metadata(t, metadata)
        self.args.append(t)
        return t  # for BC

    def symint(self, val):
        self.args.append(val)
        return val  # for BC

# Here is our writer strategy:
#  1. We will stream all of the inputs to disk
#  2. You can now deterministically randomize the inputs, or reload
#     the inputs from disk
#  3. You can YOLO run the script without the inputs, in which case
#     we'll fill the inputs with random data and pray.  This is the
#     legacy behavior, but it's also useful if you want to find out
#     if we're so broken even random inputs trigger it
#  4. We could offer an in process "check if the randomized thing
#     works too" but this is delicate so we don't do it
class InputWriter:
    def __init__(self, save_dir, *, stable_hash=False):
        self._lines = []
        # TODO: consider ensuring tensor and storage counters line up?
        self.storage_counter = itertools.count()
        self.save_dir = save_dir
        self.store = (
            ContentStoreWriter(save_dir, stable_hash=stable_hash)
            if save_dir is not None
            else None
        )
        self.seen_storages = {}

    def lines(self):
        r = [
            "def load_args(reader):",
        ]
        r.extend(f"    {l}" for l in self._lines)
        # In case we need to change the internal format of load_args
        # in an FC-breaking way
        r.append("load_args._version = 0")
        return r

    # Storages are untyped, but we need to initialize them with data if
    # we don't have the real data, so we give a hint saying what kind
    # of initialization may be appropriate
    #
    # If we had a FakeTensor, device_hint tells us what device should be
    def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
        ws = StorageWeakRef(untyped_storage)
        v = self.seen_storages.get(ws)
        if v is not None:
            return v
        v = f"buf{next(self.storage_counter)}"
        maybe_dtype_hint = ""
        if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
            maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
        # TODO: being optional on device is kind of pointless as the default
        # is CPU but most repros we care about are CUDA
        maybe_device = ""
        device = untyped_storage.device
        if device.type == "meta":
            assert device_hint is not None
            device = device_hint
        if _device_or_default(None) != device:
            maybe_device = f", device={device!r}"
        nbytes = untyped_storage.nbytes()
        storage_hash = None
        if self.store is not None and untyped_storage.device.type != "meta":
            storage_hash = self.store.write_storage(untyped_storage)
        self._lines.append(
            f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
        )
        self.seen_storages[ws] = v
        return v

    def tensor(self, name, t) -> None:
        storage = self.storage(
            t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
        )
        args = []
        # NB: this is positional, must come first
        if _stride_or_default(None, shape=t.shape) != t.stride():
            args.append(str(tuple(t.stride())))
        if _dtype_or_default(None) != t.dtype:
            args.append(f"dtype={t.dtype!r}")
        if _storage_offset_or_default(None) != t.storage_offset():
            args.append(f"storage_offset={t.storage_offset()!r}")
        tensor_metadata = torch._utils.get_tensor_metadata(t)
        if tensor_metadata:
            args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
        if _requires_grad_or_default(None) != t.requires_grad:
            args.append(f"requires_grad={t.requires_grad!r}")
        is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
        if _is_leaf_or_default(None) != is_leaf:
            args.append(f"is_leaf={is_leaf!r}")
        self._lines.append(
            "reader.tensor("
            + ", ".join([storage, str(tuple(t.shape)), *args])
            + f")  # {name}"
        )

    # TODO: this doesn't actually symint atm
    def symint(self, name, val) -> None:
        if isinstance(val, torch.SymInt):
            val = val.node.hint
        self._lines.append(f"reader.symint({val!r})  # {name}")

def aot_graph_input_parser(
    func: Callable[[List[Tensor]], List[Tensor]],
    device: str = "cuda",
    sym_shapes: Optional[Dict[str, int]] = None,
    default_sym_shape: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Takes in a function which has been printed with print_readable() and constructs kwargs to run it.

    Handles Tensor inputs, Symints, and a graph module which might have tensor constants.

    Consider a function `forward` defined as follows:

    def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",):
        _tensor_constant0: "i64[4190]" = self._tensor_constant0
        # Further implementation

    kwargs = aot_graph_input_parser(forward)
    forward(**kwargs)
    """
    from torch.fx.graph import dtype_abbrs

    dtype_map = {value: key for key, value in dtype_abbrs.items()}
    dtype_pattern = "|".join(dtype_abbrs.values())

    # Extracting the source code from the function
    source = inspect.getsource(func)

    # Regular expressions
    tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)"
    tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]"
    sym_shape_regex = r"Sym\((s\d+)\)"

    class TensorContainer:
        "Container for tensors as attributes"
        pass

    # Dictionary for tensors from annotations
    kwargs: Dict[str, Any] = {}

    sym_shapes = sym_shapes or {}

    def get_sym_int(symint):
        torch._check(
            symint in sym_shapes or default_sym_shape is not None,
            lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in",
        )
        return sym_shapes.get(symint, default_sym_shape)

    def gen_tensor(shape, dtype) -> Tensor:
        # Resolve symbolic shapes to concrete values
        resolved_shape = []
        dynamic_dims = []
        for i, dim in enumerate(shape):
            dim = dim.strip()
            if "s" in dim:
                s = get_sym_int(dim)
                resolved_shape.append(s)
                dynamic_dims.append(i)
            else:
                resolved_shape.append(int(dim))
        constructor = torch.randn if dtype.is_floating_point else torch.zeros
        out = constructor(resolved_shape, dtype=dtype, device=device)  # type: ignore[call-arg]
        for d in dynamic_dims:
            torch._dynamo.mark_dynamic(out, d)
        return out

    # Parse function annotations for tensor generation
    annotations = func.__annotations__
    for param, annotation in annotations.items():
        # Skip 'return' annotation
        if param == "return":
            continue

        match = re.search(tensor_regex, annotation)
        if match:
            data_type, shape_str = match.groups()
            shape = tuple(shape_str.split(","))
            dtype = dtype_map[data_type]
            kwargs[param] = gen_tensor(shape, dtype)

        match = re.search(sym_shape_regex, annotation)
        if match:
            kwargs[param] = get_sym_int(match.group(1))

    if "self" in inspect.signature(func).parameters:
        container = TensorContainer()
        kwargs["self"] = container
        for match in re.finditer(tensor_assignment_regex, source):
            attr_name, data_type, shape_str, _ = match.groups()
            shape = tuple(shape_str.split(","))
            dtype = dtype_map[data_type]
            setattr(container, attr_name, gen_tensor(shape, dtype))

    return kwargs
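
# Illustrative call with symbolic shapes (a sketch; `forward` is the printed
# graph from the docstring above, and the sizes chosen are arbitrary):
#
#     kwargs = aot_graph_input_parser(forward, device="cuda", sym_shapes={"s0": 128})
#     out = forward(**kwargs)
#
# Any symbol not listed in sym_shapes falls back to default_sym_shape, and a
# torch._check failure is raised if neither is available.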