mhrahmani commited on
Commit
4712ee3
·
1 Parent(s): aba350c

Update app.py

Browse files

refactor to gradio

Files changed (1) hide show
  1. app.py +45 -43
app.py CHANGED
@@ -1,60 +1,62 @@
1
- import streamlit as st
2
- import tempfile
3
  import os
4
- from TTS.config import load_config
5
- from TTS.utils.manage import ModelManager
6
  from TTS.utils.synthesizer import Synthesizer
7
- from TTS.utils.download import download_url
8
 
9
  # Define constants
10
- MAX_TXT_LEN = 800
11
  MODEL_INFO = [
12
- # ["Model Name", "Model File", "Config File", "URL"]
13
- # Add other models in the same format
14
- ["vits-espeak-57000", "checkpoint_57000.pth", "config.json", "https://huggingface.co/mhrahmani/persian-tts-vits-0/tree/main"],
15
- # ...
16
  ]
17
 
 
 
 
 
 
 
18
  # Download models
19
- def download_models():
20
- for model_name, model_file, config_file, url in MODEL_INFO:
21
- directory = model_name
22
- os.makedirs(directory, exist_ok=True)
23
- download_url(f"{url}{model_file}", directory, str(model_file))
24
- download_url(f"{url}{config_file}", directory, "config.json")
25
-
26
- # Load a model and perform TTS
27
- def synthesize_speech(text, model_name):
 
28
  if len(text) > MAX_TXT_LEN:
29
  text = text[:MAX_TXT_LEN]
30
- st.warning(f"Input text was truncated to {MAX_TXT_LEN} characters.")
31
-
32
- synthesizer = Synthesizer(f"{model_name}/best_model.pth", f"{model_name}/config.json")
33
  if synthesizer is None:
34
- st.error("Model not found!")
35
- return None
36
 
37
  wavs = synthesizer.tts(text)
 
38
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
39
  synthesizer.save_wav(wavs, fp)
40
  return fp.name
41
 
42
- # Streamlit app
43
- def main():
44
- st.title('persian tts playground')
45
- st.markdown("""
46
- Persian TTS Demo)
47
- """)
48
-
49
- text_input = st.text_area("Enter Text to Synthesize:", "زین همرهان سست عناصر، دلم گرفت.")
50
- model_name = st.selectbox("Pick a TTS Model", [info[0] for info in MODEL_INFO], index=1)
51
-
52
- if st.button('Synthesize'):
53
- audio_file = synthesize_speech(text_input, model_name)
54
- if audio_file:
55
- st.audio(audio_file, format='audio/wav')
56
-
57
- # Download models and run the Streamlit app
58
- if __name__ == "__main__":
59
- download_models()
60
- main()
 
 
 
1
  import os
2
+ import tempfile
3
+ import gradio as gr
4
  from TTS.utils.synthesizer import Synthesizer
5
+ from huggingface_hub import hf_hub_download
6
 
7
  # Define constants
 
8
  MODEL_INFO = [
9
+ # ["Model Name", "Model File", "Config File", "Hub URL"]
10
+ ["vits-espeak-57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
11
+ # Add other models similarly...
 
12
  ]
13
 
14
+ # Extract model names from MODEL_INFO
15
+ MODEL_NAMES = [info[0] for info in MODEL_INFO]
16
+
17
+ MAX_TXT_LEN = 400
18
+ TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN') # Replace with the environment variable containing your token, if different
19
+
20
  # Download models
21
+ for model_name, model_file, config_file, repo_name in MODEL_INFO:
22
+ os.makedirs(model_name, exist_ok=True)
23
+ print(f"|> Downloading: {model_name}")
24
+
25
+ # Use hf_hub_download to download models from Hugging Face repositories
26
+ hf_hub_download(repo_id=repo_name, filename=model_file, cache_dir=model_name, use_auth_token=TOKEN)
27
+ hf_hub_download(repo_id=repo_name, filename=config_file, cache_dir=model_name, use_auth_token=TOKEN)
28
+
29
+ def synthesize(text: str, model_name: str) -> str:
30
+ """Synthesize speech using the selected model."""
31
  if len(text) > MAX_TXT_LEN:
32
  text = text[:MAX_TXT_LEN]
33
+ print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
34
+
35
+ synthesizer = Synthesizer(f"{model_name}/{model_file}", f"{model_name}/{config_file}")
36
  if synthesizer is None:
37
+ raise NameError("Model not found")
 
38
 
39
  wavs = synthesizer.tts(text)
40
+
41
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
42
  synthesizer.save_wav(wavs, fp)
43
  return fp.name
44
 
45
+ # Define Gradio interface
46
+ iface = gr.Interface(
47
+ fn=synthesize,
48
+ inputs=[
49
+ gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
50
+ gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0]),
51
+ ],
52
+ outputs=gr.Audio(label="Output", type='filepath'),
53
+ examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]],
54
+ title='persian tts playground',
55
+ description="Persian text to speech model demo", # Add the required description here.
56
+ article="",
57
+ live=False
58
+ )
59
+
60
+ # Launch the interface
61
+ iface.launch(share=False)
62
+