Frenchizer commited on
Commit
c4b718f
·
verified ·
1 Parent(s): f80fc89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -4,7 +4,7 @@ from transformers import MarianTokenizer
4
  import gradio as gr
5
 
6
  # Load the tokenizer from the local folder
7
- tokenizer_path = "./onnx_model" # Path to the local tokenizer folder
8
  tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)
9
 
10
  # Load the ONNX model
@@ -20,6 +20,7 @@ def translate(texts, max_length=512):
20
  # Initialize variables for decoding
21
  batch_size = input_ids.shape[0]
22
  decoder_input_ids = np.array([[tokenizer.pad_token_id]] * batch_size, dtype=np.int64) # Start with pad token
 
23
 
24
  # Generate output tokens iteratively
25
  for _ in range(max_length):
@@ -42,8 +43,11 @@ def translate(texts, max_length=512):
42
  # Append the next tokens to the decoder input for the next iteration
43
  decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens[:, None]], axis=-1)
44
 
 
 
 
45
  # Stop if all sequences have reached the EOS token
46
- if all(tokenizer.eos_token_id in sequence for sequence in decoder_input_ids):
47
  break
48
 
49
  # Decode the output tokens to text
@@ -64,7 +68,7 @@ interface = gr.Interface(
64
  inputs=gr.Textbox(lines=5, placeholder="Enter text to translate...", label="Input Text"),
65
  outputs=gr.Textbox(lines=5, label="Translated Text"),
66
  title="ONNX English to French Translation",
67
- description="Translate English text to French using an ONNX model.",
68
  )
69
 
70
  # Launch the Gradio app
 
4
  import gradio as gr
5
 
6
  # Load the tokenizer from the local folder
7
+ tokenizer_path = "./tokenizer" # Path to the local tokenizer folder
8
  tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)
9
 
10
  # Load the ONNX model
 
20
  # Initialize variables for decoding
21
  batch_size = input_ids.shape[0]
22
  decoder_input_ids = np.array([[tokenizer.pad_token_id]] * batch_size, dtype=np.int64) # Start with pad token
23
+ eos_reached = np.zeros(batch_size, dtype=bool) # Track which sequences have finished
24
 
25
  # Generate output tokens iteratively
26
  for _ in range(max_length):
 
43
  # Append the next tokens to the decoder input for the next iteration
44
  decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens[:, None]], axis=-1)
45
 
46
+ # Check if the EOS token has been generated for each sequence
47
+ eos_reached = eos_reached | (next_tokens == tokenizer.eos_token_id)
48
+
49
  # Stop if all sequences have reached the EOS token
50
+ if all(eos_reached):
51
  break
52
 
53
  # Decode the output tokens to text
 
68
  inputs=gr.Textbox(lines=5, placeholder="Enter text to translate...", label="Input Text"),
69
  outputs=gr.Textbox(lines=5, label="Translated Text"),
70
  title="ONNX English to French Translation",
71
+ description="Translate English text to French using a MarianMT ONNX model.",
72
  )
73
 
74
  # Launch the Gradio app