SebastianBodza committed
Commit 67a186c · verified · 1 Parent(s): a370bc6

Update app.py

Files changed (1): app.py +72 -56
app.py CHANGED
@@ -7,6 +7,7 @@ import torchaudio
 import gradio as gr
 import tempfile
 import os
+import numpy as np
 
 llasa_1b ='SebastianBodza/Kartoffel-1B-v0.2'
 
@@ -20,7 +21,6 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 model_path = "srinivasbilla/xcodec2"
-
 Codec_model = XCodec2Model.from_pretrained(model_path)
 Codec_model.eval().cuda()
 
@@ -32,57 +32,37 @@ whisper_turbo_pipe = pipeline(
 )
 
 
-vad_model, utils = torch.hub.load(
-    "snakers4/silero-vad",
-    model="silero_vad",
-    force_reload=False,
-    source="github")
-
-get_speech_timestamps, *_ = utils
-
-
-def remove_silence_silero(waveform, sample_rate, vad_model):
+def normalize_audio(waveform: torch.Tensor, target_db: float = -20) -> torch.Tensor:
     """
-    Remove leading silence using Silero VAD.
+    Normalize audio volume to target dB and limit gain range.
 
     Args:
-        waveform: torch.Tensor audio waveform (channels, samples)
-        sample_rate: int sample rate
-        vad_model: Silero VAD model
+        waveform (torch.Tensor): Input audio waveform
+        target_db (float): Target dB level (default: -20)
+
+    Returns:
+        torch.Tensor: Normalized audio waveform
     """
-    if waveform.size(0) > 1:
-        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    # Calculate current dB
+    eps = 1e-10
+    current_db = 20 * torch.log10(torch.max(torch.abs(waveform)) + eps)
 
-    original_waveform = waveform
+    # Calculate required gain
+    gain_db = target_db - current_db
 
-    if sample_rate != 16000:
-        waveform_16k = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
-    else:
-        waveform_16k = waveform
-
-    # Get speech timestamps
-    speech_timestamps = get_speech_timestamps(waveform_16k[0], vad_model, sampling_rate=16000)
+    # Limit gain to -3 to 3 dB range
+    gain_db = torch.clamp(gain_db, min=-3, max=3)
 
-    if speech_timestamps:
-        # Get first speech segment start
-        first_speech = speech_timestamps[0]['start']
-
-        # Add small padding before speech (0.1 seconds)
-        padding_samples = int(0.1 * sample_rate)
-        start_idx = max(0, int(first_speech * sample_rate/16000) - padding_samples)
-
-        # Same for the end
-        last_speech = speech_timestamps[-1]['end']
-        end_idx = min(original_waveform.size(1), int(last_speech * sample_rate/16000) + padding_samples)
-
-        # Trim the original waveform (not the resampled one)
-        trimmed_wav = original_waveform[:, start_idx:end_idx]
-
-        # added padding of 16 at the start and end
-        return torch.nn.functional.pad(trimmed_wav, (16, 16), "constant", 0)
+    # Apply gain
+    gain_factor = 10 ** (gain_db / 20)
+    normalized = waveform * gain_factor
 
-    return original_waveform
-
+    # Final peak normalization
+    max_amplitude = torch.max(torch.abs(normalized))
+    if max_amplitude > 0:
+        normalized = normalized / max_amplitude
+
+    return normalized
 
 def ids_to_speech_tokens(speech_ids):
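The `normalize_audio` helper added above targets -20 dB but clamps the applied gain to ±3 dB before a final peak rescale. A minimal standalone sketch of the same gain math (synthetic tensor, `torch` only; not part of the commit) makes the clamping visible:

```python
import torch

# 1 s synthetic mono waveform with a quiet peak of 0.05 (about -26 dBFS).
t = torch.linspace(0, 1, 16000)
quiet = 0.05 * torch.sin(2 * torch.pi * 440 * t).unsqueeze(0)

peak_db = 20 * torch.log10(quiet.abs().max() + 1e-10)  # ~ -26 dB
gain_db = torch.clamp(-20 - peak_db, min=-3, max=3)    # wants +6 dB, clamped to +3
boosted = quiet * 10 ** (gain_db / 20)

print(f"peak {peak_db.item():.1f} dB, gain {gain_db.item():.1f} dB, "
      f"new peak {boosted.abs().max().item():.3f}")
```

Note that the final peak normalization in `normalize_audio` rescales the result to a maximum amplitude of 1.0 regardless, so the clamped gain mainly affects intermediate headroom.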
 
@@ -105,18 +85,19 @@ def extract_speech_ids(speech_tokens_str):
     return speech_ids
 
 @spaces.GPU(duration=30)
-def infer(sample_audio_path, target_text, progress=gr.Progress()):
+def infer(sample_audio_path, target_text, temp, top_p_val, min_new_tokens, do_sample, progress=gr.Progress()):
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         progress(0, 'Loading and trimming audio...')
         waveform, sample_rate = torchaudio.load(sample_audio_path)
-        waveform = remove_silence_silero(waveform, sample_rate, vad_model)
 
-        # For debugging save the trimmed audio
-        torchaudio.save("dev.wav", waveform, sample_rate)
+        waveform = normalize_audio(waveform)
+
 
         if len(waveform[0])/sample_rate > 15:
             gr.Warning("Trimming audio to first 15secs.")
             waveform = waveform[:, :sample_rate*15]
+        waveform = torch.nn.functional.pad(waveform, (0, int(sample_rate*0.5)), "constant", 0)
+
 
         # Check if the audio is stereo (i.e., has more than one channel)
         if waveform.size(0) > 1:
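The new trailing pad appends half a second of silence to the (possibly trimmed) reference clip. A quick shape check (synthetic tensor, assumed 16 kHz; not part of the commit):

```python
import torch
import torch.nn.functional as F

sr = 16000
wav = torch.randn(1, sr * 2)                    # 2 s of fake mono audio
padded = F.pad(wav, (0, int(sr * 0.5)), "constant", 0)
print(wav.shape, "->", padded.shape)            # torch.Size([1, 32000]) -> torch.Size([1, 40000])
```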
@@ -132,11 +113,12 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
 
         if len(target_text) == 0:
             return None
-        elif len(target_text) > 300:
-            gr.Warning("Text is too long. Please keep it under 300 characters.")
-            target_text = target_text[:300]
+        elif len(target_text) > 500:
+            gr.Warning("Text is too long. Please keep it under 500 characters.")
+            target_text = target_text[:500]
 
         input_text = prompt_text + ' ' + target_text
+        print("Transcribed text:", input_text)
 
         #TTS start!
         with torch.no_grad():
@@ -159,7 +141,7 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
                 chat,
                 tokenize=True,
                 return_tensors='pt',
-                continue_final_message=True
+                continue_final_message=True,
             )
             input_ids = input_ids.to('cuda')
             speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
@@ -168,11 +150,13 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
             outputs = model.generate(
                 input_ids,
                 max_length=2048,  # We trained our model with a max length of 2048
-                eos_token_id= speech_end_id ,
-                do_sample=True,
-                top_p=1,
-                temperature=0.8
+                eos_token_id=speech_end_id,
+                do_sample=do_sample,
+                top_p=top_p_val,
+                temperature=temp,
+                min_new_tokens=min_new_tokens,
             )
+
             # Extract the speech tokens
             generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
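The four new `generate` arguments are standard Hugging Face sampling controls; `min_new_tokens` suppresses the end-of-speech token until the minimum length is reached. As a sketch, the same settings could be bundled in a `transformers.GenerationConfig` (illustrative values, not part of the commit):

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig(
    max_length=2048,
    do_sample=True,
    top_p=1.0,
    temperature=0.8,
    min_new_tokens=3,   # blocks the eos token for the first 3 generated tokens
)
# outputs = model.generate(input_ids, generation_config=gen_cfg, eos_token_id=speech_end_id)
```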
 
@@ -198,6 +182,34 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
 
 with gr.Blocks() as app_tts:
     gr.Markdown("# Zero Shot Voice Clone TTS")
+
+    with gr.Accordion("Model Settings", open=False):
+        temperature = gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.8,
+            step=0.1,
+            label="Temperature",
+            info="Higher values = more random/creative output"
+        )
+        top_p = gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=1.0,
+            step=0.1,
+            label="Top P",
+            info="Nucleus sampling threshold"
+        )
+        min_new_tokens = gr.Slider(
+            minimum=0,
+            maximum=128,
+            value=3,
+            step=1,
+            label="Min Length",
+            info="If the model only produces a click, raise this to force longer generations."
+        )
+        do_sample = gr.Checkbox(label="Sample", value=True, info="Sample from the distribution")
+
     ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
     gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
 
@@ -211,6 +223,10 @@ with gr.Blocks() as app_tts:
         inputs=[
             ref_audio_input,
             gen_text_input,
+            temperature,
+            top_p,
+            min_new_tokens,
+            do_sample
         ],
         outputs=[audio_output, raw_output_display]  # Include both outputs
     )
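With the sliders and checkbox appended to `inputs`, Gradio passes their values positionally into the updated `infer`. A hypothetical direct call outside the UI (placeholder path and text; the two return values are assumed from the two `outputs` above):

```python
# Hypothetical invocation mirroring the slider/checkbox defaults.
audio_out, raw_tokens = infer(
    sample_audio_path="reference.wav",   # placeholder reference clip
    target_text="Guten Tag, wie geht es Ihnen?",
    temp=0.8,             # Temperature slider default
    top_p_val=1.0,        # Top P slider default
    min_new_tokens=3,     # Min Length slider default
    do_sample=True,       # Sample checkbox default
)
```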
@@ -230,7 +246,7 @@ with gr.Blocks() as app:
 
     This is a local web UI for my finetune of the llasa 1b SOTA (imo) Zero Shot Voice Cloning and TTS model.
 
-    The checkpoints support English and Chinese.
+    The checkpoints support German. If the audio is of low quality, the model may struggle to generate speech. Turn the **temperature** up to get more coherent results.
 
     If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
     """