Spaces:
Running
Running
#!/usr/bin/env python | |
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models | |
[Paper](https://arxiv.org/abs/2501.09747) | |
[Jax code](https://github.com/Physical-Intelligence/openpi) | |
Designed by Physical Intelligence. Ported from Jax by Hugging Face. | |
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): | |
```bash | |
python lerobot/scripts/train.py \ | |
--policy.path=lerobot/pi0fast_base \ | |
--dataset.repo_id=danaaubakirova/koch_test | |
``` | |
Example of training the pi0+FAST neural network from scratch:
```bash | |
python lerobot/scripts/train.py \ | |
--policy.type=pi0fast \ | |
--dataset.repo_id=danaaubakirova/koch_test | |
``` | |
Example of using the pi0+FAST pretrained model outside the LeRobot training framework:
```python | |
policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base") | |
``` | |
""" | |
from collections import deque | |
from functools import partial | |
import numpy as np | |
import torch | |
import torch.nn.functional as F # noqa: N812 | |
from PIL import Image | |
from scipy.fft import idct | |
from torch import Tensor, nn | |
from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration | |
from transformers.cache_utils import HybridCache, StaticCache | |
from transformers.models.auto import CONFIG_MAPPING | |
from lerobot.common.constants import ACTION, OBS_ROBOT | |
from lerobot.common.policies.normalize import Normalize, Unnormalize | |
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig | |
from lerobot.common.policies.pretrained import PreTrainedPolicy | |
# Maps the `precision` string from the config to the torch dtype used for the model weights.
PRECISION = dict(
    float16=torch.float16,
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)
def normalize(x, min_val, max_val):
    """Linearly map ``x`` from ``[min_val, max_val]`` onto ``[0, 1]``."""
    span = max_val - min_val
    return (x - min_val) / span
def unnormalize(x, min_val, max_val):
    """Inverse of ``normalize``: map ``x`` from ``[0, 1]`` back onto ``[min_val, max_val]``."""
    span = max_val - min_val
    return x * span + min_val
def safe_arcsin(value):
    """Return ``arcsin(value)`` after clamping ``value`` into the valid domain [-1, 1].

    Inputs nudged outside [-1, 1] (e.g. by floating-point error) would otherwise produce NaN.
    """
    clamped = value.clamp(-1.0, 1.0)
    return torch.arcsin(clamped)
def aloha_gripper_to_angular(value):
    """Convert an Aloha (linear-space) gripper position into pi0's normalized angular space.

    Aloha's runtime transforms gripper positions into a linear space, while pi0 was pretrained in
    angular space. This reverses that transformation in three steps:

    1. Unnormalize into the physical linear range (Aloha constants
       PUPPET_GRIPPER_POSITION_OPEN / PUPPET_GRIPPER_POSITION_CLOSED).
    2. Apply the inverse of the Interbotix angular-to-linear transformation.
    3. Normalize the resulting angle to [0, 1] using the range 0.4..1.5 measured on an actual
       Trossen robot.
    """
    # Constants taken from the Interbotix code.
    arm_length = 0.036
    horn_radius = 0.022

    linear_position = unnormalize(value, min_val=0.01844, max_val=0.05800)
    ratio = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
    angle = safe_arcsin(ratio)
    return normalize(angle, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
    """Convert pi0's normalized gripper angle into the gripper position used by Aloha.

    The units stay angular; only the range changes. The input is first expanded to the range
    0.4..1.5 (measured on an actual Trossen robot), then normalized over Aloha's joint range
    (PUPPET_GRIPPER_JOINT_OPEN / PUPPET_GRIPPER_JOINT_CLOSE).
    """
    angle = unnormalize(value, min_val=0.4, max_val=1.5)
    return normalize(angle, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
    """Exact inverse of ``aloha_gripper_from_angular``.

    Expands from Aloha's joint range back to angles, then re-normalizes over 0.4..1.5.
    """
    angle = unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return normalize(angle, min_val=0.4, max_val=1.5)
class PI0FASTPolicy(PreTrainedPolicy):
    """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot."""

    config_class = PI0FASTConfig
    name = "pi0fast"

    def __init__(
        self,
        config: PI0FASTConfig,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        """
        Args:
            config: Policy configuration class instance. Must not be None: ``validate_features`` is
                called on it immediately.
            dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
                that they will be passed with a call to `load_state_dict` before the policy is used.
        """
        super().__init__(config)
        config.validate_features()
        self.config = config

        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
        self.normalize_targets = Normalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_features, config.normalization_mapping, dataset_stats
        )

        # NOTE(review): not referenced inside this class; kept in case external code relies on it.
        self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")
        self.model = PI0FAST(config)

        self.reset()

    def reset(self):
        """This should be called whenever the environment is reset."""
        self._action_queue = deque([], maxlen=self.config.n_action_steps)

    def get_optim_params(self) -> dict:
        """Return the parameters to optimize (all policy parameters).

        NOTE(review): despite the ``dict`` annotation this returns ``self.parameters()``.
        """
        return self.parameters()

    def _pi_aloha_decode_state(self, state):
        """Convert an Aloha state into the joint convention pi0 was pretrained on.

        Operates on a copy so the caller's tensor is left untouched (flipping in place would
        double-flip if the same batch is processed twice).
        """
        state = state.clone()
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            state[:, motor_idx] *= -1
        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        for motor_idx in [6, 13]:
            state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
        return state

    def _pi_aloha_encode_actions(self, actions):
        """Convert pi0 actions (batch, steps, dim) into the Aloha convention, on a copy."""
        actions = actions.clone()
        # Flip the joints.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Map grippers from pi0's angular range to Aloha's joint range.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
        return actions

    def _pi_aloha_encode_actions_inv(self, actions):
        """Inverse of ``_pi_aloha_encode_actions``: Aloha actions -> pi0 convention, on a copy."""
        actions = actions.clone()
        # Flip the joints again.
        for motor_idx in [1, 2, 8, 9]:
            actions[:, :, motor_idx] *= -1
        # Undo the gripper range mapping.
        for motor_idx in [6, 13]:
            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
        return actions

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select a single action given environment observations.

        This method wraps `select_actions` in order to return one action at a time for execution in the
        environment. It works by managing the actions in a queue and only calling `select_actions` when the
        queue is empty. Runs without gradient tracking and does not modify the caller's batch.
        """
        self.eval()

        # Shallow copy so the caller's dict keeps its original (un-decoded) state tensor.
        batch = dict(batch)
        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])

        batch = self.normalize_inputs(batch)

        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
        # querying the policy.
        if len(self._action_queue) == 0:
            actions = self.model.generate_actions(batch)
            actions = actions[:, : self.config.n_action_steps]

            # Keep only the first `action_feature.shape[0]` action dimensions.
            original_action_dim = self.config.action_feature.shape[0]
            actions = actions[:, :, :original_action_dim]

            actions = self.unnormalize_outputs({"action": actions})["action"]

            if self.config.adapt_to_pi_aloha:
                actions = self._pi_aloha_encode_actions(actions)

            # `actions` is (batch_size, n_action_steps, action_dim) but the queue holds one
            # (batch_size, action_dim) entry per step, hence the transpose.
            self._action_queue.extend(actions.transpose(0, 1))
        return self._action_queue.popleft()

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """Compute the training loss.

        Returns:
            ``(loss, loss_dict)`` where ``loss`` is ``loss_dict["loss"]``.
        """
        # Shallow copy so the caller's dict is not rewritten with converted tensors.
        batch = dict(batch)
        if self.config.adapt_to_pi_aloha:
            batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
        batch = self.normalize_inputs(batch)
        batch = self.normalize_targets(batch)
        loss_dict = self.model.forward(batch)
        return loss_dict["loss"], loss_dict
