wasmdashai committed
Commit ff3f6c6 · verified · 1 parent: b3add08

Update app.py

Files changed (1): app.py (+29 -4)
app.py CHANGED
```diff
@@ -7,21 +7,46 @@ import os
 
 token=os.environ.get("key_")
 tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vtk",token=token)
+models= {}
+
+def get_model(name_model):
+    global models
+    if name_model in models:
+        return models[name_model]
+    models[name_model]=VitsModel.from_pretrained(name_model,token=token).cuda()
+    models[name_model].decoder.apply_weight_norm()
+    # torch.nn.utils.weight_norm(self.decoder.conv_pre)
+    # torch.nn.utils.weight_norm(self.decoder.conv_post)
+    for flow in models[name_model].flow.flows:
+        torch.nn.utils.weight_norm(flow.conv_pre)
+        torch.nn.utils.weight_norm(flow.conv_post)
+    return models[name_model]
+
 
-model=VitsModel.from_pretrained("wasmdashai/vtk",token=token).cuda()
 zero = torch.Tensor([0]).cuda()
 print(zero.device) # <-- 'cpu' 🤔
 import torch
 @spaces.GPU
-def modelspeech(text):
+def modelspeech(text,name_model):
 
 
     inputs = tokenizer(text, return_tensors="pt")
+    model=get_model(name_model)
     with torch.no_grad():
         wav = model(input_ids=inputs["input_ids"].cuda()).waveform.cpu().numpy().reshape(-1)#.detach()
 
     return model.config.sampling_rate,wav#remove_noise_nr(wav)
 
-
-demo = gr.Interface(fn=modelspeech, inputs=["text"], outputs=["audio"])
+model_choices = gr.Dropdown(
+    choices=[
+        "wasmdashai/vits-ar-sa",
+        "wasmdashai/vits-ar-sa-huba",
+        "wasmdashai/vits-ar-sa-ms",
+        "wasmdashai/vits-ar-sa-magd",
+        "wasmdashai/vtk",
+    ],
+    label="اختر النموذج",
+    value="wasmdashai/vtk",
+)
+demo = gr.Interface(fn=modelspeech, inputs=["text",model_choices], outputs=["audio"])
 demo.launch()
```
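
The commit replaces the single hard-coded `model` with a lazy, memoized loader: each checkpoint is fetched and moved to the GPU once, cached in a module-level dict, and reused on every later call, while a `gr.Dropdown` (labeled "اختر النموذج", Arabic for "choose the model") lets the user pick the checkpoint per request. Below is a minimal sketch of the same pattern with `functools.lru_cache` standing in for the hand-rolled dict; the repo name, `key_` secret, and weight-norm calls are taken from the diff, while the CPU fallback and the dropped `@spaces.GPU` decorator are assumptions so the sketch also runs outside a ZeroGPU Space.

```python
import os
from functools import lru_cache

import torch
from transformers import AutoTokenizer, VitsModel

token = os.environ.get("key_")
tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vtk", token=token)

@lru_cache(maxsize=None)
def get_model(name_model: str) -> VitsModel:
    # First call per checkpoint downloads and prepares it; later calls hit the cache.
    model = VitsModel.from_pretrained(name_model, token=token)
    # As in the diff: re-register weight normalization on the vocoder and the
    # flow convolutions (these fine-tuned checkpoints apparently expect it).
    model.decoder.apply_weight_norm()
    for flow in model.flow.flows:
        torch.nn.utils.weight_norm(flow.conv_pre)
        torch.nn.utils.weight_norm(flow.conv_post)
    if torch.cuda.is_available():  # assumption: fall back to CPU off-Space
        model = model.cuda()
    return model

def modelspeech(text: str, name_model: str):
    # Returns (sampling_rate, waveform), the tuple format gr.Audio accepts.
    model = get_model(name_model)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        waveform = model(input_ids=inputs["input_ids"]).waveform
    return model.config.sampling_rate, waveform.cpu().numpy().reshape(-1)
```

Two leftovers in app.py are worth flagging for a follow-up commit: `import torch` appears only after `torch` has already been used (the `zero = torch.Tensor([0]).cuda()` / `print(zero.device)` pair looks like ZeroGPU example boilerplate rather than something `modelspeech` depends on), and `torch.nn.utils.weight_norm` is deprecated in recent PyTorch releases in favor of `torch.nn.utils.parametrizations.weight_norm`.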