# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

# pointwise operators can go through a faster pathway

# NOTE(review): the empty-string entry looks like a leftover placeholder.
# It is kept unchanged because code outside this file may read this list;
# confirm whether it is still needed.
tensor_magic_methods = ["add", ""]

# Binary-operator stems that also have a reflected form. Each stem "x"
# contributes both "x" and "rx" to `pointwise_magic_methods` below
# (e.g. "add" -> "add", "radd").
pointwise_magic_methods_with_reverse = (
    "add",
    "sub",
    "mul",
    "floordiv",
    "div",
    "truediv",
    "mod",
    "pow",
    "lshift",
    "rshift",
    "and",
    "or",
    "xor",
)

# Stems (without the surrounding double underscores) of the torch.Tensor
# dunder methods treated as pointwise. The groups are:
#   - forward and reflected binary operators;
#   - comparisons;
#   - unary operators;
#   - in-place operators;
#   - scalar conversions.
# Each stem appears exactly once; "gt" was previously listed twice.
pointwise_magic_methods = (
    *(x for m in pointwise_magic_methods_with_reverse for x in (m, "r" + m)),
    "eq",
    "gt",
    "le",
    "lt",
    "ge",
    "ne",
    "neg",
    "pos",
    "abs",
    "invert",
    "iadd",
    "isub",
    "imul",
    "ifloordiv",
    "idiv",
    "itruediv",
    "imod",
    "ipow",
    "ilshift",
    "irshift",
    "iand",
    "ior",
    "ixor",
    "int",
    "long",
    "float",
    "complex",
)

# Full dunder names, e.g. "add" -> "__add__", in the same order as the stems.
pointwise_methods = tuple(f"__{m}__" for m in pointwise_magic_methods)

