xu1998hz commited on
Commit
fc60851
·
1 Parent(s): 5104b51

Delete sescore.py

Browse files
Files changed (1) hide show
  1. sescore.py +0 -139
sescore.py DELETED
@@ -1,139 +0,0 @@
1
- # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- """SEScore: a text generation evaluation metric """
15
-
16
- import evaluate
17
- import datasets
18
-
19
- import comet
20
- from typing import Dict
21
- import torch
22
- from comet.encoders.base import Encoder
23
- from comet.encoders.bert import BERTEncoder
24
- from transformers import AutoModel, AutoTokenizer
25
-
26
- class robertaEncoder(BERTEncoder):
27
- def __init__(self, pretrained_model: str) -> None:
28
- super(Encoder, self).__init__()
29
- self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
30
- self.model = AutoModel.from_pretrained(
31
- pretrained_model, add_pooling_layer=False
32
- )
33
- self.model.encoder.output_hidden_states = True
34
-
35
- @classmethod
36
- def from_pretrained(cls, pretrained_model: str) -> Encoder:
37
- return robertaEncoder(pretrained_model)
38
-
39
- def forward(
40
- self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
41
- ) -> Dict[str, torch.Tensor]:
42
- last_hidden_states, _, all_layers = self.model(
43
- input_ids=input_ids,
44
- attention_mask=attention_mask,
45
- output_hidden_states=True,
46
- return_dict=False,
47
- )
48
- return {
49
- "sentemb": last_hidden_states[:, 0, :],
50
- "wordemb": last_hidden_states,
51
- "all_layers": all_layers,
52
- "attention_mask": attention_mask,
53
- }
54
-
55
-
56
- # TODO: Add BibTeX citation
57
- _CITATION = """\
58
- @inproceedings{xu-etal-2022-not,
59
- title={Not All Errors are Equal: Learning Text Generation Metrics using Stratified Error Synthesis},
60
- author={Xu, Wenda and Tuan, Yi-lin and Lu, Yujie and Saxon, Michael and Li, Lei and Wang, William Yang},
61
- booktitle ={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing},
62
- month={dec},
63
- year={2022},
64
- url={https://arxiv.org/abs/2210.05035}
65
- }
66
- """
67
-
68
- _DESCRIPTION = """\
69
- SEScore is an evaluation metric that trys to compute an overall score to measure text generation quality.
70
- """
71
-
72
- _KWARGS_DESCRIPTION = """
73
- Calculates how good are predictions given some references
74
- Args:
75
- predictions: list of candidate outputs
76
- references: list of references
77
- Returns:
78
- {"mean_score": mean_score, "scores": scores}
79
-
80
- Examples:
81
- >>> import evaluate
82
- >>> sescore = evaluate.load("xu1998hz/sescore")
83
- >>> score = sescore.compute(
84
- references=['sescore is a simple but effective next-generation text evaluation metric'],
85
- predictions=['sescore is simple effective text evaluation metric for next generation']
86
- )
87
- """
88
-
89
- # TODO: Define external resources urls if needed
90
- BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
91
-
92
-
93
- @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
94
- class SEScore(evaluate.Metric):
95
- """SEScore"""
96
-
97
- def _info(self):
98
- # TODO: Specifies the evaluate.EvaluationModuleInfo object
99
- return evaluate.MetricInfo(
100
- # This is the description that will appear on the modules page.
101
- module_type="metric",
102
- description=_DESCRIPTION,
103
- citation=_CITATION,
104
- inputs_description=_KWARGS_DESCRIPTION,
105
- # This defines the format of each prediction and reference
106
- features=datasets.Features({
107
- 'predictions': datasets.Value("string", id="sequence"),
108
- 'references': datasets.Value("string", id="sequence"),
109
- }),
110
- # Homepage of the module for documentation
111
- homepage="http://module.homepage",
112
- # Additional links to the codebase or references
113
- codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
114
- reference_urls=["http://path.to.reference.url/new_module"]
115
- )
116
-
117
- def _download_and_prepare(self, dl_manager):
118
- """download SEScore checkpoints to compute the scores"""
119
- # Download SEScore checkpoint
120
- from comet import load_from_checkpoint
121
- import os
122
- from huggingface_hub import snapshot_download
123
- # initialize roberta into str2encoder
124
- comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
125
- print("config name: ", self.config_name)
126
- if self.config_name == "default":
127
- destination = snapshot_download(repo_id="xu1998hz/sescore_english_mt", revision="main")
128
- self.scorer = load_from_checkpoint(f'{destination}/checkpoint/sescore_english_mt.ckpt')
129
- else:
130
- print("Config name is not supported!")
131
-
132
- def _compute(self, predictions, references, gpus=None, progress_bar=False):
133
- if gpus is None:
134
- gpus = 1 if torch.cuda.is_available() else 0
135
-
136
- data = {"src": references, "mt": predictions}
137
- data = [dict(zip(data, t)) for t in zip(*data.values())]
138
- scores, mean_score = self.scorer.predict(data, gpus=gpus, progress_bar=progress_bar)
139
- return {"mean_score": mean_score, "scores": scores}