ppisljar committed on
Commit
f024d03
·
verified ·
1 Parent(s): d97805c

Upload 2 files

Browse files
Files changed (2) hide show
  1. g2p_t5.onnx +2 -2
  2. infer.py +30 -0
g2p_t5.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:77f18426d1321bd08c29a7f0fb743b2d8e83b850178a7bab16a3b64d3f04a415
3
- size 1208296672
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e15c8d2249cc940232f58442a2e93fe9e27fbaaaebdfb356f5f7a3a0fb7ec9c5
3
+ size 1208441138
infer.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Minimal CPU inference script for the exported g2p_t5.onnx model.

Encodes one hard-coded (lowercased) Slovene sentence with the ByT5
byte-level tokenizer, runs the token ids through an ONNX Runtime
session, and prints the decoded grapheme-to-phoneme prediction.
"""
import onnxruntime
import torch

from transformers import AutoTokenizer

# Probe for a CUDA device.
# NOTE(review): this block is currently dead code — the ONNX session below is
# pinned to CPUExecutionProvider and `map_location` is never consumed. Kept so
# a GPU loading path can be re-enabled later; confirm before deleting.
if torch.cuda.is_available():
    device = [0]  # use 0th CUDA device
    accelerator = 'gpu'
else:
    device = [0]  # unused on CPU; list keeps the type consistent with the GPU branch
    accelerator = 'cpu'

map_location = torch.device('cuda:{}'.format(device[0]) if accelerator == 'gpu' else 'cpu')

# Byte-level tokenizer matching the ByT5 backbone the ONNX model was exported from.
tokenizer = AutoTokenizer.from_pretrained('google/byt5-small')

# The model expects lowercased input text.
sentence = "Kupil sem bicikel in mu zamenjal stol.".lower()

ort_session = onnxruntime.InferenceSession("g2p_t5.onnx", providers=["CPUExecutionProvider"])

# Encode a batch of one sentence; return_tensors='pt' so .numpy() works below.
input_encoding = tokenizer(
    [sentence], padding='longest', max_length=512, truncation=True, return_tensors='pt',
)
# Only input_ids are fed to the session — the exported graph apparently takes a
# single input; confirm against the export before also passing attention_mask.
ort_inputs = {'input_ids': input_encoding.input_ids.numpy()}
ort_outs = ort_session.run(None, ort_inputs)

# The first output tensor holds the generated token ids for the batch.
generated_ids = [ort_outs[0]]
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_texts)