bel32123 commited on
Commit
22efacc
1 Parent(s): 4cb5b41

Create model interface

Browse files
wav2vecasr/MispronounciationDetector.py CHANGED
@@ -3,30 +3,34 @@ import torch
3
  import jiwer
4
 
5
  class MispronounciationDetector:
6
- def __init__(self, l2_phoneme_recogniser, l2_phoneme_recogniser_processor, g2p, device):
7
- self.l2_phoneme_recogniser = l2_phoneme_recogniser
8
- self.l2_phoneme_recogniser_processor = l2_phoneme_recogniser_processor
9
  self.g2p = g2p
10
  self.device = device
11
 
12
  def detect(self, audio, text):
13
- l2_phones = self.get_l2_phoneme_sequence(audio)
14
  native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
15
- raw_info = self.get_mispronounciation_output(text, l2_phones, native_speaker_phones)
 
16
  return raw_info
17
 
18
- def get_l2_phoneme_sequence(self, audio):
19
- input_dict = self.l2_phoneme_recogniser_processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
20
- logits = self.l2_phoneme_recogniser(input_dict.input_values.to(self.device)).logits
21
- pred_ids = torch.argmax(logits, dim=-1)[0]
22
- pred_phones = [phoneme for phoneme in self.l2_phoneme_recogniser_processor.batch_decode(pred_ids) if phoneme != ""]
23
- return pred_phones
24
-
25
  def get_native_speaker_phoneme_sequence(self, text):
26
  phonemes = self.g2p(text)
27
  return phonemes
28
 
29
  def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
 
 
 
 
 
 
 
 
 
 
 
30
  # get per
31
  label_phones = [phone for phone in org_label_phones if phone != " "]
32
  reference = " ".join(label_phones) # dummy phones
@@ -80,6 +84,7 @@ class MispronounciationDetector:
80
  space_padding = "-" * (len(label_phones[i]))
81
  error_bool.append(space_padding)
82
 
 
83
  delimiter_idx = 0
84
  for phone in org_label_phones:
85
  if phone == " ":
@@ -93,7 +98,7 @@ class MispronounciationDetector:
93
  ref.append("|")
94
  hyp.append("|")
95
 
96
- # get mispronounced words
97
  aligned_word_error_output = ""
98
  words = text.split(" ")
99
  word_error_bool = self.get_mispronounced_words(error_bool)
 
3
  import jiwer
4
 
5
  class MispronounciationDetector:
6
+ def __init__(self, l2_phoneme_recogniser, g2p, device):
7
+ self.phoneme_asr_model = l2_phoneme_recogniser # PhonemeASRModel class
 
8
  self.g2p = g2p
9
  self.device = device
10
 
11
  def detect(self, audio, text):
12
+ l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
13
  native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
14
+ standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
15
+ raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones)
16
  return raw_info
17
 
 
 
 
 
 
 
 
18
  def get_native_speaker_phoneme_sequence(self, text):
19
  phonemes = self.g2p(text)
20
  return phonemes
21
 
22
  def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
23
+ """
24
+ Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors
25
+ :param text: original words read by the user
26
+ :type text: string
27
+ :param pred_phones: predicted phonemes by L2 speaker from ASR Model
28
+ :type pred_phones: array
29
+ :param org_label_phones: correct, native speaker phonemes from G2P where phonemes of each word is segregated by " "
30
+ :type org_label_phones: array
31
+ :return: dictionary containing various mispronounciation information like PER, WER and error boolean arrays at phoneme/word level
32
+ :rtype: dictionary
33
+ """
34
  # get per
35
  label_phones = [phone for phone in org_label_phones if phone != " "]
36
  reference = " ".join(label_phones) # dummy phones
 
84
  space_padding = "-" * (len(label_phones[i]))
85
  error_bool.append(space_padding)
86
 
87
+ # insert word delimiters to show user phoneme sections by word
88
  delimiter_idx = 0
89
  for phone in org_label_phones:
90
  if phone == " ":
 
98
  ref.append("|")
99
  hyp.append("|")
100
 
