Spaces:
Running
Running
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
######################################################################################## | |
# Utilities | |
######################################################################################## | |
import logging | |
import time | |
import traceback | |
from contextlib import nullcontext | |
from copy import copy | |
from functools import cache | |
import rerun as rr | |
import torch | |
from deepdiff import DeepDiff | |
from termcolor import colored | |
from lerobot.common.datasets.image_writer import safe_stop_image_writer | |
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset | |
from lerobot.common.datasets.utils import get_features_from_robot | |
from lerobot.common.policies.pretrained import PreTrainedPolicy | |
from lerobot.common.robot_devices.robots.utils import Robot | |
from lerobot.common.robot_devices.utils import busy_wait | |
from lerobot.common.utils.utils import get_safe_torch_device, has_method | |
def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): | |
log_items = [] | |
if episode_index is not None: | |
log_items.append(f"ep:{episode_index}") | |
if frame_index is not None: | |
log_items.append(f"frame:{frame_index}") | |
def log_dt(shortname, dt_val_s): | |
nonlocal log_items, fps | |
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" | |
if fps is not None: | |
actual_fps = 1 / dt_val_s | |
if actual_fps < fps - 1: | |
info_str = colored(info_str, "yellow") | |
log_items.append(info_str) | |
# total step time displayed in milliseconds and its frequency | |
log_dt("dt", dt_s) | |
# TODO(aliberts): move robot-specific logs logic in robot.print_logs() | |
if not robot.robot_type.startswith("stretch"): | |
for name in robot.leader_arms: | |
key = f"read_leader_{name}_pos_dt_s" | |
if key in robot.logs: | |
log_dt("dtRlead", robot.logs[key]) | |
for name in robot.follower_arms: | |
key = f"write_follower_{name}_goal_pos_dt_s" | |
if key in robot.logs: | |
log_dt("dtWfoll", robot.logs[key]) | |
key = f"read_follower_{name}_pos_dt_s" | |
if key in robot.logs: | |
log_dt("dtRfoll", robot.logs[key]) | |
for name in robot.cameras: | |
key = f"read_camera_{name}_dt_s" | |
if key in robot.logs: | |
log_dt(f"dtR{name}", robot.logs[key]) | |
info_str = " ".join(log_items) | |
logging.info(info_str) | |
def is_headless(): | |
"""Detects if python is running without a monitor.""" | |
try: | |
import pynput # noqa | |
return False | |
except Exception: | |
print( | |
"Error trying to import pynput. Switching to headless mode. " | |
"As a result, the video stream from the cameras won't be shown, " | |
"and you won't be able to change the control flow with keyboards. " | |
"For more info, see traceback below.\n" | |
) | |
traceback.print_exc() | |
print() | |
return True | |
def predict_action(observation, policy, device, use_amp): | |
observation = copy(observation) | |
with ( | |
torch.inference_mode(), | |
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), | |
): | |
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension | |
for name in observation: | |
if "image" in name: | |
observation[name] = observation[name].type(torch.float32) / 255 | |
observation[name] = observation[name].permute(2, 0, 1).contiguous() | |
observation[name] = observation[name].unsqueeze(0) | |
observation[name] = observation[name].to(device) | |
# Compute the next action with the policy | |
# based on the current observation | |
action = policy.select_action(observation) | |
# Remove batch dimension | |
action = action.squeeze(0) | |
# Move to cpu, if not already the case | |
action = action.to("cpu") | |
return action | |
def init_keyboard_listener(): | |
# Allow to exit early while recording an episode or resetting the environment, | |
# by tapping the right arrow key '->'. This might require a sudo permission | |
# to allow your terminal to monitor keyboard events. | |
events = {} | |
events["exit_early"] = False | |
events["rerecord_episode"] = False | |
events["stop_recording"] = False | |
if is_headless(): | |
logging.warning( | |
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available." | |
) | |
listener = None | |
return listener, events | |
# Only import pynput if not in a headless environment | |
from pynput import keyboard | |
def on_press(key): | |
try: | |
if key == keyboard.Key.right: | |
print("Right arrow key pressed. Exiting loop...") | |
events["exit_early"] = True | |
elif key == keyboard.Key.left: | |
print("Left arrow key pressed. Exiting loop and rerecord the last episode...") | |
events["rerecord_episode"] = True | |
events["exit_early"] = True | |
elif key == keyboard.Key.esc: | |
print("Escape key pressed. Stopping data recording...") | |
events["stop_recording"] = True | |
events["exit_early"] = True | |
except Exception as e: | |
print(f"Error handling key press: {e}") | |
listener = keyboard.Listener(on_press=on_press) | |
listener.start() | |
return listener, events | |
def warmup_record( | |
robot, | |
events, | |
enable_teleoperation, | |
warmup_time_s, | |
display_data, | |
fps, | |
): | |
control_loop( | |
robot=robot, | |
control_time_s=warmup_time_s, | |
display_data=display_data, | |
events=events, | |
fps=fps, | |
teleoperate=enable_teleoperation, | |
) | |
def record_episode( | |
robot, | |
dataset, | |
events, | |
episode_time_s, | |
display_data, | |
policy, | |
fps, | |
single_task, | |
): | |
control_loop( | |
robot=robot, | |
control_time_s=episode_time_s, | |
display_data=display_data, | |
dataset=dataset, | |
events=events, | |
policy=policy, | |
fps=fps, | |
teleoperate=policy is None, | |
single_task=single_task, | |
) | |
def control_loop( | |
robot, | |
control_time_s=None, | |
teleoperate=False, | |
display_data=False, | |
dataset: LeRobotDataset | None = None, | |
events=None, | |
policy: PreTrainedPolicy = None, | |
fps: int | None = None, | |
single_task: str | None = None, | |
): | |
# TODO(rcadene): Add option to record logs | |
if not robot.is_connected: | |
robot.connect() | |
if events is None: | |
events = {"exit_early": False} | |
if control_time_s is None: | |
control_time_s = float("inf") | |
if teleoperate and policy is not None: | |
raise ValueError("When `teleoperate` is True, `policy` should be None.") | |
if dataset is not None and single_task is None: | |
raise ValueError("You need to provide a task as argument in `single_task`.") | |
if dataset is not None and fps is not None and dataset.fps != fps: | |
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") | |
timestamp = 0 | |
start_episode_t = time.perf_counter() | |
while timestamp < control_time_s: | |
start_loop_t = time.perf_counter() | |
if teleoperate: | |
observation, action = robot.teleop_step(record_data=True) | |
else: | |
observation = robot.capture_observation() | |
if policy is not None: | |
pred_action = predict_action( | |
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp | |
) | |
# Action can eventually be clipped using `max_relative_target`, | |
# so action actually sent is saved in the dataset. | |
action = robot.send_action(pred_action) | |
action = {"action": action} | |
if dataset is not None: | |
frame = {**observation, **action, "task": single_task} | |
dataset.add_frame(frame) | |
# TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon) | |
if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")): | |
for k, v in action.items(): | |
for i, vv in enumerate(v): | |
rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy())) | |
image_keys = [key for key in observation if "image" in key] | |
for key in image_keys: | |
rr.log(key, rr.Image(observation[key].numpy()), static=True) | |
if fps is not None: | |
dt_s = time.perf_counter() - start_loop_t | |
busy_wait(1 / fps - dt_s) | |
dt_s = time.perf_counter() - start_loop_t | |
log_control_info(robot, dt_s, fps=fps) | |
timestamp = time.perf_counter() - start_episode_t | |
if events["exit_early"]: | |
events["exit_early"] = False | |
break | |
def reset_environment(robot, events, reset_time_s, fps): | |
# TODO(rcadene): refactor warmup_record and reset_environment | |
if has_method(robot, "teleop_safety_stop"): | |
robot.teleop_safety_stop() | |
control_loop( | |
robot=robot, | |
control_time_s=reset_time_s, | |
events=events, | |
fps=fps, | |
teleoperate=True, | |
) | |
def stop_recording(robot, listener, display_data): | |
robot.disconnect() | |
if not is_headless() and listener is not None: | |
listener.stop() | |
def sanity_check_dataset_name(repo_id, policy_cfg): | |
_, dataset_name = repo_id.split("/") | |
# either repo_id doesnt start with "eval_" and there is no policy | |
# or repo_id starts with "eval_" and there is a policy | |
# Check if dataset_name starts with "eval_" but policy is missing | |
if dataset_name.startswith("eval_") and policy_cfg is None: | |
raise ValueError( | |
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})." | |
) | |
# Check if dataset_name does not start with "eval_" but policy is provided | |
if not dataset_name.startswith("eval_") and policy_cfg is not None: | |
raise ValueError( | |
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})." | |
) | |
def sanity_check_dataset_robot_compatibility( | |
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool | |
) -> None: | |
fields = [ | |
("robot_type", dataset.meta.robot_type, robot.robot_type), | |
("fps", dataset.fps, fps), | |
("features", dataset.features, get_features_from_robot(robot, use_videos)), | |
] | |
mismatches = [] | |
for field, dataset_value, present_value in fields: | |
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) | |
if diff: | |
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") | |
if mismatches: | |
raise ValueError( | |
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) | |
) | |