# Callables treated as pointwise. The tuple holds the torch.Tensor dunder
# methods named above, followed by Tensor methods, torch functions and
# torch.nn.functional variants. Each literal entry is listed once; the
# duplicate torch.nn.functional.dropout entry has been removed.
pointwise = (
    *(getattr(torch.Tensor, m) for m in pointwise_methods),
    torch.nn.functional.dropout,
    torch.where,
    torch.Tensor.abs, torch.abs,
    torch.Tensor.acos, torch.acos,
    torch.Tensor.acosh, torch.acosh,
    torch.Tensor.add, torch.add,
    torch.Tensor.addcdiv, torch.addcdiv,
    torch.Tensor.addcmul, torch.addcmul,
    torch.Tensor.addr, torch.addr,
    torch.Tensor.angle, torch.angle,
    torch.Tensor.asin, torch.asin,
    torch.Tensor.asinh, torch.asinh,
    torch.Tensor.atan, torch.atan,
    torch.Tensor.atan2, torch.atan2,
    torch.Tensor.atanh, torch.atanh,
    torch.Tensor.bitwise_and, torch.bitwise_and,
    torch.Tensor.bitwise_left_shift, torch.bitwise_left_shift,
    torch.Tensor.bitwise_not, torch.bitwise_not,
    torch.Tensor.bitwise_or, torch.bitwise_or,
    torch.Tensor.bitwise_right_shift, torch.bitwise_right_shift,
    torch.Tensor.bitwise_xor, torch.bitwise_xor,
    torch.Tensor.ceil, torch.ceil,
    torch.celu, torch.nn.functional.celu,
    torch.Tensor.clamp, torch.clamp,
    torch.Tensor.clamp_max, torch.clamp_max,
    torch.Tensor.clamp_min, torch.clamp_min,
    torch.Tensor.copysign, torch.copysign,
    torch.Tensor.cos, torch.cos,
    torch.Tensor.cosh, torch.cosh,
    torch.Tensor.deg2rad, torch.deg2rad,
    torch.Tensor.digamma, torch.digamma,
    torch.Tensor.div, torch.div,
    torch.dropout,
    torch.nn.functional.elu,
    torch.Tensor.eq, torch.eq,
    torch.Tensor.erf, torch.erf,
    torch.Tensor.erfc, torch.erfc,
    torch.Tensor.erfinv, torch.erfinv,
    torch.Tensor.exp, torch.exp,
    torch.Tensor.exp2, torch.exp2,
    torch.Tensor.expm1, torch.expm1,
    torch.feature_dropout,
    torch.Tensor.float_power, torch.float_power,
    torch.Tensor.floor, torch.floor,
    torch.Tensor.floor_divide, torch.floor_divide,
    torch.Tensor.fmod, torch.fmod,
    torch.Tensor.frac, torch.frac,
    torch.Tensor.frexp, torch.frexp,
    torch.Tensor.gcd, torch.gcd,
    torch.Tensor.ge, torch.ge,
    torch.nn.functional.gelu,
    torch.nn.functional.glu,
    torch.Tensor.gt, torch.gt,
    torch.Tensor.hardshrink, torch.hardshrink, torch.nn.functional.hardshrink,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.Tensor.heaviside, torch.heaviside,
    torch.Tensor.hypot, torch.hypot,
    torch.Tensor.i0, torch.i0,
    torch.Tensor.igamma, torch.igamma,
    torch.Tensor.igammac, torch.igammac,
    torch.Tensor.isclose, torch.isclose,
    torch.Tensor.isfinite, torch.isfinite,
    torch.Tensor.isinf, torch.isinf,
    torch.Tensor.isnan, torch.isnan,
    torch.Tensor.isneginf, torch.isneginf,
    torch.Tensor.isposinf, torch.isposinf,
    torch.Tensor.isreal, torch.isreal,
    torch.Tensor.kron, torch.kron,
    torch.Tensor.lcm, torch.lcm,
    torch.Tensor.ldexp, torch.ldexp,
    torch.Tensor.le, torch.le,
    torch.nn.functional.leaky_relu,
    torch.Tensor.lerp, torch.lerp,
    torch.Tensor.lgamma, torch.lgamma,
    torch.Tensor.log, torch.log,
    torch.Tensor.log10, torch.log10,
    torch.Tensor.log1p, torch.log1p,
    torch.Tensor.log2, torch.log2,
    torch.nn.functional.logsigmoid,
    torch.Tensor.logical_and, torch.logical_and,
    torch.Tensor.logical_not, torch.logical_not,
    torch.Tensor.logical_or, torch.logical_or,
    torch.Tensor.logical_xor, torch.logical_xor,
    torch.Tensor.logit, torch.logit,
    torch.Tensor.lt, torch.lt,
    torch.Tensor.maximum, torch.maximum,
    torch.Tensor.minimum, torch.minimum,
    torch.nn.functional.mish,
    torch.Tensor.mvlgamma, torch.mvlgamma,
    torch.Tensor.nan_to_num, torch.nan_to_num,
    torch.Tensor.ne, torch.ne,
    torch.Tensor.neg, torch.neg,
    torch.Tensor.nextafter, torch.nextafter,
    torch.Tensor.outer, torch.outer,
    torch.polar,
    torch.Tensor.polygamma, torch.polygamma,
    torch.Tensor.positive, torch.positive,
    torch.Tensor.pow, torch.pow,
    torch.Tensor.prelu, torch.prelu, torch.nn.functional.prelu,
    torch.Tensor.rad2deg, torch.rad2deg,
    torch.Tensor.reciprocal, torch.reciprocal,
    torch.Tensor.relu, torch.relu, torch.nn.functional.relu,
    torch.nn.functional.relu6,
    torch.Tensor.remainder, torch.remainder,
    torch.Tensor.round, torch.round,
    torch.rrelu, torch.nn.functional.rrelu,
    torch.Tensor.rsqrt, torch.rsqrt,
    torch.rsub,
    torch.selu, torch.nn.functional.selu,
    torch.Tensor.sgn, torch.sgn,
    torch.Tensor.sigmoid, torch.sigmoid, torch.nn.functional.sigmoid,
    torch.Tensor.sign, torch.sign,
    torch.Tensor.signbit, torch.signbit,
    torch.nn.functional.silu,
    torch.Tensor.sin, torch.sin,
    torch.Tensor.sinc, torch.sinc,
    torch.Tensor.sinh, torch.sinh,
    torch.nn.functional.softplus,
    torch.nn.functional.softshrink,
    torch.Tensor.sqrt, torch.sqrt,
    torch.Tensor.square, torch.square,
    torch.Tensor.sub, torch.sub,
    torch.Tensor.tan, torch.tan,
    torch.Tensor.tanh, torch.tanh, torch.nn.functional.tanh,
    torch.threshold, torch.nn.functional.threshold,
    # NOTE(review): torch.trapz integrates along a dimension, so it does not
    # look pointwise. It is kept for compatibility; confirm it belongs here.
    torch.trapz,
    torch.Tensor.true_divide, torch.true_divide,
    torch.Tensor.trunc, torch.trunc,
    torch.Tensor.xlogy, torch.xlogy,
    torch.rand_like,
)