Spaces:
Running
Running
"""TorchScript. | |
This module contains functionality to support the JIT's scripting frontend, notably: | |
- torch.jit.script | |
This is not intended to be imported directly; please use the exposed | |
functionalities in `torch.jit`. | |
""" | |
import collections | |
import copy | |
import enum | |
import functools | |
import inspect | |
import pickle | |
import warnings | |
from typing import Any, Callable, Dict, List, Set, Tuple, Union | |
import torch | |
import torch._jit_internal as _jit_internal | |
from torch._classes import classes | |
from torch._jit_internal import _qualified_name | |
from torch.jit._builtins import _register_builtin | |
from torch.jit._fuser import _graph_for, _script_method_graph_for | |
from torch.jit._monkeytype_config import ( | |
JitTypeTraceConfig, | |
JitTypeTraceStore, | |
monkeytype_trace, | |
) | |
from torch.jit._recursive import ( | |
_compile_and_register_class, | |
infer_methods_to_compile, | |
ScriptMethodStub, | |
wrap_cpp_module, | |
) | |
from torch.jit._state import ( | |
_enabled, | |
_set_jit_function_cache, | |
_set_jit_overload_cache, | |
_try_get_jit_cached_function, | |
_try_get_jit_cached_overloads, | |
) | |
from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def | |
from torch.nn import Module | |
from torch.overrides import ( | |
has_torch_function, | |
has_torch_function_unary, | |
has_torch_function_variadic, | |
) | |
from torch.package import PackageExporter, PackageImporter | |
from torch.utils import set_module | |
from ._serialization import validate_map_location | |
# Module-level store of MonkeyType call traces (see _get_type_trace_db).
type_trace_db = JitTypeTraceStore()  # DB to hold all call traces from MonkeyType

# Attach the ``graph_for`` debugging helpers from torch.jit._fuser to the
# C++-backed ScriptMethod / ScriptFunction classes.
setattr(torch._C.ScriptMethod, "graph_for", _script_method_graph_for)  # type: ignore[attr-defined]
setattr(torch._C.ScriptFunction, "graph_for", _graph_for)  # type: ignore[attr-defined]
# Re-export the C++ ScriptFunction type under torch.jit, with a docstring.
ScriptFunction = torch._C.ScriptFunction
set_module(ScriptFunction, "torch.jit")
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
# Throws an error if a jit function is pickled.
# Helps to avoid Python crashes for Python versions 3.9.5 + when protocol 0 or 1 is given as an argument.
def _reduce(cls):
    """``__reduce__`` replacement that always refuses to pickle.

    Raises:
        pickle.PickleError: unconditionally.
    """
    message = "ScriptFunction cannot be pickled"
    raise pickle.PickleError(message)


setattr(ScriptFunction, "__reduce__", _reduce)  # type: ignore[assignment]
# With the JIT enabled, Attribute is a (value, type) namedtuple that
# ScriptModule.__setattr__ recognizes. With it disabled, Attribute is a plain
# pass-through function that returns ``value``.
if _enabled:
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
else:

    def Attribute(value, type):  # type: ignore[no-redef]
        return value


# The docstring is restored to valid reStructuredText: directive bodies are
# indented and separated by blank lines so Sphinx/doctest can run them.
Attribute.__doc__ = """
    This method is a pass-through function that returns `value`, mostly
    used to indicate to the TorchScript compiler that the left-hand side
    expression is a class instance attribute with type of `type`. Note that
    `torch.jit.Attribute` should only be used in `__init__` method of `jit.ScriptModule`
    subclasses.

    Though TorchScript can infer correct type for most Python expressions, there are some cases where
    type inference can be wrong, including:

    - Empty containers like `[]` and `{}`, which TorchScript assumes to be container of `Tensor`
    - Optional types like `Optional[T]` but assigned a valid value of type `T`, TorchScript would assume
      it is type `T` rather than `Optional[T]`

    In eager mode, it is simply a pass-through function that returns `value`
    without other implications.

    Example:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.jit.ScriptModule):
            def __init__(self):
                super().__init__()
                self.foo = torch.jit.Attribute(0.1, float)

                # we should be able to use self.foo as a float here
                assert 0.0 < self.foo

                self.names_ages = torch.jit.Attribute({}, Dict[str, int])
                self.names_ages["someone"] = 20
                assert isinstance(self.names_ages["someone"], int)

        m = AttributeModule()
        # m will contain two attributes
        # 1. foo of type float
        # 2. names_ages of type Dict[str, int]

    .. testcleanup::

        del AttributeModule
        del m

    Note: it's now preferred to use type annotations instead of `torch.jit.Attribute`:

    .. testcode::

        import torch
        from typing import Dict

        class AttributeModule(torch.nn.Module):
            names: Dict[str, int]

            def __init__(self):
                super().__init__()
                self.names = {}

        m = AttributeModule()

    .. testcleanup::

        del AttributeModule
        del m

    Args:
        value: An initial value to be assigned to attribute.
        type: A Python type

    Returns:
        Returns `value`
"""
def _get_type_trace_db():
    """Return the module-level ``type_trace_db`` (MonkeyType call traces).

    Private API: using it for external purposes is discouraged.
    """
    return type_trace_db
def _get_function_from_type(cls, name):
    """Look up attribute ``name`` on type ``cls``.

    Returns:
        The attribute (typically a function), or ``None`` if ``cls`` has no
        such attribute.
    """
    try:
        return getattr(cls, name)
    except AttributeError:
        return None
def _is_new_style_class(cls):
    """Report whether ``cls`` looks like a new-style class.

    ScriptClasses must be new-style classes because we construct them using
    their ``__new__`` method. A class qualifies if ``dir(cls)`` lists
    ``__dict__`` or the class declares ``__slots__``.

    Returns:
        A bool, or ``None`` when ``cls`` has no ``__class__`` attribute.
    """
    if not hasattr(cls, "__class__"):
        return None
    return "__dict__" in dir(cls) or hasattr(cls, "__slots__")
# These OrderedDictWrapper classes replace the actual OrderedDicts in
# module with versions that get/set properties inside of Module.
# This allows us to reuse most of nn.Module while still storing the
# data in C++.
# Each OrderedDict needs to support:
#  x not in view
#  x in view
#  view[name] = ...
#  view.values()
#  del view[name]
#  view.items()
#  view.keys()
#  len(view)
class OrderedDictWrapper:
    """Dict-like view that forwards every operation to a C++ dict wrapper.

    ``_c`` must provide ``items()``, ``contains(k)``, ``getattr(k)`` and
    ``setattr(k, v)``. Existing entries may be reassigned, but entries can
    never be added or deleted.
    """

    def __init__(self, _c):
        self._c = _c

    def keys(self):
        return [k for k, v in self.items()]

    def values(self):
        return [v for k, v in self.items()]

    def __len__(self):
        return len(self.values())

    def __delitem__(self, k):
        raise RuntimeError("cannot delete methods or parameters of a script module")

    def items(self):
        return self._c.items()

    def __setitem__(self, k, v):
        # Only existing entries can be reassigned. The set of parameters and
        # buffers is fixed once the ScriptModule has been constructed.
        if k not in self:
            raise RuntimeError(
                f"Can't add a new parameter after ScriptModule construction. Tried to add '{k}'"
            )
        self._c.setattr(k, v)

    def __contains__(self, k):
        return self._c.contains(k)

    def __getitem__(self, k):
        if k not in self:
            raise KeyError(k)
        return self._c.getattr(k)