def block_causal_update_causal_mask(
    attention_mask,
    token_type_ids=None,
    past_key_values=None,
    cache_position=None,
    input_tensor=None,
    attn_implementation: str = "eager",
    dtype: torch.dtype = torch.float32,
):
    """
    Update the causal mask during training and generation. It can be customized to different attention masks.

    With ``token_type_ids`` and more than one query position the mask is block-causal: query ``i``
    may attend key ``j`` when ``j <= i`` or when
    ``cumsum(token_type_ids)[j] <= cumsum(token_type_ids)[i]``. Leading zero-type (prefix) tokens
    therefore see each other bidirectionally, while every non-zero (suffix) token adds one step of
    causality. Without ``token_type_ids`` (or with a single query), the mask is plain causal with
    respect to ``cache_position``.

    Args:
        attention_mask: 2D (batch, key_len) padding mask, non-zero for real tokens, or a
            precomputed 4D mask that is returned unchanged.
        token_type_ids: optional (batch, seq_len) ids; non-zero entries mark suffix tokens.
        past_key_values: optional cache. For ``StaticCache``/``HybridCache`` the key length is the
            cache's maximum shape.
        cache_position: positions of the current query tokens; also provides the mask's device.
        input_tensor: tensor whose first two dims are (batch, query_len); defaults to
            ``attention_mask``.
        attn_implementation: with ``"flash_attention_2"`` only a 2D mask (or ``None``) is returned.
        dtype: floating dtype of the returned mask; blocked positions hold ``torch.finfo(dtype).min``.

    Returns:
        For flash attention: ``attention_mask`` if it contains a 0, else ``None``. Otherwise a
        (batch, 1, query_len, key_len) additive mask with 0 where attention is allowed.
    """
    if attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None

    # Test for None first so the cache classes are only consulted when a cache is actually given.
    uses_fixed_size_cache = past_key_values is not None and isinstance(
        past_key_values, (StaticCache, HybridCache)
    )
    min_dtype = torch.finfo(dtype).min

    if input_tensor is None:
        input_tensor = attention_mask
    inputs_lead_dim, sequence_length = input_tensor.shape[:2]

    if uses_fixed_size_cache:
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = cache_position[0] + sequence_length + 1

    # Handle precomputed attention masks
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    # Start fully blocked, then open up the allowed positions.
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )

    # Standard causal masking (triu keeps only future positions blocked).
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)

    if sequence_length != 1 and token_type_ids is not None:
        token_type_ids = token_type_ids.to(causal_mask.device).bool()
        cumsum = torch.cumsum(token_type_ids, dim=1)
        block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None]
        # Combine causal_mask with block-wise attention mask.
        causal_mask = torch.where(block_causal_mask, 0.0, causal_mask)
        causal_mask = causal_mask[:, None, :, :]
    else:
        # Apply past cache position constraint.
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
            -1, 1
        )
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # Copy to contiguous memory for in-place edits
        mask_length = attention_mask.shape[-1]

        # Block keys that are padding (attention_mask == 0) wherever they were still allowed.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            causal_mask.device
        )
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
def prepare_inputs_for_generation(
    # self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    pixel_values=None,
    attention_mask=None,
    token_type_ids=None,
    use_cache=True,
    num_logits_to_keep=None,
    labels=None,
    self=None,
    **kwargs,
):
    """Replacement for PaliGemma's ``prepare_inputs_for_generation`` that applies a block-causal mask.

    ``self`` is a keyword parameter placed last (instead of the usual first positional slot) so the
    function can be bound with ``functools.partial(prepare_inputs_for_generation, self=model)`` and
    assigned onto a ``PaliGemmaForConditionalGeneration`` instance; ``PI0FAST.__init__`` does that.

    Per generation step it:
    1. On cached steps (``cache_position[0] > 0`` and non-empty ``input_ids``), appends
       ``input_ids.shape[1]`` new positions to ``position_ids``, continuing from the last one.
    2. Builds the attention mask with ``block_causal_update_causal_mask`` using ``token_type_ids``.
    3. Delegates to ``self.language_model.prepare_inputs_for_generation``.
    4. Adds 1 to the returned ``position_ids`` and passes ``pixel_values`` only on the first step.
    5. On the first step with a ``HybridCache``, replaces the mask with ``self._update_causal_mask``.

    Returns:
        The ``model_inputs`` dict returned by the language model's helper, updated as above.
    """
    # create block causal attention
    if cache_position[0] > 0 and input_ids.shape[1] > 0:
        # Cached decoding step: only the newest token is the query.
        input_tensor = input_ids[:, -1:]
        # Positions last+1 .. last+input_ids.shape[1] for the newly fed tokens.
        new_positions = (
            torch.ones(
                (position_ids.shape[0], input_ids.shape[1]),
                dtype=position_ids.dtype,
                device=position_ids.device,
            ).cumsum(-1)
            + position_ids[:, -1:]
        )
        position_ids = torch.cat([position_ids, new_positions], dim=-1)
    else:
        # First step: the whole prompt is given as embeddings.
        input_tensor = inputs_embeds
    attention_mask = block_causal_update_causal_mask(
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        cache_position=cache_position,
        input_tensor=input_tensor,
        token_type_ids=token_type_ids,
        dtype=self.dtype,
        attn_implementation=self.config.text_config._attn_implementation,
    )
    # Overwritten -- custom `position_ids` and `pixel_values` handling
    model_inputs = self.language_model.prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        cache_position=cache_position,
        use_cache=use_cache,
        num_logits_to_keep=num_logits_to_keep,
        token_type_ids=token_type_ids,
        **kwargs,
    )
    # Position_ids in Paligemma are 1-indexed
    # NOTE(review): `+= 1` is in place; if the helper returns a view of the position_ids passed in,
    # that tensor is modified too -- confirm this is intended.
    if model_inputs.get("position_ids") is not None:
        model_inputs["position_ids"] += 1
    # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
    # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
    if cache_position[0] == 0:
        model_inputs["pixel_values"] = pixel_values
    # Only used by the HybridCache branch below.
    is_training = token_type_ids is not None and labels is not None
    # NOTE(review): on the first step with a HybridCache, the block-causal mask built above is
    # replaced by the model's own `_update_causal_mask` result.
    if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
        input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
        )
        model_inputs["attention_mask"] = causal_mask
    return model_inputs
class PI0FAST(nn.Module): | |
def __init__(self, config: PI0FASTConfig):
    """Build the PaliGemma backbone, the PaliGemma/FAST tokenizers and the generation hook.

    Args:
        config: PI0FAST configuration (chunk size, action feature, precision, freezing flags, ...).
    """
    super().__init__()
    self.config = config

    # TODO: move tokenizers in Policy
    fast_tokenizer_path = "physical-intelligence/fast"
    pi0_paligemma_path = "google/paligemma-3b-pt-224"
    self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path)
    self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path)
    # NOTE: trust_remote_code executes the tokenizer code shipped with this hub repository.
    self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
    self.fast_skip_tokens = self.config.fast_skip_tokens
    self.max_input_seq_len = self.config.max_input_seq_len
    self.action_horizon = self.config.chunk_size
    self.action_dim = self.config.action_feature.shape[0]

    precision = config.precision
    torch_precision = PRECISION.get(precision, torch.float32)

    # HF tokenizers always expose `pad_token_id` (it is None when no pad token is set), so check
    # the value rather than the attribute's existence before falling back to EOS.
    pad_token_id = getattr(self.paligemma_tokenizer, "pad_token_id", None)
    self.pad_token_id = (
        pad_token_id if pad_token_id is not None else self.paligemma_tokenizer.eos_token_id
    )

    paligemma_config = CONFIG_MAPPING["paligemma"](
        transformers_version="4.48.1",
        _vocab_size=257152,
        bos_token_id=2,
        eos_token_id=1,
        hidden_size=2048,
        image_token_index=257152,
        model_type="paligemma",
        pad_token_id=0,
        projection_dim=2048,
        text_config={
            "hidden_activation": "gelu_pytorch_tanh",
            "hidden_size": 2048,
            "intermediate_size": 16384,
            "model_type": "gemma",
            "num_attention_heads": 8,
            "num_hidden_layers": 18,
            "num_image_tokens": 256,
            "num_key_value_heads": 1,
            "torch_dtype": precision,
            "vocab_size": 257152,
            "_attn_implementation": "eager",
        },
        vision_config={
            "hidden_size": 1152,
            "intermediate_size": 4304,
            "model_type": "siglip_vision_model",
            "num_attention_heads": 16,
            "num_hidden_layers": 27,
            "num_image_tokens": 256,
            "patch_size": 14,
            "projection_dim": 2048,
            "projector_hidden_act": "gelu_pytorch_tanh",
            "torch_dtype": precision,
            "vision_use_head": False,
        },
    )
    self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config)

    # Install the block-causal generation hook; `self=` binds the model into the free function.
    self.pi0_paligemma.prepare_inputs_for_generation = partial(
        prepare_inputs_for_generation, self=self.pi0_paligemma
    )

    # Cast parameters whose names contain any of these substrings to the configured precision.
    params_to_change_dtype = [
        "language_model",
        "vision_tower",
        "multi_modal",
    ]
    for name, param in self.pi0_paligemma.named_parameters():
        if any(selector in name for selector in params_to_change_dtype):
            param.data = param.data.to(dtype=torch_precision)

    self.set_requires_grad()
    self.image_keys = self.config.image_features.keys()
    self.ignore_index = self.pi0_paligemma.config.ignore_index
    self.padding_side = self.config.padding_side
