# mypy: ignore-errors

import enum
import dis
import copy
import sys
import torch
import inspect
import operator
import traceback
import collections

from dataclasses import is_dataclass, fields

from .graph import magic_methods, reflectable_magic_methods, Graph
from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable
from .node import Target, Node, Argument, base_types, map_aggregate
from ._compatibility import compatibility
from .operator_schemas import check_for_mutable_operation
import torch.fx.traceback as fx_traceback

__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError',
           'Proxy', 'Attribute', 'ParameterProxy', 'Scope',
           'ScopeContextManager']


class Scope:
    """ Scope object that records the module path and the module type
    of a module. Scope is used to track the information of the module
    that contains a Node in a Graph of GraphModule. For example::

        class Sub(torch.nn.Module):
            def forward(self, x):
                # This will be a call_method Node in GraphModule,
                # scope for this would be (module_path="sub", module_type=Sub)
                return x.transpose(1, 2)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()

            def forward(self, x):
                # This will be a call_method Node as well,
                # scope for this would be (module_path="", module_type=None)
                x = x.transpose(1, 2)
                x = self.sub(x)
                return x

    """

    def __init__(self, module_path: str, module_type: Any):
        super().__init__()
        self.module_path = module_path
        self.module_type = module_type


class ScopeContextManager:
    """ A context manager to track the Scope of Node during symbolic tracing.
    When entering a forward function of a Module, we'll update the scope information of
    the current module, and when we exit, we'll restore the previous scope information.
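
    Example (illustrative sketch; ``tracer`` is assumed to be a ``TracerBase``
    subclass instance with an already-initialized ``scope`` attribute)::

        submodule_scope = Scope("sub.linear", torch.nn.Linear)
        with ScopeContextManager(tracer.scope, submodule_scope):
            ...  # Nodes created here are attributed to "sub.linear"
        # On exit, ``tracer.scope`` is restored to its previous path and type.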
""" | |
def __init__( | |
self, | |
scope: Scope, | |
current_scope: Scope, | |
): | |
super().__init__() | |
# Keep a copy of prev scope to restore on exit | |
self._prev_scope = copy.copy(scope) | |
# Update scope to current scope | |
scope.module_path = current_scope.module_path | |
scope.module_type = current_scope.module_type | |
# Save a reference so we can restore it | |
self._scope = scope | |
def __enter__(self): | |
return self._scope | |
def __exit__(self, *args): | |
self._scope.module_path = self._prev_scope.module_path | |
self._scope.module_type = self._prev_scope.module_type | |
return | |


_COPY_META_FIELDS = ["nn_module_stack", "source_fn_stack", "original_aten", "recompute", "from_node", "quantization_tag"]


class TracerBase:
    graph: Graph
    record_stack_traces : bool = False
    # Feature flag for mutable schema checking
    # Enabled by default in 1.12
    check_mutable_operations : bool = False
    # Feature flag for assert tracing
    trace_asserts : bool = False
    # Feature flag for proxying accesses to buffer values
    proxy_buffer_attributes : bool = False

    # Name of the function to be traced. It will only be used when
    # ``root`` is an instance of ``nn.Module``
    traced_func_name: str = "forward"

    # Maps the containing module's name to the operator name
    scope : Scope

    # Records the module call stack
    module_stack: OrderedDict[str, Tuple[str, Any]]

    # Mapping of node name to module scope
    node_name_to_scope: Dict[str, Tuple[str, type]]

    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                    type_expr : Optional[Any] = None) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.
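
        A minimal sketch of such an override (the trailing-underscore test below
        is only a heuristic for in-place ``call_method`` targets, not an official
        torch.fx API)::

            class NoInplaceTracer(torch.fx.Tracer):
                def create_node(self, kind, target, args, kwargs,
                                name=None, type_expr=None):
                    if kind == 'call_method' and isinstance(target, str) and target.endswith('_'):
                        raise RuntimeError(f"in-place method '{target}' is not allowed")
                    return super().create_node(kind, target, args, kwargs, name, type_expr)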
""" | |
if kind == 'call_function' and self.check_mutable_operations: | |
check_for_mutable_operation(target, args, kwargs) | |
node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) | |
# TODO node_name_to_scope will be depreciated in favor of | |
# node.meta['nn_module_stack'] | |
self.node_name_to_scope[node.name] = ( | |
self.scope.module_path, | |
self.scope.module_type, | |
) | |
# Optionally set stack trace on the created Node for debugging purposes | |
if fx_traceback.has_preserved_node_meta(): | |
current_meta: Dict[str, Any] = fx_traceback.get_current_meta() | |
stack_trace = current_meta.get("stack_trace") | |
if stack_trace: | |
node.stack_trace = stack_trace | |
# Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta | |
# If other meta fields are needed, they can be added here | |
for field in _COPY_META_FIELDS: | |
if field in current_meta: | |
node.meta[field] = copy.copy(current_meta[field]) | |
# Here we decrement to account for the sequence_nr having | |
# just been incremented while tracing this lowered aten op. | |
new_seq_nr = torch.autograd._get_sequence_nr() - 1 | |
# The sequence_nr increments every time a new autograd Node | |
# is created. During the FWD pass we store the sequence_nr | |
# corresponding to the last autograd Node created on this fx | |
# node's meta. A single aten op can create multiple autograd | |
# nodes as is the case with in-place foreach ops. During the | |
# BWD pass we retrieve the sequence_nr stored on the current | |
# executing autograd Node. See NOTE [ Sequence Number ]. | |
if current_meta.get("in_grad_fn", 0) > 0: | |
new_seq_nr = current_meta["grad_fn_seq_nr"][-1] | |
node.meta["seq_nr"] = new_seq_nr | |
elif self.module_stack: | |
node.meta['nn_module_stack'] = copy.copy(self.module_stack) | |
return node | |

    def proxy(self, node: Node) -> 'Proxy':
        return Proxy(node, self)

    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None,
                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
        '''
        Create a Node from the given arguments, then return the Node
        wrapped in a Proxy object.

        If kind = 'placeholder', then we're creating a Node that
        represents the parameter of a function. If we need to encode
        a default parameter, we use the ``args`` tuple. ``args`` is
        otherwise empty for ``placeholder`` Nodes.
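
        For example, a tracer might record a placeholder whose default value is
        encoded in ``args`` roughly like this (illustrative sketch; ``self`` is
        assumed to be an active tracer)::

            # the single element of ``args`` is the default value
            y = self.create_proxy('placeholder', 'y', (None,), {},
                                  type_expr=Optional[torch.Tensor])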
        '''

        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)

        node = self.create_node(kind, target, args_, kwargs_, name, type_expr)

        if not proxy_factory_fn:
            proxy = self.proxy(node)
        else:
            proxy = proxy_factory_fn(node)

        if self.record_stack_traces and not proxy.node.stack_trace:
            user_frame = self._find_user_frame()
            if user_frame:
                summary = traceback.extract_stack(user_frame)
                tb_lines = summary.format()
                # stack_trace would have innermost frame at the bottom
                proxy.node.stack_trace = ''.join(tb_lines)

        return proxy

    def _find_user_frame(self):
        """
        Find the Python stack frame executing the user code during
        symbolic tracing.
        """
        # We have to do a little dance here. Basically, walk up the callstack and
        # record the first frame not in the pytorch source. This is the frame executing
        # the user code during tracing.
        frame = inspect.currentframe()

        pt_files = ['torch/fx/proxy.py',
                    'torch/fx/_symbolic_trace.py',
                    'torch/fx/experimental/proxy_tensor.py',
                    'torch/_ops.py',
                    'torch/_tensor.py',
                    'torch/utils/_python_dispatch.py',
                    'torch/_prims_common/wrappers.py',
                    'torch/_refs/__init__.py',
                    'torch/_refs/nn/functional/__init__.py',
                    'torch/utils/_stats.py',
                    ]
        while frame:
            frame = frame.f_back
            if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files):
                break

        if not frame:
            return None

        return frame

    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Can be overridden to support more trace-specific types.
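
        A minimal sketch of such an override (``Point`` is a hypothetical user
        container type, not part of torch.fx)::

            class MyTracer(torch.fx.Tracer):
                def create_arg(self, a):
                    if isinstance(a, Point):
                        # lower the custom container into plain Arguments
                        return super().create_arg((a.x, a.y))
                    return super().create_arg(a)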
""" | |
if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'): | |
return a.__fx_create_arg__(self) | |
# aggregates | |
elif isinstance(a, tuple) and hasattr(a, '_fields'): | |
# NamedTuple constructors don't seem to like getting a generator | |
# expression as an argument to their constructor, so build this | |
# intermediate tuple and unpack it into the NamedTuple constructor | |
args = tuple(self.create_arg(elem) for elem in a) | |
return type(a)(*args) # type: ignore[arg-type] | |
elif isinstance(a, (tuple, list)): | |
return type(a)(self.create_arg(elem) for elem in a) | |
elif isinstance(a, dict): | |
r = {} | |
for k, v in a.items(): | |
# Check for invalid dict keys. We do not want a Proxy to appear | |
# anywhere within the key. Since keys can be collection types, | |
# we iterate through the key with map_aggregate | |
k = self.create_arg(k) | |
def no_node(arg): | |
if isinstance(arg, Node): | |
raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " | |
f"Node. Got key: {k}") | |
map_aggregate(k, no_node) | |
r[k] = self.create_arg(v) | |
return r | |
elif isinstance(a, slice): | |
return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) | |
elif isinstance(a, range): | |
return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) | |
elif isinstance(a, torch._ops.OpOverload): | |
return a | |
if isinstance(a, Proxy): | |
# base case: we unwrap the Proxy object | |
return a.node | |
if is_dataclass(a): | |
kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)} | |
return self.create_node("call_function", a.__class__, (), kwargs) | |
elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: | |
return a | |
raise NotImplementedError(f"argument of type: {type(a)}") | |

    def to_bool(self, obj: 'Proxy') -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow. Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.
        """
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

    def iter(self, obj: 'Proxy') -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow. Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.
        """
        raise TraceError('Proxy object cannot be iterated. This can be '
                         'attempted when the Proxy is used in a loop or'
                         ' as a *args or **kwargs function argument. '
                         'See the torch.fx docs on pytorch.org for a '
                         'more detailed explanation of what types of '
                         'control flow can be traced, and check out the'
                         ' Proxy docstring for help troubleshooting '
                         'Proxy iteration errors')

    def keys(self, obj: 'Proxy') -> Any:
        """Called when a proxy object has the keys() method called.
        This is what happens when ** is called on a proxy. This should return an
        iterator if ** is supposed to work in your custom tracer.
        """
        return Attribute(obj, 'keys')()


# Used in Proxy objects when just appending to the graph while not tracing.
class GraphAppendingTracer(TracerBase):
    def __init__(self, graph: Graph):
        super().__init__()
        self.graph = graph
        self.scope = Scope("", None)
        self.module_stack = collections.OrderedDict()
        self.node_name_to_scope = {}
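
# A minimal sketch of how GraphAppendingTracer can be used to grow an existing
# Graph with Proxy arithmetic outside of a full symbolic trace (illustrative):
#
#     graph = Graph()
#     tracer = GraphAppendingTracer(graph)
#     x = Proxy(graph.placeholder('x'), tracer)
#     y = x + 1   # appends a call_function node for operator.add to ``graph``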


def assert_fn(x):
    assert x


class TraceError(ValueError):
    pass


class Proxy:
    """
    ``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap a raw ``Node`` in your
    own ``Proxy`` so that you can use the overloaded operators to add
    additional things to a ``Graph``.

    ``Proxy`` objects cannot be iterated. In other words, the symbolic
    tracer will throw an error if a ``Proxy`` is used in a loop or as
    an ``*args``/``**kwargs`` function argument.

    There are two main ways around this:
    1. Factor out the untraceable logic into a top-level function and
       use ``fx.wrap`` on it (see the example below).
    2. If the control flow is static (i.e. the loop trip count is
       based on some hyperparameter), the code can be kept in its original
       position and refactored into something like::

        for i in range(self.some_hyperparameter):
            indexed_item = proxied_value[i]

    For a more detailed description into the Proxy internals, check out
    the "Proxy" section in `torch/fx/OVERVIEW.md`
""" | |
def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): | |
if tracer is None: | |
# This allows you to create a Proxy object around a raw Node | |
tracer = GraphAppendingTracer(node.graph) | |
self.tracer = tracer | |
self.node = node | |
def __repr__(self) -> str: | |
return f'Proxy({self.node.name})' | |
def __getattr__(self, k) -> 'Attribute': | |
# note: not added to the graph yet, if this is a method call | |
# we peephole optimize to the method invocation | |
return Attribute(self, k) | |
def __call__(self, *args, **kwargs) -> 'Proxy': | |
return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) | |
def __iter__(self) -> Iterator['Proxy']: | |
frame = inspect.currentframe() | |
assert frame is not None | |
calling_frame = frame.f_back | |
assert calling_frame is not None | |
inst_list = list(dis.get_instructions(calling_frame.f_code)) | |
if sys.version_info >= (3, 11): | |
from bisect import bisect_left | |
inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset) | |
else: | |
inst_idx = calling_frame.f_lasti // 2 | |
inst = inst_list[inst_idx] | |
if inst.opname == 'UNPACK_SEQUENCE': | |
return (self[i] for i in range(inst.argval)) # type: ignore[index] | |
return self.tracer.iter(self) | |

    def __abs__(self):
        return self.tracer.create_proxy('call_function', operator.abs, (self,), {})

    def __bool__(self) -> bool:
        if self.tracer.trace_asserts:
            # check if this boolean is used in an assertion, bytecode pattern for assertions
            # is pretty stable for Python 3.7--3.9
            frame = inspect.currentframe()
            assert frame is not None
            calling_frame = frame.f_back
            assert calling_frame is not None
            insts = list(dis.get_instructions(calling_frame.f_code))
            if sys.version_info >= (3, 11):
                from bisect import bisect_left
                cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset)
            else:
                cur = calling_frame.f_lasti // 2
            inst = insts[cur]

            if inst.opname == 'POP_JUMP_IF_TRUE':
                first = insts[cur + 1]
                assert inst.arg is not None
                last = insts[inst.arg // 2 - 1]
                starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
                                      or first.opname == 'LOAD_ASSERTION_ERROR')
                if starts_with_assert and last.opname == 'RAISE_VARARGS':
                    self.tracer.create_proxy('call_function', assert_fn, (self,), {})
                    return True

        return self.tracer.to_bool(self)

    def keys(self):
        return self.tracer.keys(self)

    def __len__(self):
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
                           "this call to be recorded, please call torch.fx.wrap('len') at "
                           "module scope")

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        tracers : Dict[Any, None] = {}

        def find_tracer(a):
            if isinstance(a, cls):
                tracers[a.tracer] = None
        torch.fx.node.map_aggregate(args, find_tracer)
        torch.fx.node.map_aggregate(kwargs, find_tracer)

        if len(tracers) > 1:
            raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
                               f'trying to trace operations {orig_method}')
        tracer = next(iter(tracers.keys()))

        if isinstance(orig_method, torch._C.ScriptMethod):
            args = (orig_method.owner,) + args
            return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
        if torch.overrides.is_tensor_method_or_property(orig_method):
            return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
        else:
            if isinstance(orig_method, torch._ops.HigherOrderOperator):
                # TODO: Define how to symbolically trace HigherOrderOperators
                raise RuntimeError("Unable to symbolically trace HigherOrderOperators")
            return tracer.create_proxy('call_function', orig_method, args, kwargs,
                                       name=tracer.graph._target_to_str(orig_method.__name__))


class Attribute(Proxy):
    def __init__(self, root: Proxy, attr: str):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is None:
            self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
        return self._node

    def __call__(self, *args, **kwargs):
        return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
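
# Illustrative note: during tracing, ``proxy.foo(bar)`` first produces an
# ``Attribute(proxy, 'foo')``; calling it records a single call_method node for
# ``foo`` and no ``getattr`` node is created. Only a bare attribute access that
# is later consumed as a value (e.g. ``x = proxy.some_attr``) materializes the
# call_function(getattr, ...) node via the lazy ``node`` property above.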


class ParameterProxy(Proxy):
    """
    A special proxy which lets "shape", "size", "dim", and a few other
    attribute accesses pass through to the underlying module parameter object,
    so that conditional tests on these attributes will not throw an exception during tracing.
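
    Example (illustrative sketch; assumes the tracer is constructed with
    ``param_shapes_constant=True``, which makes attribute accesses on parameters
    return ``ParameterProxy`` objects)::

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(4, 5))

            def forward(self, x):
                # resolved at trace time instead of raising a TraceError
                if self.weight.shape[0] == 4:
                    return x + 1
                return x - 1

        graph = torch.fx.Tracer(param_shapes_constant=True).trace(M())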
""" | |
def __init__(self, tracer: TracerBase, node: Node, name, param): | |
super().__init__(node, tracer) | |
assert isinstance(param, torch.nn.Parameter) | |
self.param = param | |
self.name = name | |
def __repr__(self) -> str: | |
return f'ParameterProxy({self.name})' | |
def shape(self): | |
return self.param.shape | |
def size(self): | |
return self.param.size() | |
def dim(self): | |
return self.param.dim() | |
def ndim(self): | |
return self.param.ndim | |
def numel(self): | |
return self.param.numel() | |
def nelement(self): | |
return self.param.nelement() | |


for method in magic_methods:
    def _scope(method):
        # Wrap in a function so that ``method`` is captured by value on each
        # iteration, rather than every generated impl closing over the final
        # value of the loop variable.
        def impl(*args, **kwargs):
            tracer = args[0].tracer
            target = getattr(operator, method)
            return tracer.create_proxy('call_function', target, args, kwargs)
        impl.__name__ = method
        as_magic = f'__{method.strip("_")}__'
        setattr(Proxy, as_magic, impl)
    _scope(method)


def _define_reflectable(orig_method_name):
    # Install the reflected (right-hand) version of a binary operator on Proxy,
    # e.g. ``__radd__``, which dispatches with the operands swapped.
    method_name = f'__r{orig_method_name.strip("_")}__'

    def impl(self, rhs):
        target = getattr(operator, orig_method_name)
        return self.tracer.create_proxy('call_function', target, (rhs, self), {})
    impl.__name__ = method_name
    impl.__qualname__ = method_name
    setattr(Proxy, method_name, impl)


for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)