from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
    CosineDecayWithWarmupSchedulerConfig,
)

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


@PreTrainedConfig.register_subclass("pi0fast")
@dataclass
class PI0FASTConfig(PreTrainedConfig):
    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 10
    n_action_steps: int = 5

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )
    # Shorter state and action vectors will be padded up to these dimensions.
    max_state_dim: int = 32
    max_action_dim: int = 32
    # Image preprocessing.
    resize_imgs_with_padding: tuple[int, int] = (224, 224)
    interpolate_like_pi: bool = False

    # Add empty images. Used by pi0_aloha_sim, which adds empty left and right
    # wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime, which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False
    # Tokenizer.
    tokenizer_max_length: int = 48

    # Projector.
    proj_width: int = 1024

    # Decoding.
    max_decoding_steps: int = 256
    fast_skip_tokens: int = 128  # Skip the last 128 tokens in the PaliGemma vocab, since they are special tokens.
    max_input_seq_len: int = 256  # 512

    # Utils.
    use_cache: bool = True
    # Frozen parameters.
    freeze_vision_encoder: bool = True
    freeze_lm_head: bool = True

    # Training presets.
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-5
    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    checkpoint_path: str | None = None

    padding_side: str = "right"
    precision: str = "bfloat16"
    grad_clip_norm: float = 1.0

    # Allows padding/truncating the generated action tokens during detokenization so that
    # decoding always succeeds. In the original version, tensors of zeros were generated
    # when shapes didn't match, to keep decoding stable.
    relaxed_action_decoding: bool = True
    def __post_init__(self):
        """Input validation (not exhaustive)."""
        super().__post_init__()

        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.n_obs_steps != 1:
            raise ValueError(
                f"Multiple observation steps not handled yet. Got `n_obs_steps={self.n_obs_steps}`"
            )
    def validate_features(self) -> None:
        # Register a placeholder visual feature for each empty camera, so the
        # input signature stays consistent with checkpoints trained with extra
        # camera slots.
        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera
    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.grad_clip_norm,
        )
    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )
    @property
    def observation_delta_indices(self) -> None:
        # Single observation step (`n_obs_steps == 1`): no temporal offsets to fetch.
        return None

    @property
    def action_delta_indices(self) -> list:
        # Fetch the next `chunk_size` actions relative to each frame.
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
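

# --- Usage sketch (illustrative; not part of the upstream config file) ---
# A minimal example of how this config might be constructed and how the
# training presets defined above are consumed. It assumes the imports at the
# top of this module resolve; the field values passed below are arbitrary.
if __name__ == "__main__":
    config = PI0FASTConfig(chunk_size=10, n_action_steps=5)

    # Build the optimizer/scheduler preset dataclasses from the fields above.
    optimizer_cfg = config.get_optimizer_preset()  # AdamWConfig(lr=1e-4, ...)
    scheduler_cfg = config.get_scheduler_preset()  # cosine decay with warmup
    print(optimizer_cfg, scheduler_cfg)

    # `__post_init__` rejects inconsistent settings: n_action_steps must not
    # exceed chunk_size.
    try:
        PI0FASTConfig(chunk_size=10, n_action_steps=20)
    except ValueError as err:
        print(f"rejected: {err}")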