Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,049 Bytes
4dab15f fededd1 4dab15f 43bc5dc 1a19e0f 4dab15f fededd1 1a19e0f fededd1 1a19e0f 4dab15f fededd1 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1bcb8fe 4dab15f 1a19e0f 43bc5dc 1a19e0f b315dd9 1a19e0f 4dab15f 1a19e0f 1bcb8fe 4dab15f 1a19e0f 4dab15f 1a19e0f fededd1 1a19e0f fededd1 4dab15f c0fb8c8 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 6f84f9c b0bca14 5f7ec69 b0bca14 1a19e0f 4dab15f b6584c2 fededd1 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f 1a19e0f 4dab15f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import random
import sys
from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from hydra.utils import get_class
from omegaconf import OmegaConf
from f5_tts.infer.utils_infer import (
load_model,
load_vocoder,
transcribe,
preprocess_ref_audio_text,
infer_process,
remove_silence_for_generated_wav,
save_spectrogram,
)
from f5_tts.model.utils import seed_everything
class F5TTS:
    """Convenience wrapper that loads a TTS model plus vocoder and runs inference.

    The constructor reads the model config shipped inside the ``f5_tts``
    package, loads the vocoder and the model checkpoint, and keeps both on the
    instance. :meth:`infer` then generates audio for a text prompt given a
    reference recording and its transcript.

    Attributes set by ``__init__``:
        mel_spec_type: ``model.mel_spec.mel_spec_type`` from the config.
        target_sample_rate: ``model.mel_spec.target_sample_rate`` from the config.
        ode_method: value of the ``ode_method`` argument.
        use_ema: value of the ``use_ema`` argument.
        device: explicit ``device`` argument, or the auto-detected backend.
        seed: seed used by the most recent :meth:`infer` call (``None`` until
            the first call).
        vocoder: object returned by ``load_vocoder``.
        ema_model: object returned by ``load_model``.
    """

    def __init__(
        self,
        model="F5TTS_v1_Base",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        vocoder_local_path=None,
        device=None,
        hf_cache_dir=None,
    ):
        """Load the config, the vocoder and the model checkpoint.

        Args:
            model: Name of a YAML file (without extension) under the package's
                ``configs/`` directory; also used to build the default
                checkpoint location when ``ckpt_file`` is empty.
            ckpt_file: Checkpoint path. When empty, the checkpoint is resolved
                through ``cached_path`` from an ``hf://SWivid/...`` URL.
            vocab_file: Forwarded unchanged to ``load_model``.
            ode_method: Stored on the instance and forwarded to ``load_model``.
            use_ema: Stored on the instance and forwarded to ``load_model``.
            vocoder_local_path: Forwarded to ``load_vocoder``; whether it is
                non-``None`` is passed as that helper's second positional
                argument (presumably an "is local" flag -- confirm against
                ``load_vocoder``).
            device: Device string to use. When ``None`` the first available of
                ``"cuda"``, ``"xpu"``, ``"mps"`` is chosen, else ``"cpu"``.
            hf_cache_dir: Cache directory handed to ``load_vocoder`` and to
                ``cached_path``.
        """
        config_path = files("f5_tts").joinpath(f"configs/{model}.yaml")
        model_cfg = OmegaConf.load(str(config_path))
        model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}")
        model_arc = model_cfg.model.arch

        self.mel_spec_type = model_cfg.model.mel_spec.mel_spec_type
        self.target_sample_rate = model_cfg.model.mel_spec.target_sample_rate

        self.ode_method = ode_method
        self.use_ema = use_ema

        # Defined up front so reading `.seed` before the first infer() call
        # yields None instead of raising AttributeError.
        self.seed = None

        self.device = device if device is not None else self._detect_device()

        # Load models
        self.vocoder = load_vocoder(
            self.mel_spec_type,
            vocoder_local_path is not None,
            vocoder_local_path,
            self.device,
            hf_cache_dir,
        )

        repo_name, ckpt_step, ckpt_type = "F5-TTS", 1250000, "safetensors"

        # override for previous models
        if model == "F5TTS_Base":
            if self.mel_spec_type == "vocos":
                ckpt_step = 1200000
            elif self.mel_spec_type == "bigvgan":
                # Only the checkpoint folder name changes here; the config was
                # already loaded above under the original model name.
                model = "F5TTS_Base_bigvgan"
                ckpt_type = "pt"
        elif model == "E2TTS_Base":
            repo_name = "E2-TTS"
            ckpt_step = 1200000

        if not ckpt_file:
            ckpt_file = str(
                cached_path(
                    f"hf://SWivid/{repo_name}/{model}/model_{ckpt_step}.{ckpt_type}",
                    cache_dir=hf_cache_dir,
                )
            )

        self.ema_model = load_model(
            model_cls,
            model_arc,
            ckpt_file,
            self.mel_spec_type,
            vocab_file,
            self.ode_method,
            self.use_ema,
            self.device,
        )

    @staticmethod
    def _detect_device():
        """Return the first available backend of cuda, xpu, mps, else ``"cpu"``."""
        # Imported lazily so torch is only touched when auto-detection is needed.
        import torch

        if torch.cuda.is_available():
            return "cuda"

        # Some PyTorch builds do not expose `torch.xpu`; look it up defensively
        # so CUDA-less machines still fall through to mps/cpu.
        xpu = getattr(torch, "xpu", None)
        if xpu is not None and xpu.is_available():
            return "xpu"

        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"

        return "cpu"

    def transcribe(self, ref_audio, language=None):
        """Return the result of the module-level ``transcribe`` helper for ``ref_audio``."""
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        When ``remove_silence`` is true, ``remove_silence_for_generated_wav``
        is then called on the file that was just written.
        """
        sf.write(file_wave, wav, self.target_sample_rate)

        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spec, file_spec):
        """Save ``spec`` to ``file_spec`` via ``save_spectrogram``."""
        save_spectrogram(spec, file_spec)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spec=None,
        seed=None,
    ):
        """Generate audio for ``gen_text`` using a reference recording.

        Args:
            ref_file: Reference audio; passed through
                ``preprocess_ref_audio_text`` before inference.
            ref_text: Transcript of the reference audio; passed through
                ``preprocess_ref_audio_text`` together with ``ref_file``.
            gen_text: Text to synthesise.
            show_info, progress, target_rms, cross_fade_duration,
            sway_sampling_coef, cfg_strength, nfe_step, speed, fix_duration:
                Forwarded unchanged as keyword arguments to ``infer_process``.
            remove_silence: Forwarded to :meth:`export_wav`; only relevant when
                ``file_wave`` is given.
            file_wave: If not ``None``, the generated audio is written here.
            file_spec: If not ``None``, the spectrogram is saved here.
            seed: Seed passed to ``seed_everything``. When ``None`` a random
                integer in ``[0, sys.maxsize]`` is drawn. The value used is
                stored on ``self.seed``.

        Returns:
            The ``(wav, sr, spec)`` triple returned by ``infer_process``.
        """
        if seed is None:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text)

        wav, sr, spec = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spec is not None:
            self.export_spectrogram(spec, file_spec)

        return wav, sr, spec
if __name__ == "__main__":
    # Demo run: synthesise a short passage with the bundled English reference
    # clip, write the audio and spectrogram next to the tests, and report the
    # seed that was used.
    f5tts = F5TTS()

    package_dir = files("f5_tts")
    reference_wav = str(package_dir.joinpath("infer/examples/basic/basic_ref_en.wav"))
    output_wav = str(package_dir.joinpath("../../tests/api_out.wav"))
    output_png = str(package_dir.joinpath("../../tests/api_out.png"))

    demo_text = """I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."""

    demo_kwargs = {
        "ref_file": reference_wav,
        "ref_text": "some call me nature, others call me mother nature.",
        "gen_text": demo_text,
        "file_wave": output_wav,
        "file_spec": output_png,
        "seed": None,  # let infer() draw a random seed
    }
    wav, sr, spec = f5tts.infer(**demo_kwargs)

    print("seed :", f5tts.seed)
|