import torch
from torch import nn


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1, eps=1e-5):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor
        :return: layer-normalized tensor
        :rtype: torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        # Move the normalized dimension to the end, apply LayerNorm over it,
        # then restore the original layout.
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
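
# A minimal usage sketch (shapes assumed for illustration): with dim=1 the
# transpose branch above normalizes the channel axis of a (batch, channel,
# time) tensor.
#
#     norm = LayerNorm(256, dim=1)
#     x = torch.randn(4, 256, 100)  # (batch, channel, time)
#     y = norm(x)                   # same shape, normalized over dim=1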


class Reshape(nn.Module):
    """Reshape the input tensor to a fixed shape given at construction time."""

    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args

    def forward(self, x):
        return x.view(self.shape)
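
# Usage sketch (shapes are hypothetical): wrapping view() in a module makes
# it usable inside nn.Sequential, where a plain function call cannot appear.
#
#     flatten = Reshape(-1, 256)
#     y = flatten(torch.randn(4, 8, 32))  # 4*8*32 elements -> (4, 256)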


class Permute(nn.Module):
    """Permute the dimensions of the input tensor in a fixed order."""

    def __init__(self, *args):
        super(Permute, self).__init__()
        self.args = args

    def forward(self, x):
        return x.permute(self.args)
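
# Usage sketch (shapes are hypothetical): swapping (batch, time, channel) to
# (batch, channel, time), e.g. before a Conv1d inside nn.Sequential.
#
#     to_bct = Permute(0, 2, 1)
#     y = to_bct(torch.randn(4, 100, 256))  # -> (4, 256, 100)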


def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """Build an nn.Embedding with weights drawn from N(0, embedding_dim ** -0.5).

    The padding embedding, if any, is zeroed out after initialization.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
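
# Usage sketch (vocabulary size and padding index assumed for illustration):
#
#     emb = Embedding(1000, 512, padding_idx=0)
#     tokens = torch.tensor([[5, 7, 0, 0]])  # 0 = padding
#     vecs = emb(tokens)                     # (1, 4, 512); padding rows are zero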