import time
from string import punctuation
import epitran
import numpy as np
import torch
import ModelInterfaces as mi
import RuleBasedModels
import WordMatching as wm
import WordMetrics
import models as mo
from constants import app_logger, MODEL_NAME_DEFAULT, sample_rate_resample


def preprocessAudioStandalone(audio: torch.Tensor) -> torch.Tensor:
    """
    Normalize an audio tensor to zero mean and unit peak amplitude.
    Args:
        audio (torch.Tensor): The input audio tensor.
    Returns:
        torch.Tensor: The normalized audio tensor.
    """
    audio = audio - torch.mean(audio)
    audio = audio / torch.max(torch.abs(audio))
    return audio
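
# Illustrative example (an assumption, not part of the original module): the
# helper zero-means and peak-normalizes a waveform, e.g.:
#
#     x = torch.randn(1, 16000)                 # one second of noise at 16 kHz
#     y = preprocessAudioStandalone(x)
#     assert torch.isclose(y.abs().max(), torch.tensor(1.0))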


class PronunciationTrainer:
"""
A class to train and evaluate pronunciation accuracy using ASR and phoneme conversion models.
"""
current_transcript: str
current_ipa: str
current_recorded_audio: torch.Tensor
current_recorded_transcript: str
current_recorded_word_locations: list
    current_recorded_intonations: torch.Tensor
current_words_pronunciation_accuracy = []
categories_thresholds = np.array([80, 60, 59])
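    # Used by getPronunciationCategoryFromAccuracy via np.argmin(|threshold - accuracy|):
    # scores nearest 80 fall in category 0, nearest 60 in category 1, and scores
    # below roughly 59.5 (nearest 59) in category 2.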
sampling_rate = sample_rate_resample

    def __init__(self, asr_model: mi.IASRModel, word_to_ipa_converter: mi.ITextToPhonemModel) -> None:
        """
        Initialize the PronunciationTrainer with ASR and phoneme conversion models.
        Args:
            asr_model (mi.IASRModel): The ASR model to use.
            word_to_ipa_converter (mi.ITextToPhonemModel): The phoneme conversion model to use.
        """
        self.asr_model = asr_model
        self.ipa_converter = word_to_ipa_converter

    def getTranscriptAndWordsLocations(self, audio_length_in_samples: int) -> tuple[str, list]:
"""
Get the transcript and word locations from the ASR model.
Args:
audio_length_in_samples (int): The length of the audio in samples.
Returns:
tuple: A tuple containing the audio transcript and word locations in samples.
"""
audio_transcript = self.asr_model.getTranscript()
word_locations_in_samples = self.asr_model.getWordLocations()
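        # Pad each word's start/end timestamp by 50 ms of audio so that segment
        # boundaries do not clip word onsets and offsets.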
        fade_duration_in_samples = 0.05 * self.sampling_rate
word_locations_in_samples = [(int(np.maximum(0, word['start_ts']-fade_duration_in_samples)), int(np.minimum(
audio_length_in_samples-1, word['end_ts']+fade_duration_in_samples))) for word in word_locations_in_samples]
return audio_transcript, word_locations_in_samples
# def getWordsRelativeIntonation(self, Audio: torch.tensor, word_locations: list):
# intonations = torch.zeros((len(word_locations), 1))
# intonation_fade_samples = 0.3*self.sampling_rate
# app_logger.info(f"intonations.shape: {intonations.shape}.")
# for word in range(len(word_locations)):
# intonation_start = int(np.maximum(
# 0, word_locations[word][0]-intonation_fade_samples))
# intonation_end = int(np.minimum(
# Audio.shape[1]-1, word_locations[word][1]+intonation_fade_samples))
# intonations[word] = torch.sqrt(torch.mean(
# Audio[0][intonation_start:intonation_end]**2))
#
# intonations = intonations/torch.mean(intonations)
# return intonations

    ##################### ASR Functions ###########################
    def processAudioForGivenText(self, recordedAudio: torch.Tensor | None = None, real_text: str | None = None) -> dict:
"""
Process the recorded audio and evaluate pronunciation accuracy.
Args:
recordedAudio (torch.Tensor, optional): The recorded audio tensor. Defaults to None.
real_text (str, optional): The real text to compare against. Defaults to None.
Returns:
dict: A dictionary containing the evaluation results.
"""
start = time.time()
recording_transcript, recording_ipa, word_locations = self.getAudioTranscript(
recordedAudio)
time_transcript_audio = time.time() - start
app_logger.info(f'Time for NN to transcript audio: {time_transcript_audio:.2f}.')
start = time.time()
real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = self.matchSampleAndRecordedWords(
real_text, recording_transcript)
time_matching_transcripts = time.time() - start
app_logger.info(f'Time for matching transcripts: {time_matching_transcripts:.3f}.')
start_time, end_time = self.getWordLocationsFromRecordInSeconds(
word_locations, mapped_words_indices)
        pronunciation_accuracy, current_words_pronunciation_accuracy = self.getPronunciationAccuracy(
            real_and_transcribed_words)  # the _ipa word pairs could be scored here instead
pronunciation_categories = self.getWordsPronunciationCategory(
current_words_pronunciation_accuracy)
result = {'recording_transcript': recording_transcript,
'real_and_transcribed_words': real_and_transcribed_words,
'recording_ipa': recording_ipa, 'start_time': start_time, 'end_time': end_time,
'real_and_transcribed_words_ipa': real_and_transcribed_words_ipa, 'pronunciation_accuracy': pronunciation_accuracy,
'pronunciation_categories': pronunciation_categories}
return result

    def getAudioTranscript(self, recordedAudio: torch.Tensor | None = None) -> tuple[str, str, list]:
"""
Get the transcript and IPA representation of the recorded audio.
Args:
recordedAudio (torch.Tensor, optional): The recorded audio tensor. Defaults to None.
Returns:
tuple: A tuple containing the transcript, IPA representation, and word locations.
"""
current_recorded_audio = recordedAudio
current_recorded_audio = self.preprocessAudio(
current_recorded_audio)
self.asr_model.processAudio(current_recorded_audio)
current_recorded_transcript, current_recorded_word_locations = self.getTranscriptAndWordsLocations(
current_recorded_audio.shape[1])
current_recorded_ipa = self.ipa_converter.convertToPhonem(
current_recorded_transcript)
return current_recorded_transcript, current_recorded_ipa, current_recorded_word_locations

    def getWordLocationsFromRecordInSeconds(self, word_locations: list, mapped_words_indices: list) -> tuple[str, str]:
        """
        Get the start and end times of words in the recorded audio in seconds.
        Args:
            word_locations (list): The word locations in samples.
            mapped_words_indices (list): The indices of the mapped words.
        Returns:
            tuple[str, str]: Space-separated start times and end times, in seconds.
        """
app_logger.info(f"len_list: word_locations:{len(word_locations)}, mapped_words_indices:{len(mapped_words_indices)}, {len(word_locations) == len(mapped_words_indices)}...")
start_time = []
end_time = []
        for word_idx in range(len(mapped_words_indices)):
            start_time.append(float(word_locations[mapped_words_indices[word_idx]][0]) / self.sampling_rate)
            end_time.append(float(word_locations[mapped_words_indices[word_idx]][1]) / self.sampling_rate)
return ' '.join([str(time) for time in start_time]), ' '.join([str(time) for time in end_time])

    ##################### END ASR Functions ###########################

    ##################### Evaluation Functions ###########################
    def matchSampleAndRecordedWords(self, real_text: str, recorded_transcript: str) -> tuple[list, list, list]:
"""
Match the real text with the recorded transcript and get the IPA representations.
Args:
real_text (str): The real text to compare against.
recorded_transcript (str): The recorded transcript.
Returns:
tuple: A tuple containing the matched words, IPA representations, and mapped word indices.
"""
words_estimated = recorded_transcript.split()
try:
words_real = real_text.split()
except AttributeError:
raise ValueError("Real text is None, but should be a string.")
mapped_words, mapped_words_indices = wm.get_best_mapped_words(
words_estimated, words_real)
real_and_transcribed_words = []
real_and_transcribed_words_ipa = []
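        # If the recognizer produced fewer mapped words than the reference text,
        # pad with '-' so every reference word still gets a counterpart.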
for word_idx in range(len(words_real)):
            if word_idx >= len(mapped_words):
mapped_words.append('-')
real_and_transcribed_words.append(
(words_real[word_idx], mapped_words[word_idx]))
real_and_transcribed_words_ipa.append((self.ipa_converter.convertToPhonem(words_real[word_idx]),
self.ipa_converter.convertToPhonem(mapped_words[word_idx])))
return real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices

    def getPronunciationAccuracy(self, real_and_transcribed_words_ipa: list) -> tuple[float, list]:
        """
        Calculate the pronunciation accuracy based on the IPA representations.
        Args:
            real_and_transcribed_words_ipa (list): Tuples of (real, transcribed) IPA representations.
        Returns:
            tuple[float, list]: The overall percentage of correctly pronounced phonemes
            and the per-word accuracies.
        """
total_mismatches = 0.
number_of_phonemes = 0.
current_words_pronunciation_accuracy = []
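        # Per-word accuracy = (phonemes - edit distance) / phonemes * 100, computed
        # on lower-cased, punctuation-free strings; assumes each reference word is
        # non-empty after punctuation removal.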
for pair in real_and_transcribed_words_ipa:
real_without_punctuation = self.removePunctuation(pair[0]).lower()
number_of_word_mismatches = WordMetrics.edit_distance_python(
real_without_punctuation, self.removePunctuation(pair[1]).lower())
total_mismatches += number_of_word_mismatches
number_of_phonemes_in_word = len(real_without_punctuation)
number_of_phonemes += number_of_phonemes_in_word
current_words_pronunciation_accuracy.append(float(
number_of_phonemes_in_word-number_of_word_mismatches)/number_of_phonemes_in_word*100)
percentage_of_correct_pronunciations = (
number_of_phonemes-total_mismatches)/number_of_phonemes*100
return np.round(percentage_of_correct_pronunciations), current_words_pronunciation_accuracy

    def removePunctuation(self, word: str) -> str:
"""
Remove punctuation from a word.
Args:
word (str): The input word.
Returns:
str: The word without punctuation.
"""
return ''.join([char for char in word if char not in punctuation])

    def getWordsPronunciationCategory(self, accuracies: list) -> list:
"""
Get the pronunciation category for each word based on accuracy.
Args:
accuracies (list): A list of pronunciation accuracies.
Returns:
list: A list of pronunciation categories.
"""
categories = []
for accuracy in accuracies:
categories.append(
self.getPronunciationCategoryFromAccuracy(accuracy))
return categories

    def getPronunciationCategoryFromAccuracy(self, accuracy: float) -> int:
"""
Get the pronunciation category based on accuracy.
Args:
accuracy (float): The pronunciation accuracy.
Returns:
int: The pronunciation category.
"""
        return int(np.argmin(np.abs(self.categories_thresholds - accuracy)))

    def preprocessAudio(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Preprocess the audio by normalizing it.
        Args:
            audio (torch.Tensor): The input audio tensor.
        Returns:
            torch.Tensor: The normalized audio tensor.
        """
return preprocessAudioStandalone(audio)


def getTrainer(language: str, model_name: str = MODEL_NAME_DEFAULT) -> PronunciationTrainer:
"""
Get a PronunciationTrainer instance for the specified language and model.
Args:
language (str): The language of the model.
model_name (str, optional): The name of the model. Defaults to MODEL_NAME_DEFAULT.
Returns:
PronunciationTrainer: An instance of PronunciationTrainer.
"""
asr_model = mo.getASRModel(language, model_name=model_name)
if language == 'de':
phonem_converter = RuleBasedModels.EpitranPhonemConverter(epitran.Epitran('deu-Latn'))
elif language == 'en':
phonem_converter = RuleBasedModels.EngPhonemConverter()
else:
raise ValueError(f"Language '{language}' not implemented")
trainer = PronunciationTrainer(asr_model, phonem_converter)
return trainer
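

if __name__ == "__main__":
    # Minimal usage sketch (illustrative assumption, not part of the original
    # module): build an English trainer and score a recording against a
    # reference sentence. 'sample.wav' is a hypothetical mono recording already
    # at the expected sample rate (sample_rate_resample); torchaudio is assumed
    # to be available alongside torch.
    import torchaudio

    trainer = getTrainer('en')
    waveform, _sr = torchaudio.load('sample.wav')
    scores = trainer.processAudioForGivenText(waveform, real_text='hello world')
    print(scores['pronunciation_accuracy'])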