import random
import sys
from importlib.resources import files

import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf

from f5_tts.infer.utils_infer import (
    load_model,
    load_vocoder,
    transcribe,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
    save_spectrogram,
)
from f5_tts.model.utils import seed_everything


def _default_device():
    """Return the first available torch device name: "cuda", "xpu", "mps", else "cpu".

    torch is imported here because it is only needed when the caller did not
    pass an explicit device.
    """
    import torch

    if torch.cuda.is_available():
        return "cuda"
    # torch.xpu is missing on some PyTorch builds; accessing it unguarded raises
    # AttributeError instead of falling through to the remaining backends.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class F5TTS:
    """Text-to-speech wrapper around the f5_tts inference utilities.

    Construction loads the model config, a vocoder and the model checkpoint.
    `infer` then generates audio for `gen_text` from a reference audio file
    and its reference text.
    """

    def __init__(
        self,
        model="F5TTS_v1_Base",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        vocoder_local_path=None,
        device=None,
        hf_cache_dir=None,
    ):
        """Load the config, vocoder and model weights.

        Args:
            model: Config name; `f5_tts/configs/{model}.yaml` is loaded and its
                `model.backbone` class is resolved from `f5_tts.model`.
            ckpt_file: Checkpoint path. When empty, the checkpoint is fetched with
                `cached_path` from `hf://SWivid/{repo}/{model}/model_{step}.{ext}`.
            vocab_file: Forwarded to `load_model`.
            ode_method: Stored as `self.ode_method` and forwarded to `load_model`.
            use_ema: Stored as `self.use_ema` and forwarded to `load_model`.
            vocoder_local_path: Forwarded to `load_vocoder`, together with a flag
                that is True when this argument is not None.
            device: Device string. When None, the first available of
                cuda / xpu / mps / cpu is chosen.
            hf_cache_dir: Forwarded to `load_vocoder` and used as `cache_dir`
                for the checkpoint download.
        """
        model_cfg = OmegaConf.load(str(files("f5_tts").joinpath(f"configs/{model}.yaml")))
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate

        self.ode_method = ode_method
        self.use_ema = use_ema

        self.device = device if device is not None else _default_device()

        # Load models
        self.vocoder = load_vocoder(
            self.mel_spec_type, vocoder_local_path is not None, vocoder_local_path, self.device, hf_cache_dir
        )

        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

        # override for previous models: older releases live under a different
        # repo, training step, or file format than the default above.
        if model == "F5TTS_Base":
            if self.mel_spec_type == "vocos":
                ckpt_step = 1200000
            elif self.mel_spec_type == "bigvgan":
                model = "F5TTS_Base_bigvgan"
                ckpt_type = "pt"
        elif model == "E2TTS_Base":
            repo_name = "E2-TTS"
            ckpt_step = 1200000

        if not ckpt_file:
            ckpt_file = str(
                cached_path(f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}", cache_dir=hf_cache_dir)
            )
        self.ema_model = load_model(
            model_cls, model_arc, ckpt_file, self.mel_spec_type, vocab_file, self.ode_method, self.use_ema, self.device
        )

    def transcribe(self, ref_audio, language=None):
        """Transcribe `ref_audio` by delegating to `f5_tts.infer.utils_infer.transcribe`."""
        # Inside this method the bare name resolves to the module-level import,
        # not to this method.
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write `wav` to `file_wave` at `self.target_sample_rate`.

        If `remove_silence` is true, the written file is then passed to
        `remove_silence_for_generated_wav`.
        """
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spec, file_spec):
        """Save `spec` to `file_spec` via `save_spectrogram`."""
        save_spectrogram(spec, file_spec)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spec=None,
        seed=None,
    ):
        """Generate audio for `gen_text` and optionally export it.

        The reference pair is first run through `preprocess_ref_audio_text`.
        The generation keyword arguments (`show_info` through `fix_duration`)
        are forwarded unchanged to `infer_process`.

        Args:
            seed: Value given to `seed_everything`. When None, a random integer
                in [0, sys.maxsize] is drawn. The seed actually used is stored
                in `self.seed`.
            remove_silence: Forwarded to `export_wav` when `file_wave` is set.
            file_wave: If not None, the waveform is written here.
            file_spec: If not None, the spectrogram is saved here.

        Returns:
            The `(wav, sr, spec)` tuple produced by `infer_process`.
        """
        if seed is None:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)

        wav, sr, spec = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spec is not None:
            self.export_spectrogram(spec, file_spec)

        return wav, sr, spec


if __name__ == "__main__":
    f5tts = F5TTS()

    wav, sr, spec = f5tts.infer(
        ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")),
        file_spec=str(files("f5_tts").joinpath("../../tests/api_out.png")),
        seed=None,
    )

    print("seed :", f5tts.seed)