refactor: allow custom Encoder instances
Files changed:
- README.md: +2 -2
- encoder_models.py: +48 -40
- semf1.py: +157 -108
- tests.py: +135 -76
README.md (CHANGED)

@@ -59,8 +59,8 @@ Sem-F1 takes 2 mandatory arguments:

 Sem-F1 also accepts multiple optional arguments:

-- `model_type (str)`: Model to use for encoding sentences. Options: ['pv1' ([paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)), 'stsb' ([stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)), 'use' ([Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual)) (Default)]. Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by SentenceTransformer,
-such as `all-mpnet-base-v2` or `roberta-base`.
+- `model_type (Optional[Union[str, Encoder]])`: Model to use for encoding sentences. Options: ['pv1' ([paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)), 'stsb' ([stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)), 'use' ([Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual)) (Default)]. Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by SentenceTransformer,
+such as `all-mpnet-base-v2` or `roberta-base`. Users can also pass a custom `Encoder` instance, which must implement the `encode` method; see `SemF1/encoder_models.py`.
 - `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
 - `multi_references (bool)`: Flag to indicate whether multiple references are provided. Default: False.
 - `gpu (Union[bool, str, int, List[Union[str, int]]])`: Whether to use GPU, CPU or multiple-processes for computation.
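The contract the README now points to is small: a custom encoder only needs an `encode` method that maps a list of sentences to a `(num_sentences, embedding_dim)` array. As a minimal sketch of what such a class could look like (the `MyEncoder` name and the random embeddings are illustrative, not part of this commit; only the `encode` signature comes from `encoder_models.py`):

```python
from typing import List

import numpy as np
from numpy.typing import NDArray


class MyEncoder:  # inside the package, this would subclass encoder_models.Encoder
    """Hypothetical custom encoder implementing the required interface."""

    def encode(
        self,
        prediction: List[str],
        *,
        device="cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        # A real implementation would run a model here; random vectors
        # only demonstrate the expected output shape.
        return np.random.rand(len(prediction), 768)
```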
encoder_models.py (CHANGED)

@@ -9,68 +9,83 @@ from .type_aliases import ENCODER_DEVICE_TYPE


 class Encoder(abc.ABC):
     @abc.abstractmethod
-    def encode(self, prediction: List[str]) -> NDArray:
+    def encode(
+        self,
+        prediction: List[str],
+        *,
+        device: ENCODER_DEVICE_TYPE = "cpu",
+        batch_size: int = 32,
+        verbose: bool = False,
+    ) -> NDArray:
         """
         Abstract method to encode a list of sentences into sentence embeddings.

         Args:
             prediction (List[str]): List of sentences to encode.
             device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
             batch_size (int): Batch size for encoding.
             verbose (bool): Whether to print verbose information during encoding.

         Returns:
             NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).

         Raises:
             NotImplementedError: If the method is not implemented in the subclass.
         """
         raise NotImplementedError("Method 'encode' must be implemented in subclass.")


 class SBertEncoder(Encoder):
-    def __init__(self, model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
+    def __init__(self, model_name: str):
         """
         Initialize SBertEncoder instance.

         Args:
             model_name (str): Name or path of the Sentence Transformer model.
-            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding
-            batch_size (int): Batch size for encoding.
-            verbose (bool): Whether to print verbose information during encoding.
         """
         self.model = SentenceTransformer(model_name, trust_remote_code=True)
-        self.device = device
-        self.batch_size = batch_size
-        self.verbose = verbose

-    def encode(self, prediction: List[str]) -> NDArray:
+    def encode(
+        self,
+        prediction: List[str],
+        *,
+        device: ENCODER_DEVICE_TYPE = "cpu",
+        batch_size: int = 32,
+        verbose: bool = False,
+    ) -> NDArray:
         """
         Encode a list of sentences into sentence embeddings.

         Args:
             prediction (List[str]): List of sentences to encode.
             device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
             batch_size (int): Batch size for encoding.
             verbose (bool): Whether to print verbose information during encoding.

         Returns:
             NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
         """

         # SBert output is always Batch x Dim
-        if isinstance(self.device, list):
+        if isinstance(device, list):
             # Use multiprocess encoding for list of devices
-            pool = self.model.start_multi_process_pool(target_devices=self.device)
-            embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
+            pool = self.model.start_multi_process_pool(target_devices=device)
+            embeddings = self.model.encode_multi_process(
+                prediction, pool=pool, batch_size=batch_size
+            )
             self.model.stop_multi_process_pool(pool)
         else:
             # Single device encoding
             embeddings = self.model.encode(
                 prediction,
-                device=self.device,
-                batch_size=self.batch_size,
-                show_progress_bar=self.verbose,
+                device=device,
+                batch_size=batch_size,
+                show_progress_bar=verbose,
             )
         return embeddings


-def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool) -> Encoder:
+def get_encoder(model_name: str) -> Encoder:
     """
     Get the encoder instance based on the specified model name.

@@ -83,11 +98,6 @@ def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, v
     Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by
     SentenceTransformer.

-        device (Union[str, int, List[Union[str, int]]]): Device specification for the encoder
-            (e.g., "cuda", 0 for GPU, "cpu").
-        batch_size (int): Batch size for encoding.
-        verbose (bool): Whether to print verbose information during encoder initialization.
-
     Returns:
         Encoder: Instance of the selected encoder based on the model_name.

@@ -96,12 +106,10 @@ def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, v
     """

     try:
-        encoder = SBertEncoder(model_name, device, batch_size, verbose)
+        encoder = SBertEncoder(model_name)
     except EnvironmentError as err:
         raise EnvironmentError(str(err)) from None
     except Exception as err:
         raise RuntimeError(str(err)) from None

     return encoder
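The practical effect of the refactor is visible in how an encoder is built and called: construction now takes only the model name, while `device`, `batch_size`, and `verbose` travel with each `encode` call instead of being constructor state. A usage sketch (the model name and sentences are illustrative; the import path assumes the module is importable outside its package):

```python
from encoder_models import get_encoder  # within the repo this is a relative import

encoder = get_encoder("all-mpnet-base-v2")

# Single-device encoding: per-call keyword arguments replace constructor state
embeddings = encoder.encode(
    ["This is a test sentence.", "Here is another sentence."],
    device="cpu",
    batch_size=8,
    verbose=False,
)
print(embeddings.shape)  # (2, embedding_dim)

# Passing a list of devices, e.g. device=[0, 1], would instead take the
# multi-process branch in SBertEncoder.encode.
```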
semf1.py (CHANGED)

@@ -16,7 +16,7 @@ Sem-F1 metric
 Author: Naman Bansal
 """

-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import datasets
 import evaluate

@@ -25,9 +25,16 @@ import numpy as np
 from numpy.typing import NDArray
 from sklearn.metrics.pairwise import cosine_similarity

-from .encoder_models import get_encoder
+from .encoder_models import get_encoder, Encoder
 from .type_aliases import DEVICE_TYPE, PREDICTION_TYPE, REFERENCE_TYPE
-from .utils import is_nested_list_of_type, Scores, slice_embeddings, flatten_list, get_gpu, sent_tokenize
+from .utils import (
+    is_nested_list_of_type,
+    Scores,
+    slice_embeddings,
+    flatten_list,
+    get_gpu,
+    sent_tokenize,
+)

 _CITATION = """\
 @inproceedings{bansal-etal-2022-sem,

@@ -63,13 +70,15 @@ using precision, recall, and F1 score based on sentence embeddings.
 Args:
     predictions (list): List of predictions. Format varies based on `tokenize_sentences` and `multi_references` flags.
     references (list): List of references. Format varies based on `tokenize_sentences` and `multi_references` flags.
-    model_type (str): Model to use for encoding sentences.
-        Options: ['pv1', 'stsb', 'use']
-            pv1 - paraphrase-distilroberta-base-v1
-            stsb - stsb-roberta-large
-            use - Universal Sentence Encoder (Default)
+    model_type (Optional[Union[str, Encoder]]): Model to use for encoding sentences.
+        - Options: ['pv1', 'stsb', 'use']
+            pv1 - paraphrase-distilroberta-base-v1
+            stsb - stsb-roberta-large
+            use - Universal Sentence Encoder (Default)
+        - A string path or name for any model on Huggingface/SentenceTransformer that is supported by
+          SentenceTransformer, such as `all-mpnet-base-v2` or `roberta-base`.
+        - A custom instance of an Encoder (must implement the encode() method); see SemF1/encoder_models.py.
+
     tokenize_sentences (bool): Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
     multi_references (bool): Flag to indicate whether multiple references are provided. Default is False.
     gpu (Union[bool, str, int, List[Union[str, int]]]): Whether to use GPU or CPU for computation.

@@ -151,19 +160,21 @@ Examples:
 """


-def _compute_cosine_similarity(pred_embeds: NDArray, ref_embeds: NDArray) -> Tuple[float, float]:
+def _compute_cosine_similarity(
+    pred_embeds: NDArray, ref_embeds: NDArray
+) -> Tuple[float, float]:
     """
     Compute precision and recall based on cosine similarity between predicted and reference embeddings.

     Args:
         pred_embeds (NDArray): Predicted embeddings (shape: [num_pred, embedding_dim]).
         ref_embeds (NDArray): Reference embeddings (shape: [num_ref, embedding_dim]).

     Returns:
         Tuple[float, float]: Precision and recall based on cosine similarity scores.
             Precision: Average maximum cosine similarity score per predicted embedding.
             Recall: Average maximum cosine similarity score per reference embedding.
     """
     # Compute cosine similarity between predicted and reference embeddings
     cosine_scores = cosine_similarity(pred_embeds, ref_embeds)

@@ -181,60 +192,65 @@ def _compute_cosine_similarity(pred_embeds: NDArray, ref_embeds: NDArray) -> Tup


 def _validate_input_format(
     tokenize_sentences: bool,
     multi_references: bool,
     predictions: PREDICTION_TYPE,
     references: REFERENCE_TYPE,
 ):
     """
     Validate the format of predictions and references based on specified criteria.

     Args:
     - tokenize_sentences (bool): Flag indicating whether sentences should be tokenized.
     - multi_references (bool): Flag indicating whether multiple references are provided.
     - predictions (PREDICTION_TYPE): Predictions to validate.
     - references (REFERENCE_TYPE): References to validate.

     Raises:
     - ValueError: If the format of predictions or references does not meet the specified criteria.

     Validation Criteria:
     The function validates predictions and references based on the following conditions:
     1. If `tokenize_sentences` is True and `multi_references` is True:
        - Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
        - References must be a list of lists of strings (`is_list_of_strings_at_depth(references, 2)`).

     2. If `tokenize_sentences` is False and `multi_references` is True:
        - Predictions must be a list of lists of strings (`is_list_of_strings_at_depth(predictions, 2)`).
        - References must be a list of lists of lists of strings (`is_list_of_strings_at_depth(references, 3)`).

     3. If `tokenize_sentences` is True and `multi_references` is False:
        - Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
        - References must be a list of strings (`is_list_of_strings_at_depth(references, 1)`).

     4. If `tokenize_sentences` is False and `multi_references` is False:
        - Predictions must be a list of lists of strings (`is_list_of_strings_at_depth(predictions, 2)`).
        - References must be a list of lists of strings (`is_list_of_strings_at_depth(references, 2)`).

     The function checks these conditions and raises a ValueError if any condition is not met,
     indicating that predictions or references are not in the valid input format.

     Note:
     - `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file.
     """

     if len(predictions) != len(references):
         raise ValueError(
             f"Predictions and references must have the same length. "
             f"Got {len(predictions)} predictions and {len(references)} references."
         )

     if len(predictions) == 0:
         raise ValueError("Can't have empty inputs")

     def check_format(lst_obj, expected_depth: int, name: str):
         is_valid, error_message = is_nested_list_of_type(
             lst_obj, element_type=str, depth=expected_depth
         )
         if not is_valid:
             raise ValueError(
                 f"{name} are not in the expected format.\nError: {error_message}."
             )

     try:
         if tokenize_sentences and multi_references:

@@ -274,9 +290,13 @@ class SemF1(evaluate.Metric):
             datasets.Features(
                 {
                     # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
                     "predictions": datasets.Sequence(
                         datasets.Value("string", id="sequence"), id="predictions"
                     ),
                     # references: List[List[str]] - List of references where each reference is a list of sentences
                     "references": datasets.Sequence(
                         datasets.Value("string", id="sequence"), id="references"
                     ),
                 }
             ),
             # F1: Multi References: False, Tokenize_Sentences = True

@@ -292,12 +312,18 @@ class SemF1(evaluate.Metric):
             datasets.Features(
                 {
                     # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
                     "predictions": datasets.Sequence(
                         datasets.Value("string", id="sequence"), id="predictions"
                     ),
                     # references: List[List[List[str]]] - List of multi-references.
                     # So each "reference" is also a list (r1, r2, ...).
                     # Further, each ri's are also list of sentences.
                     "references": datasets.Sequence(
                         datasets.Sequence(
                             datasets.Value("string", id="sequence"), id="ref"
                         ),
                         id="references",
                     ),
                 }
             ),
             # F3: Multi References: True, Tokenize_Sentences = True

@@ -307,13 +333,15 @@ class SemF1(evaluate.Metric):
                     "predictions": datasets.Value("string", id="sequence"),
                     # references: List[List[List[str]]] - List of multi-references.
                     # So each "reference" is also a list (r1, r2, ...).
                     "references": datasets.Sequence(
                         datasets.Value("string", id="ref"), id="references"
                     ),
                 }
             ),
         ],
         # # Homepage of the module for documentation
         # Additional links to the codebase or references
         reference_urls=["https://aclanthology.org/2022.emnlp-main.49/"],
     )

     def _get_model_name(self, model_type: Optional[str] = None) -> str:

@@ -328,51 +356,62 @@ class SemF1(evaluate.Metric):
     def _download_and_prepare(self, dl_manager):
         """Optional: download external resources useful to compute the scores"""
         import nltk

         nltk.download("punkt", quiet=True)

     def _compute(
         self,
         predictions,
         references,
-        model_type: Optional[str] = None,
+        model_type: Optional[Union[str, Encoder]] = None,
         tokenize_sentences: bool = True,
         multi_references: bool = False,
         gpu: DEVICE_TYPE = False,
         batch_size: int = 32,
         verbose: bool = False,
         aggregate: bool = False,
     ) -> List[Scores]:
         """
         Compute precision, recall, and F1 scores for given predictions and references.

         Args:
         - predictions
         - references
         - model_type: Type of model to use for encoding.
             - Options: ['pv1', 'stsb', 'use']
                 pv1 - paraphrase-distilroberta-base-v1
                 stsb - stsb-roberta-large
                 use - Universal Sentence Encoder (Default)
             - A string path or name for any model on Huggingface/SentenceTransformer that is supported by
               SentenceTransformer.
             - A custom instance of an Encoder (must implement the encode() method); see SemF1/encoder_models.py.

         - tokenize_sentences: Flag to sentence tokenize the document.
         - multi_references: Flag to indicate multiple references.
         - gpu: GPU device to use.
         - batch_size: Batch size for encoding.
         - verbose: Flag to indicate verbose output.
         - aggregate: Flag to determine if the output should be averaged.

         Returns:
             Singleton/List of Scores dataclass with attributes as follows -
                 precision: float - precision score
                 recall: List[float] - List of recall scores corresponding to single/multiple references
                 f1: float - F1 score (between precision and average recall)
         """

         # Note: I have to specifically handle this case because the library considers the feature corresponding to
         # this case (F2) as the feature for the other case (F0) i.e. it can't make any distinction between
         # List[str] and List[List[str]]
         if not tokenize_sentences and multi_references:
             references = [
                 [eval(ref) for ref in mul_ref_ex] for mul_ref_ex in references
             ]

         # Validate inputs corresponding to flags
         _validate_input_format(
             tokenize_sentences, multi_references, predictions, references
         )

         # Get GPU
         device = get_gpu(gpu)

@@ -380,8 +419,15 @@ class SemF1(evaluate.Metric):
         print(f"Using devices: {device}")

         # Get the encoder model
-        model_name = self._get_model_name(model_type)
-        encoder = get_encoder(model_name, device, batch_size, verbose)
+        if model_type is None or isinstance(model_type, str):
+            model_name = self._get_model_name(model_type)
+            encoder = get_encoder(model_name)
+        elif isinstance(model_type, Encoder):
+            encoder = model_type
+        else:
+            raise TypeError(
+                f"Unsupported model_type: expected str or Encoder instance, got {type(model_type)}"
+            )

         # We'll handle the single reference and multi-reference case same way. So change the data format accordingly
         if not multi_references:

@@ -401,11 +447,15 @@ class SemF1(evaluate.Metric):

         # Note: This is the most optimal way of doing it
         # Encode all sentences in one go
-        embeddings = encoder.encode(all_sentences)
+        embeddings = encoder.encode(
+            all_sentences, device=device, batch_size=batch_size, verbose=verbose
+        )

         # Get embeddings corresponding to predictions and references
         pred_embeddings = slice_embeddings(embeddings, prediction_sentences_count)
         ref_embeddings = slice_embeddings(
             embeddings[sum(prediction_sentences_count) :], reference_sentences_count
         )

         # Init output scores
         results = []

@@ -418,23 +468,22 @@ class SemF1(evaluate.Metric):
         precision = np.clip(precision, a_min=0.0, a_max=1.0).item()

         # Recall: Compute individually for each reference
         recall_scores = [
             _compute_cosine_similarity(r_embeds, preds) for r_embeds in refs
         ]
         recall_scores = [
             np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores
         ]

         results.append(Scores(precision, recall_scores))

         # run aggregation procedure
         if aggregate:
             mean_prec = np.mean([score.precision for score in results])
             mean_recall = np.mean(
                 np.concatenate([np.array(score.recall) for score in results])
             )
             aggregated_score = Scores(float(mean_prec), [float(mean_recall)])
             results = aggregated_score

         return results
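From the caller's side, `model_type` now accepts either a model name or an `Encoder` instance, and the new `if/elif/else` block above routes between the two. A sketch of both call styles (the metric id passed to `evaluate.load` is assumed for illustration, and `MyEncoder` is the hypothetical class sketched earlier):

```python
import evaluate

# Metric id is an assumption; load the metric from wherever this Space is hosted.
semf1 = evaluate.load("nbansal/semf1")

predictions = ["This is a prediction sentence 1. This is a prediction sentence 2."]
references = ["This is a reference sentence 1. This is a reference sentence 2."]

# String path: resolved via _get_model_name() and get_encoder()
scores = semf1.compute(predictions=predictions, references=references, model_type="use")

# Encoder path: the instance is used as-is; device/batch_size/verbose are
# forwarded to its encode() method on each call.
scores = semf1.compute(
    predictions=predictions,
    references=references,
    model_type=MyEncoder(),
    batch_size=32,
)
```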
tests.py (CHANGED)

@@ -10,7 +10,14 @@ from unittest import TestLoader

 from .encoder_models import SBertEncoder, get_encoder
 from .semf1 import SemF1, _compute_cosine_similarity, _validate_input_format
-from .utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
+from .utils import (
+    get_gpu,
+    slice_embeddings,
+    is_nested_list_of_type,
+    flatten_list,
+    compute_f1,
+    Scores,
+)


 class TestUtils(unittest.TestCase):

@@ -40,20 +47,29 @@ class TestUtils(unittest.TestCase):
         self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")

         # Test list input with unique elements
         self.assertEqual(
             get_gpu([True, "cpu", 0]),
             [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"],
         )

         # Test list input with duplicate elements
         self.assertEqual(
             get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"]
         )

         # Test list input with duplicate elements of different types
         self.assertEqual(
             get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"]
         )

         # Test list input but only one element
         self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")

         # Test list input with all integers
         self.assertEqual(
             get_gpu(list(range(gpu_count))),
             list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"],
         )

         with self.assertRaises(ValueError):
             get_gpu("invalid")

@@ -66,12 +82,19 @@ class TestUtils(unittest.TestCase):
         num_sentences = [3, 2, 5]
         expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
         self.assertTrue(
             all(
                 np.array_equal(a, b)
                 for a, b in zip(
                     slice_embeddings(embeddings, num_sentences), expected_output
                 )
             )
         )

         num_sentences_nested = [[2, 1], [3, 4]]
         expected_output_nested = [
             [embeddings[:2], embeddings[2:3]],
             [embeddings[3:6], embeddings[6:]],
         ]
         self.assertTrue(
             slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
         )

@@ -88,7 +111,9 @@ class TestUtils(unittest.TestCase):
         self.assertEqual(is_valid, False)

         # Test case: Depth 1, list of elements matching element_type
         self.assertEqual(
             is_nested_list_of_type(["apple", "banana"], str, 1), (True, "")
         )

         # Test case: Depth 1, list of elements not matching element_type
         is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 1)

@@ -100,15 +125,18 @@ class TestUtils(unittest.TestCase):

         # Depth 2
         self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
         self.assertEqual(
             is_nested_list_of_type([["1", "2"], ["3", "4"]], str, 2), (True, "")
         )
         is_valid, err_msg = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
         self.assertEqual(is_valid, False)

         # Depth 3
         is_valid, err_msg = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
         self.assertEqual(is_valid, False)
         self.assertEqual(
             is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, "")
         )

         # Test case: Depth is negative, expecting ValueError
         with self.assertRaises(ValueError):

@@ -134,38 +162,55 @@ class TestUtils(unittest.TestCase):
 class TestSBertEncoder(unittest.TestCase):
     def setUp(self, device=None):
         if device is None:
             self.device = "cuda" if torch.cuda.is_available() else "cpu"
         else:
             self.device = device
         self.model_name = "stsb-roberta-large"
         self.batch_size = 8
         self.verbose = False
-        self.encoder = SBertEncoder(self.model_name, self.device, self.batch_size, self.verbose)
+        self.encoder = SBertEncoder(self.model_name)

     def test_initialization(self):
         self.assertIsInstance(self.encoder.model, SentenceTransformer)
-        self.assertEqual(self.encoder.device, self.device)
-        self.assertEqual(self.encoder.batch_size, self.batch_size)
-        self.assertEqual(self.encoder.verbose, self.verbose)

     def test_encode_single_device(self):
         sentences = ["This is a test sentence.", "Here is another sentence."]
-        embeddings = self.encoder.encode(sentences)
+        embeddings = self.encoder.encode(
+            sentences,
+            device=self.device,
+            batch_size=self.batch_size,
+            verbose=self.verbose,
+        )
         self.assertIsInstance(embeddings, np.ndarray)
         self.assertEqual(embeddings.shape[0], len(sentences))
         self.assertEqual(
             embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension()
         )

     def test_encode_multi_device(self):
         if torch.cuda.device_count() < 2:
             self.skipTest("Multi-GPU test requires at least 2 GPUs.")
         else:
-            devices = ["cuda:0", "cuda:1"]
+            # devices = ["cuda:0", "cuda:1"]
+            devices = [0, 1]
             self.setUp(devices)
             sentences = [
                 "This is a test sentence.",
                 "Here is another sentence.",
                 "This is a test sentence.",
             ]
-            embeddings = self.encoder.encode(sentences)
+            embeddings = self.encoder.encode(
+                sentences,
+                device=devices,
+                batch_size=self.batch_size,
+                verbose=self.verbose,
+            )
             self.assertIsInstance(embeddings, np.ndarray)
             self.assertEqual(embeddings.shape[0], 3)
             self.assertEqual(
                 embeddings.shape[1],
                 self.encoder.model.get_sentence_embedding_dimension(),
             )


 class TestGetEncoder(unittest.TestCase):

@@ -175,13 +220,8 @@ class TestGetEncoder(unittest.TestCase):
         self.verbose = False

     def _base_test(self, model_name):
-        encoder = get_encoder(model_name, self.device, self.batch_size, self.verbose)
-
-        # Assert
+        encoder = get_encoder(model_name)
         self.assertIsInstance(encoder, SBertEncoder)
-        self.assertEqual(encoder.device, self.device)
-        self.assertEqual(encoder.batch_size, self.batch_size)
-        self.assertEqual(encoder.verbose, self.verbose)

     def test_get_sbert_encoder(self):
         model_name = "stsb-roberta-large"

@@ -196,15 +236,15 @@ class TestGetEncoder(unittest.TestCase):
         model_name = "roberta-base"
         self._base_test(model_name)

     def test_get_encoder_environment_error(self):
         model_name = "abc"  # Wrong model_name
         with self.assertRaises(EnvironmentError):
-            get_encoder(model_name, self.device, self.batch_size, self.verbose)
+            get_encoder(model_name)

     def test_get_encoder_other_exception(self):
         model_name = "apple/OpenELM-270M"  # This model is not supported by SentenceTransformer lib
         with self.assertRaises(RuntimeError):
-            get_encoder(model_name, self.device, self.batch_size, self.verbose)
+            get_encoder(model_name)


 class TestSemF1(unittest.TestCase):

@@ -213,9 +253,11 @@ class TestSemF1(unittest.TestCase):

         # Example cases, #Samples = 1
         self.untokenized_single_reference_predictions = [
             "This is a prediction sentence 1. This is a prediction sentence 2."
         ]
         self.untokenized_single_reference_references = [
             "This is a reference sentence 1. This is a reference sentence 2."
         ]

         self.tokenized_single_reference_predictions = [
             ["This is a prediction sentence 1.", "This is a prediction sentence 2."],

@@ -228,7 +270,10 @@ class TestSemF1(unittest.TestCase):
             "Prediction sentence 1. Prediction sentence 2."
         ]
         self.untokenized_multi_reference_references = [
             [
                 "Reference sentence 1. Reference sentence 2.",
                 "Alternative reference 1. Alternative reference 2.",
             ],
         ]

         self.tokenized_multi_reference_predictions = [

@@ -237,21 +282,21 @@ class TestSemF1(unittest.TestCase):
         self.tokenized_multi_reference_references = [
             [
                 ["Reference sentence 1.", "Reference sentence 2."],
                 ["Alternative reference 1.", "Alternative reference 2."],
             ],
         ]
         self.multi_sample_refs = [
             "this is the first reference sample",
             "this is the second reference sample",
         ]
         self.multi_sample_preds = [
             "this is the first prediction sample",
             "this is the second prediction sample",
         ]

     def test_aggregate_multi_sample(self):
         """
         check if a `Scores` class is returned instead of a list of
         `Scores`
         """
         scores = self.semf1_metric.compute(

@@ -265,7 +310,7 @@ class TestSemF1(unittest.TestCase):
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
         print(f"Score: {scores}")

     def test_aggregate_untokenized_single_ref(self):
         scores = self.semf1_metric.compute(

@@ -279,7 +324,7 @@ class TestSemF1(unittest.TestCase):
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
         print(f"Score: {scores}")

     def test_aggregate_tokenized_single_ref(self):
         scores = self.semf1_metric.compute(

@@ -293,7 +338,7 @@ class TestSemF1(unittest.TestCase):
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
         print(f"Score: {scores}")

     def test_aggregate_untokenized_multi_ref(self):
         scores = self.semf1_metric.compute(

@@ -307,7 +352,7 @@ class TestSemF1(unittest.TestCase):
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
         print(f"Score: {scores}")

     def test_aggregate_tokenized_multi_ref(self):
         scores = self.semf1_metric.compute(

@@ -321,7 +366,7 @@ class TestSemF1(unittest.TestCase):
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
         print(f"Score: {scores}")

     def test_aggregate_same_pred_and_ref(self):
         scores = self.semf1_metric.compute(

@@ -335,7 +380,7 @@ class TestSemF1(unittest.TestCase):
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
         print(f"Score: {scores}")

     def test_untokenized_single_reference(self):
         scores = self.semf1_metric.compute(

@@ -345,10 +390,12 @@ class TestSemF1(unittest.TestCase):
             multi_references=False,
             gpu=False,
             batch_size=32,
             verbose=False,
         )
         self.assertIsInstance(scores, list)
         self.assertEqual(
             len(scores), len(self.untokenized_single_reference_predictions)
         )

     def test_tokenized_single_reference(self):
         scores = self.semf1_metric.compute(

@@ -358,7 +405,7 @@ class TestSemF1(unittest.TestCase):
             multi_references=False,
             gpu=False,
             batch_size=32,
             verbose=False,
         )
         self.assertIsInstance(scores, list)
         self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))

@@ -376,7 +423,7 @@ class TestSemF1(unittest.TestCase):
             multi_references=True,
             gpu=False,
             batch_size=32,
             verbose=False,
         )
         self.assertIsInstance(scores, list)
         self.assertEqual(len(scores), len(self.untokenized_multi_reference_predictions))

@@ -389,7 +436,7 @@ class TestSemF1(unittest.TestCase):
             multi_references=True,
             gpu=False,
             batch_size=32,
             verbose=False,
         )
         self.assertIsInstance(scores, list)
         self.assertEqual(len(scores), len(self.tokenized_multi_reference_predictions))

@@ -407,7 +454,7 @@ class TestSemF1(unittest.TestCase):
             multi_references=False,
             gpu=False,
             batch_size=32,
             verbose=False,
         )

         self.assertIsInstance(scores, list)

@@ -416,7 +463,12 @@ class TestSemF1(unittest.TestCase):
         for score in scores:
             self.assertIsInstance(score, Scores)
             self.assertAlmostEqual(score.precision, 1.0, places=6)
             assert_almost_equal(
                 score.recall,
                 1,
                 decimal=5,
                 err_msg="Not all values are almost equal to 1",
             )

     def test_exact_output_scores(self):
         predictions = [

@@ -473,7 +525,9 @@ class TestSemF1(unittest.TestCase):
             ["I am", "I am"],
             [None, "I am"],
         ]
         print(
             f"Case I\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
         )

         # Case 2: tokenize_sentences = False, multi_references = True
         tokenize_sentences = False

@@ -486,7 +540,9 @@ class TestSemF1(unittest.TestCase):
             [["I am", "I am"], [None, "I am"]],
             [[None, "I am"]],
         ]
         print(
             f"Case II\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
         )

         # Case 3: tokenize_sentences = True, multi_references = False
         tokenize_sentences = True

@@ -499,7 +555,9 @@ class TestSemF1(unittest.TestCase):
             "I am. I am.",
             "I am. I am.",
         ]
         print(
             f"Case III\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
         )

         # Case 4: tokenize_sentences = False, multi_references = False
         # This is taken care by the library itself

@@ -513,7 +571,9 @@ class TestSemF1(unittest.TestCase):
             ["I am.", "I am."],
             ["I am.", "I am."],
         ]
         print(
             f"Case IV\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
         )

     def test_empty_input(self):
         predictions = ["", ""]

@@ -538,22 +598,16 @@ class TestCosineSimilarity(unittest.TestCase):

     def setUp(self):
         # Sample embeddings for testing
-        self.pred_embeds = np.array([
-            [1, 0, 0],
-            [0, 1, 0],
-            [0, 0, 1]
-        ])
-        self.ref_embeds = np.array([
-            [1, 0, 0],
-            [0, 1, 0],
-            [0, 0, 1]
-        ])
+        self.pred_embeds = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+        self.ref_embeds = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])

         self.pred_embeds_random = np.random.rand(3, 3)
         self.ref_embeds_random = np.random.rand(3, 3)

     def test_cosine_similarity_perfect_match(self):
         precision, recall = _compute_cosine_similarity(
             self.pred_embeds, self.ref_embeds
         )

         # Expected values are 1.0 for both precision and recall since embeddings are identical
         self.assertAlmostEqual(precision, 1.0, places=5)

@@ -571,7 +625,9 @@ class TestCosineSimilarity(unittest.TestCase):
         self.assertAlmostEqual(recall, expected_recall, places=5)

     def test_cosine_similarity_random(self):
         self._test_cosine_similarity_base(
             self.pred_embeds_random, self.ref_embeds_random
         )

     def test_cosine_similarity_different_shapes(self):
         pred_embeds_diff = np.random.rand(5, 3)

@@ -607,7 +663,7 @@ class TestValidateInputFormat(unittest.TestCase):
         self.untokenized_multi_reference_references = [
             [
                 "This is a reference sentence 1. This is a reference sentence 2.",
                 "Another reference sentence.",
             ]
         ]

@@ -618,7 +674,7 @@ class TestValidateInputFormat(unittest.TestCase):
         self.tokenized_multi_reference_references = [
             [
                 ["This is a reference sentence 1.", "This is a reference sentence 2."],
                 ["Another reference sentence."],
             ]
         ]

@@ -701,7 +757,10 @@ class TestValidateInputFormat(unittest.TestCase):
             True,
             True,
             self.untokenized_single_reference_predictions,
             [
                 self.untokenized_single_reference_predictions[0],
                 self.untokenized_single_reference_predictions[0],
             ],
         )


@@ -709,5 +768,5 @@ def run_tests():
     unittest.main(verbosity=2)


-if __name__ == '__main__':
+if __name__ == "__main__":
     run_tests()