"""Tracing. | |
This module contains functionality to support the JIT's tracing frontend, notably: | |
* torch.jit.trace | |
* torch.jit.trace_module | |
This is not intended to be imported directly; please use the exposed | |
functionalities in `torch.jit`. | |
""" | |
import contextlib | |
import copy | |
import functools | |
import inspect | |
import os | |
import re | |
import warnings | |
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar | |
from typing_extensions import ParamSpec | |
import torch | |
from torch._jit_internal import ( | |
_qualified_name, | |
get_callable_argument_names, | |
is_scripting, | |
) | |
from torch.autograd import function | |
from torch.jit._script import _CachedForward, script, ScriptModule | |
from torch.jit._state import _enabled, _python_cu | |
from torch.nn import Module | |
from torch.testing._comparison import default_tolerances | |
_flatten = torch._C._jit_flatten | |
_unflatten = torch._C._jit_unflatten | |
R = TypeVar("R", covariant=True) # return type (always covariant) | |
P = ParamSpec("P") | |
def _create_interpreter_name_lookup_fn(frames_up=1):
    def _get_interpreter_name_for_var(var):
        frame = inspect.currentframe()
        if not frame:
            raise RuntimeError("failed to inspect frame")

        i = 0
        while i < frames_up + 1:
            frame = frame.f_back
            if not frame:
                raise RuntimeError("failed to get frame")
            i += 1

        f_locals = frame.f_locals
        f_globals = frame.f_globals

        for k, v in f_locals.items():
            if isinstance(v, torch.Tensor) and var is v:
                return k if k != "self" else ""
        return ""

    return _get_interpreter_name_for_var


def _unique_state_dict(module, keep_vars=False):
    # since Parameter.detach() always creates a new torch.Tensor instance,
    # id(v) doesn't work with it. So we always get the Parameter or Buffer
    # as values, and deduplicate the params using Parameters and Buffers
    state_dict = module.state_dict(keep_vars=True)
    filtered_dict = type(state_dict)()
    seen_ids: Set[int] = set()
    for k, v in state_dict.items():
        if id(v) in seen_ids:
            continue
        seen_ids.add(id(v))
        if keep_vars:
            filtered_dict[k] = v
        else:
            filtered_dict[k] = v.detach()
    return filtered_dict
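

# Wraps a module (or arbitrary callable) so that running it under the tracer
# produces a raw Graph along with the eager outputs; used by _get_trace_graph
# below, which feeds the ONNX exporter.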
class ONNXTracedModule(torch.nn.Module):
    def __init__(
        self,
        inner,
        strict=True,
        force_outplace=False,
        return_inputs=False,
        return_inputs_states=False,
    ):
        super().__init__()
        # inner may be a Module, or it may be an arbitrary callable
        # If it's a Module, we get its parameters automatically, which lets
        # us avoid a special case for functions versus modules.
        self.inner = inner
        self.strict = strict
        self._force_outplace = force_outplace
        self._return_inputs = return_inputs
        self._return_inputs_states = return_inputs_states

    def forward(self, *args: torch.Tensor):
        in_vars, in_desc = _flatten(args)
        # NOTE: use full state, because we need it for BatchNorm export
        # This differs from the compiler path, which doesn't support it at the moment.
        module_state = list(_unique_state_dict(self, keep_vars=True).values())

        ret_inputs = []
        inputs_states = []
        outs = []

        def wrapper(*args):
            in_args: List[torch.Tensor] = []
            for i in range(len(in_vars)):
                if not isinstance(args[i], torch.Tensor):
                    raise RuntimeError("Expected Tensor argument")
                in_args.append(args[i])

            trace_inputs = _unflatten(in_args, in_desc)

            if self._return_inputs:
                ret_inputs.append(
                    tuple(x.clone(memory_format=torch.preserve_format) for x in args)
                )
            if self._return_inputs_states:
                inputs_states.append(_unflatten(in_args, in_desc))
            outs.append(self.inner(*trace_inputs))
            if self._return_inputs_states:
                inputs_states[0] = (inputs_states[0], trace_inputs)
            out_vars, _ = _flatten(outs)
            if len(out_vars) == 1:
                return out_vars[0]
            else:
                return tuple(out_vars)

        graph, out = torch._C._create_graph_by_tracing(
            wrapper,
            in_vars + module_state,
            _create_interpreter_name_lookup_fn(),
            self.strict,
            self._force_outplace,
        )

        if self._return_inputs:
            return graph, outs[0], ret_inputs[0]
        if self._return_inputs_states:
            return graph, outs[0], inputs_states[0]
        else:
            return graph, outs[0]
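

# Deep-clones the example inputs (preserving requires_grad and, where present,
# .grad) so that trace checking can re-run a function without mutating or
# aliasing the tensors the user passed in.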
def _clone_inputs(args):
    def clone_input(a):
        if a is None:
            return None
        elif isinstance(a, torch.Tensor):
            # TODO: figure out one liner to .clone() and set requires_grad
            v = (
                a.detach()
                .clone(memory_format=None if a.is_mkldnn else torch.preserve_format)
                .requires_grad_(a.requires_grad)
            )
            if a.grad is not None:
                v.grad = clone_input(v.grad)
            return v
        else:
            return a.clone(memory_format=torch.preserve_format)

    return function._nested_map(
        lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
    )(args)


# This is purely for developer debugging. We are not going to advertise it.
_JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False)  # CUDA-only timing
_JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False)
_JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False)


@contextlib.contextmanager
def _time(trace_name, name, time=True):
    if (not _JIT_TIME and not time) or not torch.cuda.is_available():
        yield
        return
    stream = torch.cuda.current_stream()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    stream.record_event(start)
    try:
        yield
    finally:
        stream.record_event(end)
        end.synchronize()
        print(f"{trace_name} {name} time: {start.elapsed_time(end)} ms")


def verify(model, args, loss_fn=torch.sum, devices=None):
    """
    Verify that a JIT compiled model has the same behavior as its uncompiled version, including the backwards pass.

    If your model returns multiple outputs,
    you must also specify a `loss_fn` to produce a loss for which
    the backwards will be computed.

    This function has side-effects (e.g., it executes your model / saves and loads
    parameters), so don't expect the model to come out exactly the same as what
    you passed in.

    Args:
        model (compiled torch.nn.Module or function): the module/function to be
            verified. The module/function definition MUST have been decorated with
            `@torch.jit.compile`.
        args (tuple or Tensor): the positional arguments to pass to the
            compiled function/module to be verified. A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        loss_fn (function, optional): the loss function to be applied to
            the output of the model, before backwards is invoked. By default,
            we assume that a model returns a single result, and we :func:`torch.sum`
            before calling backwards; if this is inappropriate, you can pass your
            own loss function. Note that if a model returns a tuple of results,
            these are passed as separate positional arguments to `loss_fn`.
        devices (iterable of device IDs, optional): the GPU devices which the
            compiled module will be run on. This determines the RNG state we
            must save when running both compiled and uncompiled versions of the model.
    """
    # TODO: In principle, we track device information in our trace, so it
    # should be possible to check if our execution actually obeyed the 'devices'
    # the user provided.

    # TODO: Consider adding a utility function to torch.jit to test
    # for this case
    if not isinstance(model, torch._C.CompiledFunction):  # type: ignore[attr-defined]
        raise TypeError(
            "Cannot verify an uncompiled module. Add @torch.jit.compile to compile it"
        )
    is_module = isinstance(model, Module)

    if not isinstance(args, tuple):
        args = (args,)

    saved_args = _clone_inputs(args)
    if is_module:
        saved_state = copy.deepcopy(model.state_dict())

    def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
        params = list(model.parameters()) if is_module else []
        in_vars, _ = _flatten((args, params))
        # We use a special API to reset the trace and compile it from scratch.
        compiled_fn = model
        if force_trace:
            compiled_fn.clear_cache()
        if assert_compiled:
            hits = compiled_fn.hits
        out = model(*args)
        if assert_compiled and compiled_fn.hits == hits:  # type: ignore[possibly-undefined]
            raise RuntimeError("failed to use the compiled function")
        if not isinstance(out, tuple):
            out = (out,)
        if loss_fn == torch.sum and len(out) != 1:
            raise ValueError(
                f"Model returns {len(out)} outputs, but default loss function "
                "(torch.sum) can only handle a single output"
            )
        out_vars, _ = _flatten(out)
        saved_outs = [
            v.detach().clone(memory_format=torch.preserve_format) for v in out_vars
        ]
        loss = loss_fn(*out)
        grads = torch.autograd.grad([loss], in_vars)
        # TODO: I'm not sure if the clone here is necessary but it is safer
        saved_grads = [
            v.detach().clone(memory_format=torch.preserve_format) for v in grads
        ]
        return (saved_outs, saved_grads)

    with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
        uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
        assert model.has_trace_for(*args)

    if is_module:
        model.load_state_dict(saved_state)  # type: ignore[possibly-undefined]
    compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

    _verify_equal(uncompiled_outs, compiled_outs)
    _verify_equal(uncompiled_grads, compiled_grads)


def _verify_equal(xs, ys):
    for x, y in zip(xs, ys):
        if x.sub(y).abs().max() > 1e-6:
            raise RuntimeError("JIT and real computation mismatch")


def indent(s):
    return "\n".join(["\t" + line for line in s.splitlines()])
class TracingCheckError(Exception):
    def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
        self.message = "Tracing failed sanity checks!\n"
        if extra_msg is not None:
            self.message += extra_msg + "\n"
        if graph_diff_error is not None:
            self.message += "ERROR: Graphs differed across invocations!\n"
            self.message += indent(graph_diff_error) + "\n"
        if tensor_compare_error is not None:
            self.message += (
                "ERROR: Tensor-valued Constant nodes differed in value "
                "across invocations. This often indicates that the tracer has"
                " encountered untraceable code.\n"
            )
            self.message += indent(tensor_compare_error) + "\n"
        super().__init__(self.message)


# Check the traced module against a set of user-provided validation inputs
def _check_trace(
    check_inputs,
    func,
    traced_func,
    check_tolerance,
    strict,
    force_outplace,
    is_trace_module,
    _module_class,
    example_inputs_is_kwarg=False,
):
    # Note: tracing is independent of optimizations, which consume the trace
    for inputs in check_inputs:
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)

        if is_trace_module:
            copied_dict = {}
            for name, data in inputs.items():
                copied_dict[name] = _clone_inputs(data)
            check_mod = torch.jit.trace_module(
                getattr(func, "__self__", func),
                copied_dict,
                check_trace=False,
                strict=strict,
                _force_outplace=force_outplace,
                _module_class=_module_class,
                _compilation_unit=torch._C.CompilationUnit(),
                example_inputs_is_kwarg=example_inputs_is_kwarg,
                _store_inputs=False,
            )
            check_mod_func = check_mod._c._get_method(traced_func.name)
            inputs = inputs[traced_func.name]
            if isinstance(inputs, torch.Tensor) or (
                isinstance(inputs, dict) and not example_inputs_is_kwarg
            ):
                inputs = (inputs,)
        else:
            if example_inputs_is_kwarg:
                check_mod = torch.jit.trace(
                    func,
                    check_trace=False,
                    strict=strict,
                    _force_outplace=force_outplace,
                    _module_class=_module_class,
                    example_kwarg_inputs=_clone_inputs(inputs),
                    _store_inputs=False,
                )
            else:
                check_mod = torch.jit.trace(
                    func,
                    _clone_inputs(inputs),
                    check_trace=False,
                    strict=strict,
                    _force_outplace=force_outplace,
                    _module_class=_module_class,
                    _store_inputs=False,
                )
            check_mod_func = check_mod
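
        # Compares the original trace against the re-trace textually: both
        # graphs are canonicalized, inlined, and stripped of shape info, then
        # diffed. Returns (graph_diff_errors, tensor_compare_errors), each of
        # which is None when that check found no divergence.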
        def graph_diagnostic_info():
            mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph)
            torch._C._jit_pass_inline(mod_canonicalized)
            torch._C._jit_pass_erase_shape_information(mod_canonicalized)
            mod_str = str(mod_canonicalized)
            mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str)
            check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph)
            torch._C._jit_pass_inline(check_canonicalized)
            torch._C._jit_pass_erase_shape_information(check_canonicalized)
            check_str = str(check_canonicalized)
            check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str)

            graph_diff_errors = None
            if mod_str != check_str:
                import difflib

                graph_diff = difflib.ndiff(
                    mod_str.splitlines(True), check_str.splitlines(True)
                )
                graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n"

                for n_mod, n_check in zip(
                    mod_canonicalized.nodes(), check_canonicalized.nodes()
                ):
                    if str(n_mod) != str(n_check):
                        graph_diff_errors += "First diverging operator:\n"
                        node_diff = difflib.ndiff(
                            str(n_mod).splitlines(True), str(n_check).splitlines(True)
                        )
                        source_printout = (
                            "Node diff:\n" + indent("".join(node_diff)) + "\n"
                        )
                        mod_stack = n_mod.sourceRange()
                        if mod_stack:
                            source_printout += (
                                "Trace source location:\n" + indent(mod_stack) + "\n"
                            )
                        check_stack = n_check.sourceRange()
                        if check_stack:
                            source_printout += (
                                "Check source location:\n" + indent(check_stack) + "\n"
                            )
                        graph_diff_errors += source_printout
                        break  # For now, only print out the first pair of nodes that diverges

            tensor_compare_errors = None
            # Check Tensor-valued constant nodes
            for n_mod, n_check in zip(
                mod_canonicalized.nodes(), check_canonicalized.nodes()
            ):
                if n_mod.kind() != n_check.kind():
                    break  # Graphs have already diverged
                if n_mod.kind() == "prim::Constant" and not (
                    n_mod.mustBeNone() or n_check.mustBeNone()
                ):
                    if not n_mod.hasAttribute("value"):
                        continue
                    if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t":
                        continue

                    mod_tensor_val = n_mod.t("value")
                    check_tensor_val = n_check.t("value")

                    try:
                        torch.testing.assert_close(
                            mod_tensor_val, check_tensor_val, equal_nan=True
                        )
                    except (RuntimeError, AssertionError) as e:
                        if tensor_compare_errors is None:
                            tensor_compare_errors = ""
                        tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n"
                        compare_stack = n_mod.sourceRange()
                        if compare_stack:
                            tensor_compare_errors += (
                                "Source Location:\n" + indent(compare_stack) + "\n"
                            )
                        tensor_compare_errors += "Comparison exception: " + indent(
                            str(e)
                        )
                        break  # For now, only print the first diverging pair

            return graph_diff_errors, tensor_compare_errors
        def wrap_retval(x):
            return x if isinstance(x, tuple) else (x,)

        def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
            try:
                if isinstance(inputs, dict) and example_inputs_is_kwarg:
                    outs = wrap_retval(mod(**inputs))
                else:
                    outs = wrap_retval(mod(*_clone_inputs(inputs)))
                outs = [out for out in outs if isinstance(out, torch.Tensor)]
                return outs
            except Exception as e:
                graph_diff_errors, tensor_compare_errors = graph_diagnostic_info()
                msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}"
                raise TracingCheckError(
                    graph_diff_errors,
                    tensor_compare_errors,
                    extra_msg=msg,
                ) from e

        has_warned = [False]

        def maybe_warn_nondeterministic():
            if has_warned[0]:
                return
            has_warned[0] = True
            nondeterm_ops = [
                op for op in traced_func.graph.nodes() if op.isNondeterministic()
            ]
            if len(nondeterm_ops) > 0:
                nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
                nondeterministic_ops_warning += (
                    "Did you forget to call .eval() on your model? Nodes:\n"
                )
                nondeterministic_ops_warning += "\n".join(
                    [indent(str(op)) for op in nondeterm_ops][:20]
                )
                nondeterministic_ops_warning += (
                    "\nThis may cause errors in trace checking. To disable trace checking,"
                    " pass check_trace=False to torch.jit.trace()"
                )
                warnings.warn(
                    nondeterministic_ops_warning, category=TracerWarning, stacklevel=5
                )
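
        # Compares corresponding tensor outputs pairwise: quantized/mkldnn
        # tensors are converted to dense first, then everything is upcast
        # (double, cdouble, or float for MPS) and checked with
        # torch.testing.assert_close at the requested tolerance. Returns True
        # when all outputs match.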
        def compare_outputs(original, reference, match_what):
            all_ok = True
            for i, (orig, ref) in enumerate(zip(original, reference)):
                try:
                    if orig.is_quantized:
                        orig = orig.dequantize()
                    if ref.is_quantized:
                        ref = ref.dequantize()
                    if orig.is_mkldnn:
                        orig = orig.to_dense()
                    if ref.is_mkldnn:
                        ref = ref.to_dense()

                    if ref.is_complex() or orig.is_complex():
                        torch.testing.assert_close(
                            orig.to(torch.cdouble),
                            ref.to(torch.cdouble),
                            rtol=check_tolerance,
                            atol=default_tolerances(orig, ref)[1],
                            equal_nan=True,
                        )
                    else:
                        if orig.is_mps or ref.is_mps:
                            torch.testing.assert_close(
                                orig.float(),
                                ref.float(),
                                rtol=check_tolerance,
                                atol=default_tolerances(orig, ref)[1],
                                equal_nan=True,
                            )
                        elif getattr(orig, "is_nested", None) or getattr(
                            ref, "is_nested", None
                        ):
                            assert getattr(orig, "is_nested", None) == getattr(
                                ref, "is_nested", None
                            )
                            for t_orig, t_ref in zip(orig.unbind(), ref.unbind()):
                                torch.testing.assert_close(
                                    t_orig.double(),
                                    t_ref.double(),
                                    rtol=check_tolerance,
                                    atol=default_tolerances(t_orig, t_ref)[1],
                                    equal_nan=True,
                                )
                        else:
                            torch.testing.assert_close(
                                orig.double(),
                                ref.double(),
                                rtol=check_tolerance,
                                atol=default_tolerances(orig, ref)[1],
                                equal_nan=True,
                            )
                except AssertionError as e:
                    maybe_warn_nondeterministic()
                    warnings.warn(
                        "Output nr "
                        + str(i + 1)
                        + ". of the traced function does not match "
                        "the corresponding output of the "
                        + match_what
                        + ". Detailed error:\n"
                        + str(e),
                        category=TracerWarning,
                        stacklevel=4,
                    )
                    all_ok = False

            return all_ok

        traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
        fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
        if compare_outputs(traced_outs, fn_outs, "Python function"):
            check_outs = run_mod_and_filter_tensor_outputs(
                check_mod_func, inputs, "repeated trace"
            )
            compare_outputs(traced_outs, check_outs, "repeated trace")

        diag_info = graph_diagnostic_info()
        if any(info is not None for info in diag_info):
            raise TracingCheckError(*diag_info)


class TracerWarning(Warning):
    @staticmethod
    def ignore_lib_warnings():
        # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
        warnings.filterwarnings(
            "ignore", category=TracerWarning, module="torch.(?!jit)"
        )
        warnings.filterwarnings("ignore", "torch::jit::fuser::cuda")


# We ignore the tracer warnings coming from inside the library, because all our shape
# checks in nn will trigger them.
TracerWarning.ignore_lib_warnings()
torch._C._tracer_warn_use_python()


def make_tuple(example_inputs):
    if isinstance(example_inputs, (torch.Tensor, dict)):
        return (example_inputs,)
    # done primarily so that weird iterables fail here and not pybind11 code
    if not isinstance(example_inputs, tuple):
        return tuple(example_inputs)
    return example_inputs
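

# Decides what kind of module wrapper tracing should produce: ScriptModules
# pass through unchanged, modules with @torch.jit.export-ed methods are
# scripted so those methods survive, and everything else gets a
# TracedModule-style shell.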
def make_module(mod, _module_class, _compilation_unit):
    if isinstance(mod, ScriptModule):
        return mod
    elif torch._jit_internal.module_has_exports(mod):
        infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
        return torch.jit._recursive.create_script_module(
            mod, infer_methods_stubs_fn, share_types=False, is_tracing=True
        )
    else:
        if _module_class is None:
            _module_class = TopLevelTracedModule
        return _module_class(mod, _compilation_unit=_compilation_unit)


def wrap_check_inputs(check_inputs):
    if check_inputs is None:
        return None

    return [{"forward": c} for c in check_inputs]


def trace(
    func,
    example_inputs=None,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
    example_kwarg_inputs=None,
    _store_inputs=True,
):
    r"""
    Trace a function and return an executable or :class:`ScriptFunction` that will be optimized using just-in-time compilation.

    Tracing is ideal for code that operates only on
    ``Tensor``\\s and lists, dictionaries, and
    tuples of ``Tensor``\\s.

    Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an
    existing module or Python function into a TorchScript
    :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example
    inputs, and we run the function, recording the operations performed on all
    the tensors.

    * The resulting recording of a standalone function produces `ScriptFunction`.
    * The resulting recording of `nn.Module.forward` or `nn.Module` produces
      `ScriptModule`. This module also contains any parameters that the original
      module had as well.

    Warning:
        Tracing only correctly records functions and modules which are not data
        dependent (e.g., do not have conditionals on data in tensors) and do not have
        any untracked external dependencies (e.g., perform input/output or
        access global variables). Tracing only records operations done when the given
        function is run on the given tensors. Therefore, the returned
        `ScriptModule` will always run the same traced graph on any input. This
        has some important implications when your module is expected to run
        different sets of operations, depending on the input and/or the module
        state. For example,

        * Tracing will not record any control-flow like if-statements or loops.
          When this control-flow is constant across your module, this is fine
          and it often inlines the control-flow decisions. But sometimes the
          control-flow is actually part of the model itself. For instance, a
          recurrent network is a loop over the (possibly dynamic) length of an
          input sequence.
        * In the returned :class:`ScriptModule`, operations that have different
          behaviors in ``training`` and ``eval`` modes will always behave as if
          it is in the mode it was in during tracing, no matter which mode the
          `ScriptModule` is in.

        In cases like these, tracing would not be appropriate and
        :func:`scripting <torch.jit.script>` is a better choice. If you trace
        such models, you may silently get incorrect results on subsequent
        invocations of the model. The tracer will try to emit warnings when
        doing something that may cause an incorrect trace to be produced.

    Args:
        func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
            that will be run with `example_inputs`. `func` arguments and return
            values must be tensors or (possibly nested) tuples that contain
            tensors. When a module is passed to `torch.jit.trace`, only the
            ``forward`` method is run and traced (see :func:`torch.jit.trace_module
            <torch.jit.trace_module>` for details).

    Keyword arguments:
        example_inputs (tuple or torch.Tensor or None, optional): A tuple of example
            inputs that will be passed to the function while tracing.
            Default: ``None``. Either this argument or ``example_kwarg_inputs``
            should be specified. The resulting trace can be run with inputs of
            different types and shapes assuming the traced operations support those
            types and shapes. `example_inputs` may also be a single Tensor, in which
            case it is automatically wrapped in a tuple. When the value is None,
            ``example_kwarg_inputs`` should be specified.
        check_trace (``bool``, optional): Check if the same inputs run through
            traced code produce the same outputs. Default: ``True``. You might want
            to disable this if, for example, your network contains non-
            deterministic ops or if you are sure that the network is correct despite
            a checker failure.
        check_inputs (list of tuples, optional): A list of tuples of input
            arguments that should be used to check the trace against what is
            expected. Each tuple is equivalent to a set of input arguments that
            would be specified in ``example_inputs``. For best results, pass in
            a set of checking inputs representative of the space of shapes and
            types of inputs you expect the network to see. If not specified,
            the original ``example_inputs`` are used for checking.
        check_tolerance (float, optional): Floating-point comparison tolerance
            to use in the checker procedure. This can be used to relax the
            checker strictness in the event that results diverge numerically
            for a known reason, such as operator fusion.
        strict (``bool``, optional): Whether to run the tracer in strict mode
            (default: ``True``). Only turn this off when you want the tracer to
            record your mutable container types (currently ``list``/``dict``)
            and you are sure that the container you are using in your
            problem is a ``constant`` structure and does not get used as
            control flow (if, for) conditions.
        example_kwarg_inputs (dict, optional): A pack of keyword arguments of
            example inputs that will be passed to the function while
            tracing. Default: ``None``. Either this argument or ``example_inputs``
            should be specified. The dict will be unpacked using the argument
            names of the traced function. If the keys of the dict do not match
            the traced function's argument names, a runtime exception will be raised.

    Returns:
        If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns
        a :class:`ScriptModule` object with a single ``forward`` method
        containing the traced code. The returned `ScriptModule` will
        have the same set of sub-modules and parameters as the original
        ``nn.Module``. If ``func`` is a standalone function, ``trace``
        returns `ScriptFunction`.

    Example (tracing a function):

    .. testcode::

        import torch

        def foo(x, y):
            return 2 * x + y

        # Run `foo` with the provided inputs and record the tensor operations
        traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

        # `traced_foo` can now be run with the TorchScript interpreter or saved
        # and loaded in a Python-free environment

    Example (tracing an existing module)::

        import torch
        import torch.nn as nn

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(self, x):
                return self.conv(x)

        n = Net()
        example_weight = torch.rand(1, 1, 3, 3)
        example_forward_input = torch.rand(1, 1, 3, 3)

        # Trace a specific method and construct `ScriptModule` with
        # a single `forward` method
        module = torch.jit.trace(n.forward, example_forward_input)

        # Trace a module (implicitly traces `forward`) and construct a
        # `ScriptModule` with a single `forward` method
        module = torch.jit.trace(n, example_forward_input)
    """
    if not _enabled:
        return func
    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead"
        )

    if isinstance(func, torch.jit.ScriptModule):
        # it is hard to trace it because the forward method on ScriptModule is already defined, so it
        # would result in an error.
        warnings.warn(
            "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is."
        )
        return func

    if isinstance(func, torch.nn.Module):
        if example_inputs is None:
            if isinstance(example_kwarg_inputs, dict):
                example_inputs = example_kwarg_inputs
            else:
                raise RuntimeError("example_kwarg_inputs should be a dict")
        return trace_module(
            func,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
            example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
            _store_inputs=_store_inputs,
        )
    if (
        hasattr(func, "__self__")
        and isinstance(func.__self__, torch.nn.Module)
        and func.__name__ == "forward"
    ):
        if example_inputs is None:
            if isinstance(example_kwarg_inputs, dict):
                example_inputs = example_kwarg_inputs
            else:
                raise RuntimeError("example_kwarg_inputs should be a dict")
        return trace_module(
            func.__self__,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
            example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
            _store_inputs=_store_inputs,
        )

    # Special case for common case of passing a single Tensor
    if (
        isinstance(example_inputs, (torch.Tensor, dict))
        and example_kwarg_inputs is None
    ):
        example_inputs = (example_inputs,)
    # done primarily so that weird iterables fail here and not pybind11 code
    elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple):
        example_inputs = tuple(example_inputs)

    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module):
        raise AttributeError(
            "trace doesn't support compiling individual module's functions.\n"
            "Please use trace_module"
        )

    name = _qualified_name(func)
    if isinstance(example_kwarg_inputs, dict):
        example_inputs = example_kwarg_inputs
        traced = torch._C._create_function_from_trace_with_dict(
            name,
            func,
            example_kwarg_inputs,
            var_lookup_fn,
            strict,
            _force_outplace,
            get_callable_argument_names(func),
        )
    else:
        traced = torch._C._create_function_from_trace(
            name,
            func,
            example_inputs,
            var_lookup_fn,
            strict,
            _force_outplace,
            get_callable_argument_names(func),
        )

    # Check the trace against new traces created from user-specified inputs
    if check_trace:
        if check_inputs is not None:
            _check_trace(
                check_inputs,
                func,
                traced,
                check_tolerance,
                strict,
                _force_outplace,
                False,
                _module_class,
                example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
            )
        else:
            _check_trace(
                [example_inputs],
                func,
                traced,
                check_tolerance,
                strict,
                _force_outplace,
                False,
                _module_class,
                example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
            )

    # Allow torch.compile() to inline
    traced._torchdynamo_inline = func  # type: ignore[attr-defined]
    return traced
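

# While trace_module runs, maps each (sub)module object to its dotted
# qualified name so the tracer can attach scope names to graph nodes; it is
# restored to its previous value when tracing finishes.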
_trace_module_map: Optional[Dict[Any, Any]] = None


def trace_module(
    mod,
    inputs,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
    example_inputs_is_kwarg=False,
    _store_inputs=True,
):
""" | |
Trace a module and return an executable :class:`ScriptModule` that will be optimized using just-in-time compilation. | |
When a module is passed to :func:`torch.jit.trace <torch.jit.trace>`, only | |
the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of | |
method names to example inputs to trace (see the ``inputs``) argument below. | |
See :func:`torch.jit.trace <torch.jit.trace>` for more information on tracing. | |
Args: | |
mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are | |
specified in ``inputs``. The given methods will be compiled | |
as a part of a single `ScriptModule`. | |
inputs (dict): A dict containing sample inputs indexed by method names in ``mod``. | |
The inputs will be passed to methods whose names correspond to inputs' | |
keys while tracing. | |
``{ 'forward' : example_forward_input, 'method2': example_method2_input}`` | |
Keyword arguments: | |
check_trace (``bool``, optional): Check if the same inputs run through | |
traced code produce the same outputs. Default: ``True``. You might want | |
to disable this if, for example, your network contains non- | |
deterministic ops or if you are sure that the network is correct despite | |
a checker failure. | |
check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used | |
to check the trace against what is expected. Each tuple | |
is equivalent to a set of input arguments that would | |
be specified in ``inputs``. For best results, pass in a | |
set of checking inputs representative of the space of | |
shapes and types of inputs you expect the network to see. | |
If not specified, the original ``inputs`` are used for checking | |
check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. | |
This can be used to relax the checker strictness in the event that | |
results diverge numerically for a known reason, such as operator fusion. | |
example_inputs_is_kwarg (``bool``, optional): This parameter indicate whether the example inputs is a pack | |
pack of keyword arguments. Default: ``False``. | |
Returns: | |
A :class:`ScriptModule` object with a single ``forward`` method containing the traced code. | |
When ``func`` is a ``torch.nn.Module``, the returned :class:`ScriptModule` will have the same set of | |
sub-modules and parameters as ``func``. | |
Example (tracing a module with multiple methods):: | |
import torch | |
import torch.nn as nn | |
class Net(nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.conv = nn.Conv2d(1, 1, 3) | |
def forward(self, x): | |
return self.conv(x) | |
def weighted_kernel_sum(self, weight): | |
return weight * self.conv.weight | |
n = Net() | |
example_weight = torch.rand(1, 1, 3, 3) | |
example_forward_input = torch.rand(1, 1, 3, 3) | |
# Trace a specific method and construct `ScriptModule` with | |
# a single `forward` method | |
module = torch.jit.trace(n.forward, example_forward_input) | |
# Trace a module (implicitly traces `forward`) and construct a | |
# `ScriptModule` with a single `forward` method | |
module = torch.jit.trace(n, example_forward_input) | |
# Trace specific methods on a module (specified in `inputs`), constructs | |
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods | |
inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} | |
module = torch.jit.trace_module(n, inputs) | |
""" | |
    if not _enabled:
        return mod
    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead"
        )

    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    if not isinstance(mod, torch.nn.Module):
        raise AttributeError("expected torch.nn.Module as the first argument")

    if not isinstance(inputs, dict):
        raise AttributeError("expected a dictionary of (method_name, input) pairs")

    old_module_map = torch.jit._trace._trace_module_map
    try:
        trace_module_map: Dict[Any, Any] = {}

        def register_submods(mod, prefix):
            for name, child in mod.named_children():
                submod_qualname = prefix + "." + name
                trace_module_map[child] = submod_qualname
                register_submods(child, submod_qualname)

        trace_module_map["__module"] = mod
        torch.jit._trace._trace_module_map = trace_module_map
        register_submods(mod, "__module")

        module = make_module(mod, _module_class, _compilation_unit)

        for method_name, example_inputs in inputs.items():
            if method_name == "forward":
                # "forward" is a special case because we need to trace
                # `Module.__call__`, which sets up some extra tracing, but uses
                # argument names of the real `Module.forward` method.
                func = mod
                forward_method = getattr(mod, method_name)
                argument_names = get_callable_argument_names(forward_method)
            else:
                func = getattr(mod, method_name)
                argument_names = get_callable_argument_names(func)

            if isinstance(example_inputs, dict) and example_inputs_is_kwarg:
                # Raise an exception if the user-provided key names don't align with the traced method's argument names.
                for key in example_inputs:
                    if key not in argument_names:
                        valid_arguments = "[" + ",".join(argument_names) + "]"
                        raise NameError(
                            f"""'{key}' is not in forward() method's arguments,
                            valid argument names are {valid_arguments}"""
                        )
                module._c._create_method_from_trace_with_dict(
                    method_name,
                    func,
                    example_inputs,
                    var_lookup_fn,
                    strict,
                    _force_outplace,
                    argument_names,
                    _store_inputs,
                )
            else:
                example_inputs = make_tuple(example_inputs)
                module._c._create_method_from_trace(
                    method_name,
                    func,
                    example_inputs,
                    var_lookup_fn,
                    strict,
                    _force_outplace,
                    argument_names,
                    _store_inputs,
                )

            check_trace_method = module._c._get_method(method_name)

            # Check the trace against new traces created from user-specified inputs
            if check_trace:
                if check_inputs is not None:
                    _check_trace(
                        check_inputs,
                        func,
                        check_trace_method,
                        check_tolerance,
                        strict,
                        _force_outplace,
                        True,
                        _module_class,
                        example_inputs_is_kwarg=example_inputs_is_kwarg,
                    )
                else:
                    _check_trace(
                        [inputs],
                        func,
                        check_trace_method,
                        check_tolerance,
                        strict,
                        _force_outplace,
                        True,
                        _module_class,
                        example_inputs_is_kwarg=example_inputs_is_kwarg,
                    )
    finally:
        torch.jit._trace._trace_module_map = old_module_map

    return module


def is_tracing():
    """Return a boolean value.

    Returns ``True`` in tracing (if a function is called during the
    tracing of code with ``torch.jit.trace``) and ``False`` otherwise.
    """
    if is_scripting():
        return False
    return torch._C._is_tracing()
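

# A thin ScriptModule wrapper built around a traced nn.Module: the real
# compiled module lives in `_actual_script_module`, and attribute access is
# forwarded to it. Submodules are wrapped recursively via make_module.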
class TracedModule(ScriptModule):
    _disable_script_meta = True

    def __init__(self, orig, id_set=None, _compilation_unit=None):
        # XXX: orig can be a nn.Module or a function!
        super().__init__()
        assert isinstance(orig, torch.nn.Module)

        # Copy a subset of `orig` to a temporary nn.Module.
        # This is a way to customize what will actually get compiled by create_script_module
        id_set = set()

        # This allows us to preserve the original module's qualified name by defining a new
        # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name
        # we have a special case that will look up this attribute to override whatever qualname
        # we would get from the python type system
        class QualnameWrapper(torch.nn.Module):
            pass

        QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name(  # type: ignore[attr-defined]
            type(orig)
        )

        tmp_module = QualnameWrapper()

        def check_unique(param):
            if param in id_set:
                raise ValueError(
                    "TracedModules don't support parameter sharing between modules"
                )
            id_set.add(param)

        tmp_module.training = orig.training

        for name, param in orig._parameters.items():
            if param is not None:
                tmp_module._parameters[name] = param
                check_unique(param)
        for name, buf in orig._buffers.items():
            if buf is not None:
                tmp_module._buffers[name] = buf
                check_unique(buf)
        for name, val in orig.__dict__.items():
            if (
                torch._C._jit_is_script_object(val)
                and name not in orig._parameters
                and name not in orig._buffers
            ):
                setattr(tmp_module, name, val)

        if orig._backward_hooks:
            raise ValueError(
                "Modules that have backward hooks assigned can't be compiled: "
                + str(orig)
            )

        for name, submodule in orig._modules.items():
            if submodule is None:
                continue
            tmp_module._modules[name] = make_module(
                submodule, TracedModule, _compilation_unit=None
            )

        script_module = torch.jit._recursive.create_script_module(
            tmp_module, lambda module: (), share_types=False, is_tracing=True
        )

        self.__dict__["_name"] = type(orig).__name__
        self.__dict__["_actual_script_module"] = script_module
        for name in ("_parameters", "_buffers", "_modules", "training"):
            delattr(self, name)

    def forward(self, *args, **kwargs):
        raise RuntimeError("Trace submodules cannot be called.")

    def __getattr__(self, attr):
        if "_actual_script_module" not in self.__dict__:
            return super().__getattr__(attr)
        return getattr(self._actual_script_module, attr)

    def __setattr__(self, attr, value):
        if "_actual_script_module" not in self.__dict__:
            return super().__setattr__(attr, value)
        setattr(self._actual_script_module, attr, value)

    def _get_name(self):
        return self._name

    def extra_repr(self):
        return f"original_name={self._name}"


class TopLevelTracedModule(TracedModule):
    forward: Callable[..., Any] = _CachedForward()  # type: ignore[assignment]

    def _reconstruct(self, cpp_module):
        """
        Re-construct an instance of TopLevelTracedModule using an instance of a C++ module.

        Args:
            cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
        """
        self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
def _script_if_tracing(fn: Callable[P, R]) -> Callable[P, R]:
    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        if not is_tracing():
            # Not tracing, don't do anything
            return fn(*args, **kwargs)

        compiled_fn: Callable[P, R] = script(wrapper.__original_fn)  # type: ignore[attr-defined]
        return compiled_fn(*args, **kwargs)

    wrapper.__original_fn = fn  # type: ignore[attr-defined]
    wrapper.__script_if_tracing_wrapper = True  # type: ignore[attr-defined]

    return wrapper


def _get_trace_graph(
    f,
    args=(),
    kwargs=None,
    strict=True,
    _force_outplace=False,
    return_inputs=False,
    _return_inputs_states=False,
):
    """Return a tuple on tracing a function or model.

    .. warning::
        This function is internal-only and should only be used by the ONNX
        exporter. If you are trying to get a graph through tracing, please go
        through the public API instead::

            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
            trace_graph = trace.graph

    Trace a function or model, returning a tuple consisting of both the
    *trace* of an execution and the original return value. If ``return_inputs``
    is set, the trace inputs are also returned as part of the tuple.

    Tracing is guaranteed not to change the semantics of the function/module
    that is traced.

    Args:
        f (torch.nn.Module or function): the function or module
            to be traced.
        args (tuple or Tensor): the positional arguments to pass to the
            function/module to be traced. A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        kwargs (dict): the keyword arguments to pass to the function/module
            to be traced.

    Example (trace a cell):

    .. testcode::

        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
    """
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    outs = ONNXTracedModule(
        f, strict, _force_outplace, return_inputs, _return_inputs_states
    )(*args, **kwargs)
    return outs