File size: 2,808 Bytes
92d8e4f
 
 
4e2ccb2
92d8e4f
 
4e2ccb2
92d8e4f
 
 
b39f00f
 
254f03b
92d8e4f
ab9a8c2
 
 
 
92d8e4f
ab9a8c2
4e2ccb2
 
 
92d8e4f
7127b4a
92d8e4f
4e2ccb2
 
 
 
 
 
b39f00f
4e2ccb2
 
 
 
 
 
 
7127b4a
 
92d8e4f
b39f00f
 
 
 
4e2ccb2
92d8e4f
b39f00f
 
92d8e4f
 
 
4e2ccb2
 
92d8e4f
 
4e2ccb2
b39f00f
4e2ccb2
92d8e4f
 
 
 
 
 
 
 
4e2ccb2
92d8e4f
 
b39f00f
92d8e4f
4e2ccb2
 
 
92d8e4f
 
 
 
b39f00f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import tempfile
import gradio as gr
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer
from huggingface_hub import hf_hub_download
import json

# Model registry. Each row holds, in order: the display name shown in the UI,
# the checkpoint filename, the config filename, and the Hub repository that
# both files are downloaded from.
MODEL_INFO = [
    [
        "VITS Grapheme Multispeaker CV15(90K)",
        "best_model_56960.pth",
        "config.json",
        "saillab/multi_speaker",
    ],
    [
        "VITS Grapheme Azure (61000)",
        "checkpoint_61000.pth",
        "config.json",
        "saillab/persian-tts-azure-grapheme-60K",
    ],
    [
        "VITS Grapheme ARM24 Fine-Tuned on 1 (66651)",
        "best_model_66651.pth",
        "config.json",
        "saillab/persian-tts-grapheme-arm24-finetuned-on1",
    ],
    [
        "VITS Grapheme ARM24 Fine-Tuned on 1 (120000)",
        "checkpoint_120000.pth",
        "config.json",
        "saillab/persian-tts-grapheme-arm24-finetuned-on1",
    ],
]

# Display names only, in registry order.
MODEL_NAMES = [display_name for display_name, *_ in MODEL_INFO]

# Upper bound on the number of input characters passed to a synthesizer.
MAX_TXT_LEN = 400
# Hub access token read from the environment; None when the variable is unset.
TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN')

# Per-model registries keyed by display name; they start empty and are filled
# in by the download loop that follows.
model_files, config_files = {}, {}
synthesizers = {}

   
# Fetch every registered model from the Hub and build one CPU synthesizer per
# display name. Only the first four fields of each registry row are used.
for display_name, checkpoint_name, config_name, repo_id in (row[:4] for row in MODEL_INFO):
    print(f"|> Downloading: {display_name}")

    # Record each downloaded path in its registry as soon as it is available.
    checkpoint_path = model_files[display_name] = hf_hub_download(
        repo_id=repo_id,
        filename=checkpoint_name,
        use_auth_token=TOKEN,
    )
    config_path = config_files[display_name] = hf_hub_download(
        repo_id=repo_id,
        filename=config_name,
        use_auth_token=TOKEN,
    )

    # Build the synthesizer from the two downloaded files; CUDA is disabled.
    synthesizers[display_name] = Synthesizer(
        tts_checkpoint=checkpoint_path,
        tts_config_path=config_path,
        use_cuda=False,
    )

def synthesize(text: str, model_name: str) -> str:
    """Synthesize ``text`` with the selected model and return a WAV file path.

    Args:
        text: Text to synthesize. Input longer than ``MAX_TXT_LEN`` characters
            is truncated to that length and a notice is printed.
        model_name: Key of the synthesizer to use in the module-level
            ``synthesizers`` dict.

    Returns:
        The path of a temporary ``.wav`` file containing the audio. The file
        is created with ``delete=False``, so it is not removed when closed.

    Raises:
        NameError: If no synthesizer is registered under ``model_name``.
    """
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")

    # Look up with .get(): plain indexing raised KeyError for an unknown name
    # before the check below could run, leaving the intended error unreachable.
    synthesizer = synthesizers.get(model_name)
    if synthesizer is None:
        raise NameError("Model not found")

    wavs = synthesizer.tts(text)

    # The open file object is handed to save_wav; leaving the ``with`` block
    # closes the file, and the path stays valid because delete=False.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
        return fp.name

# Sentence used both as the textbox default and as the single example row.
_sample_sentence = "زین همرهان سست عناصر، دلم گرفت."

# Input and output components, built up front and passed to the Interface.
_text_box = gr.Textbox(label="Enter Text to Synthesize:", value=_sample_sentence)
_model_radio = gr.Radio(
    label="Pick a Model",
    choices=MODEL_NAMES,
    value=MODEL_NAMES[0],
    type="value",
)
_audio_out = gr.Audio(label="Output", type='filepath')

# Wire the components to synthesize(); live=False is passed explicitly.
iface = gr.Interface(
    fn=synthesize,
    inputs=[_text_box, _model_radio],
    outputs=_audio_out,
    examples=[[_sample_sentence, MODEL_NAMES[0]]],
    title='Persian TTS Playground',
    description="\n    ### Persian text to speech model demo. \n    ",
    article="",
    live=False,
)

# Start the app as soon as this module is executed.
iface.launch()