Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
# Pointwise operators can go through a faster pathway.
# NOTE(review): this list is not referenced in the visible part of the file,
# and its second entry is an empty string, which looks like a placeholder.
# Kept unchanged for compatibility; confirm the intended contents with users.
tensor_magic_methods = ["add", ""]
# Binary operator stems (magic-method names without the surrounding double
# underscores) that also have a reflected form, e.g. "add" stands for both
# ``__add__`` and ``__radd__``.
pointwise_magic_methods_with_reverse = (
    "add",
    "sub",
    "mul",
    "floordiv",
    "div",
    "truediv",
    "mod",
    "pow",
    "lshift",
    "rshift",
    "and",
    "or",
    "xor",
)

# Every magic-method stem treated as pointwise: each binary stem above
# immediately followed by its reflected "r"-prefixed variant, then
# comparisons, unary operators, in-place operators and scalar conversions.
# Each stem appears exactly once.
pointwise_magic_methods = (
    *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
    # comparisons
    "eq",
    "gt",
    "le",
    "lt",
    "ge",
    "ne",
    # unary operators
    "neg",
    "pos",
    "abs",
    "invert",
    # in-place operators
    "iadd",
    "isub",
    "imul",
    "ifloordiv",
    "idiv",
    "itruediv",
    "imod",
    "ipow",
    "ilshift",
    "irshift",
    "iand",
    "ior",
    "ixor",
    # scalar conversions
    "int",
    "long",
    "float",
    "complex",
)

