"""Freezing. | |
This is not intended to be imported directly; please use the exposed | |
functionalities in `torch.jit`. | |
""" | |
from typing import List, Optional | |
import torch | |
from torch.jit._script import RecursiveScriptModule, ScriptModule | |
def freeze(
    mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True
):
r"""Freeze ScriptModule, inline submodules, and attributes as constants. | |
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned | |
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph. | |
By default, `forward` will be preserved, as well as attributes & methods specified in | |
`preserved_attrs`. Additionally, any attribute that is modified within a preserved | |
method will be preserved. | |
Freezing currently only accepts ScriptModules that are in eval mode. | |
Freezing applies generic optimization that will speed up your model regardless of machine. | |
To further optimize using server-specific settings, run `optimize_for_inference` after | |
freezing. | |
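
    For example, a minimal sketch of the two-step flow (assuming ``scripted_module``
    is a scripted module already in eval mode)::

        frozen = torch.jit.freeze(scripted_module)
        # optimize_for_inference would also freeze automatically if given an
        # unfrozen module; here it only applies the build-specific passes
        optimized = torch.jit.optimize_for_inference(frozen)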

    Args:
        mod (:class:`ScriptModule`): a module to be frozen
        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
            Attributes modified in preserved methods will also be preserved.
        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
            preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`.

    Returns:
        Frozen :class:`ScriptModule`.

    Example (Freezing a simple module with a Parameter):

    .. testcode::

        import torch

        class MyModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))
                self.linear = torch.nn.Linear(N, M)

            def forward(self, input):
                output = self.weight.mm(input)
                output = self.linear(output)
                return output

        scripted_module = torch.jit.script(MyModule(2, 3).eval())
        frozen_module = torch.jit.freeze(scripted_module)
        # parameters have been removed and inlined into the Graph as constants
        assert len(list(frozen_module.named_parameters())) == 0
        # See the compiled graph as Python code
        print(frozen_module.code)

    Example (Freezing a module with preserved attributes):

    .. testcode::

        import torch

        class MyModule2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.modified_tensor = torch.tensor(10.)
                self.version = 1

            def forward(self, input):
                self.modified_tensor += 1
                return input + self.modified_tensor

        scripted_module = torch.jit.script(MyModule2().eval())
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
        assert frozen_module.version == 1
        frozen_module.version = 2
        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
        # it to retain model semantics
        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
        # now that we've run it once, the next result will be incremented by one
        assert frozen_module(torch.tensor(1)) == torch.tensor(13)

    Note:
        Freezing submodule attributes is also supported:
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])

    Note:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
        attribute is being modified.
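
        A minimal sketch of that check (``dump_alias_db`` is an internal debugging
        helper, so its exact availability may vary across builds)::

            # print the alias analysis that freezing used to decide which
            # attributes are mutated (hypothetical availability)
            frozen_module.forward.graph.dump_alias_db()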

    Note:
        Because freezing makes weights constants and removes module hierarchy, `to` and other
        nn.Module methods to manipulate device or dtype no longer work. As a workaround,
        you can remap devices by specifying `map_location` in `torch.jit.load`; however,
        device-specific logic may have been baked into the model.
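
        For example, a sketch of that workaround (assuming the frozen module has
        been saved with ``torch.jit.save``)::

            torch.jit.save(frozen_module, "frozen.pt")
            # remap all tensors to CPU when loading
            loaded = torch.jit.load("frozen.pt", map_location="cpu")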
""" | |
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "Freezing expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if mod.training:
        raise RuntimeError(
            "Freezing is currently only implemented for modules in eval mode. "
            "Please call .eval() on your module before freezing."
        )

    preserved_attrs = preserved_attrs if preserved_attrs is not None else []

    # Clone the module and inline submodules, parameters, and attributes as
    # constants in the cloned module's graph.
    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
    RecursiveScriptModule._finalize_scriptmodule(out)

    # Entries in `preserved_attrs` that name methods (rather than attributes)
    # also need their graphs optimized.
    preserved_methods = [x for x in preserved_attrs if mod._c._has_method(x)]
    run_frozen_optimizations(out, optimize_numerics, preserved_methods)

    return out

def run_frozen_optimizations(
    mod, optimize_numerics: bool = True, preserved_methods: Optional[List[str]] = None
):
r""" | |
Run a series of optimizations looking for patterns that occur in frozen graphs. | |
The current set of optimizations includes: | |
- Dropout Removal | |
- Pretranspose Linear Layers | |
- Concat Linear Layers with same input Tensor | |
- Conv -> Batchnorm folding | |
- Conv -> Add/Sub folding | |
- Conv -> Mul/Div folding | |

    Args:
        mod (:class:`ScriptModule`): a frozen module to be optimized
        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
            preserve numerics. These optimizations preserve the default rtol and atol of `torch.testing.assert_close`
            when applied on a single transformation; however, in a module where many transformations are applied,
            the rtol or atol may no longer fall within the default `assert_close` tolerance. Conv -> Batchnorm,
            Conv -> Add/Sub, and Conv -> Mul/Div folding may all alter numerics.

    Returns:
        None

    Note:
        On rare occasions, this can result in slower execution.

    Example (Freezing a module with Conv->Batchnorm):

    .. code-block:: python

        import torch
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # freeze with optimize_numerics=False so the numerics-altering passes
        # (including Conv -> Batchnorm folding) are not run automatically;
        # we run the full optimization pass manually below
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
        # inspect frozen mod
        assert "batch_norm" in str(frozen_mod.graph)
        torch.jit.run_frozen_optimizations(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)
""" | |
    if mod._c._has_method("forward"):
        torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)

    if preserved_methods is None:
        preserved_methods = []

    # Any additionally preserved methods get the same optimization pass
    # applied to their graphs.
    for method in preserved_methods:
        torch._C._jit_pass_optimize_frozen_graph(
            mod.__getattr__(method).graph, optimize_numerics
        )

def optimize_for_inference(
    mod: ScriptModule, other_methods: Optional[List[str]] = None
) -> ScriptModule:
""" | |
Perform a set of optimization passes to optimize a model for the purposes of inference. | |
If the model is not already frozen, optimize_for_inference | |
will invoke `torch.jit.freeze` automatically. | |
In addition to generic optimizations that should speed up your model regardless | |
of environment, prepare for inference will also bake in build specific settings | |
such as the presence of CUDNN or MKLDNN, and may in the future make transformations | |
which speed things up on one machine but slow things down on another. Accordingly, | |
serialization is not implemented following invoking `optimize_for_inference` and | |
is not guaranteed. | |
This is still in prototype, and may have the potential to slow down your model. | |
Primary use cases that have been targeted so far have been vision models on cpu | |
and gpu to a lesser extent. | |

    Example (optimizing a module with Conv->Batchnorm)::

        import torch
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
        assert "batch_norm" not in str(frozen_mod.graph)
        # if built with MKLDNN, convolution will be run with MKLDNN weights
        assert "MKLDNN" in str(frozen_mod.graph)
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "optimize_for_inference expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if other_methods is None:
        other_methods = []

    # A frozen module no longer carries a `training` attribute, so its presence
    # means the module has not been frozen yet.
    if hasattr(mod, "training"):
        mod = freeze(mod.eval(), preserved_attrs=other_methods)

    torch._C._jit_pass_optimize_for_inference(mod._c, other_methods)

    return mod