barghavani committed
Commit 4e2ccb2 · 1 Parent(s): 977d4f7

Update app.py

Files changed (1)
  1. app.py +120 -29
app.py CHANGED
@@ -1,70 +1,161 @@
 
import os
import tempfile
import gradio as gr
from TTS.utils.synthesizer import Synthesizer
from huggingface_hub import hf_hub_download

# Define constants
MODEL_INFO = [
-    #["vits-multispeaker-495586", "best_model_495586.pth", "config.json", "saillab/vits_multi_cv_15_validated_dataset","speakers.pth"],
-    #["VITS Grapheme Multispeaker CV15(reduct)(best at 17864)", "best_model_17864.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker"]
-    ["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"]
-]
-
-# # Extract model names from MODEL_INFO
-# MODEL_NAMES = [info[0] for info in MODEL_INFO]
-
-MODEL_NAMES = [
-    "vits-multispeaker-495586",
-    # Add other model names similarly...
]

MAX_TXT_LEN = 400
TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')

-# # Download models
-# for model_name, model_file, config_file, repo_name in MODEL_INFO:
-#     os.makedirs(model_name, exist_ok=True)
-#     print(f"|> Downloading: {model_name}")
-
-#     # Use hf_hub_download to download models from private Hugging Face repositories
-#     hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
-#     hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
-
-repo_name = "saillab/vits_multi_cv_15_validated_dataset"
-filename = "best_model_495586.pth"
-
-model_file = hf_hub_download(repo_name, filename, use_auth_token=TOKEN)
-config_file = hf_hub_download(repo_name, "config.json", use_auth_token=TOKEN)

-def synthesize(text: str, model_name: str) -> str:
    """Synthesize speech using the selected model."""
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
-
-    synthesizer = Synthesizer(model_file, config_file)
    if synthesizer is None:
        raise NameError("Model not found")
-
-    wavs = synthesizer.tts(text)
-
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
    return fp.name

iface = gr.Interface(
    fn=synthesize,
    inputs=[
        gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
-        gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0]),
    ],
    outputs=gr.Audio(label="Output", type='filepath'),
-    examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]],
    title='Persian TTS Playground',
-    description="Persian text to speech model demo",
    article="",
    live=False
)
 
+
import os
import tempfile
import gradio as gr
+from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
from huggingface_hub import hf_hub_download
+import json

# Define constants
MODEL_INFO = [
+    ["vits checkpoint 57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
+    # ["VITS Grapheme Multispeaker CV15(reduct)(best at 17864)", "best_model_17864.pth", "config.json",
+    #  "saillab/persian-tts-cv15-reduct-grapheme-multispeaker"],
+    ["VITS Grapheme Multispeaker CV15(reduct)(22000)", "checkpoint_22000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
+    ["VITS Grapheme Multispeaker CV15(reduct)(26000)", "checkpoint_25000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
+    ["vits-multispeaker-495586", "best_model_495586.pth", "config.json", "saillab/vits_multi_cv_15_validated_dataset", "speakers.pth"],
+
+    # ["VITS Grapheme Azure (best at 15934)", "best_model_15934.pth", "config.json",
+    #  "saillab/persian-tts-azure-grapheme-60K"],
+    ["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"],

+    ["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json",
+     "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
+    ["VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json",
+     "saillab/persian-tts-grapheme-arm24-finetuned-on1"],

+    # ... Add other models similarly
]

+# Extract model names from MODEL_INFO
+MODEL_NAMES = [info[0] for info in MODEL_INFO]
+
MAX_TXT_LEN = 400
TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')

+model_files = {}
+config_files = {}
+speaker_files = {}
+
+# Create a dictionary to store synthesizer objects for each model
+synthesizers = {}
+
+def update_config_speakers_file_recursive(config_dict, speakers_path):
+    """Recursively update speakers_file keys in a dictionary."""
+    if "speakers_file" in config_dict:
+        config_dict["speakers_file"] = speakers_path
+    for key, value in config_dict.items():
+        if isinstance(value, dict):
+            update_config_speakers_file_recursive(value, speakers_path)
+
+def update_config_speakers_file(config_path, speakers_path):
+    """Update the config.json file to point to the correct speakers.pth file."""
+    # Load the existing config
+    with open(config_path, 'r') as f:
+        config = json.load(f)
+
+    # Modify the speakers_file entry
+    update_config_speakers_file_recursive(config, speakers_path)
+
+    # Save the modified config
+    with open(config_path, 'w') as f:
+        json.dump(config, f, indent=4)
+
+# Download models and initialize synthesizers
+for info in MODEL_INFO:
+    model_name, model_file, config_file, repo_name = info[:4]
+    speaker_file = info[4] if len(info) == 5 else None  # Check if speakers.pth is defined for the model
+
+    print(f"|> Downloading: {model_name}")
+
+    # Download model and config files
+    model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
+    config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)

+    # Download speakers.pth if it exists
+    if speaker_file:
+        speaker_files[model_name] = hf_hub_download(repo_id=repo_name, filename=speaker_file, use_auth_token=TOKEN)
+        update_config_speakers_file(config_files[model_name], speaker_files[model_name])  # Update the config file
+        print(speaker_files[model_name])
+        # Initialize synthesizer for the model
+        synthesizer = Synthesizer(
+            tts_checkpoint=model_files[model_name],
+            tts_config_path=config_files[model_name],
+            tts_speakers_file=speaker_files[model_name],  # Pass the speakers.pth file if it exists
+            use_cuda=False  # Assuming you don't want to use GPU, adjust if needed
+        )

+    elif speaker_file is None:
+        # Initialize synthesizer for the model
+        synthesizer = Synthesizer(
+            tts_checkpoint=model_files[model_name],
+            tts_config_path=config_files[model_name],
+            # tts_speakers_file=speaker_files.get(model_name, None),  # Pass the speakers.pth file if it exists
+            use_cuda=False  # Assuming you don't want to use GPU, adjust if needed
+        )

+    synthesizers[model_name] = synthesizer

+
+def synthesize(text: str, model_name: str, speaker_name=None) -> str:
    """Synthesize speech using the selected model."""
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
+
+    # Use the synthesizer object for the selected model
+    synthesizer = synthesizers[model_name]
+
    if synthesizer is None:
        raise NameError("Model not found")
+
+    if synthesizer.tts_speakers_file == "":
+        wavs = synthesizer.tts(text)
+    elif synthesizer.tts_speakers_file != "":
+        if speaker_name == "":
+            wavs = synthesizer.tts(text, speaker_name="speaker-0")  # Should change; better once the Gradio conditions are figured out.
+        else:
+            wavs = synthesizer.tts(text, speaker_name=speaker_name)
+
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
    return fp.name

+# Callback function to update UI based on the selected model
+def update_options(model_name):
+    synthesizer = synthesizers[model_name]
+    # if synthesizer.tts.is_multi_speaker:
+    if model_name == MODEL_NAMES[1]:
+        speakers = synthesizer.tts_model.speaker_manager.speaker_names
+        # return options for the dropdown
+        return speakers
+    else:
+        # return empty options if not multi-speaker
+        return []

+# Create Gradio interface
iface = gr.Interface(
    fn=synthesize,
    inputs=[
        gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
+        gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
+        gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", value="speaker-0")
    ],
    outputs=gr.Audio(label="Output", type='filepath'),
+    examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], ""]],  # Examples should include a speaker name for multispeaker models
    title='Persian TTS Playground',
+    description="""
+    ### Persian text to speech model demo.
+
+    #### Pick a speaker for MultiSpeaker models. (It won't affect the single-speaker models.)
+    """,
    article="",
    live=False
)
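
For reference, a minimal usage sketch (not part of this commit) of how the updated synthesize() function could be exercised outside the Gradio UI. It assumes the file is importable as a module named app, that MODEL_NAMES keeps the order of MODEL_INFO above, and that network access plus a valid HUGGING_FACE_HUB_TOKEN are available so the module-level download loop can run on import.

# Hypothetical sketch: importing app runs the download loop and fills `synthesizers`.
from app import synthesize, MODEL_NAMES

# Single-speaker model (no speakers.pth): the speaker argument is ignored.
wav_path = synthesize("زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], "")

# Multi-speaker model: pass a speaker name from its speaker manager,
# or "" to fall back to the hard-coded "speaker-0" default in synthesize().
wav_path_multi = synthesize("زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[1], "speaker-0")

print(wav_path, wav_path_multi)  # paths to temporary .wav files written by synthesizer.save_wav

Both calls return the path of a temporary .wav file, which is what the gr.Audio output component consumes in the interface above.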