davidberenstein1957 HF staff commited on
Commit
19f20a1
·
1 Parent(s): 67fa2ba

feat: updated push to hub flow

Browse files
src/distilabel_dataset_generator/apps/faq.py CHANGED
@@ -27,6 +27,10 @@ with gr.Blocks() as app:
27
 
28
  <p>The current implementation is based on <a href="https://huggingface.co/docs/api-inference/index" target="_blank">Free Serverless Hugging Face Inference Endpoints</a>. They are rate limited but free to use for anyone on the Hugging Face Hub. You can re-use the underlying pipeline to generate data with other <a href="https://distilabel.argilla.io/dev/components-gallery/llms/" target="_blank">distilabel LLM integrations</a>.</p>
29
 
 
 
 
 
30
  <h4 style="text-align: center;">What is distilabel?</h4>
31
 
32
  <p>Distilabel is the framework for synthetic data and AI feedback for engineers who need fast, reliable and scalable pipelines based on verified research papers.</p>
 
27
 
28
  <p>The current implementation is based on <a href="https://huggingface.co/docs/api-inference/index" target="_blank">Free Serverless Hugging Face Inference Endpoints</a>. They are rate limited but free to use for anyone on the Hugging Face Hub. You can re-use the underlying pipeline to generate data with other <a href="https://distilabel.argilla.io/dev/components-gallery/llms/" target="_blank">distilabel LLM integrations</a>.</p>
29
 
30
+ <h4 style="text-align: center;">Can I run this locally?</h4>
31
+
32
+ <p>Yes, you can run this locally by <a href="https://huggingface.co/spaces/argilla/distilabel-datacraft?clone=true" target="_blank">cloning the Space</a> and installing the requirements with `pip install -r requirements.txt` and running `python app.py`. Alternatively, you can install the <a href="https://github.com/argilla-io/distilabel" target="_blank">distilabel library</a> with `pip install distilabel[hf-inference-endpoints]` and use the pipeline code at the bottom of each application tab. Distilabel also supports running the pipeline with <a href="https://distilabel.argilla.io/latest/components-gallery/llms/" target="_blank">other LLMs</a>.</p>
33
+
34
  <h4 style="text-align: center;">What is distilabel?</h4>
35
 
36
  <p>Distilabel is the framework for synthetic data and AI feedback for engineers who need fast, reliable and scalable pipelines based on verified research papers.</p>
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -4,6 +4,7 @@ import time
4
 
5
  import gradio as gr
6
  import pandas as pd
 
7
  from distilabel.distiset import Distiset
8
  from huggingface_hub import upload_file
9
 
@@ -69,17 +70,7 @@ def generate_sample_dataset(system_prompt, progress=gr.Progress()):
69
  return result
70
 
71
 
