Sifal committed on
Commit 9065f39
1 Parent(s): de48642

Update app.py

Files changed (1)
  1. app.py +249 -1
app.py CHANGED
@@ -1,5 +1,253 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.nn import Transformer
+
+ # helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
+ class PositionalEncoding(nn.Module):
+     def __init__(self,
+                  emb_size: int,
+                  dropout: float,
+                  maxlen: int = 5000):
+         super(PositionalEncoding, self).__init__()
+         den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
+         pos = torch.arange(0, maxlen).reshape(maxlen, 1)
+         pos_embedding = torch.zeros((maxlen, emb_size))
+         pos_embedding[:, 0::2] = torch.sin(pos * den)
+         pos_embedding[:, 1::2] = torch.cos(pos * den)
+         pos_embedding = pos_embedding.unsqueeze(-2)
+
+         self.dropout = nn.Dropout(dropout)
+         self.register_buffer('pos_embedding', pos_embedding)
+
+     def forward(self, token_embedding: Tensor):
+         return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
+
+ # helper Module to convert tensor of input indices into corresponding tensor of token embeddings
+ class TokenEmbedding(nn.Module):
+     def __init__(self, vocab_size: int, emb_size):
+         super(TokenEmbedding, self).__init__()
+         self.embedding = nn.Embedding(vocab_size, emb_size)
+         self.emb_size = emb_size
+
+     def forward(self, tokens: Tensor):
+         return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
+
+ # sequence-to-sequence Transformer: token embeddings plus positional encoding feeding nn.Transformer,
+ # with a linear generator projecting decoder states to target-vocabulary logits
+ class Seq2SeqTransformer(nn.Module):
+     def __init__(self,
+                  num_encoder_layers: int,
+                  num_decoder_layers: int,
+                  emb_size: int,
+                  nhead: int,
+                  src_vocab_size: int,
+                  tgt_vocab_size: int,
+                  dim_feedforward: int = 512,
+                  dropout: float = 0.1):
+         super(Seq2SeqTransformer, self).__init__()
+         self.transformer = Transformer(d_model=emb_size,
+                                        nhead=nhead,
+                                        num_encoder_layers=num_encoder_layers,
+                                        num_decoder_layers=num_decoder_layers,
+                                        dim_feedforward=dim_feedforward,
+                                        dropout=dropout,
+                                        batch_first=True)
+         self.generator = nn.Linear(emb_size, tgt_vocab_size)
+         self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
+         self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
+         self.positional_encoding = PositionalEncoding(
+             emb_size, dropout=dropout)
+
+     def forward(self,
+                 src: Tensor,
+                 trg: Tensor,
+                 src_mask: Tensor,
+                 tgt_mask: Tensor,
+                 src_padding_mask: Tensor,
+                 tgt_padding_mask: Tensor,
+                 memory_key_padding_mask: Tensor):
+         src_emb = self.positional_encoding(self.src_tok_emb(src))
+         tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
+         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
+                                 src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
+         return self.generator(outs)
+
+     def encode(self, src: Tensor, src_mask: Tensor):
+         return self.transformer.encoder(self.positional_encoding(
+             self.src_tok_emb(src)), src_mask)
+
+     def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
+         return self.transformer.decoder(self.positional_encoding(
+             self.tgt_tok_emb(tgt)), memory,
+             tgt_mask)
+
+ import yaml
+ from transformers import AutoTokenizer
+ from transformers import PreTrainedTokenizerFast
+ from tokenizers.processors import TemplateProcessing
+
+
+ def addPreprocessing(tokenizer):
+     # wrap every encoded sequence in the BOS and EOS special tokens
+     tokenizer._tokenizer.post_processor = TemplateProcessing(
+         single=tokenizer.bos_token + " $A " + tokenizer.eos_token,
+         special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id), (tokenizer.bos_token, tokenizer.bos_token_id)])
+
+ def load_model(model_checkpoint_dir='model.pt', config_dir='config.yaml'):
+
+     with open(config_dir, 'r') as yaml_file:
+         loaded_model_params = yaml.safe_load(yaml_file)
+
+     # Create a new instance of the model with the loaded configuration
+     model = Seq2SeqTransformer(
+         loaded_model_params["num_encoder_layers"],
+         loaded_model_params["num_decoder_layers"],
+         loaded_model_params["emb_size"],
+         loaded_model_params["nhead"],
+         loaded_model_params["source_vocab_size"],
+         loaded_model_params["target_vocab_size"],
+         loaded_model_params["ffn_hid_dim"]
+     )
+
+     # Load the checkpoint, mapping to CPU when no GPU is available
+     checkpoint = torch.load(model_checkpoint_dir) if torch.cuda.is_available() else torch.load(model_checkpoint_dir, map_location=torch.device('cpu'))
+     model.load_state_dict(checkpoint)
+
+     return model
+
+
+ def greedy_decode(model, src, src_mask, max_len, start_symbol):
+     # Move inputs to the device
+     src = src.to(device)
+     src_mask = src_mask.to(device)
+
+     # Encode the source sequence
+     memory = model.encode(src, src_mask)
+
+     # Initialize the target sequence with the start symbol
+     ys = torch.tensor([[start_symbol]]).type(torch.long).to(device)
+
+     for i in range(max_len - 1):
+         memory = memory.to(device)
+         # Create a causal target mask for autoregressive decoding
+         tgt_mask = torch.tril(torch.full((ys.size(1), ys.size(1)), float('-inf'), device=device), diagonal=-1).transpose(0, 1).to(device)
+         # Decode the target sequence
+         out = model.decode(ys, memory, tgt_mask)
+         # Project the last decoder state to logits over the vocabulary
+         prob = model.generator(out[:, -1])
+
+         # Select the next word with the highest score
+         _, next_word = torch.max(prob, dim=1)
+         next_word = next_word.item()
+
+         # Append the next word to the target sequence
+         ys = torch.cat([ys,
+                         torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
+
+         # Stop once the end-of-sequence token is generated
+         if next_word == target_tokenizer.eos_token_id:
+             break
+
+     return ys
+
+
+ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size, length_penalty):
+     # Move inputs to the device
+     src = src.to(device)
+     src_mask = src_mask.to(device)
+
+     # Encode the source sequence
+     memory = model.encode(src, src_mask)  # b * seqlen_src * hdim
+
+     # Initialize the beams (sequence, score) and the pool of finished hypotheses
+     beams = [(torch.tensor([[start_symbol]]).type(torch.long).to(device), 0)]
+     complete_beams = []
+
+     for i in range(max_len - 1):
+         new_beams = []
+
+         for ys, score in beams:
+
+             # Create a causal target mask for autoregressive decoding
+             tgt_mask = torch.tril(torch.full((ys.size(1), ys.size(1)), float('-inf'), device=device), diagonal=-1).transpose(0, 1).to(device)
+             # Decode the target sequence
+             out = model.decode(ys, memory, tgt_mask)  # b * seqlen_tgt * hdim
+             # Project the last decoder state to logits over the vocabulary
+             prob = model.generator(out[:, -1])  # b * tgt_vocab_size
+
+             # Get the top beam_size candidates for the next word
+             _, top_indices = torch.topk(prob, beam_size, dim=1)  # b * beam_size
+
+             for j, next_word in enumerate(top_indices[0]):
+
+                 next_word = next_word.item()
+
+                 # Append the next word to the target sequence
+                 new_ys = torch.cat([ys, torch.full((1, 1), fill_value=next_word, dtype=src.dtype).to(device)], dim=1)
+
+                 # Length-normalise the cumulative score so longer hypotheses are not unduly penalised
+                 length_factor = ((5 + new_ys.size(1)) / 6) ** length_penalty
+                 new_score = (score + prob[0][next_word].item()) / length_factor
+
+                 if next_word == target_tokenizer.eos_token_id:
+                     complete_beams.append((new_ys, new_score))
+                 else:
+                     new_beams.append((new_ys, new_score))
+
+         # Sort the candidate beams by score and keep the top beam_size beams
+         new_beams.sort(key=lambda x: x[1], reverse=True)
+         beams = new_beams[:beam_size]
+
+     # Pick the highest-scoring hypothesis among finished and unfinished beams
+     beams = beams + complete_beams
+     beams.sort(key=lambda x: x[1], reverse=True)
+
+     best_beam = beams[0][0]
+     return best_beam
+
+ def translate(model: torch.nn.Module, src_sentence: str, strategy: str = 'greedy', length_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
+     assert strategy in ['greedy', 'beam search'], 'the strategy for decoding has to be either greedy or beam search'
+     # Tokenize the source sentence
+     src = source_tokenizer(src_sentence, **token_config)['input_ids']
+     num_tokens = src.shape[1]
+     # Create a source mask
+     src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
+     if strategy == 'greedy':
+         # Generate the target tokens using greedy decoding
+         tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + length_extend, start_symbol=target_tokenizer.bos_token_id).flatten()
+     else:
+         # Generate the target tokens using beam search decoding
+         tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + length_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size, length_penalty=length_penalty).flatten()
+     # Decode the target tokens and clean up the result
+     return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+
+ special_tokens = {'unk_token': "[UNK]",
+                   'cls_token': "[CLS]",
+                   'eos_token': "[EOS]",
+                   'sep_token': "[SEP]",
+                   'pad_token': "[PAD]",
+                   'mask_token': "[MASK]",
+                   'bos_token': "[BOS]"}
+
+ source_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", **special_tokens)
+ target_tokenizer = PreTrainedTokenizerFast.from_pretrained('Sifal/E2KT')
+
+ addPreprocessing(source_tokenizer)
+ addPreprocessing(target_tokenizer)
+
+ token_config = {
+     "add_special_tokens": True,
+     "return_tensors": "pt",
+ }
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model = load_model()
+ model.to(device)
+ model.eval()
+
  import gradio as gr
 
- from .utils import translate
 
  x = lambda text : translate(x)
 
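
A minimal sketch of how the translate function added above could be exposed through a Gradio interface, assuming the module-level model, source_tokenizer and target_tokenizer defined in app.py; the translate_text wrapper, the textbox labels, and the launch guard are illustrative assumptions and not part of this commit:

# hypothetical wiring for the Gradio app; greedy decoding by default,
# pass strategy='beam search' to use the beam search decoder instead
def translate_text(text: str) -> str:
    return translate(model, text)

demo = gr.Interface(fn=translate_text,
                    inputs=gr.Textbox(label="Source sentence"),
                    outputs=gr.Textbox(label="Translation"))

if __name__ == "__main__":
    demo.launch()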