Spaces:
Runtime error
Runtime error
Create model interface
Browse files- wav2vecasr/MispronounciationDetector.py +18 -13
- wav2vecasr/PhonemeASRModel.py +101 -0
wav2vecasr/MispronounciationDetector.py
CHANGED
@@ -3,30 +3,34 @@ import torch
|
|
3 |
import jiwer
|
4 |
|
5 |
class MispronounciationDetector:
|
6 |
-
def __init__(self, l2_phoneme_recogniser,
|
7 |
-
self.
|
8 |
-
self.l2_phoneme_recogniser_processor = l2_phoneme_recogniser_processor
|
9 |
self.g2p = g2p
|
10 |
self.device = device
|
11 |
|
12 |
def detect(self, audio, text):
|
13 |
-
l2_phones = self.get_l2_phoneme_sequence(audio)
|
14 |
native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
|
15 |
-
|
|
|
16 |
return raw_info
|
17 |
|
18 |
-
def get_l2_phoneme_sequence(self, audio):
|
19 |
-
input_dict = self.l2_phoneme_recogniser_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
|
20 |
-
logits = self.l2_phoneme_recogniser(input_dict.input_values.to(self.device)).logits
|
21 |
-
pred_ids = torch.argmax(logits, dim=-1)[0]
|
22 |
-
pred_phones = [phoneme for phoneme in self.l2_phoneme_recogniser_processor.batch_decode(pred_ids) if phoneme != ""]
|
23 |
-
return pred_phones
|
24 |
-
|
25 |
def get_native_speaker_phoneme_sequence(self, text):
|
26 |
phonemes = self.g2p(text)
|
27 |
return phonemes
|
28 |
|
29 |
def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# get per
|
31 |
label_phones = [phone for phone in org_label_phones if phone != " "]
|
32 |
reference = " ".join(label_phones) # dummy phones
|
@@ -80,6 +84,7 @@ class MispronounciationDetector:
|
|
80 |
space_padding = "-" * (len(label_phones[i]))
|
81 |
error_bool.append(space_padding)
|
82 |
|
|
|
83 |
delimiter_idx = 0
|
84 |
for phone in org_label_phones:
|
85 |
if phone == " ":
|
@@ -93,7 +98,7 @@ class MispronounciationDetector:
|
|
93 |
ref.append("|")
|
94 |
hyp.append("|")
|
95 |
|
96 |
-
# get mispronounced words
|
97 |
aligned_word_error_output = ""
|
98 |
words = text.split(" ")
|
99 |
word_error_bool = self.get_mispronounced_words(error_bool)
|
|
|
3 |
import jiwer
|
4 |
|
5 |
class MispronounciationDetector:
|
6 |
+
def __init__(self, l2_phoneme_recogniser, g2p, device):
|
7 |
+
self.phoneme_asr_model = l2_phoneme_recogniser # PhonemeASRModel class
|
|
|
8 |
self.g2p = g2p
|
9 |
self.device = device
|
10 |
|
11 |
def detect(self, audio, text):
    """
    Compare what the speaker said against what the text should sound like.

    :param audio: audio handed unchanged to the phoneme ASR model
    :param text: words the speaker was asked to read
    :return: whatever ``get_mispronounciation_output`` produces for the two phoneme sequences
    """
    asr = self.phoneme_asr_model

    # Phonemes actually produced by the speaker.
    predicted = asr.get_l2_phoneme_sequence(audio)

    # Expected phonemes from G2P, mapped onto the ASR model's phoneme classes.
    expected_raw = self.get_native_speaker_phoneme_sequence(text)
    expected = asr.standardise_g2p_phoneme_sequence(expected_raw)

    return self.get_mispronounciation_output(text, predicted, expected)
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def get_native_speaker_phoneme_sequence(self, text):
|
19 |
phonemes = self.g2p(text)
|
20 |
return phonemes
|
21 |
|
22 |
def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
|
23 |
+
"""
|
24 |
+
Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors
|
25 |
+
:param text: original words read by the user
|
26 |
+
:type text: string
|
27 |
+
:param pred_phones: predicted phonemes by L2 speaker from ASR Model
|
28 |
+
:type pred_phones: array
|
29 |
+
:param org_label_phones: correct, native speaker phonemes from G2P where phonemes of each word is segregated by " "
|
30 |
+
:type org_label_phones: array
|
31 |
+
:return: dictionary containing various mispronounciation information like PER, WER and error boolean arrays at phoneme/word level
|
32 |
+
:rtype: dictionary
|
33 |
+
"""
|
34 |
# get per
|
35 |
label_phones = [phone for phone in org_label_phones if phone != " "]
|
36 |
reference = " ".join(label_phones) # dummy phones
|
|
|
84 |
space_padding = "-" * (len(label_phones[i]))
|
85 |
error_bool.append(space_padding)
|
86 |
|
87 |
+
# insert word delimiters to show user phoneme sections by word
|
88 |
delimiter_idx = 0
|
89 |
for phone in org_label_phones:
|
90 |
if phone == " ":
|
|
|
98 |
ref.append("|")
|
99 |
hyp.append("|")
|
100 |
|
101 |
+
# get mispronounced words based on if there are phoneme errors present in the phonemes of that word
|
102 |
aligned_word_error_output = ""
|
103 |
words = text.split(" ")
|
104 |
word_error_bool = self.get_mispronounced_words(error_bool)
|
wav2vecasr/PhonemeASRModel.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, \
|
3 |
+
Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
4 |
+
import pyctcdecode
|
5 |
+
import json
|
6 |
+
import re
|
7 |
+
from sys import platform
|
8 |
+
|
9 |
+
class PhonemeASRModel:
    """
    Interface for phoneme recognisers used for mispronunciation detection.

    Every method here must be overridden. The base implementations raise
    ``NotImplementedError`` so that a missing override fails at the call site
    instead of silently returning ``None`` and breaking later code.
    """

    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at 16k sampling rate with torchaudio
        :type audio: array
        :return: predicted phonemes for L2 speaker
        :rtype: array
        :raises NotImplementedError: when the subclass does not override this method
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_l2_phoneme_sequence"
        )

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        To facilitate mispronunciation detection

        :param phones: native speaker phones predicted by G2P model
        :type phones: array
        :return: standardised native speaker phoneme sequence that aligns with phoneme classes by the model
        :rtype: array
        :raises NotImplementedError: when the subclass does not override this method
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement standardise_g2p_phoneme_sequence"
        )

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        To facilitate testing

        :param phones: native speaker phones as annotated in l2 artic
        :type phones: array
        :return: standardised native speaker phoneme sequence that aligns with phoneme classes by the model
        :rtype: array
        :raises NotImplementedError: when the subclass does not override this method
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement "
            "standardise_l2_artic_groundtruth_phoneme_sequence"
        )
