gorbiz commited on
Commit
793e8f1
·
verified ·
1 Parent(s): 3f8297c

Default wav file(?)

Browse files
Files changed (1) hide show
  1. app.py +23 -30
app.py CHANGED
@@ -6,8 +6,21 @@ from xcodec2.modeling_xcodec2 import XCodec2Model
6
  import torchaudio
7
  import gradio as gr
8
  import tempfile
9
-
10
- llasa_3b ='srinivasbilla/llasa-3b'
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
13
 
@@ -30,19 +43,16 @@ whisper_turbo_pipe = pipeline(
30
  )
31
 
32
  def ids_to_speech_tokens(speech_ids):
33
-
34
  speech_tokens_str = []
35
  for speech_id in speech_ids:
36
  speech_tokens_str.append(f"<|s_{speech_id}|>")
37
  return speech_tokens_str
38
 
39
  def extract_speech_ids(speech_tokens_str):
40
-
41
  speech_ids = []
42
  for token_str in speech_tokens_str:
43
  if token_str.startswith('<|s_') and token_str.endswith('|>'):
44
  num_str = token_str[4:-2]
45
-
46
  num = int(num_str)
47
  speech_ids.append(num)
48
  else:
@@ -58,12 +68,9 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
58
  gr.Warning("Trimming audio to first 15secs.")
59
  waveform = waveform[:, :sample_rate*15]
60
 
61
- # Check if the audio is stereo (i.e., has more than one channel)
62
  if waveform.size(0) > 1:
63
- # Convert stereo to mono by averaging the channels
64
  waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
65
  else:
66
- # If already mono, just use the original waveform
67
  waveform_mono = waveform
68
 
69
  prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
@@ -78,18 +85,13 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
78
 
79
  input_text = prompt_text + ' ' + target_text
80
 
81
- #TTS start!
82
  with torch.no_grad():
83
- # Encode the prompt wav
84
  vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
85
-
86
  vq_code_prompt = vq_code_prompt[0,0,:]
87
- # Convert int 12345 to token <|s_12345|>
88
  speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
89
 
90
  formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
91
 
92
- # Tokenize the text and the speech prefix
93
  chat = [
94
  {"role": "user", "content": "Convert the text to speech:" + formatted_text},
95
  {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
@@ -104,29 +106,20 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
104
  input_ids = input_ids.to('cuda')
105
  speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
106
 
107
- # Generate the speech autoregressively
108
  outputs = model.generate(
109
  input_ids,
110
- max_length=2048, # We trained our model with a max length of 2048
111
  eos_token_id= speech_end_id ,
112
  do_sample=True,
113
  top_p=1,
114
  temperature=0.8
115
  )
116
- # Extract the speech tokens
117
  generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
118
-
119
  speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
120
-
121
- # Convert token <|s_23456|> to int 23456
122
  speech_tokens = extract_speech_ids(speech_tokens)
123
 
124
  speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
125
-
126
- # Decode the speech tokens to speech waveform
127
  gen_wav = Codec_model.decode_code(speech_tokens)
128
-
129
- # if only need the generated part
130
  gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
131
 
132
  progress(1, 'Synthesized!')
@@ -135,19 +128,20 @@ def infer(sample_audio_path, target_text, progress=gr.Progress()):
135
 
136
  with gr.Blocks() as app_tts:
137
  gr.Markdown("# Zero Shot Voice Clone TTS")
138
- ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
 
 
 
 
 
139
  gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
140
 
141
  generate_btn = gr.Button("Synthesize", variant="primary")
142
-
143
  audio_output = gr.Audio(label="Synthesized Audio")
144
 
145
  generate_btn.click(
146
  infer,
147
- inputs=[
148
- ref_audio_input,
149
- gen_text_input,
150
- ],
151
  outputs=[audio_output],
152
  )
153
 
@@ -173,5 +167,4 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
173
  )
174
  gr.TabbedInterface([app_tts], ["TTS"])
175
 
176
-
177
  app.launch(ssr_mode=False)
 
6
  import torchaudio
7
  import gradio as gr
8
  import tempfile
9
+ import requests # Added import for downloading the default WAV
10
+
11
+ # Download the default WAV file
12
+ default_wav_url = "https://file.thatvoid.com/main/20250127T095211591Z-ee8c576d2304e5195ddfce77a45e0377.wav"
13
+ default_wav_path = "default_voice.wav"
14
+ try:
15
+ response = requests.get(default_wav_url)
16
+ response.raise_for_status()
17
+ with open(default_wav_path, "wb") as f:
18
+ f.write(response.content)
19
+ except Exception as e:
20
+ print(f"Failed to download default WAV: {e}")
21
+ default_wav_path = None # Fallback to requiring user input
22
+
23
+ llasa_3b = 'srinivasbilla/llasa-3b'
24
 
25
  tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
26
 
 
43
  )
