import math
import copy
import time
import random
import spacy
import numpy as np
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.optim as optim

from model.sublayers import (
    MultiHeadAttention,
    PositionalEncoding,
    PositionwiseFeedForward,
    Embedding)


class EncoderLayer(nn.Module):
    """
    A single encoder building block, consisting of:
        1. Multi-head self-attention
        2. Sublayer logic (residual connection + layer normalization)
        3. Position-wise feed-forward network
    """
    def __init__(self, dk, dv, h, dim_multiplier=4, pdropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(dk, dv, h, pdropout)

        # Model dimension = per-head key dimension * number of heads
        dmodel = dk * h
        # Inner feed-forward dimension (4 * dmodel in the original paper)
        dff = dmodel * dim_multiplier

        self.attn_norm = nn.LayerNorm(dmodel)
        self.ff = PositionwiseFeedForward(dmodel, dff, pdropout=pdropout)
        self.ff_norm = nn.LayerNorm(dmodel)

        self.dropout = nn.Dropout(p=pdropout)

    def forward(self, src_inputs, src_mask=None):
        """
        Forward pass as described in Section 3.1 (page 3) of the Transformer paper.
        """
        # Self-attention: queries, keys and values all come from the source sequence
        mha_out, attention_wts = self.attention(
            query=src_inputs,
            key=src_inputs,
            val=src_inputs,
            mask=src_mask)

        # Sublayer 1: residual connection around attention, followed by layer norm
        intermediate_out = self.attn_norm(src_inputs + self.dropout(mha_out))

        # Sublayer 2: position-wise feed-forward, again with residual + layer norm
        pff_out = self.ff(intermediate_out)
        out = self.ff_norm(intermediate_out + self.dropout(pff_out))
        return out, attention_wts


class Encoder(nn.Module):
    """
    A stack of `num_encoders` identical EncoderLayer blocks.
    """
    def __init__(self, dk, dv, h, num_encoders, dim_multiplier=4, pdropout=0.1):
        super().__init__()
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(dk,
                         dv,
                         h,
                         dim_multiplier,
                         pdropout) for _ in range(num_encoders)
        ])

    def forward(self, src_inputs, src_mask=None):
        """
        Input comes from the embedding layer:
        src_inputs = (B - batch size, S/T - max token sequence length, D - model dimension)
        """
        src_representation = src_inputs

        # Refine the source representation layer by layer
        for enc in self.encoder_layers:
            src_representation, attn_probs = enc(src_representation, src_mask)

        # Keep the attention weights of the last layer for inspection/visualization
        self.attn_probs = attn_probs
        return src_representation
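

# A minimal smoke-test sketch, assuming the sublayer modules in model.sublayers behave
# as used above: MultiHeadAttention returns (output, attention_weights) and preserves
# the (B, S, dk * h) shape. The hyperparameter values below are illustrative only.
if __name__ == "__main__":
    dk, dv, h, num_encoders = 64, 64, 8, 6
    dmodel = dk * h  # 512; must match the last dimension of the input

    encoder = Encoder(dk, dv, h, num_encoders)

    batch_size, seq_len = 2, 10
    src = torch.rand(batch_size, seq_len, dmodel)  # stand-in for embedded source tokens
    out = encoder(src, src_mask=None)
    print(out.shape)  # expected: torch.Size([2, 10, 512])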