# NOTE: "Spaces: Sleeping" — Hugging Face Spaces status-banner text captured
# during export; preserved here as a comment because it is not Python code.
# ruff: noqa: E402 | |
# Above allows ruff to ignore E402: module level import not at top of file | |
import re | |
import tempfile | |
import gradio as gr | |
import soundfile as sf | |
import torchaudio | |
try:
    # `spaces` exists only when running on Hugging Face Spaces hardware.
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """Wrap *func* with the Spaces GPU scheduler when available.

    Outside a Spaces environment (``spaces`` not importable) the function
    is returned unchanged, so decorated code runs the same everywhere.
    """
    return spaces.GPU(func) if USING_SPACES else func
from f5.src.f5_tts.model import DiT | |
from f5.src.f5_tts.infer.utils_infer import ( | |
load_vocoder, | |
load_model, | |
preprocess_ref_audio_text, | |
infer_process, | |
remove_silence_for_generated_wav, | |
save_spectrogram, | |
) | |
# DiT backbone hyperparameters for F5-TTS; must match the checkpoint loaded in load().
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
# Lazily-initialized module-level singletons, populated on first use by
# load() / loadVoc() (see the guards at the top of infer()).
F5TTS_ema_model = None
vocoder = None
def load(path="./model_1200000.safetensors"):
    """Load the F5-TTS EMA checkpoint at *path* into the module-level slot.

    Stores the result in the global ``F5TTS_ema_model`` so subsequent
    calls to infer() reuse the same model instance.
    """
    global F5TTS_ema_model
    F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, path)
def loadVoc():
    """Instantiate the vocoder once and cache it in the global ``vocoder``."""
    global vocoder
    vocoder = load_vocoder()
def infer(ref_audio_orig, ref_text, gen_text, remove_silence, cross_fade_duration=0.10, speed=0.9):
    """Synthesize *gen_text* in the voice of the reference audio.

    Parameters:
        ref_audio_orig: path/handle of the reference audio clip.
        ref_text: transcript of the reference clip (may be refined by
            preprocess_ref_audio_text).
        gen_text: text to synthesize.
        remove_silence: when truthy, post-process the output to trim silences.
        cross_fade_duration: cross-fade (seconds) between generated chunks.
        speed: speech-rate multiplier passed through to infer_process.

    Returns:
        (audio_path, spectrogram_path): paths to a temp WAV of the generated
        speech and a temp PNG of its spectrogram. Files are created with
        delete=False, so callers (e.g. Gradio) own their cleanup.
    """
    # Lazily initialize the heavy globals on first call.
    # Fix: use identity checks (`is None`) instead of `== None`, which is
    # non-idiomatic and can misbehave on objects overriding __eq__.
    if F5TTS_ema_model is None:
        load()
    if vocoder is None:
        loadVoc()

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        F5TTS_ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        speed=speed,
    )

    if remove_silence:
        # The silence trimmer operates on a file, so round-trip the waveform
        # through a temporary WAV, then reload the trimmed audio.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(f.name)
            final_wave, _ = torchaudio.load(f.name)
        final_wave = final_wave.squeeze().cpu().numpy()

    # Persist the spectrogram image for the UI.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    # Persist the final waveform and hand both paths back to the caller.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f_audio:
        audio_path = f_audio.name
    sf.write(audio_path, final_wave, final_sample_rate)

    return audio_path, spectrogram_path
def parse_speechtypes_text(gen_text):
    """Split *gen_text* on ``{speechtype}`` markers into emotion-tagged segments.

    Splitting with a capturing group makes re.split alternate plain text
    (even indices) with the captured marker names (odd indices). Text before
    the first marker is tagged "Regular"; whitespace-only text chunks are
    dropped.

    Returns:
        list of {"emotion": str, "text": str} dicts, in input order.
    """
    # Non-greedy so "{a} x {b}" yields two markers, not one spanning both.
    pattern = r"\{(.*?)\}"
    tokens = re.split(pattern, gen_text)

    segments = []
    current_emotion = "Regular"
    # Fix: iterate with enumerate instead of range(len(...)) indexing.
    for i, token in enumerate(tokens):
        if i % 2 == 0:
            # Even index: literal text spoken with the current emotion.
            text = token.strip()
            if text:
                segments.append({"emotion": current_emotion, "text": text})
        else:
            # Odd index: a captured {marker} name switching the emotion.
            current_emotion = token.strip()
    return segments