Sifal committed on
Commit 5813d56
Parent: 36d6903

Update app.py

Files changed (1)
app.py +5 -1
app.py CHANGED
@@ -209,6 +209,8 @@ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size, length_penalty):
 
 def translate(src_sentence: str, strategy: str = 'greedy', lenght_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
     assert strategy in ['greedy', 'beam search'], 'the strategy for decoding has to be either greedy or beam search'
+    assert lenght_extend >= 1, 'lenght_extend must be greater than or equal to one'
+
     # Tokenize the source sentence
     src = source_tokenizer(src_sentence, **token_config)['input_ids']
     num_tokens = src.shape[1]
@@ -218,7 +220,9 @@ def translate(src_sentence: str, strategy: str = 'greedy', lenght_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
         tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id).flatten()
     # Generate the target tokens using beam search decoding
     else:
-        tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size, length_penalty=length_penalty).flatten()
+        assert length_penalty >= 0, 'length penalty must be greater than or equal to zero'
+        assert beam_size >= 1, 'beam size must be greater than or equal to one'
+        tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size, length_penalty=length_penalty).flatten()
     # Decode the target tokens and clean up the result
     return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
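
For reference, a minimal usage sketch of the updated translate() signature with both decoding strategies. It assumes app.py's model and tokenizers are already loaded at import time, as the surrounding code suggests; the import path and the example sentence are illustrative only, not part of the commit.

# Hypothetical usage of translate() as defined in app.py after this commit.
from app import translate  # assumed import path; translate lives in app.py

sentence = "This is a test sentence."  # illustrative source sentence

# Greedy decoding: generates at most num_tokens + lenght_extend target tokens.
print(translate(sentence, strategy='greedy', lenght_extend=5))

# Beam search decoding: the new asserts require beam_size >= 1 and length_penalty >= 0.
print(translate(sentence, strategy='beam search', beam_size=5, length_penalty=0.6))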