class OrderedModuleDict(OrderedDictWrapper):
    """Submodule view of a ScriptModule whose reads come from a Python dict.

    Writes are mirrored into both the C++ ``ModuleDict`` and the Python dict.
    """

    def __init__(self, module, python_dict):
        super().__init__(torch._C.ModuleDict(module))
        # ``python_dict`` holds _both_ script modules and Python-only modules.
        # Script modules are subclassed in Python and the C++ Module class does
        # not keep references to those Python objects. Reading through this
        # dict therefore always hands back the same Python value.
        self._python_modules = python_dict

    def items(self):
        return self._python_modules.items()

    def __contains__(self, k):
        return k in self._python_modules

    def __setitem__(self, k, v):
        # A submodule may be re-assigned after construction only with another
        # ScriptModule. This is safe when:
        #   1. the attribute has a module interface type, so the submodule is
        #      never inlined into the graph, or
        #   2. the new ScriptModule has the same JIT type, so the IR is unchanged.
        # Swapping in a plain nn.Module is rejected.
        if not isinstance(v, ScriptModule):
            raise RuntimeError(
                "Cannot re-assign modules in a ScriptModule with non-scripted "
                f"module, tried to replace existing module '{k}': {v}"
            )
        self._c.setattr(k, v)
        self._python_modules[k] = v

    def __getitem__(self, k):
        return self._python_modules[k]
# For each user-defined class that subclasses ScriptModule, this meta-class:
# (1) finds all the methods annotated with @script_method in a ScriptModule and
#     removes them from the class attributes
# (2) puts a wrapper around the class's __init__ method to recursively compile
#     all of the script_methods with the module after the original __init__ has
#     run. This has to occur after the user-defined __init__ so that submodules and
#     parameters are initialized _before_ the script compiler resolve references to
#     `self.param` or `self.module`.
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs):  # noqa: B902
        """Collect script-method stubs and constants, then wrap ``__init__``.

        Args:
            cls: the class being created.
            name: the class name.
            bases: the direct base classes.
            attrs: the class body namespace.
        """
        # Aggregate all the ScriptMethods and constants from superclasses
        cls._methods: Dict[str, Any] = {}
        cls._constants_set = set(getattr(cls, "__constants__", ()))
        for base in reversed(bases):
            for k, v in getattr(base, "_methods", {}).items():
                cls._methods[k] = v
            base_constants: Set = getattr(base, "_constants_set", set())
            cls._constants_set = cls._constants_set.union(base_constants)

        # find all the script methods of the current class
        for k, v in sorted(attrs.items()):
            if isinstance(v, ScriptMethodStub):
                delattr(cls, k)
                cls._methods[v.original_method.__name__] = v

        if getattr(cls, "_disable_script_meta", False):
            # We leave built-in ScriptModule types alone, since this metaclass
            # is only for compiling user classes that inherit from
            # ScriptModule.
            return super().__init__(name, bases, attrs)

        original_init = getattr(cls, "__init__", lambda self: None)

        # functools.wraps keeps the user's __init__ name, docstring and
        # signature (via __wrapped__) visible on the replacement.
        @functools.wraps(original_init)
        def init_then_script(self, *args, **kwargs):
            num_methods = len(cls._methods)
            original_init(self, *args, **kwargs)
            # ScriptModule.define() during __init__ adds stubs to cls._methods.
            # The resulting type then differs per instance, so it must not be shared.
            added_methods_in_init = len(cls._methods) > num_methods

            # Only compile once, for the most-derived class. Base-class
            # __init__ wrappers see type(self) != cls and skip this.
            if type(self) == cls:

                def make_stubs(module):
                    cls = type(module)
                    if hasattr(cls, "_methods"):
                        return [v for k, v in sorted(cls._methods.items())]
                    else:
                        return infer_methods_to_compile(module)

                self.__dict__[
                    "_actual_script_module"
                ] = torch.jit._recursive.create_script_module(
                    self, make_stubs, share_types=not added_methods_in_init
                )

                # Delete the Python attributes that now shadow the ScriptModule
                # ones, so that __getattr__ and __setattr__ will properly find
                # the scripted versions.
                concrete_type = self._actual_script_module._concrete_type
                for name in concrete_type.get_attributes():
                    delattr(self, name)
                for name, _ in concrete_type.get_modules():
                    delattr(self, name)
                for name in ("_parameters", "_buffers", "_modules"):
                    delattr(self, name)

        cls.__init__ = init_then_script  # type: ignore[misc]
        super().__init__(name, bases, attrs)
class _CachedForward:
    # Non-data descriptor used as ``ScriptModule.forward`` so that accessing
    # ``module.forward`` does not resolve to nn.Module.forward.
    def __get__(self, obj, cls):
        # _CachedForward defines no ``__getattr__`` (and ``object`` has none),
        # so this line raises AttributeError on purpose. During instance
        # attribute lookup, an AttributeError from a descriptor's ``__get__``
        # makes Python fall back to the owner class's ``__getattr__``
        # (ScriptModule / RecursiveScriptModule), which resolves "forward" to
        # the compiled method.
        return self.__getattr__("forward")  # type: ignore[attr-defined]
class ScriptWarning(Warning):
    """Warning category defined by the TorchScript scripting frontend."""
def script_method(fn):
    """Mark ``fn`` as a TorchScript method of a ``ScriptModule`` subclass.

    When the JIT is disabled, ``fn`` is returned unchanged. Otherwise this
    returns a ``ScriptMethodStub`` holding a resolution callback for the scope
    that surrounds the class definition, the parsed JIT AST of ``fn`` and
    ``fn`` itself. ``ScriptMeta`` later removes such stubs from the class
    namespace and records them in ``cls._methods``.
    """
    if not _enabled:
        return fn
    # NOTE: we need to traverse two frames here because the meta-class frame
    # for ScriptModule will be present, as opposed to invoking @script on a
    # a function or invoking define() on a CompilationUnit.
    # The stack will look like:
    #
    # 0. createResolutionCallback()
    # 1. script_method()
    # 2. ScriptModule metaclass frame
    # 3. Surrounding scope
    #
    # createResolutionCallback internally adds 1 to get us to the scope of this
    # function (the calling function). Adding 2 gets us to the proper surrounding scope.
    # Do not move this call into a helper: that would change the frame count.
    _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
    ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
    return ScriptMethodStub(_rcb, ast, fn)
class ConstMap:
    """Expose the entries of a constants mapping as attributes.

    ``const_map.c0`` returns ``const_mapping["c0"]``. A missing name raises
    ``KeyError`` from the underlying mapping.
    """

    def __init__(self, const_mapping):
        self.const_mapping = const_mapping

    def __getattr__(self, attr):
        mapping = self.const_mapping
        return mapping[attr]
def unpackage_script_module(
    importer: PackageImporter, script_module_id: str
) -> torch.nn.Module:
    """Load a ScriptModule from a ``torch.package`` archive.

    Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load``.

    Args:
        importer: the importer whose ``zip_reader`` must be a
            ``torch._C.PyTorchFileReader``, i.e. backed by an archive file.
        script_module_id: id under which the module was serialized.

    Returns:
        The loaded module wrapped as a Python ScriptModule.

    Raises:
        RuntimeError: if the importer reads from a directory rather than an
            archive file.
    """
    reader = importer.zip_reader
    if not isinstance(reader, torch._C.PyTorchFileReader):
        raise RuntimeError(
            "Loading ScriptObjects from a PackageImporter created from a "
            "directory is not supported. Use a package archive file instead."
        )
    compilation_unit = torch._C.CompilationUnit()
    loaded = torch._C._import_ir_module_from_package(
        compilation_unit,
        reader,
        importer.storage_context,
        validate_map_location(importer.last_map_location),
        script_module_id,
    )
    return wrap_cpp_module(loaded)
