# Hugging Face Spaces page header (status: Sleeping)
import io
import itertools

import gradio as gr
import pandas as pd
from datasets import Audio, Dataset, Features, Value, load_dataset
from huggingface_hub import HfApi, create_repo
def _extract_audio_array(value):
    """Return ``value["array"]`` when *value* is a dict, otherwise ``None``."""
    if isinstance(value, dict):
        return value.get("array")
    return None


def filter_dataset(dataset_name, split_name, keywords_text, max_results=None):
    """Filter a dataset by keywords in its "prompt" column and return a DataFrame.

    The dataset is loaded in streaming mode. An example is kept when its
    "prompt" contains at least one keyword (case-insensitive substring match).

    Args:
        dataset_name: Dataset id passed to ``load_dataset``.
        split_name: Split to load (e.g. "train").
        keywords_text: Comma-separated keywords; blank entries are ignored.
        max_results: Optional cap on the number of matches collected. ``None``
            (the default) consumes the whole filtered stream.

    Returns:
        A ``(DataFrame, status_message)`` tuple. The DataFrame has columns
        "prompt", "chosen" and "rejected"; each audio column holds the value's
        ``"array"`` entry when the example value is a dict, else ``None``.
        Errors are reported as an empty DataFrame plus a message, never raised.
    """
    try:
        # Parse keywords first so an empty query fails fast, before any
        # network access to load the dataset.
        keywords = [
            keyword.strip().lower()
            for keyword in (keywords_text or "").split(",")
            if keyword.strip()
        ]
        if not keywords:
            return pd.DataFrame(), "Error: No keywords provided."

        # --- 1. Load the dataset in streaming mode ---
        dataset = load_dataset(dataset_name, split=split_name, streaming=True)

        # --- 2. Filter the dataset (streaming) ---
        def filter_func(example):
            # `or ""` also covers an explicit None prompt (which .get()'s
            # default would not); str() tolerates non-string prompts.
            prompt_value = str(example.get("prompt") or "").lower()
            return any(keyword in prompt_value for keyword in keywords)

        matches = dataset.filter(filter_func)
        if max_results is not None:
            # Without a cap the stream is consumed to the end of the split.
            matches = itertools.islice(matches, max_results)

        # --- 3. Collect the matching rows ---
        rows = [
            {
                "prompt": example.get("prompt"),
                "chosen": _extract_audio_array(example.get("chosen")),
                "rejected": _extract_audio_array(example.get("rejected")),
            }
            for example in matches
        ]
        if not rows:
            return pd.DataFrame(), "No matching examples found."

        # --- 4. Create Pandas DataFrame ---
        return pd.DataFrame(rows), f"Found {len(rows)} matching examples."
    except Exception as e:
        return pd.DataFrame(), f"An error occurred: {e}"
def _coerce_to_dataframe(data):
    """Build a DataFrame from JSON text, parsed JSON (dict/list), a DataFrame, or None."""
    if data is None:
        return pd.DataFrame()
    if isinstance(data, pd.DataFrame):
        # Copy so later column rewrites never mutate the caller's object.
        return data.copy()
    if isinstance(data, str):
        if not data.strip():
            return pd.DataFrame()
        # Passing literal JSON text to read_json is deprecated in pandas;
        # a file-like wrapper is the supported form.
        return pd.read_json(io.StringIO(data))
    # Components such as gr.JSON may deliver already-parsed JSON, which
    # read_json cannot accept; DataFrame() handles the to_json() default shape.
    return pd.DataFrame(data)


def _wrap_audio(value, sampling_rate):
    """Wrap a bare sample sequence into the {"array", "sampling_rate"} form Audio encodes."""
    if value is None or isinstance(value, dict):
        return value
    if isinstance(value, float):
        # A lone float cannot be audio samples; it is a missing (NaN) cell.
        return None
    return {"array": list(value), "sampling_rate": sampling_rate}


