|
import os |
|
import sys |
|
import gradio as gr |
|
import torch |
|
import torchaudio |
|
import uroman |
|
import numpy as np |
|
import requests |
|
import hashlib |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from outetts.wav_tokenizer.decoder import WavTokenizer |
|
|
|
|
|
import logging |
|
# Send INFO-and-above records to the root handler with a timestamped,
# logger-name-prefixed layout, then create this module's own logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
# Make sure a local checkout of the YarnGPT repository exists, then put it on
# sys.path so the yarngpt package can be imported.
if not os.path.exists("yarngpt"):
    import subprocess  # only needed on the first run, when the checkout is missing

    logger.info("Cloning YarnGPT repository...")
    # An argument list avoids going through a shell, and check=True makes a
    # failed clone stop start-up right here instead of surfacing later as a
    # harder-to-diagnose ImportError.
    subprocess.run(
        ["git", "clone", "https://github.com/saheedniyi02/yarngpt.git"],
        check=True,
    )

# The path entry is needed whether the checkout was just created or already
# existed, so it is added once, unconditionally.
sys.path.append("yarngpt")
|
|
|
|
|
from yarngpt.audiotokenizer import AudioTokenizerV2 |
|
|
|
|
|
# Identifier handed to both AudioTokenizerV2 and
# AutoModelForCausalLM.from_pretrained.
MODEL_PATH = "saheedniyi/YarnGPT2b"

# Remote locations of the WavTokenizer config and checkpoint.
# NOTE(review): the config comes from the "medium-speech" repository while the
# checkpoint comes from the "large-speech" one — confirm this pairing is
# intended.
WAV_TOKENIZER_CONFIG_URL = (
    "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/"
    "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
)
WAV_TOKENIZER_MODEL_URL = (
    "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/"
    "wavtokenizer_large_speech_320_24k.ckpt"
)