if _enabled:
    # Dunder methods that RecursiveScriptClass exposes on the Python side.
    # A forwarding method is installed on RecursiveScriptClass for each name.
    # It calls RecursiveScriptClass.forward_magic_method, which raises
    # TypeError when the underlying TorchScript object lacks that method.
    _magic_methods = [
        "__iter__",
        "__len__",
        "__neg__",
        "__mul__",
        "__contains__",
        "__add__",
        "__sub__",
        "__pow__",
        "__truediv__",
        "__mod__",
        "__ne__",
        "__eq__",
        "__lt__",
        "__gt__",
        "__le__",
        "__ge__",
        "__and__",
        "__or__",
        "__xor__",
        "__getitem__",
        "__setitem__",
        "__call__",
        "__int__",
        "__float__",
        "__bool__",
        "__str__",
        "__enter__",
        "__exit__",
    ]
class RecursiveScriptClass:
    """Wrapper for a TorchScript class instance for use in Python.

    An analogue of RecursiveScriptModule for regular objects that are not modules.
    This class is a wrapper around a torch._C.ScriptObject that represents an instance
    of a TorchScript class and allows it to be used in Python.

    Attributes:
        _c [torch._C.ScriptObject]: The C++ object to which attribute lookups and method
            calls are forwarded.
        _props [Dict[str, property]]: A dictionary of properties fetched from self._c and
            exposed on this wrapper.
    """

    def __init__(self, cpp_class):
        super().__init__()
        # While this flag is set, __getattr__/__setattr__ act on the wrapper
        # itself instead of forwarding to ``_c``.
        self.__dict__["_initializing"] = True
        self._c = cpp_class

        # Add wrapped object's properties to this class instance.
        self._props = {
            prop.name: property(prop.getter, prop.setter)
            for prop in self._c._properties()
        }

        self.__dict__["_initializing"] = False

    def __getattr__(self, attr):
        if self.__dict__.get("_initializing"):
            return super().__getattr__(attr)  # type: ignore[misc]

        if attr in self._props:
            return self._props[attr].fget()  # type: ignore[call-arg, misc]

        return getattr(self._c, attr)

    def __setattr__(self, attr, value):
        if self.__dict__.get("_initializing"):
            return super().__setattr__(attr, value)

        if attr in self._props:
            return self._props[attr].fset(value)  # type: ignore[call-arg, misc]

        setattr(self._c, attr, value)

    # Delegate calls to magic methods like __len__ to the C++ module backing the
    # RecursiveScriptClass.
    def forward_magic_method(self, method_name, *args, **kwargs):
        """Call TorchScript method ``method_name`` on the wrapped object.

        Raises:
            TypeError: if the TorchScript class does not define ``method_name``.
        """
        if not self._c._has_method(method_name):
            raise TypeError(
                f"'{method_name}' is not defined on this TorchScript class"
            )

        self_method = self.__getattr__(method_name)
        return self_method(*args, **kwargs)

    def __getstate__(self):
        raise pickle.PickleError("ScriptClasses cannot be pickled")

    def __iadd__(self, other):
        # Prefer an in-place __iadd__ if the TorchScript class defines one,
        # otherwise fall back to __add__.
        if self._c._has_method("__iadd__"):
            return self.forward_magic_method("__iadd__", other)
        else:
            return self.forward_magic_method("__add__", other)
def _make_magic_method_forwarder(method_name):
    """Return a method that forwards dunder ``method_name`` to the script object.

    The name is bound through this factory's argument. A closure created
    directly in the loop below would look ``method_name`` up at call time and
    would always forward the loop's final value (``"__exit__"``).
    """

    def method_template(self, *args, **kwargs):
        return self.forward_magic_method(method_name, *args, **kwargs)

    return method_template


for method_name in _magic_methods:
    setattr(
        RecursiveScriptClass,
        method_name,
        _make_magic_method_forwarder(method_name),
    )
# ``forward`` on ScriptModule (below) is a ``_CachedForward``, a Python
# 'non-data descriptor'. The first access to a module's forward therefore ends
# up in ``__getattr__``, which looks up the compiled method.
# RecursiveScriptModule.__getattr__ caches that method in the instance dict.
# Under the standard attribute lookup rules, an instance-dict entry beats a
# non-data descriptor, so subsequent lookups return the cached method directly.
# This is necessary because nn.Module defines forward as a method. If we
# did nothing, __getattr__ would not be called. Instead we'd get nn.Module.forward
# which always throws an exception.
class ScriptModule(Module, metaclass=ScriptMeta):
    r"""Wrapper for C++ torch::jit::Module with methods, attributes, and parameters.

    A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s
    contain methods, attributes, parameters, and
    constants. These can be accessed the same way as on a normal ``nn.Module``.

    For user subclasses, ScriptMeta wraps ``__init__`` so that, after it runs,
    the compiled RecursiveScriptModule is stored as ``_actual_script_module``.
    From then on, attribute reads and writes are forwarded to it.
    """

    # Names of RecursiveScriptModule properties. Presumably the TorchScript
    # compiler skips these when inspecting a module; confirm in torch.jit._recursive.
    __jit_unused_properties__ = [
        "code",
        "code_with_constants",
        "graph",
        "inlined_graph",
        "original_name",
    ]

    def __init__(self):
        super().__init__()

    forward: Callable[..., Any] = _CachedForward()  # type: ignore[assignment]

    def __getattr__(self, attr):
        # Before scripting has completed, behave like a plain nn.Module.
        if "_actual_script_module" not in self.__dict__:
            return super().__getattr__(attr)
        return getattr(self._actual_script_module, attr)

    def __setattr__(self, attr, value):
        if "_actual_script_module" not in self.__dict__:
            # Unwrap torch.jit.Attribute into a regular setattr + record
            # the provided type in __annotations__.
            #
            # This ensures that if we use the attr again in `__init__`, it
            # will look like the actual value, not an instance of Attribute.
            if isinstance(value, Attribute):
                # NB: Ensure that we set __annotations__ on the specific
                # class in question, and not on a superclass (which would
                # be wrong wrong wrong!).
                # See also https://github.com/pytorch/pytorch/issues/39463
                if "__annotations__" not in self.__class__.__dict__:
                    self.__class__.__annotations__ = {}
                self.__annotations__[attr] = value.type
                value = value.value
            return super().__setattr__(attr, value)

        # After scripting, writes go to the backing RecursiveScriptModule.
        setattr(self._actual_script_module, attr, value)

    def define(self, src):
        """Compile TorchScript source ``src`` as a method of this module."""
        if "_actual_script_module" in self.__dict__:
            # If we have completed initialization, just defer to the
            # backing RecursiveScriptModule to eagerly compile the provided
            # source.
            return self._actual_script_module.define(src)

        # Otherwise, we are still in the object's __init__.
        # In that case, add `src` as a stub to be compiled.
        #
        # We use frames_up=1 to get to the proper surrounding scope. The stack
        # will look like:
        # 0. createResolutionCallback
        # 1. define()
        # 2. surrounding scope.
        #
        # createResolutionCallback internally adds 1 to get us to our frame, then
        # we add 1 to get to the proper surrounding scope.
        rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
        ast = torch._C._parse_source_def(src)
        # Growing cls._methods here is what ScriptMeta's __init__ wrapper
        # checks (added_methods_in_init) to disable type sharing.
        self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None)

    def _replicate_for_data_parallel(self):
        return self._actual_script_module._replicate_for_data_parallel()

    def __reduce_package__(self, exporter: PackageExporter):
        """Save a ScriptModule inside of a ``torch.package`` archive.

        Called by ``torch.package.PackageExporter``'s Pickler's ``persistent_id`` when
        saving TorchScript objects. Performs act of saving a ScriptModule inside of
        a ``torch.package`` archive.

        Returns method to load the ScriptModule from a ``torch.package.PackageImporter``'s
        Pickler's ``persistent_load`` function.
        """
        script_module_id = exporter.get_unique_id()
        exporter.script_module_serializer.serialize(self._c, int(script_module_id))
        return (unpackage_script_module, (script_module_id,))
