import json
import re
from sys import platform

import torch
import pyctcdecode
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
)

from wav2vecasr.models import MultiTaskWav2Vec2

class PhonemeASRModel:
    """Base interface for phoneme-level ASR models used in mispronunciation detection."""

    def get_l2_phoneme_sequence(self, audio):
        """
        :param audio: audio loaded with torchaudio at a 16 kHz sampling rate
        :type audio: array
        :return: predicted phonemes for the L2 speaker
        :rtype: array
        """
        raise NotImplementedError

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        Standardises G2P output to facilitate mispronunciation detection.

        :param phones: native-speaker phones predicted by a G2P model
        :type phones: array
        :return: standardised native-speaker phoneme sequence that aligns with the model's phoneme classes
        :rtype: array
        """
        raise NotImplementedError

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        Standardises ground-truth annotations to facilitate testing.

        :param phones: native-speaker phones as annotated in L2-ARCTIC
        :type phones: array
        :return: standardised native-speaker phoneme sequence that aligns with the model's phoneme classes
        :rtype: array
        """
        raise NotImplementedError
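

# A minimal sketch (not part of the original interface) of how these three methods
# compose for mispronunciation detection: standardise the G2P reference, predict the
# L2 phones, align the two sequences, and report the mismatches. difflib is used
# here purely for illustration; any edit-distance alignment would work.
def detect_mispronunciations_sketch(asr_model, audio, g2p_phones):
    import difflib

    reference = asr_model.standardise_g2p_phoneme_sequence(g2p_phones)
    hypothesis = asr_model.get_l2_phoneme_sequence(audio)
    matcher = difflib.SequenceMatcher(a=reference, b=hypothesis, autojunk=False)
    errors = []
    for op, i1, i2, j1, j2 in matcher.get_opcodes():
        if op != "equal":
            # (edit type, expected phones, phones actually produced)
            errors.append((op, reference[i1:i2], hypothesis[j1:j2]))
    return errors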


class MultitaskPhonemeASRModel(PhonemeASRModel):
    """
    Multi-task wav2vec 2.0 model that jointly predicts phonemes and an accent class;
    only the phoneme (CTC) head is used here.
    """

    def __init__(self, model_path, best_model_vocab_path, device):
        self.device = device
        tokenizer = Wav2Vec2CTCTokenizer(
            best_model_vocab_path, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
        )
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
        wav2vec2_backbone = Wav2Vec2ForCTC.from_pretrained(
            pretrained_model_name_or_path="facebook/wav2vec2-xls-r-300m",
            ignore_mismatched_sizes=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
            output_hidden_states=True,
        )
        wav2vec2_backbone = wav2vec2_backbone.to(device)
        model = MultiTaskWav2Vec2(
            wav2vec2_backbone=wav2vec2_backbone,
            backbone_hidden_size=1024,
            projection_hidden_size=256,
            num_accent_class=3,
        )
        model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
        model.to(device)
        model.eval()
        self.multitask_model = model
        self.processor = processor

    def get_l2_phoneme_sequence(self, audio):
        audio = audio.unsqueeze(0)
        audio = self.processor(audio, sampling_rate=16000).input_values[0]
        audio = torch.tensor(audio, device=self.device)
        with torch.no_grad():
            _, lm_logits, _, _ = self.multitask_model(audio)
            lm_preds = torch.argmax(lm_logits, dim=-1)
        # Decode CTC output into a phoneme string, then split into individual phones
        pred_decoded = self.processor.batch_decode(lm_preds)
        pred_phones = pred_decoded[0].split(" ")
        # Drop silence markers ("sil") and short pauses ("sp")
        pred_phones = [phone for phone in pred_phones if phone not in ("sil", "sp")]
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        return phones
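

# Example usage (a sketch; the checkpoint and vocab paths below are hypothetical):
#
#     import torchaudio
#
#     audio, sr = torchaudio.load("l2_speaker.wav")
#     audio = torchaudio.functional.resample(audio, sr, 16000)[0]
#     asr = MultitaskPhonemeASRModel("multitask_model.pt", "vocab.json", "cpu")
#     print(asr.get_l2_phoneme_sequence(audio))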


class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
    """
    Wav2vec 2.0 CTC model with greedy decoding.
    """

    def __init__(self, model_path, processor_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2Processor.from_pretrained(processor_path)

    def get_l2_phoneme_sequence(self, audio):
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = self.model(input_dict.input_values.to(self.device)).logits
        pred_ids = torch.argmax(logits, dim=-1)[0]
        # Decode frame by frame so multi-character phoneme tokens stay separate,
        # then drop the empty strings produced by pad/blank frames
        pred_phones = [phoneme for phoneme in self.processor.batch_decode(pred_ids) if phoneme != ""]
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        # Strip stress markers, e.g. "AH0" -> "AH"
        return [re.sub(r"\d", "", phone_str) for phone_str in phones]
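

# Example usage (a sketch; the model and processor paths are hypothetical, and
# `audio` is a 1-D waveform sampled at 16 kHz as in the docstring above):
#
#     asr = Wav2Vec2PhonemeASRModel("path/to/model", "path/to/processor")
#     pred_phones = asr.get_l2_phoneme_sequence(audio)
#     # Stress digits are stripped from L2-ARCTIC annotations before comparison:
#     asr.standardise_l2_artic_groundtruth_phoneme_sequence(["AH0", "L", "OW1"])  # ["AH", "L", "OW"]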


# TODO: debug on Linux, since KenLM is not supported on Windows
class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
    """
    Wav2vec 2.0 CTC model decoded with beam search and, where available, a KenLM language model.
    """

    def __init__(self, model_path, vocab_json_path, kenlm_model_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        with open(vocab_json_path) as f:
            vocab_dict = json.load(f)
        tokenizer = Wav2Vec2CTCTokenizer(
            vocab_json_path, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
        )
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )
        # Order the labels by vocabulary index so they line up with the logit dimensions
        labels = [token for token, _ in sorted(vocab_dict.items(), key=lambda item: item[1])]
        # Plain beam search by default
        decoder = pyctcdecode.decoder.build_ctcdecoder(labels)
        if (platform == "linux" or platform == "linux2") and kenlm_model_path:
            # Beam search with a KenLM language model
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels, kenlm_model_path=kenlm_model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM(
            feature_extractor=feature_extractor, tokenizer=tokenizer, decoder=decoder
        )

    def get_l2_phoneme_sequence(self, audio):
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = self.model(input_dict.input_values.to(self.device)).logits.cpu()
        # Normalise logits to per-frame probabilities before handing them to the decoder
        normalised_logits = torch.nn.functional.softmax(logits, dim=-1)
        normalised_logits = normalised_logits.numpy()[0]
        output = self.processor.decode(normalised_logits)
        pred_phones = output.text.split(" ")
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        # Strip stress markers, e.g. "AH0" -> "AH"
        return [re.sub(r"\d", "", phone_str) for phone_str in phones]
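

# Example usage (a sketch; every path below is hypothetical, and the KenLM model
# only takes effect on Linux, per the platform check in __init__):
#
#     asr = Wav2Vec2OptimisedPhonemeASRModel(
#         "path/to/model", "path/to/vocab.json", "path/to/phoneme_lm.arpa"
#     )
#     pred_phones = asr.get_l2_phoneme_sequence(audio)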