import streamlit as st
from speechbrain.pretrained import GraphemeToPhoneme
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import MultitaskPhonemeASRModel
import json
import random
import openai
from gtts import gTTS
from io import BytesIO

openai.api_key = os.getenv("OPENAI_KEY")
# Text-to-speech helper built on gTTS (https://gtts.readthedocs.io/en/latest/)
def tts_gtts(text):
    # synthesize speech for the given text and return it as an in-memory MP3 buffer
    mp3_fp = BytesIO()
    tts = gTTS(text, lang="en")
    tts.write_to_fp(mp3_fp)
    mp3_fp.seek(0)  # rewind so downstream consumers read from the start
    return mp3_fp

def pronounce(text):
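    # return synthesized audio for non-empty text; empty input yields an empty list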
    if len(text) > 0:
        data = tts_gtts(text)
        return data
    return []

@st.cache_resource
def load_model():
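    # load the fine-tuned multitask CTC checkpoint and the SoundChoice G2P model once;
    # st.cache_resource keeps them in memory across Streamlit reruns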
    path = os.path.join(os.getcwd(), "wav2vecasr", "model", "multitask_best_ctc.pt")
    vocab_path = os.path.join(os.getcwd(), "wav2vecasr", "model", "vocab")
    device = "cpu"
    asr_model = MultitaskPhonemeASRModel(path, vocab_path, device)
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    mispronounciation_detector = MispronounciationDetector(asr_model, g2p, device)
    return mispronounciation_detector


def save_file(sound_file):
    # persist the uploaded sound file under ./audio_files, creating the folder if needed
    audio_folder_path = os.path.join(os.getcwd(), 'audio_files')
    if not os.path.exists(audio_folder_path):
        os.makedirs(audio_folder_path)
    with open(os.path.join(audio_folder_path, sound_file.name), 'wb') as f:
        f.write(sound_file.getbuffer())
    return sound_file.name

@st.cache_data
def get_audio(saved_sound_filename):
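    # load the saved recording, resample it to the 16 kHz expected by Wav2Vec2,
    # and flatten it to a 1-D waveform tensor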
    audio_path = f'audio_files/{saved_sound_filename}'
    audio, org_sr = torchaudio.load(audio_path)
    audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
    audio = audio.view(audio.shape[1])
    return audio

@st.cache_data
def get_prompts():
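    # load the bundled practice prompts (drawn from the L2-Arctic corpus)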
    prompts_path = os.path.join(os.getcwd(), "wav2vecasr", "data", "prompts.json")
    with open(prompts_path) as f:
        data = json.load(f)
    return data["prompts"]

@st.cache_data
def get_articulation_videos():
    # note -- not every ARPAbet phoneme could be mapped to a video demonstrating its articulation
    path = os.path.join(os.getcwd(), "wav2vecasr", "data", "videos.json")
    with open(path) as f:
        data = json.load(f)
    return data

def get_prompts_from_l2_arctic(prompts, current_prompt, num_to_get):
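    # sample unique prompts that differ from the one just read; assumes the pool
    # holds more than num_to_get candidates, otherwise this loop never terminates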
    selected_prompts = []
    while len(selected_prompts) < num_to_get:
        prompt = random.choice(prompts)
        if prompt not in selected_prompts and prompt != current_prompt:
            selected_prompts.append(prompt)

    return selected_prompts

def get_prompt_from_openai(words_with_error_list):
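    # ask gpt-3.5-turbo for one short practice sentence that reuses the mispronounced words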
    try:
        words_with_errors = ", ".join(words_with_error_list)
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are writing practice reading prompts for learners of English to practice pronunciation. These prompts should be short, easy to understand and useful."},
                {"role": "user", "content": f"Write a short sentence of fewer than 10 words and include the following words in the sentence: {words_with_errors}. No numbers."}
            ]
        )
        return response['choices'][0]['message']['content']
    except Exception:
        # fall back to an empty prompt if the OpenAI call fails (e.g. missing key, network error)
        return ""

def mispronounciation_detection_section():
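    # page flow: upload a .wav recording and the target text, run mispronunciation
    # detection, then show phoneme/word level errors and fresh practice prompts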
    st.write('# Prediction')
    st.write('1. Upload a recording of yourself saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇",
        label_visibility='collapsed'
    )
    if st.button('Predict'):
        if uploaded_file is not None and len(text) > 0:
            # get audio from loaded file
            save_file(uploaded_file)
            audio = get_audio(uploaded_file.name)

            # load model
            mispronunciation_detector = load_model()

            st.write('# Detection Results')
            with st.spinner('Predicting...'):

                # detect
                raw_info = mispronunciation_detector.detect(audio, text, phoneme_error_threshold=0.25)

                # display prediction results for phonemes
                st.write('#### Phoneme Level Analysis')
                st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")
                st.markdown(
                f"""
                <style>
                textarea {{
                    white-space: nowrap;
                }}
                </style>
                ```
                {raw_info['ref']}
                {raw_info['hyp']}
                {raw_info['phoneme_errors']}
                ```
                """,
                    unsafe_allow_html=True,
                )

                st.divider()

                # display word errors
                md = []
                words_with_errors = []
                for word, has_error in zip(raw_info["words"], raw_info["word_errors"]):
                    if has_error:
                        words_with_errors.append(word)
                        md.append(f"**{word}**")
                    else:
                        md.append(word)

                st.write('#### Word Level Analysis')
                st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
                st.markdown(" ".join(md))

                st.divider()

                st.write('#### What is next?')

                # let the user compare their recording with a synthesized reference pronunciation
                st.write("Compare your pronunciation with a synthesized sample of the text")
                st.audio(f'audio_files/{uploaded_file.name}', format="audio/wav", start_time=0)
                pronounced_sample = pronounce(text)
                st.audio(pronounced_sample, format="audio/mp3", start_time=0)

                # display more prompts to practice -- one from ChatGPT based on the user's mistakes, the rest from L2-Arctic
                st.write('Here are some more prompts for you to practice:')

                selected_prompts = []

                unique_words_with_errors = list(set(words_with_errors))
                prompt_for_mistakes_made = get_prompt_from_openai(unique_words_with_errors)
                if prompt_for_mistakes_made:
                    selected_prompts.append(prompt_for_mistakes_made)

                prompts = get_prompts()
                l2_arctic_prompts = get_prompts_from_l2_arctic(prompts, text, 3-len(selected_prompts))
                selected_prompts.extend(l2_arctic_prompts)

                for prompt in selected_prompts:
                    st.code(prompt, language=None)


        else:
            st.error('The audio or text has not been properly input', icon="🚨")
    return

def video_section():
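    # look up an ARPAbet phoneme and show a video demonstrating its articulation, if available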
    st.write('# Get helpful videos on phoneme articulation')

    problem_phoneme = st.text_input(
        "Enter the phoneme you had problems with πŸ‘‡"
    )

    arpabet_to_video_map = get_articulation_videos()

    if st.button('Look up'):
        if not problem_phoneme:
            st.error('Please enter a phoneme to look up', icon="🚨")
        elif problem_phoneme in arpabet_to_video_map:
            video_link = arpabet_to_video_map[problem_phoneme]["link"]
            if video_link:
                st.video(video_link)
            else:
                st.write("Sorry, we couldn't find a good enough video yet :( we are working on it!")
        else:
            st.write("Sorry, we don't have a video for that phoneme yet.")

if __name__ == '__main__':
    st.write('___')
    # create a sidebar
    st.sidebar.title('Pronunciation Evaluation')
    select = st.sidebar.selectbox('', ['Main Page', 'Mispronunciation Detection', 'Helpful Videos for Problem Phonemes'], key='1', label_visibility='collapsed')
    st.sidebar.write(select)
    if select == 'Mispronunciation Detection':
        mispronounciation_detection_section()
    elif select == "Helpful Videos for Problem Phonemes":
        video_section()
    else:
        st.write('# Pronunciation Evaluation')
        st.write('This app detects mispronunciation of English words by English learners whose first languages include Korean, Mandarin and Vietnamese.')
        st.write('Wav2Vec 2.0 is used to recognise the phonemes the learner actually produced, and this output is compared with the reference phoneme sequence generated from the input text.')