Update app.py
app.py CHANGED
@@ -209,6 +209,8 @@ def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size ,l
 
 def translate(src_sentence: str, strategy:str = 'greedy' , lenght_extend :int = 5, beam_size: int = 5, length_penalty:float = 0.6):
     assert strategy in ['greedy','beam search'], 'the strategy for decoding has to be either greedy or beam search'
+    assert lenght_extend >= 1, 'lenght_extend must be greater than or equal to one'
+
     # Tokenize the source sentence
     src = source_tokenizer(src_sentence, **token_config)['input_ids']
     num_tokens = src.shape[1]
@@ -218,7 +220,9 @@ def translate(src_sentence: str, strategy:str = 'greedy' , lenght_extend :int =
         tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id).flatten()
     # Generate the target tokens using beam search decoding
     else:
-
+        assert length_penalty >= 0, 'length_penalty must be greater than or equal to zero'
+        assert beam_size >= 1, 'beam_size must be greater than or equal to one'
+        tgt_tokens = beam_search_decode(model, src, src_mask, max_len=num_tokens + lenght_extend, start_symbol=target_tokenizer.bos_token_id, beam_size=beam_size, length_penalty=length_penalty).flatten()
     # Decode the target tokens and clean up the result
     return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
 
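For context, a minimal usage sketch of the updated translate function follows. It is not part of the commit: it assumes that model, source_tokenizer, target_tokenizer and token_config are defined earlier in app.py (as the diffed lines rely on them), and the example sentence is purely illustrative since the actual language pair depends on the loaded checkpoint.

# Usage sketch (illustrative only, not part of the commit). Assumes model,
# source_tokenizer, target_tokenizer and token_config exist earlier in app.py.

# Default greedy decoding.
print(translate("Hello, how are you?"))

# Beam search decoding with the parameters that the new assertions validate.
print(translate("Hello, how are you?", strategy='beam search', beam_size=5, length_penalty=0.6))

# Invalid settings now fail fast instead of reaching the decoder, e.g.
# translate("Hello", strategy='beam search', beam_size=0) raises AssertionError.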