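"""Gradio app to extract, filter, and analyze metadata of LeRobot datasets hosted on the Hugging Face Hub.

Likely dependencies, inferred from the imports below (not a pinned requirements list):
    pip install gradio pandas huggingface_hub spacy lerobot
The action-extraction step also needs the spaCy English model:
    python -m spacy download en_core_web_sm
"""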
import ast
import json
from collections import Counter

import gradio as gr
import pandas as pd
import spacy
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata

def analyze_dataset_metadata(repo_id: str):
    """Load a dataset's metadata from the Hub and flatten it into a dict, or return None on failure."""
    # Try the v2.0 revision first, then fall back to v2.1.
    try:
        metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.0")
    except Exception:
        try:
            metadata = LeRobotDatasetMetadata(repo_id=repo_id, revision="v2.1")
        except Exception as e:
            print(f"Error loading metadata for {repo_id}: {e}")
            return None

    # Keep only datasets in the v2.x format this app understands.
    version_str = str(metadata._version).strip()
    if version_str not in ("2.0", "2.1"):
        print(f"Skipping {repo_id}: unsupported version <{version_str}>")
        return None
        
    try:
        info = {
            "repo_id": repo_id,
            "username": repo_id.split('/')[0],
            "robot_type": metadata.robot_type,
            "total_episodes": metadata.total_episodes,
            "total_frames": metadata.total_frames,
            "fps": metadata.fps,
            "camera_keys": ','.join(metadata.camera_keys),  # Convert list to string
            "num_cameras": len(metadata.camera_keys),
            "video_keys": ','.join(metadata.video_keys),
            "has_video": len(metadata.video_keys) > 0,
            "total_tasks": metadata.total_tasks,
            "tasks": json.dumps(metadata.tasks),  # Convert dict to JSON string
            "is_sim": "sim_" in repo_id.lower(),
            "is_eval": "eval_" in repo_id.lower(),
            "features": ','.join(metadata.features.keys()),
            "chunks_size": metadata.chunks_size,
            "total_chunks": metadata.total_chunks,
            "version": metadata._version
        }
        return info
    except Exception as e:
        print(f"Error extracting metadata for {repo_id}: {str(e)}")
        return None
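
# Example usage (repo id chosen for illustration; any LeRobot-format dataset on the Hub works):
#   info = analyze_dataset_metadata("lerobot/pusht")
#   if info:
#       print(info["total_episodes"], "episodes at", info["fps"], "fps")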

def extract_metadata_fn(tags, progress=gr.Progress()):
    """List Hub datasets matching the given tags and extract their metadata into a DataFrame."""
    progress(0)
    api = HfApi()
    tags = [tag.strip() for tag in tags.split(",")] if tags else None
    datasets = api.list_datasets(tags=tags)
    repo_ids = [dataset.id for dataset in datasets]
    gr.Info(f"Found {len(repo_ids)} datasets with the provided tags. Extracting metadata...")
    dataset_infos = []
    # progress.tqdm drives the Gradio progress bar; iterating the list directly
    # (rather than an enumerate generator) lets it infer the total.
    for repo_id in progress.tqdm(repo_ids, desc="Extracting metadata"):
        info = analyze_dataset_metadata(repo_id)
        if info is not None:
            dataset_infos.append(info)

    # Convert to a DataFrame and save it to CSV so it can be reloaded later.
    df = pd.DataFrame(dataset_infos)
    csv_filename = "lerobot_datasets.csv"
    gr.Info(f"Dataset metadata extracted. Saving to {csv_filename}")
    df.to_csv(csv_filename, index=False)
    return df

def load_metadata_fn(file_explorer):
    """Load previously extracted metadata from a CSV file selected in the file explorer."""
    df = pd.read_csv(file_explorer)
    gr.Info(f"Metadata loaded from {file_explorer}.")
    return df

def filter_tasks(tasks_json):
    """Return True if at least one task description looks like a plausible natural-language label.

    Filters out descriptions that are too short, look like identifiers
    (underscore-separated names), or contain "test".
    """
    try:
        tasks = json.loads(tasks_json)
        valid_tasks = [task for task in tasks.values()
                       if task and isinstance(task, str) and len(task.strip()) > 10
                       and len(task.split("_")) < 3 and "test" not in task.lower()]
        return len(valid_tasks) > 0
    except (json.JSONDecodeError, TypeError, AttributeError):
        return False
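
# For example (illustrative strings, not from a real dataset):
#   filter_tasks('{"0": "pick up the red cube"}')  -> True
#   filter_tasks('{"0": "task_001"}')              -> False (too short, identifier-like)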

def filtering_metadata(
        df,
        num_episodes,
        num_frames,
        include_sim,
        robot_set,
        include_eval,
        filter_unlabeled_tasks
    ):
    """Filter the metadata DataFrame; return the count, the repo ids, and the filtered frame."""
    all_data_number = len(df)
    filtered_datasets = df[
        (df['total_episodes'] >= num_episodes) &
        (df['total_frames'] >= num_frames) &
        (df['has_video']) &
        (df['is_sim'] == include_sim) &
        (df['robot_type'].isin(robot_set)) &
        # The original `'test' not in df['repo_id']` tested the Series index,
        # not the strings; str.contains checks each repo id.
        (~df['repo_id'].str.contains('test'))
    ].copy()
    if not include_eval:
        filtered_datasets = filtered_datasets[~filtered_datasets['is_eval']]
    if filter_unlabeled_tasks:
        # Assigning on the copy made above avoids pandas' SettingWithCopyWarning.
        filtered_datasets['has_valid_tasks'] = filtered_datasets['tasks'].apply(filter_tasks)
        filtered_datasets = filtered_datasets[filtered_datasets['has_valid_tasks']]
    gr.Info(f"Filtering datasets from {all_data_number} to {len(filtered_datasets)}")
    return len(filtered_datasets), filtered_datasets["repo_id"].to_list(), filtered_datasets

class LeRobotAnalysisApp:
    def __init__(self, ui_obj):
        self.name = "LeRobot Analysis App"
        self.description = "Analyze LeRobot datasets"
        self.ui_obj = ui_obj
        
    # TODO
    def create_app(self):
        with self.ui_obj:
            gr.Markdown("Application to filter & analyze LeRobot datasets")
            filtered_data = gr.DataFrame(visible=False)
            with gr.Tabs():
                with gr.TabItem("1) Extract/Load Data"):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("# Extract metadata from HF API")
                            gr.Markdown("Choose a set of **tags** (separated by a coma) to select the datasets to extract **metadata** from.")
                            gr.Markdown("The final metadata will be saved to a **CSV file**.")
                            tags = gr.Textbox(label="Tags", value="LeRobot", 
                                            placeholder="Enter tags separated by comma", 
                                            info="Enter tags separated by comma",
                                            lines=3)
                            btn_extract = gr.Button("Extract Data")
                            gr.Markdown("# OR Load from CSV")
                            gr.Markdown("If you already downloaded the metadata in CSV, you can directly load it here.")
                            file_explorer = gr.FileExplorer(label="Load CSV file", file_count="single")
                            btn_load = gr.Button("Load CSV Data")
                        with gr.Column():
                            out_data = gr.DataFrame()
                        btn_extract.click(extract_metadata_fn, [tags], [out_data])
                        btn_load.click(load_metadata_fn, [file_explorer], [out_data])
                with gr.TabItem("2) Filter Data"):
                    @gr.render(inputs=[out_data])
                    def filter_data(out_data):
                        if out_data.empty:
                            gr.Markdown("# Filtering data")
                            gr.Markdown("No data to display : please extract or load metadata first")
                        else:
                            df = out_data
                            min_eps = int(df['total_episodes'].min())
                            min_frames = int(df['total_frames'].min())
                            robot_types = list(set(df['robot_type'].to_list()))
                            robot_types.sort()
                            with gr.Row():
                                with gr.Column():
                                    gr.Markdown("# Filtering data")
                                    gr.Markdown("Filter the extracted datasets to your needs")
                                    data = gr.DataFrame(label="Dataset Metadata", value=out_data)
                                    is_sim = gr.Checkbox(label="Include simulation datasets", value=False)
                                    eps = gr.Number(label="Min episodes", value=min_eps)
                                    frames = gr.Number(label="Min frames", value=min_frames)
                                    robot_type = gr.CheckboxGroup(label="Robot types", choices=robot_types)
                                    incl_eval = gr.Checkbox(label="Include evaluation datasets", value=False)
                                    filter_task = gr.Checkbox(label="Filter unlabeled tasks", value=True)
                                    btn_filter = gr.Button("Filter Data")
                                with gr.Column():
                                    out_num_d = gr.Number(label="Number of datasets", value=0)
                                    out_text = gr.Text(label="Dataset repo IDs", value="")
                            btn_filter.click(filtering_metadata, 
                                    inputs=[data, eps, frames, is_sim, robot_type, incl_eval, filter_task], 
                                    outputs=[out_num_d, out_text, filtered_data])
                with gr.TabItem("3) Analyze Data"):
                    @gr.render(inputs=[out_data, filtered_data])
                    def analyze_data(out_data, filtered_data):
                        if out_data.empty:
                            gr.Markdown("# Analyzing data")
                            gr.Markdown("No data to display : please extract or load metadata first")
                        else:
                            with gr.Row():
                                with gr.Column():
                                    if filtered_data.empty:
                                        gr.BarPlot(out_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
                                    else:
                                        actions_df = self.extract_actions_from_tasks(filtered_data['tasks'])
                                        gr.BarPlot(filtered_data, x="robot_type", y="total_episodes", title="Episodes per robot type")
                                        gr.BarPlot(actions_df, title="Count of each action",
                                                    x="actions",
                                                    y="count",
                                                    x_label="Actions",
                                                    y_label="Count of actions")
                                        
    def extract_actions_from_tasks(self, tasks):
        """Lemmatize the verbs in every task description and count their occurrences."""
        gr.Info("Extracting actions from tasks, it might take a while...")
        nlp = spacy.load("en_core_web_sm")
        actions = []
        for el in tasks:
            # Each cell holds the tasks dict serialized by json.dumps, e.g. '{"0": "pick up the cube"}'.
            dict_tasks = ast.literal_eval(el)
            for task in dict_tasks.values():
                doc = nlp(task)
                for token in doc:
                    if token.pos_ == "VERB":
                        actions.append(token.lemma_)
        # Count each verb lemma; Counter replaces the original quadratic
        # list.count() loop over the unique actions.
        counts = Counter(actions)
        return pd.DataFrame({"actions": list(counts.keys()), "count": list(counts.values())})
    
    def launch_ui(self):
        self.ui_obj.launch()

if __name__ == "__main__":
    app = gr.Blocks()
    ui = LeRobotAnalysisApp(app)
    ui.create_app()
    ui.launch_ui()
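
# To launch locally (assuming this file is saved as app.py):
#   python app.py
# Gradio serves the UI on http://localhost:7860 by default.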