Image-Text-to-Text
Safetensors
openvla
custom_code
Emrys-Hong committed
Commit 1b73161 · 1 Parent(s): 264f92b
Files changed (4)
  1. config.json +5 -1
  2. configuration_prismatic.py +140 -0
  3. modeling_prismatic.py +619 -0
  4. solver.py +191 -0
config.json CHANGED
@@ -1,8 +1,12 @@
 {
   "arch_specifier": "no-align+fused-gelu-mlp",
   "architectures": [
-    "OpenVLAForActionPrediction"
+    "Emma-X"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_prismatic.OpenVLAConfig",
+    "AutoModelForVision2Seq": "modeling_prismatic.EmmaxForActionPrediction"
+  },
   "hf_llm_id": "meta-llama/Llama-2-7b-hf",
   "image_resize_strategy": "resize-naive",
   "image_sizes": [
configuration_prismatic.py ADDED
@@ -0,0 +1,140 @@
+"""
+configuration_prismatic.py
+
+HuggingFace-style configuration definition for Prismatic VLMs, inheriting from `transformers.PretrainedConfig`.
+Default configuration specifies `siglip-224px+7b`.
+"""
+
+from typing import Any, Dict, List, Optional
+
+from transformers import PretrainedConfig
+from transformers.models.auto import CONFIG_MAPPING
+
+# === Utilities for Mapping Prismatic names to HF names ===
+# fmt: off
+VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = {
+    "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224],
+
+    "clip-vit-l-336px": [336],
+    "siglip-vit-so400m-384px": [384],
+
+    "dinoclip-vit-l-336px": [336, 336],
+    "dinosiglip-vit-so-224px": [224, 224],
+    "dinosiglip-vit-so-384px": [384, 384],
+}
+VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = {
+    "clip-vit-l": ["vit_large_patch14_clip_224.openai"],
+    "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"],
+
+    "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"],
+    "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"],
+
+    "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"],
+    "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"],
+
+    "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"],
+    "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"],
+    "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"],
+}
+TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = {
+    "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"],
+    "dinov2-vit-l": [None], "in1k-vit-l": [None],
+    "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None],
+    "dinoclip-vit-l-336px": [None, "quick_gelu"],
+    "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None]
+}
+
+LLM_BACKBONE_TO_HF_PATH = {
+    "llama2-7b-pure": "meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf",
+    "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf",
+
+    "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5",
+
+    "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1",
+    "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
+
+    "phi-2-3b": "microsoft/phi-2",
+}
+LLM_BACKBONE_TO_HF_METACLASS = {
+    "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama",
+    "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama",
+
+    "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral",
+
+    "phi-2-3b": "phi",
+}
+
+VALID_VISION_BACKBONES = set(VISION_BACKBONE_TO_RESOLUTION.keys())
+VALID_LLM_BACKBONES = set(LLM_BACKBONE_TO_HF_PATH)
+# fmt: on
+
+
+class PrismaticConfig(PretrainedConfig):
+    model_type: str = "prismatic"
+    is_composition: bool = False
+
+    def __init__(
+        self,
+        vision_backbone_id: str = "siglip-vit-so400m",
+        llm_backbone_id: str = "vicuna-v15-7b",
+        arch_specifier: str = "no-align+gelu-mlp",
+        use_fused_vision_backbone: Optional[bool] = None,
+        image_resize_strategy: str = "letterbox",
+        text_config: Optional[Dict[str, Any]] = None,
+        llm_max_length: int = 2048,
+        pad_token_id: int = 32000,
+        pad_to_multiple_of: int = 64,
+        output_projector_states: bool = False,
+        **kwargs: str,
+    ) -> None:
+        if vision_backbone_id not in VALID_VISION_BACKBONES:
+            raise ValueError(f"Vision backbone `{vision_backbone_id}` not in {VALID_VISION_BACKBONES = }")
+
+        if llm_backbone_id not in VALID_LLM_BACKBONES:
+            raise ValueError(f"LLM backbone `{llm_backbone_id}` not in {VALID_LLM_BACKBONES = }")
+
+        # Set Prismatic Configuration Fields
+        self.vision_backbone_id = vision_backbone_id
+        self.llm_backbone_id = llm_backbone_id
+        self.arch_specifier = arch_specifier
+        self.output_projector_states = output_projector_states
+
+        # [Contract] All vision backbone parameters are lists =>> supports fused backbones with different preprocessing
+        self.use_fused_vision_backbone = (
+            use_fused_vision_backbone
+            if use_fused_vision_backbone is not None
+            else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"])
+        )
+
+        self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID[self.vision_backbone_id]
+        self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER[self.vision_backbone_id]
+        self.image_sizes = VISION_BACKBONE_TO_RESOLUTION[self.vision_backbone_id]
+        self.image_resize_strategy = image_resize_strategy
+
+        self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH[self.llm_backbone_id]
+        self.llm_max_length = llm_max_length
+        self.pad_token_id, self.pad_to_multiple_of = pad_token_id, pad_to_multiple_of
+
+        # [IMPORTANT] HF Utilities actually look for a `text_config` field... we need to use that specific naming!
+        self.text_config = (
+            CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]](**text_config)
+            if text_config is not None
+            else CONFIG_MAPPING[LLM_BACKBONE_TO_HF_METACLASS[self.llm_backbone_id]]()
+        )
+
+        # Dispatch **kwargs to super() =>> note that `pad_token_id` collides, so we pass it in here as well...
+        super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+
+class OpenVLAConfig(PrismaticConfig):
+    model_type: str = "openvla"
+
+    def __init__(
+        self,
+        norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None,
+        n_action_bins: int = 256,
+        **kwargs: str,
+    ) -> None:
+        self.norm_stats, self.n_action_bins = norm_stats, n_action_bins
+
+        super().__init__(**kwargs)
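
As a quick illustration of how the mapping tables in this file drive the derived config fields, the sketch below instantiates an `OpenVLAConfig` with one of the fused backbone ids and prints what gets filled in (run from the repository root; inside the package the import is relative):

# Small sketch: derived fields for a fused DINOv2 + SigLIP backbone.
from configuration_prismatic import OpenVLAConfig

cfg = OpenVLAConfig(vision_backbone_id="dinosiglip-vit-so-224px", llm_backbone_id="llama2-7b-pure")

print(cfg.use_fused_vision_backbone)  # True  (id starts with "dinosiglip")
print(cfg.timm_model_ids)             # ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"]
print(cfg.image_sizes)                # [224, 224]
print(cfg.hf_llm_id)                  # "meta-llama/Llama-2-7b-hf"
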
modeling_prismatic.py ADDED
@@ -0,0 +1,619 @@
+"""
+modeling_prismatic.py
+
+Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
+from the default `transformers.PreTrainedModel`. Meant to be standalone and self-contained, but exactly replicate the
+logic in `prismatic.models.vlms.prismatic.py`.
+
+Note =>> for the time being, not adding the custom HF "docstring" formatting.
+
+References [LLaVa, IDEFICS-2]:
+    => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
+    => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
+"""
+
+import logging
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import timm
+import tokenizers
+import torch
+import torch.nn as nn
+import transformers
+from timm.models.vision_transformer import LayerScale
+from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import ModelOutput
+
+from PIL import Image
+
+from .configuration_prismatic import OpenVLAConfig, PrismaticConfig
+from .solver import solver
+
+# Get Logger
+logger = logging.getLogger(__name__)
+
+
+# === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
+IGNORE_INDEX = -100
+
+
+# === Utility Functions for Monkey-Patching ===
+def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        result = fn(*args, **kwargs)
+        return result[0] if isinstance(result, tuple) else result
+
+    return wrapper
+
+
+# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
+#   =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
+#   =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
+def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
+    return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
+
+
+def ls_apply_patch(ls_module: LayerScale):
+    ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
+    ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
+    del ls_module.gamma
+
+
+# === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
+class PrismaticVisionBackbone(nn.Module):
+    def __init__(
+        self,
+        use_fused_vision_backbone: bool,
+        image_sizes: List[int],
+        timm_model_ids: List[str],
+        timm_override_act_layers: List[Optional[str]],
+    ) -> None:
+        super().__init__()
+        self.use_fused_vision_backbone = use_fused_vision_backbone
+
+        # [Contract] Validate number of (fused) vision backbones, create "alpha" featurizer and Instantiate
+        #   =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility
+        #               Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
+        assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
+        self.featurizer = timm.create_model(
+            timm_model_ids[0],
+            pretrained=False,
+            num_classes=0,
+            img_size=image_sizes[0],
+            act_layer=timm_override_act_layers[0],
+        )
+        self.featurizer.forward = unpack_tuple(
+            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
+        )
+        self.embed_dim = self.featurizer.embed_dim
+
+        # If `use_fused_vision_backbone` =>> create "beta" featurizer
+        if self.use_fused_vision_backbone:
+            self.fused_featurizer = timm.create_model(
+                timm_model_ids[1],
+                pretrained=False,
+                num_classes=0,
+                img_size=image_sizes[1],
+                act_layer=timm_override_act_layers[1],
+            )
+            self.fused_featurizer.forward = unpack_tuple(
+                partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2})
+            )
+            self.embed_dim += self.fused_featurizer.embed_dim
+
+        # Patch `vision_backbone.featurizer` and `vision_backbone.fused_featurizer` with HF-Compatible LayerScale
+        for module in self.featurizer.modules():
+            if isinstance(module, LayerScale):
+                ls_apply_patch(module)
+
+        if self.use_fused_vision_backbone:
+            for module in self.fused_featurizer.modules():
+                if isinstance(module, LayerScale):
+                    ls_apply_patch(module)
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        """Run image (`pixel_values`) through featurizer; if channel-stacked, then dispatch and sequence stack."""
+        if not self.use_fused_vision_backbone:
+            return self.featurizer(pixel_values)
+
+        # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
+        img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
+        patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused)
+
+        return torch.cat([patches, patches_fused], dim=2)
+
+
+# === Prismatic Projector (nn.Module) Definitions ===
+class PrismaticProjector(nn.Module):
+    def __init__(self, use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None:
+        super().__init__()
+        self.use_fused_vision_backbone = use_fused_vision_backbone
+        self.vision_dim, self.llm_dim = vision_dim, llm_dim
+
+        # Switch on `use_fused_vision_backbone` =>> use slightly different MLPs and projection factors!
+        if not self.use_fused_vision_backbone:
+            self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True)
+            self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
+            self.act_fn1 = nn.GELU()
+        else:
+            initial_projection_dim = 4 * vision_dim
+            self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True)
+            self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True)
+            self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True)
+            self.act_fn1 = nn.GELU()
+            self.act_fn2 = nn.GELU()
+
+    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
+        if not self.use_fused_vision_backbone:
+            projected_features = self.fc1(img_patches)
+            projected_features = self.act_fn1(projected_features)
+            projected_features = self.fc2(projected_features)
+        else:
+            projected_features = self.fc1(img_patches)
+            projected_features = self.act_fn1(projected_features)
+            projected_features = self.fc2(projected_features)
+            projected_features = self.act_fn2(projected_features)
+            projected_features = self.fc3(projected_features)
+
+        return projected_features
+
+
+# === Main HF Class Definitions ===
+@dataclass
+class PrismaticCausalLMOutputWithPast(ModelOutput):
+    """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+    # Additions for VLMs
+    projector_features: Optional[torch.FloatTensor] = None
+
+
+class PrismaticPreTrainedModel(PreTrainedModel):
+    config_class: PretrainedConfig = PrismaticConfig
+    base_model_prefix: str = "model"
+    supports_gradient_checkpointing: bool = True
+
+    _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
+    _skip_keys_device_placement: str = "past_key_values"
+    _supports_flash_attn_2: bool = True
+
+    def _init_weights(self, module: nn.Module) -> None:
+        # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
+        #   => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
+        #      https://github.com/TRI-ML/prismatic-vlms
+        std = (
+            self.config.initializer_range
+            if hasattr(self.config, "initializer_range")
+            else self.config.text_config.initializer_range
+        )
+
+        if hasattr(module, "class_embedding"):
+            module.class_embedding.data.normal_(mean=0.0, std=std)
+
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def _supports_sdpa(self) -> bool:
+        """Check LLM supports SDPA Attention"""
+        return self.language_model._supports_sdpa
+
+
+class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
+    def __init__(self, config: PrismaticConfig) -> None:
+        super().__init__(config)
+
+        # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
+        if config.use_fused_vision_backbone is None:
+            raise ValueError("Missing config field `use_fused_vision_backbone`")
+
+        if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
+            raise NotImplementedError(
+                "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
+                "if you urgently need support for latest TIMM versions."
+            )
+
+        if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
+            logger.warning(
+                f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
+                f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
+                f"there might be inference-time regressions due to dependency changes. If in doubt, please "
+                f"use the above versions."
+            )
+
+        # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
+        self.vision_backbone = PrismaticVisionBackbone(
+            config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
+        )
+
+        # Create Multimodal Projector
+        self.projector = PrismaticProjector(
+            config.use_fused_vision_backbone,
+            vision_dim=self.vision_backbone.embed_dim,
+            llm_dim=config.text_config.hidden_size,
+        )
+
+        # Instantiate LLM Backbone
+        self.language_model = AutoModelForCausalLM.from_config(
+            config.text_config, attn_implementation=config._attn_implementation
+        )
+        self.vocab_size = config.text_config.vocab_size
+        self.pad_token_id = config.pad_token_id
+
+        # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
+        self.post_init()
+
+    # === `PreTrainedModel` Boilerplate ===
+    def get_input_embeddings(self) -> nn.Module:
+        return self.language_model.get_input_embeddings()
+
+    def set_input_embeddings(self, value: nn.Module) -> None:
+        self.language_model.set_input_embeddings(value)
+
+    def get_output_embeddings(self) -> nn.Module:
+        return self.language_model.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def get_decoder(self) -> nn.Module:
+        return self.language_model.get_decoder()
+
+    def set_decoder(self, decoder: nn.Module) -> None:
+        self.language_model.set_decoder(decoder)
+
+    def tie_weights(self) -> None:
+        self.language_model.tie_weights()  # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
+
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+    ) -> nn.Embedding:
+        updated_embeddings = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+
+        # Update config/instance variables
+        self.config.text_config.vocab_size = updated_embeddings.num_embeddings
+        self.vocab_size = updated_embeddings.num_embeddings
+
+        return updated_embeddings
+
+    # === Core Prismatic VLM `forward()` Logic ===
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_projector_features: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
+        """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_projector_features = output_projector_features if output_projector_features is not None else False
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
+        use_cache = use_cache and not self.training
+
+        # Instantiate Placeholder for Projector Features
+        projected_patch_embeddings = None
+
+        # Note :: We only support forward passes with the following cases:
+        #   => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
+        #   => Unimodal Forward :: (pixel_values is None)
+        #   => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
+
+        # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
+        if input_ids.shape[1] == 1:
+            assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
+            assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
+            assert labels is None, "Unexpected key `labels` provided during cached generation!"
+
+            language_model_output = self.language_model(
+                input_ids=input_ids,
+                attention_mask=None,
+                position_ids=None,
+                past_key_values=past_key_values,
+                inputs_embeds=None,
+                labels=None,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        # === Handle Unimodal Forward ===
+        elif pixel_values is None:
+            assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
+            assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
+
+            language_model_output = self.language_model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                position_ids=None,
+                past_key_values=None,
+                inputs_embeds=None,
+                labels=labels,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        # === Handle Multimodal Forward ===
+        elif (input_ids.shape[0] == pixel_values.shape[0]) or (inputs_embeds.shape[0] == pixel_values.shape[0]):
+            assert past_key_values is None, "Unexpected key `past_key_values` provided during multimodal forward!"
+
+            # Visual Feature Extraction
+            patch_features = self.vision_backbone(pixel_values)
+
+            # Projection Logic =>> Update Attention Mask
+            projected_patch_embeddings = self.projector(patch_features)
+            projected_patch_attention_mask = None
+            if attention_mask is not None:
+                projected_patch_attention_mask = torch.full(
+                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
+                    fill_value=True,
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+
+            # Get Input Embeddings (from Language Model Embeddings)
+            input_embeddings = self.get_input_embeddings()(input_ids)
+
+            # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
+            multimodal_embeddings = torch.cat(
+                [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
+            )
+            multimodal_attention_mask = None
+            if attention_mask is not None:
+                multimodal_attention_mask = torch.cat(
+                    [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
+                )
+
+            # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
+            multimodal_labels = None
+            if labels is not None:
+                projected_patch_labels = torch.full(
+                    (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
+                    fill_value=IGNORE_INDEX,
+                    dtype=labels.dtype,
+                    device=labels.device,
+                )
+                multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
+
+            # Dispatch to Language Model
+            language_model_output = self.language_model(
+                input_ids=None,
+                attention_mask=multimodal_attention_mask,
+                position_ids=None,
+                past_key_values=None,
+                inputs_embeds=multimodal_embeddings,
+                labels=multimodal_labels,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+
+        # === Otherwise =>> Assume Invalid! ===
+        elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
+            raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
+
+        else:
+            raise ValueError(
+                "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
+                f"=> `input_ids` = {input_ids is not None}\n"
+                f"=> `attention_mask` = {attention_mask is not None}\n"
+                f"=> `pixel_values` = {pixel_values is not None}\n"
+                f"=> `labels` = {labels is not None}\n"
+                f"=> `input_embeds` = {inputs_embeds is not None}\n"
+                f"=> `past_key_values` = {past_key_values is not None}\n"
+                f"=> `use_cache` = {use_cache}"
+            )
+
+        # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
+        if not return_dict:
+            if output_projector_features and (projected_patch_embeddings is not None):
+                return *language_model_output, projected_patch_embeddings
+
+            return language_model_output
+
+        return PrismaticCausalLMOutputWithPast(
+            loss=language_model_output.loss,
+            logits=language_model_output.logits,
+            past_key_values=language_model_output.past_key_values,
+            hidden_states=language_model_output.hidden_states,
+            attentions=language_model_output.attentions,
+            projector_features=projected_patch_embeddings,
+        )
+
+    # === GenerationMixin Methods ===
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: str,
+    ) -> Dict[str, torch.Tensor]:
+        """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
+        if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
+            (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
+        ):
+            raise ValueError("Generation with batch size > 1 is not currently supported!")
+
+        # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        # If `input_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"input_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        # Make sure `pixel_values` are preserved in `model_inputs`
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "pixel_values": pixel_values,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+            }
+        )
+
+        return model_inputs
+
+    # Defer to Language Model (all handle this differently, with different return types)
+    def _reorder_cache(self, *args, **kwargs) -> Any:
+        return self.language_model._reorder_cache(*args, **kwargs)
+
+
+class EmmaxForActionPrediction(PrismaticForConditionalGeneration):
+    config_class: PretrainedConfig = OpenVLAConfig
+
+    def __init__(self, config: OpenVLAConfig) -> None:
+        super().__init__(config)
+        self.norm_stats = config.norm_stats
+
+        # Compute action bins
+        self.bins = np.linspace(-1, 1, config.n_action_bins)
+        self.bin_centers = (self.bins[:-1] + self.bins[1:]) / 2.0
+
+        # Compute vocab size for de-tokenization -- revert added "multiple of"
+        self.vocab_size = self.config.text_config.vocab_size - self.config.pad_to_multiple_of
+
+    def predict_action(
+        self, input_ids: Optional[torch.LongTensor] = None, unnorm_key: Optional[str] = None, **kwargs: str
+    ) -> Tuple[np.ndarray, torch.Tensor]:
+        """Thin wrapper around super().generate() that decodes predicted actions and de-normalizes them."""
+
+        # We need to add this special empty token ('') after the colon (':') token in "ASSISTANT:"
+        # in order for the predictions to match the training configuration and be accurate.
+        # NOTE: This is NOT needed for ECoT
+        # input_ids = torch.cat(
+        #     (input_ids, torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(input_ids.device)), dim=1
+        # )
+
+        # Run VLA inference
+        generated_ids = self.generate(input_ids, **kwargs)
+
+        # Extract predicted action tokens and translate into (normalized) continuous actions
+        predicted_action_token_ids = generated_ids[0, -(self.get_action_dim(unnorm_key) + 1) : -1].cpu().numpy()
+        discretized_actions = self.vocab_size - predicted_action_token_ids
+        discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=self.bin_centers.shape[0] - 1)
+        normalized_actions = self.bin_centers[discretized_actions]
+
+        # Unnormalize actions
+        action_norm_stats = self.get_action_stats(unnorm_key)
+        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
+        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
+        actions = np.where(
+            mask,
+            0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low,
+            normalized_actions,
+        )
+
+        return actions, generated_ids
+
+    @torch.inference_mode()
+    def generate_actions(self, image: Image, prompt_text: str, type: str, **kwargs: str) -> Tuple[List[np.ndarray], str]:
+        # For now, only support generation with a batch size of 1 for simplicity
+        # NOTE :: `image_transform` / `tokenizer` are expected to be attached to the backbones by the caller
+        #         (as in the original prismatic codebase); they are not created by this HF port.
+        image_transform, tokenizer = self.vision_backbone.image_transform, self.llm_backbone.tokenizer
+
+        # Prepare Inputs
+        input_ids = tokenizer(prompt_text, truncation=True, return_tensors="pt").input_ids.to(self.device)
+
+        pixel_values = image_transform(image)
+        if isinstance(pixel_values, torch.Tensor):
+            pixel_values = pixel_values[None, ...].to(self.device)
+        elif isinstance(pixel_values, dict):
+            pixel_values = {k: v[None, ...].to(self.device) for k, v in pixel_values.items()}
+        else:
+            raise ValueError(f"Unsupported `pixel_values` type = {type(pixel_values)}")
+
+        # Invoke super().generate --> taps into `GenerationMixin` which (redirects) to `forward()`
+        # autocast_dtype = self.llm_backbone.half_precision_dtype
+        # with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
+        with torch.autocast("cuda", dtype=torch.float16):
+            # fmt: off
+            generated_ids = super().generate(
+                input_ids=input_ids,            # Shape: [1, seq]
+                pixel_values=pixel_values,      # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
+                **kwargs
+            )
+            # fmt: on
+
+        generated_text = tokenizer.decode(generated_ids[0, input_ids.shape[1] :], skip_special_tokens=True).strip()
+
+        s = solver
+        actions, reasoning = s.extract_action_policies(generated_text)
+        # unnorm_key = "bridge_orig"
+
+        # unnormalize
+        unnorm_key = None
+        action_norm_stats = self.get_action_stats(unnorm_key)
+        mask = action_norm_stats.get("mask", np.ones_like(action_norm_stats["q01"], dtype=bool))
+        action_high, action_low = np.array(action_norm_stats["q99"]), np.array(action_norm_stats["q01"])
+        _actions = []
+        for action in actions:
+            action_norm = np.where(
+                mask, 0.5 * (np.array(action) + 1) * (action_high - action_low) + action_low, action
+            )
+            _actions.append(action_norm)
+
+        return _actions, generated_text
+
+    @staticmethod
+    def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
+        if unnorm_key is None and len(norm_stats) != 1:
+            raise ValueError(
+                f"Your model was trained on more than one dataset. "
+                f"Please pass a `unnorm_key` from the following options to choose the statistics used for "
+                f"de-normalizing actions: {norm_stats.keys()}"
+            )
+
+        # If None, grab the (singular) dataset in `norm_stats` to use as `unnorm_key`
+        unnorm_key = unnorm_key if unnorm_key is not None else next(iter(norm_stats.keys()))
+        if unnorm_key not in norm_stats:
+            raise ValueError(
+                f"The `unnorm_key` you chose ({unnorm_key = }) is not in the available statistics. "
+                f"Please choose from: {norm_stats.keys()}"
+            )
+
+        return unnorm_key
+
+    def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
+        """Get the dimensionality of the policy's action space."""
+        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
+        return len(self.norm_stats[unnorm_key]["action"]["q01"])
+
+    def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
+        """Get all the logged statistics for the given dataset."""
+        unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
+        return self.norm_stats[unnorm_key]["action"]
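
The action decoding in `predict_action` is plain arithmetic over the generated token ids, so it can be sanity-checked in isolation. Below is a standalone sketch of that mapping with toy token ids and toy normalization statistics (the real statistics come from `config.norm_stats`; the padded 32064-entry text vocab is an assumption about this checkpoint's `text_config`):

# Standalone sketch of the predict_action de-tokenization math (toy numbers, not real model output).
import numpy as np

n_action_bins = 256
pad_to_multiple_of = 64
text_vocab_size = 32064                              # assumed: Llama-2 vocab (32000) padded to a multiple of 64
vocab_size = text_vocab_size - pad_to_multiple_of    # -> 32000, mirrors EmmaxForActionPrediction.__init__

bins = np.linspace(-1, 1, n_action_bins)
bin_centers = (bins[:-1] + bins[1:]) / 2.0           # 255 bin centers in (-1, 1)

predicted_action_token_ids = np.array([31999, 31872, 31900, 31950, 31800, 31750, 31745])  # toy 7-DoF token ids
discretized_actions = vocab_size - predicted_action_token_ids        # action tokens live at the top of the vocab
discretized_actions = np.clip(discretized_actions - 1, a_min=0, a_max=bin_centers.shape[0] - 1)
normalized_actions = bin_centers[discretized_actions]                # values in (-1, 1)

# De-normalize with toy q01/q99 statistics (real values come from `config.norm_stats[unnorm_key]["action"]`).
action_low, action_high = np.full(7, -0.05), np.full(7, 0.05)
actions = 0.5 * (normalized_actions + 1) * (action_high - action_low) + action_low
print(actions.round(4))
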
solver.py ADDED
@@ -0,0 +1,191 @@
+from collections import defaultdict
+
+import numpy as np
+from prismatic.vla.action_tokenizer import ActionTokenizer
+from transformers import AutoTokenizer
+
+
+class Solver:
+    def __init__(self, action_tokenizer=None, verbose=True) -> None:
+        self.verbose = verbose
+        self.action_tokenizer = action_tokenizer
+        self.coordinates_key = "NEXT GRIPPER:"
+        self.movement_key = "MOVEMENT:"
+        self.policy_key = "POLICIES:"
+
+    def compare_movement(self, pred_pos, label_pos):
+        dist = np.sum(np.abs(pred_pos - label_pos))
+        relative_dist = np.sum(np.abs(dist / label_pos))
+        return dist, relative_dist, dist == 0
+
+    def compare_policy(self, pred_pol, label_pol):
+        dist = 0
+        cnt = 0
+        for i in range(min(len(label_pol), len(pred_pol))):
+            for j in range(len(label_pol[0])):
+                dist += label_pol[i][j] == pred_pol[i][j]
+                cnt += 1
+        assert cnt % 7 == 0
+        return dist / cnt
+
+    def extract_2d_coordinates(self, text):
+        try:
+            coordinates_index = text.index(self.coordinates_key) + len(self.coordinates_key)
+            coord = text[coordinates_index:]
+            coord = [o for o in coord.split("\n") if len(o.strip()) != 0]
+            coord = eval(coord[0].strip())
+        except Exception:
+            coord = [0, 0]
+        return coord
+
+    def extract_movement_plan(self, text):
+        require_unorm = None
+        try:
+            # text after key word
+            movement_index = text.index(self.movement_key) + len(self.movement_key)
+            movement_level = text[movement_index:]
+            movement_level = [o for o in movement_level.split("\n") if len(o.strip()) != 0]
+            movement_level = movement_level[0].strip()
+
+            if "gripper" not in movement_level:  # for normalized tokenized version
+                require_unorm = True
+                movement_token_ids = self.action_tokenizer.tokenizer(movement_level, add_special_tokens=False).input_ids
+                movement_norm = self.action_tokenizer.decode_token_ids_to_actions(np.array(movement_token_ids))
+                movement_norm = movement_norm[1:8]
+                assert len(movement_norm) == 7
+            else:  # for unnormalized text version
+                require_unorm = False
+                movement_level = [o for o in movement_level.split(";") if len(o) > 0]
+                movement_level = movement_level[:7]
+
+                position = defaultdict(int)
+                movement_to_pos = dict(
+                    move_backward=(-1, "y"),
+                    move_forward=(1, "y"),
+                    move_right=(-1, "x"),
+                    move_left=(1, "x"),
+                    move_downward=(-1, "z"),
+                    move_upward=(1, "z"),
+                    roll_downward=(-1, "ox"),
+                    roll_upward=(1, "ox"),
+                    swing_downward=(-1, "ox"),
+                    swing_upward=(1, "ox"),
+                    pitch_downward=(-1, "oy"),
+                    pitch_upward=(1, "oy"),
+                    yaw_downward=(-1, "oz"),
+                    yaw_upward=(1, "oz"),
+                    rotate_clockwise=(-1, "oz"),
+                    rotate_counterclockwise=(1, "oz"),
+                    close_gripper=(-1, "grip"),
+                    open_gripper=(1, "grip"),
+                )
+
+                for ml in movement_level:
+                    direction = "_".join(ml.split()[:2])
+                    sign, axis = movement_to_pos[direction]
+                    scale = 1
+                    if "o" in axis:  # for orientation
+                        scale = scale * 1e-3
+                    elif "grip" in axis:  # for gripper
+                        scale = scale
+                    else:  # for xyz
+                        scale = scale / 180 * np.pi
+
+                    if "grip" in axis:
+                        level = round("open" in ml)
+                    else:
+                        level = int(ml.split()[2])
+
+                    position[axis] += sign * scale * level
+                movement_norm = [position[idx] for idx in ["x", "y", "z", "ox", "oy", "oz", "grip"]]
+
+        except Exception:
+            movement_norm = [-100] * 7
+
+        return require_unorm, np.array(movement_norm)
+
+    def extract_action_policies(self, text):
+        try:
+            if self.policy_key in text:
+                policy_index = text.index(self.policy_key) + len(self.policy_key)
+                policy = text[policy_index:]
+                remain_text = text[: text.index(self.policy_key)]
+                policies = [o for o in policy.split("\n") if len(o.strip()) != 0]
+                policies = policies[0].strip()
+            else:
+                policies = text.strip()
+                remain_text = ""
+
+            policies_num = []
+            for policy_text in policies.split(";"):
+                policy_token = self.action_tokenizer.tokenizer(policy_text, add_special_tokens=False).input_ids
+                action_policy = self.action_tokenizer.decode_token_ids_to_actions(np.array(policy_token))
+                # The first token is meaningless
+                action_policy = action_policy[1:]
+                action_policy = action_policy[:7]
+                # assert len(action_policy) == 7
+                if len(action_policy) != 7:
+                    action_policy = np.zeros(7)
+                policies_num.append(action_policy.tolist())
+
+        except Exception:
+            policies_num = [[0] * 7]
+            remain_text = text
+
+        return policies_num, remain_text
+
+    def evaluate_single(self, ground_truth, prediction, verbose=False):
+        gt_policies, ground_truth = self.extract_action_policies(ground_truth)
+        pred_policies, prediction = self.extract_action_policies(prediction)
+
+        _, pred_movement = self.extract_movement_plan(prediction)
+        _, gt_movement = self.extract_movement_plan(ground_truth)
+
+        dist, relative_dist, _ = self.compare_movement(label_pos=gt_movement, pred_pos=pred_movement)
+
+        # pred_2d = self.extract_2d_coordinates(prediction)
+        # gt_2d = self.extract_2d_coordinates(ground_truth)
+
+        next_state_score = 0
+
+        acc = self.compare_policy(label_pol=gt_policies, pred_pol=pred_policies)
+
+        return next_state_score, acc, dist, relative_dist, pred_policies, gt_policies
+
+    def evaluate_batch(self, batch_gt, batch_pred, verbose=False):
+        state_acc_ls = []
+        action_acc_ls = []
+        L1_loss_ls = []
+        relative_L1_loss_ls = []
+        pred_policies_ls = []
+        gt_policies_ls = []
+        for i in range(len(batch_gt)):
+            ground_truth = batch_gt[i]
+            prediction = batch_pred[i]
+            next_state_score, action_policy_score, L1_dist, relative_L1_dist, pred_policies, gt_policies = (
+                self.evaluate_single(ground_truth, prediction)
+            )
+            state_acc_ls.append(next_state_score)
+            action_acc_ls.append(action_policy_score)
+            L1_loss_ls.append(L1_dist)
+            relative_L1_loss_ls.append(relative_L1_dist)
+            pred_policies_ls.append(pred_policies)
+            gt_policies_ls.append(gt_policies)
+            if verbose:
+                print(f"Ground Truth:\n\n {ground_truth}")
+                print()
+                print(f"prediction:\n\n {prediction}")
+                print()
+                print(f"Ground Truth Policies:\n\n {gt_policies}")
+                print(f"prediction policies:\n\n {pred_policies}")
+                print("*" * 40)
+
+        return state_acc_ls, action_acc_ls, L1_loss_ls, relative_L1_loss_ls, pred_policies_ls, gt_policies_ls
+
+
+# Module-level singleton used by `modeling_prismatic.py` (requires access to the Llama-2 tokenizer and `prismatic`)
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", model_max_length=2048, padding_side="right")
+action_tokenizer = ActionTokenizer(tokenizer)
+solver = Solver(action_tokenizer)
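
The plain-text branch of `extract_movement_plan` does not use the action tokenizer, so it can be exercised with `action_tokenizer=None`; note, however, that importing this module as written builds the Llama-2 action tokenizer at import time. A small usage sketch with an assumed movement string (the exact phrasing the model emits may differ):

# Usage sketch for the plain-text movement parsing (this branch does not need the action tokenizer).
from solver import Solver

s = Solver(action_tokenizer=None)
require_unnorm, movement = s.extract_movement_plan(
    "MOVEMENT: move forward 10; move left 2; open gripper"  # assumed example phrasing
)
print(require_unnorm)  # False -> parsed from the plain-text form, no un-normalization of tokenized bins needed
print(movement)        # 7-dim delta [x, y, z, ox, oy, oz, grip] accumulated from the listed movements
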