def push_to_hub(df_json, dataset_name, split_name, new_dataset_repo_id, hf_token, sampling_rate=16000):
    """Push the filtered subset to a dataset repo on the Hugging Face Hub.

    Args:
        df_json: The filtered rows as JSON text (from ``DataFrame.to_json``),
            already-parsed JSON (dict/list), a DataFrame, or ``None``.
        dataset_name: Source dataset name (unused; accepted for the UI wiring).
        split_name: Source split name (unused; accepted for the UI wiring).
        new_dataset_repo_id: Target dataset repo id.
        hf_token: Token used both to create the repo and to push the data.
        sampling_rate: Sampling rate declared for the "chosen"/"rejected"
            Audio columns (default 16000, the previously hard-coded value).

    Returns:
        ``(status_message, dataset_url)``; ``dataset_url`` is ``None`` on failure.
    """
    if not hf_token:
        return "Error: Hugging Face Token is required.", None
    repo_id = (new_dataset_repo_id or "").strip()
    if not repo_id:
        return "Error: New Dataset Repo ID is required.", None
    try:
        df = _coerce_to_dataframe(df_json)
        if df.empty:
            return "Error: Cannot push an empty dataset", None

        # --- 5. Define features (for consistent schema) ---
        target_features = {
            "prompt": Value("string"),
            "chosen": Audio(sampling_rate=sampling_rate),
            "rejected": Audio(sampling_rate=sampling_rate),
        }
        # The filter step keeps only bare sample arrays; the Audio feature can
        # only encode them once they carry a sampling rate.
        for column in ("chosen", "rejected"):
            if column in df.columns:
                df[column] = df[column].map(lambda value: _wrap_audio(value, sampling_rate))

        # preserve_index=False: a non-RangeIndex (e.g. "0", "1", ... labels
        # from JSON) would otherwise become an extra "__index_level_0__"
        # column and make the cast below fail on mismatched columns.
        dataset = Dataset.from_pandas(df, preserve_index=False)
        features = Features({
            name: target_features.get(name, feature)
            for name, feature in dataset.features.items()
        })
        try:
            dataset = dataset.cast(features)
        except Exception as e:
            return f"An error occurred during casting: {e}", None

        # --- 6. Upload to the Hugging Face Hub ---
        try:
            # exist_ok replaces matching on the exception text for an existing repo.
            create_repo(repo_id, token=hf_token, repo_type="dataset", exist_ok=True)
        except Exception as e:
            return f"Error creating repository: {e}", None
        print(f"Repository '{repo_id}' is ready.")
        # Pass the token explicitly so the push uses the user's credentials
        # rather than whatever is cached on the host.
        dataset.push_to_hub(repo_id, token=hf_token)
        dataset_url = f"https://huggingface.co/datasets/{repo_id}"
        return "Subset dataset uploaded successfully!", dataset_url
    except Exception as e:
        return f"An error occurred during push: {e}", None
# --- Gradio Interface ---
def _filter_and_store(dataset_name, split_name, keywords_text):
    """Run filter_dataset and also return its DataFrame serialized as JSON text.

    The JSON is produced straight from the DataFrame that filter_dataset
    returned, instead of being re-read from the displayed gr.Dataframe and
    passed through a JSON component, whose pre/post-processing does not hand
    the push handler the same JSON text back.
    """
    df, status = filter_dataset(dataset_name, split_name, keywords_text)
    return df, status, df.to_json()


with gr.Blocks() as demo:
    gr.Markdown("# Dataset Filter and Push")
    with gr.Row():
        # NOTE(review): confirm the default dataset actually has
        # prompt/chosen/rejected columns; otherwise nothing will match.
        dataset_name_input = gr.Textbox(label="Source Dataset Name", value="ashraq/esc50")
        split_name_input = gr.Textbox(label="Split Name", value="train")
        keywords_input = gr.Textbox(label="Keywords (comma-separated)", value="dog, cat")
    filter_button = gr.Button("Filter Dataset")

    # Display the filtered data. 'label' is important for presentation.
    filtered_data_output = gr.Dataframe(label="Filtered Data")
    filter_status_output = gr.Textbox(label="Filter Status")

    with gr.Row():
        new_dataset_repo_id_input = gr.Textbox(label="New Dataset Repo ID")
        hf_token_input = gr.Textbox(label="Hugging Face Token", type="password")
    push_button = gr.Button("Push to Hub")
    push_status_output = gr.Textbox(label="Push Status")
    dataset_url_output = gr.Textbox(label="Dataset URL")  # Display the dataset URL

    # Server-side state holding the filtered dataset as a JSON string; a State
    # passes the string to push_to_hub unchanged.
    filtered_data_json = gr.State()

    # Connect the filter button: show the result and store its JSON in one step.
    filter_button.click(
        _filter_and_store,
        inputs=[dataset_name_input, split_name_input, keywords_input],
        outputs=[filtered_data_output, filter_status_output, filtered_data_json],
    )

    # Connect the push button
    push_button.click(
        push_to_hub,
        inputs=[filtered_data_json, dataset_name_input, split_name_input, new_dataset_repo_id_input, hf_token_input],
        outputs=[push_status_output, dataset_url_output],
    )

if __name__ == "__main__":
    demo.launch()