|
""" |
|
modeling_prismatic.py |
|
|
|
Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting

from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained, but exactly

replicates the logic in `prismatic.models.vlms.prismatic.py`.
|
|
|
Note =>> for the time being, not adding the custom HF "docstring" formatting. |
|
|
|
References [LLaVa, IDEFICS-2]: |
|
=> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py |
|
=> https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py |
|
""" |
|
|
|
import collections
import logging
import math
import os
from dataclasses import dataclass
from functools import cached_property, partial
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union

import numpy as np
import timm
import tokenizers
import torch
import torch.nn as nn
import transformers
from IPython import embed
from PIL import Image
from scipy.spatial.transform import Rotation as R
from timm.models.vision_transformer import LayerScale
from torch.amp.autocast_mode import autocast
from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ModelOutput

from barrel.pipes.vlams.extern.datatypes import *
from barrel.pipes.vlams.extern.detr import *
from barrel.pipes.vlams.extern.prismatic_config import (
    OpenVLAConfig,
    PrismaticConfig,
    TrajectoryVLAConfig,
    WaypointTokenizer,
)
|
# Read the HuggingFace access token from `.hf_token` if it exists; otherwise fall back to the environment.
hf_token_path = Path(".hf_token")
HF_TOKEN = hf_token_path.read_text().strip() if hf_token_path.exists() else os.environ.get("HF_TOKEN")
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
torch.backends.cudnn.benchmark = False |
|
torch.backends.cudnn.deterministic = True |
|
|
|
|
|
IGNORE_INDEX = -100 |
|
|
|
|
|
|
|
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]: |
|
def wrapper(*args: Any, **kwargs: Any) -> Any: |
|
result = fn(*args, **kwargs) |
|
return result[0] if isinstance(result, tuple) else result |
|
|
|
return wrapper |
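# Usage sketch (illustrative; `unpack_tuple` is not called directly in this file): it is convenient for wrapping
# timm's `forward_intermediates`, which can return a tuple, so that a patched `forward` yields just the tensor.
#
#   featurizer.forward = unpack_tuple(partial(featurizer.forward_intermediates, indices=[-2]))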
|
|
|
|
|
|
|
|
|
|
|
# HF Transformers rewrites parameter names containing `gamma`/`beta` when loading state dicts, which clobbers
# timm's `LayerScale.gamma`; patch LayerScale to store its scale under `scale_factor` instead.
def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor


def ls_apply_patch(ls_module: LayerScale):
    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
    del ls_module.gamma
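# Usage sketch (assumes `featurizer` is any timm ViT created via `timm.create_model(...)`):
#
#   for module in featurizer.modules():
#       if isinstance(module, LayerScale):
#           ls_apply_patch(module)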
|
|
|
|
|
|
|
class PrismaticVisionBackbone(nn.Module): |
|
def __init__( |
|
self, |
|
use_fused_vision_backbone: bool, |
|
image_sizes: List[int], |
|
timm_model_ids: List[str], |
|
timm_override_act_layers: List[Optional[str]], |
|
) -> None: |
|
super().__init__() |
|
self.use_fused_vision_backbone = use_fused_vision_backbone |
|
|
|
|
|
|
|
|
|
assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!" |
|
|
|
self.dino_featurizer = timm.create_model( |
|
timm_model_ids[0], |
|
pretrained=True, |
|
num_classes=0, |
|
img_size=image_sizes[0], |
|
act_layer=timm_override_act_layers[0], |
|
) |
|
self.dino_featurizer.eval() |
|
|
|
self.embed_dim = self.dino_featurizer.embed_dim |
|
|
|
|
|
|
|
self.siglip_featurizer = timm.create_model( |
|
timm_model_ids[1], |
|
pretrained=True, |
|
num_classes=0, |
|
img_size=image_sizes[1], |
|
            act_layer=timm_override_act_layers[1],
        )
|
|
|
self.siglip_featurizer.eval() |
|
|
|
self.dino_featurizer.forward = partial( |
|
self.dino_featurizer.forward_intermediates, |
|
indices=[len(self.dino_featurizer.blocks) - 2], |
|
return_prefix_tokens=False, |
|
norm=False, |
|
stop_early=True, |
|
output_fmt='NLC', |
|
intermediates_only=True, |
|
) |
|
self.siglip_featurizer.forward = partial( |
|
self.siglip_featurizer.forward_intermediates, |
|
indices=[len(self.siglip_featurizer.blocks) - 2], |
|
return_prefix_tokens=False, |
|
norm=False, |
|
stop_early=True, |
|
output_fmt='NLC', |
|
intermediates_only=True, |
|
) |
|
self.embed_dim += self.siglip_featurizer.embed_dim |
|
|
|
    def forward(self, pixel_values) -> torch.Tensor:
        """Run image (`pixel_values`) through the featurizer(s); if fused, dispatch to both and concatenate."""
        if not self.use_fused_vision_backbone:
            # Non-fused case: only the first (DINOv2) featurizer is used.
            return self.dino_featurizer(pixel_values)[0]

        # Fused case: run each view through its own featurizer, then concatenate along the embedding dimension.
        img, img_fused = pixel_values['dino'], pixel_values['siglip']
        patches, patches_fused = self.dino_featurizer(img)[0], self.siglip_featurizer(img_fused)[0]

        return torch.cat([patches, patches_fused], dim=2)
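# Usage sketch for the fused backbone (timm IDs and shapes are illustrative of the `dinosiglip-vit-so-224px` setup):
#
#   backbone = PrismaticVisionBackbone(
#       use_fused_vision_backbone=True,
#       image_sizes=[224, 224],
#       timm_model_ids=["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
#       timm_override_act_layers=[None, None],
#   )
#   pixel_values = {"dino": torch.randn(1, 3, 224, 224), "siglip": torch.randn(1, 3, 224, 224)}
#   patches = backbone(pixel_values)   # roughly [1, 256, backbone.embed_dim] with 14x14 patches at 224px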
|
|
|
|
|
|
|
class PrismaticProjector(nn.Module): |
|
    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
|
super().__init__() |
|
self.initial_projection_dim = vision_dim * 4 |
|
self.projector = torch.nn.Sequential( |
|
torch.nn.Linear(vision_dim, self.initial_projection_dim, bias=True), |
|
torch.nn.GELU(), |
|
torch.nn.Linear(self.initial_projection_dim, llm_dim, bias=True), |
|
torch.nn.GELU(), |
|
torch.nn.Linear(llm_dim, llm_dim, bias=True), |
|
) |
|
|
|
def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor: |
|
return self.projector(fused_img_patches) |
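# Shape sketch (sizes are illustrative): with fused DINOv2 + SigLIP features (vision_dim = 1024 + 1152 = 2176) and a
# Llama-2-7B LLM (llm_dim = 4096), the projector maps patch features [B, 256, 2176] -> [B, 256, 4096] via a 3-layer
# GELU MLP whose first hidden width is `4 * vision_dim`.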
|
|
|
|
|
@dataclass |
|
class PrismaticCausalLMOutputWithPast(ModelOutput): |
|
"""Base class for Prismatic casual (visually-conditioned) language model outputs; also exposes visual features.""" |
|
|
|
loss: Optional[torch.FloatTensor] = None |
|
logits: torch.FloatTensor = None |
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
projector_features: Optional[torch.FloatTensor] = None |
|
|
|
|
|
class PrismaticPreTrainedModel(PreTrainedModel): |
|
    config_class = PrismaticConfig
|
base_model_prefix: str = "model" |
|
supports_gradient_checkpointing: bool = True |
|
|
|
_no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"] |
|
_skip_keys_device_placement: str = "past_key_values" |
|
_supports_flash_attn_2: bool = True |
|
|
|
def _init_weights(self, module: nn.Module) -> None: |
|
|
|
|
|
|
|
std = ( |
|
self.config.initializer_range |
|
if hasattr(self.config, "initializer_range") |
|
else self.config.text_config.initializer_range |
|
) |
|
|
|
if hasattr(module, "class_embedding"): |
|
module.class_embedding.data.normal_(mean=0.0, std=std) |
|
|
|
if isinstance(module, (nn.Linear, nn.Conv2d)): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=std) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
@property |
|
def _supports_sdpa(self) -> bool: |
|
"""Check LLM supports SDPA Attention""" |
|
return self.language_model._supports_sdpa |
|
|
|
class LLMBackbone(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
        self.llm: AutoModelForCausalLM  # Populated by the owning module (see `PrismaticForConditionalGeneration.__init__`)
|
self.tokenizer = self._create_tokenizer() |
|
|
|
def _create_tokenizer(self) -> transformers.PreTrainedTokenizerBase: |
|
|
|
print(f"Loading (Fast) Tokenizer via the AutoTokenizer API") |
|
tokenizer = transformers.AutoTokenizer.from_pretrained( |
|
self.config['hf_model_id'], |
|
model_max_length=self.config['llm_max_length'], |
|
token=HF_TOKEN, |
|
padding_side="right", |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Some LLM backbones ship tokenizers that do not automatically prepend a BOS token (e.g., Phi-2's
        # GPT-2-style tokenizer); skip the BOS sanity check for those.
        SPECIAL_CASES = {
            "microsoft/phi-2",
        }
        if self.config['hf_model_id'] not in SPECIAL_CASES:
            # Sanity check: the tokenizer should add BOS when `add_special_tokens=True`, and only then.
            assert (
                tokenizer("Test 123", add_special_tokens=True).input_ids[0] == tokenizer.bos_token_id
            ) and (
                tokenizer("Test 123", add_special_tokens=False).input_ids[0] != tokenizer.bos_token_id
            ), f"Default Tokenizer of type `{type(tokenizer)}` does not automatically prefix inputs with BOS token!\n"
|
|
|
return tokenizer |
|
|
|
class PrismaticForConditionalGeneration(PrismaticPreTrainedModel): |
|
def __init__(self, config: PrismaticConfig) -> None: |
|
super().__init__(config) |
|
|
|
if config.use_fused_vision_backbone is None: |
|
raise ValueError("Missing config field `use_fused_vision_backbone`") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.vision_backbone = PrismaticVisionBackbone( |
|
config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers |
|
) |
|
|
|
|
|
self.projector = PrismaticProjector( |
|
config.use_fused_vision_backbone, |
|
vision_dim=self.vision_backbone.embed_dim, |
|
llm_dim=config.text_config.hidden_size, |
|
) |
|
|
|
|
|
        self.llm_backbone = LLMBackbone(
            {
                'hf_model_id': config.hf_llm_id,
                'llm_max_length': config.llm_max_length,
                'pad_token_id': 32000,
                'pad_to_multiple_of': 64,
            }
        )
|
|
|
|
|
|
|
|
|
        # Load the LLM weights from the same HF checkpoint the tokenizer above was built from.
        self.llm_backbone.llm = AutoModelForCausalLM.from_pretrained(
            config.hf_llm_id,
            token=HF_TOKEN,
            attn_implementation='flash_attention_2',
            do_sample=False,
            temperature=1.0,
            use_cache=False,
            top_p=1.0,
        )
|
|
|
self.llm_backbone.tokenizer.add_special_tokens({"pad_token": "<PAD>"}) |
|
self.llm_backbone.llm.config.pad_token_id = self.llm_backbone.tokenizer.pad_token_id |
|
self.llm_backbone.llm.resize_token_embeddings(len(self.llm_backbone.tokenizer), pad_to_multiple_of=64) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.vocab_size = config.text_config.vocab_size |
|
self.pad_token_id = config.pad_token_id |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self) -> nn.Module: |
|
return self.llm_backbone.llm.get_input_embeddings() |
|
|
|
def set_input_embeddings(self, value: nn.Module) -> None: |
|
self.llm_backbone.llm.set_input_embeddings(value) |
|
|
|
def get_output_embeddings(self) -> nn.Module: |
|
return self.llm_backbone.llm.get_output_embeddings() |
|
|
|
def set_output_embeddings(self, new_embeddings: nn.Module) -> None: |
|
self.llm_backbone.llm.set_output_embeddings(new_embeddings) |
|
|
|
def get_decoder(self) -> nn.Module: |
|
return self.llm_backbone.llm.get_decoder() |
|
|
|
def set_decoder(self, decoder: nn.Module) -> None: |
|
self.llm_backbone.llm.set_decoder(decoder) |
|
|
|
def tie_weights(self) -> None: |
|
self.llm_backbone.llm.tie_weights() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward( |
|
self, |
|
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[Dict[str, torch.Tensor]] = None,
|
labels: Optional[torch.LongTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_projector_features: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
**kwargs: Any, |
|
) -> Union[Tuple, PrismaticCausalLMOutputWithPast]: |
|
"""Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance.""" |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
output_projector_features = output_projector_features if output_projector_features is not None else False |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
use_cache = use_cache and not self.training |
|
|
|
|
|
        projected_patch_embeddings = None
        patch_features, multimodal_attention_mask = None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if input_ids.shape[1] == 1: |
|
assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!" |
|
assert past_key_values is not None, "You must provide `past_key_values` during cached generation!" |
|
assert labels is None, "Unexpected key `labels` provided during cached generation!" |
|
|
|
language_model_output = self.llm_backbone.llm( |
|
input_ids=input_ids, |
|
attention_mask=None, |
|
position_ids=None, |
|
past_key_values=past_key_values, |
|
inputs_embeds=None, |
|
labels=None, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
|
|
elif pixel_values is None: |
|
assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!" |
|
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!" |
|
|
|
language_model_output = self.llm_backbone.llm( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=None, |
|
past_key_values=None, |
|
inputs_embeds=None, |
|
labels=labels, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
|
|
|
|
elif (input_ids.shape[0] == pixel_values['dino'].shape[0]) or (inputs_embeds.shape[0] == pixel_values['dino'].shape[0]): |
|
assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!" |
|
|
|
|
|
patch_features = self.vision_backbone(pixel_values) |
|
|
|
projected_patch_embeddings = self.projector(patch_features) |
|
projected_patch_attention_mask = None |
|
if attention_mask is not None: |
|
projected_patch_attention_mask = torch.full( |
|
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), |
|
fill_value=True, |
|
dtype=attention_mask.dtype, |
|
device=attention_mask.device, |
|
) |
|
|
|
|
|
input_embeddings = self.get_input_embeddings()(input_ids) |
|
|
|
|
|
multimodal_embeddings = torch.cat( |
|
[input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1 |
|
) |
|
multimodal_attention_mask = None |
|
if attention_mask is not None: |
|
multimodal_attention_mask = torch.cat( |
|
[attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1 |
|
) |
|
|
|
|
|
multimodal_labels = None |
|
if labels is not None: |
|
projected_patch_labels = torch.full( |
|
(projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]), |
|
fill_value=IGNORE_INDEX, |
|
dtype=labels.dtype, |
|
device=labels.device, |
|
) |
|
multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1) |
|
|
|
|
|
language_model_output = self.llm_backbone.llm( |
|
input_ids=None, |
|
attention_mask=multimodal_attention_mask, |
|
position_ids=None, |
|
past_key_values=None, |
|
inputs_embeds=multimodal_embeddings, |
|
labels=multimodal_labels, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
|
|
        elif (input_ids.shape[0] != pixel_values['dino'].shape[0]) or (inputs_embeds.shape[0] != pixel_values['dino'].shape[0]):
|
raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!") |
|
|
|
else: |
|
raise ValueError( |
|
"Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n" |
|
f"=> `input_ids` = {input_ids is not None}\n" |
|
f"=> `attention_mask` = {attention_mask is not None}\n" |
|
f"=> `pixel_values` = {pixel_values is not None}\n" |
|
f"=> `labels` = {labels is not None}\n" |
|
f"=> `input_embeds` = {inputs_embeds is not None}\n" |
|
f"=> `past_key_values` = {past_key_values is not None}\n" |
|
f"=> `use_cache` = {use_cache}" |
|
) |
|
|
|
|
|
if not return_dict: |
|
if output_projector_features and (projected_patch_embeddings is not None): |
|
return *language_model_output, projected_patch_embeddings |
|
|
|
return language_model_output |
|
|
|
|
|
        # NOTE: in addition to the standard HF output, also return the raw ViT patch features and the fused
        # attention mask so downstream modules (e.g., `TrajectoryVLA`) can assemble their waypointer inputs.
        return (PrismaticCausalLMOutputWithPast(
|
loss=language_model_output.loss, |
|
logits=language_model_output.logits, |
|
past_key_values=language_model_output.past_key_values, |
|
hidden_states=language_model_output.hidden_states, |
|
attentions=language_model_output.attentions, |
|
projector_features=projected_patch_embeddings, |
|
),patch_features,multimodal_attention_mask) |
|
|
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
**kwargs: str, |
|
) -> Dict[str, torch.Tensor]: |
|
"""Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic.""" |
|
if ((input_ids is not None) and (input_ids.shape[0] > 1)) or ( |
|
(inputs_embeds is not None) and (inputs_embeds.shape[0] > 1) |
|
): |
|
raise ValueError("Generation with batch size > 1 is not currently supported!") |
|
|
|
|
|
if past_key_values is not None: |
|
input_ids = input_ids[:, -1:] |
|
|
|
|
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs = {"input_embeds": inputs_embeds} |
|
else: |
|
model_inputs = {"input_ids": input_ids} |
|
|
|
|
|
model_inputs.update( |
|
{ |
|
"attention_mask": attention_mask, |
|
"pixel_values": pixel_values, |
|
"past_key_values": past_key_values, |
|
"use_cache": kwargs.get("use_cache"), |
|
} |
|
) |
|
|
|
return model_inputs |
|
|
|
|
|
    def _reorder_cache(self, *args, **kwargs) -> Any:
        return self.llm_backbone.llm._reorder_cache(*args, **kwargs)
|
|
|
|
|
class TokenProjectorConfig(PretrainedConfig): |
|
vit_tokens_layers: List[int] = [] |
|
llm_image_tokens_layers: List[int] = [] |
|
control_tokens_layers: List[int] = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
image_tokens_mode: str |
|
|
|
    def __post_init__(self):
        # NOTE: `PretrainedConfig` is not a dataclass, so this validation hook must be invoked explicitly.
        if self.image_tokens_mode == 'vit':
            assert len(self.vit_tokens_layers) > 0 or len(self.control_tokens_layers) > 0
        elif self.image_tokens_mode == 'llm':
            assert len(self.llm_image_tokens_layers) > 0 or len(self.control_tokens_layers) > 0
        elif self.image_tokens_mode == 'skip':
            assert len(self.vit_tokens_layers) > 0 or len(self.llm_image_tokens_layers) > 0
|
elif self.image_tokens_mode == 'none': |
|
assert len(self.vit_tokens_layers) == 0 |
|
assert len(self.llm_image_tokens_layers) == 0 |
|
else: |
|
raise NotImplementedError(f"Unknown image tokens mode {self.image_tokens_mode}") |
|
|
|
class TokenProjector(nn.Module): |
|
"""Project and pack VLM output tokens""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = TokenProjectorConfig() |
|
self.config.vit_tokens_layers = config['vit_tokens_layers'] |
|
self.config.llm_image_tokens_layers = config['llm_image_tokens_layers'] |
|
self.config.control_tokens_layers = config['control_tokens_layers'] |
|
self.config.image_tokens_mode = config['image_tokens_mode'] |
|
|
|
self.vit_tokens_proj = self._make_token_proj_module(self.config.vit_tokens_layers) |
|
self.llm_image_tokens_proj = self._make_token_proj_module(self.config.llm_image_tokens_layers) |
|
self.control_tokens_proj = self._make_token_proj_module(self.config.control_tokens_layers) |
|
|
|
def forward(self, inputs: WaypointerInput) -> torch.Tensor: |
|
""" |
|
Args: |
|
inputs: Contains VLM outputs |
|
Returns: |
|
torch.Tensor of shape [B, num_tokens, token_size] that always contains the control tokens |
|
and possibly the image tokens (prepended), depending on the configuration |
|
""" |
|
|
|
vit_tokens = self.vit_tokens_proj(inputs.vit_tokens) |
|
control_tokens = self.control_tokens_proj(inputs.control_tokens) |
|
llm_image_tokens = self.llm_image_tokens_proj(inputs.llm_image_tokens) |
|
|
|
if self.config.image_tokens_mode == 'vit': |
|
output = torch.cat([vit_tokens, control_tokens], dim=1) |
|
elif self.config.image_tokens_mode == 'llm': |
|
output = torch.cat([llm_image_tokens, control_tokens], dim=1) |
|
elif self.config.image_tokens_mode == 'skip': |
|
image_tokens = llm_image_tokens + vit_tokens |
|
output = torch.cat([image_tokens, control_tokens], dim=1) |
|
elif self.config.image_tokens_mode == 'none': |
|
output = control_tokens |
|
else: |
|
raise NotImplementedError(f"Unknown image tokens mode {self.config.image_tokens_mode}") |
|
|
|
return output |
|
|
|
def _make_token_proj_module(self, layer_sizes: List[int]) -> torch.nn.Module: |
|
if len(layer_sizes) == 0: |
|
return torch.nn.Identity() |
|
|
|
assert len(layer_sizes) > 1, "Need to provide input and output layer sizes at least" |
|
|
|
module = torch.nn.Sequential( |
|
*[ |
|
torch.nn.Sequential( |
|
collections.OrderedDict( |
|
{ |
|
'linear': torch.nn.Linear(layer_in_features, layer_out_features), |
|
'act': torch.nn.ReLU(), |
|
'norm': torch.nn.LayerNorm(layer_out_features), |
|
} |
|
) |
|
) |
|
for layer_in_features, layer_out_features in zip(layer_sizes[:-1], layer_sizes[1:]) |
|
] |
|
) |
|
return module |
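# Shape sketch for the 'vit' mode wired up in `__main__` below (sizes are illustrative):
#   vit_tokens     [B, 256, 2176] --(2176 -> 1024)----------> [B, 256, 1024]
#   control_tokens [B,   7, 4096] --(4096 -> 2048 -> 1024)--> [B,   7, 1024]
#   output = cat(image_tokens, control_tokens)              -> [B, 263, 1024]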
|
|
|
class NeRFPositionalEmbedding(torch.nn.Module): |
|
def __init__(self, proj_scale: int): |
|
""" |
|
Args: |
|
proj_scale: Dimension size, same as L parameter in the NeRF paper |
|
""" |
|
super().__init__() |
|
self.proj_scale = proj_scale |
|
|
|
freq = 2 ** torch.arange(self.proj_scale, dtype=torch.float32) * math.pi |
|
|
|
self.register_buffer('freq', freq) |
|
|
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor: |
|
""" |
|
        Maps values from R^N to the higher-dimensional space R^(N * 2L) using L sin/cos frequency bands.

        Args:
            inputs: torch.Tensor of shape [B, ..., N]; input values to be transformed
        Returns: torch.Tensor of shape [B, ..., N * 2L]; encoded input values
        """

        spectrum = self.freq.view(*[1] * inputs.ndim, -1) * inputs.unsqueeze(-1)
        encoding = torch.stack([torch.sin(spectrum), torch.cos(spectrum)], dim=-2)
        encoding = encoding.view(*inputs.shape[:-1], -1)
|
|
|
return encoding |
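# Example sketch: encoding 6 normalized timestamps with L = 8 frequency bands.
#
#   pe = NeRFPositionalEmbedding(proj_scale=8)
#   t = torch.linspace(0.0, 1.0, steps=6).view(1, 6)   # [B=1, N=6]
#   enc = pe(t)                                        # [1, 6 * 2 * 8] = [1, 96]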
|
|
|
class TimestepProjModuleConfig(PretrainedConfig): |
|
pos_embed_scale: int |
|
proj_layers: List[int] |
|
time_delta_sec: float = 0.25 |
|
num_tokens: int = 3 |
|
|
|
|
|
class TimestepProjModule(nn.Module): |
|
|
|
def __init__(self, config: TimestepProjModuleConfig, num_timesteps: int, token_size: int): |
|
""" |
|
Args: |
|
num_timesteps: Number of control timesteps |
|
token_size: Single token size |
|
""" |
|
super().__init__() |
|
self.config = TimestepProjModuleConfig() |
|
self.config.pos_embed_scale = config['pos_embed_scale'] |
|
self.config.proj_layers = config['proj_layers'] |
|
self.config.time_delta_sec = config['time_delta_sec'] |
|
self.config.num_tokens = config['num_tokens'] |
|
|
|
self.num_timesteps = num_timesteps |
|
self.token_size = token_size |
|
|
|
input_size = 2 * self.config.pos_embed_scale |
|
|
|
self.pos_embed = NeRFPositionalEmbedding(self.config.pos_embed_scale) |
|
|
|
|
|
feature_size = self.config.num_tokens * self.token_size |
|
|
|
|
|
|
|
self.timestep_proj = self._make_timestep_proj(in_features=int(input_size), out_features=int(feature_size)) |
|
|
|
def _make_timestep_proj(self, in_features: int, out_features: int) -> torch.nn.Module: |
|
layer_sizes = [in_features] + list(self.config.proj_layers) + [out_features] |
|
module = torch.nn.Sequential( |
|
*[ |
|
torch.nn.Sequential( |
|
collections.OrderedDict( |
|
{ |
|
'linear': torch.nn.Linear(layer_in_features, layer_out_features), |
|
'act': torch.nn.ReLU(), |
|
'norm': torch.nn.LayerNorm(layer_out_features), |
|
} |
|
) |
|
) |
|
for layer_in_features, layer_out_features in zip(layer_sizes[:-1], layer_sizes[1:]) |
|
] |
|
) |
|
return module |
|
|
|
def forward(self) -> torch.Tensor: |
|
""" |
|
Returns: |
|
torch.Tensor of sequence of timestep tokens, shape [1, num_timesteps * num_tokens, token_size] |
|
""" |
|
device = self.timestep_proj[0].linear.weight.device |
|
|
|
|
|
time_deltas_norm = self.time_deltas_norm.view(1, self.num_timesteps) |
|
time_deltas_norm = time_deltas_norm.to(device=device) |
|
|
|
|
|
timesteps_embed = self.pos_embed(time_deltas_norm) |
|
timesteps_embed = timesteps_embed.view(self.num_timesteps, -1) |
|
|
|
|
|
timesteps_tokens = self.timestep_proj(timesteps_embed) |
|
|
|
|
|
timesteps_tokens = timesteps_tokens.view( |
|
1, self.num_timesteps * self.config.num_tokens, self.token_size |
|
) |
|
|
|
return timesteps_tokens |
|
|
|
@cached_property |
|
def time_deltas_sec(self) -> torch.Tensor: |
|
return torch.arange(0, self.num_timesteps, 1, dtype=torch.float32) * self.config.time_delta_sec |
|
|
|
@cached_property |
|
def time_deltas_norm(self) -> torch.Tensor: |
|
|
|
if self.time_deltas_sec.shape[0] == 1: |
|
|
|
time_deltas_norm = self.time_deltas_sec |
|
else: |
|
time_deltas_norm = self.time_deltas_sec / self.time_deltas_sec.max() |
|
return time_deltas_norm.detach() |
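# Example sketch: with `num_timesteps=6` and `time_delta_sec=0.1` (matching the `__main__` config below),
# `time_deltas_sec` is [0.0, 0.1, 0.2, 0.3, 0.4, 0.5] and `time_deltas_norm` is [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
# the module then emits `num_timesteps * num_tokens` query tokens of size `token_size`.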
|
|
|
|
|
|
|
|
|
class TrajectoryVLA(PrismaticForConditionalGeneration): |
|
|
|
config_class: PretrainedConfig = TrajectoryVLAConfig |
|
|
|
def __init__(self, config: TrajectoryVLAConfig) -> None: |
|
super().__init__(config.prismatic_config) |
|
self.control_tokenizer = WaypointTokenizer(self.llm_backbone.tokenizer) |
|
        self.timestep_proj = TimestepProjModule(
            config.timestep_proj_config,
            num_timesteps=config.num_timesteps,
            token_size=config.token_size,
        )
|
self.num_timesteps = config.num_timesteps |
|
self.token_proj = TokenProjector(config.token_proj_config) |
|
self.transformer = DETR(config.transformer_config) |
|
self.token_size = config.token_size |
|
self.rotation_components = config.rotation_components |
|
|
|
|
|
self.translation_proj = torch.nn.Sequential( |
|
torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2), |
|
torch.nn.ReLU(), |
|
torch.nn.Linear(in_features=config.token_size // 2, out_features=3), |
|
) |
|
self.rotation_proj = torch.nn.Sequential( |
|
torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2), |
|
torch.nn.ReLU(), |
|
torch.nn.Linear( |
|
in_features=config.token_size // 2, out_features=config.rotation_components |
|
), |
|
) |
|
|
|
self.gripper_proj = torch.nn.Sequential( |
|
torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2), |
|
torch.nn.ReLU(), |
|
torch.nn.Linear(in_features=config.token_size // 2, out_features=1), |
|
) |
|
|
|
    def _pack_waypointer_input(
        self, input_ids: torch.Tensor, vlm_output: PrismaticCausalLMOutputWithPast, vit_tokens, fused_attention_mask
    ) -> WaypointerInput:
|
|
|
|
|
projected_tokens = vlm_output.hidden_states[-1] |
|
|
|
control_tokens = self._extract_control_tokens(input_ids, projected_tokens) |
|
|
|
num_image_tokens = vit_tokens.shape[1] |
|
|
|
llm_image_tokens = projected_tokens[..., 1 : 1 + num_image_tokens, :] |
|
|
|
|
|
return WaypointerInput( |
|
vit_tokens=vit_tokens, |
|
llm_image_tokens=llm_image_tokens, |
|
control_tokens=control_tokens, |
|
llm_tokens=projected_tokens, |
|
attn_mask=fused_attention_mask, |
|
) |
|
|
|
    def predict_tracks(self, inputs):
        """Run the VLM and the waypointer head; returns (translation, euler_angles, gripper) as numpy values."""
        vlm_output, vit_tokens, fused_attention_mask = super().forward(
            **inputs, output_hidden_states=True, output_attentions=True, return_dict=True
        )
|
waypointer_input = self._pack_waypointer_input(inputs['input_ids'], vlm_output,vit_tokens,fused_attention_mask) |
|
waypoint_output = self._waypointer_forward(waypointer_input) |
|
translation, rotation, gripper = torch.split( |
|
waypoint_output, [3, self.rotation_components, 1], dim=-1 ) |
|
translation, rotation, gripper = self.process_output(translation, rotation, gripper) |
|
        return translation, rotation, gripper

    def process_output(self, translation, rotation, gripper):
        """Convert raw predictions to numpy: 3x3 rotation matrices -> XYZ Euler angles; gripper logits -> rounded {0, 1}."""
|
|
|
euler_angles = [] |
|
for matrix in rotation[0]: |
|
|
|
rotation_obj = R.from_matrix(matrix.view(3, 3).detach().cpu().float().numpy().squeeze()) |
|
|
|
euler_angle = rotation_obj.as_euler('xyz', degrees=False) |
|
euler_angles.append(euler_angle) |
|
|
|
translation = translation.detach().cpu().float().numpy().squeeze() |
|
|
|
gripper = np.round(torch.sigmoid(gripper).detach().cpu().float().numpy().squeeze()) |
|
return translation,euler_angles,gripper |
|
|
|
def _extract_control_tokens(self, input_ids: torch.Tensor, output_tokens: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Extract the action tokens from the LLM output sequence. Assumes the following order |
|
[image_tokens, language_tokens, action_tokens, padding] |
|
|
|
Args: |
|
input_ids: IDs of the tokens in text input sequence; shape [B, S] |
|
output_tokens: Token sequence output from LLM; shape [B, L, token_size]. Note the length is |
|
different from input_ids as it also contains image tokens |
|
Returns: |
|
torch.Tensor of shape [B, 7, token_size] containing only action tokens |
|
""" |
|
|
|
assert input_ids.ndim == 2 |
|
assert output_tokens.ndim == 3 |
|
batch, in_seq_len, out_seq_len = *input_ids.shape, output_tokens.shape[1] |
|
|
|
device = input_ids.device |
|
|
|
num_control_tokens = self.control_tokenizer.num_control_tokens |
|
|
|
control_token_ids = torch.from_numpy( |
|
self.control_tokenizer.control_token_ids |
|
) |
|
control_token_ids = control_token_ids.to(dtype=input_ids.dtype, device=input_ids.device) |
|
is_control_token = torch.any( |
|
input_ids.unsqueeze(-1) == control_token_ids.view(1, 1, -1), |
|
dim=-1, |
|
) |
|
if not torch.all(mask := is_control_token.sum(dim=-1) == num_control_tokens): |
|
raise RuntimeError( |
|
f"Can't properly detect control tokens with ids {control_token_ids} of len=" |
|
f"{len(control_token_ids)} in input_ids {input_ids}. Rows mask: {mask}" |
|
) |
|
|
|
|
|
tokens_mask = torch.cat( |
|
[ |
|
torch.zeros(batch, out_seq_len - in_seq_len, dtype=torch.bool, device=device), |
|
is_control_token.to(torch.bool), |
|
], |
|
dim=1, |
|
) |
|
|
|
control_tokens = output_tokens[tokens_mask] |
|
control_tokens = control_tokens.view( |
|
batch, num_control_tokens, output_tokens.shape[-1] |
|
) |
|
|
|
return control_tokens |
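    # Masking sketch for a single example (token layout is illustrative): if the text sequence is
    #   input_ids = [BOS, lang..., ctrl_0..ctrl_6, PAD...]                        (length S)
    # then `output_tokens` additionally contains the image patch tokens, so the control-token mask is left-padded
    # with `out_seq_len - in_seq_len` zeros before indexing into `output_tokens`.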
|
|
|
def _waypointer_forward(self, inputs:WaypointerInput): |
|
|
|
timesteps_tokens = self.timestep_proj() |
|
|
|
|
|
llm_tokens = self.token_proj(inputs) |
|
|
|
|
|
output_tokens = self.transformer( |
|
feature_tokens=llm_tokens, query_tokens=timesteps_tokens, attn_mask=None |
|
) |
|
|
|
output_tokens = output_tokens.view( |
|
-1, self.num_timesteps, 3 * self.token_size |
|
) |
|
|
|
|
|
|
|
translation_tokens, rotation_tokens, gripper_tokens = torch.split( |
|
output_tokens, [self.token_size] * 3, dim=-1 |
|
) |
|
|
|
translation = self.translation_proj(translation_tokens) |
|
rotation = self.rotation_proj(rotation_tokens) |
|
gripper = self.gripper_proj(gripper_tokens) |
|
|
|
output = torch.cat( |
|
[translation, rotation, gripper], dim=-1 |
|
) |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str: |
|
if unnorm_key is None and len(norm_stats) != 1: |
|
raise ValueError( |
|
f"Your model was trained on more than one dataset. " |
|
f"Please pass a `unnorm_key` from the following options to choose the statistics used for " |
|
f"de-normalizing actions: {norm_stats.keys()}" |
|
) |
|
|
|
|
|
unnorm_key = unnorm_key if unnorm_key is not None else next(iter(norm_stats.keys())) |
|
if unnorm_key not in norm_stats: |
|
raise ValueError( |
|
f"The `unnorm_key` you chose ({unnorm_key = }) is not in the available statistics. " |
|
f"Please choose from: {norm_stats.keys()}" |
|
) |
|
|
|
return unnorm_key |
|
|
|
def get_action_dim(self, unnorm_key: Optional[str] = None) -> int: |
|
"""Get the dimensionality of the policy's action space.""" |
|
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) |
|
return len(self.norm_stats[unnorm_key]["action"]["q01"]) |
|
|
|
def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]: |
|
"""Get all the logged statistics for the given dataset.""" |
|
unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key) |
|
return self.norm_stats[unnorm_key]["action"] |
|
|
|
def remove_waypointer_prefix(ckpt): |
|
new_state_dict = {} |
|
for key, value in ckpt.items(): |
|
|
|
if key.startswith('waypointer.'): |
|
new_key = key[len('waypointer.'):] |
|
else: |
|
new_key = key |
|
new_state_dict[new_key] = value |
|
return new_state_dict |
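# Usage sketch (the checkpoint path and layout are hypothetical):
#
#   ckpt = torch.load("waypointer_checkpoint.pt", map_location="cpu")
#   state_dict = remove_waypointer_prefix(ckpt.get("model", ckpt))
#   model.load_state_dict(state_dict, strict=False)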
|
|
|
def image_processor(image):
    """Resize a PIL image to the expected (224, 224) input resolution and return it."""
    image_resolution = (3, 224, 224)
    image = image.resize(image_resolution[1:], resample=Image.Resampling.LANCZOS)
    return image
|
|
|
def read_pt(pt_path): |
|
data = torch.load(pt_path) |
|
return data |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
    prismatic_config_dict = {
        "vision_backbone_id": "dinosiglip-vit-so-224px",
        "llm_backbone_id": "llama2-7b-pure",
        "arch_specifier": "no-align+gelu-mlp",
        "use_fused_vision_backbone": True,
        "image_resize_strategy": "letterbox",
        "text_config": None,
        "llm_max_length": 2048,
        "pad_token_id": 32000,
        "pad_to_multiple_of": 64,
        "output_projector_states": False,
        "return_dict": False,
    }
|
|
|
token_proj_config = { |
|
"vit_tokens_layers": [2176, 1024], |
|
"control_tokens_layers": [4096, 2048, 1024], |
|
"image_tokens_mode": 'vit', |
|
'llm_image_tokens_layers': [] |
|
} |
|
timestep_proj_config = { |
|
"pos_embed_scale": 8, |
|
"proj_layers": [128,512,1024], |
|
"time_delta_sec": 0.1, |
|
"num_tokens":3 |
|
} |
|
pos_embed_config = { |
|
"num_embeddings": 300, |
|
"embedding_dim": 1024 |
|
} |
|
encoder_block_config = { |
|
"feature_size": 1024, |
|
"head_dim": 64, |
|
"num_heads": 16 |
|
} |
|
decoder_block_config = { |
|
"feature_size": 1024, |
|
"head_dim": 64, |
|
"num_heads": 16, |
|
"dropout": 0.0 |
|
} |
|
transformer_config = { |
|
"pos_embed_config": pos_embed_config, |
|
"encoder_block_config": encoder_block_config, |
|
"decoder_block_config": decoder_block_config, |
|
"num_blocks": 2 |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TrajectoryVlaConfig_config = { |
|
"prismatic_config":prismatic_config_dict, |
|
"token_size": 1024, |
|
"cheat": False, |
|
"num_timesteps": 6, |
|
"rotation_components": 9, |
|
"seperate_control_proj": True, |
|
"timestep_proj_config": timestep_proj_config, |
|
"token_proj_config": token_proj_config, |
|
"transformer_config": transformer_config, |
|
"num_timestep_tokens": 3, |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
    model_config = TrajectoryVLAConfig(**TrajectoryVlaConfig_config)
|
|
|
|
|
model = TrajectoryVLA(model_config) |
|
model = model.to(dtype=torch.bfloat16) |
|
model = model.to('cuda') |
|
model.eval() |
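    # Inference sketch (prompt, preprocessing, and control-token handling are illustrative; a real run also needs
    # the normalization / tensor conversion that `image_processor` above does not perform, plus the waypoint
    # placeholder tokens expected by `WaypointTokenizer` appended to the prompt):
    #
    #   prompt = "In: What action should the robot take to pick up the cup?\nOut:"
    #   input_ids = model.llm_backbone.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    #   pixel_values = {"dino": dino_tensor.to("cuda", torch.bfloat16),
    #                   "siglip": siglip_tensor.to("cuda", torch.bfloat16)}
    #   with torch.no_grad():
    #       translation, euler_angles, gripper = model.predict_tracks(
    #           {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "pixel_values": pixel_values}
    #       )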
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|