Spaces:
Sleeping
Sleeping
from functools import partial | |
import torch | |
from torch import nn | |
class Swish(nn.Module):
    """Swish activation, computed element-wise as ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` scaled by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
def linear():
    """Build a pass-through (identity) activation.

    Returns:
        nn.Identity: module that returns its input unchanged.
    """
    identity = nn.Identity()
    return identity
def relu():
    """Build a ReLU activation with PyTorch's default settings.

    Returns:
        nn.ReLU: a fresh ReLU module.
    """
    layer = nn.ReLU()
    return layer
def prelu(num_parameters=1, init=0.25):
    """Build a PReLU activation, whose negative slope is learnable.

    Args:
        num_parameters (int): Number of learnable slopes. 1 shares one slope
            across all channels; otherwise it should match the input's
            channel count. Defaults to 1 (PyTorch's default).
        init (float): Initial value of the slope(s). Defaults to 0.25
            (PyTorch's default).

    Returns:
        nn.PReLU: a fresh PReLU module.
    """
    return nn.PReLU(num_parameters=num_parameters, init=init)
def leaky_relu(negative_slope=0.01):
    """Build a LeakyReLU activation.

    Args:
        negative_slope (float): Slope applied to negative inputs. Defaults to
            0.01, PyTorch's default, so ``leaky_relu()`` behaves as before.

    Returns:
        nn.LeakyReLU: a fresh LeakyReLU module.
    """
    return nn.LeakyReLU(negative_slope=negative_slope)
def sigmoid():
    """Build a logistic-sigmoid activation.

    Returns:
        nn.Sigmoid: a fresh Sigmoid module.
    """
    layer = nn.Sigmoid()
    return layer
def softmax(dim=None):
    """Build a softmax activation over the given dimension.

    Args:
        dim (int or None): Dimension along which softmax is computed. With
            ``None``, the dimension is chosen implicitly by PyTorch, which
            PyTorch warns is deprecated; prefer passing ``dim`` explicitly.

    Returns:
        nn.Softmax: a fresh Softmax module.
    """
    layer = nn.Softmax(dim=dim)
    return layer
def tanh():
    """Build a hyperbolic-tangent activation.

    Returns:
        nn.Tanh: a fresh Tanh module.
    """
    layer = nn.Tanh()
    return layer
def gelu():
    """Build a GELU activation with PyTorch's default settings.

    Returns:
        nn.GELU: a fresh GELU module.
    """
    layer = nn.GELU()
    return layer
def swish():
    """Build a :class:`Swish` activation module.

    Returns:
        Swish: a fresh module computing ``x * sigmoid(x)``.
    """
    module = Swish()
    return module
def register_activation(custom_act):
    """Make a custom activation retrievable by name through ``get``.

    The activation is stored in this module's global namespace under
    ``custom_act.__name__``.

    Args:
        custom_act: Custom activation function (or class) to register.

    Raises:
        ValueError: If the name, or its lower-cased form, is already bound in
            this module's global namespace (which also includes imports).
    """
    namespace = globals()
    act_name = custom_act.__name__
    # Also reject the lower-cased spelling so e.g. "ReLU" cannot sit next to "relu".
    taken = act_name in namespace or act_name.lower() in namespace
    if taken:
        raise ValueError(f"Activation {act_name} already exists. Choose another name.")
    namespace[act_name] = custom_act
def get(identifier):
    """Resolve an activation identifier.

    Args:
        identifier (str or Callable or None): the activation identifier.
            A string is looked up by exact name among this module's globals
            (built-in factories such as ``"relu"`` plus anything added with
            ``register_activation``). A callable is returned unchanged.

    Returns:
        Callable or None: ``None`` for ``None``; otherwise the resolved
        factory/class (call it to obtain an ``nn.Module``) or the input
        callable itself.

    Raises:
        ValueError: If ``identifier`` is a string that does not name a
            callable in this module, or is of any other unsupported type.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        cls = globals().get(identifier)
        # Module globals also hold imports (e.g. ``torch``, ``nn``) and dunder
        # attributes (e.g. ``__name__``); only callables can be activations.
        if cls is None or not callable(cls):
            raise ValueError("Could not interpret activation identifier: " + str(identifier))
        return cls
    raise ValueError("Could not interpret activation identifier: " + str(identifier))