#!/usr/bin/env python3
# coding=utf-8
import torch
import torch.nn as nn
import torch.utils.checkpoint


def checkpoint(module, *args, **kwargs):
    # Wrap torch.utils.checkpoint.checkpoint so that the checkpointed segment
    # always has at least one input that requires grad (see the note below);
    # the lambda drops the dummy and forwards the real arguments to `module`.
    dummy = torch.empty(1, requires_grad=True)
    return torch.utils.checkpoint.checkpoint(
        lambda d, *a, **k: module(*a, **k), dummy, *args, **kwargs
    )

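
# Why the dummy tensor: with reentrant checkpointing, if none of the tensor
# inputs to a checkpointed segment require grad, the segment's output does not
# require grad either and the parameter gradients inside it are lost (module
# parameters live in the closure, not in the inputs). A single tiny
# requires_grad tensor keeps the segment attached to the autograd graph.
# Illustrative sketch only, with a hypothetical layer and input:
#
#   layer = nn.Linear(4, 4)
#   x = torch.randn(2, 4)          # requires_grad=False, e.g. raw features
#   checkpoint(layer, x).sum().backward()
#   assert layer.weight.grad is not None
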

class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            args.hidden_size, args.n_attention_heads, args.dropout_transformer_attention
        )
        self.dropout = nn.Dropout(args.dropout_transformer)

    def forward(self, q_input, kv_input, mask=None):
        # Inputs are sequence-first, shape (T, B, D): batch_first is left at
        # its default (False). `mask` is a key padding mask of shape (B, S)
        # with True marking positions to ignore; it is passed by keyword
        # because the fourth positional argument of
        # nn.MultiheadAttention.forward is key_padding_mask, not attn_mask.
        output, _ = self.attention(
            q_input, kv_input, kv_input, key_padding_mask=mask, need_weights=False
        )
        output = self.dropout(output)
        return output


class FeedForward(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.f = nn.Sequential(
            nn.Linear(args.hidden_size, args.hidden_size_ff),
            self._get_activation_f(args.activation),
            nn.Dropout(args.dropout_transformer),
            nn.Linear(args.hidden_size_ff, args.hidden_size),
            nn.Dropout(args.dropout_transformer),
        )

    def forward(self, x):
        return self.f(x)

    def _get_activation_f(self, activation: str):
        return {"relu": nn.ReLU, "gelu": nn.GELU}[activation]()


class DecoderLayer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.self_f = Attention(args)
        # self.cross_f = Attention(args)
        self.feedforward_f = FeedForward(args)
        self.pre_self_norm = nn.LayerNorm(args.hidden_size) if args.pre_norm else nn.Identity()
        # self.pre_cross_norm = nn.LayerNorm(args.hidden_size) if args.pre_norm else nn.Identity()
        self.pre_feedforward_norm = nn.LayerNorm(args.hidden_size) if args.pre_norm else nn.Identity()
        self.post_self_norm = nn.Identity() if args.pre_norm else nn.LayerNorm(args.hidden_size)
        # self.post_cross_norm = nn.Identity() if args.pre_norm else nn.LayerNorm(args.hidden_size)
        self.post_feedforward_norm = nn.Identity() if args.pre_norm else nn.LayerNorm(args.hidden_size)

    def forward(self, x, encoder_output, x_mask, encoder_mask):
        x_ = self.pre_self_norm(x)
        x = self.post_self_norm(x + self.self_f(x_, x_, x_mask))
        # x_ = self.pre_cross_norm(x)
        # x = self.post_cross_norm(x + self.cross_f(x_, encoder_output, encoder_mask))
        x_ = self.pre_feedforward_norm(x)
        x = self.post_feedforward_norm(x + self.feedforward_f(x_))
        return x

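
# Sub-layer pattern used above, for reference: with pre_norm the residual path
# is x = x + Sublayer(LayerNorm(x)); without it, x = LayerNorm(x + Sublayer(x))
# as in the original Transformer. Exactly one norm of each pre/post pair is a
# real LayerNorm, the other an Identity. The commented-out cross-attention
# would sit between the self-attention and feed-forward sub-layers.
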

class Decoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.layers = nn.ModuleList([DecoderLayer(args) for _ in range(args.n_layers)])

    def forward(self, target, encoder, target_mask, encoder_mask):
        target = target.transpose(0, 1)  # shape: (T, B, D)
        encoder = encoder.transpose(0, 1)  # shape: (S, B, D)
        for layer in self.layers[:-1]:
            target = checkpoint(layer, target, encoder, target_mask, encoder_mask)
        # The last layer is deliberately not checkpointed (due to grad_norm).
        target = self.layers[-1](target, encoder, target_mask, encoder_mask)
        target = target.transpose(0, 1)  # shape: (B, T, D)
        return target

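
if __name__ == "__main__":
    # Minimal smoke test, not part of the original module. The hyperparameter
    # names are exactly the ones this file reads from `args`; the concrete
    # values below are illustrative assumptions.
    from argparse import Namespace

    args = Namespace(
        hidden_size=64,
        hidden_size_ff=256,
        n_attention_heads=4,
        n_layers=3,
        dropout_transformer=0.1,
        dropout_transformer_attention=0.1,
        activation="relu",
        pre_norm=True,
    )
    decoder = Decoder(args)

    B, T, S = 2, 5, 7  # batch size, target length, source length
    target = torch.randn(B, T, args.hidden_size)
    encoder_output = torch.randn(B, S, args.hidden_size)
    # Key padding masks: True marks padded positions; all False here.
    target_mask = torch.zeros(B, T, dtype=torch.bool)
    encoder_mask = torch.zeros(B, S, dtype=torch.bool)

    out = decoder(target, encoder_output, target_mask, encoder_mask)
    print(out.shape)  # expected: torch.Size([2, 5, 64])
    out.sum().backward()  # gradients flow through the checkpointed layers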