# ZabanZad_PoC / app.py
# barghavani's picture
# Update app.py
# b39f00f
# raw
# history blame
# 2.81 kB
import os
import tempfile
import gradio as gr
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
from huggingface_hub import hf_hub_download
import json
# Model catalogue. Each entry is
# [display name, checkpoint filename, config filename, Hub repo id].
MODEL_INFO = [
    [
        "VITS Grapheme Multispeaker CV15(90K)",
        "best_model_56960.pth",
        "config.json",
        "saillab/multi_speaker",
    ],
    [
        "VITS Grapheme Azure (61000)",
        "checkpoint_61000.pth",
        "config.json",
        "saillab/persian-tts-azure-grapheme-60K",
    ],
    [
        "VITS Grapheme ARM24 Fine-Tuned on 1 (66651)",
        "best_model_66651.pth",
        "config.json",
        "saillab/persian-tts-grapheme-arm24-finetuned-on1",
    ],
    [
        "VITS Grapheme ARM24 Fine-Tuned on 1 (120000)",
        "checkpoint_120000.pth",
        "config.json",
        "saillab/persian-tts-grapheme-arm24-finetuned-on1",
    ],
]

# Display names (first field of each entry), in catalogue order.
MODEL_NAMES = [display_name for display_name, *_rest in MODEL_INFO]

# Character cap applied to input text by synthesize().
MAX_TXT_LEN = 400

# Hub access token read from the environment; None when the variable is unset.
TOKEN = os.environ.get("HUGGING_FACE_HUB_TOKEN")

# Per-model lookup tables keyed by display name, filled in by the download loop:
# local checkpoint paths, local config paths, and Synthesizer instances.
model_files, config_files, synthesizers = {}, {}, {}
# Download every model's checkpoint and config from the Hugging Face Hub, then
# build one CPU-only Synthesizer per model, keyed by its display name.
for model_name, model_file, config_file, repo_name, *_ in MODEL_INFO:
    print(f"|> Downloading: {model_name}")
    # `token=` is the supported keyword; `use_auth_token=` is deprecated in
    # huggingface_hub. TOKEN is None when HUGGING_FACE_HUB_TOKEN is unset.
    model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, token=TOKEN)
    config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, token=TOKEN)
    synthesizers[model_name] = Synthesizer(
        tts_checkpoint=model_files[model_name],
        tts_config_path=config_files[model_name],
        use_cuda=False,
    )
def synthesize(text: str, model_name: str) -> str:
    """Synthesize *text* with the selected model and return the WAV file path.

    Args:
        text: Text to speak. Anything beyond ``MAX_TXT_LEN`` characters is
            dropped (and a notice is printed).
        model_name: Display name of the model, i.e. a key of ``synthesizers``.

    Returns:
        Path to a new ``.wav`` file in the system temp directory. The file is
        created with ``delete=False`` and is not cleaned up here.

    Raises:
        NameError: If no synthesizer is loaded under *model_name*.
    """
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
    # Use .get() so an unknown name reaches the explicit error below; plain
    # indexing raised KeyError first, which made the None check unreachable.
    synthesizer = synthesizers.get(model_name)
    if synthesizer is None:
        raise NameError(f"Model not found: {model_name}")
    # NOTE(review): no speaker argument is passed; if any listed model is
    # multi-speaker, Synthesizer.tts may require one — confirm.
    wavs = synthesizer.tts(text)
    # NOTE(review): the open file object is passed where save_wav expects a
    # path; this relies on the underlying WAV writer accepting file-like
    # objects — confirm against the installed TTS version.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
    return fp.name
# Sample Persian sentence: textbox default and the single bundled example.
_EXAMPLE_TEXT = "زین همرهان سست عناصر، دلم گرفت."

# UI components: free-text input, model picker (defaults to the first model),
# and an audio player fed the WAV path that synthesize() returns.
text_input = gr.Textbox(label="Enter Text to Synthesize:", value=_EXAMPLE_TEXT)
model_picker = gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value")
audio_output = gr.Audio(label="Output", type='filepath')

iface = gr.Interface(
    fn=synthesize,
    inputs=[text_input, model_picker],
    outputs=audio_output,
    examples=[[_EXAMPLE_TEXT, MODEL_NAMES[0]]],
    title='Persian TTS Playground',
    description="""
### Persian text to speech model demo.
""",
    article="",
    live=False,
)
iface.launch()