Sifal commited on
Commit
d8b103e
1 Parent(s): f3abe06

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -81
model.py DELETED
@@ -1,81 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch import Tensor
4
- from torch.nn import Transformer
5
-
6
- # helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
7
- class PositionalEncoding(nn.Module):
8
- def __init__(self,
9
- emb_size: int,
10
- dropout: float,
11
- maxlen: int = 5000):
12
- super(PositionalEncoding, self).__init__()
13
- den = torch.exp(- torch.arange(0, emb_size, 2)* torch.log(10000) / emb_size)
14
- pos = torch.arange(0, maxlen).reshape(maxlen, 1)
15
- pos_embedding = torch.zeros((maxlen, emb_size))
16
- pos_embedding[:, 0::2] = torch.sin(pos * den)
17
- pos_embedding[:, 1::2] = torch.cos(pos * den)
18
- pos_embedding = pos_embedding.unsqueeze(-2)
19
-
20
- self.dropout = nn.Dropout(dropout)
21
- self.register_buffer('pos_embedding', pos_embedding)
22
-
23
- def forward(self, token_embedding: Tensor):
24
- return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
25
-
26
- # helper Module to convert tensor of input indices into corresponding tensor of token embeddings
27
- class TokenEmbedding(nn.Module):
28
- def __init__(self, vocab_size: int, emb_size):
29
- super(TokenEmbedding, self).__init__()
30
- self.embedding = nn.Embedding(vocab_size, emb_size)
31
- self.emb_size = emb_size
32
-
33
- def forward(self, tokens: Tensor):
34
- return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
35
-
36
- class Seq2SeqTransformer(nn.Module):
37
- def __init__(self,
38
- num_encoder_layers: int,
39
- num_decoder_layers: int,
40
- emb_size: int,
41
- nhead: int,
42
- src_vocab_size: int,
43
- tgt_vocab_size: int,
44
- dim_feedforward: int = 512,
45
- dropout: float = 0.1):
46
- super(Seq2SeqTransformer, self).__init__()
47
- self.transformer = Transformer(d_model=emb_size,
48
- nhead=nhead,
49
- num_encoder_layers=num_encoder_layers,
50
- num_decoder_layers=num_decoder_layers,
51
- dim_feedforward=dim_feedforward,
52
- dropout=dropout,
53
- batch_first=True)
54
- self.generator = nn.Linear(emb_size, tgt_vocab_size)
55
- self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
56
- self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
57
- self.positional_encoding = PositionalEncoding(
58
- emb_size, dropout=dropout)
59
-
60
- def forward(self,
61
- src: Tensor,
62
- trg: Tensor,
63
- src_mask: Tensor,
64
- tgt_mask: Tensor,
65
- src_padding_mask: Tensor,
66
- tgt_padding_mask: Tensor,
67
- memory_key_padding_mask: Tensor):
68
- src_emb = self.positional_encoding(self.src_tok_emb(src))
69
- tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
70
- outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
71
- src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
72
- return self.generator(outs)
73
-
74
- def encode(self, src: Tensor, src_mask: Tensor):
75
- return self.transformer.encoder(self.positional_encoding(
76
- self.src_tok_emb(src)), src_mask)
77
-
78
- def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
79
- return self.transformer.decoder(self.positional_encoding(
80
- self.tgt_tok_emb(tgt)), memory,
81
- tgt_mask)