# MIT License
# Copyright (c) 2025 IPEC at Shanghai AI Laboratory
# Permission is hereby granted, free of charge, to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND.
# Based on code licensed under the Apache License, Version 2.0 by Google Inc. and HuggingFace Inc. team (Copyright 2024).
# coding=utf-8
"""PaliGemmamodel configuration"""
import warnings
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from transformers import CONFIG_MAPPING, AutoConfig
logger = logging.get_logger(__name__)
class SpatialVLAConfig(PretrainedConfig):
r"""
    This is the configuration class to store the configuration of a [`SpatialVLAForConditionalGeneration`]. It is used to
    instantiate a SpatialVLA model according to the specified arguments, defining the model architecture. SpatialVLA builds on
    PaliGemma, and instantiating a configuration with the defaults will yield a configuration similar to that of the PaliGemma-2B,
    e.g. [paligemma-hf/paligemma-2b](https://huggingface.co/paligemma-hf/paligemma-2b).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`PaliGemmaVisionConfig`, *optional*):
Custom vision config or dict
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone, typically a `Gemma2Config` (the default).
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 256000):
The image token index to encode the image prompt.
        vocab_size (`int`, *optional*, defaults to 257152):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling the model.
projection_dim (`int`, *optional*, defaults to 2048):
Dimension of the multimodal projection space.
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden layer of the Language model.
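        vision_zoe_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the ZoeDepth depth-estimation backbone. A dict without a `model_type` key is
            treated as a `zoedepth` config.
        action_token_begin_idx (`int`, *optional*):
            Index of the first action token in the vocabulary.
        spatial_token_num (`int`, *optional*, defaults to 259):
            Number of spatial action tokens.
        use_spatial_token (`bool`, *optional*, defaults to `False`):
            Whether to use dedicated spatial token embeddings.
        ego3d_patch_reso (`int`, *optional*, defaults to 4):
            Patch resolution of the ego-centric 3D position encoding.
        n_freqs (`int`, *optional*, defaults to 8):
            Number of frequencies used by the sinusoidal position encoding.
        use_vision_zoe (`bool`, *optional*, defaults to `True`):
            Whether to run the ZoeDepth backbone.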
    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, Gemma2Config
    >>> # assuming this module is importable as `configuration_spatialvla`
    >>> from configuration_spatialvla import SpatialVLAConfig

    >>> # Initializing a SigLIP-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing a Gemma2 text config
    >>> text_config = Gemma2Config()

    >>> # Initializing a SpatialVLA configuration from the vision and text configs
    >>> configuration = SpatialVLAConfig(vision_config=vision_config, text_config=text_config)

    >>> # The companion modeling file's `SpatialVLAForConditionalGeneration` can then be
    >>> # instantiated from this configuration, and exposes it as `model.config`.
    ```"""
model_type = "spatialvla"
sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig, "vision_zoe_config": AutoConfig}
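    # Nested sub-configurations; `AutoConfig` allows each backbone to be any registered model type.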
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=256000,
vocab_size=257152,
projection_dim=2048,
hidden_size=2048,
vision_zoe_config=None,
action_token_begin_idx=None,
spatial_token_num=259,
use_spatial_token=False,
ego3d_patch_reso=4,
n_freqs=8,
use_vision_zoe=True,
# wrap_lora=False,
**kwargs,
):
self._ignore_index = ignore_index
self.image_token_index = image_token_index
self._vocab_size = vocab_size
self.projection_dim = projection_dim
self.hidden_size = hidden_size
self.vision_config = vision_config
self.is_encoder_decoder = False
        if isinstance(self.vision_config, dict):
            vision_config["model_type"] = vision_config.get("model_type", "siglip_vision_model")
            self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
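            # Default SigLIP vision tower matching PaliGemma: 224px images split into 14px patches.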
self.vision_config = CONFIG_MAPPING["siglip_vision_model"](
intermediate_size=4096,
hidden_size=1152,
patch_size=14,
image_size=224,
num_hidden_layers=27,
num_attention_heads=16,
vocab_size=257152,
vision_use_head=False,
)
self.text_config = text_config
        if isinstance(self.text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "gemma2")
            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
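            # Gemma-2B-sized defaults, instantiated as a `gemma2` architecture.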
self.text_config = CONFIG_MAPPING["gemma2"](
hidden_size=2048,
num_hidden_layers=18,
intermediate_size=16384,
num_attention_heads=8,
num_key_value_heads=1,
is_encoder_decoder=False,
vocab_size=vocab_size,
)
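        # Each image contributes (image_size // patch_size)^2 tokens, e.g. (224 // 14)^2 = 256.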
self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2
self.vision_config.projection_dim = projection_dim
        # vision zoe config
        self.vision_zoe_config = vision_zoe_config
        if isinstance(self.vision_zoe_config, dict):
            vision_zoe_config["model_type"] = vision_zoe_config.get("model_type", "zoedepth")
            self.vision_zoe_config = CONFIG_MAPPING[vision_zoe_config["model_type"]](**vision_zoe_config)
        else:
            # BUG: instantiating a default `zoedepth` config here raises a KeyError, so the config is
            # left as passed in (usually `None`) and must be provided explicitly when needed.
            # self.vision_zoe_config = CONFIG_MAPPING["zoedepth"]()
            logger.info(f"🔥 init from default configurations ... {self.vision_zoe_config}")
        # NOTE: additional SpatialVLA-specific attributes
self.action_token_begin_idx = action_token_begin_idx
self.spatial_token_num = spatial_token_num
self.use_spatial_token = use_spatial_token
self.ego3d_patch_reso = ego3d_patch_reso
self.n_freqs = n_freqs
self.use_vision_zoe = use_vision_zoe
# self.wrap_lora = wrap_lora
super().__init__(**kwargs)
@property
def ignore_index(self):
warnings.warn(
"The `ignore_index` attribute is deprecated and will be removed in v4.47.",
FutureWarning,
)
return self._ignore_index
@ignore_index.setter
def ignore_index(self, value):
self._ignore_index = value
def to_dict(self):
output = super().to_dict()
output.pop("_ignore_index", None)
        return output
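
# A minimal usage sketch (not part of the original module), assuming only the
# `transformers` dependency imported above: builds a configuration with the default
# SigLIP/Gemma2 backbones and round-trips it through `to_dict`.
if __name__ == "__main__":
    config = SpatialVLAConfig()
    print(config.model_type)                    # "spatialvla"
    print(config.text_config.num_image_tokens)  # 256 image tokens at 224px with 14px patches
    assert "_ignore_index" not in config.to_dict()  # stripped by the `to_dict` override

    # Initializing from nested dicts, as `from_pretrained` would after reading a config.json.
    config2 = SpatialVLAConfig(
        vision_config={"model_type": "siglip_vision_model", "image_size": 224, "patch_size": 14},
        text_config={"model_type": "gemma2"},
    )
    print(config2.vision_config.image_size)  # 224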