David Quispe committed
Commit b2c01de · unverified · 2 Parent(s): bf5ba21 3b7b628

Merge branch 'argilla-io:main' into main

.DS_Store DELETED
Binary file (8.2 kB)
 
README.md CHANGED
@@ -86,12 +86,14 @@ You can set the following environment variables to customize the generation process.
  Optionally, you can use different API providers and models.
 
  - `MODEL`: The model to use for generating the dataset, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`, `gpt-4o`, `llama3.1`.
- - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the provided `HF_TOKEN` environment variable.
+ - `API_KEY`: The API key to use for the generation API, e.g. `hf_...`, `sk-...`. If not provided, it will default to the `HF_TOKEN` environment variable.
  - `OPENAI_BASE_URL`: The base URL for any OpenAI compatible API, e.g. `https://api.openai.com/v1/`.
  - `OLLAMA_BASE_URL`: The base URL for any Ollama compatible API, e.g. `http://127.0.0.1:11434/`.
  - `HUGGINGFACE_BASE_URL`: The base URL for any Hugging Face compatible API, e.g. a TGI server or Dedicated Inference Endpoints. If you want to use serverless inference, only set the `MODEL`.
  - `VLLM_BASE_URL`: The base URL for any vLLM compatible API, e.g. `http://localhost:8000/`.
 
+ To use a specific model exclusively for generating completions, set the corresponding environment variables by appending `_COMPLETION` to the ones mentioned earlier. For example, you can use `MODEL_COMPLETION` and `OPENAI_BASE_URL_COMPLETION`.
+
  SFT and Chat Data generation is not supported with OpenAI endpoints. Additionally, you need to configure it per model family based on their prompt templates, using the right `TOKENIZER_ID` and `MAGPIE_PRE_QUERY_TEMPLATE` environment variables.
 
  - `TOKENIZER_ID`: The tokenizer ID to use for the magpie pipeline, e.g. `meta-llama/Meta-Llama-3.1-8B-Instruct`.
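The `_COMPLETION` suffix implies a simple resolution rule: a completion-specific variable wins, otherwise the base variable applies. A minimal sketch of that lookup, using a hypothetical `resolve_for_completion` helper (the package's actual internals may differ):

```python
import os

def resolve_for_completion(name: str):
    # Prefer the `_COMPLETION` override; fall back to the base variable.
    return os.environ.get(f"{name}_COMPLETION") or os.environ.get(name)

os.environ["MODEL"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"
os.environ["MODEL_COMPLETION"] = "meta-llama/Llama-3.1-70B-Instruct"

assert resolve_for_completion("MODEL") == "meta-llama/Llama-3.1-70B-Instruct"
assert resolve_for_completion("OLLAMA_BASE_URL") is None  # nothing set either way
```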
examples/hf-serverless-deployment-deepseek.py ADDED
@@ -0,0 +1,16 @@
+ # /// script
+ # requires-python = ">=3.11,<3.12"
+ # dependencies = [
+ #     "synthetic-dataset-generator",
+ # ]
+ # ///
+ import os
+
+ from synthetic_dataset_generator import launch
+
+ os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+ os.environ["MODEL"] = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # use model for instructions
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "<|begin▁of▁sentence|>User: "  # use the custom template for the model
+
+
+ launch()
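The `# /// script` header is PEP 723 inline script metadata, so the example can be launched by any PEP 723-aware runner without installing dependencies first; for instance, assuming `uv` is available:

```python
# uv reads the inline metadata, builds an isolated environment with
# synthetic-dataset-generator on Python >=3.11,<3.12, and runs the script:
#
#   uv run examples/hf-serverless-deployment-deepseek.py
```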
examples/hf-serverless-deployment.py CHANGED
@@ -9,7 +9,7 @@ import os
  from synthetic_dataset_generator import launch
 
  os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
- os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use instruct model
+ os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use model for generation
  os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # use the template for the model
 
  launch()
examples/hf-serverless-different-model-for-completion.py ADDED
@@ -0,0 +1,16 @@
+ # /// script
+ # requires-python = ">=3.11,<3.12"
+ # dependencies = [
+ #     "synthetic-dataset-generator",
+ # ]
+ # ///
+ import os
+
+ from synthetic_dataset_generator import launch
+
+ os.environ["HF_TOKEN"] = "hf_..."  # push the data to huggingface
+ os.environ["MODEL"] = "meta-llama/Llama-3.1-8B-Instruct"  # use model for instruction generation
+ os.environ["MODEL_COMPLETION"] = "meta-llama/Llama-3.1-70B-Instruct"  # use model for completion generation
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # use the template for the model
+
+ launch()
examples/ollama-different-model-for-completion.py ADDED
@@ -0,0 +1,26 @@
+ # /// script
+ # requires-python = ">=3.11,<3.12"
+ # dependencies = [
+ #     "synthetic-dataset-generator",
+ # ]
+ # ///
+ # ollama serve
+ # ollama run llama3.2
+ # ollama run llama3.2:1b
+ import os
+
+ from synthetic_dataset_generator import launch
+
+ os.environ["OLLAMA_BASE_URL"] = (
+     "http://127.0.0.1:11434/"  # in this case, the same base url for both models
+ )
+
+ os.environ["MODEL"] = "llama3.2"  # model for instruction generation
+ os.environ["MODEL_COMPLETION"] = "llama3.2:1b"  # model for completion generation
+
+ os.environ["TOKENIZER_ID"] = "meta-llama/Llama-3.2-3B-Instruct"  # tokenizer matching the instruction model (ollama's llama3.2 defaults to 3B)
+ os.environ["TOKENIZER_ID_COMPLETION"] = "meta-llama/Llama-3.2-1B-Instruct"  # tokenizer matching the completion model
+
+ os.environ["MAGPIE_PRE_QUERY_TEMPLATE"] = "llama3"  # magpie template required for instruction generation
+
+ launch()
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -1,12 +1,16 @@
  import io
  import uuid
+ from tqdm import tqdm
  from typing import Union
 
  import argilla as rg
  import gradio as gr
+ import pandas as pd
- from datasets import Dataset, concatenate_datasets, load_dataset
+ from datasets import Dataset, concatenate_datasets, get_dataset_config_names, get_dataset_split_names, load_dataset
  from gradio import OAuthToken
  from huggingface_hub import HfApi, upload_file, repo_exists
+ from unstructured.chunking.title import chunk_by_title
+ from unstructured.partition.auto import partition
 
  from synthetic_dataset_generator.constants import MAX_NUM_ROWS
  from synthetic_dataset_generator.utils import get_argilla_client
@@ -64,7 +68,7 @@ def push_pipeline_code_to_hub(
      progress(1.0, desc="Pipeline code uploaded")
 
 
- def validate_push_to_hub(org_name, repo_name):
+ def validate_push_to_hub(org_name: str, repo_name: str):
      repo_id = (
          f"{org_name}/{repo_name}"
          if repo_name is not None and org_name is not None
@@ -93,7 +97,7 @@ def combine_datasets(
      return dataset
 
 
- def show_success_message(org_name, repo_name) -> gr.Markdown:
+ def show_success_message(org_name: str, repo_name: str) -> gr.Markdown:
      client = get_argilla_client()
      if client is None:
          return gr.Markdown(
@@ -179,3 +183,81 @@ def get_iframe(hub_repo_id: str) -> str:
      ></iframe>
      """
      return iframe
+
+
+ def _get_valid_columns(dataframe: pd.DataFrame):
+     doc_valid_columns = []
+
+     for col in dataframe.columns:
+         sample_val = dataframe[col].iloc[0]
+         if isinstance(sample_val, str):
+             doc_valid_columns.append(col)
+
+     return doc_valid_columns
+
+
+ def load_dataset_from_hub(
+     repo_id: str,
+     num_rows: int = 10,
+     token: Union[OAuthToken, None] = None,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if not repo_id:
+         raise gr.Error("Please provide a Hub repo ID")
+     subsets = get_dataset_config_names(repo_id, token=token)
+     splits = get_dataset_split_names(repo_id, subsets[0], token=token)
+     ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
+     rows = []
+     for idx, row in enumerate(tqdm(ds, desc="Loading the dataset", total=num_rows)):
+         rows.append(row)
+         if idx == num_rows:
+             break
+     ds = Dataset.from_list(rows)
+     dataframe = ds.to_pandas()
+     doc_valid_columns = _get_valid_columns(dataframe)
+     col_doc = doc_valid_columns[0] if doc_valid_columns else ""
+     return (
+         dataframe,
+         gr.Dropdown(
+             choices=doc_valid_columns,
+             label="Documents column",
+             value=col_doc,
+             interactive=(False if col_doc == "" else True),
+             multiselect=False,
+         ),
+     )
+
+
+ def preprocess_input_data(
+     file_paths: list[str], num_rows: int, progress=gr.Progress(track_tqdm=True)
+ ):
+     if not file_paths:
+         raise gr.Error("Please provide an input file")
+
+     data = {}
+     total_chunks = 0
+
+     for file_path in tqdm(file_paths, desc="Processing files", total=len(file_paths)):
+         partitioned_file = partition(filename=file_path)
+         chunks = [str(chunk) for chunk in chunk_by_title(partitioned_file)]
+         data[file_path] = chunks
+         total_chunks += len(chunks)
+         if total_chunks >= num_rows:
+             break
+
+     dataframe = pd.DataFrame.from_records(
+         [(k, v) for k, values in data.items() for v in values],
+         columns=["filename", "chunks"],
+     )
+     col_doc = "chunks"
+
+     return (
+         dataframe,
+         gr.Dropdown(
+             choices=["chunks"],
+             label="Documents column",
+             value=col_doc,
+             interactive=(False if col_doc == "" else True),
+             multiselect=False,
+         ),
+     )
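The new Hub loader streams the dataset so only the first rows are ever downloaded. A minimal standalone sketch of the same pattern, assuming a public repo with a `train` split (unlike the function above, which resolves configs and splits first):

```python
from datasets import Dataset, load_dataset

def first_rows(repo_id: str, n: int = 10) -> Dataset:
    # Streaming avoids downloading the full dataset just to preview it.
    stream = load_dataset(repo_id, split="train", streaming=True)
    rows = []
    for idx, row in enumerate(stream):
        rows.append(row)
        if idx + 1 >= n:  # materialize exactly n rows
            break
    return Dataset.from_list(rows)
```

Note that `load_dataset_from_hub` breaks on `idx == num_rows` after appending, so it materializes one extra row; the sketch stops at exactly `n`.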
src/synthetic_dataset_generator/apps/chat.py CHANGED
@@ -1,4 +1,5 @@
  import ast
+ import json
  import random
  import uuid
  from typing import Dict, List, Union
@@ -8,11 +9,15 @@ import gradio as gr
  import pandas as pd
  from datasets import Dataset
  from distilabel.distiset import Distiset
+ from gradio.oauth import OAuthToken
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
  from huggingface_hub import HfApi
 
  from synthetic_dataset_generator.apps.base import (
      combine_datasets,
      hide_success_message,
+     load_dataset_from_hub,
+     preprocess_input_data,
      push_pipeline_code_to_hub,
      show_success_message,
      test_max_num_rows,
@@ -23,21 +28,25 @@ from synthetic_dataset_generator.constants import (
      BASE_URL,
      DEFAULT_BATCH_SIZE,
      MODEL,
+     MODEL_COMPLETION,
      SFT_AVAILABLE,
  )
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
  from synthetic_dataset_generator.pipelines.chat import (
      DEFAULT_DATASET_DESCRIPTIONS,
      generate_pipeline_code,
+     get_follow_up_generator,
      get_magpie_generator,
      get_prompt_generator,
      get_response_generator,
+     get_sentence_pair_generator,
  )
  from synthetic_dataset_generator.pipelines.embeddings import (
      get_embeddings,
      get_sentence_embedding_dimensions,
  )
  from synthetic_dataset_generator.utils import (
+     column_to_list,
      get_argilla_client,
      get_org_dropdown,
      get_random_repo_name,
@@ -45,6 +54,14 @@ from synthetic_dataset_generator.utils import (
  )
 
 
+ def _get_dataframe():
+     return gr.Dataframe(
+         headers=["prompt", "completion"],
+         wrap=True,
+         interactive=False,
+     )
+
+
  def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
      def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
          return ast.literal_eval(
@@ -60,7 +77,7 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
      return dataframe
 
 
- def generate_system_prompt(dataset_description, progress=gr.Progress()):
+ def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
      progress(0.1, desc="Initializing")
      generate_description = get_prompt_generator()
      progress(0.5, desc="Generating")
@@ -77,42 +94,73 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
      return result
 
 
- def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
-     progress(0.1, desc="Generating sample dataset")
+ def load_dataset_file(
+     repo_id: str,
+     file_paths: list[str],
+     input_type: str,
+     num_rows: int = 10,
+     token: Union[OAuthToken, None] = None,
+     progress=gr.Progress(),
+ ):
+     progress(0.1, desc="Loading the source data")
+     if input_type == "dataset-input":
+         return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
+     else:
+         return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
+
+
+ def generate_sample_dataset(
+     repo_id: str,
+     file_paths: list[str],
+     input_type: str,
+     system_prompt: str,
+     document_column: str,
+     num_turns: int,
+     num_rows: int,
+     oauth_token: Union[OAuthToken, None],
+     progress=gr.Progress(),
+ ):
+     if input_type == "prompt-input":
+         dataframe = pd.DataFrame(columns=["prompt", "completion"])
+     else:
+         dataframe, _ = load_dataset_file(
+             repo_id=repo_id,
+             file_paths=file_paths,
+             input_type=input_type,
+             num_rows=num_rows,
+             token=oauth_token,
+         )
+     progress(0.5, desc="Generating sample dataset")
      dataframe = generate_dataset(
+         input_type=input_type,
+         dataframe=dataframe,
          system_prompt=system_prompt,
+         document_column=document_column,
          num_turns=num_turns,
-         num_rows=10,
-         progress=progress,
+         num_rows=num_rows,
          is_sample=True,
      )
      progress(1.0, desc="Sample dataset generated")
      return dataframe
 
 
- def _get_dataframe():
-     return gr.Dataframe(
-         headers=["prompt", "completion"],
-         wrap=True,
-         interactive=False,
-     )
-
-
- def generate_dataset(
+ def generate_dataset_from_prompt(
      system_prompt: str,
      num_turns: int = 1,
      num_rows: int = 10,
      temperature: float = 0.9,
+     temperature_completion: Union[float, None] = None,
      is_sample: bool = False,
      progress=gr.Progress(),
  ) -> pd.DataFrame:
      num_rows = test_max_num_rows(num_rows)
      progress(0.0, desc="(1/2) Generating instructions")
-     magpie_generator = get_magpie_generator(
-         system_prompt, num_turns, temperature, is_sample
-     )
+     magpie_generator = get_magpie_generator(num_turns, temperature, is_sample)
      response_generator = get_response_generator(
-         system_prompt, num_turns, temperature, is_sample
+         system_prompt=system_prompt,
+         num_turns=num_turns,
+         temperature=temperature_completion or temperature,
+         is_sample=is_sample,
      )
      total_steps: int = num_rows * 2
      batch_size = DEFAULT_BATCH_SIZE
@@ -217,6 +265,180 @@
      return dataframe
 
 
+ def generate_dataset_from_seed(
+     dataframe: pd.DataFrame,
+     document_column: str,
+     num_turns: int = 1,
+     num_rows: int = 10,
+     temperature: float = 0.9,
+     temperature_completion: Union[float, None] = None,
+     is_sample: bool = False,
+     progress=gr.Progress(),
+ ) -> pd.DataFrame:
+     num_rows = test_max_num_rows(num_rows)
+     progress(0.0, desc="Initializing dataset generation")
+     document_data = column_to_list(dataframe, document_column)
+     if len(document_data) < num_rows:
+         document_data += random.choices(document_data, k=num_rows - len(document_data))
+     instruction_generator = get_sentence_pair_generator(
+         temperature=temperature, is_sample=is_sample
+     )
+     response_generator = get_response_generator(
+         system_prompt=None,
+         num_turns=1,
+         temperature=temperature_completion or temperature,
+         is_sample=is_sample,
+     )
+     follow_up_generator_instruction = get_follow_up_generator(
+         type="instruction", temperature=temperature, is_sample=is_sample
+     )
+     follow_up_generator_response = get_follow_up_generator(
+         type="response",
+         temperature=temperature_completion or temperature,
+         is_sample=is_sample,
+     )
+     steps = 2 * num_turns
+     total_steps: int = num_rows * steps
+     step_progress = round(1 / steps, 2)
+     batch_size = DEFAULT_BATCH_SIZE
+
+     # create instructions
+     n_processed = 0
+     instruction_results = []
+     while n_processed < num_rows:
+         progress(
+             step_progress * n_processed / num_rows,
+             total=total_steps,
+             desc="Generating questions",
+         )
+         remaining_rows = num_rows - n_processed
+         batch_size = min(batch_size, remaining_rows)
+         batch = [
+             {"anchor": document}
+             for document in document_data[n_processed : n_processed + batch_size]
+         ]
+         questions = list(instruction_generator.process(inputs=batch))
+         instruction_results.extend(questions[0])
+         n_processed += batch_size
+     for result in instruction_results:
+         result["instruction"] = result["positive"]
+         result["prompt"] = result.pop("positive")
+
+     progress(step_progress, desc="Generating instructions")
+
+     # generate responses
+     n_processed = 0
+     response_results = []
+     while n_processed < num_rows:
+         progress(
+             step_progress + step_progress * n_processed / num_rows,
+             total=total_steps,
+             desc="Generating responses",
+         )
+         batch = instruction_results[n_processed : n_processed + batch_size]
+         responses = list(response_generator.process(inputs=batch))
+         response_results.extend(responses[0])
+         n_processed += batch_size
+     for result in response_results:
+         result["completion"] = result.pop("generation")
+
+     # generate follow-ups
+     if num_turns > 1:
+         n_processed = 0
+         final_conversations = []
+
+         while n_processed < num_rows:
+             progress(
+                 step_progress + step_progress * n_processed / num_rows,
+                 total=total_steps,
+                 desc="Generating follow-ups",
+             )
+             batch = response_results[n_processed : n_processed + batch_size]
+             conversations_batch = [
+                 {
+                     "messages": [
+                         {"role": "user", "content": result["prompt"]},
+                         {"role": "assistant", "content": result["completion"]},
+                     ]
+                 }
+                 for result in batch
+             ]
+
+             for _ in range(num_turns - 1):
+                 follow_up_instructions = list(
+                     follow_up_generator_instruction.process(inputs=conversations_batch)
+                 )
+                 for conv, follow_up in zip(conversations_batch, follow_up_instructions[0]):
+                     conv["messages"].append(
+                         {"role": "user", "content": follow_up["generation"]}
+                     )
+
+                 follow_up_responses = list(
+                     follow_up_generator_response.process(inputs=conversations_batch)
+                 )
+                 for conv, follow_up in zip(conversations_batch, follow_up_responses[0]):
+                     conv["messages"].append(
+                         {"role": "assistant", "content": follow_up["generation"]}
+                     )
+
+             final_conversations.extend(
+                 [{"messages": conv["messages"]} for conv in conversations_batch]
+             )
+             n_processed += batch_size
+
+     # create distiset
+     distiset_results = []
+     if num_turns == 1:
+         for result in response_results:
+             record = {}
+             for relevant_keys in ["prompt", "completion"]:
+                 if relevant_keys in result:
+                     record[relevant_keys] = result[relevant_keys]
+             distiset_results.append(record)
+         dataframe = pd.DataFrame(distiset_results)
+     else:
+         distiset_results = final_conversations
+         dataframe = pd.DataFrame(distiset_results)
+         dataframe["messages"] = dataframe["messages"].apply(lambda x: json.dumps(x))
+
+     progress(1.0, desc="Dataset generation completed")
+     return dataframe
+
+
+ def generate_dataset(
+     input_type: str,
+     dataframe: pd.DataFrame,
+     system_prompt: str,
+     document_column: str,
+     num_turns: int = 1,
+     num_rows: int = 10,
+     temperature: float = 0.9,
+     temperature_completion: Union[float, None] = None,
+     is_sample: bool = False,
+     progress=gr.Progress(),
+ ) -> pd.DataFrame:
+     if input_type == "prompt-input":
+         dataframe = generate_dataset_from_prompt(
+             system_prompt=system_prompt,
+             num_turns=num_turns,
+             num_rows=num_rows,
+             temperature=temperature,
+             temperature_completion=temperature_completion,
+             is_sample=is_sample,
+         )
+     else:
+         dataframe = generate_dataset_from_seed(
+             dataframe=dataframe,
+             document_column=document_column,
+             num_turns=num_turns,
+             num_rows=num_rows,
+             temperature=temperature,
+             temperature_completion=temperature_completion,
+             is_sample=is_sample,
+         )
+     return dataframe
+
+
  def push_dataset_to_hub(
      dataframe: pd.DataFrame,
      org_name: str,
@@ -251,23 +473,48 @@ def push_dataset(
  def push_dataset(
      org_name: str,
      repo_name: str,
+     private: bool,
+     original_repo_id: str,
+     file_paths: list[str],
+     input_type: str,
      system_prompt: str,
+     document_column: str,
      num_turns: int = 1,
      num_rows: int = 10,
-     private: bool = False,
      temperature: float = 0.9,
+     temperature_completion: Union[float, None] = None,
      pipeline_code: str = "",
      oauth_token: Union[gr.OAuthToken, None] = None,
      progress=gr.Progress(),
  ) -> pd.DataFrame:
+     if input_type == "prompt-input":
+         dataframe = _get_dataframe()
+     else:
+         dataframe, _ = load_dataset_file(
+             repo_id=original_repo_id,
+             file_paths=file_paths,
+             input_type=input_type,
+             num_rows=num_rows,
+             token=oauth_token,
+         )
+     progress(0.5, desc="Generating dataset")
      dataframe = generate_dataset(
+         input_type=input_type,
+         dataframe=dataframe,
          system_prompt=system_prompt,
+         document_column=document_column,
          num_turns=num_turns,
          num_rows=num_rows,
          temperature=temperature,
+         temperature_completion=temperature_completion,
      )
      push_dataset_to_hub(
-         dataframe, org_name, repo_name, oauth_token, private, pipeline_code
+         dataframe=dataframe,
+         org_name=org_name,
+         repo_name=repo_name,
+         oauth_token=oauth_token,
+         private=private,
+         pipeline_code=pipeline_code,
      )
      try:
          progress(0.1, desc="Setting up user and workspace")
@@ -390,6 +637,28 @@ def push_dataset(
      return ""
 
 
+ def show_system_prompt_visibility():
+     return {system_prompt: gr.Textbox(visible=True)}
+
+
+ def hide_system_prompt_visibility():
+     return {system_prompt: gr.Textbox(visible=False)}
+
+
+ def show_document_column_visibility():
+     return {document_column: gr.Dropdown(visible=True)}
+
+
+ def hide_document_column_visibility():
+     return {
+         document_column: gr.Dropdown(
+             choices=["Load your data first in step 1."],
+             value="Load your data first in step 1.",
+             visible=False,
+         )
+     }
+
+
  def show_pipeline_code_visibility():
      return {pipeline_code_ui: gr.Accordion(visible=True)}
 
@@ -398,6 +667,11 @@ def hide_pipeline_code_visibility():
      return {pipeline_code_ui: gr.Accordion(visible=False)}
 
 
+ def show_temperature_completion():
+     if MODEL != MODEL_COMPLETION:
+         return {temperature_completion: gr.Slider(value=0.9, visible=True)}
+
+
  ######################
  # Gradio UI
  ######################
@@ -417,29 +691,85 @@ with gr.Blocks() as app:
          )
      )
  else:
-     gr.Markdown(value="## 1. Describe the dataset you want")
-     with gr.Row():
+     gr.Markdown("## 1. Select your input")
+     with gr.Row(equal_height=False):
          with gr.Column(scale=2):
-             dataset_description = gr.Textbox(
-                 label="Dataset description",
-                 placeholder="Give a precise description of your desired dataset.",
-             )
-             with gr.Row():
-                 clear_btn_part = gr.Button(
-                     "Clear",
-                     variant="secondary",
-                 )
-                 load_btn = gr.Button(
-                     "Create",
-                     variant="primary",
-                 )
-         with gr.Column(scale=3):
-             examples = gr.Examples(
-                 examples=DEFAULT_DATASET_DESCRIPTIONS,
-                 inputs=[dataset_description],
-                 cache_examples=False,
-                 label="Examples",
+             input_type = gr.Dropdown(
+                 label="Input type",
+                 choices=["prompt-input", "dataset-input", "file-input"],
+                 value="prompt-input",
+                 multiselect=False,
+                 visible=False,
              )
+     with gr.Tab("Generate from prompt") as tab_prompt_input:
+         with gr.Row(equal_height=False):
+             with gr.Column(scale=2):
+                 dataset_description = gr.Textbox(
+                     label="Dataset description",
+                     placeholder="Give a precise description of your desired dataset.",
+                 )
+                 with gr.Row():
+                     clear_prompt_btn_part = gr.Button("Clear", variant="secondary")
+                     load_prompt_btn = gr.Button("Create", variant="primary")
+             with gr.Column(scale=3):
+                 examples = gr.Examples(
+                     examples=DEFAULT_DATASET_DESCRIPTIONS,
+                     inputs=[dataset_description],
+                     cache_examples=False,
+                     label="Examples",
+                 )
+     with gr.Tab("Load from Hub") as tab_dataset_input:
+         with gr.Row(equal_height=False):
+             with gr.Column(scale=2):
+                 search_in = HuggingfaceHubSearch(
+                     label="Search",
+                     placeholder="Search for a dataset",
+                     search_type="dataset",
+                     sumbit_on_select=True,
+                 )
+                 with gr.Row():
+                     clear_dataset_btn_part = gr.Button("Clear", variant="secondary")
+                     load_dataset_btn = gr.Button("Load", variant="primary")
+             with gr.Column(scale=3):
+                 examples = gr.Examples(
+                     examples=[
+                         "charris/wikipedia_sample",
+                         "plaguss/argilla_sdk_docs_raw_unstructured",
+                         "BeIR/hotpotqa-generated-queries",
+                     ],
+                     label="Example datasets",
+                     fn=lambda x: x,
+                     inputs=[search_in],
+                     run_on_click=True,
+                 )
+                 search_out = gr.HTML(label="Dataset preview", visible=False)
+     with gr.Tab("Load your file") as tab_file_input:
+         with gr.Row(equal_height=False):
+             with gr.Column(scale=2):
+                 file_in = gr.File(
+                     label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
+                     file_count="multiple",
+                     file_types=[".md", ".txt", ".docx", ".pdf"],
+                 )
+                 with gr.Row():
+                     clear_file_btn_part = gr.Button("Clear", variant="secondary")
+                     load_file_btn = gr.Button("Load", variant="primary")
+             with gr.Column(scale=3):
+                 file_out = gr.HTML(label="Dataset preview", visible=False)
 
      gr.HTML(value="<hr>")
      gr.Markdown(value="## 2. Configure your dataset")
@@ -449,6 +779,16 @@ with gr.Blocks() as app:
          label="System prompt",
          placeholder="You are a helpful assistant.",
      )
+     document_column = gr.Dropdown(
+         label="Document Column",
+         info="Select the document column to generate the RAG dataset",
+         choices=["Load your data first in step 1."],
+         value="Load your data first in step 1.",
+         interactive=False,
+         multiselect=False,
+         allow_custom_value=False,
+         visible=False,
+     )
      num_turns = gr.Number(
          value=1,
          label="Number of turns in the conversation",
@@ -489,11 +829,20 @@ with gr.Blocks() as app:
      temperature = gr.Slider(
          label="Temperature",
          minimum=0.1,
-         maximum=1,
+         maximum=1.5,
          value=0.9,
          step=0.1,
          interactive=True,
      )
+     temperature_completion = gr.Slider(
+         label="Temperature for completion",
+         minimum=0.1,
+         maximum=1.5,
+         value=None,
+         step=0.1,
+         interactive=True,
+         visible=False,
+     )
      private = gr.Checkbox(
          label="Private dataset",
          value=False,
@@ -514,7 +863,10 @@ with gr.Blocks() as app:
          visible=False,
      ) as pipeline_code_ui:
          code = generate_pipeline_code(
+             repo_id=search_in.value,
+             input_type=input_type.value,
              system_prompt=system_prompt.value,
+             document_column=document_column.value,
              num_turns=num_turns.value,
              num_rows=num_rows.value,
          )
@@ -524,77 +876,138 @@ with gr.Blocks() as app:
          label="Distilabel Pipeline Code",
      )
 
-     load_btn.click(
-         fn=generate_system_prompt,
-         inputs=[dataset_description],
-         outputs=[system_prompt],
-         show_progress=True,
-     ).then(
-         fn=generate_sample_dataset,
-         inputs=[system_prompt, num_turns],
-         outputs=[dataframe],
-         show_progress=True,
-     )
-
-     btn_apply_to_sample_dataset.click(
-         fn=generate_sample_dataset,
-         inputs=[system_prompt, num_turns],
-         outputs=[dataframe],
-         show_progress=True,
-     )
-
-     btn_push_to_hub.click(
-         fn=validate_argilla_user_workspace_dataset,
-         inputs=[repo_name],
-         outputs=[success_message],
-         show_progress=True,
-     ).then(
-         fn=validate_push_to_hub,
-         inputs=[org_name, repo_name],
-         outputs=[success_message],
-         show_progress=True,
-     ).success(
-         fn=hide_success_message,
-         outputs=[success_message],
-         show_progress=True,
-     ).success(
-         fn=hide_pipeline_code_visibility,
-         inputs=[],
-         outputs=[pipeline_code_ui],
-         show_progress=True,
-     ).success(
-         fn=push_dataset,
-         inputs=[
-             org_name,
-             repo_name,
-             system_prompt,
-             num_turns,
-             num_rows,
-             private,
-             temperature,
-             pipeline_code,
-         ],
-         outputs=[success_message],
-         show_progress=True,
-     ).success(
-         fn=show_success_message,
-         inputs=[org_name, repo_name],
-         outputs=[success_message],
-     ).success(
-         fn=generate_pipeline_code,
-         inputs=[system_prompt, num_turns, num_rows],
-         outputs=[pipeline_code],
-     ).success(
-         fn=show_pipeline_code_visibility,
-         inputs=[],
-         outputs=[pipeline_code_ui],
-     )
-     gr.on(
-         triggers=[clear_btn_part.click, clear_btn_full.click],
-         fn=lambda _: ("", "", 1, _get_dataframe()),
-         inputs=[dataframe],
-         outputs=[system_prompt, num_turns, dataframe],
-     )
-     app.load(fn=get_org_dropdown, outputs=[org_name])
-     app.load(fn=get_random_repo_name, outputs=[repo_name])
-     app.load(fn=swap_visibility, outputs=main_ui)
+     tab_prompt_input.select(
+         fn=lambda: "prompt-input",
+         inputs=[],
+         outputs=[input_type],
+     ).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
+         fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
+     )
+
+     tab_dataset_input.select(
+         fn=lambda: "dataset-input",
+         inputs=[],
+         outputs=[input_type],
+     ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
+         fn=show_document_column_visibility, inputs=[], outputs=[document_column]
+     )
+
+     tab_file_input.select(
+         fn=lambda: "file-input",
+         inputs=[],
+         outputs=[input_type],
+     ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
+         fn=show_document_column_visibility, inputs=[], outputs=[document_column]
+     )
+
+     search_in.submit(
+         fn=lambda df: pd.DataFrame(columns=df.columns),
+         inputs=[dataframe],
+         outputs=[dataframe],
+     )
+
+     load_prompt_btn.click(
+         fn=generate_system_prompt,
+         inputs=[dataset_description],
+         outputs=[system_prompt],
+     ).success(
+         fn=generate_sample_dataset,
+         inputs=[
+             search_in,
+             file_in,
+             input_type,
+             system_prompt,
+             document_column,
+             num_turns,
+             num_rows,
+         ],
+         outputs=dataframe,
+     )
+
+     gr.on(
+         triggers=[load_dataset_btn.click, load_file_btn.click],
+         fn=load_dataset_file,
+         inputs=[search_in, file_in, input_type],
+         outputs=[dataframe, document_column],
+     )
+
+     btn_apply_to_sample_dataset.click(
+         fn=generate_sample_dataset,
+         inputs=[
+             search_in,
+             file_in,
+             input_type,
+             system_prompt,
+             document_column,
+             num_turns,
+             num_rows,
+         ],
+         outputs=dataframe,
+     )
+
+     btn_push_to_hub.click(
+         fn=validate_argilla_user_workspace_dataset,
+         inputs=[repo_name],
+         outputs=[success_message],
+     ).then(
+         fn=validate_push_to_hub,
+         inputs=[org_name, repo_name],
+         outputs=[success_message],
+     ).success(
+         fn=hide_success_message,
+         outputs=[success_message],
+     ).success(
+         fn=hide_pipeline_code_visibility,
+         inputs=[],
+         outputs=[pipeline_code_ui],
+     ).success(
+         fn=push_dataset,
+         inputs=[
+             org_name,
+             repo_name,
+             private,
+             search_in,
+             file_in,
+             input_type,
+             system_prompt,
+             document_column,
+             num_turns,
+             num_rows,
+             temperature,
+             temperature_completion,
+             pipeline_code,
+         ],
+         outputs=[success_message],
+     ).success(
+         fn=show_success_message,
+         inputs=[org_name, repo_name],
+         outputs=[success_message],
+     ).success(
+         fn=generate_pipeline_code,
+         inputs=[
+             search_in,
+             input_type,
+             system_prompt,
+             document_column,
+             num_turns,
+             num_rows,
+         ],
+         outputs=[pipeline_code],
+     ).success(
+         fn=show_pipeline_code_visibility,
+         inputs=[],
+         outputs=[pipeline_code_ui],
+     )
+
+     clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
+     clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
+     clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
+     clear_btn_full.click(
+         fn=lambda df: ("", "", [], _get_dataframe()),
+         inputs=[dataframe],
+         outputs=[system_prompt, document_column, num_turns, dataframe],
+     )
+     app.load(fn=swap_visibility, outputs=main_ui)
+     app.load(fn=get_org_dropdown, outputs=[org_name])
+     app.load(fn=get_random_repo_name, outputs=[repo_name])
+     app.load(fn=show_temperature_completion, outputs=[temperature_completion])
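Every stage of the new seed-based generator drives its distilabel step with the same bounded-batch loop. A distilled sketch of that pattern, independent of distilabel (the `process` callable here is a stand-in for the real generators):

```python
from typing import Callable, Iterable

def run_batched(
    items: list[dict],
    process: Callable[[list[dict]], Iterable[dict]],
    num_rows: int,
    batch_size: int = 5,
) -> list[dict]:
    results: list[dict] = []
    n_processed = 0
    while n_processed < num_rows:
        # never request more rows than remain
        batch_size = min(batch_size, num_rows - n_processed)
        batch = items[n_processed : n_processed + batch_size]
        results.extend(process(batch))
        n_processed += batch_size
    return results

# Example with a trivial "generator" that echoes a question per document:
docs = [{"anchor": f"document {i}"} for i in range(12)]
out = run_batched(docs, lambda b: [{"prompt": d["anchor"]} for d in b], num_rows=12)
assert len(out) == 12
```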
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -15,7 +15,7 @@ from datasets import (
  from distilabel.distiset import Distiset
  from gradio.oauth import OAuthToken #
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
- from huggingface_hub import HfApi, repo_exists
+ from huggingface_hub import HfApi
 
  from synthetic_dataset_generator.apps.base import (
      combine_datasets,
@@ -130,9 +130,9 @@ def load_dataset_from_hub(
          choices=response_valid_columns,
          label="Response column",
          value=col_response,
-         interactive=False
-         if col_response == "No valid response columns found."
-         else True,
+         interactive=(
+             False if col_response == "No valid response columns found." else True
+         ),
      ),
      prompt_template,
      structured_output,
@@ -831,16 +831,13 @@ with gr.Blocks() as app:
      fn=validate_argilla_user_workspace_dataset,
      inputs=[repo_name],
      outputs=[success_message],
-     show_progress=True,
  ).then(
      fn=validate_push_to_hub,
      inputs=[org_name, repo_name],
      outputs=[success_message],
-     show_progress=True,
  ).success(
      fn=hide_success_message,
      outputs=[success_message],
-     show_progress=True,
  ).success(
      fn=hide_pipeline_code_visibility,
      inputs=[],
@@ -862,7 +859,6 @@ with gr.Blocks() as app:
          pipeline_code,
      ],
      outputs=[success_message],
-     show_progress=True,
  ).success(
      fn=show_success_message,
      inputs=[org_name, repo_name],
@@ -882,13 +878,14 @@ with gr.Blocks() as app:
      outputs=[pipeline_code_ui],
  )
 
- clear_btn_part.click(fn=lambda : "", inputs=[], outputs=[search_in])
+ clear_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
  clear_btn_full.click(
      fn=lambda df: ("", "", pd.DataFrame(columns=df.columns)),
      inputs=[dataframe],
      outputs=[
          instruction_instruction_response,
          response_instruction_response,
+         dataframe,
      ],
  )
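The parenthesized ternary kept in this diff is behaviorally identical to comparing directly against the sentinel string; a one-line equivalent, shown only as an illustration of the simplification:

```python
col_response = "No valid response columns found."
# Same result as: interactive=(False if col_response == sentinel else True)
interactive = col_response != "No valid response columns found."
assert interactive is False
```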
src/synthetic_dataset_generator/apps/rag.py CHANGED
@@ -1,37 +1,30 @@
1
  import os
2
  import random
3
  import uuid
4
- from tqdm import tqdm
5
  from typing import Union
6
 
7
  import argilla as rg
8
  import gradio as gr
9
  import nltk
10
  import pandas as pd
11
- from datasets import (
12
- Dataset,
13
- get_dataset_config_names,
14
- get_dataset_split_names,
15
- load_dataset,
16
- )
17
  from distilabel.distiset import Distiset
18
  from gradio.oauth import OAuthToken
19
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
20
  from huggingface_hub import HfApi
21
- from unstructured.chunking.title import chunk_by_title
22
- from unstructured.partition.auto import partition
23
 
24
  from synthetic_dataset_generator.apps.base import (
25
  combine_datasets,
26
- get_iframe,
27
  hide_success_message,
 
 
28
  push_pipeline_code_to_hub,
29
  show_success_message,
30
  test_max_num_rows,
31
  validate_argilla_user_workspace_dataset,
32
  validate_push_to_hub,
33
  )
34
- from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
35
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
36
  from synthetic_dataset_generator.pipelines.embeddings import (
37
  get_embeddings,
@@ -39,11 +32,11 @@ from synthetic_dataset_generator.pipelines.embeddings import (
39
  )
40
  from synthetic_dataset_generator.pipelines.rag import (
41
  DEFAULT_DATASET_DESCRIPTIONS,
 
42
  get_chunks_generator,
43
  get_prompt_generator,
44
- generate_pipeline_code,
45
- get_sentence_pair_generator,
46
  get_response_generator,
 
47
  )
48
  from synthetic_dataset_generator.utils import (
49
  column_to_list,
@@ -58,80 +51,8 @@ nltk.data.path.append("./nltk_data")
58
  nltk.download("punkt_tab", download_dir="./nltk_data")
59
  nltk.download("averaged_perceptron_tagger_eng", download_dir="./nltk_data")
60
 
61
- def _get_valid_columns(dataframe: pd.DataFrame):
62
- doc_valid_columns = []
63
-
64
- for col in dataframe.columns:
65
- sample_val = dataframe[col].iloc[0]
66
- if isinstance(sample_val, str):
67
- doc_valid_columns.append(col)
68
-
69
- return doc_valid_columns
70
-
71
-
72
- def _load_dataset_from_hub(
73
- repo_id: str,
74
- num_rows: int = 10,
75
- token: Union[OAuthToken, None] = None,
76
- progress=gr.Progress(track_tqdm=True),
77
- ):
78
- if not repo_id:
79
- raise gr.Error("Hub repo id is required")
80
- subsets = get_dataset_config_names(repo_id, token=token)
81
- splits = get_dataset_split_names(repo_id, subsets[0], token=token)
82
- ds = load_dataset(repo_id, subsets[0], split=splits[0], token=token, streaming=True)
83
- rows = []
84
- for idx, row in enumerate(tqdm(ds, desc="Loading the dataset", total=num_rows)):
85
- rows.append(row)
86
- if idx == num_rows:
87
- break
88
- ds = Dataset.from_list(rows)
89
- dataframe = ds.to_pandas()
90
- doc_valid_columns = _get_valid_columns(dataframe)
91
- col_doc = doc_valid_columns[0] if doc_valid_columns else ""
92
- return (
93
- dataframe,
94
- gr.Dropdown(
95
- choices=doc_valid_columns,
96
- label="Documents column",
97
- value=col_doc,
98
- interactive=(False if col_doc == "" else True),
99
- multiselect=False,
100
- ),
101
- )
102
-
103
-
104
- def _preprocess_input_data(file_paths, num_rows, progress=gr.Progress(track_tqdm=True)):
105
- data = {}
106
- total_chunks = 0
107
 
108
- for file_path in tqdm(file_paths, desc="Processing files", total=len(file_paths)):
109
- partitioned_file = partition(filename=file_path)
110
- chunks = [str(chunk) for chunk in chunk_by_title(partitioned_file)]
111
- data[file_path] = chunks
112
- total_chunks += len(chunks)
113
- if total_chunks >= num_rows:
114
- break
115
-
116
- dataframe = pd.DataFrame.from_records(
117
- [(k, v) for k, values in data.items() for v in values],
118
- columns=["filename", "chunks"],
119
- )
120
- col_doc = "chunks"
121
-
122
- return (
123
- dataframe,
124
- gr.Dropdown(
125
- choices=["chunks"],
126
- label="Documents column",
127
- value=col_doc,
128
- interactive=(False if col_doc == "" else True),
129
- multiselect=False,
130
- ),
131
- )
132
-
133
-
134
- def generate_system_prompt(dataset_description, progress=gr.Progress()):
135
  progress(0.1, desc="Initializing")
136
  generate_description = get_prompt_generator()
137
  progress(0.5, desc="Generating")
@@ -158,9 +79,48 @@ def load_dataset_file(
158
  ):
159
  progress(0.1, desc="Loading the source data")
160
  if input_type == "dataset-input":
161
- return _load_dataset_from_hub(repo_id, num_rows, token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  else:
163
- return _preprocess_input_data(file_paths, num_rows)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
 
166
  def generate_dataset(
@@ -172,6 +132,7 @@ def generate_dataset(
172
  reranking: bool = False,
173
  num_rows: int = 10,
174
  temperature: float = 0.7,
 
175
  is_sample: bool = False,
176
  progress=gr.Progress(),
177
  ):
@@ -195,7 +156,7 @@ def generate_dataset(
195
  is_sample=is_sample,
196
  )
197
  response_generator = get_response_generator(
198
- temperature=temperature, is_sample=is_sample
199
  )
200
  if reranking:
201
  reranking_generator = get_sentence_pair_generator(
@@ -320,44 +281,6 @@ def generate_dataset(
320
  return dataframe
321
 
322
 
323
- def generate_sample_dataset(
324
- repo_id: str,
325
- file_paths: list[str],
326
- input_type: str,
327
- system_prompt: str,
328
- document_column: str,
329
- retrieval_reranking: list[str],
330
- num_rows: str,
331
- oauth_token: Union[OAuthToken, None],
332
- progress=gr.Progress(),
333
- ):
334
- retrieval = "Retrieval" in retrieval_reranking
335
- reranking = "Reranking" in retrieval_reranking
336
-
337
- if input_type == "prompt-input":
338
- dataframe = pd.DataFrame(columns=["context", "question", "response"])
339
- else:
340
- dataframe, _ = load_dataset_file(
341
- repo_id=repo_id,
342
- file_paths=file_paths,
343
- input_type=input_type,
344
- num_rows=num_rows,
345
- token=oauth_token,
346
- )
347
- progress(0.5, desc="Generating dataset")
348
- dataframe = generate_dataset(
349
- input_type=input_type,
350
- dataframe=dataframe,
351
- system_prompt=system_prompt,
352
- document_column=document_column,
353
- retrieval=retrieval,
354
- reranking=reranking,
355
- num_rows=10,
356
- is_sample=True,
357
- )
358
- return dataframe
359
-
360
-
361
  def push_dataset_to_hub(
362
  dataframe: pd.DataFrame,
363
  org_name: str,
@@ -398,6 +321,7 @@ def push_dataset(
398
  retrieval_reranking: list[str],
399
  num_rows: int,
400
  temperature: float,
 
401
  pipeline_code: str,
402
  oauth_token: Union[gr.OAuthToken, None] = None,
403
  progress=gr.Progress(),
@@ -425,15 +349,14 @@ def push_dataset(
425
  reranking=reranking,
426
  num_rows=num_rows,
427
  temperature=temperature,
 
428
  is_sample=True,
429
  )
430
  push_dataset_to_hub(
431
  dataframe, org_name, repo_name, oauth_token, private, pipeline_code
432
  )
433
  dataframe = dataframe[
434
- dataframe.applymap(
435
- lambda x: str(x).strip() if pd.notna(x) else x
436
- ).apply(
437
  lambda row: row.notna().all() and (row != "").all(), axis=1
438
  )
439
  ]
@@ -593,6 +516,11 @@ def hide_pipeline_code_visibility():
593
  return {pipeline_code_ui: gr.Accordion(visible=False)}
594
 
595
 
 
 
 
 
 
596
  ######################
597
  # Gradio UI
598
  ######################
@@ -674,40 +602,37 @@ with gr.Blocks() as app:
674
 
675
  gr.HTML(value="<hr>")
676
  gr.Markdown(value="## 2. Configure your task")
677
- with gr.Row(equal_height=True):
678
- with gr.Row(equal_height=False):
679
- with gr.Column(scale=2):
680
- system_prompt = gr.Textbox(
681
- label="System prompt",
682
- placeholder="You are a helpful assistant.",
683
- visible=False,
684
- )
685
- document_column = gr.Dropdown(
686
- label="Document Column",
687
- info="Select the document column to generate the RAG dataset",
688
- choices=["Load your data first in step 1."],
689
- value="Load your data first in step 1.",
690
- interactive=False,
691
- multiselect=False,
692
- allow_custom_value=False,
693
- )
694
- retrieval_reranking = gr.CheckboxGroup(
695
- choices=[("Retrieval", "Retrieval"), ("Reranking", "Reranking")],
696
- type="value",
697
- label="Data for RAG",
698
- info="Indicate the additional data you want to generate for RAG.",
699
- )
700
- with gr.Row():
701
- clear_btn_full = gr.Button("Clear", variant="secondary")
702
- btn_apply_to_sample_dataset = gr.Button(
703
- "Save", variant="primary"
704
- )
705
- with gr.Column(scale=3):
706
- dataframe = gr.Dataframe(
707
- headers=["context", "question", "response"],
708
- wrap=True,
709
- interactive=False,
710
- )
711
 
712
  gr.HTML(value="<hr>")
713
  gr.Markdown(value="## 3. Generate your dataset")
@@ -729,11 +654,20 @@ with gr.Blocks() as app:
729
  temperature = gr.Slider(
730
  label="Temperature",
731
  minimum=0.1,
732
- maximum=1,
733
  value=0.7,
734
  step=0.1,
735
  interactive=True,
736
  )
 
 
 
 
 
 
 
 
 
737
  private = gr.Checkbox(
738
  label="Private dataset",
739
  value=False,
@@ -753,7 +687,6 @@ with gr.Blocks() as app:
753
  ) as pipeline_code_ui:
754
  code = generate_pipeline_code(
755
  repo_id=search_in.value,
756
- file_paths=file_in.value,
757
  input_type=input_type.value,
758
  system_prompt=system_prompt.value,
759
  document_column=document_column.value,
@@ -790,35 +723,23 @@ with gr.Blocks() as app:
790
  fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
791
  )
792
 
793
- search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out).then(
794
  fn=lambda df: pd.DataFrame(columns=df.columns),
795
  inputs=[dataframe],
796
  outputs=[dataframe],
797
  )
798
 
799
- load_dataset_btn.click(
 
800
  fn=load_dataset_file,
801
  inputs=[search_in, file_in, input_type],
802
- outputs=[
803
- dataframe,
804
- document_column,
805
- ],
806
- )
807
-
808
- load_file_btn.click(
809
- fn=load_dataset_file,
810
- inputs=[search_in, file_in, input_type],
811
- outputs=[
812
- dataframe,
813
- document_column,
814
- ],
815
  )
816
 
817
  load_prompt_btn.click(
818
  fn=generate_system_prompt,
819
  inputs=[dataset_description],
820
  outputs=[system_prompt],
821
- show_progress=True,
822
  ).success(
823
  fn=generate_sample_dataset,
824
  inputs=[
@@ -851,16 +772,13 @@ with gr.Blocks() as app:
851
  fn=validate_argilla_user_workspace_dataset,
852
  inputs=[repo_name],
853
  outputs=[success_message],
854
- show_progress=True,
855
  ).then(
856
  fn=validate_push_to_hub,
857
  inputs=[org_name, repo_name],
858
  outputs=[success_message],
859
- show_progress=True,
860
  ).success(
861
  fn=hide_success_message,
862
  outputs=[success_message],
863
- show_progress=True,
864
  ).success(
865
  fn=hide_pipeline_code_visibility,
866
  inputs=[],
@@ -879,10 +797,10 @@ with gr.Blocks() as app:
879
  retrieval_reranking,
880
  num_rows,
881
  temperature,
 
882
  pipeline_code,
883
  ],
884
  outputs=[success_message],
885
- show_progress=True,
886
  ).success(
887
  fn=show_success_message,
888
  inputs=[org_name, repo_name],
@@ -891,7 +809,6 @@ with gr.Blocks() as app:
891
  fn=generate_pipeline_code,
892
  inputs=[
893
  search_in,
894
- file_in,
895
  input_type,
896
  system_prompt,
897
  document_column,
@@ -905,11 +822,9 @@ with gr.Blocks() as app:
905
  outputs=[pipeline_code_ui],
906
  )
907
 
908
- clear_dataset_btn_part.click(fn=lambda : "", inputs=[], outputs=[search_in])
909
  clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
910
- clear_prompt_btn_part.click(
911
- fn=lambda : "", inputs=[], outputs=[dataset_description]
912
- )
913
  clear_btn_full.click(
914
  fn=lambda df: ("", [], pd.DataFrame(columns=df.columns)),
915
  inputs=[dataframe],
@@ -919,3 +834,4 @@ with gr.Blocks() as app:
919
  app.load(fn=swap_visibility, outputs=main_ui)
920
  app.load(fn=get_org_dropdown, outputs=[org_name])
921
  app.load(fn=get_random_repo_name, outputs=[repo_name])
 
 
1
  import os
2
  import random
3
  import uuid
 
4
  from typing import Union
5
 
6
  import argilla as rg
7
  import gradio as gr
8
  import nltk
9
  import pandas as pd
10
+ from datasets import Dataset
 
 
 
 
 
11
  from distilabel.distiset import Distiset
12
  from gradio.oauth import OAuthToken
13
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
14
  from huggingface_hub import HfApi
 
 
15
 
16
  from synthetic_dataset_generator.apps.base import (
17
  combine_datasets,
 
18
  hide_success_message,
19
+ load_dataset_from_hub,
20
+ preprocess_input_data,
21
  push_pipeline_code_to_hub,
22
  show_success_message,
23
  test_max_num_rows,
24
  validate_argilla_user_workspace_dataset,
25
  validate_push_to_hub,
26
  )
27
+ from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE, MODEL, MODEL_COMPLETION
28
  from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
29
  from synthetic_dataset_generator.pipelines.embeddings import (
30
  get_embeddings,
 
32
  )
33
  from synthetic_dataset_generator.pipelines.rag import (
34
  DEFAULT_DATASET_DESCRIPTIONS,
35
+ generate_pipeline_code,
36
  get_chunks_generator,
37
  get_prompt_generator,
 
 
38
  get_response_generator,
39
+ get_sentence_pair_generator,
40
  )
41
  from synthetic_dataset_generator.utils import (
42
  column_to_list,
 
51
  nltk.download("punkt_tab", download_dir="./nltk_data")
52
  nltk.download("averaged_perceptron_tagger_eng", download_dir="./nltk_data")
53
54
 
55
+ def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
56
  progress(0.1, desc="Initializing")
57
  generate_description = get_prompt_generator()
58
  progress(0.5, desc="Generating")
 
79
  ):
80
  progress(0.1, desc="Loading the source data")
81
  if input_type == "dataset-input":
82
+ return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
83
+ else:
84
+ return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
85
+
86
+
87
+ def generate_sample_dataset(
88
+ repo_id: str,
89
+ file_paths: list[str],
90
+ input_type: str,
91
+ system_prompt: str,
92
+ document_column: str,
93
+ retrieval_reranking: list[str],
94
+ num_rows: str,
95
+ oauth_token: Union[OAuthToken, None],
96
+ progress=gr.Progress(),
97
+ ):
98
+ retrieval = "Retrieval" in retrieval_reranking
99
+ reranking = "Reranking" in retrieval_reranking
100
+
101
+ if input_type == "prompt-input":
102
+ dataframe = pd.DataFrame(columns=["context", "question", "response"])
103
  else:
104
+ dataframe, _ = load_dataset_file(
105
+ repo_id=repo_id,
106
+ file_paths=file_paths,
107
+ input_type=input_type,
108
+ num_rows=num_rows,
109
+ token=oauth_token,
110
+ )
111
+ progress(0.5, desc="Generating dataset")
112
+ dataframe = generate_dataset(
113
+ input_type=input_type,
114
+ dataframe=dataframe,
115
+ system_prompt=system_prompt,
116
+ document_column=document_column,
117
+ retrieval=retrieval,
118
+ reranking=reranking,
119
+ num_rows=10,
120
+ is_sample=True,
121
+ )
122
+ progress(1.0, desc="Sample dataset generated")
123
+ return dataframe
124
 
125
 
126
  def generate_dataset(
 
132
  reranking: bool = False,
133
  num_rows: int = 10,
134
  temperature: float = 0.7,
135
+ temperature_completion: Union[float, None] = None,
136
  is_sample: bool = False,
137
  progress=gr.Progress(),
138
  ):
 
156
  is_sample=is_sample,
157
  )
158
  response_generator = get_response_generator(
159
+ temperature=temperature_completion or temperature, is_sample=is_sample
160
  )
161
  if reranking:
162
  reranking_generator = get_sentence_pair_generator(
 
281
  return dataframe
282
 
283
284
  def push_dataset_to_hub(
285
  dataframe: pd.DataFrame,
286
  org_name: str,
 
321
  retrieval_reranking: list[str],
322
  num_rows: int,
323
  temperature: float,
324
+ temperature_completion: float,
325
  pipeline_code: str,
326
  oauth_token: Union[gr.OAuthToken, None] = None,
327
  progress=gr.Progress(),
 
349
  reranking=reranking,
350
  num_rows=num_rows,
351
  temperature=temperature,
352
+ temperature_completion=temperature_completion,
353
  is_sample=True,
354
  )
355
  push_dataset_to_hub(
356
  dataframe, org_name, repo_name, oauth_token, private, pipeline_code
357
  )
358
  dataframe = dataframe[
359
+ dataframe.applymap(lambda x: str(x).strip() if pd.notna(x) else x).apply(
 
 
360
  lambda row: row.notna().all() and (row != "").all(), axis=1
361
  )
362
  ]
 
516
  return {pipeline_code_ui: gr.Accordion(visible=False)}
517
 
518
 
519
+ def show_temperature_completion():
520
+ if MODEL != MODEL_COMPLETION:
521
+ return {temperature_completion: gr.Slider(value=0.9, visible=True)}
522
+ return {temperature_completion: gr.Slider(visible=False)}
523
+
524
  ######################
525
  # Gradio UI
526
  ######################
 
602
 
603
  gr.HTML(value="<hr>")
604
  gr.Markdown(value="## 2. Configure your task")
605
+ with gr.Row(equal_height=False):
606
+ with gr.Column(scale=2):
607
+ system_prompt = gr.Textbox(
608
+ label="System prompt",
609
+ placeholder="You are a helpful assistant.",
610
+ visible=False,
611
+ )
612
+ document_column = gr.Dropdown(
613
+ label="Document Column",
614
+ info="Select the document column to generate the RAG dataset",
615
+ choices=["Load your data first in step 1."],
616
+ value="Load your data first in step 1.",
617
+ interactive=False,
618
+ multiselect=False,
619
+ allow_custom_value=False,
620
+ )
621
+ retrieval_reranking = gr.CheckboxGroup(
622
+ choices=[("Retrieval", "Retrieval"), ("Reranking", "Reranking")],
623
+ type="value",
624
+ label="Data for RAG",
625
+ info="Indicate the additional data you want to generate for RAG.",
626
+ )
627
+ with gr.Row():
628
+ clear_btn_full = gr.Button("Clear", variant="secondary")
629
+ btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
630
+ with gr.Column(scale=3):
631
+ dataframe = gr.Dataframe(
632
+ headers=["context", "question", "response"],
633
+ wrap=True,
634
+ interactive=False,
635
+ )
636
 
637
  gr.HTML(value="<hr>")
638
  gr.Markdown(value="## 3. Generate your dataset")
 
654
  temperature = gr.Slider(
655
  label="Temperature",
656
  minimum=0.1,
657
+ maximum=1.5,
658
  value=0.7,
659
  step=0.1,
660
  interactive=True,
661
  )
662
+ temperature_completion = gr.Slider(
663
+ label="Temperature for completion",
664
+ minimum=0.1,
665
+ maximum=1.5,
666
+ value=None,
667
+ step=0.1,
668
+ interactive=True,
669
+ visible=False,
670
+ )
671
  private = gr.Checkbox(
672
  label="Private dataset",
673
  value=False,
 
687
  ) as pipeline_code_ui:
688
  code = generate_pipeline_code(
689
  repo_id=search_in.value,
 
690
  input_type=input_type.value,
691
  system_prompt=system_prompt.value,
692
  document_column=document_column.value,
 
723
  fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
724
  )
725
 
726
+ search_in.submit(
727
  fn=lambda df: pd.DataFrame(columns=df.columns),
728
  inputs=[dataframe],
729
  outputs=[dataframe],
730
  )
731
 
732
+ gr.on(
733
+ triggers=[load_dataset_btn.click, load_file_btn.click],
734
  fn=load_dataset_file,
735
  inputs=[search_in, file_in, input_type],
736
+ outputs=[dataframe, document_column],
737
  )
738
 
739
  load_prompt_btn.click(
740
  fn=generate_system_prompt,
741
  inputs=[dataset_description],
742
  outputs=[system_prompt],
 
743
  ).success(
744
  fn=generate_sample_dataset,
745
  inputs=[
 
772
  fn=validate_argilla_user_workspace_dataset,
773
  inputs=[repo_name],
774
  outputs=[success_message],
 
775
  ).then(
776
  fn=validate_push_to_hub,
777
  inputs=[org_name, repo_name],
778
  outputs=[success_message],
 
779
  ).success(
780
  fn=hide_success_message,
781
  outputs=[success_message],
 
782
  ).success(
783
  fn=hide_pipeline_code_visibility,
784
  inputs=[],
 
797
  retrieval_reranking,
798
  num_rows,
799
  temperature,
800
+ temperature_completion,
801
  pipeline_code,
802
  ],
803
  outputs=[success_message],
 
804
  ).success(
805
  fn=show_success_message,
806
  inputs=[org_name, repo_name],
 
809
  fn=generate_pipeline_code,
810
  inputs=[
811
  search_in,
 
812
  input_type,
813
  system_prompt,
814
  document_column,
 
822
  outputs=[pipeline_code_ui],
823
  )
824
 
825
+ clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
826
  clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
827
+ clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
 
 
828
  clear_btn_full.click(
829
  fn=lambda df: ("", [], pd.DataFrame(columns=df.columns)),
830
  inputs=[dataframe],
 
834
  app.load(fn=swap_visibility, outputs=main_ui)
835
  app.load(fn=get_org_dropdown, outputs=[org_name])
836
  app.load(fn=get_random_repo_name, outputs=[repo_name])
837
+ app.load(fn=show_temperature_completion, outputs=[temperature_completion])
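The completion-temperature slider wired up above stays hidden unless a distinct completion model is configured. A minimal launch sketch, assuming the package's `launch()` entry point and illustrative model ids (any instruction/completion pair works):

import os

from synthetic_dataset_generator import launch

os.environ["HF_TOKEN"] = "hf_..."  # token used to push the generated dataset
os.environ["MODEL"] = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # generates instructions
os.environ["MODEL_COMPLETION"] = "meta-llama/Meta-Llama-3.1-70B-Instruct"  # generates completions

# MODEL != MODEL_COMPLETION, so show_temperature_completion() reveals the slider on load
launch()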
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -49,7 +49,7 @@ def _get_dataframe():
49
  )
50
 
51
 
52
- def generate_system_prompt(dataset_description, progress=gr.Progress()):
53
  progress(0.0, desc="Starting")
54
  progress(0.3, desc="Initializing")
55
  generate_description = get_prompt_generator()
@@ -71,7 +71,12 @@ def generate_system_prompt(dataset_description, progress=gr.Progress()):
71
 
72
 
73
  def generate_sample_dataset(
74
- system_prompt, difficulty, clarity, labels, multi_label, progress=gr.Progress()
 
 
 
 
 
75
  ):
76
  dataframe = generate_dataset(
77
  system_prompt=system_prompt,
@@ -294,14 +299,14 @@ def push_dataset(
294
  temperature=temperature,
295
  )
296
  push_dataset_to_hub(
297
- dataframe,
298
- org_name,
299
- repo_name,
300
- multi_label,
301
- labels,
302
- oauth_token,
303
- private,
304
- pipeline_code,
305
  )
306
 
307
  dataframe = dataframe[
@@ -453,62 +458,59 @@ with gr.Blocks() as app:
453
 
454
  gr.HTML("<hr>")
455
  gr.Markdown("## 2. Configure your dataset")
456
- with gr.Row(equal_height=True):
457
- with gr.Row(equal_height=False):
458
- with gr.Column(scale=2):
459
- system_prompt = gr.Textbox(
460
- label="System prompt",
461
- placeholder="You are a helpful assistant.",
462
- visible=True,
463
- )
464
- labels = gr.Dropdown(
465
- choices=[],
466
- allow_custom_value=True,
467
- interactive=True,
468
- label="Labels",
469
- multiselect=True,
470
- info="Add the labels to classify the text.",
471
- )
472
- multi_label = gr.Checkbox(
473
- label="Multi-label",
474
- value=False,
475
- interactive=True,
476
- info="If checked, the text will be classified into multiple labels.",
477
- )
478
- clarity = gr.Dropdown(
479
- choices=[
480
- ("Clear", "clear"),
481
- (
482
- "Understandable",
483
- "understandable with some effort",
484
- ),
485
- ("Ambiguous", "ambiguous"),
486
- ("Mixed", "mixed"),
487
- ],
488
- value="mixed",
489
- label="Clarity",
490
- info="Set how easily the correct label or labels can be identified.",
491
- interactive=True,
492
- )
493
- difficulty = gr.Dropdown(
494
- choices=[
495
- ("High School", "high school"),
496
- ("College", "college"),
497
- ("PhD", "PhD"),
498
- ("Mixed", "mixed"),
499
- ],
500
- value="high school",
501
- label="Difficulty",
502
- info="Select the comprehension level for the text. Ensure it matches the task context.",
503
- interactive=True,
504
- )
505
- with gr.Row():
506
- clear_btn_full = gr.Button("Clear", variant="secondary")
507
- btn_apply_to_sample_dataset = gr.Button(
508
- "Save", variant="primary"
509
- )
510
- with gr.Column(scale=3):
511
- dataframe = _get_dataframe()
512
 
513
  gr.HTML("<hr>")
514
  gr.Markdown("## 3. Generate your dataset")
@@ -530,7 +532,7 @@ with gr.Blocks() as app:
530
  temperature = gr.Slider(
531
  label="Temperature",
532
  minimum=0.1,
533
- maximum=1,
534
  value=0.8,
535
  step=0.1,
536
  interactive=True,
@@ -570,45 +572,37 @@ with gr.Blocks() as app:
570
  fn=generate_system_prompt,
571
  inputs=[dataset_description],
572
  outputs=[system_prompt, labels],
573
- show_progress=True,
574
  ).then(
575
  fn=generate_sample_dataset,
576
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
577
  outputs=[dataframe],
578
- show_progress=True,
579
  )
580
 
581
  btn_apply_to_sample_dataset.click(
582
  fn=validate_input_labels,
583
  inputs=[labels],
584
  outputs=[labels],
585
- show_progress=True,
586
  ).success(
587
  fn=generate_sample_dataset,
588
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
589
  outputs=[dataframe],
590
- show_progress=True,
591
  )
592
 
593
  btn_push_to_hub.click(
594
  fn=validate_argilla_user_workspace_dataset,
595
  inputs=[repo_name],
596
  outputs=[success_message],
597
- show_progress=True,
598
  ).then(
599
  fn=validate_push_to_hub,
600
  inputs=[org_name, repo_name],
601
  outputs=[success_message],
602
- show_progress=True,
603
  ).success(
604
  fn=validate_input_labels,
605
  inputs=[labels],
606
  outputs=[labels],
607
- show_progress=True,
608
  ).success(
609
  fn=hide_success_message,
610
  outputs=[success_message],
611
- show_progress=True,
612
  ).success(
613
  fn=hide_pipeline_code_visibility,
614
  inputs=[],
@@ -629,7 +623,6 @@ with gr.Blocks() as app:
629
  pipeline_code,
630
  ],
631
  outputs=[success_message],
632
- show_progress=True,
633
  ).success(
634
  fn=show_success_message,
635
  inputs=[org_name, repo_name],
@@ -657,6 +650,7 @@ with gr.Blocks() as app:
657
  "",
658
  "",
659
  [],
 
660
  _get_dataframe(),
661
  ),
662
  inputs=[dataframe],
 
49
  )
50
 
51
 
52
+ def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
53
  progress(0.0, desc="Starting")
54
  progress(0.3, desc="Initializing")
55
  generate_description = get_prompt_generator()
 
71
 
72
 
73
  def generate_sample_dataset(
74
+ system_prompt: str,
75
+ difficulty: str,
76
+ clarity: str,
77
+ labels: List[str],
78
+ multi_label: bool,
79
+ progress=gr.Progress(),
80
  ):
81
  dataframe = generate_dataset(
82
  system_prompt=system_prompt,
 
299
  temperature=temperature,
300
  )
301
  push_dataset_to_hub(
302
+ dataframe=dataframe,
303
+ org_name=org_name,
304
+ repo_name=repo_name,
305
+ multi_label=multi_label,
306
+ labels=labels,
307
+ oauth_token=oauth_token,
308
+ private=private,
309
+ pipeline_code=pipeline_code,
310
  )
311
 
312
  dataframe = dataframe[
 
458
 
459
  gr.HTML("<hr>")
460
  gr.Markdown("## 2. Configure your dataset")
461
+ with gr.Row(equal_height=False):
462
+ with gr.Column(scale=2):
463
+ system_prompt = gr.Textbox(
464
+ label="System prompt",
465
+ placeholder="You are a helpful assistant.",
466
+ visible=True,
467
+ )
468
+ labels = gr.Dropdown(
469
+ choices=[],
470
+ allow_custom_value=True,
471
+ interactive=True,
472
+ label="Labels",
473
+ multiselect=True,
474
+ info="Add the labels to classify the text.",
475
+ )
476
+ multi_label = gr.Checkbox(
477
+ label="Multi-label",
478
+ value=False,
479
+ interactive=True,
480
+ info="If checked, the text will be classified into multiple labels.",
481
+ )
482
+ clarity = gr.Dropdown(
483
+ choices=[
484
+ ("Clear", "clear"),
485
+ (
486
+ "Understandable",
487
+ "understandable with some effort",
488
+ ),
489
+ ("Ambiguous", "ambiguous"),
490
+ ("Mixed", "mixed"),
491
+ ],
492
+ value="mixed",
493
+ label="Clarity",
494
+ info="Set how easily the correct label or labels can be identified.",
495
+ interactive=True,
496
+ )
497
+ difficulty = gr.Dropdown(
498
+ choices=[
499
+ ("High School", "high school"),
500
+ ("College", "college"),
501
+ ("PhD", "PhD"),
502
+ ("Mixed", "mixed"),
503
+ ],
504
+ value="high school",
505
+ label="Difficulty",
506
+ info="Select the comprehension level for the text. Ensure it matches the task context.",
507
+ interactive=True,
508
+ )
509
+ with gr.Row():
510
+ clear_btn_full = gr.Button("Clear", variant="secondary")
511
+ btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
512
+ with gr.Column(scale=3):
513
+ dataframe = _get_dataframe()
 
 
 
514
 
515
  gr.HTML("<hr>")
516
  gr.Markdown("## 3. Generate your dataset")
 
532
  temperature = gr.Slider(
533
  label="Temperature",
534
  minimum=0.1,
535
+ maximum=1.5,
536
  value=0.8,
537
  step=0.1,
538
  interactive=True,
 
572
  fn=generate_system_prompt,
573
  inputs=[dataset_description],
574
  outputs=[system_prompt, labels],
 
575
  ).then(
576
  fn=generate_sample_dataset,
577
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
578
  outputs=[dataframe],
 
579
  )
580
 
581
  btn_apply_to_sample_dataset.click(
582
  fn=validate_input_labels,
583
  inputs=[labels],
584
  outputs=[labels],
 
585
  ).success(
586
  fn=generate_sample_dataset,
587
  inputs=[system_prompt, difficulty, clarity, labels, multi_label],
588
  outputs=[dataframe],
 
589
  )
590
 
591
  btn_push_to_hub.click(
592
  fn=validate_argilla_user_workspace_dataset,
593
  inputs=[repo_name],
594
  outputs=[success_message],
 
595
  ).then(
596
  fn=validate_push_to_hub,
597
  inputs=[org_name, repo_name],
598
  outputs=[success_message],
 
599
  ).success(
600
  fn=validate_input_labels,
601
  inputs=[labels],
602
  outputs=[labels],
 
603
  ).success(
604
  fn=hide_success_message,
605
  outputs=[success_message],
 
606
  ).success(
607
  fn=hide_pipeline_code_visibility,
608
  inputs=[],
 
623
  pipeline_code,
624
  ],
625
  outputs=[success_message],
 
626
  ).success(
627
  fn=show_success_message,
628
  inputs=[org_name, repo_name],
 
650
  "",
651
  "",
652
  [],
653
+ "",
654
  _get_dataframe(),
655
  ),
656
  inputs=[dataframe],
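The switch to keyword arguments in `push_dataset_to_hub` above is a robustness fix: with eight positional parameters, several of the same type, a transposed pair passes silently. A toy illustration of the failure mode (this function is a stand-in, not the real helper):

def push(org_name: str, repo_name: str) -> str:
    # positional calls silently accept swapped strings; keyword calls are order-proof
    return f"{org_name}/{repo_name}"

print(push("my-dataset", "argilla"))                     # swapped, still runs: "my-dataset/argilla"
print(push(org_name="argilla", repo_name="my-dataset"))  # explicit: "argilla/my-dataset"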
src/synthetic_dataset_generator/constants.py CHANGED
@@ -3,10 +3,6 @@ import warnings
3
 
4
  import argilla as rg
5
 
6
- # Tasks
7
- TEXTCAT_TASK = "text_classification"
8
- SFT_TASK = "supervised_fine_tuning"
9
-
10
  # Inference
11
  MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
12
  MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
@@ -20,28 +16,56 @@ OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
20
  HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
21
  VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
22
 
23
- # check if model is set correctly
24
- if HUGGINGFACE_BASE_URL and MODEL:
25
- raise ValueError(
26
- "`HUGGINGFACE_BASE_URL` and `MODEL` cannot be set at the same time. Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
27
- )
28
- if not MODEL:
29
- if OPENAI_BASE_URL or OLLAMA_BASE_URL or VLLM_BASE_URL:
30
- raise ValueError("`MODEL` is not set. Please provide a model id for inference.")
31
-
32
- # Check if multiple base URLs are provided
33
- base_urls = [
34
- url
35
- for url in [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
36
- if url
 
 
37
  ]
38
- if len(base_urls) > 1:
39
- raise ValueError(
40
- f"Multiple base URLs provided: {', '.join(base_urls)}. Only one base URL can be set at a time."
41
- )
42
- BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
43
 
44
45
  # API Keys
46
  HF_TOKEN = os.getenv("HF_TOKEN")
47
  if not HF_TOKEN:
 
3
 
4
  import argilla as rg
5
6
  # Inference
7
  MAX_NUM_TOKENS = int(os.getenv("MAX_NUM_TOKENS", 2048))
8
  MAX_NUM_ROWS = int(os.getenv("MAX_NUM_ROWS", 1000))
 
16
  HUGGINGFACE_BASE_URL = os.getenv("HUGGINGFACE_BASE_URL")
17
  VLLM_BASE_URL = os.getenv("VLLM_BASE_URL")
18
 
19
+ # Just used in case of selecting a different model for completions
20
+ MODEL_COMPLETION = os.getenv("MODEL_COMPLETION", MODEL)
21
+ TOKENIZER_ID_COMPLETION = os.getenv("TOKENIZER_ID_COMPLETION", TOKENIZER_ID)
22
+ OPENAI_BASE_URL_COMPLETION = os.getenv("OPENAI_BASE_URL_COMPLETION", OPENAI_BASE_URL)
23
+ OLLAMA_BASE_URL_COMPLETION = os.getenv("OLLAMA_BASE_URL_COMPLETION", OLLAMA_BASE_URL)
24
+ HUGGINGFACE_BASE_URL_COMPLETION = os.getenv(
25
+ "HUGGINGFACE_BASE_URL_COMPLETION", HUGGINGFACE_BASE_URL
26
+ )
27
+ VLLM_BASE_URL_COMPLETION = os.getenv("VLLM_BASE_URL_COMPLETION", VLLM_BASE_URL)
28
+
29
+ base_urls = [OPENAI_BASE_URL, OLLAMA_BASE_URL, HUGGINGFACE_BASE_URL, VLLM_BASE_URL]
30
+ base_urls_completion = [
31
+ OPENAI_BASE_URL_COMPLETION,
32
+ OLLAMA_BASE_URL_COMPLETION,
33
+ HUGGINGFACE_BASE_URL_COMPLETION,
34
+ VLLM_BASE_URL_COMPLETION,
35
  ]
 
 
 
 
 
36
 
37
 
38
+ # Validate the configuration of the model and base URLs.
39
+ def validate_configuration(base_urls, model, env_context=""):
40
+ huggingface_url = base_urls[2]
41
+ if huggingface_url and model:
42
+ raise ValueError(
43
+ f"`HUGGINGFACE_BASE_URL{env_context}` and `MODEL{env_context}` cannot be set at the same time. "
44
+ "Use a model id for serverless inference and a base URL dedicated to Hugging Face Inference Endpoints."
45
+ )
46
+
47
+ if not model and any(base_urls):
48
+ raise ValueError(
49
+ f"`MODEL{env_context}` is not set. Please provide a model id for inference."
50
+ )
51
+
52
+ active_urls = [url for url in base_urls if url]
53
+ if len(active_urls) > 1:
54
+ raise ValueError(
55
+ f"Multiple base URLs are provided: {', '.join(active_urls)}. "
56
+ "Only one base URL can be set at a time."
57
+ )
58
+ validate_configuration(base_urls, MODEL)
59
+ validate_configuration(base_urls_completion, MODEL_COMPLETION, "_COMPLETION")
60
+
61
+ BASE_URL = OPENAI_BASE_URL or OLLAMA_BASE_URL or HUGGINGFACE_BASE_URL or VLLM_BASE_URL
62
+ BASE_URL_COMPLETION = (
63
+ OPENAI_BASE_URL_COMPLETION
64
+ or OLLAMA_BASE_URL_COMPLETION
65
+ or HUGGINGFACE_BASE_URL_COMPLETION
66
+ or VLLM_BASE_URL_COMPLETION
67
+ )
68
+
69
  # API Keys
70
  HF_TOKEN = os.getenv("HF_TOKEN")
71
  if not HF_TOKEN:
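Each `*_COMPLETION` constant above falls back to its base counterpart, so single-model setups keep working unchanged. A hedged sketch of the resulting semantics (values illustrative; set the variables before importing the constants module):

import os

os.environ["MODEL"] = "gpt-4o"
os.environ["OPENAI_BASE_URL"] = "https://api.openai.com/v1/"
# With MODEL_COMPLETION unset, the constants resolve to the base values:
#   MODEL_COMPLETION == "gpt-4o" and BASE_URL_COMPLETION == "https://api.openai.com/v1/"
os.environ["MODEL_COMPLETION"] = "gpt-4o-mini"  # now only completions use the cheaper model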
src/synthetic_dataset_generator/pipelines/base.py CHANGED
@@ -8,11 +8,17 @@ from synthetic_dataset_generator.constants import (
8
  API_KEYS,
9
  DEFAULT_BATCH_SIZE,
10
  HUGGINGFACE_BASE_URL,
 
11
  MODEL,
 
12
  OLLAMA_BASE_URL,
 
13
  OPENAI_BASE_URL,
 
14
  TOKENIZER_ID,
 
15
  VLLM_BASE_URL,
 
16
  )
17
 
18
  TOKEN_INDEX = 0
@@ -73,12 +79,20 @@ def _get_llm_class() -> str:
73
  return "InferenceEndpointsLLM"
74
 
75
 
76
- def _get_llm(use_magpie_template=False, **kwargs):
 
 
 
 
 
 
 
77
  if OPENAI_BASE_URL:
78
  llm = OpenAILLM(
79
- model=MODEL,
80
- base_url=OPENAI_BASE_URL,
81
  api_key=_get_next_api_key(),
 
82
  **kwargs,
83
  )
84
  if "generation_kwargs" in kwargs:
@@ -108,19 +122,25 @@ def _get_llm(use_magpie_template=False, **kwargs):
108
  kwargs["generation_kwargs"] = {}
109
  kwargs["generation_kwargs"]["options"] = options
110
  llm = OllamaLLM(
111
- model=MODEL,
112
- host=OLLAMA_BASE_URL,
113
- tokenizer_id=TOKENIZER_ID or MODEL,
114
  use_magpie_template=use_magpie_template,
 
115
  **kwargs,
116
  )
117
  elif HUGGINGFACE_BASE_URL:
118
  kwargs["generation_kwargs"]["do_sample"] = True
119
  llm = InferenceEndpointsLLM(
120
  api_key=_get_next_api_key(),
121
- base_url=HUGGINGFACE_BASE_URL,
122
- tokenizer_id=TOKENIZER_ID or MODEL,
123
  use_magpie_template=use_magpie_template,
 
124
  **kwargs,
125
  )
126
  elif VLLM_BASE_URL:
@@ -128,19 +148,21 @@ def _get_llm(use_magpie_template=False, **kwargs):
128
  if "do_sample" in kwargs["generation_kwargs"]:
129
  del kwargs["generation_kwargs"]["do_sample"]
130
  llm = ClientvLLM(
131
- base_url=VLLM_BASE_URL,
132
- model=MODEL,
133
- tokenizer=TOKENIZER_ID or MODEL,
134
  api_key=_get_next_api_key(),
135
  use_magpie_template=use_magpie_template,
 
136
  **kwargs,
137
  )
138
  else:
139
  llm = InferenceEndpointsLLM(
140
  api_key=_get_next_api_key(),
141
- tokenizer_id=TOKENIZER_ID or MODEL,
142
- model_id=MODEL,
143
  use_magpie_template=use_magpie_template,
 
144
  **kwargs,
145
  )
146
 
 
8
  API_KEYS,
9
  DEFAULT_BATCH_SIZE,
10
  HUGGINGFACE_BASE_URL,
11
+ HUGGINGFACE_BASE_URL_COMPLETION,
12
  MODEL,
13
+ MODEL_COMPLETION,
14
  OLLAMA_BASE_URL,
15
+ OLLAMA_BASE_URL_COMPLETION,
16
  OPENAI_BASE_URL,
17
+ OPENAI_BASE_URL_COMPLETION,
18
  TOKENIZER_ID,
19
+ TOKENIZER_ID_COMPLETION,
20
  VLLM_BASE_URL,
21
+ VLLM_BASE_URL_COMPLETION,
22
  )
23
 
24
  TOKEN_INDEX = 0
 
79
  return "InferenceEndpointsLLM"
80
 
81
 
82
+ def _get_llm(
83
+ structured_output: dict = None,
84
+ use_magpie_template: bool = False,
85
+ is_completion: bool = False,
86
+ **kwargs,
87
+ ):
88
+ model = MODEL_COMPLETION if is_completion else MODEL
89
+ tokenizer_id = (TOKENIZER_ID_COMPLETION if is_completion else TOKENIZER_ID) or model
90
  if OPENAI_BASE_URL:
91
  llm = OpenAILLM(
92
+ model=model,
93
+ base_url=OPENAI_BASE_URL_COMPLETION if is_completion else OPENAI_BASE_URL,
94
  api_key=_get_next_api_key(),
95
+ structured_output=structured_output,
96
  **kwargs,
97
  )
98
  if "generation_kwargs" in kwargs:
 
122
  kwargs["generation_kwargs"] = {}
123
  kwargs["generation_kwargs"]["options"] = options
124
  llm = OllamaLLM(
125
+ model=model,
126
+ host=OLLAMA_BASE_URL_COMPLETION if is_completion else OLLAMA_BASE_URL,
127
+ tokenizer_id=tokenizer_id,
128
  use_magpie_template=use_magpie_template,
129
+ structured_output=structured_output,
130
  **kwargs,
131
  )
132
  elif HUGGINGFACE_BASE_URL:
133
  kwargs["generation_kwargs"]["do_sample"] = True
134
  llm = InferenceEndpointsLLM(
135
  api_key=_get_next_api_key(),
136
+ base_url=(
137
+ HUGGINGFACE_BASE_URL_COMPLETION
138
+ if is_completion
139
+ else HUGGINGFACE_BASE_URL
140
+ ),
141
+ tokenizer_id=tokenizer_id,
142
  use_magpie_template=use_magpie_template,
143
+ structured_output=structured_output,
144
  **kwargs,
145
  )
146
  elif VLLM_BASE_URL:
 
148
  if "do_sample" in kwargs["generation_kwargs"]:
149
  del kwargs["generation_kwargs"]["do_sample"]
150
  llm = ClientvLLM(
151
+ base_url=VLLM_BASE_URL_COMPLETION if is_completion else VLLM_BASE_URL,
152
+ model=model,
153
+ tokenizer=tokenizer_id,
154
  api_key=_get_next_api_key(),
155
  use_magpie_template=use_magpie_template,
156
+ structured_output=structured_output,
157
  **kwargs,
158
  )
159
  else:
160
  llm = InferenceEndpointsLLM(
161
  api_key=_get_next_api_key(),
162
+ tokenizer_id=tokenizer_id,
163
+ model_id=model,
164
  use_magpie_template=use_magpie_template,
165
+ structured_output=structured_output,
166
  **kwargs,
167
  )
168
 
src/synthetic_dataset_generator/pipelines/chat.py CHANGED
@@ -1,4 +1,10 @@
1
- from distilabel.steps.tasks import ChatGeneration, Magpie, TextGeneration
2
 
3
  from synthetic_dataset_generator.constants import (
4
  MAGPIE_PRE_QUERY_TEMPLATE,
@@ -118,6 +124,18 @@ The prompt you write should follow the same style and structure as the following
118
  User dataset description:
119
  """
120
121
  DEFAULT_DATASET_DESCRIPTIONS = [
122
  "rude customer assistant for a phone company",
123
  "assistant that solves math puzzles using python",
@@ -140,7 +158,7 @@ else:
140
  ]
141
 
142
 
143
- def _get_output_mappings(num_turns):
144
  if num_turns == 1:
145
  return {"instruction": "prompt", "response": "completion"}
146
  else:
@@ -162,7 +180,7 @@ def get_prompt_generator():
162
  return prompt_generator
163
 
164
 
165
- def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
166
  input_mappings = _get_output_mappings(num_turns)
167
  output_mappings = input_mappings.copy()
168
  if num_turns == 1:
@@ -203,14 +221,31 @@ def get_magpie_generator(system_prompt, num_turns, temperature, is_sample):
203
  return magpie_generator
204
 
205
 
206
- def get_response_generator(system_prompt, num_turns, temperature, is_sample):
207
  if num_turns == 1:
208
  generation_kwargs = {
209
  "temperature": temperature,
210
  "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
211
  }
212
  response_generator = TextGeneration(
213
- llm=_get_llm(generation_kwargs=generation_kwargs),
214
  system_prompt=system_prompt,
215
  output_mappings={"generation": "completion"},
216
  input_mappings={"instruction": "prompt"},
@@ -221,7 +256,7 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
221
  "max_new_tokens": MAX_NUM_TOKENS,
222
  }
223
  response_generator = ChatGeneration(
224
- llm=_get_llm(generation_kwargs=generation_kwargs),
225
  output_mappings={"generation": "completion"},
226
  input_mappings={"conversation": "messages"},
227
  )
@@ -229,36 +264,236 @@ def get_response_generator(system_prompt, num_turns, temperature, is_sample):
229
  return response_generator
230
 
231
 
232
- def generate_pipeline_code(system_prompt, num_turns, num_rows):
233
input_mappings = _get_output_mappings(num_turns)
234
235
  code = f"""
236
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
237
- import os
238
  from distilabel.pipeline import Pipeline
239
- from distilabel.steps import KeepColumns
240
- from distilabel.steps.tasks import MagpieGenerator
241
- from distilabel.llms import {_get_llm_class()}
242
 
243
- SYSTEM_PROMPT = "{system_prompt}"
244
 
245
  with Pipeline(name="sft") as pipeline:
246
- magpie = MagpieGenerator(
247
  llm={_get_llm_class()}.from_dict(
248
  {_get_llm().dump()}
249
  ),
250
- n_turns={num_turns},
251
- num_rows={num_rows},
252
- batch_size=1,
253
- system_prompt=SYSTEM_PROMPT,
254
- output_mappings={input_mappings},
255
  )
256
- keep_columns = KeepColumns(
257
- columns={list(input_mappings.values())} + ["model_name"],
258
  )
259
- magpie.connect(keep_columns)
260
261
  if __name__ == "__main__":
262
  distiset = pipeline.run()
 
263
  """
264
return code
1
+ from datasets import get_dataset_config_names, get_dataset_split_names
2
+ from distilabel.steps.tasks import (
3
+ ChatGeneration,
4
+ GenerateSentencePair,
5
+ Magpie,
6
+ TextGeneration,
7
+ )
8
 
9
  from synthetic_dataset_generator.constants import (
10
  MAGPIE_PRE_QUERY_TEMPLATE,
 
124
  User dataset description:
125
  """
126
 
127
+ FOLLOW_UP_TEMPLATE = """Conversation:
128
+ {% for message in messages %}
129
+ {% if message.role == "user" %}
130
+ User Question: {{ message.content }}
131
+ {% elif message.role == "assistant" %}
132
+ Assistant Response: {{ message.content }}
133
+ {% endif %}
134
+ {% endfor %}
135
+
136
+ Please generate the next logical user message in this conversation. Do not include any other information or 'User Question' in your response.
137
+ """.rstrip()
138
+
139
  DEFAULT_DATASET_DESCRIPTIONS = [
140
  "rude customer assistant for a phone company",
141
  "assistant that solves math puzzles using python",
 
158
  ]
159
 
160
 
161
+ def _get_output_mappings(num_turns: int):
162
  if num_turns == 1:
163
  return {"instruction": "prompt", "response": "completion"}
164
  else:
 
180
  return prompt_generator
181
 
182
 
183
+ def get_magpie_generator(num_turns: int, temperature: float, is_sample: bool):
184
  input_mappings = _get_output_mappings(num_turns)
185
  output_mappings = input_mappings.copy()
186
  if num_turns == 1:
 
221
  return magpie_generator
222
 
223
 
224
+ def get_sentence_pair_generator(temperature: float, is_sample: bool):
225
+ generation_kwargs = {
226
+ "temperature": temperature,
227
+ "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
228
+ }
229
+ sentence_pair_generator = GenerateSentencePair(
230
+ llm=_get_llm(generation_kwargs=generation_kwargs),
231
+ triplet=False,
232
+ action="query",
233
+ hard_negative=True,
234
+ )
235
+ sentence_pair_generator.load()
236
+ return sentence_pair_generator
237
+
238
+
239
+ def get_response_generator(
240
+ system_prompt: str, num_turns: int, temperature: float, is_sample: bool
241
+ ):
242
  if num_turns == 1:
243
  generation_kwargs = {
244
  "temperature": temperature,
245
  "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
246
  }
247
  response_generator = TextGeneration(
248
+ llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
249
  system_prompt=system_prompt,
250
  output_mappings={"generation": "completion"},
251
  input_mappings={"instruction": "prompt"},
 
256
  "max_new_tokens": MAX_NUM_TOKENS,
257
  }
258
  response_generator = ChatGeneration(
259
+ llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
260
  output_mappings={"generation": "completion"},
261
  input_mappings={"conversation": "messages"},
262
  )
 
264
  return response_generator
265
 
266
 
267
+ def get_follow_up_generator(type: str, temperature: float, is_sample: bool):
268
+ if type == "instruction":
269
+ generation_kwargs = {
270
+ "temperature": temperature,
271
+ "max_new_tokens": 256 if is_sample else int(MAX_NUM_TOKENS * 0.5),
272
+ }
273
+ follow_up_generator = TextGeneration(
274
+ llm=_get_llm(generation_kwargs=generation_kwargs),
275
+ template=FOLLOW_UP_TEMPLATE,
276
+ columns=["messages"],
277
+ )
278
+ else:
279
+ generation_kwargs = {
280
+ "temperature": temperature,
281
+ "max_new_tokens": MAX_NUM_TOKENS,
282
+ }
283
+ follow_up_generator = ChatGeneration(
284
+ llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
285
+ )
286
+ follow_up_generator.load()
287
+ return follow_up_generator
288
+
289
+ def generate_pipeline_code_system_prompt(
290
+ system_prompt: str,
291
+ num_turns: int,
292
+ num_rows: int,
293
+ ):
294
  input_mappings = _get_output_mappings(num_turns)
295
+ code = f"""
296
+ # Requirements: `pip install distilabel[hf-inference-endpoints]`
297
+ import os
298
+ from distilabel.pipeline import Pipeline
299
+ from distilabel.steps import KeepColumns
300
+ from distilabel.steps.tasks import MagpieGenerator
301
+ from distilabel.llms import {_get_llm_class()}
302
+
303
+ SYSTEM_PROMPT = "{system_prompt}"
304
+
305
+ with Pipeline(name="sft") as pipeline:
306
+ magpie = MagpieGenerator(
307
+ llm={_get_llm_class()}.from_dict(
308
+ {_get_llm().dump()}
309
+ ),
310
+ n_turns={num_turns},
311
+ num_rows={num_rows},
312
+ batch_size=1,
313
+ system_prompt=SYSTEM_PROMPT,
314
+ output_mappings={input_mappings},
315
+ )
316
+ keep_columns = KeepColumns(
317
+ columns={list(input_mappings.values())} + ["model_name"],
318
+ )
319
+ magpie.connect(keep_columns)
320
 
321
+ if __name__ == "__main__":
322
+ distiset = pipeline.run()
323
+ """
324
+ return code
325
+
326
+ def generate_pipeline_code_seed(
327
+ repo_id: str,
328
+ subset: str,
329
+ split: str,
330
+ input_type: str,
331
+ document_column: str,
332
+ num_turns: int,
333
+ num_rows: int,
334
+ ):
335
  code = f"""
336
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
337
+ from distilabel.models import {_get_llm_class()}
338
  from distilabel.pipeline import Pipeline
339
+ from distilabel.steps import KeepColumns{", LoadDataFromDicts" if input_type != "dataset-input" else ""}{", LoadDataFromHub" if input_type == "dataset-input" else ""}{", StepInput, step" if num_turns > 1 else ""}
340
+ from distilabel.steps.tasks import GenerateSentencePair, TextGeneration {", ChatGeneration" if num_turns > 1 else ""}
341
+ """
342
 
343
+ if num_turns > 1:
344
+ code += """
345
+ FOLLOW_UP_TEMPLATE = '''Conversation:
346
+ {% for message in messages %}
347
+ {% if message.role == "user" %}
348
+ User Question: {{ message.content }}
349
+ {% elif message.role == "assistant" %}
350
+ Assistant Response: {{ message.content }}
351
+ {% endif %}
352
+ {% endfor %}
353
+
354
+ Please generate the next logical user message in this conversation. Do not include any other information or 'User Question' in your response.
355
+ '''.rstrip()
356
+
357
+ @step(inputs=["prompt", "completion"], outputs=["messages"])
358
+ def PrepareMessages(*inputs: StepInput) -> StepOutput:
359
+ for input in inputs:
360
+ for item in input:
361
+ item["messages"] = [
362
+ {"role": "user", "content": item["prompt"]},
363
+ {"role": "assistant", "content": item["completion"]},
364
+ ]
365
+ yield input
366
+
367
+
368
+ @step(inputs=["messages", "generation"], outputs=["messages"])
369
+ def FormatMessagesInstruction(*inputs: StepInput) -> StepOutput:
370
+ for input in inputs:
371
+ for item in input:
372
+ item["messages"].append({"role": "user", "content": item["generation"]})
373
+ yield input
374
+
375
+
376
+ @step(inputs=["messages", "generation"], outputs=["messages"])
377
+ def FormatMessagesResponse(*inputs: StepInput) -> StepOutput:
378
+ for input in inputs:
379
+ for item in input:
380
+ item["messages"].append({"role": "assistant", "content": item["generation"]})
381
+ yield input
382
+ """
383
+
384
+ if input_type == "dataset-input":
385
+ code += f"""
386
+ with Pipeline(name="sft") as pipeline:
387
+ load_the_dataset = LoadDataFromHub(
388
+ repo_id='{repo_id}',
389
+ config='{subset}',
390
+ split='{split}',
391
+ num_examples={num_rows},
392
+ batch_size=2,
393
+ output_mappings={{'{document_column}':'anchor'}},
394
+ )
395
+ """
396
+
397
+ else:
398
+ code += """
399
+ data = process_and_chunk_files(files=[files])
400
 
401
  with Pipeline(name="sft") as pipeline:
402
+ load_the_dataset = LoadDataFromDicts(
403
+ data = data
404
+ )
405
+ """
406
+ code += f"""
407
+ instruction_generator = GenerateSentencePair(
408
+ name="instruction_generation",
409
+ triplet=False,
410
+ hard_negative=True,
411
+ action="query",
412
+ llm={_get_llm_class()}.from_dict(
413
+ {_get_llm().dump()}
414
+ ),
415
+ input_batch_size=10,
416
+ output_mappings={{"positive": "prompt"}},
417
+ )
418
+
419
+ response_generator = TextGeneration(
420
+ name="response_generation",
421
+ llm={_get_llm_class()}.from_dict(
422
+ {_get_llm().dump()}
423
+ ),
424
+ input_batch_size=10,
425
+ input_mappings={{"instruction": "prompt"}},
426
+ output_mappings={{"generation": "completion"}},
427
+ )
428
+ """
429
+
430
+ if num_turns > 1:
431
+ code += """
432
+ prepare_messages = PrepareMessages()
433
+ """
434
+
435
+ for i in range(num_turns - 1):
436
+ code += f"""
437
+ follow_up_instruction_{i} = TextGeneration(
438
  llm={_get_llm_class()}.from_dict(
439
  {_get_llm().dump()}
440
  ),
441
+ template=FOLLOW_UP_TEMPLATE,
442
+ columns=["messages"],
 
 
 
443
  )
444
+ format_instruction_{i} = FormatMessagesInstruction()
445
+ follow_up_response_{i} = ChatGeneration(
446
+ llm={_get_llm_class()}.from_dict(
447
+ {_get_llm().dump()}
448
+ ),
449
  )
450
+ format_response_{i} = FormatMessagesResponse()
451
+ """
452
+
453
+ if num_turns > 1:
454
+ code += """
455
+ keep_columns = KeepColumns(columns=["messages"])
456
+ """
457
+ code += "load_the_dataset >> instruction_generator >> response_generator >> prepare_messages"
458
+
459
+ for i in range(num_turns - 1):
460
+ code += f" >> follow_up_instruction_{i} >> format_instruction_{i} >> follow_up_response_{i} >> format_response_{i}"
461
 
462
+ code += " >> keep_columns"
463
+
464
+ code += """
465
  if __name__ == "__main__":
466
  distiset = pipeline.run()
468
  """
469
  return code
470
+
471
+ def generate_pipeline_code(
472
+ repo_id: str,
473
+ input_type: str,
474
+ system_prompt: str,
475
+ document_column: str,
476
+ num_turns: int,
477
+ num_rows: int,
478
+ ):
479
+ if input_type == "dataset-input" and repo_id is not None:
480
+ subset = get_dataset_config_names(repo_id)[0]
481
+ split = get_dataset_split_names(repo_id, subset)[0]
482
+ else:
483
+ subset = "default"
484
+ split = "train"
485
+ if input_type == "prompt-type":
486
+ return generate_pipeline_code_system_prompt(
487
+ system_prompt=system_prompt,
488
+ num_turns=num_turns,
489
+ num_rows=num_rows,
490
+ )
491
+ return generate_pipeline_code_seed(
492
+ repo_id=repo_id,
493
+ subset=subset,
494
+ split=split,
495
+ input_type=input_type,
496
+ document_column=document_column,
497
+ num_turns=num_turns,
498
+ num_rows=num_rows,
499
+ )
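distilabel's `TextGeneration` renders the `template` argument with Jinja2 against the listed `columns`; here is a standalone sketch of what `FOLLOW_UP_TEMPLATE` produces for a one-turn history (assuming Jinja2 is installed; the message contents are made up):

from jinja2 import Template

from synthetic_dataset_generator.pipelines.chat import FOLLOW_UP_TEMPLATE

messages = [
    {"role": "user", "content": "How do I reset my password?"},
    {"role": "assistant", "content": "Use the 'Forgot password' link on the login page."},
]
# prints the transcript followed by the instruction to write the next user message
print(Template(FOLLOW_UP_TEMPLATE).render(messages=messages))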
src/synthetic_dataset_generator/pipelines/eval.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from datasets import get_dataset_config_names, get_dataset_split_names
2
  from distilabel.models import InferenceEndpointsLLM
3
  from distilabel.steps.tasks import (
@@ -10,7 +12,7 @@ from synthetic_dataset_generator.pipelines.base import _get_next_api_key
10
  from synthetic_dataset_generator.utils import extract_column_names
11
 
12
 
13
- def get_ultrafeedback_evaluator(aspect, is_sample):
14
  ultrafeedback_evaluator = UltraFeedback(
15
  llm=InferenceEndpointsLLM(
16
  model_id=MODEL,
@@ -27,7 +29,9 @@ def get_ultrafeedback_evaluator(aspect, is_sample):
27
  return ultrafeedback_evaluator
28
 
29
 
30
- def get_custom_evaluator(prompt_template, structured_output, columns, is_sample):
 
 
31
  custom_evaluator = TextGeneration(
32
  llm=InferenceEndpointsLLM(
33
  model_id=MODEL,
@@ -47,7 +51,13 @@ def get_custom_evaluator(prompt_template, structured_output, columns, is_sample)
47
 
48
 
49
  def generate_ultrafeedback_pipeline_code(
50
- repo_id, subset, split, aspects, instruction_column, response_columns, num_rows
51
  ):
52
  if len(aspects) == 1:
53
  code = f"""
 
1
+ from typing import List
2
+
3
  from datasets import get_dataset_config_names, get_dataset_split_names
4
  from distilabel.models import InferenceEndpointsLLM
5
  from distilabel.steps.tasks import (
 
12
  from synthetic_dataset_generator.utils import extract_column_names
13
 
14
 
15
+ def get_ultrafeedback_evaluator(aspect: str, is_sample: bool):
16
  ultrafeedback_evaluator = UltraFeedback(
17
  llm=InferenceEndpointsLLM(
18
  model_id=MODEL,
 
29
  return ultrafeedback_evaluator
30
 
31
 
32
+ def get_custom_evaluator(
33
+ prompt_template: str, structured_output: dict, columns: List[str], is_sample: bool
34
+ ):
35
  custom_evaluator = TextGeneration(
36
  llm=InferenceEndpointsLLM(
37
  model_id=MODEL,
 
51
 
52
 
53
  def generate_ultrafeedback_pipeline_code(
54
+ repo_id: str,
55
+ subset: str,
56
+ split: str,
57
+ aspects: List[str],
58
+ instruction_column: str,
59
+ response_columns: str,
60
+ num_rows: int,
61
  ):
62
  if len(aspects) == 1:
63
  code = f"""
src/synthetic_dataset_generator/pipelines/rag.py CHANGED
@@ -1,7 +1,3 @@
1
- import os
2
-
3
- from typing import List
4
-
5
  from datasets import get_dataset_config_names, get_dataset_split_names
6
  from distilabel.steps.tasks import (
7
  GenerateSentencePair,
@@ -87,7 +83,7 @@ def get_prompt_generator():
87
  return text_generator
88
 
89
 
90
- def get_chunks_generator(temperature, is_sample):
91
  generation_kwargs = {
92
  "temperature": temperature,
93
  "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
@@ -104,7 +100,7 @@ def get_chunks_generator(temperature, is_sample):
104
  return text_generator
105
 
106
 
107
- def get_sentence_pair_generator(action, triplet, temperature, is_sample):
108
  generation_kwargs = {
109
  "temperature": temperature,
110
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
@@ -119,13 +115,13 @@ def get_sentence_pair_generator(action, triplet, temperature, is_sample):
119
  return sentence_pair_generator
120
 
121
 
122
- def get_response_generator(temperature, is_sample):
123
  generation_kwargs = {
124
  "temperature": temperature,
125
  "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
126
  }
127
  text_generator = TextGeneration(
128
- llm=_get_llm(generation_kwargs=generation_kwargs),
129
  system_prompt=SYSTEM_PROMPT_RAG,
130
  template=RAG_TEMPLATE,
131
  columns=["context", "question"],
@@ -138,7 +134,6 @@ def get_response_generator(temperature, is_sample):
138
 
139
  def generate_pipeline_code(
140
  repo_id: str,
141
- file_paths: List[str],
142
  input_type: str,
143
  system_prompt: str,
144
  document_column: str,
@@ -293,10 +288,7 @@ with Pipeline(name="rag") as pipeline:
293
 
294
  pipeline += """
295
  if __name__ == "__main__":
296
- distiset = pipeline.run(use_cache=False)
297
- print(distiset)
298
- if distiset:
299
- print(distiset["default"]["train"][0])
300
  """
301
 
302
return base_code + pipeline
1
  from datasets import get_dataset_config_names, get_dataset_split_names
2
  from distilabel.steps.tasks import (
3
  GenerateSentencePair,
 
83
  return text_generator
84
 
85
 
86
+ def get_chunks_generator(temperature: float, is_sample: bool):
87
  generation_kwargs = {
88
  "temperature": temperature,
89
  "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
 
100
  return text_generator
101
 
102
 
103
+ def get_sentence_pair_generator(action: str, triplet: bool, temperature: float, is_sample: bool):
104
  generation_kwargs = {
105
  "temperature": temperature,
106
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
 
115
  return sentence_pair_generator
116
 
117
 
118
+ def get_response_generator(temperature: float, is_sample: bool):
119
  generation_kwargs = {
120
  "temperature": temperature,
121
  "max_new_tokens": MAX_NUM_TOKENS if is_sample else 256,
122
  }
123
  text_generator = TextGeneration(
124
+ llm=_get_llm(is_completion=True, generation_kwargs=generation_kwargs),
125
  system_prompt=SYSTEM_PROMPT_RAG,
126
  template=RAG_TEMPLATE,
127
  columns=["context", "question"],
 
134
 
135
  def generate_pipeline_code(
136
  repo_id: str,
 
137
  input_type: str,
138
  system_prompt: str,
139
  document_column: str,
 
288
 
289
  pipeline += """
290
  if __name__ == "__main__":
291
+ distiset = pipeline.run()
292
  """
293
 
294
  return base_code + pipeline
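In the RAG pipeline only answer synthesis moves to the completion backend; chunking and query generation stay on the instruction model. A hedged wiring sketch (temperatures illustrative):

from synthetic_dataset_generator.pipelines.rag import (
    get_response_generator,
    get_sentence_pair_generator,
)

query_gen = get_sentence_pair_generator(action="query", triplet=False, temperature=0.7, is_sample=True)
answer_gen = get_response_generator(temperature=0.9, is_sample=True)  # uses the *_COMPLETION settings when set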
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -85,7 +85,9 @@ def get_prompt_generator():
85
  return prompt_generator
86
 
87
 
88
- def get_textcat_generator(difficulty, clarity, temperature, is_sample):
 
 
89
  generation_kwargs = {
90
  "temperature": temperature,
91
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
@@ -102,12 +104,12 @@ def get_textcat_generator(difficulty, clarity, temperature, is_sample):
102
  return textcat_generator
103
 
104
 
105
- def get_labeller_generator(system_prompt, labels, multi_label):
106
  generation_kwargs = {
107
  "temperature": 0.01,
108
  "max_new_tokens": MAX_NUM_TOKENS,
109
  }
110
- llm = _get_llm(generation_kwargs=generation_kwargs)
111
  labeller_generator = TextClassification(
112
  llm=llm,
113
  context=system_prompt,
 
85
  return prompt_generator
86
 
87
 
88
+ def get_textcat_generator(
89
+ difficulty: str, clarity: str, temperature: float, is_sample: bool
90
+ ):
91
  generation_kwargs = {
92
  "temperature": temperature,
93
  "max_new_tokens": 256 if is_sample else MAX_NUM_TOKENS,
 
104
  return textcat_generator
105
 
106
 
107
+ def get_labeller_generator(system_prompt: str, labels: List[str], multi_label: bool):
108
  generation_kwargs = {
109
  "temperature": 0.01,
110
  "max_new_tokens": MAX_NUM_TOKENS,
111
  }
112
+ llm = _get_llm(is_completion=True, generation_kwargs=generation_kwargs)
113
  labeller_generator = TextClassification(
114
  llm=llm,
115
  context=system_prompt,