"""viXTTS Gradio demo.

Multilingual text-to-speech built on a Coqui XTTS checkpoint (capleaf/viXTTS),
with voice cloning from a reference clip and a sentiment-driven post-processing
stage (pitch shift + noise "distortion") applied to the generated audio.
"""

import csv
import datetime
import os
import re
import time
import uuid
from io import StringIO

import gradio as gr
import nltk
import numpy as np
import pyrubberband
import spaces
import torch
import torchaudio
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from nltk.sentiment import SentimentIntensityAnalyzer
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm

# One-time environment setup: VADER lexicon for sentiment scoring, the unidic
# dictionary used by XTTS's Japanese tokenizer, and a GPU visibility check.
nltk.download('vader_lexicon')
os.system("python -m unidic download")
os.system('nvidia-smi')

HF_TOKEN = None  # anonymous access; set a token here for gated/private repos
api = HfApi(token=HF_TOKEN)

checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
use_deepspeed = False

os.makedirs(checkpoint_dir, exist_ok=True)

# Download the viXTTS snapshot (plus the XTTS-v2 speaker table, which the
# viXTTS repo does not ship) only when any required file is missing locally.
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=checkpoint_dir,
    )
    hf_hub_download(
        repo_id="coqui/XTTS-v2",
        filename="speakers_xtts.pth",
        local_dir=checkpoint_dir,
    )

