Spaces:
Runtime error
Runtime error
File size: 4,344 Bytes
0e6999d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import streamlit as st
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from speechbrain.pretrained import GraphemeToPhoneme
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
@st.cache_resource
def load_model():
    """Load the fine-tuned wav2vec2 checkpoint and wrap it in a detector.

    Decorated with ``st.cache_resource`` so Streamlit keeps the returned
    detector instead of reloading the model, processor and G2P on every rerun.

    Returns:
        A ``MispronounciationDetector`` built from the local CTC checkpoint
        (``wav2vecasr/model/checkpoint-1200`` under the working directory),
        its processor, and the SpeechBrain ``soundchoice-g2p`` model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    path = os.path.join(os.getcwd(), "wav2vecasr", "model", "checkpoint-1200")
    model = Wav2Vec2ForCTC.from_pretrained(path).to(device)
    processor = Wav2Vec2Processor.from_pretrained(path)
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    # Hand the detector the same device the model was moved to; a hard-coded
    # "cpu" disagrees with the model's placement whenever CUDA is available.
    # NOTE(review): assumes the 4th positional argument of
    # MispronounciationDetector is the device - confirm against its signature.
    mispronounciation_detector = MispronounciationDetector(model, processor, g2p, device)
    return mispronounciation_detector
def save_file(sound_file):
    """Write an uploaded file into ``<cwd>/audio_files`` and return its name.

    Args:
        sound_file: Uploaded-file object exposing ``.name`` and ``.getbuffer()``
            (e.g. the value returned by ``st.file_uploader``).

    Returns:
        The base name under which the file was stored inside ``audio_files``.

    Raises:
        ValueError: If the uploaded name has no usable file-name component.
    """
    target_dir = os.path.join(os.getcwd(), "audio_files")
    # The folder may not exist on a fresh deployment; create it on demand.
    os.makedirs(target_dir, exist_ok=True)
    # The name is supplied by the client, so strip directory components to
    # keep the write inside ``audio_files`` (e.g. "../x.wav" -> "x.wav").
    filename = os.path.basename(sound_file.name)
    if not filename:
        raise ValueError(f"uploaded file has no usable name: {sound_file.name!r}")
    with open(os.path.join(target_dir, filename), "wb") as f:
        f.write(sound_file.getbuffer())
    return filename
def get_audio(saved_sound_filename):
    """Load a saved recording as a 1-D waveform tensor resampled to 16 kHz.

    Intentionally not wrapped in ``st.cache_data``: such a cache is keyed only
    on the filename, so a later upload reusing the same name would get the
    earlier recording's audio back.

    Args:
        saved_sound_filename: Name of a file inside the ``audio_files`` folder
            (as returned by ``save_file``).

    Returns:
        A 1-D tensor of samples at 16 kHz. Multi-channel recordings are
        averaged down to a single channel.
    """
    audio_path = os.path.join("audio_files", saved_sound_filename)
    audio, org_sr = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames). Averaging over the channel
    # axis gives a 1-D signal for mono (values unchanged) and stereo alike,
    # whereas ``view(shape[1])`` only works for exactly one channel.
    audio = audio.mean(dim=0)
    return torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
def _render_phoneme_analysis(raw_info):
    """Render the phoneme error rate and the aligned ref/hyp/error rows."""
    st.write('#### Phoneme Level Analysis')
    st.markdown(f"Phoneme Error Rate: ___{round(raw_info['per'],2)}___")
    # A fenced code block keeps the three rows monospaced so their phonemes
    # stay column-aligned.
    aligned_rows = "\n".join(
        " ".join(raw_info[key]) for key in ("ref", "hyp", "phoneme_errors")
    )
    st.markdown(f"```\n{aligned_rows}\n```")


def _render_word_analysis(raw_info):
    """Render the word error rate and the text with erroneous words in bold."""
    marked_words = [
        f"**{word}**" if has_error else word
        for word, has_error in zip(raw_info["words"], raw_info["word_errors"])
    ]
    st.write('#### Word Level Analysis')
    st.write(f"Word Error Rate: ___{round(raw_info['wer'], 2)}___ and the following words in bold have errors:")
    st.markdown(" ".join(marked_words))


def mispronounciation_detection_section():
    """Render the prediction page: inputs, and results after "Predict".

    The user uploads a .wav recording and types the text they read. When
    "Predict" is clicked with both inputs present, the audio is saved and
    loaded, the cached detector runs, and phoneme- and word-level analyses
    are shown; otherwise an error box is displayed.
    """
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read π",
        label_visibility='collapsed'
    )
    if not st.button('Predict'):
        return
    # Treat whitespace-only text the same as missing text.
    if uploaded_file is None or not text.strip():
        # Streamlit validates ``icon`` as a single emoji, so it must be a real one.
        st.error('The audio or text has not been properly input', icon="🚨")
        return
    # Read back under the name save_file reports it actually stored.
    saved_name = save_file(uploaded_file)
    audio = get_audio(saved_name)
    st.write('# Detection Results')
    with st.spinner('Predicting...'):
        # The first call loads the model, which is slow; keep it under the spinner.
        mispronunciation_detector = load_model()
        raw_info = mispronunciation_detector.detect(audio, text)
    _render_phoneme_analysis(raw_info)
    st.divider()
    _render_word_analysis(raw_info)
def main():
    """Build the sidebar navigation and render the selected page."""
    st.write('___')
    st.sidebar.title('Pronunciation Evaluation')
    # A non-empty label (hidden via label_visibility) avoids Streamlit's
    # empty-label accessibility warning; key='1' fixes the widget identity.
    select = st.sidebar.selectbox(
        'Page',
        ['Main Page', 'Mispronunciation Detection'],
        key='1',
        label_visibility='collapsed',
    )
    st.sidebar.write(select)
    if select == 'Mispronunciation Detection':
        mispronounciation_detection_section()
    else:
        # Home page.
        st.write('# Pronunciation Evaluation')
        st.write('This app is designed to detect mispronunciation of English words for English learners whose first language is an Asian language such as Korean, Mandarin or Vietnamese.')
        st.write('Wav2Vec2.0 was used to detect the phonemes from the learner and this output is compared with the correct phoneme sequence generated from input text.')


if __name__ == '__main__':
    main()
|