import contextlib
import functools
import logging
from unittest.mock import patch

import torch
from torch._dynamo import disable
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils._python_dispatch import _disable_current_modes

log = logging.getLogger(__name__)


class AotAutograd:
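    """Callable TorchDynamo backend that compiles the captured ``GraphModule``
    with AOT Autograd (``aot_module_simplified``), using the compiler callables
    supplied as keyword arguments."""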
    def __init__(self, **kwargs):
        self.__name__ = "compiler_fn"
        self.kwargs = kwargs

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
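        # If any example input is a list/tuple/dict, flatten the graph inputs
        # and re-invoke this backend on the flattened graph.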
        if any(isinstance(x, (list, tuple, dict)) for x in example_inputs):
            return flatten_graph_inputs(
                gm,
                example_inputs,
                self,
            )
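
        # "decompositions" may be supplied as a zero-argument callable; resolve
        # it to the actual decomposition table before compiling.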
        if callable(self.kwargs.get("decompositions")):
            self.kwargs["decompositions"] = self.kwargs["decompositions"]()

        counters["aot_autograd"]["total"] += 1
        use_fallback = False
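
        # use_fallback is currently hard-coded to False; if it were set, the
        # graph would be returned uncompiled instead of going through AOT Autograd.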
        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        def _wrapped_bw_compiler(*args, **kwargs):
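            # disable() the backward compiler call itself and the compiled
            # function it returns, so TorchDynamo neither traces the compiler
            # nor re-traces the generated backward pass.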
            return disable(disable(bw_compiler)(*args, **kwargs))

        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
        self.kwargs["inference_compiler"] = (
            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
        )

        from functorch.compile import nop

        from torch._inductor.debug import enable_aot_logging
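
        # Enable functorch debug asserts only when the forward compiler is the
        # no-op ``nop`` compiler; they add compile-time overhead that real
        # backends avoid.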
        if self.kwargs.get("fw_compiler", None) == nop:
            patch_config = patch("functorch.compile.config.debug_assert", True)
        else:
            patch_config = contextlib.nullcontext()
        try:
            with enable_aot_logging(), patch_config:
                cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
                counters["aot_autograd"]["ok"] += 1
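                # Wrap the compiled module so TorchDynamo does not try to
                # trace into it when it is called.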
                return disable(cg)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise


def aot_autograd(**kwargs):
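    """Build an AotAutograd backend that forwards ``kwargs`` to ``aot_module_simplified``."""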
    return AotAutograd(**kwargs)


def mem_efficient_fusion_kwargs(use_decomps):
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )
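
    # Compiler configuration mirroring functorch's memory-efficient fusion:
    # TorchScript (``ts_compile``) forward/backward compilers with min-cut
    # rematerialization partitioning.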
    kwargs = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs


def fake_tensor_unsupported(fn):
    """
    Decorator for backends that need real inputs. We swap out fake
    tensors for zero tensors.
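
    Example (hypothetical backend that just returns the captured module)::

        @fake_tensor_unsupported
        def my_real_input_backend(gm, example_inputs):
            # `example_inputs` arrive here as real zero tensors, not FakeTensors
            return gm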
    """

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
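            # With torch dispatch modes (e.g. FakeTensorMode) disabled, replace
            # fake tensors with zero tensors before calling the wrapped backend.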
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper


def device_from_inputs(example_inputs) -> torch.device:
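    """Return the device of the first example input that has a ``device`` attribute."""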
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device


def dtype_from_inputs(example_inputs) -> torch.dtype:
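    """Return the dtype of the first example input that has a ``dtype`` attribute."""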
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype