barghavani committed on
Commit
b39f00f
·
1 Parent(s): 505e1ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -93
app.py CHANGED
@@ -8,23 +8,14 @@ import json
8
 
9
  # Define constants
10
  MODEL_INFO = [
11
- #["vits checkpoint 57000", "checkpoint_57000.pth", "config.json", "mhrahmani/persian-tts-vits-0"],
12
- # ["VITS Grapheme Multispeaker CV15(reduct)(best at 17864)", "best_model_17864.pth", "config.json",
13
- # "saillab/persian-tts-cv15-reduct-grapheme-multispeaker"],
14
- #["VITS Grapheme Multispeaker CV15(reduct)(22000)", "checkpoint_22000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
15
- #["VITS Grapheme Multispeaker CV15(reduct)(26000)", "checkpoint_25000.pth", "config.json", "saillab/persian-tts-cv15-reduct-grapheme-multispeaker", "speakers.pth"],
16
- ["VITS Grapheme Multispeaker CV15(90K)", "best_model_56960.pth", "config.json", "saillab/multi_speaker", "speakers.pth"],
17
-
18
- # ["VITS Grapheme Azure (best at 15934)", "best_model_15934.pth", "config.json",
19
- # "saillab/persian-tts-azure-grapheme-60K"],
20
  ["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json", "saillab/persian-tts-azure-grapheme-60K"],
21
 
22
  ["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json",
23
  "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
24
  ["VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json",
25
  "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
26
-
27
- # ... Add other models similarly
28
  ]
29
 
30
  # Extract model names from MODEL_INFO
@@ -35,37 +26,14 @@ TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')
35
 
36
  model_files = {}
37
  config_files = {}
38
- speaker_files = {}
39
 
40
  # Create a dictionary to store synthesizer objects for each model
41
  synthesizers = {}
42
 
43
- def update_config_speakers_file_recursive(config_dict, speakers_path):
44
- """Recursively update speakers_file keys in a dictionary."""
45
- if "speakers_file" in config_dict:
46
- config_dict["speakers_file"] = speakers_path
47
- for key, value in config_dict.items():
48
- if isinstance(value, dict):
49
- update_config_speakers_file_recursive(value, speakers_path)
50
-
51
- def update_config_speakers_file(config_path, speakers_path):
52
- """Update the config.json file to point to the correct speakers.pth file."""
53
-
54
- # Load the existing config
55
- with open(config_path, 'r') as f:
56
- config = json.load(f)
57
-
58
- # Modify the speakers_file entry
59
- update_config_speakers_file_recursive(config, speakers_path)
60
-
61
- # Save the modified config
62
- with open(config_path, 'w') as f:
63
- json.dump(config, f, indent=4)
64
-
65
  # Download models and initialize synthesizers
66
  for info in MODEL_INFO:
67
  model_name, model_file, config_file, repo_name = info[:4]
68
- speaker_file = info[4] if len(info) == 5 else None # Check if speakers.pth is defined for the model
69
 
70
  print(f"|> Downloading: {model_name}")
71
 
@@ -73,91 +41,42 @@ for info in MODEL_INFO:
73
  model_files[model_name] = hf_hub_download(repo_id=repo_name, filename=model_file, use_auth_token=TOKEN)
74
  config_files[model_name] = hf_hub_download(repo_id=repo_name, filename=config_file, use_auth_token=TOKEN)
75
 
76
- # Download speakers.pth if it exists
77
- if speaker_file:
78
- speaker_files[model_name] = hf_hub_download(repo_id=repo_name, filename=speaker_file, use_auth_token=TOKEN)
79
- update_config_speakers_file(config_files[model_name], speaker_files[model_name]) # Update the config file
80
- print(speaker_files[model_name])
81
- # Initialize synthesizer for the model
82
- synthesizer = Synthesizer(
83
- tts_checkpoint=model_files[model_name],
84
- tts_config_path=config_files[model_name],
85
- tts_speakers_file=speaker_files[model_name], # Pass the speakers.pth file if it exists
86
- use_cuda=False # Assuming you don't want to use GPU, adjust if needed
87
- )
88
-
89
- elif speaker_file is None:
90
-
91
- # Initialize synthesizer for the model
92
- synthesizer = Synthesizer(
93
- tts_checkpoint=model_files[model_name],
94
- tts_config_path=config_files[model_name],
95
- # tts_speakers_file=speaker_files.get(model_name, None), # Pass the speakers.pth file if it exists
96
- use_cuda=False # Assuming you don't want to use GPU, adjust if needed
97
- )
98
-
99
  synthesizers[model_name] = synthesizer
100
 
101
-
102
-
103
-
104
- def synthesize(text: str, model_name: str, speaker_name=None) -> str:
105
- """Synthesize speech using the selected model."""
106
  if len(text) > MAX_TXT_LEN:
107
  text = text[:MAX_TXT_LEN]
108
  print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")
109
 
110
- # Use the synthesizer object for the selected model
111
  synthesizer = synthesizers[model_name]
112
-
113
-
114
  if synthesizer is None:
115
  raise NameError("Model not found")
116
 
117
- if synthesizer.tts_speakers_file is "":
118
- wavs = synthesizer.tts(text)
119
-
120
- elif synthesizer.tts_speakers_file is not "":
121
- if speaker_name == "":
122
- wavs = synthesizer.tts(text, speaker_name="speaker-0") ## should change, better if gradio conditions are figure out.
123
- else:
124
- wavs = synthesizer.tts(text, speaker_name=speaker_name)
125
 
126
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
127
  synthesizer.save_wav(wavs, fp)
128
  return fp.name
129
 
130
- # Callback function to update UI based on the selected model
131
- def update_options(model_name):
132
- synthesizer = synthesizers[model_name]
133
- # if synthesizer.tts.is_multi_speaker:
134
- if model_name is MODEL_NAMES[1]:
135
- speakers = synthesizer.tts_model.speaker_manager.speaker_names
136
- # return options for the dropdown
137
- return speakers
138
- else:
139
- # return empty options if not multi-speaker
140
- return []
141
-
142
- # Create Gradio interface
143
  iface = gr.Interface(
144
  fn=synthesize,
145
  inputs=[
146
  gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
147
  gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
148
- gr.Dropdown(label="Select Speaker", choices=update_options(MODEL_NAMES[1]), type="value", default="speaker-0")
149
  ],
150
  outputs=gr.Audio(label="Output", type='filepath'),
151
- examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0], ""]], # Example should include a speaker name for multispeaker models
152
  title='Persian TTS Playground',
