wasmdashai commited on
Commit
8838afa
·
verified ·
1 Parent(s): 9dd92d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import numpy as np
7
 
8
  token=os.environ.get("key_")
9
- tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vtk",token=token)
10
  models= {}
11
 
12
  import noisereduce as nr
@@ -107,15 +107,32 @@ def _inference_forward_stream(
107
  def get_model(name_model):
108
  global models
109
  if name_model in models:
110
- return models[name_model]
 
 
 
 
 
 
 
 
111
  models[name_model]=VitsModel.from_pretrained(name_model,token=token)
 
 
 
112
  models[name_model].decoder.apply_weight_norm()
113
  # torch.nn.utils.weight_norm(self.decoder.conv_pre)
114
  # torch.nn.utils.weight_norm(self.decoder.conv_post)
115
  for flow in models[name_model].flow.flows:
116
  torch.nn.utils.weight_norm(flow.conv_pre)
117
  torch.nn.utils.weight_norm(flow.conv_post)
118
- return models[name_model]
 
 
 
 
 
 
119
 
120
 
121
 
@@ -144,7 +161,8 @@ model_choices = gr.Dropdown(
144
 
145
  "wasmdashai/vits-ar-sa-A",
146
  "wasmdashai/vits-ar-ye-sa",
147
- "wasmdashai/vits-ar-sa-M-v1"
 
148
 
149
 
150
  ],
 
6
  import numpy as np
7
 
8
  token=os.environ.get("key_")
9
+ tokenizers={}
10
  models= {}
11
 
12
  import noisereduce as nr
 
107
  def get_model(name_model):
108
  global models
109
  if name_model in models:
110
+ if name_model=='wasmdashai/vits-en-v1':
111
+ tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vits-en-v1",token=token)
112
+ else:
113
+ tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vtk",token=token)
114
+
115
+
116
+
117
+
118
+ return models[name_model],tokenizer
119
  models[name_model]=VitsModel.from_pretrained(name_model,token=token)
120
+
121
+
122
+
123
  models[name_model].decoder.apply_weight_norm()
124
  # torch.nn.utils.weight_norm(self.decoder.conv_pre)
125
  # torch.nn.utils.weight_norm(self.decoder.conv_post)
126
  for flow in models[name_model].flow.flows:
127
  torch.nn.utils.weight_norm(flow.conv_pre)
128
  torch.nn.utils.weight_norm(flow.conv_post)
129
+
130
+ if name_model=='wasmdashai/vits-en-v1':
131
+ tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vits-en-v1",token=token)
132
+ else:
133
+ tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vtk",token=token)
134
+
135
+ return models[name_model],tokenizer
136
 
137
 
138
 
 
161
 
162
  "wasmdashai/vits-ar-sa-A",
163
  "wasmdashai/vits-ar-ye-sa",
164
+ "wasmdashai/vits-ar-sa-M-v1",
165
+ "wasmdashai/vits-en-v1"
166
 
167
 
168
  ],