def set_requires_grad(self):
    """Freeze parts of PaliGemma according to the config flags.

    - ``freeze_vision_encoder``: switch the vision tower to eval mode and stop its gradients.
    - ``freeze_lm_head``: stop gradients for every parameter whose name contains ``embed_tokens``
      (the LM head and the embedding layer are tied). This avoids the unused-parameter issue in
      distributed training.
    """
    paligemma = self.pi0_paligemma

    if self.config.freeze_vision_encoder:
        vision_tower = paligemma.vision_tower
        vision_tower.eval()
        for weight in vision_tower.parameters():
            weight.requires_grad = False

    if not self.config.freeze_lm_head:
        return
    for param_name, weight in paligemma.named_parameters():
        if "embed_tokens" in param_name:
            weight.requires_grad = False
def embed_tokens(self, tokens: torch.Tensor):
    """Look up the language model's input embeddings for the given token ids."""
    embedding_layer = self.pi0_paligemma.language_model.model.embed_tokens
    return embedding_layer(tokens)
def prepare_inputs_for_generation(self, *args, **kwargs):
    """Forward all arguments to ``self.pi0_paligemma.prepare_inputs_for_generation``."""
    hook = self.pi0_paligemma.prepare_inputs_for_generation
    return hook(*args, **kwargs)
def prepare_images(self, batch):
    """Preprocess LeRobot batch into Pi0 inputs.

    Returns one image tensor and one boolean mask per entry of ``self.image_keys``, in key order.
    Present cameras are resized (if ``resize_imgs_with_padding`` is set) and rescaled from [0, 1]
    to [-1, 1]. Missing cameras get an all ``-1`` placeholder shaped like the first present image,
    up to ``config.empty_cameras`` placeholders. Extra missing cameras are skipped.

    Raises:
        ValueError: if none of the expected image features is in ``batch``.
    """
    present_img_keys = [key for key in self.image_keys if key in batch]
    if len(present_img_keys) == 0:
        raise ValueError(
            f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
        )

    # Preprocess the present cameras first, so a placeholder always has a real image to copy its
    # shape from, even when a missing camera precedes every present one in `self.image_keys`.
    processed = {key: self._preprocess_image(batch[key]) for key in present_img_keys}
    template = processed[present_img_keys[0]]

    images = []
    img_masks = []
    num_empty_cameras = 0
    for key in self.image_keys:
        if key in processed:
            img = processed[key]
        else:
            if num_empty_cameras >= self.config.empty_cameras:
                continue
            img = torch.ones_like(template) * -1
            num_empty_cameras += 1
        # NOTE(review): placeholder cameras also get an all-True mask, as before; openpi-style
        # pipelines usually mask missing cameras out -- confirm this is intended.
        mask = torch.ones(img.shape[0], dtype=torch.bool, device=img.device)
        images.append(img)
        img_masks.append(mask)

    return images, img_masks

def _preprocess_image(self, img):
    """Resize one camera batch (if configured) and rescale it from [0, 1] to [-1, 1]."""
    if self.config.resize_imgs_with_padding is not None:
        img = resize_with_pad(
            img,
            *self.config.resize_imgs_with_padding,
            pad_value=0,
            interpolate_like_pi=self.config.interpolate_like_pi,
        )
    # Normalize from range [0,1] to [-1,1] as expected by SigLIP.
    return img * 2.0 - 1.0
def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor:
    """Rescale each sample's action chunk to [-1, 1] using that sample's own min and max.

    The extrema are taken jointly over dims 1 and 2. The ``1e-8`` term avoids division by zero for
    constant chunks.

    NOTE(review): nothing in the code visible here inverts this per-sample rescaling when
    decoding predicted actions -- confirm the train/inference action scales agree.
    """
    reduce_dims = (1, 2)
    lo = torch.amin(actions, dim=reduce_dims, keepdim=True)
    hi = torch.amax(actions, dim=reduce_dims, keepdim=True)
    return 2 * (actions - lo) / (hi - lo + 1e-8) - 1
