kleervoyans committed on
Commit
c4d24a3
·
verified ·
1 Parent(s): 6bdbfeb

Update evaluators/evaluator.py

Browse files
Files changed (1) hide show
  1. evaluators/evaluator.py +55 -0
evaluators/evaluator.py CHANGED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import evaluate
3
+
4
class TranslationEvaluator:
    """Score machine-translation output with BLEU, BERTScore and COMET.

    The metric objects are loaded once in ``__init__`` through
    ``evaluate.load`` and reused by every call to :meth:`evaluate`.
    """

    def __init__(self):
        """Load the BLEU, BERTScore and COMET metric modules."""
        self.bleu = evaluate.load("bleu")
        self.bertscore = evaluate.load("bertscore")
        # COMET MQM model
        # NOTE(review): the COMET checkpoint is normally selected through the
        # config name (second positional argument of ``evaluate.load``);
        # ``model_id`` may be silently ignored, in which case the default
        # checkpoint is loaded instead - confirm the intended model is used.
        self.comet = evaluate.load("comet", model_id="unbabel/comet-mqm-qe-da")
        logging.info("Loaded BLEU, BERTScore, COMET metrics")

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of *values* as a float, 0.0 if empty."""
        values = list(values)
        return float(sum(values) / len(values)) if values else 0.0

    def evaluate(self, sources, references, predictions):
        """Compute corpus-level translation metrics.

        Args:
            sources: List[str], the source-language segments.
            references: List[str], one reference translation per segment.
            predictions: List[str], the system translations to score.

        Returns:
            A dict ``{"BLEU": float, "BERTScore": float, "BERTurk": float,
            "COMET": float}``. Every value is ``0.0`` when the inputs are
            empty.

        Raises:
            ValueError: if the three inputs do not have the same length.
        """
        count = len(predictions)
        if len(references) != count or len(sources) != count:
            raise ValueError(
                "sources, references and predictions must have the same length "
                f"(got {len(sources)}, {len(references)}, {count})"
            )
        if count == 0:
            # Nothing to score; avoid handing empty batches to the metrics.
            return {"BLEU": 0.0, "BERTScore": 0.0, "BERTurk": 0.0, "COMET": 0.0}

        results = {}

        # BLEU: each prediction is paired with a list holding its one reference.
        results["BLEU"] = float(
            self.bleu.compute(
                predictions=predictions,
                references=[[ref] for ref in references],
            )["bleu"]
        )

        # BERTScore (general, lang="xx")
        general = self.bertscore.compute(
            predictions=predictions,
            references=references,
            lang="xx",
        )
        results["BERTScore"] = self._mean(general["f1"])

        # BERTurk (lang="tr")
        turkish = self.bertscore.compute(
            predictions=predictions,
            references=references,
            lang="tr",
        )
        results["BERTurk"] = self._mean(turkish["f1"])

        # COMET: the metric takes ``sources`` / ``predictions`` / ``references``.
        comet_out = self.comet.compute(
            sources=sources,
            predictions=predictions,
            references=references,
        )
        # Prefer the corpus-level mean; otherwise average the per-segment
        # scores instead of reporting only the first segment.
        mean_score = comet_out.get("mean_score")
        if mean_score is not None:
            results["COMET"] = float(mean_score)
        else:
            scores = comet_out.get("scores")
            if isinstance(scores, (list, tuple)):
                results["COMET"] = self._mean(scores)
            else:
                results["COMET"] = float(scores) if scores is not None else 0.0

        return results