# Imports
import gradio as gr
import spaces
import re
import torch
import torchaudio
import numpy as np
import tempfile
import click
import soundfile as sf
from einops import rearrange
from vocos import Vocos
from pydub import AudioSegment, silence
from model import CFM, UNetT, DiT, MMDiT
from cached_path import cached_path
from model.utils import (load_checkpoint, get_tokenizer, convert_char_to_pinyin, save_spectrogram)

# Pre-Initialize
DEVICE = "auto"
if DEVICE == "auto":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
ode_method = "euler"


def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
    # Download the checkpoint, build the CFM wrapper around the chosen transformer, and load the EMA weights.
    ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
    model = CFM(
        transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
        mel_spec_kwargs=dict(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length),
        odeint_kwargs=dict(method=ode_method),
        vocab_char_map=vocab_char_map,
    ).to(DEVICE)
    model = load_checkpoint(model, ckpt_path, DEVICE, use_ema=True)
    return model


# Variables
DEFAULT_MODEL = "F5"
DEFAULT_REMOVE_SILENCES = True
DEFAULT_STEPS = 32
DEFAULT_SPEED = 1
DEFAULT_CROSS_FADE = 0.15

target_rms = 0.1
cfg_strength = 2.0
sway_sampling_coef = -1.0
silence_offset = 25
silence_min_len = 500

vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
F5TTS_ema_model = load_model("F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)

E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
E2TTS_ema_model = load_model("E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)

css = '''
.gradio-container { max-width: 560px !important }
h1 { text-align: center }
footer { visibility: hidden }
'''


# Functions
@spaces.GPU(duration=30)
def infer_batch(input_batches, reference_audio, reference_input, model_choice=DEFAULT_MODEL, remove_silences=DEFAULT_REMOVE_SILENCES, steps=DEFAULT_STEPS, speed=DEFAULT_SPEED, cross_fade=DEFAULT_CROSS_FADE):
    if model_choice == "F5":
        ema_model = F5TTS_ema_model
    elif model_choice == "E2":
        ema_model = E2TTS_ema_model

    print("Waiting for inference...")

    # Prepare the reference audio: mono, RMS-normalized, resampled to the model rate.
    audio, sr = reference_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(DEVICE)

    generated_waves = []

    # Add a trailing space if the reference text ends with a single-byte (ASCII) character.
    if len(reference_input[-1].encode('utf-8')) == 1:
        reference_input = reference_input + " "

    print("Inferencing each batch...")
    for i, input in enumerate(input_batches):
        text_list = [reference_input + input]
        final_text_list = convert_char_to_pinyin(text_list)

        # Estimate the output duration from the reference length and the reference/input text-length ratio.
        reference_audio_len = audio.shape[-1] // hop_length
        zh_pause_punc = r"。，、；：？！"
        reference_input_len = len(reference_input.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, reference_input))
        input_len = len(input.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, input))
        duration = reference_audio_len + int(reference_audio_len / reference_input_len * input_len / speed)

        # Inference
        with torch.inference_mode():
            generated, _ = ema_model.sample(cond=audio, text=final_text_list, duration=duration, steps=steps, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef)

        # Drop the reference segment, decode the mel spectrogram with Vocos, and undo the RMS boost.
        generated = generated[:, reference_audio_len:, :]
        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        generated_wave = generated_wave.squeeze().cpu().numpy()
        generated_waves.append(generated_wave)

    # Handle combining generated waves with cross-fading
    print("Handling combining and cross-fading...")
    if cross_fade <= 0:
        final_wave = np.concatenate(generated_waves)
    else:
        final_wave = generated_waves[0]
        for i in range(1, len(generated_waves)):
            prev_wave = final_wave
            next_wave = generated_waves[i]

            cross_fade_samples = int(cross_fade * target_sample_rate)
            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                # Not enough audio to overlap, so concatenate directly.
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            # Linearly fade the previous tail out while the next head fades in.
            prev_overlap = prev_wave[-cross_fade_samples:]
            next_overlap = next_wave[:cross_fade_samples]
            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)
            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

            new_wave = np.concatenate([prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]])
            final_wave = new_wave

    # Handle removing silences
    print("Handling removing silences...")
    if remove_silences:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, target_sample_rate)
            aseg = AudioSegment.from_file(f.name)
            non_silent_segs = silence.split_on_silence(aseg, min_silence_len=silence_min_len, silence_thresh=aseg.dBFS - silence_offset, keep_silence=250)
            non_silent_wave = AudioSegment.empty()
            for seg in non_silent_segs:
                non_silent_wave += seg
            aseg = non_silent_wave
            aseg.export(f.name, format="wav")
            final_wave, _ = torchaudio.load(f.name)
        final_wave = final_wave.squeeze().cpu().numpy()

    print("Done!")
    return (target_sample_rate, final_wave)


