# Page header captured together with this file: "Spaces: Runtime error"
import streamlit as st | |
import torch | |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor | |
from speechbrain.pretrained import GraphemeToPhoneme | |
import os | |
import torchaudio | |
from wav2vecasr.MispronounciationDetector import MispronounciationDetector | |
@st.cache_resource
def load_model():
    """Build the mispronunciation detector from the local fine-tuned checkpoint.

    Loads the Wav2Vec2 CTC model and processor from
    ``<cwd>/wav2vecasr/model/checkpoint-1200`` and the SpeechBrain
    ``soundchoice-g2p`` grapheme-to-phoneme model, then wraps them in a
    ``MispronounciationDetector``.

    Decorated with ``st.cache_resource`` so the models are loaded once per
    server process instead of on every "Predict" click. The same detector
    object is therefore shared by all sessions.

    Returns:
        The ``MispronounciationDetector`` instance.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    path = os.path.join(os.getcwd(), "wav2vecasr", "model", "checkpoint-1200")
    model = Wav2Vec2ForCTC.from_pretrained(path).to(device)
    processor = Wav2Vec2Processor.from_pretrained(path)
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    # Pass the same device the model was moved to. A hard-coded "cpu" here
    # disagrees with the model's placement whenever CUDA is available.
    # NOTE(review): assumes the 4th argument is the inference device — confirm
    # against MispronounciationDetector.
    return MispronounciationDetector(model, processor, g2p, device)
def save_file(sound_file):
    """Write an uploaded file into ``<cwd>/audio_files`` and return its stored name.

    Args:
        sound_file: An uploaded-file object exposing ``.name`` and
            ``.getbuffer()`` (such as the object ``st.file_uploader`` returns).

    Returns:
        The file name, without any directory part, that the data was saved under.
    """
    target_dir = os.path.join(os.getcwd(), 'audio_files')
    # The folder may not exist on a fresh checkout or deployment.
    os.makedirs(target_dir, exist_ok=True)
    # The name is supplied by the client. Keep only its final component so it
    # cannot point outside audio_files (e.g. "../x.wav").
    filename = os.path.basename(sound_file.name)
    with open(os.path.join(target_dir, filename), 'wb') as f:
        f.write(sound_file.getbuffer())
    return filename
def get_audio(saved_sound_filename):
    """Load a saved recording as a single-channel waveform resampled to 16 kHz.

    Args:
        saved_sound_filename: Name of a file inside the ``audio_files`` folder
            under the current working directory.

    Returns:
        A 1-D tensor of samples at 16000 Hz.
    """
    audio_path = os.path.join(os.getcwd(), 'audio_files', saved_sound_filename)
    audio, org_sr = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames). Averaging over channels gives
    # a 1-D signal for stereo input too; for mono it is exactly the original
    # samples. A plain .view(frames) only works when there is one channel.
    audio = audio.mean(dim=0)
    audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
    return audio
def _render_phoneme_analysis(raw_info):
    """Show the phoneme error rate and the aligned ref/hyp/error phoneme rows."""
    st.write('#### Phoneme Level Analysis')
    st.markdown(f"Phoneme Error Rate: ___{round(raw_info['per'], 2)}___")
    ref_line = " ".join(raw_info['ref'])
    hyp_line = " ".join(raw_info['hyp'])
    err_line = " ".join(raw_info['phoneme_errors'])
    # A fenced code block renders in a monospace font, which keeps the three
    # rows column-aligned, and scrolls horizontally instead of wrapping.
    st.markdown(f"```\n{ref_line}\n{hyp_line}\n{err_line}\n```")


def _render_word_analysis(raw_info):
    """Show the word error rate and the words, bolding those flagged as errors."""
    md = [
        f"**{word}**" if has_error else word
        for word, has_error in zip(raw_info["words"], raw_info["word_errors"])
    ]
    st.write('#### Word Level Analysis')
    st.write(f"Word Error Rate: ___{round(raw_info['wer'], 2)}___ and the following words in bold have errors:")
    st.markdown(" ".join(md))


def mispronounciation_detection_section():
    """Render the upload/text form and, on "Predict", run and display detection.

    The uploaded .wav is saved with ``save_file`` and loaded with ``get_audio``.
    The detector from ``load_model`` is then run on the audio and the typed
    text, and phoneme- and word-level results are shown. An error message is
    shown instead if the file or the text is missing.
    """
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    # NOTE(review): the trailing "π" in this label looks like a mis-decoded
    # emoji. The label is collapsed, so it is not shown to the user.
    text = st.text_input(
        "Enter the text you want to read π",
        label_visibility='collapsed',
    )
    if not st.button('Predict'):
        return
    # Whitespace-only text would pass a plain length check but gives the
    # detector nothing to compare the recording against.
    if uploaded_file is None or not text.strip():
        st.error('The audio or text has not been properly input', icon="🚨")
        return

    # Load whatever name save_file actually wrote, rather than re-deriving it.
    saved_name = save_file(uploaded_file)
    audio = get_audio(saved_name)
    mispronunciation_detector = load_model()

    st.write('# Detection Results')
    with st.spinner('Predicting...'):
        raw_info = mispronunciation_detector.detect(audio, text)
    _render_phoneme_analysis(raw_info)
    st.divider()
    _render_word_analysis(raw_info)
def main():
    """Build the sidebar page selector and render the selected page."""
    st.write('___')
    st.sidebar.title('Pronunciation Evaluation')
    # A non-empty label avoids Streamlit's empty-label accessibility warning;
    # it stays hidden via label_visibility.
    select = st.sidebar.selectbox(
        'Page',
        ['Main Page', 'Mispronunciation Detection'],
        key='1',
        label_visibility='collapsed',
    )
    st.sidebar.write(select)
    if select == 'Mispronunciation Detection':
        mispronounciation_detection_section()
    else:
        # Home page.
        st.write('# Pronunciation Evaluation')
        st.write('This app is designed to detect mispronunciation of English words by English learners whose first language is an Asian language such as Korean, Mandarin or Vietnamese.')
        st.write('Wav2Vec2.0 was used to detect the phonemes from the learner and this output is compared with the correct phoneme sequence generated from the input text.')


if __name__ == '__main__':
    main()