from typing import List

import numpy as np
import torch
from torch import Tensor

from hw_asr.base.base_metric import BaseMetric
from hw_asr.base.base_text_encoder import BaseTextEncoder
from hw_asr.metric.utils import calc_cer


class ArgmaxCERMetric(BaseMetric):
    """CER computed from greedy (argmax) decoding of the model output."""

    def __init__(self, text_encoder: BaseTextEncoder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_encoder = text_encoder

    def __call__(self, log_probs: Tensor, log_probs_length: Tensor, text: List[str], **kwargs):
        cers = []
        # Greedy decoding: take the most likely token at every time step.
        predictions = torch.argmax(log_probs.cpu(), dim=-1).numpy()
        lengths = log_probs_length.detach().cpu().numpy()
        for log_prob_vec, length, target_text in zip(predictions, lengths, text):
            target_text = BaseTextEncoder.normalize_text(target_text)
            # Use CTC decoding (collapse repeats, drop blanks) when the encoder supports it.
            if hasattr(self.text_encoder, "ctc_decode"):
                pred_text = self.text_encoder.ctc_decode(log_prob_vec[:length])
            else:
                pred_text = self.text_encoder.decode(log_prob_vec[:length])
            cers.append(calc_cer(target_text, pred_text))
        return sum(cers) / len(cers)


class BeamSearchCERMetric(BaseMetric):
    """CER computed with CTC beam-search decoding."""

    def __init__(self, text_encoder: BaseTextEncoder, beam_size: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_encoder = text_encoder
        self.beam_size = beam_size

    def __call__(self, log_probs: Tensor, log_probs_length: Tensor, text: List[str], **kwargs):
        cers = []
        # Beam search operates on probabilities, so exponentiate the log-probs first.
        probs = np.exp(log_probs.detach().cpu().numpy())
        lengths = log_probs_length.detach().cpu().numpy()
        for prob, length, target_text in zip(probs, lengths, text):
            target_text = BaseTextEncoder.normalize_text(target_text)
            if hasattr(self.text_encoder, "ctc_beam_search"):
                pred_text = self.text_encoder.ctc_beam_search(prob[:length], self.beam_size)
            else:
                raise RuntimeError("text_encoder must implement ctc_beam_search")
            cers.append(calc_cer(target_text, pred_text))
        return sum(cers) / len(cers)


class LanguageModelCERMetric(BaseMetric):
    """CER computed with language-model-fused CTC beam-search decoding."""

    def __init__(self, text_encoder: BaseTextEncoder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_encoder = text_encoder

    def __call__(self, logits: Tensor, log_probs_length: Tensor, text: List[str], **kwargs):
        cers = []
        logits = logits.detach().cpu().numpy()
        lengths = log_probs_length.detach().cpu().numpy()
        for logit, length, target_text in zip(logits, lengths, text):
            target_text = BaseTextEncoder.normalize_text(target_text)
            if hasattr(self.text_encoder, "ctc_lm_beam_search"):
                pred_text = self.text_encoder.ctc_lm_beam_search(logit[:length])
            else:
                raise RuntimeError("text_encoder must implement ctc_lm_beam_search")
            cers.append(calc_cer(target_text, pred_text))
        return sum(cers) / len(cers)
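

if __name__ == "__main__":
    # Smoke-test sketch, added for illustration; not part of the original module.
    # It exercises only calc_cer, whose (target_text, pred_text) argument order
    # is taken from the calls above. Exact return values depend on calc_cer's
    # implementation, but a character error rate should be 0.0 for identical
    # strings and grow with the number of character-level edits.
    print(calc_cer("speech recognition", "speech recognition"))  # expected: 0.0
    print(calc_cer("speech recognition", "speach recognitian"))  # expected: small positive value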