Upload folder using huggingface_hub

Changed files:
- README.md +81 -43
- evaluate_cli.py +4 -2
- hf_utils.py +22 -4
- inference.py +38 -262
- llm_as_judge_constants.py +7 -5
- loaders.py +246 -27
- metrics.py +165 -14
- operator.py +1 -2
- operators.py +38 -1
- settings_utils.py +1 -0
- struct_data_operators.py +67 -0
- utils.py +47 -0
- version.py +1 -1
README.md
CHANGED
@@ -8,23 +8,9 @@ app_file: README.md
 pinned: false
 ---
 <div align="center">
-<img src="https://
+<img src="https://www.unitxt.ai/en/latest/_static/banner.png" alt="Image Description" width="100%" />
 </div>

-[](https://unitxt.readthedocs.io/en/latest/_static/video.mov)
-[](https://unitxt.readthedocs.io/en/latest/docs/introduction.html)
-[](https://unitxt.readthedocs.io/en/latest/docs/demo.html)
-[](https://unitxt.readthedocs.io/en/latest/docs/adding_dataset.html)
-[](https://arxiv.org/abs/2401.14019)
-[](https://unitxt.readthedocs.io/en/latest/catalog/catalog.__dir__.html)
-[](https://github.com/IBM/unitxt/blob/main/CONTRIBUTING.md)
-[](https://pypi.org/project/unitxt/)
-
-
-In the dynamic landscape of generative NLP, traditional text processing pipelines limit research flexibility and reproducibility, as they are tailored to specific dataset, task, and model combinations. The escalating complexity, involving system prompts, model-specific formats, instructions, and more, calls for a shift to a structured, modular, and customizable solution.
-
-Addressing this need, we present Unitxt, an innovative library for customizable textual data preparation and evaluation tailored to generative language models. Unitxt natively integrates with common libraries like HuggingFace and LM-eval-harness and deconstructs processing flows into modular components, enabling easy customization and sharing between practitioners. These components encompass model-specific formats, task prompts, and many other comprehensive dataset processing definitions. The Unitxt-Catalog centralizes these components, fostering collaboration and exploration in modern textual data workflows. Beyond being a tool, Unitxt is a community-driven platform, empowering users to build, share, and advance their pipelines collaboratively.
-
 #
 [](https://pypi.org/project/unitxt/)
 
@@ -34,34 +20,93 @@ In the dynamic landscape of generative NLP, traditional text processing pipeline
 
 [](https://pepy.tech/project/unitxt)

+### 🦄 Unitxt is a Python library for enterprise-grade evaluation of AI performance, offering the world's largest catalog of tools and data for end-to-end AI benchmarking
+
 #

-
+## Why Unitxt?

-
+- 🌐 **Comprehensive**: Evaluate text, tables, vision, speech, and code in one unified framework
+- 💼 **Enterprise-Ready**: Battle-tested components with extensive catalog of benchmarks
+- 🧠 **Model Agnostic**: Works with HuggingFace, OpenAI, WatsonX, and custom models
+- 🔒 **Reproducible**: Shareable, modular components ensure consistent results

-
-
-
-
-
+## Quick Links
+- 📖 [Documentation](https://www.unitxt.ai)
+- 🚀 [Getting Started](https://www.unitxt.ai)
+- 📁 [Browse Catalog](https://www.unitxt.ai/en/latest/catalog/catalog.__dir__.html)

-
+# Installation

-
+```bash
+pip install unitxt
 ```
-
+
+# Quick Start
+
+## Command Line Evaluation
+```bash
+# Simple evaluation
+unitxt-evaluate \
+    --tasks "card=cards.mmlu_pro.engineering" \
+    --model cross_provider \
+    --model_args "model_name=llama-3-1-8b-instruct" \
+    --limit 10
+
+# Multi-task evaluation
+unitxt-evaluate \
+    --tasks "card=cards.text2sql.bird+card=cards.mmlu_pro.engineering" \
+    --model cross_provider \
+    --model_args "model_name=llama-3-1-8b-instruct,max_tokens=256" \
+    --split test \
+    --limit 10 \
+    --output_path ./results/evaluate_cli \
+    --log_samples \
+    --apply_chat_template
+
+# Benchmark evaluation
+unitxt-evaluate \
+    --tasks "benchmarks.tool_calling" \
+    --model cross_provider \
+    --model_args "model_name=llama-3-1-8b-instruct,max_tokens=256" \
+    --split test \
+    --limit 10 \
+    --output_path ./results/evaluate_cli \
+    --log_samples \
+    --apply_chat_template
 ```
-
+
+## Loading as Dataset
+Load thousands of datasets in chat API format, ready for any model:
+```python
+from unitxt import load_dataset
+
+dataset = load_dataset(
+    card="cards.gpqa.diamond",
+    split="test",
+    format="formats.chat_api",
+)
 ```
+
+## 📊 Available on The Catalog
+
+
+
+
+
+
+
+## 🚀 Interactive Dashboard
+
+Launch the graphical user interface to explore datasets and benchmarks:
+```
+pip install unitxt[ui]
 unitxt-explore
 ```

-#
+# Complete Python Example

-
-
-See more examples in examples subdirectory.
+Evaluate your own data with any model:

 ```python
 # Import required components
@@ -114,17 +159,13 @@ print("Global Results:\n", results.global_scores.summary)
 print("Instance Results:\n", results.instance_scores.summary)
 ```

-#
+# Contributing

-
-```bash
-git clone git@github.com:IBM/unitxt.git
-cd unitxt
-pip install -e ".[dev]"
-pre-commit install
-```
+Read the [contributing guide](./CONTRIBUTING.md) for details on how to contribute to Unitxt.

-#
+#
+
+# Citation

 If you use Unitxt in your research, please cite our paper:

@@ -153,8 +194,5 @@
     publisher = "Association for Computational Linguistics",
     url = "https://aclanthology.org/2024.naacl-demo.21",
     pages = "207--215",
-    abstract = "In the dynamic landscape of generative NLP, traditional text processing pipelines limit research flexibility and reproducibility, as they are tailored to specific dataset, task, and model combinations. The escalating complexity, involving system prompts, model-specific formats, instructions, and more, calls for a shift to a structured, modular, and customizable solution.Addressing this need, we present Unitxt, an innovative library for customizable textual data preparation and evaluation tailored to generative language models. Unitxt natively integrates with common libraries like HuggingFace and LM-eval-harness and deconstructs processing flows into modular components, enabling easy customization and sharing between practitioners. These components encompass model-specific formats, task prompts, and many other comprehensive dataset processing definitions. The Unitxt Catalog centralizes these components, fostering collaboration and exploration in modern textual data workflows. Beyond being a tool, Unitxt is a community-driven platform, empowering users to build, share, and advance their pipelines collaboratively. Join the Unitxt community at https://github.com/IBM/unitxt",
 }
-```
-
-Unitxt emoji designed by [OpenMoji](https://openmoji.org/#) - the open-source emoji and icon project. License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/#)
+```
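The new Quick Start sections above show the CLI and the `load_dataset` call separately; the sketch below strings them together in Python. It is only an illustration: the `CrossProviderInferenceEngine` import path, the `infer`/`evaluate` call signatures, and the provider name are assumptions, not something spelled out in this diff.

```python
# Illustrative sketch only - call signatures are assumed, not taken from this diff.
from unitxt import evaluate, load_dataset
from unitxt.inference import CrossProviderInferenceEngine

# Same card/format as the "Loading as Dataset" example in the new README
dataset = load_dataset(
    card="cards.gpqa.diamond",
    split="test",
    format="formats.chat_api",
)

# The "cross_provider" model from the CLI examples, expressed as an engine object
model = CrossProviderInferenceEngine(
    model="llama-3-1-8b-instruct", provider="watsonx"
)
predictions = model.infer(dataset)

results = evaluate(predictions=predictions, data=dataset)
print("Global Results:\n", results.global_scores.summary)
print("Instance Results:\n", results.instance_scores.summary)
```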
evaluate_cli.py
CHANGED
@@ -299,7 +299,9 @@ def cli_load_dataset(args: argparse.Namespace) -> HFDataset:
     )
 
     # this hack circumvents an issue with multi-level benchmarks (such Bluebench's translation subset) that fail when wrapped with an additional Benchmark() object.
-    if len(benchmark_subsets) == 1
+    if len(benchmark_subsets) == 1 and isinstance(
+        next(iter(benchmark_subsets.values())), Benchmark
+    ):
         source = next(iter(benchmark_subsets.values()))
     else:
         source = Benchmark(subsets=benchmark_subsets)
@@ -452,7 +454,7 @@ def initialize_inference_engine(
         )
 
         # Keep the actual model name for the results
-        args.model = inference_model.
+        args.model = inference_model.get_engine_id()
     else:
         # This case should not be reached due to argparse choices
         logger.error(
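A minimal sketch of the wrapping rule the guard above enforces, written against a hypothetical `benchmark_subsets` dict: a lone subset is reused directly only when it is already a `Benchmark` (as in Bluebench's translation subset), so plain subsets still get wrapped exactly once and nested `Benchmark(Benchmark(...))` objects are avoided.

```python
# Hypothetical illustration of the guard added in cli_load_dataset above.
def pick_source(benchmark_subsets, Benchmark):
    # Reuse a single subset as-is only if it is already a Benchmark,
    # avoiding a Benchmark nested inside another Benchmark.
    if len(benchmark_subsets) == 1 and isinstance(
        next(iter(benchmark_subsets.values())), Benchmark
    ):
        return next(iter(benchmark_subsets.values()))
    # Otherwise wrap the (one or more) plain subsets once.
    return Benchmark(subsets=benchmark_subsets)
```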
hf_utils.py
CHANGED
@@ -1,11 +1,30 @@
+import re
 from pathlib import Path
-
-from datasets.utils.py_utils import get_imports
+from typing import List
 
 from .deprecation_utils import compare_versions
 from .file_utils import get_all_files_in_dir
 
 
+def get_internal_imports(file_path: str) -> List[str]:
+    """Return a list of local (relative) modules directly imported in the given Python file."""
+    internal_imports = []
+    is_in_docstring = False
+    with open(file_path, encoding="utf-8") as f:
+        for line in f:
+            if line.count('"""') == 1 or line.count("'''") == 1:
+                is_in_docstring = not is_in_docstring
+            if is_in_docstring:
+                continue
+            # Match "import .module" or "from .module import ..."
+            match = re.match(r"^(?:import|from)\s+\.(\w+)", line)
+            if match:
+                module = match.group(1)
+                if module not in internal_imports:
+                    internal_imports.append(module)
+    return internal_imports
+
+
 def get_missing_imports(file, exclude=None):
     if exclude is None:
         exclude = []
@@ -13,8 +32,7 @@ def get_missing_imports(file, exclude=None):
     python_files = get_all_files_in_dir(src_dir, file_extension=".py")
     # get only the file without the path and extension
     required_modules = [Path(p).stem for p in python_files]
-
-    imported_modules = [i[1] for i in imports if i[0] == "internal"]
+    imported_modules = get_internal_imports(file)
     return [
         i for i in required_modules if i not in imported_modules and i not in exclude
     ]
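A short usage sketch for the helper introduced above, which replaces the dependency on `datasets.utils.py_utils.get_imports`; the module path and the example file name are assumptions for illustration.

```python
# Assumed import path; adjust to wherever hf_utils lives in your checkout.
from unitxt.hf_utils import get_internal_imports, get_missing_imports

# Relative imports such as "from .loaders import ..." are reported as ["loaders", ...]
internal = get_internal_imports("src/unitxt/metrics.py")

# Files in the same source dir that the given file does not import (minus exclusions)
unreferenced = get_missing_imports("src/unitxt/metrics.py", exclude=["__init__"])
print(internal, unreferenced)
```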
inference.py
CHANGED
@@ -1378,45 +1378,6 @@ class MockModeMixin(Artifact):
     mock_mode: bool = False
 
 
-class IbmGenAiInferenceEngineParamsMixin(Artifact):
-    beam_width: Optional[int] = None
-    decoding_method: Optional[Literal["greedy", "sample"]] = None
-    include_stop_sequence: Optional[bool] = None
-    length_penalty: Any = None
-    max_new_tokens: Optional[int] = None
-    min_new_tokens: Optional[int] = None
-    random_seed: Optional[int] = None
-    repetition_penalty: Optional[float] = None
-    return_options: Any = None
-    stop_sequences: Optional[List[str]] = None
-    temperature: Optional[float] = None
-    time_limit: Optional[int] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    truncate_input_tokens: Optional[int] = None
-    typical_p: Optional[float] = None
-
-
-@deprecation(version="2.0.0", alternative=IbmGenAiInferenceEngineParamsMixin)
-class IbmGenAiInferenceEngineParams(Artifact):
-    beam_width: Optional[int] = None
-    decoding_method: Optional[Literal["greedy", "sample"]] = None
-    include_stop_sequence: Optional[bool] = None
-    length_penalty: Any = None
-    max_new_tokens: Optional[int] = None
-    min_new_tokens: Optional[int] = None
-    random_seed: Optional[int] = None
-    repetition_penalty: Optional[float] = None
-    return_options: Any = None
-    stop_sequences: Optional[List[str]] = None
-    temperature: Optional[float] = None
-    time_limit: Optional[int] = None
-    top_k: Optional[int] = None
-    top_p: Optional[float] = None
-    truncate_input_tokens: Optional[int] = None
-    typical_p: Optional[float] = None
-
-
 class GenericInferenceEngine(
     InferenceEngine, ArtifactFetcherMixin, LogProbInferenceEngine
 ):
@@ -1430,7 +1391,7 @@ class GenericInferenceEngine(
                 "GenericInferenceEngine could not be initialized"
                 '\nThis is since both the "UNITXT_INFERENCE_ENGINE" environmental variable is not set and no default engine was not inputted.'
                 "\nFor example, you can fix it by setting"
-                "\nexport UNITXT_INFERENCE_ENGINE=engines.
+                "\nexport UNITXT_INFERENCE_ENGINE=engines.ibm_wml.llama_3_70b_instruct"
                 "\nto your ~/.bashrc"
                 "\nor passing a similar required engine in the default argument"
             )
@@ -1601,214 +1562,6 @@ class OptionSelectingByLogProbsInferenceEngine:
         return dataset
 
 
-class IbmGenAiInferenceEngine(
-    InferenceEngine,
-    IbmGenAiInferenceEngineParamsMixin,
-    PackageRequirementsMixin,
-    LogProbInferenceEngine,
-    OptionSelectingByLogProbsInferenceEngine,
-):
-    label: str = "ibm_genai"
-    model_name: str
-    _requirements_list = {
-        "ibm-generative-ai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai"
-    }
-    data_classification_policy = ["public", "proprietary"]
-    parameters: Optional[IbmGenAiInferenceEngineParams] = None
-    rate_limit: int = 10
-
-    def get_engine_id(self):
-        return get_model_and_label_id(self.model_name, self.label)
-
-    @staticmethod
-    def _get_credentials():
-        from genai import Credentials
-
-        api_key_env_var_name = "GENAI_KEY"  # pragma: allowlist secret
-        api_key = os.environ.get(api_key_env_var_name)
-
-        assert api_key is not None, (
-            f"Error while trying to run IbmGenAiInferenceEngine."
-            f" Please set the environment param '{api_key_env_var_name}'."
-        )
-
-        return Credentials(api_key=api_key)
-
-    def prepare_engine(self):
-        self.check_missing_requirements()
-
-        from genai import Client
-        from genai.text.generation import CreateExecutionOptions
-
-        credentials = self._get_credentials()
-        self.client = Client(credentials=credentials)
-
-        self.execution_options = CreateExecutionOptions(
-            concurrency_limit=self.rate_limit
-        )
-
-        self._set_inference_parameters()
-
-    def _infer(
-        self,
-        dataset: Union[List[Dict[str, Any]], Dataset],
-        return_meta_data: bool = False,
-    ) -> Union[List[str], List[TextGenerationInferenceOutput]]:
-        from genai.schema import TextGenerationParameters, TextGenerationResult
-
-        self.verify_not_chat_api(dataset)
-
-        genai_params = TextGenerationParameters(
-            **self.to_dict([IbmGenAiInferenceEngineParamsMixin])
-        )
-
-        responses = self.client.text.generation.create(
-            model_id=self.model_name,
-            inputs=[instance["source"] for instance in dataset],
-            parameters=genai_params,
-            execution_options=self.execution_options,
-        )
-
-        results = []
-        for response in responses:
-            generation_result: TextGenerationResult = response.results[0]
-            result = self.get_return_object(
-                generation_result.generated_text, generation_result, return_meta_data
-            )
-            results.append(result)
-        return results
-
-    def _infer_log_probs(
-        self,
-        dataset: Union[List[Dict[str, Any]], Dataset],
-        return_meta_data: bool = False,
-    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
-        from genai.schema import TextGenerationParameters, TextGenerationResult
-
-        self.verify_not_chat_api(dataset)
-
-        logprobs_return_options = {
-            "generated_tokens": True,
-            "input_text": False,
-            "input_tokens": False,
-            "token_logprobs": True,
-            "token_ranks": True,
-            "top_n_tokens": 5,
-        }
-        genai_params = self.to_dict(
-            [IbmGenAiInferenceEngineParamsMixin], keep_empty=False
-        )
-        genai_params = {**genai_params, "return_options": logprobs_return_options}
-        genai_params = TextGenerationParameters(**genai_params)
-        predictions = self.client.text.generation.create(
-            model_id=self.model_name,
-            inputs=[instance["source"] for instance in dataset],
-            parameters=genai_params,
-            execution_options=self.execution_options,
-        )
-
-        predict_results = []
-        for prediction in predictions:
-            result: TextGenerationResult = prediction.results[0]
-            assert isinstance(
-                result.generated_tokens, list
-            ), "result.generated_tokens should be a list"
-
-            predict_result = []
-            for base_token in result.generated_tokens:
-                res = {**base_token.__dict__, **base_token.model_extra}
-                res["top_tokens"] = [
-                    {"logprob": top_token.logprob, "text": top_token.text}
-                    for top_token in res["top_tokens"]
-                ]
-                predict_result.append(res)
-            final_results = self.get_return_object(
-                predict_result, result, return_meta_data
-            )
-            predict_results.append(final_results)
-        return predict_results
-
-    def get_return_object(self, predict_result, result, return_meta_data):
-        if return_meta_data:
-            return TextGenerationInferenceOutput(
-                prediction=predict_result,
-                input_tokens=result.input_token_count,
-                output_tokens=result.generated_token_count,
-                model_name=self.model_name,
-                inference_type=self.label,
-                input_text=result.input_text,
-                seed=self.random_seed,
-                stop_reason=result.stop_reason,
-            )
-        return predict_result
-
-    def get_model_details(self) -> Dict:
-        from genai import ApiClient
-        from genai.model import ModelService
-
-        api_client = ApiClient(credentials=self._get_credentials())
-        model_info = (
-            ModelService(api_client=api_client).retrieve(id=self.model_name).result
-        )
-        return model_info.dict()
-
-    def get_token_count(self, dataset):
-        texts = [instance["source"] for instance in dataset]
-        token_counts = list(
-            tqdm(
-                [
-                    result.token_count
-                    for response in self.client.text.tokenization.create(
-                        model_id=self.model_name,
-                        input=texts,
-                        execution_options={"ordered": True},
-                    )
-                    for result in response.results
-                ],
-                desc="Tokenizing",
-                total=len(texts),
-            )
-        )
-        for i, token_count in enumerate(token_counts):
-            dataset[i]["token_count"] = token_count
-        return dataset
-
-    def get_options_log_probs(self, dataset):
-        """Add to each instance in the data a "options_log_prob" field, which is a dict with str as key and a list of {text: str, logprob:float}."""
-        from genai.schema import TextGenerationParameters, TextGenerationReturnOptions
-
-        texts = [x["source"] for x in dataset]
-
-        responses = tqdm(
-            self.client.text.generation.create(
-                model_id=self.model_name,
-                inputs=texts,
-                execution_options={"ordered": True},
-                parameters=TextGenerationParameters(
-                    max_new_tokens=1,
-                    return_options=TextGenerationReturnOptions(
-                        input_tokens=True, token_logprobs=True
-                    ),
-                    # random_seed=self.random_state
-                ),
-            ),
-            total=len(texts),
-            desc="Completions",
-        )
-
-        scores = [
-            [
-                {"text": token.text, "logprob": token.logprob}
-                for token in response.results[0].input_tokens
-            ]
-            for response in responses
-        ]
-
-        for instance, score in zip(dataset, scores):
-            instance["prediction"] = score[instance["task_data"]["token_count"] - 1 :]
-        return dataset
-
-
 class CredentialsOpenAi(TypedDict, total=False):
     api_key: str
     api_url: str
@@ -2099,6 +1852,11 @@ class RITSInferenceEngine(
         "meta-llama/Llama-3.1-8B-Instruct": "llama-3-1-8b-instruct",
         "meta-llama/Llama-4-Scout-17B-16E-Instruct": "llama-4-scout-17b-16e-instruct",
         "mistralai/Mistral-Small-3.1-24B-Instruct-2503": "mistral-small-3-1-24b-2503",
+        "ibm-granite/granite-guardian-3.2-3b-a800m": "granite-guardian-3-2-3b-a800m",
+        "ibm-granite/granite-guardian-3.2-5b": "granite-guardian-3-2-5b-ris",
+        "granite-guardian-3-2-5b-ris": "granite-guardian-3-3-8b",
+        "openai/gpt-oss-20b": "gpt-oss-20b",
+        "openai/gpt-oss-120b": "gpt-oss-120b",
     }
 
     def get_default_headers(self):
@@ -3467,16 +3225,18 @@ _supported_apis = Literal[
     "open-ai",
     "aws",
     "ollama",
-    "bam",
     "watsonx-sdk",
     "rits",
     "azure",
     "vertex-ai",
     "replicate",
+    "hf-local",
 ]
 
 
-class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
+class CrossProviderInferenceEngine(
+    InferenceEngine, StandardAPIParamsMixin, LogProbInferenceEngine
+):
     """Inference engine capable of dynamically switching between multiple providers APIs.
 
     This class extends the InferenceEngine and OpenAiInferenceEngineParamsMixin
@@ -3516,7 +3276,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
         "granite-3-3-2b-instruct": "ibm/granite-3-3-2b-instruct",
         "granite-3-3-8b-instruct": "ibm/granite-3-3-8b-instruct",
         "granite-34b-code-instruct": "ibm/granite-34b-code-instruct",
+        "granite-guardian-3-2b": "ibm/granite-guardian-3-2b",
         "granite-guardian-3-8b": "ibm/granite-guardian-3-8b",
+        "granite-guardian-3-1-2b": "ibm/granite-guardian-3-2b",  # LifecycleWarning: Model 'ibm/granite-guardian-3-2b' is in deprecated state from 2025-07-09 until 2025-10-08. IDs of alternative models: ibm/granite-guardian-3-2-5b.
+        "granite-guardian-3-1-8b": "ibm/granite-guardian-3-8b",
+        "granite-guardian-3-2-5b": "ibm/granite-guardian-3-2-5b",
         "granite-vision-3-2-2b": "ibm/granite-vision-3-2-2b",
         "llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct",
         "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
@@ -3570,17 +3334,14 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
            "granite-3-3-2b-instruct": "granite3.3:2b",
            "granite-3-3-8b-instruct": "granite3.3:8b",
        },
-        "bam": {
-            "granite-3-8b-instruct": "ibm/granite-8b-instruct-preview-4k",
-            "llama-3-8b-instruct": "meta-llama/llama-3-8b-instruct",
-            "llama-3-2-1b-instruct": "meta-llama/llama-3-2-1b-instruct",
-            "flan-t5-xxl": "google/flan-t5-xxl",
-        },
        "rits": {
            "granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
            "granite-3-1-8b-instruct": "ibm-granite/granite-3.1-8b-instruct",
            "granite-3-2-8b-instruct": "ibm-granite/granite-3.2-8b-instruct",
            "granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
+            "granite-guardian-3-2-3b": "ibm-granite/granite-guardian-3.2-3b-a800m",
+            "granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
+            "granite-guardian-3-3-8b": "ibm-granite/granite-guardian-3.3-8b",
            "llama-3-1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct",
            "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
            "llama-3-1-405b-instruct": "meta-llama/llama-3-1-405b-instruct-fp8",
@@ -3595,9 +3356,9 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
            "mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
            "mixtral-8x7b-instruct-v01": "mistralai/mixtral-8x7B-instruct-v0.1",
            "deepseek-v3": "deepseek-ai/DeepSeek-V3",
-            "granite-guardian-3-2-3b-a800m": "ibm-granite/granite-guardian-3.2-3b-a800m",
-            "granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
            "phi-4": "microsoft/phi-4",
+            "gpt-oss-20b": "openai/gpt-oss-20b",
+            "gpt-oss-120b": "openai/gpt-oss-120b",
        },
        "open-ai": {
            "o1-mini": "o1-mini",
@@ -3699,9 +3460,16 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
            "gpt-4-1": "replicate/openai/gpt-4.1",
        },
        "hf-local": {
-            "granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
            "llama-3-3-8b-instruct": "meta-llama/Llama-3.3-8B-Instruct",
            "SmolLM2-1.7B-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+            "granite-guardian-3-1-2b": "ibm-granite/granite-guardian-3.1-2b",
+            "granite-guardian-3-1-8b": "ibm-granite/granite-guardian-3.1-8b",
+            "granite-guardian-3-2-3b": "ibm-granite/granite-guardian-3.2-3b-a800m",
+            "granite-guardian-3-2-5b": "ibm-granite/granite-guardian-3.2-5b",
+            "granite-guardian-3-3-8b": "ibm-granite/granite-guardian-3.3-8b",
+            "granite-3-3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
+            "granite-3-3-8b-instruct": "ibm-granite/granite-3.3-8b-instruct",
+            "granite-4-0-tiny-preview": "ibm-granite/granite-4.0-tiny-preview",
        },
    }
    provider_model_map["watsonx"] = {
@@ -3714,7 +3482,6 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
        "together-ai": LiteLLMInferenceEngine,
        "aws": LiteLLMInferenceEngine,
        "ollama": OllamaInferenceEngine,
-        "bam": IbmGenAiInferenceEngine,
        "watsonx-sdk": WMLInferenceEngineChat,
        "rits": RITSInferenceEngine,
        "azure": LiteLLMInferenceEngine,
@@ -3724,7 +3491,6 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
    }

    _provider_param_renaming = {
-        "bam": {"max_tokens": "max_new_tokens", "model": "model_name"},
        "watsonx-sdk": {"model": "model_name"},
        "rits": {"model": "model_name"},
        "hf-local": {"model": "model_name", "max_tokens": "max_new_tokens"},
@@ -3737,7 +3503,6 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
        return self.provider if self.provider is not None else settings.default_provider

    def prepare_engine(self):
-        # print("provider", self.provider)
        provider = self.get_provider_name()
        if provider not in self._provider_to_base_class:
            raise UnitxtError(
@@ -3783,6 +3548,17 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
            return get_model_and_label_id(self.provider_model_map[api][self.model], api)
        return get_model_and_label_id(self.model, api)

+    def _infer_log_probs(
+        self,
+        dataset: Union[List[Dict[str, Any]], Dataset],
+        return_meta_data: bool = False,
+    ) -> Union[List[Dict], List[TextGenerationInferenceOutput]]:
+        if not isinstance(self.engine, LogProbInferenceEngine):
+            raise UnitxtError(
+                f"The underlying inference engine of this instance of CrossProviderInferenceEngine ({self.engine.get_engine_id()}) must inherit from LogProbInferenceEngine and implement _infer_log_probs"
+            )
+        return self.engine._infer_log_probs(dataset, return_meta_data)
+

 class HFOptionSelectingInferenceEngine(InferenceEngine, TorchDeviceMixin):
     """HuggingFace based class for inference engines that calculate log probabilities.
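With the change above, `CrossProviderInferenceEngine` now declares `LogProbInferenceEngine` and simply forwards `_infer_log_probs` to the provider engine it wraps. A rough sketch of how that might be exercised follows; the public method name (`infer_log_probs`) and the chosen provider are assumptions, not shown in this diff.

```python
# Sketch only: assumes infer_log_probs() is the public wrapper around _infer_log_probs().
from unitxt.inference import CrossProviderInferenceEngine

engine = CrossProviderInferenceEngine(
    model="llama-3-1-8b-instruct", provider="rits"
)
dataset = [{"source": "The capital of France is"}]

# Per the new code, this raises UnitxtError if the wrapped provider engine
# does not itself inherit from LogProbInferenceEngine.
token_logprobs = engine.infer_log_probs(dataset)
```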
llm_as_judge_constants.py
CHANGED
@@ -32,7 +32,7 @@ class Criteria(Artifact):
     prediction_field: Optional[str] = None
     """The prediction field name this criteria expects and refers to, e.g. answer/model response/summary"""
 
-    context_fields: Union[str, List[str], Dict[str, str]] = None
+    context_fields: Optional[Union[str, List[str], Dict[str, str]]] = None
     """The context field names this criteria expects, i.e. [context]/[source article, user questions]"""
 
     @staticmethod
@@ -370,7 +370,7 @@ class DirectCriteriaCatalogEnum(Enum):
         name="conciseness",
         description="Is the response concise and to the point?",
         prediction_field="response",
-        context_fields=[],
+        context_fields=["question"],
         options=[
             CriteriaOption(
                 name="Yes",
@@ -1603,7 +1603,9 @@ Errors: Are there any errors in grammar, vocabulary, punctuation, or formatting
 )
 
 
-DIRECT_CRITERIA = [c.value for c in DirectCriteriaCatalogEnum]
+DIRECT_CRITERIA: List[CriteriaWithOptions] = [
+    c.value for c in DirectCriteriaCatalogEnum
+]
 
 
 class PairwiseCriteriaCatalogEnum(Enum):
@@ -1625,7 +1627,7 @@ class PairwiseCriteriaCatalogEnum(Enum):
         name="factually_consistent",
         description="A factually consistent response contains only statements that are entailed by the source document.",
         prediction_field="response",
-        context_fields=[],
+        context_fields=["source document"],
     )
 
     INCLUSIVITY = Criteria(
@@ -1658,4 +1660,4 @@ class PairwiseCriteriaCatalogEnum(Enum):
 )
 
 
-PAIRWISE_CRITERIA = [c.value for c in PairwiseCriteriaCatalogEnum]
+PAIRWISE_CRITERIA: List[Criteria] = [c.value for c in PairwiseCriteriaCatalogEnum]
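Since the catalogs above remain plain lists (now with explicit type annotations), individual criteria can be looked up by name. A small sketch, with the import path assumed from the file name:

```python
# Assumed module path for illustration.
from unitxt.llm_as_judge_constants import DIRECT_CRITERIA, PAIRWISE_CRITERIA

conciseness = next(c for c in DIRECT_CRITERIA if c.name == "conciseness")
print(conciseness.prediction_field)  # "response"
print(conciseness.context_fields)    # ["question"] after this change

factual = next(c for c in PAIRWISE_CRITERIA if c.name == "factually_consistent")
print(factual.context_fields)        # ["source document"]
```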
loaders.py
CHANGED
@@ -24,6 +24,7 @@ Available Loaders Overview:
|
|
24 |
- :class:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
|
25 |
- :class:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
|
26 |
- :class:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.
|
|
|
27 |
|
28 |
|
29 |
|
@@ -52,6 +53,7 @@ from typing import (
|
|
52 |
Union,
|
53 |
)
|
54 |
|
|
|
55 |
import pandas as pd
|
56 |
import requests
|
57 |
from datasets import (
|
@@ -62,6 +64,7 @@ from datasets import (
|
|
62 |
)
|
63 |
from datasets import load_dataset as _hf_load_dataset
|
64 |
from huggingface_hub import HfApi
|
|
|
65 |
from tqdm import tqdm
|
66 |
|
67 |
from .dataclass import NonPositionalField
|
@@ -96,21 +99,19 @@ def hf_load_dataset(path: str, *args, **kwargs):
|
|
96 |
):
|
97 |
if settings.hf_offline_datasets_path is not None:
|
98 |
path = os.path.join(settings.hf_offline_datasets_path, path)
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
raise UnitxtUnverifiedCodeError(path) from e
|
113 |
-
raise e # Re raise
|
114 |
|
115 |
|
116 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
@@ -119,13 +120,9 @@ def hf_get_dataset_splits(path: str, name: str, revision=None):
|
|
119 |
return get_dataset_split_names(
|
120 |
path=path,
|
121 |
config_name=name,
|
122 |
-
trust_remote_code=settings.allow_unverified_code,
|
123 |
revision=revision,
|
124 |
)
|
125 |
except Exception as e:
|
126 |
-
if "trust_remote_code" in str(e):
|
127 |
-
raise UnitxtUnverifiedCodeError(path) from e
|
128 |
-
|
129 |
if "Couldn't find cache" in str(e):
|
130 |
raise FileNotFoundError(
|
131 |
f"Dataset cache path={path}, name={name} was not found."
|
@@ -354,7 +351,7 @@ class LoadHF(LazyLoader):
|
|
354 |
raise NotImplementedError() from None
|
355 |
|
356 |
if not disable_memory_caching:
|
357 |
-
self.__class__._loader_cache.
|
358 |
self.__class__._loader_cache[dataset_id] = dataset
|
359 |
self._already_logged_limited_loading = True
|
360 |
|
@@ -476,7 +473,7 @@ class LoadWithPandas(LazyLoader):
|
|
476 |
|
477 |
dataset = dataframe.to_dict("records")
|
478 |
|
479 |
-
self.__class__._loader_cache.
|
480 |
self.__class__._loader_cache[dataset_id] = dataset
|
481 |
|
482 |
for instance in self.__class__._loader_cache[dataset_id]:
|
@@ -499,7 +496,7 @@ class LoadWithPandas(LazyLoader):
|
|
499 |
|
500 |
|
501 |
class LoadCSV(LoadWithPandas):
|
502 |
-
"""Loads data from CSV files.
|
503 |
|
504 |
Supports streaming and can handle large files by loading them in chunks.
|
505 |
|
@@ -510,6 +507,7 @@ class LoadCSV(LoadWithPandas):
|
|
510 |
streaming: Bool indicating if streaming should be used.
|
511 |
sep: String specifying the separator used in the CSV files.
|
512 |
indirect_read: Bool indicating if to open a remote file with urllib first
|
|
|
513 |
|
514 |
Example:
|
515 |
Loading csv
|
@@ -517,15 +515,30 @@ class LoadCSV(LoadWithPandas):
|
|
517 |
.. code-block:: python
|
518 |
|
519 |
load_csv = LoadCSV(files={'train': 'path/to/train.csv'}, chunksize=100)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
520 |
"""
|
521 |
|
522 |
sep: str = ","
|
|
|
523 |
|
524 |
def read_dataframe(self, file) -> pd.DataFrame:
|
525 |
with error_context(
|
526 |
stage="Raw Dataset Loading",
|
527 |
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
528 |
):
|
|
|
|
|
|
|
|
|
529 |
if self.indirect_read:
|
530 |
# Open the URL with urllib first to mitigate HTTP errors that sometime happen with the internal pandas implementation
|
531 |
from urllib import request
|
@@ -535,12 +548,10 @@ class LoadCSV(LoadWithPandas):
|
|
535 |
response,
|
536 |
sep=self.sep,
|
537 |
low_memory=self.streaming,
|
538 |
-
**
|
539 |
)
|
540 |
|
541 |
-
return pd.read_csv(
|
542 |
-
file, sep=self.sep, low_memory=self.streaming, **self.get_args()
|
543 |
-
)
|
544 |
|
545 |
|
546 |
def read_file(source) -> bytes:
|
@@ -668,7 +679,7 @@ class LoadFromSklearn(LazyLoader):
|
|
668 |
df = pd.DataFrame([split_data["data"], targets]).T
|
669 |
df.columns = ["data", "target"]
|
670 |
dataset = df.to_dict("records")
|
671 |
-
self.__class__._loader_cache.
|
672 |
self.__class__._loader_cache[dataset_id] = dataset
|
673 |
for instance in self.__class__._loader_cache[dataset_id]:
|
674 |
yield recursive_copy(instance)
|
@@ -1247,3 +1258,211 @@ class LoadFromAPI(Loader):
|
|
1247 |
self.__class__._loader_cache.max_size = settings.loader_cache_size
|
1248 |
self.__class__._loader_cache[str(self)] = iterables
|
1249 |
return MultiStream.from_iterables(iterables, copying=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
- :class:`MultipleSourceLoader <unitxt.loaders.MultipleSourceLoader>` - Combines data from multiple different sources.
|
25 |
- :class:`LoadFromDictionary <unitxt.loaders.LoadFromDictionary>` - Loads data from a user-defined Python dictionary.
|
26 |
- :class:`LoadFromHFSpace <unitxt.loaders.LoadFromHFSpace>` - Downloads and loads data from HuggingFace Spaces.
|
27 |
+
- :class:`LoadIOB <unitxt.loaders.LoadIOB>` - Loads data from IOB format files for named entity recognition tasks.
|
28 |
|
29 |
|
30 |
|
|
|
53 |
Union,
|
54 |
)
|
55 |
|
56 |
+
import datasets
|
57 |
import pandas as pd
|
58 |
import requests
|
59 |
from datasets import (
|
|
|
64 |
)
|
65 |
from datasets import load_dataset as _hf_load_dataset
|
66 |
from huggingface_hub import HfApi
|
67 |
+
from packaging.version import Version
|
68 |
from tqdm import tqdm
|
69 |
|
70 |
from .dataclass import NonPositionalField
|
|
|
99 |
):
|
100 |
if settings.hf_offline_datasets_path is not None:
|
101 |
path = os.path.join(settings.hf_offline_datasets_path, path)
|
102 |
+
|
103 |
+
if settings.disable_hf_datasets_cache:
|
104 |
+
kwargs["download_mode"] = "force_redownload"
|
105 |
+
|
106 |
+
if Version(datasets.__version__) < Version("4.0.0"):
|
107 |
+
kwargs["trust_remote_code"] = True
|
108 |
+
|
109 |
+
return _hf_load_dataset(
|
110 |
+
path,
|
111 |
+
*args,
|
112 |
+
**kwargs,
|
113 |
+
verification_mode="no_checks",
|
114 |
+
)
|
|
|
|
|
115 |
|
116 |
|
117 |
@retry_connection_with_exponential_backoff(backoff_factor=2)
|
|
|
120 |
return get_dataset_split_names(
|
121 |
path=path,
|
122 |
config_name=name,
|
|
|
123 |
revision=revision,
|
124 |
)
|
125 |
except Exception as e:
|
|
|
|
|
|
|
126 |
if "Couldn't find cache" in str(e):
|
127 |
raise FileNotFoundError(
|
128 |
f"Dataset cache path={path}, name={name} was not found."
|
|
|
351 |
raise NotImplementedError() from None
|
352 |
|
353 |
if not disable_memory_caching:
|
354 |
+
self.__class__._loader_cache._max_size = settings.loader_cache_size
|
355 |
self.__class__._loader_cache[dataset_id] = dataset
|
356 |
self._already_logged_limited_loading = True
|
357 |
|
|
|
473 |
|
474 |
dataset = dataframe.to_dict("records")
|
475 |
|
476 |
+
self.__class__._loader_cache._max_size = settings.loader_cache_size
|
477 |
self.__class__._loader_cache[dataset_id] = dataset
|
478 |
|
479 |
for instance in self.__class__._loader_cache[dataset_id]:
|
|
|
496 |
|
497 |
|
498 |
class LoadCSV(LoadWithPandas):
|
499 |
+
r"""Loads data from CSV files.
|
500 |
|
501 |
Supports streaming and can handle large files by loading them in chunks.
|
502 |
|
|
|
507 |
streaming: Bool indicating if streaming should be used.
|
508 |
sep: String specifying the separator used in the CSV files.
|
509 |
indirect_read: Bool indicating if to open a remote file with urllib first
|
510 |
+
column_names: Optional list of column names to use instead of header row.
|
511 |
|
512 |
Example:
|
513 |
Loading csv
|
|
|
515 |
.. code-block:: python
|
516 |
|
517 |
load_csv = LoadCSV(files={'train': 'path/to/train.csv'}, chunksize=100)
|
518 |
+
|
519 |
+
Loading TSV with custom column names
|
520 |
+
|
521 |
+
.. code-block:: python
|
522 |
+
|
523 |
+
load_csv = LoadCSV(
|
524 |
+
files={'train': 'path/to/train.tsv'},
|
525 |
+
sep='\t',
|
526 |
+
column_names=['id', 'question', 'table_name', 'answer']
|
527 |
+
)
|
528 |
"""
|
529 |
|
530 |
sep: str = ","
|
531 |
+
column_names: Optional[List[str]] = None
|
532 |
|
533 |
def read_dataframe(self, file) -> pd.DataFrame:
|
534 |
with error_context(
|
535 |
stage="Raw Dataset Loading",
|
536 |
help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
|
537 |
):
|
538 |
+
args = self.get_args()
|
539 |
+
if self.column_names is not None:
|
540 |
+
args["names"] = self.column_names
|
541 |
+
args["header"] = None # Don't use first row as header
|
542 |
if self.indirect_read:
|
543 |
# Open the URL with urllib first to mitigate HTTP errors that sometime happen with the internal pandas implementation
|
544 |
from urllib import request
|
|
|
548 |
response,
|
549 |
sep=self.sep,
|
550 |
low_memory=self.streaming,
|
551 |
+
**args,
|
552 |
)
|
553 |
|
554 |
+
return pd.read_csv(file, sep=self.sep, low_memory=self.streaming, **args)
|
|
|
|
|
555 |
|
556 |
|
557 |
def read_file(source) -> bytes:
|
|
|
679 |
df = pd.DataFrame([split_data["data"], targets]).T
|
680 |
df.columns = ["data", "target"]
|
681 |
dataset = df.to_dict("records")
|
682 |
+
self.__class__._loader_cache._max_size = settings.loader_cache_size
|
683 |
self.__class__._loader_cache[dataset_id] = dataset
|
684 |
for instance in self.__class__._loader_cache[dataset_id]:
|
685 |
yield recursive_copy(instance)
|
|
|
1258 |
self.__class__._loader_cache.max_size = settings.loader_cache_size
|
1259 |
self.__class__._loader_cache[str(self)] = iterables
|
1260 |
return MultiStream.from_iterables(iterables, copying=True)
|
1261 |
+
|
1262 |
+
|
1263 |
+
class LoadIOB(LazyLoader):
|
1264 |
+
"""Loads data from IOB format files.
|
1265 |
+
|
1266 |
+
This loader can parse IOB (Inside-Outside-Begin) format files commonly used for
|
1267 |
+
named entity recognition tasks. It supports both local files and remote URLs,
|
1268 |
+
and can handle various IOB formats including CoNLL-U style files.
|
1269 |
+
|
1270 |
+
Args:
|
1271 |
+
files (Dict[str, str]):
|
1272 |
+
A dictionary mapping split names to file paths or URLs.
|
1273 |
+
column_names (tuple, optional):
|
1274 |
+
Column names for the IOB format. Defaults to ('id', 'token', 'tag', 'misc', 'annotator').
|
1275 |
+
fix_tags (bool, optional):
|
1276 |
+
Whether to apply tag fixing for OTH and B-O tags. Defaults to True.
|
1277 |
+
encoding (str, optional):
|
1278 |
+
File encoding. Defaults to 'utf-8'.
|
1279 |
+
|
1280 |
+
Example:
|
1281 |
+
Loading IOB files
|
1282 |
+
|
1283 |
+
.. code-block:: python
|
1284 |
+
|
1285 |
+
load_iob = LoadIOB(files={'train': 'path/to/train.iob2', 'test': 'path/to/test.iob2'})
|
1286 |
+
"""
|
1287 |
+
|
1288 |
+
files: Dict[str, str]
|
1289 |
+
column_names: tuple = ("id", "token", "tag", "misc", "annotator")
|
1290 |
+
fix_tags: bool = True
|
1291 |
+
encoding: str = "utf-8"
|
1292 |
+
|
1293 |
+
_requirements_list: List[str] = ["conllu"]
|
1294 |
+
|
1295 |
+
def _maybe_set_classification_policy(self):
|
1296 |
+
self.set_default_data_classification(
|
1297 |
+
["proprietary"], "when loading from local files"
|
1298 |
+
)
|
1299 |
+
|
1300 |
+
def get_splits(self) -> List[str]:
|
1301 |
+
return list(self.files.keys())
|
1302 |
+
|
1303 |
+
def split_generator(self, split: str) -> Generator:
|
1304 |
+
import conllu
|
1305 |
+
|
1306 |
+
dataset_id = str(self) + "_" + split
|
1307 |
+
dataset = self.__class__._loader_cache.get(dataset_id, None)
|
1308 |
+
|
1309 |
+
if dataset is None:
|
1310 |
+
if self.get_limit() is not None:
|
1311 |
+
self.log_limited_loading()
|
1312 |
+
|
1313 |
+
file_path = self.files[split]
|
1314 |
+
dataset = []
|
1315 |
+
id_counter = 0
|
1316 |
+
|
1317 |
+
try:
|
1318 |
+
# Handle remote URLs
|
1319 |
+
if file_path.startswith(("http://", "https://")):
|
1320 |
+
import io
|
1321 |
+
import urllib.request
|
1322 |
+
|
1323 |
+
with urllib.request.urlopen(file_path) as response:
|
1324 |
+
content = response.read().decode(self.encoding)
|
1325 |
+
# Use StringIO to create a file-like object
|
1326 |
+
content_file = io.StringIO(content)
|
1327 |
+
sentences = list(
|
1328 |
+
conllu.parse_incr(content_file, fields=self.column_names)
|
1329 |
+
)
|
1330 |
+
else:
|
1331 |
+
# Handle local files
|
1332 |
+
with open(file_path, encoding=self.encoding) as data_file:
|
1333 |
+
sentences = list(
|
1334 |
+
conllu.parse_incr(data_file, fields=self.column_names)
|
1335 |
+
)
|
1336 |
+
|
1337 |
+
limit = self.get_limit()
|
1338 |
+
processed_count = 0
|
1339 |
+
|
1340 |
+
for sent in sentences:
|
1341 |
+
if limit is not None and processed_count >= limit:
|
1342 |
+
break
|
1343 |
+
|
1344 |
+
# Get sentence ID
|
1345 |
+
if "sent_id" in sent.metadata:
|
1346 |
+
idx = sent.metadata["sent_id"]
|
1347 |
+
else:
|
1348 |
+
idx = id_counter
|
1349 |
+
|
1350 |
+
# Extract tokens and tags
|
1351 |
+
tokens = [token["token"] for token in sent]
|
1352 |
+
actual_tags = [token["tag"] for token in sent]
|
1353 |
+
|
1354 |
+
# Apply tag fixing if enabled
|
1355 |
+
+                if self.fix_tags:
+                    fixed_tags = []
+                    for actual_tag in actual_tags:
+                        if "OTH" in actual_tag or actual_tag == "B-O":
+                            actual_tag = "O"
+                        fixed_tags.append(actual_tag)
+                else:
+                    fixed_tags = actual_tags
+
+                # Extract annotator info if available
+                annotator = []
+                for token in sent:
+                    if "annotator" in token and token["annotator"] is not None:
+                        annotator.append(token["annotator"])
+                    else:
+                        annotator.append("")
+
+                # Get text from metadata or reconstruct from tokens
+                if "text" in sent.metadata:
+                    text = sent.metadata["text"]
+                else:
+                    text = " ".join(tokens)
+
+                instance = {
+                    "idx": str(idx),
+                    "text": text,
+                    "tokens": tokens,
+                    "ner_tags": fixed_tags,
+                    "annotator": annotator,
+                }
+
+                dataset.append(instance)
+                processed_count += 1
+                id_counter += 1
+
+        except Exception as e:
+            with error_context(
+                stage="Raw Dataset Loading",
+                help="https://www.unitxt.ai/en/latest/unitxt.loaders.html#module-unitxt.loaders",
+            ):
+                raise UnitxtError(
+                    f"Failed to load IOB file {file_path}: {e!s}"
+                ) from e
+
+        # Cache the dataset
+        self.__class__._loader_cache.max_size = settings.loader_cache_size
+        self.__class__._loader_cache[dataset_id] = dataset
+
+        # Yield instances from cached dataset
+        for instance in dataset:
+            yield recursive_copy(instance)
+
+
+class TURLColumnTypeAnnotationLoader(LazyLoader):
+    data_classification_policy = ["public"]
+    _requirements_list = ["huggingface_hub"]
+
+    def prepare(self):
+        super().prepare()
+        from huggingface_hub import hf_hub_download
+
+        self._download = hf_hub_download
+
+    def get_splits(self) -> List[str]:
+        return ["train", "validation", "test"]
+
+    @staticmethod
+    def _load_table(table_data):
+        headers = table_data[5]
+        cols = table_data[6]
+        if not cols:
+            return {"header": headers, "rows": []}
+        row_count = max(x[-1][0][0] for x in cols)
+        rows = []
+        for i in range(row_count):
+            row = []
+            for col in cols:
+                cell = next((c[1][1] for c in col if c[0][0] == i), "")
+                row.append(cell)
+            if any(row):
+                rows.append(row)
+        return {"header": headers, "rows": rows}
+
+    def split_generator(self, split: str) -> Generator[Dict[str, Any], None, None]:
+        dataset_id = str(self) + "_" + split
+        dataset = self.__class__._loader_cache.get(dataset_id, None)
+        if split == "validation":
+            split = "dev"
+        if dataset is None:
+            file_path = self._download(
+                "stanford-crfm/helm-scenarios",
+                filename=f"turl-column-type-annotation/{split}.table_col_type.json",
+                repo_type="dataset",
+                revision="main",
+            )
+            with open(file_path, encoding="utf-8") as f:
+                data = json.load(f)
+            dataset = []
+            for table_data in data:
+                table_content = self._load_table(table_data)
+                for idx, colname in enumerate(table_data[5]):
+                    instance = {
+                        "page_title": table_data[1],
+                        "section_title": table_data[3],
+                        "table_caption": table_data[4],
+                        "table": table_content,
+                        "colname": colname,
+                        "annotations": table_data[7][idx],
+                    }
+                    dataset.append(instance)
+            self.__class__._loader_cache[dataset_id] = dataset
+
+        for instance in self.__class__._loader_cache[dataset_id]:
+            yield instance
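A minimal usage sketch for the new loader; it assumes the class can be instantiated with defaults, and running it requires network access plus the `huggingface_hub` dependency listed above:

```python
# Minimal sketch, assuming default construction is sufficient for this loader.
from unitxt.loaders import TURLColumnTypeAnnotationLoader

loader = TURLColumnTypeAnnotationLoader()
loader.prepare()  # wires up hf_hub_download, as in the diff above

# split_generator yields one instance per annotated column of each table.
for instance in loader.split_generator("test"):
    print(instance["page_title"], "|", instance["colname"], "->", instance["annotations"])
    break
```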
metrics.py
CHANGED
@@ -8,7 +8,7 @@ import uuid
 import warnings
 from abc import ABC, abstractmethod
 from collections import Counter, defaultdict
-from dataclasses import asdict, field
+from dataclasses import asdict, dataclass, field
 from dataclasses import fields as dataclasses_fields
 from enum import Enum
 from functools import lru_cache
@@ -21,6 +21,7 @@ from typing import (
     Literal,
     Optional,
     Tuple,
+    Type,
     TypeVar,
     Union,
 )
@@ -6160,23 +6161,174 @@ For MacOS: If error on 'mecab-config' show up during installation ], one should
 """


-    Range: [0, 1] (higher is better)
-    Character-level tokenization of BLEU score for improved cross-lingual evaluation.
+@dataclass
+class SacreBleuStats:
+    counts: List[int]
+    totals: List[int]
+    sys_len: int
+    ref_len: int
+
+
+class NormalizedSacrebleu(
+    MapReduceMetric[str, SacreBleuStats], PackageRequirementsMixin
+):
+    """SacreBLEU metric implementation using MapReduceMetric pattern.
+
+    This implementation uses the official sacrebleu library for tokenization
+    and BLEU computation, while supporting the map-reduce pattern for proper
+    corpus-level evaluation that matches the behavior of the HuggingFace version.
+
+    Range: [0, 1] (higher is better)
+    Reference: Post, M. 2018. A Call for Clarity in Reporting BLEU Scores.
     """

-    hf_metric_name = "sacrebleu"
-    hf_main_score = "score"
-    prediction_type = str
     main_score = "sacrebleu"
-    hf_additional_input_fields_pass_one_value = ["tokenize"]
+    ci_score_names = ["sacrebleu"]
+    prediction_type = str
     _requirements_list = ["sacrebleu"]
+    language_to_tokenizer: Optional[Dict[str, str]] = None
+    # Configuration parameters matching sacrebleu API
+    tokenize: str = None
+    lowercase: bool = False
+    force: bool = False
+    smooth_method: str = "exp"
+    smooth_value: Optional[float] = None
+    use_effective_order: bool = True  # Recommended by sacrebleu for sentence-level BLEU
+    max_ngram_order: int = 4
+
+    def prepare(self):
+        super().prepare()
+        from sacrebleu.metrics.bleu import BLEU
+
+        self.bleu_metric = BLEU(
+            lowercase=self.lowercase,
+            force=self.force,
+            tokenize=self.tokenize,
+            smooth_method=self.smooth_method,
+            smooth_value=self.smooth_value,
+            max_ngram_order=self.max_ngram_order,
+            effective_order=self.use_effective_order,
+        )
+
+    def _get_tokenizer_for_language(self, language: str) -> str:
+        """Get appropriate tokenizer for a given language."""
+        if self.language_to_tokenizer is None:
+            raise ValueError("Please set language_to_tokenizer.")
+        if language.lower() not in self.language_to_tokenizer:
+            raise ValueError(
+                f"Language {language} is not in language_to_tokenizer please add it."
+            )
+
+        return self.language_to_tokenizer.get(language.lower())
+
+    @staticmethod
+    @lru_cache(maxsize=10000)
+    def get_bleu_metric(
+        lowercase: bool = False,
+        force: bool = False,
+        tokenize: Optional[str] = None,
+        smooth_method: str = "exp",
+        smooth_value: Optional[float] = None,
+        max_ngram_order: int = 4,
+        effective_order: bool = False,
+    ):
+        from sacrebleu.metrics.bleu import BLEU
+
+        return BLEU(
+            lowercase=lowercase,
+            force=force,
+            tokenize=tokenize,
+            smooth_method=smooth_method,
+            smooth_value=smooth_value,
+            max_ngram_order=max_ngram_order,
+            effective_order=effective_order,
+        )
+
+    def map(
+        self,
+        prediction: str,
+        references: List[str],
+        task_data: Dict[str, Any],
+    ) -> SacreBleuStats:
+        """Map function: compute BLEU statistics for a single instance using sacrebleu."""
+        if self.tokenize is None and "target_language" in task_data:
+            target_lang = task_data["target_language"]
+            tokenize_method = self._get_tokenizer_for_language(target_lang)
+        else:
+            tokenize_method = self.tokenize
+
+        instance_bleu_metric = self.get_bleu_metric(
+            lowercase=self.lowercase,
+            force=self.force,
+            tokenize=tokenize_method,
+            smooth_method=self.smooth_method,
+            smooth_value=self.smooth_value,
+            max_ngram_order=self.max_ngram_order,
+            effective_order=self.use_effective_order,
+        )
+
+        # Use the instance-specific metric to get per-instance statistics
+        bleu_result = instance_bleu_metric.sentence_score(prediction, references)
+
+        return SacreBleuStats(
+            counts=bleu_result.counts,
+            totals=bleu_result.totals,
+            sys_len=bleu_result.sys_len,
+            ref_len=bleu_result.ref_len,
+        )
+
+    def reduce(self, intermediates: List[SacreBleuStats]) -> Dict[str, Any]:
+        """Reduce function: aggregate statistics and compute corpus BLEU using sacrebleu."""
+        if not intermediates:
+            return {
+                "sacrebleu": 0.0,
+                "counts": [0, 0, 0, 0],
+                "totals": [0, 0, 0, 0],
+                "precisions": [0.0, 0.0, 0.0, 0.0],
+                "bp": 0.0,
+                "sys_len": 0,
+                "ref_len": 0,
+            }
+
+        # Aggregate all the statistics across instances
+        total_counts = [0] * self.max_ngram_order
+        total_totals = [0] * self.max_ngram_order
+        total_sys_len = 0
+        total_ref_len = 0
+
+        for stats in intermediates:
+            for i in range(min(len(stats.counts), self.max_ngram_order)):
+                total_counts[i] += stats.counts[i]
+                total_totals[i] += stats.totals[i]
+            total_sys_len += stats.sys_len
+            total_ref_len += stats.ref_len
+
+        # Use sacrebleu's compute_bleu static method to compute the final score from aggregated stats
+        # This is the proper way to get corpus-level BLEU from individual statistics
+        bleu_result = self.bleu_metric.compute_bleu(
+            correct=total_counts,
+            total=total_totals,
+            sys_len=total_sys_len,
+            ref_len=total_ref_len,
+            smooth_method=self.smooth_method,
+            smooth_value=self.smooth_value,
+            effective_order=self.use_effective_order,
+            max_ngram_order=self.max_ngram_order,
+        )
+
+        return {
+            "sacrebleu": round(
+                bleu_result.score / 100.0, 2
+            ),  # Convert from 0-100 to 0-1 scale
+            "counts": total_counts,
+            "totals": total_totals,
+            "precisions": [
+                round(p / 100.0, 2) for p in bleu_result.precisions
+            ],  # Convert from 0-100 to 0-1 scale
+            "bp": round(bleu_result.bp, 2),
+            "sys_len": total_sys_len,
+            "ref_len": total_ref_len,
+        }


 class CustomF1Fuzzy(CustomF1):
@@ -6599,7 +6751,6 @@ class GraniteGuardianBase(InstanceMetric):
     """Return metric for different kinds of "risk" from the Granite-3.0 Guardian model."""

     reduction_map: Dict[str, List[str]] = None
-    prediction_type = float
     main_score = None
     reduction_map = {}
     wml_model_name: str = "ibm/granite-guardian-3-8b"
@@ -6936,7 +7087,7 @@ class GraniteGuardianCustomRisk(GraniteGuardianBase):
         return messages


-RISK_TYPE_TO_CLASS: Dict[RiskType, GraniteGuardianBase] = {
+RISK_TYPE_TO_CLASS: Dict[RiskType, Type[GraniteGuardianBase]] = {
    RiskType.USER_MESSAGE: GraniteGuardianUserRisk,
    RiskType.ASSISTANT_MESSAGE: GraniteGuardianAssistantRisk,
    RiskType.RAG: GraniteGuardianRagRisk,
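The new `NormalizedSacrebleu` replaces the HuggingFace-backed attributes with a map/reduce computation over sacrebleu's own statistics. A minimal sketch of that flow, using only calls the class itself makes (`sentence_score`, `compute_bleu`); the sentences are illustrative:

```python
from sacrebleu.metrics.bleu import BLEU

bleu = BLEU(tokenize="13a", effective_order=True)
pairs = [
    ("the cat sat on the mat", ["the cat sat on the mat"]),  # illustrative data
    ("hello world", ["hello there world"]),
]

# map: per-instance n-gram statistics
stats = [bleu.sentence_score(pred, refs) for pred, refs in pairs]

# reduce: sum counts, totals and lengths, then recompute corpus-level BLEU
counts = [sum(s.counts[i] for s in stats) for i in range(4)]
totals = [sum(s.totals[i] for s in stats) for i in range(4)]
sys_len = sum(s.sys_len for s in stats)
ref_len = sum(s.ref_len for s in stats)

corpus = bleu.compute_bleu(
    correct=counts, total=totals, sys_len=sys_len, ref_len=ref_len,
    smooth_method="exp", effective_order=True,
)
print(round(corpus.score / 100.0, 2))  # scaled to [0, 1], as in reduce() above
```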
operator.py
CHANGED
@@ -2,13 +2,12 @@ from abc import abstractmethod
 from dataclasses import field
 from typing import Any, Dict, Generator, List, Optional, Union

-from pkg_resources import DistributionNotFound, VersionConflict, require
-
 from .artifact import Artifact
 from .dataclass import FinalField, InternalField, NonPositionalField
 from .error_utils import error_context
 from .settings_utils import get_constants
 from .stream import DynamicStream, EmptyStreamError, MultiStream, Stream
+from .utils import DistributionNotFound, VersionConflict, require

 constants = get_constants()

operators.py
CHANGED
@@ -218,7 +218,7 @@ class MapInstanceValues(InstanceOperator):
         if val_as_str in mapper:
             return recursive_copy(mapper[val_as_str])
         if self.strict:
-            raise
+            raise ValueError(
                 f"value '{val_as_str}', the string representation of the value in field '{key}', is not found in mapper '{mapper}'"
             )
         return val
@@ -2574,3 +2574,40 @@ class Fillna(FieldOperator):
         except TypeError:
             return value
         return value
+
+
+class ReadFile(FieldOperator):
+    """Reads file content from local path or URL.
+
+    This operator can read files from local filesystem paths or remote URLs.
+    The content is returned as a string.
+
+    Args:
+        encoding (str): Text encoding to use when reading the file. Defaults to 'utf-8'.
+
+    Example:
+        Reading a local file
+
+        .. code-block:: python
+
+            ReadFile(field="file_path", to_field="content")
+
+        Reading from URL
+
+        .. code-block:: python
+
+            ReadFile(field="url", to_field="content")
+    """
+
+    encoding: str = "utf-8"
+
+    def process_value(self, value: str) -> str:
+        """Read file content from local path or URL."""
+        if value.startswith(("http://", "https://")):
+            # Read from URL
+            response = requests.get(value)
+            response.raise_for_status()
+            return response.content.decode(self.encoding, errors="replace")
+        # Read from local file
+        with open(value, encoding=self.encoding) as f:
+            return f.read()
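A minimal usage sketch for the new `ReadFile` operator; the path is illustrative, and an `http(s)://` URL works the same way, as its docstring shows:

```python
# Minimal sketch; "data/example.txt" is a hypothetical local file.
from unitxt.operators import ReadFile

reader = ReadFile(field="file_path", to_field="content")
text = reader.process_value("data/example.txt")
print(text[:100])
```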
settings_utils.py
CHANGED
@@ -224,6 +224,7 @@ if Settings.is_uninitilized():
     settings.hf_offline_models_path = None
     settings.inference_engine_cache_path = "./inference_engine_cache/"
     settings.max_connection_retries = 3
+    settings.max_templates_tests_for_card_test = 10
     settings.dataset_cache_default = (bool, False)

 if Constants.is_uninitilized():
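The new `max_templates_tests_for_card_test` setting appears to cap how many templates are exercised when a card is tested. A minimal sketch of overriding it programmatically, assuming the usual `get_settings()` accessor:

```python
# Minimal sketch, assuming the standard unitxt settings accessor.
from unitxt.settings_utils import get_settings

settings = get_settings()
settings.max_templates_tests_for_card_test = 3  # test fewer templates per card
```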
struct_data_operators.py
CHANGED
@@ -24,6 +24,8 @@ For key-value pairs, expected input format is:
 """

 import ast
+import csv
+import io
 import json
 import random
 from abc import ABC, abstractmethod
@@ -31,6 +33,7 @@ from typing import (
     Any,
     Dict,
     List,
+    Literal,
     Optional,
     Tuple,
 )
@@ -1118,3 +1121,67 @@ class JsonStrToDict(FieldOperator):
             )
             dict_value = {}
         return {str(k): str(v) for k, v in dict_value.items() if v is not None}
+
+
+class ParseCSV(FieldOperator):
+    r"""Parse CSV/TSV text content into table format.
+
+    This operator converts CSV or TSV text content into the standard table format
+    used by Unitxt with header and rows fields.
+
+    Args:
+        separator (str): Field separator character. Defaults to ','.
+        has_header (bool): Whether the first row contains column headers. Defaults to True.
+        skip_header (bool): Whether to skip the first row entirely. Defaults to False.
+
+    Example:
+        Parsing CSV content
+
+        .. code-block:: python
+
+            ParseCSV(field="csv_content", to_field="table", separator=",")
+
+        Parsing TSV content
+
+        .. code-block:: python
+
+            ParseCSV(field="tsv_content", to_field="table", separator="\t")
+    """
+
+    separator: str = ","
+    has_header: bool = True
+    skip_header: bool = False
+    dtype: Optional[Literal["str"]] = None
+    strip_cells: bool = False
+
+    def process_value(self, value: str) -> Dict[str, Any]:
+        csv_reader = csv.reader(
+            io.StringIO(value), delimiter=self.separator, quotechar='"'
+        )
+        rows = []
+        header = []
+        for idx, row in enumerate(csv_reader):
+            if idx == 0 and self.has_header:
+                header = row
+                if self.skip_header:
+                    continue
+            else:
+                rows.append(row)
+
+        if not self.has_header or self.skip_header:
+            header = [f"col_{i}" for i in range(len(rows[0]))]
+
+        if self.strip_cells:
+
+            def clean_cell(x):
+                if isinstance(x, str):
+                    return x.replace("\n", " ").strip()
+                return x
+
+            rows = [[clean_cell(cell) for cell in row] for row in rows]
+            header = [clean_cell(h) for h in header]
+
+        return {
+            "header": header,
+            "rows": rows,
+        }
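A minimal usage sketch for `ParseCSV`, showing the `header`/`rows` structure it emits (the CSV text is illustrative):

```python
from unitxt.struct_data_operators import ParseCSV

parser = ParseCSV(field="csv_content", to_field="table", strip_cells=True)
csv_text = "name,age\nAlice, 30\nBob, 25\n"  # illustrative data
print(parser.process_value(csv_text))
# {'header': ['name', 'age'], 'rows': [['Alice', '30'], ['Bob', '25']]}
```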
utils.py
CHANGED
@@ -9,9 +9,13 @@ import time
 from collections import OrderedDict
 from contextvars import ContextVar
 from functools import wraps
+from importlib.metadata import PackageNotFoundError
+from importlib.metadata import version as get_installed_version
 from typing import Any, Dict, Optional
 from urllib.error import HTTPError as UrllibHTTPError

+from packaging.requirements import Requirement
+from packaging.version import Version
 from requests.exceptions import ConnectionError, HTTPError
 from requests.exceptions import Timeout as TimeoutError

@@ -422,3 +426,46 @@ class LongString(str):
         if self._repr_str is not None:
             return self._repr_str
         return super().__repr__()
+
+
+class DistributionNotFound(Exception):
+    def __init__(self, requirement):
+        self.requirement = requirement
+        super().__init__(f"Distribution not found for requirement: {requirement}")
+
+
+class VersionConflict(Exception):
+    def __init__(self, dist, req):
+        self.dist = dist  # Distribution object, just emulate enough for your needs
+        self.req = req
+        super().__init__(f"Version conflict: {dist} does not satisfy {req}")
+
+
+class DistStub:
+    # Minimal stub to mimic pkg_resources.Distribution
+    def __init__(self, project_name, version):
+        self.project_name = project_name
+        self.version = version
+
+
+def require(requirements):
+    """Minimal drop-in replacement for pkg_resources.require.
+
+    Accepts a single requirement string or a list of them.
+    Raises DistributionNotFound or VersionConflict.
+    Returns nothing (side-effect only).
+    """
+    if isinstance(requirements, str):
+        requirements = [requirements]
+    for req_str in requirements:
+        req = Requirement(req_str)
+        if req.marker and not req.marker.evaluate():
+            continue  # skip not needed for this environment
+        name = req.name
+        try:
+            ver = get_installed_version(name)
+        except PackageNotFoundError as e:
+            raise DistributionNotFound(req_str) from e
+        if req.specifier and not req.specifier.contains(Version(ver), prereleases=True):
+            dist = DistStub(name, ver)
+            raise VersionConflict(dist, req_str)
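These helpers back the import change in operator.py above, replacing the deprecated `pkg_resources` API with `importlib.metadata` and `packaging`. A minimal sketch of how `require` behaves (package names are illustrative):

```python
from unitxt.utils import DistributionNotFound, VersionConflict, require

try:
    # "some-package-that-is-not-installed" is a hypothetical name.
    require(["requests>=2.0", "some-package-that-is-not-installed"])
except DistributionNotFound as err:
    print(err)  # Distribution not found for requirement: some-package-that-is-not-installed
except VersionConflict as err:
    print(err.dist.project_name, err.dist.version, "does not satisfy", err.req)
```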
version.py
CHANGED
@@ -1 +1 @@
-version = "1.26.
+version = "1.26.6"