import ast
import json
from collections import Counter

import gradio as gr
import pandas as pd
import spacy
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
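
# The verb extraction in tab 3 relies on spaCy's small English model; install it
# once with: python -m spacy download en_core_web_sm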

def analyze_dataset_metadata(repo_id: str):
    # Try the v2.0 revision first, then fall back to v2.1.
    try:
        metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.0")
    except Exception:
        try:
            metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.1")
        except Exception as e:
            print(f"Error loading metadata for {repo_id}: {e}")
            return None

    # Keep only datasets in the v2.x format.
    version_str = str(metadata._version).strip()
    if version_str not in ["2.0", "2.1"]:
        print(f"Skipping {repo_id}: version <{version_str}>")
        return None

    try:
        info = {
            "repo_id": repo_id,
            "username": repo_id.split("/")[0],
            "robot_type": metadata.robot_type,
            "total_episodes": metadata.total_episodes,
            "total_frames": metadata.total_frames,
            "fps": metadata.fps,
            "camera_keys": ",".join(metadata.camera_keys),  # list -> comma-separated string
            "num_cameras": len(metadata.camera_keys),
            "video_keys": ",".join(metadata.video_keys),
            "has_video": len(metadata.video_keys) > 0,
            "total_tasks": metadata.total_tasks,
            "tasks": json.dumps(metadata.tasks),  # dict -> JSON string
            "is_sim": "sim_" in repo_id.lower(),
            "is_eval": "eval_" in repo_id.lower(),
            "features": ",".join(metadata.features.keys()),
            "chunks_size": metadata.chunks_size,
            "total_chunks": metadata.total_chunks,
            "version": version_str,
        }
        return info
    except Exception as e:
        print(f"Error extracting metadata for {repo_id}: {e}")
        return None
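
# Quick sanity check (a sketch; "lerobot/pusht" is a public LeRobot dataset,
# but any v2.x repo id works):
#   info = analyze_dataset_metadata("lerobot/pusht")
#   if info:
#       print(info["total_episodes"], info["fps"])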

def extract_metadata_fn(tags, progress=gr.Progress()):
    api = HfApi()
    tags = [tag.strip() for tag in tags.split(",")] if tags else None
    datasets = api.list_datasets(tags=tags)
    repo_ids = [dataset.id for dataset in datasets]
    gr.Info(f"Found {len(repo_ids)} datasets with the provided tags. Extracting metadata...")
    dataset_infos = []
    # progress.tqdm needs a sized iterable to render the bar, so wrap repo_ids
    # directly rather than an enumerate object.
    for repo_id in progress.tqdm(repo_ids):
        info = analyze_dataset_metadata(repo_id)
        if info is not None:
            dataset_infos.append(info)

    # Convert to a DataFrame and save it as CSV.
    df = pd.DataFrame(dataset_infos)
    csv_filename = "lerobot_datasets.csv"
    gr.Info(f"Dataset metadata extracted. Saving to {csv_filename}")
    df.to_csv(csv_filename, index=False)
    return df

def load_metadata_fn(file_explorer):
    df = pd.read_csv(file_explorer)
    gr.Info(f"Metadata loaded from {file_explorer}.")
    return df

def filter_tasks(tasks_json):
    """Keep only tasks that are long enough and do not look like placeholder names."""
    try:
        tasks = json.loads(tasks_json)
        valid_tasks = [
            task
            for task in tasks.values()
            if task
            and isinstance(task, str)
            and len(task.strip()) > 10
            and len(task.split("_")) < 3  # drop snake_case placeholders
            and "test" not in task.lower()
        ]
        return len(valid_tasks) > 0
    except (json.JSONDecodeError, AttributeError, TypeError):
        return False
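
# For instance, filter_tasks('{"0": "Pick up the cube and place it in the bin"}')
# returns True, while filter_tasks('{"0": "test_task"}') returns False: the task
# is 10 characters or fewer and contains "test".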

def filtering_metadata(
    df,
    num_episodes,
    num_frames,
    include_sim,
    robot_set,
    include_eval,
    filter_unlabeled_tasks,
):
    all_data_number = len(df)
    filtered_datasets = df[
        (df["total_episodes"] >= num_episodes)
        & (df["total_frames"] >= num_frames)
        & (df["has_video"])
        & (df["robot_type"].isin(robot_set))
        # Exclude repos whose id contains "test"; note that a bare
        # `'test' not in df['repo_id']` would check the index, not the values.
        & (~df["repo_id"].str.contains("test"))
    ].copy()
    if not include_sim:
        filtered_datasets = filtered_datasets[~filtered_datasets["is_sim"]]
    if not include_eval:
        filtered_datasets = filtered_datasets[~filtered_datasets["is_eval"]]
    if filter_unlabeled_tasks:
        filtered_datasets = filtered_datasets[filtered_datasets["tasks"].apply(filter_tasks)]
    gr.Info(f"Filtering datasets from {all_data_number} to {len(filtered_datasets)}")
    return len(filtered_datasets), filtered_datasets["repo_id"].to_list(), filtered_datasets
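
# Example with hypothetical thresholds: keep real-robot, non-eval "so100"
# datasets with at least 50 episodes and labeled tasks:
#   n, repo_ids, subset = filtering_metadata(df, 50, 0, False, ["so100"], False, True)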

class LeRobotAnalysisApp:
    def __init__(self, ui_obj):
        self.name = "LeRobot Analysis App"
        self.description = "Analyze LeRobot datasets"
        self.ui_obj = ui_obj

    def create_app(self):
        with self.ui_obj:
            gr.Markdown("Application to filter & analyze LeRobot datasets")
            filtered_data = gr.DataFrame(visible=False)
            with gr.Tabs():
                with gr.TabItem("1) Extract/Load Data"):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("# Extract metadata from the HF API")
                            gr.Markdown("Choose a set of **tags** (separated by a comma) to select the datasets to extract **metadata** from.")
                            gr.Markdown("The metadata will be saved to a **CSV file**.")
                            tags = gr.Textbox(
                                label="Tags",
                                value="LeRobot",
                                placeholder="Enter tags separated by commas",
                                info="Enter tags separated by commas",
                                lines=3,
                            )
                            btn_extract = gr.Button("Extract Data")
                            gr.Markdown("# OR load from CSV")
                            gr.Markdown("If you already downloaded the metadata as a CSV, you can load it directly here.")
                            file_explorer = gr.FileExplorer(label="Load CSV file", file_count="single")
                            btn_load = gr.Button("Load CSV Data")
                        with gr.Column():
                            out_data = gr.DataFrame()
                    btn_extract.click(extract_metadata_fn, [tags], [out_data])
                    btn_load.click(load_metadata_fn, [file_explorer], [out_data])
                with gr.TabItem("2) Filter Data"):
                    # gr.render (Gradio 4.x) rebuilds this tab whenever the
                    # extracted metadata changes.
                    @gr.render(inputs=[out_data])
                    def filter_data(out_data):
                        if out_data.empty:
                            gr.Markdown("# Filtering data")
                            gr.Markdown("No data to display: please extract or load metadata first")
                        else:
                            df = out_data
                            min_eps = int(df["total_episodes"].min())
                            min_frames = int(df["total_frames"].min())
                            robot_types = sorted(set(df["robot_type"].to_list()))
                            with gr.Row():
                                with gr.Column():
                                    gr.Markdown("# Filtering data")
                                    gr.Markdown("Filter the extracted datasets to your needs")
                                    data = gr.DataFrame(label="Dataset Metadata", value=out_data)
                                    is_sim = gr.Checkbox(label="Include simulation datasets", value=False)
                                    eps = gr.Number(label="Min episodes", value=min_eps)
                                    frames = gr.Number(label="Min frames", value=min_frames)
                                    robot_type = gr.CheckboxGroup(label="Robot types", choices=robot_types)
                                    incl_eval = gr.Checkbox(label="Include evaluation datasets", value=False)
                                    filter_task = gr.Checkbox(label="Filter unlabeled tasks", value=True)
                                    btn_filter = gr.Button("Filter Data")
                                with gr.Column():
                                    out_num_d = gr.Number(label="Number of datasets", value=0)
                                    out_text = gr.Text(label="Dataset repo IDs", value="")
                            btn_filter.click(
                                filtering_metadata,
                                inputs=[data, eps, frames, is_sim, robot_type, incl_eval, filter_task],
                                outputs=[out_num_d, out_text, filtered_data],
                            )
                with gr.TabItem("3) Analyze Data"):
                    # Same pattern: re-render the plots whenever the data changes.
                    @gr.render(inputs=[out_data, filtered_data])
                    def analyze_data(out_data, filtered_data):
                        if out_data.empty:
                            gr.Markdown("# Analyzing data")
                            gr.Markdown("No data to display: please extract or load metadata first")
                        else:
                            with gr.Row():
                                with gr.Column():
                                    if filtered_data.empty:
                                        gr.BarPlot(out_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
                                    else:
                                        actions_df = self.extract_actions_from_tasks(filtered_data["tasks"])
                                        gr.BarPlot(filtered_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
                                        gr.BarPlot(
                                            actions_df,
                                            title="Count of each action",
                                            x="actions",
                                            y="count",
                                            x_label="Actions",
                                            y_label="Count of actions",
                                        )

    def extract_actions_from_tasks(self, tasks):
        gr.Info("Extracting actions from tasks, it might take a while...")
        nlp = spacy.load("en_core_web_sm")
        actions = []
        for el in tasks:
            # Each cell holds a task dict serialized as a string; parse it back.
            dict_tasks = ast.literal_eval(el)
            for task in dict_tasks.values():
                doc = nlp(task)
                # Treat the lemma of every verb as a candidate "action".
                actions.extend(token.lemma_ for token in doc if token.pos_ == "VERB")
        # Count the occurrences of each unique action.
        counts = Counter(actions)
        return pd.DataFrame({"actions": list(counts.keys()), "count": list(counts.values())})
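
    # Example (hypothetical input): a column of serialized task dicts such as
    #   pd.Series(['{"0": "Pick the cube", "1": "Push the block to the left"}'])
    # yields one row per verb lemma, e.g. ("pick", 1) and ("push", 1).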

    def launch_ui(self):
        self.ui_obj.launch()

if __name__ == "__main__":
    app = gr.Blocks()
    ui = LeRobotAnalysisApp(app)
    ui.create_app()
    ui.launch_ui()
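
# Run the script directly (e.g. `python app.py`, assuming that filename); Gradio
# serves the UI on http://127.0.0.1:7860 by default.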