|
import hashlib
|
|
import os
|
|
import string
|
|
import subprocess
|
|
import sys
|
|
from datetime import datetime
|
|
import torch
|
|
import torchaudio
|
|
from huggingface_hub import hf_hub_download, snapshot_download
|
|
from underthesea import sent_tokenize
|
|
from unidecode import unidecode
|
|
from vinorm import TTSnorm
|
|
|
|
from TTS.tts.configs.xtts_config import XttsConfig
|
|
from TTS.tts.models.xtts import Xtts
|
|
|
|
# Handle to the loaded XTTS model; None until load_model() assigns it.
XTTS_MODEL = None

# Every filesystem location below is anchored at this script's directory,
# not at the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR, OUTPUT_DIR = (os.path.join(SCRIPT_DIR, sub_dir) for sub_dir in ("model", "output"))

# Suffix that replaces ".wav" in the path of a deepFilter-processed reference clip.
FILTER_SUFFIX = "_DeepFilterNet3.wav"

# Create the output folder eagerly, at import time.
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
def clear_gpu_cache():
    """Ask PyTorch to release its cached CUDA memory; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
|
|
|
|
def load_model(checkpoint_dir="model/", repo_id="capleaf/viXTTS", use_deepspeed=False):
    """Download (if needed) and load the XTTS checkpoint into the global ``XTTS_MODEL``.

    This is a generator that yields human-readable progress messages. The
    global ``XTTS_MODEL`` is assigned only after the checkpoint has been loaded
    successfully, just before the final "Model Loaded!" message.

    Args:
        checkpoint_dir: Directory that holds (or will receive) the checkpoint
            files ``model.pth``, ``config.json``, ``vocab.json`` and
            ``speakers_xtts.pth``.
        repo_id: Hugging Face model repo downloaded with ``snapshot_download``
            when any required file is missing.
        use_deepspeed: Forwarded to ``Xtts.load_checkpoint``.

    Yields:
        Status strings describing download/loading progress.
    """
    global XTTS_MODEL

    clear_gpu_cache()
    os.makedirs(checkpoint_dir, exist_ok=True)

    required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
    files_in_dir = set(os.listdir(checkpoint_dir))
    if not all(file in files_in_dir for file in required_files):
        yield f"Missing model files! Downloading from {repo_id}..."
        snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            local_dir=checkpoint_dir,
        )
        # speakers_xtts.pth is fetched separately, from the coqui/XTTS-v2 repo.
        hf_hub_download(
            repo_id="coqui/XTTS-v2",
            filename="speakers_xtts.pth",
            local_dir=checkpoint_dir,
        )
        yield "Model download finished..."

    config = XttsConfig()
    config.load_json(os.path.join(checkpoint_dir, "config.json"))

    # Build and load into a local first: publishing the global before
    # load_checkpoint succeeded let run_tts use an un-initialised model
    # whenever loading raised.
    model = Xtts.init_from_config(config)
    yield "Loading model..."
    model.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed)
    if torch.cuda.is_available():
        model.cuda()

    XTTS_MODEL = model
    print("Model Loaded!")
    yield "Model Loaded!"
|
|
|
|
|
|
|
|
# Speaker-audio paths in first-use order; invalidate_cache() evicts from the front.
cache_queue = list()

# Not read or written by any function in the visible part of this file.
speaker_audio_cache = dict()

# Original speaker-audio path -> path of its deepFilter-processed copy.
filter_cache = dict()

# Tuple key whose first element is the speaker-audio path (followed by model
# conditioning settings) -> (gpt_cond_latent, speaker_embedding).
conditioning_latents_cache = dict()
|
|
|
|
|
|
def invalidate_cache(cache_limit=50):
    """Evict the oldest speaker-audio entries while the cache exceeds ``cache_limit``.

    For every evicted speaker path this:
      * deletes the speaker file itself and its deepFilter output (the same
        path with ``.wav`` replaced by ``FILTER_SUFFIX``), if present;
      * drops its ``filter_cache`` entry;
      * drops every ``conditioning_latents_cache`` entry whose key is the path
        or a tuple starting with the path.

    NOTE(review): the evicted speaker file is removed from disk. That is fine
    for temporary uploads, but it would delete user-supplied or bundled sample
    clips if they ever get evicted.

    Args:
        cache_limit: Maximum number of speaker paths kept in ``cache_queue``.
    """
    while cache_queue and len(cache_queue) > cache_limit:
        key_to_remove = cache_queue.pop(0)
        print("Invalidating cache", key_to_remove)

        _remove_file_if_exists(key_to_remove)
        _remove_file_if_exists(key_to_remove.replace(".wav", FILTER_SUFFIX))
        filter_cache.pop(key_to_remove, None)

        # Latents are keyed by tuples (speaker path, model settings, ...), so
        # the former `key_to_remove in conditioning_latents_cache` check never
        # matched and the cached tensors were never released.
        stale_keys = [
            key
            for key in conditioning_latents_cache
            if key == key_to_remove
            or (isinstance(key, tuple) and key and key[0] == key_to_remove)
        ]
        for key in stale_keys:
            del conditioning_latents_cache[key]


def _remove_file_if_exists(path):
    """Delete ``path``; a file that is already gone is not an error."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
