bel32123 committed on
Commit
83aa7fc
1 Parent(s): 2676061

Include MultitaskASRModel

Browse files
Files changed (1) hide show
  1. wav2vecasr/PhonemeASRModel.py +64 -1
wav2vecasr/PhonemeASRModel.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, \
3
  Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
4
  import pyctcdecode
@@ -6,6 +7,7 @@ import json
6
  import re
7
  from sys import platform
8
 
 
9
  class PhonemeASRModel:
10
  def get_l2_phoneme_sequence(self, audio):
11
  """
@@ -38,6 +40,68 @@ class PhonemeASRModel:
38
  """
39
  pass
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
42
  """
43
  Uses greedy decoding
@@ -98,4 +162,3 @@ class Wav2Vec2OptimisedPhonemeASRModel(PhonemeASRModel):
98
  def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
99
  return [re.sub(r'\d', "", phone_str) for phone_str in phones]
100
 
101
-
 
1
  import torch
2
+ from wav2vecasr.models import MultiTaskWav2Vec2
3
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM, \
4
  Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
5
  import pyctcdecode
 
7
  import re
8
  from sys import platform
9
 
10
+
11
  class PhonemeASRModel:
12
  def get_l2_phoneme_sequence(self, audio):
13
  """
 
40
  """
41
  pass
42
 
43
class MultitaskPhonemeASRModel(PhonemeASRModel):
    """
    Phoneme recogniser built on a MultiTaskWav2Vec2 model, using greedy decoding.

    A Wav2Vec2ForCTC backbone is created from the "facebook/wav2vec2-xls-r-300m"
    checkpoint, wrapped in MultiTaskWav2Vec2, and the state dict stored at
    ``model_path`` is then loaded into the wrapped model.
    """

    # Sampling rate (Hz) used to configure the feature extractor and passed to the processor.
    SAMPLING_RATE = 16000
    # Decoded tokens removed from the prediction (presumably silence / short-pause markers).
    NON_PHONE_TOKENS = ("sil", "sp")

    def __init__(self, model_path, best_model_vocab_path, device):
        """
        Build the processor and the multitask model and put the model in eval mode.

        :param model_path: path to the saved state dict for the MultiTaskWav2Vec2 model
        :param best_model_vocab_path: path to the vocabulary file given to Wav2Vec2CTCTokenizer
        :param device: torch device the backbone, the model and the inputs are moved to
        """
        self.device = device

        tokenizer = Wav2Vec2CTCTokenizer(
            best_model_vocab_path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            word_delimiter_token="|",
        )
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=self.SAMPLING_RATE,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        # The CTC head is sized to this tokenizer's vocabulary, hence ignore_mismatched_sizes.
        wav2vec2_backbone = Wav2Vec2ForCTC.from_pretrained(
            pretrained_model_name_or_path="facebook/wav2vec2-xls-r-300m",
            ignore_mismatched_sizes=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            vocab_size=len(processor.tokenizer),
            output_hidden_states=True,
        )
        wav2vec2_backbone = wav2vec2_backbone.to(device)

        model = MultiTaskWav2Vec2(
            wav2vec2_backbone=wav2vec2_backbone,
            backbone_hidden_size=1024,
            projection_hidden_size=256,
            num_accent_class=3,
        )

        # Weights are first mapped to CPU and the model is moved to `device` afterwards.
        # NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.to(device)
        model.eval()

        self.multitask_model = model
        self.processor = processor

    def get_l2_phoneme_sequence(self, audio):
        """
        Predict the phoneme sequence for one audio input.

        :param audio: tensor exposing ``unsqueeze``; assumed to be a single waveform
            sampled at 16 kHz - TODO confirm against callers
        :return: list of predicted phone strings, without "sil" / "sp" and without empty tokens
        """
        audio = audio.unsqueeze(0)
        audio = self.processor(audio, sampling_rate=self.SAMPLING_RATE).input_values[0]
        audio = torch.tensor(audio, device=self.device)

        with torch.no_grad():
            # The model returns four values; only the second (the phoneme logits) is used here.
            _, lm_logits, _, _ = self.multitask_model(audio)
            lm_preds = torch.argmax(lm_logits, dim=-1)

        # Decode output results
        pred_decoded = self.processor.batch_decode(lm_preds)

        # Splitting on " " yields "" for an empty transcription or repeated spaces,
        # so empty tokens are dropped together with the sil / sp markers.
        return [
            phone
            for phone in pred_decoded[0].split(" ")
            if phone and phone not in self.NON_PHONE_TOKENS
        ]

    def standardise_g2p_phoneme_sequence(self, phones):
        """Return ``phones`` unchanged."""
        return phones

    def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
        """Return ``phones`` unchanged."""
        return phones
104
+
105
  class Wav2Vec2PhonemeASRModel(PhonemeASRModel):
106
  """
107
  Uses greedy decoding
 
162
  def standardise_l2_artic_groundtruth_phoneme_sequence(self, phones):
163
  return [re.sub(r'\d', "", phone_str) for phone_str in phones]
164