class RecursiveScriptModule(ScriptModule):
    # XXX: RecursiveScriptModule inherits from ScriptModule for the sole
    # reason that it retains the existing isinstance(ScriptModule)
    # behavior.
    r"""Retain the existing isinstance(ScriptModule) behavior.

    The core data structure in TorchScript is the ``ScriptModule``. It is an
    analogue of torch's ``nn.Module`` and represents an entire model as a tree of
    submodules. Like normal modules, each individual module in a ``ScriptModule`` can
    have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented
    as Python functions, but in ``ScriptModule``\s methods are implemented as
    TorchScript functions, a statically-typed subset of Python that contains all
    of PyTorch's built-in Tensor operations. This difference allows your
    ``ScriptModule``\s code to run without the need for a Python interpreter.

    ``ScriptModule``\s should not be created manually, instead use
    either :func:`tracing <torch.jit.trace>` or :func:`scripting <torch.jit.script>`.
    Tracing and scripting can be applied incrementally and :ref:`composed as necessary <Types>`.

    * Tracing records the tensor operations as executed with a set of example inputs and uses these
      operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing,
      but values other than Tensors and control flow aren't captured in the graph.

    * Scripting inspects the Python code of the model
      and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow.
      Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary.
    """

    # Makes ScriptMeta.__init__ return early, so __init__ is not wrapped.
    _disable_script_meta = True

    def __init__(self, cpp_module):
        self.__dict__["_initializing"] = True
        self._c = cpp_module
        super().__init__()
        # Delete the 'training' attribute set up by `Module.__init__`. It
        # will get set on the underlying cpp module, so we delete it here
        # to avoid this version shadowing the cpp module version.
        delattr(self, "training")

    @staticmethod
    def _construct(cpp_module, init_fn):
        """
        Construct a RecursiveScriptModule that's ready for use.

        PyTorch code should use this to construct a RecursiveScriptModule instead
        of calling `__init__` directly, as it makes sure the
        object is properly finalized (and in the future, we may take
        control of how the RecursiveScriptModule instance is created).

        Args:
            cpp_module:  The C++ Module that will hold the actual state of
                            this RecursiveScriptModule instance.
            init_fn:  Lambda that initializes the RecursiveScriptModule passed to it.
        """
        script_module = RecursiveScriptModule(cpp_module)
        init_fn(script_module)

        # Finalize the ScriptModule: replace the nn.Module state with our
        # custom implementations and flip the _initializing bit.
        RecursiveScriptModule._finalize_scriptmodule(script_module)
        return script_module

    @staticmethod
    def _finalize_scriptmodule(script_module):
        """Swap nn.Module's dicts for C++-backed views and end initialization."""
        script_module._parameters = OrderedDictWrapper(
            torch._C.ParameterDict(script_module._c)
        )
        script_module._buffers = OrderedDictWrapper(
            torch._C.BufferDict(script_module._c)
        )
        script_module._modules = OrderedModuleDict(
            script_module._c, script_module._modules
        )
        script_module._initializing = False

    def _reconstruct(self, cpp_module):
        """
        Re-construct an instance of RecursiveScriptModule using an instance of a C++ module.

        Args:
            cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around.
        """
        self.__init__(cpp_module)  # type: ignore[misc]

        # Copy the concrete type from the C++ module to this ScriptModule.
        self._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
            self._c._type()
        )

        # Copy submodules from the C++ module to this ScriptModule.
        modules = {}
        for name, sub_cpp_module in torch._C.ModuleDict(self._c).items():
            modules[name] = wrap_cpp_module(sub_cpp_module)
        self._modules = OrderedModuleDict(self._c, modules)  # type: ignore[assignment]

        # Copy parameters and buffers.
        self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c))  # type: ignore[assignment]
        self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c))  # type: ignore[assignment]

        # Get rid of the functions from the old C++ module.
        self.__dict__ = {
            k: v
            for k, v in self.__dict__.items()
            if not isinstance(v, torch._C.ScriptMethod)
        }
        self.__dict__["_initializing"] = False

    @property
    def graph(self):
        r"""Return a string representation of the internal graph for the ``forward`` method.

        See :ref:`interpreting-graphs` for details.
        """
        return self._c._get_method("forward").graph

    @property
    def inlined_graph(self):
        r"""
        Return a string representation of the internal graph for the ``forward`` method.

        This graph will be preprocessed to inline all function and method calls.
        See :ref:`interpreting-graphs` for details.
        """
        return self.forward.inlined_graph  # type: ignore[attr-defined]

    @property
    def code(self):
        r"""
        Return a pretty-printed representation (as valid Python syntax) of the internal graph for the ``forward`` method.

        See :ref:`inspecting-code` for details.
        """
        return self.forward.code  # type: ignore[attr-defined]

    @property
    def code_with_constants(self):
        r"""Return a tuple.

        Returns a tuple of:

        [0] a pretty-printed representation (as valid Python syntax) of
        the internal graph for the ``forward`` method. See `code`.
        [1] a ConstMap following the CONSTANT.cN format of the output in [0].
        The indices in the [0] output are keys to the underlying constant's values.

        See :ref:`inspecting-code` for details.
        """
        r = self.forward.code_with_constants  # type: ignore[attr-defined]
        return (r[0], ConstMap(r[1]))

    def save(self, f, **kwargs):
        r"""Save this module to the path ``f``.

        save(f, _extra_files={})

        ``f`` is converted with ``str(f)`` and treated as a file path.
        See :func:`torch.jit.save <torch.jit.save>`, which also accepts a file-like object.
        DO NOT confuse these two functions when it comes to the 'f' parameter functionality.
        """
        return self._c.save(str(f), **kwargs)

    def _save_for_lite_interpreter(self, *args, **kwargs):
        r"""Add (or update) the bytecode session to the script model.

        _save_for_lite_interpreter(f)

        The updated model is used
        in lite interpreter for mobile applications.

        Args:
            f: a string containing a file name.
            _extra_files: Map from filename to contents which will be stored as part of 'f'.
        """
        return self._c._save_for_mobile(*args, **kwargs)

    def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs):
        return self._c._save_to_buffer_for_mobile(*args, **kwargs)

    def save_to_buffer(self, *args, **kwargs):
        return self._c.save_to_buffer(*args, **kwargs)

    def get_debug_state(self, *args, **kwargs):
        # Extra arguments are accepted but ignored.
        return self._c.get_debug_state()

    def extra_repr(self):
        return f"original_name={self.original_name}"

    def graph_for(self, *args, **kwargs):
        return self.forward.graph_for(self, *args, **kwargs)  # type: ignore[attr-defined]

    @property
    def original_name(self):
        # NOTE(review): this compares a type object with a str, so it is
        # always False and the JIT type name is always returned. Kept as-is
        # to preserve behavior; the intended comparison is unclear.
        if type(self) == str(self._c._type().name()):
            return ""
        return str(self._c._type().name())

    def define(self, src):
        # We use frames_up=1 to get to the proper surrounding scope. The stack
        # will look like:
        # 0. createResolutionCallback
        # 1. define()
        # 2. surrounding scope.
        #
        # createResolutionCallback internally adds 1 to get us to our frame, then
        # we add 1 to get to the proper surrounding scope.
        rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
        self._c._define(self._concrete_type, src, rcb)

    def __getattr__(self, attr):
        if "_initializing" not in self.__dict__:
            raise RuntimeError(
                "ScriptModule has not been initialized, did you forget to call super's init?"
            )

        if self._initializing:
            return super().__getattr__(attr)

        # _modules check is before hasattr since modules are included as attributes in _c,
        # but we want to get the python wrapper from _modules instead of the raw _c object.
        if attr in self._modules:
            return self._modules[attr]
        elif self._c.hasattr(attr):
            return self._c.getattr(attr)
        elif self._c._has_method(attr):
            script_method = self._c._get_method(attr)
            # cache method so future calls do not go through __getattr__
            # to improve invocation performance
            self.__dict__[attr] = script_method
            return script_method

        return super().__getattr__(attr)

    def __setattr__(self, attr, value):
        if self._initializing:
            return super().__setattr__(attr, value)

        if attr in self._modules:
            self._modules[attr] = value
        elif self._c.hasattr(attr):
            self._c.setattr(attr, value)
        elif (
            hasattr(self, "_concrete_type")
            and attr in self._concrete_type.get_constants().keys()
        ):
            # TODO: we don't have _concrete_type set after load(), and in general we lose constant information.
            # We should encode constants as class type attributes (or something) so it persists across save/load.
            raise AttributeError(
                f"Cannot mutate TorchScript constant value: '{attr}'. Value: '{value}'"
            )
        else:
            # We allow setting Python attributes on the ScriptModule, for
            # when people want to stash some convenience info on it.
            # TODO: it's possible that the following is confusing:
            #   s = torch.jit.script(...)
            #   s.python_attr = ...
            #   s.save()   <--- this doesn't have `python_attr`
            # It's fairly trivial to save enough info to warn in this case.
            return super().__setattr__(attr, value)

    def __copy__(self):
        return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c))

    def __deepcopy__(self, memo):
        return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo))

    # Python magic methods do method lookups on an object's class type, instead of looking up
    # the method defines on the class instance. In order to continue to expose the magic methods
    # of builtin-containers (ModuleList, Sequential, ModuleDict) to Python, we
    # define magic methods here as a shim to the correct attribute.
    def forward_magic_method(self, method_name, *args, **kwargs):
        self_method = getattr(self, method_name)
        # If the lookup resolved to this class's own shim, the wrapped module
        # does not provide the method.
        if getattr(self_method, "__func__", None) == getattr(
            RecursiveScriptModule, method_name
        ):
            raise NotImplementedError()
        return self_method(*args, **kwargs)

    def __iter__(self):
        return self.forward_magic_method("__iter__")

    def __getitem__(self, idx):
        return self.forward_magic_method("__getitem__", idx)

    def __len__(self):
        return self.forward_magic_method("__len__")

    def __contains__(self, key):
        return self.forward_magic_method("__contains__", key)

    # dir is defined by the base nn.Module, so instead of throwing if
    # it is not overridden, we call into the nn.Module __dir__ method
    def __dir__(self):
        self_method = self.__dir__
        if (
            self_method.__func__  # type: ignore[attr-defined]
            == _get_function_from_type(RecursiveScriptModule, "__dir__")
        ):
            return super().__dir__()
        return self_method()

    # to resolve bool(value), Python looks if __bool__ is defined then __iter__
    # is defined then returns true for classes. Since __iter__() on this
    # class throws if it isn't overridden, we define __bool__ to preserve default behavior
    def __bool__(self):
        self_method = self.__bool__
        if (
            self_method.__func__  # type: ignore[attr-defined]
            == _get_function_from_type(RecursiveScriptModule, "__bool__")
        ):
            return True
        return self_method()

    def _replicate_for_data_parallel(self):
        # we have to initialize ScriptModule properly so that
        # it works with pybind11
        def init_fn(script_module):
            # Don't do anything here, we'll initialize the ScriptModule below
            return

        return RecursiveScriptModule._construct(
            self._c._replicate_for_data_parallel(), init_fn
        )
