from typing import Dict, List

import evaluate  # type: ignore[import]

from ..base_task_metrics import BaseTaskMetrics


class CMGMetrics(BaseTaskMetrics):
    """Reference-based text-similarity metrics for the CMG task.

    Wraps several Hugging Face ``evaluate`` modules: SacreBLEU, chrF, ROUGE and
    BERTScore. BERTScore is reported both raw and rescaled with its baseline.
    Predictions and references are fed in with :meth:`add_batch` and
    aggregated in :meth:`compute`.
    """

    def __init__(self):
        load = evaluate.load
        self.bleu = load("sacrebleu")
        self.chrf = load("chrf")
        self.rouge = load("rouge")
        # Two separate BERTScore modules: each accumulates its own copy of the
        # data, and in compute() one is scored raw and the other rescaled.
        self.bertscore = load("bertscore")
        self.bertscore_normalized = load("bertscore")

    def add_batch(self, predictions: List[str], references: List[str], *args, **kwargs) -> None:
        """Queue a batch of predictions and references on every metric module.

        Args:
            predictions: Generated texts, one per example.
            references: Ground-truth texts, one reference per prediction.
            *args, **kwargs: Accepted for interface compatibility; ignored.
        """
        # SacreBLEU and chrF take a list of references per prediction, so each
        # single reference is wrapped in its own one-element list.
        nested_references = [[reference] for reference in references]
        for module in (self.bleu, self.chrf):
            module.add_batch(predictions=predictions, references=nested_references)
        # ROUGE and BERTScore take one plain reference string per prediction.
        for module in (self.rouge, self.bertscore, self.bertscore_normalized):
            module.add_batch(predictions=predictions, references=references)

    def compute(self, *args, **kwargs) -> Dict[str, float]:
        """Aggregate everything added so far into a flat dict of scores.

        Returns:
            A dict with these keys:

            - ``bleu`` and ``chrf``: the modules' ``score`` values. BLEU uses
              the ``13a`` tokenizer.
            - ``rouge1``, ``rouge2`` and ``rougeL``: ROUGE values multiplied
              by 100.
            - ``bertscore`` and ``bertscore_normalized``: the mean per-example
              BERTScore F1, raw and baseline-rescaled respectively. Both use
              ``lang="en"``.

        Extra positional/keyword arguments are accepted and ignored.
        """
        rouge_scores = self.rouge.compute()
        raw_bertscore = self.bertscore.compute(lang="en")
        rescaled_bertscore = self.bertscore_normalized.compute(lang="en", rescale_with_baseline=True)

        results: Dict[str, float] = {}
        results["bleu"] = self.bleu.compute(tokenize="13a")["score"]
        results["chrf"] = self.chrf.compute()["score"]
        # ROUGE values are scaled by 100, presumably to sit on the same
        # 0-100 scale as BLEU/chrF — TODO confirm intent.
        for rouge_key in ("rouge1", "rouge2", "rougeL"):
            results[rouge_key] = rouge_scores[rouge_key] * 100
        # BERTScore returns one F1 per example; report their arithmetic mean.
        for out_key, bertscore_output in (
            ("bertscore", raw_bertscore),
            ("bertscore_normalized", rescaled_bertscore),
        ):
            f1_values = bertscore_output["f1"]
            results[out_key] = sum(f1_values) / len(f1_values)
        return results