Frenchizer committed on
Commit
fe50c4c
·
verified ·
1 Parent(s): d9d6d78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -12
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import numpy as np
3
  import onnxruntime as ort
4
  from transformers import AutoTokenizer
@@ -15,22 +14,42 @@ def translate_text(input_text):
15
  input_text, return_tensors="np", padding=True, truncation=True, max_length=512
16
  )
17
 
 
18
  input_ids = tokenized_input["input_ids"].astype(np.int64)
19
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
20
 
21
- # Run inference with the ONNX model
22
- outputs = translation_session.run(
23
- None,
24
- {
25
- "input_ids": input_ids,
26
- "attention_mask": attention_mask,
27
- }
28
- )
29
 
30
- # Decode the output tokens
31
- translated_tokens = np.argmax(outputs[0], axis=-1)
32
- translated_text = translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
 
 
 
 
 
 
 
 
 
 
34
  return translated_text
35
 
36
  # Create a Gradio interface
 
 
1
  import numpy as np
2
  import onnxruntime as ort
3
  from transformers import AutoTokenizer
 
14
  input_text, return_tensors="np", padding=True, truncation=True, max_length=512
15
  )
16
 
17
+ # Prepare encoder inputs
18
  input_ids = tokenized_input["input_ids"].astype(np.int64)
19
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
20
 
21
+ # Prepare decoder inputs (start with the start token)
22
+ decoder_start_token_id = translation_tokenizer.cls_token_id or translation_tokenizer.pad_token_id
23
+ decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
 
 
 
 
 
24
 
25
+ # Iteratively generate output tokens
26
+ translated_tokens = []
27
+ for _ in range(512): # Max length of output
28
+ # Run inference with the ONNX model
29
+ outputs = translation_session.run(
30
+ None,
31
+ {
32
+ "input_ids": input_ids,
33
+ "attention_mask": attention_mask,
34
+ "decoder_input_ids": decoder_input_ids,
35
+ }
36
+ )
37
+
38
+ # Get the next token ID
39
+ next_token_id = np.argmax(outputs[0][0, -1, :], axis=-1)
40
+ translated_tokens.append(next_token_id)
41
 
42
+ # Stop if the end-of-sequence token is generated
43
+ if next_token_id == translation_tokenizer.eos_token_id:
44
+ break
45
+
46
+ # Update decoder_input_ids for the next iteration
47
+ decoder_input_ids = np.concatenate(
48
+ [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
49
+ )
50
+
51
+ # Decode the output tokens
52
+ translated_text = translation_tokenizer.decode(translated_tokens, skip_special_tokens=True)
53
  return translated_text
54
 
55
  # Create a Gradio interface