lvwerra HF staff commited on
Commit
017e24f
1 Parent(s): b1f420c

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. mauve.py +54 -38
  2. requirements.txt +1 -1
mauve.py CHANGED
@@ -14,6 +14,9 @@
14
  # limitations under the License.
15
  """ MAUVE metric from https://github.com/krishnap25/mauve. """
16
 
 
 
 
17
  import datasets
18
  import faiss # Here to have a nice missing dependency error message early on
19
  import numpy # Here to have a nice missing dependency error message early on
@@ -85,14 +88,47 @@ Examples:
85
  """
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
89
  class Mauve(evaluate.Metric):
90
- def _info(self):
 
 
 
 
91
  return evaluate.MetricInfo(
92
  description=_DESCRIPTION,
93
  citation=_CITATION,
94
  homepage="https://github.com/krishnap25/mauve",
95
  inputs_description=_KWARGS_DESCRIPTION,
 
96
  features=datasets.Features(
97
  {
98
  "predictions": datasets.Value("string", id="sequence"),
@@ -106,45 +142,25 @@ class Mauve(evaluate.Metric):
106
  ],
107
  )
108
 
109
- def _compute(
110
- self,
111
- predictions,
112
- references,
113
- p_features=None,
114
- q_features=None,
115
- p_tokens=None,
116
- q_tokens=None,
117
- num_buckets="auto",
118
- pca_max_data=-1,
119
- kmeans_explained_var=0.9,
120
- kmeans_num_redo=5,
121
- kmeans_max_iter=500,
122
- featurize_model_name="gpt2-large",
123
- device_id=-1,
124
- max_text_length=1024,
125
- divergence_curve_discretization_size=25,
126
- mauve_scaling_factor=5,
127
- verbose=True,
128
- seed=25,
129
- ):
130
  out = compute_mauve(
131
  p_text=predictions,
132
  q_text=references,
133
- p_features=p_features,
134
- q_features=q_features,
135
- p_tokens=p_tokens,
136
- q_tokens=q_tokens,
137
- num_buckets=num_buckets,
138
- pca_max_data=pca_max_data,
139
- kmeans_explained_var=kmeans_explained_var,
140
- kmeans_num_redo=kmeans_num_redo,
141
- kmeans_max_iter=kmeans_max_iter,
142
- featurize_model_name=featurize_model_name,
143
- device_id=device_id,
144
- max_text_length=max_text_length,
145
- divergence_curve_discretization_size=divergence_curve_discretization_size,
146
- mauve_scaling_factor=mauve_scaling_factor,
147
- verbose=verbose,
148
- seed=seed,
149
  )
150
  return out
 
14
  # limitations under the License.
15
  """ MAUVE metric from https://github.com/krishnap25/mauve. """
16
 
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Union
19
+
20
  import datasets
21
  import faiss # Here to have a nice missing dependency error message early on
22
  import numpy # Here to have a nice missing dependency error message early on
 
88
  """
89
 
90
 
91
+ @dataclass
92
+ class MauveConfig(evaluate.info.Config):
93
+
94
+ name: str = "default"
95
+
96
+ pos_label: Union[str, int] = 1
97
+ average: str = "binary"
98
+ labels: Optional[List[str]] = None
99
+ sample_weight: Optional[List[float]] = None
100
+
101
+ p_features: Optional[List] = None
102
+ q_features: Optional[List] = None
103
+ p_tokens: Optional[List] = None
104
+ q_tokens: Optional[List] = None
105
+ num_buckets: str = "auto"
106
+ pca_max_data: int = -1
107
+ kmeans_explained_var: float = 0.9
108
+ kmeans_num_redo: int = 5
109
+ kmeans_max_iter: int = 500
110
+ featurize_model_name: str = "gpt2-large"
111
+ device_id: int = (-1,)
112
+ max_text_length: int = 1024
113
+ divergence_curve_discretization_size: int = 25
114
+ mauve_scaling_factor: int = 5
115
+ verbose: bool = True
116
+ seed: int = 25
117
+
118
+
119
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
120
  class Mauve(evaluate.Metric):
121
+
122
+ CONFIG_CLASS = MauveConfig
123
+ ALLOWED_CONFIG_NAMES = ["default"]
124
+
125
+ def _info(self, config):
126
  return evaluate.MetricInfo(
127
  description=_DESCRIPTION,
128
  citation=_CITATION,
129
  homepage="https://github.com/krishnap25/mauve",
130
  inputs_description=_KWARGS_DESCRIPTION,
131
+ config=config,
132
  features=datasets.Features(
133
  {
134
  "predictions": datasets.Value("string", id="sequence"),
 
142
  ],
143
  )
144
 
145
+ def _compute(self, predictions, references):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  out = compute_mauve(
147
  p_text=predictions,
148
  q_text=references,
149
+ p_features=self.config.p_features,
150
+ q_features=self.config.q_features,
151
+ p_tokens=self.config.p_tokens,
152
+ q_tokens=self.config.q_tokens,
153
+ num_buckets=self.config.num_buckets,
154
+ pca_max_data=self.config.pca_max_data,
155
+ kmeans_explained_var=self.config.kmeans_explained_var,
156
+ kmeans_num_redo=self.config.kmeans_num_redo,
157
+ kmeans_max_iter=self.config.kmeans_max_iter,
158
+ featurize_model_name=self.config.featurize_model_name,
159
+ device_id=self.config.device_id,
160
+ max_text_length=self.config.max_text_length,
161
+ divergence_curve_discretization_size=self.config.divergence_curve_discretization_size,
162
+ mauve_scaling_factor=self.config.mauve_scaling_factor,
163
+ verbose=self.config.verbose,
164
+ seed=self.config.seed,
165
  )
166
  return out
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  faiss-cpu
3
  sklearn
4
  mauve-text
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  faiss-cpu
3
  sklearn
4
  mauve-text