bel32123 committed on
Commit
0e6999d
•
1 Parent(s): 6d3dc99

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
4
+ from speechbrain.pretrained import GraphemeToPhoneme
5
+ import os
6
+ import torchaudio
7
+ from wav2vecasr.MispronounciationDetector import MispronounciationDetector
8
+
9
@st.cache_resource
def load_model():
    """Build the mispronunciation detector once and reuse it across reruns.

    Loads the Wav2Vec2 CTC model and its processor from
    ``wav2vecasr/model/checkpoint-1200`` under the current working directory,
    loads the ``speechbrain/soundchoice-g2p`` grapheme-to-phoneme model, and
    wraps all three in a ``MispronounciationDetector``.

    Returns:
        MispronounciationDetector: the detector whose ``detect`` method is
        used by the prediction page.
    """
    # Prefer the GPU when one is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    path = os.path.join(os.getcwd(), "wav2vecasr", "model", "checkpoint-1200")
    model = Wav2Vec2ForCTC.from_pretrained(path).to(device)
    processor = Wav2Vec2Processor.from_pretrained(path)
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    # Hand the detector the same device the model was moved to; the previous
    # hard-coded "cpu" disagreed with the model placement on CUDA hosts.
    # NOTE(review): assumes the fourth argument is the device the detector
    # uses for its inputs -- confirm against MispronounciationDetector.
    return MispronounciationDetector(model, processor, g2p, device)
18
+
19
+
20
def save_file(sound_file):
    """Store an uploaded sound file under ``audio_files/`` in the working directory.

    Args:
        sound_file: uploaded file object exposing a ``name`` attribute and a
            ``getbuffer()`` method returning the raw bytes.

    Returns:
        str: the bare file name (no directory components) the data was
        written under inside ``audio_files``.

    Raises:
        ValueError: if the upload's name contains no usable file name.
    """
    # The name comes from the client, so keep only its last path component;
    # otherwise something like "../x.wav" would be written outside audio_files.
    filename = os.path.basename(sound_file.name.replace("\\", "/"))
    if filename in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {sound_file.name!r}")

    # Create the destination on first use instead of failing on a fresh checkout.
    target_dir = os.path.join(os.getcwd(), 'audio_files')
    os.makedirs(target_dir, exist_ok=True)

    with open(os.path.join(target_dir, filename), 'wb') as f:
        f.write(sound_file.getbuffer())
    return filename
25
+
26
@st.cache_data
def _load_mono_16k(audio_path, modified_ns, size_bytes):
    """Load ``audio_path`` as a 1-D, 16 kHz waveform tensor.

    ``modified_ns`` and ``size_bytes`` are not used in the body; they exist so
    the cache entry changes whenever the file on disk is replaced.
    """
    waveform, org_sr = torchaudio.load(audio_path)
    waveform = torchaudio.functional.resample(waveform, orig_freq=org_sr, new_freq=16000)
    # Average over the channel axis: identical to the old reshape for mono
    # input, and it no longer raises for recordings with several channels.
    return waveform.mean(dim=0)


def get_audio(saved_sound_filename):
    """Return the saved recording as a mono waveform resampled to 16 kHz.

    Args:
        saved_sound_filename: name of a file inside the ``audio_files``
            directory (relative to the working directory).

    Returns:
        torch.Tensor: 1-D tensor of samples at 16 kHz.

    Raises:
        FileNotFoundError: if the file does not exist under ``audio_files``.
    """
    audio_path = f'audio_files/{saved_sound_filename}'
    # Key the cache on the file's mtime and size as well as its path, so a new
    # upload that reuses an earlier file name is decoded again instead of
    # returning the previous recording.
    info = os.stat(audio_path)
    return _load_mono_16k(audio_path, info.st_mtime_ns, info.st_size)
33
+
34
def mispronounciation_detection_section():
    """Render the prediction page and, on request, the detection results.

    Shows a ``.wav`` uploader and a text box. When the "Predict" button is
    pressed with both inputs present, the upload is saved, loaded as audio and
    passed with the text to the detector's ``detect`` method. The returned
    mapping is then rendered in two parts:

    * phoneme level: ``per`` plus the ``ref``, ``hyp`` and ``phoneme_errors``
      sequences, one per line in a fenced code block;
    * word level: ``wer`` plus ``words``, with every word whose
      ``word_errors`` entry is truthy shown in bold.

    If the upload or the text is missing, an error message is shown instead.
    """
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇",
        label_visibility='collapsed'
    )

    if not st.button('Predict'):
        return

    # Whitespace-only text is treated the same as no text at all.
    if uploaded_file is None or not text.strip():
        st.error('The audio or text has not been properly input', icon="🚨")
        return

    # Load the audio under the name save_file actually stored it as.
    saved_name = save_file(uploaded_file)
    audio = get_audio(saved_name)

    mispronunciation_detector = load_model()

    st.write('# Detection Results')
    with st.spinner('Predicting...'):
        raw_info = mispronunciation_detector.detect(audio, text)

    st.write('#### Phoneme Level Analysis')
    st.markdown(f"Phoneme Error Rate: ___{round(raw_info['per'],2)}___")
    # A fenced code block keeps the three sequences in a monospace font so
    # they line up. No raw HTML is needed for that, so unsafe_allow_html and
    # the unused textarea style rule are gone.
    aligned_phonemes = "\n".join([
        "```",
        " ".join(raw_info['ref']),
        " ".join(raw_info['hyp']),
        " ".join(raw_info['phoneme_errors']),
        "```",
    ])
    st.markdown(aligned_phonemes)

    st.divider()
    # Bold every word flagged as erroneous; leave the others untouched.
    md = [
        f"**{word}**" if has_error else word
        for word, has_error in zip(raw_info["words"], raw_info["word_errors"])
    ]

    st.write('#### Word Level Analysis')
    st.write(f"Word Error Rate: ___{round(raw_info['wer'], 2)}___ and the following words in bold have errors:")
    st.markdown(" ".join(md))
91
+
92
if __name__ == '__main__':
    # Page entry point: a sidebar selector switches between the landing page
    # and the mispronunciation detection page.
    st.write('___')
    st.sidebar.title('Pronounciation Evaluation')
    # Streamlit advises against an empty widget label; give the selector a
    # real one and keep it hidden through label_visibility.
    select = st.sidebar.selectbox(
        'Page',
        ['Main Page', 'Mispronounciation Detection'],
        key='1',
        label_visibility='collapsed',
    )
    st.sidebar.write(select)
    if select == 'Mispronounciation Detection':
        mispronounciation_detection_section()
    else:
        # Any other selection stays on the home page.
        st.write('# Pronounciation Evaluation')
        st.write('This app is designed to detect mispronounciation of English words for English learners from Asian countries like Korean, Mandarin and Vietnameses.')
        st.write('Wav2Vec2.0 was used to detect the phonemes from the learner and this output is compared with the correct phoneme sequence generated from input text')