commit all
app.py CHANGED
@@ -314,25 +314,27 @@ HID_DIM = 512
 
 # Load our Model Translation
 ENCODER = EncoderAtt(INPUT_DIM, HID_DIM)
-
+ENCODER.load_state_dict(torch.load("encoderatt_epoch_35.pt", map_location=torch.device('cpu')))
 DECODER = DecoderAtt(HID_DIM, OUTPUT_DIM)
-
+DECODER.load_state_dict(torch.load("decoderatt_epoch_35.pt", map_location=torch.device('cpu')))
 
 
-def evaluate_final_model(encoder, decoder, sentence, vocab_source, vocab_target, disable = False):
+def evaluate_final_model(sentence, encoder, decoder, vocab_source, vocab_target, disable = False):
+    """ Evaluation Model
+    @param encoder (EncoderAtt)
+    @param decoder (DecoderAtt)
+    @param sentence (str)
+    @param vocab_source (Vocabulary)
+    @param vocab_target (Vocabulary)
+    @param disable (bool)
+    """
     encoder.eval()
     decoder.eval()
     with torch.no_grad():
-        input_tensor = (
-            vocab_source.corpus_to_tensor([sentence], disable=disable)[0]
-            .view(1, -1)
-            .to(device)
-        )
+        input_tensor = vocab_source.corpus_to_tensor([sentence], disable = disable)[0].view(1,-1).to(device)
 
         encoder_outputs, encoder_hidden = encoder(input_tensor)
-        decoder_outputs, decoder_hidden, decoder_attn = decoder(
-            encoder_outputs, encoder_hidden
-        )
+        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)
 
         _, topi = decoder_outputs.topk(1)
         decoded_ids = topi.squeeze()
@@ -340,20 +342,23 @@ def evaluate_final_model(encoder, decoder, sentence, vocab_source, vocab_target, disable = False):
         decoded_words = []
         for idx in decoded_ids:
             if idx.item() == vocab_target.eos_id:
-                decoded_words.append(
+                decoded_words.append('<eos>')
                 break
             decoded_words.append(vocab_target.id2word[idx.item()])
         return decoded_words, decoder_attn
 
-
-def my_translation(sentence):
+def translate_sentence(sentence):
     output_words, _ = evaluate_final_model(sentence, ENCODER, DECODER, VOCAB_SOURCE, VOCAB_TARGET, disable= True)
-
-
-
-
+    if "<pad>" in output_words:
+        output_words.remove("<pad>")
+    if "<unk>" in output_words:
+        output_words.remove("<unk>")
+    if "<sos>" in output_words:
+        output_words.remove("<sos>")
+    if "<eos>" in output_words:
+        output_words.remove("<eos>")
 
-    return ' '.join(output_words
+    return ' '.join(output_words).capitalize()
 
 
 def envit5_translation(text):
@@ -366,10 +371,10 @@ def envit5_translation(text):
 
 
 def translation(text):
+    output1 = translate_sentence(text)
+
     if not text.endswith(('.', '!', '?')):
         text = text + '.'
-    #output1 = my_translation(text)
-    output1 = "Something"
     output2 = envit5_translation(text)
 
     return (output1, output2)
@@ -401,4 +406,4 @@ if __name__ == "__main__":
         ]
     )
 
-    demo.launch()
+    demo.launch(share = True)
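The two load_state_dict lines are the substantive fix in this commit: the Space runs on CPU, so checkpoints that were presumably saved on a GPU have to be remapped at load time. A minimal sketch of why map_location matters (the checkpoint name is taken from this commit; the rest is generic PyTorch):

import torch

# map_location remaps every tensor's storage device while the checkpoint is
# being deserialized; without it, a checkpoint whose tensors were saved on
# CUDA cannot be restored on a CPU-only host.
state_dict = torch.load("encoderatt_epoch_35.pt", map_location=torch.device("cpu"))

# A state_dict is an ordered mapping from parameter names to tensors:
for name, tensor in list(state_dict.items())[:3]:
    print(name, tuple(tensor.shape), tensor.device)  # device is now cpu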
hid512_decoder_att_epoch_20.pt → decoderatt_epoch_35.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:dee556d5e874646355b36d8d2fa94daf70d1da8684f0d95e4eb5edee4fe1b881
+size 43042290

hid512_encoder_att_epoch_20.pt → encoderatt_epoch_35.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:0879cb7c5a5359360b70195e61c9e91cd9f8caa0c6296f5924a7b9530f1350b0
+size 16437536
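Both weight files live in Git LFS, so the diffs above touch only three-line pointer files; the repository never stores the binary payload, and the pointer's oid field is the SHA-256 of that payload. A small standard-library check that a downloaded checkpoint matches its pointer (the size and hash values are copied from this commit):

import hashlib
from pathlib import Path

def sha256_of(path):
    # Stream in 1 MiB chunks so large checkpoints need not fit in memory.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

path = "decoderatt_epoch_35.pt"
assert Path(path).stat().st_size == 43042290  # the pointer's size line
assert sha256_of(path) == "dee556d5e874646355b36d8d2fa94daf70d1da8684f0d95e4eb5edee4fe1b881"  # its oid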
temp.ipynb DELETED
@@ -1,569 +0,0 @@
Deleted notebook contents, cell by cell:

# In[1]:
import gradio as gr
from transformers import pipeline
import re
import pickle
import torch
import torch.nn as nn
from torchtext.transforms import PadTransform
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from tqdm import tqdm
from underthesea import word_tokenize, text_normalize
# (stdout: a TensorFlow/Keras deprecation warning about tf.losses.sparse_softmax_cross_entropy)

# In[7]:
import gradio as gr

def translation(text):
    output1 = 1
    output2 = 2
    #output3 = finetune_BERT(text)

    return (output1, output2)


examples = [["Input: Hello guys"],
            ["Output: Xin chào các bạn"]]

demo = gr.Interface(
    theme = gr.themes.Base(),
    fn=translation,
    title="Co Gai Mo Duong",
    description="""
    ## Machine Translation: English to Vietnamese
    """,
    examples=examples,
    inputs=[
        gr.Textbox(
            lines=5, placeholder="Enter text", label="Input"
        )
    ],
    outputs=[
        gr.Textbox(
            "text", label="Our Machine Translation"
        ),
        gr.Textbox(
            "text", label="VietAI Machine Translation"
        )
    ]
)

demo.launch(shared = True)
# Out[7]: Running on local URL: http://127.0.0.1:7864
#         To create a public link, set `share=True` in `launch()`.

# In[2]:
# Build Vocabulary
MAX_LENGTH = 30
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

# In[3]:
class Vocabulary:
    """The Vocabulary class is used to record words, which are used to convert
    text to numbers and vice versa.
    """

    def __init__(self, lang="vi"):
        self.lang = lang
        self.word2id = dict()
        self.word2id["<sos>"] = 0  # Start of Sentence Token
        self.word2id["<eos>"] = 1  # End of Sentence Token
        self.word2id["<unk>"] = 2  # Unknown Token
        self.word2id["<pad>"] = 3  # Pad Token
        self.sos_id = self.word2id["<sos>"]
        self.eos_id = self.word2id["<eos>"]
        self.unk_id = self.word2id["<unk>"]
        self.pad_id = self.word2id["<pad>"]
        self.id2word = {v: k for k, v in self.word2id.items()}
        self.pad_transform = PadTransform(max_length = MAX_LENGTH, pad_value = self.pad_id)

    def __getitem__(self, word):
        """Return the ID of the word if it exists, else the ID of the unknown token
        @param word (str)
        """
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        """Return True if word is in the Vocabulary, else False
        @param word (str)
        """
        return word in self.word2id

    def __len__(self):
        """
        Return the number of tokens (including sos, eos, unk and pad tokens) in the Vocabulary
        """
        return len(self.word2id)

    def lookup_tokens(self, word_indexes: list):
        """Return the list of words found by ID lookup
        @param word_indexes (list(int))
        @return words (list(str))
        """
        return [self.id2word[word_index] for word_index in word_indexes]

    def add(self, word):
        """Add word to vocabulary
        @param word (str)
        @return index (str): index of the word just added
        """
        if word not in self:
            word_index = self.word2id[word] = len(self.word2id)
            self.id2word[word_index] = word
            return word_index
        else:
            return self[word]

    def preprocessing_sent(self, sent, lang="en"):
        """Preprocess a sentence (depending on the language, English or Vietnamese)"""

        if (lang == "en") or (lang == "eng") or (lang == "english"):
            # Remove unnecessary spaces
            sent = re.sub(" +", " ", sent)

            # Expand short forms
            sent = re.sub("'m ", "am ", sent)
            # Don't know how to preprocess the possessive case yet
            sent = re.sub("'s ", "is ", sent)
            sent = re.sub("'re ", "are ", sent)
            sent = re.sub("'ve ", "have ", sent)
            sent = re.sub("'ll ", "will ", sent)
            sent = re.sub("'d ", "would ", sent)

            sent = re.sub("aren 't", "are not", sent)
            sent = re.sub("isn 't", "is not", sent)
            sent = re.sub("don 't", "do not", sent)
            sent = re.sub("doesn 't", "does not", sent)
            sent = re.sub("wasn 't", "was not", sent)
            sent = re.sub("weren 't", "were not", sent)
            sent = re.sub("won 't", "will not", sent)
            sent = re.sub("can 't", "can not", sent)
            sent = re.sub("let 's", "let us", sent)

        else:
            # The underthesea.text_normalize package supports normalizing Vietnamese
            sent = text_normalize(sent)

        # Undo Moses-style escape entities
        sent = re.sub("&apos;", "'", sent)
        sent = re.sub("&quot;", '"', sent)
        sent = re.sub("&#91;", "[", sent)
        sent = re.sub("&#93;", "]", sent)

        # Lowercase the sentence and strip leading/trailing spaces
        return sent.lower().strip()

    def tokenize_corpus(self, corpus, disable=False):
        """Split the documents of the corpus into words
        @param corpus (list(str)): list of documents
        @return tokenized_corpus (list(list(str))): list of words
        """
        if not disable:
            print("Tokenize the corpus...")
        tokenized_corpus = list()
        for document in tqdm(corpus, disable=disable):
            tokenized_document = ["<sos>"] + self.preprocessing_sent(document).split(" ") + ["<eos>"]
            tokenized_corpus.append(tokenized_document)
        return tokenized_corpus

    def corpus_to_tensor(self, corpus, is_tokenized=False, disable=False):
        """Convert corpus to a list of indices tensors
        @param corpus (list(str) if is_tokenized==False else list(list(str)))
        @param is_tokenized (bool)
        @return indicies_corpus (list(tensor))
        """
        if is_tokenized:
            tokenized_corpus = corpus
        else:
            tokenized_corpus = self.tokenize_corpus(corpus, disable=disable)
        indicies_corpus = list()
        for document in tqdm(tokenized_corpus, disable=disable):
            indicies_document = torch.tensor(
                list(map(lambda word: self[word], document)), dtype=torch.int64
            )

            indicies_corpus.append(self.pad_transform(indicies_document))

        return indicies_corpus

    def tensor_to_corpus(self, tensor, disable=False):
        """Convert a list of indices tensors to a list of tokenized documents
        @param indicies_corpus (list(tensor))
        @return corpus (list(list(str)))
        """
        corpus = list()
        for indicies in tqdm(tensor, disable=disable):
            document = list(map(lambda index: self.id2word[index.item()], indicies))
            corpus.append(document)

        return corpus

# In[4]:
def create_input_emb_layer():
    num_embeddings, embedding_dim = 32998, 100
    emb_layer = nn.Embedding(num_embeddings, embedding_dim)
    emb_layer.weight.requires_grad = False

    return emb_layer, embedding_dim

def create_output_emb_layer():
    num_embeddings, embedding_dim = 15405, 100
    emb_layer = nn.Embedding(num_embeddings, embedding_dim)
    emb_layer.weight.requires_grad = False

    return emb_layer, embedding_dim

class EncoderRNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout = 0.2):
        super(EncoderRNN, self).__init__()

        self.hidden_dim = hidden_dim
        #self.embedding = nn.Embedding(input_dim, hidden_dim)
        # Changed to the pretrained input embedding
        self.embedding, self.embedding_dim = create_input_emb_layer()
        self.gru = nn.GRU(self.embedding_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        output, hidden = self.gru(embedded)
        return output, hidden

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        # self.embedding = nn.Embedding(output_size, hidden_size)
        # Changed to the pretrained output embedding
        self.embedding, self.embedding_dim = create_output_emb_layer()
        self.fc = nn.Linear(self.embedding_dim, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(0)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def forward_step(self, input, hidden, encoder_outputs):
        embedded = self.dropout(self.fc(self.embedding(input)))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights
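BahdanauAttention above is additive attention: each encoder state h_i is scored against the decoder state s by a one-layer MLP, score(s, h_i) = v_a^T tanh(W_a s + U_a h_i), and the softmaxed scores mix the encoder outputs into a context vector. A toy-sized sketch of the same tensor shapes (random inputs; the dimensions are chosen only for illustration):

import torch
import torch.nn as nn
from torch.nn import functional as F

B, T, H = 1, 5, 8                       # batch, source length, hidden size
query = torch.randn(B, 1, H)            # decoder hidden state s_{t-1}
keys = torch.randn(B, T, H)             # encoder outputs h_1 .. h_T
Wa, Ua, Va = nn.Linear(H, H), nn.Linear(H, H), nn.Linear(H, 1)

scores = Va(torch.tanh(Wa(query) + Ua(keys)))   # broadcasts to (B, T, 1)
scores = scores.squeeze(2).unsqueeze(1)         # (B, 1, T): one score per source position
weights = F.softmax(scores, dim=-1)             # attention distribution over the source
context = torch.bmm(weights, keys)              # (B, 1, H): weighted sum of encoder states
print(weights.shape, context.shape)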
# In[41]:
with open("vocab_source.pkl", "rb") as file:
    VOCAB_SOURCE = pickle.load(file)
with open("vocab_target.pkl", "rb") as file:
    VOCAB_TARGET = pickle.load(file)

INPUT_DIM = len(VOCAB_SOURCE)
OUTPUT_DIM = len(VOCAB_TARGET)
HID_DIM = 512

# Load our Model Translation
ENCODER = EncoderRNN(INPUT_DIM, HID_DIM)
ENCODER.load_state_dict(torch.load('encoder_att_epoch_16.pt'))
DECODER = AttnDecoderRNN(HID_DIM, OUTPUT_DIM)
DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))
# Out[41]: <All keys matched successfully>

# In[42]:
def evaluate(encoder, decoder, sentence, vocab_source, vocab_target, disable = False):
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        input_tensor = vocab_source.corpus_to_tensor([sentence], disable = disable)[0].view(1,-1).to(device)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == vocab_target.eos_id:
                decoded_words.append('<eos>')
                break
            decoded_words.append(vocab_target.id2word[idx.item()])
        return decoded_words, decoder_attn

def my_translate_model(sentence):
    output_words, _ = evaluate(ENCODER, DECODER, sentence, VOCAB_SOURCE, VOCAB_TARGET, disable= True)

    return ' '.join(output_words[1:-1]).capitalize() + '.'

# In[61]:
my_translate_model("I hope you will be better")
# Out[61]: 'Tôi hy vọng các bạn sẽ có thể làm được giải pháp.'
#          (roughly: "I hope you will be able to make the solution")

# In[60]:
ENCODER = EncoderRNN(INPUT_DIM, HID_DIM)
ENCODER.load_state_dict(torch.load('encoder_att_epoch_16.pt'))
DECODER = AttnDecoderRNN(HID_DIM, OUTPUT_DIM)
DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))
# Out[60]: <All keys matched successfully>

# In[48]:
torch.load('decoder_att_epoch_16.pt').keys()
# Out[48]: odict_keys(['embedding.weight', 'fc.weight', 'fc.bias', 'attention.Wa.weight',
#          'attention.Wa.bias', 'attention.Ua.weight', 'attention.Ua.bias', 'attention.Va.weight',
#          'attention.Va.bias', 'gru.weight_ih_l0', 'gru.weight_hh_l0', 'gru.bias_ih_l0',
#          'gru.bias_hh_l0', 'out.weight', 'out.bias'])

# In[52]:
DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))
# Out[52]: <All keys matched successfully>

# In[57]:
DECODER = AttnDecoderRNN(HID_DIM, OUTPUT_DIM)
DECODER.load_state_dict(torch.load('decoder_att_epoch_16.pt'))
# Out[57]: <All keys matched successfully>

# (notebook metadata: Python 3 kernel, Python 3.11.5)
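One caveat that applies both to translate_sentence in the new app.py and to the output_words[1:-1] slice in the notebook's my_translate_model: list.remove deletes only the first matching element, so an output containing several <pad> tokens keeps all but one of them. A filter over a set of special tokens drops every occurrence; a minimal sketch (the helper name is ours, not the repo's):

SPECIAL_TOKENS = {"<sos>", "<eos>", "<pad>", "<unk>"}

def strip_special_tokens(output_words):
    # Keeps word order, drops all occurrences of every special token.
    return [w for w in output_words if w not in SPECIAL_TOKENS]

words = ["<sos>", "xin", "chào", "<eos>", "<pad>", "<pad>"]
print(" ".join(strip_special_tokens(words)).capitalize())  # Xin chào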