import builtins
import copy
import functools
import inspect
import math
import os
import warnings
import collections
from itertools import chain
from types import CodeType, FunctionType, ModuleType
from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
)

import torch
import torch.utils._pytree as pytree
from torch._C import ScriptObject  # type: ignore[attr-defined]

from ._compatibility import compatibility
from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph
from .graph_module import GraphModule
from ._lazy_graph_module import _make_graph_module
from .node import Argument, base_types, map_aggregate
from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager

HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS

# These need to run in global scope to handle nested calls correctly
_orig_module_call: Callable = torch.nn.Module.__call__
_orig_module_getattr: Callable = torch.nn.Module.__getattr__

_proxyable_classes: Dict[Type, None] = {}

_is_fx_tracing_flag = False


def is_fx_tracing():
    return _is_fx_tracing_flag
class ProxyableClassMeta(type):
    """
    ProxyableClassMeta allows you to make construction of a given Python class
    symbolically traceable. For example::

        import torch
        import torch.fx

        class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
            def __init__(self, left, right):
                self.left, self.right = left, right

            def add(self, other):
                l = self.left + other.left
                r = self.right + other.right
                return TensorPair(l, r)

            def mul(self, other):
                l = self.left * other.left
                r = self.right * other.right
                return TensorPair(l, r)

        def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
            s = x.add(TensorPair(y, y))
            return s.mul(x)

        x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
        y = torch.randn(5, 3)
        ref_out = use_tensor_pair_ctor(x, y)

        traced = torch.fx.symbolic_trace(use_tensor_pair_ctor)
        print(traced.code)
        '''
        def forward(self, x : __main___TensorPair, y : torch.Tensor):
            tensor_pair = __main___TensorPair(y, y);  y = None
            add = x.add(tensor_pair);  tensor_pair = None
            mul = add.mul(x);  add = x = None
            return mul
        '''

    From this example, we can see that construction of a class (``TensorPair``)
    defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic
    tracing.
    """
    def __init__(cls, name, bases, attrs):
        _proxyable_classes.setdefault(cls)
        super().__init__(name, bases, attrs)

    def __call__(cls, *args, **kwargs):
        instance = cls.__new__(cls)  # type: ignore[call-overload]

        if not is_fx_tracing():
            cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
            return instance

        found_proxies = []

        def check_proxy(a):
            if isinstance(a, Proxy):
                found_proxies.append(a)

        map_aggregate(args, check_proxy)
        map_aggregate(kwargs, check_proxy)

        if len(found_proxies) != 0:
            tracer = found_proxies[0].tracer
            return tracer.create_proxy("call_function", cls, args, kwargs)
        else:
            cls.__init__(instance, *args, **kwargs)  # type: ignore[misc]
            return instance
# Rewrite ``fn``'s code object so that it accepts exactly ``nargs`` positional
# arguments, clearing the *args/**kwargs flags in the process.
def _patch_function(fn: FunctionType, nargs: int) -> FunctionType:
    co = fn.__code__
    co_flags = co.co_flags & ~HAS_VARSTUFF
    co_args: tuple
    if hasattr(co, "co_qualname"):
        # Python-3.11+ code signature
        co_args = (
            nargs,
            0,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_qualname,  # type: ignore[attr-defined]
            co.co_firstlineno,
            co.co_lnotab,
            co.co_exceptiontable,  # type: ignore[attr-defined]
            co.co_freevars,
            co.co_cellvars,
        )
    elif hasattr(co, "co_posonlyargcount"):
        co_args = (
            nargs,
            0,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_firstlineno,
            co.co_lnotab,
            co.co_freevars,
            co.co_cellvars,
        )
    else:
        co_args = (
            nargs,
            0,
            co.co_nlocals,
            co.co_stacksize,
            co_flags,
            co.co_code,
            co.co_consts,
            co.co_names,
            co.co_varnames,
            co.co_filename,
            co.co_name,
            co.co_firstlineno,
            co.co_lnotab,
            co.co_freevars,
            co.co_cellvars,
        )
    new_code = CodeType(*co_args)  # type: ignore[arg-type]
    return FunctionType(
        new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__
    )
# We need to insert placeholder nodes for *args and **kwargs.
# We can't call this function normally, otherwise it would try to unpack them.
# Instead, let's make Python think that args and kwargs are normal variables.


class PHBase:
    """
    Object representing an input placeholder to `concrete_args`
    """

    def __repr__(self):
        return "PH"


PH = PHBase()
class PHWithMeta(PHBase):
    """
    Object representing an input placeholder to `concrete_args` that also
    carries metadata: any non-callable, non-dunder attributes set on it are
    transferred onto the resulting placeholder node.
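
    For example, a small sketch of attaching metadata to a placeholder (the
    ``tag`` attribute and the function ``f`` are illustrative, not part of
    the API)::

        ph = PHWithMeta(ph_key="mask")
        ph.tag = "dynamic"  # non-callable, non-dunder attrs are copied to the node
        gm = torch.fx.symbolic_trace(f, concrete_args={"x": ph})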
""" | |
def __init__(self, ph_key: Optional[str] = None): | |
super().__init__() | |
# Provide a hey for user to identify placeholder node during analysis | |
self.ph_key = ph_key | |
def _transfer_attrs(fr, to):
    for attr_name in dir(fr):
        attr_val = getattr(fr, attr_name)
        if (
            not callable(attr_val)
            and not attr_name.startswith("__")
            and not hasattr(to, attr_name)
        ):
            setattr(to, attr_name, attr_val)
class Tracer(TracerBase):
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `math`s path from the
    # build environment (e.g. `<module 'math' from '/leaked/path'>`).

    """Tracer(autowrap_modules=(math,), autowrap_functions=())

    ``Tracer`` is the class that implements the symbolic tracing functionality
    of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent
    to ``Tracer().trace(m)``.

    Tracer can be subclassed to override various behaviors of the tracing
    process. The different behaviors that can be overridden are described
    in the docstrings of the methods on this class.
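
    For example, a minimal sketch of a subclass that traces through every
    submodule instead of keeping ``torch.nn`` modules as leaves
    (``InlineEverythingTracer`` and ``my_module`` are illustrative, not part
    of the API)::

        class InlineEverythingTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, module_qualified_name):
                return False  # never treat a module as a leaf

        graph = InlineEverythingTracer().trace(my_module)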
""" | |
# Not checking BC on this API because the default value for `autowrap_modules` | |
# includes the local filepath to the `math` module, which would jitter | |
# across machines. | |
def __init__( | |
self, | |
autowrap_modules: Tuple[ModuleType] = (math,), | |
autowrap_functions: Tuple[Callable, ...] = (), | |
param_shapes_constant: bool = False, | |
) -> None: | |
# This method's signature is overridden by the first line of this class' | |
# docstring. If this method's signature is modified, the signature that | |
# overrides it also should be modified accordingly. | |
""" | |
Construct a Tracer object. | |
Args: | |
autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, | |
Python modules whose functions should be wrapped automatically | |
without needing to use fx.wrap(). Backward-compatibility for | |
this parameter is guaranteed. | |
autowrap_functions (Tuple[Callable, ...]): defaults to `()`, | |
Python functions that should be wrapped automatically without | |
needing to use fx.wrap(). Backward compatibility for this | |
parameter is guaranteed. | |
param_shapes_constant (bool): When this flag is set, calls to shape, | |
size and a few other shape like attributes of a module's parameter | |
will be evaluated directly, rather than returning a new Proxy value | |
for an attribute access. Backward compatibility for this parameter | |
is guaranteed. | |
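
        For example, a sketch of keeping a custom function as a leaf call
        (``my_op`` and ``my_module`` are illustrative; ``my_op`` must be
        visible in the globals of the code being traced for autowrapping to
        find it)::

            def my_op(x):
                return x + 1

            tracer = Tracer(autowrap_functions=(my_op,))
            graph = tracer.trace(my_module)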
""" | |
super().__init__() | |
# Functions we will eagerly wrap when we see them while tracing | |
# this captures both `math.sqrt()` and `from math import sqrt` automatically | |
self._autowrap_function_ids: Set[int] = { | |
id(value) | |
for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) | |
if not name.startswith("_") and callable(value) | |
} | |
self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) | |
# Python modules to apply autowrap to at the start, in addition to | |
# modules we see while tracing | |
self._autowrap_search: List[ModuleType] = list(autowrap_modules) | |
self.param_shapes_constant = param_shapes_constant | |
self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None | |
self.root_module_name: str = "" | |
# Maps the containing module's name to the operator name | |
self.scope = Scope("", None) | |
# Records the module call stack | |
self.module_stack = collections.OrderedDict() | |
# Mapping of node name to module scope | |
self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} | |
    def create_arg(self, a: Any) -> "Argument":
        """
        A method to specify the behavior of tracing when preparing values to
        be used as arguments to nodes in the ``Graph``.

        By default, the behavior includes:

        #. Iterate through collection types (e.g. tuple, list, dict) and recursively
           call ``create_arg`` on the elements.
        #. Given a Proxy object, return a reference to the underlying IR ``Node``
        #. Given a non-Proxy Tensor object, emit IR for various cases:

            * For a Parameter, emit a ``get_attr`` node referring to that Parameter
            * For a non-Parameter Tensor, store the Tensor in a special attribute
              on the root module and emit a ``get_attr`` node referring to it.

        This method can be overridden to support more types.

        Args:

            a (Any): The value to be emitted as an ``Argument`` in the ``Graph``.

        Returns:

            The value ``a`` converted into the appropriate ``Argument``
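
        For example, a sketch of an override that lowers a custom container
        type to a tuple the IR already understands (``MyPair`` is
        illustrative, not part of the API)::

            class MyTracer(Tracer):
                def create_arg(self, a):
                    if isinstance(a, MyPair):
                        # represent the container by its elements
                        return self.create_arg((a.left, a.right))
                    return super().create_arg(a)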
""" | |
# The base tracer is used to construct Graphs when there is no associated | |
# module hierarchy, so it can never create parameter references. | |
# The default tracer adds the ability to refer to parameters when | |
# tracing modules. | |
if isinstance(a, torch.nn.Parameter): | |
for n, p in self.root.named_parameters(): | |
if a is p: | |
return self.create_node("get_attr", n, (), {}) | |
raise NameError("parameter is not a member of this module") | |
elif isinstance(a, torch.Tensor): | |
for n_, p_ in self.root.named_buffers(): | |
if a is p_: | |
return self.create_node("get_attr", n_, (), {}) | |
elif isinstance(a, torch.nn.Module): | |
for n_, p_ in self.root.named_modules(): | |
if a is p_: | |
return self.create_node("get_attr", n_, (), {}) | |
# For NamedTuple instances that appear literally as args, we emit | |
# a node to construct the NamedTuple and use that Node as the argument. | |
if isinstance(a, tuple) and hasattr(a, "_fields"): | |
args = tuple(self.create_arg(elem) for elem in a) | |
return self.create_node("call_function", a.__class__, args, {}) | |
# Tensors do not have a reliable string repr() from which they can be | |
# constructed (and we probably don't want to rely on that, either), so | |
# for any constant Tensor values we encounter, first search for if they | |
# are an attribute of some module in the module hierarchy. If so, emit | |
# a get_attr to retrieve that tensor. Otherwise, we'll store away the | |
# tensor value into a special attribute on the Module s.t. we can | |
# retrieve it with a get_attr. | |
if isinstance(a, (torch.Tensor, ScriptObject)): | |
qualname: Optional[str] = self.tensor_attrs.get(a) | |
# Tensor was not found in the Module hierarchy, stow it away in a | |
# special attribute and set the qualname to refer to that | |
if not qualname: | |
i = 0 | |
while True: | |
qualname = f"_tensor_constant{i}" | |
if not hasattr(self.root, qualname): | |
break | |
i += 1 | |
self.tensor_attrs[a] = qualname | |
setattr(self.root, qualname, a) | |
return self.create_node("get_attr", qualname, (), {}) | |
if type(a) in _proxyable_classes: | |
# This is an instance of a proxyable class for which we did not | |
# witness its construction. Intern this as a constant attribute | |
# TODO: binary search | |
i = 0 | |
while True: | |
qualname = f"_{a.__class__.__name__}_constant_{i}" | |
if not hasattr(self.root, qualname): | |
break | |
i += 1 | |
setattr(self.root, qualname, a) | |
return self.create_node("get_attr", qualname, (), {}) | |
return super().create_arg(a) | |
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        """
        A method to specify whether a given ``nn.Module`` is a "leaf" module.

        Leaf modules are the atomic units that appear in
        the IR, referenced by ``call_module`` calls. By default,
        Modules in the PyTorch standard library namespace (torch.nn)
        are leaf modules. All other modules are traced through and
        their constituent ops are recorded, unless specified otherwise
        by overriding this method.

        Args:

            m (Module): The module being queried about
            module_qualified_name (str): The path to root of this module. For example,
                if you have a module hierarchy where submodule ``foo`` contains
                submodule ``bar``, which contains submodule ``baz``, that module will
                appear with the qualified name ``foo.bar.baz`` here.
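
        For example, a sketch that additionally keeps a custom module class as
        a single ``call_module`` node (``MyFusedBlock`` is illustrative, not
        part of the API)::

            class MyTracer(Tracer):
                def is_leaf_module(self, m, module_qualified_name):
                    if isinstance(m, MyFusedBlock):
                        return True
                    return super().is_leaf_module(m, module_qualified_name)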
""" | |
return ( | |
(m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) | |
and not isinstance(m, torch.nn.Sequential) | |
) | |
    def path_of_module(self, mod: torch.nn.Module) -> str:
        """
        Helper method to find the qualified name of ``mod`` in the Module hierarchy
        of ``root``. For example, if ``root`` has a submodule named ``foo``, which has
        a submodule named ``bar``, passing ``bar`` into this function will return
        the string "foo.bar".

        Args:

            mod (Module): The ``Module`` to retrieve the qualified name for.
        """
        # Prefer the O(1) algorithm
        if self.submodule_paths:
            path = self.submodule_paths.get(mod)
            if path is None:
                raise NameError("module is not installed as a submodule")
            assert isinstance(path, str)
            return path

        # O(N^2) fallback in the case that we didn't store the submodule
        # paths.
        else:
            for n, p in self.root.named_modules():
                if mod is p:
                    return n
            raise NameError("module is not installed as a submodule")
    def call_module(
        self,
        m: torch.nn.Module,
        forward: Callable[..., Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> Any:
        """
        Method that specifies the behavior of this ``Tracer`` when it encounters
        a call to an ``nn.Module`` instance.

        By default, the behavior is to check if the called module is a leaf module
        via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to
        ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through
        the operations in its ``forward`` function.

        This method can be overridden to, for example, create nested traced
        GraphModules, or implement any other behavior you would want while tracing
        across ``Module`` boundaries.

        Args:

            m (Module): The module for which a call is being emitted
            forward (Callable): The forward() method of the ``Module`` to be invoked
            args (Tuple): args of the module callsite
            kwargs (Dict): kwargs of the module callsite

        Return:

            The return value from the Module call. In the case that a ``call_module``
            node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever
            value was returned from the ``Module`` invocation.
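
        For example, a sketch that logs every module boundary and otherwise
        defers to the default behavior (``VerboseTracer`` is illustrative, not
        part of the API)::

            class VerboseTracer(Tracer):
                def call_module(self, m, forward, args, kwargs):
                    print(f"calling {self.path_of_module(m)}")
                    return super().call_module(m, forward, args, kwargs)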
""" | |
module_qualified_name = self.path_of_module(m) | |
with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: | |
# module_stack is an ordered dict so writing then deleting the | |
# entry is equivalent to push/pop on a list | |
self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) | |
if not self.is_leaf_module(m, module_qualified_name): | |
ret_val = forward(*args, **kwargs) | |
else: | |
ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) | |
key, _ = self.module_stack.popitem(last=True) | |
assert key == _scope.module_path, f" Unexpected key {key}" | |
return ret_val | |
    def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
        """
        Method that specifies the behavior of this ``Tracer`` when ``getattr``
        is called on an ``nn.Module`` instance during tracing.

        By default, the behavior is to return a proxy value for the attribute. It
        also stores the proxy value in the ``parameter_proxy_cache``, so that future
        calls will reuse the proxy rather than creating a new one.

        This method can be overridden to, for example, not return proxies when
        querying parameters.

        Args:

            attr (str): The name of the attribute being queried
            attr_val (Any): The value of the attribute
            parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies

        Return:

            The return value from the getattr call.
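
        For example, a sketch that returns raw parameter values instead of
        proxies; the values are then used eagerly, and any tensors that flow
        into the graph are turned into ``get_attr`` nodes by ``create_arg``
        (``ConstParamTracer`` is illustrative, not part of the API)::

            class ConstParamTracer(Tracer):
                def getattr(self, attr, attr_val, parameter_proxy_cache):
                    if isinstance(attr_val, torch.nn.Parameter):
                        return attr_val
                    return super().getattr(attr, attr_val, parameter_proxy_cache)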
""" | |
def maybe_get_proxy_for_attr( | |
attr_val, collection_to_search, parameter_proxy_cache | |
): | |
for n, p in collection_to_search: | |
if attr_val is p: | |
if n not in parameter_proxy_cache: | |
kwargs = {} | |
if ( | |
"proxy_factory_fn" | |
in inspect.signature(self.create_proxy).parameters | |
): | |
kwargs["proxy_factory_fn"] = ( | |
None | |
if not self.param_shapes_constant | |
else lambda node: ParameterProxy( | |
self, node, n, attr_val | |
) | |
) | |
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] | |
parameter_proxy_cache[n] = val_proxy | |
return parameter_proxy_cache[n] | |
return None | |
if isinstance(attr_val, torch.nn.Parameter): | |
maybe_parameter_proxy = maybe_get_proxy_for_attr( | |
attr_val, self.root.named_parameters(), parameter_proxy_cache | |
) | |
if maybe_parameter_proxy is not None: | |
return maybe_parameter_proxy | |
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): | |
maybe_buffer_proxy = maybe_get_proxy_for_attr( | |
attr_val, self.root.named_buffers(), parameter_proxy_cache | |
) | |
if maybe_buffer_proxy is not None: | |
return maybe_buffer_proxy | |
return attr_val | |
    # This method will be refactored
    def create_args_for_root(self, root_fn, is_module, concrete_args=None):
        """
        Create ``placeholder`` nodes corresponding to the signature of the ``root``
        Module. This method introspects root's signature and emits those
        nodes accordingly, also supporting ``*args`` and ``**kwargs``.
        """
        # In some cases, a function or method has been decorated with a wrapper
        # defined via ``functools.wraps``. In this case, the outer code object
        # will likely not contain the actual parameters we care about, so unwrap
        # the function to get to the innermost callable.
        fn_for_analysis = inspect.unwrap(root_fn)
        co = fn_for_analysis.__code__
        total_args = co.co_argcount + co.co_kwonlyargcount
        orig_args = list(co.co_varnames)
        names_iter = iter(co.co_varnames)
        args: List[Any] = []
        skip_arg_idx = 0
        if is_module:
            if total_args == 0:
                raise RuntimeError(
                    "``self`` argument cannot be part of *args expansion!"
                )
            skip_arg_idx = 1
            next(names_iter)  # skip self
            args.append(self.root)

        sig = inspect.signature(fn_for_analysis)

        # This covers the very specific case where we are passing in flat
        # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
        # In this case, just take the concrete_args and pass them through.
        name_idx = 0
        if (
            isinstance(concrete_args, tuple)
            and len(concrete_args) > 0
            and (co.co_flags & HAS_VARSTUFF)
            and total_args == 1
        ):
            for concrete_arg in concrete_args:
                out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
                if isinstance(concrete_arg, PHBase):
                    if concrete_arg != PH:
                        # Transfer attrs in the case where you're using a placeholder other
                        # than the singleton PH (PH has no attributes to transfer).
                        # Proxies were created out of the placeholders.
                        # Transfer any metadata (put on the placeholders in the form of
                        # attributes set by the user) from the placeholder to the
                        # underlying nodes (the proxy is unwrapped by the user, but
                        # the metadata should hold).
                        _transfer_attrs(fr=concrete_arg, to=out.node)
                args.append(out)
                name_idx += 1
            return root_fn, args

        arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
        if isinstance(concrete_args, tuple):
            if len(arg_names) != len(concrete_args):
                raise RuntimeError(
                    f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments"
                )
            concrete_args = dict(zip(arg_names, concrete_args))

        def proxy_placeholder(name):
            return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis)

        args.extend(proxy_placeholder(names) for names in arg_names)

        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
            # TODO: type annotations for *args and **kwargs
            if co.co_flags & inspect.CO_VARARGS:
                args.append(proxy_placeholder("*" + next(names_iter)))
            if co.co_flags & inspect.CO_VARKEYWORDS:
                args.append(proxy_placeholder("**" + next(names_iter)))
            root_fn = _patch_function(root_fn, len(args))

        flat_args, in_spec = pytree.tree_flatten(tuple(args))
        if not all(child.is_leaf() for child in in_spec.children_specs):
            # In the case that we have pytree-flattened inputs in
            # `concrete_args`, generate a flattening wrapper around the
            # original root function and return that.
            self.graph._codegen = _PyTreeCodeGen(
                _PyTreeInfo(orig_args[:total_args], in_spec, None)
            )

            def flatten_fn(*args):
                tree_args = pytree.tree_unflatten(list(args), in_spec)
                tree_out = root_fn(*tree_args)
                out_args, out_spec = pytree.tree_flatten(tree_out)
                assert isinstance(self.graph._codegen, _PyTreeCodeGen)
                self.graph._codegen.pytree_info = (
                    self.graph._codegen.pytree_info._replace(out_spec=out_spec)
                )
                return out_args

            return flatten_fn, flat_args
        return root_fn, args
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
    ) -> Graph:
        """
        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
        can either be an ``nn.Module`` instance or a Python callable.

        Note that after this call, ``self.root`` may be different from the ``root`` passed
        in here. For example, when a free function is passed to ``trace()``, we will
        create an ``nn.Module`` instance to use as the root and add embedded constants
        to.

        Args:

            root (Union[Module, Callable]): Either a ``Module`` or a function to be
                traced through. Backwards-compatibility for this parameter is
                guaranteed.
            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
                not be treated as Proxies. This parameter is experimental and
                its backwards-compatibility is *NOT* guaranteed.

        Returns:

            A ``Graph`` representing the semantics of the passed-in ``root``.
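
        For example, a minimal sketch of calling ``trace`` directly and
        packaging the result (assuming ``my_module`` is an ``nn.Module``
        instance)::

            tracer = Tracer()
            graph = tracer.trace(my_module)
            gm = torch.fx.GraphModule(tracer.root, graph)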
""" | |
global _is_fx_tracing_flag | |
old_is_fx_tracing_flag = _is_fx_tracing_flag | |
_is_fx_tracing_flag = True | |
try: | |
if isinstance(root, torch.nn.Module): | |
# do real recompilation for _LazyGraphModule before retracing since the trace | |
# method can not trace the _lazy_forward method. Got error: | |
# https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 | |
# without this. | |
from torch.fx._lazy_graph_module import _LazyGraphModule | |
_LazyGraphModule.force_recompile(root) | |
self.root = root | |
assert hasattr( | |
type(root), self.traced_func_name | |
), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" | |
fn = getattr(type(root), self.traced_func_name) | |
self.root_module_name = root._get_name() | |
self.submodule_paths = {mod: name for name, mod in root.named_modules()} | |
else: | |
self.root = torch.nn.Module() | |
fn = root | |
tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None) | |
self.graph = Graph(tracer_cls=tracer_cls) | |
if hasattr(fn, '__code__'): | |
code = fn.__code__ | |
self.graph._co_fields = { | |
'co_name': code.co_name, | |
'co_filename': code.co_filename, | |
'co_firstlineno': code.co_firstlineno, | |
} | |
# When we encounter a Tensor value that's not a parameter, we look if it | |
# is some other attribute on the model. Construct a dict mapping Tensor | |
# values to the qualified name here for efficiency. This is used downstream | |
# in create_arg | |
self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} | |
def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): | |
for k, v in m.__dict__.items(): | |
if isinstance(v, (torch.Tensor, ScriptObject)): | |
self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) | |
for k, v in m.named_children(): | |
collect_tensor_attrs(v, prefix_atoms + [k]) | |
collect_tensor_attrs(self.root, []) | |
assert isinstance(fn, FunctionType) | |
fn_globals = fn.__globals__ # run before it gets patched | |
fn, args = self.create_args_for_root( | |
fn, isinstance(root, torch.nn.Module), concrete_args | |
) | |
parameter_proxy_cache: Dict[ | |
str, Proxy | |
] = {} # Reduce number of get_attr calls | |
# Method dispatch on parameters is not recorded unless it's directly used. | |
# Thus, we need to insert a proxy when __getattr__ requests a parameter. | |
def module_getattr_wrapper(mod, attr): | |
attr_val = _orig_module_getattr(mod, attr) | |
return self.getattr(attr, attr_val, parameter_proxy_cache) | |
def module_call_wrapper(mod, *args, **kwargs): | |
def forward(*args, **kwargs): | |
return _orig_module_call(mod, *args, **kwargs) | |
_autowrap_check( | |
patcher, | |
getattr(getattr(mod, "forward", mod), "__globals__", {}), | |
self._autowrap_function_ids, | |
) | |
return self.call_module(mod, forward, args, kwargs) | |
with _Patcher() as patcher: | |
# allow duplicate patches to support the case of nested calls | |
patcher.patch_method( | |
torch.nn.Module, | |
"__getattr__", | |
module_getattr_wrapper, | |
deduplicate=False, | |
) | |
patcher.patch_method( | |
torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False | |
) | |
_patch_wrapped_functions(patcher) | |
_autowrap_check(patcher, fn_globals, self._autowrap_function_ids) | |
for module in self._autowrap_search: | |
_autowrap_check( | |
patcher, module.__dict__, self._autowrap_function_ids | |
) | |
self.create_node( | |
"output", | |
"output", | |
(self.create_arg(fn(*args)),), | |
{}, | |
type_expr=fn.__annotations__.get("return", None), | |
) | |
self.submodule_paths = None | |
finally: | |
_is_fx_tracing_flag = old_is_fx_tracing_flag | |
return self.graph | |
    def __deepcopy__(self, memo):
        # _autowrap_search contains modules, which cannot be deepcopied.
        new_tracer = Tracer.__new__(Tracer)

        for k, v in self.__dict__.items():
            if k in {'_autowrap_search'}:
                new_obj = copy.copy(v)
            else:
                new_obj = copy.deepcopy(v, memo)

            new_tracer.__dict__[k] = new_obj

        return new_tracer
    def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis):
        if concrete_args is not None and name in concrete_args:
            cnt = 0

            def replace_ph(x):
                nonlocal cnt
                cnt += 1
                param = sig.parameters[name]
                default = (
                    ()
                    if param.default is inspect.Parameter.empty
                    else (param.default,)
                )
                out = self.create_proxy(
                    "placeholder", f"{name}_{str(cnt)}", default, {}
                )
                if isinstance(x, PHBase):
                    if x != PH:
                        # Transfer attrs in the case where you're using a placeholder other
                        # than the singleton PH (PH has no attributes to transfer).
                        # Proxies were created out of the placeholders.
                        # Transfer any metadata (put on the placeholders in the form of
                        # attributes set by the user) from the placeholder to the
                        # underlying nodes (the proxy is unwrapped by the user, but
                        # the metadata should hold).
                        _transfer_attrs(fr=x, to=out.node)

                    return out

                # Union[int, bool] == bool in Python <= 3.6
                if (
                    type(x) == bool
                    or type(x) in base_types
                    and type(x) != torch.Tensor
                ):
                    torch._assert(
                        out == x,
                        f"{name} has been specialized to have value {x} but got another value",
                    )
                elif x is None:
                    args = (
                        out,
                        f"{name} has been specialized to have value None but got another value",
                    )
                    self.create_proxy("call_function", _assert_is_none, args, {})
                else:
                    warnings.warn(
                        f"Was not able to add assertion to guarantee correct input {name} to "
                        f"specialized function. It is up to the user to make sure that the inputs "
                        f"match the inputs the function was specialized with."
                    )
                return x

            return pytree.tree_map(replace_ph, concrete_args[name])

        if name[0] == "*":
            default = ()
        else:
            param = sig.parameters[name]
            default = () if param.default is inspect.Parameter.empty else (param.default,)  # type: ignore[assignment]

        return self.create_proxy(
            "placeholder",
            name,
            default,
            {},
            type_expr=fn_for_analysis.__annotations__.get(name, None),
        )
# Dictionary of (id(globals dict), function name) => globals_dict to patch for
# the purposes of the wrap() API.
# We key by the globals dict id and function name to ensure we're wrapping a given
# function only once.
_wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {}

# List of methods on classes to wrap (class type, function name)
# this currently only works for Tensor.* methods that aren't traced properly
_wrapped_methods_to_patch: List[Tuple[type, str]] = []

if os.environ.get("FX_PATCH_GETITEM") == "1":
    # This change is needed to trace models like PositionalEmbedding from BERT:
    # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py
    # but causes issues in quantization documented here:
    # https://github.com/pytorch/pytorch/issues/50710
    # once that is fixed we can make this the default behavior.
    _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__"))
def _find_proxy(*objects_to_search):
    """
    Recursively search a data structure for a Proxy() and return it,
    or return None if not found.
    """
    proxy = None

    def find_proxy(x):
        nonlocal proxy
        if isinstance(x, Proxy):
            proxy = x

    map_aggregate(objects_to_search, find_proxy)
    return proxy
def _create_wrapped_func(orig_fn):
    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Given a closed-over ``orig_fn`` to invoke, search the args and kwargs for
        a Proxy object. If there is one, emit a ``call_function`` node to preserve the
        call to this leaf function directly. Otherwise, just return the results of
        this function call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return_proxy = proxy.tracer.create_proxy(
                "call_function", orig_fn, args, kwargs
            )
            return_proxy.node.meta["is_wrapped"] = True
            return return_proxy
        return orig_fn(*args, **kwargs)

    return wrapped
def _create_wrapped_method(cls, name):
    orig_fn = getattr(cls, name)

    @functools.wraps(orig_fn)
    def wrapped(*args, **kwargs):
        """
        Search the args and kwargs for a Proxy object. If there is one,
        emit a ``call_method`` node to preserve the call to this method
        directly. Otherwise, just return the results of this function
        call, as this function is not being traced.
        """
        proxy = _find_proxy(args, kwargs)
        if proxy is not None:
            return proxy.tracer.create_proxy("call_method", name, args, kwargs)
        return orig_fn(*args, **kwargs)

    return wrapped
class _PatchedFn(NamedTuple):
    frame_dict: Any
    fn_name: str
    orig_fn: Any

    def revert(self):
        raise NotImplementedError()


class _PatchedFnSetItem(_PatchedFn):
    def revert(self):
        self.frame_dict[self.fn_name] = self.orig_fn


class _PatchedFnDel(_PatchedFn):
    def revert(self):
        del self.frame_dict[self.fn_name]


class _PatchedFnSetAttr(_PatchedFn):
    def revert(self):
        setattr(self.frame_dict, self.fn_name, self.orig_fn)


class _Patcher:
    def __init__(self):
        super().__init__()
        self.patches_made: List[_PatchedFn] = []
        self.visited: Set[int] = set()

    def patch(
        self,
        frame_dict: Dict[str, Any],
        name: str,
        new_fn: Callable,
        deduplicate: bool = True,
    ):
        """
        Replace frame_dict[name] with new_fn until we exit the context manager.
        """
        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
        if name not in frame_dict and hasattr(builtins, name):
            self.patches_made.append(_PatchedFnDel(frame_dict, name, None))
        elif getattr(frame_dict[name], "__fx_already_patched", False):
            return  # already patched, no need to do it again
        else:
            self.patches_made.append(
                _PatchedFnSetItem(frame_dict, name, frame_dict[name])
            )
        frame_dict[name] = new_fn
    def patch_method(
        self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True
    ):
        """
        Replace ``cls.name`` with ``new_fn`` until we exit the context manager.
        """
        new_fn.__fx_already_patched = deduplicate  # type: ignore[attr-defined]
        orig_fn = getattr(cls, name)
        if getattr(orig_fn, "__fx_already_patched", False):
            return  # already patched, no need to do it again
        self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn))
        setattr(cls, name, new_fn)

    def visit_once(self, thing: Any):
        """Return True on the first call with ``thing``, otherwise False"""
        idx = id(thing)
        if idx in self.visited:
            return False
        self.visited.add(idx)
        return True
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Undo all the changes made via self.patch() and self.patch_method()
        """
        while self.patches_made:
            # unpatch in reverse order to handle duplicates correctly
            self.patches_made.pop().revert()
        self.visited.clear()
def _patch_wrapped_functions(patcher: _Patcher):
    """
    Go through ``_wrapped_fns_to_patch`` and, for each frame object, wrap
    the listed global functions in the `_create_wrapped_func` wrapper.
    """
    for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items():
        if name not in frame_dict and hasattr(builtins, name):
            orig_fn = getattr(builtins, name)
        else:
            orig_fn = frame_dict[name]
        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))

    for cls, name in _wrapped_methods_to_patch:
        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))


def _autowrap_check(
    patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int]
):
    """
    Some functions, like `math.sqrt`, are common enough that we want to
    automatically wrap them as we see them. This method searches a scope for
    them and patches them if found.
    """
    if patcher.visit_once(frame_dict):
        for name, value in frame_dict.items():
            if (
                not name.startswith("_")
                and callable(value)
                and id(value) in function_ids
            ):
                patcher.patch(frame_dict, name, _create_wrapped_func(value))
def wrap(fn_or_name: Union[str, Callable]):
    """
    This function can be called at module-level scope to register fn_or_name as a "leaf function".
    A "leaf function" will be preserved as a ``call_function`` node in the FX trace instead of being
    traced through::

        # foo/bar/baz.py
        def my_custom_function(x, y):
            return x * x + y * y

        torch.fx.wrap('my_custom_function')

        def fn_to_be_traced(x, y):
            # When symbolic tracing, the below call to my_custom_function will be inserted into
            # the graph rather than traced through.
            return my_custom_function(x, y)

    This function can also equivalently be used as a decorator::

        # foo/bar/baz.py
        @torch.fx.wrap
        def my_custom_function(x, y):
            return x * x + y * y

    A wrapped function can be thought of as a "leaf function", analogous to the concept of
    "leaf modules", that is, it is a function that is left as a call in the FX trace
    rather than traced through.

    Args:

        fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the
            graph when it's called
    """
    if not callable(fn_or_name) and not isinstance(fn_or_name, str):
        raise RuntimeError(
            "Unsupported type for global function! Must be either a callable or "
            "string name"
        )

    if callable(fn_or_name):
        assert not isinstance(fn_or_name, str)  # to make mypy happy
        fn_name = fn_or_name.__name__
    else:
        assert isinstance(
            fn_or_name, str
        ), "fn_or_name must be a global function or string name"
        fn_name = fn_or_name

    currentframe = inspect.currentframe()
    assert currentframe is not None
    f = currentframe.f_back
    assert f is not None
    if f.f_code.co_name != "<module>":
        raise NotImplementedError("wrap must be called at the top level of a module")

    # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search
    # semantics would be slightly different, but would add support `from x import wrapped_function`
    _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals
    return fn_or_name
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
) -> GraphModule:
    """
    Symbolic tracing API

    Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
    constructed by recording operations seen while tracing through ``root``.

    ``concrete_args`` allows you to partially specialize your function, whether to remove
    control flow or to remove data-structure handling.

    For example::

        def f(a, b):
            if b == True:
                return a
            else:
                return a*2

    FX typically cannot trace through this due to the presence of control
    flow. However, we can use `concrete_args` to specialize on the value of
    `b` to trace through this::

        f = fx.symbolic_trace(f, concrete_args={'b': False})
        assert f(3, False) == 6

    Note that although you can still pass in different values of `b`, they will be ignored.

    We can also use `concrete_args` to eliminate data-structure handling from
    our function. This will use pytrees to flatten your input. To avoid
    overspecializing, pass in `fx.PH` for values that shouldn't be
    specialized. For example::

        def f(x):
            out = 0
            for v in x.values():
                out += v
            return out

        f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
        assert f({'a': 1, 'b': 2, 'c': 4}) == 7

    Args:
        root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted
            into a Graph representation.
        concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized

    Returns:
        GraphModule: a Module created from the recorded operations from ``root``.
    """
    tracer = Tracer()
    graph = tracer.trace(root, concrete_args)
    name = (
        root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    )
    return _make_graph_module(tracer.root, graph, name)
def _assert_is_none(value, msg):
    assert value is None, msg