import os
import sys
import re
import json
import torch
import inflect
import random
import uroman as ur
import numpy as np
import torchaudio
import gradio as gr
import subprocess
from transformers import AutoModelForCausalLM, AutoTokenizer
from outetts.wav_tokenizer.decoder import WavTokenizer
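# Note: re, json, inflect, random, uroman, numpy, AutoTokenizer and WavTokenizer are
# not referenced directly in this script; they appear to be carried over from the
# upstream YarnGPT example code, and importing them here makes missing runtime
# dependencies fail fast at startup.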

# Clone the YarnGPT repository if it doesn't exist
if not os.path.exists("yarngpt"):
    print("Cloning YarnGPT repository...")
    subprocess.run(["git", "clone", "https://github.com/saheedniyi02/yarngpt.git"], check=True)

# Add the yarngpt directory to the Python path
yarngpt_path = os.path.abspath("yarngpt")
if yarngpt_path not in sys.path:
    sys.path.append(yarngpt_path)
    print(f"Added {yarngpt_path} to Python path")

# Now try importing from yarngpt
from yarngpt.audiotokenizer import AudioTokenizerV2

# Download model files if they don't exist
wav_tokenizer_config_path = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
wav_tokenizer_model_path = "wavtokenizer_large_speech_320_24k.ckpt"

if not os.path.exists(wav_tokenizer_config_path):
    print(f"Downloading {wav_tokenizer_config_path}...")
    subprocess.run([
        "wget", 
        "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
    ], check=True)
    
if not os.path.exists(wav_tokenizer_model_path):
    print(f"Downloading {wav_tokenizer_model_path}...")
    subprocess.run([
        "wget", 
        "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"
    ], check=True)

# Initialize paths and models
tokenizer_path = "saheedniyi/YarnGPT2"

# Print debug info
print(f"Current directory: {os.getcwd()}")
print(f"Files in directory: {os.listdir('.')}")
print(f"Config exists: {os.path.exists(wav_tokenizer_config_path)}")
print(f"Model exists: {os.path.exists(wav_tokenizer_model_path)}")

# Initialize the audio tokenizer
print("Initializing audio tokenizer...")
audio_tokenizer = AudioTokenizerV2(
    tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
)
print("Audio tokenizer initialized")

# Load the model
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    tokenizer_path, torch_dtype="auto"
).to(audio_tokenizer.device)
print("Model loaded successfully")

# Function to generate speech
def generate_speech(text, language, speaker_name, temperature=0.1, repetition_penalty=1.1):
    print(f"Generating speech for: '{text[:50]}...'")
    print(f"Parameters: language={language}, speaker={speaker_name}, temp={temperature}, rep_penalty={repetition_penalty}")
    
    # Create prompt
    prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=speaker_name)
    print("Prompt created")
    
    # Tokenize prompt
    input_ids = audio_tokenizer.tokenize_prompt(prompt)
    print("Prompt tokenized")
    
    # Generate output (do_sample=True so the user-selected temperature actually
    # affects decoding; without it generation typically falls back to greedy search)
    output = model.generate(
        input_ids=input_ids,
        do_sample=True,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        max_length=4000,
    )
    print("Model generation complete")
    
    # Get audio codes and convert to audio
    codes = audio_tokenizer.get_codes(output)
    print("Audio codes extracted")
    
    audio = audio_tokenizer.get_audio(codes)
    print("Audio generated")
    
    # Save audio to file
    output_path = "output.wav"
    torchaudio.save(output_path, audio, sample_rate=24000)
    print(f"Audio saved to {output_path}")
    
    return output_path
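
# A minimal sketch of calling generate_speech directly, without the Gradio UI
# (assumes the model and audio tokenizer above loaded without errors):
#
#   wav_path = generate_speech("How una dey?", "pidgin", "idera")
#   print(wav_path)  # -> "output.wav" (24 kHz)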

# Create Gradio interface
def tts_interface(text, language, speaker_name, temperature, repetition_penalty):
    try:
        audio_path = generate_speech(
            text, 
            language, 
            speaker_name,
            temperature,
            repetition_penalty
        )
        return audio_path
    except Exception as e:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error in tts_interface: {str(e)}\n{error_details}")
        # Raise a Gradio error so the message is shown in the UI; returning a plain
        # string here would be misinterpreted as a file path by the Audio output.
        raise gr.Error(f"Speech generation failed: {e}")

# Define available languages and speakers
languages = ["english", "igbo", "yoruba", "hausa", "pidgin"]
speakers = ["idera", "enitan", "abeo", "eniola", "kachi", "aisha", "amara", "bello", "chidi"]

# Create the Gradio interface
demo = gr.Interface(
    fn=tts_interface,
    inputs=[
        gr.Textbox(label="Text to convert to speech", lines=5, value="Welcome to YarnGPT text-to-speech model for African languages."),
        gr.Dropdown(languages, label="Language", value="english"),
        gr.Dropdown(speakers, label="Speaker", value="idera"),
        gr.Slider(0.1, 1.0, value=0.1, label="Temperature"),
        gr.Slider(1.0, 2.0, value=1.1, label="Repetition Penalty"),
    ],
    outputs=gr.Audio(type="filepath"),
    title="YarnGPT Text-to-Speech",
    description="Convert text to speech using YarnGPT model for various African languages.",
    examples=[
        ["The election was won by businessman and politician, Moshood Abiola, but Babangida annulled the results, citing concerns over national security.", "english", "idera", 0.1, 1.1],
        ["Hello, how are you today?", "english", "enitan", 0.1, 1.1],
        ["Bawo ni?", "yoruba", "eniola", 0.2, 1.2],
    ]
)

# Launch the app
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch()
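    # When running in a container or hosted environment you may need to bind to all
    # interfaces instead, e.g. demo.launch(server_name="0.0.0.0", server_port=7860)
    # (standard Gradio launch options; adjust to your deployment).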