# mypy: ignore-errors
import enum
import dis
import copy
import sys
import torch
import inspect
import operator
import traceback
import collections
from dataclasses import is_dataclass, fields
from .graph import magic_methods, reflectable_magic_methods, Graph
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
from ._compatibility import compatibility
from .operator_schemas import check_for_mutable_operation
import torch.fx.traceback as fx_traceback
__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
'ScopeContextManager']
@compatibility(is_backward_compatible=False)
class Scope:
""" Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in the Graph of a GraphModule. For example::
class Sub(torch.nn.Module):
def forward(self, x):
# This will be a call_method Node in GraphModule,
# scope for this would be (module_path="sub", module_type=Sub)
return x.transpose(1, 2)
class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()
            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=M)
x = x.transpose(1, 2)
x = self.sub(x)
return x
"""
def __init__(self, module_path: str, module_type: Any):
super().__init__()
self.module_path = module_path
self.module_type = module_type
@compatibility(is_backward_compatible=False)
class ScopeContextManager:
""" A context manager to track the Scope of Node during symbolic tracing.
When entering a forward function of a Module, we'll update the scope information of
the current module, and when we exit, we'll restore the previous scope information.
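
    A minimal usage sketch (``Sub`` as in the ``Scope`` example above)::

        scope = Scope("", None)  # the tracer's mutable current scope
        with ScopeContextManager(scope, Scope("sub", Sub)):
            assert scope.module_path == "sub"  # updated on enter
        assert scope.module_path == ""  # previous scope restored on exit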
"""
def __init__(
self,
scope: Scope,
current_scope: Scope,
):
super().__init__()
# Keep a copy of prev scope to restore on exit
self._prev_scope = copy.copy(scope)
# Update scope to current scope
scope.module_path = current_scope.module_path
scope.module_type = current_scope.module_type
# Save a reference so we can restore it
self._scope = scope
def __enter__(self):
return self._scope
def __exit__(self, *args):
self._scope.module_path = self._prev_scope.module_path
self._scope.module_type = self._prev_scope.module_type
return
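# Meta fields propagated from fx_traceback's preserved node meta onto each
# newly created node (see TracerBase.create_node below).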
_COPY_META_FIELDS = ["nn_module_stack", "source_fn_stack", "original_aten", "recompute", "from_node", "quantization_tag"]
@compatibility(is_backward_compatible=True)
class TracerBase:
graph: Graph
record_stack_traces : bool = False
# Feature flag for mutable schema checking
    # Enable by default in 1.12
check_mutable_operations : bool = False
# Feature flag for assert tracing
trace_asserts : bool = False
# Feature flag for proxying accesses to buffer values
proxy_buffer_attributes : bool = False
# Name of the function to be traced. It will only be used when
# ``root`` is an instance of ``nn.Module``
traced_func_name: str = "forward"
    # The scope (module path and module type) in which create_node currently
    # inserts nodes
scope : Scope
# Records the module call stack
module_stack: OrderedDict[str, Tuple[str, Any]]
# Mapping of node name to module scope
node_name_to_scope: Dict[str, Tuple[str, type]]
@compatibility(is_backward_compatible=True)
def create_node(self, kind : str, target : Target,
args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
type_expr : Optional[Any] = None) -> Node:
"""
Inserts a graph node given target, args, kwargs, and name.
This method can be overridden to do extra checking, validation, or
modification of values used in node creation. For example, one might
want to disallow in-place operations from being recorded.
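
        A sketch of such an override (``_INPLACE_METHODS`` is a hypothetical
        denylist, not an exhaustive one)::

            _INPLACE_METHODS = {'add_', 'mul_', 'copy_', 'fill_'}

            class NoInplaceTracer(TracerBase):
                def create_node(self, kind, target, args, kwargs,
                                name=None, type_expr=None):
                    if kind == 'call_method' and target in _INPLACE_METHODS:
                        raise RuntimeError(f'in-place method {target} is not allowed')
                    return super().create_node(kind, target, args, kwargs,
                                               name, type_expr)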
"""
if kind == 'call_function' and self.check_mutable_operations:
check_for_mutable_operation(target, args, kwargs)
node = self.graph.create_node(kind, target, args, kwargs, name, type_expr)
        # TODO node_name_to_scope will be deprecated in favor of
# node.meta['nn_module_stack']
self.node_name_to_scope[node.name] = (
self.scope.module_path,
self.scope.module_type,
)
# Optionally set stack trace on the created Node for debugging purposes
if fx_traceback.has_preserved_node_meta():
current_meta: Dict[str, Any] = fx_traceback.get_current_meta()
stack_trace = current_meta.get("stack_trace")
if stack_trace:
node.stack_trace = stack_trace
# Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
# If other meta fields are needed, they can be added here
for field in _COPY_META_FIELDS:
if field in current_meta:
node.meta[field] = copy.copy(current_meta[field])
# Here we decrement to account for the sequence_nr having
# just been incremented while tracing this lowered aten op.
new_seq_nr = torch.autograd._get_sequence_nr() - 1
# The sequence_nr increments every time a new autograd Node
# is created. During the FWD pass we store the sequence_nr
# corresponding to the last autograd Node created on this fx
# node's meta. A single aten op can create multiple autograd
# nodes as is the case with in-place foreach ops. During the
# BWD pass we retrieve the sequence_nr stored on the current
# executing autograd Node. See NOTE [ Sequence Number ].
if current_meta.get("in_grad_fn", 0) > 0:
new_seq_nr = current_meta["grad_fn_seq_nr"][-1]
node.meta["seq_nr"] = new_seq_nr
elif self.module_stack:
node.meta['nn_module_stack'] = copy.copy(self.module_stack)
return node
@compatibility(is_backward_compatible=True)
def proxy(self, node: Node) -> 'Proxy':
return Proxy(node, self)
@compatibility(is_backward_compatible=True)
def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
name: Optional[str] = None, type_expr : Optional[Any] = None,
                     proxy_factory_fn: Optional[Callable[[Node], 'Proxy']] = None):
'''
Create a Node from the given arguments, then return the Node
wrapped in a Proxy object.
If kind = 'placeholder', then we're creating a Node that
represents the parameter of a function. If we need to encode
a default parameter, we use the ``args`` tuple. ``args`` is
otherwise empty for ``placeholder`` Nodes.
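
        For instance, a placeholder ``y`` with default value ``None`` would be
        encoded as (a sketch of the convention described above)::

            tracer.create_proxy('placeholder', 'y', (None,), {})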
'''
args_ = self.create_arg(args)
kwargs_ = self.create_arg(kwargs)
assert isinstance(args_, tuple)
assert isinstance(kwargs_, dict)
node = self.create_node(kind, target, args_, kwargs_, name, type_expr)
if not proxy_factory_fn:
proxy = self.proxy(node)
else:
proxy = proxy_factory_fn(node)
if self.record_stack_traces and not proxy.node.stack_trace:
user_frame = self._find_user_frame()
if user_frame:
summary = traceback.extract_stack(user_frame)
tb_lines = summary.format()
# stack_trace would have innermost frame at the bottom
proxy.node.stack_trace = ''.join(tb_lines)
return proxy
def _find_user_frame(self):
"""
Find the Python stack frame executing the user code during
symbolic tracing.
"""
# We have to do a little dance here. Basically, walk up the callstack and
# record the first frame not in the pytorch source. This is the frame executing
# the user code during tracing.
frame = inspect.currentframe()
pt_files = ['torch/fx/proxy.py',
'torch/fx/_symbolic_trace.py',
'torch/fx/experimental/proxy_tensor.py',
'torch/_ops.py',
'torch/_tensor.py',
'torch/utils/_python_dispatch.py',
'torch/_prims_common/wrappers.py',
'torch/_refs/__init__.py',
'torch/_refs/nn/functional/__init__.py',
'torch/utils/_stats.py',
]
while frame:
frame = frame.f_back
if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files):
break
if not frame:
return None
return frame
@compatibility(is_backward_compatible=True)
def create_arg(self, a: Any) -> Argument:
"""
A method that lowers the objects seen as arguments during symbolic evaluation
into Argument types that can be stored in IR.
        Can be overridden to support more trace-specific types.
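
        For example, a custom container can make itself traceable by defining
        ``__fx_create_arg__``, which this method calls first (``Pair`` here is
        a hypothetical type)::

            class Pair:
                def __init__(self, a, b):
                    self.a, self.b = a, b

                def __fx_create_arg__(self, tracer):
                    # lower each field to an IR-storable Argument
                    return (tracer.create_arg(self.a), tracer.create_arg(self.b))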
"""
if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
return a.__fx_create_arg__(self)
# aggregates
elif isinstance(a, tuple) and hasattr(a, '_fields'):
# NamedTuple constructors don't seem to like getting a generator
# expression as an argument to their constructor, so build this
# intermediate tuple and unpack it into the NamedTuple constructor
args = tuple(self.create_arg(elem) for elem in a)
return type(a)(*args) # type: ignore[arg-type]
elif isinstance(a, (tuple, list)):
return type(a)(self.create_arg(elem) for elem in a)
elif isinstance(a, dict):
r = {}
for k, v in a.items():
# Check for invalid dict keys. We do not want a Proxy to appear
# anywhere within the key. Since keys can be collection types,
# we iterate through the key with map_aggregate
k = self.create_arg(k)
def no_node(arg):
if isinstance(arg, Node):
raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
f"Node. Got key: {k}")
map_aggregate(k, no_node)
r[k] = self.create_arg(v)
return r
elif isinstance(a, slice):
return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
elif isinstance(a, range):
return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
elif isinstance(a, torch._ops.OpOverload):
return a
if isinstance(a, Proxy):
# base case: we unwrap the Proxy object
return a.node
if is_dataclass(a):
kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)}
return self.create_node("call_function", a.__class__, (), kwargs)
elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...:
return a
raise NotImplementedError(f"argument of type: {type(a)}")
@compatibility(is_backward_compatible=True)
def to_bool(self, obj: 'Proxy') -> bool:
"""Called when a proxy object is being converted to a boolean, such as
when used in control flow. Normally we don't know what to do because
we don't know the value of the proxy, but a custom tracer can attach more
information to the graph node using create_node and can choose to return a value.
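
        A sketch of such an override (recording the coercion and always taking
        the ``True`` branch is just one possible policy)::

            class BranchTakingTracer(TracerBase):
                def to_bool(self, obj):
                    self.create_proxy('call_function', bool, (obj,), {})
                    return True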
"""
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
@compatibility(is_backward_compatible=True)
def iter(self, obj: 'Proxy') -> Iterator:
"""Called when a proxy object is being iterated over, such as
when used in control flow. Normally we don't know what to do because
we don't know the value of the proxy, but a custom tracer can attach more
information to the graph node using create_node and can choose to return an iterator.
"""
        raise TraceError('Proxy object cannot be iterated. This can be '
                         'attempted when the Proxy is used in a loop or'
                         ' as *args or **kwargs function arguments. '
                         'See the torch.fx docs on pytorch.org for a '
                         'more detailed explanation of what types of '
                         'control flow can be traced, and check out the'
                         ' Proxy docstring for help troubleshooting '
                         'Proxy iteration errors')
@compatibility(is_backward_compatible=True)
def keys(self, obj: 'Proxy') -> Any:
"""Called when a proxy object is has the keys() method called.
This is what happens when ** is called on a proxy. This should return an
iterator it ** is suppose to work in your custom tracer.
"""
return Attribute(obj, 'keys')()
# Used by Proxy objects when just appending to the graph while not tracing.
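# A minimal sketch of its use: wrap Nodes of an existing Graph in Proxies so
# that the overloaded operators append new nodes to that same graph, e.g.
#   g = Graph()
#   t = GraphAppendingTracer(g)
#   x = Proxy(g.placeholder('x'), t)
#   y = x + 1  # records a call_function node for operator.add in ``g``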
@compatibility(is_backward_compatible=True)
class GraphAppendingTracer(TracerBase):
def __init__(self, graph: Graph):
super().__init__()
self.graph = graph
self.scope = Scope("", None)
self.module_stack = collections.OrderedDict()
self.node_name_to_scope = {}
@compatibility(is_backward_compatible=False)
def assert_fn(x):
assert x
@compatibility(is_backward_compatible=True)
class TraceError(ValueError):
pass
@compatibility(is_backward_compatible=True)
class Proxy:
"""
``Proxy`` objects are ``Node`` wrappers that flow through the
program during symbolic tracing and record all the operations
(``torch`` function calls, method calls, operators) that they touch
into the growing FX Graph.
If you're doing graph transforms, you can wrap your own ``Proxy``
method around a raw ``Node`` so that you can use the overloaded
operators to add additional things to a ``Graph``.
``Proxy`` objects cannot be iterated. In other words, the symbolic
tracer will throw an error if a ``Proxy`` is used in a loop or as
an ``*args``/``**kwargs`` function argument.
There are two main ways around this:
    1. Factor out the untraceable logic into a top-level function and
       use ``fx.wrap`` on it (see the sketch after these workarounds).
2. If the control flow is static (i.e. the loop trip count is
based on some hyperparameter), the code can be kept in its original
position and refactored into something like::
for i in range(self.some_hyperparameter):
indexed_item = proxied_value[i]
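
    For workaround 1, a sketch using a hypothetical ``my_untraceable_fn``::

        @torch.fx.wrap
        def my_untraceable_fn(x):
            return len(x) * x  # runs opaquely; recorded as a single call
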
    For a more detailed description of the Proxy internals, check out
    the "Proxy" section in `torch/fx/OVERVIEW.md`
"""
@compatibility(is_backward_compatible=True)
def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
if tracer is None:
# This allows you to create a Proxy object around a raw Node
tracer = GraphAppendingTracer(node.graph)
self.tracer = tracer
self.node = node
def __repr__(self) -> str:
return f'Proxy({self.node.name})'
def __getattr__(self, k) -> 'Attribute':
        # note: not added to the graph yet; if this turns out to be a method
        # call, we peephole-optimize it into the method invocation
return Attribute(self, k)
def __call__(self, *args, **kwargs) -> 'Proxy':
return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)
def __iter__(self) -> Iterator['Proxy']:
frame = inspect.currentframe()
assert frame is not None
calling_frame = frame.f_back
assert calling_frame is not None
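        # Peek at the caller's bytecode: if the Proxy is being sequence-unpacked
        # (e.g. ``a, b = proxy``), UNPACK_SEQUENCE carries the element count, and
        # we can emit that many indexed accesses instead of raising.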
inst_list = list(dis.get_instructions(calling_frame.f_code))
if sys.version_info >= (3, 11):
from bisect import bisect_left
inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset)
else:
inst_idx = calling_frame.f_lasti // 2
inst = inst_list[inst_idx]
if inst.opname == 'UNPACK_SEQUENCE':
return (self[i] for i in range(inst.argval)) # type: ignore[index]
return self.tracer.iter(self)
def __abs__(self):
return self.tracer.create_proxy('call_function', operator.abs, (self,), {})
def __bool__(self) -> bool:
if self.tracer.trace_asserts:
# check if this boolean is used in an assertion, bytecode pattern for assertions
# is pretty stable for Python 3.7--3.9
frame = inspect.currentframe()
assert frame is not None
calling_frame = frame.f_back
assert calling_frame is not None
insts = list(dis.get_instructions(calling_frame.f_code))
if sys.version_info >= (3, 11):
from bisect import bisect_left
cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
else:
cur = calling_frame.f_lasti // 2
inst = insts[cur]
if inst.opname == 'POP_JUMP_IF_TRUE':
first = insts[cur + 1]
assert inst.arg is not None
last = insts[inst.arg // 2 - 1]
starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
or first.opname == 'LOAD_ASSERTION_ERROR')
if starts_with_assert and last.opname == 'RAISE_VARARGS':
self.tracer.create_proxy('call_function', assert_fn, (self,), {})
return True
return self.tracer.to_bool(self)
@compatibility(is_backward_compatible=True)
def keys(self):
return self.tracer.keys(self)
def __len__(self):
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
"this call to be recorded, please call torch.fx.wrap('len') at "
"module scope")
@classmethod
def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
args = args if args else ()
kwargs = kwargs if kwargs else {}
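        # Collect, in a dict used as an ordered set, every tracer that owns a
        # Proxy somewhere in args/kwargs; mixing tracers is not supported.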
tracers : Dict[Any, None] = {}
def find_tracer(a):
if isinstance(a, cls):
tracers[a.tracer] = None
torch.fx.node.map_aggregate(args, find_tracer)
torch.fx.node.map_aggregate(kwargs, find_tracer)
if len(tracers) > 1:
raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
f'trying to trace operations {orig_method}')
tracer = next(iter(tracers.keys()))
if isinstance(orig_method, torch._C.ScriptMethod):
args = (orig_method.owner,) + args
return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
if torch.overrides.is_tensor_method_or_property(orig_method):
return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
else:
if isinstance(orig_method, torch._ops.HigherOrderOperator):
# TODO: Define how to symbolically trace HigherOrderOperators
raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
return tracer.create_proxy('call_function', orig_method, args, kwargs,
name=tracer.graph._target_to_str(orig_method.__name__))
@compatibility(is_backward_compatible=True)
class Attribute(Proxy):
@compatibility(is_backward_compatible=True)
def __init__(self, root: Proxy, attr: str):
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node: Optional[Node] = None
@property
def node(self):
        # The node for attributes is added lazily, since most will just be
        # method calls, which do not rely on the underlying getattr node
if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
return self._node
def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
@compatibility(is_backward_compatible=False)
class ParameterProxy(Proxy):
"""
A special proxy which lets "shape", "size", "dim", and a few other
attribute accesses pass through to the underlying module parameter object,
    so that conditional tests on these attributes will not throw an exception
    during tracing.
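
    For example, with a plain ``Proxy``, a test like ``if w.dim() == 2:`` would
    hit ``to_bool`` and raise; with a ``ParameterProxy`` the comparison is
    evaluated against the real parameter::

        if w.dim() == 2:  # resolved concretely during tracing
            out = x @ w.t()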
"""
def __init__(self, tracer: TracerBase, node: Node, name, param):
super().__init__(node, tracer)
assert isinstance(param, torch.nn.Parameter)
self.param = param
self.name = name
def __repr__(self) -> str:
return f'ParameterProxy({self.name})'
@property
def shape(self):
return self.param.shape
def size(self):
return self.param.size()
def dim(self):
return self.param.dim()
@property
def ndim(self):
return self.param.ndim
def numel(self):
return self.param.numel()
def nelement(self):
return self.param.nelement()
for method in magic_methods:
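    # Wrap in a helper so that ``method`` is bound by value on each iteration;
    # otherwise every generated impl would close over the final value of the
    # loop variable and dispatch to the wrong operator.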
def _scope(method):
def impl(*args, **kwargs):
tracer = args[0].tracer
target = getattr(operator, method)
return tracer.create_proxy('call_function', target, args, kwargs)
impl.__name__ = method
as_magic = f'__{method.strip("_")}__'
setattr(Proxy, as_magic, impl)
_scope(method)
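# Reflected ("r") magic methods such as __radd__ handle expressions where the
# Proxy is the right-hand operand (e.g. ``2 + proxy``); note the swapped
# (rhs, self) argument order recorded by each impl.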
def _define_reflectable(orig_method_name):
method_name = f'__r{orig_method_name.strip("_")}__'
def impl(self, rhs):
target = getattr(operator, orig_method_name)
return self.tracer.create_proxy('call_function', target, (rhs, self), {})
impl.__name__ = method_name
impl.__qualname__ = method_name
setattr(Proxy, method_name, impl)
for orig_method_name in reflectable_magic_methods:
_define_reflectable(orig_method_name)