File size: 4,344 Bytes
0e6999d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from speechbrain.pretrained import GraphemeToPhoneme
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector

@st.cache_resource
def load_model():
    """Build the mispronunciation detector from the local checkpoint (cached).

    Loads the fine-tuned Wav2Vec2 CTC model and its processor from
    ``<cwd>/wav2vecasr/model/checkpoint-1200`` and the SpeechBrain
    ``soundchoice-g2p`` grapheme-to-phoneme model, then wraps them in a
    ``MispronounciationDetector``. ``st.cache_resource`` keeps the returned
    object across Streamlit reruns, so the checkpoints are loaded only once.

    Returns:
        MispronounciationDetector: the detector used by the prediction page.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    path = os.path.join(os.getcwd(), "wav2vecasr", "model", "checkpoint-1200")
    model = Wav2Vec2ForCTC.from_pretrained(path).to(device)
    processor = Wav2Vec2Processor.from_pretrained(path)
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    # The 4th argument is presumably the device the detector runs inference on
    # (TODO confirm in MispronounciationDetector). Pass the device the model was
    # actually moved to; a hard-coded "cpu" mismatches the model whenever CUDA
    # is available.
    return MispronounciationDetector(model, processor, g2p, device)


def save_file(sound_file):
    """Write an uploaded sound file into ``<cwd>/audio_files``.

    Args:
        sound_file: uploaded file object exposing ``.name`` and
            ``.getbuffer()`` (e.g. a Streamlit ``UploadedFile``).

    Returns:
        str: the file name the data was written under, relative to
        ``audio_files``.

    Raises:
        ValueError: if the uploaded name has no usable file-name component.
    """
    target_dir = os.path.join(os.getcwd(), 'audio_files')
    # The folder may not exist yet (fresh checkout / new deployment).
    os.makedirs(target_dir, exist_ok=True)
    # The name is supplied by the client; keep only its last path component so
    # a name like "../../app.py" cannot write outside the audio folder.
    filename = os.path.basename(sound_file.name)
    if filename in ('', '.', '..'):
        raise ValueError(f"Invalid upload file name: {sound_file.name!r}")
    with open(os.path.join(target_dir, filename), 'wb') as f:
        f.write(sound_file.getbuffer())
    return filename

def get_audio(saved_sound_filename):
    """Load a saved recording as a mono, 16 kHz, 1-D waveform tensor.

    Args:
        saved_sound_filename: name of a file previously written into
            ``<cwd>/audio_files``.

    Returns:
        torch.Tensor: 1-D waveform resampled to 16000 Hz.
    """
    audio_path = os.path.join(os.getcwd(), 'audio_files', saved_sound_filename)
    file_stat = os.stat(audio_path)
    # st.cache_data keys only on the function arguments. Keying on the file name
    # alone returned stale audio when a new recording was uploaded under the
    # same name, so the modification time and size are part of the key.
    return _load_mono_16k(audio_path, file_stat.st_mtime_ns, file_stat.st_size)


@st.cache_data
def _load_mono_16k(audio_path, mtime_ns, size):
    """Decode ``audio_path`` into a mono 16 kHz waveform (cached).

    ``mtime_ns`` and ``size`` are not used in the body; they exist only so the
    cache entry is invalidated when the file on disk is rewritten.
    """
    audio, org_sr = torchaudio.load(audio_path)
    # torchaudio.load returns (channels, frames). Averaging the channel axis
    # yields a 1-D signal for mono and multi-channel files alike; a plain
    # view(frames) fails for stereo input.
    audio = audio.mean(dim=0)
    return torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)

def mispronounciation_detection_section():
    """Render the prediction page: input form plus results after "Predict".

    The user uploads a ``.wav`` recording and types the text they read. When
    "Predict" is pressed with both inputs present, the recording is saved,
    loaded as 16 kHz audio, and passed with the text to the cached detector.
    The phoneme-level and word-level results are then displayed. Otherwise an
    error message is shown.
    """
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(
        'Upload a .wav recording',
        type='wav',
        label_visibility='collapsed'
    )
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇",
        label_visibility='collapsed'
    )
    if not st.button('Predict'):
        return
    # Reject whitespace-only text as well as empty text.
    if uploaded_file is None or not text.strip():
        st.error('The audio or text has not been properly input', icon="🚨")
        return

    # Load using the name save_file actually wrote, so saving and loading agree.
    saved_name = save_file(uploaded_file)
    audio = get_audio(saved_name)

    st.write('# Detection Results')
    with st.spinner('Predicting...'):
        # load_model is cached; only the first call is slow, so it runs under
        # the spinner too.
        mispronunciation_detector = load_model()
        raw_info = mispronunciation_detector.detect(audio, text)

        _render_phoneme_analysis(raw_info)
        st.divider()
        _render_word_analysis(raw_info)


def _render_phoneme_analysis(raw_info):
    """Show the phoneme error rate and the ref / hyp / error phoneme rows."""
    st.write('#### Phoneme Level Analysis')
    st.markdown(f"Phoneme Error Rate: ___{round(raw_info['per'], 2)}___")
    # A fenced code block displays the three rows in a monospaced font.
    # Presumably the detector pads the rows so their tokens line up (TODO confirm).
    rows = "\n".join(
        " ".join(raw_info[key]) for key in ('ref', 'hyp', 'phoneme_errors')
    )
    st.markdown(f"```\n{rows}\n```")


def _render_word_analysis(raw_info):
    """Show the word error rate and the detector's words, bolding flagged ones."""
    st.write('#### Word Level Analysis')
    st.write(f"Word Error Rate: ___{round(raw_info['wer'], 2)}___ and the following words in bold have errors:")
    words_md = [
        f"**{word}**" if has_error else word
        for word, has_error in zip(raw_info["words"], raw_info["word_errors"])
    ]
    st.markdown(" ".join(words_md))

if __name__ == '__main__':
    st.write('___')
    # Sidebar navigation between the landing page and the detection tool.
    st.sidebar.title('Pronunciation Evaluation')
    # Streamlit discourages empty widget labels for accessibility reasons, so
    # a real label is given and hidden with label_visibility instead.
    select = st.sidebar.selectbox(
        'Page',
        ['Main Page', 'Mispronunciation Detection'],
        key='1',
        label_visibility='collapsed',
    )
    st.sidebar.write(select)
    if select == 'Mispronunciation Detection':
        mispronounciation_detection_section()
    else:
        # Home page.
        st.write('# Pronunciation Evaluation')
        st.write('This app is designed to detect mispronunciation of English words by English learners whose native language is an Asian language such as Korean, Mandarin or Vietnamese.')
        st.write('Wav2Vec2.0 is used to recognize the phonemes spoken by the learner, and this output is compared with the correct phoneme sequence generated from the input text.')