bel32123's picture
Add utterance playback
3f0393b
raw
history blame
8.43 kB
import streamlit as st
from speechbrain.pretrained import GraphemeToPhoneme
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import MultitaskPhonemeASRModel
import json
import os
import random
import openai
from gtts import gTTS
from io import BytesIO
# Module-level API key read by openai.ChatCompletion in get_prompt_from_openai.
# It is None when the OPENAI_KEY environment variable is not set.
openai.api_key = os.environ.get("OPENAI_KEY")
# https://gtts.readthedocs.io/en/latest/
def tts_gtts(text):
    """Synthesise *text* as English speech with Google Text-to-Speech (gTTS).

    Args:
        text: The sentence to speak.

    Returns:
        An in-memory ``BytesIO`` holding the audio bytes written by
        ``gTTS.write_to_fp`` (MP3 according to the gTTS docs). It is rewound
        to the start so consumers can read it directly.
    """
    mp3_fp = BytesIO()
    tts = gTTS(text, lang="en")
    tts.write_to_fp(mp3_fp)
    # write_to_fp leaves the cursor at the end of the buffer. Rewind it so a
    # plain .read() by the consumer returns the audio instead of b"".
    mp3_fp.seek(0)
    return mp3_fp
def pronounce(text):
    """Return synthesised speech for *text*, or ``[]`` when *text* is empty.

    Non-empty text is handed to ``tts_gtts``. The empty case returns an empty
    list rather than ``None``, which is what callers have always received.
    """
    if len(text) == 0:
        return []
    return tts_gtts(text)
@st.cache_resource
def load_model():
    """Build the mispronunciation detector.

    ``st.cache_resource`` caches the returned object, so the models are loaded
    once and shared across reruns instead of being reloaded on every click.
    The detector combines:

    * the multitask phoneme ASR checkpoint ``multitask_best_ctc.pt`` and its
      ``vocab``, both under ``<cwd>/wav2vecasr/model``;
    * the SpeechBrain ``soundchoice-g2p`` grapheme-to-phoneme model.

    All models are placed on the CPU.
    """
    model_dir = os.path.join(os.getcwd(), "wav2vecasr", "model")
    device = "cpu"
    asr_model = MultitaskPhonemeASRModel(
        os.path.join(model_dir, "multitask_best_ctc.pt"),
        os.path.join(model_dir, "vocab"),
        device,
    )
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    return MispronounciationDetector(asr_model, g2p, device)
def save_file(sound_file):
    """Persist an uploaded recording under ``<cwd>/audio_files``.

    Args:
        sound_file: An uploaded-file object exposing ``.name`` and
            ``.getbuffer()`` (e.g. Streamlit's ``UploadedFile``).

    Returns:
        The file name the recording was stored under. Only the final path
        component of ``sound_file.name`` is used, so this differs from the
        client-supplied name when that name contains directory parts.

    Raises:
        ValueError: If the supplied name has no usable final component.
    """
    audio_folder_path = os.path.join(os.getcwd(), 'audio_files')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(audio_folder_path, exist_ok=True)
    # The upload name is chosen by the client. Strip any directory components
    # so a name like "../../app.py" cannot write outside audio_files.
    file_name = os.path.basename(sound_file.name)
    if file_name in ("", ".", ".."):
        raise ValueError(f"invalid upload file name: {sound_file.name!r}")
    with open(os.path.join(audio_folder_path, file_name), 'wb') as f:
        f.write(sound_file.getbuffer())
    return file_name
@st.cache_data
def get_audio(saved_sound_filename, target_sr=16000):
    """Load a saved recording as a 1-D waveform resampled to *target_sr* Hz.

    Args:
        saved_sound_filename: Name of a file inside ``audio_files/`` (relative
            to the current working directory), as stored by ``save_file``.
        target_sr: Output sample rate in Hz. Defaults to the 16 kHz used
            previously.

    Returns:
        A 1-D tensor of samples. Multi-channel recordings are down-mixed to
        mono by averaging their channels.

    NOTE(review): ``st.cache_data`` keys only on the arguments, so
    re-uploading a different recording under the same file name returns the
    stale cached waveform. Consider passing a content hash or size as an
    extra argument.
    """
    audio_path = os.path.join('audio_files', saved_sound_filename)
    audio, org_sr = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames). The previous
    # view(shape[1]) only worked for a single channel. Averaging over the
    # channel axis handles stereo too and leaves mono samples unchanged.
    audio = audio.mean(dim=0)
    audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=target_sr)
    return audio
