Update app.py
app.py CHANGED
@@ -1,151 +1,137 @@
 import os
-import
-import re
-import json
+import gradio as gr
 import torch
-import inflect
-import random
-import uroman as ur
-import numpy as np
 import torchaudio
-import
-import
+import uroman
+import numpy as np
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from outetts.wav_tokenizer.decoder import WavTokenizer

-#
-
-print("Cloning YarnGPT repository...")
-subprocess.run(["git", "clone", "https://github.com/saheedniyi02/yarngpt.git"], check=True)
-
-# Add the yarngpt directory to the Python path
-yarngpt_path = os.path.abspath("yarngpt")
-if yarngpt_path not in sys.path:
-    sys.path.append(yarngpt_path)
-print(f"Added {yarngpt_path} to Python path")
-
-# Now try importing from yarngpt
+# Import the YarnGPT AudioTokenizer
+# Assuming the git repository is cloned in the same directory
 from yarngpt.audiotokenizer import AudioTokenizerV2

-#
-
-if not os.path.exists(wav_tokenizer_config_path):
-    print(f"Downloading {wav_tokenizer_config_path}...")
-    subprocess.run([
-        "wget",
-        "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
-    ], check=True)
-
-if not os.path.exists(wav_tokenizer_model_path):
-    print(f"Downloading {wav_tokenizer_model_path}...")
-    subprocess.run([
-        "wget",
-        "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"
-    ], check=True)
+# Constants and paths
+MODEL_PATH = "saheedniyi/YarnGPT2b"
+WAV_TOKENIZER_CONFIG_PATH = "wavtokenizer_config.yaml"
+WAV_TOKENIZER_MODEL_PATH = "wavtokenizer_model.ckpt"
+
+# Download the model files at startup
+os.system(f"wget -O {WAV_TOKENIZER_CONFIG_PATH} https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml")
+os.system(f"wget -O {WAV_TOKENIZER_MODEL_PATH} https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt")
+os.system("git clone https://github.com/saheedniyi02/yarngpt.git")

-# Initialize
+# Initialize the model and tokenizer
+def initialize_model():
+    audio_tokenizer = AudioTokenizerV2(
+        MODEL_PATH,
+        WAV_TOKENIZER_MODEL_PATH,
+        WAV_TOKENIZER_CONFIG_PATH
+    )
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        torch_dtype="auto"
+    ).to(audio_tokenizer.device)
+
+    return model, audio_tokenizer

-# Initialize the
-audio_tokenizer = AudioTokenizerV2(
-    tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
-)
-print("Audio tokenizer initialized")
+# Initialize the model and tokenizer
+model, audio_tokenizer = initialize_model()

-#
-model = AutoModelForCausalLM.from_pretrained(
-    tokenizer_path, torch_dtype="auto"
-).to(audio_tokenizer.device)
-print("Model loaded successfully")
+# Available voices and languages
+VOICES = ["idera", "jude", "kemi", "tunde", "funmi"]
+LANGUAGES = ["english", "yoruba", "igbo", "hausa", "pidgin"]

 # Function to generate speech
-def generate_speech(text, language, speaker_name, temperature, repetition_penalty):
+def generate_speech(text, language, voice, temperature=0.1, rep_penalty=1.1):
+    if not text:
+        return None, "Please enter some text to convert to speech."

-    # Create prompt
-    prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=speaker_name)
-    print("Prompt created")
-
-    # Tokenize prompt
-    input_ids = audio_tokenizer.tokenize_prompt(prompt)
-    print("Prompt tokenized")
-
-    # Generate output
-    output = model.generate(
-        input_ids=input_ids,
-        temperature=temperature,
-        repetition_penalty=repetition_penalty,
-        max_length=4000,
-    )
-    print("Model generation complete")
-
-    # Get audio codes and convert to audio
-    codes = audio_tokenizer.get_codes(output)
-    print("Audio codes extracted")
-
-    audio = audio_tokenizer.get_audio(codes)
-    print("Audio generated")
-
-    # Save audio to file
-    output_path = "output.wav"
-    torchaudio.save(output_path, audio, sample_rate=24000)
-    print(f"Audio saved to {output_path}")
-
-    return output_path
-
-# Create Gradio interface
-def tts_interface(text, language, speaker_name, temperature, repetition_penalty):
     try:
+        # Create prompt
+        prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
+
+        # Tokenize prompt
+        input_ids = audio_tokenizer.tokenize_prompt(prompt)
+
+        # Generate output
+        output = model.generate(
+            input_ids=input_ids,
+            temperature=temperature,
+            repetition_penalty=rep_penalty,
+            max_length=4000,
        )
+
+        # Convert to audio
+        codes = audio_tokenizer.get_codes(output)
+        audio = audio_tokenizer.get_audio(codes)
+
+        # Save audio to file
+        temp_audio_path = "output.wav"
+        torchaudio.save(temp_audio_path, audio, sample_rate=24000)
+
+        return temp_audio_path, f"Successfully generated speech for: {text[:50]}..."
+
     except Exception as e:
-        error_details = traceback.format_exc()
-        print(f"Error in tts_interface: {str(e)}\n{error_details}")
-        return f"Error: {str(e)}"
-
-# Define available languages and speakers
-languages = ["english", "igbo", "yoruba", "hausa", "pidgin"]
-speakers = ["idera", "enitan", "abeo", "eniola", "kachi", "aisha", "amara", "bello", "chidi"]
+        return None, f"Error generating speech: {str(e)}"

 # Create the Gradio interface
-gr.
+with gr.Blocks(title="YarnGPT - Nigerian Accented Text-to-Speech") as demo:
+    gr.Markdown("# YarnGPT - Nigerian Accented Text-to-Speech")
+    gr.Markdown("Generate speech with Nigerian accents using the YarnGPT model.")
+
+    with gr.Tab("Basic TTS"):
+        with gr.Row():
+            with gr.Column():
+                text_input = gr.Textbox(
+                    label="Text to convert to speech",
+                    placeholder="Enter text here...",
+                    lines=5
+                )
+                language = gr.Dropdown(
+                    label="Language",
+                    choices=LANGUAGES,
+                    value="english"
+                )
+                voice = gr.Dropdown(
+                    label="Voice",
+                    choices=VOICES,
+                    value="idera"
+                )
+                temperature = gr.Slider(
+                    label="Temperature",
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.1,
+                    step=0.1
+                )
+                rep_penalty = gr.Slider(
+                    label="Repetition Penalty",
+                    minimum=1.0,
+                    maximum=2.0,
+                    value=1.1,
+                    step=0.1
+                )
+                generate_btn = gr.Button("Generate Speech")
+
+            with gr.Column():
+                audio_output = gr.Audio(label="Generated Speech")
+                status_output = gr.Textbox(label="Status")
+
+    generate_btn.click(
+        generate_speech,
+        inputs=[text_input, language, voice, temperature, rep_penalty],
+        outputs=[audio_output, status_output]
+    )
+
+    gr.Markdown("""
+    ## About YarnGPT
+    YarnGPT is a text-to-speech model with Nigerian accents. It supports multiple languages and voices.
+
+    ### Credits
+    - Model by [saheedniyi](https://huggingface.co/saheedniyi/YarnGPT2b)
+    - [Original Repository](https://github.com/saheedniyi02/yarngpt)
+    """)

 # Launch the app
-
-print("Starting Gradio interface...")
-demo.launch()
+demo.launch()
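Because generate_speech now returns a (path, status) pair instead of printing at each step, it can be smoke-tested without the Gradio UI. A minimal sketch, assuming the module has been loaded up to (but not including) demo.launch() so that model and audio_tokenizer exist; the sample text here is arbitrary and not part of the commit:

    # Hypothetical smoke test for generate_speech; not part of the commit.
    audio_path, status = generate_speech(
        "Good morning, this is a test of YarnGPT.",
        language="english",
        voice="idera",  # default voice from VOICES
    )
    print(status)      # success message or "Error generating speech: ..."
    print(audio_path)  # "output.wav" on success, None on failure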
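The commit pins no dependencies; the imports imply roughly the following requirements.txt, given here as an assumed sketch rather than a file in the repository (yarngpt itself is cloned from GitHub at startup rather than installed from PyPI):

    # Assumed requirements sketch inferred from the imports in app.py.
    torch
    torchaudio
    transformers
    gradio
    numpy
    uroman
    outetts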