# Full dunder names for the stems above, e.g. "__add__", "__radd__", "__eq__".
pointwise_methods = tuple(f"__{m}__" for m in pointwise_magic_methods)
# Callables recognized by the fast pointwise pathway. The tuple begins with the
# ``torch.Tensor`` dunder methods named in ``pointwise_methods``, followed by
# explicit ``torch``, ``torch.Tensor`` and ``torch.nn.functional`` callables.
# Each explicit callable is listed once; presumably consumers only test
# membership (``op in pointwise``) -- TODO confirm against callers.
pointwise = (
    *(getattr(torch.Tensor, m) for m in pointwise_methods),
    torch.nn.functional.dropout,
    torch.where,
    torch.Tensor.abs,
    torch.abs,
    torch.Tensor.acos,
    torch.acos,
    torch.Tensor.acosh,
    torch.acosh,
    torch.Tensor.add,
    torch.add,
    torch.Tensor.addcdiv,
    torch.addcdiv,
    torch.Tensor.addcmul,
    torch.addcmul,
    # NOTE(review): addr computes input + outer(vec1, vec2) with 1-D vec1/vec2,
    # so it is not elementwise; confirm the fast path handles it correctly.
    torch.Tensor.addr,
    torch.addr,
    torch.Tensor.angle,
    torch.angle,
    torch.Tensor.asin,
    torch.asin,
    torch.Tensor.asinh,
    torch.asinh,
    torch.Tensor.atan,
    torch.atan,
    torch.Tensor.atan2,
    torch.atan2,
    torch.Tensor.atanh,
    torch.atanh,
    torch.Tensor.bitwise_and,
    torch.bitwise_and,
    torch.Tensor.bitwise_left_shift,
    torch.bitwise_left_shift,
    torch.Tensor.bitwise_not,
    torch.bitwise_not,
    torch.Tensor.bitwise_or,
    torch.bitwise_or,
    torch.Tensor.bitwise_right_shift,
    torch.bitwise_right_shift,
    torch.Tensor.bitwise_xor,
    torch.bitwise_xor,
    torch.Tensor.ceil,
    torch.ceil,
    torch.celu,
    torch.nn.functional.celu,
    torch.Tensor.clamp,
    torch.clamp,
    torch.Tensor.clamp_max,
    torch.clamp_max,
    torch.Tensor.clamp_min,
    torch.clamp_min,
    torch.Tensor.copysign,
    torch.copysign,
    torch.Tensor.cos,
    torch.cos,
    torch.Tensor.cosh,
    torch.cosh,
    torch.Tensor.deg2rad,
    torch.deg2rad,
    torch.Tensor.digamma,
    torch.digamma,
    torch.Tensor.div,
    torch.div,
    torch.dropout,
    torch.nn.functional.elu,
    torch.Tensor.eq,
    torch.eq,
    torch.Tensor.erf,
    torch.erf,
    torch.Tensor.erfc,
    torch.erfc,
    torch.Tensor.erfinv,
    torch.erfinv,
    torch.Tensor.exp,
    torch.exp,
    torch.Tensor.exp2,
    torch.exp2,
    torch.Tensor.expm1,
    torch.expm1,
    torch.feature_dropout,
    torch.Tensor.float_power,
    torch.float_power,
    torch.Tensor.floor,
    torch.floor,
    torch.Tensor.floor_divide,
    torch.floor_divide,
    torch.Tensor.fmod,
    torch.fmod,
    torch.Tensor.frac,
    torch.frac,
    torch.Tensor.frexp,
    torch.frexp,
    torch.Tensor.gcd,
    torch.gcd,
    torch.Tensor.ge,
    torch.ge,
    torch.nn.functional.gelu,
    # NOTE(review): glu splits its input in half along ``dim``, so its output
    # shape differs from the input; confirm the fast path handles it correctly.
    torch.nn.functional.glu,
    torch.Tensor.gt,
    torch.gt,
    torch.Tensor.hardshrink,
    torch.hardshrink,
    torch.nn.functional.hardshrink,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.Tensor.heaviside,
    torch.heaviside,
    torch.Tensor.hypot,
    torch.hypot,
    torch.Tensor.i0,
    torch.i0,
    torch.Tensor.igamma,
    torch.igamma,
    torch.Tensor.igammac,
    torch.igammac,
    torch.Tensor.isclose,
    torch.isclose,
    torch.Tensor.isfinite,
    torch.isfinite,
    torch.Tensor.isinf,
    torch.isinf,
    torch.Tensor.isnan,
    torch.isnan,
    torch.Tensor.isneginf,
    torch.isneginf,
    torch.Tensor.isposinf,
    torch.isposinf,
    torch.Tensor.isreal,
    torch.isreal,
    # NOTE(review): kron is a Kronecker product whose output shape is not the
    # broadcast of its inputs; confirm it belongs in the pointwise fast path.
    torch.Tensor.kron,
    torch.kron,
    torch.Tensor.lcm,
    torch.lcm,
    torch.Tensor.ldexp,
    torch.ldexp,
    torch.Tensor.le,
    torch.le,
    torch.nn.functional.leaky_relu,
    torch.Tensor.lerp,
    torch.lerp,
    torch.Tensor.lgamma,
    torch.lgamma,
    torch.Tensor.log,
    torch.log,
    torch.Tensor.log10,
    torch.log10,
    torch.Tensor.log1p,
    torch.log1p,
    torch.Tensor.log2,
    torch.log2,
    torch.nn.functional.logsigmoid,
    torch.Tensor.logical_and,
    torch.logical_and,
    torch.Tensor.logical_not,
    torch.logical_not,
    torch.Tensor.logical_or,
    torch.logical_or,
    torch.Tensor.logical_xor,
    torch.logical_xor,
    torch.Tensor.logit,
    torch.logit,
    torch.Tensor.lt,
    torch.lt,
    torch.Tensor.maximum,
    torch.maximum,
    torch.Tensor.minimum,
    torch.minimum,
    torch.nn.functional.mish,
    torch.Tensor.mvlgamma,
    torch.mvlgamma,
    torch.Tensor.nan_to_num,
    torch.nan_to_num,
    torch.Tensor.ne,
    torch.ne,
    torch.Tensor.neg,
    torch.neg,
    torch.Tensor.nextafter,
    torch.nextafter,
    # NOTE(review): outer takes two 1-D vectors and returns their outer
    # product, so it is not elementwise; confirm it belongs here.
    torch.Tensor.outer,
    torch.outer,
    torch.polar,
    torch.Tensor.polygamma,
    torch.polygamma,
    torch.Tensor.positive,
    torch.positive,
    torch.Tensor.pow,
    torch.pow,
    # NOTE(review): prelu applies its weight per channel (the input's second
    # dim), so results depend on dimension order; confirm the fast path keeps it.
    torch.Tensor.prelu,
    torch.prelu,
    torch.nn.functional.prelu,
    torch.Tensor.rad2deg,
    torch.rad2deg,
    torch.Tensor.reciprocal,
    torch.reciprocal,
    torch.Tensor.relu,
    torch.relu,
    torch.nn.functional.relu,
    torch.nn.functional.relu6,
    torch.Tensor.remainder,
    torch.remainder,
    torch.Tensor.round,
    torch.round,
    torch.rrelu,
    torch.nn.functional.rrelu,
    torch.Tensor.rsqrt,
    torch.rsqrt,
    torch.rsub,
    torch.selu,
    torch.nn.functional.selu,
    torch.Tensor.sgn,
    torch.sgn,
    torch.Tensor.sigmoid,
    torch.sigmoid,
    torch.nn.functional.sigmoid,
    torch.Tensor.sign,
    torch.sign,
    torch.Tensor.signbit,
    torch.signbit,
    torch.nn.functional.silu,
    torch.Tensor.sin,
    torch.sin,
    torch.Tensor.sinc,
    torch.sinc,
    torch.Tensor.sinh,
    torch.sinh,
    torch.nn.functional.softplus,
    torch.nn.functional.softshrink,
    torch.Tensor.sqrt,
    torch.sqrt,
    torch.Tensor.square,
    torch.square,
    torch.Tensor.sub,
    torch.sub,
    torch.Tensor.tan,
    torch.tan,
    torch.Tensor.tanh,
    torch.tanh,
    torch.nn.functional.tanh,
    torch.threshold,
    torch.nn.functional.threshold,
    # NOTE(review): trapz integrates along ``dim`` with the trapezoidal rule,
    # i.e. it is a reduction, not an elementwise op; confirm it belongs here.
    torch.trapz,
    torch.Tensor.true_divide,
    torch.true_divide,
    torch.Tensor.trunc,
    torch.trunc,
    torch.Tensor.xlogy,
    torch.xlogy,
    # Random output whose shape follows the input; it does not depend on the
    # input's values.
    torch.rand_like,
)