# Load the XTTS model from the downloaded checkpoint.
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
MODEL = Xtts.init_from_config(config)
MODEL.load_checkpoint(
    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL.to(device)

supported_languages = config.languages
# The stock XTTS config may omit these two codes, but the UI offers them.
if "vi" not in supported_languages:
    supported_languages.append("vi")
if "es-AR" not in supported_languages:
    supported_languages.append("es-AR")

# Single shared analyzer: building it re-parses the VADER lexicon, so do it
# once at import time instead of on every request.
_SENTIMENT_ANALYZER = SentimentIntensityAnalyzer()


def normalize_vietnamese_text(text):
    """Normalize Vietnamese text for TTS.

    Runs vinorm's ``TTSnorm``, then cleans up punctuation artifacts the
    normalizer leaves behind and spells out "AI" phonetically so the model
    pronounces it.
    """
    text = (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
        .replace("AI", "Ây Ai")
        .replace("A.I", "Ây Ai")
    )
    return text


def analyze_sentiment(text):
    """Return the VADER compound sentiment score of ``text`` in [-1, 1]."""
    scores = _SENTIMENT_ANALYZER.polarity_scores(text)
    return scores['compound']


def change_pitch(audio_data, sampling_rate, sentiment):
    """Pitch-shift ``audio_data`` by up to ±2 semitones, scaled by sentiment.

    Positive sentiment raises the pitch, negative lowers it.
    """
    semitones = sentiment * 2
    return pyrubberband.pitch_shift(audio_data, sampling_rate, semitones)


def apply_distortion(audio_data, sentiment):
    """Apply random amplitude modulation proportional to |sentiment|.

    The stronger the sentiment (either polarity), the noisier the result;
    at |sentiment| == 1 the per-sample gain is 1 ± 0.5·N(0, 1).
    """
    distortion_factor = abs(sentiment) * 0.5
    return audio_data * (1 + distortion_factor * np.random.randn(len(audio_data)))


@spaces.GPU(duration=0)
def predict(
    prompt,
    language,
    audio_file_pth,
    normalize_text=True,
):
    """Synthesize speech for ``prompt`` in ``language``, cloning the voice in
    ``audio_file_pth``.

    Returns a ``(wav_path_or_None, metrics_text)`` tuple for the Gradio
    outputs. On user error a ``gr.Warning`` is raised/shown and the audio
    output is ``None``.
    """
    if language not in supported_languages:
        metrics_text = gr.Warning(
            f"El idioma seleccionado ({language}) no está disponible. Por favor, elige uno de la lista."
        )
        return (None, metrics_text)

    speaker_wav = audio_file_pth

    if len(prompt) < 2:
        metrics_text = gr.Warning("Por favor, introduce un texto más largo.")
        return (None, metrics_text)

    try:
        metrics_text = ""
        t_latent = time.time()

        # Encode the reference clip into the conditioning latents the decoder
        # needs; failure usually means an unusable/missing reference audio.
        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = MODEL.get_conditioning_latents(
                audio_path=speaker_wav,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error", str(e))
            metrics_text = gr.Warning(
                "¿Has activado el micrófono? Parece que hay un problema con la referencia de audio."
            )
            return (None, metrics_text)

        # Double sentence-final punctuation (after a word char or non-ASCII
        # char) to coax a longer pause out of the model.
        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
        if normalize_text and language == "vi":
            prompt = normalize_vietnamese_text(prompt)

        # Sentiment nudges the sampling temperature, clamped to [0.5, 1.0].
        sentiment = analyze_sentiment(prompt)
        temperature = 0.75 + sentiment * 0.2
        temperature = max(0.5, min(temperature, 1.0))

        t0 = time.time()
        out = MODEL.inference(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=5.0,
            temperature=temperature,
            enable_text_splitting=True,
        )
        inference_time = time.time() - t0
        metrics_text += (
            f"Tiempo de generación de audio: {round(inference_time*1000)} milisegundos\n"
        )
        # RTF = wall-clock time / audio duration (XTTS emits 24 kHz audio).
        real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
        metrics_text += f"Factor de tiempo real (RTF): {real_time_factor:.2f}\n"

        # Sentiment-driven post-processing, then write the result to disk.
        audio_data = np.array(out["wav"])
        modified_audio = change_pitch(audio_data, 24000, sentiment)
        modified_audio = apply_distortion(modified_audio, sentiment)
        torchaudio.save("output.wav", torch.tensor(modified_audio).unsqueeze(0), 24000)

    except RuntimeError as e:
        if "device-side assert" in str(e):
            # CUDA device-side asserts poison the context: flag the failing
            # input to the dataset repo, then restart the Space.
            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
            error_data = [
                error_time,
                prompt,
                language,
                audio_file_pth,
            ]
            error_data = [
                item if isinstance(item, str) else str(item) for item in error_data
            ]
            write_io = StringIO()
            csv.writer(write_io).writerows([error_data])
            csv_upload = write_io.getvalue().encode()
            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"

            error_api = HfApi()
            error_api.upload_file(
                path_or_fileobj=csv_upload,
                path_in_repo=filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
            error_api.upload_file(
                path_or_fileobj=speaker_wav,
                path_in_repo=speaker_filename,
                repo_id="coqui/xtts-flagged-dataset",
                repo_type="dataset",
            )

            # NOTE(review): repo_id here is the model repo, not a Space id —
            # confirm this is the intended target for restart_space.
            space = api.get_space_runtime(repo_id=repo_id)
            if space.stage != "BUILDING":
                api.restart_space(repo_id=repo_id)
        else:
            if "Failed to decode" in str(e):
                # BUGFIX: gr.Warning takes the message positionally; the old
                # `metrics_text=` keyword raised TypeError at runtime.
                metrics_text = gr.Warning(
                    "Parece que hay un problema con la referencia de audio. ¿Has activado el micrófono?"
                )
            else:
                metrics_text = gr.Warning(
                    "Se ha producido un error inesperado. Por favor, inténtalo de nuevo."
                )
        return (None, metrics_text)

    return ("output.wav", metrics_text)


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(analytics_enabled=False) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # viXTTS Demo ✨
                """
            )
        with gr.Column():
            pass

    with gr.Row():
        with gr.Column():
            input_text_gr = gr.Textbox(
                label="Texto a convertir a voz",
                value="Hola, soy un modelo de texto a voz.",
            )
            language_gr = gr.Dropdown(
                label="Idioma",
                choices=[
                    "es-AR",
                    "vi",
                    "en",
                    "es",
                    "fr",
                    "de",
                    "it",
                    "pt",
                    "pl",
                    "tr",
                    "ru",
                    "nl",
                    "cs",
                    "ar",
                    "zh-cn",
                    "ja",
                    "ko",
                    "hu",
                    "hi",
                ],
                max_choices=1,
                value="es-AR",
            )
            normalize_text = gr.Checkbox(
                label="Normalizar texto en vietnamita",
                info="Solo aplicable al idioma vietnamita",
                value=True,
            )
            ref_gr = gr.Audio(
                label="Audio de referencia (opcional)",
                type="filepath",
                value="model/samples/nu-luu-loat.wav",
            )
            tts_button = gr.Button(
                "Generar voz 🗣️🔥",
                elem_id="send-btn",
                visible=True,
                variant="primary",
            )

        with gr.Column():
            audio_gr = gr.Audio(label="Audio generado", autoplay=True)
            out_text_gr = gr.Text(label="Métricas")

    tts_button.click(
        predict,
        [
            input_text_gr,
            language_gr,
            ref_gr,
            normalize_text,
        ],
        outputs=[audio_gr, out_text_gr],
        api_name="predict",
    )

demo.queue()
demo.launch(debug=True, show_api=True, share=True)