import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
import random

BASE_FNS = {'gelu': mtf.gelu,
            'relu': mtf.relu,
            'sigmoid': mtf.sigmoid,
            'tanh': mtf.tanh,
            'selu': mtf.selu,
            'elu': mtf.elu,
            'abs': mtf.abs,
            'sin': mtf.sin,
            'cos': mtf.cos,
            'sign': mtf.sign,
            'silu': mtf.swish,
            'softplus': mtf.softplus
            }


def _arcsinh(x):
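    """Inverse hyperbolic sine: log(x + sqrt(1 + x**2))."""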
    return mtf.log(x + mtf.sqrt(1 + x ** 2))


def _var(x, init):
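    """Trainable scalar on x's mesh, initialized to `init`.

    The random hex suffix keeps variable names unique when an activation
    is applied more than once in the graph.
    """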
    return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [],
                            initializer=tf.constant_initializer(init), dtype=x.dtype)


def _pos_var(x, val):
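    """Trainable scalar constrained to stay above `val` via softplus."""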
    return mtf.softplus(_var(x, 0)) + val


def _rrelu(x):
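    """Randomized leaky ReLU: x for x > 0, x * (1 - s) / (1 + s) for x < 0.

    The slope factor s is sampled once at graph-construction time, not
    per training step.
    """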
    negative_scale = random.random()
    return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)


def _elish(x):
    # ELiSH: x * sigmoid(x) for x > 0 and (exp(x) - 1) * sigmoid(x) otherwise.
    # exp is clamped at 0 so the inactive branch cannot overflow for large x.
    cond = mtf.cast(mtf.greater(x, 0), x.dtype)
    exp = mtf.exp(mtf.minimum(x, 0))
    return cond * x * mtf.sigmoid(x) + (1 - cond) * (exp - 1) / (1 / exp + 1)


CUSTOM_FNS = {'lrelu001': lambda x: mtf.leaky_relu(x, alpha=0.01),
              'lrelu020': lambda x: mtf.leaky_relu(x, alpha=0.20),
              'id': lambda x: x,
              'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin(7 * x) / 49,
              'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7,
              'spike': lambda x: 1 / (1 + x ** 2),
              'spike2': lambda x: mtf.exp(-x ** 2),
              'tanhshrink': lambda x: x - mtf.tanh(x),
              'softsign': lambda x: x / (mtf.abs(x) + 1),
              'softmax': lambda x: mtf.softmax(x, x.shape[-1]),
              'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]),
              'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1,
              'rrelu': _rrelu,
              'elish': _elish,
              'arcsinh': _arcsinh,
              'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (
                          _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))),
              'prelu': lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)),
              'parcsinh': lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)),
              'psoftplus': lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0),
              'proottanh': lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x),
              'maxsig': lambda x: mtf.maximum(x, mtf.sigmoid(x)),
              'cosid': lambda x: mtf.cos(x) - x,
              'minsin': lambda x: mtf.minimum(x, mtf.sin(x)),
              'maxtanh': lambda x: mtf.maximum(x, mtf.tanh(x)),
              'mish': lambda x: x * mtf.tanh(mtf.softplus(x)),
              'tanhexp': lambda x: x * mtf.tanh(mtf.exp(x)),
              'lisht': lambda x: x * mtf.tanh(x),
              'seagull': lambda x: mtf.log(1 + x ** 2),
              'snake': lambda x: x + mtf.sin(x) ** 2,
              'roottanh': lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x),
              'softplusmone': lambda x: mtf.softplus(x) - 1
              }


def get_activation_fn(params):
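    """Return the activation named by params["activation_fn"], defaulting to GELU."""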
    if "activation_fn" in params:
        activation_fn = params["activation_fn"]
    else:
        print("Defaulting to GELU activation (see here: https://arxiv.org/abs/1606.08415)")
        activation_fn = "gelu"

    if activation_fn in BASE_FNS:
        return BASE_FNS[activation_fn]

    if activation_fn in CUSTOM_FNS:
        return CUSTOM_FNS[activation_fn]

    raise ValueError(f'unknown activation function "{activation_fn}" in config')
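

# Minimal usage sketch (illustrative, not part of the upstream API): build a
# toy mesh-tensorflow graph and apply a configured activation. Constructing
# the graph alone does not require lowering to a device mesh.
if __name__ == "__main__":
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "activation_demo")  # hypothetical mesh name
    features = mtf.Dimension("features", 8)
    x = mtf.ones(mesh, mtf.Shape([features]))
    activation_fn = get_activation_fn({"activation_fn": "mish"})
    y = activation_fn(x)
    print(y.shape)  # mtf Shape with the "features" dimension preserved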