SebastianBodza committed on
Commit 99be1d8 · verified · 1 Parent(s): 67a186c

Update app.py

Files changed (1)
  1. app.py +257 -62
app.py CHANGED
@@ -6,18 +6,15 @@ from xcodec2.modeling_xcodec2 import XCodec2Model
import torchaudio
import gradio as gr
import tempfile
- import os
- import numpy as np

llasa_1b ='SebastianBodza/Kartoffel-1B-v0.2'

tokenizer = AutoTokenizer.from_pretrained(llasa_1b, token=os.getenv("HF_TOKEN"))

model = AutoModelForCausalLM.from_pretrained(
-     llasa_1b,
-     trust_remote_code=True,
-     device_map='cuda',
-     token=os.getenv("HF_TOKEN")
)

model_path = "srinivasbilla/xcodec2"
@@ -28,54 +25,107 @@ whisper_turbo_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
-     device='cuda',
)

def normalize_audio(waveform: torch.Tensor, target_db: float = -20) -> torch.Tensor:
    """
    Normalize audio volume to target dB and limit gain range.
-
    Args:
        waveform (torch.Tensor): Input audio waveform
        target_db (float): Target dB level (default: -20)
-
    Returns:
        torch.Tensor: Normalized audio waveform
    """
    # Calculate current dB
    eps = 1e-10
    current_db = 20 * torch.log10(torch.max(torch.abs(waveform)) + eps)
-
    # Calculate required gain
    gain_db = target_db - current_db
-
    # Limit gain to -3 to 3 dB range
    gain_db = torch.clamp(gain_db, min=-3, max=3)
-
    # Apply gain
    gain_factor = 10 ** (gain_db / 20)
    normalized = waveform * gain_factor
-
    # Final peak normalization
    max_amplitude = torch.max(torch.abs(normalized))
    if max_amplitude > 0:
        normalized = normalized / max_amplitude
-
    return normalized


def ids_to_speech_tokens(speech_ids):
-
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str


def extract_speech_ids(speech_tokens_str):
-
    speech_ids = []
    for token_str in speech_tokens_str:
-         if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]

            num = int(num_str)
@@ -84,20 +134,57 @@ def extract_speech_ids(speech_tokens_str):
            print(f"Unexpected token: {token_str}")
    return speech_ids

@spaces.GPU(duration=30)
- def infer(sample_audio_path, target_text, temp, top_p_val, min_new_tokens, do_sample, progress=gr.Progress()):
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
-         progress(0, 'Loading and trimming audio...')
        waveform, sample_rate = torchaudio.load(sample_audio_path)

        waveform = normalize_audio(waveform)

-
-         if len(waveform[0])/sample_rate > 15:
            gr.Warning("Trimming audio to first 15secs.")
-             waveform = waveform[:, :sample_rate*15]
-             waveform = torch.nn.functional.pad(waveform, (0, int(sample_rate*0.5)), "constant", 0)
-

        # Check if the audio is stereo (i.e., has more than one channel)
        if waveform.size(0) > 1:
@@ -107,78 +194,104 @@ def infer(sample_audio_path, target_text, temp, top_p_val, min_new_tokens, do_sa
            # If already mono, just use the original waveform
            waveform_mono = waveform

-         prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
-         prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())['text'].strip()
-         progress(0.5, 'Transcribed! Generating speech...')

        if len(target_text) == 0:
            return None
        elif len(target_text) > 500:
            gr.Warning("Text is too long. Please keep it under 300 characters.")
            target_text = target_text[:500]
-
-         input_text = prompt_text + ' ' + target_text
        print("Transcribed text:", input_text)

-         #TTS start!
        with torch.no_grad():
            # Encode the prompt wav
            vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)

-             vq_code_prompt = vq_code_prompt[0,0,:]
            # Convert int 12345 to token <|s_12345|>
            speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)

-             formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

            # Tokenize the text and the speech prefix
            chat = [
-                 {"role": "user", "content": "Convert the text to speech:" + formatted_text},
-                 {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
            ]

            input_ids = tokenizer.apply_chat_template(
-                 chat,
-                 tokenize=True,
-                 return_tensors='pt',
                continue_final_message=True,
            )
-             input_ids = input_ids.to('cuda')
-             speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

            # Generate the speech autoregressively
            outputs = model.generate(
                input_ids,
                max_length=2048, # We trained our model with a max length of 2048
-                 eos_token_id= speech_end_id,
                do_sample=do_sample,
-                 top_p=top_p_val,
                temperature=temp,
                min_new_tokens=min_new_tokens,
            )

            # Extract the speech tokens
-             generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
-
-             speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
-             raw_output = ' '.join(speech_tokens) # Capture raw tokens

-             speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

-             # Convert token <|s_23456|> to int 23456
            speech_tokens = extract_speech_ids(speech_tokens)

            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

            # Decode the speech tokens to speech waveform
-             gen_wav = Codec_model.decode_code(speech_tokens)

            # if only need the generated part
-             gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]

-         progress(1, 'Synthesized!')

-         return (16000, gen_wav[0, 0, :].cpu().numpy()), raw_output # Return both audio and raw tokens

with gr.Blocks() as app_tts:
    gr.Markdown("# Zero Shot Voice Clone TTS")
@@ -187,10 +300,10 @@ with gr.Blocks() as app_tts:
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
-             value=0.8,
            step=0.1,
            label="Temperature",
-             info="Higher values = more random/creative output"
        )
        top_p = gr.Slider(
            minimum=0.1,
@@ -198,7 +311,7 @@ with gr.Blocks() as app_tts:
            value=1.0,
            step=0.1,
            label="Top P",
-             info="Nucleus sampling threshold"
        )
        min_new_tokens = gr.Slider(
            minimum=0,
@@ -206,9 +319,11 @@ with gr.Blocks() as app_tts:
            value=3,
            step=1,
            label="Min Length",
-             info="If the model just produces a click you can force it to create longer generations."
        )
-         do_sample = gr.Checkbox(label="Sample", value=True, info="Sample from the distribution")

    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
@@ -216,21 +331,101 @@ with gr.Blocks() as app_tts:
    generate_btn = gr.Button("Synthesize", variant="primary")

    audio_output = gr.Audio(label="Synthesized Audio")
-     raw_output_display = gr.Textbox(label="Raw Model Output", interactive=False) # Add textbox

    generate_btn.click(
-         infer,
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature,
            top_p,
            min_new_tokens,
-             do_sample
        ],
-         outputs=[audio_output, raw_output_display] # Include both outputs
    )

with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits

@@ -251,7 +446,7 @@ The checkpoints support German. If the audio is of low quality, the model may st
If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
"""
    )
-     gr.TabbedInterface([app_tts], ["TTS"])


- app.launch(ssr_mode=False)
import torchaudio
import gradio as gr
import tempfile
+ import os
+ import numpy as np

llasa_1b ='SebastianBodza/Kartoffel-1B-v0.2'

tokenizer = AutoTokenizer.from_pretrained(llasa_1b, token=os.getenv("HF_TOKEN"))

model = AutoModelForCausalLM.from_pretrained(
+     llasa_1b, trust_remote_code=True, device_map="cuda", token=os.getenv("HF_TOKEN")
)

model_path = "srinivasbilla/xcodec2"

    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
+     device="cuda",
)


+ SPEAKERS = {
+     "Male 1": {
+         "path": "speakers/deep_speaker.mp3",
+         "transcript": "Das große Tor von Minas Tirith brach erst, nachdem er die Ramme eingesetzt hatte.",
+         "description": "Eine tiefe epische Männerstimme.",
+     },
+     "Male 2": {
+         "path": "speakers/male_austrian_accent.mp3",
+         "transcript": "Man kann sich auch leichter vorstellen, wie schwierig es ist, dass man Entscheidungen trifft, die allen passen.",
+         "description": "Eine männliche Stimme mit österreicherischem Akzent.",
+     },
+     "Male 3": {
+         "path": "speakers/male_energic.mp3",
+         "transcript": "Wo keine Infrastruktur, da auch keine Ansiedlung von IT-Unternehmen und deren Beschäftigten bzw. dem geeigneten Fachkräftenachwuchs. Kann man diese Rechnung so einfach aufmachen, wie es es tatsächlich um deren regionale Verteilung beschäftigt?",
+         "description": "Eine männliche energische Stimme",
+     },
+     "Male 4": {
+         "path": "speakers/schneller_speaker.mp3",
+         "transcript": "Genau, wenn wir alle Dächer voll machen, also alle Dächer von Einfamilienhäusern, alleine mit den Einfamilienhäusern können wir 20 Prozent des heutigen Strombedarfs decken.",
+         "description": "Eine männliche Spreche mit schnellerem Tempo.",
+     },
+     "Female 1": {
+         "path": "speakers/female_standard.mp3",
+         "transcript": "Es wird ein Beispiel für ein barrierearmes Layout gegeben, sowie Tipps und ein Verweis auf eine Checkliste, die hilft, Barrierearmut in den eigenen Materialien zu prüfen bzw. umzusetzen.",
+         "description": "Eine weibliche Stimme.",
+     },
+     "Female 2": {
+         "path": "speakers/female_energic.mp3",
+         "transcript": "Dunkel flog weiter durch das Wald. Er sah die Sterne am Phaneten an sich vorbeiziehen und fühlte sich frei und glücklich.",
+         "description": "Eine weibliche Erzähler-Stimme.",
+     },
+     "Female 3": {
+         "path": "speakers/austrian_accent.mp3",
+         "transcript": "Die politische Europäische Union war geboren, verbrieft im Vertrag von Maastricht. Ab diesem Zeitpunkt bestehen zwei Vertragswerke.",
+         "description": "Eine weibliche Stimme mit österreicherischem Akzent.",
+     },
+     "Special 1": {
+         "path": "speakers/low_audio.mp3",
+         "transcript": "Druckplatten und Lasersensoren, um sicherzugehen, dass er auch da drin ist und",
+         "description": "Eine männliche Stimme mit schlechter Audioqualität als Effekt.",
+     },
+ }
+
+
+ def preview_speaker(display_name):
+     """Returns the audio and transcript for preview"""
+     speaker_name = speaker_display_dict[display_name]
+     if speaker_name in SPEAKERS:
+         waveform, sample_rate = torchaudio.load(SPEAKERS[speaker_name]["path"])
+         return (sample_rate, waveform[0].numpy()), SPEAKERS[speaker_name]["transcript"]
+     return None, ""
+
+
def normalize_audio(waveform: torch.Tensor, target_db: float = -20) -> torch.Tensor:
    """
    Normalize audio volume to target dB and limit gain range.
+
    Args:
        waveform (torch.Tensor): Input audio waveform
        target_db (float): Target dB level (default: -20)
+
    Returns:
        torch.Tensor: Normalized audio waveform
    """
    # Calculate current dB
    eps = 1e-10
    current_db = 20 * torch.log10(torch.max(torch.abs(waveform)) + eps)
+
    # Calculate required gain
    gain_db = target_db - current_db
+
    # Limit gain to -3 to 3 dB range
    gain_db = torch.clamp(gain_db, min=-3, max=3)
+
    # Apply gain
    gain_factor = 10 ** (gain_db / 20)
    normalized = waveform * gain_factor
+
    # Final peak normalization
    max_amplitude = torch.max(torch.abs(normalized))
    if max_amplitude > 0:
        normalized = normalized / max_amplitude
+
    return normalized

+
def ids_to_speech_tokens(speech_ids):
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str

+
def extract_speech_ids(speech_tokens_str):
    speech_ids = []
    for token_str in speech_tokens_str:
+         if token_str.startswith("<|s_") and token_str.endswith("|>"):
            num_str = token_str[4:-2]

            num = int(num_str)

            print(f"Unexpected token: {token_str}")
    return speech_ids

+
+ def infer_with_speaker(
+     display_name,
+     target_text,
+     temp,
+     top_p_val,
+     min_new_tokens,
+     do_sample,
+     progress=gr.Progress(),
+ ):
+     """Modified infer function that uses predefined speaker"""
+     speaker_name = speaker_display_dict[display_name] # Get actual speaker name
+     if speaker_name not in SPEAKERS:
+         return None, "Invalid speaker selected"
+
+
+     return infer(
+         SPEAKERS[speaker_name]["path"],
+         target_text,
+         temp,
+         top_p_val,
+         min_new_tokens,
+         do_sample,
+         SPEAKERS[speaker_name]["transcript"], # Pass the predefined transcript
+         progress,
+     )
+
+
@spaces.GPU(duration=30)
+ def infer(
+     sample_audio_path,
+     target_text,
+     temp,
+     top_p_val,
+     min_new_tokens,
+     do_sample,
+     transcribed_text=None,
+     progress=gr.Progress(),
+ ):
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+         progress(0, "Loading and trimming audio...")
        waveform, sample_rate = torchaudio.load(sample_audio_path)

        waveform = normalize_audio(waveform)

+         if len(waveform[0]) / sample_rate > 15:
            gr.Warning("Trimming audio to first 15secs.")
+             waveform = waveform[:, : sample_rate * 15]
+             waveform = torch.nn.functional.pad(
+                 waveform, (0, int(sample_rate * 0.5)), "constant", 0
+             )

        # Check if the audio is stereo (i.e., has more than one channel)
        if waveform.size(0) > 1:

            # If already mono, just use the original waveform
            waveform_mono = waveform

+         prompt_wav = torchaudio.transforms.Resample(
+             orig_freq=sample_rate, new_freq=16000
+         )(waveform_mono)
+
+         if transcribed_text is None:
+             progress(0.3, "Transcribing audio...")
+             prompt_text = whisper_turbo_pipe(prompt_wav[0].numpy())["text"].strip()
+             print("Transcribed text:", prompt_text)
+         else:
+             prompt_text = transcribed_text
+
+         progress(0.5, "Transcribed! Generating speech...")

        if len(target_text) == 0:
            return None
        elif len(target_text) > 500:
            gr.Warning("Text is too long. Please keep it under 300 characters.")
            target_text = target_text[:500]
+
+         input_text = prompt_text + " " + target_text
        print("Transcribed text:", input_text)

+         # TTS start!
        with torch.no_grad():
            # Encode the prompt wav
            vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)

+             vq_code_prompt = vq_code_prompt[0, 0, :]
            # Convert int 12345 to token <|s_12345|>
            speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)

+             formatted_text = (
+                 f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
+             )

            # Tokenize the text and the speech prefix
            chat = [
+                 {
+                     "role": "user",
+                     "content": "Convert the text to speech:" + formatted_text,
+                 },
+                 {
+                     "role": "assistant",
+                     "content": "<|SPEECH_GENERATION_START|>"
+                     + "".join(speech_ids_prefix),
+                 },
            ]

            input_ids = tokenizer.apply_chat_template(
+                 chat,
+                 tokenize=True,
+                 return_tensors="pt",
                continue_final_message=True,
            )
+             input_ids = input_ids.to("cuda")
+             speech_end_id = tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")

            # Generate the speech autoregressively
            outputs = model.generate(
                input_ids,
                max_length=2048, # We trained our model with a max length of 2048
+                 eos_token_id=speech_end_id,
                do_sample=do_sample,
+                 top_p=top_p_val,
                temperature=temp,
                min_new_tokens=min_new_tokens,
            )

            # Extract the speech tokens
+             generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix) : -1]

+             speech_tokens = tokenizer.batch_decode(
+                 generated_ids, skip_special_tokens=False
+             )
+             raw_output = " ".join(speech_tokens) # Capture raw tokens

+             speech_tokens = tokenizer.batch_decode(
+                 generated_ids, skip_special_tokens=True
+             )
+
+             # Convert token <|s_23456|> to int 23456
            speech_tokens = extract_speech_ids(speech_tokens)

            speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

            # Decode the speech tokens to speech waveform
+             gen_wav = Codec_model.decode_code(speech_tokens)

            # if only need the generated part
+             gen_wav = gen_wav[:, :, prompt_wav.shape[1] :]

+         progress(1, "Synthesized!")
+
+         return (
+             16000,
+             gen_wav[0, 0, :].cpu().numpy(),
+         ), raw_output # Return both audio and raw tokens


with gr.Blocks() as app_tts:
    gr.Markdown("# Zero Shot Voice Clone TTS")

        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
+             value=0.4,
            step=0.1,
            label="Temperature",
+             info="Higher values = more random/creative output",
        )
        top_p = gr.Slider(
            minimum=0.1,

            value=1.0,
            step=0.1,
            label="Top P",
+             info="Nucleus sampling threshold",
        )
        min_new_tokens = gr.Slider(
            minimum=0,

            value=3,
            step=1,
            label="Min Length",
+             info="If the model just produces a click you can force it to create longer generations.",
+         )
+         do_sample = gr.Checkbox(
+             label="Sample", value=True, info="Sample from the distribution"
        )

    ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
    gen_text_input = gr.Textbox(label="Text to Generate", lines=10)

    generate_btn = gr.Button("Synthesize", variant="primary")

    audio_output = gr.Audio(label="Synthesized Audio")
+     raw_output_display = gr.Textbox(
+         label="Raw Model Output", interactive=False
+     ) # Add textbox

    generate_btn.click(
+         lambda *args: infer(*args, transcribed_text=None),
        inputs=[
            ref_audio_input,
            gen_text_input,
            temperature,
            top_p,
            min_new_tokens,
+             do_sample,
        ],
+         outputs=[audio_output, raw_output_display], # Include both outputs
    )

+
+ with gr.Blocks() as app_speaker:
+     gr.Markdown("# Predefined Speaker TTS")
+
+     with gr.Accordion("Model Settings", open=False):
+         temperature = gr.Slider(
+             minimum=0.0,
+             maximum=1.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature",
+             info="Higher values = more random/creative output",
+         )
+         top_p = gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=1.0,
+             step=0.1,
+             label="Top P",
+             info="Nucleus sampling threshold",
+         )
+         min_new_tokens = gr.Slider(
+             minimum=0,
+             maximum=128,
+             value=3,
+             step=1,
+             label="Min Length",
+             info="If the model just produces a click you can force it to create longer generations.",
+         )
+         do_sample = gr.Checkbox(
+             label="Sample", value=True, info="Sample from the distribution"
+         )
+
+     with gr.Row():
+         speaker_display_dict = {
+             f"{name} - {SPEAKERS[name]['description']}": name
+             for name in SPEAKERS.keys()
+         }
+         speaker_dropdown = gr.Dropdown(
+             choices=list(speaker_display_dict.keys()),
+             label="Select Speaker",
+             value=list(speaker_display_dict.keys())[0],
+         )
+         preview_btn = gr.Button("Preview Voice")
+
+
+     with gr.Row():
+         preview_audio = gr.Audio(label="Preview")
+         preview_text = gr.Textbox(label="Original Transcript", interactive=False)
+
+     gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
+     generate_btn = gr.Button("Synthesize", variant="primary")
+
+     audio_output = gr.Audio(label="Synthesized Audio")
+     raw_output_display = gr.Textbox(label="Raw Model Output", interactive=False)
+
+     # Connect the preview button
+     preview_btn.click(
+         preview_speaker,
+         inputs=[speaker_dropdown],
+         outputs=[preview_audio, preview_text],
+     )
+
+     # Connect the generate button
+     generate_btn.click(
+         infer_with_speaker,
+         inputs=[
+             speaker_dropdown,
+             gen_text_input,
+             temperature,
+             top_p,
+             min_new_tokens,
+             do_sample,
+         ],
+         outputs=[audio_output, raw_output_display],
+     )
+
+
with gr.Blocks() as app_credits:
    gr.Markdown("""
# Credits

If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
"""
    )
+     gr.TabbedInterface([app_speaker, app_tts], ["Speaker", "Clone"])


+ app.launch(ssr_mode=False)
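
As the credits text above suggests, reference audio for the "Clone" tab works best as a short mono WAV. A minimal preparation sketch along those lines, mirroring the 15-second trim the app itself applies (the file names my_voice.mp3 and my_voice_ref.wav are hypothetical; assumes torchaudio is installed):

import torch
import torchaudio

# Load the reference recording (hypothetical input file).
waveform, sample_rate = torchaudio.load("my_voice.mp3")

# Downmix stereo to mono, as infer() does before encoding.
if waveform.size(0) > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)

# Keep only the first 15 seconds, matching the app's own trimming.
waveform = waveform[:, : sample_rate * 15]

# Save as WAV and upload it as "Reference Audio".
torchaudio.save("my_voice_ref.wav", waveform, sample_rate)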