Commit ffa2ee0
Parent: d15b1c7

update logic textcat for inferring labels
src/synthetic_dataset_generator/apps/base.py

@@ -4,6 +4,7 @@ from typing import Union
 
 import argilla as rg
 import gradio as gr
+from datasets import Dataset, concatenate_datasets, load_dataset
 from gradio import OAuthToken
 from huggingface_hub import HfApi, upload_file
 
@@ -75,6 +76,14 @@ def validate_push_to_hub(org_name, repo_name):
     return repo_id
 
 
+def combine_datasets(repo_id: str, dataset: Dataset) -> Dataset:
+    try:
+        existing_dataset = load_dataset(repo_id, split="train")
+        return concatenate_datasets([existing_dataset, dataset])
+    except Exception:
+        return dataset
+
+
 def show_success_message(org_name, repo_name) -> gr.Markdown:
     client = get_argilla_client()
     if client is None:
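
The new combine_datasets helper makes pushes additive: it loads the train split that already exists at repo_id and appends the freshly generated rows, falling back to the new rows alone when the repo cannot be read (for example on the first push). A minimal sketch of that append behaviour with in-memory datasets, where the column names and rows are hypothetical stand-ins for a Hub split:

    from datasets import Dataset, concatenate_datasets

    # Hypothetical stand-ins for the split already on the Hub and a new batch.
    existing = Dataset.from_dict({"text": ["old row"], "label": [0]})
    new_batch = Dataset.from_dict({"text": ["new row"], "label": [1]})

    # Same merge as combine_datasets: existing rows first, new rows appended.
    combined = concatenate_datasets([existing, new_batch])
    print(combined.num_rows)  # 2, so repeated pushes grow the dataset

Note that concatenate_datasets requires both datasets to share the same features, which is why the textcat app below casts the dataframe with an explicit Features schema before combining.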
src/synthetic_dataset_generator/apps/eval.py

@@ -18,6 +18,7 @@ from gradio_huggingfacehub_search import HuggingfaceHubSearch
 from huggingface_hub import HfApi, repo_exists
 
 from synthetic_dataset_generator.apps.base import (
+    combine_datasets,
     hide_success_message,
     push_pipeline_code_to_hub,
     show_success_message,
@@ -355,7 +356,9 @@ def push_dataset_to_hub(
     pipeline_code: str,
 ):
     repo_id = validate_push_to_hub(org_name, repo_name)
-    distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
+    dataset = Dataset.from_pandas(dataframe)
+    dataset = combine_datasets(repo_id, dataset)
+    distiset = Distiset({"default": dataset})
     distiset.push_to_hub(
         repo_id=repo_id,
         private=private,
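
push_dataset_to_hub in eval.py now builds the Distiset from the combined dataset rather than wrapping the dataframe directly, so an existing evaluation dataset on the Hub is extended instead of overwritten. A sketch of the same flow with hypothetical dataframe contents; the combine and push calls are commented out because they need a real repo and authentication:

    import pandas as pd
    from datasets import Dataset
    from distilabel.distiset import Distiset

    dataframe = pd.DataFrame({"instruction": ["Rate this answer"], "rating": [5]})
    dataset = Dataset.from_pandas(dataframe)
    # dataset = combine_datasets(repo_id, dataset)  # append rows already on the Hub
    distiset = Distiset({"default": dataset})  # one named dataset configuration
    # distiset.push_to_hub(repo_id="org/dataset", private=True)  # needs HF auth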
src/synthetic_dataset_generator/apps/sft.py

@@ -10,6 +10,7 @@ from distilabel.distiset import Distiset
 from huggingface_hub import HfApi
 
 from synthetic_dataset_generator.apps.base import (
+    combine_datasets,
     hide_success_message,
     push_pipeline_code_to_hub,
     show_success_message,
@@ -209,11 +210,18 @@ def push_dataset_to_hub(
     oauth_token: Union[gr.OAuthToken, None],
     private: bool,
     pipeline_code: str,
+    progress=gr.Progress(),
 ):
+    progress(0.0, desc="Validating")
     repo_id = validate_push_to_hub(org_name, repo_name)
+    progress(0.3, desc="Converting")
     original_dataframe = dataframe.copy(deep=True)
     dataframe = convert_dataframe_messages(dataframe)
-    distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
+    progress(0.7, desc="Creating dataset")
+    dataset = Dataset.from_pandas(dataframe)
+    dataset = combine_datasets(repo_id, dataset)
+    progress(0.9, desc="Pushing dataset")
+    distiset = Distiset({"default": dataset})
     distiset.push_to_hub(
         repo_id=repo_id,
         private=private,
@@ -222,6 +230,7 @@ def push_dataset_to_hub(
         create_pr=False,
     )
     push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
+    progress(1.0, desc="Dataset pushed")
     return original_dataframe
 
 
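
The progress=gr.Progress() default added here is Gradio's injection pattern: when the function runs as an event handler, Gradio swaps the default for a live tracker, and each progress(fraction, desc=...) call moves the progress bar shown over the output component. A self-contained sketch of the pattern; slow_push and its sleeps are hypothetical:

    import time

    import gradio as gr

    def slow_push(text: str, progress=gr.Progress()) -> str:
        progress(0.0, desc="Validating")
        time.sleep(0.5)
        progress(0.9, desc="Pushing dataset")
        time.sleep(0.5)
        progress(1.0, desc="Dataset pushed")
        return text

    demo = gr.Interface(fn=slow_push, inputs="text", outputs="text")
    # demo.launch()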
src/synthetic_dataset_generator/apps/textcat.py

@@ -11,6 +11,7 @@ from distilabel.distiset import Distiset
 from huggingface_hub import HfApi
 
 from src.synthetic_dataset_generator.apps.base import (
+    combine_datasets,
     hide_success_message,
     push_pipeline_code_to_hub,
     show_success_message,
@@ -129,7 +130,9 @@ def generate_dataset(
         sampled_labels = random.sample(labels, num_labels)
         random.shuffle(sampled_labels)
         inputs.append(
-            {
+            {
+                "task": f"{system_prompt}. The text represents the following categories: {', '.join(sampled_labels)}"
+            }
         )
     batch = list(textcat_generator.process(inputs=inputs))
     textcat_results.extend(batch[0])
@@ -194,9 +197,13 @@ def push_dataset_to_hub(
     oauth_token: Union[gr.OAuthToken, None] = None,
     private: bool = False,
     pipeline_code: str = "",
+    progress=gr.Progress(),
 ):
+    progress(0.0, desc="Validating")
     repo_id = validate_push_to_hub(org_name, repo_name)
+    progress(0.3, desc="Preprocessing")
     labels = get_preprocess_labels(labels)
+    progress(0.7, desc="Creating dataset")
     if num_labels == 1:
         dataframe["label"] = dataframe["label"].replace("", None)
     features = Features(
@@ -209,7 +216,10 @@ def push_dataset_to_hub(
             "labels": Sequence(feature=ClassLabel(names=labels)),
         }
     )
-    distiset = Distiset({"default": Dataset.from_pandas(dataframe, features=features)})
+    dataset = Dataset.from_pandas(dataframe, features=features)
+    dataset = combine_datasets(repo_id, dataset)
+    distiset = Distiset({"default": dataset})
+    progress(0.9, desc="Pushing dataset")
     distiset.push_to_hub(
         repo_id=repo_id,
         private=private,
@@ -218,6 +228,7 @@ def push_dataset_to_hub(
         create_pr=False,
     )
     push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
+    progress(1.0, desc="Dataset pushed")
 
 
 def push_dataset(
@@ -439,7 +450,7 @@ with gr.Blocks() as app:
                     ("Ambiguous", "ambiguous"),
                     ("Mixed", "mixed"),
                 ],
-                value="
+                value="mixed",
                 label="Clarity",
                 info="Set how easily the correct label or labels can be identified.",
                 interactive=True,
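
The explicit Features schema is what keeps real label names on the pushed dataset: ClassLabel stores each label as an integer with an attached names table, and Sequence(feature=ClassLabel(...)) does the same per row for the multi-label case. A small sketch with hypothetical column names, labels, and sample row:

    import pandas as pd
    from datasets import ClassLabel, Dataset, Features, Sequence, Value

    labels = ["positive", "negative"]
    features = Features(
        {
            "text": Value("string"),
            "labels": Sequence(feature=ClassLabel(names=labels)),
        }
    )
    df = pd.DataFrame({"text": ["great value"], "labels": [[0]]})
    dataset = Dataset.from_pandas(df, features=features)
    print(dataset.features["labels"].feature.names)  # ['positive', 'negative']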
src/synthetic_dataset_generator/pipelines/textcat.py

@@ -26,16 +26,16 @@ Don't include the labels in the classification_task but only provide a high leve
 If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
 
 Description: DavidMovieHouse is a cinema that has been in business for 10 years.
-Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customers reviews. Classify the customer reviews as", "labels": ["positive", "negative"]}
+Output: {"classification_task": "The company DavidMovieHouse is a cinema that has been in business for 10 years and has had customer reviews from varying customer groups. Classify the customer reviews as", "labels": ["positive", "negative"]}
 
 Description: A dataset that focuses on creating neo-ludite discussions about technologies within the AI space.
-Output: {"classification_task": "Neo-ludiite discussions about technologies within the AI space cover. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
+Output: {"classification_task": "Neo-luddite discussions about technologies within the AI space, held by different speakers. Categorize the discussions into one of the following categories", "labels": ["tech-support", "tech-opposition"]}
 
 Description: A dataset that covers the articles of a niche sports website called TheSportBlogs that focuses on female sports within the ballsport domain for the US market.
-Output: {"classification_task": "TechSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market. Determine the category of based on the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
+Output: {"classification_task": "TheSportBlogs is a niche sports website that focuses on female sports within the ballsport domain for the US market, written by different journalists. Determine the category of the article using the following categories", "labels": ["basketball", "volleyball", "tennis", "hockey", "baseball", "soccer"]}
 
 Description: A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels "data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"
-Output: {"classification_task": "A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has review with labels", "labels": ["data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"]}
+Output: {"classification_task": "A dataset covering customer reviews for an e-commerce website called Argilla that sells technology datasets within the open source Natural Language Processing space and has reviews from various customer demographics with labels", "labels": ["data-quality", "data-accuracy", "customer-service", "price", "product-availability", "shipping-speed"]}
 
 Description:
 """