DatasetManager / app.py
Bradarr's picture
Update app.py
cc93bc7 verified
raw
history blame
5.72 kB
import io
import itertools
import json

import gradio as gr
import pandas as pd
from datasets import Audio, Dataset, Features, Value, load_dataset
from huggingface_hub import HfApi, create_repo
def _extract_audio_array(value):
    """Return the ``"array"`` entry of an audio dict, or ``None`` for any non-dict value."""
    # NOTE(review): assumes the source decodes audio columns to dicts carrying an
    # "array" key; any other decoded representation falls through to None.
    return value.get("array") if isinstance(value, dict) else None


def filter_dataset(dataset_name, split_name, keywords_text, max_matches=None):
    """Stream a dataset split and keep the rows whose ``prompt`` contains a keyword.

    Matching is a case-insensitive (``str.casefold``) substring search against
    the ``prompt`` field. Rows whose ``prompt`` is missing or not a string
    never match.

    Args:
        dataset_name: Dataset id passed to ``datasets.load_dataset``.
        split_name: Split to stream, e.g. ``"train"``.
        keywords_text: Comma-separated keywords; blank entries are ignored.
        max_matches: Optional cap on how many matching rows to collect.
            ``None`` (the default) scans the whole split.

    Returns:
        A ``(DataFrame, status_message)`` tuple. The DataFrame has the columns
        ``prompt``, ``chosen`` and ``rejected``. The two audio columns hold the
        ``"array"`` entry of the source dict, or ``None`` when the source value
        is not a dict. On any failure, the DataFrame is empty and the message
        describes the problem.
    """
    try:
        # Validate keywords before load_dataset so that a blank keyword box
        # fails fast instead of first opening the remote dataset.
        keywords = [
            keyword.strip().casefold()
            for keyword in (keywords_text or "").split(",")
            if keyword.strip()
        ]
        if not keywords:
            return pd.DataFrame(), "Error: No keywords provided."

        dataset = load_dataset(dataset_name, split=split_name, streaming=True)

        def matches(example):
            # A missing or non-string prompt (e.g. None) simply doesn't match,
            # rather than raising inside the streaming filter and aborting the scan.
            prompt = example.get("prompt")
            if not isinstance(prompt, str):
                return False
            prompt = prompt.casefold()
            return any(keyword in prompt for keyword in keywords)

        # islice(iterable, None) is unbounded, so max_matches=None scans everything.
        matched = itertools.islice(dataset.filter(matches), max_matches)
        rows = [
            {
                "prompt": example.get("prompt"),
                "chosen": _extract_audio_array(example.get("chosen")),
                "rejected": _extract_audio_array(example.get("rejected")),
            }
            for example in matched
        ]
    except Exception as e:
        # UI boundary: report any failure as a status message instead of raising.
        return pd.DataFrame(), f"An error occurred: {e}"

    if not rows:
        return pd.DataFrame(), "No matching examples found."
    return pd.DataFrame(rows), f"Found {len(rows)} matching examples."
def _frame_from_json(payload):
    """Rebuild the filtered DataFrame from ``DataFrame.to_json()`` output.

    ``payload`` may be the JSON text, the already-parsed ``dict``/``list``, or
    a DataFrame. A DataFrame is returned as a copy.
    """
    if isinstance(payload, pd.DataFrame):
        return payload.copy()
    if not isinstance(payload, str):
        payload = json.dumps(payload)
    # StringIO: passing literal JSON text to read_json is deprecated in pandas.
    # dtype=False: keep cell values as parsed, instead of letting pandas coerce
    # numeric-looking prompt strings into numbers.
    return pd.read_json(io.StringIO(payload), dtype=False)


def _to_audio_example(value, sampling_rate):
    """Convert one ``chosen``/``rejected`` cell into a value ``datasets.Audio`` can cast.

    Sample sequences become ``{"array": [...], "sampling_rate": sampling_rate}``.
    Dicts pass through unchanged, and missing values (None/NaN) become None.

    Raises:
        TypeError: for any other value type (e.g. a string), rather than
            silently dropping the data.
    """
    if hasattr(value, "tolist"):  # numpy arrays/scalars -> plain Python values
        value = value.tolist()
    if isinstance(value, (list, tuple)):
        return {"array": list(value), "sampling_rate": sampling_rate}
    if isinstance(value, dict):
        return value
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    raise TypeError(f"unsupported audio value of type {type(value).__name__}")