# Local file names the two downloads are stored under.
WAV_TOKENIZER_CONFIG_PATH = "wavtokenizer_config.yaml"
WAV_TOKENIZER_MODEL_PATH = "wavtokenizer_model.ckpt"
|
|
|
|
|
def download_file(url, output_path):
    """Download ``url`` to ``output_path`` and report whether it succeeded.

    The response body is streamed into ``<output_path>.part`` and moved onto
    ``output_path`` only after the whole transfer has been written, so an
    interrupted download never leaves a truncated file at the destination.
    Progress is logged once per 10% step when the server sends a
    Content-Length header.

    Args:
        url: Address of the file to fetch.
        output_path: Destination path on disk.

    Returns:
        True when a non-empty file was written (and, where it can be checked,
        its size matches Content-Length); False otherwise.

    Raises:
        requests.RequestException: On connection problems, timeouts or an
            HTTP error status.
        OSError: If the file cannot be written or moved into place.
    """
    logger.info(f"Downloading {url} to {output_path}")

    partial_path = f"{output_path}.part"
    try:
        # (connect, read) timeouts so a stalled server cannot hang start-up.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            # iter_content yields decoded bytes; when the body is compressed
            # in transit the byte count cannot be compared to Content-Length.
            size_checkable = total_size > 0 and not response.headers.get('content-encoding')

            downloaded = 0
            last_logged_step = 0
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)

                    if total_size > 0:
                        percent = min(100, 100 * downloaded // total_size)
                        step = percent - percent % 10
                        # Log only when a new 10% step is reached, not on
                        # every chunk that falls on a multiple of ten.
                        if step > last_logged_step:
                            last_logged_step = step
                            logger.info(f"Download progress: {step}%")

        if downloaded == 0 or (size_checkable and downloaded != total_size):
            logger.error(f"Failed to download {output_path}")
            return False

        # Atomic move: the destination is either absent/old or complete.
        os.replace(partial_path, output_path)
    finally:
        # Remove the partial file after a failure; after a successful move it
        # no longer exists and this is a no-op.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    logger.info(f"Successfully downloaded {output_path}")
    return True
|
|
|
|
|
def download_required_files():
    """Ensure the WavTokenizer config and checkpoint exist locally.

    Each file is downloaded when it is absent or zero bytes long. Afterwards
    both files are checked again for presence and non-zero size.

    Raises:
        RuntimeError: If a download reports failure, or a required file is
            still missing or empty afterwards.
    """

    def _missing_or_empty(path):
        # A zero-byte file is treated exactly like an absent one.
        return not os.path.exists(path) or os.path.getsize(path) == 0

    if _missing_or_empty(WAV_TOKENIZER_CONFIG_PATH):
        logger.info("Downloading WavTokenizer config...")
        config_ok = download_file(WAV_TOKENIZER_CONFIG_URL, WAV_TOKENIZER_CONFIG_PATH)
        if not config_ok:
            raise RuntimeError("Failed to download WavTokenizer config")

    if _missing_or_empty(WAV_TOKENIZER_MODEL_PATH):
        logger.info("Downloading WavTokenizer model...")
        model_ok = download_file(WAV_TOKENIZER_MODEL_URL, WAV_TOKENIZER_MODEL_PATH)
        if not model_ok:
            raise RuntimeError("Failed to download WavTokenizer model")

    # Final verification pass over both files, config first.
    required_paths = (WAV_TOKENIZER_CONFIG_PATH, WAV_TOKENIZER_MODEL_PATH)

    if not all(os.path.exists(path) for path in required_paths):
        raise RuntimeError("Required files not found")

    if any(os.path.getsize(path) == 0 for path in required_paths):
        raise RuntimeError("Downloaded files are empty")

    logger.info("All required files are downloaded and verified")
|
|
|
|
|
def initialize_model():
    """Fetch the required files and build the tokenizer and language model.

    Returns:
        A ``(model, audio_tokenizer)`` tuple: the causal LM loaded from
        ``MODEL_PATH`` and moved to the tokenizer's device, and the
        ``AudioTokenizerV2`` instance.

    Raises:
        Exception: Any failure during download or loading is logged and then
            re-raised unchanged.
    """
    try:
        download_required_files()

        logger.info("Initializing AudioTokenizer...")
        tokenizer = AudioTokenizerV2(
            MODEL_PATH, WAV_TOKENIZER_MODEL_PATH, WAV_TOKENIZER_CONFIG_PATH
        )

        logger.info("Loading YarnGPT model...")
        causal_lm = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype="auto")
        # Keep the model on the same device the tokenizer reports.
        causal_lm = causal_lm.to(tokenizer.device)

        logger.info("Model initialization complete!")
        return causal_lm, tokenizer
    except Exception as exc:
        logger.error(f"Failed to initialize model: {str(exc)}")
        raise
|
|
|
|
|
# Load the model and tokenizer once at start-up. If that fails, serve a
# minimal error page instead of the real UI and exit with a failure status
# once that page stops serving.
logger.info("Starting model initialization...")
try:
    model, audio_tokenizer = initialize_model()
except Exception as e:
    logger.error(f"Error initializing model: {str(e)}")

    # Build the message now rather than inside the callback: Python unbinds
    # the `except ... as e` name when the handler exits, so a callback that
    # reads `e` lazily only works while execution is still inside this block.
    init_error_message = (
        f"Model initialization failed: {str(e)}. "
        "Please check the space logs for more details."
    )

    demo = gr.Interface(
        fn=lambda x: init_error_message,
        inputs=gr.Textbox(label="Error occurred during initialization"),
        outputs=gr.Textbox(),
        title="YarnGPT - Initialization Error"
    )
    demo.launch()

    sys.exit(1)
|
|
|
|
|
# Speaker names offered in the "Voice" dropdown.
VOICES = [
    "idera",
    "jude",
    "kemi",
    "tunde",
    "funmi",
]

