cantabile-kwok
prepare demo page
05005db
raw
history blame contribute delete
958 Bytes
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer normalization module."""
import torch
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    :class:`torch.nn.LayerNorm` normalizes over the trailing dimension. This
    subclass can instead normalize over an arbitrary single dimension ``dim``
    by temporarily swapping that dimension with the last axis.

    Args:
        nout (int): Output dim size, i.e. the size of the normalized dimension.
        dim (int): Dimension to be normalized.
        eps (float): Value added to the denominator for numerical stability.
            Defaults to ``1e-12``, the value this module has always used.
    """

    def __init__(self, nout, dim=-1, eps=1e-12):
        """Construct a LayerNorm object."""
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor; its size along ``self.dim`` must
                equal ``nout``.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        if self.dim == -1:
            return super().forward(x)
        # The parent class only normalizes the last axis: move ``dim`` to the
        # end, normalize, then swap it back to restore the original layout.
        normalized = super().forward(x.transpose(self.dim, -1))
        return normalized.transpose(self.dim, -1)