davidberenstein1957 HF staff commited on
Commit
7b7c1be
·
1 Parent(s): dc56474

add MAX_NUM_ROWS

Browse files
README.md CHANGED
@@ -79,6 +79,12 @@ demo.launch()
79
 
80
  Optionally, you can set the following environment variables to customize the generation process.
81
 
 
 
 
 
 
 
82
  - `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`.
83
  - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`.
84
  - `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`, `sk-...`.
 
79
 
80
  Optionally, you can set the following environment variables to customize the generation process.
81
 
82
+ - `MAX_NUM_TOKENS`: The maximum number of tokens to generate, defaults to `2048`.
83
+ - `MAX_NUM_ROWS`: The maximum number of rows to generate, defaults to `1000`.
84
+ - `DEFAULT_BATCH_SIZE`: The default batch size to use for generating the dataset, defaults to `5`.
85
+
86
+ Optionally, you can use different models and APIs.
87
+
88
  - `BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api-inference.huggingface.co/v1/`, `https://api.openai.com/v1/`.
89
  - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`.
90
  - `API_KEY`: The API key to use for the corresponding API, e.g. `hf_...`, `sk-...`.
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -8,6 +8,7 @@ from datasets import Dataset, concatenate_datasets, load_dataset
8
  from gradio import OAuthToken
9
  from huggingface_hub import HfApi, upload_file
10
 
 
11
  from synthetic_dataset_generator.utils import get_argilla_client
12
 
13
 
@@ -136,3 +137,12 @@ def show_success_message(org_name, repo_name) -> gr.Markdown:
136
 
137
  def hide_success_message() -> gr.Markdown:
138
  return gr.Markdown(value="")
 
 
 
 
 
 
 
 
 
 
8
  from gradio import OAuthToken
9
  from huggingface_hub import HfApi, upload_file
10
 
11
+ from synthetic_dataset_generator.constants import MAX_NUM_ROWS
12
  from synthetic_dataset_generator.utils import get_argilla_client
13
 
14
 
 
137
 
138
  def hide_success_message() -> gr.Markdown:
139
  return gr.Markdown(value="")
140
+
141
+
142
def test_max_num_rows(num_rows: int) -> int:
    """Clamp the requested number of rows to the configured ``MAX_NUM_ROWS``.

    Args:
        num_rows: The number of rows the user asked to generate.

    Returns:
        ``num_rows`` unchanged when it is within the limit, otherwise
        ``MAX_NUM_ROWS``. Emits a ``gr.Info`` notification when clamping.
    """
    # int() guards against MAX_NUM_ROWS arriving as a str: os.getenv returns
    # strings whenever the environment variable is actually set, and comparing
    # int > str raises TypeError.
    max_rows = int(MAX_NUM_ROWS)
    if num_rows > max_rows:
        num_rows = max_rows
        gr.Info(
            f"Number of rows is larger than the configured maximum. Setting number of rows to {MAX_NUM_ROWS}. Set environment variable `MAX_NUM_ROWS` to change this behavior."
        )
    return num_rows
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -22,6 +22,7 @@ from synthetic_dataset_generator.apps.base import (
22
  hide_success_message,
23
  push_pipeline_code_to_hub,
24
  show_success_message,
 
25
  validate_argilla_user_workspace_dataset,
26
  validate_push_to_hub,
27
  )
@@ -303,6 +304,7 @@ def _evaluate_dataset(
303
  num_rows: int = 10,
304
  is_sample: bool = False,
305
  ):
 
306
  if eval_type == "chat-eval":
307
  dataframe = evaluate_instruction_response(
308
  dataframe=dataframe,
 
22
  hide_success_message,
23
  push_pipeline_code_to_hub,
24
  show_success_message,
25
+ test_max_num_rows,
26
  validate_argilla_user_workspace_dataset,
27
  validate_push_to_hub,
28
  )
 
304
  num_rows: int = 10,
305
  is_sample: bool = False,
306
  ):
307
+ num_rows = test_max_num_rows(num_rows)
308
  if eval_type == "chat-eval":
309
  dataframe = evaluate_instruction_response(
310
  dataframe=dataframe,
src/synthetic_dataset_generator/apps/sft.py CHANGED
@@ -14,6 +14,7 @@ from synthetic_dataset_generator.apps.base import (
14
  hide_success_message,
15
  push_pipeline_code_to_hub,
16
  show_success_message,
 
17
  validate_argilla_user_workspace_dataset,
18
  validate_push_to_hub,
19
  )
@@ -100,6 +101,7 @@ def generate_dataset(
100
  is_sample: bool = False,
101
  progress=gr.Progress(),
102
  ) -> pd.DataFrame:
 
103
  progress(0.0, desc="(1/2) Generating instructions")
104
  magpie_generator = get_magpie_generator(
105
  system_prompt, num_turns, temperature, is_sample
 
14
  hide_success_message,
15
  push_pipeline_code_to_hub,
16
  show_success_message,
17
+ test_max_num_rows,
18
  validate_argilla_user_workspace_dataset,
19
  validate_push_to_hub,
20
  )
 
101
  is_sample: bool = False,
102
  progress=gr.Progress(),
103
  ) -> pd.DataFrame:
104
+ num_rows = test_max_num_rows(num_rows)
105
  progress(0.0, desc="(1/2) Generating instructions")
106
  magpie_generator = get_magpie_generator(
107
  system_prompt, num_turns, temperature, is_sample
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -15,6 +15,7 @@ from src.synthetic_dataset_generator.apps.base import (
15
  hide_success_message,
16
  push_pipeline_code_to_hub,
17
  show_success_message,
 
18
  validate_argilla_user_workspace_dataset,
19
  validate_push_to_hub,
20
  )
@@ -94,6 +95,7 @@ def generate_dataset(
94
  is_sample: bool = False,
95
  progress=gr.Progress(),
96
  ) -> pd.DataFrame:
 
97
  progress(0.0, desc="(1/2) Generating dataset")
98
  labels = get_preprocess_labels(labels)
99
  textcat_generator = get_textcat_generator(
 
15
  hide_success_message,
16
  push_pipeline_code_to_hub,
17
  show_success_message,
18
+ test_max_num_rows,
19
  validate_argilla_user_workspace_dataset,
20
  validate_push_to_hub,
21
  )
 
95
  is_sample: bool = False,
96
  progress=gr.Progress(),
97
  ) -> pd.DataFrame:
98
+ num_rows = test_max_num_rows(num_rows)
99
  progress(0.0, desc="(1/2) Generating dataset")
100
  labels = get_preprocess_labels(labels)
101
  textcat_generator = get_textcat_generator(
src/synthetic_dataset_generator/constants.py CHANGED
@@ -15,7 +15,9 @@ if HF_TOKEN is None:
15
  )
16
 
17
  # Inference
18
- DEFAULT_BATCH_SIZE = 5
 
 
19
  MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
20
  API_KEYS = (
21
  [os.getenv("HF_TOKEN")]
 
15
  )
16
 
17
  # Inference
18
# Inference generation limits, overridable via environment variables.
# os.getenv returns a str when the variable is set (and the given default
# otherwise), so coerce to int here to guarantee a consistent numeric type
# for downstream comparisons and arithmetic (e.g. `num_rows > MAX_NUM_ROWS`).
MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
DEFAULT_BATCH_SIZE = int(os.getenv("DEFAULT_BATCH_SIZE", 5))
21
  MODEL = os.getenv("MODEL", "meta-llama/Meta-Llama-3.1-8B-Instruct")
22
  API_KEYS = (
23
  [os.getenv("HF_TOKEN")]
src/synthetic_dataset_generator/pipelines/eval.py CHANGED
@@ -5,7 +5,7 @@ from distilabel.steps.tasks import (
5
  UltraFeedback,
6
  )
7
 
8
- from synthetic_dataset_generator.constants import BASE_URL, MODEL
9
  from synthetic_dataset_generator.pipelines.base import _get_next_api_key
10
  from synthetic_dataset_generator.utils import extract_column_names
11
 
@@ -18,7 +18,7 @@ def get_ultrafeedback_evaluator(aspect, is_sample):
18
  api_key=_get_next_api_key(),
19
  generation_kwargs={
20
  "temperature": 0.01,
21
- "max_new_tokens": 2048 if not is_sample else 512,
22
  },
23
  ),
24
  aspect=aspect,
@@ -36,7 +36,7 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)
36
  structured_output={"format": "json", "schema": structured_output},
37
  generation_kwargs={
38
  "temperature": 0.01,
39
- "max_new_tokens": 2048 if not is_sample else 512,
40
  },
41
  ),
42
  template=prompt_template,
@@ -79,7 +79,7 @@ with Pipeline(name="ultrafeedback") as pipeline:
79
  api_key=os.environ["API_KEY"],
80
  generation_kwargs={{
81
  "temperature": 0.01,
82
- "max_new_tokens": 2048,
83
  }},
84
  ),
85
  aspect=aspect,
@@ -123,7 +123,7 @@ with Pipeline(name="ultrafeedback") as pipeline:
123
  api_key=os.environ["BASE_URL"],
124
  generation_kwargs={{
125
  "temperature": 0.01,
126
- "max_new_tokens": 2048,
127
  }},
128
  output_mappings={{
129
  "ratings": f"ratings_{{aspect}}",
@@ -177,7 +177,7 @@ with Pipeline(name="custom-evaluation") as pipeline:
177
  structured_output={{"format": "json", "schema": {structured_output}}},
178
  generation_kwargs={{
179
  "temperature": 0.01,
180
- "max_new_tokens": 2048,
181
  }},
182
  ),
183
  template=CUSTOM_TEMPLATE,
 
5
  UltraFeedback,
6
  )
7
 
8
+ from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
9
  from synthetic_dataset_generator.pipelines.base import _get_next_api_key
10
  from synthetic_dataset_generator.utils import extract_column_names
11
 
 
18
  api_key=_get_next_api_key(),
19
  generation_kwargs={
20
  "temperature": 0.01,
21
+ "max_new_tokens": MAX_NUM_TOKENS if not is_sample else 512,
22
  },
23
  ),
24
  aspect=aspect,
 
36
  structured_output={"format": "json", "schema": structured_output},
37
  generation_kwargs={
38
  "temperature": 0.01,
39
+ "max_new_tokens": MAX_NUM_TOKENS if not is_sample else 512,
40
  },
41
  ),
42
  template=prompt_template,
 
79
  api_key=os.environ["API_KEY"],
80
  generation_kwargs={{
81
  "temperature": 0.01,
82
+ "max_new_tokens": {MAX_NUM_TOKENS},
83
  }},
84
  ),
85
  aspect=aspect,
 
123
  api_key=os.environ["BASE_URL"],
124
  generation_kwargs={{
125
  "temperature": 0.01,
126
+ "max_new_tokens": {MAX_NUM_TOKENS},
127
  }},
128
  output_mappings={{
129
  "ratings": f"ratings_{{aspect}}",
 
177
  structured_output={{"format": "json", "schema": {structured_output}}},
178
  generation_kwargs={{
179
  "temperature": 0.01,
180
+ "max_new_tokens": {MAX_NUM_TOKENS},
181
  }},
182
  ),
183
  template=CUSTOM_TEMPLATE,
src/synthetic_dataset_generator/pipelines/sft.py CHANGED
@@ -4,6 +4,7 @@ from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
4
  from synthetic_dataset_generator.constants import (
5
  BASE_URL,
6
  MAGPIE_PRE_QUERY_TEMPLATE,
 
7
  MODEL,
8
  )
9
  from synthetic_dataset_generator.pipelines.base import _get_next_api_key
@@ -149,7 +150,7 @@ def get_prompt_generator():
149
  base_url=BASE_URL,
150
  generation_kwargs={
151
  "temperature": 0.8,
152
- "max_new_tokens": 2048,
153
  "do_sample": True,
154
  },
155
  ),
@@ -174,7 +175,7 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
174
  generation_kwargs={
175
  "temperature": temperature,
176
  "do_sample": True,
177
- "max_new_tokens": 256 if is_sample else 512,
178
  "stop_sequences": _STOP_SEQUENCES,
179
  },
180
  ),
@@ -194,7 +195,7 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
194
  generation_kwargs={
195
  "temperature": temperature,
196
  "do_sample": True,
197
- "max_new_tokens": 256 if is_sample else 1024,
198
  "stop_sequences": _STOP_SEQUENCES,
199
  },
200
  ),
@@ -217,7 +218,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
217
  api_key=_get_next_api_key(),
218
  generation_kwargs={
219
  "temperature": temperature,
220
- "max_new_tokens": 256 if is_sample else 1024,
221
  },
222
  ),
223
  system_prompt=system_prompt,
@@ -233,7 +234,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
233
  api_key=_get_next_api_key(),
234
  generation_kwargs={
235
  "temperature": temperature,
236
- "max_new_tokens": 2048,
237
  },
238
  ),
239
  output_mappings={"generation": "completion"},
@@ -268,7 +269,7 @@ with Pipeline(name="sft") as pipeline:
268
  generation_kwargs={{
269
  "temperature": {temperature},
270
  "do_sample": True,
271
- "max_new_tokens": 2048,
272
  "stop_sequences": {_STOP_SEQUENCES}
273
  }},
274
  api_key=os.environ["BASE_URL"],
 
4
  from synthetic_dataset_generator.constants import (
5
  BASE_URL,
6
  MAGPIE_PRE_QUERY_TEMPLATE,
7
+ MAX_NUM_TOKENS,
8
  MODEL,
9
  )
10
  from synthetic_dataset_generator.pipelines.base import _get_next_api_key
 
150
  base_url=BASE_URL,
151
  generation_kwargs={
152
  "temperature": 0.8,
153
+ "max_new_tokens": MAX_NUM_TOKENS,
154
  "do_sample": True,
155
  },
156
  ),
 
175
  generation_kwargs={
176
  "temperature": temperature,
177
  "do_sample": True,
178
+ "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
179
  "stop_sequences": _STOP_SEQUENCES,
180
  },
181
  ),
 
195
  generation_kwargs={
196
  "temperature": temperature,
197
  "do_sample": True,
198
+ "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
199
  "stop_sequences": _STOP_SEQUENCES,
200
  },
201
  ),
 
218
  api_key=_get_next_api_key(),
219
  generation_kwargs={
220
  "temperature": temperature,
221
+ "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
222
  },
223
  ),
224
  system_prompt=system_prompt,
 
234
  api_key=_get_next_api_key(),
235
  generation_kwargs={
236
  "temperature": temperature,
237
+ "max_new_tokens": MAX_NUM_TOKENS,
238
  },
239
  ),
240
  output_mappings={"generation": "completion"},
 
269
  generation_kwargs={{
270
  "temperature": {temperature},
271
  "do_sample": True,
272
+ "max_new_tokens": {MAX_NUM_TOKENS},
273
  "stop_sequences": {_STOP_SEQUENCES}
274
  }},
275
  api_key=os.environ["BASE_URL"],
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -9,7 +9,7 @@ from distilabel.steps.tasks import (
9
  )
10
  from pydantic import BaseModel, Field
11
 
12
- from synthetic_dataset_generator.constants import BASE_URL, MODEL
13
  from synthetic_dataset_generator.pipelines.base import _get_next_api_key
14
  from synthetic_dataset_generator.utils import get_preprocess_labels
15
 
@@ -69,7 +69,7 @@ def get_prompt_generator():
69
  structured_output={"format": "json", "schema": TextClassificationTask},
70
  generation_kwargs={
71
  "temperature": 0.8,
72
- "max_new_tokens": 2048,
73
  "do_sample": True,
74
  },
75
  ),
@@ -88,7 +88,7 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
88
  api_key=_get_next_api_key(),
89
  generation_kwargs={
90
  "temperature": temperature,
91
- "max_new_tokens": 256 if is_sample else 2048,
92
  "do_sample": True,
93
  "top_k": 50,
94
  "top_p": 0.95,
@@ -110,7 +110,7 @@ def get_labeller_generator(system_prompt, labels, num_labels):
110
  api_key=_get_next_api_key(),
111
  generation_kwargs={
112
  "temperature": 0.7,
113
- "max_new_tokens": 2048,
114
  },
115
  ),
116
  context=system_prompt,
@@ -159,7 +159,7 @@ with Pipeline(name="textcat") as pipeline:
159
  api_key=os.environ["API_KEY"],
160
  generation_kwargs={{
161
  "temperature": {temperature},
162
- "max_new_tokens": 2048,
163
  "do_sample": True,
164
  "top_k": 50,
165
  "top_p": 0.95,
@@ -203,7 +203,7 @@ with Pipeline(name="textcat") as pipeline:
203
  api_key=os.environ["API_KEY"],
204
  generation_kwargs={{
205
  "temperature": 0.8,
206
- "max_new_tokens": 2048,
207
  }},
208
  ),
209
  n={num_labels},
 
9
  )
10
  from pydantic import BaseModel, Field
11
 
12
+ from synthetic_dataset_generator.constants import BASE_URL, MAX_NUM_TOKENS, MODEL
13
  from synthetic_dataset_generator.pipelines.base import _get_next_api_key
14
  from synthetic_dataset_generator.utils import get_preprocess_labels
15
 
 
69
  structured_output={"format": "json", "schema": TextClassificationTask},
70
  generation_kwargs={
71
  "temperature": 0.8,
72
+ "max_new_tokens": MAX_NUM_TOKENS,
73
  "do_sample": True,
74
  },
75
  ),
 
88
  api_key=_get_next_api_key(),
89
  generation_kwargs={
90
  "temperature": temperature,
91
+ "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
92
  "do_sample": True,
93
  "top_k": 50,
94
  "top_p": 0.95,
 
110
  api_key=_get_next_api_key(),
111
  generation_kwargs={
112
  "temperature": 0.7,
113
+ "max_new_tokens": MAX_NUM_TOKENS,
114
  },
115
  ),
116
  context=system_prompt,
 
159
  api_key=os.environ["API_KEY"],
160
  generation_kwargs={{
161
  "temperature": {temperature},
162
+ "max_new_tokens": {MAX_NUM_TOKENS},
163
  "do_sample": True,
164
  "top_k": 50,
165
  "top_p": 0.95,
 
203
  api_key=os.environ["API_KEY"],
204
  generation_kwargs={{
205
  "temperature": 0.8,
206
+ "max_new_tokens": {MAX_NUM_TOKENS},
207
  }},
208
  ),
209
  n={num_labels},