Spaces:
Build error
Commit · 4e2ccb2
Parent(s): 977d4f7
Update app.py

app.py CHANGED
@@ -1,70 +1,161 @@
Removed from the previous 70-line app.py (only fragments of the deleted lines were captured in this view):
-    #["VITS Grapheme Multispeaker CV15(reduct)(best at 17864)", "best_model_17864.pth", "config.json",
-    "vits-multispeaker-495586",
-    # Add other model names similarly...
-    gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0]),
-    examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]],
-    description="
app.py after this commit:

import os
import tempfile
import gradio as gr
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
from huggingface_hub import hf_hub_download
import json

# Define constants
MODEL_INFO = [
    ["vits checkpoint 57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
    # ["VITS Grapheme Multispeaker CV15(reduct)(best at 17864)", "best_model_17864.pth", "config.json",
    #  "saillab/persian-tts-cv15-reduct-grapheme-multispeaker"],
    ["VITS Grapheme Multispeaker CV15(reduct)(22000)", "checkpoint_22000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
    ["VITS Grapheme Multispeaker CV15(reduct)(26000)", "checkpoint_25000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
    ["vits-multispeaker-495586", "best_model_495586.pth", "config.json", "saillab/vits_multi_cv_15_validated_dataset", "speakers.pth"],

    # ["VITS Grapheme Azure (best at 15934)", "best_model_15934.pth", "config.json",
    #  "saillab/persian-tts-azure-grapheme-60K"],
    ["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"],

    ["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json",
     "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
    ["VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json",
     "saillab/persian-tts-grapheme-arm24-finetuned-on1"],

    # ... Add other models similarly
]

# Extract model names from MODEL_INFO
MODEL_NAMES = [info[0] for info in MODEL_INFO]

MAX_TXT_LEN = 400
TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')

model_files = {}
config_files = {}
speaker_files = {}

# Dictionary of synthesizer objects, one per model
synthesizers = {}


def update_config_speakers_file_recursive(config_dict, speakers_path):
    """Recursively update speakers_file keys in a dictionary."""
    if "speakers_file" in config_dict:
        config_dict["speakers_file"] = speakers_path
    for key, value in config_dict.items():
        if isinstance(value, dict):
            update_config_speakers_file_recursive(value, speakers_path)


def update_config_speakers_file(config_path, speakers_path):
    """Update the config.json file to point to the correct speakers.pth file."""
    # Load the existing config
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Modify the speakers_file entries
    update_config_speakers_file_recursive(config, speakers_path)

    # Save the modified config
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=4)


# Download models and initialize synthesizers
for info in MODEL_INFO:
    model_name, model_file, config_file, repo_name = info[:4]
    speaker_file = info[4] if len(info) == 5 else None  # speakers.pth is only defined for multi-speaker models

    print(f"|> Downloading: {model_name}")

    # Download model and config files
    model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
    config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)

    # Download speakers.pth if it exists and point the config at it
    if speaker_file:
        speaker_files[model_name] = hf_hub_download(repo_id=repo_name, filename=speaker_file, use_auth_token=TOKEN)
        update_config_speakers_file(config_files[model_name], speaker_files[model_name])  # Update the config file
        print(speaker_files[model_name])

        # Initialize synthesizer for the multi-speaker model
        synthesizer = Synthesizer(
            tts_checkpoint=model_files[model_name],
            tts_config_path=config_files[model_name],
            tts_speakers_file=speaker_files[model_name],  # Pass the speakers.pth file
            use_cuda=False  # Assuming you don't want to use a GPU; adjust if needed
        )
    else:
        # Initialize synthesizer for the single-speaker model
        synthesizer = Synthesizer(
            tts_checkpoint=model_files[model_name],
            tts_config_path=config_files[model_name],
            # tts_speakers_file=speaker_files.get(model_name, None),  # Pass the speakers.pth file if it exists
            use_cuda=False  # Assuming you don't want to use a GPU; adjust if needed
        )

    synthesizers[model_name] = synthesizer


def synthesize(text: str, model_name: str, speaker_name=None) -> str:
    """Synthesize speech using the selected model."""
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")

    # Use the synthesizer object for the selected model
    synthesizer = synthesizers[model_name]

    if synthesizer is None:
        raise NameError("Model not found")

    if not synthesizer.tts_speakers_file:
        # Single-speaker model
        wavs = synthesizer.tts(text)
    else:
        # Multi-speaker model
        if not speaker_name:
            wavs = synthesizer.tts(text, speaker_name="speaker-0")  # should change; better once the Gradio conditions are figured out
        else:
            wavs = synthesizer.tts(text, speaker_name=speaker_name)

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
        return fp.name


# Callback function to update the UI based on the selected model
def update_options(model_name):
    synthesizer = synthesizers[model_name]
    # if synthesizer.tts.is_multi_speaker:
    if model_name == MODEL_NAMES[1]:
        speakers = synthesizer.tts_model.speaker_manager.speaker_names
        # return options for the dropdown
        return speakers
    else:
        # return empty options if not multi-speaker
        return []


# Create Gradio interface
iface = gr.Interface(
    fn=synthesize,
    inputs=[
        gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
        gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
        gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", value="speaker-0")
    ],
    outputs=gr.Audio(label="Output", type='filepath'),
    examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], ""]],  # A speaker name would be needed here for multi-speaker models
    title='Persian TTS Playground',
    description="""
    ### Persian text-to-speech model demo.

    #### Pick a speaker for the multi-speaker models. (It won't affect the single-speaker models.)
    """,
    article="",
    live=False
)
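
As a quick illustration of what the speakers-file patching above does, a minimal sketch (the nested config dict is made up for the example, not taken from any of the listed repos):

# Hypothetical nested config, only to show update_config_speakers_file_recursive in isolation.
cfg = {"speakers_file": None, "model_args": {"speakers_file": "old/speakers.pth"}}
update_config_speakers_file_recursive(cfg, "/data/speakers.pth")
print(cfg)
# {'speakers_file': '/data/speakers.pth', 'model_args': {'speakers_file': '/data/speakers.pth'}}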
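
Assuming the downloads above complete, synthesize() can also be exercised directly, outside the Gradio UI; an empty speaker name falls back to "speaker-0" for multi-speaker models:

# MODEL_NAMES[0] is the single-speaker "vits checkpoint 57000" entry, so the speaker argument is ignored here.
wav_path = synthesize("زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], "")
print(wav_path)  # path to a temporary .wav file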
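
The inline note about Gradio conditions points at the real limitation: the speaker dropdown is populated once, for MODEL_NAMES[1] only. One way to make it follow the selected model is a change event in Gradio Blocks; the sketch below assumes Gradio 3.x (where gr.Dropdown.update exists) and introduces a hypothetical refresh_speakers helper, so it is an alternative layout rather than what this commit ships:

def refresh_speakers(model_name):
    # Reuses update_options() defined above (currently only MODEL_NAMES[1] reports speakers).
    speakers = update_options(model_name)
    return gr.Dropdown.update(choices=speakers, value=speakers[0] if speakers else None)

with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت.")
    model_in = gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
    speaker_in = gr.Dropdown(label="Select Speaker", choices=[])
    audio_out = gr.Audio(label="Output", type="filepath")
    run_btn = gr.Button("Synthesize")

    # Repopulate the speaker list whenever the model selection changes.
    model_in.change(refresh_speakers, inputs=model_in, outputs=speaker_in)
    run_btn.click(synthesize, inputs=[text_in, model_in, speaker_in], outputs=audio_out)

demo.launch()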