Spaces:
Runtime error
Runtime error
File size: 6,218 Bytes
22efacc 83aa7fc 22efacc 83aa7fc 22efacc 83aa7fc 22efacc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import torch
from wav2vecasr.models import MultiTaskWav2Vec2
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, \
Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
import pyctcdecode
import json
import re
from sys import platform
class PhonemeASRModel:
    """
    Common interface of the phoneme recognition models.

    Every method here is a placeholder that returns ``None``; subclasses are
    expected to override them.
    """

    def get_l2_phoneme_sequence(self, audio):
        """
        Predict the phonemes spoken by an L2 speaker.

        :param audio: audio loaded with torchaudio at a 16k sampling rate
        :type audio: array
        :return: phonemes predicted for the L2 speaker
        :rtype: array
        """
        return None

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        Standardise G2P output so that mispronunciation detection can compare it
        with the model predictions.

        :param phones: native speaker phones predicted by the G2P model
        :type phones: array
        :return: native speaker phoneme sequence mapped onto the phoneme classes of the model
        :rtype: array
        """
        return None

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        Standardise L2-ARCTIC annotations so that they can be used for testing.

        :param phones: native speaker phones as annotated in L2-ARCTIC
        :type phones: array
        :return: native speaker phoneme sequence mapped onto the phoneme classes of the model
        :rtype: array
        """
        return None
class MultitaskPhonemeASRModel(PhonemeASRModel):
    """
    Phoneme recogniser built on ``MultiTaskWav2Vec2``; predictions are decoded
    greedily (argmax over the phoneme logits).
    """

    def __init__(self, model_path, best_model_vocab_path, device):
        """
        Build the processor and the multitask model, then load the trained weights.

        :param model_path: path of the saved state dict of the multitask model
        :type model_path: str
        :param best_model_vocab_path: path of the vocabulary file used by the CTC tokenizer
        :type best_model_vocab_path: str
        :param device: device the model and its inputs are moved to
        :type device: str or torch.device
        """
        self.device = device
        tokenizer = Wav2Vec2CTCTokenizer(
            best_model_vocab_path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            word_delimiter_token="|",
        )
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        # The backbone starts from the public XLS-R checkpoint; the trained
        # weights are loaded from ``model_path`` further below.
        # NOTE(review): ignore_mismatched_sizes is presumably needed because the
        # CTC head is sized to this tokenizer's vocabulary - confirm.
        wav2vec2_backbone = Wav2Vec2ForCTC.from_pretrained(
            pretrained_model_name_or_path="facebook/wav2vec2-xls-r-300m",
            ignore_mismatched_sizes=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
            output_hidden_states=True,
        )
        wav2vec2_backbone = wav2vec2_backbone.to(device)

        model = MultiTaskWav2Vec2(
            wav2vec2_backbone=wav2vec2_backbone,
            backbone_hidden_size=1024,
            projection_hidden_size=256,
            num_accent_class=3,
        )
        # Weights are deserialised onto the CPU first and only then moved to
        # the requested device.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.to(device)
        model.eval()

        self.multitask_model = model
        self.processor = processor

    def get_l2_phoneme_sequence(self, audio):
        """
        Predict the phonemes spoken in ``audio``.

        :param audio: audio sampled at 16k sampling rate with torchaudio
            (a tensor; ``unsqueeze`` is called on it)
        :type audio: torch.Tensor
        :return: predicted phonemes without ``sil``/``sp`` markers and without empty tokens
        :rtype: list
        """
        audio = audio.unsqueeze(0)
        audio = self.processor(audio, sampling_rate=16000).input_values[0]
        audio = torch.tensor(audio, device=self.device)

        with torch.no_grad():
            # The model returns four values; only the second one (the phoneme
            # logits) is used here.
            _, lm_logits, _, _ = self.multitask_model(audio)
            lm_preds = torch.argmax(lm_logits, dim=-1)

        # Decode output results
        pred_decoded = self.processor.batch_decode(lm_preds)

        # Remove sil and sp. Empty strings are dropped as well: splitting an
        # empty transcription on " " yields [""], and doubled spaces yield
        # empty tokens, neither of which is a phoneme.
        ignored = {"", "sil", "sp"}
        return [phone for phone in pred_decoded[0].split(" ") if phone not in ignored]

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        Return the G2P phones unchanged.

        :param phones: native speaker phones predicted by the G2P model
        :type phones: array
        :return: the same sequence that was passed in
        :rtype: array
        """
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        Return the L2-ARCTIC phones unchanged.

        :param phones: native speaker phones as annotated in L2-ARCTIC
        :type phones: array
        :return: the same sequence that was passed in
        :rtype: array
        """
        return phones
