Spaces:
Runtime error
Runtime error
File size: 8,426 Bytes
0e6999d 13375b8 da306a6 13375b8 da306a6 0e6999d 1e93f37 0e6999d 1e93f37 0e6999d cca7a66 0e6999d 13375b8 0e6999d 13375b8 1e93f37 0e6999d 13375b8 0e6999d cca7a66 0e6999d 1e93f37 0e6999d 13375b8 0e6999d 13375b8 0e6999d 13375b8 0e6999d cca7a66 0e6999d 13375b8 da306a6 3f0393b da306a6 13375b8 0e6999d 13375b8 0e6999d 13375b8 0e6999d 13375b8 0e6999d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import streamlit as st
from speechbrain.pretrained import GraphemeToPhoneme
import os
import torchaudio
from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import MultitaskPhonemeASRModel
import json
import os
import random
import openai
from gtts import gTTS
from io import BytesIO
openai.api_key = os.getenv("OPENAI_KEY")
# https://gtts.readthedocs.io/en/latest/
#
def tts_gtts(text):
    """Synthesise *text* to MP3 audio with Google Text-to-Speech.

    Args:
        text: sentence to speak; language is fixed to English.

    Returns:
        BytesIO containing the MP3 bytes, rewound to position 0 so the
        caller (``st.audio``) can read it from the start.
    """
    mp3_fp = BytesIO()
    tts = gTTS(text, lang="en")
    tts.write_to_fp(mp3_fp)
    # write_to_fp leaves the stream cursor at EOF; without this rewind a
    # subsequent read() would yield no data
    mp3_fp.seek(0)
    return mp3_fp
def pronounce(text):
    """Return synthesised audio for *text*; an empty list when there is nothing to say."""
    if not text:
        return []
    return tts_gtts(text)
@st.cache_resource
def load_model():
    """Build and cache the mispronunciation detector (ASR model + G2P) on CPU."""
    base_dir = os.getcwd()
    device = "cpu"
    # multitask wav2vec2 checkpoint + phoneme vocabulary shipped with the app
    asr_model = MultitaskPhonemeASRModel(
        os.path.join(base_dir, "wav2vecasr", "model", "multitask_best_ctc.pt"),
        os.path.join(base_dir, "wav2vecasr", "model", "vocab"),
        device,
    )
    # grapheme-to-phoneme model used to build the reference phoneme sequence
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    return MispronounciationDetector(asr_model, g2p, device)
def save_file(sound_file):
    """Persist an uploaded sound file under ./audio_files.

    Args:
        sound_file: Streamlit UploadedFile-like object exposing ``name``
            and ``getbuffer()``.

    Returns:
        The file's name (not its full path).
    """
    target_dir = os.path.join(os.getcwd(), 'audio_files')
    os.makedirs(target_dir, exist_ok=True)
    with open(os.path.join(target_dir, sound_file.name), 'wb') as out:
        out.write(sound_file.getbuffer())
    return sound_file.name
@st.cache_data
def get_audio(saved_sound_filename):
    """Load a saved wav from ./audio_files, resample to 16 kHz, return a 1-D tensor.

    Args:
        saved_sound_filename: file name previously stored by ``save_file``.

    Returns:
        torch.Tensor of shape (num_samples,) at 16 kHz.
    """
    audio_path = f'audio_files/{saved_sound_filename}'
    audio, org_sr = torchaudio.load(audio_path)
    audio = torchaudio.functional.resample(audio, orig_freq=org_sr, new_freq=16000)
    # collapse to a single channel: the original ``view(audio.shape[1])``
    # raised on stereo uploads; averaging handles any channel count and
    # is a no-op difference for mono input
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0)
    else:
        audio = audio.view(audio.shape[1])
    return audio
@st.cache_data
def get_prompts():
    """Load the practice prompt list from wav2vecasr/data/prompts.json.

    Returns:
        The "prompts" field of the JSON file (a list of prompt strings).
    """
    prompts_path = os.path.join(os.getcwd(), "wav2vecasr", "data", "prompts.json")
    # context manager closes the handle; the original leaked the open file
    with open(prompts_path) as f:
        data = json.load(f)
    return data["prompts"]
@st.cache_data
def get_articulation_videos():
    """Load the ARPAbet-phoneme -> articulation-video mapping from disk.

    Note: not every ARPAbet symbol has a video with an articulation
    visualisation, so lookups may miss.

    Returns:
        dict parsed from wav2vecasr/data/videos.json.
    """
    path = os.path.join(os.getcwd(), "wav2vecasr", "data", "videos.json")
    # context manager closes the handle; the original leaked the open file
    with open(path) as f:
        data = json.load(f)
    return data
def get_prompts_from_l2_arctic(prompts, current_prompt, num_to_get):
    """Pick distinct practice prompts, excluding the one just read.

    Args:
        prompts: pool of candidate prompt strings.
        current_prompt: prompt the user just practised; never returned.
        num_to_get: number of prompts wanted.

    Returns:
        List of up to *num_to_get* unique prompts. When the pool holds
        fewer eligible prompts, all of them are returned -- the original
        ``while`` loop spun forever in that case.
    """
    candidates = [p for p in set(prompts) if p != current_prompt]
    return random.sample(candidates, min(num_to_get, len(candidates)))
