#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Layer normalization module."""

import torch
import torch.nn as nn

class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Args:
        nout (int): Output dim size.
        dim (int): Dimension to be normalized.
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
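
# Usage sketch: normalizing a non-last dimension via the `dim` argument; the
# transpose trick above moves that dim to the end so torch.nn.LayerNorm can
# act on it. `nout` must equal the size of the normalized dim.
# >>> x = torch.randn(4, 16, 100)   # (batch, channel, time)
# >>> ln = LayerNorm(16, dim=1)     # normalize over the channel dim
# >>> ln(x).shape
# torch.Size([4, 16, 100])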

class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Arguments
    ---------
    dim : int or list or torch.Size
        Size of the channel dimension to be normalized.
    shape : int
        Number of dimensions of the expected input:
        3 for [N, C, L], 4 for [N, C, K, S].
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            elif shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # gLN pools mean/var over every dim except the batch dim:
        # stats have shape N x 1 x 1 for [N, C, L] inputs,
        # and N x 1 x 1 x 1 for [N, C, K, S] inputs.
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        elif x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x
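
# Usage sketch: gLN pools statistics over the channel and time dims jointly,
# so each sample of the output starts out with ~zero mean and ~unit variance
# (weight is initialized to ones and bias to zeros).
# >>> x = torch.randn(5, 10, 20)          # N x C x L
# >>> gln = GlobalLayerNorm(10, shape=3)  # weight/bias shaped (10, 1)
# >>> y = gln(x)
# >>> bool(y.mean(dim=(1, 2)).abs().max() < 1e-3)
# True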

class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        Whether to use learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # Move the channel dim last, apply nn.LayerNorm over it, move it back.
        if x.dim() == 4:
            # N x C x K x S -> N x K x S x C
            x = x.permute(0, 2, 3, 1).contiguous()
            # Normalize over the channel dim only.
            x = super().forward(x)
            # N x K x S x C -> N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()
        elif x.dim() == 3:
            # N x C x L -> N x L x C
            x = torch.transpose(x, 1, 2)
            # Normalize over the channel dim only.
            x = super().forward(x)
            # N x L x C -> N x C x L
            x = torch.transpose(x, 1, 2)
        return x
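
# Usage sketch: despite the name, this layer applies standard LayerNorm over
# the channel dim of every frame; the permutes/transposes only move C to the
# last position and back, so input shapes are preserved.
# >>> x = torch.randn(5, 10, 20)   # N x C x L
# >>> cln = CumulativeLayerNorm(10)
# >>> cln(x).shape
# torch.Size([5, 10, 20])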

class ScaleNorm(nn.Module):
    """Scale normalization: rescale x to L2 norm g * sqrt(dim) over the last dim."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Scaled L2 norm over the last dim: ||x||_2 / sqrt(dim).
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
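
# Usage sketch: ScaleNorm (as in "Transformers without Tears", Nguyen &
# Salazar 2019) rescales each last-dim vector to L2 norm g * sqrt(dim),
# using a single learnable scalar g instead of per-element affine weights.
# >>> x = torch.randn(4, 32, 64)
# >>> sn = ScaleNorm(64)
# >>> bool((sn(x).norm(dim=-1) / 64 ** 0.5 - 1).abs().max() < 1e-4)
# True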