nnmthuw committed · Commit 60b414a · Parent: 9c2ca71

commit all

app.py CHANGED
@@ -314,25 +314,27 @@ HID_DIM = 512
 
 # Load our Model Translation
 ENCODER = EncoderAtt(INPUT_DIM, HID_DIM)
-#ENCODER.load_state_dict(torch.load("hid512_encoder_att_epoch_20.pt"), map_location=torch.device('cpu'))
+ENCODER.load_state_dict(torch.load("encoderatt_epoch_35.pt", map_location=torch.device('cpu')))
 DECODER = DecoderAtt(HID_DIM, OUTPUT_DIM)
-#DECODER.load_state_dict(torch.load("hid512_decoder_att_epoch_20.pt"), map_location=torch.device('cpu'))
+DECODER.load_state_dict(torch.load("decoderatt_epoch_35.pt", map_location=torch.device('cpu')))
 
 
-def evaluate_final_model(encoder, decoder, sentence, vocab_source, vocab_target, disable=False):
+def evaluate_final_model(sentence, encoder, decoder, vocab_source, vocab_target, disable=False):
+    """Evaluate the translation model on a single sentence.
+    @param sentence (str)
+    @param encoder (EncoderAtt)
+    @param decoder (DecoderAtt)
+    @param vocab_source (Vocabulary)
+    @param vocab_target (Vocabulary)
+    @param disable (bool): disable the tqdm progress bar
+    """
     encoder.eval()
     decoder.eval()
     with torch.no_grad():
-        input_tensor = (
-            vocab_source.corpus_to_tensor([sentence], disable=disable)[0]
-            .view(1, -1)
-            .to(device)
-        )
+        input_tensor = vocab_source.corpus_to_tensor([sentence], disable=disable)[0].view(1, -1).to(device)
 
         encoder_outputs, encoder_hidden = encoder(input_tensor)
-        decoder_outputs, decoder_hidden, decoder_attn = decoder(
-            encoder_outputs, encoder_hidden
-        )
+        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)
 
         _, topi = decoder_outputs.topk(1)
         decoded_ids = topi.squeeze()
@@ -340,20 +342,23 @@ def evaluate_final_model(encoder, decoder, sentence, vocab_source, vocab_target,
         decoded_words = []
         for idx in decoded_ids:
             if idx.item() == vocab_target.eos_id:
-                decoded_words.append("<eos>")
+                decoded_words.append('<eos>')
                 break
             decoded_words.append(vocab_target.id2word[idx.item()])
         return decoded_words, decoder_attn
 
-
-def my_translation(sentence):
+def translate_sentence(sentence):
     output_words, _ = evaluate_final_model(sentence, ENCODER, DECODER, VOCAB_SOURCE, VOCAB_TARGET, disable=True)
-    output_words = output_words.remove("<pad>")
-    output_words = output_words.remove("<unk>")
-    output_words = output_words.remove("<sos>")
-    output_words = output_words.remove("<eos>")
+    if "<pad>" in output_words:
+        output_words.remove("<pad>")
+    if "<unk>" in output_words:
+        output_words.remove("<unk>")
+    if "<sos>" in output_words:
+        output_words.remove("<sos>")
+    if "<eos>" in output_words:
+        output_words.remove("<eos>")
 
-    return ' '.join(output_words[1:-1]).capitalize()
+    return ' '.join(output_words).capitalize()
 
 
 def envit5_translation(text):
@@ -366,10 +371,10 @@ def envit5_translation(text):
 
 
 def translation(text):
+    output1 = translate_sentence(text)
+
     if not text.endswith(('.', '!', '?')):
         text = text + '.'
-    #output1 = my_translation(text)
-    output1 = "Something"
     output2 = envit5_translation(text)
 
     return (output1, output2)
@@ -401,4 +406,4 @@ if __name__ == "__main__":
         ]
     )
 
-    demo.launch()
+    demo.launch(share=True)
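Two fixes in this diff are worth calling out. First, the commit reorders `evaluate_final_model`'s parameters to match the existing call site (`sentence` first), which previously passed `sentence` where `encoder` was expected. Second, the previously commented-out loading code passed `map_location` to `load_state_dict()`, where it is not a valid argument; the new lines pass it to `torch.load()`, which is what actually remaps CUDA-trained tensors onto the CPU. A minimal sketch of the difference (`model` and `checkpoint.pt` are illustrative placeholders):

```python
import torch

# Broken (the old, commented-out pattern): load_state_dict() has no
# map_location parameter, so this raises a TypeError.
# model.load_state_dict(torch.load("checkpoint.pt"), map_location=torch.device("cpu"))

# Working (what this commit does): map_location belongs to torch.load(),
# so GPU-trained weights deserialize onto the CPU.
state_dict = torch.load("checkpoint.pt", map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
```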
hid512_decoder_att_epoch_20.pt → decoderatt_epoch_35.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2b13f49e00d60a51226db3a66e343ef3b73eccf06e0efe771cac417e1994a706
-size 40323250
+oid sha256:dee556d5e874646355b36d8d2fa94daf70d1da8684f0d95e4eb5edee4fe1b881
+size 43042290
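(The `.pt` entries here and in the encoder rename below are Git LFS pointer files, not the weights themselves: only the `oid sha256:` content hash and the byte `size` of the tracked checkpoint change. The decoder checkpoint grows from about 40.3 MB to about 43.0 MB.)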
hid512_encoder_att_epoch_20.pt → encoderatt_epoch_35.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ec38b650930515f30086a04a16285c88430ceed352cfbd52cc27e34b4283221a
-size 16096464
+oid sha256:0879cb7c5a5359360b70195e61c9e91cd9f8caa0c6296f5924a7b9530f1350b0
+size 16437536
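With the renamed checkpoints in place, a quick smoke test of the new code path could look like the following (a sketch only: it assumes `app.py` imports cleanly with the `.pt` checkpoints and whatever vocabulary files it expects present, and the example sentences are illustrative):

```python
# Illustrative smoke test; demo.launch() stays behind the
# if __name__ == "__main__" guard, so importing app does not start the UI.
from app import translate_sentence, translation

print(translate_sentence("I hope you will be better"))

# translation() returns a pair: (our seq2seq model, VietAI envit5).
ours, envit5 = translation("Hello guys")
print(ours)
print(envit5)
```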
temp.ipynb DELETED
@@ -1,569 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "WARNING:tensorflow:From c:\\Users\\THU\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\keras\\src\\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.\n",
-      "\n"
-     ]
-    }
-   ],
-   "source": [
-    "import gradio as gr\n",
-    "from transformers import pipeline \n",
-    "import re\n",
-    "import pickle \n",
-    "import torch\n",
-    "import torch.nn as nn\n",
-    "from torchtext.transforms import PadTransform\n",
-    "from torch.utils.data import Dataset, DataLoader\n",
-    "from torch.nn import functional as F\n",
-    "from tqdm import tqdm\n",
-    "from underthesea import word_tokenize, text_normalize"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 7,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Running on local URL: http://127.0.0.1:7864\n",
-      "\n",
-      "To create a public link, set `share=True` in `launch()`.\n"
-     ]
-    },
-    {
-     "data": {
-      "text/html": [
-       "<div><iframe src=\"http://127.0.0.1:7864/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
-      ],
-      "text/plain": [
-       "<IPython.core.display.HTML object>"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/plain": []
-     },
-     "execution_count": 7,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "import gradio as gr\n",
-    "\n",
-    "def translation(text):\n",
-    "    output1 = 1\n",
-    "    output2 = 2\n",
-    "    #output3 = finetune_BERT(text)\n",
-    "\n",
-    "    return (output1, output2)\n",
-    "\n",
-    "\n",
-    "\n",
-    "examples = [[\"Input: Hello guys\"], \n",
-    "            [\"Output: Xin chào các bạn\"]]\n",
-    "\n",
-    "demo = gr.Interface(\n",
-    "    theme = gr.themes.Base(),\n",
-    "    fn=translation,\n",
-    "    title=\"Co Gai Mo Duong\",\n",
-    "    description=\"\"\"\n",
-    "    ## Machine Translation: English to Vietnamese\n",
-    "    \"\"\",\n",
-    "    examples=examples,\n",
-    "    inputs=[\n",
-    "        gr.Textbox(\n",
-    "            lines=5, placeholder=\"Enter text\", label=\"Input\"\n",
-    "        )\n",
-    "    ],\n",
-    "    outputs=[\n",
-    "        gr.Textbox(\n",
-    "            \"text\", label=\"Our Machine Translation\"\n",
-    "        ),\n",
-    "        gr.Textbox(\n",
-    "            \"text\", label=\"VietAI Machine Translation\"\n",
-    "        )\n",
-    "    ]\n",
-    ")\n",
-    "\n",
-    "demo.launch(shared = True)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Build Vocabulary\n",
-    "MAX_LENGTH = 30\n",
-    "#device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
-    "device = 'cpu'"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "class Vocabulary:\n",
-    "    \"\"\"The Vocabulary class is used to record words, which are used to convert\n",
-    "    text to numbers and vice versa.\n",
-    "    \"\"\"\n",
-    "\n",
-    "    def __init__(self, lang=\"vi\"):\n",
-    "        self.lang = lang\n",
-    "        self.word2id = dict()\n",
-    "        self.word2id[\"<sos>\"] = 0  # Start of Sentence Token\n",
-    "        self.word2id[\"<eos>\"] = 1  # End of Sentence Token\n",
-    "        self.word2id[\"<unk>\"] = 2  # Unknown Token\n",
-    "        self.word2id[\"<pad>\"] = 3  # Pad Token\n",
-    "        self.sos_id = self.word2id[\"<sos>\"]\n",
-    "        self.eos_id = self.word2id[\"<eos>\"]\n",
-    "        self.unk_id = self.word2id[\"<unk>\"]\n",
-    "        self.pad_id = self.word2id[\"<pad>\"]\n",
-    "        self.id2word = {v: k for k, v in self.word2id.items()}\n",
-    "        self.pad_transform = PadTransform(max_length = MAX_LENGTH, pad_value = self.pad_id)\n",
-    "\n",
-    "    def __getitem__(self, word):\n",
-    "        \"\"\"Return ID of word if existed else return ID unknown token\n",
-    "        @param word (str)\n",
-    "        \"\"\"\n",
-    "        return self.word2id.get(word, self.unk_id)\n",
-    "\n",
-    "    def __contains__(self, word):\n",
-    "        \"\"\"Return True if word in Vocabulary else return False\n",
-    "        @param word (str)\n",
-    "        \"\"\"\n",
-    "        return word in self.word2id\n",
-    "\n",
-    "    def __len__(self):\n",
-    "        \"\"\"\n",
-    "        Return number of tokens (include sos, eos, unk and pad tokens) in Vocabulary\n",
-    "        \"\"\"\n",
-    "        return len(self.word2id)\n",
-    "\n",
-    "    def lookup_tokens(self, word_indexes: list):\n",
-    "        \"\"\"Return the list of words by lookup by ID\n",
-    "        @param word_indexes (list(int))\n",
-    "        @return words (list(str))\n",
-    "        \"\"\"\n",
-    "        return [self.id2word[word_index] for word_index in word_indexes]\n",
-    "\n",
-    "    def add(self, word):\n",
-    "        \"\"\"Add word to vocabulary\n",
-    "        @param word (str)\n",
-    "        @return index (str): index of the word just added\n",
-    "        \"\"\"\n",
-    "        if word not in self:\n",
-    "            word_index = self.word2id[word] = len(self.word2id)\n",
-    "            self.id2word[word_index] = word\n",
-    "            return word_index\n",
-    "        else:\n",
-    "            return self[word]\n",
-    "\n",
-    "    def preprocessing_sent(self, sent, lang=\"en\"):\n",
-    "        \"\"\"Preprocess a sentence (depending on the language, English or Vietnamese)\"\"\"\n",
-    "\n",
-    "        if (lang == \"en\") or (lang == \"eng\") or (lang == \"english\"):\n",
-    "            # Remove unnecessary space\n",
-    "            sent = re.sub(\" +\", \" \", sent)\n",
-    "\n",
-    "            # Replace short form\n",
-    "            sent = re.sub(\"&apos;m \", \"am \", sent)\n",
-    "            # Don't know how to preprocess the possessive case\n",
-    "            sent = re.sub(\"&apos;s \", \"is \", sent)\n",
-    "            sent = re.sub(\"&apos;re \", \"are \", sent)\n",
-    "            sent = re.sub(\"&apos;ve \", \"have \", sent)\n",
-    "            sent = re.sub(\"&apos;ll \", \"will \", sent)\n",
-    "            sent = re.sub(\"&apos;d \", \"would \", sent)\n",
-    "\n",
-    "            sent = re.sub(\"aren &apos;t\", \"are not\", sent)\n",
-    "            sent = re.sub(\"isn &apos;t\", \"is not\", sent)\n",
-    "            sent = re.sub(\"don &apos;t\", \"do not\", sent)\n",
-    "            sent = re.sub(\"doesn &apos;t\", \"does not\", sent)\n",
-    "            sent = re.sub(\"wasn &apos;t\", \"was not\", sent)\n",
-    "            sent = re.sub(\"weren &apos;t\", \"were not\", sent)\n",
-    "            sent = re.sub(\"won &apos;t\", \"will not\", sent)\n",
-    "            sent = re.sub(\"can &apos;t\", \"can not\", sent)\n",
-    "            sent = re.sub(\"let &apos;s\", \"let us\", sent)\n",
-    "\n",
-    "        else:\n",
-    "            # Package underthesea.text_normalize supports normalizing Vietnamese\n",
-    "            sent = text_normalize(sent)\n",
-    "\n",
-    "        sent = re.sub(\"&apos;\", \"'\", sent)\n",
-    "        sent = re.sub(\"&quot;\", '\"', sent)\n",
-    "        sent = re.sub(\"&#91;\", \"[\", sent)\n",
-    "        sent = re.sub(\"&#93;\", \"]\", sent)\n",
-    "\n",
-    "        # Lowercase sentence and remove space at beginning and ending\n",
-    "        return sent.lower().strip()\n",
-    "\n",
-    "    def tokenize_corpus(self, corpus, disable=False):\n",
-    "        \"\"\"Split the documents of the corpus into words\n",
-    "        @param corpus (list(str)): list of documents\n",
-    "        @return tokenized_corpus (list(list(str))): list of words\n",
-    "        \"\"\"\n",
-    "        if not disable:\n",
-    "            print(\"Tokenize the corpus...\")\n",
-    "        tokenized_corpus = list()\n",
-    "        for document in tqdm(corpus, disable=disable):\n",
-    "            tokenized_document = [\"<sos>\"] + self.preprocessing_sent(document).split(\" \") + [\"<eos>\"]\n",
-    "            tokenized_corpus.append(tokenized_document)\n",
-    "        return tokenized_corpus\n",
-    "\n",
-    "    def corpus_to_tensor(self, corpus, is_tokenized=False, disable=False):\n",
-    "        \"\"\"Convert corpus to a list of indices tensor\n",
-    "        @param corpus (list(str) if is_tokenized==False else list(list(str)))\n",
-    "        @param is_tokenized (bool)\n",
-    "        @return indicies_corpus (list(tensor))\n",
-    "        \"\"\"\n",
-    "        if is_tokenized:\n",
-    "            tokenized_corpus = corpus\n",
-    "        else:\n",
-    "            tokenized_corpus = self.tokenize_corpus(corpus, disable=disable)\n",
-    "        indicies_corpus = list()\n",
-    "        for document in tqdm(tokenized_corpus, disable=disable):\n",
-    "            indicies_document = torch.tensor(\n",
-    "                list(map(lambda word: self[word], document)), dtype=torch.int64\n",
-    "            )\n",
-    "\n",
-    "            indicies_corpus.append(self.pad_transform(indicies_document))\n",
-    "\n",
-    "        return indicies_corpus\n",
-    "\n",
-    "    def tensor_to_corpus(self, tensor, disable=False):\n",
-    "        \"\"\"Convert list of indices tensor to a list of tokenized documents\n",
-    "        @param indicies_corpus (list(tensor))\n",
-    "        @return corpus (list(list(str)))\n",
-    "        \"\"\"\n",
-    "        corpus = list()\n",
-    "        for indicies in tqdm(tensor, disable=disable):\n",
-    "            document = list(map(lambda index: self.id2word[index.item()], indicies))\n",
-    "            corpus.append(document)\n",
-    "\n",
-    "        return corpus"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def create_input_emb_layer():\n",
-    "    num_embeddings, embedding_dim = 32998, 100\n",
-    "    emb_layer = nn.Embedding(num_embeddings, embedding_dim)\n",
-    "    emb_layer.weight.requires_grad = False\n",
-    "\n",
-    "    return emb_layer, embedding_dim\n",
-    "\n",
-    "def create_output_emb_layer():\n",
-    "    num_embeddings, embedding_dim = 15405, 100\n",
-    "    emb_layer = nn.Embedding(num_embeddings, embedding_dim)\n",
-    "    emb_layer.weight.requires_grad = False\n",
-    "\n",
-    "    return emb_layer, embedding_dim\n",
-    "\n",
-    "class EncoderRNN(nn.Module):\n",
-    "    def __init__(self, input_dim, hidden_dim, dropout = 0.2):\n",
-    "        super(EncoderRNN, self).__init__()\n",
-    "\n",
-    "        self.hidden_dim = hidden_dim\n",
-    "        #self.embedding = nn.Embedding(input_dim, hidden_dim)\n",
-    "        # Changed to the input embedding layer\n",
-    "        self.embedding, self.embedding_dim = create_input_emb_layer()\n",
-    "        self.gru = nn.GRU(self.embedding_dim, hidden_dim, batch_first=True)\n",
-    "        self.dropout = nn.Dropout(dropout)\n",
-    "\n",
-    "    def forward(self, src):\n",
-    "        embedded = self.dropout(self.embedding(src))\n",
-    "        output, hidden = self.gru(embedded)\n",
-    "        return output, hidden\n",
-    "\n",
-    "class BahdanauAttention(nn.Module):\n",
-    "    def __init__(self, hidden_size):\n",
-    "        super(BahdanauAttention, self).__init__()\n",
-    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
-    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
-    "        self.Va = nn.Linear(hidden_size, 1)\n",
-    "\n",
-    "    def forward(self, query, keys):\n",
-    "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
-    "        scores = scores.squeeze(2).unsqueeze(1)\n",
-    "\n",
-    "        weights = F.softmax(scores, dim=-1)\n",
-    "        context = torch.bmm(weights, keys)\n",
-    "\n",
-    "        return context, weights\n",
-    "\n",
-    "class AttnDecoderRNN(nn.Module):\n",
-    "    def __init__(self, hidden_size, output_size, dropout_p=0.1):\n",
-    "        super(AttnDecoderRNN, self).__init__()\n",
-    "        # self.embedding = nn.Embedding(output_size, hidden_size)\n",
-    "        # Changed to the output embedding layer\n",
-    "        self.embedding, self.embedding_dim = create_output_emb_layer()\n",
-    "        self.fc = nn.Linear(self.embedding_dim, hidden_size)\n",
-    "        self.attention = BahdanauAttention(hidden_size)\n",
-    "        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
-    "        self.out = nn.Linear(hidden_size, output_size)\n",
-    "        self.dropout = nn.Dropout(dropout_p)\n",
-    "\n",
-    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
-    "        batch_size = encoder_outputs.size(0)\n",
-    "        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(0)\n",
-    "        decoder_hidden = encoder_hidden\n",
-    "        decoder_outputs = []\n",
-    "        attentions = []\n",
-    "\n",
-    "        for i in range(MAX_LENGTH):\n",
-    "            decoder_output, decoder_hidden, attn_weights = self.forward_step(\n",
-    "                decoder_input, decoder_hidden, encoder_outputs\n",
-    "            )\n",
-    "            decoder_outputs.append(decoder_output)\n",
-    "            attentions.append(attn_weights)\n",
-    "\n",
-    "            if target_tensor is not None:\n",
-    "                # Teacher forcing: Feed the target as the next input\n",
-    "                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing\n",
-    "            else:\n",
-    "                # Without teacher forcing: use its own predictions as the next input\n",
-    "                _, topi = decoder_output.topk(1)\n",
-    "                decoder_input = topi.squeeze(-1).detach()  # detach from history as input\n",
-    "\n",
-    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
-    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
-    "        attentions = torch.cat(attentions, dim=1)\n",
-    "\n",
-    "        return decoder_outputs, decoder_hidden, attentions\n",
-    "\n",
-    "\n",
-    "    def forward_step(self, input, hidden, encoder_outputs):\n",
-    "        embedded = self.dropout(self.fc(self.embedding(input)))\n",
-    "\n",
-    "        query = hidden.permute(1, 0, 2)\n",
-    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
-    "        input_gru = torch.cat((embedded, context), dim=2)\n",
-    "\n",
-    "        output, hidden = self.gru(input_gru, hidden)\n",
-    "        output = self.out(output)\n",
-    "\n",
-    "        return output, hidden, attn_weights"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 41,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<All keys matched successfully>"
-      ]
-     },
-     "execution_count": 41,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "with open(\"vocab_source.pkl\", \"rb\") as file:\n",
-    "    VOCAB_SOURCE = pickle.load(file)\n",
-    "with open(\"vocab_target.pkl\", \"rb\") as file:\n",
-    "    VOCAB_TARGET = pickle.load(file)\n",
-    "\n",
-    "INPUT_DIM = len(VOCAB_SOURCE)\n",
-    "OUTPUT_DIM = len(VOCAB_TARGET)\n",
-    "HID_DIM = 512\n",
-    "\n",
-    "# Load our Model Translation\n",
-    "ENCODER = EncoderRNN(INPUT_DIM, HID_DIM)\n",
-    "ENCODER.load_state_dict(torch.load('encoder_att_epoch_16.pt'))\n",
-    "DECODER = AttnDecoderRNN(HID_DIM, OUTPUT_DIM)\n",
-    "DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 42,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def evaluate(encoder, decoder, sentence, vocab_source, vocab_target, disable = False):\n",
-    "    encoder.eval()\n",
-    "    decoder.eval()\n",
-    "    with torch.no_grad():\n",
-    "        input_tensor = vocab_source.corpus_to_tensor([sentence], disable = disable)[0].view(1,-1).to(device)\n",
-    "\n",
-    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
-    "        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)\n",
-    "\n",
-    "        _, topi = decoder_outputs.topk(1)\n",
-    "        decoded_ids = topi.squeeze()\n",
-    "\n",
-    "        decoded_words = []\n",
-    "        for idx in decoded_ids:\n",
-    "            if idx.item() == vocab_target.eos_id:\n",
-    "                decoded_words.append('<eos>')\n",
-    "                break\n",
-    "            decoded_words.append(vocab_target.id2word[idx.item()])\n",
-    "        return decoded_words, decoder_attn\n",
-    "\n",
-    "def my_translate_model(sentence):\n",
-    "    output_words, _ = evaluate(ENCODER, DECODER, sentence, VOCAB_SOURCE, VOCAB_TARGET, disable= True)\n",
-    "\n",
-    "    return ' '.join(output_words[1:-1]).capitalize() + '.'"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 61,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "'Tôi hy vọng các bạn sẽ có thể làm được giải pháp.'"
-      ]
-     },
-     "execution_count": 61,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "my_translate_model(\"I hope you will be better\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 60,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<All keys matched successfully>"
-      ]
-     },
-     "execution_count": 60,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "ENCODER = EncoderRNN(INPUT_DIM, HID_DIM)\n",
-    "ENCODER.load_state_dict(torch.load('encoder_att_epoch_16.pt'))\n",
-    "DECODER = AttnDecoderRNN(HID_DIM, OUTPUT_DIM)\n",
-    "DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 48,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "odict_keys(['embedding.weight', 'fc.weight', 'fc.bias', 'attention.Wa.weight', 'attention.Wa.bias', 'attention.Ua.weight', 'attention.Ua.bias', 'attention.Va.weight', 'attention.Va.bias', 'gru.weight_ih_l0', 'gru.weight_hh_l0', 'gru.bias_ih_l0', 'gru.bias_hh_l0', 'out.weight', 'out.bias'])"
-      ]
-     },
-     "execution_count": 48,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "torch.load('decoder_att_epoch_16.pt').keys()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 52,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<All keys matched successfully>"
-      ]
-     },
-     "execution_count": 52,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 57,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "<All keys matched successfully>"
-      ]
-     },
-     "execution_count": 57,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "DECODER = AttnDecoderRNN(HID_DIM, OUTPUT_DIM)\n",
-    "DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))"
-   ]
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.5"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
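For reference, the attention module from the deleted notebook (additive scoring in the style of Bahdanau et al.) can still be exercised standalone; the following sketch reproduces it with illustrative dimensions and annotates the tensor shapes that the decoder's `forward_step` relies on:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

# Additive (Bahdanau) attention, as defined in the deleted temp.ipynb.
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)  # projects the query
        self.Ua = nn.Linear(hidden_size, hidden_size)  # projects the keys
        self.Va = nn.Linear(hidden_size, 1)            # scores each source position

    def forward(self, query, keys):
        # query: (batch, 1, hidden), keys: (batch, seq_len, hidden)
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))  # (batch, seq_len, 1)
        scores = scores.squeeze(2).unsqueeze(1)                       # (batch, 1, seq_len)
        weights = F.softmax(scores, dim=-1)   # attention distribution over source tokens
        context = torch.bmm(weights, keys)    # (batch, 1, hidden)
        return context, weights

# Illustrative dimensions: batch of 2, 30 source tokens (MAX_LENGTH), hidden size 512.
attn = BahdanauAttention(hidden_size=512)
query = torch.randn(2, 1, 512)   # decoder hidden state, reshaped as in forward_step
keys = torch.randn(2, 30, 512)   # encoder outputs
context, weights = attn(query, keys)
print(context.shape, weights.shape)  # torch.Size([2, 1, 512]) torch.Size([2, 1, 30])
```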