arpitg1304 commited on
Commit
59aa4e5
·
verified ·
1 Parent(s): 4152dfd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +363 -0
app.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi
3
+
4
+ def build_teleop_command(
5
+ robot_type,
6
+ robot_port,
7
+ robot_id,
8
+ cam_index,
9
+ cam_width,
10
+ cam_height,
11
+ cam_fps,
12
+ teleop_type,
13
+ teleop_port,
14
+ teleop_id,
15
+ fps,
16
+ teleop_duration,
17
+ display_data,
18
+ ):
19
+ cam_cfg = (
20
+ "{ front: {type: opencv, index_or_path: %d, width: %d, height: %d, fps: %d}}"
21
+ % (cam_index, cam_width, cam_height, cam_fps)
22
+ )
23
+
24
+ cmd = [
25
+ "python -m lerobot.teleoperate",
26
+ f"--robot.type={robot_type}",
27
+ f"--robot.port={robot_port}",
28
+ f"--robot.id={robot_id}",
29
+ f"--robot.cameras=\"{cam_cfg}\"",
30
+ f"--teleop.type={teleop_type}",
31
+ f"--teleop.port={teleop_port}",
32
+ f"--teleop.id={teleop_id}",
33
+ f"--fps={fps}",
34
+ ]
35
+ if teleop_duration:
36
+ cmd.append(f"--teleop_time_s={teleop_duration}")
37
+ cmd.append(f"--display_data={'true' if display_data else 'false'}")
38
+ return " \\\n ".join(cmd)
39
+
40
+
41
+ def build_record_command(
42
+ robot_type,
43
+ robot_port,
44
+ robot_id,
45
+ cam_index,
46
+ cam_width,
47
+ cam_height,
48
+ cam_fps,
49
+ teleop_type,
50
+ teleop_port,
51
+ teleop_id,
52
+ display_data,
53
+ dataset_repo,
54
+ num_episodes,
55
+ single_task,
56
+ resume,
57
+ push_to_hub,
58
+ use_existing,
59
+ existing_ds,
60
+ ):
61
+ # if using existing dataset, override dataset_repo
62
+ if use_existing and existing_ds:
63
+ dataset_repo = existing_ds
64
+
65
+ camera_cfg = (
66
+ "{ front: {type: opencv, index_or_path: %d, width: %d, height: %d, fps: %d}}"
67
+ % (cam_index, cam_width, cam_height, cam_fps)
68
+ )
69
+ cmd = [
70
+ "python -m lerobot.record",
71
+ f"--robot.type={robot_type}",
72
+ f"--robot.port={robot_port}",
73
+ f"--robot.id={robot_id}",
74
+ f"--robot.cameras=\"{camera_cfg}\"",
75
+ f"--teleop.type={teleop_type}",
76
+ f"--teleop.port={teleop_port}",
77
+ f"--teleop.id={teleop_id}",
78
+ f"--display_data={'true' if display_data else 'false'}",
79
+ f"--dataset.repo_id={dataset_repo}",
80
+ f"--dataset.num_episodes={num_episodes}",
81
+ f"--dataset.single_task=\"{single_task}\"",
82
+ ]
83
+ cmd.append(f"--dataset.push_to_hub={'true' if push_to_hub else 'false'}")
84
+ if resume:
85
+ cmd.append("--resume=True")
86
+ return " \\\n ".join(cmd)
87
+
88
+
89
+ def build_train_command(
90
+ policy_path,
91
+ dataset_repo,
92
+ batch_size,
93
+ steps,
94
+ output_dir,
95
+ job_name,
96
+ device,
97
+ wandb_enable,
98
+ policy_repo_id,
99
+ ):
100
+ cmd = [
101
+ "python -m lerobot.scripts.train",
102
+ f"--policy.path={policy_path}",
103
+ f"--dataset.repo_id={dataset_repo}",
104
+ f"--batch_size={batch_size}",
105
+ f"--steps={steps}",
106
+ f"--output_dir={output_dir}",
107
+ f"--job_name={job_name}",
108
+ f"--policy.device={device}",
109
+ f"--wandb.enable={'true' if wandb_enable else 'false'}",
110
+ f"--policy.repo_id={policy_repo_id}" if policy_repo_id else "",
111
+ ]
112
+ # filter empty strings
113
+ cmd = [c for c in cmd if c]
114
+ return " \\\n ".join(cmd)
115
+
116
+
117
+ def build_eval_command(
118
+ robot_type,
119
+ robot_port,
120
+ robot_id,
121
+ cam_index,
122
+ cam_width,
123
+ cam_height,
124
+ cam_fps,
125
+ display_data,
126
+ dataset_repo,
127
+ num_episodes,
128
+ single_task,
129
+ policy_path,
130
+ resume,
131
+ ):
132
+ camera_cfg = (
133
+ "{ front: {type: opencv, index_or_path: %d, width: %d, height: %d, fps: %d}}"
134
+ % (cam_index, cam_width, cam_height, cam_fps)
135
+ )
136
+
137
+ cmd = [
138
+ "python -m lerobot.record",
139
+ f"--robot.type={robot_type}",
140
+ f"--robot.port={robot_port}",
141
+ f"--robot.id={robot_id}",
142
+ f"--robot.cameras=\"{camera_cfg}\"",
143
+ f"--display_data={'true' if display_data else 'false'}",
144
+ f"--dataset.repo_id={dataset_repo}",
145
+ f"--dataset.num_episodes={num_episodes}",
146
+ f"--dataset.single_task=\"{single_task}\"",
147
+ f"--policy.path={policy_path}",
148
+ ]
149
+ if resume:
150
+ cmd.append("--resume=True")
151
+ return " \\\n ".join(cmd)
152
+
153
+
154
+ # Helper to list datasets on Hugging Face for given username
155
+ def _list_remote_datasets(username: str):
156
+ try:
157
+ api = HfApi()
158
+ datasets = api.list_datasets(author=username)
159
+ return sorted([d.id for d in datasets])
160
+ except Exception:
161
+ return []
162
+
163
+
164
+ def build_ui():
165
+ with gr.Blocks(title="Lerobot Scripts Controller (Generate Only)") as demo:
166
+ hf_username_tb = gr.Textbox(label="HF Username", value="arpitg1304")
167
+ with gr.Tabs():
168
+ # Teleoperate Tab
169
+ with gr.TabItem("Teleoperate Robot"):
170
+ gr.Markdown("### Teleoperate robot with camera")
171
+ with gr.Row():
172
+ robot_type = gr.Textbox(label="Robot Type", value="so101_follower")
173
+ robot_port = gr.Textbox(label="Robot Port", value="/dev/ttyACM0")
174
+ robot_id = gr.Textbox(label="Robot ID", value="follower")
175
+ with gr.Row():
176
+ cam_index = gr.Number(label="Cam Index", value=0, precision=0)
177
+ cam_width = gr.Number(label="Width", value=640, precision=0)
178
+ cam_height = gr.Number(label="Height", value=480, precision=0)
179
+ cam_fps = gr.Number(label="FPS", value=30, precision=0)
180
+ with gr.Row():
181
+ teleop_type = gr.Textbox(label="Teleop Type", value="so101_leader")
182
+ teleop_port = gr.Textbox(label="Teleop Port", value="/dev/ttyACM1")
183
+ teleop_id = gr.Textbox(label="Teleop ID", value="leader")
184
+ with gr.Row():
185
+ fps = gr.Number(label="Loop FPS", value=60, precision=0)
186
+ teleop_duration = gr.Number(label="Duration (s)", value=60, precision=0)
187
+ display_data = gr.Checkbox(label="Display Data", value=True)
188
+ teleop_cmd = gr.Textbox(label="Generated Command", interactive=False, lines=16)
189
+
190
+ inputs_teleop = [
191
+ robot_type,
192
+ robot_port,
193
+ robot_id,
194
+ cam_index,
195
+ cam_width,
196
+ cam_height,
197
+ cam_fps,
198
+ teleop_type,
199
+ teleop_port,
200
+ teleop_id,
201
+ fps,
202
+ teleop_duration,
203
+ display_data,
204
+ ]
205
+ gr.Button("Generate Command").click(build_teleop_command, inputs_teleop, outputs=teleop_cmd)
206
+
207
+ # Record Data Tab
208
+ with gr.TabItem("Record Data"):
209
+ gr.Markdown("### Record episodes with policy")
210
+ with gr.Row():
211
+ robot_type2 = gr.Textbox(label="Robot Type", value="so101_follower")
212
+ robot_port2 = gr.Textbox(label="Robot Port", value="/dev/ttyACM0")
213
+ robot_id2 = gr.Textbox(label="Robot ID", value="follower")
214
+ with gr.Row():
215
+ cam_index2 = gr.Number(label="Cam Index", value=0, precision=0)
216
+ cam_width2 = gr.Number(label="Width", value=640, precision=0)
217
+ cam_height2 = gr.Number(label="Height", value=480, precision=0)
218
+ cam_fps2 = gr.Number(label="FPS", value=30, precision=0)
219
+ with gr.Row():
220
+ teleop_type_r = gr.Textbox(label="Teleop Type", value="so101_leader")
221
+ teleop_port_r = gr.Textbox(label="Teleop Port", value="/dev/ttyACM1")
222
+ teleop_id_r = gr.Textbox(label="Teleop ID", value="leader")
223
+ with gr.Row():
224
+ display_data2 = gr.Checkbox(label="Display Data", value=True)
225
+ dataset_repo = gr.Textbox(label="Dataset Repo", value="")
226
+ num_episodes = gr.Number(label="Num Episodes", value=2, precision=0)
227
+ single_task = gr.Textbox(label="Single Task", value="Grab the cylinder")
228
+ resume_chk = gr.Checkbox(label="Resume", value=False)
229
+ push_hub_chk = gr.Checkbox(label="Push to Hub", value=False)
230
+ with gr.Row():
231
+ use_existing = gr.Checkbox(label="Use Existing Dataset", value=False)
232
+ existing_dd = gr.Dropdown(label="User Datasets", choices=_list_remote_datasets(hf_username_tb.value), visible=False)
233
+ # Toggle dropdown visibility
234
+ use_existing.change(lambda f: gr.update(visible=f), inputs=use_existing, outputs=existing_dd)
235
+
236
+ # Update dataset choices when username changes
237
+ def _update_ds_choices(username):
238
+ return gr.update(choices=_list_remote_datasets(username))
239
+
240
+ hf_username_tb.change(_update_ds_choices, inputs=hf_username_tb, outputs=existing_dd)
241
+
242
+ record_cmd = gr.Textbox(label="Generated Command", interactive=False, lines=16)
243
+ inputs_rec = [
244
+ robot_type2,
245
+ robot_port2,
246
+ robot_id2,
247
+ cam_index2,
248
+ cam_width2,
249
+ cam_height2,
250
+ cam_fps2,
251
+ teleop_type_r,
252
+ teleop_port_r,
253
+ teleop_id_r,
254
+ display_data2,
255
+ dataset_repo,
256
+ num_episodes,
257
+ single_task,
258
+ resume_chk,
259
+ push_hub_chk,
260
+ use_existing,
261
+ existing_dd,
262
+ ]
263
+ gr.Button("Generate Command").click(build_record_command, inputs_rec, record_cmd)
264
+
265
+ # Train Policy Tab
266
+ with gr.TabItem("Train Policy"):
267
+ gr.Markdown("### Train Policy")
268
+ # Row 1: Policy path (full width)
269
+ policy_path_t = gr.Textbox(label="Base Policy Path", value="lerobot/smolvla_base")
270
+
271
+ # Row 2: Dataset + Device + WandB enable
272
+ with gr.Row():
273
+ dataset_repo_t = gr.Dropdown(
274
+ label="User Dataset",
275
+ choices=_list_remote_datasets(hf_username_tb.value),
276
+ scale=4,
277
+ )
278
+ device_t = gr.Dropdown(label="Device", choices=["cpu", "cuda"], value="cuda", scale=1)
279
+ wandb_chk = gr.Checkbox(label="W&B", value=True, scale=1)
280
+
281
+ # Row 3: Batch size & Steps
282
+ with gr.Row():
283
+ batch_size_t = gr.Number(label="Batch Size", value=16, precision=0, scale=1)
284
+ steps_t = gr.Number(label="Steps", value=20000, precision=0, scale=1)
285
+
286
+ # Row 4: Output dir (full width)
287
+ output_dir_t = gr.Textbox(label="Output Dir", value="outputs/train/my_smolvla_1")
288
+
289
+ # Row 5: Job name & Policy repo id
290
+ with gr.Row():
291
+ job_name_t = gr.Textbox(label="Job Name", value="smolvla_place_cylinder", scale=1)
292
+ policy_repo_t = gr.Textbox(label="Policy Repo ID (optional)", value="", scale=1)
293
+
294
+ # Update train dataset dropdown when username changes
295
+ hf_username_tb.change(_update_ds_choices, inputs=hf_username_tb, outputs=dataset_repo_t)
296
+
297
+ train_cmd = gr.Textbox(label="Generated Command", interactive=False, lines=16)
298
+ gr.Button("Generate Command").click(
299
+ build_train_command,
300
+ [
301
+ policy_path_t,
302
+ dataset_repo_t,
303
+ batch_size_t,
304
+ steps_t,
305
+ output_dir_t,
306
+ job_name_t,
307
+ device_t,
308
+ wandb_chk,
309
+ policy_repo_t,
310
+ ],
311
+ train_cmd,
312
+ )
313
+
314
+ # Evaluate Policy Tab
315
+ with gr.TabItem("Evaluate Policy"):
316
+ gr.Markdown("### Evaluate Policy")
317
+ with gr.Row():
318
+ robot_type_e = gr.Textbox(label="Robot Type", value="so101_follower")
319
+ robot_port_e = gr.Textbox(label="Robot Port", value="/dev/ttyACM0")
320
+ robot_id_e = gr.Textbox(label="Robot ID", value="follower")
321
+ with gr.Row():
322
+ cam_index_e = gr.Number(label="Cam Index", value=0, precision=0)
323
+ cam_width_e = gr.Number(label="Width", value=640, precision=0)
324
+ cam_height_e = gr.Number(label="Height", value=480, precision=0)
325
+ cam_fps_e = gr.Number(label="FPS", value=30, precision=0)
326
+ with gr.Row():
327
+ display_data_e = gr.Checkbox(label="Display Data", value=True)
328
+ dataset_repo_e = gr.Dropdown(label="User Dataset", choices=_list_remote_datasets(hf_username_tb.value))
329
+ num_episodes_e = gr.Number(label="Num Episodes", value=2, precision=0)
330
+ single_task_e = gr.Textbox(label="Single Task", value="place cylinder")
331
+ with gr.Row():
332
+ policy_path_e = gr.Textbox(label="Policy Path", value="arpitg1304/smolvla_place_cylinder")
333
+ resume_e = gr.Checkbox(label="Resume", value=True)
334
+ eval_cmd = gr.Textbox(label="Generated Command", interactive=False, lines=16)
335
+ # update evaluate dataset dropdown when username changes
336
+ hf_username_tb.change(_update_ds_choices, inputs=hf_username_tb, outputs=dataset_repo_e)
337
+
338
+ inputs_eval = [
339
+ robot_type_e,
340
+ robot_port_e,
341
+ robot_id_e,
342
+ cam_index_e,
343
+ cam_width_e,
344
+ cam_height_e,
345
+ cam_fps_e,
346
+ display_data_e,
347
+ dataset_repo_e,
348
+ num_episodes_e,
349
+ single_task_e,
350
+ policy_path_e,
351
+ resume_e,
352
+ ]
353
+ gr.Button("Generate Command").click(
354
+ build_eval_command,
355
+ inputs_eval,
356
+ eval_cmd,
357
+ )
358
+
359
+ return demo
360
+
361
+
362
+ if __name__ == "__main__":
363
+ build_ui().launch()