# mypy: disable-error-code="method-assign"
import copy
import functools
import getpass
import inspect
import itertools
import logging
import os
import re
import subprocess
import tempfile
import textwrap
from collections import Counter
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional, TypeVar
import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils
from torch import Tensor
from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
from . import config
from .utils import clone_inputs, get_debug_dir
log = logging.getLogger(__name__)
T = TypeVar("T")
inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()
if use_buck:
import libfb.py.build_info
extra_deps = []
extra_imports = ""
if use_buck:
extra_deps = [
"//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
"//caffe2/torch/fb/sparsenn:sparsenn_operators",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
]
cur_target = libfb.py.build_info.BuildInfo.get_build_rule().replace("fbcode:", "//") # type: ignore[possibly-undefined]
extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])
BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]
class BuckTargetWriter:
def __init__(self, filename):
self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
self.target = self.py_file.replace(".py", "")
# Get main_module path from fbcode
self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
self.path = self.path[self.path.find("fbcode.") :]
self.path = self.path[7:]
# Get cmd line path
tmp = self.subdir
tmp = tmp[tmp.find("fbcode/") :][7:]
self.cmd_line_path = f"//{tmp}:{self.target}"
def build(self):
extra_cpp_deps = "\n".join([f' "{x}",' for x in extra_deps])
return textwrap.dedent(
f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
python_binary(
name="{self.target}",
srcs = ["{self.py_file}"],
compile = False,
deps = [
"//caffe2:torch",
"//caffe2/functorch:functorch",
"//triton:triton",
"{cur_target}",
],
cpp_deps = [
{extra_cpp_deps}
],
main_module = "{self.path}",
par_style = "xar",
)
"""
)
def write(self, print_msg=True):
target_file = os.path.join(self.subdir, "TARGETS")
with open(target_file, "w") as fd:
fd.write(self.build())
# log.warning("Wrote isolation TARGETS file at %s", target_file)
cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
if print_msg:
log.warning(
"Found an example that reproduces the error. Run this cmd to repro - %s",
" ".join(cmd_split),
)
return cmd_split
def minifier_dir():
path = os.path.join(get_debug_dir(), "minifier")
if path is None:
path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
return path
MAX_CONSTANT_NUMEL_INLINE = 4
class NNModuleToString:
safe_reprs = [
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.LayerNorm,
torch.nn.Dropout,
torch.nn.Softmax,
torch.nn.ReLU,
torch.nn.GELU,
torch.nn.Identity,
torch.nn.MaxPool2d,
torch.nn.Embedding,
torch.nn.Tanh,
torch.nn.ConvTranspose1d,
torch.nn.GLU,
torch.nn.LSTM,
torch.nn.Flatten,
torch.nn.AdaptiveAvgPool2d,
]
@staticmethod
def can_convert_to_string(gm):
cant_convert = set()
for _, module in gm.named_children():
if type(module) not in NNModuleToString.safe_reprs:
cant_convert.add(module)
if len(cant_convert) > 0:
log.warning("We have not tested reprs of some modules - %s", cant_convert)
# TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
return True
@staticmethod
def convert(gm):
from torch.nn.modules.module import _addindent
tab = " " * 4
model_str = textwrap.dedent(
"""
from torch.nn import *
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
"""
)
for module_name, module in gm.named_children():
module_str = f"{module.__repr__()}"
# module should be a core torch.nn.Module, so all parameters
# should be on the same device.
example_param = next(module.parameters(), None)
if example_param is not None and example_param.is_cuda:
module_str = f"{module_str}.cuda()"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
for buffer_name, buffer in gm._buffers.items():
if buffer is None:
continue
# Serialize full data for small buffers
if buffer.numel() <= MAX_CONSTANT_NUMEL_INLINE:
from torch._tensor_str import PRINT_OPTS
assert PRINT_OPTS.threshold >= MAX_CONSTANT_NUMEL_INLINE
tensor_str = repr(buffer)
elif torch.is_floating_point(buffer):
tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
else:
tensor_str = (
f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
)
if buffer.is_cuda:
tensor_str = f"{tensor_str}.cuda()"
model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"
for param_name, param in gm._parameters.items():
if param is None:
continue
maybe_device = ""
if param.is_cuda:
maybe_device = ', device="cuda"'
tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}{maybe_device}))"
model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"
# TODO - Keep this code for now. But, I don't think we will need this.
# attrs = dir(gm)
# for attr in attrs:
# if "_tensor_constant" in attr:
# val = getattr(gm, attr)
# model_str += f" {attr} = {val!r}\n"
model_str += f"{_addindent(gm.code, 4)}\n"
return model_str
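# Illustrative sketch (not exercised by this module; `M` and `gm` are
# hypothetical names): convert a small traced module into its string repro form.
#
#   class M(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.lin = torch.nn.Linear(4, 4)
#       def forward(self, x):
#           return torch.relu(self.lin(x))
#
#   gm = torch.fx.symbolic_trace(M())
#   if NNModuleToString.can_convert_to_string(gm):
#       print(NNModuleToString.convert(gm))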
@functools.lru_cache(None) # subprocess is expensive
def _cuda_system_info_comment():
if not torch.cuda.is_available():
return "# torch.cuda.is_available()==False, no GPU info collected\n"
model_str = "# CUDA Info: \n"
try:
cuda_version_out = subprocess.check_output(["nvcc", "--version"])
cuda_version_lines = cuda_version_out.decode().split("\n")
comment = "".join([f"# {s} \n" for s in cuda_version_lines if s not in [""]])
model_str += f"{comment}\n"
except (FileNotFoundError, subprocess.CalledProcessError):
model_str += "# nvcc not found\n"
gpu_names = Counter(
torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
)
model_str += "# GPU Hardware Info: \n"
for name, count in gpu_names.items():
model_str += f"# {name} : {count} \n"
model_str += "\n"
return model_str
def generate_config_string(*, stable_output=False):
import torch._functorch.config
import torch._inductor.config
if stable_output:
return "# config omitted due to stable_output=True"
experimental_config = torch.fx.experimental._config.codegen_config() # type: ignore[attr-defined]
return f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
import torch.fx.experimental._config
{torch._dynamo.config.codegen_config()}
{torch._inductor.config.codegen_config()}
{torch._functorch.config.codegen_config()}
{experimental_config}
"""
def get_minifier_repro_path():
return os.path.join(minifier_dir(), "minifier_launcher.py")
def helper_for_dump_minify(contents):
minified_repro_path = get_minifier_repro_path()
log.warning("Writing minified repro to:\n%s", minified_repro_path)
if use_buck:
BuckTargetWriter(minified_repro_path).write()
try:
with open(minified_repro_path, "w") as fd:
fd.write(contents)
except OSError as e:
log.exception(e)
raise NotImplementedError(f"Could not write to {minified_repro_path}") from e
class AccuracyError(Exception):
pass
def clone_inputs_retaining_gradness(example_inputs):
"""
This clone inputs is different from utils clone_input. In case of minifier,
all the tensors are leaf tensors while creating a new graph. So, we set the
requires_grad field w/o checking the leafness of the tensor.
"""
cloned_inputs = clone_inputs(example_inputs)
for idx in range(len(example_inputs)):
if isinstance(cloned_inputs[idx], torch.Tensor):
cloned_inputs[idx].requires_grad_(example_inputs[idx].requires_grad)
return cloned_inputs
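# For example (illustrative), a leaf input that requires grad keeps that flag
# on its clone:
#
#   cloned = clone_inputs_retaining_gradness([torch.randn(3, requires_grad=True)])
#   assert cloned[0].requires_grad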
def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
"""
Runs a forward and possibly backward iteration for a given mod and args.
When disable_clone is True, we will use args as-is without cloning.
This is higher fidelity but we may destroy the args in the process.
"""
from torch._functorch.aot_autograd import make_boxed_func
from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass
gm = copy.deepcopy(gm)
if not disable_clone:
args = clone_inputs_retaining_gradness(args)
if hasattr(gm, "zero_grad"):
gm.zero_grad(True)
# The callable returned by TorchInductor expects a list of inputs, so box the call.
orig_named_parameters = getattr(gm, "named_parameters", None)
orig_named_buffers = getattr(gm, "named_buffers", None)
if not hasattr(gm, "_boxed_call") and (
orig_named_parameters is not None or orig_named_buffers is not None
):
gm = make_boxed_func(gm)
if orig_named_parameters is not None:
gm.named_parameters = orig_named_parameters
if orig_named_buffers is not None:
gm.named_buffers = orig_named_buffers
out = gm(args)
if only_fwd:
return out
if requires_bwd_pass(out):
loss = reduce_to_scalar_loss(out)
loss.backward()
return collect_results(gm, out, None, args)
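# Usage sketch (illustrative; `gm` and `args` are hypothetical names): run a
# traced module forward and, if the output requires a backward pass, backward
# as well, returning the collected results (outputs plus gradients).
#
#   gm = torch.fx.symbolic_trace(M())  # M as in the example above
#   args = [torch.randn(2, 4, requires_grad=True)]
#   full = run_fwd_maybe_bwd(gm, args)
#   fwd_only = run_fwd_maybe_bwd(gm, args, only_fwd=True)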
def same_two_models(
gm,
opt_gm,
example_inputs,
only_fwd=False,
*,
require_fp64=False,
ignore_non_fp=False,
):
"""
Check two models have same accuracy.
require_fp64: if True, raise an error if we unable to calculate the fp64 reference
ignore_non_fp: if True, do not compare outputs which are not floating point. This
is mostly useful for the minifier (which wants to avoid quantizing floating point
error into integer/boolean error)
"""
from .eval_frame import OptimizedModule
from .testing import (
named_buffers_for_optimized_module,
named_parameters_for_optimized_module,
)
from .utils import same
if isinstance(gm, OptimizedModule):
gm.named_parameters = named_parameters_for_optimized_module(gm)
gm.named_buffers = named_buffers_for_optimized_module(gm)
if isinstance(opt_gm, OptimizedModule):
opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm)
ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
fp64_ref = None
if config.same_two_models_use_fp64:
try:
fp64_model, fp64_examples = cast_to_fp64(
copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
)
fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
except Exception:
if require_fp64:
raise RuntimeError("Could not generate fp64 outputs") # noqa: TRY200
log.warning("Could not generate fp64 outputs")
try:
res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
except Exception as e:
# This means that the minified graph is bad/exposes a different problem.
# As we are checking accuracy here, let's log the exception and return True.
log.exception(
"While minifying the program in accuracy minification mode, "
"ran into a runtime exception which is likely an unrelated issue."
" Skipping this graph."
)
return True
passing = same(
ref,
res,
fp64_ref,
tol=config.repro_tolerance,
equal_nan=True,
ignore_non_fp=ignore_non_fp,
)
return passing
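# Usage sketch (illustrative; `gm`, `opt_gm`, and `args` are hypothetical
# names): compare an eager module against a compiled counterpart on the same
# inputs.
#
#   opt_gm = torch.compile(gm)
#   if not same_two_models(gm, opt_gm, args, only_fwd=True):
#       raise AccuracyError("compiled model diverged from the eager baseline")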
def cast_dtype_args_to_fp64(model):
for node in model.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.prims.convert_element_type.default
):
assert len(node.args) == 2
if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
node.args = (node.args[0], torch.float64)
if node.op == "call_function":
dtype = node.kwargs.get("dtype")
if dtype is not None and is_float_dtype(dtype):
new_kwargs = dict(node.kwargs)
new_kwargs["dtype"] = torch.float64
node.kwargs = new_kwargs
model.graph.lint()
model.recompile()
return model
def cast_to(dtype, model, inputs):
from torch.utils._pytree import tree_map
model = model.to(dtype)
if dtype == torch.float64:
# If casting to fp64 for accuracy comparison, we need to
# replace dtype arguments embedded in the graph with fp64
model = cast_dtype_args_to_fp64(model)
inputs = tree_map(
lambda x: x.to(dtype)
if isinstance(x, torch.Tensor) and x.is_floating_point()
else x,
inputs,
)
return model, inputs
def cast_to_fp64(model, inputs):
return cast_to(torch.float64, model, inputs)
def backend_accuracy_fails(
gm,
example_inputs,
compiler_fn,
only_fwd=False,
*,
require_fp64=False,
ignore_non_fp=False,
):
try:
compiled_gm = compiler_fn(
copy.deepcopy(gm), clone_inputs_retaining_gradness(example_inputs)
)
return not same_two_models(
gm,
compiled_gm,
example_inputs,
only_fwd,
require_fp64=require_fp64,
ignore_non_fp=ignore_non_fp,
)
except Exception as e:
# This means that the minified graph is bad/exposes a different problem.
# As we are checking accuracy here, let's log the exception and return False.
log.exception(
"While minifying the program in accuracy minification mode, "
"ran into a runtime exception which is likely an unrelated issue."
" Skipping this graph"
)
return False
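# Usage sketch (illustrative): `compiler_fn` has the usual Dynamo backend
# signature (GraphModule, example_inputs) -> callable, so a registered backend
# such as Inductor can be looked up and checked directly (assumes
# torch._dynamo.backends.registry.lookup_backend is available):
#
#   from torch._dynamo.backends.registry import lookup_backend
#   fails = backend_accuracy_fails(gm, args, lookup_backend("inductor"))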
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# REPRO SUPPORT CODE
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# Helper functions for computing what the default values of tensor
# values should be. These all coincide with factory functions, e.g., torch.empty
def _stride_or_default(
stride: Optional["torch._prims_common.StrideType"],
*,
shape: "torch._prims_common.ShapeType",
) -> "torch._prims_common.StrideType":
return stride if stride is not None else utils.make_contiguous_strides_for(shape)
def _mk_defaulter(d: T) -> Callable[[Optional[T]], T]:
return lambda x: x if x is not None else d
_dtype_or_default = _mk_defaulter(torch.float32)
_device_or_default = _mk_defaulter(torch.device("cpu"))
_storage_offset_or_default = _mk_defaulter(0)
_requires_grad_or_default = _mk_defaulter(False)
_is_leaf_or_default = _mk_defaulter(False)
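# For example, _dtype_or_default(None) == torch.float32 while
# _dtype_or_default(torch.int64) == torch.int64, and
# _stride_or_default(None, shape=(2, 3)) == (3, 1), i.e. contiguous strides.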
class NopInputReader:
def __init__(self):
self.total = 0
def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
self.total += 1
def tensor(self, *args, **kwargs):
pass
def symint(self, *args, **kwargs):
pass
# TODO: Support bundling the entire repro into a zip file for ease of
# transferring around
class InputReader:
def __init__(self, save_dir=None, *, pbar=None):
# If None, we will generate random data instead. It's important
# to natively support this use case as it will allow people to
# share repros without including the real data, if the problem
# reproduces even on random data.
if save_dir is None:
log.warning("no save_dir specified, will generate random data")
self.store = ContentStoreReader(save_dir) if save_dir is not None else None
self.args = []
self.pbar = pbar
def storage(self, storage_hash, nbytes, *, device=None, dtype_hint=None):
if self.pbar is not None:
self.pbar.update(1)
device = _device_or_default(device)
dtype_hint = _dtype_or_default(dtype_hint)
if self.store is not None and storage_hash is not None:
try:
storage = self.store.read_storage(storage_hash)
except FileNotFoundError:
pass
else:
if device != storage.device:
log.warning("device mismatch: %s != %s", device, storage.device)
# TODO: transfer it to the right device? But failing this
# way would be very mysterious! Would have been better
# not to store device in the serialized format...
return storage
log.warning("could not load %s, generating random data instead", storage_hash)
shape = (nbytes // dtype_hint.itemsize,)
stride = _stride_or_default(None, shape=shape)
return rand_strided(shape, stride, dtype_hint, device).untyped_storage()
def tensor(
self,
storage,
shape,
stride=None,
*,
storage_offset=None,
dtype=None,
requires_grad=None,
is_leaf=None,
**metadata,
):
stride = _stride_or_default(stride, shape=shape)
storage_offset = _storage_offset_or_default(storage_offset)
dtype = _dtype_or_default(dtype)
is_leaf = _is_leaf_or_default(is_leaf)
requires_grad = _requires_grad_or_default(requires_grad)
t = torch.tensor(
[], dtype=dtype, device=storage.device, requires_grad=requires_grad
)
with torch.no_grad():
t.set_(storage, storage_offset, shape, stride)
if not is_leaf:
# Fake up some autograd history in a very naughty way
with torch.enable_grad():
t = t.clone(memory_format=torch.preserve_format)
with torch.no_grad():
t.set_(storage, storage_offset, shape, stride)
assert torch._subclasses.meta_utils.safe_is_leaf(t) == is_leaf
torch._utils.set_tensor_metadata(t, metadata)
self.args.append(t)
return t # for BC
def symint(self, val):
self.args.append(val)
return val # for BC
# Here is our writer strategy:
# 1. We will stream all of the inputs to disk
# 2. You can now deterministically randomize the inputs, or reload
# the inputs from disk
# 3. You can YOLO run the script without the inputs, in which case
# we'll fill the inputs with random data and pray. This is the
# legacy behavior, but it's also useful if you want to find out
# if we're so broken even random inputs trigger it
# 4. We could offer an in process "check if the randomized thing
# works too" but this is delicate so we don't do it
class InputWriter:
def __init__(self, save_dir, *, stable_hash=False):
self._lines = []
# TODO: consider ensuring tensor and storage counters line up?
self.storage_counter = itertools.count()
self.save_dir = save_dir
self.store = (
ContentStoreWriter(save_dir, stable_hash=stable_hash)
if save_dir is not None
else None
)
self.seen_storages = {}
def lines(self):
r = [
"def load_args(reader):",
]
r.extend(f" {l}" for l in self._lines)
# In case we need to change the internal format of load_args
# in an FC-breaking way
r.append("load_args._version = 0")
return r
# Storages are untyped, but we need to initialize them with data if
# we don't have the real data, so we give a hint saying what kind
# of initialization may be appropriate
#
# If we had a FakeTensor, device_hint tells us what the device should be
def storage(self, untyped_storage, *, dtype_hint=None, device_hint=None) -> str:
ws = StorageWeakRef(untyped_storage)
v = self.seen_storages.get(ws)
if v is not None:
return v
v = f"buf{next(self.storage_counter)}"
maybe_dtype_hint = ""
if _dtype_or_default(None) != _dtype_or_default(dtype_hint):
maybe_dtype_hint = f", dtype_hint={dtype_hint!r}"
# TODO: being optional on device is kind of pointless as the default
# is CPU but most repros we care about are CUDA
maybe_device = ""
device = untyped_storage.device
if device.type == "meta":
assert device_hint is not None
device = device_hint
if _device_or_default(None) != device:
maybe_device = f", device={device!r}"
nbytes = untyped_storage.nbytes()
storage_hash = None
if self.store is not None and untyped_storage.device.type != "meta":
storage_hash = self.store.write_storage(untyped_storage)
self._lines.append(
f"{v} = reader.storage({storage_hash!r}, {nbytes!r}{maybe_device}{maybe_dtype_hint})"
)
self.seen_storages[ws] = v
return v
def tensor(self, name, t) -> None:
storage = self.storage(
t.untyped_storage(), dtype_hint=t.dtype, device_hint=t.device
)
args = []
# NB: this is positional, must come first
if _stride_or_default(None, shape=t.shape) != t.stride():
args.append(str(tuple(t.stride())))
if _dtype_or_default(None) != t.dtype:
args.append(f"dtype={t.dtype!r}")
if _storage_offset_or_default(None) != t.storage_offset():
args.append(f"storage_offset={t.storage_offset()!r}")
tensor_metadata = torch._utils.get_tensor_metadata(t)
if tensor_metadata:
args.extend(f"{k}={v!r}" for k, v in tensor_metadata.items())
if _requires_grad_or_default(None) != t.requires_grad:
args.append(f"requires_grad={t.requires_grad!r}")
is_leaf = torch._subclasses.meta_utils.safe_is_leaf(t)
if _is_leaf_or_default(None) != is_leaf:
args.append(f"is_leaf={is_leaf!r}")
self._lines.append(
"reader.tensor("
+ ", ".join([storage, str(tuple(t.shape)), *args])
+ f") # {name}"
)
# TODO: this doesn't actually record a symbolic int yet; SymInts are flattened to their hints
def symint(self, name, val) -> None:
if isinstance(val, torch.SymInt):
val = val.node.hint
self._lines.append(f"reader.symint({val!r}) # {name}")
def aot_graph_input_parser(
func: Callable[[List[Tensor]], List[Tensor]],
device: str = "cuda",
sym_shapes: Optional[Dict[str, int]] = None,
default_sym_shape: Optional[int] = None,
) -> Dict[str, Any]:
"""
Takes in a function which has been printed with print_readable() and constructs kwargs to run it.
Handles Tensor inputs, Symints, and a graph module which might have tensor constants.
Consider a function `forward` defined as follows:
def forward(self, primals_1: "f32[1001, 6]", primals_2: "f32[s0]", primals_3: "Sym(s0)",):
_tensor_constant0: "i64[4190]" = self._tensor_constant0
# Further implementation
kwargs = aot_graph_input_parser(forward)
forward(**kwargs)
"""
from torch.fx.graph import dtype_abbrs
dtype_map = {value: key for key, value in dtype_abbrs.items()}
dtype_pattern = "|".join(dtype_abbrs.values())
# Extracting the source code from the function
source = inspect.getsource(func)
# Regular expressions
tensor_assignment_regex = rf"(_tensor_constant\d+): \"({dtype_pattern})\[\s*(.*?)\s*\]\" = self\.(_tensor_constant\d+)"
tensor_regex = rf"({dtype_pattern})\[\s*(.*?)\s*\]"
sym_shape_regex = r"Sym\((s\d+)\)"
class TensorContainer:
"Container for tensors as attributes"
pass
# Dictionary for tensors from annotations
kwargs: Dict[str, Any] = {}
sym_shapes = sym_shapes or {}
def get_sym_int(symint):
torch._check(
symint in sym_shapes or default_sym_shape is not None,
lambda: f"{symint} not in symbolic_shapes and default sym shape not passed in",
)
return sym_shapes.get(symint, default_sym_shape)
def gen_tensor(shape, dtype) -> Tensor:
# Resolve symbolic shapes to concrete values
resolved_shape = []
dynamic_dims = []
for i, dim in enumerate(shape):
dim = dim.strip()
if "s" in dim:
s = get_sym_int(dim)
resolved_shape.append(s)
dynamic_dims.append(i)
else:
resolved_shape.append(int(dim))
constructor = torch.randn if dtype.is_floating_point else torch.zeros
out = constructor(resolved_shape, dtype=dtype, device=device) # type: ignore[call-arg]
for d in dynamic_dims:
torch._dynamo.mark_dynamic(out, d)
return out
# Parse function annotations for tensor generation
annotations = func.__annotations__
for param, annotation in annotations.items():
# Skip 'return' annotation
if param == "return":
continue
match = re.search(tensor_regex, annotation)
if match:
data_type, shape_str = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
kwargs[param] = gen_tensor(shape, dtype)
match = re.search(sym_shape_regex, annotation)
if match:
kwargs[param] = get_sym_int(match.group(1))
if "self" in inspect.signature(func).parameters:
container = TensorContainer()
kwargs["self"] = container
for match in re.finditer(tensor_assignment_regex, source):
attr_name, data_type, shape_str, _ = match.groups()
shape = tuple(shape_str.split(","))
dtype = dtype_map[data_type]
setattr(container, attr_name, gen_tensor(shape, dtype))
return kwargs
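# When the printed graph uses symbolic sizes such as "f32[s0, 16]", concrete
# values can be supplied explicitly (illustrative sketch):
#
#   kwargs = aot_graph_input_parser(
#       forward, device="cuda", sym_shapes={"s0": 128}, default_sym_shape=64
#   )
#   forward(**kwargs)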