# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
# Modified by Yiwei Guo, 2024
# including conditioned snakebeta activation

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    """Sine-based periodic activation: ``x + 1/a * sin^2(x * a)``.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    Parameters:
        - alpha - per-channel parameter used both as the frequency inside
          the sine and as the divisor of the periodic term

    References:
        - Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195

    Examples:
        >>> act = Snake(256)
        >>> x = torch.randn(4, 256, 100)
        >>> y = act(x)
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """Create the activation.

        Args:
            in_features: size passed to ``torch.zeros`` / ``torch.ones`` to
                build the ``alpha`` parameter (the channel count C).
            alpha: scale applied to the initial tensor. In linear mode the
                parameter starts at ``alpha``; higher values mean higher
                frequency. In log-scale mode it multiplies a zero tensor.
            alpha_trainable: value assigned to ``alpha.requires_grad``.
            alpha_logscale: when true, ``alpha`` is kept in log space and
                exponentiated inside ``forward``.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale parameters start from zeros, linear-scale ones from ones.
        base = torch.zeros if alpha_logscale else torch.ones
        self.alpha = Parameter(base(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable

        # Small constant added to the divisor in forward().
        self.no_div_by_zero = 1e-9

    def forward(self, x):
        """Apply ``x + 1/a * sin^2(x * a)`` elementwise and return the result."""
        # Add leading and trailing singleton axes so alpha lines up with [B, C, T].
        freq = self.alpha[None, ..., None]
        if self.alpha_logscale:
            freq = freq.exp()

        inv_scale = 1.0 / (freq + self.no_div_by_zero)
        wave = sin(x * freq)
        return x + inv_scale * pow(wave, 2)


class SnakeBeta(nn.Module):
    """Snake variant with a separate magnitude parameter: ``x + 1/b * sin^2(x * a)``.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    Parameters:
        - alpha - per-channel parameter that controls frequency
        - beta - per-channel parameter that controls magnitude

    References:
        - Modified version of the activation from Liu Ziyin, Tilman Hartwig,
          Masahito Ueda: https://arxiv.org/abs/2006.08195

    Examples:
        >>> act = SnakeBeta(256)
        >>> x = torch.randn(4, 256, 100)
        >>> y = act(x)
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """Create the activation.

        Args:
            in_features: size passed to ``torch.zeros`` / ``torch.ones`` to
                build ``alpha`` and ``beta`` (the channel count C).
            alpha: scale applied to the initial tensors of BOTH ``alpha`` and
                ``beta``. In log-scale mode it multiplies zero tensors.
            alpha_trainable: value assigned to ``requires_grad`` of both
                parameters.
            alpha_logscale: when true, both parameters are kept in log space
                and exponentiated inside ``forward``.
        """
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # Log-scale parameters start from zeros, linear-scale ones from ones.
        base = torch.zeros if alpha_logscale else torch.ones
        self.alpha = Parameter(base(in_features) * alpha)
        self.beta = Parameter(base(in_features) * alpha)
        for param in (self.alpha, self.beta):
            param.requires_grad = alpha_trainable

        # Small constant added to the divisor in forward().
        self.no_div_by_zero = 1e-9

    def forward(self, x, cond=None):
        """Apply ``x + 1/b * sin^2(x * a)`` elementwise.

        ``cond`` is accepted but not used by this class.
        """
        # Add leading and trailing singleton axes so the parameters line up
        # with [B, C, T].
        freq = self.alpha[None, ..., None]
        magnitude = self.beta[None, ..., None]
        if self.alpha_logscale:
            freq = freq.exp()
            magnitude = magnitude.exp()

        inv_scale = 1.0 / (magnitude + self.no_div_by_zero)
        wave = sin(x * freq)
        return x + inv_scale * pow(wave, 2)


class SnakeBetaWithCondition(nn.Module):
    """SnakeBeta whose frequency and magnitude are shifted by a condition vector.

    Shape:
        - Input: (B, C, T)
        - Condition: (B, D), where the D dimension is mapped to C dimensions
        - Output: (B, C, T), same shape as the input

    Parameters:
        - alpha - per-channel parameter that controls frequency
        - beta - per-channel parameter that controls magnitude
        - condition_alpha_prenet - linear layer mapping the condition to a
          per-channel offset applied to both alpha and beta

    References:
        - Modified version of the activation from Liu Ziyin, Tilman Hartwig,
          Masahito Ueda: https://arxiv.org/abs/2006.08195

    Examples:
        >>> act = SnakeBetaWithCondition(256, 128)
        >>> x = torch.randn(4, 256, 100)
        >>> cond = torch.randn(4, 128)
        >>> y = act(x, cond)
    """

    def __init__(self, in_features, condition_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        """Create the activation.

        Args:
            in_features: dimension of the input channels (C).
            condition_features: dimension of the condition vectors (D).
            alpha: scale applied to the initial tensors of BOTH ``alpha`` and
                ``beta``. In log-scale mode it multiplies zero tensors.
            alpha_trainable: value assigned to ``requires_grad`` of ``alpha``
                and ``beta`` (the prenet is not affected by this flag).
            alpha_logscale: when true, both parameters are kept in log space
                and exponentiated inside ``forward``.
        """
        super().__init__()
        self.in_features = in_features

        # One prenet serves both alpha and beta; no separate beta prenet is
        # created.
        self.condition_alpha_prenet = torch.nn.Linear(condition_features, in_features)

        self.alpha_logscale = alpha_logscale
        # Log-scale parameters start from zeros, linear-scale ones from ones.
        base = torch.zeros if alpha_logscale else torch.ones
        self.alpha = Parameter(base(in_features) * alpha)
        self.beta = Parameter(base(in_features) * alpha)
        for param in (self.alpha, self.beta):
            param.requires_grad = alpha_trainable

        # Small constant added to the divisor in forward().
        self.no_div_by_zero = 1e-9

    def forward(self, x, condition):
        """Apply the conditioned ``x + 1/b * sin^2(x * a)`` elementwise.

        Args:
            x: input, documented as [B, C, T].
            condition: condition vectors, documented as [B, D].
        """
        # Add leading and trailing singleton axes so the parameters line up
        # with [B, C, T].
        freq = self.alpha[None, ..., None]
        magnitude = self.beta[None, ..., None]
        if self.alpha_logscale:
            freq = freq.exp()
            magnitude = magnitude.exp()

        # Same prenet for both alpha and beta, to save parameters.
        projected = self.condition_alpha_prenet(condition)
        offset = torch.tanh(projected.unsqueeze(-1))
        freq = freq + offset
        # The offset is halved for beta to avoid beta becoming too small.
        magnitude = magnitude + 0.5 * offset

        inv_scale = 1.0 / (magnitude + self.no_div_by_zero)
        wave = sin(x * freq)
        return x + inv_scale * pow(wave, 2)