101
+ # get mispronounced words based on if there are phoneme errors present in the phonemes of that word
102
  aligned_word_error_output = ""
103
  words = text.split(" ")
104
  word_error_bool = self.get_mispronounced_words(error_bool)
wav2vecasr/PhonemeASRModel.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, \
3
+ Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
4
+ import pyctcdecode
5
+ import json
6
+ import re
7
+ from sys import platform
8
+
9
+ class PhonemeASRModel:
10
+ def get_l2_phoneme_sequence(self, audio):
11
+ """
12
+ :param audio: audio sampled at 16k sampling rate with torchaudio
13
+ :type audio: array
14
+ :return: predicted phonemes for L2 speaker
15
+ :rtype: array
16
+ """
17
+ pass
18
+
19
+ def standardise_g2p_phoneme_sequence(self, phones):
20
+ """
21
+ To facilitate mispronounciation detection
22
+
23
+ :param phones: native speaker phones predicted by G2P model
24
+ :type phones: array
25
+ :return: standardised native speaker phoneme sequence that aligns with phoneme classes by the model
26
+ :rtype: array
27
+ """
28
+ pass
29
+
30
+ def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
31
+ """
32
+ To facilitate testing
33
+
34
+ :param phones: native speaker phones as annotated in l2 artic
35
+ :type phones: array
36
+ :return: standardised native speaker phoneme sequence that aligns with phoneme classes by the model
37
+ :rtype: array
38
+ """
39
+ pass
40
+
41
+ class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
42
+ """
43
+ Uses greedy decoding
44
+ """
45
+ def __init__(self, model_path, processor_path):
46
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
47
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
48
+ self.processor = Wav2Vec2Processor.from_pretrained(processor_path)
49
+
50
+ def get_l2_phoneme_sequence(self, audio):
51
+ input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
52
+ logits = self.model(input_dict.input_values.to(self.device)).logits
53
+ pred_ids = torch.argmax(logits, dim=-1)[0]
54
+ pred_phones = [phoneme for phoneme in self.processor.batch_decode(pred_ids) if phoneme != ""]
55
+ return pred_phones
56
+
57
+ def standardise_g2p_phoneme_sequence(self, phones):
58
+ return phones
59
+
60
+ def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
61
+ return [re.sub(r'\d', "", phone_str) for phone_str in phones]
62
+
63
+ # TODO debug on linux because KenLM is not supported on Windows
64
+ class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
65
+ """
66
+ Uses beam search and a LM for decoding
67
+ """
68
+
69
+ def __init__(self, model_path, vocab_json_path, kenlm_model_path):
70
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
71
+
72
+ f = open(vocab_json_path)
73
+ vocab_dict = json.load(f)
74
+ tokenizer = Wav2Vec2CTCTokenizer(vocab_json_path, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
75
+ feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
76
+ do_normalize=True, return_attention_mask=False)
77
+ labels = list(vocab_dict.keys())
78
+ # beam search
79
+ decoder = pyctcdecode.decoder.build_ctcdecoder(labels)
80
+ if (platform == "linux" or platform == "linux2") and kenlm_model_path:
81
+ # beam search + LM
82
+ decoder = pyctcdecode.decoder.build_ctcdecoder(labels, kenlm_model_path=kenlm_model_path)
83
+ self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
84
+ self.processor = Wav2Vec2ProcessorWithLM(feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder)
85
+
86
+ def get_l2_phoneme_sequence(self, audio):
87
+ input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
88
+ logits = self.model(input_dict.input_values.to(self.device)).logits.cpu().detach()
89
+ normalised_logits = torch.nn.Softmax(dim=2)(logits)
90
+ normalised_logits = normalised_logits.numpy()[0]
91
+ output = self.processor.decode(normalised_logits)
92
+ pred_phones = output.text.split(" ")
93
+ return pred_phones
94
+
95
+ def standardise_g2p_phoneme_sequence(self, phones):
96
+ return phones
97
+
98
+ def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
99
+ return [re.sub(r'\d', "", phone_str) for phone_str in phones]
100
+
101
+