def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
    """Map FAST action-token ids onto the tail of the PaliGemma vocabulary.

    ``id -> (vocab_size - 1 - fast_skip_tokens) - id``. The mapping is its own inverse, so the same
    function also converts PaliGemma ids back to FAST ids.
    """
    offset = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens
    return offset - tokens
def fast_tokenizer_wrapper(self, actions_norm):
    """
    Tokenize a batch of normalized action chunks with the FAST tokenizer and pad the result.

    Returns the output of ``self.processor.tokenizer.pad`` as PyTorch tensors: padded
    ``input_ids`` and the matching ``attention_mask``.
    """
    fast_ids = self.fast_tokenizer(actions_norm)
    return self.processor.tokenizer.pad({"input_ids": fast_ids}, return_tensors="pt")
def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int | torch.Tensor) -> torch.Tensor:
    """Flag the positions that come after the first ``prefix_len`` real tokens of each row.

    Args:
        padded_mask: (batch, seq_len) mask, non-zero for real (non-padding) tokens.
        prefix_len: number of prefix tokens, either an int or a (batch, 1) tensor for
            per-sample lengths (broadcast against the cumulative count).

    Returns:
        Bool tensor shaped like ``padded_mask``: True for suffix positions.
    """
    # Running count of real tokens up to and including each position.
    real_tokens_seen = (padded_mask != 0).cumsum(dim=1)
    return real_tokens_seen > prefix_len
def create_input_tokens(self, state, lang_text, actions=None):
    """Tokenize the "Task/State" prefix and, for training, the FAST action suffix.

    Args:
        state: (batch, state_dim) state tensor. Values are discretized into 256 bins over [-1, 1]
            and only the first 32 dims are used.
        lang_text: task strings, one per sample.
        actions: optional (batch, steps, action_dim) actions. When given, an
            ``"Action: " + FAST tokens + EOS`` suffix is appended; when None (inference) no suffix
            is added.

    Returns:
        The padded tokenizer output with keys ``input_ids``, ``padded_mask`` (tokenizer attention
        mask: non-zero for real tokens), ``attention_mask`` (True after the prefix), ``loss_mask``
        (suffix and not padding) and ``token_type_ids`` (True for suffix positions).
    """
    bsize = state.shape[0]
    device = state.device

    # right=True gives bins[i-1] <= x < bins[i], i.e. numpy.digitize's default convention,
    # so values exactly on a bin edge (such as 0 or -1) land in the same bin as the reference.
    bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1]
    discretized = torch.bucketize(state, bins, right=True) - 1
    discretized = discretized[:, :32]

    prefix_texts = []
    for txt, disc in zip(lang_text, discretized, strict=False):
        cleaned = txt.lower().strip().replace("_", " ")
        state_str = " ".join(str(val.item()) for val in disc)
        prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n")

    prefix_out = self.paligemma_tokenizer(
        prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False
    )
    prefix_ids = prefix_out["input_ids"].to(device)
    prefix_mask = prefix_out["attention_mask"].to(device)
    # Real prefix length per sample, shaped (batch, 1) to broadcast against per-position counts.
    prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu()

    if actions is not None:
        actions_norm = self.normalize_actions(actions)
        # Pad (or truncate) the action dimension to config.max_action_dim.
        actions_pad = F.pad(
            actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0
        )[:, :, : self.config.max_action_dim]
        fast_out = self.fast_tokenizer_wrapper(actions_pad.cpu())
        act_mask = fast_out["attention_mask"].to(device)
        act_ids = self._act_tokens_to_paligemma_tokens(fast_out["input_ids"]).to(device)

        # Id 0 (presumably the pad id used by the pad call) maps to this value; replace it with
        # the PaliGemma pad token.
        mapped_zero_id = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens
        act_ids = torch.where(act_ids == mapped_zero_id, self.pad_token_id, act_ids)

        eos_token = torch.tensor(
            [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device
        ).expand(bsize, -1)
        eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1)
        bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt")
        bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device)
        bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device)
        act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1)
        act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1).to(device)
    else:
        # Inference: no action suffix (a zero-length sequence dimension).
        act_ids = torch.empty(bsize, 0, dtype=torch.long, device=device)
        act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device)

    final_ids = torch.cat([prefix_ids, act_ids], dim=1)
    final_mask = torch.cat([prefix_mask, act_mask], dim=1)
    batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()}

    # Use tokenizer pad function
    padded_output = self.paligemma_tokenizer.pad(
        batch_inputs, padding="longest", max_length=180, return_tensors="pt"
    )
    padded_mask = padded_output["attention_mask"]

    # True for every position after the first `prefix_lens` real tokens.
    att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens
    token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens)

    padded_output["padded_mask"] = padded_output.pop("attention_mask")
    padded_output["attention_mask"] = att_mask
    # loss is computed not on prefix, and not on padding
    padded_output["loss_mask"] = att_mask & padded_output["padded_mask"]
    padded_output["token_type_ids"] = token_type_ids
    return padded_output
