# app.py — added by Beegbrain in commit 827715e ("Add app.py"), 10.7 kB.
import gradio as gr
import pandas as pd
import json
from huggingface_hub import HfApi
import pandas as pd
import json
import spacy
import ast
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
def analyze_dataset_metadata(repo_id: str):
    """Load the LeRobot metadata of ``repo_id`` and flatten it into one row.

    The metadata is requested at revision ``"v2.0"`` first and, if that
    fails, at ``"v2.1"``. Datasets whose codebase version is not 2.0 / 2.1
    are skipped.

    Args:
        repo_id: Hugging Face dataset id, ``"<user>/<name>"``.

    Returns:
        A dict of scalar / string fields suitable for one DataFrame row, or
        ``None`` if the metadata cannot be loaded, has an unsupported version,
        or a field cannot be extracted. Errors are printed rather than raised
        so that a batch scan over many repos keeps going.
    """
    metadata = None
    load_errors = []
    for revision in ("v2.0", "v2.1"):
        try:
            metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision=revision)
            break
        except Exception as e:  # best-effort: any failure means "try the next revision"
            load_errors.append(f"{revision}: {e}")
    if metadata is None:
        # Report every attempt, not just the last one, to ease debugging.
        print(f"Error loading metadata for {repo_id}: {'; '.join(load_errors)}")
        return None

    # Check version. NOTE(review): depending on the installed lerobot release,
    # `_version` may stringify as "2.1" or as "v2.1"; normalize so both pass.
    raw_version = str(metadata._version).strip()
    version_str = raw_version[1:] if raw_version.lower().startswith("v") else raw_version
    if version_str not in ("2.0", "2.1"):
        print(f"Skipping {repo_id}: version <{raw_version}>")
        return None

    try:
        info = {
            "repo_id": repo_id,
            "username": repo_id.split('/')[0],
            "robot_type": metadata.robot_type,
            "total_episodes": metadata.total_episodes,
            "total_frames": metadata.total_frames,
            "fps": metadata.fps,
            "camera_keys": ','.join(metadata.camera_keys),  # list -> CSV-friendly string
            "num_cameras": len(metadata.camera_keys),
            "video_keys": ','.join(metadata.video_keys),
            "has_video": len(metadata.video_keys) > 0,
            "total_tasks": metadata.total_tasks,
            "tasks": json.dumps(metadata.tasks),  # stored as JSON text; parsed back by filter_tasks
            "is_sim": "sim_" in repo_id.lower(),
            "is_eval": "eval_" in repo_id.lower(),
            "features": ','.join(metadata.features.keys()),
            "chunks_size": metadata.chunks_size,
            "total_chunks": metadata.total_chunks,
            "version": metadata._version,
        }
    except Exception as e:  # best-effort: one malformed repo must not abort the scan
        print(f"Error extracting metadata for {repo_id}: {str(e)}")
        return None
    return info
def extract_metadata_fn(tags, progress=gr.Progress()):
    """Scan the Hub for datasets carrying ``tags`` and collect their metadata.

    Args:
        tags: Comma-separated tag string from the UI (e.g. ``"LeRobot, so100"``).
            Whitespace around each tag and empty entries are ignored; if no tag
            remains, ``list_datasets`` is called without a tag filter.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A DataFrame with one row per dataset whose metadata could be read
        (see ``analyze_dataset_metadata``). The same table is also written to
        ``lerobot_datasets.csv`` in the current working directory.
    """
    progress(0)
    api = HfApi()
    # Strip whitespace so "LeRobot, so100" queries the tag "so100", not " so100".
    tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()] if tags else []
    datasets = api.list_datasets(tags=tag_list or None)
    repo_ids = [dataset.id for dataset in datasets]
    gr.Info(f"Found {len(repo_ids)} datasets with provided tags. Extracting metadata...")
    dataset_infos = []
    # progress.tqdm advances the bar on its own; passing the list (not an
    # enumerate object) lets it know the total.
    for repo_id in progress.tqdm(repo_ids):
        info = analyze_dataset_metadata(repo_id)
        if info is not None:
            dataset_infos.append(info)
    # Convert to DataFrame and save to CSV.
    df = pd.DataFrame(dataset_infos)
    csv_filename = "lerobot_datasets.csv"
    gr.Info(f"Dataset metadata extracted. Saving to {csv_filename}")
    df.to_csv(csv_filename, index=False)
    return df