# Every RecursiveScriptModule method/property must also live on ScriptModule:
# ``super().foo()`` does not go through ``__getattr__``, so a user subclass of
# ScriptModule would otherwise not find ``foo``. Dunder names and names that
# ScriptModule already defines are left alone.
for name, item in RecursiveScriptModule.__dict__.items():
    if callable(item) or isinstance(item, property):
        if not name.startswith("__") and not hasattr(ScriptModule, name):
            # Copying the implementation wholesale is fine: apart from the
            # ``super()`` issue, ScriptModule behaves exactly like
            # RecursiveScriptModule.
            setattr(ScriptModule, name, item)
def _get_methods(cls): | |
import inspect | |
# In Python 3 unbound methods are functions, but in Python 2 they are methods | |
return inspect.getmembers( | |
cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x) | |
) | |
_compiled_methods_allowlist = { | |
"forward", | |
"register_buffer", | |
"register_parameter", | |
"register_module", | |
"add_module", | |
"_apply", | |
"apply", | |
"cuda", | |
"cpu", | |
"to", | |
"type", | |
"float", | |
"double", | |
"half", | |
"state_dict", | |
"_save_to_state_dict", | |
"load_state_dict", | |
"_load_from_state_dict", | |
"_named_members", | |
"parameters", | |
"named_parameters", | |
"buffers", | |
"named_buffers", | |
"children", | |
"named_children", | |
"modules", | |
"named_modules", | |
"zero_grad", | |
"share_memory", | |
"_get_name", | |
"extra_repr", | |
"_slow_forward", | |
"_tracing_name", | |
"eval", | |
"train", | |
"get_extra_state", | |
"set_extra_state", | |
} | |
def _make_fail(name): | |
def fail(self, *args, **kwargs): | |
raise RuntimeError(name + " is not supported on ScriptModules") | |
return fail | |
# Replace every nn.Module method that is neither a dunder / *_call_impl helper,
# nor defined by RecursiveScriptModule, nor explicitly allow-listed with a stub
# that raises when called.
for name, method in _get_methods(torch.nn.Module):
    if (
        name.startswith("__")
        or name.endswith("_call_impl")
        or name in RecursiveScriptModule.__dict__
        or name in _compiled_methods_allowlist
    ):
        continue
    setattr(RecursiveScriptModule, method.__name__, _make_fail(name))
else: | |
# TODO MAKE SURE THAT DISABLING WORKS | |
class RecursiveScriptClass: # type: ignore[no-redef] | |
pass | |
class ScriptModule(torch.nn.Module): # type: ignore[no-redef] | |
def __init__(self, arg=None): | |
super().__init__() | |
class RecursiveScriptModule(ScriptModule): # type: ignore[no-redef] | |
def __init__(self, arg=None): | |
super().__init__() | |
def call_prepare_scriptable_func_impl(obj, memo):
    """Recursively apply ``__prepare_scriptable__`` to ``obj`` and its submodules.

    Args:
        obj: Object to prepare. Anything that is not an ``nn.Module`` is
            returned unchanged.
        memo: Maps ``id()`` of modules already visited to their prepared
            result. It is shared across the recursion so each module is
            prepared only once and cycles in the hierarchy terminate.

    Returns:
        ``obj.__prepare_scriptable__()`` when that hook exists, otherwise
        ``obj`` itself. In both cases the result's submodules (registered in
        ``_modules`` or stored directly in ``__dict__``) are replaced by their
        prepared versions.
    """
    if not isinstance(obj, torch.nn.Module):
        return obj

    obj_id = id(obj)

    # If obj_id is in memo, obj has already been prepared or is being
    # prepared in another call up the stack.
    if obj_id in memo:
        return memo[obj_id]

    obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj  # type: ignore[operator]
    # Record obj in memo before recursing, to avoid infinite recursion when the
    # module hierarchy contains cycles.
    memo[obj_id] = obj

    new_obj_dict = {}
    for name, sub_module in obj.__dict__.items():
        if name == "_modules":
            # Registered submodules are replaced in place; only existing keys
            # are reassigned, so mutating the dict while iterating is safe.
            for k, v in sub_module.items():
                sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
            new_obj_dict[name] = sub_module
        elif isinstance(sub_module, torch.nn.Module) and not isinstance(
            sub_module, ScriptModule
        ):
            # Modules stored directly in __dict__ rather than in _modules.
            new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
        else:
            new_obj_dict[name] = sub_module

    # Write every (possibly replaced) attribute back under its OWN key so that
    # prepared modules held outside `_modules` actually take effect.
    obj.__dict__.update(new_obj_dict)
    return obj
