File size: 5,539 Bytes
6d3dc99
 
 
 
 
22efacc
 
6d3dc99
 
 
 
22efacc
6d3dc99
22efacc
 
6d3dc99
 
 
 
 
 
 
22efacc
 
 
 
 
 
 
 
 
 
 
6d3dc99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22efacc
6d3dc99
 
 
 
 
 
 
 
 
 
 
 
 
22efacc
6d3dc99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from pandas.core.construction import T
import torch
import jiwer

class MispronounciationDetector:
  """Detect mispronunciations by aligning an L2 speaker's recognised phones with G2P reference phones."""

  def __init__(self, l2_phoneme_recogniser, g2p, device):
    """
    :param l2_phoneme_recogniser: phoneme recogniser (noted as a PhonemeASRModel instance); must provide
        get_l2_phoneme_sequence(audio) and standardise_g2p_phoneme_sequence(phones)
    :param g2p: callable mapping text to a native-speaker phoneme sequence
    :param device: stored on the instance; not used by the methods of this class
    """
    self.phoneme_asr_model = l2_phoneme_recogniser # PhonemeASRModel class
    self.g2p = g2p
    self.device = device

  def detect(self, audio, text):
    """
    Run the full pipeline for one utterance.

    :param audio: recording of the L2 speaker, passed unchanged to the phoneme recogniser
    :param text: the text the speaker read
    :return: the dictionary produced by get_mispronounciation_output
    """
    l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
    native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
    standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
    raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones)
    return raw_info

  def get_native_speaker_phoneme_sequence(self, text):
    """Return the expected (native speaker) phoneme sequence for text, as produced by self.g2p."""
    phonemes = self.g2p(text)
    return phonemes

  def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
    """
    Align the L2 speaker's predicted phones with the expected native-speaker phones to find errors.

    :param text: original words read by the user
    :type text: string
    :param pred_phones: phonemes predicted for the L2 speaker by the ASR model
    :type pred_phones: array
    :param org_label_phones: correct native-speaker phonemes from G2P, with a " " element separating
        the phonemes of consecutive words
    :type org_label_phones: array
    :return: dictionary with
        - "ref", "hyp", "phoneme_errors": parallel, width-aligned display lists. Each has "|" right after
          the last phone of every word and once more at the very end. A "phoneme_errors" entry ends in
          "s" (substituted), "d" (deleted) or "a" (added/inserted), or is all "-" (correct).
        - "per": jiwer's word error rate computed over the phone tokens (phone error rate)
        - "words": text.split(" ")
        - "word_errors": one bool per "|"-delimited group of phoneme errors
        - "wer": sum(word_errors) / len(words)
    :rtype: dictionary

    NOTE(review): phones are joined with " " and jiwer's chunk indices are used to index back into the
    phone lists. This assumes no phone is empty or contains whitespace — TODO confirm.
    NOTE(review): word groups follow the G2P " " separators, while len(words) comes from text.split(" ").
    The two can disagree (punctuation, repeated spaces) — confirm.
    """
    # Phone error rate: jiwer aligns the space-separated phone tokens as if they were words.
    label_phones = [phone for phone in org_label_phones if phone != " "]
    reference = " ".join(label_phones)
    hypothesis = " ".join(pred_phones)
    res = jiwer.process_words(reference, hypothesis)
    per = res.wer

    # Build three parallel lists, one entry per alignment step.
    ref, hyp, error_bool = [], [], []
    # label_positions[j] is the index in ref/hyp/error_bool where label_phones[j] was placed.
    # It is used to put word delimiters by position rather than by re-matching phone strings.
    label_positions = []
    for chunk in res.alignments[0]:
      ref_start, ref_end = chunk.ref_start_idx, chunk.ref_end_idx
      hyp_start, hyp_end = chunk.hyp_start_idx, chunk.hyp_end_idx
      if chunk.type == "equal":
        for label_phone in label_phones[ref_start:ref_end]:
          label_positions.append(len(ref))
          ref.append(label_phone)
          error_bool.append("-" * len(label_phone))
        hyp.extend(pred_phones[hyp_start:hyp_end])
      elif chunk.type == "insert":
        # Extra phone spoken by the L2 speaker: no reference phone, marked "a" (added).
        for pred_phone in pred_phones[hyp_start:hyp_end]:
          ref.append("*" * len(pred_phone))
          hyp.append(pred_phone)
          error_bool.append(" " * (len(pred_phone) - 1) + "a")
      elif chunk.type == "delete":
        for label_phone in label_phones[ref_start:ref_end]:
          label_positions.append(len(ref))
          ref.append(label_phone)
          hyp.append("*" * len(label_phone))
          error_bool.append(" " * (len(label_phone) - 1) + "d")
      else:
        # Substitution: pad the shorter phone on the left so ref and hyp columns line up.
        for correct_phone, pred_phone in zip(label_phones[ref_start:ref_end], pred_phones[hyp_start:hyp_end]):
          width = max(len(correct_phone), len(pred_phone))
          label_positions.append(len(ref))
          ref.append(correct_phone.rjust(width))
          hyp.append(pred_phone.rjust(width))
          error_bool.append(" " * (width - 1) + "s")

    # Insert word delimiters to show the user phoneme sections by word.
    self._insert_word_delimiters(org_label_phones, label_positions, ref, hyp, error_bool)

    # A word is mispronounced if any of its phones has an error.
    words = text.split(" ")
    word_error_bool = self.get_mispronounced_words(error_bool)
    wer = sum(word_error_bool) / len(words)

    # Close the last word in every display list.
    ref.append("|")
    hyp.append("|")
    error_bool.append("|")

    raw_info = {"ref":ref, "hyp": hyp, "per":per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors":word_error_bool}

    return raw_info

  @staticmethod
  def _insert_word_delimiters(org_label_phones, label_positions, *parallel_lists):
    """
    Insert "|" into every list in parallel_lists right after the last phone of each word.

    :param org_label_phones: G2P phones with " " elements marking word boundaries
    :param label_positions: index in the parallel lists of each non-" " phone, in order
    :param parallel_lists: equally long lists, modified in place
    """
    boundaries = []
    phones_seen = 0
    for phone in org_label_phones:
      if phone == " ":
        # Delimiter goes right after the previous word's last phone (or at the front if none yet).
        boundaries.append(label_positions[phones_seen - 1] + 1 if phones_seen else 0)
      else:
        phones_seen += 1
    # Insert from the back so earlier insertion points are not shifted by later insertions.
    for boundary in reversed(boundaries):
      for lst in parallel_lists:
        lst.insert(boundary, "|")

  def get_mispronounced_words(self, phoneme_error_bool):
    """
    Map phoneme-level error markers back to words.

    :param phoneme_error_bool: phoneme error markers with "|" between words (a trailing "|" is not needed).
        The list is not modified.
    :return: one bool per word, True if any phone in that word is marked "s", "d" or "a"
    """
    # Work on a copy: appending the closing delimiter must not leak into the caller's list.
    word_phones = self.split_lst_by_delim(list(phoneme_error_bool) + ["|"], "|")
    error_markers = ("s", "d", "a")
    return [any(marker in phones for marker in error_markers) for phones in word_phones]

  def split_lst_by_delim(self, lst, delimiter):
    """
    Split lst into sublists of whitespace-stripped items, cutting at every delimiter element.

    Items after the last delimiter are dropped, so callers must end lst with the delimiter
    to keep the final group.
    """
    temp = []
    res = []
    for item in lst:
      if item != delimiter:
        temp.append(item.strip())
      else:
        res.append(temp)
        temp = []
    return res