# sakharamg's picture
# Uploading all files
# 158b61b
"""Position feed-forward network from "Attention is All You Need"."""
import torch.nn as nn
import torch.nn.functional as F
class ActivationFunction:
    """Names of the activation functions the feed-forward layer accepts.

    Each attribute is a plain string; these strings are the keys of the
    ``ACTIVATION_FUNCTIONS`` lookup table.
    """

    relu = "relu"
    gelu = "gelu"
# Lookup table: activation name (an ``ActivationFunction`` attribute) ->
# the functional op from ``torch.nn.functional`` that implements it.
ACTIVATION_FUNCTIONS = {
    ActivationFunction.relu: F.relu,
    ActivationFunction.gelu: F.gelu,
}
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward block with a residual connection.

    The input is layer-normalized, expanded from ``d_model`` to ``d_ff``,
    passed through the activation, projected back to ``d_model`` and then
    added to the original (un-normalized) input. Dropout is applied after
    the activation and again after the output projection.

    Args:
        d_model (int): feature size of the block's input and output.
        d_ff (int): size of the hidden (inner) layer of the FFN.
        dropout (float): dropout probability in :math:`[0, 1)`.
        activation_fn (ActivationFunction): name of the activation to use;
            it is looked up in ``ACTIVATION_FUNCTIONS``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1,
                 activation_fn=ActivationFunction.relu):
        super().__init__()
        # d_model -> d_ff expansion followed by d_ff -> d_model projection.
        self.w_1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.w_2 = nn.Linear(in_features=d_ff, out_features=d_model)
        # Normalization is applied to the input, before the first linear.
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout_1 = nn.Dropout(p=dropout)
        # An unknown ``activation_fn`` surfaces here as a ``KeyError``.
        self.activation = ACTIVATION_FUNCTIONS[activation_fn]
        self.dropout_2 = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run the feed-forward block on ``x``.

        Args:
            x: ``(batch_size, input_len, model_dim)``

        Returns:
            (FloatTensor): Output ``(batch_size, input_len, model_dim)``.
        """
        normed = self.layer_norm(x)
        hidden = self.activation(self.w_1(normed))
        hidden = self.dropout_1(hidden)
        projected = self.dropout_2(self.w_2(hidden))
        # Residual connection around the whole block uses the raw input.
        return projected + x

    def update_dropout(self, dropout):
        """Set the probability of both dropout layers to ``dropout``."""
        for layer in (self.dropout_1, self.dropout_2):
            layer.p = dropout