Frenchizer commited on
Commit
fc36581
·
1 Parent(s): 89b5af7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -51
app.py CHANGED
@@ -1,65 +1,37 @@
1
  import numpy as np
2
  import onnxruntime as ort
3
- from transformers import AutoTokenizer
 
4
  import gradio as gr
5
 
6
- # Load the ONNX model and tokenizer
7
- model_path = "model.onnx"
8
- translation_session = ort.InferenceSession(model_path)
9
- translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
 
 
 
10
 
11
  def translate_text(input_text):
12
  # Tokenize input text
13
- tokenized_input = translation_tokenizer(
14
- input_text, return_tensors="np", padding=True, truncation=True, max_length=512
15
  )
16
-
17
- # Prepare encoder inputs
18
- input_ids = tokenized_input["input_ids"].astype(np.int64)
19
- attention_mask = tokenized_input["attention_mask"].astype(np.int64)
20
-
21
- # Prepare decoder inputs (start with the start token)
22
- decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
23
- decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
24
-
25
- # Iteratively generate output tokens
26
- translated_tokens = []
27
- for _ in range(512): # Max length of output
28
- # Run inference with the ONNX model
29
- outputs = translation_session.run(
30
- None,
31
- {
32
- "input_ids": input_ids,
33
- "attention_mask": attention_mask,
34
- "decoder_input_ids": decoder_input_ids,
35
- }
36
- )
37
-
38
- # Get the next token ID
39
- next_token_id = np.argmax(outputs[0][0, -1, :], axis=-1)
40
- translated_tokens.append(next_token_id)
41
-
42
- # Stop if the end-of-sequence token is generated
43
- if next_token_id == translation_tokenizer.eos_token_id:
44
- break
45
-
46
- # Update decoder_input_ids for the next iteration
47
- decoder_input_ids = np.concatenate(
48
- [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
49
  )
50
 
51
  # Decode the output tokens
52
- translated_text = translation_tokenizer.decode(translated_tokens, skip_special_tokens=True)
53
  return translated_text
54
 
55
- # Create a Gradio interface
56
- interface = gr.Interface(
57
- fn=translate_text,
58
- inputs="text",
59
- outputs="text",
60
- title="Frenchizer Translation Model",
61
- description="Translate text from English to French using an ONNX model."
62
- )
63
-
64
- # Launch the Gradio app
65
  interface.launch()
 
1
  import numpy as np
2
  import onnxruntime as ort
3
+ import torch
4
+ from transformers import MarianMTModel, MarianTokenizer
5
  import gradio as gr
6
 
7
+ # Load the MarianMT model and tokenizer from the local folder
8
+ model_path = "./model.onnx" # Path to the folder containing the model files
9
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
10
+ decoder_model = MarianMTModel.from_pretrained(model_name).get_decoder()
11
+
12
+ # Load the ONNX encoder
13
+ encoder_session = ort.InferenceSession("./onnx_model/encoder.onnx")
14
 
15
  def translate_text(input_text):
16
  # Tokenize input text
17
+ tokenized_input = tokenizer(
18
+ input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
19
  )
20
+ input_ids = tokenized_input["input_ids"]
21
+ attention_mask = tokenized_input["attention_mask"]
22
+
23
+ # Generate translation using the model
24
+ with torch.no_grad():
25
+ outputs = model.generate(
26
+ input_ids=input_ids,
27
+ attention_mask=attention_mask,
28
+ max_length=512, # Maximum length of the output
29
+ num_beams=5, # Use beam search for better translations
30
+ early_stopping=True, # Stop generation when the model predicts the end-of-sequence token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  )
32
 
33
  # Decode the output tokens
34
+ translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
35
  return translated_text
36
 
 
 
 
 
 
 
 
 
 
 
37
  interface.launch()