Sifal committed on
Commit fb8d450
1 Parent(s): e6827ec

Update utlis.py

Files changed (1)
  1. utlis.py +35 -3
utlis.py CHANGED
@@ -1,6 +1,17 @@
 import yaml
+import torch
+from .model import Seq2SeqTransformer
+from transformers import AutoTokenizer
+from transformers import PreTrainedTokenizerFast
+from tokenizers.processors import TemplateProcessing
 
-def load_checkpoint(model_checkpoint_dir='model.pt',config_dir='config.yaml'):
+
+def addPreprocessing(tokenizer):
+    tokenizer._tokenizer.post_processor = TemplateProcessing(
+        single=tokenizer.bos_token + " $A " + tokenizer.eos_token,
+        special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id), (tokenizer.bos_token, tokenizer.bos_token_id)])
+
+def load_model(model_checkpoint_dir='model.pt',config_dir='config.yaml'):
 
     with open(config_dir, 'r') as yaml_file:
         loaded_model_params = yaml.safe_load(yaml_file)
@@ -116,7 +127,7 @@ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size ,l
     best_beam = beams[0][0]
     return best_beam
 
-def translate(model: torch.nn.Module, strategy:str, src_sentence: str, lenght_extend :int = 0, beam_size: int = 5, raw: bool = False, length_penalty:float = 0.6):
+def translate(model: torch.nn.Module, strategy:str, src_sentence: str, lenght_extend :int = 0, beam_size: int = 5, length_penalty:float = 0.6):
     assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
     model.to(device)
     model.eval()
@@ -131,4 +142,25 @@ def translate(model: torch.nn.Module, strategy:str, src_sentence: str, lenght_ex
     else:
         tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size,length_penalty=length_penalty).flatten()
     # Decode the target tokens and clean up the result
-    return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+    return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
+
+special_tokens = {'unk_token':"[UNK]",
+                  'cls_token':"[CLS]",
+                  'eos_token': '[EOS]',
+                  'sep_token':"[SEP]",
+                  'pad_token':"[PAD]",
+                  'mask_token':"[MASK]",
+                  'bos_token':"[BOS]"}
+
+source_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", **special_tokens)
+target_tokenizer = PreTrainedTokenizerFast.from_pretrained('Sifal/E2KT')
+
+addPreprocessing(source_tokenizer)
+addPreprocessing(target_tokenizer)
+
+token_config = {
+    "add_special_tokens": True,
+    "return_tensors": True,
+}
+
+load_model()
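For context on the first hunk: addPreprocessing replaces the backing tokenizer's post-processor with a TemplateProcessing template, so every single-sequence encoding comes out wrapped in the BOS/EOS pair instead of BERT's usual [CLS] ... [SEP]. A minimal sketch of the effect, reusing the helper from this commit (the bert-base-uncased setup mirrors the module-level code further down; the sample sentence and prints are illustrative only):

    from transformers import AutoTokenizer
    from tokenizers.processors import TemplateProcessing

    def addPreprocessing(tokenizer):
        # Same helper as in the commit: wrap "$A" (the raw sequence) in BOS ... EOS
        tokenizer._tokenizer.post_processor = TemplateProcessing(
            single=tokenizer.bos_token + " $A " + tokenizer.eos_token,
            special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id),
                            (tokenizer.bos_token, tokenizer.bos_token_id)])

    # bert-base-uncased has no BOS/EOS of its own, so register them here,
    # just as the commit does via its special_tokens dict
    tok = AutoTokenizer.from_pretrained("bert-base-uncased", bos_token="[BOS]", eos_token="[EOS]")
    addPreprocessing(tok)

    ids = tok("how are you").input_ids
    # The encoding should now start with the BOS id and end with the EOS id
    print(ids[0] == tok.bos_token_id, ids[-1] == tok.eos_token_id)

Only the post-processing step changes; the vocabulary and the underlying WordPiece model stay as they are.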
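Taken together, the module now builds both tokenizers, patches them with addPreprocessing, and calls the renamed load_model() at import time, while translate() drops the raw flag. A minimal usage sketch, assuming load_model() returns the restored Seq2SeqTransformer, that the global device used inside translate() is defined elsewhere in the file, and that the package layout allows importing utlis directly (the example sentence is illustrative):

    from utlis import load_model, translate

    model = load_model(model_checkpoint_dir='model.pt', config_dir='config.yaml')

    # Greedy decoding
    print(translate(model, strategy='greedy', src_sentence="How are you?"))

    # Beam search with a wider beam; lenght_extend lets the target run a few
    # tokens past the source length (max_len = num_tokens + lenght_extend)
    print(translate(model, strategy='beam search', src_sentence="How are you?",
                    beam_size=8, lenght_extend=5, length_penalty=0.6))

Note that the module-level load_model() call added at the end of the file runs as a side effect of importing utlis and discards its return value; the sketch calls it again explicitly to get a handle on the loaded model.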