File size: 9,129 Bytes
78b0078
9147378
03d09ab
b5d25fc
 
03d09ab
 
6c2dbc0
 
b5d25fc
 
d4a2e16
6c2dbc0
 
 
 
 
 
9147378
6c2dbc0
9147378
 
 
6c2dbc0
 
9147378
03d09ab
ac16e60
29bfa47
03d09ab
 
6c2dbc0
 
03d09ab
 
d4a2e16
6c2dbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9147378
6c2dbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29bfa47
03d09ab
 
6c2dbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29bfa47
03d09ab
6c2dbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6fd3a8
03d09ab
 
 
b6fd3a8
b5d25fc
03d09ab
 
 
ac16e60
b5d25fc
6c2dbc0
 
03d09ab
 
 
 
 
 
 
 
 
 
 
 
b5d25fc
03d09ab
 
 
 
 
 
 
 
 
6c2dbc0
03d09ab
 
b5d25fc
6c2dbc0
03d09ab
b6fd3a8
9147378
 
 
 
 
 
 
b5d25fc
03d09ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9147378
 
 
 
 
 
 
 
03d09ab
 
 
 
 
 
 
 
 
 
 
 
 
 
b6fd3a8
b5d25fc
9147378
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import hashlib
import os
import sys
import tempfile

import gradio as gr
import numpy as np
import requests
import torch
import torchaudio
import uroman
from outetts.wav_tokenizer.decoder import WavTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
# The root logger is configured once at import time; all messages in this
# script are emitted through the module-level ``logger`` created below.
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Clone YarnGPT at startup
# The repository is fetched only when no local "yarngpt" directory exists.
if not os.path.exists("yarngpt"):
    logger.info("Cloning YarnGPT repository...")
    clone_status = os.system("git clone https://github.com/saheedniyi02/yarngpt.git")
    if clone_status != 0:
        # Do not abort here: the import that follows decides whether the
        # package is usable. Logging the status makes the root cause visible
        # when that import subsequently fails.
        logger.error(
            f"git clone of YarnGPT exited with status {clone_status}; "
            "importing yarngpt may fail"
        )

# Add the repository to Python path (needed whether it was just cloned or
# was already present)
sys.path.append("yarngpt")

# Import the YarnGPT AudioTokenizer
from yarngpt.audiotokenizer import AudioTokenizerV2

# Constants and paths
# MODEL_PATH is the model identifier handed to both the audio tokenizer and
# the causal LM loader. The *_URL values are the remote WavTokenizer assets and
# the *_PATH values are where they are stored locally (relative to the
# working directory).
MODEL_PATH = "saheedniyi/YarnGPT2b"
WAV_TOKENIZER_CONFIG_URL = (
    "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token"
    "/resolve/main/"
    "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
)
WAV_TOKENIZER_MODEL_URL = (
    "https://huggingface.co/novateur/WavTokenizer-large-speech-75token"
    "/resolve/main/"
    "wavtokenizer_large_speech_320_24k.ckpt"
)
WAV_TOKENIZER_CONFIG_PATH = "wavtokenizer_config.yaml"
WAV_TOKENIZER_MODEL_PATH = "wavtokenizer_model.ckpt"

