import contextlib |
|
import functools |
|
from typing import Dict, List, Optional, TYPE_CHECKING |
|
|
|
import torch |
|
from torch._dynamo.external_utils import call_backward, call_hook |
|
from torch._dynamo.source import GetItemSource, LocalSource |
|
from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal |
|
from torch._logging import getArtifactLogger, trace_structured |
|
from torch._prims_common import clone_preserve_strides |
|
from torch._subclasses import FakeTensorMode |
|
from torch.fx import GraphModule |
|
from torch.fx.experimental._backward_state import BackwardState |
|
from torch.fx.experimental.proxy_tensor import ( |
|
decompose, |
|
disable_autocast_cache, |
|
disable_proxy_modes_tracing, |
|
fetch_object_proxy, |
|
ProxyTorchDispatchMode, |
|
PythonKeyTracer, |
|
track_tensor_tree, |
|
) |
|
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv |
|
from torch.fx.traceback import preserve_node_meta, set_stack_trace |
|
from torch.utils._traceback import CapturedTraceback |
|
|
|
if TYPE_CHECKING: |
|
from torch.fx.proxy import Proxy |
|
|
|
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd") |
|
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") |
|
|
|
|
|
def snapshot_verbose_logging_enabled(): |
|
return torch._logging._internal.log_state.is_artifact_enabled( |
|
"compiled_autograd_verbose" |
|
) |
|
|
|
|
|
def cpp_verbose_log_fn(msg: str) -> None: |
|
verbose_log.debug(msg) |
|
|
|
|
|
def snapshot_cudagraph_enabled(): |
|
return torch._inductor.config.triton.cudagraphs |
|
|
|
|
|
def maybe_clone(x): |
|
if x is not None: |
|
return clone_preserve_strides(x) |
|
return x |
|
|
|
|
|
class AutogradCompilerInstance: |
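    """
    Captures one autograd graph as an FX graph and compiles it with
    compiler_fn. Instances are created by the C++ autograd engine through
    the partial installed by enable(), which drives begin_capture(), the
    hook/backward proxy methods, and finally end_capture().
    """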
|
def __init__(self, compiler_fn) -> None: |
|
self.compiler_fn = compiler_fn |
|
self.stack = contextlib.ExitStack() |
|
self.close = self.stack.close |
|
self.shape_env = ShapeEnv() |
|
self.fake_tensor_mode = FakeTensorMode( |
|
allow_fallback_kernels=True, |
|
allow_non_fake_inputs=True, |
|
shape_env=self.shape_env, |
|
) |
|
self.fx_tracer = PythonKeyTracer() |
|
self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic") |
|
self.hooks_proxy: Optional[Proxy] = None |
|
|
|
def wrap_fake(self, x, source): |
|
assert isinstance(x, torch.Tensor) |
|
return self.fake_tensor_mode.from_tensor(x, source=source) |
|
|
|
@staticmethod |
|
def source(name, idx) -> GetItemSource: |
|
return GetItemSource(LocalSource(name), idx) |
|
|
|
def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]): |
|
counters["compiled_autograd"]["captures"] += 1 |
|
self.fx_tracer.root = torch.nn.Module() |
|
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer) |
|
self.fx_tracer.tensor_attrs = {} |
|
args_proxy = self.fx_tracer.create_proxy("placeholder", "inputs", (), {}) |
|
sizes_proxy = self.fx_tracer.create_proxy("placeholder", "sizes", (), {}) |
|
self.hooks_proxy = self.fx_tracer.create_proxy("placeholder", "hooks", (), {}) |
|
|
|
|
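        # tensor inputs -> fake tensors, so capture never touches real data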
|
inputs = [ |
|
self.wrap_fake(x, self.source("inputs", idx)) |
|
for idx, x in enumerate(inputs) |
|
] |
|
proxies = [args_proxy[i] for i in range(len(inputs))] |
|
self.bind_tensors_to_proxies(inputs, proxies) |
|
|
|
|
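        # size inputs -> symints, so captured sizes stay dynamic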
|
sizes = [ |
|
self.shape_env.create_unspecified_symint_and_symbol( |
|
val, |
|
self.source("sizes", idx), |
|
DimDynamic.DYNAMIC, |
|
) |
|
for idx, val in enumerate(sizes) |
|
] |
|
self.bind_tensors_to_proxies(sizes, sizes_proxy) |
|
|
|
|
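        # enter the capture-time modes; they stay active until end_capture()
        # closes the stack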
|
self.stack.enter_context(decompose({})) |
|
self.stack.enter_context(self.fake_tensor_mode) |
|
self.stack.enter_context(self.proxy_mode.sym_mode) |
|
self.stack.enter_context(self.proxy_mode) |
|
self.stack.enter_context(disable_autocast_cache()) |
|
self.stack.enter_context(preserve_node_meta()) |
|
return inputs, sizes |
|
|
|
def proxy_call_backward( |
|
self, |
|
inputs, |
|
output_metadatas, |
|
saved_tensors, |
|
backward_idx: int, |
|
): |
|
assert self.hooks_proxy is not None |
|
backward_c_function = self.hooks_proxy[backward_idx] |
|
proxies = self.fx_tracer.create_proxy( |
|
kind="call_function", |
|
target=call_backward, |
|
args=( |
|
backward_c_function, |
|
self.to_proxy(saved_tensors), |
|
*self.to_proxy(inputs), |
|
), |
|
kwargs={}, |
|
) |
|
|
|
with disable_proxy_modes_tracing(): |
|
|
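            # create fake tensors for the backward's grad inputs, using the
            # layout/device/dtype/size recorded in output_metadatas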
|
grad_ins: List[Optional[torch.Tensor]] = [] |
|
for output_metadata in output_metadatas: |
|
if output_metadata is None: |
|
grad_ins.append(None) |
|
continue |
|
|
|
layout, device, dtype, size = output_metadata |
|
grad_ins.append( |
|
torch.empty(size=size, dtype=dtype, layout=layout, device=device) |
|
) |
|
self.bind_tensors_to_proxies(grad_ins, proxies) |
|
return tuple(grad_ins) |
|
|
|
def proxy_call_hook(self, hook, *args): |
|
return self.fx_tracer.create_proxy( |
|
"call_function", |
|
call_hook, |
|
( |
|
hook, |
|
*[self.to_proxy(x) for x in args], |
|
), |
|
{}, |
|
) |
|
|
|
def tensor_pre_hook(self, inputs, hook_id, i: int): |
|
assert self.hooks_proxy is not None |
|
hook = self.hooks_proxy[hook_id] |
|
proxy = self.proxy_call_hook( |
|
hook, |
|
inputs[i], |
|
) |
|
with disable_proxy_modes_tracing(): |
|
inputs[i] = maybe_clone(inputs[i]) |
|
self.bind_tensors_to_proxies([inputs[i]], [proxy]) |
|
return inputs |
|
|
|
def pre_hook(self, inputs, hook_id): |
|
assert self.hooks_proxy is not None |
|
hook = self.hooks_proxy[hook_id] |
|
proxies = self.proxy_call_hook( |
|
hook, |
|
inputs, |
|
) |
|
with disable_proxy_modes_tracing(): |
|
inputs = [maybe_clone(x) for x in inputs] |
|
self.bind_tensors_to_proxies(inputs, proxies) |
|
return inputs |
|
|
|
def post_hook(self, outputs, inputs, hook_id): |
|
assert self.hooks_proxy is not None |
|
hook = self.hooks_proxy[hook_id] |
|
proxies = self.proxy_call_hook( |
|
hook, |
|
outputs, |
|
inputs, |
|
) |
|
with disable_proxy_modes_tracing(): |
|
outputs = [maybe_clone(x) for x in outputs] |
|
self.bind_tensors_to_proxies(outputs, proxies) |
|
return outputs |
|
|
|
def post_acc_grad_hook(self, input, hook_id): |
|
assert isinstance(input, torch.Tensor) |
|
assert self.hooks_proxy is not None |
|
hook = self.hooks_proxy[hook_id] |
|
proxies = self.proxy_call_hook( |
|
hook, |
|
input, |
|
) |
|
with disable_proxy_modes_tracing(): |
|
input = [maybe_clone(input)] |
|
self.bind_tensors_to_proxies(input, proxies) |
|
return input |
|
|
|
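    # When cudagraphs is enabled, cpu scalar inputs are moved onto cuda so the
    # whole graph can be captured: a scalar is only moved if the graph already
    # has cuda inputs and all of the scalar's users are aten/prims ops.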
|
def move_graph_nodes_to_cuda(self, graph) -> List[int]: |
|
to_move: Dict[int, torch.fx.Node] = {} |
|
has_cuda_inputs = False |
|
nodes = list(graph.nodes) |
|
assert nodes[0].target == "inputs" |
|
inputs = nodes[0] |
|
inputs_users = list(inputs.users.keys()) |
|
|
|
|
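        # graph node ordering is [inputs, sizes, hooks, getitem, getitem_1, ...],
        # where getitem_i reads inputs[i]; the asserts below verify this layout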
|
first_getitem_idx = 3 |
|
assert nodes[first_getitem_idx] == inputs_users[0] |
|
last_getitem_idx = first_getitem_idx + len(inputs_users) - 1 |
|
assert nodes[last_getitem_idx] == inputs_users[-1] |
|
for i, node in enumerate(inputs_users): |
|
if not has_cuda_inputs and node.meta["val"].device.type == "cuda": |
|
has_cuda_inputs = True |
|
continue |
|
|
|
is_cpu = node.meta["val"].device.type == "cpu" |
|
is_scalar = len(node.meta["val"].size()) == 0 |
|
if is_cpu and is_scalar: |
|
node_users = list(node.users.keys()) |
|
if all( |
|
isinstance(user.target, torch._ops.OpOverload) |
|
and user.target.namespace in ("prims", "aten") |
|
for user in node_users |
|
): |
|
|
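                    # all users are aten/prims ops, so this input can be
                    # moved to cuda safely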
|
to_move[i] = node |
|
|
|
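        # only move cpu scalars to cuda if there were cuda activations in this
        # graph; this handles cudagraphs being enabled on a cpu-only graph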
|
if has_cuda_inputs: |
|
for node in to_move.values(): |
|
node.meta["val"] = node.meta["val"].cuda() |
|
|
|
|
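            # runtime indices of the inputs to move to cuda before each call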
|
return list(to_move.keys()) |
|
|
|
return [] |
|
|
|
def end_capture(self, outputs): |
|
self.stack.close() |
|
self.fx_tracer.create_node( |
|
"output", |
|
"output", |
|
(self.fx_tracer.create_arg(self.to_proxy(outputs)),), |
|
{}, |
|
) |
|
self.reorder_accumulate_grad_nodes() |
|
runtime_inputs_to_move: List[int] = [] |
|
if snapshot_cudagraph_enabled(): |
|
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) |
|
|
|
graph = GraphModule( |
|
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd" |
|
) |
|
set_locals_to_steal(graph, ["inputs"]) |
|
compiled_autograd_log.info( |
|
"%s", lazy_format_graph_code("Compiled autograd graph", graph) |
|
) |
|
verbose_log.debug( |
|
"%s", |
|
lazy_format_graph_code( |
|
"Compiled autograd graph", graph, include_device=True |
|
), |
|
) |
|
trace_structured( |
|
"compiled_autograd_graph", |
|
payload_fn=lambda: graph.print_readable(print_output=False), |
|
) |
|
|
|
def runtime_wrapper(compiled_fn, inputs, sizes, hooks): |
|
for i in runtime_inputs_to_move: |
|
inputs[i] = inputs[i].cuda() |
|
|
|
return compiled_fn(inputs, sizes, hooks) |
|
|
|
return runtime_wrapper, self.compiler_fn(graph) |
|
|
|
def reorder_accumulate_grad_nodes(self): |
|
""" |
|
Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of |
|
the graph. This differs from eager mode, which schedules them as soon as possible. This |
|
pass attempts to reorder the graph to mimic eager behavior. |
|
""" |
|
for node in self.fx_tracer.graph.find_nodes( |
|
op="call_function", target=torch.ops.inductor.accumulate_grad_.default |
|
): |
|
            arg = max(node.args)  # the last arg in graph order
|
if arg is not node.prev and arg.op != "placeholder": |
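                # hoist the accumulate node to sit right after its last input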
|
arg.append(node) |
|
|
|
def to_proxy(self, t): |
|
if t is None: |
|
return None |
|
if isinstance(t, list): |
|
return [self.to_proxy(x) for x in t] |
|
if isinstance(t, tuple): |
|
return tuple(self.to_proxy(x) for x in t) |
|
assert isinstance(t, (torch.Tensor, torch.SymInt)) |
|
return fetch_object_proxy(self.fx_tracer)(t).proxy |
|
|
|
def bind_tensors_to_proxies(self, tensors, proxies): |
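        # a single parent proxy fans out into one getitem proxy per tensor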
|
if isinstance(proxies, torch.fx.Proxy): |
|
proxies = [proxies[i] for i in range(len(tensors))] |
|
assert len(tensors) == len(proxies) |
|
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer) |
|
|
|
def bind_backward_state(self, index: int): |
|
assert self.hooks_proxy is not None |
|
proxy = self.hooks_proxy[index] |
|
bw_state = BackwardState() |
|
track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer) |
|
return bw_state |
|
|
|
def set_node_origin(self, node_name, node_index): |
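        # CapturedTraceback.extract() captures this very frame, so the last
        # formatted line holds this function's own source text; swap it for
        # the node's name so stack traces point at the autograd node itself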
|
raw_stack_trace = CapturedTraceback.extract().format()[-1] |
|
new_code = f"{node_name} (NodeCall {node_index})" |
|
new_stack_trace = raw_stack_trace.replace( |
|
"raw_stack_trace = CapturedTraceback.extract().format()[-1]", new_code |
|
) |
|
set_stack_trace(new_stack_trace) |
|
|
|
|
|
compiled_autograd_enabled = False |
|
|
|
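# Nesting depth of active enable() contexts: enable() increments this on
# entry and decrements it on exit; reset() asserts it is back to zero.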
|
compiled_autograd_enabled_count = 0 |
|
|
|
|
|
@contextlib.contextmanager |
|
def enable(compiler_fn): |
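    """
    Context manager that routes backward() calls to compiled autograd:
    each autograd graph is captured as an FX graph and lowered with
    compiler_fn. Autograd multithreading is disabled inside the context.

    A minimal usage sketch; the inductor backend and the `loss` tensor are
    illustrative assumptions, not requirements:

        def compiler_fn(gm: torch.fx.GraphModule):
            return torch.compile(gm, backend="inductor")

        with torch._dynamo.compiled_autograd.enable(compiler_fn):
            loss.backward()
    """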
|
prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler( |
|
functools.partial(AutogradCompilerInstance, compiler_fn) |
|
) |
|
if snapshot_verbose_logging_enabled(): |
|
torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn) |
|
global compiled_autograd_enabled, compiled_autograd_enabled_count |
|
compiled_autograd_enabled = True |
|
compiled_autograd_enabled_count += 1 |
|
try: |
|
with torch.autograd.set_multithreading_enabled(False): |
|
yield |
|
finally: |
|
compiled_autograd_enabled_count -= 1 |
|
if not prior: |
|
compiled_autograd_enabled = False |
|
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) |
|
|
|
|
|
@contextlib.contextmanager |
|
def disable(): |
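    """Turn compiled autograd off within this context, restoring the prior
    compiler (and the enabled flag) on exit."""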
|
prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) |
|
global compiled_autograd_enabled |
|
compiled_autograd_enabled = False |
|
try: |
|
yield |
|
finally: |
|
if prior: |
|
compiled_autograd_enabled = True |
|
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) |
|
|
|
|
|
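# Return compiled autograd's process-global state to that of a fresh process.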
|
def reset() -> None: |
|
    global compiled_autograd_enabled
    compiled_autograd_enabled = False
|
assert compiled_autograd_enabled_count == 0 |
|
torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) |
|
torch._C._dynamo.compiled_autograd.set_verbose_logger(None) |
|
|