Spaces:
Build error
Build error
barghavani
committed on
Commit
·
b39f00f
1
Parent(s):
505e1ed
Update app.py
Browse files
app.py
CHANGED
@@ -8,23 +8,14 @@ import json
|
|
8 |
|
9 |
# Define constants
|
10 |
MODEL_INFO = [
|
11 |
-
|
12 |
-
|
13 |
-
# "saillab/persian-tts-cv15-reduct-grapheme-multispeaker"],
|
14 |
-
#["VITS Grapheme Multispeaker CV15(reduct)(22000)", "checkpoint_22000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
|
15 |
-
#["VITS Grapheme Multispeaker CV15(reduct)(26000)", "checkpoint_25000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
|
16 |
-
["VITS Grapheme Multispeaker CV15(90K)", "best_model_56960.pth", "config.json", "saillab/multi_speaker", "speakers.pth"],
|
17 |
-
|
18 |
-
# ["VITS Grapheme Azure (best at 15934)", "best_model_15934.pth", "config.json",
|
19 |
-
# "saillab/persian-tts-azure-grapheme-60K"],
|
20 |
["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"],
|
21 |
|
22 |
["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json",
|
23 |
"saillab/persian-tts-grapheme-arm24-finetuned-on1"],
|
24 |
["VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json",
|
25 |
"saillab/persian-tts-grapheme-arm24-finetuned-on1"],
|
26 |
-
|
27 |
-
# ... Add other models similarly
|
28 |
]
|
29 |
|
30 |
# Extract model names from MODEL_INFO
|
@@ -35,37 +26,14 @@ TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
|
|
35 |
|
36 |
model_files = {}
|
37 |
config_files = {}
|
38 |
-
speaker_files = {}
|
39 |
|
40 |
# Create a dictionary to store synthesizer objects for each model
|
41 |
synthesizers = {}
|
42 |
|
43 |
-
|
44 |
-
"""Recursively update speakers_file keys in a dictionary."""
|
45 |
-
if "speakers_file" in config_dict:
|
46 |
-
config_dict["speakers_file"] = speakers_path
|
47 |
-
for key, value in config_dict.items():
|
48 |
-
if isinstance(value, dict):
|
49 |
-
update_config_speakers_file_recursive(value, speakers_path)
|
50 |
-
|
51 |
-
def update_config_speakers_file(config_path, speakers_path):
    """Rewrite the JSON config at ``config_path`` with updated ``speakers_file`` entries.

    The file is parsed, passed together with ``speakers_path`` to
    ``update_config_speakers_file_recursive`` (which modifies the parsed
    dictionary in place), and then written back to the same path with
    4-space indentation.

    Args:
        config_path: Path of the JSON config file to read and overwrite.
        speakers_path: Value handed to the recursive updater for every
            ``speakers_file`` key.
    """
    # Be explicit about the encoding: open() otherwise falls back to the
    # platform default, which is not UTF-8 on every system, while JSON
    # documents are UTF-8.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Modify the speakers_file entries of the parsed config.
    update_config_speakers_file_recursive(config, speakers_path)

    # Overwrite the original file with the modified config.
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=4)
|
64 |
-
|
65 |
# Download models and initialize synthesizers
|
66 |
for info in MODEL_INFO:
|
67 |
model_name, model_file, config_file, repo_name = info[:4]
|
68 |
-
speaker_file = info[4] if len(info) == 5 else None # Check if speakers.pth is defined for the model
|
69 |
|
70 |
print(f"|> Downloading: {model_name}")
|
71 |
|
@@ -73,91 +41,42 @@ for info in MODEL_INFO:
|
|
73 |
model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
|
74 |
config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
|
75 |
|
76 |
-
#
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
print(speaker_files[model_name])
|
81 |
-
# Initialize synthesizer for the model
|
82 |
-
synthesizer = Synthesizer(
|
83 |
-
tts_checkpoint=model_files[model_name],
|
84 |
-
tts_config_path=config_files[model_name],
|
85 |
-
tts_speakers_file=speaker_files[model_name], # Pass the speakers.pth file if it exists
|
86 |
-
use_cuda=False # Assuming you don't want to use GPU, adjust if needed
|
87 |
-
)
|
88 |
-
|
89 |
-
elif speaker_file is None:
|
90 |
-
|
91 |
-
# Initialize synthesizer for the model
|
92 |
-
synthesizer = Synthesizer(
|
93 |
-
tts_checkpoint=model_files[model_name],
|
94 |
-
tts_config_path=config_files[model_name],
|
95 |
-
# tts_speakers_file=speaker_files.get(model_name, None), # Pass the speakers.pth file if it exists
|
96 |
-
use_cuda=False # Assuming you don't want to use GPU, adjust if needed
|
97 |
-
)
|
98 |
-
|
99 |
synthesizers[model_name] = synthesizer
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
def synthesize(text: str, model_name: str, speaker_name=None) -> str:
|
105 |
-
"""Synthesize speech using the selected model."""
|
106 |
if len(text) > MAX_TXT_LEN:
|
107 |
text = text[:MAX_TXT_LEN]
|
108 |
print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
|
109 |
|
110 |
-
# Use the synthesizer object for the selected model
|
111 |
synthesizer = synthesizers[model_name]
|
112 |
-
|
113 |
-
|
114 |
if synthesizer is None:
|
115 |
raise NameError("Model not found")
|
116 |
|
117 |
-
|
118 |
-
wavs = synthesizer.tts(text)
|
119 |
-
|
120 |
-
elif synthesizer.tts_speakers_file is not "":
|
121 |
-
if speaker_name == "":
|
122 |
-
wavs = synthesizer.tts(text, speaker_name="speaker-0") ## should change, better if gradio conditions are figure out.
|
123 |
-
else:
|
124 |
-
wavs = synthesizer.tts(text, speaker_name=speaker_name)
|
125 |
|
126 |
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
|
127 |
synthesizer.save_wav(wavs, fp)
|
128 |
return fp.name
|
129 |
|
130 |
-
# Callback function to update UI based on the selected model
|
131 |
-
def update_options(model_name):
    """Return the speaker choices to show for ``model_name``.

    Only the model listed at ``MODEL_NAMES[1]`` is treated as multi-speaker;
    for it the speaker names held by the synthesizer's speaker manager are
    returned. Every other model yields an empty list.

    Args:
        model_name: Key into ``synthesizers``.

    Returns:
        The speaker names for the multi-speaker model, otherwise ``[]``.

    Raises:
        KeyError: If ``model_name`` has no entry in ``synthesizers``.
    """
    synthesizer = synthesizers[model_name]

    # Compare by value, not identity: ``is`` only succeeds when the caller
    # passes the very same string object that is stored in MODEL_NAMES, so an
    # equal string arriving from elsewhere would wrongly get no speakers.
    # NOTE(review): the multi-speaker model is selected by a hard-coded index;
    # confirm that MODEL_NAMES[1] really is the multi-speaker entry.
    if model_name == MODEL_NAMES[1]:
        return synthesizer.tts_model.speaker_manager.speaker_names

    # Not the multi-speaker model: nothing to choose from.
    return []
