"""Freezing.



This is not intended to be imported directly; please use the exposed

functionalities in `torch.jit`.

"""

from typing import List, Optional

import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule


def freeze(
    mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True
):
    r"""Freeze ScriptModule, inline submodules, and attributes as constants.



    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
    By default, `forward` will be preserved, as well as attributes and methods specified in
    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
    method will be preserved.

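    The preservation rule above can be pictured with a small pure-Python sketch. It is
    only an illustration under stated assumptions: `partition_attrs` is a hypothetical
    helper, plain values stand in for tensors, and the set of mutated attribute names is
    passed in directly, whereas freezing detects mutation itself.

```python
from typing import Dict, Iterable, Set, Tuple


def partition_attrs(
    attrs: Dict[str, object],
    preserved_attrs: Iterable[str],
    mutated_in_preserved_methods: Set[str],
) -> Tuple[Dict[str, object], Set[str]]:
    # Hypothetical helper: attributes that are explicitly preserved or mutated
    # by a preserved method stay live; everything else can become a constant.
    keep = set(preserved_attrs) | mutated_in_preserved_methods
    constants = {name: value for name, value in attrs.items() if name not in keep}
    preserved = {name for name in attrs if name in keep}
    return constants, preserved


# Loosely modeled on MyModule2 below, plus a hypothetical read-only `scale`.
attrs = {"modified_tensor": 10.0, "version": 1, "scale": 2.0}
constants, preserved = partition_attrs(attrs, ["version"], {"modified_tensor"})
print(constants)  # {'scale': 2.0}
print(sorted(preserved))  # ['modified_tensor', 'version']
```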

    Freezing currently only accepts ScriptModules that are in eval mode.

    Freezing applies generic optimizations that should speed up your model regardless of
    the machine it runs on. To further optimize using server-specific settings, run
    `optimize_for_inference` after freezing.

    Args:
        mod (:class:`ScriptModule`): a module to be frozen
        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
            Attributes modified in preserved methods will also be preserved.
        optimize_numerics (bool): If ``True``, a set of optimization passes that do not strictly
            preserve numerics will be run. Full details of these optimizations can be found at
            `torch.jit.run_frozen_optimizations`.

    Returns:
        Frozen :class:`ScriptModule`.

    Example (Freezing a simple module with a Parameter):

    .. testcode::

        import torch
        class MyModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))
                self.linear = torch.nn.Linear(N, M)

            def forward(self, input):
                output = self.weight.mm(input)
                output = self.linear(output)
                return output

        scripted_module = torch.jit.script(MyModule(2, 3).eval())
        frozen_module = torch.jit.freeze(scripted_module)
        # parameters have been removed and inlined into the Graph as constants
        assert len(list(frozen_module.named_parameters())) == 0
        # See the compiled graph as Python code
        print(frozen_module.code)

    Example (Freezing a module with preserved attributes):

    .. testcode::

        import torch
        class MyModule2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.modified_tensor = torch.tensor(10.)
                self.version = 1

            def forward(self, input):
                self.modified_tensor += 1
                return input + self.modified_tensor

        scripted_module = torch.jit.script(MyModule2().eval())
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
        assert frozen_module.version == 1
        frozen_module.version = 2
        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
        # it to retain model semantics
        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
        # now that we've run it once, the next result will be incremented by one
        assert frozen_module(torch.tensor(1)) == torch.tensor(13)

    Note:
        Preserving submodule attributes is also supported, using a dotted name::

            frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["submodule.version"])

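    A dotted entry in `preserved_attrs` addresses an attribute on a nested submodule. The
    sketch below only illustrates that naming convention with plain Python objects; the
    helper `resolve_dotted` is hypothetical and does not reflect how freezing parses the
    names internally.

```python
from functools import reduce
from types import SimpleNamespace


def resolve_dotted(root: object, path: str) -> object:
    # Hypothetical helper: follow "submodule.version" one attribute at a time.
    return reduce(getattr, path.split("."), root)


# Plain objects standing in for a scripted module that owns a submodule.
root = SimpleNamespace(submodule=SimpleNamespace(version=3))
print(resolve_dotted(root, "submodule.version"))  # 3
```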

    Note:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on `frozen_module.forward.graph` to see whether freezing has detected
        that the attribute is being modified.

    Note:
        Because freezing makes weights constants and removes the module hierarchy, `to` and
        other nn.Module methods that manipulate device or dtype no longer work. As a workaround,
        you can remap devices by specifying `map_location` in `torch.jit.load`; however,
        device-specific logic may have been baked into the model.
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "Freezing expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if mod.training:
        raise RuntimeError(
            "Freezing is currently only implemented for modules in eval mode. "
            "Please call .eval() on your module before freezing."
        )

    preserved_attrs = preserved_attrs if preserved_attrs is not None else []

    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
    RecursiveScriptModule._finalize_scriptmodule(out)

    preserved_methods = [x for x in preserved_attrs if mod._c._has_method(x)]
    run_frozen_optimizations(out, optimize_numerics, preserved_methods)

    return out


def run_frozen_optimizations(
    mod, optimize_numerics: bool = True, preserved_methods: Optional[List[str]] = None
):
    r"""

    Run a series of optimizations looking for patterns that occur in frozen graphs.



    The current set of optimizations includes:
        - Dropout Removal
        - Pretranspose Linear Layers
        - Concat Linear Layers with same input Tensor
        - Conv -> Batchnorm folding
        - Conv -> Add/Sub folding
        - Conv -> Mul/Div folding

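    To make the "Concat Linear Layers with same input Tensor" item concrete, here is a
    pure-Python sketch that uses lists in place of tensors: two linear layers reading the
    same input are replaced by one layer with stacked weight rows and biases, and the fused
    output is sliced back into per-layer results. `linear` and `concat_linear` are
    hypothetical helpers that model the idea, not the TorchScript pass itself.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]
Layer = Tuple[Matrix, Vector]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # y[i] = sum_j weight[i][j] * x[j] + bias[i]
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weight, bias)]


def concat_linear(x: Vector, layers: List[Layer]) -> List[Vector]:
    # Stack every layer's weight rows and biases, run one linear op on the shared
    # input, then slice the fused output back apart, one slice per layer.
    weight = [row for w, _ in layers for row in w]
    bias = [b for _, bs in layers for b in bs]
    fused = linear(x, weight, bias)
    outputs, start = [], 0
    for w, _ in layers:
        outputs.append(fused[start:start + len(w)])
        start += len(w)
    return outputs


x = [1.0, 2.0]
layer_a = ([[1.0, 0.0], [0.0, 1.0]], [0.5, -0.5])
layer_b = ([[2.0, 1.0]], [1.0])
separate = [linear(x, w, b) for w, b in (layer_a, layer_b)]
print(separate == concat_linear(x, [layer_a, layer_b]))  # True
```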

    Args:
        mod (:class:`ScriptModule`): a frozen module to be optimized
        optimize_numerics (bool): If ``True``, a set of optimization passes that do not strictly
            preserve numerics will be run. These optimizations preserve the default rtol and atol
            of `torch.testing.assert_close` when applied to a single transformation; however, in a
            module where many transformations are applied, the differences may no longer fall
            within the default `assert_close` tolerances. Conv -> Batchnorm folding,
            Conv -> Add/Sub folding, and Conv -> Mul/Div folding may all alter numerics.
        preserved_methods (Optional[List[str]]): names of methods other than `forward` whose
            graphs will also be optimized.

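    The Conv -> Batchnorm case shows why these passes are "numerics-altering": the folded
    and unfolded computations are algebraically identical, but floating-point rounding
    happens in a different order. The sketch below uses a single-channel 1-D convolution on
    Python floats; `conv1d`, `batch_norm`, and `fold_conv_bn` are hypothetical helpers, and
    the tolerance check mimics the float32 defaults of `torch.testing.assert_close`
    (rtol=1.3e-6, atol=1e-5) rather than calling it.

```python
import math
from typing import List, Tuple


def conv1d(x: List[float], w: List[float], b: float) -> List[float]:
    # Single-channel "valid" 1-D convolution (cross-correlation, no padding, stride 1).
    k = len(w)
    return [
        sum(wi * xi for wi, xi in zip(w, x[i:i + k])) + b
        for i in range(len(x) - k + 1)
    ]


def batch_norm(
    y: List[float], mean: float, var: float, gamma: float, beta: float, eps: float
) -> List[float]:
    # Eval-mode batch norm: uses fixed running statistics.
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in y]


def fold_conv_bn(
    w: List[float], b: float, mean: float, var: float, gamma: float, beta: float, eps: float
) -> Tuple[List[float], float]:
    # Fold the batch norm into the conv: scale the kernel, then rewrite the bias.
    scale = gamma / math.sqrt(var + eps)
    return [wi * scale for wi in w], (b - mean) * scale + beta


x = [0.3, -1.2, 2.5, 0.7, -0.4]
w, b = [0.5, -0.25, 1.5], 0.1
stats = dict(mean=0.2, var=1.7, gamma=0.9, beta=-0.3, eps=1e-3)

reference = batch_norm(conv1d(x, w, b), **stats)
w_folded, b_folded = fold_conv_bn(w, b, **stats)
folded = conv1d(x, w_folded, b_folded)

# assert_close-style check: |actual - expected| <= atol + rtol * |expected|
rtol, atol = 1.3e-6, 1e-5
ok = all(abs(f - r) <= atol + rtol * abs(r) for f, r in zip(folded, reference))
print(ok)  # True
```

    On Python floats (float64) the gap here is tiny; the concern above is float32 tensors,
    where small per-transformation differences can add up across many folded layers.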

    Returns:
        None

    Note:
        On rare occasions, this can result in slower execution.

    Example (Freezing a module with Conv->Batchnorm):

    .. code-block:: python

        import torch
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize_numerics to False here; by default, freezing runs
        # run_frozen_optimizations with the numerics-altering passes enabled
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
        # inspect frozen mod
        assert "batch_norm" in str(frozen_mod.graph)
        torch.jit.run_frozen_optimizations(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)

    """
    if mod._c._has_method("forward"):
        torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)

    if preserved_methods is None:
        preserved_methods = []

    for method in preserved_methods:
        torch._C._jit_pass_optimize_frozen_graph(
            mod.__getattr__(method).graph, optimize_numerics
        )


def optimize_for_inference(
    mod: ScriptModule, other_methods: Optional[List[str]] = None
) -> ScriptModule:
    """

    Perform a set of optimization passes to optimize a model for the purposes of inference.



    If the model is not already frozen, `optimize_for_inference`
    will invoke `torch.jit.freeze` automatically.

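    The "not already frozen" check in the code below is `hasattr(mod, "training")`, which
    appears to rely on freezing inlining the `training` flag away. The sketch below mirrors
    that check with two hypothetical stand-in classes instead of real ScriptModules.

```python
class NotFrozen:
    # Hypothetical stand-in for a scripted module that still has a `training` flag.
    training = False


class AlreadyFrozen:
    # Hypothetical stand-in for a frozen module, where `training` no longer exists.
    pass


def needs_freezing(mod: object) -> bool:
    # Same test as optimize_for_inference: a surviving `training` attribute
    # is taken to mean the module has not been frozen yet.
    return hasattr(mod, "training")


print(needs_freezing(NotFrozen()), needs_freezing(AlreadyFrozen()))  # True False
```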

    In addition to generic optimizations that should speed up your model regardless
    of environment, `optimize_for_inference` will also bake in build-specific settings
    such as the presence of CUDNN or MKLDNN, and may in the future make transformations
    that speed things up on one machine but slow things down on another. Accordingly,
    serialization of a model after invoking `optimize_for_inference` is not implemented
    and is not guaranteed to work.

    This is still a prototype feature and has the potential to slow down your model.
    The primary use cases targeted so far have been vision models on CPU and, to a
    lesser extent, GPU.

    Example (optimizing a module with Conv->Batchnorm)::

        import torch
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        frozen_mod = torch.jit.optimize_for_inference(torch.jit.script(mod.eval()))
        assert "batch_norm" not in str(frozen_mod.graph)
        # if built with MKLDNN, convolution will be run with MKLDNN weights
        assert "MKLDNN" in str(frozen_mod.graph)

    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "optimize_for_inference expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if other_methods is None:
        other_methods = []

    if hasattr(mod, "training"):
        mod = freeze(mod.eval(), preserved_attrs=other_methods)

    torch._C._jit_pass_optimize_for_inference(mod._c, other_methods)

    return mod