# importing required libraries
import math
import copy
import time
import random
import spacy
import numpy as np
import os 

# torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.optim as optim

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention (MHA) block.

    For an in-depth walkthrough of the transformer and MHA, see:
    https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a

    Instead of keeping a separate linear layer per head, the per-head projections
    are combined into single linear layers and the projected tensors are reshaped
    into heads before the matrix multiplications.
    """
    def __init__(self, dk, dv, h, pdropout=0.1):
        """
        Input Args:
        
        dk(int): Key dimension used for generating the key weight matrix
        dv(int): Value dimension used for generating the value weight matrix
        h(int) : Number of heads in MHA
        pdropout(float): Dropout probability applied to the attention weights
        """
        super().__init__()
        assert dk == dv
        self.dk = dk
        self.dv = dv
        self.h = h
        self.dmodel = self.dk * self.h  # model dimension
        
        # wq, wk, wv: the per-head projection weights are combined into single
        # (dmodel, dmodel) linear layers and split into heads later in forward()
        self.WQ = nn.Linear(self.dmodel, self.dmodel) # shape -> (dmodel, dmodel)
        self.WK = nn.Linear(self.dmodel, self.dmodel) # shape -> (dmodel, dmodel)
        self.WV = nn.Linear(self.dmodel, self.dmodel) # shape -> (dmodel, dmodel)
        # Output Weights
        self.WO = nn.Linear(self.h*self.dv, self.dmodel)  # shape -> (dmodel, dmodel)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(p = pdropout)
        
    def forward(self, query, key, val, mask=None):
        """
        Forward pass for MHA.

        Each input has a size of (batch_size, seq_length, d_model);
        Wq, Wk, and Wv each have a size of (d_model, d_model).

        Performs scaled dot-product attention on all heads in parallel.

        Notation: B - batch size, S/T - max src/trg token-sequence length
        query shape = (B, S, dmodel)
        key shape   = (B, S, dmodel)
        val shape   = (B, S, dmodel)
        """
        # Project the queries, keys and values
        Q = self.WQ(query)     # shape -> (B, S, dmodel)
        K = self.WK(key)       # shape -> (B, S, dmodel)
        V = self.WV(val)       # shape -> (B, S, dmodel)
        
        # Split the last dimension into (number of heads, dk)
        batch_size = Q.size(0)   
        Q = Q.view(batch_size, -1, self.h, self.dk)   # shape -> (B, S, h, dk)
        K = K.view(batch_size, -1, self.h, self.dk)   # shape -> (B, S, h, dk)
        V = V.view(batch_size, -1, self.h, self.dk)   # shape -> (B, S, h, dk)
        
        # each sequence is split across n_heads, with each head receiving seq_length tokens 
        # with d_key elements in each token instead of d_model.
        Q = Q.permute(0, 2, 1, 3) # shape -> (B, h, S, dk)
        K = K.permute(0, 2, 1, 3) # shape -> (B, h, S, dk)
        V = V.permute(0, 2, 1, 3) # shape -> (B, h, S, dk)
        
        # dot product of Q and K
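        # Scaled dot-product attention scores, as in "Attention Is All You Need":
        #   Attention(Q, K, V) = softmax(Q @ K^T / sqrt(dk)) @ V
        # K.permute(0, 1, 3, 2) transposes the last two dims, so
        # (B, h, S, dk) @ (B, h, dk, S) -> (B, h, S, S)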
        scaled_dot_product = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(self.dk)
        
        # Fill positions of the product with -1e10 wherever the mask is 0
        if mask is not None:
            scaled_dot_product = scaled_dot_product.masked_fill(mask == 0, -1e10)
            
        attn_probs = self.softmax(scaled_dot_product)
        
        # Weight the values with the attention probabilities to form each head
        head = torch.matmul(self.dropout(attn_probs), V)  # shape -> (B, h, S, S) * (B, h, S, dk) = (B, h, S, dk)
        # Prepare the head to pass it through output linear layer
        head = head.permute(0, 2, 1, 3).contiguous()  # shape -> (B, S, h, dk)
        # Concatenate the head together
        head = head.view(batch_size, -1, self.h* self.dk)  # shape -> (B, S, (h*dk = dmodel))
        # Pass through output layer
        token_representation = self.WO(head)
        return token_representation, attn_probs
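

# A minimal usage sketch for MultiHeadAttention. The batch size, sequence length
# and head dimensions below are arbitrary values chosen for illustration only.
def _demo_multi_head_attention():
    mha = MultiHeadAttention(dk=64, dv=64, h=8)   # dmodel = 64 * 8 = 512
    x = torch.randn(2, 10, 512)                   # (B=2, S=10, dmodel=512)
    out, attn_probs = mha(x, x, x)                # self-attention: query = key = val
    assert out.shape == (2, 10, 512)              # (B, S, dmodel)
    assert attn_probs.shape == (2, 8, 10, 10)     # (B, h, S, S)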


class Embedding(nn.Module):
    """
    Embedding lookup table whose output feeds the positional
    encoding block.
    The lookup table is shared across the input and output embeddings.
    """
    def __init__(self, vocab_size, dmodel):
        """
        The embedding lookup table is a (vocab_size, dmodel) matrix
        that maps token ids to dense vectors.
        """
        super().__init__()
        self.embedding_lookup = nn.Embedding(vocab_size, dmodel)
        self.vocab_size = vocab_size
        self.dmodel = dmodel

    def forward(self, token_ids):
        """
        Look up the embedding vector for each given token id.

        As per the paper, the embedding vectors are also multiplied by sqrt(dmodel).
        """
        assert token_ids.ndim == 2, \
        f'Expected: (batch size, max token sequence length), got {token_ids.shape}'
        
        embedding_vector = self.embedding_lookup(token_ids)
        
        return embedding_vector * math.sqrt(self.dmodel)
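

# A minimal usage sketch for Embedding. The vocabulary size, batch size and
# sequence length below are arbitrary values chosen for illustration only.
def _demo_embedding():
    emb = Embedding(vocab_size=1000, dmodel=512)
    token_ids = torch.randint(0, 1000, (2, 10))   # (B=2, S=10) integer token ids
    vectors = emb(token_ids)                      # scaled by sqrt(dmodel) internally
    assert vectors.shape == (2, 10, 512)          # (B, S, dmodel)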
    

class PositionalEncoding(nn.Module):
    def __init__(self, dmodel, max_seq_length = 5000, pdropout = 0.1):
        """
        dmodel(int): model dimensions
        max_seq_length(int): Maximum input sequence length
        pdropout(float): Dropout probability
        """
        super().__init__()
        self.dropout = nn.Dropout(p = pdropout)
        
        # Calculate frequencies
        position_ids = torch.arange(0, max_seq_length).unsqueeze(1)
        # The negative sign inverts the exponent: 10000^(-2i/dmodel) = 1 / 10000^(2i/dmodel),
        # which is the frequency each position id gets multiplied by below
        frequencies = torch.pow(10000, -torch.arange(0, dmodel, 2, dtype = torch.float)/ dmodel) 
        
        # Create positional encoding table
        positional_encoding_table = torch.zeros(max_seq_length, dmodel)
        # Fill the table with even entries with sin and odd entries with cosine
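        #   PE(pos, 2i)   = sin(pos / 10000^(2i/dmodel))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i/dmodel))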
        positional_encoding_table[:, 0::2] = torch.sin(position_ids * frequencies)
        positional_encoding_table[:, 1::2] = torch.cos(position_ids * frequencies)
    
        # Register the positional encoding table as a buffer: it is stored in
        # state_dict but not returned as a named parameter, since it is not trainable
        self.register_buffer("positional_encoding_table", positional_encoding_table)

    def forward(self, embeddings_batch):
        """
        embeddings_batch shape = (batch size, seq_length, dmodel)
        positional_encoding_table shape = (max_seq_length, dmodel)
        """
        assert embeddings_batch.ndim == 3, \
            f"Embeddings batch should have 3 dimensions but got {embeddings_batch.ndim}"
        assert embeddings_batch.size(-1) == self.positional_encoding_table.size(-1), \
            f"Last dim of embeddings batch ({embeddings_batch.size(-1)}) should match " \
            f"last dim of positional_encoding_table ({self.positional_encoding_table.size(-1)})"
        
        # Get encodings for the given input sequence length
        pos_encodings = self.positional_encoding_table[:embeddings_batch.shape[1]] # Choose only seq_length out of max_seq_length
        
        # Final output 
        out = embeddings_batch + pos_encodings
        out = self.dropout(out)
        return out
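

# A minimal usage sketch for PositionalEncoding. The shapes below are arbitrary
# values chosen for illustration only.
def _demo_positional_encoding():
    pos_enc = PositionalEncoding(dmodel=512, max_seq_length=100)
    embeddings = torch.randn(2, 10, 512)          # (B=2, S=10, dmodel=512)
    out = pos_enc(embeddings)                     # adds sin/cos encodings, then dropout
    assert out.shape == (2, 10, 512)              # shape is unchanged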

 
class PositionwiseFeedForward(nn.Module):
    def __init__(self, dmodel, dff, pdropout = 0.1):
        """
        dmodel(int): model dimensions
        dff(int): hidden size of the intermediate (expansion) layer
        pdropout(float): Dropout probability
        """
        super().__init__()
        
        self.dropout = nn.Dropout(p = pdropout)
        
        self.W1 = nn.Linear(dmodel, dff)      # Intermediate layer
        self.W2 = nn.Linear(dff, dmodel)    # Output layer
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        """
        Position-wise feed-forward computation, applied identically at every position:
        out = W2(relu(dropout(W1(x))))

        x shape = (B - batch size, S/T - max token sequence length, D - model dimension)
        """
        out = self.W2(self.relu(self.dropout(self.W1(x))))
        return out
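

# A minimal usage sketch for PositionwiseFeedForward. dff = 2048 mirrors the
# common 4 * dmodel convention; all sizes below are illustrative assumptions.
def _demo_positionwise_feed_forward():
    ffn = PositionwiseFeedForward(dmodel=512, dff=2048)
    x = torch.randn(2, 10, 512)                   # (B=2, S=10, dmodel=512)
    out = ffn(x)                                  # applied independently at each position
    assert out.shape == (2, 10, 512)              # shape is unchanged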