def call_prepare_scriptable_func(obj):
    """Run ``__prepare_scriptable__`` preparation over ``obj`` with a fresh memo."""
    return call_prepare_scriptable_func_impl(obj, memo={})
def create_script_dict(obj):
    """
    Build a ``torch._C.ScriptDict`` holding the same data as ``obj``.

    Args:
        obj (dict): Python dictionary whose contents initialize the returned
            ``ScriptDict``.

    Returns:
        A ``torch._C.ScriptDict`` with the same data as ``obj`` that can be
        passed between Python and TorchScript with reference semantics and
        zero copy overhead.
    """
    script_dict_type = torch._C.ScriptDict  # type: ignore[attr-defined]
    return script_dict_type(obj)
def create_script_list(obj, type_hint=None):
    """
    Create a ``torch._C.ScriptList`` instance with the data from ``obj``.

    Args:
        obj (list): The Python list that is used to initialize the ``ScriptList``
            returned by this function.
        type_hint: Currently ignored by this function.

    Returns:
        An instance of ``torch._C.ScriptList`` that has the same data as ``obj``
        and can be passed between Python and TorchScript with reference semantics and
        zero copy overhead.
    """
    return torch._C.ScriptList(obj)  # type: ignore[attr-defined]
def script(
    obj,
    optimize=None,
    _frames_up=0,
    _rcb=None,
    example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
):
    r"""Script the function.

    Scripting a function or ``nn.Module`` will inspect the source code, compile
    it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
    :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all
    features in Python work, but we provide enough functionality to compute on
    tensors and do control-dependent operations. For a complete guide, see the
    :ref:`language-reference`.

    Scripting a dictionary or list copies the data inside it into a TorchScript instance that can be
    subsequently passed by reference between Python and TorchScript with zero copy overhead.

    ``torch.jit.script`` can be used as a function for modules, functions, dictionaries and lists
    and as a decorator ``@torch.jit.script`` for :ref:`torchscript-classes` and functions.

    Args:
        obj (Callable, class, or nn.Module): The ``nn.Module``, function, class type,
            dictionary, or list to compile.
        optimize: Deprecated and ignored. Passing anything other than ``None`` only
            emits a warning; use ``torch.jit.optimized_execution()`` instead.
        _frames_up: Internal. Extra stack frames to skip when building the
            resolution callback for a scripted class.
        _rcb: Internal. Resolution callback handed to the compiler; when ``None`` it is
            built from the caller's frame (classes) or the function's closure (functions).
        example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]): Provide example inputs
            to annotate the arguments for a function or ``nn.Module``.

    Returns:
        If ``obj`` is ``nn.Module``, ``script`` returns
        a :class:`ScriptModule` object. The returned :class:`ScriptModule` will
        have the same set of sub-modules and parameters as the
        original ``nn.Module``. If ``obj`` is a standalone function,
        a :class:`ScriptFunction` will be returned. If ``obj`` is a ``dict``, then
        ``script`` returns an instance of `torch._C.ScriptDict`. If ``obj`` is a ``list``,
        then ``script`` returns an instance of `torch._C.ScriptList`.

    Raises:
        ValueError: if MonkeyType is available and ``example_inputs`` is neither a
            list nor a dict.
        RuntimeError: if ``obj`` is an ``nn.Module`` subclass (rather than an instance),
            a class that is not new-style or that uses inheritance, or a function that
            is marked as unsupported in script mode or is overloaded.

    **Scripting a function**
        The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction`
        by compiling the body of the function.

        Example (scripting a function):

        .. testcode::

            import torch

            @torch.jit.script
            def foo(x, y):
                if x.max() > y.max():
                    r = x
                else:
                    r = y
                return r

            print(type(foo))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(foo.code)

            # Call the function using the TorchScript interpreter
            foo(torch.ones(2, 2), torch.ones(2, 2))

        .. testoutput::
            :hide:

            ...

    **Scripting a function using example_inputs**
        Example inputs can be used to annotate a function arguments.

        Example (annotating a function before scripting):

        .. testcode::

            import torch

            def test_sum(a, b):
                return a + b

            # Annotate the arguments to be int
            scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

            print(type(scripted_fn))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(scripted_fn.code)

            # Call the function using the TorchScript interpreter
            scripted_fn(20, 100)

        .. testoutput::
            :hide:

            ...

    **Scripting an nn.Module**
        Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
        compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
        features supported in TorchScript, no changes to the original module code should be necessary. ``script``
        will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of
        the original module.

        Example (scripting a simple module with a Parameter):

        .. testcode::

            import torch

            class MyModule(torch.nn.Module):
                def __init__(self, N, M):
                    super().__init__()
                    # This parameter will be copied to the new ScriptModule
                    self.weight = torch.nn.Parameter(torch.rand(N, M))

                    # When this submodule is used, it will be compiled
                    self.linear = torch.nn.Linear(N, M)

                def forward(self, input):
                    output = self.weight.mv(input)

                    # This calls the `forward` method of the `nn.Linear` module, which will
                    # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
                    output = self.linear(output)
                    return output

            scripted_module = torch.jit.script(MyModule(2, 3))

        Example (scripting a module with traced submodules):

        .. testcode::

            import torch
            import torch.nn as nn
            import torch.nn.functional as F

            class MyModule(nn.Module):
                def __init__(self):
                    super().__init__()
                    # torch.jit.trace produces a ScriptModule's conv1 and conv2
                    self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
                    self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

                def forward(self, input):
                    input = F.relu(self.conv1(input))
                    input = F.relu(self.conv2(input))
                    return input

            scripted_module = torch.jit.script(MyModule())

        To compile a method other than ``forward`` (and recursively compile anything it calls), add
        the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation
        use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`.

        Example (an exported and ignored method in a module)::

            import torch
            import torch.nn as nn

            class MyModule(nn.Module):
                def __init__(self):
                    super().__init__()

                @torch.jit.export
                def some_entry_point(self, input):
                    return input + 10

                @torch.jit.ignore
                def python_only_fn(self, input):
                    # This function won't be compiled, so any
                    # Python APIs can be used
                    import pdb
                    pdb.set_trace()

                def forward(self, input):
                    if self.training:
                        self.python_only_fn(input)
                    return input * 99

            scripted_module = torch.jit.script(MyModule())
            print(scripted_module.some_entry_point(torch.randn(2, 2)))
            print(scripted_module(torch.randn(2, 2)))

        Example (annotating forward of nn.Module using example_inputs)::

            import torch
            import torch.nn as nn
            from typing import List, NamedTuple

            class MyModule(NamedTuple):
                result: List[int]

            class TestNNModule(torch.nn.Module):
                def forward(self, a) -> MyModule:
                    result = MyModule(result=a)
                    return result

            pdt_model = TestNNModule()

            # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward
            scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

            # Run the scripted_model with actual inputs
            print(scripted_model([20]))
    """
    global type_trace_db

    if not _enabled:
        return obj

    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead",
            stacklevel=2,
        )

    # No-op for modules, functions, class instances that are already scripted
    if isinstance(obj, (RecursiveScriptClass, ScriptModule, ScriptFunction)):
        return obj

    if example_inputs:
        # If MonkeyType is installed, enable profile directed type annotation:
        # run the eager-mode version of the function / methods on the provided
        # example inputs so every call trace is logged in type_trace_db.
        type_trace_db = JitTypeTraceStore()
        if monkeytype_trace:
            monkeytype_config = JitTypeTraceConfig(type_trace_db)
            with monkeytype_trace(monkeytype_config):
                if isinstance(example_inputs, dict):
                    # If the obj is an nn.Module or a class, then each method is
                    # executed with the arguments provided in the example inputs.
                    # example inputs here will be of type Dict(class.method, (arguments))
                    # This is used to infer type annotations for those methods
                    # which are not called directly under the hood of monkeytype.
                    for module, example_input in example_inputs.items():
                        for example in example_input:
                            module(*example)
                elif isinstance(example_inputs, list):
                    for examples in example_inputs:
                        obj(*examples)
                else:
                    raise ValueError(
                        "Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
                        " or `Dict[Callable, List[Tuple]]` to be run with MonkeyType."
                    )
        else:
            warnings.warn(
                "Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
                "to enable Profile-Directed Typing in TorchScript. Refer to "
                "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. "
            )

    if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            obj, torch.jit._recursive.infer_methods_to_compile
        )

    obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj  # type: ignore[operator]

    if isinstance(obj, dict):
        return create_script_dict(obj)
    if isinstance(obj, list):
        return create_script_list(obj)

    if inspect.isclass(obj):
        qualified_name = _qualified_name(obj)
        # If this type is a `nn.Module` subclass, they probably meant to pass
        # an instance instead of a Module
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError(
                f"Type '{obj}' cannot be compiled since it inherits from nn.Module, pass an instance instead"
            )

        # Enums are automatically usable in TorchScript, explicitly scripting
        # is not necessary, but not harmful either.
        if issubclass(obj, enum.Enum):
            return obj

        if not _is_new_style_class(obj):
            raise RuntimeError(
                "TorchScript classes must be new-style classes. "
                "Please inherit from 'object'."
            )
        if len(obj.mro()) > 2:
            raise RuntimeError(
                "TorchScript classes does not support inheritance yet. "
                "Please directly inherit from 'object'."
            )
        # The resolution callback must be created in this frame (not a helper)
        # so that `_frames_up + 1` reaches the caller's frame.
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj
    elif inspect.isfunction(obj) or inspect.ismethod(obj):
        qualified_name = _qualified_name(obj)
        # this is a decorated fn, and we need the underlying fn and its rcb
        if hasattr(obj, "__script_if_tracing_wrapper"):
            obj = obj.__original_fn  # type: ignore[union-attr]
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

        # some functions are explicitly marked as not supported in script mode
        if hasattr(obj, "__script_unsupported"):
            raise RuntimeError("TorchScript error: " + obj.__script_unsupported)

        _check_directly_compile_overloaded(obj)
        maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
        if maybe_already_compiled_fn:
            return maybe_already_compiled_fn
        ast = get_jit_def(obj, obj.__name__)
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
        fn = torch._C._jit_script_compile(
            qualified_name, ast, _rcb, get_default_args(obj)
        )
        # Forward docstrings
        fn.__doc__ = obj.__doc__
        # Allow torch.compile() to inline
        fn._torchdynamo_inline = obj  # type: ignore[attr-defined]
        _set_jit_function_cache(obj, fn)
        return fn
    else:
        return torch.jit._recursive.create_script_class(obj)