@spaces.GPU(duration=30)
def infer(input, reference_audio, reference_input, model_choice=DEFAULT_MODEL, remove_silences=DEFAULT_REMOVE_SILENCES, steps=DEFAULT_STEPS, speed=DEFAULT_SPEED, cross_fade=DEFAULT_CROSS_FADE):
    print("Modifying reference audio...")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        aseg = AudioSegment.from_file(reference_audio)

        audio_duration = len(aseg)
        if audio_duration > 15000:
            gr.Warning("Audio is over 15s, clipping to only first 15s.")
            aseg = aseg[:15000]

        aseg.export(f.name, format="wav")
        ref_audio = f.name

    # Ensure the reference text ends with a period.
    print("Modifying reference input...")
    if not reference_input.endswith(". "):
        if reference_input.endswith("."):
            reference_input += " "
        else:
            reference_input += ". "

    print("Loading reference audio...")
    audio, sr = torchaudio.load(ref_audio)

    # Split input into chunks
    print("--------------------------------------------- INPUT")
    print(f"Input: {input}")
    print(f"Reference Input: {reference_input}")
    print(f"Parameters: {model_choice}, {remove_silences}, {steps}, {speed}, {cross_fade}")
    print("---------------------------------------------------")

    max_chars = int(len(reference_input.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
    input_batches = chunk_text(input, max_chars=max_chars)

    print("------------------------------------------ BATCHES")
    for i, batch_text in enumerate(input_batches):
        print(f" {i}: ", batch_text)
    print("---------------------------------------------------")

    return infer_batch(input_batches, (audio, sr), reference_input, model_choice, remove_silences, steps, speed, cross_fade)


def chunk_text(text, max_chars=135):
    chunks = []
    current_chunk = ""

    # Split input into sentences with punctuations
    sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[；：，。！？])', text)

    for sentence in sentences:
        if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
            current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence

    if current_chunk:
        chunks.append(current_chunk.strip())

    print("-------------------------------------------- CHUNKS")
    print(chunks)
    print("---------------------------------------------------")

    return chunks


def cloud():
    print("[CLOUD] | Space maintained.")


# Initialize
with gr.Blocks(css=css) as main:
    with gr.Column():
        gr.Markdown("🪄 Speak text to audio.")

    with gr.Column():
        input = gr.Textbox(lines=1, value="", label="Input")
        reference_audio = gr.Audio(sources="upload", type="filepath", label="Reference Audio")
        reference_input = gr.Textbox(lines=1, value="", label="Reference Text")
        model_choice = gr.Radio(["F5", "E2"], label="TTS Model", value=DEFAULT_MODEL)
        remove_silences = gr.Checkbox(value=DEFAULT_REMOVE_SILENCES, label="Remove Silences")
        steps = gr.Slider(minimum=1, maximum=64, value=DEFAULT_STEPS, step=1, label="Steps")
        speed = gr.Slider(minimum=0.3, maximum=2.0, value=DEFAULT_SPEED, step=0.1, label="Speed")
        cross_fade = gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_CROSS_FADE, step=0.01, label="Audio Cross-Fade Duration Between Sentences")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")

    with gr.Column():
        output = gr.Audio(label="Output")

    submit.click(infer, inputs=[input, reference_audio, reference_input, model_choice, remove_silences, steps, speed, cross_fade], outputs=output, queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)