import ast
import json
from collections import Counter  # used to count extracted action verbs

import gradio as gr
import pandas as pd
import spacy
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
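
# Loading tries the "v2.0" dataset revision first and falls back to "v2.1",
# since LeRobot datasets exist on the Hub under both codebase versions.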
def analyze_dataset_metadata(repo_id: str):
    try:
        metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.0")
    except Exception:
        try:
            metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.1")
        except Exception as e:
            print(f"Error loading metadata for {repo_id}: {str(e)}")
            return None

    # Check version (the version string may or may not carry a "v" prefix)
    version_str = str(metadata._version).strip().lstrip("v")
    if version_str not in ["2.0", "2.1"]:
        print(f"Skipping {repo_id}: version <{version_str}>")
        return None

    try:
        info = {
            "repo_id": repo_id,
            "username": repo_id.split('/')[0],
            "robot_type": metadata.robot_type,
            "total_episodes": metadata.total_episodes,
            "total_frames": metadata.total_frames,
            "fps": metadata.fps,
            "camera_keys": ','.join(metadata.camera_keys),  # Convert list to string
            "num_cameras": len(metadata.camera_keys),
            "video_keys": ','.join(metadata.video_keys),
            "has_video": len(metadata.video_keys) > 0,
            "total_tasks": metadata.total_tasks,
            "tasks": json.dumps(metadata.tasks),  # Convert dict to JSON string
            "is_sim": "sim_" in repo_id.lower(),
            "is_eval": "eval_" in repo_id.lower(),
            "features": ','.join(metadata.features.keys()),
            "chunks_size": metadata.chunks_size,
            "total_chunks": metadata.total_chunks,
            "version": metadata._version,
        }
        return info
    except Exception as e:
        print(f"Error extracting metadata for {repo_id}: {str(e)}")
        return None
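
# Query the Hub for datasets matching the given tags, extract per-repo
# metadata, and persist everything to a CSV so it can be reloaded later.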
def extract_metadata_fn(tags, progress=gr.Progress()):
    progress(0)
    api = HfApi()
    tags = [tag.strip() for tag in tags.split(",")] if tags else None
    datasets = api.list_datasets(tags=tags)
    repo_ids = [dataset.id for dataset in datasets]
    gr.Info(f"Found {len(repo_ids)} datasets with the provided tags. Extracting metadata...")

    dataset_infos = []
    for repo_id in progress.tqdm(repo_ids):
        info = analyze_dataset_metadata(repo_id)
        if info is not None:
            dataset_infos.append(info)

    # Convert to DataFrame and save to CSV
    df = pd.DataFrame(dataset_infos)
    csv_filename = "lerobot_datasets.csv"
    gr.Info(f"Dataset metadata extracted. Saving to {csv_filename}")
    df.to_csv(csv_filename, index=False)
    return df

def load_metadata_fn(file_explorer):
    df = pd.read_csv(file_explorer)
    gr.Info(f"Metadata loaded from {file_explorer}.")
    return df
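
# Heuristic filter for task labels: keep only datasets with at least one
# task string that looks like a real description (long enough, few
# underscores, not a test label).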
def filter_tasks(tasks_json):
    """Return True if at least one task label looks like a valid description."""
    try:
        tasks = json.loads(tasks_json)
        valid_tasks = [task for task in tasks.values()
                       if task and isinstance(task, str) and len(task.strip()) > 10
                       and len(task.split("_")) < 3 and "test" not in task.lower()]
        return len(valid_tasks) > 0
    except Exception:
        return False
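
# Apply the UI filters to the metadata table and return the number of
# matching datasets, their repo IDs, and the filtered DataFrame.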
def filtering_metadata(
    df,
    num_episodes,
    num_frames,
    include_sim,
    robot_set,
    include_eval,
    filter_unlabeled_tasks,
):
    all_data_number = len(df)
    # Note: `'test' not in df['repo_id']` would test the index, not the
    # values, so the substring check must go through .str.contains().
    filtered_datasets = df[
        (df['total_episodes'] >= num_episodes) &
        (df['total_frames'] >= num_frames) &
        (df['has_video'] == True) &
        (df['robot_type'].isin(robot_set)) &
        (~df['repo_id'].str.contains('test'))
    ].copy()
    if not include_sim:
        filtered_datasets = filtered_datasets[filtered_datasets['is_sim'] == False]
    if not include_eval:
        filtered_datasets = filtered_datasets[filtered_datasets['is_eval'] == False]
    if filter_unlabeled_tasks:
        filtered_datasets['has_valid_tasks'] = filtered_datasets['tasks'].apply(filter_tasks)
        filtered_datasets = filtered_datasets[filtered_datasets['has_valid_tasks']]
    gr.Info(f"Filtering datasets from {all_data_number} to {len(filtered_datasets)}")
    return len(filtered_datasets), filtered_datasets["repo_id"].to_list(), filtered_datasets
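
# Gradio application with three tabs: extract/load the metadata,
# filter it, and visualize the filtered datasets.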
class LeRobotAnalysisApp:
    def __init__(self, ui_obj):
        self.name = "LeRobot Analysis App"
        self.description = "Analyze LeRobot datasets"
        self.ui_obj = ui_obj
        # TODO

    def create_app(self):
        with self.ui_obj:
            gr.Markdown("Application to filter & analyze LeRobot datasets")
            filtered_data = gr.DataFrame(visible=False)
            with gr.Tabs():
                with gr.TabItem("1) Extract/Load Data"):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("# Extract metadata from HF API")
                            gr.Markdown("Choose a set of **tags** (separated by a comma) to select the datasets to extract **metadata** from.")
                            gr.Markdown("The final metadata will be saved to a **CSV file**.")
                            tags = gr.Textbox(label="Tags", value="LeRobot",
                                              placeholder="Enter tags separated by a comma",
                                              info="Enter tags separated by a comma",
                                              lines=3)
                            btn_extract = gr.Button("Extract Data")
                            gr.Markdown("# OR Load from CSV")
                            gr.Markdown("If you already downloaded the metadata as a CSV file, you can load it directly here.")
                            file_explorer = gr.FileExplorer(label="Load CSV file", file_count="single")
                            btn_load = gr.Button("Load CSV Data")
                        with gr.Column():
                            out_data = gr.DataFrame()
                    btn_extract.click(extract_metadata_fn, [tags], [out_data])
                    btn_load.click(load_metadata_fn, [file_explorer], [out_data])
with gr.TabItem("2) Filter Data"):
@gr.render(inputs=[out_data])
def filter_data(out_data):
if out_data.empty:
gr.Markdown("# Filtering data")
gr.Markdown("No data to display : please extract or load metadata first")
else:
df = out_data
min_eps = int(df['total_episodes'].min())
min_frames = int(df['total_frames'].min())
robot_types = list(set(df['robot_type'].to_list()))
robot_types.sort()
with gr.Row():
with gr.Column():
gr.Markdown("# Filtering data")
gr.Markdown("Filter the extracted datasets to your needs")
data = gr.DataFrame(label="Dataset Metadata", value=out_data)
is_sim = gr.Checkbox(label="Include simulation datasets", value=False)
eps = gr.Number(label="Min episodes ", value=min_eps)
frames = gr.Number(label="Min frames", value=min_frames)
robot_type = gr.CheckboxGroup(label="Robot types", choices=robot_types)
incl_eval = gr.Checkbox(label="Include evaluation datasets", value=False)
filter_task = gr.Checkbox(label="Filter unlabeled tasks", value=True)
btn_filter = gr.Button("Filter Data")
with gr.Column():
out_num_d = gr.Number(label="Number of datasets", value=0)
out_text = gr.Text(label="Dataset repo IDs", value="")
btn_filter.click(filtering_metadata,
inputs=[data, eps, frames, is_sim, robot_type, incl_eval, filter_task],
outputs=[out_num_d, out_text, filtered_data])
with gr.TabItem("3) Analyze Data"):
@gr.render(inputs=[out_data, filtered_data])
def analyze_data(out_data, filtered_data):
if out_data.empty:
gr.Markdown("# Analyzing data")
gr.Markdown("No data to display : please extract or load metadata first")
else:
with gr.Row():
with gr.Column():
if filtered_data.empty:
gr.BarPlot(out_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
else:
actions_df = self.extract_actions_from_tasks(filtered_data['tasks'])
gr.BarPlot(filtered_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
gr.BarPlot(actions_df, title="Counting of each actions",
x="actions",
y="count",
x_label="Actions",
y_label="Count of actions")
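
    # Run spaCy over every task description and collect the lemmas of all
    # verbs. Requires the small English model to be installed, e.g.:
    #   python -m spacy download en_core_web_sm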
    def extract_actions_from_tasks(self, tasks):
        gr.Info("Extracting actions from tasks, it might take a while...")
        nlp = spacy.load("en_core_web_sm")
        actions = []
        for el in tasks:
            dict_tasks = ast.literal_eval(el)
            for _task_id, task in dict_tasks.items():
                doc = nlp(task)
                for token in doc:
                    if token.pos_ == "VERB":
                        actions.append(token.lemma_)
        # Count occurrences of each unique action verb
        count_actions = Counter(actions)
        return pd.DataFrame({"actions": list(count_actions.keys()), "count": list(count_actions.values())})

    def launch_ui(self):
        self.ui_obj.launch()
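
# Entry point: build the Blocks UI and launch the Gradio server.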
if __name__ == "__main__":
    app = gr.Blocks()
    ui = LeRobotAnalysisApp(app)
    ui.create_app()
    ui.launch_ui()