def load_metadata_fn(file_explorer):
    """Load a previously saved metadata CSV selected in the file explorer.

    Args:
        file_explorer: Path of the selected file, as produced by the
            ``gr.FileExplorer(file_count="single")`` input, or ``None`` /
            empty when nothing is selected.

    Returns:
        The CSV contents as a DataFrame. If no file is selected, an empty
        DataFrame is returned so the other tabs show their "no data" message
        instead of the handler crashing in ``pd.read_csv(None)``.
    """
    if not file_explorer:
        gr.Warning("No file selected: please pick a CSV file first.")
        return pd.DataFrame()
    df = pd.read_csv(file_explorer)
    # Announce success only after the read has actually succeeded.
    gr.Info(f"Metadata loaded from {file_explorer}.")
    return df
def filter_tasks(tasks_json):
    """Filter out tasks that are too short and contain weird names.

    ``tasks_json`` is the JSON text stored in the ``tasks`` column: an object
    mapping task index to task description. A description is valid when it:

    * is a string longer than 10 characters once stripped,
    * contains at most one underscore (rejects identifier-like names such as
      ``"pick_up_cube"``), and
    * does not contain "test" (case-insensitive).

    Returns:
        ``True`` if at least one description is valid. ``False`` otherwise,
        including when the cell is missing (e.g. NaN read from a CSV), is not
        valid JSON, or does not decode to a JSON object.
    """
    try:
        tasks = json.loads(tasks_json)
        descriptions = tasks.values()
    except (TypeError, ValueError, AttributeError):
        # TypeError: non-string input such as NaN/None.
        # ValueError: malformed JSON (json.JSONDecodeError subclasses it).
        # AttributeError: valid JSON that is not an object (list, string, ...).
        return False
    return any(
        isinstance(task, str)
        and len(task.strip()) > 10
        and task.count("_") <= 1
        and "test" not in task.lower()
        for task in descriptions
    )
def filtering_metadata(
    df,
    num_episodes,
    num_frames,
    include_sim,
    robot_set,
    include_eval,
    filter_unlabeled_tasks
):
    """Select the datasets of ``df`` matching the "Filter Data" settings.

    A dataset is kept when it:

    * has at least ``num_episodes`` episodes and ``num_frames`` frames,
    * has video (``has_video`` is True),
    * has a ``robot_type`` contained in ``robot_set``,
    * has no "test" in its repo id (case-insensitive),
    * is not a simulation dataset, unless ``include_sim`` is set,
    * is not an evaluation dataset, unless ``include_eval`` is set,
    * when ``filter_unlabeled_tasks`` is set, has at least one valid task
      according to ``filter_tasks``.

    Returns:
        ``(count, repo_ids, filtered_df)``. When task filtering is on,
        ``filtered_df`` carries an extra boolean ``has_valid_tasks`` column.
    """
    all_data_number = len(df)
    mask = (
        (df['total_episodes'] >= num_episodes)
        & (df['total_frames'] >= num_frames)
        & df['has_video'].eq(True)
        & df['robot_type'].isin(robot_set)
        # `'test' not in df['repo_id']` checks the Series *index*, so it never
        # filtered anything; match against the values instead.
        & ~df['repo_id'].str.contains('test', case=False, regex=False, na=False)
    )
    # "Include ..." adds sim / eval datasets to the others; unchecked drops them.
    if not include_sim:
        mask &= df['is_sim'].eq(False)
    if not include_eval:
        mask &= df['is_eval'].eq(False)
    # Copy so the column added below is written to a new frame, not a view of `df`.
    filtered_datasets = df[mask].copy()
    if filter_unlabeled_tasks:
        # astype(bool): on an empty frame, .apply can return an object-dtype
        # Series, which pandas does not accept as a boolean mask.
        filtered_datasets['has_valid_tasks'] = filtered_datasets['tasks'].apply(filter_tasks).astype(bool)
        filtered_datasets = filtered_datasets[filtered_datasets['has_valid_tasks']]
    gr.Info(f"Filtering datasets from {all_data_number} to {len(filtered_datasets)}")
    return len(filtered_datasets), filtered_datasets["repo_id"].to_list(), filtered_datasets
