from typing import List

import numpy as np
import torch
from torch import Tensor

from hw_asr.base.base_metric import BaseMetric
from hw_asr.base.base_text_encoder import BaseTextEncoder
from hw_asr.metric.utils import calc_cer


class ArgmaxCERMetric(BaseMetric):
    """CER of greedy (argmax) decoding of the model output."""

    def __init__(self, text_encoder: BaseTextEncoder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_encoder = text_encoder

    def __call__(self, log_probs: Tensor, log_probs_length: Tensor, text: List[str], **kwargs):
        cers = []
        # Greedy decoding: take the most likely token at every time step.
        predictions = torch.argmax(log_probs.cpu(), dim=-1).numpy()
        lengths = log_probs_length.detach().numpy()
        for pred_inds, length, target_text in zip(predictions, lengths, text):
            target_text = BaseTextEncoder.normalize_text(target_text)
            # Prefer CTC decoding (collapse repeats, drop blanks) when the
            # encoder supports it; otherwise fall back to plain decoding.
            if hasattr(self.text_encoder, "ctc_decode"):
                pred_text = self.text_encoder.ctc_decode(pred_inds[:length])
            else:
                pred_text = self.text_encoder.decode(pred_inds[:length])
            cers.append(calc_cer(target_text, pred_text))
        return sum(cers) / len(cers)


class BeamSearchCERMetric(BaseMetric):
    """CER of CTC beam-search decoding of the model output."""

    def __init__(self, text_encoder: BaseTextEncoder, beam_size: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_encoder = text_encoder
        self.beam_size = beam_size

    def __call__(self, log_probs: Tensor, log_probs_length: Tensor, text: List[str], **kwargs):
        cers = []
        # The beam search consumes probabilities, so undo the log.
        probs = np.exp(log_probs.detach().cpu().numpy())
        lengths = log_probs_length.detach().numpy()
        for prob, length, target_text in zip(probs, lengths, text):
            target_text = BaseTextEncoder.normalize_text(target_text)
            if hasattr(self.text_encoder, "ctc_beam_search"):
                pred_text = self.text_encoder.ctc_beam_search(prob[:length], self.beam_size)
            else:
                # Fail loudly instead of a bare `assert False`: this metric is
                # meaningless without a beam-search-capable encoder.
                raise NotImplementedError(
                    "text_encoder must implement ctc_beam_search"
                )
            cers.append(calc_cer(target_text, pred_text))
        return sum(cers) / len(cers)


class LanguageModelCERMetric(BaseMetric):
    """CER of CTC beam-search decoding fused with a language model."""

    def __init__(self, text_encoder: BaseTextEncoder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_encoder = text_encoder

    def __call__(self, logits: Tensor, log_probs_length: Tensor, text: List[str], **kwargs):
        cers = []
        # The LM decoder takes raw logits rather than log-probabilities.
        logits = logits.detach().cpu().numpy()
        lengths = log_probs_length.detach().numpy()
        for logit, length, target_text in zip(logits, lengths, text):
            target_text = BaseTextEncoder.normalize_text(target_text)
            if hasattr(self.text_encoder, "ctc_lm_beam_search"):
                pred_text = self.text_encoder.ctc_lm_beam_search(logit[:length])
            else:
                raise NotImplementedError(
                    "text_encoder must implement ctc_lm_beam_search"
                )
            cers.append(calc_cer(target_text, pred_text))
        return sum(cers) / len(cers)
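

if __name__ == "__main__":
    # Usage sketch, not part of the original module: a minimal smoke test of
    # ArgmaxCERMetric. The duck-typed stub encoder, its toy alphabet, and the
    # no-extra-argument BaseMetric construction are all assumptions made for
    # illustration; real runs use the project's CTC text encoder and batches
    # produced by the trainer.
    class _GreedyStub:
        """Exposes only ctc_decode, which is all the metric probes for."""

        blank = 0
        ind2char = {1: "a", 2: "b", 3: "c"}  # hypothetical toy alphabet

        def ctc_decode(self, inds) -> str:
            # Standard CTC collapse: merge repeated tokens, then drop blanks.
            chars, prev = [], self.blank
            for i in map(int, inds):
                if i != prev and i != self.blank:
                    chars.append(self.ind2char[i])
                prev = i
            return "".join(chars)

    metric = ArgmaxCERMetric(text_encoder=_GreedyStub())
    fake_log_probs = torch.log_softmax(torch.randn(2, 10, 4), dim=-1)
    fake_lengths = torch.tensor([10, 8])
    print(metric(log_probs=fake_log_probs, log_probs_length=fake_lengths,
                 text=["ab", "abc"]))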