Update utlis.py
Browse files
utlis.py
CHANGED
@@ -1,6 +1,17 @@
|
|
1 |
import yaml
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
with open(config_dir, 'r') as yaml_file:
|
6 |
loaded_model_params = yaml.safe_load(yaml_file)
|
@@ -116,7 +127,7 @@ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size ,l
|
|
116 |
best_beam = beams[0][0]
|
117 |
return best_beam
|
118 |
|
119 |
-
def translate(model: torch.nn.Module, strategy:str, src_sentence: str, lenght_extend :int = 0, beam_size: int = 5,
|
120 |
assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
|
121 |
model.to(device)
|
122 |
model.eval()
|
@@ -131,4 +142,25 @@ def translate(model: torch.nn.Module, strategy:str, src_sentence: str, lenght_ex
|
|
131 |
else:
|
132 |
tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size,length_penalty=length_penalty).flatten()
|
133 |
# Decode the target tokens and clean up the result
|
134 |
-
return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import yaml
|
2 |
+
import torch
|
3 |
+
from .model import Seq2SeqTransformer
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
from transformers import PreTrainedTokenizerFast
|
6 |
+
from tokenizers.processors import TemplateProcessing
|
7 |
|
8 |
+
|
9 |
+
def addPreprocessing(tokenizer):
|
10 |
+
tokenizer._tokenizer.post_processor = TemplateProcessing(
|
11 |
+
single=tokenizer.bos_token + " $A " + tokenizer.eos_token,
|
12 |
+
special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id), (tokenizer.bos_token, tokenizer.bos_token_id)])
|
13 |
+
|
14 |
+
def load_model(model_checkpoint_dir='model.pt',config_dir='config.yaml'):
|
15 |
|
16 |
with open(config_dir, 'r') as yaml_file:
|
17 |
loaded_model_params = yaml.safe_load(yaml_file)
|
|
|
127 |
best_beam = beams[0][0]
|
128 |
return best_beam
|
129 |
|
130 |
+
def translate(model: torch.nn.Module, strategy:str, src_sentence: str, lenght_extend :int = 0, beam_size: int = 5, length_penalty:float = 0.6):
|
131 |
assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
|
132 |
model.to(device)
|
133 |
model.eval()
|
|
|
142 |
else:
|
143 |
tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size,length_penalty=length_penalty).flatten()
|
144 |
# Decode the target tokens and clean up the result
|
145 |
+
return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
|
146 |
+
|
# Special-token inventory shared by both tokenizers below.
special_tokens = {'unk_token': "[UNK]",
                  'cls_token': "[CLS]",
                  'eos_token': '[EOS]',
                  'sep_token': "[SEP]",
                  'pad_token': "[PAD]",
                  'mask_token': "[MASK]",
                  'bos_token': "[BOS]"}

# Source side reuses the pretrained BERT-uncased vocabulary; target side
# loads a custom fast tokenizer published on the Hub.
source_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", **special_tokens)
target_tokenizer = PreTrainedTokenizerFast.from_pretrained('Sifal/E2KT')

# Wrap every encoded sequence in BOS/EOS so the seq2seq model receives
# explicit sentence boundaries on both sides.
addPreprocessing(source_tokenizer)
addPreprocessing(target_tokenizer)

# Shared kwargs for tokenizer calls.
token_config = {
    "add_special_tokens": True,
    # BUG FIX: `return_tensors` expects a framework identifier string
    # ("pt"/"tf"/"np"), not a boolean — passing True raises when the
    # tokenizer coerces it to a TensorType. The file uses torch, so "pt".
    "return_tensors": "pt",
}

# NOTE(review): this runs at import time and discards the returned model —
# presumably a smoke test or side-effect initialization; confirm intent.
load_model()