# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"


def get_triton_activation_kernel(activation: Optional[Activation]):
    # Map an Activation enum member to its forward Triton kernel
    # (None means no activation is applied)
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    # Map an Activation enum member to the Triton kernel computing its gradient
    # (None means no activation is applied)
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )
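
# Note: the two helpers above return @triton.jit functions, which cannot be
# called directly from Python; they are meant to be passed into another Triton
# kernel (e.g. as a constexpr argument of a fused kernel) and invoked there.
# A hedged usage sketch is given at the bottom of this file.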


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview


# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # ReLU is different from other activations in that it does not require
    # the input to retrospectively compute its gradient: here the input is the
    # downstream gradient, and we return the upstream gradient directly
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1
    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)
    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
    # GELU(x) = x * Phi(x), so d/dx GELU(x) = Phi(x) + x * phi(x),
    # with Phi the standard normal CDF and phi its PDF
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    # 0.79788456 ~= sqrt(2 / pi) and 0.1070322243 ~= 3 * 0.044715 * sqrt(2 / pi)
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
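

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the adapted xformers file): a minimal
# element-wise kernel showing how the activations above are typically consumed.
# The names `_activation_fw` and `apply_activation` and the block size of 1024
# are assumptions made for this example; the pattern of passing a @triton.jit
# activation as a `tl.constexpr` kernel argument mirrors the fused xformers
# kernels this file was adapted from.
# ---------------------------------------------------------------------------
@triton.jit
def _activation_fw(X, OUT, n_elements, ACTIVATION: tl.constexpr, BLOCK: tl.constexpr):
    # Each program instance handles BLOCK contiguous elements
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(X + offsets, mask=mask)
    tl.store(OUT + offsets, ACTIVATION(x), mask=mask)


def apply_activation(x, activation: Optional[Activation] = None):
    # Hypothetical host-side wrapper: dispatch to the requested activation kernel
    import torch  # local import; the rest of this module does not need torch

    act = get_triton_activation_kernel(activation)
    if act is None:
        # No activation requested: nothing to do
        return x
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    _activation_fw[grid](x, out, n_elements, ACTIVATION=act, BLOCK=1024)
    return out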