|
141 |
-
|
142 |
-
# Create Gradio interface
|
143 |
iface = gr.Interface(
|
144 |
fn=synthesize,
|
145 |
inputs=[
|
146 |
gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
|
147 |
gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
|
148 |
-
gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", default="speaker-0")
|
149 |
],
|
150 |
outputs=gr.Audio(label="Output", type='filepath'),
|
151 |
-
examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]
|
152 |
title='Persian TTS Playground',
|
153 |
description="""
|
154 |
### Persian text to speech model demo.
|
155 |
-
|
156 |
-
#### Pick a speaker for MultiSpeaker models. (It won't affect the single speaker models)
|
157 |
""",
|
158 |
article="",
|
159 |
live=False
|
160 |
)
|
161 |
|
162 |
-
iface.launch()
|
163 |
-
|
|
|
8 |
|
9 |
# Define constants
|
10 |
# One entry per model, in the order:
#   display name, checkpoint filename, config filename, Hugging Face repo id.
MODEL_INFO = [
    [
        "VITS Grapheme Multispeaker CV15(90K)",
        "best_model_56960.pth",
        "config.json",
        "saillab/multi_speaker",
    ],
    [
        "VITS Grapheme Azure (61000)",
        "checkpoint_61000.pth",
        "config.json",
        "saillab/persian-tts-azure-grapheme-60K",
    ],
    [
        "VITS Grapheme ARM24 Fine-Tuned on 1 (66651)",
        "best_model_66651.pth",
        "config.json",
        "saillab/persian-tts-grapheme-arm24-finetuned-on1",
    ],
    [
        "VITS Grapheme ARM24 Fine-Tuned on 1 (120000)",
        "checkpoint_120000.pth",
        "config.json",
        "saillab/persian-tts-grapheme-arm24-finetuned-on1",
    ],
]
|
20 |
|
21 |
# Extract model names from MODEL_INFO
|
|
|
26 |
|
27 |
model_files = {}
|
28 |
config_files = {}
|
|
|
29 |
|
30 |
# Create a dictionary to store synthesizer objects for each model
|
31 |
synthesizers = {}
|
32 |
|
33 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
# Download models and initialize synthesizers
|
35 |
for info in MODEL_INFO:
    # Only the first four fields are read; any extra trailing fields are ignored.
    model_name, model_file, config_file, repo_name = info[:4]

    print(f"|> Downloading: {model_name}")

    # Fetch the checkpoint and its config from the model's Hub repository and
    # remember the paths returned for them.
    # NOTE(review): newer huggingface_hub releases prefer `token=` over
    # `use_auth_token=` - confirm against the version this app installs.
    model_files[model_name] = hf_hub_download(
        repo_id=repo_name,
        filename=model_file,
        use_auth_token=TOKEN,
    )
    config_files[model_name] = hf_hub_download(
        repo_id=repo_name,
        filename=config_file,
        use_auth_token=TOKEN,
    )

    # Build a synthesizer for this model from the downloaded files and
    # register it under the model's display name.
    synthesizer = Synthesizer(
        tts_checkpoint=model_files[model_name],
        tts_config_path=config_files[model_name],
        use_cuda=False,
    )
    synthesizers[model_name] = synthesizer
