xu1998hz commited on
Commit
5ffcd76
1 Parent(s): fc60851

Create sescore_english_mt.py

Browse files
Files changed (1) hide show
  1. sescore_english_mt.py +139 -0
sescore_english_mt.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """SEScore: a text generation evaluation metric """
15
+
16
+ import evaluate
17
+ import datasets
18
+
19
+ import comet
20
+ from typing import Dict
21
+ import torch
22
+ from comet.encoders.base import Encoder
23
+ from comet.encoders.bert import BERTEncoder
24
+ from transformers import AutoModel, AutoTokenizer
25
+
26
+ class robertaEncoder(BERTEncoder):
27
+ def __init__(self, pretrained_model: str) -> None:
28
+ super(Encoder, self).__init__()
29
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
30
+ self.model = AutoModel.from_pretrained(
31
+ pretrained_model, add_pooling_layer=False
32
+ )
33
+ self.model.encoder.output_hidden_states = True
34
+
35
+ @classmethod
36
+ def from_pretrained(cls, pretrained_model: str) -> Encoder:
37
+ return robertaEncoder(pretrained_model)
38
+
39
+ def forward(
40
+ self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
41
+ ) -> Dict[str, torch.Tensor]:
42
+ last_hidden_states, _, all_layers = self.model(
43
+ input_ids=input_ids,
44
+ attention_mask=attention_mask,
45
+ output_hidden_states=True,
46
+ return_dict=False,
47
+ )
48
+ return {
49
+ "sentemb": last_hidden_states[:, 0, :],
50
+ "wordemb": last_hidden_states,
51
+ "all_layers": all_layers,
52
+ "attention_mask": attention_mask,
53
+ }
54
+
55
+
56
+ # TODO: Add BibTeX citation
57
+ _CITATION = """\
58
+ @inproceedings{xu-etal-2022-not,
59
+ title={Not All Errors are Equal: Learning Text Generation Metrics using Stratified Error Synthesis},
60
+ author={Xu, Wenda and Tuan, Yi-lin and Lu, Yujie and Saxon, Michael and Li, Lei and Wang, William Yang},
61
+ booktitle ={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing},
62
+ month={dec},
63
+ year={2022},
64
+ url={https://arxiv.org/abs/2210.05035}
65
+ }
66
+ """
67
+
68
+ _DESCRIPTION = """\
69
+ SEScore is an evaluation metric that trys to compute an overall score to measure text generation quality.
70
+ """
71
+
72
+ _KWARGS_DESCRIPTION = """
73
+ Calculates how good are predictions given some references
74
+ Args:
75
+ predictions: list of candidate outputs
76
+ references: list of references
77
+ Returns:
78
+ {"mean_score": mean_score, "scores": scores}
79
+
80
+ Examples:
81
+ >>> import evaluate
82
+ >>> sescore = evaluate.load("xu1998hz/sescore")
83
+ >>> score = sescore.compute(
84
+ references=['sescore is a simple but effective next-generation text evaluation metric'],
85
+ predictions=['sescore is simple effective text evaluation metric for next generation']
86
+ )
87
+ """
88
+
89
+ # TODO: Define external resources urls if needed
90
+ BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
91
+
92
+
93
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
94
+ class SEScore(evaluate.Metric):
95
+ """SEScore"""
96
+
97
+ def _info(self):
98
+ # TODO: Specifies the evaluate.EvaluationModuleInfo object
99
+ return evaluate.MetricInfo(
100
+ # This is the description that will appear on the modules page.
101
+ module_type="metric",
102
+ description=_DESCRIPTION,
103
+ citation=_CITATION,
104
+ inputs_description=_KWARGS_DESCRIPTION,
105
+ # This defines the format of each prediction and reference
106
+ features=datasets.Features({
107
+ 'predictions': datasets.Value("string", id="sequence"),
108
+ 'references': datasets.Value("string", id="sequence"),
109
+ }),
110
+ # Homepage of the module for documentation
111
+ homepage="http://module.homepage",
112
+ # Additional links to the codebase or references
113
+ codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
114
+ reference_urls=["http://path.to.reference.url/new_module"]
115
+ )
116
+
117
+ def _download_and_prepare(self, dl_manager):
118
+ """download SEScore checkpoints to compute the scores"""
119
+ # Download SEScore checkpoint
120
+ from comet import load_from_checkpoint
121
+ import os
122
+ from huggingface_hub import snapshot_download
123
+ # initialize roberta into str2encoder
124
+ comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
125
+ print("config name: ", self.config_name)
126
+ if self.config_name == "default":
127
+ destination = snapshot_download(repo_id="xu1998hz/sescore_english_mt", revision="main")
128
+ self.scorer = load_from_checkpoint(f'{destination}/checkpoint/sescore_english_mt.ckpt')
129
+ else:
130
+ print("Config name is not supported!")
131
+
132
+ def _compute(self, predictions, references, gpus=None, progress_bar=False):
133
+ if gpus is None:
134
+ gpus = 1 if torch.cuda.is_available() else 0
135
+
136
+ data = {"src": references, "mt": predictions}
137
+ data = [dict(zip(data, t)) for t in zip(*data.values())]
138
+ scores, mean_score = self.scorer.predict(data, gpus=gpus, progress_bar=progress_bar)
139
+ return {"mean_score": mean_score, "scores": scores}