ssa-perin / model /module /transformer.py
larkkin's picture
Add code
991f07c
raw
history blame
3.25 kB
#!/usr/bin/env python3
# coding=utf-8
import torch
import torch.nn as nn
def checkpoint(module, *args, **kwargs):
    """Run ``module`` under ``torch.utils.checkpoint.checkpoint``.

    A throwaway one-element tensor with ``requires_grad=True`` is passed as
    the first checkpointed input and discarded before ``module`` is called.
    NOTE(review): presumably this keeps the reentrant checkpoint attached to
    the autograd graph even when none of the real inputs require gradients.

    Args:
        module: Callable to execute under activation checkpointing.
        *args: Positional inputs handed to ``module``.
        **kwargs: Passed to ``torch.utils.checkpoint.checkpoint`` (for
            example ``preserve_rng_state`` or ``use_reentrant``). Anything
            PyTorch forwards on to the wrapped callable is handed to
            ``module`` as keyword arguments.

    Returns:
        Whatever ``module`` returns.
    """
    # Local imports: a bare ``import torch`` does not guarantee that the
    # ``torch.utils.checkpoint`` submodule has been loaded, and ``inspect``
    # is not imported at file level.
    import inspect
    from torch.utils.checkpoint import checkpoint as torch_checkpoint

    # Keep the historical (reentrant) behaviour explicit where the installed
    # PyTorch exposes the switch, so newer releases neither warn nor change
    # semantics under us. Older releases without the parameter are untouched,
    # and a caller-supplied value always wins.
    if "use_reentrant" in inspect.signature(torch_checkpoint).parameters:
        kwargs.setdefault("use_reentrant", True)

    dummy = torch.empty(1, requires_grad=True)

    def run(_dummy, *inputs, **options):
        # Drop the dummy tensor; only the real inputs reach the module.
        return module(*inputs, **options)

    return torch_checkpoint(run, dummy, *args, **kwargs)
class Attention(nn.Module):
    """Multi-head attention whose output is passed through dropout.

    Reads ``hidden_size``, ``n_attention_heads``,
    ``dropout_transformer_attention`` and ``dropout_transformer`` from
    ``args``.
    """

    def __init__(self, args):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            args.hidden_size,
            args.n_attention_heads,
            dropout=args.dropout_transformer_attention,
        )
        self.dropout = nn.Dropout(args.dropout_transformer)

    def forward(self, q_input, kv_input, mask=None):
        """Attend from ``q_input`` to ``kv_input`` and apply dropout.

        ``kv_input`` serves as both key and value. ``mask`` is handed to
        ``nn.MultiheadAttention`` as its ``key_padding_mask``.
        NOTE(review): inputs are presumably sequence-first (T, B, D), since
        the attention module is built without ``batch_first`` -- confirm
        against callers.
        """
        attended = self.attention(
            q_input,
            kv_input,
            kv_input,
            key_padding_mask=mask,
            need_weights=False,
        )[0]
        return self.dropout(attended)
class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer.

    The stack is Linear(hidden_size -> hidden_size_ff), activation, dropout,
    Linear(hidden_size_ff -> hidden_size), dropout, all configured from
    ``args``.
    """

    def __init__(self, args):
        super().__init__()
        # Build the pieces in the same order they appear in the stack so that
        # seeded parameter initialisation is unaffected.
        expand = nn.Linear(args.hidden_size, args.hidden_size_ff)
        nonlinearity = self._get_activation_f(args.activation)
        inner_dropout = nn.Dropout(args.dropout_transformer)
        project = nn.Linear(args.hidden_size_ff, args.hidden_size)
        outer_dropout = nn.Dropout(args.dropout_transformer)
        self.f = nn.Sequential(expand, nonlinearity, inner_dropout, project, outer_dropout)

    def forward(self, x):
        """Apply the feed-forward stack to ``x``."""
        return self.f(x)

    def _get_activation_f(self, activation: str):
        """Return a new activation module for ``"relu"`` or ``"gelu"``.

        Any other name raises ``KeyError``.
        """
        factories = {"relu": nn.ReLU, "gelu": nn.GELU}
        factory = factories[activation]
        return factory()
class DecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a feed-forward sublayer.

    Each sublayer is wrapped in a residual connection. When ``args.pre_norm``
    is truthy, LayerNorm is applied to the sublayer input and the post-norms
    are identities; otherwise the pre-norms are identities and LayerNorm is
    applied after the residual sum.

    The cross-attention sublayer over the encoder output is disabled in this
    layer, so no ``cross_f`` / ``pre_cross_norm`` / ``post_cross_norm``
    modules are created.
    """

    def __init__(self, args):
        super().__init__()

        def norm_if(enabled):
            # LayerNorm when this position is active, otherwise a no-op.
            if enabled:
                return nn.LayerNorm(args.hidden_size)
            return nn.Identity()

        self.self_f = Attention(args)
        self.feedforward_f = FeedForward(args)

        self.pre_self_norm = norm_if(args.pre_norm)
        self.pre_feedforward_norm = norm_if(args.pre_norm)

        self.post_self_norm = norm_if(not args.pre_norm)
        self.post_feedforward_norm = norm_if(not args.pre_norm)

    def forward(self, x, encoder_output, x_mask, encoder_mask):
        """Apply the self-attention and feed-forward sublayers to ``x``.

        ``encoder_output`` and ``encoder_mask`` are accepted to keep the
        decoder-layer signature, but are not used because cross-attention is
        disabled. ``x_mask`` is passed to the self-attention module as its
        mask.
        """
        # Self-attention sublayer with residual connection.
        normed = self.pre_self_norm(x)
        attended = self.self_f(normed, normed, x_mask)
        x = self.post_self_norm(x + attended)

        # (Cross-attention over encoder_output would sit here; it is disabled.)

        # Feed-forward sublayer with residual connection.
        normed = self.pre_feedforward_norm(x)
        transformed = self.feedforward_f(normed)
        return self.post_feedforward_norm(x + transformed)
class Decoder(nn.Module):
    """Stack of ``args.n_layers`` decoder layers.

    Every layer except the last runs under activation checkpointing.
    """

    def __init__(self, args):
        super().__init__()
        stack = [DecoderLayer(args) for _ in range(args.n_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, target, encoder, target_mask, encoder_mask):
        """Run ``target`` through all decoder layers.

        Both ``target`` and ``encoder`` have their first two dimensions
        swapped before the layers run, and the result is swapped back before
        returning. Per the original author's shape notes this is
        (B, T, D) -> (T, B, D) on the way in and (T, B, D) -> (B, T, D) on
        the way out.

        Raises:
            IndexError: If the decoder was built with zero layers.
        """
        hidden = target.transpose(0, 1)  # shape: (T, B, D)
        memory = encoder.transpose(0, 1)  # shape: (T, B, D)

        # All but the final layer are checkpointed.
        for layer in self.layers[:-1]:
            hidden = checkpoint(layer, hidden, memory, target_mask, encoder_mask)

        # The final layer runs directly; the original author notes it is not
        # checkpointed "due to grad_norm".
        final_layer = self.layers[-1]
        hidden = final_layer(hidden, memory, target_mask, encoder_mask)

        return hidden.transpose(0, 1)  # shape: (B, T, D)