44
 
45
  def ids_to_speech_tokens(speech_ids):
 
46
  speech_tokens_str = []
47
  for speech_id in speech_ids:
48
  speech_tokens_str.append(f"<|s_{speech_id}|>")
49
  return speech_tokens_str
50
 
51
  def extract_speech_ids(speech_tokens_str):
 
52
  speech_ids = []
53
  for token_str in speech_tokens_str:
54
  if token_str.startswith('<|s_') and token_str.endswith('|>'):
55
  num_str = token_str[4:-2]
 
56
  num = int(num_str)
57
  speech_ids.append(num)
58
  else:
 
68
  gr.Warning("Trimming audio to first 15secs.")
69
  waveform = waveform[:, :sample_rate*15]
70
 
 
71
  if waveform.size(0) > 1:
 
72
  waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
73
  else:
 
74
  waveform_mono = waveform
75
 
76
  prompt_wav = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform_mono)
 
85
 
86
  input_text = prompt_text + ' ' + target_text
87
 
 
88
  with torch.no_grad():
 
89
  vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
 
90
  vq_code_prompt = vq_code_prompt[0,0,:]
 
91
  speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)
92
 
93
  formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
94
 
 
95
  chat = [
96
  {"role": "user", "content": "Convert the text to speech:" + formatted_text},
97
  {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
 
106
  input_ids = input_ids.to('cuda')
107
  speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')
108
 
 
109
  outputs = model.generate(
110
  input_ids,
111
+ max_length=2048,
112
  eos_token_id= speech_end_id ,
113
  do_sample=True,
114
  top_p=1,
115
  temperature=0.8
116
  )
 
117
  generated_ids = outputs[0][input_ids.shape[1]-len(speech_ids_prefix):-1]
 
118
  speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
 
 
119
  speech_tokens = extract_speech_ids(speech_tokens)
120
 
121
  speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
 
 
122
  gen_wav = Codec_model.decode_code(speech_tokens)
 
 
123
  gen_wav = gen_wav[:,:,prompt_wav.shape[1]:]
124
 
125
  progress(1, 'Synthesized!')
 
128
 
129
  with gr.Blocks() as app_tts:
130
  gr.Markdown("# Zero Shot Voice Clone TTS")
131
+ # Set default value for the audio input
132
+ ref_audio_input = gr.Audio(
133
+ label="Reference Audio",
134
+ type="filepath",
135
+ value=default_wav_path if default_wav_path else None # Use downloaded file or fallback
136
+ )
137
  gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
138
 
139
  generate_btn = gr.Button("Synthesize", variant="primary")
 
140
  audio_output = gr.Audio(label="Synthesized Audio")
141
 
142
  generate_btn.click(
143
  infer,
144
+ inputs=[ref_audio_input, gen_text_input],
 
 
 
145
  outputs=[audio_output],
146
  )
147
 
 
167
  )
168
  gr.TabbedInterface([app_tts], ["TTS"])
169
 
 
170
  app.launch(ssr_mode=False)