yu-val-weiss
committed on
Commit
·
eb6c7b0
1
Parent(s):
17ddf40
Update blimp.py
Browse files
blimp.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14 |
"""Blimp Metric."""
|
15 |
|
16 |
from collections import defaultdict
|
|
|
17 |
|
18 |
import datasets
|
19 |
import evaluate
|
@@ -123,7 +124,7 @@ Args:
|
|
123 |
predictions (list[str]): names of metrics to run. pass empty list or ["*"] to run all of them
|
124 |
batch_size (int): the batch size to run texts through the model. Defaults to 16.
|
125 |
device (str): device to run on, defaults to 'cuda' when available.
|
126 |
-
samples_per_set (int): the number of samples per phenomenon, defaults to
|
127 |
|
128 |
Returns:
|
129 |
blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
|
@@ -156,7 +157,7 @@ class Blimp(evaluate.Metric):
|
|
156 |
predictions=None,
|
157 |
batch_size: int = 16,
|
158 |
device=None,
|
159 |
-
samples_per_set: int =
|
160 |
):
|
161 |
if device is not None:
|
162 |
assert device in ["gpu", "cpu", "cuda", "mps"], (
|
@@ -171,6 +172,9 @@ class Blimp(evaluate.Metric):
|
|
171 |
else ("mps" if torch.mps.is_available() else "cpu")
|
172 |
)
|
173 |
|
|
|
|
|
|
|
174 |
model = AutoModelForCausalLM.from_pretrained(model_id)
|
175 |
model = model.to(device)
|
176 |
model.eval()
|
|
|
14 |
"""Blimp Metric."""
|
15 |
|
16 |
from collections import defaultdict
|
17 |
+
from typing import Optional
|
18 |
|
19 |
import datasets
|
20 |
import evaluate
|
|
|
124 |
predictions (list[str]): names of metrics to run. pass empty list or ["*"] to run all of them
|
125 |
batch_size (int): the batch size to run texts through the model. Defaults to 16.
|
126 |
device (str): device to run on, defaults to 'cuda' when available.
|
127 |
+
samples_per_set (Optional[int]): the number of samples per phenomenon. Max is 1,000 (but will not error if higher value given.) If None, defaults to 1000.
|
128 |
|
129 |
Returns:
|
130 |
blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
|
|
|
157 |
predictions=None,
|
158 |
batch_size: int = 16,
|
159 |
device=None,
|
160 |
+
samples_per_set: Optional[int] = None,
|
161 |
):
|
162 |
if device is not None:
|
163 |
assert device in ["gpu", "cpu", "cuda", "mps"], (
|
|
|
172 |
else ("mps" if torch.mps.is_available() else "cpu")
|
173 |
)
|
174 |
|
175 |
+
if samples_per_set is None or samples_per_set <= 0:
|
176 |
+
samples_per_set = 1000
|
177 |
+
|
178 |
model = AutoModelForCausalLM.from_pretrained(model_id)
|
179 |
model = model.to(device)
|
180 |
model.eval()
|