import json
import uuid
from typing import List, Union

import argilla as rg
import gradio as gr
import pandas as pd
from datasets import ClassLabel, Dataset, Features, Sequence, Value
from distilabel.distiset import Distiset
from huggingface_hub import HfApi

from distilabel_dataset_generator.constants import DEFAULT_BATCH_SIZE
from src.distilabel_dataset_generator.apps.base import (
    hide_success_message,
    show_success_message,
    validate_argilla_user_workspace_dataset,
    validate_push_to_hub,
)
from src.distilabel_dataset_generator.pipelines.embeddings import (
    get_embeddings,
    get_sentence_embedding_dimensions,
)
from src.distilabel_dataset_generator.pipelines.textcat import (
    DEFAULT_DATASET_DESCRIPTIONS,
    generate_pipeline_code,
    get_labeller_generator,
    get_prompt_generator,
    get_textcat_generator,
)
from src.distilabel_dataset_generator.utils import (
    get_argilla_client,
    get_org_dropdown,
    get_preprocess_labels,
    swap_visibility,
)


def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
    """Turn a free-text dataset description into a task prompt and label set.

    Runs the prompt generator once on ``dataset_description`` and parses its
    ``"generation"`` output as JSON.

    Args:
        dataset_description: User-provided description of the desired dataset.
        temperature: Sampling temperature forwarded to ``get_prompt_generator``.
        progress: Gradio progress tracker used to report status.

    Returns:
        A ``(system_prompt, labels)`` tuple read from the ``"classification_task"``
        and ``"labels"`` keys of the generated JSON.
    """
    progress(0.0, desc="Generating text classification task")
    progress(0.3, desc="Initializing text generation")
    generate_description = get_prompt_generator(temperature)
    progress(0.7, desc="Generating text classification task")
    result = next(
        generate_description.process(
            [
                {
                    "instruction": dataset_description,
                }
            ]
        )
    )[0]["generation"]
    progress(1.0, desc="Text classification task generated")
    data = json.loads(result)
    system_prompt = data["classification_task"]
    labels = data["labels"]
    return system_prompt, labels


def generate_sample_dataset(
    system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()
):
    """Generate a 10-row preview dataset with ``is_sample=True``.

    Thin wrapper around ``generate_dataset``; returns its DataFrame unchanged.
    """
    dataframe = generate_dataset(
        system_prompt=system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        labels=labels,
        num_labels=num_labels,
        num_rows=10,
        progress=progress,
        is_sample=True,
    )
    return dataframe


def _normalize_single_label(value, labels):
    """Return the lower-cased, stripped label if it is a known label, else None.

    Non-string values (e.g. ``None`` or NaN when the labeller produced no
    label) map to ``None`` instead of raising.
    """
    if not isinstance(value, str):
        return None
    cleaned = value.lower().strip()
    return cleaned if cleaned in labels else None


