nbansal committed
Commit 47cf512
Parent: 9db3d74

refactor: allow custom Encoder instances

Files changed (4)
  1. README.md +2 -2
  2. encoder_models.py +48 -40
  3. semf1.py +157 -108
  4. tests.py +135 -76
README.md CHANGED
@@ -59,8 +59,8 @@ Sem-F1 takes 2 mandatory arguments:
 
 Sem-F1 also accepts multiple optional arguments:
 
-- `model_type (str)`: Model to use for encoding sentences. Options: ['pv1' ([paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)), 'stsb' ([stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)), 'use' ([Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual)) (Default)]. Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by SentenceTransformer
-such as `all-mpnet-base-v2` or `roberta-base`.
+- `model_type (Optional[Union[str, Encoder]])`: Model to use for encoding sentences. Options: ['pv1' ([paraphrase-distilroberta-base-v1](https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1)), 'stsb' ([stsb-roberta-large](https://huggingface.co/sentence-transformers/stsb-roberta-large)), 'use' ([Universal Sentence Encoder](https://huggingface.co/sentence-transformers/use-cmlm-multilingual)) (Default)]. Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by SentenceTransformer,
+such as `all-mpnet-base-v2` or `roberta-base`. Users can also pass a custom `Encoder` instance, which must implement the `encode` method; refer to `SemF1/encoder_models.py` (see the sketch below).
 - `tokenize_sentences (bool)`: Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
 - `multi_references (bool)`: Flag to indicate whether multiple references are provided. Default: False.
 - `gpu (Union[bool, str, int, List[Union[str, int]]])`: Whether to use GPU, CPU or multiple-processes for computation.
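
To make the new contract concrete, here is a minimal sketch of a custom `Encoder`. Only the `Encoder` base class and the `encode` signature come from `encoder_models.py` as of this commit; the class name, the character-count featurizer, and the import path are illustrative assumptions, not part of the repo.

```python
from typing import List

import numpy as np
from numpy.typing import NDArray

from encoder_models import Encoder  # hypothetical import path; adjust to your checkout


class CharCountEncoder(Encoder):
    """Toy encoder mapping each sentence to bucketed character counts (illustrative only)."""

    def encode(
        self,
        prediction: List[str],
        *,
        device="cpu",          # accepted for interface compatibility; unused in this toy
        batch_size: int = 32,  # likewise unused here
        verbose: bool = False,
    ) -> NDArray:
        dim = 64
        embeddings = np.zeros((len(prediction), dim), dtype=np.float32)
        for i, sentence in enumerate(prediction):
            for ch in sentence:
                embeddings[i, ord(ch) % dim] += 1.0
        return embeddings
```

An instance of such a class can then be passed directly as `model_type`, e.g. `semf1_metric.compute(predictions=..., references=..., model_type=CharCountEncoder())`.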
encoder_models.py CHANGED
@@ -9,68 +9,83 @@ from .type_aliases import ENCODER_DEVICE_TYPE
 
 class Encoder(abc.ABC):
     @abc.abstractmethod
-    def encode(self, prediction: List[str]) -> NDArray:
+    def encode(
+        self,
+        prediction: List[str],
+        *,
+        device: ENCODER_DEVICE_TYPE = "cpu",
+        batch_size: int = 32,
+        verbose: bool = False,
+    ) -> NDArray:
         """
         Abstract method to encode a list of sentences into sentence embeddings.
 
         Args:
             prediction (List[str]): List of sentences to encode.
+            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
+            batch_size (int): Batch size for encoding.
+            verbose (bool): Whether to print verbose information during encoding.
 
         Returns:
             NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
 
         Raises:
             NotImplementedError: If the method is not implemented in the subclass.
         """
         raise NotImplementedError("Method 'encode' must be implemented in subclass.")
 
 
 class SBertEncoder(Encoder):
-    def __init__(self, model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
+    def __init__(self, model_name: str):
         """
         Initialize SBertEncoder instance.
 
         Args:
             model_name (str): Name or path of the Sentence Transformer model.
-            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding
-            batch_size (int): Batch size for encoding.
-            verbose (bool): Whether to print verbose information during encoding.
         """
         self.model = SentenceTransformer(model_name, trust_remote_code=True)
-        self.device = device
-        self.batch_size = batch_size
-        self.verbose = verbose
 
-    def encode(self, prediction: List[str]) -> NDArray:
+    def encode(
+        self,
+        prediction: List[str],
+        *,
+        device: ENCODER_DEVICE_TYPE = "cpu",
+        batch_size: int = 32,
+        verbose: bool = False,
+    ) -> NDArray:
         """
         Encode a list of sentences into sentence embeddings.
 
         Args:
             prediction (List[str]): List of sentences to encode.
+            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
+            batch_size (int): Batch size for encoding.
+            verbose (bool): Whether to print verbose information during encoding.
 
         Returns:
             NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
         """
 
         # SBert output is always Batch x Dim
-        if isinstance(self.device, list):
+        if isinstance(device, list):
             # Use multiprocess encoding for list of devices
-            pool = self.model.start_multi_process_pool(target_devices=self.device)
-            embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
+            pool = self.model.start_multi_process_pool(target_devices=device)
+            embeddings = self.model.encode_multi_process(
+                prediction, pool=pool, batch_size=batch_size
+            )
             self.model.stop_multi_process_pool(pool)
         else:
             # Single device encoding
             embeddings = self.model.encode(
                 prediction,
-                device=self.device,
-                batch_size=self.batch_size,
-                show_progress_bar=self.verbose,
+                device=device,
+                batch_size=batch_size,
+                show_progress_bar=verbose,
             )
-
         return embeddings
 
 
-def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool) -> Encoder:
+def get_encoder(model_name: str) -> Encoder:
     """
     Get the encoder instance based on the specified model name.
 
@@ -83,11 +98,6 @@ def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, v
     Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by
     SentenceTransformer.
 
-    device (Union[str, int, List[Union[str, int]]): Device specification for the encoder
-        (e.g., "cuda", 0 for GPU, "cpu").
-    batch_size (int): Batch size for encoding.
-    verbose (bool): Whether to print verbose information during encoder initialization.
-
     Returns:
         Encoder: Instance of the selected encoder based on the model_name.
 
@@ -96,12 +106,10 @@ def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, v
     """
 
     try:
-        encoder = SBertEncoder(model_name, device, batch_size, verbose)
+        encoder = SBertEncoder(model_name)
    except EnvironmentError as err:
        raise EnvironmentError(str(err)) from None
    except Exception as err:
        raise RuntimeError(str(err)) from None
 
    return encoder
-
-
 
semf1.py CHANGED
@@ -16,7 +16,7 @@ Sem-F1 metric
 Author: Naman Bansal
 """
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import datasets
 import evaluate
@@ -25,9 +25,16 @@ import numpy as np
 from numpy.typing import NDArray
 from sklearn.metrics.pairwise import cosine_similarity
 
-from .encoder_models import get_encoder
+from .encoder_models import get_encoder, Encoder
 from .type_aliases import DEVICE_TYPE, PREDICTION_TYPE, REFERENCE_TYPE
-from .utils import is_nested_list_of_type, Scores, slice_embeddings, flatten_list, get_gpu, sent_tokenize
+from .utils import (
+    is_nested_list_of_type,
+    Scores,
+    slice_embeddings,
+    flatten_list,
+    get_gpu,
+    sent_tokenize,
+)
 
 _CITATION = """\
 @inproceedings{bansal-etal-2022-sem,
@@ -63,13 +70,15 @@ using precision, recall, and F1 score based on sentence embeddings.
 Args:
     predictions (list): List of predictions. Format varies based on `tokenize_sentences` and `multi_references` flags.
     references (list): List of references. Format varies based on `tokenize_sentences` and `multi_references` flags.
-    model_type (str): Model to use for encoding sentences. Options: ['pv1', 'stsb', 'use']
-        pv1 - paraphrase-distilroberta-base-v1
-        stsb - stsb-roberta-large
-        use - Universal Sentence Encoder (Default)
-        Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by SentenceTransformer such
-        as `all-mpnet-base-v2` or `roberta-base`
-
+    model_type (Optional[Union[str, Encoder]]): Model to use for encoding sentences.
+        Options: ['pv1', 'stsb', 'use']
+            pv1 - paraphrase-distilroberta-base-v1
+            stsb - stsb-roberta-large
+            use - Universal Sentence Encoder (Default)
+        - A string path or name for any model on Huggingface/SentenceTransformer that is supported by
+          SentenceTransformer, such as `all-mpnet-base-v2` or `roberta-base`.
+        - A custom instance of an Encoder (must implement the encode() method). Refer to SemF1/encoder_models.py.
+
     tokenize_sentences (bool): Flag to indicate whether to tokenize the sentences in the input documents. Default: True.
     multi_references (bool): Flag to indicate whether multiple references are provided. Default is False.
     gpu (Union[bool, str, int, List[Union[str, int]]]): Whether to use GPU or CPU for computation.
@@ -151,19 +160,21 @@ Examples:
 """
 
 
-def _compute_cosine_similarity(pred_embeds: NDArray, ref_embeds: NDArray) -> Tuple[float, float]:
+def _compute_cosine_similarity(
+    pred_embeds: NDArray, ref_embeds: NDArray
+) -> Tuple[float, float]:
     """
     Compute precision and recall based on cosine similarity between predicted and reference embeddings.
 
     Args:
         pred_embeds (NDArray): Predicted embeddings (shape: [num_pred, embedding_dim]).
         ref_embeds (NDArray): Reference embeddings (shape: [num_ref, embedding_dim]).
 
     Returns:
        Tuple[float, float]: Precision and recall based on cosine similarity scores.
            Precision: Average maximum cosine similarity score per predicted embedding.
            Recall: Average maximum cosine similarity score per reference embedding.
     """
     # Compute cosine similarity between predicted and reference embeddings
     cosine_scores = cosine_similarity(pred_embeds, ref_embeds)
@@ -181,60 +192,65 @@ def _compute_cosine_similarity(pred_embeds: NDArray, ref_embeds: NDArray) -> Tup
 
 
 def _validate_input_format(
     tokenize_sentences: bool,
     multi_references: bool,
     predictions: PREDICTION_TYPE,
     references: REFERENCE_TYPE,
 ):
     """
     Validate the format of predictions and references based on specified criteria.
 
     Args:
         - tokenize_sentences (bool): Flag indicating whether sentences should be tokenized.
         - multi_references (bool): Flag indicating whether multiple references are provided.
         - predictions (PREDICTION_TYPE): Predictions to validate.
         - references (REFERENCE_TYPE): References to validate.
 
     Raises:
         - ValueError: If the format of predictions or references does not meet the specified criteria.
 
     Validation Criteria:
     The function validates predictions and references based on the following conditions:
     1. If `tokenize_sentences` is True and `multi_references` is True:
         - Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
         - References must be a list of list of strings (`is_list_of_strings_at_depth(references, 2)`).
 
     2. If `tokenize_sentences` is False and `multi_references` is True:
         - Predictions must be a list of list of strings (`is_list_of_strings_at_depth(predictions, 2)`).
         - References must be a list of list of list of strings (`is_list_of_strings_at_depth(references, 3)`).
 
     3. If `tokenize_sentences` is True and `multi_references` is False:
         - Predictions must be a list of strings (`is_list_of_strings_at_depth(predictions, 1)`).
         - References must be a list of strings (`is_list_of_strings_at_depth(references, 1)`).
 
     4. If `tokenize_sentences` is False and `multi_references` is False:
         - Predictions must be a list of list of strings (`is_list_of_strings_at_depth(predictions, 2)`).
         - References must be a list of list of strings (`is_list_of_strings_at_depth(references, 2)`).
 
     The function checks these conditions and raises a ValueError if any condition is not met,
     indicating that predictions or references are not in the valid input format.
 
     Note:
         - `PREDICTION_TYPE` and `REFERENCE_TYPE` are defined at the top of the file
     """
 
     if len(predictions) != len(references):
-        raise ValueError(f"Predictions and references must have the same length. "
-                         f"Got {len(predictions)} predictions and {len(references)} references.")
+        raise ValueError(
+            f"Predictions and references must have the same length. "
+            f"Got {len(predictions)} predictions and {len(references)} references."
+        )
 
     if len(predictions) == 0:
         raise ValueError("Can't have empty inputs")
 
     def check_format(lst_obj, expected_depth: int, name: str):
-        is_valid, error_message = is_nested_list_of_type(lst_obj, element_type=str, depth=expected_depth)
+        is_valid, error_message = is_nested_list_of_type(
+            lst_obj, element_type=str, depth=expected_depth
+        )
         if not is_valid:
-            raise ValueError(f"{name} are not in the expected format.\n"
-                             f"Error: {error_message}.")
+            raise ValueError(
+                f"{name} are not in the expected format.\n" f"Error: {error_message}."
+            )
 
     try:
         if tokenize_sentences and multi_references:
@@ -274,9 +290,13 @@ class SemF1(evaluate.Metric):
                 datasets.Features(
                     {
                         # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
-                        "predictions": datasets.Sequence(datasets.Value("string", id="sequence"), id="predictions"),
+                        "predictions": datasets.Sequence(
+                            datasets.Value("string", id="sequence"), id="predictions"
+                        ),
                        # references: List[List[str]] - List of references where each reference is a list of sentences
-                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
+                        "references": datasets.Sequence(
+                            datasets.Value("string", id="sequence"), id="references"
+                        ),
                    }
                ),
                # F1: Multi References: False, Tokenize_Sentences = True
@@ -292,12 +312,18 @@ class SemF1(evaluate.Metric):
                 datasets.Features(
                     {
                         # predictions: List[List[str]] - List of predictions where prediction is a list of sentences
-                        "predictions": datasets.Sequence(datasets.Value("string", id="sequence"), id="predictions"),
+                        "predictions": datasets.Sequence(
+                            datasets.Value("string", id="sequence"), id="predictions"
+                        ),
                        # references: List[List[List[str]]] - List of multi-references.
                        # So each "reference" is also a list (r1, r2, ...).
                        # Further, each ri's are also list of sentences.
                        "references": datasets.Sequence(
-                            datasets.Sequence(datasets.Value("string", id="sequence"), id="ref"), id="references"),
+                            datasets.Sequence(
+                                datasets.Value("string", id="sequence"), id="ref"
+                            ),
+                            id="references",
+                        ),
                    }
                ),
                # F3: Multi References: True, Tokenize_Sentences = True
@@ -307,13 +333,15 @@ class SemF1(evaluate.Metric):
                         "predictions": datasets.Value("string", id="sequence"),
                        # references: List[List[List[str]]] - List of multi-references.
                        # So each "reference" is also a list (r1, r2, ...).
-                        "references": datasets.Sequence(datasets.Value("string", id="ref"), id="references"),
+                        "references": datasets.Sequence(
+                            datasets.Value("string", id="ref"), id="references"
+                        ),
                    }
                ),
            ],
            # # Homepage of the module for documentation
            # Additional links to the codebase or references
-            reference_urls=["https://aclanthology.org/2022.emnlp-main.49/"]
+            reference_urls=["https://aclanthology.org/2022.emnlp-main.49/"],
        )
 
     def _get_model_name(self, model_type: Optional[str] = None) -> str:
@@ -328,51 +356,62 @@ class SemF1(evaluate.Metric):
     def _download_and_prepare(self, dl_manager):
         """Optional: download external resources useful to compute the scores"""
         import nltk
+
         nltk.download("punkt", quiet=True)
 
     def _compute(
         self,
         predictions,
         references,
-        model_type: Optional[str] = None,
+        model_type: Optional[Union[str, Encoder]] = None,
         tokenize_sentences: bool = True,
        multi_references: bool = False,
        gpu: DEVICE_TYPE = False,
        batch_size: int = 32,
        verbose: bool = False,
        aggregate: bool = False,
     ) -> List[Scores]:
         """
         Compute precision, recall, and F1 scores for given predictions and references.
 
-        :param predictions
-        :param references
-        :param model_type: Type of model to use for encoding.
-            Options: [pv1, stsb, use]
-                pv1 - paraphrase-distilroberta-base-v1
-                stsb - stsb-roberta-large
-                use - Universal Sentence Encoder (Default)
-            Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by
-            SentenceTransformer.
-
-        :param tokenize_sentences: Flag to sentence tokenize the document.
-        :param multi_references: Flag to indicate multiple references.
-        :param gpu: GPU device to use.
-        :param batch_size: Batch size for encoding.
-        :param verbose: Flag to indicate verbose output.
-        :param aggregate: Flag to determine if output should be averaged
-
-        :return: List of Scores dataclass with precision, recall, and F1 scores.
+        Args:
+            - predictions
+            - references
+            - model_type: Type of model to use for encoding.
+                Options: [pv1, stsb, use]
+                    pv1 - paraphrase-distilroberta-base-v1
+                    stsb - stsb-roberta-large
+                    use - Universal Sentence Encoder (Default)
+                - A string path or name for any model on Huggingface/SentenceTransformer that is supported by
+                  SentenceTransformer.
+                - A custom instance of an Encoder (must implement the encode() method). Refer to SemF1/encoder_models.py.
+
+            - tokenize_sentences: Flag to sentence tokenize the document.
+            - multi_references: Flag to indicate multiple references.
+            - gpu: GPU device to use.
+            - batch_size: Batch size for encoding.
+            - verbose: Flag to indicate verbose output.
+            - aggregate: Flag to determine if output should be averaged
+
+        Returns:
+            Singleton/List of Scores dataclass with attributes as follows -
+                precision: float - precision score
+                recall: List[float] - List of recall scores corresponding to single/multiple references
+                f1: float - F1 score (between precision and average recall)
        """
 
        # Note: I have to specifically handle this case because the library considers the feature corresponding to
        # this case (F2) as the feature for the other case (F0) i.e. it can't make any distinction between
        # List[str] and List[List[str]]
        if not tokenize_sentences and multi_references:
-            references = [[eval(ref) for ref in mul_ref_ex] for mul_ref_ex in references]
+            references = [
+                [eval(ref) for ref in mul_ref_ex] for mul_ref_ex in references
+            ]
 
        # Validate inputs corresponding to flags
-        _validate_input_format(tokenize_sentences, multi_references, predictions, references)
+        _validate_input_format(
+            tokenize_sentences, multi_references, predictions, references
+        )
 
        # Get GPU
        device = get_gpu(gpu)
@@ -380,8 +419,15 @@ class SemF1(evaluate.Metric):
             print(f"Using devices: {device}")
 
         # Get the encoder model
-        model_name = self._get_model_name(model_type)
-        encoder = get_encoder(model_name, device=device, batch_size=batch_size, verbose=verbose)
+        if model_type is None or isinstance(model_type, str):
+            model_name = self._get_model_name(model_type)
+            encoder = get_encoder(model_name)
+        elif isinstance(model_type, Encoder):
+            encoder = model_type
+        else:
+            raise TypeError(
+                f"Unsupported model_type: expected str or Encoder instance, got {type(model_type)}"
+            )
 
        # We'll handle the single reference and multi-reference case same way. So change the data format accordingly
        if not multi_references:
@@ -401,11 +447,15 @@ class SemF1(evaluate.Metric):
 
         # Note: This is the most optimal way of doing it
         # Encode all sentences in one go
-        embeddings = encoder.encode(all_sentences)
+        embeddings = encoder.encode(
+            all_sentences, device=device, batch_size=batch_size, verbose=verbose
+        )
 
        # Get embeddings corresponding to predictions and references
        pred_embeddings = slice_embeddings(embeddings, prediction_sentences_count)
-        ref_embeddings = slice_embeddings(embeddings[sum(prediction_sentences_count):], reference_sentences_count)
+        ref_embeddings = slice_embeddings(
+            embeddings[sum(prediction_sentences_count) :], reference_sentences_count
+        )
 
        # Init output scores
        results = []
@@ -418,23 +468,22 @@ class SemF1(evaluate.Metric):
             precision = np.clip(precision, a_min=0.0, a_max=1.0).item()
 
             # Recall: Compute individually for each reference
-            recall_scores = [_compute_cosine_similarity(r_embeds, preds) for r_embeds in refs]
-            recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]
+            recall_scores = [
+                _compute_cosine_similarity(r_embeds, preds) for r_embeds in refs
+            ]
+            recall_scores = [
+                np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores
+            ]
 
            results.append(Scores(precision, recall_scores))
 
        # run aggregation procedure
        if aggregate:
-            mean_prec = np.mean(
-                [score.precision for score in results]
-            )
-            mean_recall = np.mean(np.concatenate(
-                [np.array(score.recall) for score in results]
-            ))
-            aggregated_score = Scores(
-                float(mean_prec),
-                [float(mean_recall)]
+            mean_prec = np.mean([score.precision for score in results])
+            mean_recall = np.mean(
+                np.concatenate([np.array(score.recall) for score in results])
             )
+            aggregated_score = Scores(float(mean_prec), [float(mean_recall)])
            results = aggregated_score
 
        return results
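
For reference, here is a minimal numeric sketch of what `_compute_cosine_similarity` computes, mirroring its docstring (precision averages each prediction's best match, recall averages each reference's best match); the toy embeddings are made up for illustration:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

pred_embeds = np.array([[1.0, 0.0], [0.0, 1.0]])  # two prediction sentences
ref_embeds = np.array([[1.0, 0.0]])               # one reference sentence

cosine_scores = cosine_similarity(pred_embeds, ref_embeds)  # shape (num_pred, num_ref) = (2, 1)
precision = cosine_scores.max(axis=-1).mean()  # best match per prediction: (1.0 + 0.0) / 2 = 0.5
recall = cosine_scores.max(axis=0).mean()      # best match per reference: 1.0
```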
tests.py CHANGED
@@ -10,7 +10,14 @@ from unittest import TestLoader
 
 from .encoder_models import SBertEncoder, get_encoder
 from .semf1 import SemF1, _compute_cosine_similarity, _validate_input_format
-from .utils import get_gpu, slice_embeddings, is_nested_list_of_type, flatten_list, compute_f1, Scores
+from .utils import (
+    get_gpu,
+    slice_embeddings,
+    is_nested_list_of_type,
+    flatten_list,
+    compute_f1,
+    Scores,
+)
 
 
 class TestUtils(unittest.TestCase):
@@ -40,20 +47,29 @@ class TestUtils(unittest.TestCase):
         self.assertEqual(get_gpu(1), 1 if gpu_available else "cpu")
 
         # Test list input with unique elements
-        self.assertEqual(get_gpu([True, "cpu", 0]), [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"])
+        self.assertEqual(
+            get_gpu([True, "cpu", 0]),
+            [0, "cpu"] if gpu_available else ["cpu", "cpu", "cpu"],
+        )
 
        # Test list input with duplicate elements
-        self.assertEqual(get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
+        self.assertEqual(
+            get_gpu([0, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"]
+        )
 
        # Test list input with duplicate elements of different types
-        self.assertEqual(get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"])
+        self.assertEqual(
+            get_gpu([True, 0, "gpu"]), 0 if gpu_available else ["cpu", "cpu", "cpu"]
+        )
 
        # Test list input but only one element
        self.assertEqual(get_gpu([True]), 0 if gpu_available else "cpu")
 
        # Test list input with all integers
-        self.assertEqual(get_gpu(list(range(gpu_count))),
-                         list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"])
+        self.assertEqual(
+            get_gpu(list(range(gpu_count))),
+            list(range(gpu_count)) if gpu_available else gpu_count * ["cpu"],
+        )
 
        with self.assertRaises(ValueError):
            get_gpu("invalid")
@@ -66,12 +82,19 @@ class TestUtils(unittest.TestCase):
         num_sentences = [3, 2, 5]
         expected_output = [embeddings[:3], embeddings[3:5], embeddings[5:]]
         self.assertTrue(
-            all(np.array_equal(a, b) for a, b in zip(slice_embeddings(embeddings, num_sentences),
-                                                     expected_output))
+            all(
+                np.array_equal(a, b)
+                for a, b in zip(
+                    slice_embeddings(embeddings, num_sentences), expected_output
+                )
+            )
        )
 
        num_sentences_nested = [[2, 1], [3, 4]]
-        expected_output_nested = [[embeddings[:2], embeddings[2:3]], [embeddings[3:6], embeddings[6:]]]
+        expected_output_nested = [
+            [embeddings[:2], embeddings[2:3]],
+            [embeddings[3:6], embeddings[6:]],
+        ]
        self.assertTrue(
            slice_embeddings(embeddings, num_sentences_nested), expected_output_nested
        )
@@ -88,7 +111,9 @@ class TestUtils(unittest.TestCase):
         self.assertEqual(is_valid, False)
 
         # Test case: Depth 1, list of elements matching element_type
-        self.assertEqual(is_nested_list_of_type(["apple", "banana"], str, 1), (True, ""))
+        self.assertEqual(
+            is_nested_list_of_type(["apple", "banana"], str, 1), (True, "")
+        )
 
        # Test case: Depth 1, list of elements not matching element_type
        is_valid, err_msg = is_nested_list_of_type([1, 2, 3], str, 1)
@@ -100,15 +125,18 @@
 
         # Depth 2
         self.assertEqual(is_nested_list_of_type([[1, 2], [3, 4]], int, 2), (True, ""))
-        self.assertEqual(is_nested_list_of_type([['1', '2'], ['3', '4']], str, 2), (True, ""))
+        self.assertEqual(
+            is_nested_list_of_type([["1", "2"], ["3", "4"]], str, 2), (True, "")
+        )
        is_valid, err_msg = is_nested_list_of_type([[1, 2], ["a", "b"]], int, 2)
        self.assertEqual(is_valid, False)
 
-
        # Depth 3
        is_valid, err_msg = is_nested_list_of_type([[[1], [2]], [[3], [4]]], list, 3)
        self.assertEqual(is_valid, False)
-        self.assertEqual(is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, ""))
+        self.assertEqual(
+            is_nested_list_of_type([[[1], [2]], [[3], [4]]], int, 3), (True, "")
+        )
 
        # Test case: Depth is negative, expecting ValueError
        with self.assertRaises(ValueError):
@@ -134,38 +162,55 @@ class TestUtils(unittest.TestCase):
 class TestSBertEncoder(unittest.TestCase):
     def setUp(self, device=None):
         if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
         else:
             self.device = device
         self.model_name = "stsb-roberta-large"
         self.batch_size = 8
         self.verbose = False
-        self.encoder = SBertEncoder(self.model_name, self.device, self.batch_size, self.verbose)
+        self.encoder = SBertEncoder(self.model_name)
 
     def test_initialization(self):
         self.assertIsInstance(self.encoder.model, SentenceTransformer)
-        self.assertEqual(self.encoder.device, self.device)
-        self.assertEqual(self.encoder.batch_size, self.batch_size)
-        self.assertEqual(self.encoder.verbose, self.verbose)
 
     def test_encode_single_device(self):
         sentences = ["This is a test sentence.", "Here is another sentence."]
-        embeddings = self.encoder.encode(sentences)
+        embeddings = self.encoder.encode(
+            sentences,
+            device=self.device,
+            batch_size=self.batch_size,
+            verbose=self.verbose,
+        )
         self.assertIsInstance(embeddings, np.ndarray)
         self.assertEqual(embeddings.shape[0], len(sentences))
-        self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
+        self.assertEqual(
+            embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension()
+        )
 
     def test_encode_multi_device(self):
         if torch.cuda.device_count() < 2:
             self.skipTest("Multi-GPU test requires at least 2 GPUs.")
         else:
-            devices = ["cuda:0", "cuda:1"]
+            # devices = ["cuda:0", "cuda:1"]
+            devices = [0, 1]
             self.setUp(devices)
-            sentences = ["This is a test sentence.", "Here is another sentence.", "This is a test sentence."]
-            embeddings = self.encoder.encode(sentences)
+            sentences = [
+                "This is a test sentence.",
+                "Here is another sentence.",
+                "This is a test sentence.",
+            ]
+            embeddings = self.encoder.encode(
+                sentences,
+                device=devices,
+                batch_size=self.batch_size,
+                verbose=self.verbose,
+            )
            self.assertIsInstance(embeddings, np.ndarray)
            self.assertEqual(embeddings.shape[0], 3)
-            self.assertEqual(embeddings.shape[1], self.encoder.model.get_sentence_embedding_dimension())
+            self.assertEqual(
+                embeddings.shape[1],
+                self.encoder.model.get_sentence_embedding_dimension(),
+            )
 
 
 class TestGetEncoder(unittest.TestCase):
@@ -175,13 +220,8 @@ class TestGetEncoder(unittest.TestCase):
         self.verbose = False
 
     def _base_test(self, model_name):
-        encoder = get_encoder(model_name, self.device, self.batch_size, self.verbose)
-
-        # Assert
+        encoder = get_encoder(model_name)
         self.assertIsInstance(encoder, SBertEncoder)
-        self.assertEqual(encoder.device, self.device)
-        self.assertEqual(encoder.batch_size, self.batch_size)
-        self.assertEqual(encoder.verbose, self.verbose)
 
     def test_get_sbert_encoder(self):
         model_name = "stsb-roberta-large"
@@ -196,15 +236,15 @@
         model_name = "roberta-base"
         self._base_test(model_name)
 
-    def test_get_encoder_environment_error(self):  # This parameter is used when using patch decorator
+    def test_get_encoder_environment_error(self):
         model_name = "abc"  # Wrong model_name
         with self.assertRaises(EnvironmentError):
-            get_encoder(model_name, self.device, self.batch_size, self.verbose)
+            get_encoder(model_name)
 
     def test_get_encoder_other_exception(self):
         model_name = "apple/OpenELM-270M"  # This model is not supported by SentenceTransformer lib
         with self.assertRaises(RuntimeError):
-            get_encoder(model_name, self.device, self.batch_size, self.verbose)
+            get_encoder(model_name)
 
 
 class TestSemF1(unittest.TestCase):
@@ -213,9 +253,11 @@
 
         # Example cases, #Samples = 1
         self.untokenized_single_reference_predictions = [
-            "This is a prediction sentence 1. This is a prediction sentence 2."]
+            "This is a prediction sentence 1. This is a prediction sentence 2."
+        ]
         self.untokenized_single_reference_references = [
-            "This is a reference sentence 1. This is a reference sentence 2."]
+            "This is a reference sentence 1. This is a reference sentence 2."
+        ]
 
        self.tokenized_single_reference_predictions = [
            ["This is a prediction sentence 1.", "This is a prediction sentence 2."],
@@ -228,7 +270,10 @@
             "Prediction sentence 1. Prediction sentence 2."
         ]
         self.untokenized_multi_reference_references = [
-            ["Reference sentence 1. Reference sentence 2.", "Alternative reference 1. Alternative reference 2."],
+            [
+                "Reference sentence 1. Reference sentence 2.",
+                "Alternative reference 1. Alternative reference 2.",
+            ],
        ]
 
        self.tokenized_multi_reference_predictions = [
@@ -237,21 +282,21 @@
         self.tokenized_multi_reference_references = [
             [
                 ["Reference sentence 1.", "Reference sentence 2."],
-                ["Alternative reference 1.", "Alternative reference 2."]
+                ["Alternative reference 1.", "Alternative reference 2."],
             ],
         ]
         self.multi_sample_refs = [
-            'this is the first reference sample',
-            'this is the second reference sample',
+            "this is the first reference sample",
+            "this is the second reference sample",
         ]
         self.multi_sample_preds = [
-            'this is the first prediction sample',
-            'this is the second prediction sample',
+            "this is the first prediction sample",
+            "this is the second prediction sample",
         ]
 
     def test_aggregate_multi_sample(self):
         """
         check if a `Scores` class is returned instead of a list of
         `Scores`
         """
         scores = self.semf1_metric.compute(
@@ -265,7 +310,7 @@
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
-        print(f'Score: {scores}')
+        print(f"Score: {scores}")
 
     def test_aggregate_untokenized_single_ref(self):
         scores = self.semf1_metric.compute(
@@ -279,7 +324,7 @@
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
-        print(f'Score: {scores}')
+        print(f"Score: {scores}")
 
     def test_aggregate_tokenized_single_ref(self):
         scores = self.semf1_metric.compute(
@@ -293,7 +338,7 @@
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
-        print(f'Score: {scores}')
+        print(f"Score: {scores}")
 
     def test_aggregate_untokenized_multi_ref(self):
         scores = self.semf1_metric.compute(
@@ -307,7 +352,7 @@
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
-        print(f'Score: {scores}')
+        print(f"Score: {scores}")
 
     def test_aggregate_tokenized_multi_ref(self):
         scores = self.semf1_metric.compute(
@@ -321,7 +366,7 @@
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
-        print(f'Score: {scores}')
+        print(f"Score: {scores}")
 
     def test_aggregate_same_pred_and_ref(self):
         scores = self.semf1_metric.compute(
@@ -335,7 +380,7 @@
             aggregate=True,
         )
         self.assertIsInstance(scores, Scores)
-        print(f'Score: {scores}')
+        print(f"Score: {scores}")
 
     def test_untokenized_single_reference(self):
         scores = self.semf1_metric.compute(
@@ -345,10 +390,12 @@
             multi_references=False,
             gpu=False,
             batch_size=32,
-            verbose=False
+            verbose=False,
         )
         self.assertIsInstance(scores, list)
-        self.assertEqual(len(scores), len(self.untokenized_single_reference_predictions))
+        self.assertEqual(
+            len(scores), len(self.untokenized_single_reference_predictions)
+        )
 
     def test_tokenized_single_reference(self):
         scores = self.semf1_metric.compute(
@@ -358,7 +405,7 @@
             multi_references=False,
             gpu=False,
             batch_size=32,
-            verbose=False
+            verbose=False,
         )
         self.assertIsInstance(scores, list)
         self.assertEqual(len(scores), len(self.tokenized_single_reference_predictions))
@@ -376,7 +423,7 @@
             multi_references=True,
             gpu=False,
             batch_size=32,
-            verbose=False
+            verbose=False,
         )
         self.assertIsInstance(scores, list)
         self.assertEqual(len(scores), len(self.untokenized_multi_reference_predictions))
@@ -389,7 +436,7 @@
             multi_references=True,
             gpu=False,
             batch_size=32,
-            verbose=False
+            verbose=False,
         )
         self.assertIsInstance(scores, list)
         self.assertEqual(len(scores), len(self.tokenized_multi_reference_predictions))
@@ -407,7 +454,7 @@
             multi_references=False,
             gpu=False,
             batch_size=32,
-            verbose=False
+            verbose=False,
         )
 
         self.assertIsInstance(scores, list)
@@ -416,7 +463,12 @@
         for score in scores:
             self.assertIsInstance(score, Scores)
             self.assertAlmostEqual(score.precision, 1.0, places=6)
-            assert_almost_equal(score.recall, 1, decimal=5, err_msg="Not all values are almost equal to 1")
+            assert_almost_equal(
+                score.recall,
+                1,
+                decimal=5,
+                err_msg="Not all values are almost equal to 1",
+            )
 
     def test_exact_output_scores(self):
         predictions = [
@@ -473,7 +525,9 @@
             ["I am", "I am"],
             [None, "I am"],
         ]
-        print(f"Case I\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n")
+        print(
+            f"Case I\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
+        )
 
        # Case 2: tokenize_sentences = False, multi_references = True
        tokenize_sentences = False
@@ -486,7 +540,9 @@
             [["I am", "I am"], [None, "I am"]],
             [[None, "I am"]],
         ]
-        print(f"Case II\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n")
+        print(
+            f"Case II\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
+        )
 
        # Case 3: tokenize_sentences = True, multi_references = False
        tokenize_sentences = True
@@ -499,7 +555,9 @@
             "I am. I am.",
             "I am. I am.",
         ]
-        print(f"Case III\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n")
+        print(
+            f"Case III\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
+        )
 
        # Case 4: tokenize_sentences = False, multi_references = False
        # This is taken care by the library itself
@@ -513,7 +571,9 @@
             ["I am.", "I am."],
             ["I am.", "I am."],
         ]
-        print(f"Case IV\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n")
+        print(
+            f"Case IV\n{_call_metric(predictions, references, tokenize_sentences, multi_references)}\n"
+        )
 
    def test_empty_input(self):
        predictions = ["", ""]
@@ -538,22 +598,16 @@ class TestCosineSimilarity(unittest.TestCase):
 
     def setUp(self):
         # Sample embeddings for testing
-        self.pred_embeds = np.array([
-            [1, 0, 0],
-            [0, 1, 0],
-            [0, 0, 1]
-        ])
-        self.ref_embeds = np.array([
-            [1, 0, 0],
-            [0, 1, 0],
-            [0, 0, 1]
-        ])
+        self.pred_embeds = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+        self.ref_embeds = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
 
        self.pred_embeds_random = np.random.rand(3, 3)
        self.ref_embeds_random = np.random.rand(3, 3)
 
    def test_cosine_similarity_perfect_match(self):
-        precision, recall = _compute_cosine_similarity(self.pred_embeds, self.ref_embeds)
+        precision, recall = _compute_cosine_similarity(
+            self.pred_embeds, self.ref_embeds
+        )
 
        # Expected values are 1.0 for both precision and recall since embeddings are identical
        self.assertAlmostEqual(precision, 1.0, places=5)
@@ -571,7 +625,9 @@
         self.assertAlmostEqual(recall, expected_recall, places=5)
 
     def test_cosine_similarity_random(self):
-        self._test_cosine_similarity_base(self.pred_embeds_random, self.ref_embeds_random)
+        self._test_cosine_similarity_base(
+            self.pred_embeds_random, self.ref_embeds_random
+        )
 
    def test_cosine_similarity_different_shapes(self):
        pred_embeds_diff = np.random.rand(5, 3)
@@ -607,7 +663,7 @@ class TestValidateInputFormat(unittest.TestCase):
         self.untokenized_multi_reference_references = [
             [
                 "This is a reference sentence 1. This is a reference sentence 2.",
-                "Another reference sentence."
+                "Another reference sentence.",
             ]
         ]
@@ -618,7 +674,7 @@
         self.tokenized_multi_reference_references = [
             [
                 ["This is a reference sentence 1.", "This is a reference sentence 2."],
-                ["Another reference sentence."]
+                ["Another reference sentence."],
             ]
         ]
@@ -701,7 +757,10 @@
             True,
             True,
             self.untokenized_single_reference_predictions,
-            [self.untokenized_single_reference_predictions[0], self.untokenized_single_reference_predictions[0]],
+            [
+                self.untokenized_single_reference_predictions[0],
+                self.untokenized_single_reference_predictions[0],
+            ],
         )
@@ -709,5 +768,5 @@ def run_tests():
     unittest.main(verbosity=2)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     run_tests()
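
Because tests.py uses relative imports (`from .encoder_models import ...`), the suite has to run as part of the package. A sketch for invoking it programmatically; the dotted name `SemF1.tests` is an assumption about how the repo sits on your `PYTHONPATH`:

```python
import unittest

# Load and run the test module by its package-qualified name (adjust as needed)
suite = unittest.defaultTestLoader.loadTestsFromName("SemF1.tests")
unittest.TextTestRunner(verbosity=2).run(suite)
```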