Delete model.py
model.py
DELETED
@@ -1,81 +0,0 @@
import math

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer


# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # frequencies for the sinusoidal encoding
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # shape (1, maxlen, emb_size) so it broadcasts over batch-first inputs
        pos_embedding = pos_embedding.unsqueeze(0)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        # token_embedding: (batch, seq_len, emb_size); slice the encoding along the sequence dimension
        return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1), :])


# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # scale embeddings by sqrt(d_model), as in "Attention Is All You Need"
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


# sequence-to-sequence Transformer: token embedding + positional encoding + nn.Transformer + output projection
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(
            self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)
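For reference, below is a minimal sketch of how this deleted model would be driven for one forward pass. The hyperparameters, batch shapes, and PAD_IDX are illustrative assumptions, not values taken from this repository; the mask construction follows the usual nn.Transformer conventions (boolean padding masks, a causal float mask for the decoder).

import torch

# assumed hyperparameters and padding index, for illustration only
SRC_VOCAB, TGT_VOCAB, EMB, NHEAD, PAD_IDX = 10000, 10000, 512, 8, 1

model = Seq2SeqTransformer(num_encoder_layers=3, num_decoder_layers=3,
                           emb_size=EMB, nhead=NHEAD,
                           src_vocab_size=SRC_VOCAB, tgt_vocab_size=TGT_VOCAB)

src = torch.randint(0, SRC_VOCAB, (32, 20))   # (batch, src_len) token ids
tgt = torch.randint(0, TGT_VOCAB, (32, 19))   # (batch, tgt_len) decoder input ids

# encoder may attend everywhere; decoder gets a causal (look-ahead) mask
src_mask = torch.zeros((src.size(1), src.size(1)), dtype=torch.bool)
tgt_mask = torch.triu(torch.full((tgt.size(1), tgt.size(1)), float('-inf')), diagonal=1)

# True where a token is padding, so attention ignores those positions
src_padding_mask = src == PAD_IDX
tgt_padding_mask = tgt == PAD_IDX

logits = model(src, tgt, src_mask, tgt_mask,
               src_padding_mask, tgt_padding_mask, src_padding_mask)
print(logits.shape)  # torch.Size([32, 19, 10000])

At inference time, the encode and decode methods can instead be called separately, e.g. to run the encoder once and decode tokens step by step.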