class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
    """
    Uses greedy decoding
    """

    def __init__(self, model_path, processor_path):
        """
        Load the CTC model and its processor.

        :param model_path: name or path passed to ``Wav2Vec2ForCTC.from_pretrained``
        :type model_path: str
        :param processor_path: name or path passed to ``Wav2Vec2Processor.from_pretrained``
        :type processor_path: str
        """
        # CUDA is used whenever it is available, otherwise the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2Processor.from_pretrained(processor_path)

    def get_l2_phoneme_sequence(self, audio):
        """
        Predict the phonemes spoken in ``audio`` by taking the argmax of the logits.

        :param audio: audio sampled at 16k sampling rate with torchaudio
        :type audio: array
        :return: non-empty decoded phonemes
        :rtype: list
        """
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = self.model(input_dict.input_values.to(self.device)).logits

        # Only the first item of the batch is decoded.
        pred_ids = torch.argmax(logits, dim=-1)[0]

        # NOTE(review): batch_decode receives the ids of a single item, so each
        # id appears to be decoded on its own and repeated predictions on
        # consecutive frames are not merged - confirm this is intended.
        pred_phones = [phoneme for phoneme in self.processor.batch_decode(pred_ids) if phoneme != ""]
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        Return the G2P phones unchanged.

        :param phones: native speaker phones predicted by the G2P model
        :type phones: array
        :return: the same sequence that was passed in
        :rtype: array
        """
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        Remove every digit from the annotated phones.

        :param phones: native speaker phones as annotated in L2-ARCTIC
        :type phones: array
        :return: the phones with all digit characters removed
        :rtype: list
        """
        return [re.sub(r'\d', "", phone_str) for phone_str in phones]
# TODO: debug on Linux, because KenLM is not supported on Windows.
class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
    """
    Uses beam search and a LM for decoding
    """

    def __init__(self, model_path, vocab_json_path, kenlm_model_path):
        """
        Load the CTC model and build a processor whose decoder performs beam search.

        :param model_path: name or path passed to ``Wav2Vec2ForCTC.from_pretrained``
        :type model_path: str
        :param vocab_json_path: path of the JSON vocabulary (token -> id)
        :type vocab_json_path: str
        :param kenlm_model_path: path of a KenLM model; only used on Linux and
            when it is not empty/None
        :type kenlm_model_path: str or None
        """
        # CUDA is used whenever it is available, otherwise the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # The context manager closes the file handle; JSON is UTF-8 by
        # specification, so do not depend on the platform default encoding.
        with open(vocab_json_path, encoding="utf-8") as vocab_file:
            vocab_dict = json.load(vocab_file)

        tokenizer = Wav2Vec2CTCTokenizer(
            vocab_json_path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            word_delimiter_token="|",
        )
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )

        # The decoder pairs label i with logit column i, so order the labels by
        # token id instead of trusting the key order of the JSON file.
        labels = sorted(vocab_dict, key=vocab_dict.get)

        # Build the decoder once: beam search + LM when KenLM can be used,
        # plain beam search otherwise.
        if (platform == "linux" or platform == "linux2") and kenlm_model_path:
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels, kenlm_model_path=kenlm_model_path)
        else:
            decoder = pyctcdecode.decoder.build_ctcdecoder(labels)

        self.model = Wav2Vec2ForCTC.from_pretrained(model_path).to(self.device)
        self.processor = Wav2Vec2ProcessorWithLM(
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
            decoder=decoder,
        )

    def get_l2_phoneme_sequence(self, audio):
        """
        Predict the phonemes spoken in ``audio`` with the processor's decoder.

        :param audio: audio sampled at 16k sampling rate with torchaudio
        :type audio: array
        :return: non-empty decoded phonemes
        :rtype: list
        """
        input_dict = self.processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = self.model(input_dict.input_values.to(self.device)).logits.cpu().detach()

        # Softmax over the last of the three dimensions; only the first item of
        # the batch is handed to the decoder.
        normalised_logits = torch.nn.Softmax(dim=2)(logits)
        normalised_logits = normalised_logits.numpy()[0]
        output = self.processor.decode(normalised_logits)

        # Splitting an empty transcription on " " yields [""], and doubled
        # spaces yield empty tokens; neither is a phoneme, so drop them.
        pred_phones = [phone for phone in output.text.split(" ") if phone]
        return pred_phones

    def standardise_g2p_phoneme_sequence(self, phones):
        """
        Return the G2P phones unchanged.

        :param phones: native speaker phones predicted by the G2P model
        :type phones: array
        :return: the same sequence that was passed in
        :rtype: array
        """
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """
        Remove every digit from the annotated phones.

        :param phones: native speaker phones as annotated in L2-ARCTIC
        :type phones: array
        :return: the phones with all digit characters removed
        :rtype: list
        """
        return [re.sub(r'\d', "", phone_str) for phone_str in phones]
|