Spaces:
Running
Running
File size: 4,516 Bytes
529ed6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
@PreTrainedConfig.register_subclass("pi0fast")
@dataclass
class PI0FASTConfig(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 10
n_action_steps: int = 5
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32 # 32
max_action_dim: int = 32 # 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)
interpolate_like_pi: bool = False
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Projector
proj_width: int = 1024
# Decoding
max_decoding_steps: int = 256
fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
max_input_seq_len: int = 256 # 512
# Utils
use_cache: bool = True
# Frozen parameters
freeze_vision_encoder: bool = True
freeze_lm_head: bool = True
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
checkpoint_path: str = None
padding_side: str = "right"
precision: str = "bfloat16"
grad_clip_norm: float = 1
# Allows padding/truncation of generated action tokens during detokenization to ensure decoding.
# In the original version, tensors of 0s were generated if shapes didn't match for stable decoding.
relaxed_action_decoding: bool = True
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
def validate_features(self) -> None:
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
|