File size: 1,307 Bytes
be9d6f9
 
1b862fc
be9d6f9
3853a8e
8abf414
be9d6f9
8abf414
 
be9d6f9
8abf414
 
 
 
 
3853a8e
8abf414
1b862fc
 
8abf414
 
 
 
 
 
be9d6f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from comet import download_model, load_from_checkpoint
from sacrebleu.metrics import BLEU, CHRF, TER
from scores import LLM_eval

class multi_scores:
    """Aggregate translation-quality scorer combining COMET, BLEU and an LLM judge."""

    def __init__(self, source_lang="English", target_lang="Chinese", domain="starcraft 2") -> None:
        """Load the COMET checkpoint, configure BLEU, and initialize the LLM evaluator.

        Args:
            source_lang: Source language name passed to the LLM evaluator.
            target_lang: Target language name passed to the LLM evaluator.
            domain: Domain hint passed to the LLM evaluator.
        """
        # Downloads (if not cached) and loads the WMT22 COMET-DA checkpoint.
        self.comet_model = load_from_checkpoint(download_model("Unbabel/wmt22-comet-da"))
        # Chinese tokenizer so BLEU segments CJK output correctly.
        self.bleu_model = BLEU(tokenize="zh")
        self.LLM_model = LLM_eval.init_evaluator(source_lang=source_lang, target_lang=target_lang, domain=domain)

    def get_scores(self, src: str, mt: str, ref: str) -> dict:
        """Score one machine translation against its source and reference.

        Args:
            src: Original source-language sentence.
            mt: Machine translation to evaluate.
            ref: Reference (human) translation.

        Returns:
            dict with keys 'bleu_score', 'comet_score', 'llm_score',
            'llm_explanation'.
        """
        comet_score = self.comet_model.predict(
            [{"src": src, "mt": mt, "ref": ref}], batch_size=8, gpus=0
        ).scores[0]
        # sacrebleu expects references as a list of reference *streams*
        # (list[list[str]]). The previous call passed [ref] — a bare string
        # as a stream — which sacrebleu iterates character-by-character,
        # corrupting the score; wrap twice instead.
        bleu_score = self.bleu_model.corpus_score([mt], [[ref]]).score
        llm_score, llm_explanation = LLM_eval.evaluate_prediction(src, ref, mt, self.LLM_model)
        return {
            'bleu_score': bleu_score,
            'comet_score': comet_score,
            'llm_score': llm_score,
            'llm_explanation': llm_explanation,
        }
    
if __name__ == "__main__":
    # Smoke test: score one sample machine translation against its reference.
    sample_src = "this is an test sentences"
    sample_mt = "这是一个测试句子。"
    sample_ref = "这不是一个测试语句。"
    scorer = multi_scores()
    print(scorer.get_scores(sample_src, sample_mt, sample_ref))