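# Page header captured with the file (not part of the program):
#   author avatar: sdiazlor
#   commit message: "refactor: add local save to README and improve layout"
#   commit: 6a5179a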
import ast
import json
import os
import random
import uuid
from typing import Dict, List, Union
import argilla as rg
import gradio as gr
import pandas as pd
from datasets import Dataset
from distilabel.distiset import Distiset
from gradio.oauth import OAuthToken
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from huggingface_hub import HfApi
from synthetic_dataset_generator.apps.base import (
combine_datasets,
hide_success_message,
load_dataset_from_hub,
preprocess_input_data,
push_pipeline_code_to_hub,
show_success_message,
test_max_num_rows,
validate_argilla_user_workspace_dataset,
validate_push_to_hub,
)
from synthetic_dataset_generator.constants import (
BASE_URL,
DEFAULT_BATCH_SIZE,
MODEL,
MODEL_COMPLETION,
SAVE_LOCAL_DIR,
SFT_AVAILABLE,
)
from synthetic_dataset_generator.pipelines.base import get_rewritten_prompts
from synthetic_dataset_generator.pipelines.chat import (
DEFAULT_DATASET_DESCRIPTIONS,
generate_pipeline_code,
get_follow_up_generator,
get_magpie_generator,
get_prompt_generator,
get_response_generator,
get_sentence_pair_generator,
)
from synthetic_dataset_generator.pipelines.embeddings import (
get_embeddings,
get_sentence_embedding_dimensions,
)
from synthetic_dataset_generator.utils import (
column_to_list,
get_argilla_client,
get_org_dropdown,
get_random_repo_name,
swap_visibility,
)
def _get_dataframe():
    """Create the read-only, text-wrapping preview table.

    Returns:
        A new ``gr.Dataframe`` with "prompt" and "completion" headers.
    """
    table_options = {
        "headers": ["prompt", "completion"],
        "wrap": True,
        "interactive": False,
    }
    return gr.Dataframe(**table_options)
def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Turn stringified conversations in the "messages" column into lists.

    String cells are repaired (a comma is appended after every dict that ends
    with a ``'user'``, ``'system'`` or ``'assistant'`` value) and then parsed
    with ``ast.literal_eval``. Non-string cells are left untouched.

    Args:
        dataframe: Frame to convert. It is modified in place.

    Returns:
        The same ``dataframe`` object; unchanged when it has no "messages"
        column.
    """
    # (old, new) pairs, applied in this order, that add the separator expected
    # between consecutive message dicts.
    separator_fixes = (
        ("'user'}", "'user'},"),
        ("'system'}", "'system'},"),
        ("'assistant'}", "'assistant'},"),
    )

    def convert_to_list_of_dicts(messages: str) -> List[Dict[str, str]]:
        repaired = messages
        for old, new in separator_fixes:
            repaired = repaired.replace(old, new)
        return ast.literal_eval(repaired)

    def parse_if_string(cell):
        if isinstance(cell, str):
            return convert_to_list_of_dicts(cell)
        return cell

    if "messages" not in dataframe.columns:
        return dataframe
    dataframe["messages"] = dataframe["messages"].apply(parse_if_string)
    return dataframe
def generate_system_prompt(dataset_description: str, progress=gr.Progress()):
    """Generate a system prompt from a free-text dataset description.

    Args:
        dataset_description: Text passed to the prompt generator as its
            "instruction".
        progress: Gradio progress tracker.

    Returns:
        The "generation" field of the first result in the first batch yielded
        by the prompt generator.
    """
    progress(0.1, desc="Initializing")
    prompt_generator = get_prompt_generator()
    progress(0.5, desc="Generating")
    generator_inputs = [{"instruction": dataset_description}]
    first_batch = next(prompt_generator.process(generator_inputs))
    generated_prompt = first_batch[0]["generation"]
    progress(1.0, desc="Prompt generated")
    return generated_prompt
def load_dataset_file(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    num_rows: int = 10,
    token: Union[OAuthToken, None] = None,
    progress=gr.Progress(),
):
    """Load source data from the Hub or from uploaded files.

    Args:
        repo_id: Hub dataset id, used when ``input_type`` is "dataset-input".
        file_paths: Uploaded file paths, used for every other input type.
        input_type: Selected input mode.
        num_rows: Number of rows forwarded to the loader.
        token: OAuth token forwarded to the Hub loader.
        progress: Gradio progress tracker.

    Returns:
        Whatever the selected loader returns (``load_dataset_from_hub`` or
        ``preprocess_input_data``).
    """
    progress(0.1, desc="Loading the source data")
    if input_type != "dataset-input":
        return preprocess_input_data(file_paths=file_paths, num_rows=num_rows)
    return load_dataset_from_hub(repo_id=repo_id, num_rows=num_rows, token=token)
def generate_sample_dataset(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    oauth_token: Union[OAuthToken, None],
    progress=gr.Progress(),
):
    """Generate a sample dataset for the preview table.

    For "prompt-input" an empty prompt/completion frame is used as the seed;
    otherwise the source data is loaded first. Generation then runs with
    ``is_sample=True``.

    Returns:
        The dataframe produced by ``generate_dataset``.
    """
    if input_type != "prompt-input":
        source_frame, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    else:
        source_frame = pd.DataFrame(columns=["prompt", "completion"])

    progress(0.5, desc="Generating sample dataset")
    sample_frame = generate_dataset(
        input_type=input_type,
        dataframe=source_frame,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        is_sample=True,
    )
    progress(1.0, desc="Sample dataset generated")
    return sample_frame
def generate_dataset_from_prompt(
    system_prompt: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate a chat dataset from a system prompt alone.

    Instructions (or conversations when ``num_turns`` is not 1) are produced in
    batches by the Magpie generator, each batch using a randomly chosen rewrite
    of the system prompt. The response generator then adds a completion (single
    turn) or a final assistant message (multi turn).

    Args:
        system_prompt: System prompt describing the desired dataset.
        num_turns: 1 for prompt/completion rows, otherwise conversations.
        num_rows: Requested number of rows; passed through ``test_max_num_rows``.
        temperature: Sampling temperature for the instruction generator.
        temperature_completion: Fallback temperature for the response
            generator; only used when ``temperature`` is falsy (see note below).
        is_sample: Forwarded to the generator factories.
        progress: Gradio progress tracker.

    Returns:
        A dataframe with "prompt", "completion" and "system_prompt" columns when
        ``num_turns`` is 1, otherwise with a single "messages" column.
    """
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="(1/2) Generating instructions")
    magpie_generator = get_magpie_generator(num_turns, temperature, is_sample)
    # NOTE(review): `temperature or temperature_completion` only falls back to
    # `temperature_completion` when `temperature` is falsy, so a completion
    # temperature chosen in the UI is effectively ignored. Left unchanged
    # because the intended precedence cannot be confirmed from this file.
    response_generator = get_response_generator(
        system_prompt=system_prompt,
        num_turns=num_turns,
        temperature=temperature or temperature_completion,
        is_sample=is_sample,
    )
    total_steps: int = num_rows * 2
    batch_size = DEFAULT_BATCH_SIZE

    # create prompt rewrites
    prompt_rewrites = get_rewritten_prompts(system_prompt, num_rows)

    # create instructions
    n_processed = 0
    magpie_results = []
    while n_processed < num_rows:
        progress(
            0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(1/2) Generating instructions",
        )
        # Shrink only this iteration's batch; `batch_size` itself must stay at
        # the default so the response phase is not forced into tiny batches.
        current_batch_size = min(batch_size, num_rows - n_processed)
        rewritten_system_prompt = random.choice(prompt_rewrites)
        inputs = [
            {"system_prompt": rewritten_system_prompt}
            for _ in range(current_batch_size)
        ]
        batch = list(magpie_generator.process(inputs=inputs))
        magpie_results.extend(batch[0])
        n_processed += current_batch_size
        random.seed(a=random.randint(0, 2**32 - 1))
    progress(0.5, desc="(1/2) Generating instructions")

    # generate responses
    if num_turns != 1:
        # Multi-turn: the response generator consumes a "messages" list that
        # starts with the system prompt.
        for result in magpie_results:
            result["conversation"].insert(
                0, {"role": "system", "content": system_prompt}
            )
            result["messages"] = result["conversation"]

    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        progress(
            0.5 + 0.5 * n_processed / num_rows,
            total=total_steps,
            desc="(2/2) Generating responses",
        )
        # Slicing clamps the final (possibly shorter) batch automatically.
        batch = magpie_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size
        random.seed(a=random.randint(0, 2**32 - 1))

    if num_turns == 1:
        for result in response_results:
            result["prompt"] = result["instruction"]
            result["completion"] = result["generation"]
            result["system_prompt"] = system_prompt
    else:
        for result in response_results:
            result["messages"].append(
                {"role": "assistant", "content": result["generation"]}
            )
    progress(
        1,
        total=total_steps,
        desc="(2/2) Creating dataset",
    )

    # create distiset, keeping only the columns of interest
    relevant_keys = ("messages", "prompt", "completion", "model_name", "system_prompt")
    distiset_results = [
        {key: result[key] for key in relevant_keys if key in result}
        for result in response_results
    ]
    distiset = Distiset(
        {
            "default": Dataset.from_list(distiset_results),
        }
    )

    # If not pushing to hub generate the dataset directly
    distiset = distiset["default"]
    if num_turns == 1:
        outputs = distiset.to_pandas()[["prompt", "completion", "system_prompt"]]
    else:
        outputs = distiset.to_pandas()[["messages"]]
    dataframe = pd.DataFrame(outputs)
    progress(1.0, desc="Dataset generation completed")
    return dataframe
def generate_dataset_from_seed(
    dataframe: pd.DataFrame,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Generate a chat dataset seeded by documents from ``dataframe``.

    For each document an instruction is generated, then a response, and, when
    ``num_turns`` > 1, ``num_turns - 1`` additional user/assistant exchanges.

    Args:
        dataframe: Source data containing ``document_column``.
        document_column: Column whose values are used as seed documents.
        num_turns: Number of user/assistant exchanges per conversation.
        num_rows: Requested number of rows; passed through ``test_max_num_rows``.
        temperature: Sampling temperature for the instruction generators.
        temperature_completion: Fallback temperature for the response
            generators; only used when ``temperature`` is falsy (see note below).
        is_sample: Forwarded to the generator factories.
        progress: Gradio progress tracker.

    Returns:
        A dataframe with "prompt" and "completion" columns when ``num_turns``
        is 1, otherwise a "messages" column holding JSON-encoded conversations.

    Raises:
        gr.Error: If the document column yields no documents.
    """
    num_rows = test_max_num_rows(num_rows)
    progress(0.0, desc="Initializing dataset generation")
    document_data = column_to_list(dataframe, document_column)
    if len(document_data) == 0:
        # random.choices() on an empty list would otherwise fail with an
        # opaque IndexError when padding below.
        raise gr.Error(
            f"No documents found in column '{document_column}'. "
            "Load your data first in step 1."
        )
    if len(document_data) < num_rows:
        # Pad with randomly repeated documents to reach the requested size.
        document_data += random.choices(document_data, k=num_rows - len(document_data))

    # NOTE(review): `temperature or temperature_completion` only falls back to
    # `temperature_completion` when `temperature` is falsy, so a completion
    # temperature chosen in the UI is effectively ignored. Left unchanged
    # because the intended precedence cannot be confirmed from this file.
    instruction_generator = get_sentence_pair_generator(
        temperature=temperature, is_sample=is_sample
    )
    response_generator = get_response_generator(
        system_prompt=None,
        num_turns=1,
        temperature=temperature or temperature_completion,
        is_sample=is_sample,
    )
    follow_up_generator_instruction = get_follow_up_generator(
        type="instruction", temperature=temperature, is_sample=is_sample
    )
    follow_up_generator_response = get_follow_up_generator(
        type="response",
        temperature=temperature or temperature_completion,
        is_sample=is_sample,
    )
    steps = 2 * num_turns
    total_steps: int = num_rows * steps
    step_progress = round(1 / steps, 2)
    batch_size = DEFAULT_BATCH_SIZE

    # create instructions
    n_processed = 0
    instruction_results = []
    while n_processed < num_rows:
        progress(
            step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating instructions",
        )
        # Shrink only this iteration's batch so later phases keep the default
        # batch size instead of inheriting the last (smaller) one.
        current_batch_size = min(batch_size, num_rows - n_processed)
        batch = [
            {"anchor": document}
            for document in document_data[n_processed : n_processed + current_batch_size]
        ]
        questions = list(instruction_generator.process(inputs=batch))
        instruction_results.extend(questions[0])
        n_processed += current_batch_size
    for result in instruction_results:
        result["instruction"] = result["positive"]
        result["prompt"] = result.pop("positive")
    progress(step_progress, desc="Generating instructions")

    # generate responses
    n_processed = 0
    response_results = []
    while n_processed < num_rows:
        progress(
            step_progress + step_progress * n_processed / num_rows,
            total=total_steps,
            desc="Generating responses",
        )
        batch = instruction_results[n_processed : n_processed + batch_size]
        responses = list(response_generator.process(inputs=batch))
        response_results.extend(responses[0])
        n_processed += batch_size
    for result in response_results:
        result["completion"] = result.pop("generation")

    # generate follow-ups
    if num_turns > 1:
        n_processed = 0
        final_conversations = []
        # Follow-ups cover everything after the instruction and response
        # phases, so their progress starts where the response phase ended.
        follow_up_start = 2 * step_progress
        while n_processed < num_rows:
            progress(
                follow_up_start + (1 - follow_up_start) * n_processed / num_rows,
                total=total_steps,
                desc="Generating follow-ups",
            )
            batch = response_results[n_processed : n_processed + batch_size]
            conversations_batch = [
                {
                    "messages": [
                        {"role": "user", "content": result["prompt"]},
                        {"role": "assistant", "content": result["completion"]},
                    ]
                }
                for result in batch
            ]

            for _ in range(num_turns - 1):
                follow_up_instructions = list(
                    follow_up_generator_instruction.process(inputs=conversations_batch)
                )
                for conv, follow_up in zip(
                    conversations_batch, follow_up_instructions[0]
                ):
                    conv["messages"].append(
                        {"role": "user", "content": follow_up["generation"]}
                    )

                follow_up_responses = list(
                    follow_up_generator_response.process(inputs=conversations_batch)
                )
                for conv, follow_up in zip(conversations_batch, follow_up_responses[0]):
                    conv["messages"].append(
                        {"role": "assistant", "content": follow_up["generation"]}
                    )

            final_conversations.extend(
                [{"messages": conv["messages"]} for conv in conversations_batch]
            )
            n_processed += batch_size

    # create distiset
    if num_turns == 1:
        distiset_results = [
            {key: result[key] for key in ("prompt", "completion") if key in result}
            for result in response_results
        ]
        dataframe = pd.DataFrame(distiset_results)
    else:
        distiset_results = final_conversations
        dataframe = pd.DataFrame(distiset_results)
        # Conversations are stored as JSON strings in the resulting frame.
        dataframe["messages"] = dataframe["messages"].apply(lambda x: json.dumps(x))

    progress(1.0, desc="Dataset generation completed")
    return dataframe
def generate_dataset(
    input_type: str,
    dataframe: pd.DataFrame,
    system_prompt: str,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    is_sample: bool = False,
    progress=gr.Progress(),
) -> pd.DataFrame:
    """Dispatch generation to the prompt-based or seed-based generator.

    "prompt-input" uses only ``system_prompt``; every other input type uses
    ``dataframe`` and ``document_column``. The remaining options are forwarded
    unchanged to the selected generator.

    Returns:
        The dataframe produced by the selected generator.
    """
    shared_options = {
        "num_turns": num_turns,
        "num_rows": num_rows,
        "temperature": temperature,
        "temperature_completion": temperature_completion,
        "is_sample": is_sample,
    }
    if input_type == "prompt-input":
        return generate_dataset_from_prompt(
            system_prompt=system_prompt, **shared_options
        )
    return generate_dataset_from_seed(
        dataframe=dataframe, document_column=document_column, **shared_options
    )
def push_dataset_to_hub(
    dataframe: pd.DataFrame,
    org_name: str,
    repo_name: str,
    oauth_token: Union[gr.OAuthToken, None],
    private: bool,
    pipeline_code: str,
    progress=gr.Progress(),
):
    """Push ``dataframe`` and the pipeline code to the Hugging Face Hub.

    The frame's "messages" column is converted in place via
    ``convert_dataframe_messages``, the result is combined with
    ``combine_datasets`` and pushed as the "default" configuration.

    Returns:
        A deep copy of ``dataframe`` taken before the conversion.
    """
    progress(0.0, desc="Validating")
    repo_id = validate_push_to_hub(org_name, repo_name)

    progress(0.3, desc="Converting")
    # Snapshot first: the conversion below mutates the frame it receives.
    untouched_copy = dataframe.copy(deep=True)
    converted_frame = convert_dataframe_messages(dataframe)

    progress(0.7, desc="Creating dataset")
    merged_dataset = combine_datasets(
        repo_id, Dataset.from_pandas(converted_frame), oauth_token
    )

    progress(0.9, desc="Pushing dataset")
    Distiset({"default": merged_dataset}).push_to_hub(
        repo_id=repo_id,
        private=private,
        include_script=False,
        token=oauth_token.token,
        create_pr=False,
    )
    push_pipeline_code_to_hub(pipeline_code, org_name, repo_name, oauth_token)
    progress(1.0, desc="Dataset pushed")
    return untouched_copy
def push_dataset(
    org_name: str,
    repo_name: str,
    private: bool,
    original_repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int = 1,
    num_rows: int = 10,
    temperature: float = 0.9,
    temperature_completion: Union[float, None] = None,
    pipeline_code: str = "",
    oauth_token: Union[gr.OAuthToken, None] = None,
    progress=gr.Progress(),
) -> str:
    """Generate the full dataset, push it to the Hub and log it to Argilla.

    The Argilla step is skipped when ``get_argilla_client`` returns ``None``.
    Records are logged with length metadata and embeddings; the settings depend
    on whether the frame has a "messages" column (conversations) or
    prompt/completion columns.

    Returns:
        An empty string (used to fill the success message output).

    Raises:
        gr.Error: If anything fails during the Argilla step; the original
            exception is chained as the cause. Errors during generation or the
            Hub push are not wrapped.
    """
    if input_type == "prompt-input":
        # Placeholder only: prompt-based generation ignores the input frame.
        dataframe = pd.DataFrame(columns=["prompt", "completion"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=original_repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
            token=oauth_token,
        )
    progress(0.5, desc="Generating dataset")
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    # push_dataset_to_hub runs convert_dataframe_messages on this frame, which
    # rewrites its "messages" column in place; the Argilla code below iterates
    # over those converted message lists.
    push_dataset_to_hub(
        dataframe=dataframe,
        org_name=org_name,
        repo_name=repo_name,
        oauth_token=oauth_token,
        private=private,
        pipeline_code=pipeline_code,
    )
    try:
        progress(0.1, desc="Setting up user and workspace")
        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
        client = get_argilla_client()
        if client is None:
            return ""
        progress(0.5, desc="Creating dataset in Argilla")
        if "messages" in dataframe.columns:
            settings = rg.Settings(
                fields=[
                    rg.ChatField(
                        name="messages",
                        description="The messages in the conversation",
                        title="Messages",
                    ),
                ],
                questions=[
                    rg.RatingQuestion(
                        name="rating",
                        title="Rating",
                        description="The rating of the conversation",
                        values=list(range(1, 6)),
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name="user_message_length", title="User Message Length"
                    ),
                    rg.IntegerMetadataProperty(
                        name="assistant_message_length",
                        title="Assistant Message Length",
                    ),
                ],
                vectors=[
                    rg.VectorField(
                        name="messages_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                ],
                guidelines="Please review the conversation and provide a score for the assistant's response.",
            )

            # Total characters written by each role across the conversation.
            dataframe["user_message_length"] = dataframe["messages"].apply(
                lambda x: sum(len(y["content"]) for y in x if y["role"] == "user")
            )
            dataframe["assistant_message_length"] = dataframe["messages"].apply(
                lambda x: sum(len(y["content"]) for y in x if y["role"] == "assistant")
            )
            # One embedding per conversation, over all message contents joined.
            dataframe["messages_embeddings"] = get_embeddings(
                dataframe["messages"].apply(
                    lambda x: " ".join([y["content"] for y in x])
                )
            )
        else:
            settings = rg.Settings(
                fields=[
                    rg.TextField(
                        name="system_prompt",
                        title="System Prompt",
                        description="The system prompt used for the conversation",
                        required=False,
                    ),
                    rg.TextField(
                        name="prompt",
                        title="Prompt",
                        description="The prompt used for the conversation",
                    ),
                    rg.TextField(
                        name="completion",
                        title="Completion",
                        description="The completion from the assistant",
                    ),
                ],
                questions=[
                    rg.RatingQuestion(
                        name="rating",
                        title="Rating",
                        description="The rating of the conversation",
                        values=list(range(1, 6)),
                    ),
                ],
                metadata=[
                    rg.IntegerMetadataProperty(
                        name="prompt_length", title="Prompt Length"
                    ),
                    rg.IntegerMetadataProperty(
                        name="completion_length", title="Completion Length"
                    ),
                ],
                vectors=[
                    rg.VectorField(
                        name="prompt_embeddings",
                        dimensions=get_sentence_embedding_dimensions(),
                    )
                ],
                guidelines="Please review the conversation and correct the prompt and completion where needed.",
            )
            dataframe["prompt_length"] = dataframe["prompt"].apply(len)
            dataframe["completion_length"] = dataframe["completion"].apply(len)
            dataframe["prompt_embeddings"] = get_embeddings(dataframe["prompt"])

        # Reuse the dataset when the lookup finds one, otherwise create it.
        rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
        if rg_dataset is None:
            rg_dataset = rg.Dataset(
                name=repo_name,
                workspace=hf_user,
                settings=settings,
                client=client,
            )
            rg_dataset = rg_dataset.create()
        progress(0.7, desc="Pushing dataset to Argilla")
        hf_dataset = Dataset.from_pandas(dataframe)
        rg_dataset.records.log(records=hf_dataset)
        progress(1.0, desc="Dataset pushed to Argilla")
    except Exception as e:
        # Chain the cause so the original traceback stays available in logs.
        raise gr.Error(f"Error pushing dataset to Argilla: {e}") from e
    return ""
def save_local(
    repo_id: str,
    file_paths: list[str],
    input_type: str,
    system_prompt: str,
    document_column: str,
    num_turns: int,
    num_rows: int,
    temperature: float,
    repo_name: str,
    temperature_completion: Union[float, None] = None,
) -> tuple[str, str]:
    """Generate the full dataset and save it as CSV and JSON in SAVE_LOCAL_DIR.

    Files are named ``<repo_name>.csv`` and ``<repo_name>.json``. The target
    directory is created when it does not exist yet.

    Returns:
        The paths of the written CSV and JSON files, in that order.
    """
    if input_type == "prompt-input":
        # Placeholder only: prompt-based generation ignores the input frame.
        dataframe = pd.DataFrame(columns=["prompt", "completion"])
    else:
        dataframe, _ = load_dataset_file(
            repo_id=repo_id,
            file_paths=file_paths,
            input_type=input_type,
            num_rows=num_rows,
        )
    dataframe = generate_dataset(
        input_type=input_type,
        dataframe=dataframe,
        system_prompt=system_prompt,
        document_column=document_column,
        num_turns=num_turns,
        num_rows=num_rows,
        temperature=temperature,
        temperature_completion=temperature_completion,
    )
    local_dataset = Dataset.from_pandas(dataframe)
    # Ensure the target directory exists so the writes below do not fail on a
    # fresh setup.
    os.makedirs(SAVE_LOCAL_DIR, exist_ok=True)
    output_csv = os.path.join(SAVE_LOCAL_DIR, repo_name + ".csv")
    output_json = os.path.join(SAVE_LOCAL_DIR, repo_name + ".json")
    local_dataset.to_csv(output_csv, index=False)
    local_dataset.to_json(output_json, index=False)
    return output_csv, output_json
def show_system_prompt_visibility():
    """Return an update that makes the system prompt textbox visible."""
    visible_textbox = gr.Textbox(visible=True)
    return {system_prompt: visible_textbox}
def hide_system_prompt_visibility():
    """Return an update that hides the system prompt textbox."""
    hidden_textbox = gr.Textbox(visible=False)
    return {system_prompt: hidden_textbox}
def show_document_column_visibility():
    """Return an update that makes the document column dropdown visible."""
    visible_dropdown = gr.Dropdown(visible=True)
    return {document_column: visible_dropdown}
def hide_document_column_visibility():
    """Return an update that hides the document column dropdown.

    The dropdown is also reset to its single placeholder choice.
    """
    placeholder = "Load your data first in step 1."
    hidden_dropdown = gr.Dropdown(
        choices=[placeholder],
        value=placeholder,
        visible=False,
    )
    return {document_column: hidden_dropdown}
def show_pipeline_code_visibility():
    """Return an update that makes the pipeline code accordion visible."""
    visible_accordion = gr.Accordion(visible=True)
    return {pipeline_code_ui: visible_accordion}
def hide_pipeline_code_visibility():
    """Return an update that hides the pipeline code accordion."""
    hidden_accordion = gr.Accordion(visible=False)
    return {pipeline_code_ui: hidden_accordion}
def show_temperature_completion():
    """Reveal the completion temperature slider when the models differ.

    Returns:
        An update showing the slider with value 0.9 when ``MODEL`` differs from
        ``MODEL_COMPLETION``; ``None`` otherwise.
    """
    if MODEL == MODEL_COMPLETION:
        return None
    completion_slider = gr.Slider(value=0.9, visible=True)
    return {temperature_completion: completion_slider}
def show_save_local_button():
    """Return an update that makes the "Save locally" button visible."""
    visible_button = gr.Button(visible=True)
    return {btn_save_local: visible_button}
def hide_save_local_button():
    """Return an update that hides the "Save locally" button."""
    hidden_button = gr.Button(visible=False)
    return {btn_save_local: hidden_button}
def show_save_local():
    """Return updates that make the CSV and JSON file outputs visible.

    The success message component is returned as its own update value.
    """
    # NOTE(review): this function used to call
    # `gr.update(success_message, min_height=0)` and discard the result, which
    # had no effect; the dead call was removed. If the Markdown's min_height is
    # meant to change here, return an explicit update for it instead.
    return {
        csv_file: gr.File(visible=True),
        json_file: gr.File(visible=True),
        success_message: success_message,
    }
def hide_save_local():
    """Return updates that hide the CSV and JSON file outputs.

    The success message component is returned as its own update value.
    """
    # NOTE(review): this function used to call
    # `gr.update(success_message, min_height=100)` and discard the result,
    # which had no effect; the dead call was removed. If the Markdown's
    # min_height is meant to change here, return an explicit update instead.
    return {
        csv_file: gr.File(visible=False),
        json_file: gr.File(visible=False),
        success_message: success_message,
    }
######################
# Gradio UI
######################
# Layout and event wiring of the chat-data generation app. Components are
# created in display order; event listeners are registered after the layout.
with gr.Blocks() as app:
    with gr.Column() as main_ui:
        if not SFT_AVAILABLE:
            # Only an explanatory notice is rendered in this case.
            gr.Markdown(
                value="\n".join(
                    [
                        "## Supervised Fine-Tuning not available",
                        "",
                        f"This tool relies on the [Magpie](https://arxiv.org/abs/2406.08464) prequery template, which is not implemented for the {MODEL} with {BASE_URL}.",
                        "Use Llama3 or Qwen2 models with Hugging Face Inference Endpoints.",
                    ]
                )
            )
        else:
            # ---- Step 1: choose where the seed data comes from ----
            gr.Markdown("## 1. Select your input")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    # Hidden dropdown holding the current input mode; it is set
                    # by the tab-select listeners registered further below.
                    input_type = gr.Dropdown(
                        label="Input type",
                        choices=["prompt-input", "dataset-input", "file-input"],
                        value="prompt-input",
                        multiselect=False,
                        visible=False,
                    )
                    with gr.Tab("Generate from prompt") as tab_prompt_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                dataset_description = gr.Textbox(
                                    label="Dataset description",
                                    placeholder="Give a precise description of your desired dataset.",
                                )
                                with gr.Row():
                                    clear_prompt_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_prompt_btn = gr.Button(
                                        "Create", variant="primary"
                                    )
                            with gr.Column(scale=3):
                                examples = gr.Examples(
                                    examples=DEFAULT_DATASET_DESCRIPTIONS,
                                    inputs=[dataset_description],
                                    cache_examples=False,
                                    label="Examples",
                                )
                    with gr.Tab("Load from Hub") as tab_dataset_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                search_in = HuggingfaceHubSearch(
                                    label="Search",
                                    placeholder="Search for a dataset",
                                    search_type="dataset",
                                    sumbit_on_select=True,
                                )
                                with gr.Row():
                                    clear_dataset_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_dataset_btn = gr.Button(
                                        "Load", variant="primary"
                                    )
                            with gr.Column(scale=3):
                                # Rebinds `examples`; the previous Examples
                                # object is not referenced again in this file.
                                examples = gr.Examples(
                                    examples=[
                                        "charris/wikipedia_sample",
                                        "plaguss/argilla_sdk_docs_raw_unstructured",
                                        "BeIR/hotpotqa-generated-queries",
                                    ],
                                    label="Example datasets",
                                    fn=lambda x: x,
                                    inputs=[search_in],
                                    run_on_click=True,
                                )
                                search_out = gr.HTML(
                                    label="Dataset preview", visible=False
                                )
                    with gr.Tab("Load your file") as tab_file_input:
                        with gr.Row(equal_height=False):
                            with gr.Column(scale=2):
                                file_in = gr.File(
                                    label="Upload your file. Supported formats: .md, .txt, .docx, .pdf",
                                    file_count="multiple",
                                    file_types=[".md", ".txt", ".docx", ".pdf"],
                                )
                                with gr.Row():
                                    clear_file_btn_part = gr.Button(
                                        "Clear", variant="secondary"
                                    )
                                    load_file_btn = gr.Button("Load", variant="primary")
                            with gr.Column(scale=3):
                                file_out = gr.HTML(
                                    label="Dataset preview", visible=False
                                )

            # ---- Step 2: configure generation and preview a sample ----
            gr.HTML(value="<hr>")
            gr.Markdown(value="## 2. Configure your dataset")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    system_prompt = gr.Textbox(
                        label="System prompt",
                        placeholder="You are a helpful assistant.",
                    )
                    # Hidden until a dataset or file tab is selected.
                    document_column = gr.Dropdown(
                        label="Document Column",
                        info="Select the document column to generate the chat data",
                        choices=["Load your data first in step 1."],
                        value="Load your data first in step 1.",
                        interactive=False,
                        multiselect=False,
                        allow_custom_value=False,
                        visible=False,
                    )
                    num_turns = gr.Number(
                        value=1,
                        label="Number of turns in the conversation",
                        minimum=1,
                        maximum=4,
                        step=1,
                        interactive=True,
                        info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
                    )
                    with gr.Row():
                        clear_btn_full = gr.Button(
                            "Clear",
                            variant="secondary",
                        )
                        btn_apply_to_sample_dataset = gr.Button(
                            "Save", variant="primary"
                        )
                with gr.Column(scale=3):
                    # Preview table for loaded data and generated samples.
                    dataframe = _get_dataframe()

            # ---- Step 3: generate the full dataset and export it ----
            gr.HTML(value="<hr>")
            gr.Markdown(value="## 3. Generate your dataset")
            with gr.Row(equal_height=False):
                with gr.Column(scale=2):
                    org_name = get_org_dropdown()
                    repo_name = gr.Textbox(
                        label="Repo name",
                        placeholder="dataset_name",
                        value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                        interactive=True,
                    )
                    num_rows = gr.Number(
                        label="Number of rows",
                        value=10,
                        interactive=True,
                        scale=1,
                    )
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.1,
                        maximum=1.5,
                        value=0.9,
                        step=0.1,
                        interactive=True,
                    )
                    # Hidden by default; show_temperature_completion reveals it
                    # on load when MODEL differs from MODEL_COMPLETION.
                    temperature_completion = gr.Slider(
                        label="Temperature for completion",
                        minimum=0.1,
                        maximum=1.5,
                        value=None,
                        step=0.1,
                        interactive=True,
                        visible=False,
                    )
                    private = gr.Checkbox(
                        label="Private dataset",
                        value=False,
                        interactive=True,
                        scale=1,
                    )
                    btn_push_to_hub = gr.Button(
                        "Push to Hub", variant="primary", scale=2
                    )
                    # Hidden by default; revealed on load when SAVE_LOCAL_DIR
                    # is configured (see the end of this block).
                    btn_save_local = gr.Button(
                        "Save locally", variant="primary", scale=2, visible=False
                    )
                with gr.Column(scale=3):
                    csv_file = gr.File(
                        label="CSV",
                        elem_classes="datasets",
                        visible=False,
                    )
                    json_file = gr.File(
                        label="JSON",
                        elem_classes="datasets",
                        visible=False,
                    )
                    success_message = gr.Markdown(
                        visible=False,
                        min_height=0  # don't remove this otherwise progress is not visible
                    )
                    with gr.Accordion(
                        "Customize your pipeline with distilabel",
                        open=False,
                        visible=False,
                    ) as pipeline_code_ui:
                        # Initial pipeline code rendered from the components'
                        # default values at build time.
                        code = generate_pipeline_code(
                            repo_id=search_in.value,
                            input_type=input_type.value,
                            system_prompt=system_prompt.value,
                            document_column=document_column.value,
                            num_turns=num_turns.value,
                            num_rows=num_rows.value,
                        )
                        pipeline_code = gr.Code(
                            value=code,
                            language="python",
                            label="Distilabel Pipeline Code",
                        )

    # ---- Event wiring ----
    # Selecting a tab records the input mode in the hidden dropdown, then
    # toggles the system prompt / document column visibility to match.
    tab_prompt_input.select(
        fn=lambda: "prompt-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=show_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=hide_document_column_visibility, inputs=[], outputs=[document_column]
    )

    tab_dataset_input.select(
        fn=lambda: "dataset-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=show_document_column_visibility, inputs=[], outputs=[document_column]
    )

    tab_file_input.select(
        fn=lambda: "file-input",
        inputs=[],
        outputs=[input_type],
    ).then(fn=hide_system_prompt_visibility, inputs=[], outputs=[system_prompt]).then(
        fn=show_document_column_visibility, inputs=[], outputs=[document_column]
    )

    # Submitting a Hub search empties the preview table but keeps its columns.
    search_in.submit(
        fn=lambda df: pd.DataFrame(columns=df.columns),
        inputs=[dataframe],
        outputs=[dataframe],
    )

    # "Create": generate a system prompt, then (only on success) a sample.
    load_prompt_btn.click(
        fn=generate_system_prompt,
        inputs=[dataset_description],
        outputs=[system_prompt],
    ).success(
        fn=generate_sample_dataset,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
        ],
        outputs=dataframe,
    )

    # Both "Load" buttons share one handler; num_rows is not passed, so
    # load_dataset_file uses its default.
    gr.on(
        triggers=[load_dataset_btn.click, load_file_btn.click],
        fn=load_dataset_file,
        inputs=[search_in, file_in, input_type],
        outputs=[dataframe, document_column],
    )

    # "Save" in step 2 regenerates the sample with the current settings.
    btn_apply_to_sample_dataset.click(
        fn=generate_sample_dataset,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
        ],
        outputs=dataframe,
    )

    # "Push to Hub": validate, reset the output area, generate and push, then
    # show the success message and the refreshed pipeline code. Steps chained
    # with .success() only run when the previous step did not fail.
    btn_push_to_hub.click(
        fn=validate_argilla_user_workspace_dataset,
        inputs=[repo_name],
        outputs=[success_message],
    ).then(
        fn=validate_push_to_hub,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=hide_save_local,
        outputs=[csv_file, json_file, success_message],
    ).success(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=push_dataset,
        inputs=[
            org_name,
            repo_name,
            private,
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
            temperature,
            temperature_completion,
            pipeline_code,
        ],
        outputs=[success_message],
    ).success(
        fn=show_success_message,
        inputs=[org_name, repo_name],
        outputs=[success_message],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    # "Save locally": reset the output area, reveal the file outputs, write the
    # CSV/JSON files, then show the refreshed pipeline code.
    btn_save_local.click(
        fn=hide_success_message,
        outputs=[success_message],
    ).success(
        fn=hide_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    ).success(
        fn=show_save_local,
        inputs=[],
        outputs=[csv_file, json_file, success_message],
    ).success(
        save_local,
        inputs=[
            search_in,
            file_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
            temperature,
            repo_name,
            temperature_completion,
        ],
        outputs=[csv_file, json_file],
    ).success(
        fn=generate_pipeline_code,
        inputs=[
            search_in,
            input_type,
            system_prompt,
            document_column,
            num_turns,
            num_rows,
        ],
        outputs=[pipeline_code],
    ).success(
        fn=show_pipeline_code_visibility,
        inputs=[],
        outputs=[pipeline_code_ui],
    )

    # Per-tab "Clear" buttons reset only their own input.
    clear_dataset_btn_part.click(fn=lambda: "", inputs=[], outputs=[search_in])
    clear_file_btn_part.click(fn=lambda: None, inputs=[], outputs=[file_in])
    clear_prompt_btn_part.click(fn=lambda: "", inputs=[], outputs=[dataset_description])
    # NOTE(review): the third returned value ([]) is sent to `num_turns`, a
    # gr.Number — confirm this is the intended reset value. The `df` input is
    # not used by the lambda.
    clear_btn_full.click(
        fn=lambda df: ("", "", [], _get_dataframe()),
        inputs=[dataframe],
        outputs=[system_prompt, document_column, num_turns, dataframe],
    )

    # Page-load hooks.
    app.load(fn=swap_visibility, outputs=main_ui)
    app.load(fn=get_org_dropdown, outputs=[org_name])
    app.load(fn=get_random_repo_name, outputs=[repo_name])
    app.load(fn=show_temperature_completion, outputs=[temperature_completion])
    # The local-save button is only revealed when a save directory is set.
    if SAVE_LOCAL_DIR is not None:
        app.load(fn=show_save_local_button, outputs=btn_save_local)