# Language names offered in the "Language" dropdown.
LANGUAGES = [
    "english",
    "yoruba",
    "igbo",
    "hausa",
    "pidgin",
]
|
|
|
|
|
def generate_speech(text, language, voice, temperature=0.1, rep_penalty=1.1):
    """Synthesize ``text`` and return the path of the generated WAV file.

    Args:
        text: Text to speak. None, empty and whitespace-only values are
            rejected with a prompt message.
        language: Language name passed to the tokenizer as ``lang``.
        voice: Speaker name passed to the tokenizer as ``speaker_name``.
        temperature: Forwarded to ``model.generate``.
        rep_penalty: Forwarded to ``model.generate`` as
            ``repetition_penalty``.

    Returns:
        A ``(audio_path, status_message)`` tuple. ``audio_path`` is None when
        the input is blank or generation fails; errors are reported through
        the status message and never raised to the caller.
    """
    import tempfile  # local import keeps this function self-contained

    # Blank input (including whitespace only) has nothing to synthesize.
    if not text or not str(text).strip():
        return None, "Please enter some text to convert to speech."

    try:
        logger.info(f"Generating speech for text: {text[:50]}...")

        prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
        input_ids = audio_tokenizer.tokenize_prompt(prompt)

        # NOTE(review): `temperature` only influences generation when sampling
        # is enabled; whether it is depends on the model's generation config,
        # which is not visible here — confirm.
        output = model.generate(
            input_ids=input_ids,
            temperature=temperature,
            repetition_penalty=rep_penalty,
            max_length=4000,
        )

        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)

        # Use a unique file per request so overlapping requests cannot
        # overwrite each other's audio. delete=False keeps the file on disk
        # after this function returns so the caller can read it; such files
        # are not cleaned up here.
        with tempfile.NamedTemporaryFile(
            prefix="yarngpt_", suffix=".wav", delete=False
        ) as tmp:
            temp_audio_path = tmp.name
        torchaudio.save(temp_audio_path, audio, sample_rate=24000)

        logger.info("Speech generation complete")
        return temp_audio_path, f"Successfully generated speech for: {text[:50]}..."

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Error generating speech: {str(e)}")
        return None, f"Error generating speech: {str(e)}"
|
|
|
|
|
# Sample rows for gr.Examples; each row is [text, language, voice].
examples = [
    [
        "Hello, my name is Claude. I am an AI assistant created by Anthropic.",
        "english",
        "idera",
    ],
    [
        "Báwo ni o ṣe wà? Mo ń gbádùn ọjọ́ mi.",
        "yoruba",
        "kemi",
    ],
    [
        "I don dey come house now, make you prepare food.",
        "pidgin",
        "jude",
    ],
]
|
|
|
|
|
# Main application UI: one "Basic TTS" tab with the input controls on the
# left, the generated audio and status on the right, clickable examples, and
# an "About" section.
with gr.Blocks(title="YarnGPT - Nigerian Accented Text-to-Speech") as demo:
    gr.Markdown("# YarnGPT - Nigerian Accented Text-to-Speech")
    gr.Markdown("Generate speech with Nigerian accents using YarnGPT model.")

    with gr.Tab("Basic TTS"):
        with gr.Row():
            # Input side: the text plus the generation settings.
            with gr.Column():
                text_input = gr.Textbox(
                    lines=5,
                    placeholder="Enter text here...",
                    label="Text to convert to speech",
                )
                language = gr.Dropdown(choices=LANGUAGES, value="english", label="Language")
                voice = gr.Dropdown(choices=VOICES, value="idera", label="Voice")
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.0, step=0.1, value=0.1, label="Temperature"
                )
                rep_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, step=0.1, value=1.1, label="Repetition Penalty"
                )
                generate_btn = gr.Button("Generate Speech")

            # Output side: the audio player and a status line.
            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")
                status_output = gr.Textbox(label="Status")

        # Examples supply only text, language and voice, so generate_speech
        # falls back to its default temperature and repetition penalty.
        gr.Examples(
            fn=generate_speech,
            examples=examples,
            inputs=[text_input, language, voice],
            outputs=[audio_output, status_output],
            cache_examples=False,
        )

        generate_btn.click(
            fn=generate_speech,
            inputs=[text_input, language, voice, temperature, rep_penalty],
            outputs=[audio_output, status_output],
        )

    gr.Markdown("""
    ## About YarnGPT
    YarnGPT is a text-to-speech model with Nigerian accents. It supports multiple languages and voices.

    ### Credits
    - Model by [saheedniyi](https://huggingface.co/saheedniyi/YarnGPT2b)
    - [Original Repository](https://github.com/saheedniyi02/yarngpt)
    """)


# Start the Gradio app when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()