okewunmi commited on
Commit
29bfa47
·
verified ·
1 Parent(s): 531b21a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -10
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import re
3
  import json
4
  import torch
@@ -8,24 +9,72 @@ import uroman as ur
8
  import numpy as np
9
  import torchaudio
10
  import gradio as gr
 
11
  from transformers import AutoModelForCausalLM, AutoTokenizer
12
  from outetts.wav_tokenizer.decoder import WavTokenizer
13
- from yarngpt.audiotokenizer import AudioTokenizerV2
14
 
15
- # Initialize paths and models
16
- tokenizer_path = "saheedniyi/YarnGPT2"
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  wav_tokenizer_config_path = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
18
  wav_tokenizer_model_path = "wavtokenizer_large_speech_320_24k.ckpt"
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Initialize the audio tokenizer
21
- audio_tokenizer = AudioTokenizerV2(
22
- tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
23
- )
 
 
 
 
 
 
24
 
25
  # Load the model
26
- model = AutoModelForCausalLM.from_pretrained(
27
- tokenizer_path, torch_dtype="auto"
28
- ).to(audio_tokenizer.device)
 
 
 
 
 
 
29
 
30
  # Function to generate speech
31
  def generate_speech(text, language, speaker_name, temperature=0.1, repetition_penalty=1.1):
@@ -56,6 +105,7 @@ def generate_speech(text, language, speaker_name, temperature=0.1, repetition_pe
56
  # Create Gradio interface
57
  def tts_interface(text, language, speaker_name, temperature, repetition_penalty):
58
  try:
 
59
  audio_path = generate_speech(
60
  text,
61
  language,
@@ -63,8 +113,10 @@ def tts_interface(text, language, speaker_name, temperature, repetition_penalty)
63
  temperature,
64
  repetition_penalty
65
  )
 
66
  return audio_path
67
  except Exception as e:
 
68
  return f"Error: {str(e)}"
69
 
70
  # Define available languages and speakers
@@ -75,7 +127,7 @@ speakers = ["idera", "enitan", "abeo", "eniola", "kachi", "aisha", "amara", "bel
75
  demo = gr.Interface(
76
  fn=tts_interface,
77
  inputs=[
78
- gr.Textbox(label="Text to convert to speech", lines=5),
79
  gr.Dropdown(languages, label="Language", value="english"),
80
  gr.Dropdown(speakers, label="Speaker", value="idera"),
81
  gr.Slider(0.1, 1.0, value=0.1, label="Temperature"),
@@ -88,4 +140,5 @@ demo = gr.Interface(
88
 
89
  # Launch the app
90
  if __name__ == "__main__":
 
91
  demo.launch()
 
1
  import os
2
+ import sys
3
  import re
4
  import json
5
  import torch
 
9
  import numpy as np
10
  import torchaudio
11
  import gradio as gr
12
+ import subprocess
13
  from transformers import AutoModelForCausalLM, AutoTokenizer
14
  from outetts.wav_tokenizer.decoder import WavTokenizer
 
15
 
16
+ # Check if yarngpt is installed, if not install it manually
17
+ try:
18
+ from yarngpt.audiotokenizer import AudioTokenizerV2
19
+ except ImportError:
20
+ print("YarnGPT not found, attempting to install...")
21
+ subprocess.run(["chmod", "+x", "install.sh"], check=True)
22
+ subprocess.run(["./install.sh"], check=True)
23
+
24
+ # Add the yarngpt directory to the Python path
25
+ sys.path.append(os.path.join(os.getcwd(), "yarngpt"))
26
+
27
+ # Try importing again
28
+ from yarngpt.audiotokenizer import AudioTokenizerV2
29
+
30
+ # Check if model files exist
31
  wav_tokenizer_config_path = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
32
  wav_tokenizer_model_path = "wavtokenizer_large_speech_320_24k.ckpt"
33
 
34
+ if not os.path.exists(wav_tokenizer_config_path) or not os.path.exists(wav_tokenizer_model_path):
35
+ print("Model files not found, downloading...")
36
+ if not os.path.exists(wav_tokenizer_config_path):
37
+ subprocess.run([
38
+ "wget",
39
+ "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
40
+ ], check=True)
41
+
42
+ if not os.path.exists(wav_tokenizer_model_path):
43
+ subprocess.run([
44
+ "wget",
45
+ "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"
46
+ ], check=True)
47
+
48
+ # Initialize paths and models
49
+ tokenizer_path = "saheedniyi/YarnGPT2"
50
+
51
+ # Add debug info
52
+ print(f"Current directory: {os.getcwd()}")
53
+ print(f"Files in directory: {os.listdir('.')}")
54
+ print(f"Config exists: {os.path.exists(wav_tokenizer_config_path)}")
55
+ print(f"Model exists: {os.path.exists(wav_tokenizer_model_path)}")
56
+
57
  # Initialize the audio tokenizer
58
+ try:
59
+ print("Initializing audio tokenizer...")
60
+ audio_tokenizer = AudioTokenizerV2(
61
+ tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
62
+ )
63
+ print("Audio tokenizer initialized")
64
+ except Exception as e:
65
+ print(f"Error initializing audio tokenizer: {str(e)}")
66
+ raise
67
 
68
  # Load the model
69
+ try:
70
+ print("Loading model...")
71
+ model = AutoModelForCausalLM.from_pretrained(
72
+ tokenizer_path, torch_dtype="auto"
73
+ ).to(audio_tokenizer.device)
74
+ print("Model loaded")
75
+ except Exception as e:
76
+ print(f"Error loading model: {str(e)}")
77
+ raise
78
 
79
  # Function to generate speech
80
  def generate_speech(text, language, speaker_name, temperature=0.1, repetition_penalty=1.1):
 
105
  # Create Gradio interface
106
  def tts_interface(text, language, speaker_name, temperature, repetition_penalty):
107
  try:
108
+ print(f"Generating speech for: {text[:30]}...")
109
  audio_path = generate_speech(
110
  text,
111
  language,
 
113
  temperature,
114
  repetition_penalty
115
  )
116
+ print("Speech generated successfully")
117
  return audio_path
118
  except Exception as e:
119
+ print(f"Error in tts_interface: {str(e)}")
120
  return f"Error: {str(e)}"
121
 
122
  # Define available languages and speakers
 
127
  demo = gr.Interface(
128
  fn=tts_interface,
129
  inputs=[
130
+ gr.Textbox(label="Text to convert to speech", lines=5, value="Welcome to YarnGPT text-to-speech model for African languages."),
131
  gr.Dropdown(languages, label="Language", value="english"),
132
  gr.Dropdown(speakers, label="Speaker", value="idera"),
133
  gr.Slider(0.1, 1.0, value=0.1, label="Temperature"),
 
140
 
141
  # Launch the app
142
  if __name__ == "__main__":
143
+ print("Starting Gradio interface...")
144
  demo.launch()