class LeRobotAnalysisApp:
    """Gradio UI to extract, filter and analyze LeRobot dataset metadata.

    ``create_app`` builds three tabs:
    1) extract metadata from the Hub or load it from a CSV,
    2) filter the datasets,
    3) plot episodes per robot type and the verbs used in task descriptions.
    """

    def __init__(self, ui_obj):
        """Store ``ui_obj``, the ``gr.Blocks`` container that ``create_app`` fills."""
        self.name = "LeRobot Analysis App"
        self.description = "Analyze LeRobot datasets"
        self.ui_obj = ui_obj
        # spaCy pipeline, loaded lazily by extract_actions_from_tasks.
        self._nlp = None
        # TODO

    def create_app(self):
        """Build every widget and event handler inside ``self.ui_obj``."""
        with self.ui_obj:
            gr.Markdown("Application to filter & analyze LeRobot datasets")
            # Hidden holder for the "Filter Data" result; it feeds tab 3.
            filtered_data = gr.DataFrame(visible=False)
            with gr.Tabs():
                with gr.TabItem("1) Extract/Load Data"):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("# Extract metadata from HF API")
                            gr.Markdown("Choose a set of **tags** (separated by a coma) to select the datasets to extract **metadata** from.")
                            gr.Markdown("The final metadata will be saved to a **CSV file**.")
                            tags = gr.Textbox(label="Tags", value="LeRobot",
                                              placeholder="Enter tags separated by comma",
                                              info="Enter tags separated by comma",
                                              lines=3)
                            btn_extract = gr.Button("Extract Data")
                            gr.Markdown("# OR Load from CSV")
                            gr.Markdown("If you already downloaded the metadata in CSV, you can directly load it here.")
                            file_explorer = gr.FileExplorer(label="Load CSV file", file_count="single")
                            btn_load = gr.Button("Load CSV Data")
                        with gr.Column():
                            out_data = gr.DataFrame()
                    btn_extract.click(extract_metadata_fn, [tags], [out_data])
                    btn_load.click(load_metadata_fn, [file_explorer], [out_data])

                with gr.TabItem("2) Filter Data"):
                    # Re-rendered whenever the metadata table (out_data) changes.
                    @gr.render(inputs=[out_data])
                    def filter_data(metadata_df):
                        if metadata_df.empty:
                            gr.Markdown("# Filtering data")
                            gr.Markdown("No data to display : please extract or load metadata first")
                            return
                        min_eps = int(metadata_df['total_episodes'].min())
                        min_frames = int(metadata_df['total_frames'].min())
                        # dropna/astype(str): a missing robot_type (NaN) would make
                        # sorting a mix of str and float raise TypeError.
                        robot_types = sorted(metadata_df['robot_type'].dropna().astype(str).unique())
                        with gr.Row():
                            with gr.Column():
                                gr.Markdown("# Filtering data")
                                gr.Markdown("Filter the extracted datasets to your needs")
                                data = gr.DataFrame(label="Dataset Metadata", value=metadata_df)
                                is_sim = gr.Checkbox(label="Include simulation datasets", value=False)
                                eps = gr.Number(label="Min episodes ", value=min_eps)
                                frames = gr.Number(label="Min frames", value=min_frames)
                                robot_type = gr.CheckboxGroup(label="Robot types", choices=robot_types)
                                incl_eval = gr.Checkbox(label="Include evaluation datasets", value=False)
                                filter_task = gr.Checkbox(label="Filter unlabeled tasks", value=True)
                                btn_filter = gr.Button("Filter Data")
                            with gr.Column():
                                out_num_d = gr.Number(label="Number of datasets", value=0)
                                out_text = gr.Text(label="Dataset repo IDs", value="")
                        btn_filter.click(filtering_metadata,
                                         inputs=[data, eps, frames, is_sim, robot_type, incl_eval, filter_task],
                                         outputs=[out_num_d, out_text, filtered_data])

                with gr.TabItem("3) Analyze Data"):
                    # Re-rendered when either the full or the filtered table changes.
                    @gr.render(inputs=[out_data, filtered_data])
                    def analyze_data(metadata_df, filtered_df):
                        if metadata_df.empty:
                            gr.Markdown("# Analyzing data")
                            gr.Markdown("No data to display : please extract or load metadata first")
                            return
                        with gr.Row():
                            with gr.Column():
                                if filtered_df.empty:
                                    # Nothing filtered yet: plot the full table.
                                    gr.BarPlot(metadata_df, x="robot_type", y="total_episodes", title="Episodes per robot type")
                                else:
                                    actions_df = self.extract_actions_from_tasks(filtered_df['tasks'])
                                    gr.BarPlot(filtered_df, x="robot_type", y="total_episodes", title="Episodes per robot type")
                                    gr.BarPlot(actions_df, title="Counting of each actions",
                                               x="actions",
                                               y="count",
                                               x_label="Actions",
                                               y_label="Count of actions")

    @staticmethod
    def _parse_task_dict(raw):
        """Decode one ``tasks`` cell into a dict, or return ``None`` if it is not one."""
        if not isinstance(raw, str):  # e.g. NaN for a missing CSV cell
            return None
        try:
            parsed = json.loads(raw)  # cells are written with json.dumps
        except ValueError:
            # Fall back to Python-literal syntax (e.g. "{0: 'pick'}") for hand-made CSVs.
            try:
                parsed = ast.literal_eval(raw)
            except (ValueError, SyntaxError):
                return None
        return parsed if isinstance(parsed, dict) else None

    def extract_actions_from_tasks(self, tasks):
        """Count the verbs (as lemmas) used across task descriptions.

        Args:
            tasks: Iterable of ``tasks`` cells, each a JSON object mapping task
                index to description. Cells that cannot be decoded into a dict,
                and non-string descriptions, are skipped.

        Returns:
            DataFrame with columns ``actions`` (verb lemma) and ``count``
            (number of occurrences), in first-seen order.
        """
        from collections import Counter

        gr.Info("Extracting actions from tasks, it might take a while...")
        # Loading a spaCy model is slow; keep one per app instance.
        nlp = getattr(self, "_nlp", None)
        if nlp is None:
            nlp = self._nlp = spacy.load("en_core_web_sm")
        descriptions = []
        for cell in tasks:
            task_dict = self._parse_task_dict(cell)
            if task_dict is not None:
                descriptions.extend(text for text in task_dict.values() if isinstance(text, str))
        verb_counts = Counter(
            token.lemma_
            for doc in nlp.pipe(descriptions)
            for token in doc
            if token.pos_ == "VERB"
        )
        return pd.DataFrame({"actions": list(verb_counts), "count": list(verb_counts.values())})

    def launch_ui(self):
        """Start the Gradio server for the wrapped Blocks."""
        self.ui_obj.launch()
if __name__ == "__main__":
    # Wrap a fresh Blocks container, lay out the UI inside it, then serve it.
    analysis_ui = LeRobotAnalysisApp(gr.Blocks())
    analysis_ui.create_app()
    analysis_ui.launch_ui()