# Function to download files with verification
def download_file(url, output_path):
    """Stream ``url`` to ``output_path`` and report whether data was received.

    The response body is first written to ``<output_path>.part`` and only
    renamed onto ``output_path`` once the whole stream has been consumed, so
    an interrupted download never leaves a truncated file at the final
    location.

    Args:
        url: Address of the file to fetch.
        output_path: Destination path on the local filesystem.

    Returns:
        True when a non-empty file was written to ``output_path``; False when
        the server answered with an empty body.

    Raises:
        requests.RequestException: On connection problems, timeouts, or an
            HTTP error status (via ``raise_for_status``).
        OSError: If the partial file cannot be written or renamed.
    """
    logger.info(f"Downloading {url} to {output_path}")

    partial_path = f"{output_path}.part"
    downloaded = 0
    next_report = 10  # next progress percentage that is worth a log line

    try:
        # (connect, read) timeout so a stalled server cannot hang start-up
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)

                    # Without a content-length there is no percentage to show
                    if total_size <= 0:
                        continue
                    percent = 100 * downloaded // total_size
                    # Log once per 10% step rather than once per chunk
                    if percent >= next_report:
                        logger.info(f"Download progress: {percent}%")
                        next_report = (percent // 10 + 1) * 10

        # Verify that some content actually arrived
        if downloaded == 0:
            logger.error(f"Failed to download {output_path}")
            return False

        os.replace(partial_path, output_path)
        logger.info(f"Successfully downloaded {output_path}")
        return True
    finally:
        # After a successful rename the partial file no longer exists; in
        # every other case (empty body, exception) it is removed here.
        if os.path.exists(partial_path):
            os.remove(partial_path)

# Download the required files
def download_required_files():
    """Make sure the WavTokenizer config and checkpoint exist locally.

    Each asset is downloaded when its local file is missing or empty. After
    that, both files are re-checked for existence and non-zero size.

    Raises:
        RuntimeError: If a download reports failure, or if either file is
            missing or empty after the download step.
    """
    assets = (
        ("config", WAV_TOKENIZER_CONFIG_URL, WAV_TOKENIZER_CONFIG_PATH),
        ("model", WAV_TOKENIZER_MODEL_URL, WAV_TOKENIZER_MODEL_PATH),
    )

    for label, url, path in assets:
        if os.path.exists(path) and os.path.getsize(path) != 0:
            # A usable copy is already on disk
            continue
        logger.info(f"Downloading WavTokenizer {label}...")
        if not download_file(url, path):
            raise RuntimeError(f"Failed to download WavTokenizer {label}")

    local_paths = [path for _, _, path in assets]

    # Verify files exist
    if not all(os.path.exists(path) for path in local_paths):
        raise RuntimeError("Required files not found")

    # Verify files have content
    if any(os.path.getsize(path) == 0 for path in local_paths):
        raise RuntimeError("Downloaded files are empty")

    logger.info("All required files are downloaded and verified")

# Initialize the model and tokenizer
def initialize_model():
    """Fetch the WavTokenizer assets, then build the tokenizer and the model.

    Returns:
        A ``(model, audio_tokenizer)`` tuple. The model is moved to the device
        exposed by the tokenizer's ``device`` attribute.

    Raises:
        Exception: Any failure during download or loading is logged and then
            re-raised unchanged.
    """
    try:
        # Assets must be on disk before the tokenizer is constructed
        download_required_files()

        logger.info("Initializing AudioTokenizer...")
        tokenizer = AudioTokenizerV2(
            MODEL_PATH, WAV_TOKENIZER_MODEL_PATH, WAV_TOKENIZER_CONFIG_PATH
        )

        logger.info("Loading YarnGPT model...")
        causal_lm = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype="auto")
        causal_lm = causal_lm.to(tokenizer.device)

        logger.info("Model initialization complete!")
        return causal_lm, tokenizer
    except Exception as e:
        logger.error(f"Failed to initialize model: {str(e)}")
        raise

# Initialize the model and tokenizer
# On success this defines the module-level ``model`` and ``audio_tokenizer``
# used by the speech handler. On failure a minimal error page is served
# instead of the real UI and the process then exits with status 1.
logger.info("Starting model initialization...")
try:
    model, audio_tokenizer = initialize_model()
except Exception as e:
    logger.error(f"Error initializing model: {str(e)}")
    # Build the message now. Python unbinds ``e`` when this except block ends,
    # so a callback that looked ``e`` up lazily would raise NameError if it
    # were ever invoked after that point (e.g. with a non-blocking launch).
    _init_error_message = (
        f"Model initialization failed: {str(e)}. "
        "Please check the space logs for more details."
    )
    # Provide a basic interface to show the error
    demo = gr.Interface(
        fn=lambda x: _init_error_message,
        inputs=gr.Textbox(label="Error occurred during initialization"),
        outputs=gr.Textbox(),
        title="YarnGPT - Initialization Error"
    )
    demo.launch()
    # Exit the script
    sys.exit(1)

# Available voices and languages
# These populate the two dropdowns in the UI.
# NOTE(review): nothing in this file validates these values against what the
# audio tokenizer accepts — confirm each entry against the YarnGPT repository.
VOICES = [
    "idera",
    "jude",
    "kemi",
    "tunde",
    "funmi",
]
LANGUAGES = [
    "english",
    "yoruba",
    "igbo",
    "hausa",
    "pidgin",
]