|
|
|
|
|
|
def generate_hash(data):
    """Return the hexadecimal MD5 digest of the bytes-like object ``data``."""
    hasher = hashlib.md5()
    hasher.update(data)
    digest = hasher.hexdigest()
    return digest
|
|
|
|
|
|
def get_file_name(text, max_char=50):
    """Build a timestamped, filesystem-safe ASCII file stem from ``text``.

    The first ``max_char`` characters are lower-cased, spaces become ``_``,
    ASCII punctuation other than ``_`` is dropped, and the remainder is
    transliterated with ``unidecode``. The stem is then reduced to ASCII
    letters, digits, ``_`` and ``-``, because:

      * whitespace other than a plain space (newlines, tabs) was never replaced;
      * transliteration may reintroduce spaces or punctuation.

    The result is prefixed with the current local time formatted as
    ``MMDDHHMMSS_``.

    NOTE(review): the timestamp has one-second resolution, so two calls in the
    same second with the same text prefix yield the same name.

    Args:
        text: Source text, typically the text being synthesized.
        max_char: Number of leading characters of ``text`` to use.

    Returns:
        The file stem (without extension).
    """
    safe_chars = set(string.ascii_letters + string.digits + "_-")

    filename = text[:max_char].lower()
    filename = filename.replace(" ", "_")
    filename = filename.translate(str.maketrans("", "", string.punctuation.replace("_", "")))
    filename = unidecode(filename)
    # Turn any leftover whitespace into "_" and drop anything that is not
    # plain ASCII alphanumerics, "_" or "-", so the stem is a valid file name.
    filename = "_".join(filename.split())
    filename = "".join(ch for ch in filename if ch in safe_chars)

    current_datetime = datetime.now().strftime("%m%d%H%M%S")
    return f"{current_datetime}_{filename}"
|
|
|
|
from unicodedata import normalize
|
|
def normalize_vietnamese_text(text):
    """Clean up Vietnamese input text before it is sent to the TTS model.

    The function:
      * applies Unicode NFC normalization;
      * collapses a few punctuation artefacts (``..``, ``!.``, ``?.``, and a
        space before ``.`` or ``,``);
      * removes single and double quotes;
      * spells out the standalone abbreviations ``AI`` and ``A.I`` as
        "Ây Ai".

    Note: ``TTSnorm`` is imported at the top of the module but is not applied
    here.

    Args:
        text: Raw input text.

    Returns:
        The normalized text.
    """
    import re

    text = (
        normalize("NFC", text)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
    )
    # Expand "AI" only as a whole token. A plain substring replace also
    # rewrote words that merely contain it, e.g. "HAI" -> "HÂy Ai".
    text = re.sub(r"\bAI\b", "Ây Ai", text)
    text = re.sub(r"\bA\.I\b", "Ây Ai", text)
    return text
