mrfakename committed on
Commit 68d3791
1 Parent(s): ca0edb1

Update app_local.py

Files changed (1): app_local.py (+79 -55)
app_local.py CHANGED
@@ -14,17 +14,17 @@ from pydub import AudioSegment
 from model import CFM, UNetT, DiT, MMDiT
 from cached_path import cached_path
 from model.utils import (
-    get_tokenizer,
-    convert_char_to_pinyin,
+    get_tokenizer,
+    convert_char_to_pinyin,
     save_spectrogram,
 )
 from transformers import pipeline
+import spaces
 import librosa
+from txtsplit import txtsplit
 
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
-print(f"Using {device} device")
-
 pipe = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-large-v3-turbo",
@@ -79,13 +79,13 @@ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 F5TTS_ema_model, F5TTS_base_model = load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
 E2TTS_ema_model, E2TTS_base_model = load_model("E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
 
-def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
+def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, progress = gr.Progress()):
     print(gen_text)
-    if len(gen_text) > 200:
-        raise gr.Error("Please keep your text under 200 chars.")
     gr.Info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
+        # Convert to mono
+        aseg = aseg.set_channels(1)
         audio_duration = len(aseg)
         if audio_duration > 15000:
             gr.Warning("Audio is over 15s, clipping to only first 15s.")
@@ -98,7 +98,7 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
     elif exp_name == "E2-TTS":
         ema_model = E2TTS_ema_model
         base_model = E2TTS_base_model
-
+
     if not ref_text.strip():
         gr.Info("No reference text provided, transcribing reference audio...")
         ref_text = outputs = pipe(
@@ -112,7 +112,9 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
     else:
         gr.Info("Using custom reference text...")
     audio, sr = torchaudio.load(ref_audio)
-
+    # Audio
+    if audio.shape[0] > 1:
+        audio = torch.mean(audio, dim=0, keepdim=True)
     rms = torch.sqrt(torch.mean(torch.square(audio)))
     if rms < target_rms:
         audio = audio * target_rms / rms
@@ -120,44 +122,49 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
         resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
         audio = resampler(audio)
     audio = audio.to(device)
-
-    # Prepare the text
-    text_list = [ref_text + gen_text]
-    final_text_list = convert_char_to_pinyin(text_list)
-
-    # Calculate duration
-    ref_audio_len = audio.shape[-1] // hop_length
-    # if fix_duration is not None:
-    #     duration = int(fix_duration * target_sample_rate / hop_length)
-    # else:
-    zh_pause_punc = r"。,、;:?!"
-    ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
-    gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
-    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
-
-    # inference
-    gr.Info(f"Generating audio using {exp_name}")
-    with torch.inference_mode():
-        generated, _ = base_model.sample(
-            cond=audio,
-            text=final_text_list,
-            duration=duration,
-            steps=nfe_step,
-            cfg_strength=cfg_strength,
-            sway_sampling_coef=sway_sampling_coef,
-        )
-
-    generated = generated[:, ref_audio_len:, :]
-    generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
-    gr.Info("Running vocoder")
-    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
-    generated_wave = vocos.decode(generated_mel_spec.cpu())
-    if rms < target_rms:
-        generated_wave = generated_wave * rms / target_rms
-
-    # wav -> numpy
-    generated_wave = generated_wave.squeeze().cpu().numpy()
-
+    # Chunk
+    chunks = txtsplit(gen_text, 100, 150) # 100 chars preferred, 150 max
+    results = []
+    generated_mel_specs = []
+    for chunk in progress.tqdm(chunks):
+        # Prepare the text
+        text_list = [ref_text + chunk]
+        final_text_list = convert_char_to_pinyin(text_list)
+
+        # Calculate duration
+        ref_audio_len = audio.shape[-1] // hop_length
+        # if fix_duration is not None:
+        #     duration = int(fix_duration * target_sample_rate / hop_length)
+        # else:
+        zh_pause_punc = r"。,、;:?!"
+        ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
+        gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
+        duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+
+        # inference
+        gr.Info(f"Generating audio using {exp_name}")
+        with torch.inference_mode():
+            generated, _ = base_model.sample(
+                cond=audio,
+                text=final_text_list,
+                duration=duration,
+                steps=nfe_step,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=sway_sampling_coef,
+            )
+
+        generated = generated[:, ref_audio_len:, :]
+        generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
+        gr.Info("Running vocoder")
+        vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+        generated_wave = vocos.decode(generated_mel_spec.cpu())
+        if rms < target_rms:
+            generated_wave = generated_wave * rms / target_rms
+
+        # wav -> numpy
+        generated_wave = generated_wave.squeeze().cpu().numpy()
+        results.append(generated_wave)
+    generated_wave = np.concatenate(results)
     if remove_silence:
         gr.Info("Removing audio silences... This may take a moment")
         non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
@@ -169,11 +176,11 @@ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence):
 
 
     # spectogram
-    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
-        spectrogram_path = tmp_spectrogram.name
-        save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
+    # with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
+    #     spectrogram_path = tmp_spectrogram.name
+    #     save_spectrogram(generated_mel_spec[0].cpu().numpy(), spectrogram_path)
 
-    return (target_sample_rate, generated_wave), spectrogram_path
+    return (target_sample_rate, generated_wave)
 
 with gr.Blocks() as app:
     gr.Markdown("""
@@ -190,21 +197,38 @@ The checkpoints support English and Chinese.
 
 If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt. If you're still running into issues, please open a [community Discussion](https://huggingface.co/spaces/mrfakename/E2-F5-TTS/discussions).
 
+Long-form/batched inference + speech editing is coming soon!
+
 **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
 """)
 
     ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
-    gen_text_input = gr.Textbox(label="Text to Generate (max 200 chars.)", lines=4)
+    gen_text_input = gr.Textbox(label="Text to Generate (longer text will use chunking)", lines=4)
     model_choice = gr.Radio(choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS")
     generate_btn = gr.Button("Synthesize", variant="primary")
    with gr.Accordion("Advanced Settings", open=False):
        ref_text_input = gr.Textbox(label="Reference Text", info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.", lines=2)
        remove_silence = gr.Checkbox(label="Remove Silences", info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.", value=True)
-
+
    audio_output = gr.Audio(label="Synthesized Audio")
-    spectrogram_output = gr.Image(label="Spectrogram")
+    # spectrogram_output = gr.Image(label="Spectrogram")
+
+    generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output])
+    gr.Markdown("""
+## Run Locally
+
+Run this demo locally on CPU, CUDA, or MPS/Apple Silicon (requires macOS >= 14):
 
-    generate_btn.click(infer, inputs=[ref_audio_input, ref_text_input, gen_text_input, model_choice, remove_silence], outputs=[audio_output, spectrogram_output])
+First, ensure `ffmpeg` is installed.
+
+```bash
+git clone https://huggingface.co/spaces/mrfakename/E2-F5-TTS
+cd E2-F5-TTS
+python -m pip install -r requirements.txt
+python app_local.py
+```
+
+""")
    gr.Markdown("Unofficial demo by [mrfakename](https://x.com/realmrfakename)")
 
 
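
The headline change in this diff is that `infer` no longer rejects prompts over 200 characters: the prompt is split into roughly 100-character chunks with `txtsplit`, each chunk is synthesized against the same reference audio inside `progress.tqdm(chunks)`, and the per-chunk waveforms are concatenated at the end. Below is a minimal sketch of that flow for orientation, not code from the commit; `synthesize_chunk` is a hypothetical stand-in for the `convert_char_to_pinyin` / `base_model.sample` / `vocos.decode` pipeline, and only the `txtsplit` call and the concatenation mirror `app_local.py`.

```python
# Minimal sketch of the chunked long-form flow (illustration, not the committed code).
import numpy as np
from txtsplit import txtsplit

SAMPLE_RATE = 24000  # assumed output rate (Vocos charactr/vocos-mel-24khz)

def synthesize_chunk(ref_text: str, chunk: str) -> np.ndarray:
    # Hypothetical stand-in for the real pipeline, which runs
    # convert_char_to_pinyin([ref_text + chunk]) -> base_model.sample(...) -> vocos.decode(...).
    # Here we return silence proportional to the chunk length so the sketch runs standalone.
    return np.zeros(int(len(chunk) * 0.06 * SAMPLE_RATE), dtype=np.float32)

def synthesize_long_form(ref_text: str, gen_text: str) -> np.ndarray:
    chunks = txtsplit(gen_text, 100, 150)  # ~100 chars preferred, 150 max, as in the commit
    results = [synthesize_chunk(ref_text, chunk) for chunk in chunks]
    return np.concatenate(results)  # one continuous waveform
```

In the committed code, the vocoder and the RMS rescaling run per chunk inside the loop, while the optional silence removal is applied once to the concatenated waveform.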
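The other change worth noting is the reference-audio preprocessing: the clip is forced to mono both at the pydub stage (`aseg.set_channels(1)`) and after `torchaudio.load` (mean over channels), quiet clips are lifted toward the target RMS, and the waveform is resampled to the model rate. Here is a small self-contained sketch of that sequence; the `target_rms = 0.1` and `target_sample_rate = 24000` values are assumptions for illustration (both names are referenced in the diff but defined elsewhere in `app_local.py`).

```python
# Illustration of the mono / RMS / resample handling of the reference clip.
import torch
import torchaudio

def preprocess_reference(path: str,
                         target_rms: float = 0.1,          # assumed value, defined elsewhere in app_local.py
                         target_sample_rate: int = 24000,  # assumed value (Vocos mel-24khz rate)
                         ) -> torch.Tensor:
    audio, sr = torchaudio.load(path)
    # Downmix multi-channel clips to mono, mirroring the new audio.shape[0] > 1 branch
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    # Lift quiet references toward the target RMS; louder clips are left untouched
    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    # Resample to the rate the model expects
    if sr != target_sample_rate:
        audio = torchaudio.transforms.Resample(sr, target_sample_rate)(audio)
    return audio
```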