def push_to_hub(df_json, dataset_name, split_name, new_dataset_repo_id, hf_token,
                sampling_rate=16000):
    """Upload previously filtered rows to a new (or existing) Hub dataset repo.

    Args:
        df_json: The filtered rows as produced by ``DataFrame.to_json()``.
            Accepts the JSON text, the already-parsed ``dict``/``list``, or a
            DataFrame.
        dataset_name: Source dataset id. Not used; accepted so the UI can pass
            its inputs positionally.
        split_name: Source split name. Not used; accepted for the same reason.
        new_dataset_repo_id: Target dataset repo id.
        hf_token: Hugging Face token, used both to create the repo and to upload.
        sampling_rate: Sampling rate (Hz) attached to the ``chosen``/``rejected``
            arrays. The filtered rows do not carry the source rate, so this must
            match the source dataset. Defaults to 16 kHz.

    Returns:
        ``(status_message, dataset_url)``; ``dataset_url`` is ``None`` on failure.
    """
    if not hf_token:
        return "Error: Hugging Face Token is required.", None
    if df_json is None or (isinstance(df_json, str) and not df_json.strip()):
        return "Error: No filtered data to push. Run the filter first.", None
    try:
        df = _frame_from_json(df_json)
        if df.empty:
            return "Error: Cannot push an empty dataset", None

        # Fixed target schema, so every pushed subset has consistent features.
        features = Features(
            {
                "prompt": Value("string"),
                "chosen": Audio(sampling_rate=sampling_rate),
                "rejected": Audio(sampling_rate=sampling_rate),
            }
        )
        try:
            # Audio cannot cast bare sample lists; wrap them as array/rate dicts.
            for column in ("chosen", "rejected"):
                if column in df.columns:
                    df[column] = [_to_audio_example(v, sampling_rate) for v in df[column]]
            # preserve_index=False: otherwise a non-Range index (as read_json
            # produces) may be stored as an extra "__index_level_0__" column,
            # which cast() rejects because the column names no longer match.
            dataset = Dataset.from_pandas(df, preserve_index=False).cast(features)
        except Exception as e:
            return f"An error occurred during casting: {e}", None

        try:
            # exist_ok replaces matching on the "already exists" error text.
            create_repo(new_dataset_repo_id, token=hf_token, repo_type="dataset",
                        exist_ok=True)
        except Exception as e:
            return f"Error creating repository: {e}", None

        # Pass the user's token explicitly; without it the upload does not use
        # the token supplied in the UI.
        dataset.push_to_hub(new_dataset_repo_id, token=hf_token)
    except Exception as e:
        return f"An error occurred during push: {e}", None

    dataset_url = f"https://huggingface.co/datasets/{new_dataset_repo_id}"
    return "Subset dataset uploaded successfully!", dataset_url
# --- Gradio Interface ---
def _filter_and_serialize(dataset_name, split_name, keywords_text):
    """Run ``filter_dataset`` and also return its DataFrame serialized via ``to_json()``.

    The JSON is taken directly from the filter result rather than read back
    from the on-screen ``gr.Dataframe``. Whatever the table component does to
    cell values for display therefore cannot leak into what gets pushed.
    """
    df, status = filter_dataset(dataset_name, split_name, keywords_text)
    return df, status, df.to_json()


with gr.Blocks() as demo:
    gr.Markdown("# Dataset Filter and Push")
    with gr.Row():
        # NOTE(review): filtering reads only a "prompt" column and pushing expects
        # "chosen"/"rejected"; confirm this default dataset actually has them.
        dataset_name_input = gr.Textbox(label="Source Dataset Name", value="ashraq/esc50")
        split_name_input = gr.Textbox(label="Split Name", value="train")
        keywords_input = gr.Textbox(label="Keywords (comma-separated)", value="dog, cat")
    filter_button = gr.Button("Filter Dataset")

    # Display-only view of the filtered rows.
    filtered_data_output = gr.Dataframe(label="Filtered Data")
    filter_status_output = gr.Textbox(label="Filter Status")

    with gr.Row():
        new_dataset_repo_id_input = gr.Textbox(label="New Dataset Repo ID")
        hf_token_input = gr.Textbox(label="Hugging Face Token", type="password")
    push_button = gr.Button("Push to Hub")
    push_status_output = gr.Textbox(label="Push Status")
    dataset_url_output = gr.Textbox(label="Dataset URL")

    # Per-session store for the serialized filter result. gr.State hands its
    # value back to callbacks verbatim, so push_to_hub receives the to_json() text.
    filtered_data_json = gr.State(None)

    filter_button.click(
        _filter_and_serialize,
        inputs=[dataset_name_input, split_name_input, keywords_input],
        outputs=[filtered_data_output, filter_status_output, filtered_data_json],
    )

    push_button.click(
        push_to_hub,
        inputs=[
            filtered_data_json,
            dataset_name_input,
            split_name_input,
            new_dataset_repo_id_input,
            hf_token_input,
        ],
        outputs=[push_status_output, dataset_url_output],
    )

if __name__ == "__main__":
    demo.launch()