|
|
|
|
|
|
def calculate_keep_len(text, lang):
    """Estimate how much of a short sentence's generated waveform to keep.

    Simple heuristic for short sentences. With fewer than 5 words the result
    is 15000 per word; with fewer than 10 words it is 13000 per word. In both
    cases 2000 is added per ``.``, ``!``, ``?`` or ``,`` character.

    Returns -1 for Japanese / Chinese (``"ja"``, ``"zh-cn"``) and for
    sentences of 10 or more words.
    """
    if lang in ["ja", "zh-cn"]:
        return -1

    words = len(text.split())
    punctuation_marks = sum(text.count(mark) for mark in ".!?,")

    # (exclusive word-count limit, weight per word), checked in order
    for limit, per_word in ((5, 15000), (10, 13000)):
        if words < limit:
            return per_word * words + 2000 * punctuation_marks
    return -1
|
|
|
|
|
|
def run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text):
    """Synthesize ``tts_text`` in the voice of ``speaker_audio_file`` and save a WAV file.

    Args:
        lang: Language code passed to XTTS. ``"ja"``/``"zh-cn"`` text is split
            on ``。``; everything else is split with ``sent_tokenize``.
        tts_text: Text to synthesize.
        speaker_audio_file: Path to the reference voice clip.
        use_deepfilter: Run the ``deepFilter`` CLI on the reference clip first.
            Falls back to the unfiltered clip if that fails.
        normalize_text: Apply ``normalize_vietnamese_text``, but only when
            ``lang == "vi"``.

    Returns:
        ``("Speech generated !", out_path)`` on success, where ``out_path`` is
        a 24000 Hz WAV file in ``OUTPUT_DIR``. Returns
        ``(error_message, None, None)`` when a precondition fails.
        NOTE(review): the success and error tuples differ in length; kept
        as-is for existing callers.
    """
    if XTTS_MODEL is None:
        return "You need to run the previous step to load the model !!", None, None

    if not speaker_audio_file:
        return "You need to provide reference audio!!!", None, None

    # The original path is the cache identity, even if a filtered copy is used below.
    speaker_audio_key = speaker_audio_file
    if speaker_audio_key not in cache_queue:
        cache_queue.append(speaker_audio_key)
        invalidate_cache()

    if use_deepfilter:
        speaker_audio_file = _get_filtered_speaker_audio(speaker_audio_key)

    gpt_cond_latent, speaker_embedding = _get_cached_conditioning_latents(
        speaker_audio_key, speaker_audio_file
    )

    if normalize_text and lang == "vi":
        tts_text = normalize_vietnamese_text(tts_text)

    wav_chunks = [
        torch.tensor(_synthesize_sentence(sentence, lang, gpt_cond_latent, speaker_embedding))
        for sentence in _split_sentences(tts_text, lang)
    ]
    if not wav_chunks:
        # torch.cat([]) would raise on empty / whitespace-only text.
        return "You need to provide text to synthesize!!!", None, None

    out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0)
    out_path = os.path.join(OUTPUT_DIR, f"{get_file_name(tts_text)}.wav")
    print("Saving output to ", out_path)
    torchaudio.save(out_path, out_wav, 24000)

    return "Speech generated !", out_path


def _get_filtered_speaker_audio(speaker_audio_key):
    """Return the deepFilter-processed copy of a reference clip, running the CLI on a cache miss.

    If the tool cannot be started, exits non-zero, or does not produce the
    expected output file, a warning is printed and the unfiltered clip is
    returned. Nothing is cached in that case, so the next call retries.
    """
    if speaker_audio_key in filter_cache:
        print("Using filter cache...")
        return filter_cache[speaker_audio_key]

    print("Running filter...")
    filtered_path = speaker_audio_key.replace(".wav", FILTER_SUFFIX)
    try:
        result = subprocess.run(
            [
                "deepFilter",
                speaker_audio_key,
                "-o",
                os.path.dirname(speaker_audio_key),
            ]
        )
    except OSError as err:
        print(f"Could not start deepFilter ({err}); using unfiltered audio.")
        return speaker_audio_key

    if result.returncode != 0 or not os.path.isfile(filtered_path):
        print("deepFilter did not produce", filtered_path, "- using unfiltered audio.")
        return speaker_audio_key

    filter_cache[speaker_audio_key] = filtered_path
    return filtered_path


