# ViDove/evaluation/scores/multi_scores.py
from comet import download_model, load_from_checkpoint
from sacrebleu.metrics import BLEU, CHRF, TER
from scores import LLM_eval

class multi_scores:
    def __init__(self, source_lang="English", target_lang="Chinese", domain="starcraft 2") -> None:
        self.comet_model = load_from_checkpoint(download_model("Unbabel/wmt22-comet-da"))
        self.bleu_model = BLEU(tokenize="zh")
        self.LLM_model = LLM_eval.init_evaluator(source_lang=source_lang, target_lang=target_lang, domain=domain)

    # The function to get the scores
    # src: original sentence
    # mt: machine translation
    # ref: reference translation
    def get_scores(self, src: str, mt: str, ref: str) -> dict:
        comet_score = self.comet_model.predict([{"src": src, "mt": mt, "ref": ref}], batch_size=8, gpus=0).scores[0]
        # sacrebleu expects a list of reference streams, so the single reference is wrapped in a nested list
        bleu_score = self.bleu_model.corpus_score([mt], [[ref]]).score
        llm_score, llm_explanation = LLM_eval.evaluate_prediction(src, ref, mt, self.LLM_model)
        return {'bleu_score': bleu_score, 'comet_score': comet_score, 'llm_score': llm_score, 'llm_explanation': llm_explanation}

if __name__ == "__main__":
    src = "this is a test sentence"
    mt = "这是一个测试句子。"
    ref = "这不是一个测试语句。"
    print(multi_scores().get_scores(src, mt, ref))
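
# A minimal sketch (commented out) of evaluating a small batch with the same class.
# The sentence pairs below are hypothetical sample data, and simple per-sentence
# averaging is just one possible way to aggregate the returned scores.
#
#     scorer = multi_scores()
#     samples = [
#         {"src": "build more pylons", "mt": "建造更多水晶塔", "ref": "多造一些水晶塔"},
#         {"src": "the attack failed", "mt": "进攻失败了", "ref": "这次进攻失败了"},
#     ]
#     results = [scorer.get_scores(s["src"], s["mt"], s["ref"]) for s in samples]
#     print({"avg_bleu": sum(r["bleu_score"] for r in results) / len(results),
#            "avg_comet": sum(r["comet_score"] for r in results) / len(results)})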