Spaces:
Sleeping
Sleeping
File size: 5,492 Bytes
4de59fa 636ff5e 4de59fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import torch
import torch.onnx
import onnx
from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder
from VitsModelSplit.vits_model import VitsModel
import gradio as gr
class OnnxModelConverter:
    """Export VITS checkpoints to ONNX and expose the export through a Gradio UI.

    The three ``convert_*`` methods differ only in the class used to load the
    checkpoint; the tracing/export step itself is shared via ``_export``.
    """

    # Directory every exported ONNX file is written to.
    OUTPUT_DIR = "/tmp"
    # ONNX opset version handed to ``torch.onnx.export``.
    OPSET_VERSION = 11
    # Number of token ids in the random example input used for tracing.
    EXAMPLE_SEQUENCE_LENGTH = 100

    def __init__(self):
        # Placeholder attribute; no method of this class reads or assigns it.
        self.model = None

    def download_file(self, file_path):
        """Return the URL entry of a ``gr.File`` component created for a file.

        Args:
            file_path (str): Path of the file handed to ``gr.File``.

        Returns:
            The ``'url'`` item of the component's ``value``.
        """
        # NOTE(review): assumes gr.File stores its value as a mapping with a
        # 'url' key -- confirm against the installed Gradio version.
        file_component = gr.File(value=file_path, visible=True)
        return file_component.value['url']

    def convert(self, model_name, token, onnx_filename, conversion_type):
        """Dispatch to the converter selected by ``conversion_type``.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired filename (without extension) for the
                ONNX output.
            conversion_type (str): One of ``'decoder'``, ``'only_decoder'`` or
                ``'full_model'``.

        Returns:
            str: The path to the generated ONNX file.

        Raises:
            ValueError: If ``conversion_type`` is not one of the supported
                values, or if ``onnx_filename`` has no usable file name.
        """
        converters = {
            "decoder": self.convert_decoder,
            "only_decoder": self.convert_only_decoder,
            "full_model": self.convert_full_model,
        }
        try:
            converter = converters[conversion_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values; both mean "unknown type".
            raise ValueError(
                "Invalid conversion type. Choose from 'decoder', 'only_decoder', or 'full_model'."
            ) from None
        return converter(model_name, token, onnx_filename)

    def convert_decoder(self, model_name, token, onnx_filename):
        """Load the checkpoint with ``VitsModel`` and export it to ONNX.

        NOTE(review): this loads and exports exactly what
        ``convert_full_model`` does; if a decoder-only graph is intended here,
        the loading step needs to change -- confirm with the author.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str: The path to the generated ONNX file.
        """
        model = VitsModel.from_pretrained(model_name, token=token)
        return self._export(model, onnx_filename)

    def convert_only_decoder(self, model_name, token, onnx_filename):
        """Load the checkpoint with ``Vits_models_only_decoder`` and export it.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str: The path to the generated ONNX file.
        """
        model = Vits_models_only_decoder.from_pretrained(model_name, token=token)
        return self._export(model, onnx_filename)

    def convert_full_model(self, model_name, token, onnx_filename):
        """Load the full model with ``VitsModel`` and export it to ONNX.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model.
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str: The path to the generated ONNX file.
        """
        model = VitsModel.from_pretrained(model_name, token=token)
        return self._export(model, onnx_filename)

    @classmethod
    def _onnx_path(cls, onnx_filename):
        """Build the output path for ``onnx_filename`` inside ``OUTPUT_DIR``.

        The name comes from a UI textbox, so directory components are dropped
        to keep the written file inside ``OUTPUT_DIR``.

        Raises:
            ValueError: If nothing is left of the name after stripping.
        """
        import os  # local import: keeps this block independent of file-level imports

        name = os.path.basename(str(onnx_filename or "").strip())
        if not name:
            raise ValueError("onnx_filename must contain a file name.")
        return f"{cls.OUTPUT_DIR}/{name}.onnx"

    def _export(self, model, onnx_filename):
        """Trace ``model`` with a random token batch and write it as ONNX.

        Args:
            model: Loaded model exposing ``text_encoder.embed_tokens.weight``.
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str: The path to the generated ONNX file.
        """
        onnx_file = self._onnx_path(onnx_filename)
        # Row count of the embedding matrix bounds the valid token ids.
        vocab_size = model.text_encoder.embed_tokens.weight.size(0)
        example_input = torch.randint(
            0, vocab_size, (1, self.EXAMPLE_SEQUENCE_LENGTH), dtype=torch.long
        )
        torch.onnx.export(
            model,
            example_input,
            onnx_file,
            opset_version=self.OPSET_VERSION,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size', 1: 'sequence_length'},
                'output': {0: 'batch_size'},
            },
        )
        return onnx_file

    def starrt(self):
        """Assemble the Gradio interface that drives ``convert``.

        Returns:
            The ``gr.Blocks`` instance; it is built here but not launched.
        """
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    text_n_model = gr.Textbox(label="name model")
                    text_n_token = gr.Textbox(label="token")
                    text_n_onxx = gr.Textbox(label="name model onxx")
                    choice = gr.Dropdown(
                        choices=["decoder", "only_decoder", "full_model"],
                        label="My Dropdown",
                    )
                with gr.Column():
                    btn = gr.Button("convert")
                    gr.Label("return name model onxx")
            # The output component must exist before it is wired to the event;
            # the path returned by convert() is offered here for download.
            output_file = gr.File(label="Download File")
            btn.click(
                self.convert,
                [text_n_model, text_n_token, text_n_onxx, choice],
                [output_file],
            )
        return demo
# Script entry: build the converter UI and launch it, passing share=True
# through to Gradio's launch().
c = OnnxModelConverter()
cc = c.starrt()
cc.launch(share=True)
|