def _get_cached_conditioning_latents(speaker_audio_key, speaker_audio_file):
    """Return ``(gpt_cond_latent, speaker_embedding)`` for the clip actually used, caching the result."""
    config = XTTS_MODEL.config
    # The first element stays the original speaker path so the key still
    # identifies the speaker clip. The file actually conditioned on is
    # appended so filtered and unfiltered latents no longer share an entry.
    cache_key = (
        speaker_audio_key,
        config.gpt_cond_len,
        config.max_ref_len,
        config.sound_norm_refs,
        speaker_audio_file,
    )
    if cache_key in conditioning_latents_cache:
        print("Using conditioning latents cache...")
        return conditioning_latents_cache[cache_key]

    print("Computing conditioning latents...")
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=config.gpt_cond_len,
        max_ref_length=config.max_ref_len,
        sound_norm_refs=config.sound_norm_refs,
    )
    conditioning_latents_cache[cache_key] = (gpt_cond_latent, speaker_embedding)
    return gpt_cond_latent, speaker_embedding


def _split_sentences(tts_text, lang):
    """Split text into non-blank sentences (on "。" for ja/zh-cn, otherwise with sent_tokenize)."""
    if lang in ["ja", "zh-cn"]:
        sentences = tts_text.split("。")
    else:
        sentences = sent_tokenize(tts_text)
    return [sentence for sentence in sentences if sentence.strip()]


def _synthesize_sentence(sentence, lang, gpt_cond_latent, speaker_embedding):
    """Run XTTS inference for one sentence and apply the short-sentence trim."""
    wav_chunk = XTTS_MODEL.inference(
        text=sentence,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=0.3,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=30,
        top_p=0.85,
        enable_text_splitting=True,
    )
    wav = wav_chunk["wav"]
    keep_len = calculate_keep_len(sentence, lang)
    # calculate_keep_len signals "no trim" with -1. Slicing with [:-1] used to
    # silently drop the final sample of every chunk.
    if keep_len >= 0:
        wav = wav[:keep_len]
    return wav
|
|
|
|
|
|
|
|
|
|
def create_interface():
    """Load the model, then synthesize a fixed Vietnamese demo sentence.

    Despite the name, no UI is built. The function returns whatever
    ``run_tts`` returns. If model loading raises, it returns
    ``("Error loading model: ...", None, None)``; if synthesis raises, it
    returns ``("Error generating speech: ...", None, None)``.
    """
    try:
        for message in load_model(checkpoint_dir=MODEL_DIR, repo_id="capleaf/viXTTS", use_deepspeed=False):
            print(message)
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None

    # Resolve the bundled samples next to this script (like MODEL_DIR and
    # OUTPUT_DIR), using the platform path separator. The former
    # r"samples\..." literals were CWD-relative and backslash-separated, so
    # they only resolved on Windows when launched from the script directory.
    sample_names = [
        "nu-nhe-nhang.wav",
        "nu-nhan-nha.wav",
        "nu-luu-loat.wav",
        "nu-cham.wav",
        "nu-calm.wav",
        "nam-truyen-cam.wav",
        "nam-nhanh.wav",
        "nam-cham.wav",
        "nam-calm.wav",
    ]
    speaker_audio_files = [os.path.join(SCRIPT_DIR, "samples", name) for name in sample_names]
    speaker_audio_file = speaker_audio_files[0]

    lang = "vi"
    normalize_text = True
    use_deepfilter = False
    tts_text = "Chào bạn, tôi là một trợ lý ảo."

    # Kept separate from the loading try-block so synthesis failures are not
    # misreported as model-loading errors.
    try:
        return run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text)
    except Exception as e:
        return f"Error generating speech: {str(e)}", None, None
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script; importing this module
    # previously downloaded/loaded the model and synthesized audio as a side effect.
    print(create_interface())
|
|
|