# Function to generate speech
def generate_speech(text, language, voice, temperature=0.1, rep_penalty=1.1):
    """Synthesize ``text`` to a WAV file and return its path with a status line.

    Args:
        text: Text to speak. Empty, ``None`` or whitespace-only input is
            rejected without calling the model.
        language: Language name forwarded to the tokenizer as ``lang``.
        voice: Speaker name forwarded to the tokenizer as ``speaker_name``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        rep_penalty: Forwarded to ``model.generate`` as ``repetition_penalty``.

    Returns:
        ``(audio_path, status_message)``. ``audio_path`` is ``None`` when the
        input was rejected or generation failed; errors are reported through
        the status message and are not raised.
    """
    # Whitespace-only input carries nothing to speak, so treat it like empty
    if not text or not text.strip():
        return None, "Please enter some text to convert to speech."

    try:
        logger.info(f"Generating speech for text: {text[:50]}...")

        # Create prompt
        prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)

        # Tokenize prompt
        input_ids = audio_tokenizer.tokenize_prompt(prompt)

        # Generate output
        output = model.generate(
            input_ids=input_ids,
            temperature=temperature,
            repetition_penalty=rep_penalty,
            max_length=4000,
        )

        # Convert to audio
        codes = audio_tokenizer.get_codes(output)
        audio = audio_tokenizer.get_audio(codes)

        # Save audio to a file that is unique to this request, so overlapping
        # requests cannot overwrite each other's output. The file is left on
        # disk (delete=False) because the caller still has to read it.
        with tempfile.NamedTemporaryFile(
            prefix="yarngpt_", suffix=".wav", delete=False
        ) as tmp_file:
            temp_audio_path = tmp_file.name
        torchaudio.save(temp_audio_path, audio, sample_rate=24000)

        logger.info("Speech generation complete")
        return temp_audio_path, f"Successfully generated speech for: {text[:50]}..."

    except Exception as e:
        # logger.exception records the traceback alongside the message
        logger.exception(f"Error generating speech: {str(e)}")
        return None, f"Error generating speech: {str(e)}"

# Example text for demonstration
# Each row holds one sample as [text, language, voice].
examples = [
    [
        "Hello, my name is Claude. I am an AI assistant created by Anthropic.",
        "english",
        "idera",
    ],
    [
        "Báwo ni o ṣe wà? Mo ń gbádùn ọjọ́ mi.",
        "yoruba",
        "kemi",
    ],
    [
        "I don dey come house now, make you prepare food.",
        "pidgin",
        "jude",
    ],
]

# Create the Gradio interface
# Layout: a page header, one "Basic TTS" tab with the input controls in the
# left column and the results in the right column, example rows below the
# columns, and an "About" section at the bottom of the page.
with gr.Blocks(title="YarnGPT - Nigerian Accented Text-to-Speech") as demo:
    gr.Markdown("# YarnGPT - Nigerian Accented Text-to-Speech")
    gr.Markdown("Generate speech with Nigerian accents using YarnGPT model.")
    
    with gr.Tab("Basic TTS"):
        with gr.Row():
            # Left column: everything the user can set before generating
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text to convert to speech", 
                    placeholder="Enter text here...",
                    lines=5
                )
                language = gr.Dropdown(
                    label="Language", 
                    choices=LANGUAGES, 
                    value="english"
                )
                voice = gr.Dropdown(
                    label="Voice", 
                    choices=VOICES, 
                    value="idera"
                )
                # Slider defaults mirror the defaults of generate_speech
                temperature = gr.Slider(
                    label="Temperature", 
                    minimum=0.1, 
                    maximum=1.0, 
                    value=0.1, 
                    step=0.1
                )
                rep_penalty = gr.Slider(
                    label="Repetition Penalty", 
                    minimum=1.0, 
                    maximum=2.0, 
                    value=1.1, 
                    step=0.1
                )
                generate_btn = gr.Button("Generate Speech")
            
            # Right column: the two values returned by generate_speech
            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")
                status_output = gr.Textbox(label="Status")
    
        # Example rows provide only text, language and voice; the two sliders
        # are not part of the example inputs. cache_examples=False means the
        # example outputs are not pre-computed.
        gr.Examples(
            examples=examples,
            inputs=[text_input, language, voice],
            outputs=[audio_output, status_output],
            fn=generate_speech,
            cache_examples=False
        )
    
    # The button passes all five controls, in the parameter order of
    # generate_speech, and routes its two return values to the result widgets.
    generate_btn.click(
        generate_speech, 
        inputs=[text_input, language, voice, temperature, rep_penalty], 
        outputs=[audio_output, status_output]
    )
    
    gr.Markdown("""
    ## About YarnGPT
    YarnGPT is a text-to-speech model with Nigerian accents. It supports multiple languages and voices.
    
    ### Credits
    - Model by [saheedniyi](https://huggingface.co/saheedniyi/YarnGPT2b)
    - [Original Repository](https://github.com/saheedniyi02/yarngpt)
    """)

# Launch the app
# Only start the server when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()