def generate_dataset(
    system_prompt: str,
    difficulty: str,
    clarity: str,
    labels: List[str] = None,
    num_labels: int = 1,
    num_rows: int = 10,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate texts and label them, returning the result as a DataFrame.

    The work happens in two batched phases: the text generator produces
    ``num_rows`` texts, then the labeller annotates them.

    Args:
        system_prompt: Description of the classification task.
        difficulty: Difficulty setting forwarded to the text generator.
        clarity: Clarity setting forwarded to the text generator.
        labels: Candidate labels; passed through ``get_preprocess_labels``.
        num_labels: Number of labels per text; ``1`` selects single-label mode.
        num_rows: Number of rows to generate.
        is_sample: Forwarded to ``get_textcat_generator``.
        progress: Gradio progress tracker used to report status.

    Returns:
        A DataFrame with a ``text`` column and either a ``label`` column
        (single-label mode; values not in ``labels`` become ``None``) or a
        ``labels`` column (multi-label mode).
    """
    progress(0.0, desc="(1/2) Generating text classification data")
    labels = get_preprocess_labels(labels)
    # The same "task + label list" string is used for every generated row and
    # for the labeller's system prompt, so build it once.
    task_prompt = f"{system_prompt} {', '.join(labels)}"
    textcat_generator = get_textcat_generator(
        difficulty=difficulty, clarity=clarity, is_sample=is_sample
    )
    labeller_generator = get_labeller_generator(
        system_prompt=task_prompt,
        labels=labels,
        num_labels=num_labels,
    )
    total_steps: int = num_rows * 2
    batch_size = DEFAULT_BATCH_SIZE

    # create text classification data
    n_processed = 0
    textcat_results = []
    while n_processed < num_rows:
        # NOTE(review): this fraction reaches 1.0 at the end of phase one while
        # phase two restarts at 0.5 — confirm whether 0.5 * n / rows was meant.
        progress(
            2 * 0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(1/2) Generating text classification data",
        )
        remaining_rows = num_rows - n_processed
        # Use a per-iteration size so the final, shorter batch does not shrink
        # the batch size later used by the labelling phase.
        current_batch_size = min(batch_size, remaining_rows)
        inputs = [{"task": task_prompt} for _ in range(current_batch_size)]
        batch = list(textcat_generator.process(inputs=inputs))
        textcat_results.extend(batch[0])
        n_processed += current_batch_size
    for result in textcat_results:
        result["text"] = result["input_text"]

    # label text classification data
    progress(2 * 0.5, desc="(1/2) Generating text classification data")
    n_processed = 0
    labeller_results = []
    while n_processed < num_rows:
        progress(
            0.5 + 0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(1/2) Labeling text classification data",
        )
        batch = textcat_results[n_processed : n_processed + batch_size]
        labels_batch = list(labeller_generator.process(inputs=batch))
        labeller_results.extend(labels_batch[0])
        n_processed += batch_size
    progress(
        1,
        total=total_steps,
        desc="(2/2) Creating dataset",
    )

    # create final dataset
    distiset_results = [
        {key: result[key] for key in ["labels", "text"] if key in result}
        for result in labeller_results
    ]
    dataframe = pd.DataFrame(distiset_results)
    if num_labels == 1:
        dataframe = dataframe.rename(columns={"labels": "label"})
        dataframe["label"] = dataframe["label"].apply(
            lambda x: _normalize_single_label(x, labels)
        )
    progress(1.0, desc="Dataset generation completed")
    return dataframe


def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    num_labels: int = 1,
    labels: List[str] = None,
    oauth_token: Union[gr.OAuthToken, None] = None,
    private: bool = False,
):
    """Push the generated DataFrame to the Hugging Face Hub as a Distiset.

    Args:
        dataframe: Generated data with ``text`` and ``label``/``labels`` columns.
        org_name: Organisation or user name, validated by ``validate_push_to_hub``.
        repo_name: Dataset repository name.
        num_labels: ``1`` uses a ``ClassLabel`` feature named ``label``;
            otherwise a sequence of ``ClassLabel`` named ``labels``.
        labels: Label names; passed through ``get_preprocess_labels``.
        oauth_token: Hugging Face OAuth token used for the upload.
        private: Whether the Hub repository is private.

    Raises:
        gr.Error: If no OAuth token is available.
    """
    repo_id = validate_push_to_hub(org_name, repo_name)
    if oauth_token is None:
        # The token is dereferenced below; fail with a readable message instead
        # of an AttributeError on None.
        raise gr.Error(
            "Please sign in with Hugging Face to push the dataset to the Hub."
        )
    labels = get_preprocess_labels(labels)
    if num_labels == 1:
        # Work on a copy so the caller's DataFrame is not modified.
        dataframe = dataframe.copy()
        dataframe["label"] = dataframe["label"].replace("", None)
        features = Features(
            {"text": Value("string"), "label": ClassLabel(names=labels)}
        )
    else:
        features = Features(
            {
                "text": Value("string"),
                "labels": Sequence(feature=ClassLabel(names=labels)),
            }
        )
    distiset = Distiset({"default": Dataset.from_pandas(dataframe, features=features)})
    distiset.push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )


def push_dataset(
    org_name: str,
    repo_name: str,
    system_prompt: str,
    difficulty: str,
    clarity: str,
    num_labels: int = 1,
    num_rows: int = 10,
    labels: List[str] = None,
    private: bool = False,
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    """Generate the full dataset, push it to the Hub, then log it to Argilla.

    Rows whose ``text`` is missing or blank are dropped before the Argilla
    upload. A label suggestion is attached to a record only when the generated
    label(s) are all part of ``labels``.

    Returns:
        An empty string (the value written to the success-message component).

    Raises:
        gr.Error: If anything fails while setting up or pushing to Argilla.
    """
    dataframe = generate_dataset(
        system_prompt=system_prompt,
        difficulty=difficulty,
        clarity=clarity,
        num_labels=num_labels,
        labels=labels,
        num_rows=num_rows,
        progress=progress,
    )
    push_dataset_to_hub(
        dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
    )
    # Copy the filtered frame: new columns are assigned to it below, and
    # assigning onto a filtered slice is ambiguous in pandas.
    dataframe = dataframe[
        (dataframe["text"].str.strip() != "") & (dataframe["text"].notna())
    ].copy()
    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""
        labels = get_preprocess_labels(labels)
        single_label = num_labels == 1
        if single_label:
            question = rg.LabelQuestion(
                name="label",
                title="Label",
                description="The label of the text",
                labels=labels,
            )
        else:
            question = rg.MultiLabelQuestion(
                name="labels",
                title="Labels",
                description="The labels of the conversation",
                labels=labels,
            )
        settings = rg.Settings(
            fields=[
                rg.TextField(
                    name="text",
                    description="The text classification data",
                    title="Text",
                ),
            ],
            questions=[question],
            metadata=[
                rg.IntegerMetadataProperty(name="text_length", title="Text Length"),
            ],
            vectors=[
                rg.VectorField(
                    name="text_embeddings",
                    dimensions=get_sentence_embedding_dimensions(),
                )
            ],
            guidelines="Please review the text and provide or correct the label where needed.",
        )
        dataframe["text_length"] = dataframe["text"].apply(len)
        dataframe["text_embeddings"] = get_embeddings(dataframe["text"].to_list())

        progress(0.5, desc="Creating dataset")
        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()

        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        records = [
            rg.Record(
                fields={
                    "text": sample["text"],
                },
                metadata={"text_length": sample["text_length"]},
                vectors={"text_embeddings": sample["text_embeddings"]},
                suggestions=(
                    [
                        rg.Suggestion(
                            question_name="label" if num_labels == 1 else "labels",
                            value=(
                                sample["label"] if num_labels == 1 else sample["labels"]
                            ),
                        )
                    ]
                    if (
                        (num_labels == 1 and sample["label"] in labels)
                        or (
                            num_labels > 1
                            and all(label in labels for label in sample["labels"])
                        )
                    )
                    else []
                ),
            )
            for sample in hf_dataset
        ]
        rg_dataset.records.log(records=records)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return ""


def validate_input_labels(labels):
    """Return ``labels`` unchanged, or raise ``gr.Error`` if fewer than two are given."""
    if not labels or len(labels) < 2:
        raise gr.Error(
            f"Please select at least 2 labels to classify your text. You selected {len(labels) if labels else 0}."
        )
    return labels


def update_max_num_labels(labels):
    """Cap the "number of labels" input at the number of selected labels (min 1)."""
    return gr.update(maximum=len(labels) if labels else 1)


def show_pipeline_code_visibility():
    """Make the pipeline-code accordion visible."""
    return {pipeline_code_ui: gr.Accordion(visible=True)}


def hide_pipeline_code_visibility():
    """Hide the pipeline-code accordion."""
    return {pipeline_code_ui: gr.Accordion(visible=False)}


######################
# Gradio UI
######################
# NOTE(review): the layout nesting below was reconstructed from a source whose
# indentation had been lost — confirm component placement against the running app.

with gr.Blocks() as app:
    with gr.Column() as main_ui:
        gr.Markdown("## 1. Describe the dataset you want")
        with gr.Row():
            with gr.Column(scale=2):
                dataset_description = gr.Textbox(
                    label="Dataset description",
                    placeholder="Give a precise description of your desired dataset.",
                )
                with gr.Accordion("Temperature", open=False):
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1,
                        value=0.8,
                        step=0.1,
                        interactive=True,
                        show_label=False,
                    )
                load_btn = gr.Button(
                    "Create dataset",
                    variant="primary",
                )
            with gr.Column(scale=2):
                examples = gr.Examples(
                    examples=DEFAULT_DATASET_DESCRIPTIONS,
                    inputs=[dataset_description],
                    cache_examples=False,
                    label="Examples",
                )
            with gr.Column(scale=1):
                pass

        # NOTE(review): the separator markup appeared stripped in the source
        # (string literal broken across lines); restored as a horizontal rule — confirm.
        gr.HTML("<hr>")
        gr.Markdown("## 2. Configure your dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                system_prompt = gr.Textbox(
                    label="System prompt",
                    placeholder="You are a helpful assistant.",
                    visible=True,
                )
                labels = gr.Dropdown(
                    choices=[],
                    allow_custom_value=True,
                    interactive=True,
                    label="Labels",
                    multiselect=True,
                    info="Add the labels to classify the text.",
                )
                num_labels = gr.Number(
                    label="Number of labels per text",
                    value=1,
                    minimum=1,
                    maximum=10,
                    info="Select 1 for single-label and >1 for multi-label.",
                    interactive=True,
                )
                clarity = gr.Dropdown(
                    choices=[
                        ("Clear", "clear"),
                        (
                            "Understandable",
                            "understandable with some effort",
                        ),
                        ("Ambiguous", "ambiguous"),
                        ("Mixed", "mixed"),
                    ],
                    value="mixed",
                    label="Clarity",
                    info="Set how easily the correct label or labels can be identified.",
                    interactive=True,
                )
                difficulty = gr.Dropdown(
                    choices=[
                        ("High School", "high school"),
                        ("College", "college"),
                        ("PhD", "PhD"),
                        ("Mixed", "mixed"),
                    ],
                    value="mixed",
                    label="Difficulty",
                    info="Select the comprehension level for the text. Ensure it matches the task context.",
                    interactive=True,
                )
                btn_apply_to_sample_dataset = gr.Button(
                    "Refresh dataset", variant="secondary"
                )
            with gr.Column(scale=3):
                dataframe = gr.Dataframe(
                    headers=["labels", "text"], wrap=True, height=500, interactive=False
                )

        # NOTE(review): separator restored as above — confirm.
        gr.HTML("<hr>")
        gr.Markdown("## 3. Generate your dataset")
        with gr.Row(equal_height=False):
            with gr.Column(scale=2):
                org_name = get_org_dropdown()
                repo_name = gr.Textbox(
                    label="Repo name",
                    placeholder="dataset_name",
                    value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                    interactive=True,
                )
                num_rows = gr.Number(
                    label="Number of rows",
                    value=10,
                    interactive=True,
                    scale=1,
                )
                private = gr.Checkbox(
                    label="Private dataset",
                    value=False,
                    interactive=True,
                    scale=1,
                )
                btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
            with gr.Column(scale=3):
                success_message = gr.Markdown(visible=True)
                with gr.Accordion(
                    "Do you want to go further? Customize and run with Distilabel",
                    open=False,
                    visible=False,
                ) as pipeline_code_ui:
                    # Initial code preview rendered from the components' default values.
                    code = generate_pipeline_code(
                        system_prompt.value,
                        difficulty=difficulty.value,
                        clarity=clarity.value,
                        labels=labels.value,
                        num_labels=num_labels.value,
                        num_rows=num_rows.value,
                    )
                    pipeline_code = gr.Code(
                        value=code,
                        language="python",
                        label="Distilabel Pipeline Code",
                    )

    # "Create dataset": derive prompt + labels, preview a sample, then cap the
    # labels-per-text input to the number of labels.
    load_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description, temperature],
        outputs=[system_prompt, labels],
        show_progress=True,
    ).then(
        fn=generate_sample_dataset,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels],
        outputs=[dataframe],
        show_progress=True,
    ).then(
        fn=update_max_num_labels,
        inputs=[labels],
        outputs=[num_labels],
    )

    labels.input(
        fn=update_max_num_labels,
        inputs=[labels],
        outputs=[num_labels],
    )

    btn_apply_to_sample_dataset.click(
        fn=generate_sample_dataset,
        inputs=[system_prompt, difficulty, clarity, labels, num_labels],
        outputs=[dataframe],
        show_progress=True,
    )

    # "Push to Hub": validate, hide stale UI, generate + push, then reveal the
    # success message and the matching pipeline code.
    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
        show_progress=True,
    ).then(
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=push_dataset,
        inputs=[
            org_name,
            repo_name,
            system_prompt,
            difficulty,
            clarity,
            num_labels,
            num_rows,
            labels,
            private,
        ],
        outputs=[success_message],
        show_progress=True,
    ).success(
        fn=show_success_message,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            system_prompt,
            difficulty,
            clarity,
            labels,
            num_labels,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])