Commit 1fc08db
Parent(s): 371c76b

Add requirements.txt for synthetic dataset generator dependency, update import paths for InferenceEndpointsLLM in multiple files, and modify example deployment script to comment out HF_TOKEN.

Files changed:
- examples/hf-serverless-deployment.py +1 -1
- requirements.txt +1 -0
- src/synthetic_dataset_generator/_inference_endpoints.py +2 -2
- src/synthetic_dataset_generator/pipelines/base.py +1 -1
- src/synthetic_dataset_generator/pipelines/chat.py +2 -1
- src/synthetic_dataset_generator/pipelines/eval.py +4 -4
- src/synthetic_dataset_generator/pipelines/textcat.py +5 -3
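
The common thread in the source changes is distilabel's import-path move: the LLM classes that used to live under distilabel.llms (visible in the old assignment in _inference_endpoints.py below) are now imported from distilabel.models. A minimal sketch of how downstream code can tolerate both layouts during the migration; the fallback branch is an assumption about older distilabel releases, not part of this commit:

try:
    from distilabel.models import InferenceEndpointsLLM  # new path, used throughout this commit
except ImportError:
    # assumed fallback: older distilabel releases kept the class under distilabel.llms
    from distilabel.llms import InferenceEndpointsLLM

# Hypothetical usage; the model id is taken from the example script in this commit.
llm = InferenceEndpointsLLM(model_id="meta-llama/Llama-3.1-8B-Instruct")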
examples/hf-serverless-deployment.py CHANGED
@@ -8,7 +8,7 @@ import os
 
 from synthetic_dataset_generator import launch
 
-os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
+# os.environ["HF_TOKEN"] = "hf_..." # push the data to huggingface
 os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct" # use instruct model
 os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3" # use the template for the model
 
requirements.txt ADDED
@@ -0,0 +1 @@
+-e git+https://github.com/argilla-io/synthetic-data-generator.git#egg=synthetic-dataset-generator
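
The lone entry is an editable VCS requirement: pip clones the repository from GitHub and installs it in development mode under the egg name synthetic-dataset-generator, so a plain pip install -r requirements.txt pulls the generator straight from source rather than from PyPI.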
src/synthetic_dataset_generator/_inference_endpoints.py CHANGED
@@ -2,7 +2,7 @@ import warnings
 
 import distilabel
 import distilabel.distiset
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
 from pydantic import (
     ValidationError,
     model_validator,
@@ -55,4 +55,4 @@ class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
     )
 
 
-distilabel.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM
+distilabel.models.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM
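
The second hunk is the interesting one: after defining CustomInferenceEndpointsLLM, the module monkey-patches it over distilabel's own class so that code resolving the attribute at runtime gets the subclass. A minimal sketch of the pattern, with a stand-in subclass body (the real one adds pydantic-based validation, per the imports above):

import distilabel.models.llms
from distilabel.models import InferenceEndpointsLLM

class CustomInferenceEndpointsLLM(InferenceEndpointsLLM):
    """Stand-in body; the actual subclass overrides validation behavior."""

# Rebind the module attribute so later lookups of
# distilabel.models.llms.InferenceEndpointsLLM return the subclass.
distilabel.models.llms.InferenceEndpointsLLM = CustomInferenceEndpointsLLM

The usual caveat with this pattern: modules that from-imported the original class before the patch ran keep their old binding, which is why a patch like this has to execute at package import time.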
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -1,7 +1,7 @@
 import math
 import random
 
-from distilabel.llms import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
+from distilabel.models import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
 from distilabel.steps.tasks import TextGeneration
 
 from synthetic_dataset_generator.constants import (
src/synthetic_dataset_generator/pipelines/chat.py CHANGED
@@ -229,6 +229,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
 
 def generate_pipeline_code(system_prompt, num_turns, num_rows, temperature):
     input_mappings = _get_output_mappings(num_turns)
+
     code = f"""
 # Requirements: `pip install distilabel[hf-inference-endpoints]`
 import os
@@ -241,7 +242,7 @@ SYSTEM_PROMPT = "{system_prompt}"
 
 with Pipeline(name="sft") as pipeline:
     magpie = MagpieGenerator(
-        llm={_get_llm_class()}.
+        llm={_get_llm_class()}.from_dict({_get_llm().model_dump()}),
         n_turns={num_turns},
         num_rows={num_rows},
         batch_size=1,
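
The generated snippet now embeds the app's configured LLM by dumping its configuration and rebuilding it, instead of spelling out constructor arguments. A rough sketch of that round trip, assuming distilabel LLMs are pydantic models exposing the from_dict constructor the diff relies on:

from distilabel.models import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(model_id="meta-llama/Llama-3.1-8B-Instruct")
config = llm.model_dump()  # plain dict of the pydantic fields
clone = InferenceEndpointsLLM.from_dict(config)  # fresh, equivalently configured instance

Since model_dump() returns a plain dict whose repr is typically a valid Python literal, interpolating it into the code template means the exported pipeline reproduces whatever provider and parameters the app was launched with.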
src/synthetic_dataset_generator/pipelines/eval.py CHANGED
@@ -1,5 +1,5 @@
 from datasets import get_dataset_config_names, get_dataset_split_names
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
 from distilabel.steps.tasks import (
     TextGeneration,
     UltraFeedback,
@@ -57,7 +57,7 @@ from datasets import load_dataset
 from distilabel.pipeline import Pipeline
 from distilabel.steps import LoadDataFromDicts
 from distilabel.steps.tasks import UltraFeedback
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
 
 MODEL = "{MODEL}"
 BASE_URL = "{BASE_URL}"
@@ -97,7 +97,7 @@ import os
 from distilabel.pipeline import Pipeline
 from distilabel.steps import LoadDataFromDicts, CombineOutputs
 from distilabel.steps.tasks import UltraFeedback
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
 
 MODEL = "{MODEL}"
 BASE_URL = "{BASE_URL}"
@@ -154,7 +154,7 @@ import os
 from distilabel.pipeline import Pipeline
 from distilabel.steps import LoadDataFromHub
 from distilabel.steps.tasks import TextGeneration
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
 
 MODEL = "{MODEL}"
 BASE_URL = "{BASE_URL}"
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -133,17 +133,19 @@ def generate_pipeline_code(
 # Requirements: `pip install distilabel[hf-inference-endpoints]`
 import os
 import random
-from distilabel.llms import {_get_llm_class()}
+from distilabel.models import {_get_llm_class()}
 from distilabel.pipeline import Pipeline
 from distilabel.steps import LoadDataFromDicts, KeepColumns
 from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
 
+SYSTEM_PROMPT = "{system_prompt}"
+
 with Pipeline(name="textcat") as pipeline:
 
     task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
 
     textcat_generation = GenerateTextClassificationData(
-        llm={_get_llm_class()}.
+        llm={_get_llm_class()}.from_dict({_get_llm().model_dump()}),
         seed=random.randint(0, 2**32 - 1),
         difficulty={None if difficulty == "mixed" else repr(difficulty)},
         clarity={None if clarity == "mixed" else repr(clarity)},
@@ -176,7 +178,7 @@ with Pipeline(name="textcat") as pipeline:
     )
 
     textcat_labeller = TextClassification(
-        llm={_get_llm_class()}.
+        llm={_get_llm_class()}.from_dict({_get_llm().model_dump()}),
         n={num_labels},
         available_labels={labels},
         context=TEXT_CLASSIFICATION_TASK,