from pathlib import Path
import os.path
import subprocess
import wave
import hashlib
import tempfile
import shutil
import sys
import re
import io

import numpy as np
import soundfile as sf
import torch
import torchaudio
from cached_path import cached_path

import folder_paths
from comfy.utils import ProgressBar

from .Install import Install

sys.path.append(os.path.join(Install.f5TTSPath, "src"))
from f5_tts.model import DiT, UNetT  # noqa: E402
from f5_tts.infer.utils_infer import (  # noqa: E402
    load_model,
    load_vocoder,
    preprocess_ref_audio_text,
    infer_process,
)
sys.path.pop()
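
# ComfyUI custom nodes for the F5-TTS text-to-speech model.
# F5TTSCreate does the model loading and inference; F5TTSAudioInputs and
# F5TTSAudio are the node front-ends (AUDIO input vs. input-directory sample).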


class F5TTSCreate:
    voice_reg = re.compile(r"\{([^\}]+)\}")
    model_types = ["F5", "F5-HI", "F5-JP", "F5-FR", "E2"]
    vocoder_types = ["vocos", "bigvgan"]
    tooltip_seed = "Seed. -1 = random"
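
    # Speech text may contain voice tags that switch the reference voice,
    # e.g. "{main} Hello there. {alice} Hi, how are you?"  ("alice" is an
    # illustrative name; any tag works if a matching sample file exists).

    # Local models are picked up from <checkpoints>/F5-TTS: any file with a
    # supported checkpoint extension plus a same-named .txt vocab file is
    # listed as "model://<file>".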
    @staticmethod
    def get_model_types():
        model_types = F5TTSCreate.model_types[:]
        models_path = folder_paths.get_folder_paths("checkpoints")
        for model_path in models_path:
            f5_model_path = os.path.join(model_path, 'F5-TTS')
            if os.path.isdir(f5_model_path):
                for file in os.listdir(f5_model_path):
                    p = Path(file)
                    if (
                        p.suffix in folder_paths.supported_pt_extensions
                        and os.path.isfile(os.path.join(f5_model_path, file))
                    ):
                        txtFile = F5TTSCreate.get_txt_file_path(
                            os.path.join(f5_model_path, file)
                        )
                        if os.path.isfile(txtFile):
                            model_types.append("model://" + file)
        return model_types

    @staticmethod
    def get_txt_file_path(file):
        p = Path(file)
        return os.path.join(os.path.dirname(file), p.stem + ".txt")

    def is_voice_name(self, word):
        return self.voice_reg.match(word.strip())

    def get_voice_names(self, chunks):
        voice_names = {}
        for text in chunks:
            match = self.is_voice_name(text)
            if match:
                voice_names[match[1]] = True
        return voice_names
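
    # split_text keeps each {voice} tag attached to the text that follows it,
    # e.g. "{main} Hi. {alice} Hey." -> ['', '{main} Hi. ', '{alice} Hey.'].
    # ("alice" is an illustrative tag name; blank leading chunks are skipped
    # later because their stripped text is empty.)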
    def split_text(self, speech):
        reg1 = r"(?=\{[^\}]+\})"
        return re.split(reg1, speech)

    @staticmethod
    def load_voice(ref_audio, ref_text):
        main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
        main_voice["ref_audio"], main_voice["ref_text"] = \
            preprocess_ref_audio_text(ref_audio, ref_text)
        return main_voice

    def get_model_funcs(self):
        return {
            "F5": self.load_f5_model,
            "F5-HI": self.load_f5_model_hi,
            "F5-JP": self.load_f5_model_jp,
            "F5-FR": self.load_f5_model_fr,
            "E2": self.load_e2_model,
        }

    def get_vocoder(self, vocoder_name):
        # Local checkpoint directory for the given vocoder.
        if vocoder_name == "vocos":
            return os.path.join(
                Install.f5TTSPath, "checkpoints/vocos-mel-24khz"
            )
        elif vocoder_name == "bigvgan":
            return os.path.join(
                Install.f5TTSPath,
                "checkpoints/bigvgan_v2_24khz_100band_256x"
            )

    def load_vocoder(self, vocoder_name):
        return load_vocoder(vocoder_name=vocoder_name)
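
    # load_model dispatches on the model name: known presets load via the
    # loaders below, anything else (hf:// URLs or "model://<file>" local
    # checkpoints) goes through load_f5_model_url.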
    def load_model(self, model, vocoder_name):
        model_funcs = self.get_model_funcs()
        if model in model_funcs:
            return model_funcs[model](vocoder_name)
        else:
            return self.load_f5_model_url(model, vocoder_name)

    def get_vocab_file(self):
        return os.path.join(
            Install.f5TTSPath, "data/Emilia_ZH_EN_pinyin/vocab.txt"
        )

    def load_e2_model(self, vocoder):
        model_cls = UNetT
        model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
        repo_name = "E2-TTS"
        exp_name = "E2TTS_Base"
        ckpt_step = 1200000
        ckpt_file = str(cached_path(
            f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"
        ))
        vocab_file = self.get_vocab_file()
        vocoder_name = "vocos"
        ema_model = load_model(
            model_cls, model_cfg,
            ckpt_file, vocab_file=vocab_file,
            mel_spec_type=vocoder_name,
        )
        vocoder = self.load_vocoder(vocoder_name)
        return (ema_model, vocoder, vocoder_name)

    def load_f5_model(self, vocoder):
        repo_name = "F5-TTS"
        if vocoder == "bigvgan":
            exp_name = "F5TTS_Base_bigvgan"
            ckpt_step = 1250000
        else:
            exp_name = "F5TTS_Base"
            ckpt_step = 1200000
        return self.load_f5_model_url(
            f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors",
            vocoder,
        )

    def load_f5_model_jp(self, vocoder):
        return self.load_f5_model_url(
            "hf://Jmica/F5TTS/JA_8500000/model_8499660.pt",
            vocoder,
            "hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt"
        )

    def load_f5_model_fr(self, vocoder):
        return self.load_f5_model_url(
            "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_1374000.pt",
            vocoder,
            "hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt"
        )
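
    # "model://<file>" resolves against <checkpoints>/F5-TTS in each
    # configured checkpoints directory; anything else is handed to the
    # cached_path library (e.g. hf:// URLs).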
    def cached_path(self, url):
        if url.startswith("model:"):
            path = re.sub("^model:/*", "", url)
            models_path = folder_paths.get_folder_paths("checkpoints")
            for model_path in models_path:
                f5_model_path = os.path.join(model_path, 'F5-TTS')
                model_file = os.path.join(f5_model_path, path)
                if os.path.isfile(model_file):
                    return model_file
            raise FileNotFoundError("No model found: " + url)
        return str(cached_path(url))

    def load_f5_model_hi(self, vocoder):
        model_cfg = dict(
            dim=768, depth=18, heads=12,
            ff_mult=2, text_dim=512, conv_layers=4
        )
        return self.load_f5_model_url(
            "hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors",
            "vocos",
            "hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt",
            model_cfg=model_cfg,
        )
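
    # Vocab resolution: an explicit vocab_url wins; "model://" checkpoints
    # use the .txt file sitting next to the checkpoint; everything else
    # falls back to the default Emilia vocab bundled with F5-TTS.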
    def load_f5_model_url(
        self, url, vocoder_name, vocab_url=None, model_cfg=None
    ):
        vocoder = self.load_vocoder(vocoder_name)
        model_cls = DiT
        if model_cfg is None:
            model_cfg = dict(
                dim=1024, depth=22, heads=16,
                ff_mult=2, text_dim=512, conv_layers=4
            )

        ckpt_file = str(self.cached_path(url))

        if vocab_url is None:
            if url.startswith("model:"):
                vocab_file = F5TTSCreate.get_txt_file_path(ckpt_file)
            else:
                vocab_file = self.get_vocab_file()
        else:
            vocab_file = str(self.cached_path(vocab_url))
        ema_model = load_model(
            model_cls, model_cfg,
            ckpt_file, vocab_file=vocab_file,
            mel_spec_type=vocoder_name,
        )
        return (ema_model, vocoder, vocoder_name)
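
    # generate_audio synthesises each chunk with its tagged voice and
    # concatenates the results.  The return value follows ComfyUI's AUDIO
    # convention: {"waveform": [batch, channels, samples] tensor,
    # "sample_rate": int}.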
    def generate_audio(
        self, voices, model_obj, chunks, seed, vocoder, mel_spec_type
    ):
        if seed >= 0:
            torch.manual_seed(seed)
        else:
            torch.random.seed()

        frame_rate = 44100
        generated_audio_segments = []
        pbar = ProgressBar(len(chunks))
        for text in chunks:
            match = self.is_voice_name(text)
            if match:
                voice = match[1]
            else:
                print("No voice tag found, using main.")
                voice = "main"
            if voice not in voices:
                print(f"Voice {voice} not found, using main.")
                voice = "main"
            text = F5TTSCreate.voice_reg.sub("", text)
            gen_text = text.strip()
            if gen_text == "":
                print(f"No text for {voice}, skipping.")
                continue
            ref_audio = voices[voice]["ref_audio"]
            ref_text = voices[voice]["ref_text"]
            print(f"Voice: {voice}")
            print("text: " + gen_text)
            audio, final_sample_rate, spectrogram = infer_process(
                ref_audio, ref_text, gen_text, model_obj,
                vocoder=vocoder, mel_spec_type=mel_spec_type
            )
            generated_audio_segments.append(audio)
            frame_rate = final_sample_rate
            pbar.update(1)

        if not generated_audio_segments:
            raise ValueError("No audio generated: no usable text chunks.")

        final_wave = np.concatenate(generated_audio_segments)
        wave_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        sf.write(wave_file.name, final_wave, frame_rate)
        wave_file.close()

        waveform, sample_rate = torchaudio.load(wave_file.name)
        audio = {
            "waveform": waveform.unsqueeze(0),
            "sample_rate": sample_rate
        }
        os.unlink(wave_file.name)
        return audio

    def create(
        self, voices, chunks, seed=-1, model="F5", vocoder_name="vocos"
    ):
        (
            model_obj,
            vocoder,
            mel_spec_type
        ) = self.load_model(model, vocoder_name)
        return self.generate_audio(
            voices,
            model_obj,
            chunks, seed,
            vocoder, mel_spec_type=mel_spec_type,
        )
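

# Node: build the reference voice from an AUDIO input plus its transcript,
# then speak the given text with it.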
class F5TTSAudioInputs:
    def __init__(self):
        self.wave_file_name = None

    @classmethod
    def INPUT_TYPES(s):
        model_types = F5TTSCreate.get_model_types()
        return {
            "required": {
                "sample_audio": ("AUDIO",),
                "sample_text": ("STRING", {
                    "default": "Text of sample_audio"
                }),
                "speech": ("STRING", {
                    "multiline": True,
                    "default": "This is what I want to say"
                }),
                "seed": ("INT", {
                    "display": "number", "step": 1,
                    "default": 1, "min": -1,
                    "tooltip": F5TTSCreate.tooltip_seed,
                }),
                "model": (model_types,),
            },
        }

    CATEGORY = "audio"

    RETURN_TYPES = ("AUDIO", )
    FUNCTION = "create"

    def load_voice_from_input(self, sample_audio, sample_text):
        wave_file = tempfile.NamedTemporaryFile(
            suffix=".wav", delete=False
        )
        self.wave_file_name = wave_file.name
        wave_file.close()

        # Only the first waveform in the batch is used as the reference.
        hasAudio = False
        for (batch_number, waveform) in enumerate(
            sample_audio["waveform"].cpu()
        ):
            buff = io.BytesIO()
            torchaudio.save(
                buff, waveform, sample_audio["sample_rate"], format="WAV"
            )
            with open(self.wave_file_name, 'wb') as f:
                f.write(buff.getbuffer())
            hasAudio = True
            break
        if not hasAudio:
            raise FileNotFoundError("No audio input")
        r = F5TTSCreate.load_voice(self.wave_file_name, sample_text)
        return r

    def remove_wave_file(self):
        if self.wave_file_name is not None:
            try:
                os.unlink(self.wave_file_name)
                self.wave_file_name = None
            except Exception as e:
                print("F5TTS: cannot remove " + self.wave_file_name)
                print(e)

    def create(
        self, sample_audio, sample_text, speech, seed=-1, model="F5"
    ):
        vocoder = "vocos"
        try:
            main_voice = self.load_voice_from_input(sample_audio, sample_text)

            f5ttsCreate = F5TTSCreate()

            voices = {}
            chunks = f5ttsCreate.split_text(speech)
            voices['main'] = main_voice

            audio = f5ttsCreate.create(
                voices, chunks, seed, model, vocoder
            )
        finally:
            self.remove_wave_file()
        return (audio, )

    @classmethod
    def IS_CHANGED(s, sample_audio, sample_text, speech, seed, model):
        # hashlib needs bytes: encode the strings and hash the raw samples.
        m = hashlib.sha256()
        m.update(sample_text.encode('utf-8'))
        m.update(sample_audio["waveform"].cpu().numpy().tobytes())
        m.update(str(sample_audio["sample_rate"]).encode('utf-8'))
        m.update(speech.encode('utf-8'))
        m.update(str(seed).encode('utf-8'))
        m.update(model.encode('utf-8'))
        return m.digest().hex()
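

# Node: pick a sample file from the input directory; its transcript lives in
# a same-named .txt file.  Extra voices use "<stem>.<voice><ext>" files, e.g.
# "me.wav" plus "me.alice.wav" (illustrative names).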
class F5TTSAudio:
    def __init__(self):
        self.use_cli = False

    @classmethod
    def INPUT_TYPES(s):
        input_dir = folder_paths.get_input_directory()
        files = folder_paths.filter_files_content_types(
            os.listdir(input_dir), ["audio", "video"]
        )
        filesWithTxt = []
        for file in files:
            txtFile = F5TTSCreate.get_txt_file_path(file)
            if os.path.isfile(os.path.join(input_dir, txtFile)):
                filesWithTxt.append(file)
        filesWithTxt = sorted(filesWithTxt)

        model_types = F5TTSCreate.get_model_types()

        return {
            "required": {
                "sample": (filesWithTxt, {"audio_upload": True}),
                "speech": ("STRING", {
                    "multiline": True,
                    "default": "This is what I want to say"
                }),
                "seed": ("INT", {
                    "display": "number", "step": 1,
                    "default": 1, "min": -1,
                    "tooltip": F5TTSCreate.tooltip_seed,
                }),
                "model": (model_types,),
            }
        }

    CATEGORY = "audio"

    RETURN_TYPES = ("AUDIO", )
    FUNCTION = "create"
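
    # Optional CLI path (only taken when self.use_cli is True): shells out
    # to the inference-cli.py script in Install.f5TTSPath instead of calling
    # the library in-process.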
    def create_with_cli(self, audio_path, audio_text, speech, output_dir):
        subprocess.run(
            [
                "python", "inference-cli.py", "--model", "F5-TTS",
                "--ref_audio", audio_path, "--ref_text", audio_text,
                "--gen_text", speech,
                "--output_dir", output_dir
            ],
            cwd=Install.f5TTSPath
        )
        output_audio = os.path.join(output_dir, "out.wav")
        with wave.open(output_audio, "rb") as wave_file:
            frame_rate = wave_file.getframerate()

        waveform, _ = torchaudio.load(output_audio)
        audio = {"waveform": waveform.unsqueeze(0), "sample_rate": frame_rate}
        return audio

    def load_voice_from_file(self, sample):
        input_dir = folder_paths.get_input_directory()
        txt_file = os.path.join(
            input_dir,
            F5TTSCreate.get_txt_file_path(sample)
        )
        audio_text = ''
        with open(txt_file, 'r', encoding='utf-8') as file:
            audio_text = file.read()
        audio_path = folder_paths.get_annotated_filepath(sample)
        print("audio_text: " + audio_text)
        return F5TTSCreate.load_voice(audio_path, audio_text)
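
    # For every {voice} tag other than "main", look for a sibling sample
    # named "<stem>.<voice><suffix>", e.g. "me.wav" + "{alice}" ->
    # "me.alice.wav" (illustrative names).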
    def load_voices_from_files(self, sample, voice_names):
        voices = {}
        p = Path(sample)
        for voice_name in voice_names:
            if voice_name == "main":
                continue
            sample_file = os.path.join(
                os.path.dirname(sample),
                "{stem}.{voice_name}{suffix}".format(
                    stem=p.stem,
                    voice_name=voice_name,
                    suffix=p.suffix
                )
            )
            print("voice: " + voice_name + ", " + sample_file + ", " + sample)
            voices[voice_name] = self.load_voice_from_file(sample_file)
        return voices

    def create(self, sample, speech, seed=-1, model="F5"):
        vocoder = "vocos"

        main_voice = self.load_voice_from_file(sample)

        f5ttsCreate = F5TTSCreate()
        if self.use_cli:
            output_dir = tempfile.mkdtemp()
            audio_path = folder_paths.get_annotated_filepath(sample)
            audio = self.create_with_cli(
                audio_path, main_voice["ref_text"],
                speech, output_dir
            )
            shutil.rmtree(output_dir)
        else:
            chunks = f5ttsCreate.split_text(speech)
            voice_names = f5ttsCreate.get_voice_names(chunks)
            voices = self.load_voices_from_files(sample, voice_names)
            voices['main'] = main_voice

            audio = f5ttsCreate.create(voices, chunks, seed, model, vocoder)
        return (audio, )

    @classmethod
    def IS_CHANGED(s, sample, speech, seed, model):
        # hashlib needs bytes: encode every value before hashing.
        m = hashlib.sha256()
        audio_path = folder_paths.get_annotated_filepath(sample)
        audio_txt_path = F5TTSCreate.get_txt_file_path(audio_path)
        last_modified_timestamp = os.path.getmtime(audio_path)
        txt_last_modified_timestamp = os.path.getmtime(audio_txt_path)
        m.update(audio_path.encode('utf-8'))
        m.update(str(last_modified_timestamp).encode('utf-8'))
        m.update(str(txt_last_modified_timestamp).encode('utf-8'))
        m.update(speech.encode('utf-8'))
        m.update(str(seed).encode('utf-8'))
        m.update(model.encode('utf-8'))
        return m.digest().hex()