davidberenstein1957 HF staff committed on
Commit
ffa2ee0
·
1 Parent(s): d15b1c7

update logic textcat for inferring labels

Browse files
src/synthetic_dataset_generator/apps/base.py CHANGED
@@ -4,6 +4,7 @@ from typing import Union
4
 
5
  import argilla as rg
6
  import gradio as gr
 
7
  from gradio import OAuthToken
8
  from huggingface_hub import HfApi, upload_file
9
 
@@ -75,6 +76,14 @@ def validate_push_to_hub(org_name, repo_name):
75
  return repo_id
76
 
77
 
 
 
 
 
 
 
 
 
78
  def show_success_message(org_name, repo_name) -> gr.Markdown:
79
  client = get_argilla_client()
80
  if client is None:
 
4
 
5
  import argilla as rg
6
  import gradio as gr
7
+ from datasets import Dataset, concatenate_datasets, load_dataset
8
  from gradio import OAuthToken
9
  from huggingface_hub import HfApi, upload_file
10
 
 
76
  return repo_id
77
 
78
 
79
def combine_datasets(repo_id: str, dataset: Dataset) -> Dataset:
    """Append ``dataset`` to the ``train`` split already stored at ``repo_id``.

    Args:
        repo_id: Hub repository id passed to ``load_dataset``.
        dataset: Newly created dataset that is about to be pushed.

    Returns:
        The existing ``train`` split followed by ``dataset`` when the existing
        split can be loaded and concatenated; otherwise ``dataset`` unchanged.
    """
    try:
        # Load into a separate name: rebinding ``dataset`` here would discard
        # the new rows and concatenate the existing split with itself.
        existing = load_dataset(repo_id, split="train")
        return concatenate_datasets([existing, dataset])
    except Exception:
        # Deliberate best-effort: if loading or concatenating fails for any
        # reason, fall back to pushing only the new data.
        return dataset
85
+
86
+
87
  def show_success_message(org_name, repo_name) -> gr.Markdown:
88
  client = get_argilla_client()
89
  if client is None:
src/synthetic_dataset_generator/apps/eval.py CHANGED
@@ -18,6 +18,7 @@ from gradio_huggingfacehub_search import HuggingfaceHubSearch
18
  from huggingface_hub import HfApi, repo_exists
19
 
