File size: 6,768 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import cmath
import math
import warnings

from collections import OrderedDict
from typing import Dict, Optional

import torch
import torch.backends.cudnn as cudnn

from ..nn.modules.utils import _list_with_default, _pair, _quadruple, _single, _triple

# Cache mapping id(callable) -> "aten::..." op name. It stays None until
# _get_builtin_table() builds it on first use.
_builtin_table: Optional[Dict[int, str]] = None

# Modules whose callable attributes _get_builtin_table() registers as
# builtins under the name "aten::<attribute name>".
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special)  # type: ignore[attr-defined] # noqa: B950

_builtin_ops = [
    # Pairs of (function, op_name)
    (_pair, "aten::_pair"),
    (_quadruple, "aten::_quadruple"),
    (_single, "aten::_single"),
    (_triple, "aten::_triple"),
    (_list_with_default, "aten::list_with_default"),
    (OrderedDict, "aten::dict"),
    (dict, "aten::dict"),
    (cudnn.is_acceptable, "aten::cudnn_is_acceptable"),
    (math.ceil, "aten::ceil"),
    (math.copysign, "aten::copysign"),
    (math.erf, "aten::erf"),
    (math.erfc, "aten::erfc"),
    (math.exp, "aten::exp"),
    (math.expm1, "aten::expm1"),
    (math.fabs, "aten::fabs"),
    (math.floor, "aten::floor"),
    (math.gamma, "aten::gamma"),
    (math.lgamma, "aten::lgamma"),
    (math.log, "aten::log"),
    (math.log10, "aten::log10"),
    (math.log1p, "aten::log1p"),
    (math.pow, "aten::pow"),
    (math.sqrt, "aten::sqrt"),
    (math.isnan, "aten::isnan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.cosh, "aten::cosh"),
    (math.sinh, "aten::sinh"),
    (math.tanh, "aten::tanh"),
    (math.acos, "aten::acos"),
    (math.asin, "aten::asin"),
    (math.atan, "aten::atan"),
    (math.atan2, "aten::atan2"),
    (math.cos, "aten::cos"),
    (math.sin, "aten::sin"),
    (math.tan, "aten::tan"),
    (math.asinh, "aten::asinh"),
    (math.atanh, "aten::atanh"),
    (math.acosh, "aten::acosh"),
    (math.fmod, "aten::fmod"),
    (math.modf, "aten::modf"),
    (math.factorial, "aten::factorial"),
    (math.frexp, "aten::frexp"),
    (math.isinf, "aten::isinf"),
    (math.degrees, "aten::degrees"),
    (math.radians, "aten::radians"),
    (cmath.isnan, "aten::isnan"),
    (cmath.isfinite, "aten::isfinite"),
    (cmath.isinf, "aten::isinf"),
    (cmath.phase, "aten::angle"),
    (cmath.rect, "aten::polar"),
    (cmath.log, "aten::log"),
    (cmath.log10, "aten::log10"),
    (cmath.sqrt, "aten::sqrt"),
    (cmath.exp, "aten::exp"),
    (cmath.sin, "aten::sin"),
    (cmath.tan, "aten::tan"),
    (cmath.cos, "aten::cos"),
    (cmath.asin, "aten::asin"),
    (cmath.acos, "aten::acos"),
    (cmath.atan, "aten::atan"),
    (cmath.sinh, "aten::sinh"),
    (cmath.cosh, "aten::cosh"),
    (cmath.tanh, "aten::tanh"),
    (cmath.asinh, "aten::asinh"),
    (cmath.acosh, "aten::acosh"),
    (cmath.atanh, "aten::atanh"),
    (math.ldexp, "aten::ldexp"),
    (torch._assert, "aten::_assert"),
    (torch.autograd.grad, "aten::grad"),
    (torch.autograd.backward, "aten::backward"),
    (torch._C._infer_size, "aten::_infer_size"),
    (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"),  # type: ignore[attr-defined]
    (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"),
    (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"),
    (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
    (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
    (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
    (torch._C._get_tracing_state, "aten::_get_tracing_state"),
    (torch._C._get_cpu_capability, "aten::_get_cpu_capability"),
    (warnings.warn, "aten::warn"),
    (torch._VF.stft, "aten::stft"),  # type: ignore[attr-defined]
    (torch._VF.istft, "aten::istft"),  # type: ignore[attr-defined]
    (torch._VF.cdist, "aten::cdist"),  # type: ignore[attr-defined]
    (torch._VF.norm, "aten::norm"),  # type: ignore[attr-defined]
    (torch._VF.unique_dim, "aten::unique_dim"),
    (torch._VF.unique_consecutive, "aten::unique_consecutive"),  # type: ignore[attr-defined]
    (torch._VF.nuclear_norm, "aten::nuclear_norm"),
    (torch._VF.frobenius_norm, "aten::frobenius_norm"),
    (torch._VF.tensordot, "aten::tensordot"),  # type: ignore[attr-defined]
]

# Ops in torch.functional are bound to torch.
# In these cases, we want to resolve the function to its Python implementation
# instead of looking up a builtin "aten::" schema.


def _gen_torch_functional_registered_ops():
    """Collect the torch.functional callables named in the list below.

    Eventually the list should cover all of torch/functional.py
    (torch.functional.__all__), but only some of those functions can be
    compiled today, and some map directly to their aten:: implementations.

    Returns:
        A set holding the attribute of ``torch.functional`` for each listed name.
    """
    # TODO: add support for more ops
    supported_names = (
        "stft",
        "istft",
        "lu",
        "cdist",
        "norm",
        "unique",
        "unique_consecutive",
        "tensordot",
    )
    registered = set()
    for op_name in supported_names:
        registered.add(getattr(torch.functional, op_name))
    return registered


# Computed once when this module is imported; consulted by
# _is_special_functional_bound_op().
_functional_registered_ops = _gen_torch_functional_registered_ops()


def _is_special_functional_bound_op(fn):
    """Return whether ``fn`` is a member of ``_functional_registered_ops``."""
    registered = _functional_registered_ops
    return fn in registered


# lazily built to ensure the correct initialization order
def _get_builtin_table():
    """Return the ``id(callable) -> op name`` table, building it on first call.

    On the first call this extends ``_builtin_ops`` with every eligible
    callable found in ``_modules_containing_builtins``, a few extra ``math``
    functions and, when available, the distributed autograd entry points.
    It then folds ``_builtin_ops`` into the module-level ``_builtin_table``.
    Later calls return the cached dict unchanged.
    """
    global _builtin_table
    if _builtin_table is not None:
        return _builtin_table

    # Publish the (still empty) dict before populating it, exactly as the
    # cache variable is expected to be set while the table is being built.
    table: Dict[int, str] = {}
    _builtin_table = table

    for module in _modules_containing_builtins:
        for attr_name in dir(module):
            candidate = getattr(module, attr_name)
            if not callable(candidate):
                continue
            if _is_special_functional_bound_op(candidate):
                continue
            if candidate is torch.no_grad or candidate is torch.autocast:
                continue
            # Fixup inconsistency in segment_reduce: drop the leading underscore
            op_suffix = attr_name[1:] if attr_name == "_segment_reduce" else attr_name
            _builtin_ops.append((candidate, "aten::" + op_suffix))

    _builtin_ops.extend(
        (
            (math.gcd, "aten::gcd"),
            (math.isfinite, "aten::isfinite"),
            (math.remainder, "aten::mathremainder"),  # type: ignore[attr-defined]
        )
    )

    import torch.distributed.autograd as dist_autograd

    if dist_autograd.is_available():
        _builtin_ops.extend(
            (
                (dist_autograd.get_gradients, "aten::get_gradients"),
                (dist_autograd.backward, "aten::dist_backward"),
            )
        )

    # populate the table from _builtin_ops; for a repeated callable the
    # last pair in the list wins
    table.update((id(fn), op_name) for fn, op_name in _builtin_ops)

    return table


def _register_builtin(fn, op):
    """Map ``id(fn)`` to ``op`` in the builtin table, replacing any existing entry."""
    table = _get_builtin_table()
    table[id(fn)] = op


def _find_builtin(fn):
    """Return the op name stored for ``id(fn)`` in the builtin table, or None if absent."""
    table = _get_builtin_table()
    return table.get(id(fn))