#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer normalization module."""

import torch
import torch.nn as nn


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Thin wrapper around ``torch.nn.LayerNorm`` (constructed with
    ``eps=1e-12``) that can normalize over an axis other than the last one
    by swapping that axis into the last position and back.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.

    """

    def __init__(self, nout, dim=-1):
        """Construct an LayerNorm object."""
        super().__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.

        """
        if self.dim == -1:
            return super().forward(x)
        # The parent normalizes the trailing axis only, so move the target
        # axis to the end, normalize, then restore the original layout.
        return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)


class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Statistics are computed per sample over every non-batch axis.

    Arguments
    ---------
    dim : int
        Size of the channel axis (axis 1) of the expected input. The affine
        parameters are created with shape ``(dim, 1)`` or ``(dim, 1, 1)``.
    shape : int
        Number of dimensions of the expected input: 3 for ``[N, C, L]`` or
        4 for ``[N, C, K, S]``. Only used to size the affine parameters.
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Raises
    ------
    ValueError
        If ``elementwise_affine`` is True and ``shape`` is neither 3 nor 4.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape not in (3, 4):
                # Fail at construction time: without parameters, forward()
                # would otherwise die later with an opaque AttributeError.
                raise ValueError(
                    "GlobalLayerNorm with elementwise_affine=True supports "
                    "shape 3 or 4, got {!r}".format(shape)
                )
            # One trailing singleton per non-channel axis so the parameters
            # broadcast along the channel axis of the input.
            param_shape = (self.dim,) + (1,) * (shape - 2)
            self.weight = nn.Parameter(torch.ones(*param_shape))
            self.bias = nn.Parameter(torch.zeros(*param_shape))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].

        Returns
        -------
        torch.Tensor
            Tensor of the same size as ``x``. Inputs that are neither 3-D
            nor 4-D are returned unchanged.
        """
        # x = N x C x K x S or N x C x L
        # gln: mean,var N x 1 x 1 (x 1)
        if x.dim() not in (3, 4):
            return x

        # Reduce over every axis except the batch axis.
        reduce_dims = tuple(range(1, x.dim()))
        mean = torch.mean(x, reduce_dims, keepdim=True)
        var = torch.mean((x - mean) ** 2, reduce_dims, keepdim=True)

        if self.elementwise_affine:
            return self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
        return (x - mean) / torch.sqrt(var + self.eps)


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Applies ``nn.LayerNorm`` (with ``eps=1e-8``) over the channel axis
    (axis 1) of a 3-D or 4-D input.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=1e-8)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor size [N, C, K, S] or [N, C, L]

        Returns
        -------
        torch.Tensor
            Tensor of the same size as ``x``. Inputs that are neither 3-D
            nor 4-D are returned unchanged.
        """
        # x: N x C x K x S or N x C x L
        if x.dim() == 4:
            # N x K x S x C == only channel norm
            x = x.permute(0, 2, 3, 1).contiguous()
            x = super().forward(x)
            # back to N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()
        elif x.dim() == 3:
            # N x L x C == only channel norm
            x = torch.transpose(x, 1, 2)
            x = super().forward(x)
            # back to N x C x L
            x = torch.transpose(x, 1, 2)
        return x


class ScaleNorm(nn.Module):
    """Scale normalization with a single learnable gain.

    Divides the input by its L2 norm along the last axis, where the norm is
    first multiplied by ``dim ** -0.5`` and clamped from below by ``eps``,
    then multiplies the result by the learnable scalar ``g``.

    Arguments
    ---------
    dim : int
        Value used to compute the norm scaling factor ``dim ** -0.5``.
    eps : float
        Lower bound applied to the scaled norm to avoid division by zero.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> norm = ScaleNorm(20)
    >>> x_norm = norm(x)
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        # Single learnable gain, initialized to 1.
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Returns the scale-normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor; the norm is taken over its last axis.

        Returns
        -------
        torch.Tensor
            Tensor of the same size as ``x``.
        """
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g