# TTS-Model / viXTTS /main.py
# duyv's picture
# Upload 381 files
# 813828b verified
import hashlib
import os
import string
import subprocess
import sys
from datetime import datetime
import torch
import torchaudio
from huggingface_hub import hf_hub_download, snapshot_download
from underthesea import sent_tokenize
from unidecode import unidecode
from vinorm import TTSnorm
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# Loaded Xtts model instance; stays None until load_model() has assigned it.
XTTS_MODEL = None
# Directory that contains this script; the model and output folders live next to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(SCRIPT_DIR, "model")
OUTPUT_DIR = os.path.join(SCRIPT_DIR, "output")
# Suffix run_tts() substitutes for ".wav" to locate the DeepFilterNet-filtered reference file.
FILTER_SUFFIX = "_DeepFilterNet3.wav"
# Side effect at import time: make sure the output directory exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)
def clear_gpu_cache():
    """Empty PyTorch's CUDA cache; do nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def load_model(checkpoint_dir="model/", repo_id="capleaf/viXTTS", use_deepspeed=False):
    """Load the XTTS checkpoint into the module-level XTTS_MODEL.

    This is a generator: it yields human-readable progress strings and only
    does its work while being iterated.

    Args:
        checkpoint_dir: Directory holding (or receiving) the checkpoint files.
        repo_id: Hugging Face model repository to download from when any
            required file is missing in ``checkpoint_dir``.
        use_deepspeed: Forwarded to ``Xtts.load_checkpoint``.

    Yields:
        Status messages describing the download / loading progress.
    """
    global XTTS_MODEL
    clear_gpu_cache()
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Download everything again as soon as a single required file is absent.
    present = os.listdir(checkpoint_dir)
    missing = [
        name
        for name in ("model.pth", "config.json", "vocab.json", "speakers_xtts.pth")
        if name not in present
    ]
    if missing:
        yield f"Missing model files! Downloading from {repo_id}..."
        snapshot_download(repo_id=repo_id, repo_type="model", local_dir=checkpoint_dir)
        # The speaker file is fetched separately from the upstream XTTS-v2 repo.
        hf_hub_download(
            repo_id="coqui/XTTS-v2",
            filename="speakers_xtts.pth",
            local_dir=checkpoint_dir,
        )
        yield "Model download finished..."

    config = XttsConfig()
    config.load_json(os.path.join(checkpoint_dir, "config.json"))
    XTTS_MODEL = Xtts.init_from_config(config)
    yield "Loading model..."
    XTTS_MODEL.load_checkpoint(
        config,
        checkpoint_dir=checkpoint_dir,
        use_deepspeed=use_deepspeed,
    )
    if torch.cuda.is_available():
        XTTS_MODEL.cuda()

    done_message = "Model Loaded!"
    print(done_message)
    yield done_message
# Module-level caches shared by run_tts() and invalidate_cache().
# cache_queue holds the speaker-audio keys in insertion order (oldest first).
cache_queue = []
speaker_audio_cache = {}
# Maps a speaker-audio key to the path of its filtered reference audio.
filter_cache = {}
# Maps a tuple whose FIRST element is the speaker-audio key to the cached
# (gpt_cond_latent, speaker_embedding) pair.
conditioning_latents_cache = {}


def invalidate_cache(cache_limit=50):
    """Evict the oldest speaker-audio key once the queue exceeds ``cache_limit``.

    At most one key is evicted per call. Eviction removes the key from
    ``cache_queue``, deletes the reference audio file and its
    ``_DeepFilterNet3.wav`` sibling from disk when they exist, and drops the
    matching entries from ``filter_cache`` and ``conditioning_latents_cache``.

    NOTE(review): the reference audio file itself is deleted from disk, so
    keys are presumably expected to be disposable (e.g. temporary uploads) --
    confirm before passing paths to files that must be kept.

    Args:
        cache_limit: Maximum number of keys allowed to stay in the queue.
    """
    if len(cache_queue) <= cache_limit:
        return

    key_to_remove = cache_queue.pop(0)
    print("Invalidating cache", key_to_remove)

    # Same naming rule run_tts() uses for the filtered file (FILTER_SUFFIX).
    filtered_path = key_to_remove.replace(".wav", "_DeepFilterNet3.wav")
    for path in (key_to_remove, filtered_path):
        if os.path.exists(path):
            os.remove(path)

    filter_cache.pop(key_to_remove, None)

    # Latents are stored under tuple keys that start with the speaker-audio
    # key, so a plain `key in dict` test never matched and entries leaked.
    stale_keys = [
        cached_key
        for cached_key in conditioning_latents_cache
        if cached_key == key_to_remove
        or (isinstance(cached_key, tuple) and cached_key[:1] == (key_to_remove,))
    ]
    for cached_key in stale_keys:
        del conditioning_latents_cache[cached_key]
def generate_hash(data):
    """Return the hexadecimal MD5 digest of the bytes-like ``data``."""
    return hashlib.md5(data).hexdigest()
def get_file_name(text, max_char=50):
    """Build a timestamped, ASCII-only file name stem from ``text``.

    Only the first ``max_char`` characters of ``text`` are used. They are
    transliterated to ASCII, lower-cased, whitespace is turned into ``_`` and
    all ASCII punctuation except ``_`` is removed. The result is prefixed
    with the current local time formatted as ``%m%d%H%M%S``.

    Args:
        text: Text the file name is derived from.
        max_char: Number of leading characters of ``text`` to consider.

    Returns:
        The file name stem (no directory, no extension).
    """
    # Transliterate BEFORE stripping punctuation: unidecode can emit ASCII
    # punctuation for non-ASCII symbols (e.g. curly quotes or an ellipsis),
    # which would otherwise survive into the file name.
    stem = unidecode(text[:max_char]).lower()

    # One pass: every whitespace character (not only " ", so embedded
    # newlines/tabs cannot end up in the name) becomes "_", and punctuation
    # other than "_" is dropped.
    unsafe_chars = string.punctuation.replace("_", "")
    table = str.maketrans(
        string.whitespace,
        "_" * len(string.whitespace),
        unsafe_chars,
    )
    stem = stem.translate(table)

    timestamp = datetime.now().strftime("%m%d%H%M%S")
    return f"{timestamp}_{stem}"
from unicodedata import normalize
def normalize_vietnamese_text(text):
    """Clean up Vietnamese text before it is passed to the TTS model.

    Steps, in order:
      1. Unicode NFC normalization.
      2. Collapse any run of two or more dots into a single dot.
      3. Drop a dot that follows ``!`` or ``?`` and a space that precedes
         ``.`` or ``,``.
      4. Remove ASCII double and single quotes.
      5. Spell out the standalone acronym ``AI`` / ``A.I`` as ``Ây Ai``.

    The acronym is left untouched when it is directly preceded by an
    uppercase letter or directly followed by a letter, so uppercase words
    such as ``MAI`` are not mangled while ``OpenAI`` is still spelled out.

    Args:
        text: Input text.

    Returns:
        The normalized text.
    """
    import re  # local import: the module does not import `re` at file level

    def spell_out_ai(match):
        source = match.string
        before = source[max(match.start() - 1, 0) : match.start()]
        after = source[match.end() : match.end() + 1]
        # Part of a longer word (e.g. "MAI", "AIR"): keep it as written.
        if before.isupper() or after.isalpha():
            return match.group(0)
        return "Ây Ai"

    text = normalize("NFC", text)
    # A plain replace("..", ".") turns "..." into "..", so use a regex that
    # consumes the whole run of dots.
    text = re.sub(r"\.{2,}", ".", text)
    text = (
        text.replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
    )
    return re.sub(r"A\.?I", spell_out_ai, text)
def calculate_keep_len(text, lang):
    """Heuristic length limit for short sentences.

    Returns ``-1`` for Japanese / Chinese (``"ja"``, ``"zh-cn"``) and for
    texts of ten or more whitespace-separated words. Otherwise the result is
    a per-word budget (15000 below five words, 13000 below ten) times the
    word count, plus 2000 for every ``.``, ``!``, ``?`` or ``,`` in the text.

    Args:
        text: Sentence to measure.
        lang: Language code of the sentence.

    Returns:
        The computed length, or ``-1`` as the sentinel described above.
    """
    if lang in ("ja", "zh-cn"):
        return -1

    n_words = len(text.split())
    if n_words >= 10:
        return -1

    per_word = 15000 if n_words < 5 else 13000
    n_punct = sum(text.count(mark) for mark in (".", "!", "?", ","))
    return per_word * n_words + 2000 * n_punct
def run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text):
    """Synthesize ``tts_text`` in the voice of ``speaker_audio_file``.

    The text is split into sentences, each sentence is synthesized with the
    module-level ``XTTS_MODEL``, and the concatenated audio is written as a
    24000 Hz wav file into ``OUTPUT_DIR``.

    Args:
        lang: Language code passed to the model. ``"ja"`` and ``"zh-cn"``
            are split on the ideographic full stop; every other language
            goes through ``sent_tokenize``.
        tts_text: Text to speak.
        speaker_audio_file: Path of the reference audio. It is also used as
            the cache key for the filtered file and conditioning latents.
        use_deepfilter: When true, run the ``deepFilter`` command on the
            reference audio and condition on its output.
        normalize_text: When true and ``lang == "vi"``, run
            ``normalize_vietnamese_text`` on the text first.

    Returns:
        On success the 2-tuple ``("Speech generated !", out_path)``.
        On a handled error the 3-tuple ``(error_message, None, None)``.
        NOTE(review): the two shapes differ in length; kept as-is so
        existing callers are not broken -- confirm which one is expected.
    """
    if XTTS_MODEL is None:
        return "You need to run the previous step to load the model !!", None, None
    if not speaker_audio_file:
        return "You need to provide reference audio!!!", None, None

    # Use the file name as the key, since it's supposed to be unique.
    speaker_audio_key = speaker_audio_file
    if speaker_audio_key not in cache_queue:
        cache_queue.append(speaker_audio_key)
        invalidate_cache()

    # Resolve the (optionally filtered) reference audio.
    if use_deepfilter and speaker_audio_key in filter_cache:
        print("Using filter cache...")
        speaker_audio_file = filter_cache[speaker_audio_key]
    elif use_deepfilter:
        print("Running filter...")
        filtered_file = speaker_audio_file.replace(".wav", FILTER_SUFFIX)
        result = subprocess.run(
            [
                "deepFilter",
                speaker_audio_file,
                "-o",
                # dirname() is "" for a bare file name; use the current dir.
                os.path.dirname(speaker_audio_file) or ".",
            ]
        )
        # Only cache the filtered path when the tool really produced it;
        # otherwise every later call would reuse a path that does not exist.
        if result.returncode != 0 or not os.path.exists(filtered_file):
            return "DeepFilterNet failed to filter the reference audio!", None, None
        filter_cache[speaker_audio_key] = filtered_file
        speaker_audio_file = filtered_file

    # The key starts with the speaker-audio key and also contains the audio
    # path actually used, so latents computed from the raw reference are not
    # reused when the filtered reference is requested (and vice versa).
    cache_key = (
        speaker_audio_key,
        speaker_audio_file,
        XTTS_MODEL.config.gpt_cond_len,
        XTTS_MODEL.config.max_ref_len,
        XTTS_MODEL.config.sound_norm_refs,
    )
    if cache_key in conditioning_latents_cache:
        print("Using conditioning latents cache...")
        gpt_cond_latent, speaker_embedding = conditioning_latents_cache[cache_key]
    else:
        print("Computing conditioning latents...")
        gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
            audio_path=speaker_audio_file,
            gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
            max_ref_length=XTTS_MODEL.config.max_ref_len,
            sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
        )
        conditioning_latents_cache[cache_key] = (gpt_cond_latent, speaker_embedding)

    if normalize_text and lang == "vi":
        tts_text = normalize_vietnamese_text(tts_text)

    # Split text by sentence
    if lang in ["ja", "zh-cn"]:
        sentences = tts_text.split("。")
    else:
        sentences = sent_tokenize(tts_text)

    wav_chunks = []
    for sentence in sentences:
        if sentence.strip() == "":
            continue
        wav_chunk = XTTS_MODEL.inference(
            text=sentence,
            language=lang,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            # The following values are carefully chosen for viXTTS
            temperature=0.3,
            length_penalty=1.0,
            repetition_penalty=10.0,
            top_k=30,
            top_p=0.85,
            enable_text_splitting=True,
        )
        wav = wav_chunk["wav"]
        keep_len = calculate_keep_len(sentence, lang)
        # -1 is the "keep everything" sentinel; slicing with it would
        # silently drop the last sample instead.
        if keep_len >= 0:
            wav = wav[:keep_len]
        wav_chunks.append(torch.tensor(wav))

    # torch.cat() raises on an empty list (blank or whitespace-only text).
    if not wav_chunks:
        return "You need to provide some text to synthesize!", None, None

    out_wav = torch.cat(wav_chunks, dim=0).unsqueeze(0)
    out_path = os.path.join(OUTPUT_DIR, f"{get_file_name(tts_text)}.wav")
    print("Saving output to ", out_path)
    torchaudio.save(out_path, out_wav, 24000)
    return "Speech generated !", out_path
def create_interface():
    """Load the viXTTS model and synthesize one fixed demo sentence.

    The model is loaded from ``MODEL_DIR`` (progress messages are printed),
    then ``run_tts`` is called with the first bundled sample voice.

    Returns:
        Whatever ``run_tts`` returns, or the 3-tuple
        ``(error_message, None, None)`` when loading or synthesis raises.
    """
    try:
        # load_model is a generator: it must be iterated to completion for
        # the model to actually be loaded. Print each status message.
        model_loading_gen = load_model(
            checkpoint_dir=MODEL_DIR,
            repo_id="capleaf/viXTTS",
            use_deepspeed=False,
        )
        for message in model_loading_gen:
            print(message)
    except Exception as e:
        return f"Error loading model: {str(e)}", None, None

    # Bundled reference voices. Built with os.path.join so the paths are not
    # tied to the Windows "\" separator (same strings as before on Windows).
    sample_names = [
        "nu-nhe-nhang.wav",
        "nu-nhan-nha.wav",
        "nu-luu-loat.wav",
        "nu-cham.wav",
        "nu-calm.wav",
        "nam-truyen-cam.wav",
        "nam-nhanh.wav",
        "nam-cham.wav",
        "nam-calm.wav",
    ]
    speaker_audio_files = [os.path.join("samples", name) for name in sample_names]
    speaker_audio_file = speaker_audio_files[0]

    # Remaining synthesis parameters.
    lang = "vi"
    normalize_text = True
    use_deepfilter = False
    tts_text = "Chào bạn, tôi là một trợ lý ảo."

    try:
        # Run synthesis only after the model has been loaded.
        return run_tts(lang, tts_text, speaker_audio_file, use_deepfilter, normalize_text)
    except Exception as e:
        # Reported separately so a synthesis failure is not mislabeled as a
        # model-loading failure.
        return f"Error generating speech: {str(e)}", None, None
# Call create_interface to start the whole process and print what it returns.
# NOTE: there is no `if __name__ == "__main__":` guard, so this also runs
# whenever the module is imported.
print(create_interface())