Commit 7b7c1be · add MAX_NUM_ROWS
1 parent: dc56474

Files changed:
- README.md +6 -0
- src/synthetic_dataset_generator/apps/base.py +10 -0
- src/synthetic_dataset_generator/apps/eval.py +2 -0
- src/synthetic_dataset_generator/apps/sft.py +2 -0
- src/synthetic_dataset_generator/apps/textcat.py +2 -0
- src/synthetic_dataset_generator/constants.py +3 -1
- src/synthetic_dataset_generator/pipelines/eval.py +6 -6
- src/synthetic_dataset_generator/pipelines/sft.py +7 -6
- src/synthetic_dataset_generator/pipelines/textcat.py +6 -6
README.md
CHANGED
```diff
@@ -79,6 +79,12 @@ demo.launch()
 
 Optionally, you can set the following environment variables to customize the generation process.
 
+- `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
+- `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
+- `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
+
+Optionally, you can use different models and APIs.
+
 - `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`.
 - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`.
 - `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`, `sk-...`.
```
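The new variables are read at import time by `constants.py` (see below), so they must be set before the package is imported. A minimal sketch with illustrative values, not defaults:

```python
# Configure the new knobs before importing the generator so that
# constants.py picks them up at module load.
import os

os.environ["MAX_NUM_TOKENS"] = "1024"    # cap per-generation length
os.environ["MAX_NUM_ROWS"] = "500"       # cap generated dataset size
os.environ["DEFAULT_BATCH_SIZE"] = "2"   # smaller generation batches
```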
src/synthetic_dataset_generator/apps/base.py
CHANGED
```diff
@@ -8,6 +8,7 @@ from datasets import Dataset, concatenate_datasets, load_dataset
 from gradio import OAuthToken
 from huggingface_hub import HfApi, upload_file
 
+from synthetic_dataset_generator.constants import MAX_NUM_ROWS
 from synthetic_dataset_generator.utils import get_argilla_client
 
 
@@ -136,3 +137,12 @@ def show_success_message(org_name, repo_name) -> gr.Markdown:
 
 def hide_success_message() -> gr.Markdown:
     return gr.Markdown(value="")
+
+
+def test_max_num_rows(num_rows: int) -> int:
+    if num_rows > MAX_NUM_ROWS:
+        num_rows = MAX_NUM_ROWS
+        gr.Info(
+            f"Number of rows is larger than the configured maximum. Setting number of rows to {MAX_NUM_ROWS}. Set environment variable `MAX_NUM_ROWS` to change this behavior."
+        )
+    return num_rows
```
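A standalone usage sketch of the new clamp (the `gr.Info` toast is stubbed with a `print` so it runs outside a Gradio event context):

```python
MAX_NUM_ROWS = 1000  # mirrors the constants.py default below

def test_max_num_rows(num_rows: int) -> int:
    # Clamp the requested row count and tell the user why (gr.Info in the app).
    if num_rows > MAX_NUM_ROWS:
        num_rows = MAX_NUM_ROWS
        print(f"Clamped to {MAX_NUM_ROWS}; set MAX_NUM_ROWS to change this.")
    return num_rows

assert test_max_num_rows(500) == 500    # under the cap: unchanged
assert test_max_num_rows(5000) == 1000  # over the cap: clamped
```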
src/synthetic_dataset_generator/apps/eval.py
CHANGED
```diff
@@ -22,6 +22,7 @@ from synthetic_dataset_generator.apps.base import (
     hide_success_message,
     push_pipeline_code_to_hub,
     show_success_message,
+    test_max_num_rows,
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
@@ -303,6 +304,7 @@ def _evaluate_dataset(
     num_rows: int = 10,
     is_sample: bool = False,
 ):
+    num_rows = test_max_num_rows(num_rows)
     if eval_type == "chat-eval":
         dataframe = evaluate_instruction_response(
             dataframe=dataframe,
```
src/synthetic_dataset_generator/apps/sft.py
CHANGED
```diff
@@ -14,6 +14,7 @@ from synthetic_dataset_generator.apps.base import (
     hide_success_message,
     push_pipeline_code_to_hub,
     show_success_message,
+    test_max_num_rows,
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
@@ -100,6 +101,7 @@ def generate_dataset(
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
+    num_rows = test_max_num_rows(num_rows)
     progress(0.0, desc="(1/2) Generating instructions")
     magpie_generator = get_magpie_generator(
         system_prompt, num_turns, temperature, is_sample
```
src/synthetic_dataset_generator/apps/textcat.py
CHANGED
```diff
@@ -15,6 +15,7 @@ from src.synthetic_dataset_generator.apps.base import (
     hide_success_message,
     push_pipeline_code_to_hub,
     show_success_message,
+    test_max_num_rows,
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
@@ -94,6 +95,7 @@ def generate_dataset(
     is_sample: bool = False,
     progress=gr.Progress(),
 ) -> pd.DataFrame:
+    num_rows = test_max_num_rows(num_rows)
     progress(0.0, desc="(1/2) Generating dataset")
     labels = get_preprocess_labels(labels)
     textcat_generator = get_textcat_generator(
```
src/synthetic_dataset_generator/constants.py
CHANGED
```diff
@@ -15,7 +15,9 @@ if HF_TOKEN is None:
 )
 
 # Inference
-
+MAX_NUM_TOKENS = os.getenv("MAX_NUM_TOKENS", 2048)
+MAX_NUM_ROWS: str | int = os.getenv("MAX_NUM_ROWS", 1000)
+DEFAULT_BATCH_SIZE = os.getenv("DEFAULT_BATCH_SIZE", 5)
 MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
 API_KEYS = (
     [os.getenv("HF_TOKEN")]
```
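The `str | int` annotation on `MAX_NUM_ROWS` reflects a quirk of `os.getenv`: it returns a string when the variable is set and the (here integer) default otherwise. A sketch of the consequence, relevant because `test_max_num_rows` above compares `num_rows > MAX_NUM_ROWS`:

```python
import os

os.environ["MAX_NUM_ROWS"] = "2000"
MAX_NUM_ROWS = os.getenv("MAX_NUM_ROWS", 1000)
print(type(MAX_NUM_ROWS))   # <class 'str'> when the env var is set

# A comparison such as 1500 > MAX_NUM_ROWS would then raise TypeError,
# so callers may need an explicit cast:
print(1500 > int(MAX_NUM_ROWS))  # False
```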
src/synthetic_dataset_generator/pipelines/eval.py
CHANGED
```diff
@@ -5,7 +5,7 @@ from distilabel.steps.tasks import (
     UltraFeedback,
 )
 
-from synthetic_dataset_generator.constants import BASE_URL, MODEL
+from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
 from synthetic_dataset_generator.pipelines.base import _get_next_api_key
 from synthetic_dataset_generator.utils import extract_column_names
 
@@ -18,7 +18,7 @@ def get_ultrafeedback_evaluator(aspect, is_sample):
         api_key=_get_next_api_key(),
         generation_kwargs={
             "temperature": 0.01,
-            "max_new_tokens":
+            "max_new_tokens": MAX_NUM_TOKENS if not is_sample else 512,
         },
     ),
     aspect=aspect,
@@ -36,7 +36,7 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)
         structured_output={"format": "json", "schema": structured_output},
         generation_kwargs={
             "temperature": 0.01,
-            "max_new_tokens":
+            "max_new_tokens": MAX_NUM_TOKENS if not is_sample else 512,
         },
     ),
     template=prompt_template,
@@ -79,7 +79,7 @@ with Pipeline(name="ultrafeedback") as pipeline:
         api_key=os.environ["API_KEY"],
         generation_kwargs={{
             "temperature": 0.01,
-            "max_new_tokens":
+            "max_new_tokens": {MAX_NUM_TOKENS},
         }},
     ),
     aspect=aspect,
@@ -123,7 +123,7 @@ with Pipeline(name="ultrafeedback") as pipeline:
         api_key=os.environ["BASE_URL"],
         generation_kwargs={{
             "temperature": 0.01,
-            "max_new_tokens":
+            "max_new_tokens": {MAX_NUM_TOKENS},
         }},
         output_mappings={{
             "ratings": f"ratings_{{aspect}}",
@@ -177,7 +177,7 @@ with Pipeline(name="custom-evaluation") as pipeline:
         structured_output={{"format": "json", "schema": {structured_output}}},
         generation_kwargs={{
             "temperature": 0.01,
-            "max_new_tokens":
+            "max_new_tokens": {MAX_NUM_TOKENS},
         }},
     ),
     template=CUSTOM_TEMPLATE,
```
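Both evaluator factories use the same cap pattern: the full `MAX_NUM_TOKENS` budget for real runs, a small budget for UI sample runs. A minimal sketch of that logic:

```python
MAX_NUM_TOKENS = 2048  # constants.py default

def eval_generation_kwargs(is_sample: bool) -> dict:
    # eval.py caps samples at 512; the sft/textcat pipelines below use 256.
    return {
        "temperature": 0.01,
        "max_new_tokens": MAX_NUM_TOKENS if not is_sample else 512,
    }

assert eval_generation_kwargs(True)["max_new_tokens"] == 512
assert eval_generation_kwargs(False)["max_new_tokens"] == 2048
```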
src/synthetic_dataset_generator/pipelines/sft.py
CHANGED
```diff
@@ -4,6 +4,7 @@ from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
 from synthetic_dataset_generator.constants import (
     BASE_URL,
     MAGPIE_PRE_QUERY_TEMPLATE,
+    MAX_NUM_TOKENS,
     MODEL,
 )
 from synthetic_dataset_generator.pipelines.base import _get_next_api_key
@@ -149,7 +150,7 @@ def get_prompt_generator():
         base_url=BASE_URL,
         generation_kwargs={
             "temperature": 0.8,
-            "max_new_tokens":
+            "max_new_tokens": MAX_NUM_TOKENS,
             "do_sample": True,
         },
     ),
@@ -174,7 +175,7 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
         generation_kwargs={
             "temperature": temperature,
             "do_sample": True,
-            "max_new_tokens": 256 if is_sample else
+            "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
             "stop_sequences": _STOP_SEQUENCES,
         },
     ),
@@ -194,7 +195,7 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
         generation_kwargs={
             "temperature": temperature,
             "do_sample": True,
-            "max_new_tokens": 256 if is_sample else
+            "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
             "stop_sequences": _STOP_SEQUENCES,
         },
     ),
@@ -217,7 +218,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
         api_key=_get_next_api_key(),
         generation_kwargs={
             "temperature": temperature,
-            "max_new_tokens": 256 if is_sample else
+            "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
         },
     ),
     system_prompt=system_prompt,
@@ -233,7 +234,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
         api_key=_get_next_api_key(),
         generation_kwargs={
             "temperature": temperature,
-            "max_new_tokens":
+            "max_new_tokens": MAX_NUM_TOKENS,
         },
     ),
     output_mappings={"generation": "completion"},
@@ -268,7 +269,7 @@ with Pipeline(name="sft") as pipeline:
         generation_kwargs={{
             "temperature": {temperature},
             "do_sample": True,
-            "max_new_tokens":
+            "max_new_tokens": {MAX_NUM_TOKENS},
             "stop_sequences": {_STOP_SEQUENCES}
         }},
         api_key=os.environ["BASE_URL"],
```
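The `with Pipeline(name="sft")` hunk sits inside an f-string that renders shareable pipeline code, which is why its braces are doubled: `{{` escapes a literal brace, while `{MAX_NUM_TOKENS}` interpolates the configured cap into the generated snippet. A minimal sketch of that templating:

```python
MAX_NUM_TOKENS = 2048
temperature = 0.9

# Doubled braces survive as literal braces; single braces interpolate.
snippet = f"""generation_kwargs={{
    "temperature": {temperature},
    "max_new_tokens": {MAX_NUM_TOKENS},
}}"""
print(snippet)
# generation_kwargs={
#     "temperature": 0.9,
#     "max_new_tokens": 2048,
# }
```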
src/synthetic_dataset_generator/pipelines/textcat.py
CHANGED
```diff
@@ -9,7 +9,7 @@ from distilabel.steps.tasks import (
 )
 from pydantic import BaseModel, Field
 
-from synthetic_dataset_generator.constants import BASE_URL, MODEL
+from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
 from synthetic_dataset_generator.pipelines.base import _get_next_api_key
 from synthetic_dataset_generator.utils import get_preprocess_labels
 
@@ -69,7 +69,7 @@ def get_prompt_generator():
         structured_output={"format": "json", "schema": TextClassificationTask},
         generation_kwargs={
             "temperature": 0.8,
-            "max_new_tokens":
+            "max_new_tokens": MAX_NUM_TOKENS,
             "do_sample": True,
         },
     ),
@@ -88,7 +88,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
         api_key=_get_next_api_key(),
         generation_kwargs={
             "temperature": temperature,
-            "max_new_tokens": 256 if is_sample else
+            "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
             "do_sample": True,
             "top_k": 50,
             "top_p": 0.95,
@@ -110,7 +110,7 @@ def get_labeller_generator(system_prompt, labels, num_labels):
         api_key=_get_next_api_key(),
         generation_kwargs={
             "temperature": 0.7,
-            "max_new_tokens":
+            "max_new_tokens": MAX_NUM_TOKENS,
         },
     ),
     context=system_prompt,
@@ -159,7 +159,7 @@ with Pipeline(name="textcat") as pipeline:
         api_key=os.environ["API_KEY"],
         generation_kwargs={{
             "temperature": {temperature},
-            "max_new_tokens":
+            "max_new_tokens": {MAX_NUM_TOKENS},
             "do_sample": True,
             "top_k": 50,
             "top_p": 0.95,
@@ -203,7 +203,7 @@ with Pipeline(name="textcat") as pipeline:
         api_key=os.environ["API_KEY"],
         generation_kwargs={{
             "temperature": 0.8,
-            "max_new_tokens":
+            "max_new_tokens": {MAX_NUM_TOKENS},
         }},
     ),
     n={num_labels},
```