import math
import copy
import time
import random
import spacy
import numpy as np
import os

# torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.optim as optim

from model.sublayers import (
    MultiHeadAttention,
    PositionalEncoding,
    PositionwiseFeedForward,
    Embedding)

class EncoderLayer(nn.Module):
    """
    A single building block of the encoder, consisting of:
    1. Multi-head attention
    2. Sublayer logic (residual connection, dropout and layer norm)
    3. Position-wise feed-forward network
    """

    def __init__(self, dk, dv, h, dim_multiplier=4, pdropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(dk, dv, h, pdropout)
        # Reference: page 5, section 3.2.2 "Multi-Head Attention"
        dmodel = dk * h
        # Reference: page 5, section 3.3 "Position-wise Feed-Forward Networks"
        dff = dmodel * dim_multiplier
        self.attn_norm = nn.LayerNorm(dmodel)
        self.ff = PositionwiseFeedForward(dmodel, dff, pdropout=pdropout)
        self.ff_norm = nn.LayerNorm(dmodel)
        self.dropout = nn.Dropout(p=pdropout)

    def forward(self, src_inputs, src_mask=None):
        """
        Forward pass as per page 3, section 3.1
        """
        mha_out, attention_wts = self.attention(
            query=src_inputs,
            key=src_inputs,
            val=src_inputs,
            mask=src_mask)
        # Residual connection between the input and the sublayer output;
        # dropout on the sublayer output follows page 7, section 5.4 "Regularization"
        intermediate_out = self.attn_norm(src_inputs + self.dropout(mha_out))
        pff_out = self.ff(intermediate_out)
        # Second Add & Norm, around the feed-forward sublayer
        out = self.ff_norm(intermediate_out + self.dropout(pff_out))
        return out, attention_wts
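
# Minimal shape-check sketch (illustrative, not part of the original module).
# It assumes MultiHeadAttention accepts the keyword signature used above
# (query/key/val/mask) and returns (output, attention weights) with the output
# shape unchanged. dk = dv = 64 and h = 8 match the paper's base model, so
# dmodel = dk * h = 512. The helper name `_smoke_test_encoder_layer` is
# hypothetical.
def _smoke_test_encoder_layer():
    layer = EncoderLayer(dk=64, dv=64, h=8)
    x = torch.rand(2, 10, 512)       # (B, S, dmodel)
    out, attn_wts = layer(x)         # residual + LayerNorm keep the shape
    assert out.shape == x.shape      # expected: (2, 10, 512)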

class Encoder(nn.Module):
    def __init__(self, dk, dv, h, num_encoders, dim_multiplier=4, pdropout=0.1):
        super().__init__()
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(dk,
                         dv,
                         h,
                         dim_multiplier,
                         pdropout) for _ in range(num_encoders)
        ])

    def forward(self, src_inputs, src_mask=None):
        """
        Input from the embedding layer:
        src_inputs = (B - batch size, S/T - max token sequence length, D - model dimension)
        """
        src_representation = src_inputs
        # Forward pass through the encoder stack
        for enc in self.encoder_layers:
            src_representation, attn_probs = enc(src_representation, src_mask)
        # Keep the attention weights of the last encoder layer for inspection
        self.attn_probs = attn_probs
        return src_representation
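
# Minimal usage sketch (illustrative): a 6-layer stack as in the paper's base
# configuration. It assumes src_inputs already comes through the Embedding /
# PositionalEncoding sublayers imported above, i.e. activations of shape
# (B, S, dmodel); the padding mask is optional and omitted here. The helper
# name `_smoke_test_encoder` is hypothetical.
def _smoke_test_encoder():
    encoder = Encoder(dk=64, dv=64, h=8, num_encoders=6)
    x = torch.rand(2, 10, 512)         # (B, S, dmodel) embedded inputs
    out = encoder(x, src_mask=None)    # each layer preserves the shape
    assert out.shape == x.shape        # expected: (2, 10, 512)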