@st.cache_data
def get_prompts():
    """Return the practice prompts stored in ``wav2vecasr/data/prompts.json``.

    The file (under the current working directory) must be a JSON object with
    a ``"prompts"`` key; that value is returned as-is. Results are cached by
    ``st.cache_data``, so the file is only re-read on a cache miss.
    """
    prompts_path = os.path.join(os.getcwd(), "wav2vecasr", "data", "prompts.json")
    # Close the handle deterministically; the previous bare open() leaked it.
    # JSON is UTF-8, so don't depend on the platform's locale encoding.
    with open(prompts_path, encoding="utf-8") as f:
        data = json.load(f)
    return data["prompts"]
@st.cache_data
def get_articulation_videos():
    """Return the phoneme-to-video mapping from ``wav2vecasr/data/videos.json``.

    The parsed JSON is returned unchanged. ``video_section`` looks entries up
    by phoneme symbol and reads each entry's ``"link"``.
    """
    # note -- not all arpabets could be mapped to a video with visualisation on articulation
    path = os.path.join(os.getcwd(), "wav2vecasr", "data", "videos.json")
    # Close the handle deterministically and decode as UTF-8, as JSON requires.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def get_prompts_from_l2_arctic(prompts, current_prompt, num_to_get):
    """Pick up to *num_to_get* distinct random prompts, excluding *current_prompt*.

    Args:
        prompts: Candidate prompts (a sequence; may contain duplicates).
        current_prompt: The prompt the learner just read. It is never returned.
        num_to_get: Maximum number of prompts to return.

    Returns:
        A list of distinct prompts in random order. It is shorter than
        *num_to_get* (possibly empty) when there are not enough distinct
        candidates. The previous rejection-sampling loop never terminated in
        that case, and raised ``IndexError`` when *prompts* was empty.
    """
    selected_prompts = []
    # Walk a shuffled copy once. Membership is only checked against the small
    # selected list, so prompts need not be hashable.
    for prompt in random.sample(prompts, len(prompts)):
        if len(selected_prompts) >= num_to_get:
            break
        if prompt != current_prompt and prompt not in selected_prompts:
            selected_prompts.append(prompt)
    return selected_prompts
def get_prompt_from_openai(words_with_error_list):
    """Ask ChatGPT (gpt-3.5-turbo) for a short practice sentence using given words.

    Args:
        words_with_error_list: Words the learner mispronounced.

    Returns:
        The generated sentence, or ``""`` when the list is empty or the OpenAI
        request fails for any reason (missing key, network error, unexpected
        response shape). Callers treat ``""`` as "no suggestion".
    """
    import logging

    if not words_with_error_list:
        # Nothing was mispronounced, so there is nothing to build a targeted
        # prompt around; skip the paid API call.
        return ""
    try:
        words_with_errors = ", ".join(words_with_error_list)
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are writing practice reading prompts for learners of English to practice pronunciation. These prompts should be short, easy to understand and useful."},
                # The period terminates the word list so the model doesn't read
                # "No numbers" as one of the words to include.
                {"role": "user", "content": f"Write a short sentence of less than 10 words and include the following words in the sentence: {words_with_errors}. No numbers."},
            ],
        )
        return response['choices'][0]['message']['content']
    except Exception:
        # Best effort: a failed suggestion must not break the results page,
        # but record why instead of silently swallowing the error.
        logging.getLogger(__name__).exception("OpenAI prompt generation failed")
        return ""
def _render_phoneme_analysis(raw_info):
    """Show the phoneme error rate and the ref/hyp/error rows from *raw_info*."""
    st.write('#### Phoneme Level Analysis')
    st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")
    # NOTE(review): the textarea rule likely does not affect the fenced code
    # block below (fences usually render as <pre><code>); kept as before.
    st.markdown(
        "\n".join([
            "<style>",
            "textarea {",
            "white-space: nowrap;",
            "}",
            "</style>",
            "```",
            f"{raw_info['ref']}",
            f"{raw_info['hyp']}",
            f"{raw_info['phoneme_errors']}",
            "```",
        ]),
        unsafe_allow_html=True,
    )


def _render_word_analysis(raw_info):
    """Show the word error rate with mispronounced words in bold.

    Returns:
        The words flagged in ``raw_info["word_errors"]``, in sentence order
        (duplicates kept).
    """
    md = []
    words_with_errors = []
    for word, has_error in zip(raw_info["words"], raw_info["word_errors"]):
        if has_error:
            words_with_errors.append(word)
            md.append(f"**{word}**")
        else:
            md.append(word)
    st.write('#### Word Level Analysis')
    st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
    st.markdown(" ".join(md))
    return words_with_errors


