wasmdashai committed on
Commit 4de59fa · verified · 1 Parent(s): 19af058

Update app.py

Files changed (1)
  1. app.py +145 -8
app.py CHANGED
@@ -1,9 +1,146 @@
- import gradio as gr
- from VitsModelSplit.vits_model import VitsModel
  from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder
- def greet(name):
-
-     model = Vits_models_only_decoder.from_pretrained("wasmdashai/vits-ar-sa-huba", token=name)
-     return "Hello " + name + "!!"
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch(share=True)
+ import torch
+ import torch.onnx
+ import onnx
  from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder
+ from VitsModelSplit.vits_model import VitsModel
+ import gradio as gr
+
+
+ class OnnxModelConverter:
+     def __init__(self):
+         self.model = None
+
+     def download_file(self, file_path):
+         # Helper that wraps a local path in a gr.File component (not used by the UI below).
+         ff = gr.File(value=file_path, visible=True)
+         file_url = ff.value['url']
+         return file_url
+
+     def convert(self, model_name, token, onnx_filename, conversion_type):
+         """
+         Main function to handle different types of model conversions.
+
+         Args:
+             model_name (str): Name of the model to convert.
+             token (str): Access token for loading the model.
+             onnx_filename (str): Desired filename for the ONNX output.
+             conversion_type (str): Type of conversion ('decoder', 'only_decoder', or 'full_model').
+
+         Returns:
+             str: The path to the generated ONNX file.
+         """
+         if conversion_type == "decoder":
+             return self.convert_decoder(model_name, token, onnx_filename)
+         elif conversion_type == "only_decoder":
+             return self.convert_only_decoder(model_name, token, onnx_filename)
+         elif conversion_type == "full_model":
+             return self.convert_full_model(model_name, token, onnx_filename)
+         else:
+             raise ValueError("Invalid conversion type. Choose from 'decoder', 'only_decoder', or 'full_model'.")
+
+     def convert_decoder(self, model_name, token, onnx_filename):
+         """
+         Converts only the decoder part of the Vits model to ONNX format.
+
+         Args:
+             model_name (str): Name of the model to convert.
+             token (str): Access token for loading the model.
+             onnx_filename (str): Desired filename for the ONNX output.
+
+         Returns:
+             str: The path to the generated ONNX file.
+         """
+         model = VitsModel.from_pretrained(model_name, token=token)
+         onnx_file = f"/tmp/{onnx_filename}.onnx"
+
+         # Dummy batch of token ids used to trace the model during export.
+         vocab_size = model.text_encoder.embed_tokens.weight.size(0)
+         example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
+
+         torch.onnx.export(
+             model,
+             example_input,
+             onnx_file,
+             opset_version=11,
+             input_names=['input'],
+             output_names=['output'],
+             dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
+         )
+
+         return onnx_file
+
+     def convert_only_decoder(self, model_name, token, onnx_filename):
+         """
+         Converts only the decoder part of the Vits model to ONNX format.
+
+         Args:
+             model_name (str): Name of the model to convert.
+             token (str): Access token for loading the model.
+             onnx_filename (str): Desired filename for the ONNX output.
+
+         Returns:
+             str: The path to the generated ONNX file.
+         """
+         model = Vits_models_only_decoder.from_pretrained(model_name, token=token)
+         onnx_file = f"/tmp/{onnx_filename}.onnx"
+
+         vocab_size = model.text_encoder.embed_tokens.weight.size(0)
+         example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
+
+         torch.onnx.export(
+             model,
+             example_input,
+             onnx_file,
+             opset_version=11,
+             input_names=['input'],
+             output_names=['output'],
+             dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
+         )
+
+         return onnx_file
+
+     def convert_full_model(self, model_name, token, onnx_filename):
+         """
+         Converts the full Vits model (including encoder and decoder) to ONNX format.
+
+         Args:
+             model_name (str): Name of the model to convert.
+             token (str): Access token for loading the model.
+             onnx_filename (str): Desired filename for the ONNX output.
+
+         Returns:
+             str: The path to the generated ONNX file.
+         """
+         model = VitsModel.from_pretrained(model_name, token=token)
+         onnx_file = f"/tmp/{onnx_filename}.onnx"
+
+         vocab_size = model.text_encoder.embed_tokens.weight.size(0)
+         example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
+
+         torch.onnx.export(
+             model,
+             example_input,
+             onnx_file,
+             opset_version=11,
+             input_names=['input'],
+             output_names=['output'],
+             dynamic_axes={'input': {0: 'batch_size', 1: 'sequence_length'}, 'output': {0: 'batch_size'}}
+         )
+
+         return onnx_file
+
+     def start(self):
+         # Gradio UI: model/token/filename inputs plus a conversion-type dropdown;
+         # the Convert button runs self.convert and exposes the resulting ONNX file for download.
+         with gr.Blocks() as demo:
+             with gr.Row():
+                 with gr.Column():
+                     text_n_model = gr.Textbox(label="Model name")
+                     text_n_token = gr.Textbox(label="Access token")
+                     text_n_onnx = gr.Textbox(label="ONNX file name")
+                     choice = gr.Dropdown(choices=["decoder", "only_decoder", "full_model"], label="Conversion type")
+
+                 with gr.Column():
+                     btn = gr.Button("Convert")
+                     label = gr.Label("Converted ONNX model name")
+                     output_file = gr.File(label="Download File")
+             btn.click(self.convert, [text_n_model, text_n_token, text_n_onnx, choice], outputs=output_file)
+             # choice.change(fn=function_change, inputs=choice, outputs=label)
+
+         return demo
+
+
+ c = OnnxModelConverter()
+ cc = c.start()
+ cc.launch(share=True)
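
The converter can also be exercised without the Gradio UI, and the exported file can be sanity-checked with onnxruntime. The sketch below is not part of the commit: the access token "hf_xxx" is a placeholder, the model id is taken from the removed greet() demo, and it assumes onnxruntime is installed and that the graph was exported with input_names=['input'] as in the export calls above.

import numpy as np
import onnx
import onnxruntime as ort

from app import OnnxModelConverter  # assumes the file above is saved as app.py

converter = OnnxModelConverter()
# Model id comes from the removed demo code; "hf_xxx" is a placeholder token.
onnx_path = converter.convert("wasmdashai/vits-ar-sa-huba", "hf_xxx", "vits-ar-sa-huba", "only_decoder")

# Structural validation of the exported graph.
onnx.checker.check_model(onnx.load(onnx_path))

# Run the graph once on dummy token ids (int64, matching the torch.long example input).
session = ort.InferenceSession(onnx_path)
token_ids = np.random.randint(0, 100, size=(1, 50), dtype=np.int64)
outputs = session.run(None, {"input": token_ids})
print(onnx_path, [o.shape for o in outputs])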