Upload folder using huggingface_hub
Browse files- base_metric.py +229 -0
- benchmark.py +15 -0
- dataset.py +1 -0
- evaluate_cli.py +6 -8
- fusion.py +14 -2
- image_operators.py +5 -0
- inference.py +83 -6
- llm_as_judge.py +1 -1
- llm_as_judge_constants.py +93 -14
- loaders.py +127 -54
- metric.py +1 -0
- metrics.py +67 -215
- operators.py +76 -70
- version.py +1 -1
base_metric.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from typing import (
|
3 |
+
Any,
|
4 |
+
Dict,
|
5 |
+
List,
|
6 |
+
Union,
|
7 |
+
)
|
8 |
+
|
9 |
+
from .artifact import Artifact
|
10 |
+
from .dataclass import (
|
11 |
+
AbstractField,
|
12 |
+
)
|
13 |
+
from .deprecation_utils import deprecation
|
14 |
+
from .error_utils import Documentation, UnitxtWarning
|
15 |
+
from .stream import Stream
|
16 |
+
from .type_utils import Type, isoftype, parse_type_string, to_type_string
|
17 |
+
|
18 |
+
|
19 |
+
@deprecation(
|
20 |
+
version="2.0.0",
|
21 |
+
msg="use regular type instead of strings (e.g Dict[str] instead of 'Dict[str]')",
|
22 |
+
)
|
23 |
+
def parse_string_types_instead_of_actual_objects(obj):
|
24 |
+
return parse_type_string(obj)
|
25 |
+
|
26 |
+
class Metric(Artifact):
|
27 |
+
main_score: str = AbstractField()
|
28 |
+
# Override 'prediction_type' with the expected type of predictions
|
29 |
+
# and references. Example: "List[str]", "List[Dict]"", "string".
|
30 |
+
# If left with default None, a warning will be displayed.
|
31 |
+
# In future versions of unitxt, this will be an error.
|
32 |
+
prediction_type: Union[Type, str] = Any
|
33 |
+
|
34 |
+
# Standard metrics can receive multiple references per predictions (in a list)
|
35 |
+
# Some metrics support only a single reference per prediction (one element in the list)
|
36 |
+
single_reference_per_prediction: bool = False
|
37 |
+
|
38 |
+
#
|
39 |
+
# Used to add a prefix to all score, except the "score_name" and "score" fields.
|
40 |
+
# This is used to distinguish two scores of the same metrics, operating on different fields of the task
|
41 |
+
#
|
42 |
+
score_prefix: str = ""
|
43 |
+
|
44 |
+
def prepare_args(self):
|
45 |
+
super().prepare_args()
|
46 |
+
if isinstance(self.prediction_type, str):
|
47 |
+
self.prediction_type = parse_string_types_instead_of_actual_objects(
|
48 |
+
self.prediction_type
|
49 |
+
)
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
def process_data_after_load(cls, data):
|
53 |
+
if "prediction_type" in data:
|
54 |
+
data["prediction_type"] = parse_type_string(data["prediction_type"])
|
55 |
+
return data
|
56 |
+
|
57 |
+
def process_data_before_dump(self, data):
|
58 |
+
if "prediction_type" in data:
|
59 |
+
if not isinstance(data["prediction_type"], str):
|
60 |
+
data["prediction_type"] = to_type_string(data["prediction_type"])
|
61 |
+
return data
|
62 |
+
|
63 |
+
def _add_score_prefix(self, score_name):
|
64 |
+
return (
|
65 |
+
self.score_prefix + score_name
|
66 |
+
if score_name not in ["score", "score_name", "num_of_instances"]
|
67 |
+
else score_name
|
68 |
+
)
|
69 |
+
|
70 |
+
def _add_score_prefixes_to_score_dict_and_check_against_existing_scores(
|
71 |
+
self, scores: Dict[str, Any], existing_scores: Dict[str, Any]
|
72 |
+
) -> Dict[str, Any]:
|
73 |
+
new_scores = {}
|
74 |
+
for score_name, score in scores.items():
|
75 |
+
score_with_prefix = self._add_score_prefix(score_name)
|
76 |
+
new_scores[score_with_prefix] = (
|
77 |
+
score if score_name not in ["score_name"] else self.score_prefix + score
|
78 |
+
)
|
79 |
+
for new_score_name in new_scores:
|
80 |
+
if new_score_name in ["score", "score_name", "num_of_instances"]:
|
81 |
+
continue
|
82 |
+
if new_score_name in existing_scores:
|
83 |
+
UnitxtWarning(
|
84 |
+
message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
|
85 |
+
f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
|
86 |
+
f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
|
87 |
+
f"which will yield, in this case, a score named: 'my_second_{new_score_name}')",
|
88 |
+
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
|
89 |
+
)
|
90 |
+
return new_scores
|
91 |
+
|
92 |
+
def _validate_references_and_prediction(self, references, predictions):
|
93 |
+
if not isoftype(predictions, List[Any]):
|
94 |
+
raise ValueError(
|
95 |
+
f"Metric {self.get_metric_name()} should receive a list of predictions {self.get_metric_name()}. Received predictions of type {type(predictions)}: {predictions}"
|
96 |
+
)
|
97 |
+
|
98 |
+
if not isoftype(references, List[Any]):
|
99 |
+
raise ValueError(
|
100 |
+
f"Metric {self.get_metric_name()} should receive a list of predictions. Received references of type {type(references)}: {references}"
|
101 |
+
)
|
102 |
+
|
103 |
+
if len(references) != len(predictions):
|
104 |
+
raise ValueError(
|
105 |
+
f"references size ({len(references)})"
|
106 |
+
f" doesn't mach predictions size ({len(references)})."
|
107 |
+
)
|
108 |
+
|
109 |
+
for reference in references:
|
110 |
+
self._validate_reference(reference)
|
111 |
+
|
112 |
+
for prediction in predictions:
|
113 |
+
self._validate_prediction(prediction)
|
114 |
+
|
115 |
+
def _validate_prediction(self, prediction):
|
116 |
+
if not isoftype(prediction, self.prediction_type):
|
117 |
+
raise ValueError(
|
118 |
+
f"Each prediction is expected to be of type '{to_type_string(self.prediction_type)}' in {self.get_metric_name()} metric. Received prediction of type {type(prediction)}: {prediction}"
|
119 |
+
)
|
120 |
+
|
121 |
+
def _validate_reference(self, reference):
|
122 |
+
if not isoftype(reference, List[Any]):
|
123 |
+
raise ValueError(
|
124 |
+
f"Expecting a list of references for each prediction in {self.get_metric_name()} metric. Received reference of type {type(reference)}: {reference}"
|
125 |
+
)
|
126 |
+
if self.single_reference_per_prediction and not len(reference) == 1:
|
127 |
+
raise ValueError(
|
128 |
+
f"Expecting a list with a single reference per prediction in {self.get_metric_name()} metric. Received a list with multiple references: {reference}"
|
129 |
+
)
|
130 |
+
for ref in reference:
|
131 |
+
if not isoftype(ref, self.prediction_type):
|
132 |
+
raise ValueError(
|
133 |
+
f"Each reference is expected to be of type '{to_type_string(self.prediction_type)}' in {self.get_metric_name()} metric. Received reference of type {type(ref)}: {ref}"
|
134 |
+
)
|
135 |
+
|
136 |
+
def get_metric_name(self):
|
137 |
+
if self.__id__ is not None:
|
138 |
+
return self.__id__
|
139 |
+
return self.__class__.__name__
|
140 |
+
|
141 |
+
def consume_stream(self, stream: Stream):
|
142 |
+
references = []
|
143 |
+
predictions = []
|
144 |
+
additional_inputs = []
|
145 |
+
instances = []
|
146 |
+
for instance in stream:
|
147 |
+
instance = self.verify_instance(instance)
|
148 |
+
references.append(instance["references"])
|
149 |
+
predictions.append(instance["prediction"])
|
150 |
+
additional_inputs.append(
|
151 |
+
instance["additional_inputs"] if "additional_inputs" in instance else {}
|
152 |
+
)
|
153 |
+
instances.append(instance)
|
154 |
+
return predictions, references, additional_inputs, instances
|
155 |
+
|
156 |
+
@staticmethod
|
157 |
+
def update_instance_scores(instances, instances_scores: List[Dict[str, Any]]):
|
158 |
+
for instance, new_scores in zip(instances, instances_scores):
|
159 |
+
if "score" not in instance:
|
160 |
+
instance["score"] = {}
|
161 |
+
scores = instance["score"]
|
162 |
+
if "instance" not in scores:
|
163 |
+
scores["instance"] = {}
|
164 |
+
scores["instance"].update(new_scores)
|
165 |
+
|
166 |
+
@staticmethod
|
167 |
+
def set_global_score(instances, global_score: Dict[str, Any]):
|
168 |
+
for instance in instances:
|
169 |
+
if "score" not in instance:
|
170 |
+
instance["score"] = {}
|
171 |
+
scores = instance["score"]
|
172 |
+
if "global" not in scores:
|
173 |
+
scores["global"] = {}
|
174 |
+
scores["global"] = global_score
|
175 |
+
|
176 |
+
@abstractmethod
|
177 |
+
def disable_confidence_interval_calculation(self):
|
178 |
+
pass
|
179 |
+
|
180 |
+
# update instance["score"]["global"] with the global_score just computed for the
|
181 |
+
# current metric. global_score contains "score" and "score_name" fields that reflect
|
182 |
+
# (the main_score of) the current metric. If CI was computed for global_score, then global_score
|
183 |
+
# also contains "score_ci_low" and "score_ci_high" that reflect (the main_score of) the current metric.
|
184 |
+
# A simple python-dictionary-update adds new fields to instance["score"]["global"], and also replaces the values
|
185 |
+
# of its fields "score" and "score_name" (and "score_ci_low", "score_ci_high" if applicable),
|
186 |
+
# to reflect the current metric, overwriting previous metrics' settings of these fields
|
187 |
+
# (if any previous metric exists).
|
188 |
+
# When global_score does NOT contain ci score (because CI was not computed for the current metric), but
|
189 |
+
# one of the previous metrics computed did have, the last of such previous metrics set the values in
|
190 |
+
# fields "score_ci_low" and "score_ci_high" in instance["score"]["global"] to reflect its
|
191 |
+
# (the previous metric's) CI scores.
|
192 |
+
# Because CI is not computed for the current metric, global_score does not contain fields "score_ci_low" and
|
193 |
+
# "score_ci_high" to overwrite the ones existing in instance["score"]["global"], and these might remain in
|
194 |
+
# instance["score"]["global"], but their values, that are not associated with the current metric, are,
|
195 |
+
# therefore, not consistent with "score_name".
|
196 |
+
# In such a case, following the python-dictionary-update, we pop out fields "score_ci_low" and
|
197 |
+
# "score_ci_high" from instance["score"]["global"], so that now all the fields "score.." in
|
198 |
+
# instance["score"]["global"] are consistent with the current metric: The metric that is named
|
199 |
+
# instance["score"]["global"]["score_name"], its score shows in
|
200 |
+
# field instance["score"]["global"]["score"], and it does not have ci_scores,
|
201 |
+
# which is also reflected in the absence of fields "score_ci_low" and "score_ci_high" from instance["score"]["global"].
|
202 |
+
# If ci IS computed for the current metric, global_score contains "score_ci_low" and "score_ci_high", and these overwrite
|
203 |
+
# the ones existing in instance["score"]["global"] by the simple python-dictionary-update, and no need for any further fixeup.
|
204 |
+
def update_and_adjust_global_score(
|
205 |
+
self, instance: Dict[str, Any], global_score: dict
|
206 |
+
):
|
207 |
+
for score_name in global_score:
|
208 |
+
if score_name in [
|
209 |
+
"score",
|
210 |
+
"score_name",
|
211 |
+
"score_ci_low",
|
212 |
+
"score_ci_high",
|
213 |
+
"num_of_instances",
|
214 |
+
]:
|
215 |
+
continue
|
216 |
+
if score_name in instance["score"]["global"]:
|
217 |
+
UnitxtWarning(
|
218 |
+
message=f"Global metric '{score_name}' that has just been evaluated to {global_score[score_name]}, is already recorded "
|
219 |
+
f"to have value {instance['score']['global'][score_name]} by a previous metric evaluation on this stream. "
|
220 |
+
f"To avoid overwriting the value, add a score_prefix to the metric (e.g. score_prefix='my_{score_name}'.",
|
221 |
+
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
|
222 |
+
)
|
223 |
+
instance["score"]["global"].update(global_score)
|
224 |
+
for score_ci in ["score_ci_low", "score_ci_high"]:
|
225 |
+
if score_ci in global_score:
|
226 |
+
continue
|
227 |
+
if score_ci in instance["score"]["global"]:
|
228 |
+
instance["score"]["global"].pop(score_ci)
|
229 |
+
|
benchmark.py
CHANGED
@@ -30,6 +30,9 @@ class Benchmark(BaseBenchmark):
|
|
30 |
|
31 |
max_total_samples: int = None
|
32 |
max_samples_per_subset: int = None
|
|
|
|
|
|
|
33 |
|
34 |
def verify(self):
|
35 |
super().verify()
|
@@ -73,10 +76,22 @@ class Benchmark(BaseBenchmark):
|
|
73 |
subsets = {self.subset: self.subsets[self.subset]}
|
74 |
else:
|
75 |
subsets = self.subsets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
if self.max_total_samples is None:
|
77 |
operator = FixedFusion(
|
78 |
subsets=subsets,
|
79 |
max_instances_per_subset=self.max_samples_per_subset,
|
|
|
80 |
include_splits=self.splits,
|
81 |
)
|
82 |
else:
|
|
|
30 |
|
31 |
max_total_samples: int = None
|
32 |
max_samples_per_subset: int = None
|
33 |
+
max_train_instances: int = None
|
34 |
+
max_validation_instances: int = None
|
35 |
+
max_test_instances: int = None
|
36 |
|
37 |
def verify(self):
|
38 |
super().verify()
|
|
|
76 |
subsets = {self.subset: self.subsets[self.subset]}
|
77 |
else:
|
78 |
subsets = self.subsets
|
79 |
+
|
80 |
+
max_instances_per_split = {}
|
81 |
+
if self.max_train_instances is not None:
|
82 |
+
max_instances_per_split["train"] = self.max_train_instances
|
83 |
+
if self.max_validation_instances is not None:
|
84 |
+
max_instances_per_split["validation"] = self.max_validation_instances
|
85 |
+
if self.max_test_instances is not None:
|
86 |
+
max_instances_per_split["test"] = self.max_test_instances
|
87 |
+
if len(max_instances_per_split) == 0:
|
88 |
+
max_instances_per_split = None
|
89 |
+
|
90 |
if self.max_total_samples is None:
|
91 |
operator = FixedFusion(
|
92 |
subsets=subsets,
|
93 |
max_instances_per_subset=self.max_samples_per_subset,
|
94 |
+
max_instances_per_split=max_instances_per_split,
|
95 |
include_splits=self.splits,
|
96 |
)
|
97 |
else:
|
dataset.py
CHANGED
@@ -6,6 +6,7 @@ import datasets
|
|
6 |
from .api import __file__ as _
|
7 |
from .artifact import __file__ as _
|
8 |
from .augmentors import __file__ as _
|
|
|
9 |
from .benchmark import __file__ as _
|
10 |
from .blocks import __file__ as _
|
11 |
from .card import __file__ as _
|
|
|
6 |
from .api import __file__ as _
|
7 |
from .artifact import __file__ as _
|
8 |
from .augmentors import __file__ as _
|
9 |
+
from .base_metric import __file__ as _
|
10 |
from .benchmark import __file__ as _
|
11 |
from .blocks import __file__ as _
|
12 |
from .card import __file__ as _
|
evaluate_cli.py
CHANGED
@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
13 |
|
14 |
from datasets import Dataset as HFDataset
|
15 |
|
16 |
-
from .api import evaluate,
|
17 |
from .artifact import UnitxtArtifactNotFoundError
|
18 |
from .benchmark import Benchmark
|
19 |
|
@@ -27,7 +27,6 @@ from .logging_utils import get_logger
|
|
27 |
from .metric_utils import EvaluationResults
|
28 |
from .parsing_utils import parse_key_equals_value_string_to_dict
|
29 |
from .settings_utils import settings
|
30 |
-
from .standard import DatasetRecipe
|
31 |
|
32 |
# Define logger early so it can be used in initial error handling
|
33 |
# Basic config for initial messages, will be reconfigured in main()
|
@@ -294,21 +293,20 @@ def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
|
|
294 |
|
295 |
benchmark_subsets = {}
|
296 |
for task_str in args.tasks:
|
297 |
-
|
298 |
-
|
299 |
-
benchmark_subsets[task_str] = DatasetRecipe(**dataset_args)
|
300 |
|
301 |
benchmark = Benchmark(subsets=benchmark_subsets)
|
302 |
|
303 |
-
test_dataset =
|
304 |
logger.info(
|
305 |
f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
|
306 |
)
|
307 |
return test_dataset
|
308 |
|
309 |
|
310 |
-
def
|
311 |
-
dataset_args =
|
312 |
|
313 |
if args.limit is not None:
|
314 |
assert f"max_{args.split}_instances" not in dataset_args, (
|
|
|
13 |
|
14 |
from datasets import Dataset as HFDataset
|
15 |
|
16 |
+
from .api import _source_to_dataset, evaluate, load_recipe
|
17 |
from .artifact import UnitxtArtifactNotFoundError
|
18 |
from .benchmark import Benchmark
|
19 |
|
|
|
27 |
from .metric_utils import EvaluationResults
|
28 |
from .parsing_utils import parse_key_equals_value_string_to_dict
|
29 |
from .settings_utils import settings
|
|
|
30 |
|
31 |
# Define logger early so it can be used in initial error handling
|
32 |
# Basic config for initial messages, will be reconfigured in main()
|
|
|
293 |
|
294 |
benchmark_subsets = {}
|
295 |
for task_str in args.tasks:
|
296 |
+
overwrite_args = extract_overwrite_args(args)
|
297 |
+
benchmark_subsets[task_str] = load_recipe(dataset_query=task_str, **overwrite_args)
|
|
|
298 |
|
299 |
benchmark = Benchmark(subsets=benchmark_subsets)
|
300 |
|
301 |
+
test_dataset = _source_to_dataset(benchmark, split=args.split)
|
302 |
logger.info(
|
303 |
f"Dataset loaded successfully. Number of instances: {len(test_dataset)}"
|
304 |
)
|
305 |
return test_dataset
|
306 |
|
307 |
|
308 |
+
def extract_overwrite_args(args):
|
309 |
+
dataset_args = {}
|
310 |
|
311 |
if args.limit is not None:
|
312 |
assert f"max_{args.split}_instances" not in dataset_args, (
|
fusion.py
CHANGED
@@ -67,6 +67,7 @@ class FixedFusion(BaseFusion):
|
|
67 |
"""
|
68 |
|
69 |
max_instances_per_subset: Optional[int] = None
|
|
|
70 |
|
71 |
def prepare(self):
|
72 |
super().prepare()
|
@@ -78,12 +79,23 @@ class FixedFusion(BaseFusion):
|
|
78 |
if split not in multi_stream:
|
79 |
continue
|
80 |
emitted_from_this_split = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
logger.info(f"Processing {split} from {origin_name}...")
|
82 |
try:
|
83 |
for instance in multi_stream[split]:
|
84 |
if (
|
85 |
-
|
86 |
-
and emitted_from_this_split >=
|
87 |
):
|
88 |
break
|
89 |
if isinstance(origin_name, str):
|
|
|
67 |
"""
|
68 |
|
69 |
max_instances_per_subset: Optional[int] = None
|
70 |
+
max_instances_per_split: Optional[Dict[str, int]]= None
|
71 |
|
72 |
def prepare(self):
|
73 |
super().prepare()
|
|
|
79 |
if split not in multi_stream:
|
80 |
continue
|
81 |
emitted_from_this_split = 0
|
82 |
+
max_from_this_split = None
|
83 |
+
if self.max_instances_per_subset is not None:
|
84 |
+
max_from_this_split = self.max_instances_per_subset
|
85 |
+
if self.max_instances_per_split is not None:
|
86 |
+
max_per_this_split = self.max_instances_per_split.get(split)
|
87 |
+
if max_per_this_split is not None:
|
88 |
+
if max_from_this_split is None:
|
89 |
+
max_from_this_split = max_per_this_split
|
90 |
+
elif max_per_this_split < max_from_this_split:
|
91 |
+
max_from_this_split = max_per_this_split
|
92 |
+
|
93 |
logger.info(f"Processing {split} from {origin_name}...")
|
94 |
try:
|
95 |
for instance in multi_stream[split]:
|
96 |
if (
|
97 |
+
max_from_this_split is not None
|
98 |
+
and emitted_from_this_split >= max_from_this_split
|
99 |
):
|
100 |
break
|
101 |
if isinstance(origin_name, str):
|
image_operators.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import base64
|
|
|
2 |
import io
|
3 |
import re
|
4 |
from abc import abstractmethod
|
@@ -113,6 +114,10 @@ class EncodeImageToString(FieldOperator):
|
|
113 |
def process_value(self, value: Any) -> Any:
|
114 |
return {"image": self.encode_image_to_base64(value)}
|
115 |
|
|
|
|
|
|
|
|
|
116 |
|
117 |
class DecodeImage(FieldOperator, PillowMixin):
|
118 |
def process_value(self, value: str) -> Any:
|
|
|
1 |
import base64
|
2 |
+
import hashlib
|
3 |
import io
|
4 |
import re
|
5 |
from abc import abstractmethod
|
|
|
114 |
def process_value(self, value: Any) -> Any:
|
115 |
return {"image": self.encode_image_to_base64(value)}
|
116 |
|
117 |
+
class HashImage(FieldOperator, PillowMixin):
|
118 |
+
|
119 |
+
def process_value(self, value: Any) -> Any:
|
120 |
+
return hashlib.md5(value.tobytes()).hexdigest()
|
121 |
|
122 |
class DecodeImage(FieldOperator, PillowMixin):
|
123 |
def process_value(self, value: str) -> Any:
|
inference.py
CHANGED
@@ -6,6 +6,7 @@ import hashlib
|
|
6 |
import io
|
7 |
import json
|
8 |
import logging
|
|
|
9 |
import os
|
10 |
import re
|
11 |
import sys
|
@@ -35,6 +36,7 @@ from tqdm import tqdm, trange
|
|
35 |
from tqdm.asyncio import tqdm_asyncio
|
36 |
|
37 |
from .artifact import Artifact
|
|
|
38 |
from .dataclass import InternalField, NonPositionalField
|
39 |
from .deprecation_utils import deprecation
|
40 |
from .error_utils import UnitxtError, UnitxtWarning
|
@@ -238,7 +240,7 @@ class InferenceEngine(Artifact):
|
|
238 |
result = self._mock_infer(dataset)
|
239 |
else:
|
240 |
if self.use_cache:
|
241 |
-
number_of_batches = len(dataset)
|
242 |
result = []
|
243 |
for batch_index, batch in enumerate(
|
244 |
batched(dataset, self.cache_batch_size)
|
@@ -3342,10 +3344,12 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3342 |
provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
|
3343 |
"watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
|
3344 |
"granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
|
3345 |
-
"granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
|
3346 |
-
"granite-3-3-8b-instruct": "ibm/granite-3-3-8b-instruct",
|
3347 |
"granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
|
3348 |
"granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
|
|
|
|
|
|
|
|
|
3349 |
"granite-34b-code-instruct": "ibm/granite-34b-code-instruct",
|
3350 |
"granite-guardian-3-8b": "ibm/granite-guardian-3-8b",
|
3351 |
"granite-vision-3-2-2b": "ibm/granite-vision-3-2-2b",
|
@@ -3361,7 +3365,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3361 |
"mistral-large-instruct": "mistralai/mistral-large",
|
3362 |
"mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7b-instruct-v01",
|
3363 |
},
|
3364 |
-
"together-ai": {
|
3365 |
"llama-3-8b-instruct": "together_ai/meta-llama/Llama-3-8b-chat-hf",
|
3366 |
"llama-3-70b-instruct": "together_ai/meta-llama/Llama-3-70b-chat-hf",
|
3367 |
"llama-3-1-8b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
@@ -3369,10 +3373,23 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3369 |
"llama-3-1-405b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
3370 |
"llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
|
3371 |
"llama-3-3-70b-instruct": "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
|
|
|
|
|
|
|
|
|
|
3372 |
},
|
3373 |
-
"aws": {
|
3374 |
"llama-3-8b-instruct": "bedrock/meta.llama3-8b-instruct-v1:0",
|
3375 |
"llama-3-70b-instruct": "bedrock/meta.llama3-70b-instruct-v1:0",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3376 |
},
|
3377 |
"ollama": {
|
3378 |
"llama-3-8b-instruct": "llama3:8b",
|
@@ -3383,6 +3400,8 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3383 |
"llama-3-2-1b-instruct": "llama3.2:1b",
|
3384 |
"llama-3-2-3b-instruct": "llama3.2:3b",
|
3385 |
"llama-3-3-70b-instruct": "llama3.3",
|
|
|
|
|
3386 |
},
|
3387 |
"bam": {
|
3388 |
"granite-3-8b-instruct": "ibm/granite-8b-instruct-preview-4k",
|
@@ -3401,9 +3420,12 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3401 |
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
3402 |
"llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
3403 |
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
|
|
|
|
|
3404 |
"mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
|
3405 |
"mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
|
3406 |
-
"
|
|
|
3407 |
"granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
|
3408 |
"granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
|
3409 |
},
|
@@ -3432,6 +3454,12 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3432 |
"gpt-4-32k-0314": "gpt-4-32k-0314",
|
3433 |
"gpt-4-32k-0613": "gpt-4-32k-0613",
|
3434 |
"gpt-4-vision-preview": "gpt-4-vision-preview",
|
|
|
|
|
|
|
|
|
|
|
|
|
3435 |
},
|
3436 |
"azure": {
|
3437 |
"o1-mini": "azure/o1-mini",
|
@@ -3454,11 +3482,23 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3454 |
"gpt-3.5-turbo-16k": "azure/gpt-3.5-turbo-16k",
|
3455 |
"gpt-3.5-turbo-16k-0613": "azure/gpt-3.5-turbo-16k-0613",
|
3456 |
"gpt-4-vision": "azure/gpt-4-vision",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3457 |
},
|
3458 |
"vertex-ai": {
|
3459 |
"llama-3-1-8b-instruct": "vertex_ai/meta/llama-3.1-8b-instruct-maas",
|
3460 |
"llama-3-1-70b-instruct": "vertex_ai/meta/llama-3.1-70b-instruct-maas",
|
3461 |
"llama-3-1-405b-instruct": "vertex_ai/meta/llama-3.1-405b-instruct-maas",
|
|
|
|
|
|
|
|
|
3462 |
},
|
3463 |
"replicate": {
|
3464 |
"granite-3-2-8b-instruct": "replicate/ibm-granite/granite-3.2-8b-instruct",
|
@@ -3480,9 +3520,13 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3480 |
"llama-3-70b-instruct": "replicate/meta/meta-llama-3-70b-instruct",
|
3481 |
"llama-3-8b": "replicate/meta/meta-llama-3-8b",
|
3482 |
"llama-3-8b-instruct": "replicate/meta/meta-llama-3-8b-instruct",
|
|
|
|
|
|
|
3483 |
"mistral-7b-instruct-v0.2": "replicate/mistralai/mistral-7b-instruct-v0.2",
|
3484 |
"mistral-7b-v0.1": "replicate/mistralai/mistral-7b-v0.1",
|
3485 |
"mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1",
|
|
|
3486 |
},
|
3487 |
}
|
3488 |
provider_model_map["watsonx"] = {
|
@@ -3516,6 +3560,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
|
|
3516 |
return self.provider if self.provider is not None else settings.default_provider
|
3517 |
|
3518 |
def prepare_engine(self):
|
|
|
3519 |
provider = self.get_provider_name()
|
3520 |
if provider not in self._provider_to_base_class:
|
3521 |
raise UnitxtError(
|
@@ -3675,3 +3720,35 @@ class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
|
|
3675 |
predictions.append(options_scores.most_common(1)[0][0])
|
3676 |
|
3677 |
return predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import io
|
7 |
import json
|
8 |
import logging
|
9 |
+
import math
|
10 |
import os
|
11 |
import re
|
12 |
import sys
|
|
|
36 |
from tqdm.asyncio import tqdm_asyncio
|
37 |
|
38 |
from .artifact import Artifact
|
39 |
+
from .base_metric import Metric
|
40 |
from .dataclass import InternalField, NonPositionalField
|
41 |
from .deprecation_utils import deprecation
|
42 |
from .error_utils import UnitxtError, UnitxtWarning
|
|
|
240 |
result = self._mock_infer(dataset)
|
241 |
else:
|
242 |
if self.use_cache:
|
243 |
+
number_of_batches = math.ceil(len(dataset) / self.cache_batch_size)
|
244 |
result = []
|
245 |
for batch_index, batch in enumerate(
|
246 |
batched(dataset, self.cache_batch_size)
|
|
|
3344 |
provider_model_map: Dict[_supported_apis, Dict[str, str]] = {
|
3345 |
"watsonx-sdk": { # checked from ibm_watsonx_ai.APIClient().foundation_models.ChatModels
|
3346 |
"granite-20b-code-instruct": "ibm/granite-20b-code-instruct",
|
|
|
|
|
3347 |
"granite-3-2b-instruct": "ibm/granite-3-2b-instruct",
|
3348 |
"granite-3-8b-instruct": "ibm/granite-3-8b-instruct",
|
3349 |
+
"granite-3-2-2b-instruct": "ibm/granite-3-2-2b-instruct",
|
3350 |
+
"granite-3-2-8b-instruct": "ibm/granite-3-2-8b-instruct",
|
3351 |
+
"granite-3-3-2b-instruct": "ibm/granite-3-3-2b-instruct",
|
3352 |
+
"granite-3-3-8b-instruct": "ibm/granite-3-3-8b-instruct",
|
3353 |
"granite-34b-code-instruct": "ibm/granite-34b-code-instruct",
|
3354 |
"granite-guardian-3-8b": "ibm/granite-guardian-3-8b",
|
3355 |
"granite-vision-3-2-2b": "ibm/granite-vision-3-2-2b",
|
|
|
3365 |
"mistral-large-instruct": "mistralai/mistral-large",
|
3366 |
"mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7b-instruct-v01",
|
3367 |
},
|
3368 |
+
"together-ai": { # checked from https://www.together.ai/models
|
3369 |
"llama-3-8b-instruct": "together_ai/meta-llama/Llama-3-8b-chat-hf",
|
3370 |
"llama-3-70b-instruct": "together_ai/meta-llama/Llama-3-70b-chat-hf",
|
3371 |
"llama-3-1-8b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
|
|
3373 |
"llama-3-1-405b-instruct": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
|
3374 |
"llama-3-2-1b-instruct": "together_ai/togethercomputer/llama-3-2-1b-instruct",
|
3375 |
"llama-3-3-70b-instruct": "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo",
|
3376 |
+
"llama-4-maverick": "together_ai/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", #pragma: allowlist secret
|
3377 |
+
"llama-4-scout": "together_ai/meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
3378 |
+
"deepseek-v3": "together_ai/deepseek-ai/DeepSeek-V3",
|
3379 |
+
"llama-3-3-70b-instruct-free": "together_ai/meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
|
3380 |
+
"deepseek-r1-distilled-llama-70b-free": "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
|
3381 |
},
|
3382 |
+
"aws": { # checked from https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
|
3383 |
"llama-3-8b-instruct": "bedrock/meta.llama3-8b-instruct-v1:0",
|
3384 |
"llama-3-70b-instruct": "bedrock/meta.llama3-70b-instruct-v1:0",
|
3385 |
+
"llama-3-1-70b-instruct": "bedrock/meta.llama3-1-70b-instruct-v1:0",
|
3386 |
+
"llama-3-1-405b-instruct": "bedrock/meta.llama3-1-405b-instruct-v1:0",
|
3387 |
+
"llama-3-3-70b-instruct": "bedrock/meta.llama3-3-70b-instruct-v1:0",
|
3388 |
+
"llama-4-maverick": "bedrock/meta.llama4-maverick-17b-instruct-v1:0", #pragma: allowlist secret
|
3389 |
+
"llama-4-scout": "bedrock/meta.llama4-scout-17b-instruct-v1:0",
|
3390 |
+
"mistral-large-instruct": "bedrock/mistral.mistral-large-2407-v1:0",
|
3391 |
+
"deepseek-r1": "bedrock/deepseek.r1-v1:0",
|
3392 |
+
"claude-3-7-sonnet": "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
|
3393 |
},
|
3394 |
"ollama": {
|
3395 |
"llama-3-8b-instruct": "llama3:8b",
|
|
|
3400 |
"llama-3-2-1b-instruct": "llama3.2:1b",
|
3401 |
"llama-3-2-3b-instruct": "llama3.2:3b",
|
3402 |
"llama-3-3-70b-instruct": "llama3.3",
|
3403 |
+
"granite-3-3-2b-instruct": "granite3.3:2b",
|
3404 |
+
"granite-3-3-8b-instruct": "granite3.3:8b",
|
3405 |
},
|
3406 |
"bam": {
|
3407 |
"granite-3-8b-instruct": "ibm/granite-8b-instruct-preview-4k",
|
|
|
3420 |
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
3421 |
"llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
3422 |
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
|
3423 |
+
"llama-4-scout": "llama-4-scout-17b-16e",
|
3424 |
+
"llama-4-maverick": "llama-4-mvk-17b-128e-fp8",
|
3425 |
"mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
|
3426 |
"mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
|
3427 |
+
"mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
|
3428 |
+
"deepseek-v3": "deepseek-ai/deepseek-v3-h200",
|
3429 |
"granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
|
3430 |
"granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
|
3431 |
},
|
|
|
3454 |
"gpt-4-32k-0314": "gpt-4-32k-0314",
|
3455 |
"gpt-4-32k-0613": "gpt-4-32k-0613",
|
3456 |
"gpt-4-vision-preview": "gpt-4-vision-preview",
|
3457 |
+
"gpt-4-1": "gpt-4.1",
|
3458 |
+
"gpt-4-1-2025-04-14": "gpt-4.1-2025-04-14",
|
3459 |
+
"gpt-4-1-nano": "gpt-4.1-nano",
|
3460 |
+
"gpt-4-1-nano-2025-04-14": "gpt-4.1-nano-2025-04-14",
|
3461 |
+
"gpt-4-1-mini": "gpt-4.1-mini",
|
3462 |
+
"gpt-4-1-mini-2025-04-14": "gpt-4.1-mini-2025-04-14",
|
3463 |
},
|
3464 |
"azure": {
|
3465 |
"o1-mini": "azure/o1-mini",
|
|
|
3482 |
"gpt-3.5-turbo-16k": "azure/gpt-3.5-turbo-16k",
|
3483 |
"gpt-3.5-turbo-16k-0613": "azure/gpt-3.5-turbo-16k-0613",
|
3484 |
"gpt-4-vision": "azure/gpt-4-vision",
|
3485 |
+
"gpt-4-1": "azure/gpt-4.1",
|
3486 |
+
"gpt-4-1-nano": "azure/gpt-4.1-nano",
|
3487 |
+
"gpt-4-1-mini": "azure/gpt-4.1-mini",
|
3488 |
+
"gpt-4-1-mini-2025-04-14": "azure/gpt-4.1-mini-2025-04-14",
|
3489 |
+
"llama-3-1-405b-instruct": "azure/Meta-Llama-3.1-405B-Instruct",
|
3490 |
+
"llama-3-3-70b-instruct": "azure/Llama-3.3-70B-Instruct",
|
3491 |
+
"llama-4-maverick": "azure/Llama-4-Maverick-17B-128E-Instruct-FP8", #pragma: allowlist secret
|
3492 |
+
"llama-4-scout": "azure/Llama-4-Scout-17B-16E-Instruct",
|
3493 |
},
|
3494 |
"vertex-ai": {
|
3495 |
"llama-3-1-8b-instruct": "vertex_ai/meta/llama-3.1-8b-instruct-maas",
|
3496 |
"llama-3-1-70b-instruct": "vertex_ai/meta/llama-3.1-70b-instruct-maas",
|
3497 |
"llama-3-1-405b-instruct": "vertex_ai/meta/llama-3.1-405b-instruct-maas",
|
3498 |
+
"gemini-2-5-pro": "vertex_ai/gemini-2.5-pro-preview-05-06",
|
3499 |
+
"gemini-2-5-pro-preview-05-06": "vertex_ai/gemini-2.5-pro-preview-05-06",
|
3500 |
+
"gemini-2.5-flash": "gemini-2.5-flash-preview-05-20",
|
3501 |
+
"gemini-2.5-flash-preview-05-20": "gemini-2.5-flash-preview-05-20",
|
3502 |
},
|
3503 |
"replicate": {
|
3504 |
"granite-3-2-8b-instruct": "replicate/ibm-granite/granite-3.2-8b-instruct",
|
|
|
3520 |
"llama-3-70b-instruct": "replicate/meta/meta-llama-3-70b-instruct",
|
3521 |
"llama-3-8b": "replicate/meta/meta-llama-3-8b",
|
3522 |
"llama-3-8b-instruct": "replicate/meta/meta-llama-3-8b-instruct",
|
3523 |
+
"llama-3-3-70b-instruct": "replicate/meta/meta-llama-3.3-70b-instruct",
|
3524 |
+
"llama-4-maverick": "replicate/meta/llama-4-maverick-instruct",
|
3525 |
+
"llama-4-scout": "replicate/meta/llama-4-scout-instruct",
|
3526 |
"mistral-7b-instruct-v0.2": "replicate/mistralai/mistral-7b-instruct-v0.2",
|
3527 |
"mistral-7b-v0.1": "replicate/mistralai/mistral-7b-v0.1",
|
3528 |
"mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1",
|
3529 |
+
"gpt-4-1": "replicate/openai/gpt-4.1",
|
3530 |
},
|
3531 |
}
|
3532 |
provider_model_map["watsonx"] = {
|
|
|
3560 |
return self.provider if self.provider is not None else settings.default_provider
|
3561 |
|
3562 |
def prepare_engine(self):
|
3563 |
+
# print("provider", self.provider)
|
3564 |
provider = self.get_provider_name()
|
3565 |
if provider not in self._provider_to_base_class:
|
3566 |
raise UnitxtError(
|
|
|
3720 |
predictions.append(options_scores.most_common(1)[0][0])
|
3721 |
|
3722 |
return predictions
|
3723 |
+
|
3724 |
+
class MetricInferenceEngine(InferenceEngine):
|
3725 |
+
"""An inference engine that uses the output of a metric as its prediction. Used to evaluate metrics like LLM as Judge or Granite Guardian.
|
3726 |
+
|
3727 |
+
Args:
|
3728 |
+
InferenceEngine (_type_): _description_
|
3729 |
+
"""
|
3730 |
+
metric: Metric
|
3731 |
+
prediction_field: str
|
3732 |
+
|
3733 |
+
def _infer(
|
3734 |
+
self,
|
3735 |
+
dataset: Union[List[Dict[str, Any]], Dataset],
|
3736 |
+
return_meta_data: bool = False,
|
3737 |
+
) -> Union[List[str], List[TextGenerationInferenceOutput]]:
|
3738 |
+
task_data = [
|
3739 |
+
json.loads(instance["task_data"]) if "task_data" in instance else {}
|
3740 |
+
for instance in dataset
|
3741 |
+
]
|
3742 |
+
predictions=[td[self.prediction_field] for td in task_data]
|
3743 |
+
references = [instance["references"] for instance in dataset]
|
3744 |
+
return self.metric.compute(
|
3745 |
+
task_data=task_data,
|
3746 |
+
predictions=predictions,
|
3747 |
+
references=references,
|
3748 |
+
)
|
3749 |
+
|
3750 |
+
def prepare_engine(self):
|
3751 |
+
pass
|
3752 |
+
|
3753 |
+
def get_engine_id(self):
|
3754 |
+
return "metric_inference_engine"
|
llm_as_judge.py
CHANGED
@@ -251,7 +251,7 @@ class LLMJudgeDirect(LLMJudge):
|
|
251 |
self.assessment_task = Task(
|
252 |
input_fields={
|
253 |
"context_variables": str,
|
254 |
-
"response":
|
255 |
"criteria_description": str,
|
256 |
"display_options_instruction": str,
|
257 |
},
|
|
|
251 |
self.assessment_task = Task(
|
252 |
input_fields={
|
253 |
"context_variables": str,
|
254 |
+
"response": Any,
|
255 |
"criteria_description": str,
|
256 |
"display_options_instruction": str,
|
257 |
},
|
llm_as_judge_constants.py
CHANGED
@@ -71,8 +71,13 @@ class EvaluatorNameEnum(str, Enum):
|
|
71 |
LLAMA3_1_70B = "Llama3.1-70b"
|
72 |
LLAMA3_2_3B = "Llama3.2-3b"
|
73 |
LLAMA3_3_70B = "Llama3.3-70b"
|
|
|
|
|
74 |
PROMETHEUS = "Prometheus"
|
75 |
-
|
|
|
|
|
|
|
76 |
O1_PREVIEW = "o1-Preview"
|
77 |
O1_MINI = "o1-Mini"
|
78 |
GRANITE_13B = "Granite-13b"
|
@@ -81,23 +86,36 @@ class EvaluatorNameEnum(str, Enum):
|
|
81 |
GRANITE3_1_2B = "Granite3.1-2b"
|
82 |
GRANITE3_1_8B = "Granite3.1-8b"
|
83 |
GRANITE3_2_8B = "Granite3.2-8b"
|
84 |
-
|
|
|
|
|
|
|
85 |
|
86 |
class ModelProviderEnum(str, Enum):
|
87 |
WATSONX = "watsonx"
|
88 |
OPENAI = "open-ai"
|
89 |
RITS = "rits"
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
|
93 |
EVALUATOR_TO_MODEL_ID = {
|
94 |
EvaluatorNameEnum.MIXTRAL8_7b: "mixtral-8x7b-instruct-v01",
|
95 |
EvaluatorNameEnum.MIXTRAL_LARGE: "mistral-large-instruct",
|
96 |
EvaluatorNameEnum.LLAMA3_1_405B: "llama-3-1-405b-instruct",
|
97 |
-
EvaluatorNameEnum.LLAMA3_1_8B: "llama-3-1-
|
98 |
EvaluatorNameEnum.LLAMA3_1_70B: "llama-3-1-70b-instruct",
|
99 |
EvaluatorNameEnum.LLAMA3_3_70B: "llama-3-3-70b-instruct",
|
100 |
-
EvaluatorNameEnum.
|
|
|
|
|
|
|
|
|
|
|
101 |
EvaluatorNameEnum.O1_PREVIEW: "o1-preview",
|
102 |
EvaluatorNameEnum.O1_MINI: "o1-mini",
|
103 |
EvaluatorNameEnum.GRANITE3_2B: "granite-3-2b-instruct",
|
@@ -105,8 +123,15 @@ EVALUATOR_TO_MODEL_ID = {
|
|
105 |
EvaluatorNameEnum.GRANITE3_1_2B: "granite-3-1-2b-instruct",
|
106 |
EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
|
107 |
EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
|
|
|
|
|
|
|
|
|
108 |
}
|
109 |
|
|
|
|
|
|
|
110 |
class EvaluatorMetadata:
|
111 |
name: EvaluatorNameEnum
|
112 |
providers: List[ModelProviderEnum]
|
@@ -123,7 +148,7 @@ EVALUATORS_METADATA = [
|
|
123 |
),
|
124 |
EvaluatorMetadata(
|
125 |
EvaluatorNameEnum.MIXTRAL_LARGE,
|
126 |
-
[ModelProviderEnum.RITS, ModelProviderEnum.WATSONX],
|
127 |
),
|
128 |
EvaluatorMetadata(
|
129 |
EvaluatorNameEnum.GRANITE3_8B,
|
@@ -138,33 +163,69 @@ EVALUATORS_METADATA = [
|
|
138 |
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
139 |
),
|
140 |
EvaluatorMetadata(
|
141 |
-
EvaluatorNameEnum.
|
142 |
-
[ModelProviderEnum.
|
|
|
|
|
|
|
|
|
143 |
),
|
144 |
EvaluatorMetadata(
|
145 |
EvaluatorNameEnum.O1_MINI,
|
146 |
-
[ModelProviderEnum.OPENAI, ModelProviderEnum.
|
147 |
),
|
148 |
EvaluatorMetadata(
|
149 |
EvaluatorNameEnum.O1_PREVIEW,
|
150 |
-
[ModelProviderEnum.OPENAI, ModelProviderEnum.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
),
|
152 |
EvaluatorMetadata(
|
153 |
EvaluatorNameEnum.LLAMA3_1_70B,
|
154 |
-
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
155 |
),
|
156 |
EvaluatorMetadata(
|
157 |
EvaluatorNameEnum.LLAMA3_1_8B,
|
158 |
-
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
159 |
),
|
160 |
EvaluatorMetadata(
|
161 |
EvaluatorNameEnum.LLAMA3_1_405B,
|
162 |
-
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
163 |
),
|
164 |
EvaluatorMetadata(
|
165 |
EvaluatorNameEnum.LLAMA3_3_70B,
|
166 |
-
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
|
|
|
|
|
|
|
|
167 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
]
|
169 |
|
170 |
################################ Direct Assessment Criterias ################################
|
@@ -952,6 +1013,24 @@ class DirectCriteriaCatalogEnum(Enum):
|
|
952 |
"incorrect": 0.0,
|
953 |
},
|
954 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
955 |
|
956 |
|
957 |
DIRECT_CRITERIA = [c.value for c in DirectCriteriaCatalogEnum]
|
|
|
71 |
LLAMA3_1_70B = "Llama3.1-70b"
|
72 |
LLAMA3_2_3B = "Llama3.2-3b"
|
73 |
LLAMA3_3_70B = "Llama3.3-70b"
|
74 |
+
LLAMA3_4_MAVERICK = "Llama4-Maverick"
|
75 |
+
LLAMA3_4_SCOUT = "Llama4-Scout"
|
76 |
PROMETHEUS = "Prometheus"
|
77 |
+
GPT4o = "GPT-4o"
|
78 |
+
GPT4_1 = "GPT-4.1"
|
79 |
+
GPT4_1_NANO = "GPT-4.1-nano"
|
80 |
+
GPT4_1_MINI = "GPT-4.1-mini"
|
81 |
O1_PREVIEW = "o1-Preview"
|
82 |
O1_MINI = "o1-Mini"
|
83 |
GRANITE_13B = "Granite-13b"
|
|
|
86 |
GRANITE3_1_2B = "Granite3.1-2b"
|
87 |
GRANITE3_1_8B = "Granite3.1-8b"
|
88 |
GRANITE3_2_8B = "Granite3.2-8b"
|
89 |
+
GRANITE3_3_8B = "Granite3.3-8b"
|
90 |
+
DEEPSEEK_V3 = "DeepSeek V3"
|
91 |
+
GEMMA_2_5_PRO = "Gemmini 2.5 Pro"
|
92 |
+
GEMINI_2_5_FLASH = "Gemini 2.5 Flash"
|
93 |
|
94 |
class ModelProviderEnum(str, Enum):
|
95 |
WATSONX = "watsonx"
|
96 |
OPENAI = "open-ai"
|
97 |
RITS = "rits"
|
98 |
+
AZURE = "azure"
|
99 |
+
TOGETHER_AI = "together-ai"
|
100 |
+
AWS = "aws"
|
101 |
+
VERTEX_AI = "vertex-ai"
|
102 |
+
OLLAMA = "ollama"
|
103 |
+
REPLICATE = "replicate"
|
104 |
|
105 |
|
106 |
EVALUATOR_TO_MODEL_ID = {
|
107 |
EvaluatorNameEnum.MIXTRAL8_7b: "mixtral-8x7b-instruct-v01",
|
108 |
EvaluatorNameEnum.MIXTRAL_LARGE: "mistral-large-instruct",
|
109 |
EvaluatorNameEnum.LLAMA3_1_405B: "llama-3-1-405b-instruct",
|
110 |
+
EvaluatorNameEnum.LLAMA3_1_8B: "llama-3-1-8b-instruct",
|
111 |
EvaluatorNameEnum.LLAMA3_1_70B: "llama-3-1-70b-instruct",
|
112 |
EvaluatorNameEnum.LLAMA3_3_70B: "llama-3-3-70b-instruct",
|
113 |
+
EvaluatorNameEnum.LLAMA3_4_MAVERICK: "llama-4-maverick",
|
114 |
+
EvaluatorNameEnum.LLAMA3_4_SCOUT: "llama-4-scout",
|
115 |
+
EvaluatorNameEnum.GPT4o: "gpt-4o-2024-08-06",
|
116 |
+
EvaluatorNameEnum.GPT4_1: "gpt-4-1",
|
117 |
+
EvaluatorNameEnum.GPT4_1_NANO: "gpt-4-1-nano",
|
118 |
+
EvaluatorNameEnum.GPT4_1_MINI: "gpt-4-1-mini",
|
119 |
EvaluatorNameEnum.O1_PREVIEW: "o1-preview",
|
120 |
EvaluatorNameEnum.O1_MINI: "o1-mini",
|
121 |
EvaluatorNameEnum.GRANITE3_2B: "granite-3-2b-instruct",
|
|
|
123 |
EvaluatorNameEnum.GRANITE3_1_2B: "granite-3-1-2b-instruct",
|
124 |
EvaluatorNameEnum.GRANITE3_1_8B: "granite-3-1-8b-instruct",
|
125 |
EvaluatorNameEnum.GRANITE3_2_8B: "granite-3-2-8b-instruct",
|
126 |
+
EvaluatorNameEnum.GRANITE3_3_8B: "granite-3-3-8b-instruct",
|
127 |
+
EvaluatorNameEnum.DEEPSEEK_V3: "deepseek-ai/DeepSeek-V3",
|
128 |
+
EvaluatorNameEnum.GEMMA_2_5_PRO: "gemma-2-5-pro",
|
129 |
+
EvaluatorNameEnum.GEMINI_2_5_FLASH: "gemini-2-5-flash",
|
130 |
}
|
131 |
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
class EvaluatorMetadata:
|
136 |
name: EvaluatorNameEnum
|
137 |
providers: List[ModelProviderEnum]
|
|
|
148 |
),
|
149 |
EvaluatorMetadata(
|
150 |
EvaluatorNameEnum.MIXTRAL_LARGE,
|
151 |
+
[ModelProviderEnum.RITS, ModelProviderEnum.WATSONX, ModelProviderEnum.AWS],
|
152 |
),
|
153 |
EvaluatorMetadata(
|
154 |
EvaluatorNameEnum.GRANITE3_8B,
|
|
|
163 |
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS],
|
164 |
),
|
165 |
EvaluatorMetadata(
|
166 |
+
EvaluatorNameEnum.GRANITE3_3_8B,
|
167 |
+
[ModelProviderEnum.WATSONX, ModelProviderEnum.RITS, ModelProviderEnum.OLLAMA],
|
168 |
+
),
|
169 |
+
EvaluatorMetadata(
|
170 |
+
EvaluatorNameEnum.GPT4o,
|
171 |
+
[ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE],
|
172 |
),
|
173 |
EvaluatorMetadata(
|
174 |
EvaluatorNameEnum.O1_MINI,
|
175 |
+
[ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE],
|
176 |
),
|
177 |
EvaluatorMetadata(
|
178 |
EvaluatorNameEnum.O1_PREVIEW,
|
179 |
+
[ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE],
|
180 |
+
),
|
181 |
+
EvaluatorMetadata(
|
182 |
+
EvaluatorNameEnum.GPT4_1,
|
183 |
+
[ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE, ModelProviderEnum.REPLICATE],
|
184 |
+
),
|
185 |
+
EvaluatorMetadata(
|
186 |
+
EvaluatorNameEnum.GPT4_1_NANO,
|
187 |
+
[ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE],
|
188 |
+
),
|
189 |
+
EvaluatorMetadata(
|
190 |
+
EvaluatorNameEnum.GPT4_1_MINI,
|
191 |
+
[ModelProviderEnum.OPENAI, ModelProviderEnum.AZURE],
|
192 |
),
|
193 |
EvaluatorMetadata(
|
194 |
EvaluatorNameEnum.LLAMA3_1_70B,
|
195 |
+
[ModelProviderEnum.WATSONX, ModelProviderEnum.TOGETHER_AI, ModelProviderEnum.RITS, ModelProviderEnum.OLLAMA],
|
196 |
),
|
197 |
EvaluatorMetadata(
|
198 |
EvaluatorNameEnum.LLAMA3_1_8B,
|
199 |
+
[ModelProviderEnum.WATSONX, ModelProviderEnum.TOGETHER_AI, ModelProviderEnum.RITS, ModelProviderEnum.OLLAMA],
|
200 |
),
|
201 |
EvaluatorMetadata(
|
202 |
EvaluatorNameEnum.LLAMA3_1_405B,
|
203 |
+
[ModelProviderEnum.WATSONX, ModelProviderEnum.TOGETHER_AI, ModelProviderEnum.RITS, ModelProviderEnum.AWS, ModelProviderEnum.OLLAMA],
|
204 |
),
|
205 |
EvaluatorMetadata(
|
206 |
EvaluatorNameEnum.LLAMA3_3_70B,
|
207 |
+
[ModelProviderEnum.WATSONX, ModelProviderEnum.TOGETHER_AI, ModelProviderEnum.RITS, ModelProviderEnum.AWS, ModelProviderEnum.OLLAMA, ModelProviderEnum.AZURE],
|
208 |
+
),
|
209 |
+
EvaluatorMetadata(
|
210 |
+
EvaluatorNameEnum.LLAMA3_4_SCOUT,
|
211 |
+
[ModelProviderEnum.AZURE, ModelProviderEnum.TOGETHER_AI, ModelProviderEnum.AWS, ModelProviderEnum.REPLICATE, ModelProviderEnum.RITS],
|
212 |
),
|
213 |
+
EvaluatorMetadata(
|
214 |
+
EvaluatorNameEnum.LLAMA3_4_MAVERICK,
|
215 |
+
[ModelProviderEnum.AZURE, ModelProviderEnum.TOGETHER_AI, ModelProviderEnum.AWS, ModelProviderEnum.REPLICATE, ModelProviderEnum.RITS],
|
216 |
+
),
|
217 |
+
EvaluatorMetadata(
|
218 |
+
EvaluatorNameEnum.DEEPSEEK_V3,
|
219 |
+
[ModelProviderEnum.RITS, ModelProviderEnum.TOGETHER_AI, ModelProviderEnum.AWS],
|
220 |
+
),
|
221 |
+
EvaluatorMetadata(
|
222 |
+
EvaluatorNameEnum.GEMMA_2_5_PRO,
|
223 |
+
[ModelProviderEnum.VERTEX_AI]
|
224 |
+
),
|
225 |
+
EvaluatorMetadata(
|
226 |
+
EvaluatorNameEnum.GEMINI_2_5_FLASH,
|
227 |
+
[ModelProviderEnum.VERTEX_AI]
|
228 |
+
)
|
229 |
]
|
230 |
|
231 |
################################ Direct Assessment Criterias ################################
|
|
|
1013 |
"incorrect": 0.0,
|
1014 |
},
|
1015 |
)
|
1016 |
+
TOXICITY = CriteriaWithOptions(
|
1017 |
+
name="toxicity",
|
1018 |
+
description="Is the text toxic?",
|
1019 |
+
options=[
|
1020 |
+
CriteriaOption(
|
1021 |
+
name="Yes",
|
1022 |
+
description="The text may be interpreted as offensive, profane or controversial by any culture. This option also includes text with veiled or direct targeted offenses, regardless of the context or the presence of any positive elements or includes text that asks toxic text to be generated.",
|
1023 |
+
),
|
1024 |
+
CriteriaOption(
|
1025 |
+
name="No",
|
1026 |
+
description="The text does not contain offensive or profane language. This includes text that is polite, respectful, and does not use profanity. It excludes any content that could be considered a targeted offense, insult, threat, or that contains profane language or swear words.",
|
1027 |
+
),
|
1028 |
+
],
|
1029 |
+
option_map={
|
1030 |
+
"Yes": 1.0,
|
1031 |
+
"No": 0.0,
|
1032 |
+
}
|
1033 |
+
)
|
1034 |
|
1035 |
|
1036 |
DIRECT_CRITERIA = [c.value for c in DirectCriteriaCatalogEnum]
|
loaders.py
CHANGED
@@ -46,7 +46,6 @@ from typing import (
|
|
46 |
Generator,
|
47 |
Iterable,
|
48 |
List,
|
49 |
-
Literal,
|
50 |
Mapping,
|
51 |
Optional,
|
52 |
Sequence,
|
@@ -66,6 +65,7 @@ from huggingface_hub import HfApi
|
|
66 |
from tqdm import tqdm
|
67 |
|
68 |
from .dataclass import NonPositionalField
|
|
|
69 |
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
70 |
from .fusion import FixedFusion
|
71 |
from .logging_utils import get_logger
|
@@ -403,64 +403,20 @@ class LoadHF(LazyLoader):
|
|
403 |
if i + 1 >= limit:
|
404 |
break
|
405 |
|
406 |
-
|
407 |
-
class
|
408 |
-
"""Loads data from CSV files.
|
409 |
-
|
410 |
-
Supports streaming and can handle large files by loading them in chunks.
|
411 |
-
|
412 |
-
Args:
|
413 |
-
files (Dict[str, str]): A dictionary mapping names to file paths.
|
414 |
-
chunksize : Size of the chunks to load at a time.
|
415 |
-
loader_limit: Optional integer to specify a limit on the number of records to load.
|
416 |
-
streaming: Bool indicating if streaming should be used.
|
417 |
-
sep: String specifying the separator used in the CSV files.
|
418 |
-
|
419 |
-
Example:
|
420 |
-
Loading csv
|
421 |
-
|
422 |
-
.. code-block:: python
|
423 |
-
|
424 |
-
load_csv = LoadCSV(files={'train': 'path/to/train.csv'}, chunksize=100)
|
425 |
-
"""
|
426 |
|
427 |
files: Dict[str, str]
|
428 |
chunksize: int = 1000
|
429 |
loader_limit: Optional[int] = None
|
430 |
streaming: bool = True
|
431 |
-
sep: str = ","
|
432 |
compression: Optional[str] = None
|
433 |
-
lines: Optional[bool] = None
|
434 |
-
file_type: Literal["csv", "json"] = "csv"
|
435 |
|
436 |
def _maybe_set_classification_policy(self):
|
437 |
self.set_default_data_classification(
|
438 |
["proprietary"], "when loading from local files"
|
439 |
)
|
440 |
|
441 |
-
def get_reader(self):
|
442 |
-
if self.file_type == "csv":
|
443 |
-
return pd.read_csv
|
444 |
-
if self.file_type == "json":
|
445 |
-
return pd.read_json
|
446 |
-
raise ValueError()
|
447 |
-
|
448 |
-
def get_args(self):
|
449 |
-
args = {}
|
450 |
-
if self.file_type == "csv":
|
451 |
-
args["sep"] = self.sep
|
452 |
-
args["low_memory"] = self.streaming
|
453 |
-
if self.compression is not None:
|
454 |
-
args["compression"] = self.compression
|
455 |
-
if self.lines is not None:
|
456 |
-
args["lines"] = self.lines
|
457 |
-
if self.get_limit() is not None:
|
458 |
-
args["nrows"] = self.get_limit()
|
459 |
-
return args
|
460 |
-
|
461 |
-
def get_splits(self) -> List[str]:
|
462 |
-
return list(self.files.keys())
|
463 |
-
|
464 |
def split_generator(self, split: str) -> Generator:
|
465 |
dataset_id = str(self) + "_" + split
|
466 |
dataset = self.__class__._loader_cache.get(dataset_id, None)
|
@@ -469,33 +425,150 @@ class LoadCSV(LazyLoader):
|
|
469 |
self.log_limited_loading()
|
470 |
for attempt in range(settings.loaders_max_retries):
|
471 |
try:
|
472 |
-
|
473 |
if self.get_limit() is not None:
|
474 |
self.log_limited_loading()
|
475 |
|
476 |
try:
|
477 |
-
|
478 |
-
"records"
|
479 |
-
)
|
480 |
break
|
481 |
except ValueError:
|
482 |
import fsspec
|
483 |
|
484 |
-
with fsspec.open(
|
485 |
-
|
486 |
break
|
487 |
except Exception as e:
|
488 |
-
logger.
|
489 |
if attempt < settings.loaders_max_retries - 1:
|
490 |
time.sleep(2)
|
491 |
else:
|
492 |
raise e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
self.__class__._loader_cache.max_size = settings.loader_cache_size
|
494 |
self.__class__._loader_cache[dataset_id] = dataset
|
495 |
|
496 |
for instance in self.__class__._loader_cache[dataset_id]:
|
497 |
yield recursive_copy(instance)
|
498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
|
500 |
class LoadFromSklearn(LazyLoader):
|
501 |
"""Loads datasets from the sklearn library.
|
|
|
46 |
Generator,
|
47 |
Iterable,
|
48 |
List,
|
|
|
49 |
Mapping,
|
50 |
Optional,
|
51 |
Sequence,
|
|
|
65 |
from tqdm import tqdm
|
66 |
|
67 |
from .dataclass import NonPositionalField
|
68 |
+
from .dict_utils import dict_get
|
69 |
from .error_utils import Documentation, UnitxtError, UnitxtWarning
|
70 |
from .fusion import FixedFusion
|
71 |
from .logging_utils import get_logger
|
|
|
403 |
if i + 1 >= limit:
|
404 |
break
|
405 |
|
406 |
+
class LoadWithPandas(LazyLoader):
|
407 |
+
"""Utility base class for classes loading with pandas."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
408 |
|
409 |
files: Dict[str, str]
|
410 |
chunksize: int = 1000
|
411 |
loader_limit: Optional[int] = None
|
412 |
streaming: bool = True
|
|
|
413 |
compression: Optional[str] = None
|
|
|
|
|
414 |
|
415 |
def _maybe_set_classification_policy(self):
|
416 |
self.set_default_data_classification(
|
417 |
["proprietary"], "when loading from local files"
|
418 |
)
|
419 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
def split_generator(self, split: str) -> Generator:
|
421 |
dataset_id = str(self) + "_" + split
|
422 |
dataset = self.__class__._loader_cache.get(dataset_id, None)
|
|
|
425 |
self.log_limited_loading()
|
426 |
for attempt in range(settings.loaders_max_retries):
|
427 |
try:
|
428 |
+
file = self.files[split]
|
429 |
if self.get_limit() is not None:
|
430 |
self.log_limited_loading()
|
431 |
|
432 |
try:
|
433 |
+
dataframe = self.read_dataframe(file)
|
|
|
|
|
434 |
break
|
435 |
except ValueError:
|
436 |
import fsspec
|
437 |
|
438 |
+
with fsspec.open(file, mode="rt") as file:
|
439 |
+
dataframe = self.read_dataframe(file)
|
440 |
break
|
441 |
except Exception as e:
|
442 |
+
logger.warning(f"Attempt load {attempt + 1} failed: {e}")
|
443 |
if attempt < settings.loaders_max_retries - 1:
|
444 |
time.sleep(2)
|
445 |
else:
|
446 |
raise e
|
447 |
+
|
448 |
+
limit = self.get_limit()
|
449 |
+
if limit is not None and len(dataframe) > limit:
|
450 |
+
dataframe = dataframe.head(limit)
|
451 |
+
|
452 |
+
dataset = dataframe.to_dict("records")
|
453 |
+
|
454 |
self.__class__._loader_cache.max_size = settings.loader_cache_size
|
455 |
self.__class__._loader_cache[dataset_id] = dataset
|
456 |
|
457 |
for instance in self.__class__._loader_cache[dataset_id]:
|
458 |
yield recursive_copy(instance)
|
459 |
|
460 |
+
def get_splits(self) -> List[str]:
|
461 |
+
return list(self.files.keys())
|
462 |
+
|
463 |
+
|
464 |
+
def get_args(self) -> Dict[str, Any]:
|
465 |
+
args = {}
|
466 |
+
if self.compression is not None:
|
467 |
+
args["compression"] = self.compression
|
468 |
+
if self.get_limit() is not None:
|
469 |
+
args["nrows"] = self.get_limit()
|
470 |
+
return args
|
471 |
+
|
472 |
+
@abstractmethod
|
473 |
+
def read_dataframe(self, file) -> pd.DataFrame:
|
474 |
+
...
|
475 |
+
|
476 |
+
class LoadCSV(LoadWithPandas):
|
477 |
+
"""Loads data from CSV files.
|
478 |
+
|
479 |
+
Supports streaming and can handle large files by loading them in chunks.
|
480 |
+
|
481 |
+
Args:
|
482 |
+
files (Dict[str, str]): A dictionary mapping names to file paths.
|
483 |
+
chunksize : Size of the chunks to load at a time.
|
484 |
+
loader_limit: Optional integer to specify a limit on the number of records to load.
|
485 |
+
streaming: Bool indicating if streaming should be used.
|
486 |
+
sep: String specifying the separator used in the CSV files.
|
487 |
+
|
488 |
+
Example:
|
489 |
+
Loading csv
|
490 |
+
|
491 |
+
.. code-block:: python
|
492 |
+
|
493 |
+
load_csv = LoadCSV(files={'train': 'path/to/train.csv'}, chunksize=100)
|
494 |
+
"""
|
495 |
+
|
496 |
+
sep: str = ","
|
497 |
+
|
498 |
+
def read_dataframe(self, file) -> pd.DataFrame:
|
499 |
+
return pd.read_csv(
|
500 |
+
file,
|
501 |
+
sep=self.sep,
|
502 |
+
low_memory=self.streaming,
|
503 |
+
**self.get_args()
|
504 |
+
)
|
505 |
+
|
506 |
+
|
507 |
+
def read_file(source) -> bytes:
|
508 |
+
|
509 |
+
if hasattr(source, "read"):
|
510 |
+
return source.read()
|
511 |
+
|
512 |
+
if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")):
|
513 |
+
from urllib import request
|
514 |
+
with request.urlopen(source) as response:
|
515 |
+
return response.read()
|
516 |
+
|
517 |
+
with open(source, "rb") as f:
|
518 |
+
return f.read()
|
519 |
+
|
520 |
+
class LoadJsonFile(LoadWithPandas):
|
521 |
+
"""Loads data from JSON files.
|
522 |
+
|
523 |
+
Supports streaming and can handle large files by loading them in chunks.
|
524 |
+
|
525 |
+
Args:
|
526 |
+
files (Dict[str, str]): A dictionary mapping names to file paths.
|
527 |
+
chunksize : Size of the chunks to load at a time.
|
528 |
+
loader_limit: Optional integer to specify a limit on the number of records to load.
|
529 |
+
streaming: Bool indicating if streaming should be used.
|
530 |
+
lines: Bool indicate if it is json lines file structure. Otherwise, assumes a single json object in the file.
|
531 |
+
data_field: optional field within the json object, that contains the list of instances.
|
532 |
+
|
533 |
+
Example:
|
534 |
+
Loading json lines
|
535 |
+
|
536 |
+
.. code-block:: python
|
537 |
+
|
538 |
+
load_csv = LoadJsonFile(files={'train': 'path/to/train.jsonl'}, line=True, chunksize=100)
|
539 |
+
"""
|
540 |
+
|
541 |
+
lines: bool = False
|
542 |
+
data_field: Optional[str] = None
|
543 |
+
|
544 |
+
def read_dataframe(self, file) -> pd.DataFrame:
|
545 |
+
|
546 |
+
args = self.get_args()
|
547 |
+
if not self.lines:
|
548 |
+
data = json.loads(read_file(file))
|
549 |
+
if (self.data_field):
|
550 |
+
instances = dict_get(data, self.data_field)
|
551 |
+
if not isoftype(instances,List[Dict[str,Any]]):
|
552 |
+
raise UnitxtError(f"{self.data_field} of file {file} is not a list of dictionariess in LoadJsonFile loader")
|
553 |
+
else:
|
554 |
+
if isoftype(data,Dict[str,Any]):
|
555 |
+
instances = [data]
|
556 |
+
elif isoftype(data,List[Dict[str,Any]]):
|
557 |
+
instances=data
|
558 |
+
else:
|
559 |
+
raise UnitxtError(f"data of file {file} is not dictionary or a list of dictionaries in LoadJsonFile loader")
|
560 |
+
dataframe = pd.DataFrame(instances)
|
561 |
+
else:
|
562 |
+
if self.data_field is not None:
|
563 |
+
raise UnitxtError("Can not load from a specific 'data_field' when loading multiple lines (lines=True)")
|
564 |
+
dataframe = pd.read_json(
|
565 |
+
file,
|
566 |
+
lines=self.lines,
|
567 |
+
**args
|
568 |
+
)
|
569 |
+
return dataframe
|
570 |
+
|
571 |
+
|
572 |
|
573 |
class LoadFromSklearn(LazyLoader):
|
574 |
"""Loads datasets from the sklearn library.
|
metric.py
CHANGED
@@ -5,6 +5,7 @@ import evaluate
|
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
7 |
from .augmentors import __file__ as _
|
|
|
8 |
from .benchmark import __file__ as _
|
9 |
from .blocks import __file__ as _
|
10 |
from .card import __file__ as _
|
|
|
5 |
from .api import __file__ as _
|
6 |
from .artifact import __file__ as _
|
7 |
from .augmentors import __file__ as _
|
8 |
+
from .base_metric import __file__ as _
|
9 |
from .benchmark import __file__ as _
|
10 |
from .blocks import __file__ as _
|
11 |
from .card import __file__ as _
|
metrics.py
CHANGED
@@ -33,6 +33,7 @@ from scipy.stats import bootstrap
|
|
33 |
from scipy.stats._warnings_errors import DegenerateDataWarning
|
34 |
|
35 |
from .artifact import Artifact
|
|
|
36 |
from .collections import ListCollection
|
37 |
from .dataclass import (
|
38 |
AbstractField,
|
@@ -63,7 +64,7 @@ from .operators import ArtifactFetcherMixin, Copy, Set
|
|
63 |
from .random_utils import get_seed
|
64 |
from .settings_utils import get_settings
|
65 |
from .stream import MultiStream, Stream
|
66 |
-
from .type_utils import
|
67 |
from .types import ToolCall
|
68 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
69 |
|
@@ -154,211 +155,6 @@ def parse_string_types_instead_of_actual_objects(obj):
|
|
154 |
return parse_type_string(obj)
|
155 |
|
156 |
|
157 |
-
class Metric(Artifact):
|
158 |
-
main_score: str = AbstractField()
|
159 |
-
# Override 'prediction_type' with the expected type of predictions
|
160 |
-
# and references. Example: "List[str]", "List[Dict]"", "string".
|
161 |
-
# If left with default None, a warning will be displayed.
|
162 |
-
# In future versions of unitxt, this will be an error.
|
163 |
-
prediction_type: Union[Type, str] = Any
|
164 |
-
|
165 |
-
# Standard metrics can receive multiple references per predictions (in a list)
|
166 |
-
# Some metrics support only a single reference per prediction (one element in the list)
|
167 |
-
single_reference_per_prediction: bool = False
|
168 |
-
|
169 |
-
#
|
170 |
-
# Used to add a prefix to all score, except the "score_name" and "score" fields.
|
171 |
-
# This is used to distinguish two scores of the same metrics, operating on different fields of the task
|
172 |
-
#
|
173 |
-
score_prefix: str = ""
|
174 |
-
|
175 |
-
def prepare_args(self):
|
176 |
-
super().prepare_args()
|
177 |
-
if isinstance(self.prediction_type, str):
|
178 |
-
self.prediction_type = parse_string_types_instead_of_actual_objects(
|
179 |
-
self.prediction_type
|
180 |
-
)
|
181 |
-
|
182 |
-
@classmethod
|
183 |
-
def process_data_after_load(cls, data):
|
184 |
-
if "prediction_type" in data:
|
185 |
-
data["prediction_type"] = parse_type_string(data["prediction_type"])
|
186 |
-
return data
|
187 |
-
|
188 |
-
def process_data_before_dump(self, data):
|
189 |
-
if "prediction_type" in data:
|
190 |
-
if not isinstance(data["prediction_type"], str):
|
191 |
-
data["prediction_type"] = to_type_string(data["prediction_type"])
|
192 |
-
return data
|
193 |
-
|
194 |
-
def _add_score_prefix(self, score_name):
|
195 |
-
return (
|
196 |
-
self.score_prefix + score_name
|
197 |
-
if score_name not in ["score", "score_name", "num_of_instances"]
|
198 |
-
else score_name
|
199 |
-
)
|
200 |
-
|
201 |
-
def _add_score_prefixes_to_score_dict_and_check_against_existing_scores(
|
202 |
-
self, scores: Dict[str, Any], existing_scores: Dict[str, Any]
|
203 |
-
) -> Dict[str, Any]:
|
204 |
-
new_scores = {}
|
205 |
-
for score_name, score in scores.items():
|
206 |
-
score_with_prefix = self._add_score_prefix(score_name)
|
207 |
-
new_scores[score_with_prefix] = (
|
208 |
-
score if score_name not in ["score_name"] else self.score_prefix + score
|
209 |
-
)
|
210 |
-
for new_score_name in new_scores:
|
211 |
-
if new_score_name in ["score", "score_name", "num_of_instances"]:
|
212 |
-
continue
|
213 |
-
if new_score_name in existing_scores:
|
214 |
-
UnitxtWarning(
|
215 |
-
message=f"Metric '{new_score_name}' that has just been evaluated to {new_scores[new_score_name]}, is already recorded "
|
216 |
-
f"to have value {existing_scores[new_score_name]} by a previous metric evaluation on this instance or stream. "
|
217 |
-
f"To avoid overwriting the existing value, add a score_prefix to the metric name (e.g. score_prefix='my_second_' , "
|
218 |
-
f"which will yield, in this case, a score named: 'my_second_{new_score_name}')",
|
219 |
-
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
|
220 |
-
)
|
221 |
-
return new_scores
|
222 |
-
|
223 |
-
def _validate_references_and_prediction(self, references, predictions):
|
224 |
-
if not isoftype(predictions, List[Any]):
|
225 |
-
raise ValueError(
|
226 |
-
f"Metric {self.get_metric_name()} should receive a list of predictions {self.get_metric_name()}. Received predictions of type {type(predictions)}: {predictions}"
|
227 |
-
)
|
228 |
-
|
229 |
-
if not isoftype(references, List[Any]):
|
230 |
-
raise ValueError(
|
231 |
-
f"Metric {self.get_metric_name()} should receive a list of predictions. Received references of type {type(references)}: {references}"
|
232 |
-
)
|
233 |
-
|
234 |
-
if len(references) != len(predictions):
|
235 |
-
raise ValueError(
|
236 |
-
f"references size ({len(references)})"
|
237 |
-
f" doesn't mach predictions size ({len(references)})."
|
238 |
-
)
|
239 |
-
|
240 |
-
for reference in references:
|
241 |
-
self._validate_reference(reference)
|
242 |
-
|
243 |
-
for prediction in predictions:
|
244 |
-
self._validate_prediction(prediction)
|
245 |
-
|
246 |
-
def _validate_prediction(self, prediction):
|
247 |
-
if not isoftype(prediction, self.prediction_type):
|
248 |
-
raise ValueError(
|
249 |
-
f"Each prediction is expected to be of type '{to_type_string(self.prediction_type)}' in {self.get_metric_name()} metric. Received prediction of type {type(prediction)}: {prediction}"
|
250 |
-
)
|
251 |
-
|
252 |
-
def _validate_reference(self, reference):
|
253 |
-
if not isoftype(reference, List[Any]):
|
254 |
-
raise ValueError(
|
255 |
-
f"Expecting a list of references for each prediction in {self.get_metric_name()} metric. Received reference of type {type(reference)}: {reference}"
|
256 |
-
)
|
257 |
-
if self.single_reference_per_prediction and not len(reference) == 1:
|
258 |
-
raise ValueError(
|
259 |
-
f"Expecting a list with a single reference per prediction in {self.get_metric_name()} metric. Received a list with multiple references: {reference}"
|
260 |
-
)
|
261 |
-
for ref in reference:
|
262 |
-
if not isoftype(ref, self.prediction_type):
|
263 |
-
raise ValueError(
|
264 |
-
f"Each reference is expected to be of type '{to_type_string(self.prediction_type)}' in {self.get_metric_name()} metric. Received reference of type {type(ref)}: {ref}"
|
265 |
-
)
|
266 |
-
|
267 |
-
def get_metric_name(self):
|
268 |
-
if self.__id__ is not None:
|
269 |
-
return self.__id__
|
270 |
-
return self.__class__.__name__
|
271 |
-
|
272 |
-
def consume_stream(self, stream: Stream):
|
273 |
-
references = []
|
274 |
-
predictions = []
|
275 |
-
additional_inputs = []
|
276 |
-
instances = []
|
277 |
-
for instance in stream:
|
278 |
-
instance = self.verify_instance(instance)
|
279 |
-
references.append(instance["references"])
|
280 |
-
predictions.append(instance["prediction"])
|
281 |
-
additional_inputs.append(
|
282 |
-
instance["additional_inputs"] if "additional_inputs" in instance else {}
|
283 |
-
)
|
284 |
-
instances.append(instance)
|
285 |
-
return predictions, references, additional_inputs, instances
|
286 |
-
|
287 |
-
@staticmethod
|
288 |
-
def update_instance_scores(instances, instances_scores: List[Dict[str, Any]]):
|
289 |
-
for instance, new_scores in zip(instances, instances_scores):
|
290 |
-
if "score" not in instance:
|
291 |
-
instance["score"] = {}
|
292 |
-
scores = instance["score"]
|
293 |
-
if "instance" not in scores:
|
294 |
-
scores["instance"] = {}
|
295 |
-
scores["instance"].update(new_scores)
|
296 |
-
|
297 |
-
@staticmethod
|
298 |
-
def set_global_score(instances, global_score: Dict[str, Any]):
|
299 |
-
for instance in instances:
|
300 |
-
if "score" not in instance:
|
301 |
-
instance["score"] = {}
|
302 |
-
scores = instance["score"]
|
303 |
-
if "global" not in scores:
|
304 |
-
scores["global"] = {}
|
305 |
-
scores["global"] = global_score
|
306 |
-
|
307 |
-
@abstractmethod
|
308 |
-
def disable_confidence_interval_calculation(self):
|
309 |
-
pass
|
310 |
-
|
311 |
-
# update instance["score"]["global"] with the global_score just computed for the
|
312 |
-
# current metric. global_score contains "score" and "score_name" fields that reflect
|
313 |
-
# (the main_score of) the current metric. If CI was computed for global_score, then global_score
|
314 |
-
# also contains "score_ci_low" and "score_ci_high" that reflect (the main_score of) the current metric.
|
315 |
-
# A simple python-dictionary-update adds new fields to instance["score"]["global"], and also replaces the values
|
316 |
-
# of its fields "score" and "score_name" (and "score_ci_low", "score_ci_high" if applicable),
|
317 |
-
# to reflect the current metric, overwriting previous metrics' settings of these fields
|
318 |
-
# (if any previous metric exists).
|
319 |
-
# When global_score does NOT contain ci score (because CI was not computed for the current metric), but
|
320 |
-
# one of the previous metrics computed did have, the last of such previous metrics set the values in
|
321 |
-
# fields "score_ci_low" and "score_ci_high" in instance["score"]["global"] to reflect its
|
322 |
-
# (the previous metric's) CI scores.
|
323 |
-
# Because CI is not computed for the current metric, global_score does not contain fields "score_ci_low" and
|
324 |
-
# "score_ci_high" to overwrite the ones existing in instance["score"]["global"], and these might remain in
|
325 |
-
# instance["score"]["global"], but their values, that are not associated with the current metric, are,
|
326 |
-
# therefore, not consistent with "score_name".
|
327 |
-
# In such a case, following the python-dictionary-update, we pop out fields "score_ci_low" and
|
328 |
-
# "score_ci_high" from instance["score"]["global"], so that now all the fields "score.." in
|
329 |
-
# instance["score"]["global"] are consistent with the current metric: The metric that is named
|
330 |
-
# instance["score"]["global"]["score_name"], its score shows in
|
331 |
-
# field instance["score"]["global"]["score"], and it does not have ci_scores,
|
332 |
-
# which is also reflected in the absence of fields "score_ci_low" and "score_ci_high" from instance["score"]["global"].
|
333 |
-
# If ci IS computed for the current metric, global_score contains "score_ci_low" and "score_ci_high", and these overwrite
|
334 |
-
# the ones existing in instance["score"]["global"] by the simple python-dictionary-update, and no need for any further fixeup.
|
335 |
-
def update_and_adjust_global_score(
|
336 |
-
self, instance: Dict[str, Any], global_score: dict
|
337 |
-
):
|
338 |
-
for score_name in global_score:
|
339 |
-
if score_name in [
|
340 |
-
"score",
|
341 |
-
"score_name",
|
342 |
-
"score_ci_low",
|
343 |
-
"score_ci_high",
|
344 |
-
"num_of_instances",
|
345 |
-
]:
|
346 |
-
continue
|
347 |
-
if score_name in instance["score"]["global"]:
|
348 |
-
UnitxtWarning(
|
349 |
-
message=f"Global metric '{score_name}' that has just been evaluated to {global_score[score_name]}, is already recorded "
|
350 |
-
f"to have value {instance['score']['global'][score_name]} by a previous metric evaluation on this stream. "
|
351 |
-
f"To avoid overwriting the value, add a score_prefix to the metric (e.g. score_prefix='my_{score_name}'.",
|
352 |
-
additional_info_id=Documentation.MULTIPLE_METRICS_OUTPUTS,
|
353 |
-
)
|
354 |
-
instance["score"]["global"].update(global_score)
|
355 |
-
for score_ci in ["score_ci_low", "score_ci_high"]:
|
356 |
-
if score_ci in global_score:
|
357 |
-
continue
|
358 |
-
if score_ci in instance["score"]["global"]:
|
359 |
-
instance["score"]["global"].pop(score_ci)
|
360 |
-
|
361 |
-
|
362 |
def new_random_generator():
|
363 |
# The np.random.default_rng expects a 32-bit int, while hash(..) can return a 64-bit integer.
|
364 |
# So use '& MAX_32BIT' to get a 32-bit seed.
|
@@ -848,8 +644,10 @@ class ToolCallingMetric(ReductionInstanceMetric[str, Dict[str, float]]):
|
|
848 |
|
849 |
if len(prediction["arguments"]) > 0:
|
850 |
score = value_matches / len(prediction["arguments"])
|
851 |
-
|
852 |
score = 1.0
|
|
|
|
|
853 |
if score > argument_value_precision:
|
854 |
argument_value_precision = score
|
855 |
|
@@ -3593,17 +3391,61 @@ class KeyValueExtraction(GlobalMetric):
|
|
3593 |
return result
|
3594 |
|
3595 |
class ToolCallKeyValueExtraction(KeyValueExtraction):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3596 |
prediction_type = ToolCall
|
3597 |
|
|
|
|
|
3598 |
def flatten_dict(self,nested_dict, parent_key="", sep="."):
|
3599 |
flat_dict = {}
|
3600 |
for k, v in nested_dict.items():
|
3601 |
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
3602 |
-
|
3603 |
-
|
3604 |
-
|
3605 |
-
|
3606 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3607 |
flat_dict.update(self.flatten_dict(v, new_key, sep=sep))
|
3608 |
else:
|
3609 |
flat_dict[new_key] = v
|
@@ -6290,7 +6132,7 @@ class GraniteGuardianBase(InstanceMetric):
|
|
6290 |
return result
|
6291 |
|
6292 |
def create_message(self, role: str, content: str) -> List[Dict[str, str]]:
|
6293 |
-
return [{"role": role, "content": content}]
|
6294 |
|
6295 |
def parse_output(self, generated_tokens_list):
|
6296 |
top_tokens_list = [
|
@@ -6421,12 +6263,22 @@ class GraniteGuardianAgenticRisk(GraniteGuardianBase):
|
|
6421 |
|
6422 |
def process_input_fields(self, task_data):
|
6423 |
messages = []
|
|
|
|
|
|
|
|
|
|
|
6424 |
messages += self.create_message(
|
6425 |
-
"tools",
|
6426 |
)
|
6427 |
messages += self.create_message("user", task_data[self.user_message_field])
|
|
|
|
|
|
|
|
|
|
|
6428 |
messages += self.create_message(
|
6429 |
-
"assistant",
|
6430 |
)
|
6431 |
return messages
|
6432 |
|
|
|
33 |
from scipy.stats._warnings_errors import DegenerateDataWarning
|
34 |
|
35 |
from .artifact import Artifact
|
36 |
+
from .base_metric import Metric
|
37 |
from .collections import ListCollection
|
38 |
from .dataclass import (
|
39 |
AbstractField,
|
|
|
64 |
from .random_utils import get_seed
|
65 |
from .settings_utils import get_settings
|
66 |
from .stream import MultiStream, Stream
|
67 |
+
from .type_utils import isoftype, parse_type_string, to_type_string
|
68 |
from .types import ToolCall
|
69 |
from .utils import deep_copy, recursive_copy, retry_connection_with_exponential_backoff
|
70 |
|
|
|
155 |
return parse_type_string(obj)
|
156 |
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
def new_random_generator():
|
159 |
# The np.random.default_rng expects a 32-bit int, while hash(..) can return a 64-bit integer.
|
160 |
# So use '& MAX_32BIT' to get a 32-bit seed.
|
|
|
644 |
|
645 |
if len(prediction["arguments"]) > 0:
|
646 |
score = value_matches / len(prediction["arguments"])
|
647 |
+
elif len(reference["arguments"]) == 0:
|
648 |
score = 1.0
|
649 |
+
else:
|
650 |
+
score = 0.0
|
651 |
if score > argument_value_precision:
|
652 |
argument_value_precision = score
|
653 |
|
|
|
3391 |
return result
|
3392 |
|
3393 |
class ToolCallKeyValueExtraction(KeyValueExtraction):
|
3394 |
+
"""Metrics that formulate ToolCall evaluation as a Key Value Extraction task.
|
3395 |
+
|
3396 |
+
Each argument and each nested value are first flatten to a key value.
|
3397 |
+
|
3398 |
+
{ arguments : {"name" : "John", "address" : { "street" : "Main St", "City" : "Smallville" } } }
|
3399 |
+
|
3400 |
+
becomes
|
3401 |
+
|
3402 |
+
argument.names = "John"
|
3403 |
+
argument.address.street = "Main St"
|
3404 |
+
argument.address.city = "Smallvile"
|
3405 |
+
|
3406 |
+
Note that by default, if a parameter is a list of dictionaries, they are flattened with indexes
|
3407 |
+
|
3408 |
+
{ arguments : {"addresses" : [{ "street" : "Main St", "City" : "Smallville" } ,
|
3409 |
+
{ "street" : "Log St", "City" : "BigCity" } ] } }
|
3410 |
+
|
3411 |
+
argument.address.0.street = "Main St"
|
3412 |
+
argument.address.0.city = "Smallvile"
|
3413 |
+
argument.address.1.street = "Log St"
|
3414 |
+
argument.address.1.city = "BigCity"
|
3415 |
+
|
3416 |
+
But if each dictionary in the list has a single unique key, it is used instead.
|
3417 |
+
|
3418 |
+
{ arguments : {"addresses" : [ { "home" : { "street" : "Main St", "City" : "Smallville" }} ,
|
3419 |
+
{ "work" : {"street" : "Log St", "City" : "BigCity" } ] } }
|
3420 |
+
|
3421 |
+
argument.address.home.street = "Main St"
|
3422 |
+
argument.address.home.city = "Smallvile"
|
3423 |
+
argument.address.work.street = "Log St"
|
3424 |
+
argument.address.work.city = "BigCity"
|
3425 |
+
|
3426 |
+
"""
|
3427 |
prediction_type = ToolCall
|
3428 |
|
3429 |
+
flatten_list_of_dictionaries = False
|
3430 |
+
|
3431 |
def flatten_dict(self,nested_dict, parent_key="", sep="."):
|
3432 |
flat_dict = {}
|
3433 |
for k, v in nested_dict.items():
|
3434 |
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
3435 |
+
|
3436 |
+
|
3437 |
+
|
3438 |
+
|
3439 |
+
if isoftype(v, List[Dict[Any,Any]]):
|
3440 |
+
if (all(len(d) == 1 for d in v)):
|
3441 |
+
keys = [next(iter(d.keys())) for d in v]
|
3442 |
+
if len(keys) == len(set(keys)):
|
3443 |
+
for e in v:
|
3444 |
+
flat_dict.update(self.flatten_dict(e, f"{new_key}",sep=sep))
|
3445 |
+
continue
|
3446 |
+
for i,e in enumerate(v):
|
3447 |
+
flat_dict.update(self.flatten_dict(e, f"{new_key}{sep}{i}",sep=sep))
|
3448 |
+
elif isoftype(v, Dict[Any,Any]):
|
3449 |
flat_dict.update(self.flatten_dict(v, new_key, sep=sep))
|
3450 |
else:
|
3451 |
flat_dict[new_key] = v
|
|
|
6132 |
return result
|
6133 |
|
6134 |
def create_message(self, role: str, content: str) -> List[Dict[str, str]]:
|
6135 |
+
return [{"role": role, "content": str(content)}]
|
6136 |
|
6137 |
def parse_output(self, generated_tokens_list):
|
6138 |
top_tokens_list = [
|
|
|
6263 |
|
6264 |
def process_input_fields(self, task_data):
|
6265 |
messages = []
|
6266 |
+
|
6267 |
+
tools = task_data[self.tools_field]
|
6268 |
+
if isinstance(tools, str):
|
6269 |
+
tools = json.loads(tools)
|
6270 |
+
|
6271 |
messages += self.create_message(
|
6272 |
+
"tools", tools
|
6273 |
)
|
6274 |
messages += self.create_message("user", task_data[self.user_message_field])
|
6275 |
+
|
6276 |
+
calls = task_data[self.assistant_message_field]
|
6277 |
+
if isinstance(calls, str):
|
6278 |
+
calls = json.loads(calls)
|
6279 |
+
|
6280 |
messages += self.create_message(
|
6281 |
+
"assistant", calls
|
6282 |
)
|
6283 |
return messages
|
6284 |
|
operators.py
CHANGED
@@ -76,7 +76,6 @@ from .operator import (
|
|
76 |
PagedStreamOperator,
|
77 |
SequentialOperator,
|
78 |
SideEffectOperator,
|
79 |
-
SingleStreamReducer,
|
80 |
SourceOperator,
|
81 |
StreamingOperator,
|
82 |
StreamInitializerOperator,
|
@@ -85,7 +84,7 @@ from .operator import (
|
|
85 |
from .random_utils import new_random_generator
|
86 |
from .settings_utils import get_settings
|
87 |
from .stream import DynamicStream, Stream
|
88 |
-
from .text_utils import
|
89 |
from .type_utils import isoftype
|
90 |
from .utils import (
|
91 |
LRUCache,
|
@@ -283,6 +282,7 @@ class Set(InstanceOperator):
|
|
283 |
dict_set(instance, key, value)
|
284 |
return instance
|
285 |
|
|
|
286 |
def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
|
287 |
"""Recursively traverses a data structure (dicts and lists), replaces values of target_key using value_map, and removes values listed in value_remove.
|
288 |
|
@@ -323,13 +323,34 @@ def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
|
|
323 |
recursive_key_value_replace(item, target_key, value_map, value_remove)
|
324 |
return data
|
325 |
|
|
|
326 |
class RecursiveReplace(InstanceOperator):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
key: str
|
328 |
map_values: dict
|
329 |
remove_values: Optional[list] = None
|
330 |
|
331 |
-
def process(
|
332 |
-
|
|
|
|
|
|
|
|
|
|
|
333 |
|
334 |
@deprecation(version="2.0.0", alternative=Set)
|
335 |
class AddFields(Set):
|
@@ -427,8 +448,8 @@ class InstanceFieldOperator(InstanceOperator):
|
|
427 |
def verify_field_definition(self):
|
428 |
if hasattr(self, "_field_to_field") and self._field_to_field is not None:
|
429 |
return
|
430 |
-
assert (
|
431 |
-
|
432 |
), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
|
433 |
assert (
|
434 |
self.to_field is None or self.field_to_field is None
|
@@ -605,10 +626,27 @@ class AddConstant(FieldOperator):
|
|
605 |
def process_value(self, value: Any) -> Any:
|
606 |
return self.add + value
|
607 |
|
608 |
-
|
609 |
class ShuffleFieldValues(FieldOperator):
|
610 |
-
|
|
|
|
|
|
|
|
|
|
|
611 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
612 |
def process_value(self, value: Any) -> Any:
|
613 |
res = list(value)
|
614 |
random_generator = new_random_generator(sub_seed=res)
|
@@ -784,9 +822,8 @@ class InterleaveListsToDialogOperator(InstanceOperator):
|
|
784 |
user_turns = instance[self.user_turns_field]
|
785 |
assistant_turns = instance[self.assistant_turns_field]
|
786 |
|
787 |
-
assert (
|
788 |
-
len(user_turns)
|
789 |
-
or (len(user_turns) - len(assistant_turns) == 1)
|
790 |
), "user_turns must have either the same length as assistant_turns or one more turn."
|
791 |
|
792 |
interleaved_dialog = []
|
@@ -945,7 +982,14 @@ class CopyFields(Copy):
|
|
945 |
|
946 |
|
947 |
class GetItemByIndex(FieldOperator):
|
948 |
-
"""Get from the
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
949 |
|
950 |
items_list: List[Any]
|
951 |
|
@@ -977,7 +1021,13 @@ class Cast(FieldOperator):
|
|
977 |
failure_default: Optional[Any] = "__UNDEFINED__"
|
978 |
|
979 |
def prepare(self):
|
980 |
-
self.types = {
|
|
|
|
|
|
|
|
|
|
|
|
|
981 |
|
982 |
def process_value(self, value):
|
983 |
try:
|
@@ -1658,63 +1708,6 @@ class RemoveValues(FieldOperator):
|
|
1658 |
return [e for e in value if e not in self.unallowed_values]
|
1659 |
|
1660 |
|
1661 |
-
class Unique(SingleStreamReducer):
|
1662 |
-
"""Reduces a stream to unique instances based on specified fields.
|
1663 |
-
|
1664 |
-
Args:
|
1665 |
-
fields (List[str]): The fields that should be unique in each instance.
|
1666 |
-
"""
|
1667 |
-
|
1668 |
-
fields: List[str] = field(default_factory=list)
|
1669 |
-
|
1670 |
-
@staticmethod
|
1671 |
-
def to_tuple(instance: dict, fields: List[str]) -> tuple:
|
1672 |
-
result = []
|
1673 |
-
for field_name in fields:
|
1674 |
-
value = instance[field_name]
|
1675 |
-
if isinstance(value, list):
|
1676 |
-
value = tuple(value)
|
1677 |
-
result.append(value)
|
1678 |
-
return tuple(result)
|
1679 |
-
|
1680 |
-
def process(self, stream: Stream) -> Stream:
|
1681 |
-
seen = set()
|
1682 |
-
for instance in stream:
|
1683 |
-
values = self.to_tuple(instance, self.fields)
|
1684 |
-
if values not in seen:
|
1685 |
-
seen.add(values)
|
1686 |
-
return list(seen)
|
1687 |
-
|
1688 |
-
|
1689 |
-
class SplitByValue(MultiStreamOperator):
|
1690 |
-
"""Splits a MultiStream into multiple streams based on unique values in specified fields.
|
1691 |
-
|
1692 |
-
Args:
|
1693 |
-
fields (List[str]): The fields to use when splitting the MultiStream.
|
1694 |
-
"""
|
1695 |
-
|
1696 |
-
fields: List[str] = field(default_factory=list)
|
1697 |
-
|
1698 |
-
def process(self, multi_stream: MultiStream) -> MultiStream:
|
1699 |
-
uniques = Unique(fields=self.fields)(multi_stream)
|
1700 |
-
|
1701 |
-
result = {}
|
1702 |
-
|
1703 |
-
for stream_name, stream in multi_stream.items():
|
1704 |
-
stream_unique_values = uniques[stream_name]
|
1705 |
-
for unique_values in stream_unique_values:
|
1706 |
-
filtering_values = dict(zip(self.fields, unique_values))
|
1707 |
-
filtered_streams = FilterByCondition(
|
1708 |
-
values=filtering_values, condition="eq"
|
1709 |
-
)._process_single_stream(stream)
|
1710 |
-
filtered_stream_name = (
|
1711 |
-
stream_name + "_" + nested_tuple_to_string(unique_values)
|
1712 |
-
)
|
1713 |
-
result[filtered_stream_name] = filtered_streams
|
1714 |
-
|
1715 |
-
return MultiStream(result)
|
1716 |
-
|
1717 |
-
|
1718 |
class SplitByNestedGroup(MultiStreamOperator):
|
1719 |
"""Splits a MultiStream that is small - for metrics, hence: whole stream can sit in memory, split by the value of field 'group'.
|
1720 |
|
@@ -1761,6 +1754,16 @@ class SplitByNestedGroup(MultiStreamOperator):
|
|
1761 |
return MultiStream.from_iterables(result)
|
1762 |
|
1763 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1764 |
class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
|
1765 |
"""Applies stream operators to a stream based on specified fields in each instance.
|
1766 |
|
@@ -2516,10 +2519,13 @@ class WikipediaFetcher(FieldOperator):
|
|
2516 |
|
2517 |
return {"title": page.title, "body": getattr(page, self.mode)}
|
2518 |
|
|
|
2519 |
class Fillna(FieldOperator):
|
2520 |
value: Any
|
|
|
2521 |
def process_value(self, value: Any) -> Any:
|
2522 |
import numpy as np
|
|
|
2523 |
try:
|
2524 |
if np.isnan(value):
|
2525 |
return self.value
|
|
|
76 |
PagedStreamOperator,
|
77 |
SequentialOperator,
|
78 |
SideEffectOperator,
|
|
|
79 |
SourceOperator,
|
80 |
StreamingOperator,
|
81 |
StreamInitializerOperator,
|
|
|
84 |
from .random_utils import new_random_generator
|
85 |
from .settings_utils import get_settings
|
86 |
from .stream import DynamicStream, Stream
|
87 |
+
from .text_utils import to_pretty_string
|
88 |
from .type_utils import isoftype
|
89 |
from .utils import (
|
90 |
LRUCache,
|
|
|
282 |
dict_set(instance, key, value)
|
283 |
return instance
|
284 |
|
285 |
+
|
286 |
def recursive_key_value_replace(data, target_key, value_map, value_remove=None):
|
287 |
"""Recursively traverses a data structure (dicts and lists), replaces values of target_key using value_map, and removes values listed in value_remove.
|
288 |
|
|
|
323 |
recursive_key_value_replace(item, target_key, value_map, value_remove)
|
324 |
return data
|
325 |
|
326 |
+
|
327 |
class RecursiveReplace(InstanceOperator):
|
328 |
+
# Assisted by watsonx Code Assistant
|
329 |
+
"""An operator to recursively replace values in dictionary fields of instances based on a key and a mapping of values.
|
330 |
+
|
331 |
+
Attributes:
|
332 |
+
key (str): The key in the dictionary to start the replacement process.
|
333 |
+
map_values (dict): A dictionary containing the key-value pairs to replace the original values.
|
334 |
+
remove_values (Optional[list]): An optional list of values to remove from the dictionary. Defaults to None.
|
335 |
+
|
336 |
+
Example:
|
337 |
+
RecursiveReplace(key="a", map_values={"1": "hi", "2": "bye" }, remove_values=["3"])
|
338 |
+
replaces the value of key "a" in all instances of all streams:
|
339 |
+
instance ``{"field" : [{"a": "1", "b" : "2"}, {"a" : "3", "b:" "4"}}` becomes ``{"field" : [{"a": "hi", "b" : "2"}, {"b": "4"}}``
|
340 |
+
|
341 |
+
Notice how the value of field ``"a"`` in the first instance is replaced with ``"hi"`` and the value of field ``"a"`` in the second instance is removed.
|
342 |
+
"""
|
343 |
key: str
|
344 |
map_values: dict
|
345 |
remove_values: Optional[list] = None
|
346 |
|
347 |
+
def process(
|
348 |
+
self, instance: Dict[str, Any], stream_name: Optional[str] = None
|
349 |
+
) -> Dict[str, Any]:
|
350 |
+
return recursive_key_value_replace(
|
351 |
+
instance, self.key, self.map_values, self.remove_values
|
352 |
+
)
|
353 |
+
|
354 |
|
355 |
@deprecation(version="2.0.0", alternative=Set)
|
356 |
class AddFields(Set):
|
|
|
448 |
def verify_field_definition(self):
|
449 |
if hasattr(self, "_field_to_field") and self._field_to_field is not None:
|
450 |
return
|
451 |
+
assert (self.field is None) != (
|
452 |
+
self.field_to_field is None
|
453 |
), "Must uniquely define the field to work on, through exactly one of either 'field' or 'field_to_field'"
|
454 |
assert (
|
455 |
self.to_field is None or self.field_to_field is None
|
|
|
626 |
def process_value(self, value: Any) -> Any:
|
627 |
return self.add + value
|
628 |
|
|
|
629 |
class ShuffleFieldValues(FieldOperator):
|
630 |
+
# Assisted by watsonx Code Assistant
|
631 |
+
"""An operator that shuffles the values of a list field.
|
632 |
+
|
633 |
+
the seed for shuffling in the is determined by the elements of the input field,
|
634 |
+
ensuring that the shuffling operation produces different results for different input lists,
|
635 |
+
but also that it is deterministic and reproducible.
|
636 |
|
637 |
+
Attributes:
|
638 |
+
None
|
639 |
+
|
640 |
+
Methods:
|
641 |
+
process_value(value: Any) -> Any:
|
642 |
+
Shuffles the elements of the input list and returns the shuffled list.
|
643 |
+
|
644 |
+
Parameters:
|
645 |
+
value (Any): The input list to be shuffled.
|
646 |
+
|
647 |
+
Returns:
|
648 |
+
Any: The shuffled list.
|
649 |
+
"""
|
650 |
def process_value(self, value: Any) -> Any:
|
651 |
res = list(value)
|
652 |
random_generator = new_random_generator(sub_seed=res)
|
|
|
822 |
user_turns = instance[self.user_turns_field]
|
823 |
assistant_turns = instance[self.assistant_turns_field]
|
824 |
|
825 |
+
assert len(user_turns) == len(assistant_turns) or (
|
826 |
+
len(user_turns) - len(assistant_turns) == 1
|
|
|
827 |
), "user_turns must have either the same length as assistant_turns or one more turn."
|
828 |
|
829 |
interleaved_dialog = []
|
|
|
982 |
|
983 |
|
984 |
class GetItemByIndex(FieldOperator):
|
985 |
+
"""Get the element from the fixed list by the index in the given field and store in another field.
|
986 |
+
|
987 |
+
Example:
|
988 |
+
GetItemByIndex(items_list=["dog",cat"],field="animal_index",to_field="animal")
|
989 |
+
|
990 |
+
on instance {"animal_index" : 1} will change the instance to {"animal_index" : 1, "animal" : "cat"}
|
991 |
+
|
992 |
+
"""
|
993 |
|
994 |
items_list: List[Any]
|
995 |
|
|
|
1021 |
failure_default: Optional[Any] = "__UNDEFINED__"
|
1022 |
|
1023 |
def prepare(self):
|
1024 |
+
self.types = {
|
1025 |
+
"int": int,
|
1026 |
+
"float": float,
|
1027 |
+
"str": str,
|
1028 |
+
"bool": bool,
|
1029 |
+
"tuple": tuple,
|
1030 |
+
}
|
1031 |
|
1032 |
def process_value(self, value):
|
1033 |
try:
|
|
|
1708 |
return [e for e in value if e not in self.unallowed_values]
|
1709 |
|
1710 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1711 |
class SplitByNestedGroup(MultiStreamOperator):
|
1712 |
"""Splits a MultiStream that is small - for metrics, hence: whole stream can sit in memory, split by the value of field 'group'.
|
1713 |
|
|
|
1754 |
return MultiStream.from_iterables(result)
|
1755 |
|
1756 |
|
1757 |
+
class AddIncrementalId(StreamOperator):
|
1758 |
+
|
1759 |
+
to_field: str
|
1760 |
+
|
1761 |
+
def process(self, stream: Stream, stream_name: Optional[str] = None) -> Generator:
|
1762 |
+
for i, instance in enumerate(stream):
|
1763 |
+
instance[self.to_field] = i
|
1764 |
+
yield instance
|
1765 |
+
|
1766 |
+
|
1767 |
class ApplyStreamOperatorsField(StreamOperator, ArtifactFetcherMixin):
|
1768 |
"""Applies stream operators to a stream based on specified fields in each instance.
|
1769 |
|
|
|
2519 |
|
2520 |
return {"title": page.title, "body": getattr(page, self.mode)}
|
2521 |
|
2522 |
+
|
2523 |
class Fillna(FieldOperator):
|
2524 |
value: Any
|
2525 |
+
|
2526 |
def process_value(self, value: Any) -> Any:
|
2527 |
import numpy as np
|
2528 |
+
|
2529 |
try:
|
2530 |
if np.isnan(value):
|
2531 |
return self.value
|
version.py
CHANGED
@@ -1 +1 @@
|
|
1 |
-
version = "1.23.
|
|
|
1 |
+
version = "1.23.1"
|