def _render_next_steps(text, saved_name, words_with_errors):
    """Play back user vs. synthesised audio and list three follow-up prompts."""
    st.write('#### What is next?')
    # display pronounciation e.g.
    st.write("Compare your pronunciation to pronounced sample")
    st.audio(f'audio_files/{saved_name}', format="audio/wav", start_time=0)
    # gTTS output is MP3, so label it as such rather than audio/wav.
    st.audio(pronounce(text), format="audio/mpeg", start_time=0)
    # 1 prompt from ChatGPT based on the user's mistakes, the rest from L2 Arctic.
    st.write('Here are some more prompts for you to practice:')
    selected_prompts = []
    # Order-preserving de-duplication keeps the OpenAI request deterministic
    # (list(set(...)) shuffled the word order between runs).
    unique_words_with_errors = list(dict.fromkeys(words_with_errors))
    prompt_for_mistakes_made = get_prompt_from_openai(unique_words_with_errors)
    if prompt_for_mistakes_made:
        selected_prompts.append(prompt_for_mistakes_made)
    prompts = get_prompts()
    selected_prompts.extend(
        get_prompts_from_l2_arctic(prompts, text, 3 - len(selected_prompts))
    )
    for prompt in selected_prompts:
        # Prompts are English sentences; language=None disables the Python
        # syntax highlighting that coloured words like "for"/"in"/"is".
        st.code(f'{prompt}', language=None)


def mispronounciation_detection_section():
    """Render the prediction page.

    The learner uploads a .wav recording and types the text they read. When
    "Predict" is pressed, the recording is saved, scored by the
    mispronunciation detector, and phoneme-level and word-level feedback is
    shown, followed by audio comparison and further practice prompts.
    """
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇",
        label_visibility='collapsed'
    )
    if not st.button('Predict'):
        return
    # Whitespace-only text is treated as missing input.
    if uploaded_file is None or not (text and text.strip()):
        st.error('The audio or text has not been properly input', icon="🚨")
        return

    # Use the name save_file actually stored the upload under.
    saved_name = save_file(uploaded_file)
    audio = get_audio(saved_name)
    mispronunciation_detector = load_model()

    st.write('# Detection Results')
    with st.spinner('Predicting...'):
        raw_info = mispronunciation_detector.detect(audio, text, phoneme_error_threshold=0.25)

    _render_phoneme_analysis(raw_info)
    st.divider()
    words_with_errors = _render_word_analysis(raw_info)
    st.divider()
    _render_next_steps(text, saved_name, words_with_errors)
def video_section():
    """Render the page that looks up an articulation video for a phoneme.

    The learner types a phoneme symbol. If ``videos.json`` has an entry for it
    with a non-empty ``"link"``, that video is embedded; otherwise the learner
    is told why nothing is shown.
    """
    st.write('# Get helpful videos on phoneme articulation')
    problem_phoneme = st.text_input(
        "Enter the phoneme you had problems with 👇"
    )
    arpabet_to_video_map = get_articulation_videos()
    if not st.button('Look up'):
        return

    phoneme = (problem_phoneme or "").strip()
    if not phoneme:
        # This page has no audio input, so don't reuse the upload page's message.
        st.error('Please enter a phoneme to look up', icon="🚨")
        return

    # Try the text as typed first, then upper-cased.
    # NOTE(review): assumes videos.json keys are upper-case ARPAbet symbols;
    # confirm against the data file.
    entry = arpabet_to_video_map.get(phoneme)
    if entry is None:
        entry = arpabet_to_video_map.get(phoneme.upper())
    if entry is None:
        # Previously an unknown phoneme silently rendered nothing.
        st.warning(f"No articulation entry found for '{phoneme}'. Please check the phoneme symbol.")
        return

    # Some phonemes are listed without a usable video (empty or missing "link").
    video_link = entry.get("link")
    if video_link:
        st.video(video_link)
    else:
        st.write("Sorry, we couldn't find a good enough video yet :( we are working on it!")
if __name__ == '__main__':
    st.write('___')
    # Sidebar navigation between the three pages of the app.
    st.sidebar.title('Pronunciation Evaluation')
    # Streamlit warns on empty widget labels (accessibility). Give it a real
    # label and keep it hidden with label_visibility='collapsed'.
    select = st.sidebar.selectbox(
        'Page',
        ['Main Page', 'Mispronunciation Detection', 'Helpful Videos for Problem Phonemes'],
        key='1',
        label_visibility='collapsed',
    )
    st.sidebar.write(select)
    if select == 'Mispronunciation Detection':
        mispronounciation_detection_section()
    elif select == "Helpful Videos for Problem Phonemes":
        video_section()
    else:
        st.write('# Pronunciation Evaluation')
        st.write('This app is designed to detect mispronunciation of English words by English learners whose first language is an Asian language such as Korean, Mandarin or Vietnamese.')
        st.write('Wav2Vec2.0 is used to detect the phonemes spoken by the learner, and this output is compared with the correct phoneme sequence generated from the input text.')