def get_prompt_from_openai(words_with_error_list):
    """Ask ChatGPT for a short practice sentence containing the given words.

    Args:
        words_with_error_list: words the user mispronounced.

    Returns:
        The generated sentence, or "" when the API call fails for any
        reason (missing key, network error, quota, ...).
    """
    try:
        words_with_errors = ", ".join(words_with_error_list)
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are writing practice reading prompts for learners of English to practice pronunciation. These prompts should be short, easy to understand and useful."},
                {"role": "user", "content": f"Write a short sentence of less than 10 words and include the following words in the sentence: {words_with_errors} No numbers."}
            ]
        )
        return response['choices'][0]['message']['content']
    except Exception:
        # Best-effort: the caller falls back to canned L2-Arctic prompts,
        # so API failures are swallowed rather than crashing the UI.  A
        # bare ``except:`` (as before) would also trap SystemExit and
        # KeyboardInterrupt, which should propagate.
        return ""
def mispronounciation_detection_section():
    """Render the main page: upload a .wav + its transcript, run detection,
    and show phoneme/word-level errors plus follow-up practice prompts.

    Side effects: writes the upload to ./audio_files, calls the (cached)
    model, and may call the OpenAI API for one extra prompt.
    """
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read π",
        label_visibility='collapsed'
    )
    if st.button('Predict'):
        # both inputs are required before anything runs
        if uploaded_file is not None and len(text) > 0:
            # get audio from loaded file
            save_file(uploaded_file)
            audio = get_audio(uploaded_file.name)
            # load model (cached across reruns by st.cache_resource)
            mispronunciation_detector = load_model()
            st.write('# Detection Results')
            with st.spinner('Predicting...'):
                # detect
                # NOTE(review): detector contract (keys 'per', 'ref', 'hyp',
                # 'phoneme_errors', 'words', 'word_errors', 'wer') is defined
                # in MispronounciationDetector -- confirm against that class
                raw_info = mispronunciation_detector.detect(audio, text, phoneme_error_threshold=0.25)
                # display prediction results for phonemes
                st.write('#### Phoneme Level Analysis')
                st.write(f"Phoneme Error Rate: {round(raw_info['per'],2)}")
                # CSS disables textarea wrapping so the aligned ref/hyp/error
                # rows stay readable in the fenced block below
                st.markdown(
                    f"""
                    <style>
                    textarea {{
                        white-space: nowrap;
                    }}
                    </style>
                    ```
                    {raw_info['ref']}
                    {raw_info['hyp']}
                    {raw_info['phoneme_errors']}
                    ```
                    """,
                    unsafe_allow_html=True,
                )
                st.divider()
                # display word errors (mistakes rendered in bold markdown)
                md = []
                words_with_errors = []
                for word, has_error in zip(raw_info["words"], raw_info["word_errors"]):
                    if has_error:
                        words_with_errors.append(word)
                        md.append(f"**{word}**")
                    else:
                        md.append(word)
                st.write('#### Word Level Analysis')
                st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
                st.markdown(" ".join(md))
                st.divider()
                st.write('#### What is next?')
                # display pronounciation e.g. user's recording vs a TTS sample
                st.write("Compare your pronunciation to pronounced sample")
                st.audio(f'audio_files/{uploaded_file.name}', format="audio/wav", start_time=0)
                pronounced_sample = pronounce(text)
                st.audio(pronounced_sample, format="audio/wav", start_time=0)
                # display more prompts to practice -- 1 from ChatGPT -- based on user's mistakes, 2 from L2 Arctic
                st.write('Here are some more prompts for you to practice:')
                selected_prompts = []
                unique_words_with_errors = list(set(words_with_errors))
                prompt_for_mistakes_made = get_prompt_from_openai(unique_words_with_errors)
                # OpenAI prompt is best-effort; "" means the call failed
                if prompt_for_mistakes_made:
                    selected_prompts.append(prompt_for_mistakes_made)
                prompts = get_prompts()
                # top up to 3 total prompts from the L2-Arctic pool
                l2_arctic_prompts = get_prompts_from_l2_arctic(prompts, text, 3-len(selected_prompts))
                selected_prompts.extend(l2_arctic_prompts)
                for prompt in selected_prompts:
                    st.code(f'''{prompt}''', language="python")
        else:
            st.error('The audio or text has not been properly input', icon="π¨")
    return
def video_section():
    """Render the phoneme -> articulation-video lookup page."""
    st.write('# Get helpful videos on phoneme articulation')
    problem_phoneme = st.text_input(
        "Enter the phoneme you had problems with π"
    )
    arpabet_to_video_map = get_articulation_videos()
    if st.button('Look up'):
        if not problem_phoneme:
            # fixed: the old message ("The audio or text has not been
            # properly input") was copy-pasted from the detection section
            st.error('The phoneme has not been properly input', icon="π¨")
        elif problem_phoneme in arpabet_to_video_map:
            video_link = arpabet_to_video_map[problem_phoneme]["link"]
            if video_link:
                st.video(video_link)
            else:
                st.write("Sorry, we couldn't find a good enough video yet :( we are working on it!")
        else:
            # fixed: an unknown phoneme previously produced no feedback at all
            st.write("Sorry, we couldn't find a good enough video yet :( we are working on it!")
if __name__ == '__main__':
    st.write('___')
    # create a sidebar for page navigation
    st.sidebar.title('Pronounciation Evaluation')
    pages = ['Main Page', 'Mispronounciation Detection', 'Helpful Videos for Problem Phonemes']
    page_choice = st.sidebar.selectbox('', pages, key='1', label_visibility='collapsed')
    st.sidebar.write(page_choice)
    if page_choice == 'Mispronounciation Detection':
        mispronounciation_detection_section()
    elif page_choice == "Helpful Videos for Problem Phonemes":
        video_section()
    else:
        # landing page ('Main Page' or anything unexpected)
        st.write('# Pronounciation Evaluation')
        st.write('This app is designed to detect mispronounciation of English words for English learners from Asian countries like Korean, Mandarin and Vietnameses.')
        st.write('Wav2Vec2.0 was used to detect the phonemes from the learner and this output is compared with the correct phoneme sequence generated from input text')
|