def shift_padding_side(
    self,
    tokens: torch.Tensor,
    ar_mask: torch.Tensor,
    padding_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    targets: torch.Tensor,
    token_type_ids: torch.Tensor,
    padding_side: str = "right",
) -> tuple[torch.Tensor, ...]:
    """Move each row's padding positions (``padding_mask == 0``) to one end of the sequence.

    The same per-row permutation is applied to all six tensors along dim 1. Within the padding and
    non-padding groups the original order is kept. Any trailing dims (e.g. embeddings) move along.

    Args:
        tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids: tensors of shape
            (batch, seq_len, ...) sharing the first two dims.
        padding_side: ``"left"`` puts padding first, ``"right"`` puts it last; any other value
            returns the inputs unchanged.

    Returns:
        The six tensors, reordered, in the same order as the arguments.
    """
    if padding_side not in ["right", "left"]:
        return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids

    # Positions with sort key 0 come first; a stable sort preserves their relative order.
    if padding_side == "left":
        sort_key = (padding_mask == 1).long()
    else:
        sort_key = (padding_mask == 0).long()
    order = torch.argsort(sort_key, dim=1, stable=True)
    rows = torch.arange(order.shape[0], device=order.device)[:, None]

    def reorder(t: torch.Tensor) -> torch.Tensor:
        return t[rows.to(t.device), order.to(t.device)]

    return (
        reorder(tokens),
        reorder(ar_mask),
        reorder(padding_mask),
        reorder(loss_mask),
        reorder(targets),
        reorder(token_type_ids),
    )
def forward(self, batch: dict[str, Tensor]):
    """Compute the next-token cross-entropy loss over the action suffix.

    Returns:
        dict with ``"ce_loss"`` (Python float) and ``"loss"`` (scalar tensor). The loss is averaged
        over positions selected by the loss mask, with the denominator clamped to at least 1.
    """
    device = batch[OBS_ROBOT].device
    # TODO: keep like this or move to the policy .forward
    images, img_masks = self.prepare_images(batch)

    tokenized = self.create_input_tokens(
        state=batch[OBS_ROBOT],
        lang_text=batch["task"],
        actions=batch[ACTION],
    )

    embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs(
        images,
        img_masks,
        tokenized["input_ids"],
        tokenized["padded_mask"],
        tokenized["attention_mask"],
        tokenized["loss_mask"],
        tokenized["token_type_ids"],
        padding_side=self.padding_side,
    )
    position_ids = pad_masks.cumsum(dim=1) - 1

    # No cache during training, so positions simply run 0..seq_len-1.
    cache_position = torch.arange(0, embs.shape[1], device=embs.device)
    attention_4d = block_causal_update_causal_mask(
        attention_mask=pad_masks,
        past_key_values=None,
        cache_position=cache_position,
        input_tensor=embs,
        token_type_ids=token_type_ids.to(dtype=torch.int64),
        dtype=self.pi0_paligemma.dtype,
        attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation,
    )

    logits = self.pi0_paligemma.forward(
        input_ids=None,
        token_type_ids=None,
        attention_mask=attention_4d,
        position_ids=position_ids,
        past_key_values=None,
        inputs_embeds=embs,
        use_cache=False,
        labels=None,
    ).logits

    # Next-token prediction: the logit at position t is scored against the target at t + 1.
    shifted_logits = logits[:, :-1, :]
    shifted_targets = targets[:, 1:].to(device)
    shifted_loss_mask = loss_mask[:, 1:].to(device)

    per_token = nn.CrossEntropyLoss(reduction="none")(
        shifted_logits.reshape(-1, shifted_logits.shape[-1]), shifted_targets.reshape(-1)
    )
    per_token = per_token * shifted_loss_mask.reshape(-1)
    loss = per_token.sum() / torch.clamp(shifted_loss_mask.sum(), min=1)

    return {"ce_loss": loss.item(), "loss": loss}
def decode_actions_with_fast(
    self,
    tokens: list[list[int]],
    *,
    time_horizon: int | None = None,
    action_dim: int | None = None,
    relaxed_decoding: bool = True,
) -> np.ndarray:
    """
    Adapt original decoding in FAST to always return actions instead of zeros.

    Each token sequence is BPE-decoded into DCT coefficients. With ``relaxed_decoding`` the
    coefficients are truncated or zero-padded to ``time_horizon * action_dim``. They are then
    reshaped to (time_horizon, action_dim) and inverted with an orthonormal inverse DCT along time.
    Any sequence that fails to decode falls back to zero coefficients, and a warning is logged.

    Args:
        tokens: FAST token-id sequences, one per sample.
        time_horizon: timesteps per chunk; defaults to the FAST tokenizer's value.
        action_dim: action dimensionality; defaults to the FAST tokenizer's value.
        relaxed_decoding: fix up sequences whose length does not match the expected size.

    Returns:
        Array of shape (len(tokens), time_horizon, action_dim).

    Raises:
        ValueError: if neither the arguments nor the FAST tokenizer provide the two sizes.
    """
    import logging

    # Use locals: the decoding sizes must not overwrite the model's own `self.action_dim`.
    time_horizon = (
        time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon
    )
    action_dim = action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim
    if time_horizon is None or action_dim is None:
        raise ValueError(
            "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim."
        )

    logger = logging.getLogger(__name__)
    expected_seq_len = time_horizon * action_dim
    decoded_actions = []
    for token in tokens:
        try:
            decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token)
            decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token
            if relaxed_decoding:
                diff = expected_seq_len - decoded_dct_coeff.shape[0]
                if diff < 0:
                    decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len]  # Truncate on the right
                elif diff > 0:
                    decoded_dct_coeff = np.pad(
                        decoded_dct_coeff, (0, diff), mode="constant", constant_values=0
                    )
            decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim)
            # Explicit check (not `assert`) so it also runs under `python -O`.
            if decoded_dct_coeff.shape != (time_horizon, action_dim):
                raise ValueError(
                    f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})"
                )
        except Exception as e:
            # Best effort: a malformed sequence yields zero coefficients instead of failing the batch.
            logger.warning("Error decoding tokens: %s. Tokens: %s", e, token)
            decoded_dct_coeff = np.zeros((time_horizon, action_dim))
        decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho"))
    return np.stack(decoded_actions)
