"""Wrappers around wav2vec 2.0 phoneme-level ASR models used for
mispronunciation detection on L2-ARCTIC."""

import json
import re
from sys import platform

import pyctcdecode
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, \
    Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor

from wav2vecasr.models import MultiTaskWav2Vec2


class PhonemeASRModel:
    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio sampled at a 16 kHz sampling rate with torchaudio
        :type audio: array
        :return: predicted phonemes for the L2 speaker
        :rtype: array
        """
        pass

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        To facilitate mispronunciation detection.

        :param phones: native speaker phones predicted by a G2P model
        :type phones: array
        :return: standardised native speaker phoneme sequence that aligns with the phoneme classes used by the model
        :rtype: array
        """
        pass

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        To facilitate testing.

        :param phones: native speaker phones as annotated in L2-ARCTIC
        :type phones: array
        :return: standardised native speaker phoneme sequence that aligns with the phoneme classes used by the model
        :rtype: array
        """
        pass


class MultitaskPhonemeASRModel(PhonemeASRModel):
    def __init__(self, model_path, best_model_vocab_path, device):
        self.device = device

        # Build the processor: a CTC tokenizer over the phoneme vocabulary plus
        # a feature extractor for raw 16 kHz waveforms
        tokenizer = Wav2Vec2CTCTokenizer(best_model_vocab_path, unk_token="[UNK]", pad_token="[PAD]",
                                         word_delimiter_token="|")
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        # Wrap the XLS-R backbone in the multitask head and load the fine-tuned weights
        wav2vec2_backbone = Wav2Vec2ForCTC.from_pretrained(
            pretrained_model_name_or_path="facebook/wav2vec2-xls-r-300m",
            ignore_mismatched_sizes=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
            output_hidden_states=True,
        )
        wav2vec2_backbone = wav2vec2_backbone.to(device)
        model = MultiTaskWav2Vec2(
            wav2vec2_backbone=wav2vec2_backbone,
            backbone_hidden_size=1024,
            projection_hidden_size=256,
            num_accent_class=3,
        )
        model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        model.to(device)
        model.eval()

        self.multitask_model = model
        self.processor = processor

    def get_l2_phoneme_sequence(self, audio):
        audio = audio.unsqueeze(0)
        audio = self.processor(audio, sampling_rate=16000).input_values[0]
        audio = torch.tensor(audio, device=self.device)
        with torch.no_grad():
            _, lm_logits, _, _ = self.multitask_model(audio)
        lm_preds = torch.argmax(lm_logits, dim=-1)

        # Decode the predicted token ids into phoneme strings
        pred_decoded = self.processor.batch_decode(lm_preds)
        pred_phones = pred_decoded[0].split(" ")

        # Remove the silence markers "sil" and "sp"
        pred_phones = [phone for phone in pred_phones if phone != "sil" and phone != "sp"]

        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        return phones


class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
    """
    Uses greedy decoding.
    """

    def __init__(self, model_path, processor_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2Processor.from_pretrained(processor_path)

    def get_l2_phoneme_sequence(self, audio):
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        logits = self.model(input_dict.input_values.to(self.device)).logits
        pred_ids = torch.argmax(logits, dim=-1)[0]
        # Drop the empty strings left behind by pad/blank tokens
        pred_phones = [phoneme for phoneme in self.processor.batch_decode(pred_ids) if phoneme != ""]
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        # Strip lexical stress digits, e.g. "AH0" -> "AH"
        return [re.sub(r"\d", "", phone_str) for phone_str in phones]


# TODO: debug on Linux, since KenLM is not supported on Windows
class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
    """
    Uses beam search and an LM for decoding.
    """

    def __init__(self, model_path, vocab_json_path, kenlm_model_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        with open(vocab_json_path) as f:
            vocab_dict = json.load(f)
        tokenizer = Wav2Vec2CTCTokenizer(vocab_json_path, unk_token="[UNK]", pad_token="[PAD]",
                                         word_delimiter_token="|")
        feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0,
                                                     do_normalize=True, return_attention_mask=False)
        labels = list(vocab_dict.keys())

        # Plain beam search by default
        decoder = pyctcdecode.decoder.build_ctcdecoder(labels)
        if (platform == "linux" or platform == "linux2") and kenlm_model_path:
            # Beam search with a KenLM language model (Linux only)
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels, kenlm_model_path=kenlm_model_path)

        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM(feature_extractor=feature_extractor, tokenizer=tokenizer,
                                                 decoder=decoder)

    def get_l2_phoneme_sequence(self, audio):
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        logits = self.model(input_dict.input_values.to(self.device)).logits.cpu().detach()
        # pyctcdecode expects per-frame probabilities, so normalise the logits first
        normalised_logits = torch.nn.Softmax(dim=2)(logits)
        normalised_logits = normalised_logits.numpy()[0]
        output = self.processor.decode(normalised_logits)
        pred_phones = output.text.split(" ")
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        # Strip lexical stress digits, e.g. "AH0" -> "AH"
        return [re.sub(r"\d", "", phone_str) for phone_str in phones]