# ruff: noqa: E402
# Above allows ruff to ignore E402: module level import not at top of file
import re
import tempfile

import gradio as gr
import soundfile as sf
import torchaudio

try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """Wrap *func* with the HF Spaces GPU decorator when running on Spaces, else return it unchanged."""
    if USING_SPACES:
        return spaces.GPU(func)
    return func


from f5.src.f5_tts.model import DiT
from f5.src.f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)

# DiT architecture hyperparameters matching the F5-TTS checkpoint.
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

# Module-level singletons, lazily populated on first call to infer().
F5TTS_ema_model = None
vocoder = None


def load(path="./model_1200000.safetensors"):
    """Load the F5-TTS EMA model weights from *path* into the module-level singleton."""
    global F5TTS_ema_model
    F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, path)


def loadVoc():
    """Load the vocoder into the module-level singleton."""
    global vocoder
    vocoder = load_vocoder()


@gpu_decorator
def infer(ref_audio_orig, ref_text, gen_text, remove_silence, cross_fade_duration=0.10, speed=0.9):
    """Synthesize speech for *gen_text* conditioned on a reference audio/text pair.

    Lazily loads the model and vocoder on first use, then runs the F5-TTS
    inference pipeline and writes the results to temp files for the UI.

    Args:
        ref_audio_orig: Path to the reference audio clip.
        ref_text: Transcript of the reference audio (passed through
            preprocess_ref_audio_text; presumably "" triggers auto-transcription
            there — TODO confirm against utils_infer).
        gen_text: Text to synthesize.
        remove_silence: If truthy, strip silences from the generated waveform.
        cross_fade_duration: Cross-fade length in seconds between generated chunks.
        speed: Speech-rate multiplier.

    Returns:
        Tuple of (audio_path, spectrogram_path): paths to temp files on disk
        (created with delete=False so they outlive this call; the caller/UI
        is responsible for them).
    """
    # Lazy one-time initialization of the heavyweight singletons.
    # Fixed: original compared with `== None`; identity checks must use `is None`.
    if F5TTS_ema_model is None:
        load()
    if vocoder is None:
        loadVoc()

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        F5TTS_ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        speed=speed,
    )

    if remove_silence:
        # Round-trip through a temp WAV because the silence remover operates on files.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(f.name)
            final_wave, _ = torchaudio.load(f.name)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Persist outputs as named temp files (delete=False keeps them alive for the UI).
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f_audio:
        audio_path = f_audio.name
    sf.write(audio_path, final_wave, final_sample_rate)

    return audio_path, spectrogram_path


def parse_speechtypes_text(gen_text):
    """Split *gen_text* on {speechtype} markers into emotion-tagged segments.

    re.split with a capturing group alternates plain-text tokens (even indices)
    and marker contents (odd indices). Text before the first marker is tagged
    "Regular".

    Returns:
        List of {"emotion": str, "text": str} dicts; empty text runs are skipped.
    """
    tokens = re.split(r"\{(.*?)\}", gen_text)
    segments = []
    current_emotion = "Regular"
    for i, token in enumerate(tokens):
        token = token.strip()
        if i % 2 == 0:
            # Even tokens: plain text between markers.
            if token:
                segments.append({"emotion": current_emotion, "text": token})
        else:
            # Odd tokens: marker contents — update the active emotion.
            current_emotion = token
    return segments