wasmdashai commited on
Commit
e217e10
·
verified ·
1 Parent(s): ffc5468

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForTextToWaveform
4
+ def install_model(namemodel,tokenn,namemodelonxx):
5
+
6
+ model = AutoModelForTextToWaveform.from_pretrained(namemodel,token=tokenn)
7
+ namemodelonxxx=convert_to_onnx(model,namemodelonxx)
8
+ return namemodelonxxx
9
+ def convert_to_onnx(model,namemodelonxx):
10
+ vocab_size = model.text_encoder.embed_tokens.weight.size(0)
11
+ example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
12
+ torch.onnx.export(
13
+ model, # The model to be exported
14
+ example_input, # Example input for the model
15
+ namemodelonxx, # The filename for the exported ONNX model
16
+ opset_version=11, # Use an appropriate ONNX opset version
17
+ input_names=['input'], # Name of the input layer
18
+ output_names=['output'], # Name of the output layer
19
+ dynamic_axes={
20
+ 'input': {0: 'batch_size', 1: 'sequence_length'}, # Dynamic axes for variable-length inputs
21
+ 'output': {0: 'batch_size'}
22
+ }
23
+ )
24
+ return namemodelonxx
25
+ with gr.Blocks() as demo:
26
+ with gr.Row():
27
+ with gr.Column():
28
+ text_n_model=gr.Textbox(label="name model")
29
+ text_n_token=gr.Textbox(label="token")
30
+ text_n_onxx=gr.Textbox(label="name model onxx")
31
+ with gr.Column():
32
+
33
+ btn=gr.Button("convert")
34
+ label=gr.Label("return name model onxx")
35
+ btn.click(install_model,[text_n_model,text_n_token,text_n_onxx],[label])
36
+ demo.launch()