Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import os | |
import logging | |
import soundfile as sf | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
from TTS.tts.configs.xtts_config import XttsConfig | |
from TTS.tts.models.xtts import Xtts | |
# --- CONSTANTES --- | |
REPO_ID = "dofbi/galsenai-xtts-v2-wolof-inference" | |
LOCAL_DIR = "./models" | |
class WolofXTTSInference: | |
def __init__(self, repo_id=REPO_ID, local_dir=LOCAL_DIR): | |
# Configuration du logging | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(levelname)s - %(message)s' | |
) | |
self.logger = logging.getLogger(__name__) | |
# Créer le dossier local s'il n'existe pas | |
os.makedirs(local_dir, exist_ok=True) | |
# Téléchargement des fichiers nécessaires | |
try: | |
# Créer les sous-dossiers nécessaires | |
os.makedirs(os.path.join(local_dir, "Anta_GPT_XTTS_Wo"), exist_ok=True) | |
os.makedirs(os.path.join(local_dir, "XTTS_v2.0_original_model_files"), exist_ok=True) | |
# Télécharger le checkpoint | |
self.model_path = hf_hub_download( | |
repo_id=repo_id, | |
filename="Anta_GPT_XTTS_Wo/best_model_89250.pth", | |
local_dir=local_dir | |
) | |
# Télécharger le fichier de configuration | |
self.config_path = hf_hub_download( | |
repo_id=repo_id, | |
filename="Anta_GPT_XTTS_Wo/config.json", | |
local_dir=local_dir | |
) | |
# Télécharger le vocabulaire | |
self.vocab_path = hf_hub_download( | |
repo_id=repo_id, | |
filename="XTTS_v2.0_original_model_files/vocab.json", | |
local_dir=local_dir | |
) | |
# Télécharger l'audio de référence | |
self.reference_audio = hf_hub_download( | |
repo_id=repo_id, | |
filename="anta_sample.wav", | |
local_dir=local_dir | |
) | |
except Exception as e: | |
self.logger.error(f"Erreur lors du téléchargement des fichiers : {e}") | |
raise | |
# Sélection du device | |
self.device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
# Initialisation du modèle | |
self.model = self._load_model() | |
def _load_model(self): | |
"""Charge le modèle XTTS""" | |
try: | |
self.logger.info("Chargement du modèle XTTS...") | |
# Initialisation du modèle | |
config = XttsConfig() | |
config.load_json(self.config_path) | |
model = Xtts.init_from_config(config) | |
# Chargement du checkpoint avec load_checkpoint | |
model.load_checkpoint(config, | |
checkpoint_path=self.model_path, | |
vocab_path=self.vocab_path, | |
use_deepspeed=False | |
) | |
model.to(self.device) | |
model.eval() # Mettre le modèle en mode évaluation | |
self.logger.info("Modèle chargé avec succès!") | |
return model | |
except Exception as e: | |
self.logger.error(f"Erreur lors du chargement du modèle : {e}") | |
raise | |
def generate_audio( | |
self, | |
text: str, | |
reference_audio: str = None, | |
speed: float = 1.06, | |
language: str = "wo", | |
output_path: str = None | |
) -> tuple[np.ndarray, int]: | |
""" | |
Génère de l'audio à partir du texte fourni | |
Args: | |
text (str): Texte à convertir en audio | |
reference_audio (str, optional): Chemin vers l'audio de référence. Defaults to None. | |
speed (float, optional): Vitesse de lecture. Defaults to 1.06. | |
language (str, optional): Langue du texte. Defaults to "wo". | |
output_path (str, optional): Chemin de sauvegarde de l'audio généré. Defaults to None. | |
Returns: | |
tuple[np.ndarray, int]: audio_array, sample_rate | |
""" | |
if not text: | |
raise ValueError("Le texte ne peut pas être vide.") | |
try: | |
# Utiliser l'audio de référence fourni ou par défaut | |
ref_audio = reference_audio or self.reference_audio | |
# Obtenir les embeddings | |
gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents( | |
audio_path=[ref_audio], | |
gpt_cond_len=self.model.config.gpt_cond_len, | |
max_ref_length=self.model.config.max_ref_len, | |
sound_norm_refs=self.model.config.sound_norm_refs | |
) | |
# Génération de l'audio | |
result = self.model.inference( | |
text=text.lower(), | |
gpt_cond_latent=gpt_cond_latent, | |
speaker_embedding=speaker_embedding, | |
do_sample=False, | |
speed=speed, | |
language=language, | |
enable_text_splitting=True | |
) | |
# Récupérer le taux d'échantillonnage | |
sample_rate = self.model.config.audio.sample_rate | |
# Sauvegarde optionnelle | |
if output_path: | |
sf.write(output_path, result["wav"], sample_rate) | |
self.logger.info(f"Audio sauvegardé dans {output_path}") | |
return result["wav"], sample_rate | |
except Exception as e: | |
self.logger.error(f"Erreur lors de la génération de l'audio : {e}") | |
raise | |
def generate_audio_from_config(self, text: str, config: dict, output_path: str = None) -> tuple[np.ndarray, int]: | |
""" | |
Génère de l'audio à partir du texte et d'un dictionnaire de configuration. | |
Args: | |
text (str): Texte à convertir en audio | |
config (dict): Dictionnaire de configuration (speed, language, reference_audio) | |
output_path (str, optional): Chemin de sauvegarde de l'audio généré. Defaults to None. | |
Returns: | |
tuple[np.ndarray, int]: audio_array, sample_rate | |
""" | |
speed = config.get('speed', 1.06) | |
language = config.get('language', "wo") | |
reference_audio = config.get('reference_audio', None) | |
return self.generate_audio(text=text, reference_audio=reference_audio, speed=speed, language=language, output_path=output_path) | |
# Exemple d'utilisation | |
if __name__ == "__main__": | |
tts = WolofXTTSInference() | |
# Exemple de génération d'audio | |
text = "Màngi tuddu Aadama, di baat bii waa Galsen A.I defar ngir wax ak yéen ci wolof!" | |
# Simple | |
audio, sr = tts.generate_audio( | |
text, | |
output_path="generated_audio.wav" | |
) | |
# Avec une config | |
config_gen_audio = { | |
"speed": 1.2, | |
"language": "wo", | |
} | |
audio, sr = tts.generate_audio_from_config( | |
text=text, | |
config=config_gen_audio, | |
output_path="generated_audio_config.wav" | |
) |