Spaces:
Running
Running
File size: 1,542 Bytes
03f6091 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# -*- coding: utf-8 -*-
r"""
Activation Functions
====================
Provides an easy to use interface to initialize different activation functions.
"""
import torch
from torch import nn
def build_activation(activation: str) -> nn.Module:
    """Build an activation module from its name.

    The name is looked up in ``torch.nn`` first; ``"Swish"`` is resolved to
    the ``Swish`` module defined in this file when ``torch.nn`` does not
    provide a module class under that name.

    :param activation: string defining the name of the activation function.
        Activations available:
        Swish + every native pytorch activation function.
    :return: a newly constructed instance of the requested module, built
        with no constructor arguments.
    :raises ValueError: if the name is unknown, or if it names something in
        ``torch.nn`` that is not an ``nn.Module`` subclass.
    """
    candidate = getattr(nn, activation, None)
    # ``torch.nn`` also exposes submodules and non-module classes (e.g.
    # ``functional``, ``init``, ``Parameter``); only ``nn.Module`` subclasses
    # can honour the declared return type, so everything else is rejected.
    if isinstance(candidate, type) and issubclass(candidate, nn.Module):
        return candidate()
    if activation == "Swish":
        return Swish()
    raise ValueError("{} invalid activation function.".format(activation))
def swish(input: torch.Tensor) -> torch.Tensor:
    """
    Self-gated Swish activation, computed element-wise.

    Each element is scaled by its own sigmoid:
    swish(x) = x * sigmoid(x)

    :param input: tensor to activate.
    :return: tensor holding ``input * sigmoid(input)``.
    """
    gate = torch.sigmoid(input)
    return input * gate
class Swish(nn.Module):
    """
    Module wrapper around the element-wise Swish activation:

        Swish(x) = x * sigmoid(x)

    The module holds no parameters of its own; ``forward`` delegates to the
    ``swish`` function.

    Shape:
        - Input: (N, *) where * means, any number of additional
          dimensions
        - Output: (N, *), same shape as the input

    References:
        - Related paper:
          https://arxiv.org/pdf/1710.05941v1.pdf
    """

    def __init__(self) -> None:
        """
        Set up the underlying ``nn.Module`` state.
        """
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply Swish to ``input`` and return the activated tensor.
        """
        activated = swish(input)
        return activated
|