okewunmi committed · verified
Commit ac16e60 · 1 Parent(s): fd3c8cb

Update app.py

Files changed (1)
  1. app.py +56 -49
app.py CHANGED
@@ -13,76 +13,73 @@ import subprocess
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from outetts.wav_tokenizer.decoder import WavTokenizer
 
-# Check if yarngpt is installed, if not install it manually
-try:
-    from yarngpt.audiotokenizer import AudioTokenizerV2
-except ImportError:
-    print("YarnGPT not found, attempting to install...")
-    subprocess.run(["chmod", "+x", "install.sh"], check=True)
-    subprocess.run(["./install.sh"], check=True)
-
-    # Add the yarngpt directory to the Python path
-    sys.path.append(os.path.join(os.getcwd(), "yarngpt"))
-
-    # Try importing again
-    from yarngpt.audiotokenizer import AudioTokenizerV2
+# Clone the YarnGPT repository if it doesn't exist
+if not os.path.exists("yarngpt"):
+    print("Cloning YarnGPT repository...")
+    subprocess.run(["git", "clone", "https://github.com/saheedniyi02/yarngpt.git"], check=True)
+
+# Add the yarngpt directory to the Python path
+yarngpt_path = os.path.abspath("yarngpt")
+if yarngpt_path not in sys.path:
+    sys.path.append(yarngpt_path)
+    print(f"Added {yarngpt_path} to Python path")
+
+# Now try importing from yarngpt
+from yarngpt.audiotokenizer import AudioTokenizerV2
 
-# Check if model files exist
+# Download model files if they don't exist
 wav_tokenizer_config_path = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
 wav_tokenizer_model_path = "wavtokenizer_large_speech_320_24k.ckpt"
 
-if not os.path.exists(wav_tokenizer_config_path) or not os.path.exists(wav_tokenizer_model_path):
-    print("Model files not found, downloading...")
-    if not os.path.exists(wav_tokenizer_config_path):
-        subprocess.run([
-            "wget",
-            "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
-        ], check=True)
+if not os.path.exists(wav_tokenizer_config_path):
+    print(f"Downloading {wav_tokenizer_config_path}...")
+    subprocess.run([
+        "wget",
+        "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
+    ], check=True)
 
-    if not os.path.exists(wav_tokenizer_model_path):
-        subprocess.run([
-            "wget",
-            "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"
-        ], check=True)
+if not os.path.exists(wav_tokenizer_model_path):
+    print(f"Downloading {wav_tokenizer_model_path}...")
+    subprocess.run([
+        "wget",
+        "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/resolve/main/wavtokenizer_large_speech_320_24k.ckpt"
+    ], check=True)
 
 # Initialize paths and models
 tokenizer_path = "saheedniyi/YarnGPT2"
 
-# Add debug info
+# Print debug info
 print(f"Current directory: {os.getcwd()}")
 print(f"Files in directory: {os.listdir('.')}")
 print(f"Config exists: {os.path.exists(wav_tokenizer_config_path)}")
 print(f"Model exists: {os.path.exists(wav_tokenizer_model_path)}")
 
 # Initialize the audio tokenizer
-try:
-    print("Initializing audio tokenizer...")
-    audio_tokenizer = AudioTokenizerV2(
-        tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
-    )
-    print("Audio tokenizer initialized")
-except Exception as e:
-    print(f"Error initializing audio tokenizer: {str(e)}")
-    raise
+print("Initializing audio tokenizer...")
+audio_tokenizer = AudioTokenizerV2(
+    tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
+)
+print("Audio tokenizer initialized")
 
 # Load the model
-try:
-    print("Loading model...")
-    model = AutoModelForCausalLM.from_pretrained(
-        tokenizer_path, torch_dtype="auto"
-    ).to(audio_tokenizer.device)
-    print("Model loaded")
-except Exception as e:
-    print(f"Error loading model: {str(e)}")
-    raise
+print("Loading model...")
+model = AutoModelForCausalLM.from_pretrained(
+    tokenizer_path, torch_dtype="auto"
+).to(audio_tokenizer.device)
+print("Model loaded successfully")
 
 # Function to generate speech
 def generate_speech(text, language, speaker_name, temperature=0.1, repetition_penalty=1.1):
+    print(f"Generating speech for: '{text[:50]}...'")
+    print(f"Parameters: language={language}, speaker={speaker_name}, temp={temperature}, rep_penalty={repetition_penalty}")
+
     # Create prompt
     prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=speaker_name)
+    print("Prompt created")
 
     # Tokenize prompt
     input_ids = audio_tokenizer.tokenize_prompt(prompt)
+    print("Prompt tokenized")
 
     # Generate output
     output = model.generate(
@@ -91,21 +88,25 @@ def generate_speech(text, language, speaker_name, temperature=0.1, repetition_pe
         repetition_penalty=repetition_penalty,
         max_length=4000,
     )
+    print("Model generation complete")
 
     # Get audio codes and convert to audio
     codes = audio_tokenizer.get_codes(output)
+    print("Audio codes extracted")
+
     audio = audio_tokenizer.get_audio(codes)
+    print("Audio generated")
 
     # Save audio to file
     output_path = "output.wav"
     torchaudio.save(output_path, audio, sample_rate=24000)
+    print(f"Audio saved to {output_path}")
 
     return output_path
 
 # Create Gradio interface
 def tts_interface(text, language, speaker_name, temperature, repetition_penalty):
     try:
-        print(f"Generating speech for: {text[:30]}...")
         audio_path = generate_speech(
             text,
             language,
@@ -113,10 +114,11 @@ def tts_interface(text, language, speaker_name, temperature, repetition_penalty)
             temperature,
             repetition_penalty
         )
-        print("Speech generated successfully")
         return audio_path
     except Exception as e:
-        print(f"Error in tts_interface: {str(e)}")
+        import traceback
+        error_details = traceback.format_exc()
+        print(f"Error in tts_interface: {str(e)}\n{error_details}")
         return f"Error: {str(e)}"
 
 # Define available languages and speakers
@@ -135,7 +137,12 @@ demo = gr.Interface(
     ],
     outputs=gr.Audio(type="filepath"),
     title="YarnGPT Text-to-Speech",
-    description="Convert text to speech using YarnGPT model for various African languages",
+    description="Convert text to speech using YarnGPT model for various African languages.",
+    examples=[
+        ["The election was won by businessman and politician, Moshood Abiola, but Babangida annulled the results, citing concerns over national security.", "english", "idera", 0.1, 1.1],
+        ["Hello, how are you today?", "english", "enitan", 0.1, 1.1],
+        ["Bawo ni?", "yoruba", "eniola", 0.2, 1.2],
+    ]
 )
 
 # Launch the app