|
40 |
+
|
41 |
+
class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
    """
    Uses greedy decoding
    """

    # Compiled once; strips the digits from phone labels such as "AH0".
    _DIGITS = re.compile(r'\d')

    def __init__(self, model_path, processor_path):
        """
        :param model_path: location passed to ``Wav2Vec2ForCTC.from_pretrained``
        :param processor_path: location passed to ``Wav2Vec2Processor.from_pretrained``
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2Processor.from_pretrained(processor_path)

    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at 16k sampling rate with torchaudio
        :type audio: array
        :return: predicted phonemes for L2 speaker, with empty strings removed
        :rtype: array
        """
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        # Inference only: skip autograd bookkeeping so no graph is kept alive.
        with torch.no_grad():
            logits = self.model(input_dict.input_values.to(self.device)).logits
        # Greedy choice per frame; [0] selects the first item of the batch.
        pred_ids = torch.argmax(logits, dim=-1)[0]
        # NOTE(review): batch_decode receives the per-frame ids of one utterance,
        # so each id is decoded as its own entry - confirm this is the intended
        # way to obtain one phoneme string per frame.
        decoded = self.processor.batch_decode(pred_ids)
        return [phoneme for phoneme in decoded if phoneme != ""]

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        :param phones: native speaker phones predicted by G2P model
        :type phones: array
        :return: ``phones`` unchanged
        :rtype: array
        """
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        :param phones: native speaker phones as annotated in l2 artic
        :type phones: array
        :return: the same phones with every digit removed from each entry
        :rtype: array
        """
        return [self._DIGITS.sub("", phone_str) for phone_str in phones]
|
62 |
+
|
63 |
+
# TODO debug on linux because KenLM is not supported on Windows
|
64 |
+
class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
    """
    Uses beam search and a LM for decoding
    """

    # Compiled once; strips the digits from phone labels such as "AH0".
    _DIGITS = re.compile(r'\d')

    def __init__(self, model_path, vocab_json_path, kenlm_model_path):
        """
        :param model_path: location passed to ``Wav2Vec2ForCTC.from_pretrained``
        :param vocab_json_path: JSON vocabulary file used for both the tokenizer and the decoder labels
        :param kenlm_model_path: KenLM model for LM-assisted beam search; ignored when falsy or when not running on Linux
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Context manager closes the handle; UTF-8 is stated explicitly so the
        # vocabulary is read the same way regardless of the platform default.
        with open(vocab_json_path, encoding="utf-8") as vocab_file:
            vocab_dict = json.load(vocab_file)

        tokenizer = Wav2Vec2CTCTokenizer(vocab_json_path, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
        feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
                                                     do_normalize=True, return_attention_mask=False)

        # The decoder indexes labels by position, so order them by token id
        # rather than relying on the key order of the JSON file.
        # NOTE(review): assumes the vocabulary maps token -> integer id - confirm.
        labels = [token for token, _ in sorted(vocab_dict.items(), key=lambda item: item[1])]

        use_lm = platform.startswith("linux") and bool(kenlm_model_path)
        if use_lm:
            # beam search + LM
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels, kenlm_model_path=kenlm_model_path)
        else:
            # beam search
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels)

        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)

    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at 16k sampling rate with torchaudio
        :type audio: array
        :return: predicted phonemes for L2 speaker, with empty strings removed
        :rtype: array
        """
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        # Inference only: skip autograd bookkeeping so no graph is kept alive.
        with torch.no_grad():
            logits = self.model(input_dict.input_values.to(self.device)).logits.cpu()
        # Softmax over the last of the three axes, then take the first batch item.
        normalised_logits = torch.softmax(logits, dim=2).numpy()[0]
        output = self.processor.decode(normalised_logits)
        # Drop empty strings so an empty transcription gives [] rather than [""].
        return [phone for phone in output.text.split(" ") if phone != ""]

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        :param phones: native speaker phones predicted by G2P model
        :type phones: array
        :return: ``phones`` unchanged
        :rtype: array
        """
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        :param phones: native speaker phones as annotated in l2 artic
        :type phones: array
        :return: the same phones with every digit removed from each entry
        :rtype: array
        """
        return [self._DIGITS.sub("", phone_str) for phone_str in phones]
|
100 |
+
|
101 |
+
|