"""
modeling_prismatic.py
Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the
logic in `prismatic.models.vlms.prismatic.py`.
Note =>> for the time being, not adding the custom HF "docstring" formatting.
References [LLaVa, IDEFICS-2]:
=> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
=> https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
"""
import logging
from dataclasses import dataclass
from functools import cached_property, partial
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
# from barrel.components.nn.layers.nerf_pos_embed import NeRFPositionalEmbedding
import numpy as np
import timm
import tokenizers
import torch
import torch.nn as nn
import transformers
from timm.models.vision_transformer import LayerScale
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
import collections
import math
from barrel.pipes.vlams.extern.prismatic_config import OpenVLAConfig, PrismaticConfig, TrajectoryVLAConfig, WaypointTokenizer
# from barrel.pipes.vlams.models.control.token_proj import TokenProjector
from barrel.pipes.vlams.extern.datatypes import *
from barrel.pipes.vlams.extern.detr import *
from IPython import embed
import os
from PIL import Image
from pathlib import Path
from torch.amp.autocast_mode import autocast # Corrected import for latest PyTorch
from scipy.spatial.transform import Rotation as R
hf_token_path = Path(".hf_token")
HF_TOKEN = hf_token_path.read_text().strip() if hf_token_path.exists() else os.environ.get("HF_TOKEN", "")
# Get Logger
logger = logging.getLogger(__name__)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
IGNORE_INDEX = -100
# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
result = fn(*args, **kwargs)
return result[0] if isinstance(result, tuple) else result
return wrapper
# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
def ls_apply_patch(ls_module: LayerScale):
ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
del ls_module.gamma
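# Illustrative usage sketch (not called in this file) =>> walking a vision backbone and patching every TIMM
# `LayerScale` module so its `gamma` parameter is renamed to `scale_factor` before the HF parameter-renaming
# logic can clobber it:
#
#   for module in vision_backbone.modules():
#       if isinstance(module, LayerScale):
#           ls_apply_patch(module)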
# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
class PrismaticVisionBackbone(nn.Module):
def __init__(
self,
use_fused_vision_backbone: bool,
image_sizes: List[int],
timm_model_ids: List[str],
timm_override_act_layers: List[Optional[str]],
) -> None:
super().__init__()
self.use_fused_vision_backbone = use_fused_vision_backbone
        # [Contract] Validate number of (fused) vision backbones, then instantiate the "alpha" (DINO) featurizer
        #   =>> Note :: Monkey-Patch the `forward()` function of each backbone to ensure FSDP-compatibility
        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
self.dino_featurizer = timm.create_model(
timm_model_ids[0],
pretrained=True,
num_classes=0,
img_size=image_sizes[0],
act_layer=timm_override_act_layers[0],
)
self.dino_featurizer.eval()
self.embed_dim = self.dino_featurizer.embed_dim
        # Create the "beta" (SigLIP) featurizer
        #   =>> Note :: instantiated unconditionally in this port, though only used when `use_fused_vision_backbone`
        self.siglip_featurizer = timm.create_model(
            timm_model_ids[1],
            pretrained=True,
            num_classes=0,
            img_size=image_sizes[1],
            act_layer=timm_override_act_layers[1],
        )
self.siglip_featurizer.eval()
self.dino_featurizer.forward = partial(
self.dino_featurizer.forward_intermediates,
indices=[len(self.dino_featurizer.blocks) - 2],
return_prefix_tokens=False,
norm=False,
stop_early=True,
output_fmt='NLC',
intermediates_only=True,
)
self.siglip_featurizer.forward = partial(
self.siglip_featurizer.forward_intermediates,
indices=[len(self.siglip_featurizer.blocks) - 2],
return_prefix_tokens=False,
norm=False,
stop_early=True,
output_fmt='NLC',
intermediates_only=True,
)
self.embed_dim += self.siglip_featurizer.embed_dim
    def forward(self, pixel_values) -> torch.Tensor:
        """Run image (`pixel_values`) through featurizer; if fused, featurize each view and concatenate embeddings."""
        if not self.use_fused_vision_backbone:
            return self.dino_featurizer(pixel_values)[0]

        # `pixel_values` is a dict of per-backbone image tensors :: {"dino": [bsz, 3, res, res], "siglip": [bsz, 3, res, res]}
        img = pixel_values['dino']
        img_fused = pixel_values['siglip']
        patches, patches_fused = self.dino_featurizer(img)[0], self.siglip_featurizer(img_fused)[0]
        return torch.cat([patches, patches_fused], dim=2)
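# Illustrative example (shapes assume 224px inputs with 14px patches =>> 16 x 16 = 256 patch tokens per view):
#
#   pixel_values = {"dino": torch.randn(1, 3, 224, 224), "siglip": torch.randn(1, 3, 224, 224)}
#   patches = vision_backbone(pixel_values)   # [1, 256, vision_backbone.embed_dim] (DINO + SigLIP dims concatenated)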
class PrismaticProjector(nn.Module):
def __init__(self, use_fused_vision_backbone, vision_dim: int, llm_dim: int) -> None:
super().__init__()
self.initial_projection_dim = vision_dim * 4
self.projector = torch.nn.Sequential(
torch.nn.Linear(vision_dim, self.initial_projection_dim, bias=True),
torch.nn.GELU(),
torch.nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
torch.nn.GELU(),
torch.nn.Linear(llm_dim, llm_dim, bias=True),
)
def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
return self.projector(fused_img_patches)
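# Illustrative example =>> the projector maps each patch token from the fused vision dim to the LLM hidden dim;
# e.g. assuming DINOv2-L (1024) + SigLIP (1152) => vision_dim = 2176 and a Llama-2 7B llm_dim = 4096:
#
#   projector = PrismaticProjector(True, vision_dim=2176, llm_dim=4096)
#   projected = projector(torch.randn(1, 256, 2176))   # [1, 256, 4096]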
# === Main HF Class Definitions ===
@dataclass
class PrismaticCausalLMOutputWithPast(ModelOutput):
"""Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features."""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Additions for VLMs
projector_features: Optional[torch.FloatTensor] = None
class PrismaticPreTrainedModel(PreTrainedModel):
    config_class = PrismaticConfig
base_model_prefix: str = "model"
supports_gradient_checkpointing: bool = True
_no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
_skip_keys_device_placement: str = "past_key_values"
_supports_flash_attn_2: bool = True
def _init_weights(self, module: nn.Module) -> None:
# Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
# => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
# https://github.com/TRI-ML/prismatic-vlms
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@property
def _supports_sdpa(self) -> bool:
"""Check LLM supports SDPA Attention"""
        return self.llm_backbone.llm._supports_sdpa
class LLMBackbone(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
        self.llm: AutoModelForCausalLM  # populated by the caller (see `PrismaticForConditionalGeneration.__init__`)
self.tokenizer = self._create_tokenizer()
def _create_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
# Load (Fast) Tokenizer
print(f"Loading (Fast) Tokenizer via the AutoTokenizer API")
tokenizer = transformers.AutoTokenizer.from_pretrained(
self.config['hf_model_id'],
model_max_length=self.config['llm_max_length'],
token=HF_TOKEN,
padding_side="right",
)
# Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
# starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
# find that adding image patches *after* the BOS leads to much better performance.
#
# As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
# line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
# override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
# and VLM `forward()` logic!
SPECIAL_CASES = {
# Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
# =>> We'll prepend BOS to first input (to play nicely with image token insertion logic; verified that
# this works well with base LLM generation.
# =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
"microsoft/phi-2",
}
if self.config['hf_model_id'] not in SPECIAL_CASES:
            # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast` ==> includes Mistral!)
assert (
tokenizer("Test 123", add_special_tokens=True).input_ids[0] == tokenizer.bos_token_id
) and (
tokenizer("Test 123", add_special_tokens=False).input_ids[0] != tokenizer.bos_token_id
), f"Default Tokenizer of type `{type(tokenizer)}` does not automatically prefix inputs with BOS token!\n"
return tokenizer
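# Illustrative example =>> the behavior validated above, e.g. for a Llama-2 style `LlamaTokenizerFast`:
#
#   tokenizer("Test 123", add_special_tokens=True).input_ids[0] == tokenizer.bos_token_id    # True  (BOS prepended)
#   tokenizer("Test 123", add_special_tokens=False).input_ids[0] != tokenizer.bos_token_id   # True  (no BOS)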
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
def __init__(self, config: PrismaticConfig) -> None:
super().__init__(config)
# [Validation] Lightweight Validate on `config` Fields + Dependency Versions
if config.use_fused_vision_backbone is None:
raise ValueError("Missing config field `use_fused_vision_backbone`")
# if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
# raise NotImplementedError(
# "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
# "if you urgently need support for latest TIMM versions."
# )
# if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
# logger.warning(
# f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
# f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
# f"there might be inference-time regressions due to dependency changes. If in doubt, please"
# f"use the above versions."
# )
# Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
self.vision_backbone = PrismaticVisionBackbone(
config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
)
# Create Multimodal Projector
self.projector = PrismaticProjector(
config.use_fused_vision_backbone,
vision_dim=self.vision_backbone.embed_dim,
llm_dim=config.text_config.hidden_size,
)
# Instantiate LLM Backbone
        self.llm_backbone = LLMBackbone(
            {"hf_model_id": config.hf_llm_id, "llm_max_length": config.llm_max_length,
             "pad_token_id": 32000, "pad_to_multiple_of": 64}
        )
# self.llm_backbone.llm = AutoModelForCausalLM.from_config(
# config.text_config, attn_implementation="flash_attention_2"
# )
self.llm_backbone.llm = AutoModelForCausalLM.from_pretrained(
'meta-llama/Llama-2-7b-hf',
token=HF_TOKEN,
attn_implementation='flash_attention_2',
# The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding!
do_sample=False,
temperature=1.0,
use_cache=False,
top_p=1.0, )
self.llm_backbone.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
self.llm_backbone.llm.config.pad_token_id = self.llm_backbone.tokenizer.pad_token_id
self.llm_backbone.llm.resize_token_embeddings(len(self.llm_backbone.tokenizer), pad_to_multiple_of=64)
# self.llm_backbone.llm.config.pad_token_id = self.llm_backbone.tokenizer.pad_token_id
# self.llm_backbone.llm.resize_token_embeddings(len(self.llm_backbone.tokenizer), pad_to_multiple_of=64)
# self.resize_token_embeddings(32001,64)
self.vocab_size = config.text_config.vocab_size
self.pad_token_id = config.pad_token_id
# HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
self.post_init()
# === `PreTrainedModel` Boilerplate ===
def get_input_embeddings(self) -> nn.Module:
return self.llm_backbone.llm.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module) -> None:
self.llm_backbone.llm.set_input_embeddings(value)
def get_output_embeddings(self) -> nn.Module:
return self.llm_backbone.llm.get_output_embeddings()
def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
self.llm_backbone.llm.set_output_embeddings(new_embeddings)
def get_decoder(self) -> nn.Module:
return self.llm_backbone.llm.get_decoder()
def set_decoder(self, decoder: nn.Module) -> None:
self.llm_backbone.llm.set_decoder(decoder)
def tie_weights(self) -> None:
self.llm_backbone.llm.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
# def resize_token_embeddings(
# self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
# ) -> nn.Embedding:
# updated_embeddings = self.llm_backbone.llm.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
# # Update config/instance variables
# self.config.text_config.vocab_size = updated_embeddings.num_embeddings
# self.vocab_size = updated_embeddings.num_embeddings
# return updated_embeddings
# === Core Prismatic VLM `forward()` Logic ===
def forward(
self,
input_ids: Optional[torch.LongTensor] ,
attention_mask: Optional[torch.Tensor],
        pixel_values: Optional[Dict[str, torch.Tensor]] = None,
labels: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_projector_features: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Any,
) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
"""Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_projector_features = output_projector_features if output_projector_features is not None else False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
use_cache = use_cache and not self.training
        # Instantiate Placeholders for Projector Features / raw ViT patches / fused attention mask
        # (set in the multimodal branch; returned alongside the language model output below)
        projected_patch_embeddings = None
        patch_features = None
        multimodal_attention_mask = None
# Note :: We only support forward passes with the following cases:
# => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
# => Unimodal Forward :: (pixel_values is None)
# => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
# === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
if input_ids.shape[1] == 1:
assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
assert labels is None, "Unexpected key `labels` provided during cached generation!"
language_model_output = self.llm_backbone.llm(
input_ids=input_ids,
attention_mask=None,
position_ids=None,
past_key_values=past_key_values,
inputs_embeds=None,
labels=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# === Handle Unimodal Forward ===
elif pixel_values is None:
assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
language_model_output = self.llm_backbone.llm(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# === Handle Multimodal Forward ===
elif (input_ids.shape[0] == pixel_values['dino'].shape[0]) or (inputs_embeds.shape[0] == pixel_values['dino'].shape[0]):
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
# Visual Feature Extraction
patch_features = self.vision_backbone(pixel_values)
projected_patch_embeddings = self.projector(patch_features) ## matches
projected_patch_attention_mask = None
if attention_mask is not None:
projected_patch_attention_mask = torch.full(
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
fill_value=True,
dtype=attention_mask.dtype,
device=attention_mask.device,
)
# Get Input Embeddings (from Language Model Embeddings)
input_embeddings = self.get_input_embeddings()(input_ids)
# Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
multimodal_embeddings = torch.cat(
[input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
)
multimodal_attention_mask = None
if attention_mask is not None:
multimodal_attention_mask = torch.cat(
[attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
)
# Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
multimodal_labels = None
if labels is not None:
projected_patch_labels = torch.full(
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
fill_value=IGNORE_INDEX,
dtype=labels.dtype,
device=labels.device,
)
multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
# Dispatch to Language Model
language_model_output = self.llm_backbone.llm(
input_ids=None,
attention_mask=multimodal_attention_mask,
position_ids=None,
past_key_values=None,
inputs_embeds=multimodal_embeddings,
labels=multimodal_labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# === Otherwise =>> Assume Invalid! ===
        elif (input_ids.shape[0] != pixel_values["dino"].shape[0]) or (
            inputs_embeds is not None and inputs_embeds.shape[0] != pixel_values["dino"].shape[0]
        ):
raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
else:
raise ValueError(
"Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
f"=> `input_ids` = {input_ids is not None}\n"
f"=> `attention_mask` = {attention_mask is not None}\n"
f"=> `pixel_values` = {pixel_values is not None}\n"
f"=> `labels` = {labels is not None}\n"
f"=> `input_embeds` = {inputs_embeds is not None}\n"
f"=> `past_key_values` = {past_key_values is not None}\n"
f"=> `use_cache` = {use_cache}"
)
# Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
if not return_dict:
if output_projector_features and (projected_patch_embeddings is not None):
return *language_model_output, projected_patch_embeddings
return language_model_output
        return (
            PrismaticCausalLMOutputWithPast(
                loss=language_model_output.loss,
                logits=language_model_output.logits,
                past_key_values=language_model_output.past_key_values,
                hidden_states=language_model_output.hidden_states,
                attentions=language_model_output.attentions,
                projector_features=projected_patch_embeddings,
            ),
            patch_features,
            multimodal_attention_mask,
        )
# === GenerationMixin Methods ===
def prepare_inputs_for_generation(
self,
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: str,
) -> Dict[str, torch.Tensor]:
"""Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
(inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
):
raise ValueError("Generation with batch size > 1 is not currently supported!")
# Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# If `input_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"input_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# Make sure `pixel_values` are preserved in `model_inputs`
model_inputs.update(
{
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
}
)
return model_inputs
# Defer to Language Model (all handle this differently, with different return types)
def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.llm_backbone.llm._reorder_cache(*args, **kwargs)
class TokenProjectorConfig(PretrainedConfig):
vit_tokens_layers: List[int] = [] # If empty, torch.nn.Identity
llm_image_tokens_layers: List[int] = [] # If empty, torch.nn.Identity
control_tokens_layers: List[int] = [] # If empty, torch.nn.Identity
# image_tokens_mode:
# vit: use ViT tokens only
# llm: use LLM tokens only
# skip: skip connection between projector(ViT) and LLM with addition
# none: don't feed to TokenProjector
image_tokens_mode: str
def __post_init__(self):
super().__post_init__()
if self.image_tokens_mode == 'vit':
assert len(self.vit_tokens_layers) > 0 or len(self.control_tokens_layers) > 0
elif self.image_tokens_mode == 'llm':
assert len(self.vit_tokens_layers) > 0 or len(self.control_tokens_layers) > 0
elif self.image_tokens_mode == 'skip':
assert len(self.vit_tokens_layers) > 0 or len(self.llm_image_tokens_layers) > 0
elif self.image_tokens_mode == 'none':
assert len(self.vit_tokens_layers) == 0
assert len(self.llm_image_tokens_layers) == 0
else:
raise NotImplementedError(f"Unknown image tokens mode {self.image_tokens_mode}")
class TokenProjector(nn.Module):
"""Project and pack VLM output tokens"""
def __init__(self, config):
super().__init__()
self.config = TokenProjectorConfig()
self.config.vit_tokens_layers = config['vit_tokens_layers']
self.config.llm_image_tokens_layers = config['llm_image_tokens_layers']
self.config.control_tokens_layers = config['control_tokens_layers']
self.config.image_tokens_mode = config['image_tokens_mode']
self.vit_tokens_proj = self._make_token_proj_module(self.config.vit_tokens_layers)
self.llm_image_tokens_proj = self._make_token_proj_module(self.config.llm_image_tokens_layers)
self.control_tokens_proj = self._make_token_proj_module(self.config.control_tokens_layers)
def forward(self, inputs: WaypointerInput) -> torch.Tensor:
"""
Args:
inputs: Contains VLM outputs
Returns:
torch.Tensor of shape [B, num_tokens, token_size] that always contains the control tokens
and possibly the image tokens (prepended), depending on the configuration
"""
vit_tokens = self.vit_tokens_proj(inputs.vit_tokens)
control_tokens = self.control_tokens_proj(inputs.control_tokens)
llm_image_tokens = self.llm_image_tokens_proj(inputs.llm_image_tokens)
if self.config.image_tokens_mode == 'vit':
output = torch.cat([vit_tokens, control_tokens], dim=1) # [B, img + control, token_size]
elif self.config.image_tokens_mode == 'llm':
output = torch.cat([llm_image_tokens, control_tokens], dim=1) # [B, img + control, token_size]
elif self.config.image_tokens_mode == 'skip':
image_tokens = llm_image_tokens + vit_tokens
output = torch.cat([image_tokens, control_tokens], dim=1) # [B, img + control, token_size]
elif self.config.image_tokens_mode == 'none':
output = control_tokens
else:
raise NotImplementedError(f"Unknown image tokens mode {self.config.image_tokens_mode}")
return output
def _make_token_proj_module(self, layer_sizes: List[int]) -> torch.nn.Module:
if len(layer_sizes) == 0:
return torch.nn.Identity()
assert len(layer_sizes) > 1, "Need to provide input and output layer sizes at least"
module = torch.nn.Sequential(
*[
torch.nn.Sequential(
collections.OrderedDict(
{
'linear': torch.nn.Linear(layer_in_features, layer_out_features),
'act': torch.nn.ReLU(),
'norm': torch.nn.LayerNorm(layer_out_features),
}
)
)
for layer_in_features, layer_out_features in zip(layer_sizes[:-1], layer_sizes[1:])
]
)
return module
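# Illustrative example =>> `layer_sizes=[4096, 2048, 1024]` builds two (Linear -> ReLU -> LayerNorm) stages,
# mapping LLM-sized control tokens (4096) down to the waypointer token size (1024); values mirror the config
# used in `__main__` below:
#
#   proj = TokenProjector({"vit_tokens_layers": [2176, 1024], "llm_image_tokens_layers": [],
#                          "control_tokens_layers": [4096, 2048, 1024], "image_tokens_mode": "vit"})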
class NeRFPositionalEmbedding(torch.nn.Module):
def __init__(self, proj_scale: int):
"""
Args:
proj_scale: Dimension size, same as L parameter in the NeRF paper
"""
super().__init__()
self.proj_scale = proj_scale
freq = 2 ** torch.arange(self.proj_scale, dtype=torch.float32) * math.pi # size: [L]
self.register_buffer('freq', freq)
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Maps values from R^N to the higher-dimensional space R^(N * 2L)
        Args:
            inputs: torch.Tensor of shape [B, ..., N]; input values to be transformed
        Returns: torch.Tensor of shape [B, ..., N * 2L]; encoded input values
        """
        spectrum = self.freq.view(*[1] * inputs.ndim, -1) * inputs.unsqueeze(-1)  # [B, ..., N, L]
        encoding = torch.stack([torch.sin(spectrum), torch.cos(spectrum)], dim=-2)  # [B, ..., N, 2, L]
        encoding = encoding.flatten(-3)  # [B, ..., N * 2L] (flatten the trailing [N, 2, L] dims)
        return encoding
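# Illustrative example =>> with `proj_scale=8` (L = 8), each scalar input maps to 2 * 8 = 16 features,
# so a [1, 6] tensor of normalized timesteps becomes a [1, 6 * 16] = [1, 96] embedding:
#
#   pos_embed = NeRFPositionalEmbedding(proj_scale=8)
#   encoded = pos_embed(torch.linspace(0, 1, 6).view(1, 6))   # [1, 96]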
class TimestepProjModuleConfig(PretrainedConfig):
pos_embed_scale: int # How much to scale timestep values when doing position embedding
proj_layers: List[int]
time_delta_sec: float = 0.25 # Time delta between two predictions
num_tokens: int = 3 # Number of tokens per timestep; Currently 3 - translation, rotation, gripper
class TimestepProjModule(nn.Module):
def __init__(self, config: TimestepProjModuleConfig, num_timesteps: int, token_size: int):
"""
Args:
num_timesteps: Number of control timesteps
token_size: Single token size
"""
super().__init__()
self.config = TimestepProjModuleConfig()
self.config.pos_embed_scale = config['pos_embed_scale']
self.config.proj_layers = config['proj_layers']
self.config.time_delta_sec = config['time_delta_sec']
self.config.num_tokens = config['num_tokens']
self.num_timesteps = num_timesteps
self.token_size = token_size
input_size = 2 * self.config.pos_embed_scale
self.pos_embed = NeRFPositionalEmbedding(self.config.pos_embed_scale)
# We output one token for translation, one for rotation and one for gripper state
feature_size = self.config.num_tokens * self.token_size
# Make MLP projection
self.timestep_proj = self._make_timestep_proj(in_features=int(input_size), out_features=int(feature_size))
def _make_timestep_proj(self, in_features: int, out_features: int) -> torch.nn.Module:
layer_sizes = [in_features] + list(self.config.proj_layers) + [out_features]
module = torch.nn.Sequential(
*[
torch.nn.Sequential(
collections.OrderedDict(
{
'linear': torch.nn.Linear(layer_in_features, layer_out_features),
'act': torch.nn.ReLU(),
'norm': torch.nn.LayerNorm(layer_out_features),
}
)
)
for layer_in_features, layer_out_features in zip(layer_sizes[:-1], layer_sizes[1:])
]
)
return module
def forward(self) -> torch.Tensor:
"""
Returns:
torch.Tensor of sequence of timestep tokens, shape [1, num_timesteps * num_tokens, token_size]
"""
device = self.timestep_proj[0].linear.weight.device # type: ignore[index]
# Position encode timesteps
time_deltas_norm = self.time_deltas_norm.view(1, self.num_timesteps) # [1, num_timesteps]
time_deltas_norm = time_deltas_norm.to(device=device)
# Embed timesteps to intermediate dimension
timesteps_embed = self.pos_embed(time_deltas_norm) # [1, num_timesteps * 2 * L]
timesteps_embed = timesteps_embed.view(self.num_timesteps, -1) # [num_timesteps, 2 * L]
# Project the timesteps via MLP to tokens
timesteps_tokens = self.timestep_proj(timesteps_embed) # [num_timesteps, token_size * 3]
# Reshape MLP outputs into tokens
timesteps_tokens = timesteps_tokens.view( # [1, num_timesteps * 3, token_size]
1, self.num_timesteps * self.config.num_tokens, self.token_size
)
return timesteps_tokens
@cached_property
def time_deltas_sec(self) -> torch.Tensor:
return torch.arange(0, self.num_timesteps, 1, dtype=torch.float32) * self.config.time_delta_sec
@cached_property
def time_deltas_norm(self) -> torch.Tensor:
# Normalize time deltas between [0, 1]. We are saving [-1, 0] interval for possible past supervision
if self.time_deltas_sec.shape[0] == 1:
# Can't divide by 0
time_deltas_norm = self.time_deltas_sec
else:
time_deltas_norm = self.time_deltas_sec / self.time_deltas_sec.max() # [num_timesteps]
return time_deltas_norm.detach()
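# Illustrative example =>> with `num_timesteps=6` and `time_delta_sec=0.1` (the values used in `__main__` below):
#   time_deltas_sec  = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
#   time_deltas_norm = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]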
# class Waypointer(nn.Module):
class TrajectoryVLA(PrismaticForConditionalGeneration):
# class TrajectoryVLA(nn.Module):
config_class: PretrainedConfig = TrajectoryVLAConfig
def __init__(self, config: TrajectoryVLAConfig) -> None:
super().__init__(config.prismatic_config)
self.control_tokenizer = WaypointTokenizer(self.llm_backbone.tokenizer)
        self.timestep_proj = TimestepProjModule(
            config.timestep_proj_config,
            num_timesteps=config.num_timesteps,
            token_size=config.token_size,
        )
self.num_timesteps = config.num_timesteps
self.token_proj = TokenProjector(config.token_proj_config)
self.transformer = DETR(config.transformer_config)
self.token_size = config.token_size
self.rotation_components = config.rotation_components
# if self.config.separate_control_proj:
# Project translation, rotation and gripper separately. Each timestep is projected separately
self.translation_proj = torch.nn.Sequential(
torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2),
torch.nn.ReLU(),
torch.nn.Linear(in_features=config.token_size // 2, out_features=3),
)
self.rotation_proj = torch.nn.Sequential(
torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2),
torch.nn.ReLU(),
torch.nn.Linear(
in_features=config.token_size // 2, out_features=config.rotation_components
),
)
self.gripper_proj = torch.nn.Sequential(
torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2),
torch.nn.ReLU(),
torch.nn.Linear(in_features=config.token_size // 2, out_features=1),
)
    def _pack_waypointer_input(
        self, input_ids: torch.Tensor, vlm_output: PrismaticCausalLMOutputWithPast, vit_tokens, fused_attention_mask
    ) -> WaypointerInput:
# Get the LLM output
# assert vlm_output.llm_output.hidden_states is not None
projected_tokens = vlm_output.hidden_states[-1]
control_tokens = self._extract_control_tokens(input_ids, projected_tokens) # type: ignore
num_image_tokens = vit_tokens.shape[1] # type: ignore[union-attr]
# TODO: This assumes a specific position of image tokens in the sequence. Make general
llm_image_tokens = projected_tokens[..., 1 : 1 + num_image_tokens, :]
return WaypointerInput(
vit_tokens=vit_tokens,
llm_image_tokens=llm_image_tokens,
control_tokens=control_tokens,
llm_tokens=projected_tokens,
attn_mask=fused_attention_mask,
)
    def predict_tracks(self, inputs):
        vlm_output, vit_tokens, fused_attention_mask = super().forward(
            **inputs, output_hidden_states=True, output_attentions=True, return_dict=True
        )
        waypointer_input = self._pack_waypointer_input(inputs['input_ids'], vlm_output, vit_tokens, fused_attention_mask)
        waypoint_output = self._waypointer_forward(waypointer_input)
        translation, rotation, gripper = torch.split(
            waypoint_output, [3, self.rotation_components, 1], dim=-1
        )
        translation, rotation, gripper = self.process_output(translation, rotation, gripper)
        return translation, rotation, gripper
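    # Illustrative usage sketch =>> `inputs` mirrors the `forward()` signature (variable names are assumptions):
    #
    #   inputs = {
    #       "input_ids": input_ids,             # [1, S] prompt + control-token placeholders
    #       "attention_mask": attention_mask,   # [1, S]
    #       "pixel_values": {"dino": dino_px, "siglip": siglip_px},   # each [1, 3, 224, 224]
    #   }
    #   with torch.no_grad():
    #       translation, euler_angles, gripper = model.predict_tracks(inputs)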
    def process_output(self, translation, rotation, gripper):
        # Convert each predicted rotation matrix to Euler angles
        euler_angles = []
        for matrix in rotation[0]:
            # Convert each 3x3 rotation matrix to a `Rotation` object
            rotation_obj = R.from_matrix(matrix.view(3, 3).detach().cpu().float().numpy().squeeze())
            # Convert to Euler angles in radians with the chosen convention, e.g. 'xyz'
            euler_angle = rotation_obj.as_euler('xyz', degrees=False)
            euler_angles.append(euler_angle)
        translation = translation.detach().cpu().float().numpy().squeeze()
        # Squash gripper logits through a sigmoid, then round to a binary open/close state
        gripper = np.round(torch.sigmoid(gripper).detach().cpu().float().numpy().squeeze())
        return translation, euler_angles, gripper
def _extract_control_tokens(self, input_ids: torch.Tensor, output_tokens: torch.Tensor) -> torch.Tensor:
"""
Extract the action tokens from the LLM output sequence. Assumes the following order
[image_tokens, language_tokens, action_tokens, padding]
Args:
input_ids: IDs of the tokens in text input sequence; shape [B, S]
output_tokens: Token sequence output from LLM; shape [B, L, token_size]. Note the length is
different from input_ids as it also contains image tokens
Returns:
torch.Tensor of shape [B, 7, token_size] containing only action tokens
"""
assert input_ids.ndim == 2
assert output_tokens.ndim == 3
batch, in_seq_len, out_seq_len = *input_ids.shape, output_tokens.shape[1]
device = input_ids.device
num_control_tokens = self.control_tokenizer.num_control_tokens # type: ignore[attr-defined]
control_token_ids = torch.from_numpy( # type: ignore[attr-defined]
self.control_tokenizer.control_token_ids # type: ignore[attr-defined]
)
control_token_ids = control_token_ids.to(dtype=input_ids.dtype, device=input_ids.device)
is_control_token = torch.any( # shape: [B, S]
input_ids.unsqueeze(-1) == control_token_ids.view(1, 1, -1),
dim=-1,
)
if not torch.all(mask := is_control_token.sum(dim=-1) == num_control_tokens):
raise RuntimeError(
f"Can't properly detect control tokens with ids {control_token_ids} of len="
f"{len(control_token_ids)} in input_ids {input_ids}. Rows mask: {mask}"
)
# Pad is_control_tokens mask to the LLM output sequence size
tokens_mask = torch.cat( # shape: [B, L]
[
torch.zeros(batch, out_seq_len - in_seq_len, dtype=torch.bool, device=device),
is_control_token.to(torch.bool),
],
dim=1,
)
control_tokens = output_tokens[tokens_mask] # shape: 1D tensor
control_tokens = control_tokens.view( # [B, num_control_tokens, token_size]
batch, num_control_tokens, output_tokens.shape[-1]
)
return control_tokens
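    # Illustrative example =>> assuming 7 control tokens appended to the prompt (per the docstring above) and
    # 256 image patches inserted after <BOS>, the LLM output sequence has length L = S + 256, and the boolean
    # mask built above selects exactly the last-layer hidden states at the control-token positions:
    #
    #   control_tokens = self._extract_control_tokens(input_ids, hidden_states[-1])   # [B, 7, 4096] for Llama-2 7B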
def _waypointer_forward(self, inputs:WaypointerInput):
timesteps_tokens = self.timestep_proj() # [1, num_timesteps * 3, token_size]
# Project and pack LLM tokens
llm_tokens = self.token_proj(inputs) # [B, num_tokens, token_size]
# TODO: Pass inputs.attn_mask if you start using the LLM tokens
output_tokens = self.transformer( # [B, num_timesteps * 3, token_size]
feature_tokens=llm_tokens, query_tokens=timesteps_tokens, attn_mask=None
)
output_tokens = output_tokens.view( # [B, num_timesteps, 3 * token_size]
-1, self.num_timesteps, 3 * self.token_size
)
# if self.config.separate_control_proj:
# [B, num_timesteps, token_size] each
translation_tokens, rotation_tokens, gripper_tokens = torch.split(
output_tokens, [self.token_size] * 3, dim=-1
)
translation = self.translation_proj(translation_tokens) # [B, num_timesteps, 3]
rotation = self.rotation_proj(rotation_tokens) # [B, num_timesteps, rotation_components]
gripper = self.gripper_proj(gripper_tokens) # [B, num_timesteps, 1]
output = torch.cat( # [B, num_timesteps, control_components]
[translation, rotation, gripper], dim=-1
)
return output
# def predict_waypoints(self,input_ids: Optional[torch.LongTensor] = None, **kwargs: str) -> np.ndarray:
# vlm_output = super().forward(
# inputs=input_ids,
# use_cache=use_cache,
# output_attentions=output_attentions,
# output_hidden_states=True,
# return_dict=return_dict,
# )
@staticmethod
def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
if unnorm_key is None and len(norm_stats) != 1:
raise ValueError(
f"Your model was trained on more than one dataset. "
f"Please pass a `unnorm_key` from the following options to choose the statistics used for "
f"de-normalizing actions: {norm_stats.keys()}"
)
# If None, grab the (singular) dataset in `norm_stats` to use as `unnorm_key`
unnorm_key = unnorm_key if unnorm_key is not None else next(iter(norm_stats.keys()))
if unnorm_key not in norm_stats:
raise ValueError(
f"The `unnorm_key` you chose ({unnorm_key = }) is not in the available statistics. "
f"Please choose from: {norm_stats.keys()}"
)
return unnorm_key
def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
"""Get the dimensionality of the policy's action space."""
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
return len(self.norm_stats[unnorm_key]["action"]["q01"])
def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
"""Get all the logged statistics for the given dataset."""
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
return self.norm_stats[unnorm_key]["action"]
def remove_waypointer_prefix(ckpt):
new_state_dict = {}
for key, value in ckpt.items():
# Remove the 'waypointer.' prefix if it exists
if key.startswith('waypointer.'):
new_key = key[len('waypointer.'):]
else:
new_key = key
new_state_dict[new_key] = value
return new_state_dict
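# Illustrative example:
#   remove_waypointer_prefix({"waypointer.transformer.weight": w, "projector.bias": b})
#   => {"transformer.weight": w, "projector.bias": b}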
def image_processor(image):
    # Resize to the backbone input resolution (C, H, W) = (3, 224, 224) and return the resized PIL image
    image_resolution = (3, 224, 224)
    return image.resize(image_resolution[1:], resample=Image.Resampling.LANCZOS)
def read_pt(pt_path):
data = torch.load(pt_path)
return data
# model_input = read_pt('/work/nikolay_nikolov/debug/inference/model_input.pt')
# vit_output = read_pt('/work/nikolay_nikolov/debug/inference/vit_output.pt')['vit_output']
# llm_output = read_pt('/work/nikolay_nikolov/debug/inference/llm_output.pt')['llm_output']
# projector_output = read_pt('/work/nikolay_nikolov/debug/inference/projector_output.pt')['projector_output']
# transformer_input = read_pt('/work/nikolay_nikolov/debug/inference/transformer_input.pt')
# feature_tokens = transformer_input['feature_tokens']
# timestep_tokens = transformer_input['timestep_tokens']
# # waypointer_input_nikolay = read_pt('/work/nikolay_nikolov/debug/inference/waypointer_input.pt')
# transformer_input = read_pt('/work/nikolay_nikolov/debug/inference/transformer_input.pt')
# control_target = read_pt('/work/nikolay_nikolov/debug/inference/control_target.pt')
if __name__ == "__main__":
prismatic_config_dict = {
"vision_backbone_id":"dinosiglip-vit-so-224px",
"llm_backbone_id":"llama2-7b-pure",
"arch_specifier": "no-align+gelu-mlp", ## TODO: check
"use_fused_vision_backbone" :True, ## TODO: check
"image_resize_strategy" : "letterbox",
"text_config" : None,
"llm_max_length" : 2048,
"pad_token_id" :32000,
"pad_to_multiple_of" : 64,
"output_projector_states" : False,
"return_dict": False,
}
token_proj_config = {
"vit_tokens_layers": [2176, 1024],
"control_tokens_layers": [4096, 2048, 1024],
"image_tokens_mode": 'vit',
'llm_image_tokens_layers': []
}
timestep_proj_config = {
"pos_embed_scale": 8,
"proj_layers": [128,512,1024],
"time_delta_sec": 0.1,
"num_tokens":3
}
pos_embed_config = {
"num_embeddings": 300,
"embedding_dim": 1024
}
encoder_block_config = {
"feature_size": 1024,
"head_dim": 64,
"num_heads": 16
}
decoder_block_config = {
"feature_size": 1024,
"head_dim": 64,
"num_heads": 16,
"dropout": 0.0
}
transformer_config = {
"pos_embed_config": pos_embed_config,
"encoder_block_config": encoder_block_config,
"decoder_block_config": decoder_block_config,
"num_blocks": 2
}
# transformer_config:
# autoclass: barrel.components.nn.layers.detr.DETR
# pos_embed_config:
# autoclass: barrel.components.nn.layers.positional_encodings.LearnedPosEmbed1D
# num_embeddings: 300 # Max number of input tokens
# embedding_dim: *token_size # token_size
# # num_embeddings: 256 # Number of image tokens
# # embedding_dim: 512 # token_size / 2
# encoder_block_config:
# autoclass: barrel.components.nn.layers.detr.TransformerEncoderBlock
# feature_size: *token_size
# # head_dim: 128
# # num_heads: 8
# head_dim: 64
# num_heads: 16
# decoder_block_config:
# autoclass: barrel.components.nn.layers.detr.TransformerDecoderBlock
# feature_size: *token_size
# # head_dim: 128
# # num_heads: 8
# head_dim: 64
# num_heads: 16
TrajectoryVlaConfig_config = {
"prismatic_config":prismatic_config_dict,
"token_size": 1024,
"cheat": False,
"num_timesteps": 6,
"rotation_components": 9,
"seperate_control_proj": True,
"timestep_proj_config": timestep_proj_config,
"token_proj_config": token_proj_config,
"transformer_config": transformer_config,
"num_timestep_tokens": 3,
}
# ckpt_path = '/work/nikolay_nikolov/debug/inference/model.ckpt'
# ckpt_params = torch.load(ckpt_path, map_location='cpu', mmap= True)
# ckpt_params = remove_waypointer_prefix(ckpt_params)
## Testing for prismatic
    model_config = TrajectoryVLAConfig(**TrajectoryVlaConfig_config)
# model.load_state_dict(ckpt_params, strict=True)
model = TrajectoryVLA(model_config)
model = model.to(dtype=torch.bfloat16)
model = model.to('cuda')
model.eval()
# with autocast('cuda',dtype=torch.bfloat16):
# with torch.no_grad():
# output = model.predict_tracks(model_input)
# Get matched keys by finding keys that exist in both the model and checkpoint
# TrajectoryVLA.load_state_dict(ckpt_params, strict=False)
# model_keys = set(TrajectoryVLA.state_dict().keys())
# checkpoint_keys = set(ckpt_params.keys())
# matched_keys = model_keys.intersection(checkpoint_keys)
# print('Matched Keys:')
# for key in matched_keys:
# print(key)
# embed()
# hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
# hf_processor.push_to_hub(cfg.output_hf_model_hub_path)
# import code; code.interact(local=vars())