alvations committed on
Commit
33f2ebb
·
1 Parent(s): 554f046

added comet da

Browse files
Files changed (1) hide show
  1. cometda.py +81 -0
cometda.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+
4
+ import datasets
5
+ import evaluate
6
+ from huggingface_hub import snapshot_download, login
7
+
8
+ from comet.models.multitask.unified_metric import UnifiedMetric
9
+
10
+
11
# Authenticate with the Hugging Face Hub only when a token is supplied through
# the HF_TOKEN environment variable. Access tokens are credentials and must
# never be committed to source control; any token that was ever hard-coded
# here has to be treated as leaked and revoked.
if os.environ.get("HF_TOKEN"):
    login(token=os.environ["HF_TOKEN"])
12
+
13
# BibTeX entry reported through MetricInfo.citation.
# The accent macros are written as ``\\'`` because a plain ``\'`` inside a
# non-raw Python string collapses to a bare apostrophe and would drop the
# backslash that BibTeX needs.
_CITATION = """\
@inproceedings{rei-etal-2022-comet,
title = "{COMET}-22: Unbabel-{IST} 2022 Submission for the Metrics Shared Task",
author = "Rei, Ricardo and
C. de Souza, Jos{\\'e} G. and
Alves, Duarte and
Zerva, Chrysoula and
Farinha, Ana C and
Glushkova, Taisiya and
Lavie, Alon and
Coheur, Luisa and
Martins, Andr{\\'e} F. T.",
booktitle = "Proceedings of the Seventh Conference on Machine Translation (WMT)",
month = dec,
year = "2022",
address = "Abu Dhabi, United Arab Emirates (Hybrid)",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.wmt-1.52",
pages = "578--585",
}
"""
34
+
35
+
36
# Human-readable description reported through MetricInfo.description.
# It names the Hub repository whose checkpoint this metric actually loads.
_DESCRIPTION = """\
From https://huggingface.co/Unbabel/wmt22-cometkiwi-da
"""
39
+
40
class COMETDA(evaluate.Metric):
    """`evaluate` metric backed by the ``Unbabel/wmt22-cometkiwi-da`` checkpoint.

    The checkpoint is loaded into ``self.model`` by ``_download_and_prepare``;
    both scoring methods below require that to have happened first.
    """

    # Hub repository holding the checkpoint, and the folder name under which
    # ``snapshot_download`` stores it inside ``cache_dir``.
    _REPO_ID = "Unbabel/wmt22-cometkiwi-da"
    _LOCAL_CACHE_NAME = "models--Unbabel--wmt22-cometkiwi-da"

    def _info(self):
        """Return the metric description, citation and expected input features."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
        )

    def _download_and_prepare(self, dl_manager):
        """Load the model checkpoint, downloading it when no local copy exists.

        A ``*.ckpt`` file is first searched for under the repository's cache
        folder in the current working directory. Only when none is found is
        the repository snapshot downloaded into that same directory. Errors
        raised while loading the checkpoint are propagated instead of being
        silently retried.

        Args:
            dl_manager: Unused; accepted to match the ``evaluate`` hook signature.
        """
        # The download cache and the local lookup share the working directory
        # so that a snapshot fetched once is found again on the next run.
        cache_dir = pathlib.Path.cwd()
        local_cache = cache_dir / self._LOCAL_CACHE_NAME

        model_checkpoint_path = next(local_cache.rglob("*.ckpt"), None)
        if model_checkpoint_path is None:
            model_path = snapshot_download(
                repo_id=self._REPO_ID,
                cache_dir=str(cache_dir),
            )
            model_checkpoint_path = pathlib.Path(model_path) / "checkpoints" / "model.ckpt"

        self.model = UnifiedMetric.load_from_checkpoint(model_checkpoint_path)

    def _compute(
        self,
        predictions,
        references,
        data_keys=None,
        batch_size=8,
    ):
        """Score each prediction against its paired ground-truth text.

        Lets the caller use either source inputs or reference translations as
        the ground truth by choosing the keys the model input is built with.

        Args:
            predictions: Iterable of hypothesis strings.
            references: Iterable of ground-truth strings, paired with
                ``predictions`` position by position.
            data_keys: Two-item sequence. ``data_keys[0]`` is the input key
                used for each prediction and ``data_keys[1]`` the key used for
                each reference (for example ``("mt", "src")`` or
                ``("mt", "ref")``). Required.
            batch_size: Batch size forwarded to ``self.model.predict``.

        Returns:
            ``{"scores": ...}`` holding the ``scores`` attribute of the
            model's prediction output.

        Raises:
            ValueError: If ``data_keys`` is missing or has fewer than two items.
        """
        if data_keys is None or len(data_keys) < 2:
            raise ValueError(
                "data_keys must name the model input keys for predictions and "
                "references, e.g. ('mt', 'src') or ('mt', 'ref')."
            )
        prediction_key, reference_key = data_keys[0], data_keys[1]
        data = [
            {prediction_key: p, reference_key: r}
            for p, r in zip(predictions, references)
        ]
        return {"scores": self.model.predict(data, batch_size=batch_size).scores}

    def compute_triplet(
        self,
        predictions,
        references,
        sources,
        batch_size=8,
    ):
        """Return unified scores computed from sources, hypotheses and references.

        Args:
            predictions: Iterable of hypothesis strings (model input key ``"mt"``).
            references: Iterable of reference strings (model input key ``"ref"``).
            sources: Iterable of source strings (model input key ``"src"``).
            batch_size: Batch size forwarded to ``self.model.predict``.

        Returns:
            ``{"scores": ...}`` holding ``metadata.unified_scores`` from the
            model's prediction output.
        """
        data = [
            {"src": s, "mt": p, "ref": r}
            for s, p, r in zip(sources, predictions, references)
        ]
        output = self.model.predict(data, batch_size=batch_size)
        return {"scores": output.metadata.unified_scores}
81
+