Sombit committed on
Commit cdba976 · verified
1 Parent(s): 36e08c0

Upload TrajectoryVLA

config.json CHANGED
@@ -1,222 +1,37 @@
1
  {
2
  "auto_map": {
3
- "AutoConfig": "prismatic_config.TrajectoryVLAConfig"
4
  },
5
- "cheat": false,
6
- "model_type": "trajectoryvla",
7
- "num_timesteps": 6,
8
- "prismatic_config": {
9
- "_name_or_path": "",
10
- "add_cross_attention": false,
11
- "arch_specifier": "no-align+gelu-mlp",
12
- "architectures": [
13
- "TrajectoryVLA"
14
- ],
15
- "auto_map": {
16
- "AutoModelForVision2Seq": "prismatic_model.TrajectoryVLA"
17
- },
18
- "bad_words_ids": null,
19
- "begin_suppress_tokens": null,
20
- "bos_token_id": null,
21
- "chunk_size_feed_forward": 0,
22
- "cross_attention_hidden_size": null,
23
- "decoder_start_token_id": null,
24
- "diversity_penalty": 0.0,
25
- "do_sample": false,
26
- "early_stopping": false,
27
- "encoder_no_repeat_ngram_size": 0,
28
- "eos_token_id": null,
29
- "exponential_decay_length_penalty": null,
30
- "finetuning_task": null,
31
- "forced_bos_token_id": null,
32
- "forced_eos_token_id": null,
33
- "hf_llm_id": "meta-llama/Llama-2-7b-hf",
34
- "id2label": {
35
- "0": "LABEL_0",
36
- "1": "LABEL_1"
37
- },
38
- "image_resize_strategy": "letterbox",
39
- "image_sizes": [
40
- 224,
41
- 224
42
- ],
43
- "is_decoder": false,
44
- "is_encoder_decoder": false,
45
- "label2id": {
46
- "LABEL_0": 0,
47
- "LABEL_1": 1
48
- },
49
- "length_penalty": 1.0,
50
- "llm_backbone_id": "llama2-7b-pure",
51
- "llm_max_length": 2048,
52
- "max_length": 20,
53
- "min_length": 0,
54
- "model_type": "prismatic",
55
- "no_repeat_ngram_size": 0,
56
- "num_beam_groups": 1,
57
- "num_beams": 1,
58
- "num_return_sequences": 1,
59
- "output_attentions": false,
60
- "output_hidden_states": false,
61
- "output_projector_states": false,
62
- "output_scores": false,
63
- "pad_to_multiple_of": 64,
64
- "pad_token_id": 32000,
65
- "prefix": null,
66
- "problem_type": null,
67
- "pruned_heads": {},
68
- "remove_invalid_values": false,
69
- "repetition_penalty": 1.0,
70
- "return_dict": false,
71
- "return_dict_in_generate": false,
72
- "sep_token_id": null,
73
- "suppress_tokens": null,
74
- "task_specific_params": null,
75
- "temperature": 1.0,
76
- "text_config": {
77
- "_name_or_path": "",
78
- "add_cross_attention": false,
79
- "architectures": null,
80
- "attention_bias": false,
81
- "attention_dropout": 0.0,
82
- "bad_words_ids": null,
83
- "begin_suppress_tokens": null,
84
- "bos_token_id": 1,
85
- "chunk_size_feed_forward": 0,
86
- "cross_attention_hidden_size": null,
87
- "decoder_start_token_id": null,
88
- "diversity_penalty": 0.0,
89
- "do_sample": false,
90
- "early_stopping": false,
91
- "encoder_no_repeat_ngram_size": 0,
92
- "eos_token_id": 2,
93
- "exponential_decay_length_penalty": null,
94
- "finetuning_task": null,
95
- "forced_bos_token_id": null,
96
- "forced_eos_token_id": null,
97
- "hidden_act": "silu",
98
- "hidden_size": 4096,
99
- "id2label": {
100
- "0": "LABEL_0",
101
- "1": "LABEL_1"
102
- },
103
- "initializer_range": 0.02,
104
- "intermediate_size": 11008,
105
- "is_decoder": false,
106
- "is_encoder_decoder": false,
107
- "label2id": {
108
- "LABEL_0": 0,
109
- "LABEL_1": 1
110
- },
111
- "length_penalty": 1.0,
112
- "max_length": 20,
113
- "max_position_embeddings": 2048,
114
- "min_length": 0,
115
- "mlp_bias": false,
116
- "model_type": "llama",
117
- "no_repeat_ngram_size": 0,
118
- "num_attention_heads": 32,
119
- "num_beam_groups": 1,
120
- "num_beams": 1,
121
- "num_hidden_layers": 32,
122
- "num_key_value_heads": 32,
123
- "num_return_sequences": 1,
124
- "output_attentions": false,
125
- "output_hidden_states": false,
126
- "output_scores": false,
127
- "pad_token_id": null,
128
- "prefix": null,
129
- "pretraining_tp": 1,
130
- "problem_type": null,
131
- "pruned_heads": {},
132
- "remove_invalid_values": false,
133
- "repetition_penalty": 1.0,
134
- "return_dict": true,
135
- "return_dict_in_generate": false,
136
- "rms_norm_eps": 1e-06,
137
- "rope_scaling": null,
138
- "rope_theta": 10000.0,
139
- "sep_token_id": null,
140
- "suppress_tokens": null,
141
- "task_specific_params": null,
142
- "temperature": 1.0,
143
- "tf_legacy_loss": false,
144
- "tie_encoder_decoder": false,
145
- "tie_word_embeddings": false,
146
- "tokenizer_class": null,
147
- "top_k": 50,
148
- "top_p": 1.0,
149
- "torch_dtype": null,
150
- "torchscript": false,
151
- "typical_p": 1.0,
152
- "use_bfloat16": false,
153
- "use_cache": true,
154
- "vocab_size": 32000
155
- },
156
- "tf_legacy_loss": false,
157
- "tie_encoder_decoder": false,
158
- "tie_word_embeddings": true,
159
- "timm_model_ids": [
160
- "vit_large_patch14_reg4_dinov2.lvd142m",
161
- "vit_so400m_patch14_siglip_224"
162
- ],
163
- "timm_override_act_layers": [
164
- null,
165
- null
166
- ],
167
- "tokenizer_class": null,
168
- "top_k": 50,
169
- "top_p": 1.0,
170
- "torch_dtype": "bfloat16",
171
- "torchscript": false,
172
- "typical_p": 1.0,
173
- "use_bfloat16": false,
174
- "use_fused_vision_backbone": true,
175
- "vision_backbone_id": "dinosiglip-vit-so-224px"
176
  },
177
- "rotation_components": 9,
178
- "seperate_control_proj": true,
179
- "timestep_proj_config": {
180
- "num_tokens": 3,
181
- "pos_embed_scale": 8,
182
- "proj_layers": [
183
- 128,
184
- 512,
185
- 1024
186
- ],
187
- "time_delta_sec": 0.1
188
- },
189
- "token_proj_config": {
190
- "control_tokens_layers": [
191
- 4096,
192
- 2048,
193
- 1024
194
- ],
195
- "image_tokens_mode": "vit",
196
- "llm_image_tokens_layers": [],
197
- "vit_tokens_layers": [
198
- 2176,
199
- 1024
200
- ]
201
- },
202
- "token_size": 1024,
203
- "transformer_config": {
204
- "decoder_block_config": {
205
- "dropout": 0.0,
206
- "feature_size": 1024,
207
- "head_dim": 64,
208
- "num_heads": 16
209
- },
210
- "encoder_block_config": {
211
- "feature_size": 1024,
212
- "head_dim": 64,
213
- "num_heads": 16
214
- },
215
- "num_blocks": 2,
216
- "pos_embed_config": {
217
- "embedding_dim": 1024,
218
- "num_embeddings": 300
219
- }
220
- },
221
- "transformers_version": "4.44.2"
222
  }
 
1
  {
2
+ "arch_specifier": "no-align+gelu-mlp",
3
+ "architectures": [
4
+ "TrajectoryVLA"
5
+ ],
6
  "auto_map": {
7
+ "AutoModelForVision2Seq": "prismatic_model.TrajectoryVLA"
8
  },
9
+ "hf_llm_id": "meta-llama/Llama-2-7b-hf",
10
+ "image_resize_strategy": "letterbox",
11
+ "image_sizes": [
12
+ 224,
13
+ 224
14
+ ],
15
+ "llm_backbone_id": "llama2-7b-pure",
16
+ "llm_max_length": 2048,
17
+ "model_type": "prismatic",
18
+ "output_projector_states": false,
19
+ "pad_to_multiple_of": 64,
20
+ "pad_token_id": 32000,
21
+ "return_dict": false,
22
+ "text_config": {
23
+ "model_type": "llama"
24
  },
25
+ "timm_model_ids": [
26
+ "vit_large_patch14_reg4_dinov2.lvd142m",
27
+ "vit_so400m_patch14_siglip_224"
28
+ ],
29
+ "timm_override_act_layers": [
30
+ null,
31
+ null
32
+ ],
33
+ "torch_dtype": "bfloat16",
34
+ "transformers_version": "4.44.2",
35
+ "use_fused_vision_backbone": true,
36
+ "vision_backbone_id": "dinosiglip-vit-so-224px"
37
  }
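The trimmed config keeps only the fields the custom code reads and registers the model class through `auto_map` ("AutoModelForVision2Seq" -> "prismatic_model.TrajectoryVLA"), so it is loaded via the Transformers auto classes with `trust_remote_code=True`. A minimal loading sketch; the repo id below is a placeholder (not given in this commit) and assumes `prismatic_config.py`/`prismatic_model.py` ship alongside `config.json`:

import torch
from transformers import AutoModelForVision2Seq

# Placeholder repo id -- substitute the actual namespace/name of this repository.
model = AutoModelForVision2Seq.from_pretrained(
    "your-namespace/TrajectoryVLA",
    torch_dtype=torch.bfloat16,   # matches "torch_dtype": "bfloat16" in config.json
    trust_remote_code=True,       # required because auto_map points at prismatic_model.TrajectoryVLA
)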
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 32000,
6
+ "transformers_version": "4.44.2"
7
+ }
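These ids line up with the Llama-2 tokenizer used by the model code below (BOS = 1, EOS = 2) plus the extra `<PAD>` token the model adds at id 32000. A quick sanity check, reusing the placeholder repo id from the sketch above:

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("your-namespace/TrajectoryVLA")  # placeholder repo id
print(gen_cfg.bos_token_id, gen_cfg.eos_token_id, gen_cfg.pad_token_id)     # expected: 1 2 32000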
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cab95ea8a69faf885ec29dce3dba829617f86bf9fc8fdd730dbf28804ad7bf1
3
+ size 6948963952
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbb646e9b5155db78dfeb12260d2e3171f0ae53bed32d9a5a7c488e08c3372ee
3
+ size 6971232352
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf01b558f01cfa15b7d7112150d134d1b633d6ab56db901759885f59503371ff
3
+ size 1266349562
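The three entries above are Git LFS pointer files: the `oid sha256:...` and `size` lines describe the real shard contents fetched by LFS, not the weights themselves. A short sketch for verifying a downloaded shard against its pointer (local file name assumed to match the pointer's):

import hashlib

def sha256_of(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

# Should print the oid from the pointer, e.g. cf01b558... for the third shard.
print(sha256_of("model-00003-of-00003.safetensors"))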
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
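Although the index diff is not rendered, `model.safetensors.index.json` follows the standard Transformers sharded-checkpoint layout: a `metadata.total_size` field plus a `weight_map` from parameter names to the three shards above. A small inspection sketch (placeholder repo id again):

import json
from huggingface_hub import hf_hub_download

index_path = hf_hub_download("your-namespace/TrajectoryVLA", "model.safetensors.index.json")
with open(index_path) as f:
    index = json.load(f)

print(index["metadata"]["total_size"])            # total bytes across the three shards
print(sorted(set(index["weight_map"].values())))  # which shard file each tensor lives in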
 
prismatic_model.py ADDED
@@ -0,0 +1,1129 @@
1
+ """
2
+ prismatic_model.py
3
+
4
+ Core HuggingFace-style PrismaticPreTrainedModel and PrismaticForConditionalGeneration class definitions, inheriting
5
+ from the default `transformers.PretrainedModel`. Meant to be standalone and self-contained, but exactly replicate the
6
+ logic in `prismatic.models.vlms.prismatic.py`.
7
+
8
+ Note =>> for the time being, not adding the custom HF "docstring" formatting.
9
+
10
+ References [LLaVa, IDEFICS-2]:
11
+ => https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/modeling_llava.py
12
+ => https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
13
+ """
14
+
15
+ import logging
16
+ from dataclasses import dataclass
17
+ from functools import partial
18
+ from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union
19
+ from functools import cached_property
20
+ # from barrel.components.nn.layers.nerf_pos_embed import NeRFPositionalEmbedding
21
+
22
+ import numpy as np
23
+ import timm
24
+ import tokenizers
25
+ import torch
26
+ import torch.nn as nn
27
+ import transformers
28
+ from timm.models.vision_transformer import LayerScale
29
+ from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
30
+ from transformers.modeling_outputs import ModelOutput
31
+ import collections
32
+ import math
33
+ from barrel.pipes.vlams.extern.prismatic_config import OpenVLAConfig, PrismaticConfig , TrajectoryVLAConfig, WaypointTokenizer
34
+ # from barrel.pipes.vlams.models.control.token_proj import TokenProjector
35
+ from barrel.pipes.vlams.extern.datatypes import *
36
+ from barrel.pipes.vlams.extern.detr import *
37
+ from IPython import embed
38
+ import os
39
+ from PIL import Image
40
+ from pathlib import Path
41
+ from torch.amp.autocast_mode import autocast # Corrected import for latest PyTorch
42
+ from scipy.spatial.transform import Rotation as R
43
+ ht_token_path = Path(".hf_token")
44
+ HF_TOKEN = ht_token_path.read_text().strip() if ht_token_path.exists() else os.environ.get("HF_TOKEN", "")  # read .hf_token if present, else fall back to the env var (original else-branch referenced an undefined name)
45
+
46
+ # Get Logger
47
+ logger = logging.getLogger(__name__)
48
+ torch.backends.cudnn.benchmark = False
49
+ torch.backends.cudnn.deterministic = True
50
+
51
+ # === PyTorch/HuggingFace Default IGNORE_INDEX (for CrossEntropyLoss labels)
52
+ IGNORE_INDEX = -100
53
+
54
+
55
+ # === Utility Functions for Monkey-Patching ===
56
+ def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
57
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
58
+ result = fn(*args, **kwargs)
59
+ return result[0] if isinstance(result, tuple) else result
60
+
61
+ return wrapper
62
+
63
+
64
+ # HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale.
65
+ # =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109
66
+ # =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960
67
+ def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor:
68
+ return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor
69
+
70
+
71
+ def ls_apply_patch(ls_module: LayerScale):
72
+ ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone())
73
+ ls_module.forward = _ls_new_forward.__get__(ls_module, LayerScale)
74
+ del ls_module.gamma
75
+
76
+
77
+ # === Prismatic Vision Backbone (nn.Module) Definitions (w/ Fused Backbone Support) ===
78
+ class PrismaticVisionBackbone(nn.Module):
79
+ def __init__(
80
+ self,
81
+ use_fused_vision_backbone: bool,
82
+ image_sizes: List[int],
83
+ timm_model_ids: List[str],
84
+ timm_override_act_layers: List[Optional[str]],
85
+ ) -> None:
86
+ super().__init__()
87
+ self.use_fused_vision_backbone = use_fused_vision_backbone
88
+
89
+ # [Contract] Validate number of (fused) vision backbones, create "alpha" featurizer and Instantiate
90
+ # =>> Note :: Monkey-Patch the `forward()` function of the backbone to ensure FSDP-compatibility
91
+ # Hardcodes `get_intermediate_layers` to return the **SECOND-TO-LAST** layer patches!
92
+ assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!"
93
+
94
+ self.dino_featurizer = timm.create_model(
95
+ timm_model_ids[0],
96
+ pretrained=True,
97
+ num_classes=0,
98
+ img_size=image_sizes[0],
99
+ act_layer=timm_override_act_layers[0],
100
+ )
101
+ self.dino_featurizer.eval()
102
+
103
+ self.embed_dim = self.dino_featurizer.embed_dim
104
+
105
+ # If `use_fused_vision_backbone` =>> create "beta" featurizer
106
+ # if self.use_fused_vision_backbone:
107
+ self.siglip_featurizer = timm.create_model(
108
+ timm_model_ids[1],
109
+ pretrained=True,
110
+ num_classes=0,
111
+ img_size=image_sizes[1],
112
+ act_layer=timm_override_act_layers[1],)
113
+
114
+ self.siglip_featurizer.eval()
115
+
116
+ self.dino_featurizer.forward = partial(
117
+ self.dino_featurizer.forward_intermediates,
118
+ indices=[len(self.dino_featurizer.blocks) - 2],
119
+ return_prefix_tokens=False,
120
+ norm=False,
121
+ stop_early=True,
122
+ output_fmt='NLC',
123
+ intermediates_only=True,
124
+ )
125
+ self.siglip_featurizer.forward = partial(
126
+ self.siglip_featurizer.forward_intermediates,
127
+ indices=[len(self.siglip_featurizer.blocks) - 2],
128
+ return_prefix_tokens=False,
129
+ norm=False,
130
+ stop_early=True,
131
+ output_fmt='NLC',
132
+ intermediates_only=True,
133
+ )
134
+ self.embed_dim += self.siglip_featurizer.embed_dim
135
+
136
+ def forward(self, pixel_values) -> torch.Tensor:
137
+ """Run image (`pixel_values`) through featurizer; if channel-stacked, then dispatch and sequence stack."""
138
+ if not self.use_fused_vision_backbone:
139
+ return self.featurizer(pixel_values)  # non-fused path: assumes a single `featurizer` attribute, which this fused-only class never constructs
140
+
141
+ # Split `pixel_values :: [bsz, 2 * 3, resolution, resolution]` =>> featurize =>> channel stack
142
+ # img, img_fused = torch.split(pixel_values, [3, 3], dim=1)
143
+ img = pixel_values['dino']
144
+ img_fused = pixel_values['siglip']
145
+ patches, patches_fused = self.dino_featurizer(img)[0], self.siglip_featurizer(img_fused)[0]
146
+
147
+ return torch.cat([patches, patches_fused], dim=2)
148
+
149
+
150
+
151
+ class PrismaticProjector(nn.Module):
152
+ def __init__(self, use_fused_vision_backbone, vision_dim: int, llm_dim: int) -> None:
153
+ super().__init__()
154
+ self.initial_projection_dim = vision_dim * 4
155
+ self.projector = torch.nn.Sequential(
156
+ torch.nn.Linear(vision_dim, self.initial_projection_dim, bias=True),
157
+ torch.nn.GELU(),
158
+ torch.nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
159
+ torch.nn.GELU(),
160
+ torch.nn.Linear(llm_dim, llm_dim, bias=True),
161
+ )
162
+
163
+ def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
164
+ return self.projector(fused_img_patches)
165
+
166
+ # === Main HF Class Definitions ===
167
+ @dataclass
168
+ class PrismaticCausalLMOutputWithPast(ModelOutput):
169
+ """Base class for Prismatic causal (visually-conditioned) language model outputs; also exposes visual features."""
170
+
171
+ loss: Optional[torch.FloatTensor] = None
172
+ logits: torch.FloatTensor = None
173
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
174
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
175
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
176
+
177
+ # Additions for VLMs
178
+ projector_features: Optional[torch.FloatTensor] = None
179
+
180
+
181
+ class PrismaticPreTrainedModel(PreTrainedModel):
182
+ config_class: PrismaticConfig
183
+ base_model_prefix: str = "model"
184
+ supports_gradient_checkpointing: bool = True
185
+
186
+ _no_split_modules: ClassVar[List[str]] = ["PrismaticProjector"]
187
+ _skip_keys_device_placement: str = "past_key_values"
188
+ _supports_flash_attn_2: bool = True
189
+
190
+ def _init_weights(self, module: nn.Module) -> None:
191
+ # Important :: this HF ported version is *not* meant for training from scratch; only inference and fine-tuning!
192
+ # => As such, this init_weights code is not correct; if training VLMs from scratch, use the main codebase at
193
+ # https://github.com/TRI-ML/prismatic-vlms
194
+ std = (
195
+ self.config.initializer_range
196
+ if hasattr(self.config, "initializer_range")
197
+ else self.config.text_config.initializer_range
198
+ )
199
+
200
+ if hasattr(module, "class_embedding"):
201
+ module.class_embedding.data.normal_(mean=0.0, std=std)
202
+
203
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
204
+ module.weight.data.normal_(mean=0.0, std=std)
205
+ if module.bias is not None:
206
+ module.bias.data.zero_()
207
+ elif isinstance(module, nn.Embedding):
208
+ module.weight.data.normal_(mean=0.0, std=std)
209
+ if module.padding_idx is not None:
210
+ module.weight.data[module.padding_idx].zero_()
211
+
212
+ @property
213
+ def _supports_sdpa(self) -> bool:
214
+ """Check LLM supports SDPA Attention"""
215
+ return self.language_model._supports_sdpa
216
+
217
+ class LLMBackbone(nn.Module):
218
+ def __init__(self, config):
219
+ super().__init__()
220
+ self.config = config
221
+ self.llm : AutoModelForCausalLM
222
+ self.tokenizer = self._create_tokenizer()
223
+
224
+ def _create_tokenizer(self) -> transformers.PreTrainedTokenizerBase:
225
+ # Load (Fast) Tokenizer
226
+ print(f"Loading (Fast) Tokenizer via the AutoTokenizer API")
227
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
228
+ self.config['hf_model_id'],
229
+ model_max_length=self.config['llm_max_length'],
230
+ token=HF_TOKEN,
231
+ padding_side="right",
232
+ )
233
+
234
+ # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
235
+ # starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
236
+ # find that adding image patches *after* the BOS leads to much better performance.
237
+ #
238
+ # As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
239
+ # line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
240
+ # override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
241
+ # and VLM `forward()` logic!
242
+ SPECIAL_CASES = {
243
+ # Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
244
+ # =>> We'll prepend BOS to first input (to play nicely with image token insertion logic; verified that
245
+ # this works well with base LLM generation.
246
+ # =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
247
+ "microsoft/phi-2",
248
+ }
249
+ if self.config['hf_model_id'] not in SPECIAL_CASES:
250
+ # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast`) =>> includes Mistral!
251
+ assert (
252
+ tokenizer("Test 123", add_special_tokens=True).input_ids[0] == tokenizer.bos_token_id
253
+ ) and (
254
+ tokenizer("Test 123", add_special_tokens=False).input_ids[0] != tokenizer.bos_token_id
255
+ ), f"Default Tokenizer of type `{type(tokenizer)}` does not automatically prefix inputs with BOS token!\n"
256
+
257
+ return tokenizer
258
+
259
+ class PrismaticForConditionalGeneration(PrismaticPreTrainedModel):
260
+ def __init__(self, config: PrismaticConfig) -> None:
261
+ super().__init__(config)
262
+ # [Validation] Lightweight Validate on `config` Fields + Dependency Versions
263
+ if config.use_fused_vision_backbone is None:
264
+ raise ValueError("Missing config field `use_fused_vision_backbone`")
265
+
266
+ # if timm.__version__ not in {"0.9.10", "0.9.11", "0.9.12", "0.9.16"}:
267
+ # raise NotImplementedError(
268
+ # "TIMM Version must be >= 0.9.10 and < 1.0.0 (breaking); please raise a GitHub Issue "
269
+ # "if you urgently need support for latest TIMM versions."
270
+ # )
271
+
272
+ # if (transformers.__version__ != "4.40.1") or (tokenizers.__version__ != "0.19.1"):
273
+ # logger.warning(
274
+ # f"Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got "
275
+ # f"`transformers=={transformers.__version__}` and `tokenizers=={tokenizers.__version__}`; "
276
+ # f"there might be inference-time regressions due to dependency changes. If in doubt, please"
277
+ # f"use the above versions."
278
+ # )
279
+
280
+ # Instantiate PrismaticVisionBackbone (w/ Potential Fused Backbone)
281
+ self.vision_backbone = PrismaticVisionBackbone(
282
+ config.use_fused_vision_backbone, config.image_sizes, config.timm_model_ids, config.timm_override_act_layers
283
+ )
284
+
285
+ # Create Multimodal Projector
286
+ self.projector = PrismaticProjector(
287
+ config.use_fused_vision_backbone,
288
+ vision_dim=self.vision_backbone.embed_dim,
289
+ llm_dim=config.text_config.hidden_size,
290
+ )
291
+
292
+ # Instantiate LLM Backbone
293
+ self.llm_backbone = LLMBackbone({'hf_model_id': config.hf_llm_id, 'llm_max_length': config.llm_max_length, "pad_token_id" :32000,
294
+ "pad_to_multiple_of" : 64,})
295
+
296
+ # self.llm_backbone.llm = AutoModelForCausalLM.from_config(
297
+ # config.text_config, attn_implementation="flash_attention_2"
298
+ # )
299
+ self.llm_backbone.llm = AutoModelForCausalLM.from_pretrained(
300
+ 'meta-llama/Llama-2-7b-hf',
301
+ token=HF_TOKEN,
302
+ attn_implementation='flash_attention_2',
303
+ # The following parameters are set to prevent `UserWarnings` from HF; we want greedy decoding!
304
+ do_sample=False,
305
+ temperature=1.0,
306
+ use_cache=False,
307
+ top_p=1.0, )
308
+
309
+ self.llm_backbone.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
310
+ self.llm_backbone.llm.config.pad_token_id = self.llm_backbone.tokenizer.pad_token_id
311
+ self.llm_backbone.llm.resize_token_embeddings(len(self.llm_backbone.tokenizer), pad_to_multiple_of=64)
312
+
313
+
314
+
315
+ # self.llm_backbone.llm.config.pad_token_id = self.llm_backbone.tokenizer.pad_token_id
316
+ # self.llm_backbone.llm.resize_token_embeddings(len(self.llm_backbone.tokenizer), pad_to_multiple_of=64)
317
+ # self.resize_token_embeddings(32001,64)
318
+
319
+ self.vocab_size = config.text_config.vocab_size
320
+ self.pad_token_id = config.pad_token_id
321
+
322
+ # HF Boilerplate =>> initializes weights via `_init_weights()` and sets gradient checkpointing
323
+ self.post_init()
324
+
325
+ # === `PreTrainedModel` Boilerplate ===
326
+ def get_input_embeddings(self) -> nn.Module:
327
+ return self.llm_backbone.llm.get_input_embeddings()
328
+
329
+ def set_input_embeddings(self, value: nn.Module) -> None:
330
+ self.llm_backbone.llm.set_input_embeddings(value)
331
+
332
+ def get_output_embeddings(self) -> nn.Module:
333
+ return self.llm_backbone.llm.get_output_embeddings()
334
+
335
+ def set_output_embeddings(self, new_embeddings: nn.Module) -> None:
336
+ self.llm_backbone.llm.set_output_embeddings(new_embeddings)
337
+
338
+ def get_decoder(self) -> nn.Module:
339
+ return self.llm_backbone.llm.get_decoder()
340
+
341
+ def set_decoder(self, decoder: nn.Module) -> None:
342
+ self.llm_backbone.llm.set_decoder(decoder)
343
+
344
+ def tie_weights(self) -> None:
345
+ self.llm_backbone.llm.tie_weights() # Note: `Llama-2` and `Mistral` don't tie weights (no-op)
346
+
347
+ # def resize_token_embeddings(
348
+ # self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
349
+ # ) -> nn.Embedding:
350
+ # updated_embeddings = self.llm_backbone.llm.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
351
+
352
+ # # Update config/instance variables
353
+ # self.config.text_config.vocab_size = updated_embeddings.num_embeddings
354
+ # self.vocab_size = updated_embeddings.num_embeddings
355
+
356
+ # return updated_embeddings
357
+
358
+ # === Core Prismatic VLM `forward()` Logic ===
359
+ def forward(
360
+ self,
361
+ input_ids: Optional[torch.LongTensor] ,
362
+ attention_mask: Optional[torch.Tensor],
363
+ # pixel_values: Optional[torch.FloatTensor] = None,
364
+ pixel_values: Dict[str, torch.Tensor] = {},
365
+ labels: Optional[torch.LongTensor] = None,
366
+ inputs_embeds: Optional[torch.FloatTensor] = None,
367
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
368
+ use_cache: Optional[bool] = None,
369
+ output_attentions: Optional[bool] = None,
370
+ output_hidden_states: Optional[bool] = None,
371
+ output_projector_features: Optional[bool] = None,
372
+ return_dict: Optional[bool] = None,
373
+ **kwargs: Any,
374
+ ) -> Union[Tuple, PrismaticCausalLMOutputWithPast]:
375
+ """Run a forward pass through the VLM, returning a PrismaticCausalLMOutputWithPast instance."""
376
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
377
+ output_hidden_states = (
378
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
379
+ )
380
+ output_projector_features = output_projector_features if output_projector_features is not None else False
381
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
382
+
383
+ # Respect `use_cache` only if not training (even if `gradient_checkpointing` is off)
384
+ use_cache = use_cache and not self.training
385
+
386
+ # Instantiate Placeholder for Projector Features
387
+ projected_patch_embeddings = None
388
+
389
+ # Note :: We only support forward passes with the following cases:
390
+ # => Cached Generation :: (input_ids.shape[1] == 1) and (past_key_values is not None)
391
+ # => Unimodal Forward :: (pixel_values is None)
392
+ # => Multimodal Forward :: (pixel_values is not None) and (input_ids/embeds.shape[0] == pixel_values.shape[0])
393
+
394
+ # === Handle Generation with Cache (`input_ids.shape[1] == 1`) =>> requires `past_keys_values` ===
395
+ if input_ids.shape[1] == 1:
396
+ assert input_ids.shape[0] == 1, "Generation is only currently supported for batch size of 1!"
397
+ assert past_key_values is not None, "You must provide `past_key_values` during cached generation!"
398
+ assert labels is None, "Unexpected key `labels` provided during cached generation!"
399
+
400
+ language_model_output = self.llm_backbone.llm(
401
+ input_ids=input_ids,
402
+ attention_mask=None,
403
+ position_ids=None,
404
+ past_key_values=past_key_values,
405
+ inputs_embeds=None,
406
+ labels=None,
407
+ use_cache=use_cache,
408
+ output_attentions=output_attentions,
409
+ output_hidden_states=output_hidden_states,
410
+ return_dict=return_dict,
411
+ )
412
+
413
+ # === Handle Unimodal Forward ===
414
+ elif pixel_values is None:
415
+ assert (input_ids is not None) and (inputs_embeds is None), "Missing `input_ids` in language-only forward!"
416
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
417
+
418
+ language_model_output = self.llm_backbone.llm(
419
+ input_ids=input_ids,
420
+ attention_mask=attention_mask,
421
+ position_ids=None,
422
+ past_key_values=None,
423
+ inputs_embeds=None,
424
+ labels=labels,
425
+ use_cache=use_cache,
426
+ output_attentions=output_attentions,
427
+ output_hidden_states=output_hidden_states,
428
+ return_dict=return_dict,
429
+ )
430
+
431
+ # === Handle Multimodal Forward ===
432
+
433
+ elif (input_ids.shape[0] == pixel_values['dino'].shape[0]) or (inputs_embeds.shape[0] == pixel_values['dino'].shape[0]):
434
+ assert past_key_values is None, "Unexpected key `past_key_values` provided during language-only forward!"
435
+
436
+ # Visual Feature Extraction
437
+ patch_features = self.vision_backbone(pixel_values)
438
+
439
+ projected_patch_embeddings = self.projector(patch_features) ## matches
440
+ projected_patch_attention_mask = None
441
+ if attention_mask is not None:
442
+ projected_patch_attention_mask = torch.full(
443
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
444
+ fill_value=True,
445
+ dtype=attention_mask.dtype,
446
+ device=attention_mask.device,
447
+ )
448
+
449
+ # Get Input Embeddings (from Language Model Embeddings)
450
+ input_embeddings = self.get_input_embeddings()(input_ids)
451
+
452
+ # Build Multimodal Embeddings & Attention Mask =>> Prismatic defaults to inserting after <BOS> token (1:)
453
+ multimodal_embeddings = torch.cat(
454
+ [input_embeddings[:, :1, :], projected_patch_embeddings, input_embeddings[:, 1:, :]], dim=1
455
+ )
456
+ multimodal_attention_mask = None
457
+ if attention_mask is not None:
458
+ multimodal_attention_mask = torch.cat(
459
+ [attention_mask[:, :1], projected_patch_attention_mask, attention_mask[:, 1:]], dim=1
460
+ )
461
+
462
+ # Build Labels (if specified) =>> Ignore Labels for Patch Embeddings
463
+ multimodal_labels = None
464
+ if labels is not None:
465
+ projected_patch_labels = torch.full(
466
+ (projected_patch_embeddings.shape[0], projected_patch_embeddings.shape[1]),
467
+ fill_value=IGNORE_INDEX,
468
+ dtype=labels.dtype,
469
+ device=labels.device,
470
+ )
471
+ multimodal_labels = torch.cat([labels[:, :1], projected_patch_labels, labels[:, 1:]], dim=1)
472
+
473
+ # Dispatch to Language Model
474
+ language_model_output = self.llm_backbone.llm(
475
+ input_ids=None,
476
+ attention_mask=multimodal_attention_mask,
477
+ position_ids=None,
478
+ past_key_values=None,
479
+ inputs_embeds=multimodal_embeddings,
480
+ labels=multimodal_labels,
481
+ use_cache=use_cache,
482
+ output_attentions=output_attentions,
483
+ output_hidden_states=output_hidden_states,
484
+ return_dict=return_dict,
485
+ )
486
+
487
+ # === Otherwise =>> Assume Invalid! ===
488
+ elif (input_ids.shape[0] != pixel_values.shape[0]) or (inputs_embeds.shape[0] != pixel_values.shape[0]):
489
+ raise ValueError("Non-homogenous batch of (text, image) input -- forward() does not support mixed batches!")
490
+
491
+ else:
492
+ raise ValueError(
493
+ "Invalid PrismaticForConditionalGeneration `forward()` call with provided arguments:\n"
494
+ f"=> `input_ids` = {input_ids is not None}\n"
495
+ f"=> `attention_mask` = {attention_mask is not None}\n"
496
+ f"=> `pixel_values` = {pixel_values is not None}\n"
497
+ f"=> `labels` = {labels is not None}\n"
498
+ f"=> `input_embeds` = {inputs_embeds is not None}\n"
499
+ f"=> `past_key_values` = {past_key_values is not None}\n"
500
+ f"=> `use_cache` = {use_cache}"
501
+ )
502
+
503
+ # Unpack `language_model_output` and return PrismaticCausalLMOutputWithPast (or tuple if not `return_dict`)
504
+ if not return_dict:
505
+ if output_projector_features and (projected_patch_embeddings is not None):
506
+ return *language_model_output, projected_patch_embeddings
507
+
508
+ return language_model_output
509
+
510
+
511
+ return (PrismaticCausalLMOutputWithPast(
512
+ loss=language_model_output.loss,
513
+ logits=language_model_output.logits,
514
+ past_key_values=language_model_output.past_key_values,
515
+ hidden_states=language_model_output.hidden_states,
516
+ attentions=language_model_output.attentions,
517
+ projector_features=projected_patch_embeddings,
518
+ ),patch_features,multimodal_attention_mask)
519
+
520
+ # === GenerationMixin Methods ===
521
+ def prepare_inputs_for_generation(
522
+ self,
523
+ input_ids: Optional[torch.Tensor] = None,
524
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
525
+ inputs_embeds: Optional[torch.FloatTensor] = None,
526
+ pixel_values: Optional[torch.FloatTensor] = None,
527
+ attention_mask: Optional[torch.Tensor] = None,
528
+ **kwargs: str,
529
+ ) -> Dict[str, torch.Tensor]:
530
+ """Borrowed from `LlamaForCausalLM` and simplified for batch size = 1; mirrors original PrismaticVLM logic."""
531
+ if ((input_ids is not None) and (input_ids.shape[0] > 1)) or (
532
+ (inputs_embeds is not None) and (inputs_embeds.shape[0] > 1)
533
+ ):
534
+ raise ValueError("Generation with batch size > 1 is not currently supported!")
535
+
536
+ # Handle `past_key_values` (cache) =>> assume `input_ids` just has unprocessed tokens
537
+ if past_key_values is not None:
538
+ input_ids = input_ids[:, -1:]
539
+
540
+ # If `input_embeds` are passed, we only want to use them in the 1st generation step
541
+ if inputs_embeds is not None and past_key_values is None:
542
+ model_inputs = {"input_embeds": inputs_embeds}
543
+ else:
544
+ model_inputs = {"input_ids": input_ids}
545
+
546
+ # Make sure `pixel_values` are preserved in `model_inputs`
547
+ model_inputs.update(
548
+ {
549
+ "attention_mask": attention_mask,
550
+ "pixel_values": pixel_values,
551
+ "past_key_values": past_key_values,
552
+ "use_cache": kwargs.get("use_cache"),
553
+ }
554
+ )
555
+
556
+ return model_inputs
557
+
558
+ # Defer to Language Model (all handle this differently, with different return types)
559
+ def _reorder_cache(self, *args, **kwargs) -> Any:
560
+ return self.language_model._reorder_cache(*args, **kwargs)
561
+
562
+
563
+ class TokenProjectorConfig(PretrainedConfig):
564
+ vit_tokens_layers: List[int] = [] # If empty, torch.nn.Identity
565
+ llm_image_tokens_layers: List[int] = [] # If empty, torch.nn.Identity
566
+ control_tokens_layers: List[int] = [] # If empty, torch.nn.Identity
567
+
568
+ # image_tokens_mode:
569
+ # vit: use ViT tokens only
570
+ # llm: use LLM tokens only
571
+ # skip: skip connection between projector(ViT) and LLM with addition
572
+ # none: don't feed to TokenProjector
573
+ image_tokens_mode: str
574
+
575
+ def __post_init__(self):
576
+ super().__post_init__()
577
+
578
+ if self.image_tokens_mode == 'vit':
579
+ assert len(self.vit_tokens_layers) > 0 or len(self.control_tokens_layers) > 0
580
+ elif self.image_tokens_mode == 'llm':
581
+ assert len(self.vit_tokens_layers) > 0 or len(self.control_tokens_layers) > 0
582
+ elif self.image_tokens_mode == 'skip':
583
+ assert len(self.vit_tokens_layers) > 0 or len(self.llm_image_tokens_layers) > 0
584
+ elif self.image_tokens_mode == 'none':
585
+ assert len(self.vit_tokens_layers) == 0
586
+ assert len(self.llm_image_tokens_layers) == 0
587
+ else:
588
+ raise NotImplementedError(f"Unknown image tokens mode {self.image_tokens_mode}")
589
+
590
+ class TokenProjector(nn.Module):
591
+ """Project and pack VLM output tokens"""
592
+
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.config = TokenProjectorConfig()
596
+ self.config.vit_tokens_layers = config['vit_tokens_layers']
597
+ self.config.llm_image_tokens_layers = config['llm_image_tokens_layers']
598
+ self.config.control_tokens_layers = config['control_tokens_layers']
599
+ self.config.image_tokens_mode = config['image_tokens_mode']
600
+
601
+ self.vit_tokens_proj = self._make_token_proj_module(self.config.vit_tokens_layers)
602
+ self.llm_image_tokens_proj = self._make_token_proj_module(self.config.llm_image_tokens_layers)
603
+ self.control_tokens_proj = self._make_token_proj_module(self.config.control_tokens_layers)
604
+
605
+ def forward(self, inputs: WaypointerInput) -> torch.Tensor:
606
+ """
607
+ Args:
608
+ inputs: Contains VLM outputs
609
+ Returns:
610
+ torch.Tensor of shape [B, num_tokens, token_size] that always contains the control tokens
611
+ and possibly the image tokens (prepended), depending on the configuration
612
+ """
613
+
614
+ vit_tokens = self.vit_tokens_proj(inputs.vit_tokens)
615
+ control_tokens = self.control_tokens_proj(inputs.control_tokens)
616
+ llm_image_tokens = self.llm_image_tokens_proj(inputs.llm_image_tokens)
617
+
618
+ if self.config.image_tokens_mode == 'vit':
619
+ output = torch.cat([vit_tokens, control_tokens], dim=1) # [B, img + control, token_size]
620
+ elif self.config.image_tokens_mode == 'llm':
621
+ output = torch.cat([llm_image_tokens, control_tokens], dim=1) # [B, img + control, token_size]
622
+ elif self.config.image_tokens_mode == 'skip':
623
+ image_tokens = llm_image_tokens + vit_tokens
624
+ output = torch.cat([image_tokens, control_tokens], dim=1) # [B, img + control, token_size]
625
+ elif self.config.image_tokens_mode == 'none':
626
+ output = control_tokens
627
+ else:
628
+ raise NotImplementedError(f"Unknown image tokens mode {self.config.image_tokens_mode}")
629
+
630
+ return output
631
+
632
+ def _make_token_proj_module(self, layer_sizes: List[int]) -> torch.nn.Module:
633
+ if len(layer_sizes) == 0:
634
+ return torch.nn.Identity()
635
+
636
+ assert len(layer_sizes) > 1, "Need to provide input and output layer sizes at least"
637
+
638
+ module = torch.nn.Sequential(
639
+ *[
640
+ torch.nn.Sequential(
641
+ collections.OrderedDict(
642
+ {
643
+ 'linear': torch.nn.Linear(layer_in_features, layer_out_features),
644
+ 'act': torch.nn.ReLU(),
645
+ 'norm': torch.nn.LayerNorm(layer_out_features),
646
+ }
647
+ )
648
+ )
649
+ for layer_in_features, layer_out_features in zip(layer_sizes[:-1], layer_sizes[1:])
650
+ ]
651
+ )
652
+ return module
653
+
654
+ class NeRFPositionalEmbedding(torch.nn.Module):
655
+ def __init__(self, proj_scale: int):
656
+ """
657
+ Args:
658
+ proj_scale: Dimension size, same as L parameter in the NeRF paper
659
+ """
660
+ super().__init__()
661
+ self.proj_scale = proj_scale
662
+
663
+ freq = 2 ** torch.arange(self.proj_scale, dtype=torch.float32) * math.pi # size: [L]
664
+
665
+ self.register_buffer('freq', freq)
666
+
667
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
668
+ """
669
+ Maps values from R^N to a higher dimensional space R^(N2L)
670
+ Args:
671
+ inputs: torch.Tensor of shape [B, ..., N]; input values to be transformed
672
+ Returns: torch.Tensor of shape [B, ..., N2L]; encoded input values
673
+ """
674
+
675
+ spectrum = self.freq.view(*[1] * inputs.ndim, -1) * inputs.unsqueeze(-1) # [B, ..., N, L]
676
+ encoding = torch.stack([torch.sin(spectrum), torch.cos(spectrum)], dim=-2) # [B, ..., N, 2, L]
677
+ encoding = encoding.view(inputs.shape[-1], -1) # [B, ..., N2L]
678
+
679
+ return encoding
680
+
681
+ class TimestepProjModuleConfig(PretrainedConfig):
682
+ pos_embed_scale: int # How much to scale timestep values when doing position embedding
683
+ proj_layers: List[int]
684
+ time_delta_sec: float = 0.25 # Time delta between two predictions
685
+ num_tokens: int = 3 # Number of tokens per timestep; Currently 3 - translation, rotation, gripper
686
+
687
+
688
+ class TimestepProjModule(nn.Module):
689
+
690
+ def __init__(self, config: TimestepProjModuleConfig, num_timesteps: int, token_size: int):
691
+ """
692
+ Args:
693
+ num_timesteps: Number of control timesteps
694
+ token_size: Single token size
695
+ """
696
+ super().__init__()
697
+ self.config = TimestepProjModuleConfig()
698
+ self.config.pos_embed_scale = config['pos_embed_scale']
699
+ self.config.proj_layers = config['proj_layers']
700
+ self.config.time_delta_sec = config['time_delta_sec']
701
+ self.config.num_tokens = config['num_tokens']
702
+
703
+ self.num_timesteps = num_timesteps
704
+ self.token_size = token_size
705
+
706
+ input_size = 2 * self.config.pos_embed_scale
707
+
708
+ self.pos_embed = NeRFPositionalEmbedding(self.config.pos_embed_scale)
709
+
710
+ # We output one token for translation, one for rotation and one for gripper state
711
+ feature_size = self.config.num_tokens * self.token_size
712
+
713
+ # Make MLP projection
714
+
715
+ self.timestep_proj = self._make_timestep_proj(in_features=int(input_size), out_features=int(feature_size))
716
+
717
+ def _make_timestep_proj(self, in_features: int, out_features: int) -> torch.nn.Module:
718
+ layer_sizes = [in_features] + list(self.config.proj_layers) + [out_features]
719
+ module = torch.nn.Sequential(
720
+ *[
721
+ torch.nn.Sequential(
722
+ collections.OrderedDict(
723
+ {
724
+ 'linear': torch.nn.Linear(layer_in_features, layer_out_features),
725
+ 'act': torch.nn.ReLU(),
726
+ 'norm': torch.nn.LayerNorm(layer_out_features),
727
+ }
728
+ )
729
+ )
730
+ for layer_in_features, layer_out_features in zip(layer_sizes[:-1], layer_sizes[1:])
731
+ ]
732
+ )
733
+ return module
734
+
735
+ def forward(self) -> torch.Tensor:
736
+ """
737
+ Returns:
738
+ torch.Tensor of sequence of timestep tokens, shape [1, num_timesteps * num_tokens, token_size]
739
+ """
740
+ device = self.timestep_proj[0].linear.weight.device # type: ignore[index]
741
+
742
+ # Position encode timesteps
743
+ time_deltas_norm = self.time_deltas_norm.view(1, self.num_timesteps) # [1, num_timesteps]
744
+ time_deltas_norm = time_deltas_norm.to(device=device)
745
+
746
+ # Embed timesteps to intermediate dimension
747
+ timesteps_embed = self.pos_embed(time_deltas_norm) # [1, num_timesteps * 2 * L]
748
+ timesteps_embed = timesteps_embed.view(self.num_timesteps, -1) # [num_timesteps, 2 * L]
749
+
750
+ # Project the timesteps via MLP to tokens
751
+ timesteps_tokens = self.timestep_proj(timesteps_embed) # [num_timesteps, token_size * 3]
752
+
753
+ # Reshape MLP outputs into tokens
754
+ timesteps_tokens = timesteps_tokens.view( # [1, num_timesteps * 3, token_size]
755
+ 1, self.num_timesteps * self.config.num_tokens, self.token_size
756
+ )
757
+
758
+ return timesteps_tokens
759
+
760
+ @cached_property
761
+ def time_deltas_sec(self) -> torch.Tensor:
762
+ return torch.arange(0, self.num_timesteps, 1, dtype=torch.float32) * self.config.time_delta_sec
763
+
764
+ @cached_property
765
+ def time_deltas_norm(self) -> torch.Tensor:
766
+ # Normalize time deltas between [0, 1]. We are saving [-1, 0] interval for possible past supervision
767
+ if self.time_deltas_sec.shape[0] == 1:
768
+ # Can't divide by 0
769
+ time_deltas_norm = self.time_deltas_sec
770
+ else:
771
+ time_deltas_norm = self.time_deltas_sec / self.time_deltas_sec.max() # [num_timesteps]
772
+ return time_deltas_norm.detach()
773
+
774
+
775
+ # class Waypointer(nn.Module):
776
+
777
+ class TrajectoryVLA(PrismaticForConditionalGeneration):
778
+
779
+
780
+ config_class: PretrainedConfig = TrajectoryVLAConfig
781
+
782
+ def __init__(self, config: TrajectoryVLAConfig) -> None:
783
+ super().__init__(config.prismatic_config)
784
+ self.control_tokenizer = WaypointTokenizer(self.llm_backbone.tokenizer)
785
+ self.timestep_proj = TimestepProjModule(
786
+ config.timestep_proj_config,
787
+ num_timesteps=config.num_timesteps,
788
+ token_size=config.token_size, )
789
+ self.num_timesteps = config.num_timesteps
790
+ self.token_proj = TokenProjector(config.token_proj_config)
791
+ self.transformer = DETR(config.transformer_config)
792
+ self.token_size = config.token_size
793
+ self.rotation_components = config.rotation_components
794
+ # if self.config.separate_control_proj:
795
+ # Project translation, rotation and gripper separately. Each timestep is projected separately
796
+ self.translation_proj = torch.nn.Sequential(
797
+ torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2),
798
+ torch.nn.ReLU(),
799
+ torch.nn.Linear(in_features=config.token_size // 2, out_features=3),
800
+ )
801
+ self.rotation_proj = torch.nn.Sequential(
802
+ torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2),
803
+ torch.nn.ReLU(),
804
+ torch.nn.Linear(
805
+ in_features=config.token_size // 2, out_features=config.rotation_components
806
+ ),
807
+ )
808
+
809
+ self.gripper_proj = torch.nn.Sequential(
810
+ torch.nn.Linear(in_features=config.token_size, out_features=config.token_size // 2),
811
+ torch.nn.ReLU(),
812
+ torch.nn.Linear(in_features=config.token_size // 2, out_features=1),
813
+ )
814
+
815
+ def _pack_waypointer_input(self, input_ids: torch.Tensor, vlm_output: PrismaticCausalLMOutputWithPast,vit_tokens,fused_attention_mask) -> WaypointerInput:
816
+ # Get the LLM output
817
+ # assert vlm_output.llm_output.hidden_states is not None
818
+ projected_tokens = vlm_output.hidden_states[-1]
819
+
820
+ control_tokens = self._extract_control_tokens(input_ids, projected_tokens) # type: ignore
821
+
822
+ num_image_tokens = vit_tokens.shape[1] # type: ignore[union-attr]
823
+ # TODO: This assumes a specific position of image tokens in the sequence. Make general
824
+ llm_image_tokens = projected_tokens[..., 1 : 1 + num_image_tokens, :]
825
+
826
+
827
+ return WaypointerInput(
828
+ vit_tokens=vit_tokens,
829
+ llm_image_tokens=llm_image_tokens,
830
+ control_tokens=control_tokens,
831
+ llm_tokens=projected_tokens,
832
+ attn_mask=fused_attention_mask,
833
+ )
834
+
835
+ def predict_tracks(self,inputs):
836
+
837
+ vlm_output,vit_tokens,fused_attention_mask = super().forward(**inputs,output_hidden_states=True,output_attentions=True,return_dict=True)
838
+ waypointer_input = self._pack_waypointer_input(inputs['input_ids'], vlm_output,vit_tokens,fused_attention_mask)
839
+ waypoint_output = self._waypointer_forward(waypointer_input)
840
+ translation, rotation, gripper = torch.split(
841
+ waypoint_output, [3, self.rotation_components, 1], dim=-1 )
842
+ translation, rotation, gripper = self.process_output(translation, rotation, gripper)
843
+ return translation, rotation, gripper
844
+ def process_output(self,translation,rotation,gripper):
845
+ ## convert rotation from matrix to euler angles
846
+ euler_angles = []
847
+ for matrix in rotation[0]:
848
+ # Convert each rotation matrix to a Rotation object
849
+ rotation_obj = R.from_matrix(matrix.view(3, 3).detach().cpu().float().numpy().squeeze())
850
+ # Convert to Euler angles in radians with chosen convention, e.g., 'xyz'
851
+ euler_angle = rotation_obj.as_euler('xyz', degrees=False)
852
+ euler_angles.append(euler_angle)
853
+
854
+ translation = translation.detach().cpu().float().numpy().squeeze()
855
+ ## sigmoid and clip from 0-1
856
+ gripper = np.round(torch.sigmoid(gripper).detach().cpu().float().numpy().squeeze())
857
+ return translation,euler_angles,gripper
858
+
859
+ def _extract_control_tokens(self, input_ids: torch.Tensor, output_tokens: torch.Tensor) -> torch.Tensor:
860
+ """
861
+ Extract the action tokens from the LLM output sequence. Assumes the following order
862
+ [image_tokens, language_tokens, action_tokens, padding]
863
+
864
+ Args:
865
+ input_ids: IDs of the tokens in text input sequence; shape [B, S]
866
+ output_tokens: Token sequence output from LLM; shape [B, L, token_size]. Note the length is
867
+ different from input_ids as it also contains image tokens
868
+ Returns:
869
+ torch.Tensor of shape [B, 7, token_size] containing only action tokens
870
+ """
871
+
872
+ assert input_ids.ndim == 2
873
+ assert output_tokens.ndim == 3
874
+ batch, in_seq_len, out_seq_len = *input_ids.shape, output_tokens.shape[1]
875
+
876
+ device = input_ids.device
877
+
878
+ num_control_tokens = self.control_tokenizer.num_control_tokens # type: ignore[attr-defined]
879
+
880
+ control_token_ids = torch.from_numpy( # type: ignore[attr-defined]
881
+ self.control_tokenizer.control_token_ids # type: ignore[attr-defined]
882
+ )
883
+ control_token_ids = control_token_ids.to(dtype=input_ids.dtype, device=input_ids.device)
884
+ is_control_token = torch.any( # shape: [B, S]
885
+ input_ids.unsqueeze(-1) == control_token_ids.view(1, 1, -1),
886
+ dim=-1,
887
+ )
888
+ if not torch.all(mask := is_control_token.sum(dim=-1) == num_control_tokens):
889
+ raise RuntimeError(
890
+ f"Can't properly detect control tokens with ids {control_token_ids} of len="
891
+ f"{len(control_token_ids)} in input_ids {input_ids}. Rows mask: {mask}"
892
+ )
893
+
894
+ # Pad is_control_tokens mask to the LLM output sequence size
895
+ tokens_mask = torch.cat( # shape: [B, L]
896
+ [
897
+ torch.zeros(batch, out_seq_len - in_seq_len, dtype=torch.bool, device=device),
898
+ is_control_token.to(torch.bool),
899
+ ],
900
+ dim=1,
901
+ )
902
+
903
+ control_tokens = output_tokens[tokens_mask] # shape: 1D tensor
904
+ control_tokens = control_tokens.view( # [B, num_control_tokens, token_size]
905
+ batch, num_control_tokens, output_tokens.shape[-1]
906
+ )
907
+
908
+ return control_tokens
909
+
910
+ def _waypointer_forward(self, inputs:WaypointerInput):
911
+
912
+ timesteps_tokens = self.timestep_proj() # [1, num_timesteps * 3, token_size]
913
+
914
+ # Project and pack LLM tokens
915
+ llm_tokens = self.token_proj(inputs) # [B, num_tokens, token_size]
916
+
917
+ # TODO: Pass inputs.attn_mask if you start using the LLM tokens
918
+ output_tokens = self.transformer( # [B, num_timesteps * 3, token_size]
919
+ feature_tokens=llm_tokens, query_tokens=timesteps_tokens, attn_mask=None
920
+ )
921
+
922
+ output_tokens = output_tokens.view( # [B, num_timesteps, 3 * token_size]
923
+ -1, self.num_timesteps, 3 * self.token_size
924
+ )
925
+
926
+ # if self.config.separate_control_proj:
927
+ # [B, num_timesteps, token_size] each
928
+ translation_tokens, rotation_tokens, gripper_tokens = torch.split(
929
+ output_tokens, [self.token_size] * 3, dim=-1
930
+ )
931
+
932
+ translation = self.translation_proj(translation_tokens) # [B, num_timesteps, 3]
933
+ rotation = self.rotation_proj(rotation_tokens) # [B, num_timesteps, rotation_components]
934
+ gripper = self.gripper_proj(gripper_tokens) # [B, num_timesteps, 1]
935
+
936
+ output = torch.cat( # [B, num_timesteps, control_components]
937
+ [translation, rotation, gripper], dim=-1
938
+ )
939
+
940
+ return output
941
+ # def predict_waypoints(self,input_ids: Optional[torch.LongTensor] = None, **kwargs: str) -> np.ndarray:
942
+ # vlm_output = super().forward(
943
+ # inputs=input_ids,
944
+ # use_cache=use_cache,
945
+ # output_attentions=output_attentions,
946
+ # output_hidden_states=True,
947
+ # return_dict=return_dict,
948
+ # )
949
+
950
+
951
+ @staticmethod
952
+ def _check_unnorm_key(norm_stats: Dict[str, Dict[str, Any]], unnorm_key: Optional[str]) -> str:
953
+ if unnorm_key is None and len(norm_stats) != 1:
954
+ raise ValueError(
955
+ f"Your model was trained on more than one dataset. "
956
+ f"Please pass a `unnorm_key` from the following options to choose the statistics used for "
957
+ f"de-normalizing actions: {norm_stats.keys()}"
958
+ )
959
+
960
+ # If None, grab the (singular) dataset in `norm_stats` to use as `unnorm_key`
961
+ unnorm_key = unnorm_key if unnorm_key is not None else next(iter(norm_stats.keys()))
962
+ if unnorm_key not in norm_stats:
963
+ raise ValueError(
964
+ f"The `unnorm_key` you chose ({unnorm_key = }) is not in the available statistics. "
965
+ f"Please choose from: {norm_stats.keys()}"
966
+ )
967
+
968
+ return unnorm_key
969
+
970
+ def get_action_dim(self, unnorm_key: Optional[str] = None) -> int:
971
+ """Get the dimensionality of the policy's action space."""
972
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
973
+ return len(self.norm_stats[unnorm_key]["action"]["q01"])
974
+
975
+ def get_action_stats(self, unnorm_key: Optional[str] = None) -> Dict[str, Any]:
976
+ """Get all the logged statistics for the given dataset."""
977
+ unnorm_key = self._check_unnorm_key(self.norm_stats, unnorm_key)
978
+ return self.norm_stats[unnorm_key]["action"]
979
+
980
+ def remove_waypointer_prefix(ckpt):
981
+ new_state_dict = {}
982
+ for key, value in ckpt.items():
983
+ # Remove the 'waypointer.' prefix if it exists
984
+ if key.startswith('waypointer.'):
985
+ new_key = key[len('waypointer.'):]
986
+ else:
987
+ new_key = key
988
+ new_state_dict[new_key] = value
989
+ return new_state_dict
990
+
991
+ def image_processor(image):
992
+ image_resolution = (3,224,224)
993
+ image = image.resize(image_resolution[1:], resample=Image.Resampling.LANCZOS)
+ return image
994
+
995
+ def read_pt(pt_path):
996
+ data = torch.load(pt_path)
997
+ return data
998
+
999
+ # model_input = read_pt('/work/nikolay_nikolov/debug/inference/model_input.pt')
1000
+ # vit_output = read_pt('/work/nikolay_nikolov/debug/inference/vit_output.pt')['vit_output']
1001
+ # llm_output = read_pt('/work/nikolay_nikolov/debug/inference/llm_output.pt')['llm_output']
1002
+ # projector_output = read_pt('/work/nikolay_nikolov/debug/inference/projector_output.pt')['projector_output']
1003
+ # transformer_input = read_pt('/work/nikolay_nikolov/debug/inference/transformer_input.pt')
1004
+ # feature_tokens = transformer_input['feature_tokens']
1005
+ # timestep_tokens = transformer_input['timestep_tokens']
1006
+ # # waypointer_input_nikolay = read_pt('/work/nikolay_nikolov/debug/inference/waypointer_input.pt')
1007
+ # transformer_input = read_pt('/work/nikolay_nikolov/debug/inference/transformer_input.pt')
1008
+ # control_target = read_pt('/work/nikolay_nikolov/debug/inference/control_target.pt')
1009
+
1010
+ if __name__ == "__main__":
1011
+
1012
+ prismatic_config_dict = {
1013
+ "vision_backbone_id":"dinosiglip-vit-so-224px",
1014
+ "llm_backbone_id":"llama2-7b-pure",
1015
+ "arch_specifier": "no-align+gelu-mlp", ## TODO: check
1016
+ "use_fused_vision_backbone" :True, ## TODO: check
1017
+ "image_resize_strategy" : "letterbox",
1018
+ "text_config" : None,
1019
+ "llm_max_length" : 2048,
1020
+ "pad_token_id" :32000,
1021
+ "pad_to_multiple_of" : 64,
1022
+ "output_projector_states" : False,
1023
+ "return_dict": False,
1024
+ }
1025
+
1026
+ token_proj_config = {
1027
+ "vit_tokens_layers": [2176, 1024],
1028
+ "control_tokens_layers": [4096, 2048, 1024],
1029
+ "image_tokens_mode": 'vit',
1030
+ 'llm_image_tokens_layers': []
1031
+ }
1032
+ timestep_proj_config = {
1033
+ "pos_embed_scale": 8,
1034
+ "proj_layers": [128,512,1024],
1035
+ "time_delta_sec": 0.1,
1036
+ "num_tokens":3
1037
+ }
1038
+ pos_embed_config = {
1039
+ "num_embeddings": 300,
1040
+ "embedding_dim": 1024
1041
+ }
1042
+ encoder_block_config = {
1043
+ "feature_size": 1024,
1044
+ "head_dim": 64,
1045
+ "num_heads": 16
1046
+ }
1047
+ decoder_block_config = {
1048
+ "feature_size": 1024,
1049
+ "head_dim": 64,
1050
+ "num_heads": 16,
1051
+ "dropout": 0.0
1052
+ }
1053
+ transformer_config = {
1054
+ "pos_embed_config": pos_embed_config,
1055
+ "encoder_block_config": encoder_block_config,
1056
+ "decoder_block_config": decoder_block_config,
1057
+ "num_blocks": 2
1058
+ }
1059
+
1060
+ # transformer_config:
1061
+ # autoclass: barrel.components.nn.layers.detr.DETR
1062
+ # pos_embed_config:
1063
+ # autoclass: barrel.components.nn.layers.positional_encodings.LearnedPosEmbed1D
1064
+ # num_embeddings: 300 # Max number of input tokens
1065
+ # embedding_dim: *token_size # token_size
1066
+ # # num_embeddings: 256 # Number of image tokens
1067
+ # # embedding_dim: 512 # token_size / 2
1068
+ # encoder_block_config:
1069
+ # autoclass: barrel.components.nn.layers.detr.TransformerEncoderBlock
1070
+ # feature_size: *token_size
1071
+ # # head_dim: 128
1072
+ # # num_heads: 8
1073
+ # head_dim: 64
1074
+ # num_heads: 16
1075
+ # decoder_block_config:
1076
+ # autoclass: barrel.components.nn.layers.detr.TransformerDecoderBlock
1077
+ # feature_size: *token_size
1078
+ # # head_dim: 128
1079
+ # # num_heads: 8
1080
+ # head_dim: 64
1081
+ # num_heads: 16
1082
+
1083
+ TrajectoryVlaConfig_config = {
1084
+ "prismatic_config":prismatic_config_dict,
1085
+ "token_size": 1024,
1086
+ "cheat": False,
1087
+ "num_timesteps": 6,
1088
+ "rotation_components": 9,
1089
+ "seperate_control_proj": True,
1090
+ "timestep_proj_config": timestep_proj_config,
1091
+ "token_proj_config": token_proj_config,
1092
+ "transformer_config": transformer_config,
1093
+ "num_timestep_tokens": 3,
1094
+ }
1095
+
1096
+ # ckpt_path = '/work/nikolay_nikolov/debug/inference/model.ckpt'
1097
+ # ckpt_params = torch.load(ckpt_path, map_location='cpu', mmap= True)
1098
+ # ckpt_params = remove_waypointer_prefix(ckpt_params)
1099
+
1100
+ ## Testing for prismatic
1101
+ model_config = TrajectoryVLAConfig( **TrajectoryVlaConfig_config)
1102
+ # model.load_state_dict(ckpt_params, strict=True)
1103
+
1104
+ model = TrajectoryVLA(model_config)
1105
+ model = model.to(dtype=torch.bfloat16)
1106
+ model = model.to('cuda')
1107
+ model.eval()
1108
+
1109
+ # with autocast('cuda',dtype=torch.bfloat16):
1110
+ # with torch.no_grad():
1111
+ # output = model.predict_tracks(model_input)
1112
+
1113
+
1114
+ # Get matched keys by finding keys that exist in both the model and checkpoint
1115
+ # TrajectoryVLA.load_state_dict(ckpt_params, strict=False)
1116
+
1117
+ # model_keys = set(TrajectoryVLA.state_dict().keys())
1118
+ # checkpoint_keys = set(ckpt_params.keys())
1119
+ # matched_keys = model_keys.intersection(checkpoint_keys)
1120
+ # print('Matched Keys:')
1121
+ # for key in matched_keys:
1122
+ # print(key)
1123
+ # embed()
1124
+
1125
+ # hf_image_processor.push_to_hub(cfg.output_hf_model_hub_path)
1126
+ # hf_processor.push_to_hub(cfg.output_hf_model_hub_path)
1127
+
1128
+ # import code; code.interact(local=vars())
1129
+
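Based on the signatures above, `predict_tracks` expects a dict with `input_ids`, `attention_mask`, and a `pixel_values` dict holding separate `dino` and `siglip` image tensors; the prompt must already contain the control tokens produced by `WaypointTokenizer` (defined in `prismatic_config.py`, which is not part of this commit). A hedged inference sketch continuing from the `__main__` block; the prompt string and image tensors are placeholders, not the repository's actual preprocessing:

import torch

tokenizer = model.llm_backbone.tokenizer
# Placeholder prompt -- in practice it must end with the WaypointTokenizer control tokens,
# otherwise _extract_control_tokens() raises a RuntimeError.
prompt = "In: What action should the robot take to pick up the cup?\nOut:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
attention_mask = torch.ones_like(input_ids)

pixel_values = {  # letterboxed 224x224 crops for each backbone (random tensors as stand-ins)
    "dino": torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device="cuda"),
    "siglip": torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device="cuda"),
}

with torch.no_grad():
    translation, euler_angles, gripper = model.predict_tracks(
        {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
    )
# translation: [num_timesteps, 3]; euler_angles: xyz angles per timestep; gripper: 0/1 per timestep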