# overloads are registered in _jit_internal and compiled here so that _overload | |
# can be used in nn/functional.py without an import cycle | |
def _check_overload_defaults(impl_defaults, overload_defaults, loc): | |
for name, overload_value in overload_defaults.items(): | |
if name not in impl_defaults or impl_defaults[name] != overload_value: | |
raise torch.jit.frontend.FrontendError( | |
loc, | |
"Default parameters on overloads do not affect the runtime so they " | |
"must equal to the default parameter on the implementation function. Found on " | |
f"parameter {name}", | |
) | |
def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
    """Compile ``impl_fn``'s body against the signature declared by ``overload_fn``.

    The overload's defaults are validated against the implementation's first,
    then the pair is compiled under ``qual_name``.
    """
    decl = get_jit_def(overload_fn, overload_fn.__name__).decl()
    signature = torch.jit.annotations.get_signature(
        overload_fn, None, None, inspect.ismethod(overload_fn)
    )
    impl_def = get_jit_def(impl_fn, impl_fn.__name__)
    overload_defaults = get_default_args(overload_fn)
    impl_defaults = get_default_args(impl_fn)
    resolution_cb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)
    _check_overload_defaults(impl_defaults, overload_defaults, decl.range())
    return torch._C._jit_script_compile_overload(
        qual_name,
        decl,
        impl_def,
        resolution_cb,
        impl_defaults,
        signature,
    )
def _get_overloads(obj):
    """Return the compiled overloads of ``obj``, compiling any still pending.

    Newly compiled overloads are appended after previously cached ones; the
    combined list is cached and the pending registrations are cleared.
    """
    cached = _try_get_jit_cached_overloads(obj)
    qual_name = _qualified_name(obj)
    pending = _jit_internal._get_fn_overloads(qual_name)
    if pending is None:
        return cached

    if obj in pending:
        raise RuntimeError(
            _jit_internal.get_overload_no_implementation_error_message("function", obj)
        )

    fresh = [
        _compile_function_with_overload(overload_fn, qual_name, obj)
        for overload_fn in pending
    ]
    compiled_fns = cached + fresh if cached else fresh

    # Cache the compilation and drop the information used to perform it.
    _set_jit_overload_cache(obj, compiled_fns)
    _jit_internal._clear_fn_overloads(qual_name)
    return compiled_fns
def _check_directly_compile_overloaded(obj): | |
qual_name = _qualified_name(obj) | |
if _jit_internal._get_fn_overloads(qual_name) or _try_get_jit_cached_overloads(obj): | |
raise RuntimeError( | |
f"Function {qual_name} cannot be directly compiled because it" | |
" is overloaded. It must be used in a context of a function" | |
" where its inputs can determine which overload to call." | |
) | |
def interface(obj):
    r"""Decorate to annotate classes or modules of different types.

    This decorator can be used to define an interface that can be used to annotate
    classes or modules of different types. This can be used to annotate a submodule
    or attribute class that could have different types that implement the same
    interface, or which could be swapped at runtime; or to store a list of modules or
    classes of varying types.

    It is sometimes used to implement "Callables" - functions or modules that implement
    an interface but whose implementations differ and which can be swapped out.

    Args:
        obj: The class to compile as an interface. If it subclasses ``nn.Module``
            (directly), a module interface is created; otherwise a class interface.

    Returns:
        ``obj`` itself, with ``__torch_script_interface__`` set to the compiled
        (mangled) interface name.

    Raises:
        RuntimeError: if ``obj`` is not a class, is not a new-style class, or
            inherits from anything other than ``object`` or ``nn.Module``.

    Example:
    .. testcode::

        import torch
        from typing import List

        @torch.jit.interface
        class InterfaceType:
            def run(self, x: torch.Tensor) -> torch.Tensor:
                pass

        # implements InterfaceType
        @torch.jit.script
        class Impl1:
            def run(self, x: torch.Tensor) -> torch.Tensor:
                return x.relu()

        class Impl2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.val = torch.rand(())

            @torch.jit.export
            def run(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.val

        def user_fn(impls: List[InterfaceType], idx: int, val: torch.Tensor) -> torch.Tensor:
            return impls[idx].run(val)

        user_fn_jit = torch.jit.script(user_fn)

        impls = [Impl1(), torch.jit.script(Impl2())]
        val = torch.rand(4, 4)
        user_fn_jit(impls, 0, val)
        user_fn_jit(impls, 1, val)
    """
    if not inspect.isclass(obj):
        raise RuntimeError("interface must be applied to a class")
    if not _is_new_style_class(obj):
        raise RuntimeError("TorchScript interfaces must inherit from 'object'")

    # Expected MRO for a module interface is:
    #     User module
    #     torch.nn.modules.module.Module
    #     object
    is_module_interface = issubclass(obj, torch.nn.Module) and len(obj.mro()) == 3

    if not is_module_interface and len(obj.mro()) > 2:
        raise RuntimeError(
            "TorchScript interface does not support inheritance yet. "
            "Please directly inherit from 'object' or 'nn.Module'."
        )

    qualified_name = _qualified_name(obj)
    # Frame 1 is the caller of `interface`; this call must stay in this function.
    rcb = _jit_internal.createResolutionCallbackFromFrame(1)
    # if this type is a `nn.Module` subclass, generate a module interface type
    # instead of a class interface type; a module interface type only compiles
    # the user provided methods as part of the interface
    ast = get_jit_class_def(obj, obj.__name__)
    mangled_classname = torch._C._jit_script_interface_compile(
        qualified_name, ast, rcb, is_module_interface
    )
    obj.__torch_script_interface__ = mangled_classname
    return obj
