File size: 5,695 Bytes
9db8db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46c63a3
9db8db7
 
 
 
ba092fc
9db8db7
 
 
 
 
 
4e0f879
 
 
9db8db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1541189
 
9db8db7
 
1541189
 
9db8db7
 
 
 
7217d6a
 
495f38b
364be12
9db8db7
 
 
 
 
 
35bc035
9db8db7
 
 
 
 
 
 
 
 
1541189
 
9db8db7
 
 
 
 
 
 
 
18e88dc
9db8db7
 
 
4e0f879
7217d6a
364be12
9db8db7
7217d6a
 
 
9db8db7
 
46c63a3
7217d6a
008aa62
9db8db7
4e0f879
495f38b
4e0f879
 
9db8db7
ba092fc
 
 
 
 
 
 
 
 
 
 
495f38b
ba092fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9db8db7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright 2020 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MeaningBERT metric. """

from contextlib import contextmanager
from itertools import chain
from typing import List, Dict

import datasets
import evaluate
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


@contextmanager
def filter_logging_context():
    """Temporarily drop one specific message from a transformers logger.

    While the context is active, any record sent to the
    ``transformers.modeling_utils`` logger whose message contains
    "This IS expected if you are initializing" is discarded. The filter is
    removed again on exit, including when the wrapped code raises.
    """

    def filter_log(record):
        # ``record.msg`` is whatever object was passed to the logging call,
        # not necessarily a ``str``; coerce it so the substring test cannot
        # raise ``TypeError`` from inside the logging machinery.
        return "This IS expected if you are initializing" not in str(record.msg)

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        logger.removeFilter(filter_log)


# BibTeX entry for the MeaningBERT article; passed as ``citation`` to
# ``evaluate.MetricInfo`` in ``MeaningBERT._info``.
_CITATION = """\
@ARTICLE{10.3389/frai.2023.1223924,
AUTHOR={Beauchemin, David and Saggion, Horacio and Khoury, Richard},    
TITLE={MeaningBERT: assessing meaning preservation between sentences},      
JOURNAL={Frontiers in Artificial Intelligence},      
VOLUME={6},           
YEAR={2023},      
URL={https://www.frontiersin.org/articles/10.3389/frai.2023.1223924},       
DOI={10.3389/frai.2023.1223924},      
ISSN={2624-8212},   
}
"""

# Human-readable summary of the metric; passed as ``description`` to
# ``evaluate.MetricInfo`` and prepended to the ``MeaningBERT`` class docstring
# through ``add_start_docstrings``.
_DESCRIPTION = """\
MeaningBERT is an automatic and trainable metric for assessing meaning preservation between sentences. MeaningBERT was
proposed in our
article [MeaningBERT: assessing meaning preservation between sentences](https://www.frontiersin.org/articles/10.3389/frai.2023.1223924/full).
Its goal is to assess meaning preservation between two sentences that correlate highly with human judgments and sanity 
checks. For more details, refer to our publicly available article.

See the project's README at https://github.com/GRAAL-Research/MeaningBERT for more information.
"""

# Usage documentation (arguments, returned keys, example); passed as
# ``inputs_description`` to ``evaluate.MetricInfo`` and prepended to the
# ``MeaningBERT`` class docstring through ``add_start_docstrings``.
# The returned keys listed here must match the dict built in ``_compute``.
_KWARGS_DESCRIPTION = """
MeaningBERT metric for assessing meaning preservation between sentences.

Args:
    predictions (list of str): Prediction sentences.
    references (list of str): Reference sentences (same number of elements as predictions).

Returns:
    scores: the meaning preservation scores, as a list that follows the order of the
    prediction/reference pairs.
    hashcode: Hashcode of the library.

Examples:

    >>> references = ["hello there", "general kenobi"]
    >>> predictions = ["hello there", "general kenobi"]
    >>> meaning_bert = evaluate.load("davebulaval/meaningbert")
    >>> results = meaning_bert.compute(predictions=predictions, references=references)
"""

# Returned verbatim under the "hashcode" key of the metric output.
# NOTE(review): what exactly this digest is computed over is not visible in
# this file — confirm before changing it.
_HASH = "21845c0cc85a2e8e16c89bb0053f489095cf64c5b19e9c3865d3e10047aba51b"


# ``evaluate`` metric that scores prediction/reference sentence pairs with the
# pretrained ``davebulaval/MeaningBERT`` sequence-classification model.
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class MeaningBERT(evaluate.Metric):
    def _info(self):
        """Return the metric metadata: description, citation, input features and URLs."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/GRAAL-Research/MeaningBERT",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                )
            ],
            codebase_urls=["https://github.com/GRAAL-Research/MeaningBERT"],
            reference_urls=[
                "https://github.com/GRAAL-Research/MeaningBERT",
                "https://www.frontiersin.org/articles/10.3389/frai.2023.1223924/full",
            ],
            module_type="metric",
        )

    def _compute(
        self,
        predictions: List,
        references: List,
    ) -> Dict:
        """Score each prediction against the reference at the same index.

        Args:
            predictions: Prediction sentences.
            references: Reference sentences, aligned index-by-index with
                ``predictions``.

        Returns:
            A dict with two keys: ``"scores"``, one value per pair in input
            order (pairs whose two sentences are identical are set to 100
            instead of the model output), and ``"hashcode"``, the module-level
            ``_HASH``.

        Raises:
            AssertionError: If ``references`` and ``predictions`` differ in
                length.
        """
        assert len(references) == len(
            predictions
        ), "The number of references is different of the number of predictions."
        hashcode = _HASH

        # Nothing to score: return early instead of loading the model and
        # feeding an empty batch to the tokenizer.
        if len(references) == 0:
            return {"scores": [], "hashcode": hashcode}

        # Indices of pairs whose reference and prediction are identical. The
        # comparison is made pair by pair: a reference that merely equals some
        # *other* prediction is not a perfect match for its own pair.
        matching_index = [
            index
            for index, (reference, prediction) in enumerate(
                zip(references, predictions)
            )
            if reference == prediction
        ]

        # We load the MeaningBERT pretrained model. Loading happens inside the
        # logging filter so that the message it targets is also hidden while
        # the checkpoint is being loaded.
        with filter_logging_context():
            scorer = AutoModelForSequenceClassification.from_pretrained(
                "davebulaval/MeaningBERT"
            )
        scorer.eval()

        # We load MeaningBERT tokenizer
        tokenizer = AutoTokenizer.from_pretrained("davebulaval/MeaningBERT")

        # We tokenize the text as a pair and return Pytorch Tensors
        tokenize_text = tokenizer(
            references,
            predictions,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )

        # Inference only: no gradients are needed.
        with torch.no_grad():
            with filter_logging_context():
                # We process the text
                model_output = scorer(**tokenize_text)

        # Flatten the per-pair lists of logits into one flat list
        scores = list(chain.from_iterable(model_output.logits.tolist()))

        # Handle case of perfect match
        for matching_element_index in matching_index:
            scores[matching_element_index] = 100

        return {
            "scores": scores,
            "hashcode": hashcode,
        }