import streamlit as st
from speechbrain.pretrained import GraphemeToPhoneme
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import MultitaskPhonemeASRModel
import json
import random
import openai
from gtts import gTTS
from io import BytesIO

openai.api_key = os.getenv("OPENAI_KEY")
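# The OpenAI key is read from the OPENAI_KEY environment variable; if it is not set,
# get_prompt_from_openai() below fails soft and only L2-ARCTIC prompts are shown.
# Assumed setup before launching, e.g.: export OPENAI_KEY=sk-...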
# gTTS docs: https://gtts.readthedocs.io/en/latest/
def tts_gtts(text):
    # synthesise `text` to speech with Google TTS and return it as an in-memory file
    mp3_fp = BytesIO()
    tts = gTTS(text, lang="en")
    tts.write_to_fp(mp3_fp)
    mp3_fp.seek(0)  # rewind so downstream readers start at the beginning
    return mp3_fp
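# A minimal usage sketch for the helper above (gTTS only produces MP3, so the
# returned BytesIO holds MP3 bytes): st.audio(tts_gtts("hello"), format="audio/mp3")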
def pronounce(text):
    # return synthesised audio for non-empty text, otherwise nothing to play
    if len(text) > 0:
        return tts_gtts(text)
    return None
def load_model():
    path = os.path.join(os.getcwd(), "wav2vecasr", "model", "multitask_best_ctc.pt")
    vocab_path = os.path.join(os.getcwd(), "wav2vecasr", "model", "vocab")
    device = "cpu"
    asr_model = MultitaskPhonemeASRModel(path, vocab_path, device)
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    mispronounciation_detector = MispronounciationDetector(asr_model, g2p, device)
    return mispronounciation_detector
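# Assumed on-disk layout for the checkpoint and vocab loaded above:
#   wav2vecasr/model/multitask_best_ctc.pt
#   wav2vecasr/model/vocab
# The speechbrain/soundchoice-g2p weights are downloaded from the Hugging Face Hub
# on first use, so the first call to load_model() needs network access.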
def save_file(sound_file):
    # save the uploaded sound file into the audio_files folder, creating it if needed
    audio_folder_path = os.path.join(os.getcwd(), 'audio_files')
    if not os.path.exists(audio_folder_path):
        os.makedirs(audio_folder_path)
    with open(os.path.join(audio_folder_path, sound_file.name), 'wb') as f:
        f.write(sound_file.getbuffer())
    return sound_file.name
def get_audio(saved_sound_filename):
    audio_path = os.path.join('audio_files', saved_sound_filename)
    audio, org_sr = torchaudio.load(audio_path)
    # wav2vec 2.0 expects 16 kHz input, so resample whatever rate the upload used
    audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
    audio = audio.view(audio.shape[1])  # flatten (1, num_samples) to 1-D; assumes mono audio
    return audio
def get_prompts():
    prompts_path = os.path.join(os.getcwd(), "wav2vecasr", "data", "prompts.json")
    with open(prompts_path) as f:
        data = json.load(f)
    return data["prompts"]
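# Assumed shape of wav2vecasr/data/prompts.json, inferred from the indexing above
# and the random.choice() sampling below:
#   {"prompts": ["<practice sentence 1>", "<practice sentence 2>", ...]}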
def get_articulation_videos():
    # note -- not all ARPAbet phonemes could be mapped to a video visualising articulation
    path = os.path.join(os.getcwd(), "wav2vecasr", "data", "videos.json")
    with open(path) as f:
        data = json.load(f)
    return data
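# Assumed shape of wav2vecasr/data/videos.json, inferred from the lookup in
# video_section() (ARPAbet symbol -> record with a possibly empty "link"):
#   {"AH": {"link": "https://..."}, "TH": {"link": ""}}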
def get_prompts_from_l2_arctic(prompts, current_prompt, num_to_get):
    # sample distinct practice prompts, skipping the one the user just read;
    # assumes `prompts` holds at least `num_to_get` other distinct entries
    selected_prompts = []
    while len(selected_prompts) < num_to_get:
        prompt = random.choice(prompts)
        if prompt not in selected_prompts and prompt != current_prompt:
            selected_prompts.append(prompt)
    return selected_prompts
def get_prompt_from_openai(words_with_error_list):
    try:
        words_with_errors = ", ".join(words_with_error_list)
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are writing practice reading prompts for learners of English to practice pronunciation. These prompts should be short, easy to understand and useful."},
                {"role": "user", "content": f"Write a short sentence of less than 10 words and include the following words in the sentence: {words_with_errors}. No numbers."}
            ]
        )
        return response['choices'][0]['message']['content']
    except Exception:
        # fail soft: an empty string tells the caller no GPT prompt is available
        return ""
def mispronounciation_detection_section():
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')

    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇",
        label_visibility='collapsed'
    )

    if st.button('Predict'):
        if uploaded_file is not None and len(text) > 0:
            # get audio from the uploaded file
            save_file(uploaded_file)
            audio = get_audio(uploaded_file.name)

            # load model
            mispronunciation_detector = load_model()

            st.write('# Detection Results')
            with st.spinner('Predicting...'):
                raw_info = mispronunciation_detector.detect(audio, text, phoneme_error_threshold=0.25)

                # display prediction results for phonemes
                st.write('#### Phoneme Level Analysis')
                st.write(f"Phoneme Error Rate: {round(raw_info['per'], 2)}")
                st.markdown(
                    f"""
                    <style>
                    textarea {{
                        white-space: nowrap;
                    }}
                    </style>

                    ```
                    {raw_info['ref']}
                    {raw_info['hyp']}
                    {raw_info['phoneme_errors']}
                    ```
                    """,
                    unsafe_allow_html=True,
                )
                st.divider()

                # display word errors
                md = []
                words_with_errors = []
                for word, has_error in zip(raw_info["words"], raw_info["word_errors"]):
                    if has_error:
                        words_with_errors.append(word)
                        md.append(f"**{word}**")
                    else:
                        md.append(word)
                st.write('#### Word Level Analysis')
                st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
                st.markdown(" ".join(md))
                st.divider()

                st.write('#### What is next?')
                # play back the learner's recording next to a synthesised reference
                st.write("Compare your pronunciation to the pronounced sample below")
                st.audio(f'audio_files/{uploaded_file.name}', format="audio/wav", start_time=0)
                pronounced_sample = pronounce(text)
                st.audio(pronounced_sample, format="audio/mp3", start_time=0)  # gTTS output is MP3

                # display more prompts to practice -- 1 from ChatGPT based on the user's mistakes, the rest from L2-ARCTIC
                st.write('Here are some more prompts for you to practice:')
                selected_prompts = []
                unique_words_with_errors = list(set(words_with_errors))
                prompt_for_mistakes_made = get_prompt_from_openai(unique_words_with_errors)
                if prompt_for_mistakes_made:
                    selected_prompts.append(prompt_for_mistakes_made)
                prompts = get_prompts()
                l2_arctic_prompts = get_prompts_from_l2_arctic(prompts, text, 3 - len(selected_prompts))
                selected_prompts.extend(l2_arctic_prompts)
                for prompt in selected_prompts:
                    st.code(prompt, language=None)  # plain text, not code
        else:
            st.error('Please upload a recording and enter its text before predicting', icon="🚨")
def video_section():
    st.write('# Get helpful videos on phoneme articulation')
    problem_phoneme = st.text_input(
        "Enter the phoneme you had problems with 👇"
    )
    arpabet_to_video_map = get_articulation_videos()
    if st.button('Look up'):
        if not problem_phoneme:
            st.error('Please enter a phoneme to look up', icon="🚨")
        elif problem_phoneme in arpabet_to_video_map:
            video_link = arpabet_to_video_map[problem_phoneme]["link"]
            if video_link:
                st.video(video_link)
            else:
                st.write("Sorry, we couldn't find a good enough video yet :( we are working on it!")
if __name__ == '__main__':
    st.write('___')

    # create a sidebar
    st.sidebar.title('Pronunciation Evaluation')
    select = st.sidebar.selectbox('', ['Main Page', 'Mispronunciation Detection', 'Helpful Videos for Problem Phonemes'], key='1', label_visibility='collapsed')
    st.sidebar.write(select)

    if select == 'Mispronunciation Detection':
        mispronounciation_detection_section()
    elif select == "Helpful Videos for Problem Phonemes":
        video_section()
    else:
        st.write('# Pronunciation Evaluation')
        st.write('This app detects mispronunciations of English words by English learners whose first language is Korean, Mandarin, or Vietnamese.')
        st.write('Wav2Vec2.0 detects the phonemes the learner produced, and this output is compared with the correct phoneme sequence generated from the input text.')
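# Run locally with (filename assumed): streamlit run app.py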