"""Streamlit front end for the Wav2Vec2-based mispronunciation detector."""

import hashlib
import os

from speechbrain.pretrained import GraphemeToPhoneme
import streamlit as st
import torchaudio

from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import Wav2Vec2PhonemeASRModel, Wav2Vec2OptimisedPhonemeASRModel

# Sample rate every recording is resampled to before detection.
TARGET_SAMPLE_RATE = 16000


def _audio_folder():
    """Return the folder (under the current working directory) holding uploads."""
    return os.path.join(os.getcwd(), 'audio_files')


@st.cache_resource
def load_model():
    """Build the mispronunciation detector once and reuse it across reruns.

    Returns:
        A ``MispronounciationDetector`` wrapping the phoneme ASR model loaded
        from ``wav2vecasr/model/checkpoint-600`` and the SpeechBrain
        grapheme-to-phoneme model, configured with the ``"cpu"`` device string.
    """
    path = os.path.join(os.getcwd(), "wav2vecasr", "model", "checkpoint-600")
    asr_model = Wav2Vec2OptimisedPhonemeASRModel(
        path,
        os.path.join(path, "wav2vec2_vocab_final.json"),
        os.path.join(os.getcwd(), "wav2vecasr", "pretrained_models", "ken-lm-ngram"),
    )
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    mispronounciation_detector = MispronounciationDetector(asr_model, g2p, "cpu")
    return mispronounciation_detector


def save_file(sound_file):
    """Write an uploaded sound file into the ``audio_files`` folder.

    Args:
        sound_file: Uploaded file object exposing ``name`` and ``getbuffer()``.

    Returns:
        The file name (without directories) the upload was stored under.

    Raises:
        ValueError: If the upload's name has no usable file-name component.
    """
    audio_folder_path = _audio_folder()
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(audio_folder_path, exist_ok=True)

    # The name comes from the client; keep only its last component so a name
    # such as "../../app.py" cannot escape the upload folder.
    safe_name = os.path.basename(sound_file.name)
    if not safe_name:
        raise ValueError(f"Invalid upload file name: {sound_file.name!r}")

    with open(os.path.join(audio_folder_path, safe_name), 'wb') as f:
        f.write(sound_file.getbuffer())
    return safe_name


@st.cache_data
def get_audio(saved_sound_filename, content_digest=None):
    """Load a saved recording as a mono waveform resampled to 16 kHz.

    Args:
        saved_sound_filename: Name of a file inside the ``audio_files`` folder.
        content_digest: Optional digest of the file's contents. It is not used
            in the body; it only takes part in the ``st.cache_data`` key so a
            new upload that reuses an old file name is not served from cache.

    Returns:
        A 1-D tensor holding the resampled samples.
    """
    audio_path = os.path.join(_audio_folder(), saved_sound_filename)
    audio, org_sr = torchaudio.load(audio_path)
    audio = torchaudio.functional.resample(
        audio, orig_freq=org_sr, new_freq=TARGET_SAMPLE_RATE
    )
    # Average over the channel dimension: identical to the old flatten for a
    # mono file, and it no longer raises for stereo/multi-channel recordings.
    return audio.mean(dim=0)


def mispronounciation_detection_section():
    """Render the prediction page: upload, text input and detection results."""
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇",
        label_visibility='collapsed'
    )

    if not st.button('Predict'):
        return

    # Whitespace-only text is treated as missing input.
    spoken_text = text.strip()
    if uploaded_file is None or not spoken_text:
        st.error('The audio or text has not been properly input', icon="🚨")
        return

    # Persist the upload, then load it; the digest keeps the cached waveform
    # tied to this exact file content rather than just its name.
    saved_name = save_file(uploaded_file)
    digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    audio = get_audio(saved_name, content_digest=digest)

    mispronunciation_detector = load_model()

    st.write('# Detection Results')
    with st.spinner('Predicting...'):
        raw_info = mispronunciation_detector.detect(audio, spoken_text)

    st.write('#### Phoneme Level Analysis')
    st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")

    # One row each for reference, hypothesis and error markers, shown in a
    # code fence so the monospace font keeps the columns lined up.
    aligned_rows = "\n".join(
        " ".join(raw_info[key]) for key in ('ref', 'hyp', 'phoneme_errors')
    )
    st.markdown(f"```\n{aligned_rows}\n```")

    st.divider()

    # Words flagged as erroneous are rendered in bold.
    md = [
        f"**{word}**" if has_error else word
        for word, has_error in zip(raw_info["words"], raw_info["word_errors"])
    ]

    st.write('#### Word Level Analysis')
    st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
    st.markdown(" ".join(md))


def main():
    """Render the sidebar navigation and the page it selects."""
    st.write('___')
    # create a sidebar
    st.sidebar.title('Pronounciation Evaluation')
    # The label is hidden, but Streamlit discourages an empty one.
    select = st.sidebar.selectbox(
        'Page',
        ['Main Page', 'Mispronounciation Detection'],
        key='1',
        label_visibility='collapsed',
    )
    st.sidebar.write(select)

    if select == 'Mispronounciation Detection':
        mispronounciation_detection_section()
    # else: stay on the home page
    else:
        st.write('# Pronounciation Evaluation')
        st.write('This app is designed to detect mispronounciation of English words for English learners from Asian countries like Korean, Mandarin and Vietnameses.')
        st.write('Wav2Vec2.0 was used to detect the phonemes from the learner and this output is compared with the correct phoneme sequence generated from input text')


if __name__ == '__main__':
    main()