72
- def generate_dataset(
73
- system_prompt: str,
74
- num_turns: int = 1,
75
- num_rows: int = 5,
76
- private: bool = True,
77
- org_name: str = None,
78
- repo_name: str = None,
79
- oauth_token: OAuthToken = None,
80
- progress=gr.Progress(),
81
- is_sample: bool = False,
82
- ):
83
  repo_id = (
84
  f"{org_name}/{repo_name}"
85
  if repo_name is not None and org_name is not None
@@ -90,15 +81,16 @@ def generate_dataset(
90
  raise gr.Error(
91
  "Please provide a `repo_name` and `org_name` to push the dataset to."
92
  )
 
93
 
94
- if num_turns > 4:
95
- num_turns = 4
96
- gr.Info("You can only generate a dataset with 4 or fewer turns. Setting to 4.")
97
- if num_rows > 5000:
98
- num_rows = 1000
99
- gr.Info(
100
- "You can only generate a dataset with 1000 or fewer rows. Setting to 1000."
101
- )
102
  if num_rows < 5:
103
  duration = 25
104
  elif num_rows < 10:
@@ -137,24 +129,37 @@ def generate_dataset(
137
 
138
  distiset = result_queue.get()
139
 
140
- if repo_id is not None:
141
- progress(0.95, desc="Pushing dataset to Hugging Face Hub.")
142
- distiset.push_to_hub(
143
- repo_id=repo_id,
144
- private=private,
145
- include_script=True,
146
- token=oauth_token,
147
- )
148
-
149
  # If not pushing to hub generate the dataset directly
150
  distiset = distiset["default"]["train"]
151
  if num_turns == 1:
152
  outputs = distiset.to_pandas()[["prompt", "completion"]]
153
  else:
154
  outputs = distiset.to_pandas()[["messages"]]
 
155
 
156
  progress(1.0, desc="Dataset generation completed")
157
- return pd.DataFrame(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
 
160
  def upload_pipeline_code(
@@ -182,7 +187,7 @@ with gr.Blocks(
182
  ) as app:
183
  with gr.Row():
184
  gr.Markdown(
185
- "To push the dataset to the Hugging Face Hub you need to sign in. This will only be used for pushing the dataset not for data generation."
186
  )
187
  with gr.Row():
188
  gr.Column()
@@ -269,22 +274,30 @@ with gr.Blocks(
269
  maximum=500,
270
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
271
  )
272
-
273
  with gr.Row(variant="panel"):
274
  org_name = get_org_dropdown()
275
  repo_name = gr.Textbox(
276
  label="Repo name", placeholder="dataset_name", value="my-distiset"
277
  )
278
  private = gr.Checkbox(
279
- label="Private dataset", value=True, interactive=True, scale=0.5
 
 
 
280
  )
281
  with gr.Row() as regenerate_row:
282
  gr.Column(scale=1)
283
  btn_generate_full_dataset = gr.Button(
284
- value="Generate Full Dataset", variant="primary"
 
 
 
 
 
 
285
  )
286
- gr.Column(scale=1)
287
 
 
288
  with gr.Row():
289
  final_dataset = gr.DataFrame(
290
  value=DEFAULT_DATASETS[0],
@@ -292,6 +305,7 @@ with gr.Blocks(
292
  interactive=False,
293
  wrap=True,
294
  )
 
295
  with gr.Row():
296
  success_message = gr.Markdown(visible=False)
297
 
@@ -340,16 +354,37 @@ with gr.Blocks(
340
  outputs=[success_message],
341
  ).then(
342
  fn=generate_dataset,
343
- inputs=[
344
- system_prompt,
345
- num_turns,
346
- num_rows,
347
- private,
348
- org_name,
349
- repo_name,
350
- ],
351
  outputs=[final_dataset],
352
  show_progress=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  ).then(
354
  fn=upload_pipeline_code,
355
  inputs=[pipeline_code, org_name, repo_name],
 
4
 
5
  import gradio as gr
6
  import pandas as pd
7
+ from datasets import Dataset
8
  from distilabel.distiset import Distiset
9
  from huggingface_hub import upload_file
10
 
 
70
  return result
71
 
72
 
73
+ def _check_push_to_hub(org_name, repo_name):
 
 
 
 
 
 
 
 
 
 
74
  repo_id = (
75
  f"{org_name}/{repo_name}"
76
  if repo_name is not None and org_name is not None
 
81
  raise gr.Error(
82
  "Please provide a `repo_name` and `org_name` to push the dataset to."
83
  )
84
+ return repo_id
85
 
86
+
87
+ def generate_dataset(
88
+ system_prompt: str,
89
+ num_turns: int = 1,
90
+ num_rows: int = 5,
91
+ is_sample: bool = False,
92
+ progress=gr.Progress(),
93
+ ):
94
  if num_rows < 5:
95
  duration = 25
96
  elif num_rows < 10:
 
129
 
130
  distiset = result_queue.get()
131
 
 
 
 
 
 
 
 
 
 
132
  # If not pushing to hub generate the dataset directly
133
  distiset = distiset["default"]["train"]
134
  if num_turns == 1:
135
  outputs = distiset.to_pandas()[["prompt", "completion"]]
136
  else:
137
  outputs = distiset.to_pandas()[["messages"]]
138
+ dataframe = pd.DataFrame(outputs)
139
 
140
  progress(1.0, desc="Dataset generation completed")
141
+ return dataframe
142
+
143
+
144
+ def push_to_hub(
145
+ dataframe,
146
+ private: bool = True,
147
+ org_name: str = None,
148
+ repo_name: str = None,
149
+ oauth_token: OAuthToken = None,
150
+ ):
151
+ distiset = Distiset(
152
+ {
153
+ "default": Dataset.from_pandas(dataframe),
154
+ }
155
+ )
156
+ distiset.push_to_hub(
157
+ repo_id=f"{org_name}/{repo_name}",
158
+ private=private,
159
+ include_script=True,
160
+ token=oauth_token,
161
+ )
162
+ return dataframe
163
 
164
 
165
  def upload_pipeline_code(
 
187
  ) as app:
188
  with gr.Row():
189
  gr.Markdown(
190
+ "Want to run this locally or with other LLMs? Take a look at the FAQ tab. DataCraft is free, we use the authentication token to push the dataset to the Hugging Face Hub and not for data generation."
191
  )
192
  with gr.Row():
193
  gr.Column()
 
274
  maximum=500,
275
  info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
276
  )
 
277
  with gr.Row(variant="panel"):
278
  org_name = get_org_dropdown()
279
  repo_name = gr.Textbox(
280
  label="Repo name", placeholder="dataset_name", value="my-distiset"
281
  )
282
  private = gr.Checkbox(
283
+ label="Private dataset",
284
+ value=True,
285
+ interactive=True,
286
+ scale=0.5,
287
  )
288
  with gr.Row() as regenerate_row:
289
  gr.Column(scale=1)
290
  btn_generate_full_dataset = gr.Button(
291
+ value="Generate", variant="primary", scale=2
292
+ )
293
+ btn_generate_and_push_to_hub = gr.Button(
294
+ value="Generate and Push to Hub", variant="primary", scale=2
295
+ )
296
+ btn_push_to_hub = gr.Button(
297
+ value="Push to Hub", variant="primary", scale=2
298
  )
 
299
 
300
+ gr.Column(scale=1)
301
  with gr.Row():
302
  final_dataset = gr.DataFrame(
303
  value=DEFAULT_DATASETS[0],
 
305
  interactive=False,
306
  wrap=True,
307
  )
308
+
309
  with gr.Row():
310
  success_message = gr.Markdown(visible=False)
311
 
 
354
  outputs=[success_message],
355
  ).then(
356
  fn=generate_dataset,
357
+ inputs=[system_prompt, num_turns, num_rows],
 
 
 
 
 
 
 
358
  outputs=[final_dataset],
359
  show_progress=True,
360
+ )
361
+ btn_generate_and_push_to_hub.click(
362
+ fn=hide_success_message,
363
+ outputs=[success_message],
364
+ ).then(
365
+ fn=generate_dataset,
366
+ inputs=[system_prompt, num_turns, num_rows],
367
+ outputs=[final_dataset],
368
+ show_progress=True,
369
+ ).then(
370
+ fn=push_to_hub,
371
+ inputs=[final_dataset, private, org_name, repo_name],
372
+ outputs=[final_dataset],
373
+ show_progress=True,
374
+ ).then(
375
+ fn=upload_pipeline_code,
376
+ inputs=[pipeline_code, org_name, repo_name],
377
+ outputs=[],
378
+ ).success(
379
+ fn=show_success_message,
380
+ inputs=[org_name, repo_name],
381
+ outputs=[success_message],
382
+ )
383
+
384
+ btn_push_to_hub.click(
385
+ fn=push_to_hub,
386
+ inputs=[final_dataset, private, org_name, repo_name],
387
+ outputs=[final_dataset],
388
  ).then(
389
  fn=upload_pipeline_code,
390
  inputs=[pipeline_code, org_name, repo_name],