# ruff: noqa: E402
# Above allows ruff to ignore E402: module level import not at top of file
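"""Voice synthesis helpers built on F5-TTS.

The model and vocoder are loaded lazily on first use. `infer()` synthesizes
new text in the voice of a reference clip; `parse_speechtypes_text()` splits
input annotated with inline {speechtype} tags into per-emotion segments.
"""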
import re
import tempfile

import soundfile as sf
import torchaudio
try:
    import spaces  # only present when running on Hugging Face Spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False


def gpu_decorator(func):
    """Run `func` on Spaces GPU hardware when available; no-op locally."""
    if USING_SPACES:
        return spaces.GPU(func)
    return func

from f5.src.f5_tts.model import DiT
from f5.src.f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)
# DiT hyperparameters for the architecture used by the checkpoint below.
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

# Model and vocoder are loaded lazily on the first call to infer().
F5TTS_ema_model = None
vocoder = None

def load(path="./model_1200000.safetensors"):
    """Load the F5-TTS EMA checkpoint into the module-level model slot."""
    global F5TTS_ema_model
    F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, path)


def loadVoc():
    """Load the vocoder that converts generated mel spectrograms to audio."""
    global vocoder
    vocoder = load_vocoder()

@gpu_decorator
def infer(ref_audio_orig, ref_text, gen_text, remove_silence, cross_fade_duration=0.10, speed=0.9):
    """Synthesize `gen_text` in the voice of the reference clip `ref_audio_orig`.

    Returns a tuple of (generated WAV path, spectrogram PNG path).
    """
    if F5TTS_ema_model is None:
        load()
    if vocoder is None:
        loadVoc()

    ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)
    ema_model = F5TTS_ema_model
    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        ema_model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        speed=speed,
    )
print("final_wave doe")
if remove_silence:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
sf.write(f.name, final_wave, final_sample_rate)
remove_silence_for_generated_wav(f.name)
final_wave, _ = torchaudio.load(f.name)
final_wave = final_wave.squeeze().cpu().numpy()
print('silence removed')
    # Persist the spectrogram image and final audio; return their paths.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
        save_spectrogram(combined_spectrogram, spectrogram_path)

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f_audio:
        audio_path = f_audio.name
        sf.write(audio_path, final_wave, final_sample_rate)

    return audio_path, spectrogram_path

def parse_speechtypes_text(gen_text):
    """Split `gen_text` into segments keyed by inline {speechtype} tags.

    Text appearing before the first tag is attributed to the "Regular" emotion.
    """
    # Pattern to find {speechtype}
    pattern = r"\{(.*?)\}"

    # Splitting on a capturing group makes tokens alternate: text, tag, text, ...
    tokens = re.split(pattern, gen_text)

    segments = []
    current_emotion = "Regular"
    for i in range(len(tokens)):
        if i % 2 == 0:
            # Even indices hold the text between tags.
            text = tokens[i].strip()
            if text:
                segments.append({"emotion": current_emotion, "text": text})
        else:
            # Odd indices hold the tag contents, i.e. the emotion name.
            current_emotion = tokens[i].strip()
    return segments
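

# A minimal usage sketch. The demo string is illustrative; "ref.wav" and its
# transcript are hypothetical inputs, and end-to-end synthesis also needs the
# ./model_1200000.safetensors checkpoint, so the infer() call is left commented.
if __name__ == "__main__":
    for segment in parse_speechtypes_text("{Happy} Great to see you! {Sad} I have to go now."):
        print(segment)
    # -> {'emotion': 'Happy', 'text': 'Great to see you!'}
    #    {'emotion': 'Sad', 'text': 'I have to go now.'}

    # audio_path, spectrogram_path = infer(
    #     "ref.wav", "Transcript of ref.wav.", "Hello there!", remove_silence=False
    # )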