def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor:
    """
    Extracts actions from predicted output tokens using the FAST model.

    Args:
        tokens (torch.Tensor): Generated PaliGemma token ids, one sequence per row.
        action_horizon (int): The number of timesteps for actions.
        action_dim (int): The dimensionality of each action.

    Returns:
        torch.Tensor: Decoded actions of shape (batch_size, action_horizon, action_dim), one entry
        per row of ``tokens``, placed on ``tokens.device``.
    """
    # Decode predicted output tokens
    decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True)
    # Drop the "Action:" marker and any colons, then keep the text before the first "|".
    cleaned_tokens = [
        tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip()
        for tokens_sequence in decoded_tokens
    ]
    # Re-encode without special tokens: a BOS id is not an action token and must not be mapped to
    # a FAST id below.
    raw_action_tokens = [
        self.processor.tokenizer.encode(
            sample_tokens, add_special_tokens=False, return_tensors="pt", padding=False
        )
        for sample_tokens in cleaned_tokens
    ]
    # The id mapping is its own inverse, so this converts PaliGemma ids back to FAST ids.
    action_tokens = [
        self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens
    ]
    # One (action_horizon, action_dim) tensor per sample.
    decoded_actions = [
        torch.tensor(
            self.decode_actions_with_fast(
                tok.tolist(),
                time_horizon=action_horizon,
                action_dim=action_dim,
                relaxed_decoding=self.config.relaxed_action_decoding,
            ),
            device=tokens.device,
        ).squeeze(0)
        for tok in action_tokens
    ]

    return torch.stack(decoded_actions, dim=0)
def generate_actions(self, batch: dict[str, Tensor]):
    """Greedily decode FAST action tokens for ``batch`` and convert them into actions.

    Returns the result of ``extract_actions`` for ``self.action_horizon`` x ``self.action_dim``.
    """
    # TODO: keep like this or move to the policy .forward
    images, img_masks = self.prepare_images(batch)
    prompt = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None)

    embs, pad_masks, _att_masks, _targets, _loss_mask, token_type_ids = self.embed_inputs(
        images,
        img_masks,
        prompt["input_ids"],
        prompt["padded_mask"],
        prompt["attention_mask"],
        prompt["loss_mask"],
        prompt["token_type_ids"],
        # Generation always pads on the left (training uses self.padding_side).
        padding_side="left",
    )

    output_tokens = self.pi0_paligemma.generate(
        input_ids=None,
        attention_mask=pad_masks,
        position_ids=torch.cumsum(pad_masks, dim=1) - 1,
        past_key_values=None,
        inputs_embeds=embs,
        use_cache=self.config.use_cache,
        max_new_tokens=self.config.max_decoding_steps,
        do_sample=False,
        num_beams=1,
        token_type_ids=token_type_ids.to(dtype=torch.int64),
    )
    return self.extract_actions(output_tokens, self.action_horizon, self.action_dim)
def embed_image(self, image: torch.Tensor):
    """Return the PaliGemma image features for ``image``.

    This delegates directly to ``self.pi0_paligemma.get_image_features``; no
    reshaping or device movement happens here.
    """
    backbone = self.pi0_paligemma
    features = backbone.get_image_features(image)
    return features