153
  description="""
154
  ### Persian text to speech model demo.
155
-
156
- #### Pick a speaker for MultiSpeaker models. (It won't affect the single speaker models)
157
  """,
158
  article="",
159
  live=False
160
  )
161
 
162
- iface.launch()
163
-
 
8
 
9
# Define constants
# Each entry: [display name, checkpoint filename, config filename, HF repo id]
MODEL_INFO = [
    ["VITS Grapheme Multispeaker CV15(90K)", "best_model_56960.pth", "config.json",
     "saillab/multi_speaker"],
    ["VITS Grapheme Azure (61000)", "checkpoint_61000.pth", "config.json",
     "saillab/persian-tts-azure-grapheme-60K"],
    ["VITS Grapheme ARM24 Fine-Tuned on 1 (66651)", "best_model_66651.pth", "config.json",
     "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
    ["VITS Grapheme ARM24 Fine-Tuned on 1 (120000)", "checkpoint_120000.pth", "config.json",
     "saillab/persian-tts-grapheme-arm24-finetuned-on1"],
]
20
 
21
  # Extract model names from MODEL_INFO
 
26
 
27
# Per-model caches of locally downloaded artifact paths (filled by the loop below).
model_files = {}
config_files = {}

# Create a dictionary to store synthesizer objects for each model
# (keyed by the model's display name from MODEL_INFO).
synthesizers = {}
32
 
33
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
# Download models and initialize synthesizers:
# fetch each model's checkpoint and config from the Hugging Face Hub,
# then build a CPU-only Synthesizer for it.
for entry in MODEL_INFO:
    name, ckpt_file, cfg_file, repo_id = entry[:4]

    print(f"|> Downloading: {name}")

    model_files[name] = hf_hub_download(repo_id=repo_id, filename=ckpt_file, use_auth_token=TOKEN)
    config_files[name] = hf_hub_download(repo_id=repo_id, filename=cfg_file, use_auth_token=TOKEN)

    # Initialize synthesizer for the model (CPU inference; no GPU assumed).
    synthesizer = Synthesizer(
        tts_checkpoint=model_files[name],
        tts_config_path=config_files[name],
        use_cuda=False,
    )
    synthesizers[name] = synthesizer
49
 
50
def synthesize(text: str, model_name: str) -> str:
    """Synthesize speech for *text* using the selected model.

    Args:
        text: Input text; silently truncated to MAX_TXT_LEN characters.
        model_name: Display name of the model (a key of ``synthesizers``).

    Returns:
        Path to a temporary WAV file containing the rendered audio.

    Raises:
        NameError: If *model_name* has no registered synthesizer.
    """
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cut off as it exceeded the {MAX_TXT_LEN} character limit.")

    # Use .get() so an unknown model name reaches the explicit NameError below;
    # the original dict subscript raised a bare KeyError first, making the
    # None-check dead code.
    synthesizer = synthesizers.get(model_name)
    if synthesizer is None:
        raise NameError("Model not found")

    wavs = synthesizer.tts(text)

    # delete=False: Gradio reads the file by path after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
    return fp.name
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
# Gradio UI: a textbox plus model picker in, one audio player out.
iface = gr.Interface(
    fn=synthesize,
    inputs=[
        gr.Textbox(label="Enter Text to Synthesize:", value="زین همرهان سست عناصر، دلم گرفت."),
        gr.Radio(label="Pick a Model", choices=MODEL_NAMES, value=MODEL_NAMES[0], type="value"),
    ],
    outputs=gr.Audio(label="Output", type="filepath"),
    examples=[["زین همرهان سست عناصر، دلم گرفت.", MODEL_NAMES[0]]],
    title="Persian TTS Playground",
    description="""
    ### Persian text to speech model demo.
    """,
    article="",
    live=False,
)
81
 
82
+ iface.launch()