|
49 |
|
50 |
+
def synthesize(text: str, model_name: str) -> str:
    """Synthesize ``text`` with the selected model and return a WAV file path.

    Args:
        text: Text to speak. Anything beyond ``MAX_TXT_LEN`` characters is
            dropped and a notice is printed.
        model_name: Key into ``synthesizers`` selecting the model to use.

    Returns:
        Path of a temporary ``.wav`` file holding the audio. The file is
        created with ``delete=False`` and is not removed by this function.

    Raises:
        NameError: If no synthesizer is registered under ``model_name``.
    """
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")

    # Use .get() so an unknown model name reaches the check below; plain
    # indexing raised KeyError first, which left that check unreachable.
    synthesizer = synthesizers.get(model_name)
    if synthesizer is None:
        raise NameError("Model not found")

    wavs = synthesizer.tts(text)

    # delete=False keeps the file on disk after the handle is closed, so the
    # returned path is still valid for the caller.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
    return fp.name
|
65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
# Gradio UI: a text box and a model picker feed synthesize(); the file path it
# returns is played back by the audio component.
iface = gr.Interface(
    fn=synthesize,
    # Page text.
    title='Persian TTS Playground',
    description="""
### Persian text to speech model demo.
""",
    article="",
    # Inputs are passed to synthesize() positionally: text first, then model name.
    inputs=[
        gr.Textbox(
            label="Enter Text to Synthesize:",
            value="زین همرهان سست عناصر، دلم گرفت.",
        ),
        gr.Radio(
            label="Pick a Model",
            choices=MODEL_NAMES,
            value=MODEL_NAMES[0],
            type="value",
        ),
    ],
    outputs=gr.Audio(label="Output", type='filepath'),
    # A single ready-made example: the default sentence with the first model.
    examples=[
        ["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]],
    ],
    live=False,
)

iface.launch()
|
|