def _recursive_compile_class(obj, loc):
    """Compile and register class ``obj`` reached during recursive compilation at ``loc``."""
    qualified = _qualified_name(obj)
    # A new compilation is starting: record it on the error call stack so a
    # failure is reported against ``loc``. Keep the object bound to a name for
    # the whole call — presumably the entry lives only as long as this object.
    _error_call_stack = torch._C.CallStack(qualified, loc)  # noqa: F841
    resolution_cb = _jit_internal.createResolutionCallbackForClassMethods(obj)
    return _compile_and_register_class(obj, resolution_cb, qualified)
# Re-export the C++ CompilationUnit and report its module as ``torch.jit``.
set_module(torch._C.CompilationUnit, "torch.jit")
CompilationUnit = torch._C.CompilationUnit
def pad(s: str, padding: int, offset: int = 0, char: str = " "):
    """Right-align ``s`` to ``padding`` columns, then prepend ``offset`` extra fills.

    Args:
        s: Text to pad.
        padding: Target width. If ``s`` is already at least this wide, no
            width padding is added (``s`` is never truncated).
        offset: Additional fill characters prepended regardless of width.
        char: Fill character.

    Returns:
        ``char`` repeated ``max(padding - len(s), 0) + offset`` times, followed by ``s``.
    """
    # Clamp at zero: when ``s`` is wider than ``padding`` it must not gain an
    # extra ``padding`` worth of fill (which made the output width non-monotonic).
    fill = max(padding - len(s), 0)
    return char * (fill + offset) + s
class _ScriptProfileColumn: | |
def __init__(self, header: str, alignment: int = 4, offset: int = 0): | |
self.header = header | |
self.alignment = alignment | |
self.offset = offset | |
self.rows: Dict[int, Any] = {} | |
def add_row(self, lineno: int, value: Any): | |
self.rows[lineno] = value | |
def materialize(self): | |
max_length = len(self.header) | |
rows: List[Tuple[int, str]] = [] | |
for key, value in self.rows.items(): | |
cell = str(value) | |
rows.append((key, cell)) | |
max_length = max(len(cell), max_length) | |
if self.alignment > 0: | |
padding = max_length + self.alignment | |
padding -= padding % self.alignment | |
else: | |
padding = 0 | |
rows = [(key, pad(cell, padding, self.offset)) for key, cell in rows] | |
return pad(self.header, padding, self.offset), rows | |
class _ScriptProfileTable:
    """Lays out a list of profile columns over a range of source lines."""

    def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]):
        self.cols = cols
        self.source_range = source_range

    def dump_string(self):
        """Render the header, an ``=`` rule, and one text row per source line."""
        materialized = [col.materialize() for col in self.cols]
        header_line = "".join(header for header, _ in materialized)
        by_line = [(header, dict(rows)) for header, rows in materialized]

        lines: List[str] = [header_line, pad("", len(header_line), 0, "=")]
        for line in self.source_range:
            pieces = []
            for header, rows in by_line:
                cell = rows.get(line)
                # Columns with no value for this line are blank-filled to the header width.
                pieces.append(pad("", len(header)) if cell is None else cell)
            lines.append("".join(pieces))
        return "\n".join(lines)
class _ScriptProfile:
    """Python wrapper around the ``classes.profiling._ScriptProfile`` object.

    ``enable``/``disable`` forward to the underlying profiler; ``dump_string``
    renders its collected per-line statistics (hit count and duration in ns)
    as one text table per profiled source.
    """

    def __init__(self):
        self.profile = classes.profiling._ScriptProfile()

    def enable(self):
        self.profile.enable()

    def disable(self):
        self.profile.disable()

    def dump_string(self) -> str:
        """Return one line-by-line profile table per source, separated by blank lines."""
        outputs: List[str] = []
        for source_stats in self.profile._dump_stats():
            source_ref = source_stats.source()
            source_lines = source_ref.text().splitlines()
            # Strip the common leading-space indentation. Blank / whitespace-only
            # lines are ignored when measuring it (an empty line has zero indent
            # and would otherwise disable dedenting entirely); ``default`` keeps
            # min() from raising on an empty source.
            dedent = min(
                (
                    len(line) - len(line.lstrip(" "))
                    for line in source_lines
                    if line.strip()
                ),
                default=0,
            )
            source_lines = [line[dedent:] for line in source_lines]

            start_line = source_ref.starting_lineno()
            end_line = start_line + len(source_lines)
            source_range = range(start_line, end_line)
            lineno = _ScriptProfileColumn("Line #")
            hits = _ScriptProfileColumn("Hits")
            time_ns = _ScriptProfileColumn("Time (ns)")
            line_contents = _ScriptProfileColumn("Line Contents", 0, 1)
            stats = source_stats.line_map()
            for line in source_range:
                lineno.add_row(line, line)
                line_contents.add_row(line, source_lines[line - start_line])
                # Lines with no recorded stats leave the Hits / Time cells empty.
                stat = stats.get(line)
                if stat is not None:
                    hits.add_row(line, stat.count())
                    time_ns.add_row(line, stat.duration_ns())

            table = _ScriptProfileTable(
                [lineno, hits, time_ns, line_contents], list(source_range)
            )
            outputs.append(table.dump_string())
        return "\n\n".join(outputs)

    def dump(self):
        print(self.dump_string())
def _unwrap_optional(x): | |
assert x is not None, "Unwrapping null optional" | |
return x | |
# Bind these Python helpers to their TorchScript builtin ops. All three
# has_torch_function variants map onto the single ``aten::has_torch_function`` op.
for _builtin_fn, _builtin_op in (
    (_unwrap_optional, "aten::_unwrap_optional"),
    (_jit_internal.is_scripting, "aten::is_scripting"),
    (has_torch_function, "aten::has_torch_function"),
    (has_torch_function_unary, "aten::has_torch_function"),
    (has_torch_function_variadic, "aten::has_torch_function"),
):
    _register_builtin(_builtin_fn, _builtin_op)
# Do not leak the loop variables into the module namespace.
del _builtin_fn, _builtin_op