okewunmi commited on
Commit
b5d25fc
·
verified ·
1 Parent(s): 388787c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -65
app.py CHANGED
@@ -1,50 +1,36 @@
1
- import gradio as gr
2
- import torch
3
- import torchaudio
4
  import os
5
  import re
6
- import subprocess
7
- from transformers import AutoModelForCausalLM
8
- from yarngpt_utils import AudioTokenizer
 
 
 
 
 
 
 
 
9
 
10
- # Download model files if they don't exist
11
- def download_if_not_exists(url, filename):
12
- if not os.path.exists(filename):
13
- print(f"Downloading {filename}...")
14
- subprocess.run(["wget", url, "-O", filename])
15
- print(f"Downloaded {filename}")
16
 
17
- # Download necessary files
18
- download_if_not_exists(
19
- "https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
20
- "wavtokenizer_config.yaml"
21
- )
22
- download_if_not_exists(
23
- "https://huggingface.co/novateur/WavTokenizer-large-speech-75token/blob/main/wavtokenizer_large_speech_320_v2.ckpt",
24
- "wavtokenizer_model.ckpt"
25
  )
26
 
27
- # Initialize the model (this runs when the app starts)
28
- def initialize_model():
29
- # Set paths
30
- hf_path = "saheedniyi/YarnGPT"
31
- wav_tokenizer_config_path = "wavtokenizer_config.yaml"
32
- wav_tokenizer_model_path = "wavtokenizer_model.ckpt"
33
-
34
- # Create AudioTokenizer
35
- audio_tokenizer = AudioTokenizer(
36
- hf_path, wav_tokenizer_model_path, wav_tokenizer_config_path
37
- )
38
-
39
- # Load model
40
- model = AutoModelForCausalLM.from_pretrained(hf_path, torch_dtype="auto").to(audio_tokenizer.device)
41
-
42
- return model, audio_tokenizer
43
 
44
- # Generate audio from text
45
- def generate_speech(text, speaker_name):
46
  # Create prompt
47
- prompt = audio_tokenizer.create_prompt(text, speaker_name)
48
 
49
  # Tokenize prompt
50
  input_ids = audio_tokenizer.tokenize_prompt(prompt)
@@ -52,45 +38,54 @@ def generate_speech(text, speaker_name):
52
  # Generate output
53
  output = model.generate(
54
  input_ids=input_ids,
55
- temperature=0.1,
56
- repetition_penalty=1.1,
57
  max_length=4000,
58
  )
59
 
60
- # Convert to audio codes
61
  codes = audio_tokenizer.get_codes(output)
62
-
63
- # Convert codes to audio
64
  audio = audio_tokenizer.get_audio(codes)
65
 
66
- # Save audio temporarily
67
- temp_path = "output.wav"
68
- torchaudio.save(temp_path, audio, sample_rate=24000)
69
 
70
- return temp_path
71
-
72
- # Load model globally
73
- print("Loading model...")
74
- model, audio_tokenizer = initialize_model()
75
- print("Model loaded!")
76
-
77
- # Add this before initializing the model
78
- from inspect import signature
79
- from outetts.wav_tokenizer.decoder import WavTokenizer
80
- print("WavTokenizer parameters:", signature(WavTokenizer.__init__))
81
 
82
  # Create Gradio interface
83
- speakers = ["idera", "emma", "jude", "osagie", "tayo", "zainab", "joke", "regina", "remi", "umar", "chinenye"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
 
85
  demo = gr.Interface(
86
- fn=generate_speech,
87
  inputs=[
88
- gr.Textbox(lines=5, placeholder="Enter text here..."),
89
- gr.Dropdown(choices=speakers, label="Speaker", value="idera")
 
 
 
90
  ],
91
  outputs=gr.Audio(type="filepath"),
92
- title="YarnGPT: Nigerian Accented Text-to-Speech",
93
- description="Generate natural-sounding Nigerian accented speech from text."
94
  )
95
 
96
- demo.launch()
 
 
 
 
 
 
1
  import os
2
  import re
3
+ import json
4
+ import torch
5
+ import inflect
6
+ import random
7
+ import uroman as ur
8
+ import numpy as np
9
+ import torchaudio
10
+ import gradio as gr
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from outetts.wav_tokenizer.decoder import WavTokenizer
13
+ from yarngpt.audiotokenizer import AudioTokenizerV2
14
 
15
+ # Initialize paths and models
16
+ tokenizer_path = "saheedniyi/YarnGPT2"
17
+ wav_tokenizer_config_path = "wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
18
+ wav_tokenizer_model_path = "wavtokenizer_large_speech_320_24k.ckpt"
 
 
19
 
20
+ # Initialize the audio tokenizer
21
+ audio_tokenizer = AudioTokenizerV2(
22
+ tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
 
 
 
 
 
23
  )
24
 
25
+ # Load the model
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ tokenizer_path, torch_dtype="auto"
28
+ ).to(audio_tokenizer.device)
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ # Function to generate speech
31
+ def generate_speech(text, language, speaker_name, temperature=0.1, repetition_penalty=1.1):
32
  # Create prompt
33
+ prompt = audio_tokenizer.create_prompt(text, lang=language, speaker_name=speaker_name)
34
 
35
  # Tokenize prompt
36
  input_ids = audio_tokenizer.tokenize_prompt(prompt)
 
38
  # Generate output
39
  output = model.generate(
40
  input_ids=input_ids,
41
+ temperature=temperature,
42
+ repetition_penalty=repetition_penalty,
43
  max_length=4000,
44
  )
45
 
46
+ # Get audio codes and convert to audio
47
  codes = audio_tokenizer.get_codes(output)
 
 
48
  audio = audio_tokenizer.get_audio(codes)
49
 
50
+ # Save audio to file
51
+ output_path = "output.wav"
52
+ torchaudio.save(output_path, audio, sample_rate=24000)
53
 
54
+ return output_path
 
 
 
 
 
 
 
 
 
 
55
 
56
  # Create Gradio interface
57
+ def tts_interface(text, language, speaker_name, temperature, repetition_penalty):
58
+ try:
59
+ audio_path = generate_speech(
60
+ text,
61
+ language,
62
+ speaker_name,
63
+ temperature,
64
+ repetition_penalty
65
+ )
66
+ return audio_path
67
+ except Exception as e:
68
+ return f"Error: {str(e)}"
69
+
70
+ # Define available languages and speakers
71
+ languages = ["english", "igbo", "yoruba", "hausa", "pidgin"]
72
+ speakers = ["idera", "enitan", "abeo", "eniola", "kachi", "aisha", "amara", "bello", "chidi"]
73
 
74
+ # Create the Gradio interface
75
  demo = gr.Interface(
76
+ fn=tts_interface,
77
  inputs=[
78
+ gr.Textbox(label="Text to convert to speech", lines=5),
79
+ gr.Dropdown(languages, label="Language", value="english"),
80
+ gr.Dropdown(speakers, label="Speaker", value="idera"),
81
+ gr.Slider(0.1, 1.0, value=0.1, label="Temperature"),
82
+ gr.Slider(1.0, 2.0, value=1.1, label="Repetition Penalty"),
83
  ],
84
  outputs=gr.Audio(type="filepath"),
85
+ title="YarnGPT Text-to-Speech",
86
+ description="Convert text to speech using YarnGPT model for various African languages",
87
  )
88
 
89
+ # Launch the app
90
+ if __name__ == "__main__":
91
+ demo.launch()