# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
    CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
    """Configuration for the pi0 policy."""

    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded to these maximum dimensions.
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Image preprocessing: images are resized (with padding) to this (height, width).
    resize_imgs_with_padding: tuple[int, int] = (224, 224)

    # Add empty images. Used by pi0_aloha_sim, which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime, which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing them to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48

    # Projector
    proj_width: int = 1024

    # Decoding
    num_steps: int = 10

    # Attention utils
    use_cache: bool = True
    attention_implementation: str = "eager"  # or fa2, flex

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = False
    train_state_proj: bool = True

    # Training presets
    optimizer_lr: float = 2.5e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    # TODO: Add EMA

    def __post_init__(self):
        super().__post_init__()

        # TODO(Steven): Validate device and amp? in all policy configs?
        # Input validation (not exhaustive).
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps are not handled yet. Got `n_obs_steps={self.n_obs_steps}`."
            )
        if self.use_delta_joint_actions_aloha:
            raise NotImplementedError(
                "`use_delta_joint_actions_aloha` is used by pi0 for Aloha real models. It is not yet ported to LeRobot."
            )

    def validate_features(self) -> None:
        # TODO: implement value error
        # if not self.image_features and not self.env_state_feature:
        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")

        # Register one placeholder visual feature per requested empty camera.
        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> None:
        return None

    @property
    def action_delta_indices(self) -> list:
        # The policy predicts a full chunk of actions relative to the current step.
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
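
# Illustrative usage below: a minimal sketch of how this config is meant to be
# consumed, not part of the library API. It assumes the inherited
# `PreTrainedConfig.__post_init__` runs cleanly with default arguments; in
# practice the config is usually constructed by LeRobot's training entry
# points rather than by hand.
if __name__ == "__main__":
    config = PI0Config(empty_cameras=2)

    # `validate_features` registers one placeholder camera feature per
    # `empty_cameras`, keyed like "observation.images.empty_camera_0".
    config.validate_features()
    print(sorted(key for key in config.input_features if "empty_camera" in key))

    # The optimizer/scheduler presets mirror the training hyperparameters above.
    print(config.get_optimizer_preset())
    print(config.get_scheduler_preset())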