|
"""Wav2Vec 2.0 speech recognizer fine-tuned with CTC, as a LightningModule."""

from typing import Optional, Tuple

import torch
from pytorch_lightning import LightningModule
from torchmetrics import MeanMetric
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForPreTraining,
)

from src.utils.metrics import character_error_rate, word_error_rate
from src.utils.scheduler import TriStateScheduler

|
class SpeechRecognizer(LightningModule):
    """Wav2Vec 2.0 backbone with a small classifier head, trained with CTC loss."""

    def __init__(
        self,
        wav2vec2: Wav2Vec2ForPreTraining,
        tokenizer: Wav2Vec2CTCTokenizer,
        feature_extractor: Wav2Vec2FeatureExtractor,
        adam_config: dict,
        tristate_scheduler_config: dict,
    ):
        super().__init__()

        # The backbone's first output is the projected states, so the head's
        # input width is proj_codevector_dim rather than hidden_size.
        self.hidden_size = wav2vec2.config.proj_codevector_dim
        self.vocab_size = tokenizer.vocab_size

        self.wav2vec2 = wav2vec2
        # Keep the convolutional feature encoder frozen during fine-tuning.
        self.wav2vec2.freeze_feature_encoder()
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor

        self.adam_config = adam_config
        self.tristate_scheduler_config = tristate_scheduler_config

        # Two-layer classifier head mapping backbone features to vocabulary logits.
        self.dropout = torch.nn.Dropout(0.1)
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size // 2),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(self.hidden_size // 2, self.vocab_size),
        )

        # The pad token doubles as the CTC blank; zero_infinity guards against
        # degenerate alignments producing infinite losses.
        self.criterion = torch.nn.CTCLoss(blank=tokenizer.pad_token_id, zero_infinity=True)

        self.train_loss = MeanMetric()

        self.save_hyperparameters(ignore=["wav2vec2", "tokenizer", "feature_extractor"])
|
    def forward(
        self,
        waveforms: Tuple[torch.Tensor, ...],
        transcripts: Optional[Tuple[str, ...]] = None,
    ):
        # The feature extractor expects raw arrays, so move waveforms off-device.
        waveforms = tuple(waveform.cpu().numpy() for waveform in waveforms)

        features = self.feature_extractor(
            waveforms,
            sampling_rate=16000,
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        input_values = features["input_values"].to(self.device)
        attention_mask = features["attention_mask"].to(self.device)

        # The first output of Wav2Vec2ForPreTraining is the projected hidden states.
        hidden_states = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
        )[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.fc(hidden_states)

        # Map raw sample counts to the encoder's output frame counts.
        input_lengths = self.wav2vec2._get_feat_extract_output_lengths(
            attention_mask.sum(-1)
        ).long()
|
        if transcripts is not None:
            encoded = self.tokenizer(
                transcripts,
                padding=True,
                return_length=True,
                return_attention_mask=False,
                return_tensors="pt",
            )

            target_ids = encoded["input_ids"].to(self.device)
            assert (
                target_ids < self.tokenizer.vocab_size
            ).all(), "target_ids is out of range"

            # CTC requires each target to be no longer than its own input sequence,
            # so compare per-sample against input_lengths rather than the padded
            # time dimension.
            target_lengths = encoded["length"].to(self.device)
            assert (
                target_lengths <= input_lengths
            ).all(), "target is longer than input"

            # CTCLoss expects log-probabilities of shape (time, batch, vocab).
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1).transpose(0, 1)

            loss = self.criterion(log_probs, target_ids, input_lengths, target_lengths)

            return loss, logits, input_lengths
        else:
            return logits, input_lengths
|
    @staticmethod
    def _get_predicted_ids(logits: torch.Tensor, lengths: torch.Tensor):
        """Greedy CTC decoding: argmax over the vocabulary at every frame."""
        predicted_ids = torch.argmax(logits, dim=-1)

        # Trim each sequence to its true (unpadded) number of frames.
        predicted_ids = [
            predicted_id[:length]
            for predicted_id, length in zip(predicted_ids, lengths)
        ]

        return predicted_ids
|
    def training_step(self, batch, batch_idx):
        transcripts, waveforms = batch

        loss = self(waveforms, transcripts)[0]
        self.train_loss(loss)

        # Log the running mean loss every 500 steps to limit logging overhead.
        if self.global_step % 500 == 0:
            self.log("train/loss", self.train_loss, on_step=True, on_epoch=True)

        return loss

    def on_train_epoch_end(self) -> None:
        self.train_loss.reset()
|
    def validation_step(self, batch, batch_idx):
        transcripts, waveforms = batch

        logits, seq_lengths = self(waveforms)

        predicted_ids = self._get_predicted_ids(logits, seq_lengths)
        predicted_texts = self.tokenizer.batch_decode(
            predicted_ids, skip_special_tokens=True
        )

        wer = word_error_rate(predicted_texts, transcripts)
        cer = character_error_rate(predicted_texts, transcripts)

        return wer, cer
|
    def validation_epoch_end(self, outputs):
        wer, cer = zip(*outputs)

        # Unweighted mean over validation batches.
        wer = sum(wer) / len(wer)
        cer = sum(cer) / len(cer)

        self.log("val/wer", wer, on_epoch=True)
        self.log("val/cer", cer, on_epoch=True)
|
    @torch.no_grad()
    def predict(self, waveforms: Tuple[torch.Tensor, ...]):
        logits, seq_lengths = self(waveforms)

        predicted_ids = self._get_predicted_ids(logits, seq_lengths)
        predicted_texts = self.tokenizer.batch_decode(
            predicted_ids, skip_special_tokens=True
        )

        return predicted_texts
|
    def configure_optimizers(self):
        # Separate learning rates for the pretrained backbone and the fresh head.
        optimizer = torch.optim.AdamW(
            params=[
                {
                    "params": self.wav2vec2.parameters(),
                    "lr": self.adam_config["wav2vec2_lr"],
                },
                {
                    "params": self.fc.parameters(),
                    "lr": self.adam_config["classifier_lr"],
                },
            ],
            weight_decay=self.adam_config["weight_decay"],
        )

        # The scheduler is stepped once per optimizer step (interval="step").
        scheduler = TriStateScheduler(optimizer, **self.tristate_scheduler_config)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
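

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how this module might be instantiated.
# The checkpoint names and adam_config values below are assumptions, not part
# of this project, and TriStateScheduler's arguments depend on its definition
# in src/utils/scheduler.py, so they are left as a placeholder.
# ---------------------------------------------------------------------------
# wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
# tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
# feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
#
# model = SpeechRecognizer(
#     wav2vec2=wav2vec2,
#     tokenizer=tokenizer,
#     feature_extractor=feature_extractor,
#     adam_config={"wav2vec2_lr": 1e-5, "classifier_lr": 1e-4, "weight_decay": 0.01},
#     tristate_scheduler_config={...},  # fill in per TriStateScheduler's signature
# )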
|
|