okewunmi committed
Commit 03d09ab · verified · 1 Parent(s): 5155c90

Update app.py

Files changed (1)
  1. app.py +115 -129
app.py CHANGED
@@ -1,151 +1,137 @@
  import os
- import sys
- import re
- import json
  import torch
- import inflect
- import random
- import uroman as ur
- import numpy as np
  import torchaudio
- import gradio as gr
- import subprocess
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from outetts.wav_tokenizer.decoder import WavTokenizer
 
- # Clone the YarnGPT repository if it doesn't exist
- if not os.path.exists("yarngpt"):
-     print("Cloning YarnGPT repository...")
-     subprocess.run(["git", "clone", "https://github.com/saheedniyi02/yarngpt.git"], check=True)
-
- # Add the yarngpt directory to the Python path
- yarngpt_path = os.path.abspath("yarngpt")
- if yarngpt_path not in sys.path:
-     sys.path.append(yarngpt_path)
-     print(f"Added {yarngpt_path} to Python path")
-
- # Now try importing from yarngpt
  from yarngpt.audiotokenizer import AudioTokenizerV2
 
- # Download model files if they don't exist
- wav_tokenizer_config_path = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
- wav_tokenizer_model_path = "wavtokenizer_large_speech_320_24k.ckpt"

- if not os.path.exists(wav_tokenizer_config_path):
-     print(f"Downloading {wav_tokenizer_config_path}...")
-     subprocess.run([
-         "wget",
-         "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
-     ], check=True)
-
- if not os.path.exists(wav_tokenizer_model_path):
-     print(f"Downloading {wav_tokenizer_model_path}...")
-     subprocess.run([
-         "wget",
-         "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"
-     ], check=True)

- # Initialize paths and models
- tokenizer_path = "saheedniyi/YarnGPT2"
-
- # Print debug info
- print(f"Current directory: {os.getcwd()}")
- print(f"Files in directory: {os.listdir('.')}")
- print(f"Config exists: {os.path.exists(wav_tokenizer_config_path)}")
- print(f"Model exists: {os.path.exists(wav_tokenizer_model_path)}")

- # Initialize the audio tokenizer
- print("Initializing audio tokenizer...")
- audio_tokenizer = AudioTokenizerV2(
-     tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
- )
- print("Audio tokenizer initialized")

- # Load the model
- print("Loading model...")
- model = AutoModelForCausalLM.from_pretrained(
-     tokenizer_path, torch_dtype="auto"
- ).to(audio_tokenizer.device)
- print("Model loaded successfully")
 
  # Function to generate speech
- def generate_speech(text, language, speaker_name, temperature=0.1, repetition_penalty=1.1):
-     print(f"Generating speech for: '{text[:50]}...'")
-     print(f"Parameters: language={language}, speaker={speaker_name}, temp={temperature}, rep_penalty={repetition_penalty}")

-     # Create prompt
-     prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=speaker_name)
-     print("Prompt created")
-
-     # Tokenize prompt
-     input_ids = audio_tokenizer.tokenize_prompt(prompt)
-     print("Prompt tokenized")
-
-     # Generate output
-     output = model.generate(
-         input_ids=input_ids,
-         temperature=temperature,
-         repetition_penalty=repetition_penalty,
-         max_length=4000,
-     )
-     print("Model generation complete")
-
-     # Get audio codes and convert to audio
-     codes = audio_tokenizer.get_codes(output)
-     print("Audio codes extracted")
-
-     audio = audio_tokenizer.get_audio(codes)
-     print("Audio generated")
-
-     # Save audio to file
-     output_path = "output.wav"
-     torchaudio.save(output_path, audio, sample_rate=24000)
-     print(f"Audio saved to {output_path}")
-
-     return output_path
-
- # Create Gradio interface
- def tts_interface(text, language, speaker_name, temperature, repetition_penalty):
      try:
-         audio_path = generate_speech(
-             text,
-             language,
-             speaker_name,
-             temperature,
-             repetition_penalty
          )
-         return audio_path
      except Exception as e:
-         import traceback
-         error_details = traceback.format_exc()
-         print(f"Error in tts_interface: {str(e)}\n{error_details}")
-         return f"Error: {str(e)}"
-
- # Define available languages and speakers
- languages = ["english", "igbo", "yoruba", "hausa", "pidgin"]
- speakers = ["idera", "enitan", "abeo", "eniola", "kachi", "aisha", "amara", "bello", "chidi"]
 
  # Create the Gradio interface
- demo = gr.Interface(
-     fn=tts_interface,
-     inputs=[
-         gr.Textbox(label="Text to convert to speech", lines=5, value="Welcome to YarnGPT text-to-speech model for African languages."),
-         gr.Dropdown(languages, label="Language", value="english"),
-         gr.Dropdown(speakers, label="Speaker", value="idera"),
-         gr.Slider(0.1, 1.0, value=0.1, label="Temperature"),
-         gr.Slider(1.0, 2.0, value=1.1, label="Repetition Penalty"),
-     ],
-     outputs=gr.Audio(type="filepath"),
-     title="YarnGPT Text-to-Speech",
-     description="Convert text to speech using YarnGPT model for various African languages.",
-     examples=[
-         ["The election was won by businessman and politician, Moshood Abiola, but Babangida annulled the results, citing concerns over national security.", "english", "idera", 0.1, 1.1],
-         ["Hello, how are you today?", "english", "enitan", 0.1, 1.1],
-         ["Bawo ni?", "yoruba", "eniola", 0.2, 1.2],
-     ]
- )
 
  # Launch the app
- if __name__ == "__main__":
-     print("Starting Gradio interface...")
-     demo.launch()
 
  import os
+ import gradio as gr
  import torch
  import torchaudio
+ import uroman
+ import numpy as np
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from outetts.wav_tokenizer.decoder import WavTokenizer
 
+ # Import the YarnGPT AudioTokenizer
+ # Assuming the git repository is cloned in the same directory
  from yarngpt.audiotokenizer import AudioTokenizerV2
 
+ # Constants and paths
+ MODEL_PATH = "saheedniyi/YarnGPT2b"
+ WAV_TOKENIZER_CONFIG_PATH = "wavtokenizer_config.yaml"
+ WAV_TOKENIZER_MODEL_PATH = "wavtokenizer_model.ckpt"
 
+ # Download the model files at startup
+ os.system(f"wget -O {WAV_TOKENIZER_CONFIG_PATH} https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml")
+ os.system(f"wget -O {WAV_TOKENIZER_MODEL_PATH} https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt")
+ os.system("git clone https://github.com/saheedniyi02/yarngpt.git")
 
+ # Initialize the model and tokenizer
+ def initialize_model():
+     audio_tokenizer = AudioTokenizerV2(
+         MODEL_PATH,
+         WAV_TOKENIZER_MODEL_PATH,
+         WAV_TOKENIZER_CONFIG_PATH
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_PATH,
+         torch_dtype="auto"
+     ).to(audio_tokenizer.device)
+
+     return model, audio_tokenizer
 
+ # Initialize the model and tokenizer
+ model, audio_tokenizer = initialize_model()
 
+ # Available voices and languages
+ VOICES = ["idera", "jude", "kemi", "tunde", "funmi"]
+ LANGUAGES = ["english", "yoruba", "igbo", "hausa", "pidgin"]
 
  # Function to generate speech
+ def generate_speech(text, language, voice, temperature=0.1, rep_penalty=1.1):
+     if not text:
+         return None, "Please enter some text to convert to speech."

      try:
+         # Create prompt
+         prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=voice)
+
+         # Tokenize prompt
+         input_ids = audio_tokenizer.tokenize_prompt(prompt)
+
+         # Generate output
+         output = model.generate(
+             input_ids=input_ids,
+             temperature=temperature,
+             repetition_penalty=rep_penalty,
+             max_length=4000,
          )
+
+         # Convert to audio
+         codes = audio_tokenizer.get_codes(output)
+         audio = audio_tokenizer.get_audio(codes)
+
+         # Save audio to file
+         temp_audio_path = "output.wav"
+         torchaudio.save(temp_audio_path, audio, sample_rate=24000)
+
+         return temp_audio_path, f"Successfully generated speech for: {text[:50]}..."
+
      except Exception as e:
+         return None, f"Error generating speech: {str(e)}"
 
  # Create the Gradio interface
+ with gr.Blocks(title="YarnGPT - Nigerian Accented Text-to-Speech") as demo:
+     gr.Markdown("# YarnGPT - Nigerian Accented Text-to-Speech")
+     gr.Markdown("Generate speech with Nigerian accents using YarnGPT model.")
+
+     with gr.Tab("Basic TTS"):
+         with gr.Row():
+             with gr.Column():
+                 text_input = gr.Textbox(
+                     label="Text to convert to speech",
+                     placeholder="Enter text here...",
+                     lines=5
+                 )
+                 language = gr.Dropdown(
+                     label="Language",
+                     choices=LANGUAGES,
+                     value="english"
+                 )
+                 voice = gr.Dropdown(
+                     label="Voice",
+                     choices=VOICES,
+                     value="idera"
+                 )
+                 temperature = gr.Slider(
+                     label="Temperature",
+                     minimum=0.1,
+                     maximum=1.0,
+                     value=0.1,
+                     step=0.1
+                 )
+                 rep_penalty = gr.Slider(
+                     label="Repetition Penalty",
+                     minimum=1.0,
+                     maximum=2.0,
+                     value=1.1,
+                     step=0.1
+                 )
+                 generate_btn = gr.Button("Generate Speech")
+
+             with gr.Column():
+                 audio_output = gr.Audio(label="Generated Speech")
+                 status_output = gr.Textbox(label="Status")
+
+         generate_btn.click(
+             generate_speech,
+             inputs=[text_input, language, voice, temperature, rep_penalty],
+             outputs=[audio_output, status_output]
+         )
+
+     gr.Markdown("""
+     ## About YarnGPT
+     YarnGPT is a text-to-speech model with Nigerian accents. It supports multiple languages and voices.
+
+     ### Credits
+     - Model by [saheedniyi](https://huggingface.co/saheedniyi/YarnGPT2b)
+     - [Original Repository](https://github.com/saheedniyi02/yarngpt)
+     """)

  # Launch the app
+ demo.launch()
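
For reference, the updated generate_speech returns an (audio path, status message) pair, which is what the two outputs wired to generate_btn.click consume. A minimal sketch of calling it directly, assuming it runs in the same Python session as the new app.py after the checkpoints have downloaded and the model has initialized; the sample text and settings below are illustrative, not part of the commit:

# Minimal sketch: exercise generate_speech outside the Gradio UI.
# Assumes the setup code above has already run, so model, audio_tokenizer,
# and generate_speech are in scope.
audio_path, status = generate_speech(
    text="Hello, how are you today?",
    language="english",
    voice="idera",
    temperature=0.1,
    rep_penalty=1.1,
)
print(status)      # "Successfully generated speech for: ..." or an error message
print(audio_path)  # "output.wav" on success, None on failure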