def embed_inputs(
    self,
    images,
    img_masks,
    tokens,
    pad_mask,
    ar_mask,
    loss_mask,
    token_type_ids,
    padding_side: str = "right",
):
    """Build the joint image + text prefix for the PaliGemma backbone.

    Images are embedded with ``self.embed_image`` and prepended to the token
    embeddings from ``self.embed_tokens``. Every per-token tensor is extended with
    entries for the image positions, and pads are then moved to ``padding_side``
    by ``self.shift_padding_side``.

    Args:
        images: list of ``n`` image tensors, each ``(b, c, h, w)``. They are
            stacked along a new camera axis, so they must share shape and device.
        img_masks: list of ``n`` per-sample masks, each of shape ``(b,)``
            (presumably bool; TODO confirm).
        tokens: ``(b, t)`` token ids.
        pad_mask: ``(b, t)`` tensor aligned with ``tokens``.
        ar_mask: ``(b, t)`` tensor aligned with ``tokens``.
        loss_mask: ``(b, t)`` tensor aligned with ``tokens``.
        token_type_ids: ``(b, t)`` tensor aligned with ``tokens``.
        padding_side: forwarded to ``shift_padding_side``. Use ``"left"`` for
            ``.generate()`` and ``"right"`` for training.

    Returns:
        Tuple ``(embs, pad_masks, att_masks, targets, loss_masks,
        token_type_ids)``. Each tensor has ``n * P`` image positions prepended,
        where ``P`` is the number of embeddings per image. Image positions get:
        - the camera's mask value in ``pad_masks``;
        - 0 in ``att_masks``, ``loss_masks`` and ``token_type_ids``;
        - ``self.pad_token_id`` as target.
        Afterwards, every target equal to ``self.pad_token_id`` is replaced by
        ``self.ignore_index``.
    """
    # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
    device = images[0].device

    # Embed every camera of every sample in one call: (b, n, c, h, w) -> (b * n, c, h, w).
    all_images = torch.stack(images, dim=1)
    b, n, c, h, w = all_images.shape
    embedded = self.embed_image(all_images.reshape(b * n, c, h, w)).to(device)
    # (b * n, P, D) -> (b, n * P, D), so each sample's image tokens are contiguous and
    # ordered camera by camera. reshape (not view) tolerates non-contiguous backbone outputs.
    _, num_img_emb, image_embedding_dim = embedded.shape
    embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim)
    num_img_tokens = n * num_img_emb

    tokens = tokens.to(device)
    tokens_embs = self.embed_tokens(tokens)

    # Repeat each camera's per-sample mask for all P embeddings of that camera.
    img_pad_masks = (
        torch.stack(img_masks, dim=1).to(device)[:, :, None].expand(-1, -1, num_img_emb).reshape(b, -1)
    )
    # Shared zeros for the image part of att_masks, loss_masks and token_type_ids.
    # torch.cat copies, so sharing this tensor is safe.
    img_zeros = torch.zeros((b, num_img_tokens), dtype=torch.long, device=device)
    image_target_tokens = torch.full(
        (b, num_img_tokens), self.pad_token_id, dtype=torch.long, device=device
    )

    embs = torch.cat([embedded, tokens_embs], dim=1)
    pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1)
    att_masks = torch.cat([img_zeros, ar_mask.to(device)], dim=1)
    loss_masks = torch.cat([img_zeros, loss_mask.to(device)], dim=1)
    targets = torch.cat([image_target_tokens, tokens], dim=1)
    token_type_ids = torch.cat([img_zeros, token_type_ids.to(device)], dim=1)

    # Shift pad tokens to the left (.generate()) or right (.train())
    embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side(
        embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side
    )
    targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets)
    return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids
def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True):
    """Fit a batch of images inside ``(height, width)`` keeping aspect ratio, then pad.

    All images are scaled by one factor, so the limiting side exactly fills the
    target; this may upscale as well as downscale. The remaining rows and columns
    are filled with ``pad_value`` on the top and left. The resize always runs,
    even when the input already has the target size.

    Args:
        img: ``(b, c, h, w)`` float tensor. With ``interpolate_like_pi`` the values
            are assumed to lie in ``[0, 1]``, because they are quantised to uint8.
        width: target width.
        height: target height.
        pad_value: fill value for the padded pixels.
        interpolate_like_pi: if True, resize each image with PIL bilinear
            resampling on uint8 data (presumably to match openpi's preprocessing).
            If False, use ``F.interpolate`` on the float tensor.

    Returns:
        ``(b, c, height, width)`` float tensor on the input's device.

    Raises:
        ValueError: if ``img`` is not 4-D.
    """
    if img.ndim != 4:
        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")

    cur_height, cur_width = img.shape[2:]
    # One scale factor preserves the aspect ratio; the larger ratio decides which side fills the box.
    ratio = max(cur_width / width, cur_height / height)
    # Keep at least one pixel per side: very elongated inputs would otherwise round a side
    # down to 0, which both PIL and F.interpolate reject.
    resized_height = max(1, int(cur_height / ratio))
    resized_width = max(1, int(cur_width / ratio))

    if interpolate_like_pi:
        original_device = img.device
        # Clamp before casting: PyTorch leaves out-of-range float -> uint8 casts undefined
        # (in practice they wrap), so pixels slightly above 1.0 could turn black.
        img_uint8 = (img * 255.0).clamp(0.0, 255.0).to(dtype=torch.uint8)
        hwc = img_uint8.permute(0, 2, 3, 1).to(device="cpu").numpy()
        resized = []
        for sub_img in hwc:
            # PIL cannot build an image from an (h, w, 1) array; use a 2-D ("L") image instead.
            single_channel = sub_img.shape[-1] == 1
            pil_img = Image.fromarray(sub_img[..., 0] if single_channel else sub_img)
            out = np.array(pil_img.resize((resized_width, resized_height), resample=2))  # 2 == PIL bilinear
            if out.ndim == 2:
                out = out[..., None]
            resized.append(torch.from_numpy(out))
        resized_img = torch.stack(resized, dim=0).permute(0, 3, 1, 2)
        resized_img = resized_img.to(device=original_device, dtype=torch.float32) / 255.0
    else:
        resized_img = F.interpolate(
            img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
        )

    pad_height = max(0, int(height - resized_height))
    pad_width = max(0, int(width - resized_width))
    # F.pad takes (left, right, top, bottom) for the last two dims: pad on left and top only.
    return F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)