Spaces:
Running
Running
File size: 10,020 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 |
import contextlib
import functools
from typing import List, Optional
import torch
from torch._dynamo.external_utils import call_backward, call_hook
from torch._dynamo.source import GetItemSource, LocalSource
from torch._dynamo.utils import counters, lazy_format_graph_code
from torch._logging import getArtifactLogger, trace_structured
from torch._prims_common import clone_preserve_strides
from torch._subclasses import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import (
decompose,
disable_autocast_cache,
disable_proxy_modes_tracing,
fetch_object_proxy,
ProxyTorchDispatchMode,
PythonKeyTracer,
track_tensor_tree,
)
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.fx.proxy import Proxy
# Artifact logger registered under the "compiled_autograd" artifact name;
# end_capture() uses it to log the code of each captured graph.
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
def maybe_clone(x):
    """Return a stride-preserving clone of ``x``; pass ``None`` through as-is."""
    return None if x is None else clone_preserve_strides(x)
class AutogradCompilerInstance:
def __init__(self, compiler_fn) -> None:
self.compiler_fn = compiler_fn
self.stack = contextlib.ExitStack()
self.close = self.stack.close
self.shape_env = ShapeEnv()
self.fake_tensor_mode = FakeTensorMode(
allow_fallback_kernels=True,
allow_non_fake_inputs=True,
shape_env=self.shape_env,
)
self.fx_tracer = PythonKeyTracer()
self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
self.hooks_proxy: Optional[Proxy] = None
def wrap_fake(self, x, source):
assert isinstance(x, torch.Tensor)
return self.fake_tensor_mode.from_tensor(x, source=source)
@staticmethod
def source(name, idx) -> GetItemSource:
return GetItemSource(LocalSource(name), idx)
def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]):
counters["compiled_autograd"]["captures"] += 1
self.fx_tracer.root = torch.nn.Module()
self.fx_tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
self.fx_tracer.tensor_attrs = {}
args_proxy = self.fx_tracer.create_proxy("placeholder", "inputs", (), {})
sizes_proxy = self.fx_tracer.create_proxy("placeholder", "sizes", (), {})
self.hooks_proxy = self.fx_tracer.create_proxy("placeholder", "hooks", (), {})
# tensor inputs to fake tensors
inputs = [
self.wrap_fake(x, self.source("inputs", idx))
for idx, x in enumerate(inputs)
]
proxies = [args_proxy[i] for i in range(len(inputs))]
self.bind_tensors_to_proxies(inputs, proxies)
# size inputs to symints
sizes = [
self.shape_env.create_unspecified_symint_and_symbol(
val,
self.source("sizes", idx),
DimDynamic.DYNAMIC,
)
for idx, val in enumerate(sizes)
]
self.bind_tensors_to_proxies(sizes, sizes_proxy)
# TODO(jansel): are all these modes needed?
self.stack.enter_context(decompose({}))
self.stack.enter_context(self.fake_tensor_mode)
self.stack.enter_context(self.proxy_mode.sym_mode)
self.stack.enter_context(self.proxy_mode)
self.stack.enter_context(disable_autocast_cache())
return inputs, sizes
def proxy_call_backward(
self,
inputs,
output_metadatas,
saved_tensors,
backward_idx: int,
):
assert self.hooks_proxy is not None
backward_fn = self.hooks_proxy[backward_idx] # type: ignore[index]
proxies = self.fx_tracer.create_proxy(
kind="call_function",
target=call_backward,
args=(
backward_fn,
self.to_proxy(saved_tensors),
*self.to_proxy(inputs),
),
kwargs={},
)
with disable_proxy_modes_tracing():
# create fake Tensors
grad_ins: List[Optional[torch.Tensor]] = []
for output_metadata in output_metadatas:
if output_metadata is None:
grad_ins.append(None)
continue
layout, device, dtype, size = output_metadata
grad_ins.append(
torch.empty(size=size, dtype=dtype, layout=layout, device=device)
)
self.bind_tensors_to_proxies(grad_ins, proxies)
return tuple(grad_ins)
def proxy_call_hook(self, hook, *args):
return self.fx_tracer.create_proxy(
"call_function",
call_hook,
(
hook,
*[self.to_proxy(x) for x in args],
),
{},
)
def tensor_pre_hook(self, inputs, hook_id, i: int):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxy = self.proxy_call_hook(
hook,
inputs[i],
)
with disable_proxy_modes_tracing():
inputs[i] = maybe_clone(inputs[i])
self.bind_tensors_to_proxies([inputs[i]], [proxy])
return inputs
def pre_hook(self, inputs, hook_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
hook,
inputs,
)
with disable_proxy_modes_tracing():
inputs = [maybe_clone(x) for x in inputs]
self.bind_tensors_to_proxies(inputs, proxies)
return inputs
def post_hook(self, outputs, inputs, hook_id):
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
hook,
outputs,
inputs,
)
with disable_proxy_modes_tracing():
outputs = [maybe_clone(x) for x in outputs]
self.bind_tensors_to_proxies(outputs, proxies)
return outputs
def post_acc_grad_hook(self, input, hook_id):
assert isinstance(input, torch.Tensor)
assert self.hooks_proxy is not None
hook = self.hooks_proxy[hook_id] # type: ignore[index]
proxies = self.proxy_call_hook(
hook,
input,
)
with disable_proxy_modes_tracing():
input = [maybe_clone(input)]
self.bind_tensors_to_proxies(input, proxies)
return input
def end_capture(self, outputs):
self.stack.close()
self.fx_tracer.create_node(
"output",
"output",
(self.fx_tracer.create_arg(self.to_proxy(outputs)),),
{},
)
graph = GraphModule(
self.fx_tracer.root, self.fx_tracer.graph, "CompiledAutograd"
)
compiled_autograd_log.info(
"%s", lazy_format_graph_code("Compiled autograd graph", graph)
)
trace_structured(
"compiled_autograd_graph",
payload_fn=lambda: graph.print_readable(print_output=False),
)
return self.compiler_fn(graph)
def to_proxy(self, t):
if t is None:
return None
if isinstance(t, list):
return [self.to_proxy(x) for x in t]
if isinstance(t, tuple):
return tuple(self.to_proxy(x) for x in t)
assert isinstance(t, (torch.Tensor, torch.SymInt))
return fetch_object_proxy(self.fx_tracer)(t).proxy
def bind_tensors_to_proxies(self, tensors, proxies):
if isinstance(proxies, torch.fx.Proxy):
proxies = [proxies[i] for i in range(len(tensors))]
assert len(tensors) == len(proxies)
track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)
def bind_backward_state(self, index: int):
assert self.hooks_proxy is not None
proxy = self.hooks_proxy[index] # type: ignore[index]
bw_state = BackwardState()
track_tensor_tree(bw_state, proxy, constant=None, tracer=self.fx_tracer)
return bw_state
# Whether compiled autograd is currently active; toggled by the enable() and
# disable() context managers below.
compiled_autograd_enabled = False
# We may have code like:
#   with enable(compiler_fn):
#     ...
#     with disable():
#       ...
#     ...
# The disable() call only wants to disable compiled autograd temporarily,
# but overall the feature is still enabled.
#
# The code covered by the disable context manager has no way to know whether
# compiled autograd is enabled overall. Use another variable,
# compiled_autograd_enabled_count, to indicate how many times compiled
# autograd has been enabled in the call stack for this purpose.
compiled_autograd_enabled_count = 0
@contextlib.contextmanager
def enable(compiler_fn):
    """Turn compiled autograd on for the duration of the ``with`` block.

    Registers a factory that builds ``AutogradCompilerInstance(compiler_fn)``
    as the autograd compiler, marks the feature enabled, and disables
    multithreaded autograd inside the block. On exit the nesting count is
    decremented, the enabled flag is cleared if no compiler was registered
    before, and the previously registered compiler is restored.

    Args:
        compiler_fn: callable forwarded to ``AutogradCompilerInstance``.
    """
    global compiled_autograd_enabled, compiled_autograd_enabled_count

    binding = torch._C._dynamo.compiled_autograd
    instance_factory = functools.partial(AutogradCompilerInstance, compiler_fn)
    prior = binding.set_autograd_compiler(instance_factory)

    compiled_autograd_enabled = True
    compiled_autograd_enabled_count += 1
    try:
        with torch.autograd.set_multithreading_enabled(False):
            yield
    finally:
        compiled_autograd_enabled_count -= 1
        if not prior:
            # Nothing was registered before us, so the feature is now off.
            compiled_autograd_enabled = False
        binding.set_autograd_compiler(prior)
@contextlib.contextmanager
def disable():
    """Turn compiled autograd off for the duration of the ``with`` block.

    Unregisters the autograd compiler and clears the enabled flag; on exit
    the flag is set again if a compiler had been registered, and that
    compiler is restored.
    """
    global compiled_autograd_enabled

    binding = torch._C._dynamo.compiled_autograd
    prior = binding.set_autograd_compiler(None)
    compiled_autograd_enabled = False
    try:
        yield
    finally:
        if prior:
            # A compiler was active before we disabled it; re-flag as enabled.
            compiled_autograd_enabled = True
        binding.set_autograd_compiler(prior)
|