wasmdashai commited on
Commit
bd2d3be
·
verified ·
1 Parent(s): 5d372fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -44,7 +44,14 @@ class model_onxx:
44
  def __init__(self):
45
  self.model=None
46
  self.n_onxx=""
 
47
  pass
 
 
 
 
 
 
48
  def function_change(self,n_model,token,n_onxx,choice):
49
  if choice=="decoder":
50
 
@@ -66,12 +73,15 @@ class model_onxx:
66
  def convert_to_onnx_only_decoder(self,n_model,token,namemodelonxx):
67
  model=VitsModel.from_pretrained(n_model,token=token)
68
  x=f"{namemodelonxx}.onnx"
 
 
 
69
  vocab_size = model.text_encoder.embed_tokens.weight.size(0)
70
  example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
71
  torch.onnx.export(
72
  model, # The model to be exported
73
  example_input, # Example input for the model
74
- x, # The filename for the exported ONNX model
75
  opset_version=11, # Use an appropriate ONNX opset version
76
  input_names=['input'], # Name of the input layer
77
  output_names=['output'], # Name of the output layer
@@ -80,7 +90,7 @@ class model_onxx:
80
  'output': {0: 'batch_size'}
81
  }
82
  )
83
- return x
84
  def convert_to_onnx_all(self,n_model,token ,namemodelonxx):
85
 
86
  model=VitsModel.from_pretrained(n_model,token=token)
 
44
  def __init__(self):
45
  self.model=None
46
  self.n_onxx=""
47
+ self.storage_dir = "uploads"
48
  pass
49
+
50
+
51
+
52
+
53
+
54
+
55
  def function_change(self,n_model,token,n_onxx,choice):
56
  if choice=="decoder":
57
 
 
73
  def convert_to_onnx_only_decoder(self,n_model,token,namemodelonxx):
74
  model=VitsModel.from_pretrained(n_model,token=token)
75
  x=f"{namemodelonxx}.onnx"
76
+ if not os.path.exists("uploads"):
77
+ os.makedirs(storage_dir)
78
+ file_path = os.path.join("uploads",x)
79
  vocab_size = model.text_encoder.embed_tokens.weight.size(0)
80
  example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
81
  torch.onnx.export(
82
  model, # The model to be exported
83
  example_input, # Example input for the model
84
+ file_path, # The filename for the exported ONNX model
85
  opset_version=11, # Use an appropriate ONNX opset version
86
  input_names=['input'], # Name of the input layer
87
  output_names=['output'], # Name of the output layer
 
90
  'output': {0: 'batch_size'}
91
  }
92
  )
93
+ return file_path
94
  def convert_to_onnx_all(self,n_model,token ,namemodelonxx):
95
 
96
  model=VitsModel.from_pretrained(n_model,token=token)