20
  from synthetic_dataset_generator.apps.base import (
 
21
  hide_success_message,
22
  push_pipeline_code_to_hub,
23
  show_success_message,
@@ -355,7 +356,9 @@ def push_dataset_to_hub(
355
  pipeline_code: str,
356
  ):
357
  repo_id = validate_push_to_hub(org_name, repo_name)
358
- distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
 
 
359
  distiset.push_to_hub(
360
  repo_id=repo_id,
361
  private=private,
 
18
  from huggingface_hub import HfApi, repo_exists
19
 
20
  from synthetic_dataset_generator.apps.base import (
21
+ combine_datasets,
22
  hide_success_message,
23
  push_pipeline_code_to_hub,
24
  show_success_message,
 
356
  pipeline_code: str,
357
  ):
358
  repo_id = validate_push_to_hub(org_name, repo_name)
359
+ dataset = Dataset.from_pandas(dataframe)
360
+ dataset = combine_datasets(repo_id, dataset)
361
+ distiset = Distiset({"default": dataset})
362
  distiset.push_to_hub(
363
  repo_id=repo_id,
364
  private=private,
src/synthetic_dataset_generator/apps/sft.py CHANGED
@@ -10,6 +10,7 @@ from distilabel.distiset import Distiset
10
  from huggingface_hub import HfApi
11
 
12
  from synthetic_dataset_generator.apps.base import (
 
13
  hide_success_message,
14
  push_pipeline_code_to_hub,
15
  show_success_message,
@@ -209,11 +210,18 @@ def push_dataset_to_hub(
209
  oauth_token: Union[gr.OAuthToken, None],
210
  private: bool,
211
  pipeline_code: str,
 
212
  ):
 
213
  repo_id = validate_push_to_hub(org_name, repo_name)
 
214
  original_dataframe = dataframe.copy(deep=True)
215
  dataframe = convert_dataframe_messages(dataframe)
216
- distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
 
 
 
 
217
  distiset.push_to_hub(
218
  repo_id=repo_id,
219
  private=private,
@@ -222,6 +230,7 @@ def push_dataset_to_hub(
222
  create_pr=False,
223
  )
224
  push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
 
225
  return original_dataframe
226
 
227
 
 
10
  from huggingface_hub import HfApi
11
 
12
  from synthetic_dataset_generator.apps.base import (
13
+ combine_datasets,
14
  hide_success_message,
15
  push_pipeline_code_to_hub,
16
  show_success_message,
 
210
  oauth_token: Union[gr.OAuthToken, None],
211
  private: bool,
212
  pipeline_code: str,
213
+ progress=gr.Progress(),
214
  ):
215
+ progress(0.0, desc="Validating")
216
  repo_id = validate_push_to_hub(org_name, repo_name)
217
+ progress(0.3, desc="Converting")
218
  original_dataframe = dataframe.copy(deep=True)
219
  dataframe = convert_dataframe_messages(dataframe)
220
+ progress(0.7, desc="Creating dataset")
221
+ dataset = Dataset.from_pandas(dataframe)
222
+ dataset = combine_datasets(repo_id, dataset)
223
+ progress(0.9, desc="Pushing dataset")
224
+ distiset = Distiset({"default": dataset})
225
  distiset.push_to_hub(
226
  repo_id=repo_id,
227
  private=private,
 
230
  create_pr=False,
231
  )
232
  push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
233
+ progress(1.0, desc="Dataset pushed")
234
  return original_dataframe
235
 
236
 
src/synthetic_dataset_generator/apps/textcat.py CHANGED
@@ -11,6 +11,7 @@ from distilabel.distiset import Distiset
11
  from huggingface_hub import HfApi
12
 
13
  from src.synthetic_dataset_generator.apps.base import (
 
14
  hide_success_message,
15
  push_pipeline_code_to_hub,
16
  show_success_message,
@@ -129,7 +130,9 @@ def generate_dataset(
129
  sampled_labels = random.sample(labels, num_labels)
130
  random.shuffle(sampled_labels)
131
  inputs.append(
132
- {"task": f"{system_prompt}. Labels: {', '.join(sampled_labels)}"}
 
 
133
  )
134
  batch = list(textcat_generator.process(inputs=inputs))
135
  textcat_results.extend(batch[0])
@@ -194,9 +197,13 @@ def push_dataset_to_hub(
194
  oauth_token: Union[gr.OAuthToken, None] = None,
195
  private: bool = False,
196
  pipeline_code: str = "",
 
197
  ):
 
198
  repo_id = validate_push_to_hub(org_name, repo_name)
 
199
  labels = get_preprocess_labels(labels)
 
200
  if num_labels == 1:
201
  dataframe["label"] = dataframe["label"].replace("", None)
202
  features = Features(
@@ -209,7 +216,10 @@ def push_dataset_to_hub(
209
  "labels": Sequence(feature=ClassLabel(names=labels)),
210
  }
211
  )
212
- distiset = Distiset({"default": Dataset.from_pandas(dataframe, features=features)})
 
 
 
213
  distiset.push_to_hub(
214
  repo_id=repo_id,
215
  private=private,
@@ -218,6 +228,7 @@ def push_dataset_to_hub(
218
  create_pr=False,
219
  )
220
  push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
 
221
 
222
 
223
  def push_dataset(
@@ -439,7 +450,7 @@ with gr.Blocks() as app:
439
  ("Ambiguous", "ambiguous"),
440
  ("Mixed", "mixed"),
441
  ],
442
- value="understandable with some effort",
443
  label="Clarity",
444
  info="Set how easily the correct label or labels can be identified.",
445
  interactive=True,
 
11
  from huggingface_hub import HfApi
12
 
13
  from src.synthetic_dataset_generator.apps.base import (
14
+ combine_datasets,
15
  hide_success_message,
16
  push_pipeline_code_to_hub,
17
  show_success_message,
 
130
  sampled_labels = random.sample(labels, num_labels)
131
  random.shuffle(sampled_labels)
132
  inputs.append(
133
+ {
134
+ "task": f"{system_prompt}. The text represents the following categories: {', '.join(sampled_labels)}"
135
+ }
136
  )
137
  batch = list(textcat_generator.process(inputs=inputs))
138
  textcat_results.extend(batch[0])
 
197
  oauth_token: Union[gr.OAuthToken, None] = None,
198
  private: bool = False,
199
  pipeline_code: str = "",
200
+ progress=gr.Progress(),
201
  ):
202
+ progress(0.0, desc="Validating")
203
  repo_id = validate_push_to_hub(org_name, repo_name)
204
+ progress(0.3, desc="Preprocessing")
205
  labels = get_preprocess_labels(labels)
206
+ progress(0.7, desc="Creating dataset")
207
  if num_labels == 1:
208
  dataframe["label"] = dataframe["label"].replace("", None)
209
  features = Features(
 
216
  "labels": Sequence(feature=ClassLabel(names=labels)),
217
  }
218
  )
219
+ dataset = Dataset.from_pandas(dataframe, features=features)
220
+ dataset = combine_datasets(repo_id, dataset)
221
+ distiset = Distiset({"default": dataset})
222
+ progress(0.9, desc="Pushing dataset")
223
  distiset.push_to_hub(
224
  repo_id=repo_id,
225
  private=private,
 
228
  create_pr=False,
229
  )
230
  push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
231
+ progress(1.0, desc="Dataset pushed")
232
 
233
 
234
  def push_dataset(
 
450
  ("Ambiguous", "ambiguous"),
451
  ("Mixed", "mixed"),
452
  ],
453
+ value="mixed",
454
  label="Clarity",
455
  info="Set how easily the correct label or labels can be identified.",
456
  interactive=True,
src/synthetic_dataset_generator/pipelines/textcat.py CHANGED
@@ -26,16 +26,16 @@ Don't include the labels in the classification_task but only provide a high leve
26
  If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
27
 
28
  Description: DavidMovieHouse is a cinema that has been in business for 10 years.
29
- Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews. Classify the customer reviews as", "labels": ["positive", "negative"]}
30
 
31
  Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
32
- Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
33
 
34
  Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
35
- Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
36
 
37
  Description: A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels "data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"
38
- Output: {"classification_task": "A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels", "labels": ["data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"]}
39
 
40
  Description:
41
  """
 
26
  If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
27
 
28
  Description: DavidMovieHouse is a cinema that has been in business for 10 years.
29
+ Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews of varying customer groups. Classify the customer reviews as", "labels": ["positive", "negative"]}
30
 
31
  Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
32
+ Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover from different speaking people . Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
33
 
34
  Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
35
+ Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Written by different journalists. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
36
 
37
  Description: A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels "data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"
38
+ Output: {"classification_task": "A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review from various cusomer demographics with labels", "labels": ["data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"]}
39
 
40
  Description:
41
  """