Sifal committed on
Commit
4b177f2
1 Parent(s): fb8d450

Update utlis.py

Browse files
Files changed (1) hide show
  1. utlis.py +7 -5
utlis.py CHANGED
@@ -127,17 +127,15 @@ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size ,l
127
  best_beam = beams[0][0]
128
  return best_beam
129
 
130
- def translate(model: torch.nn.Module, strategy:str, src_sentence: str, lenght_extend :int = 0, beam_size: int = 5, length_penalty:float = 0.6):
131
  assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
132
- model.to(device)
133
- model.eval()
134
  # Tokenize the source sentence
135
  src = source_tokenizer(src_sentence, **token_config)['input_ids']
136
  num_tokens = src.shape[1]
137
  # Create a source mask
138
  src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
139
  if strategy == 'greedy':
140
- tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5, start_symbol=target_tokenizer.bos_token_id).flatten()
141
  # Generate the target tokens using beam search decoding
142
  else:
143
  tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size,length_penalty=length_penalty).flatten()
@@ -163,4 +161,8 @@ token_config = {
163
  "return_tensors": True,
164
  }
165
 
166
- load_model()
 
 
 
 
 
127
  best_beam = beams[0][0]
128
  return best_beam
129
 
130
+ def translate(model: torch.nn.Module, src_sentence: str, strategy: str = 'greedy', lenght_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
131
  assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
 
 
132
  # Tokenize the source sentence
133
  src = source_tokenizer(src_sentence, **token_config)['input_ids']
134
  num_tokens = src.shape[1]
135
  # Create a source mask
136
  src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
137
  if strategy == 'greedy':
138
+ tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id).flatten()
139
  # Generate the target tokens using beam search decoding
140
  else:
141
  tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size,length_penalty=length_penalty).flatten()
 
161
  "return_tensors": True,
162
  }
163
 
164
# Pick the inference device: prefer the first CUDA GPU, otherwise run on CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda:0" if _cuda_ok else "cpu")

# Build the trained model once at module load, move it onto the chosen device,
# and switch it to evaluation mode (inference-time behavior for layers such as
# dropout / batch-norm, per PyTorch's Module.eval contract).
model = load_model()
model.to(device)
model.eval()