Commit c906038 by Zhiding (parent: 442fb1f)

clean codes
README.md CHANGED

@@ -75,7 +75,7 @@ We provide a [demo inference script](./demo.py) to help you quickly start using
 ### 0. Install the dependencies
 
 ```bash
-pip install transformers==4.37.2
+pip install transformers
 pip install flash-attn
 ```
 **Note**: Latest version of transformers is not compatible with the model.
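
Because the pin on `transformers==4.37.2` is dropped while the compatibility note stays in place, a small runtime check can flag a mismatched install. The sketch below is illustrative only; it assumes `packaging` is available (it ships as a transformers dependency) and uses 4.37.2 as the reference because that is the `transformers_version` recorded in this repo's config.json.

```python
# Minimal sketch: warn if the installed transformers build is newer than the one
# this checkpoint was exported with (4.37.2 per config.json).
from packaging import version
import transformers

EXPORTED_WITH = "4.37.2"  # "transformers_version" from config.json
if version.parse(transformers.__version__) > version.parse(EXPORTED_WITH):
    print(
        f"Warning: transformers {transformers.__version__} is newer than {EXPORTED_WITH}; "
        "the README notes the latest release may be incompatible with this model."
    )
```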
config.json CHANGED

@@ -200,6 +200,7 @@
   "transformers_version": "4.37.2",
   "typical_p": 1.0,
   "use_bfloat16": false,
-  "vision_use_head": false
+  "vision_use_head": false,
+  "_attn_implementation": "flash_attention_2"
  }
 }
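
The new `_attn_implementation` entry pins FlashAttention-2 in the saved config. As a hedged illustration (the repo id below is a hypothetical placeholder, not taken from this commit), the same backend can be requested or overridden at load time:

```python
# Sketch: override the attention backend at load time if flash-attn is unavailable.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "nvidia/Eagle2-placeholder",              # hypothetical repo id
    trust_remote_code=True,                   # the Eagle2 chat classes live in the repo
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager" without flash-attn
)
```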
configuration_eagle_chat.py CHANGED

@@ -9,12 +9,10 @@ import copy
 from transformers import AutoConfig, LlamaConfig
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
-from .configuration_siglip import SiglipVisionConfig
-from .configuration_qwen2 import Qwen2Config
-from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
+from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
 logger = logging.get_logger(__name__)
 
-
 class Eagle2ChatConfig(PretrainedConfig):
     model_type = 'eagle_chat'
     is_composition = True
@@ -36,6 +34,7 @@ class Eagle2ChatConfig(PretrainedConfig):
             mlp_checkpoint=True,
             pre_feature_reduction=False,
             keep_aspect_ratio=False,
+            vocab_size=-1,
             **kwargs):
         super().__init__(**kwargs)
 
@@ -49,8 +48,6 @@ class Eagle2ChatConfig(PretrainedConfig):
 
         if vision_config['model_type'] == 'siglip_vision_model':
             self.vision_config = SiglipVisionConfig(**vision_config)
-        elif vision_config['model_type'].startswith("MOB"):
-            self.vision_config = MultiBackboneChannelConcatenationVisionModelConfig(**vision_config)
         else:
             raise ValueError('Unsupported model_type: {}'.format(vision_config['model_type']))
 
@@ -73,6 +70,7 @@ class Eagle2ChatConfig(PretrainedConfig):
         self.mlp_checkpoint = mlp_checkpoint
         self.pre_feature_reduction = pre_feature_reduction
         self.keep_aspect_ratio = keep_aspect_ratio
+        self.vocab_size = self.llm_config.vocab_size
         logger.info(f'keep_aspect_ratio: {self.keep_aspect_ratio}')
         logger.info(f'vision_select_layer: {self.select_layer}')
         logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
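
This change swaps the vendored SigLIP/Qwen2 config modules for the copies that ship with transformers (available since 4.37). As a sketch of a quick sanity check, the new import paths should resolve and expose the `vocab_size` value that `Eagle2ChatConfig` now mirrors:

```python
# Sketch: verify the upstream config classes the cleaned-up code now depends on.
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig

vision_cfg = SiglipVisionConfig()  # default SigLIP vision-tower settings
llm_cfg = Qwen2Config()            # default Qwen2 settings
print(llm_cfg.vocab_size)          # the value exposed as Eagle2ChatConfig.vocab_size
```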
configuration_multi_backbone_channel_concatentation_model.py DELETED

@@ -1,86 +0,0 @@
-# --------------------------------------------------------
-# Eagle2
-# Copyright (c) 2025 NVIDIA
-# Licensed under The Apache License [see LICENSE for details]
-# --------------------------------------------------------
-
-import os
-from typing import Union
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-from .configuration_siglip import SiglipVisionConfig
-logger = logging.get_logger(__name__)
-
-
-class MultiBackboneChannelConcatenationVisionModelConfig(PretrainedConfig):
-    r"""
-    This is the configuration class to store the configuration of a [`MultiBackboneChannelConcatenationVisionModelConfig`]. It is used to
-    instantiate a vision encoder according to the specified arguments, defining the model architecture.
-
-    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
-    documentation from [`PretrainedConfig`] for more information.
-
-    Args:
-        vision_path (str): Path to the vision model or its configuration.
-        mm_vision_select_layer (int, optional): The layer to select from the vision model
-            for multi-modal processing. Defaults to -2.
-        grid_size (int, optional): The size of the grid for vision processing. Defaults to 32.
-        **kwargs: Additional keyword arguments to be passed to the parent PretrainedConfig.
-
-    """
-
-    model_type = 'MOB'
-
-    def __init__(
-        self,
-        vision_path,
-        mm_vision_select_layer=-2,
-        grid_size=32,
-        input_image_size=1024,
-        hidden_size='lazy_calculation',
-        image_size=1024,
-        freeze_backbones=None,
-        moe_version_type=None,
-        delay_load=False,
-        convnext_img_size=1024,
-        vision_tower_siglip_path=None,
-        vision_tower_convnext_path='convnext_xxlarge.clip_laion2b_soup',
-        normalize_type='siglip',
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.normalize_type = normalize_type
-        self.vision_path = vision_path
-        self.mm_vision_select_layer = mm_vision_select_layer
-        self.grid_size = grid_size
-        self.input_image_size = input_image_size
-        self.image_size = image_size
-        self.hidden_size = hidden_size
-        self.freeze_backbones = freeze_backbones
-        self.moe_version_type = moe_version_type
-        self.delay_load = delay_load
-        self.convnext_img_size = convnext_img_size
-        # other args. to make it compatable with eagle-next
-        self.vision_tower_siglip_path = vision_tower_siglip_path
-        self.vision_tower_convnext_path = vision_tower_convnext_path
-        self.vision_tower = self.vision_path[4:]  # remove `MOB:` prefix
-
-        # asserts
-        assert image_size == input_image_size, f"input_image_size ({input_image_size}) != image_size ({image_size})"
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
-        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
-
-        if 'vision_config' in config_dict:
-            config_dict = config_dict['vision_config']
-
-        if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
-            logger.warning(
-                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
-                f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
-            )
-
-        return cls.from_dict(config_dict, **kwargs)
 
configuration_qwen2.py DELETED
@@ -1,149 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Qwen2 model configuration"""
16
-
17
- from transformers.configuration_utils import PretrainedConfig
18
- from transformers.utils import logging
19
-
20
-
21
- logger = logging.get_logger(__name__)
22
-
23
- QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
- "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json",
25
- }
26
-
27
-
28
- class Qwen2Config(PretrainedConfig):
29
- r"""
30
- This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
31
- Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
32
- with the defaults will yield a similar configuration to that of
33
- Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
34
-
35
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
- documentation from [`PretrainedConfig`] for more information.
37
-
38
-
39
- Args:
40
- vocab_size (`int`, *optional*, defaults to 151936):
41
- Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
42
- `inputs_ids` passed when calling [`Qwen2Model`]
43
- hidden_size (`int`, *optional*, defaults to 4096):
44
- Dimension of the hidden representations.
45
- intermediate_size (`int`, *optional*, defaults to 22016):
46
- Dimension of the MLP representations.
47
- num_hidden_layers (`int`, *optional*, defaults to 32):
48
- Number of hidden layers in the Transformer encoder.
49
- num_attention_heads (`int`, *optional*, defaults to 32):
50
- Number of attention heads for each attention layer in the Transformer encoder.
51
- num_key_value_heads (`int`, *optional*, defaults to 32):
52
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54
- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
55
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56
- by meanpooling all the original heads within that group. For more details checkout [this
57
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
58
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59
- The non-linear activation function (function or string) in the decoder.
60
- max_position_embeddings (`int`, *optional*, defaults to 32768):
61
- The maximum sequence length that this model might ever be used with.
62
- initializer_range (`float`, *optional*, defaults to 0.02):
63
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65
- The epsilon used by the rms normalization layers.
66
- use_cache (`bool`, *optional*, defaults to `True`):
67
- Whether or not the model should return the last key/values attentions (not used by all models). Only
68
- relevant if `config.is_decoder=True`.
69
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70
- Whether the model's input and output word embeddings should be tied.
71
- rope_theta (`float`, *optional*, defaults to 10000.0):
72
- The base period of the RoPE embeddings.
73
- use_sliding_window (`bool`, *optional*, defaults to `False`):
74
- Whether to use sliding window attention.
75
- sliding_window (`int`, *optional*, defaults to 4096):
76
- Sliding window attention (SWA) window size. If not specified, will default to `4096`.
77
- max_window_layers (`int`, *optional*, defaults to 28):
78
- The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
79
- attention_dropout (`float`, *optional*, defaults to 0.0):
80
- The dropout ratio for the attention probabilities.
81
-
82
- ```python
83
- >>> from transformers import Qwen2Model, Qwen2Config
84
-
85
- >>> # Initializing a Qwen2 style configuration
86
- >>> configuration = Qwen2Config()
87
-
88
- >>> # Initializing a model from the Qwen2-7B style configuration
89
- >>> model = Qwen2Model(configuration)
90
-
91
- >>> # Accessing the model configuration
92
- >>> configuration = model.config
93
- ```"""
94
-
95
- model_type = "qwen2"
96
- keys_to_ignore_at_inference = ["past_key_values"]
97
-
98
- def __init__(
99
- self,
100
- vocab_size=151936,
101
- hidden_size=4096,
102
- intermediate_size=22016,
103
- num_hidden_layers=32,
104
- num_attention_heads=32,
105
- num_key_value_heads=32,
106
- hidden_act="silu",
107
- max_position_embeddings=32768,
108
- initializer_range=0.02,
109
- rms_norm_eps=1e-6,
110
- use_cache=True,
111
- tie_word_embeddings=False,
112
- rope_theta=10000.0,
113
- use_sliding_window=False,
114
- sliding_window=4096,
115
- max_window_layers=28,
116
- attention_dropout=0.0,
117
- attn_implementation='flash_attention_2',
118
- **kwargs,
119
- ):
120
- self.vocab_size = vocab_size
121
- self.max_position_embeddings = max_position_embeddings
122
- self.hidden_size = hidden_size
123
- self.intermediate_size = intermediate_size
124
- self.num_hidden_layers = num_hidden_layers
125
- self.num_attention_heads = num_attention_heads
126
- self.use_sliding_window = use_sliding_window
127
- self.sliding_window = sliding_window
128
- self.max_window_layers = max_window_layers
129
-
130
- self.attn_implementation = attn_implementation
131
- if self.attn_implementation is None:
132
- self.attn_implementation = "flash_attention_2"
133
-
134
- # for backward compatibility
135
- if num_key_value_heads is None:
136
- num_key_value_heads = num_attention_heads
137
-
138
- self.num_key_value_heads = num_key_value_heads
139
- self.hidden_act = hidden_act
140
- self.initializer_range = initializer_range
141
- self.rms_norm_eps = rms_norm_eps
142
- self.use_cache = use_cache
143
- self.rope_theta = rope_theta
144
- self.attention_dropout = attention_dropout
145
-
146
- super().__init__(
147
- tie_word_embeddings=tie_word_embeddings,
148
- **kwargs,
149
- )
 
configuration_siglip.py DELETED
@@ -1,302 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ Siglip model configuration"""
16
-
17
- import os
18
- from typing import Union
19
-
20
- from transformers.configuration_utils import PretrainedConfig
21
- from transformers.utils import logging
22
-
23
-
24
- logger = logging.get_logger(__name__)
25
-
26
- SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
- "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
28
- }
29
-
30
-
31
- class SiglipTextConfig(PretrainedConfig):
32
- r"""
33
- This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
34
- Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
35
- configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
36
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
37
-
38
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
39
- documentation from [`PretrainedConfig`] for more information.
40
-
41
- Args:
42
- vocab_size (`int`, *optional*, defaults to 32000):
43
- Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
44
- the `inputs_ids` passed when calling [`SiglipModel`].
45
- hidden_size (`int`, *optional*, defaults to 768):
46
- Dimensionality of the encoder layers and the pooler layer.
47
- intermediate_size (`int`, *optional*, defaults to 3072):
48
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
49
- num_hidden_layers (`int`, *optional*, defaults to 12):
50
- Number of hidden layers in the Transformer encoder.
51
- num_attention_heads (`int`, *optional*, defaults to 12):
52
- Number of attention heads for each attention layer in the Transformer encoder.
53
- max_position_embeddings (`int`, *optional*, defaults to 64):
54
- The maximum sequence length that this model might ever be used with. Typically set this to something large
55
- just in case (e.g., 512 or 1024 or 2048).
56
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
57
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
58
- `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
59
- layer_norm_eps (`float`, *optional*, defaults to 1e-06):
60
- The epsilon used by the layer normalization layers.
61
- attention_dropout (`float`, *optional*, defaults to 0.0):
62
- The dropout ratio for the attention probabilities.
63
- pad_token_id (`int`, *optional*, defaults to 1):
64
- The id of the padding token in the vocabulary.
65
- bos_token_id (`int`, *optional*, defaults to 49406):
66
- The id of the beginning-of-sequence token in the vocabulary.
67
- eos_token_id (`int`, *optional*, defaults to 49407):
68
- The id of the end-of-sequence token in the vocabulary.
69
-
70
- Example:
71
-
72
- ```python
73
- >>> from transformers import SiglipTextConfig, SiglipTextModel
74
-
75
- >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
76
- >>> configuration = SiglipTextConfig()
77
-
78
- >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
79
- >>> model = SiglipTextModel(configuration)
80
-
81
- >>> # Accessing the model configuration
82
- >>> configuration = model.config
83
- ```"""
84
-
85
- model_type = "siglip_text_model"
86
-
87
- def __init__(
88
- self,
89
- vocab_size=32000,
90
- hidden_size=768,
91
- intermediate_size=3072,
92
- num_hidden_layers=12,
93
- num_attention_heads=12,
94
- max_position_embeddings=64,
95
- hidden_act="gelu_pytorch_tanh",
96
- layer_norm_eps=1e-6,
97
- attention_dropout=0.0,
98
- # This differs from `CLIPTokenizer`'s default and from openai/siglip
99
- # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
100
- pad_token_id=1,
101
- bos_token_id=49406,
102
- eos_token_id=49407,
103
- **kwargs,
104
- ):
105
- super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
106
-
107
- self.vocab_size = vocab_size
108
- self.hidden_size = hidden_size
109
- self.intermediate_size = intermediate_size
110
- self.num_hidden_layers = num_hidden_layers
111
- self.num_attention_heads = num_attention_heads
112
- self.max_position_embeddings = max_position_embeddings
113
- self.layer_norm_eps = layer_norm_eps
114
- self.hidden_act = hidden_act
115
- self.attention_dropout = attention_dropout
116
-
117
- @classmethod
118
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
119
- cls._set_token_in_kwargs(kwargs)
120
-
121
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
122
-
123
- # get the text config dict if we are loading from SiglipConfig
124
- if config_dict.get("model_type") == "siglip":
125
- config_dict = config_dict["text_config"]
126
-
127
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
128
- logger.warning(
129
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
130
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
131
- )
132
-
133
- return cls.from_dict(config_dict, **kwargs)
134
-
135
-
136
- class SiglipVisionConfig(PretrainedConfig):
137
- r"""
138
- This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
139
- Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
140
- configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
141
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
142
-
143
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
144
- documentation from [`PretrainedConfig`] for more information.
145
-
146
- Args:
147
- hidden_size (`int`, *optional*, defaults to 768):
148
- Dimensionality of the encoder layers and the pooler layer.
149
- intermediate_size (`int`, *optional*, defaults to 3072):
150
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
151
- num_hidden_layers (`int`, *optional*, defaults to 12):
152
- Number of hidden layers in the Transformer encoder.
153
- num_attention_heads (`int`, *optional*, defaults to 12):
154
- Number of attention heads for each attention layer in the Transformer encoder.
155
- num_channels (`int`, *optional*, defaults to 3):
156
- Number of channels in the input images.
157
- image_size (`int`, *optional*, defaults to 224):
158
- The size (resolution) of each image.
159
- patch_size (`int`, *optional*, defaults to 16):
160
- The size (resolution) of each patch.
161
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
162
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
163
- `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
164
- layer_norm_eps (`float`, *optional*, defaults to 1e-06):
165
- The epsilon used by the layer normalization layers.
166
- attention_dropout (`float`, *optional*, defaults to 0.0):
167
- The dropout ratio for the attention probabilities.
168
-
169
- Example:
170
-
171
- ```python
172
- >>> from transformers import SiglipVisionConfig, SiglipVisionModel
173
-
174
- >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
175
- >>> configuration = SiglipVisionConfig()
176
-
177
- >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
178
- >>> model = SiglipVisionModel(configuration)
179
-
180
- >>> # Accessing the model configuration
181
- >>> configuration = model.config
182
- ```"""
183
-
184
- model_type = "siglip_vision_model"
185
-
186
- def __init__(
187
- self,
188
- hidden_size=768,
189
- intermediate_size=3072,
190
- num_hidden_layers=12,
191
- num_attention_heads=12,
192
- num_channels=3,
193
- image_size=224,
194
- patch_size=16,
195
- hidden_act="gelu_pytorch_tanh",
196
- layer_norm_eps=1e-6,
197
- attention_dropout=0.0,
198
- **kwargs,
199
- ):
200
- super().__init__(**kwargs)
201
-
202
- self.hidden_size = hidden_size
203
- self.intermediate_size = intermediate_size
204
- self.num_hidden_layers = num_hidden_layers
205
- self.num_attention_heads = num_attention_heads
206
- self.num_channels = num_channels
207
- self.patch_size = patch_size
208
- self.image_size = image_size
209
- self.attention_dropout = attention_dropout
210
- self.layer_norm_eps = layer_norm_eps
211
- self.hidden_act = hidden_act
212
-
213
- @classmethod
214
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
215
- cls._set_token_in_kwargs(kwargs)
216
-
217
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
218
-
219
- # get the vision config dict if we are loading from SiglipConfig
220
- if config_dict.get("model_type") == "siglip":
221
- config_dict = config_dict["vision_config"]
222
-
223
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
224
- logger.warning(
225
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
226
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
227
- )
228
-
229
- return cls.from_dict(config_dict, **kwargs)
230
-
231
-
232
- class SiglipConfig(PretrainedConfig):
233
- r"""
234
- [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
235
- instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
236
- Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
237
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
238
-
239
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
240
- documentation from [`PretrainedConfig`] for more information.
241
-
242
- Args:
243
- text_config (`dict`, *optional*):
244
- Dictionary of configuration options used to initialize [`SiglipTextConfig`].
245
- vision_config (`dict`, *optional*):
246
- Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
247
- kwargs (*optional*):
248
- Dictionary of keyword arguments.
249
-
250
- Example:
251
-
252
- ```python
253
- >>> from transformers import SiglipConfig, SiglipModel
254
-
255
- >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
256
- >>> configuration = SiglipConfig()
257
-
258
- >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
259
- >>> model = SiglipModel(configuration)
260
-
261
- >>> # Accessing the model configuration
262
- >>> configuration = model.config
263
-
264
- >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
265
- >>> from transformers import SiglipTextConfig, SiglipVisionConfig
266
-
267
- >>> # Initializing a SiglipText and SiglipVision configuration
268
- >>> config_text = SiglipTextConfig()
269
- >>> config_vision = SiglipVisionConfig()
270
-
271
- >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
272
- ```"""
273
-
274
- model_type = "siglip"
275
-
276
- def __init__(self, text_config=None, vision_config=None, **kwargs):
277
- super().__init__(**kwargs)
278
-
279
- if text_config is None:
280
- text_config = {}
281
- logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
282
-
283
- if vision_config is None:
284
- vision_config = {}
285
- logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
286
-
287
- self.text_config = SiglipTextConfig(**text_config)
288
- self.vision_config = SiglipVisionConfig(**vision_config)
289
-
290
- self.initializer_factor = 1.0
291
-
292
- @classmethod
293
- def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
294
- r"""
295
- Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
296
- model configuration.
297
-
298
- Returns:
299
- [`SiglipConfig`]: An instance of a configuration object
300
- """
301
-
302
- return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
 
conversation.py DELETED
@@ -1,434 +0,0 @@
1
- """
2
- Conversation prompt templates.
3
-
4
- We kindly request that you import fastchat instead of copying this file if you wish to use it.
5
- If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
- """
7
-
8
- import dataclasses
9
- from enum import IntEnum, auto
10
- from typing import Any, Dict, List, Tuple, Union
11
-
12
-
13
- class SeparatorStyle(IntEnum):
14
- """Separator styles."""
15
-
16
- ADD_COLON_SINGLE = auto()
17
- ADD_COLON_TWO = auto()
18
- ADD_COLON_SPACE_SINGLE = auto()
19
- NO_COLON_SINGLE = auto()
20
- NO_COLON_TWO = auto()
21
- ADD_NEW_LINE_SINGLE = auto()
22
- LLAMA2 = auto()
23
- CHATGLM = auto()
24
- CHATML = auto()
25
- CHATINTERN = auto()
26
- DOLLY = auto()
27
- RWKV = auto()
28
- PHOENIX = auto()
29
- ROBIN = auto()
30
- FALCON_CHAT = auto()
31
- CHATGLM3 = auto()
32
- INTERNVL_ZH = auto()
33
- MPT = auto()
34
- LLAMA3 = auto()
35
-
36
-
37
- @dataclasses.dataclass
38
- class Conversation:
39
- """A class that manages prompt templates and keeps all conversation history."""
40
-
41
- # The name of this template
42
- name: str
43
- # The template of the system prompt
44
- system_template: str = '{system_message}'
45
- # The system message
46
- system_message: str = ''
47
- # The names of two roles
48
- roles: Tuple[str] = ('USER', 'ASSISTANT')
49
- # All messages. Each item is (role, message).
50
- messages: List[List[str]] = ()
51
- # The number of few shot examples
52
- offset: int = 0
53
- # The separator style and configurations
54
- sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
55
- sep: str = '\n'
56
- sep2: str = None
57
- # Stop criteria (the default one is EOS token)
58
- stop_str: Union[str, List[str]] = None
59
- # Stops generation if meeting any token in this list
60
- stop_token_ids: List[int] = None
61
-
62
- def get_prompt(self) -> str:
63
- """Get the prompt for generation."""
64
- system_prompt = self.system_template.format(system_message=self.system_message)
65
- if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
66
- ret = system_prompt + self.sep
67
- for role, message in self.messages:
68
- if message:
69
- ret += role + ': ' + message + self.sep
70
- else:
71
- ret += role + ':'
72
- return ret
73
- elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
74
- seps = [self.sep, self.sep2]
75
- ret = system_prompt + seps[0]
76
- for i, (role, message) in enumerate(self.messages):
77
- if message:
78
- ret += role + ': ' + message + seps[i % 2]
79
- else:
80
- ret += role + ':'
81
- return ret
82
- elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
83
- ret = system_prompt + self.sep
84
- for role, message in self.messages:
85
- if message:
86
- ret += role + ': ' + message + self.sep
87
- else:
88
- ret += role + ': ' # must be end with a space
89
- return ret
90
- elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
91
- ret = '' if system_prompt == '' else system_prompt + self.sep
92
- for role, message in self.messages:
93
- if message:
94
- ret += role + '\n' + message + self.sep
95
- else:
96
- ret += role + '\n'
97
- return ret
98
- elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
99
- ret = system_prompt
100
- for role, message in self.messages:
101
- if message:
102
- ret += role + message + self.sep
103
- else:
104
- ret += role
105
- return ret
106
- elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
107
- seps = [self.sep, self.sep2]
108
- ret = system_prompt
109
- for i, (role, message) in enumerate(self.messages):
110
- if message:
111
- ret += role + message + seps[i % 2]
112
- else:
113
- ret += role
114
- return ret
115
- elif self.sep_style == SeparatorStyle.RWKV:
116
- ret = system_prompt
117
- for i, (role, message) in enumerate(self.messages):
118
- if message:
119
- ret += (
120
- role
121
- + ': '
122
- + message.replace('\r\n', '\n').replace('\n\n', '\n')
123
- )
124
- ret += '\n\n'
125
- else:
126
- ret += role + ':'
127
- return ret
128
- elif self.sep_style == SeparatorStyle.LLAMA2:
129
- seps = [self.sep, self.sep2]
130
- if self.system_message:
131
- ret = system_prompt
132
- else:
133
- ret = '[INST] '
134
- for i, (role, message) in enumerate(self.messages):
135
- tag = self.roles[i % 2]
136
- if message:
137
- if i == 0:
138
- ret += message + ' '
139
- else:
140
- ret += tag + ' ' + message + seps[i % 2]
141
- else:
142
- ret += tag
143
- return ret
144
- elif self.sep_style == SeparatorStyle.CHATGLM:
145
- # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
146
- # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
147
- round_add_n = 1 if self.name == 'chatglm2' else 0
148
- if system_prompt:
149
- ret = system_prompt + self.sep
150
- else:
151
- ret = ''
152
-
153
- for i, (role, message) in enumerate(self.messages):
154
- if i % 2 == 0:
155
- ret += f'[Round {i//2 + round_add_n}]{self.sep}'
156
-
157
- if message:
158
- ret += f'{role}:{message}{self.sep}'
159
- else:
160
- ret += f'{role}:'
161
- return ret
162
- elif self.sep_style == SeparatorStyle.CHATML:
163
- ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
164
- for role, message in self.messages:
165
- if message:
166
- ret += role + '\n' + message + self.sep + '\n'
167
- else:
168
- ret += role + '\n'
169
- return ret
170
- elif self.sep_style == SeparatorStyle.CHATGLM3:
171
- ret = ''
172
- if self.system_message:
173
- ret += system_prompt
174
- for role, message in self.messages:
175
- if message:
176
- ret += role + '\n' + ' ' + message
177
- else:
178
- ret += role
179
- return ret
180
- elif self.sep_style == SeparatorStyle.CHATINTERN:
181
- # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
182
- seps = [self.sep, self.sep2]
183
- ret = system_prompt
184
- for i, (role, message) in enumerate(self.messages):
185
- # if i % 2 == 0:
186
- # ret += "<s>"
187
- if message:
188
- ret += role + ':' + message + seps[i % 2] + '\n'
189
- else:
190
- ret += role + ':'
191
- return ret
192
- elif self.sep_style == SeparatorStyle.DOLLY:
193
- seps = [self.sep, self.sep2]
194
- ret = system_prompt
195
- for i, (role, message) in enumerate(self.messages):
196
- if message:
197
- ret += role + ':\n' + message + seps[i % 2]
198
- if i % 2 == 1:
199
- ret += '\n\n'
200
- else:
201
- ret += role + ':\n'
202
- return ret
203
- elif self.sep_style == SeparatorStyle.PHOENIX:
204
- ret = system_prompt
205
- for role, message in self.messages:
206
- if message:
207
- ret += role + ': ' + '<s>' + message + '</s>'
208
- else:
209
- ret += role + ': ' + '<s>'
210
- return ret
211
- elif self.sep_style == SeparatorStyle.ROBIN:
212
- ret = system_prompt + self.sep
213
- for role, message in self.messages:
214
- if message:
215
- ret += role + ':\n' + message + self.sep
216
- else:
217
- ret += role + ':\n'
218
- return ret
219
- elif self.sep_style == SeparatorStyle.FALCON_CHAT:
220
- ret = ''
221
- if self.system_message:
222
- ret += system_prompt + self.sep
223
- for role, message in self.messages:
224
- if message:
225
- ret += role + ': ' + message + self.sep
226
- else:
227
- ret += role + ':'
228
-
229
- return ret
230
- elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
231
- seps = [self.sep, self.sep2]
232
- ret = self.system_message + seps[0]
233
- for i, (role, message) in enumerate(self.messages):
234
- if message:
235
- ret += role + ': ' + message + seps[i % 2]
236
- else:
237
- ret += role + ':'
238
- return ret
239
- elif self.sep_style == SeparatorStyle.MPT:
240
- ret = system_prompt + self.sep
241
- for role, message in self.messages:
242
- if message:
243
- if type(message) is tuple:
244
- message, _, _ = message
245
- ret += role + message + self.sep
246
- else:
247
- ret += role
248
- return ret
249
- elif self.sep_style == SeparatorStyle.LLAMA3:
250
- ret = system_prompt + self.sep
251
- for role, message in self.messages:
252
- if message:
253
- if type(message) is tuple:
254
- message, _, _ = message
255
- ret += role + message + self.sep
256
- else:
257
- ret += role
258
- return ret
259
- else:
260
- raise ValueError(f'Invalid style: {self.sep_style}')
261
-
262
- def set_system_message(self, system_message: str):
263
- """Set the system message."""
264
- self.system_message = system_message
265
-
266
- def append_message(self, role: str, message: str):
267
- """Append a new message."""
268
- self.messages.append([role, message])
269
-
270
- def update_last_message(self, message: str):
271
- """Update the last output.
272
-
273
- The last message is typically set to be None when constructing the prompt,
274
- so we need to update it in-place after getting the response from a model.
275
- """
276
- self.messages[-1][1] = message
277
-
278
- def to_gradio_chatbot(self):
279
- """Convert the conversation to gradio chatbot format."""
280
- ret = []
281
- for i, (role, msg) in enumerate(self.messages[self.offset :]):
282
- if i % 2 == 0:
283
- ret.append([msg, None])
284
- else:
285
- ret[-1][-1] = msg
286
- return ret
287
-
288
- def to_openai_api_messages(self):
289
- """Convert the conversation to OpenAI chat completion format."""
290
- ret = [{'role': 'system', 'content': self.system_message}]
291
-
292
- for i, (_, msg) in enumerate(self.messages[self.offset :]):
293
- if i % 2 == 0:
294
- ret.append({'role': 'user', 'content': msg})
295
- else:
296
- if msg is not None:
297
- ret.append({'role': 'assistant', 'content': msg})
298
- return ret
299
-
300
- def copy(self):
301
- return Conversation(
302
- name=self.name,
303
- system_template=self.system_template,
304
- system_message=self.system_message,
305
- roles=self.roles,
306
- messages=[[x, y] for x, y in self.messages],
307
- offset=self.offset,
308
- sep_style=self.sep_style,
309
- sep=self.sep,
310
- sep2=self.sep2,
311
- stop_str=self.stop_str,
312
- stop_token_ids=self.stop_token_ids,
313
- )
314
-
315
- def dict(self):
316
- return {
317
- 'template_name': self.name,
318
- 'system_message': self.system_message,
319
- 'roles': self.roles,
320
- 'messages': self.messages,
321
- 'offset': self.offset,
322
- }
323
-
324
-
325
- # A global registry for all conversation templates
326
- conv_templates: Dict[str, Conversation] = {}
327
-
328
-
329
- def register_conv_template(template: Conversation, override: bool = False):
330
- """Register a new conversation template."""
331
- if not override:
332
- assert (
333
- template.name not in conv_templates
334
- ), f'{template.name} has been registered.'
335
-
336
- conv_templates[template.name] = template
337
-
338
-
339
- def get_conv_template(name: str) -> Conversation:
340
- """Get a conversation template."""
341
- return conv_templates[name].copy()
342
-
343
-
344
- # Note that for inference, using the Hermes-2 and internlm2-chat templates is equivalent.
345
- register_conv_template(
346
- Conversation(
347
- name='Hermes-2',
348
- system_template='<|im_start|>system\n{system_message}',
349
- # note: The new system prompt was not used here to avoid changes in benchmark performance.
350
- # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
351
- system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
352
- roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
353
- sep_style=SeparatorStyle.MPT,
354
- sep='<|im_end|>',
355
- stop_token_ids=[
356
- 2,
357
- 6,
358
- 7,
359
- 8,
360
- ],
361
- stop_str='<|endoftext|>',
362
- )
363
- )
364
-
365
-
366
- register_conv_template(
367
- Conversation(
368
- name='internlm2-chat',
369
- system_template='<|im_start|>system\n{system_message}',
370
- # note: The new system prompt was not used here to avoid changes in benchmark performance.
371
- # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
372
- system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
373
- roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
374
- sep_style=SeparatorStyle.MPT,
375
- sep='<|im_end|>',
376
- stop_token_ids=[
377
- 2,
378
- 92543,
379
- 92542
380
- ]
381
- )
382
- )
383
-
384
-
385
- register_conv_template(
386
- Conversation(
387
- name='phi3-chat',
388
- system_template='<|system|>\n{system_message}',
389
- # note: The new system prompt was not used here to avoid changes in benchmark performance.
390
- # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室及多家合作单位联合开发的多模态大语言模型。人工智能实验室致力于原始技术创新,开源开放,共享共创,推动科技进步和产业发展。',
391
- system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
392
- roles=('<|user|>\n', '<|assistant|>\n'),
393
- sep_style=SeparatorStyle.MPT,
394
- sep='<|end|>',
395
- stop_token_ids=[
396
- 2,
397
- 32000,
398
- 32007
399
- ]
400
- )
401
- )
402
- register_conv_template(
403
- Conversation(
404
- name='llama3-chat',
405
- system_template='<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_message}',
406
- system_message='You are an AI assistant whose name is Eagle-Next.',
407
- roles=('<|start_header_id|>user<|end_header_id|>\n\n', '<|start_header_id|>assistant<|end_header_id|>\n\n'),
408
- sep_style=SeparatorStyle.LLAMA3,
409
- sep='<|eot_id|>',
410
- stop_token_ids=[
411
- 128259,
412
- 128001
413
- ]
414
- )
415
- )
416
-
417
- # Qwen-chat default template
418
- # source: https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L130
419
- register_conv_template(
420
- Conversation(
421
- name='qwen2-chat',
422
- system_template='<|im_start|>system\n{system_message}',
423
- system_message='You are a helpful assistant.',
424
- roles=('<|im_start|>user', '<|im_start|>assistant'),
425
- sep_style=SeparatorStyle.CHATML,
426
- sep='<|im_end|>',
427
- stop_token_ids=[
428
- 151643,
429
- 151644,
430
- 151645,
431
- ], # "<|endoftext|>", "<|im_start|>", "<|im_end|>"
432
- stop_str='<|endoftext|>',
433
- )
434
- )
 
convnext.py DELETED
@@ -1,572 +0,0 @@
1
- """ ConvNeXt
2
-
3
- Papers:
4
- * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
5
- @Article{liu2022convnet,
6
- author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
7
- title = {A ConvNet for the 2020s},
8
- journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
9
- year = {2022},
10
- }
11
-
12
- * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
13
- @article{Woo2023ConvNeXtV2,
14
- title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
15
- author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
16
- year={2023},
17
- journal={arXiv preprint arXiv:2301.00808},
18
- }
19
-
20
- Original code and weights from:
21
- * https://github.com/facebookresearch/ConvNeXt, original copyright below
22
- * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
23
-
24
- Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
25
-
26
- Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
27
- """
28
- # ConvNeXt
29
- # Copyright (c) Meta Platforms, Inc. and affiliates.
30
- # All rights reserved.
31
- # This source code is licensed under the MIT license
32
-
33
- # ConvNeXt-V2
34
- # Copyright (c) Meta Platforms, Inc. and affiliates.
35
- # All rights reserved.
36
- # This source code is licensed under the license found in the
37
- # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
38
- # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
39
-
40
- from collections import OrderedDict
41
- from functools import partial
42
- from typing import Callable, Optional, Tuple, Union
43
-
44
- import torch
45
- import torch.nn as nn
46
-
47
- from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
48
- from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
49
- LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
50
- from timm.layers import NormMlpClassifierHead, ClassifierHead
51
- from timm.models._builder import build_model_with_cfg
52
- from timm.models._manipulate import named_apply, checkpoint_seq
53
- from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
54
-
55
- __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
56
-
57
-
58
- class Downsample(nn.Module):
59
-
60
- def __init__(self, in_chs, out_chs, stride=1, dilation=1):
61
- super().__init__()
62
- avg_stride = stride if dilation == 1 else 1
63
- if stride > 1 or dilation > 1:
64
- avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
65
- self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
66
- else:
67
- self.pool = nn.Identity()
68
-
69
- if in_chs != out_chs:
70
- self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
71
- else:
72
- self.conv = nn.Identity()
73
-
74
- def forward(self, x):
75
- x = self.pool(x)
76
- x = self.conv(x)
77
- return x
78
-
79
-
80
- class ConvNeXtBlock(nn.Module):
81
- """ ConvNeXt Block
82
- There are two equivalent implementations:
83
- (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
84
- (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
85
-
86
- Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
87
- choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
88
- is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
89
- """
90
-
91
- def __init__(
92
- self,
93
- in_chs: int,
94
- out_chs: Optional[int] = None,
95
- kernel_size: int = 7,
96
- stride: int = 1,
97
- dilation: Union[int, Tuple[int, int]] = (1, 1),
98
- mlp_ratio: float = 4,
99
- conv_mlp: bool = False,
100
- conv_bias: bool = True,
101
- use_grn: bool = False,
102
- ls_init_value: Optional[float] = 1e-6,
103
- act_layer: Union[str, Callable] = 'gelu',
104
- norm_layer: Optional[Callable] = None,
105
- drop_path: float = 0.,
106
- ):
107
- """
108
-
109
- Args:
110
- in_chs: Block input channels.
111
- out_chs: Block output channels (same as in_chs if None).
112
- kernel_size: Depthwise convolution kernel size.
113
- stride: Stride of depthwise convolution.
114
- dilation: Tuple specifying input and output dilation of block.
115
- mlp_ratio: MLP expansion ratio.
116
- conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
117
- conv_bias: Apply bias for all convolution (linear) layers.
118
- use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
119
- ls_init_value: Layer-scale init values, layer-scale applied if not None.
120
- act_layer: Activation layer.
121
- norm_layer: Normalization layer (defaults to LN if not specified).
122
- drop_path: Stochastic depth probability.
123
- """
124
- super().__init__()
125
- out_chs = out_chs or in_chs
126
- dilation = to_ntuple(2)(dilation)
127
- act_layer = get_act_layer(act_layer)
128
- if not norm_layer:
129
- norm_layer = LayerNorm2d if conv_mlp else LayerNorm
130
- mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
131
- self.use_conv_mlp = conv_mlp
132
- self.conv_dw = create_conv2d(
133
- in_chs,
134
- out_chs,
135
- kernel_size=kernel_size,
136
- stride=stride,
137
- dilation=dilation[0],
138
- depthwise=True,
139
- bias=conv_bias,
140
- )
141
- self.norm = norm_layer(out_chs)
142
- self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
143
- self.weight = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
144
- if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
145
- self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
146
- else:
147
- self.shortcut = nn.Identity()
148
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
149
-
150
- def forward(self, x):
151
- shortcut = x
152
- x = self.conv_dw(x)
153
- if self.use_conv_mlp:
154
- x = self.norm(x)
155
- x = self.mlp(x)
156
- else:
157
- x = x.permute(0, 2, 3, 1)
158
- x = self.norm(x)
159
- x = self.mlp(x)
160
- x = x.permute(0, 3, 1, 2)
161
- if self.weight is not None:
162
- x = x.mul(self.weight.reshape(1, -1, 1, 1))
163
-
164
- x = self.drop_path(x) + self.shortcut(shortcut)
165
- return x
166
-
167
-
168
- class ConvNeXtStage(nn.Module):
169
-
170
- def __init__(
171
- self,
172
- in_chs,
173
- out_chs,
174
- kernel_size=7,
175
- stride=2,
176
- depth=2,
177
- dilation=(1, 1),
178
- drop_path_rates=None,
179
- ls_init_value=1.0,
180
- conv_mlp=False,
181
- conv_bias=True,
182
- use_grn=False,
183
- act_layer='gelu',
184
- norm_layer=None,
185
- norm_layer_cl=None
186
- ):
187
- super().__init__()
188
- self.grad_checkpointing = False
189
-
190
- if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
191
- ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
192
- pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
193
- self.downsample = nn.Sequential(
194
- norm_layer(in_chs),
195
- create_conv2d(
196
- in_chs,
197
- out_chs,
198
- kernel_size=ds_ks,
199
- stride=stride,
200
- dilation=dilation[0],
201
- padding=pad,
202
- bias=conv_bias,
203
- ),
204
- )
205
- in_chs = out_chs
206
- else:
207
- self.downsample = nn.Identity()
208
-
209
- drop_path_rates = drop_path_rates or [0.] * depth
210
- stage_blocks = []
211
- for i in range(depth):
212
- stage_blocks.append(ConvNeXtBlock(
213
- in_chs=in_chs,
214
- out_chs=out_chs,
215
- kernel_size=kernel_size,
216
- dilation=dilation[1],
217
- drop_path=drop_path_rates[i],
218
- ls_init_value=ls_init_value,
219
- conv_mlp=conv_mlp,
220
- conv_bias=conv_bias,
221
- use_grn=use_grn,
222
- act_layer=act_layer,
223
- norm_layer=norm_layer if conv_mlp else norm_layer_cl,
224
- ))
225
- in_chs = out_chs
226
- self.blocks = nn.Sequential(*stage_blocks)
227
-
228
- def forward(self, x):
229
- x = self.downsample(x)
230
- if self.grad_checkpointing and not torch.jit.is_scripting():
231
- x = checkpoint_seq(self.blocks, x)
232
- else:
233
- x = self.blocks(x)
234
- return x
235
-
236
-
237
- class ConvNeXt(nn.Module):
238
- r""" ConvNeXt
239
- A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
240
- """
241
-
242
- def __init__(
243
- self,
244
- in_chans: int = 3,
245
- num_classes: int = 1000,
246
- global_pool: str = 'avg',
247
- output_stride: int = 32,
248
- depths: Tuple[int, ...] = (3, 3, 9, 3),
249
- dims: Tuple[int, ...] = (96, 192, 384, 768),
250
- kernel_sizes: Union[int, Tuple[int, ...]] = 7,
251
- ls_init_value: Optional[float] = 1e-6,
252
- stem_type: str = 'patch',
253
- patch_size: int = 4,
254
- head_init_scale: float = 1.,
255
- head_norm_first: bool = False,
256
- head_hidden_size: Optional[int] = None,
257
- conv_mlp: bool = False,
258
- conv_bias: bool = True,
259
- use_grn: bool = False,
260
- act_layer: Union[str, Callable] = 'gelu',
261
- norm_layer: Optional[Union[str, Callable]] = None,
262
- norm_eps: Optional[float] = None,
263
- drop_rate: float = 0.,
264
- drop_path_rate: float = 0.,
265
- ):
266
- """
267
- Args:
268
- in_chans: Number of input image channels.
269
- num_classes: Number of classes for classification head.
270
- global_pool: Global pooling type.
271
- output_stride: Output stride of network, one of (8, 16, 32).
272
- depths: Number of blocks at each stage.
273
- dims: Feature dimension at each stage.
274
- kernel_sizes: Depthwise convolution kernel-sizes for each stage.
275
- ls_init_value: Init value for Layer Scale, disabled if None.
276
- stem_type: Type of stem.
277
- patch_size: Stem patch size for patch stem.
278
- head_init_scale: Init scaling value for classifier weights and biases.
279
- head_norm_first: Apply normalization before global pool + head.
280
- head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
281
- conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
282
- conv_bias: Use bias layers w/ all convolutions.
283
- use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
284
- act_layer: Activation layer type.
285
- norm_layer: Normalization layer type.
286
- drop_rate: Head pre-classifier dropout rate.
287
- drop_path_rate: Stochastic depth drop rate.
288
- """
289
- super().__init__()
290
- assert output_stride in (8, 16, 32)
291
- kernel_sizes = to_ntuple(4)(kernel_sizes)
292
- if norm_layer is None:
293
- norm_layer = LayerNorm2d
294
- norm_layer_cl = norm_layer if conv_mlp else LayerNorm
295
- if norm_eps is not None:
296
- norm_layer = partial(norm_layer, eps=norm_eps)
297
- norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
298
- else:
299
- assert conv_mlp,\
300
- 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
301
- norm_layer_cl = norm_layer
302
- if norm_eps is not None:
303
- norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
304
-
305
- self.num_classes = num_classes
306
- self.drop_rate = drop_rate
307
- self.feature_info = []
308
-
309
- assert stem_type in ('patch', 'overlap', 'overlap_tiered')
310
- if stem_type == 'patch':
311
- # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
312
- self.stem = nn.Sequential(
313
- nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
314
- norm_layer(dims[0]),
315
- )
316
- stem_stride = patch_size
317
- else:
318
- mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
319
- self.stem = nn.Sequential(
320
- nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
321
- nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
322
- norm_layer(dims[0]),
323
- )
324
- stem_stride = 4
325
-
326
- self.stages = nn.Sequential()
327
- dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
328
- stages = []
329
- prev_chs = dims[0]
330
- curr_stride = stem_stride
331
- dilation = 1
332
- # 4 feature resolution stages, each consisting of multiple residual blocks
333
- for i in range(4):
334
- stride = 2 if curr_stride == 2 or i > 0 else 1
335
- if curr_stride >= output_stride and stride > 1:
336
- dilation *= stride
337
- stride = 1
338
- curr_stride *= stride
339
- first_dilation = 1 if dilation in (1, 2) else 2
340
- out_chs = dims[i]
341
- stages.append(ConvNeXtStage(
342
- prev_chs,
343
- out_chs,
344
- kernel_size=kernel_sizes[i],
345
- stride=stride,
346
- dilation=(first_dilation, dilation),
347
- depth=depths[i],
348
- drop_path_rates=dp_rates[i],
349
- ls_init_value=ls_init_value,
350
- conv_mlp=conv_mlp,
351
- conv_bias=conv_bias,
352
- use_grn=use_grn,
353
- act_layer=act_layer,
354
- norm_layer=norm_layer,
355
- norm_layer_cl=norm_layer_cl,
356
- ))
357
- prev_chs = out_chs
358
- # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
359
- self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
360
- self.stages = nn.Sequential(*stages)
361
- self.num_features = prev_chs
362
-
363
- # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
364
- # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
365
- if head_norm_first:
366
- assert not head_hidden_size
367
- self.norm_pre = norm_layer(self.num_features)
368
- self.head = ClassifierHead(
369
- self.num_features,
370
- num_classes,
371
- pool_type=global_pool,
372
- drop_rate=self.drop_rate,
373
- )
374
- else:
375
- self.norm_pre = nn.Identity()
376
- self.head = NormMlpClassifierHead(
377
- self.num_features,
378
- num_classes,
379
- hidden_size=head_hidden_size,
380
- pool_type=global_pool,
381
- drop_rate=self.drop_rate,
382
- norm_layer=norm_layer,
383
- act_layer='gelu',
384
- )
385
- named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
386
-
387
- @torch.jit.ignore
388
- def group_matcher(self, coarse=False):
389
- return dict(
390
- stem=r'^stem',
391
- blocks=r'^stages\.(\d+)' if coarse else [
392
- (r'^stages\.(\d+)\.downsample', (0,)), # blocks
393
- (r'^stages\.(\d+)\.blocks\.(\d+)', None),
394
- (r'^norm_pre', (99999,))
395
- ]
396
- )
397
-
398
- @torch.jit.ignore
399
- def set_grad_checkpointing(self, enable=True):
400
- for s in self.stages:
401
- s.grad_checkpointing = enable
402
-
403
- @torch.jit.ignore
404
- def get_classifier(self):
405
- return self.head.fc
406
-
407
- def reset_classifier(self, num_classes=0, global_pool=None):
408
- self.head.reset(num_classes, global_pool)
409
-
410
- def forward_features(self, x):
411
- x = self.stem(x)
412
- x = self.stages(x)
413
- x = self.norm_pre(x)
414
- return x
415
-
416
- def forward_head(self, x, pre_logits: bool = False):
417
- return self.head(x, pre_logits=True) if pre_logits else self.head(x)
418
-
419
- def forward(self, x):
420
- x = self.forward_features(x)
421
- x = self.forward_head(x)
422
- return x
423
-
424
-
425
- def _init_weights(module, name=None, head_init_scale=1.0):
426
- if isinstance(module, nn.Conv2d):
427
- trunc_normal_(module.weight, std=.02)
428
- if module.bias is not None:
429
- nn.init.zeros_(module.bias)
430
- elif isinstance(module, nn.Linear):
431
- trunc_normal_(module.weight, std=.02)
432
- nn.init.zeros_(module.bias)
433
- if name and 'head.' in name:
434
- module.weight.data.mul_(head_init_scale)
435
- module.bias.data.mul_(head_init_scale)
436
-
437
-
438
- def checkpoint_filter_fn(state_dict, model):
439
- """ Remap FB checkpoints -> timm """
440
- if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
441
- out_dict={}
442
- out_dict = {k.replace('gamma', 'weight'): v for k, v in state_dict.items()}
443
- return out_dict # non-FB checkpoint
444
- if 'model' in state_dict:
445
- state_dict = state_dict['model']
446
-
447
- out_dict = {}
448
- if 'visual.trunk.stem.0.weight' in state_dict:
449
- out_dict = {k.replace('visual.trunk.', '').replace('gamma', 'weight'): v for k, v in state_dict.items() if
450
- k.startswith('visual.trunk.')}
451
-
452
- if 'visual.head.proj.weight' in state_dict:
453
- out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
454
- out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
455
- elif 'visual.head.mlp.fc1.weight' in state_dict:
456
- out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
457
- out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
458
- out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
459
- out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
460
- return out_dict
461
-
462
- import re
463
- for k, v in state_dict.items():
464
- k = k.replace('downsample_layers.0.', 'stem.')
465
- k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
466
- k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
467
- k = k.replace('dwconv', 'conv_dw')
468
- k = k.replace('pwconv', 'mlp.fc')
469
- if 'grn' in k:
470
- k = k.replace('grn.beta', 'mlp.grn.bias')
471
- k = k.replace('grn.gamma', 'mlp.grn.weight')
472
- v = v.reshape(v.shape[-1])
473
- k = k.replace('head.', 'head.fc.')
474
- if k.startswith('norm.'):
475
- k = k.replace('norm', 'head.norm')
476
- if v.ndim == 2 and 'head' not in k:
477
- model_shape = model.state_dict()[k].shape
478
- v = v.reshape(model_shape)
479
- k=k.replace('gamma','weight')
480
- out_dict[k] = v
481
-
482
- return out_dict
483
-
484
-
485
- def _create_convnext(variant, pretrained=False, **kwargs):
486
- if kwargs.get('pretrained_cfg', '') == 'fcmae':
487
- # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
488
- # This is workaround loading with num_classes=0 w/o removing norm-layer.
489
- kwargs.setdefault('pretrained_strict', False)
490
-
491
- model = build_model_with_cfg(
492
- ConvNeXt, variant, pretrained,
493
- pretrained_filter_fn=checkpoint_filter_fn,
494
- feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
495
- **kwargs)
496
- return model
497
-
498
-
499
- def _cfg(url='', **kwargs):
500
- return {
501
- 'url': url,
502
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
503
- 'crop_pct': 0.875, 'interpolation': 'bicubic',
504
- 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
505
- 'first_conv': 'stem.0', 'classifier': 'head.fc',
506
- **kwargs
507
- }
508
-
509
-
510
- def _cfgv2(url='', **kwargs):
511
- return {
512
- 'url': url,
513
- 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
514
- 'crop_pct': 0.875, 'interpolation': 'bicubic',
515
- 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
516
- 'first_conv': 'stem.0', 'classifier': 'head.fc',
517
- 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
518
- 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
519
- 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
520
- **kwargs
521
- }
522
-
523
-
524
- default_cfgs = generate_default_cfgs({
525
- 'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
526
- hf_hub_id='timm/',
527
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
528
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
529
-
530
- 'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
531
- hf_hub_id='timm/',
532
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
533
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
534
- 'convnext_xxlarge.clip_laion2b_soup': _cfg(
535
- hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
536
- hf_hub_filename='open_clip_pytorch_model.bin',
537
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
538
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
539
- 'convnext_xxlarge.clip_laion2b_rewind': _cfg(
540
- hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
541
- hf_hub_filename='open_clip_pytorch_model.bin',
542
- mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
543
- input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
544
- })
545
-
546
-
547
-
548
- @register_model
549
- def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
550
- model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
551
- model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
552
- return model
553
-
554
-
555
-
556
- # register_model_deprecations(__name__, {
557
- # 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
558
- # 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
559
- # 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
560
- # 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
561
- # 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
562
- # 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
563
- # 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
564
- # 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
565
- # 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
566
- # 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
567
- # 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
568
- # 'convnext_small_in22k': 'convnext_small.fb_in22k',
569
- # 'convnext_base_in22k': 'convnext_base.fb_in22k',
570
- # 'convnext_large_in22k': 'convnext_large.fb_in22k',
571
- # 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
572
- # })
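For reference, the only variant registered by the removed file, `convnext_xxlarge`, is also available from upstream timm. A minimal sketch of building the same backbone directly, assuming the installed timm release ships this variant and the CLIP pretrained tags listed in `default_cfgs` above:

```python
import timm
import torch

# Tag taken from `default_cfgs` above; whether it resolves depends on the
# installed timm version.
model = timm.create_model(
    'convnext_xxlarge.clip_laion2b_soup',
    pretrained=False,      # set True to pull the CLIP-trained weights
    num_classes=0,         # feature extraction only, no classifier head
)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 256))   # 256x256 matches the CLIP configs above
print(feats.shape)         # pooled features, 3072-dim for the xxlarge variant
```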
convnext_encoder.py DELETED
@@ -1,301 +0,0 @@
1
- import torch, os
2
- import torch.nn as nn
3
- from timm import create_model
4
- from transformers import CLIPImageProcessor
5
- from .convnext import convnext_xxlarge
6
- from torch.utils.checkpoint import checkpoint
7
- import torch
8
- from torchvision import transforms as T
9
- from PIL import Image
10
-
11
-
12
-
13
- cfg={
14
- "crop_size": 256,
15
- "do_center_crop": True,
16
- "do_normalize": True,
17
- "do_resize": True,
18
- "feature_extractor_type": "CLIPFeatureExtractor",
19
- "image_mean": [
20
- 0.48145466,
21
- 0.4578275,
22
- 0.40821073
23
- ],
24
- "image_std": [
25
- 0.26862954,
26
- 0.26130258,
27
- 0.27577711
28
- ],
29
- "resample": 3,
30
- "size": 256
31
- }
32
-
33
-
34
-
35
- MEAN_SLIP = [0.5, 0.5, 0.5]
36
- STD_SLIP = [0.5, 0.5, 0.5]
37
-
38
- MEAN_CLIP = [0.48145466, 0.4578275, 0.40821073]
39
- STD_CLIP = [0.26862954, 0.26130258, 0.27577711]
40
-
41
-
42
- a = [s_slip / s_clip for s_slip, s_clip in zip(STD_SLIP, STD_CLIP)]
43
- b = [(m_slip - m_clip) / s_clip for m_slip, m_clip, s_clip in zip(MEAN_SLIP, MEAN_CLIP, STD_CLIP)]
44
-
45
-
46
- class SlipToClipTransform:
47
- def __init__(self, a, b):
48
- self.a = torch.tensor(a).view(-1, 1, 1)
49
- self.b = torch.tensor(b).view(-1, 1, 1)
50
-
51
- def __call__(self, x_slip):
52
- return x_slip * self.a.to(x_slip.device) + self.b.to(x_slip.device)
53
- slip_to_clip = SlipToClipTransform(a, b)
54
-
55
- class ConvNextVisionTower(nn.Module):
56
- def __init__(self, vision_tower, args, delay_load=False, normalize_type=None):
57
- super().__init__()
58
-
59
- self.is_loaded = False
60
- self.freeze_vision=args.freeze_vision
61
- self.input_image_size=args.input_image_size
62
- self.vision_tower_name = vision_tower
63
- self.name = 'convnext'
64
- self.select_layer = args.mm_vision_select_layer
65
- self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
66
- self.pre_norm = normalize_type
67
-
68
- print('pre_norm: ', self.pre_norm)
69
- self.delay_load = delay_load
70
- self.load_model()
71
-
72
- def load_model(self):
73
- if 'xxlarge' in self.vision_tower_name:
74
- if self.delay_load:
75
- self.vision_tower = convnext_xxlarge(pretrained=False)
76
- else:
77
- self.vision_tower = convnext_xxlarge(self.vision_tower_name)
78
- setattr(self.vision_tower, 'hidden_size', 3072)
79
- elif os.path.exists(self.vision_tower_name):
80
- self.vision_tower = torch.load(self.vision_tower_name)
81
- else:
82
- assert False, 'Not implemented'
83
-
84
-
85
- self.vision_tower = self.vision_tower.to(torch.bfloat16)
86
-
87
- if self.freeze_vision:
88
- self.vision_tower.requires_grad_(False)
89
-
90
- # if self.vision_tower.grad_checkpointing:
91
- for s in self.vision_tower.stages:
92
- s.grad_checkpointing = True
93
-
94
- self.is_loaded = True
95
-
96
- def feature_select(self, image_forward_outs):
97
-
98
- if self.select_layer>100:
99
- image_features = image_forward_outs[-4:]
100
- else:
101
- image_features = image_forward_outs[-1]
102
- return image_features
103
-
104
- def forward_features(self, x):
105
- x = self.vision_tower.stem(x)
106
- image_forward_out=[]
107
- for blk in self.vision_tower.stages:
108
- x = blk(x)
109
- b,c,h,w=x.shape
110
- image_forward_out.append(x.view(b,c,-1).transpose(1,2))
111
- return image_forward_out
112
-
113
- def forward(self, images):
114
- if self.freeze_vision:
115
- with torch.no_grad():
116
- image_features = self._forward_images(images)
117
- else:
118
- image_features = self._forward_images(images)
119
-
120
- return image_features
121
-
122
- def _forward_images(self, images):
123
-
124
- if type(images) is list:
125
- image_features = []
126
- for image in images:
127
- if self.pre_norm == 'siglip':
128
- dtype = image.dtype
129
- image = slip_to_clip(image.to(torch.float32)).to(dtype)
130
- image_forward_out = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
131
- image_feature = self.feature_select(image_forward_out)
132
- image_features.append(image_feature)
133
- else:
134
- if self.pre_norm == 'siglip':
135
- dtype = images.dtype
136
- images = slip_to_clip(images.to(torch.float32)).to(dtype)
137
- image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
138
- image_features = self.feature_select(image_forward_outs)
139
-
140
- return image_features
141
-
142
- @property
143
- def dummy_feature(self):
144
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
145
-
146
- @property
147
- def dtype(self):
148
- return next(self.vision_tower.parameters()).dtype
149
-
150
- @property
151
- def device(self):
152
- return next(self.vision_tower.parameters()).device
153
-
154
- @property
155
- def config(self):
156
- assert NotImplementedError
157
- pass
158
-
159
- @property
160
- def num_attention_heads(self):
161
- # as constant
162
- return 16
163
- @property
164
- def num_layers(self):
165
- # as constant
166
- return 4
167
- @property
168
- def hidden_size(self):
169
- return self.vision_tower.hidden_size
170
-
171
- @property
172
- def num_patches(self):
173
- return (self.input_image_size // self.patch_embed.patch_size[0]) ** 2
174
-
175
-
176
- class ConvNextFPNVisionTower(nn.Module):
177
- def __init__(self,
178
- vision_tower,
179
- args,
180
- fpn_target_level=1,
181
- fpn_layer_idx=[1,2,3],
182
- fpn_input_dim=[768,1536,3072],
183
- delay_load=False):
184
-
185
- super().__init__()
186
-
187
- self.is_loaded = False
188
- self.vision_tower_name = vision_tower.replace('-fpn', 'fpn')
189
- self.freeze_vision = getattr(args, "frozen_backbone", True)
190
- # self.input_image_size = getattr(args, "vision_tower_input_size", 1024)
191
- self.input_image_size = 1024 # hardcode
192
- self.select_layer = args.mm_vision_select_layer # no effect
193
- self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
194
-
195
- self.need_fpn = True
196
- self.fpn_layer_idx = fpn_layer_idx # [1, 2, 3] # x8, x16, x32
197
- self.fpn_input_dim = [768, 1536, 3072]
198
- self.delay_load = delay_load
199
- self.load_model()
200
-
201
- def load_model(self):
202
- if self.is_loaded:
203
- return
204
-
205
- self.image_processor = CLIPImageProcessor(**cfg)
206
- if 'xxlarge' in self.vision_tower_name:
207
- self.vision_tower = convnext_xxlarge(self.vision_tower_name)
208
- setattr(self.vision_tower, 'hidden_size', self.fpn_input_dim)
209
- # setattr(self.vision_tower, 'hidden_size', 3072)
210
- else:
211
- self.vision_tower = convnext_large_mlp(self.vision_tower_name)
212
- setattr(self.vision_tower, 'hidden_size', 1536)
213
- if self.freeze_vision:
214
- self.vision_tower.requires_grad_(False)
215
-
216
- # if self.vision_tower.grad_checkpointing:
217
- for s in self.vision_tower.stages:
218
- s.grad_checkpointing = True
219
-
220
- if self.input_image_size is not None:
221
- self.image_processor.size=self.input_image_size
222
- self.image_processor.crop_size={
223
- 'height':self.input_image_size,
224
- 'width': self.input_image_size
225
- }
226
-
227
- self.is_loaded = True
228
-
229
- @torch.no_grad()
230
- def forward_features(self, x):
231
- x = self.vision_tower.stem(x)
232
- image_forward_out=[]
233
- for blk in self.vision_tower.stages:
234
- x = blk(x)
235
- image_forward_out.append(x)
236
- return image_forward_out
237
-
238
- @torch.no_grad()
239
- def forward(self, images):
240
- if type(images) is list:
241
- image_features = []
242
- for image in images:
243
- image_feature = self.forward_features(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
244
- image_features.append(image_feature)
245
- else:
246
- image_features = self.forward_features(images.to(device=self.device, dtype=self.dtype))
247
- image_features = [image_features[idx] for idx in self.fpn_layer_idx]
248
-
249
- return image_features
250
-
251
- @property
252
- def dummy_feature(self):
253
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
254
-
255
- @property
256
- def dtype(self):
257
- return next(self.vision_tower.parameters()).dtype
258
-
259
- @property
260
- def device(self):
261
- return next(self.vision_tower.parameters()).device
262
-
263
- @property
264
- def config(self):
265
- assert NotImplementedError
266
- pass
267
-
268
- @property
269
- def num_attention_heads(self):
270
- # as constant
271
- return 16
272
- @property
273
- def num_layers(self):
274
- # as constant
275
- return 4
276
- @property
277
- def hidden_size(self):
278
- return self.vision_tower.hidden_size
279
-
280
- @property
281
- def num_patches(self):
282
- return (cfg['image_size'] // self.patch_embed.patch_size[0]) ** 2
283
-
284
- if __name__ == '__main__':
285
- COMBINED_STD = [s_slip / s_clip for s_slip, s_clip in zip(STD_SigLIP, STD_CLIP)]
286
- COMBINED_MEAN = [(m_slip - m_clip) / s_clip for m_slip, m_clip, s_clip in zip(MEAN_SigLIP, MEAN_CLIP, STD_CLIP)]
287
-
288
- # Define the combined normalization transform
289
- combined_normalize = T.Normalize(mean=COMBINED_MEAN, std=COMBINED_STD)
290
- x = torch.randn(1, 3, 256, 256).cuda()
291
- a = normalize_clip(x).to(torch.bfloat16)
292
- b = normalize_siglip(x).to(torch.bfloat16)
293
- c = denormalize_siglip(b.to(torch.float32))
294
- c2 = normalize_clip(c).to(torch.bfloat16)
295
- c3 = combined_normalize(b)
296
- print((c-x).abs().max())
297
- print((c2-a).abs().max())
298
- print((c3-a).abs().max())
299
- from IPython import embed
300
- embed()
301
- exit()
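The `a`/`b` coefficients in the removed `SlipToClipTransform` are just the affine map that turns a SigLIP-normalized tensor into a CLIP-normalized one. A small self-contained check, using the mean/std constants from the file above:

```python
# If x_s = (x - mean_s) / std_s, then re-normalizing the raw pixels with CLIP
# stats gives (x - mean_c) / std_c = x_s * (std_s / std_c) + (mean_s - mean_c) / std_c,
# which is exactly the a / b pair computed above.
import torch

MEAN_S, STD_S = torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5])
MEAN_C = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD_C = torch.tensor([0.26862954, 0.26130258, 0.27577711])

a = (STD_S / STD_C).view(-1, 1, 1)
b = ((MEAN_S - MEAN_C) / STD_C).view(-1, 1, 1)

x = torch.rand(3, 8, 8)                                   # raw pixels in [0, 1]
x_siglip = (x - MEAN_S.view(-1, 1, 1)) / STD_S.view(-1, 1, 1)
x_clip_direct = (x - MEAN_C.view(-1, 1, 1)) / STD_C.view(-1, 1, 1)
x_clip_affine = x_siglip * a + b                          # what SlipToClipTransform computed

print(torch.allclose(x_clip_direct, x_clip_affine, atol=1e-6))  # True
```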
demo.py CHANGED
@@ -237,7 +237,7 @@ class ModelWorker:
237
  self.norm_type = 'siglip'
238
  else:
239
  self.norm_type = 'imagenet'
240
-
241
  if any(x in model_path.lower() for x in ['34b']):
242
  device_map = split_model(model_path, self.device)
243
  else:
@@ -261,7 +261,7 @@ class ModelWorker:
261
  self.image_size = self.model.config.force_image_size
262
  self.context_len = tokenizer.model_max_length
263
  self.per_tile_len = 256
264
-
265
  def reload_model(self):
266
  del self.model
267
  torch.cuda.empty_cache()
@@ -297,6 +297,7 @@ class ModelWorker:
297
 
298
  global_image_cnt = 0
299
  history, pil_images, max_input_tile_list = [], [], []
 
300
  for message in send_messages:
301
  if message['role'] == 'user':
302
  prefix = ''
@@ -341,6 +342,7 @@ class ModelWorker:
341
  max_input_tiles_limited_by_contect = params['max_input_tiles']
342
  while True:
343
  image_tiles = []
 
344
  for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
345
  if self.model.config.dynamic_image_size:
346
  tiles = dynamic_preprocess(
@@ -348,6 +350,7 @@ class ModelWorker:
348
  use_thumbnail=self.model.config.use_thumbnail)
349
  else:
350
  tiles = [pil_image]
 
351
  image_tiles += tiles
352
  if (len(image_tiles) * self.per_tile_len < self.context_len):
353
  break
@@ -358,6 +361,8 @@ class ModelWorker:
358
  break
359
 
360
  pixel_values = [transform(item) for item in image_tiles]
 
 
361
  pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
362
 
363
  else:
@@ -372,13 +377,14 @@ class ModelWorker:
372
  max_length=self.context_len,
373
  top_p=top_p,
374
  )
375
-
376
  response = self.model.chat(
377
  tokenizer=self.tokenizer,
378
  pixel_values=pixel_values,
379
  question=question,
380
  history=history,
381
  return_history=False,
 
382
  generation_config=generation_config,
383
  )
384
  self.model.system_message = old_system_message
@@ -390,8 +396,8 @@ class ModelWorker:
390
 
391
  if __name__ == '__main__':
392
  parser = argparse.ArgumentParser()
393
- parser.add_argument('--model-path', type=str, default='nvidia/Eagle2-2B')
394
- parser.add_argument('--model-name', type=str, default='Eagle2-2B')
395
  parser.add_argument('--device', type=str, default='cuda')
396
  parser.add_argument('--load-8bit', action='store_true')
397
  args = parser.parse_args()
@@ -404,9 +410,10 @@ if __name__ == '__main__':
404
  args.device)
405
  prompt = [
406
  {'role': 'system', 'content': 'You are a helpful assistant.'},
407
- {'role': 'user', 'content': 'Describe this image in details.',
408
  'image':[
409
- {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'}
 
410
  ]
411
  }
412
  ]
 
237
  self.norm_type = 'siglip'
238
  else:
239
  self.norm_type = 'imagenet'
240
+ print('norm_type: ', self.norm_type)
241
  if any(x in model_path.lower() for x in ['34b']):
242
  device_map = split_model(model_path, self.device)
243
  else:
 
261
  self.image_size = self.model.config.force_image_size
262
  self.context_len = tokenizer.model_max_length
263
  self.per_tile_len = 256
264
+ print(self.model)
265
  def reload_model(self):
266
  del self.model
267
  torch.cuda.empty_cache()
 
297
 
298
  global_image_cnt = 0
299
  history, pil_images, max_input_tile_list = [], [], []
300
+
301
  for message in send_messages:
302
  if message['role'] == 'user':
303
  prefix = ''
 
342
  max_input_tiles_limited_by_contect = params['max_input_tiles']
343
  while True:
344
  image_tiles = []
345
+ num_patches_list = []
346
  for current_max_input_tiles, pil_image in zip(max_input_tile_list, pil_images):
347
  if self.model.config.dynamic_image_size:
348
  tiles = dynamic_preprocess(
 
350
  use_thumbnail=self.model.config.use_thumbnail)
351
  else:
352
  tiles = [pil_image]
353
+ num_patches_list.append(len(tiles))
354
  image_tiles += tiles
355
  if (len(image_tiles) * self.per_tile_len < self.context_len):
356
  break
 
361
  break
362
 
363
  pixel_values = [transform(item) for item in image_tiles]
364
+
365
+
366
  pixel_values = torch.stack(pixel_values).to(self.model.device, dtype=torch.bfloat16)
367
 
368
  else:
 
377
  max_length=self.context_len,
378
  top_p=top_p,
379
  )
380
+ print(f'pixel_values: {pixel_values.shape}')
381
  response = self.model.chat(
382
  tokenizer=self.tokenizer,
383
  pixel_values=pixel_values,
384
  question=question,
385
  history=history,
386
  return_history=False,
387
+ num_patches_list=num_patches_list,
388
  generation_config=generation_config,
389
  )
390
  self.model.system_message = old_system_message
 
396
 
397
  if __name__ == '__main__':
398
  parser = argparse.ArgumentParser()
399
+ parser.add_argument('--model-path', type=str, default='/home/zhidingy/workspace/eagle-next/internvl_chat/work_dirs/release/test/Eagle2-2B')
400
+ parser.add_argument('--model-name', type=str, default='Eagle2')
401
  parser.add_argument('--device', type=str, default='cuda')
402
  parser.add_argument('--load-8bit', action='store_true')
403
  args = parser.parse_args()
 
410
  args.device)
411
  prompt = [
412
  {'role': 'system', 'content': 'You are a helpful assistant.'},
413
+ {'role': 'user', 'content': 'Describe these two images in details respectively.',
414
  'image':[
415
+ {'url': 'https://www.nvidia.com/content/dam/en-zz/Solutions/about-nvidia/logo-and-brand/[email protected]'},
416
+ {'url': "https://www.google.com.hk/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png"}
417
  ]
418
  }
419
  ]
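The updated demo collects a per-image tile count (`num_patches_list`) and forwards it to `model.chat` together with the stacked tiles. A hedged sketch of the same calling pattern outside the `ModelWorker` class; `preprocess` below is a simplified stand-in for demo.py's `dynamic_preprocess` + `transform` pipeline, and the checkpoint id is a placeholder:

```python
import torch
from PIL import Image
from torchvision import transforms as T
from transformers import AutoModel, AutoTokenizer

path = 'nvidia/Eagle2-2B'   # placeholder checkpoint id; substitute a local path if needed
model = AutoModel.from_pretrained(path, trust_remote_code=True,
                                  torch_dtype=torch.bfloat16).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

def preprocess(img: Image.Image):
    """Simplified stand-in for demo.py's dynamic_preprocess + transform steps:
    every image becomes a single tile (no dynamic tiling). 448 is illustrative;
    use model.config.force_image_size in practice."""
    tfm = T.Compose([T.Resize((448, 448)), T.ToTensor(),
                     T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    return [tfm(img.convert('RGB'))]

num_patches_list, all_tiles = [], []
for img in (Image.open('a.png'), Image.open('b.png')):    # any two local images
    tiles = preprocess(img)
    num_patches_list.append(len(tiles))                   # one entry per image, as in demo.py
    all_tiles.extend(tiles)

pixel_values = torch.stack(all_tiles).to(model.device, dtype=torch.bfloat16)
response = model.chat(
    tokenizer=tokenizer,
    pixel_values=pixel_values,
    question='<image>\n<image>\nDescribe these two images.',
    history=[],
    return_history=False,
    num_patches_list=num_patches_list,
    generation_config=dict(do_sample=False, max_new_tokens=256),
)
print(response)
```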
flash_attention.py DELETED
@@ -1,76 +0,0 @@
1
- # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2
- import torch
3
- import torch.nn as nn
4
- from einops import rearrange
5
-
6
- try: # v1
7
- from flash_attn.flash_attn_interface import \
8
- flash_attn_unpadded_qkvpacked_func
9
- except: # v2
10
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
11
-
12
- from flash_attn.bert_padding import pad_input, unpad_input
13
-
14
-
15
- class FlashAttention(nn.Module):
16
- """Implement the scaled dot product attention with softmax.
17
- Arguments
18
- ---------
19
- softmax_scale: The temperature to use for the softmax attention.
20
- (default: 1/sqrt(d_keys) where d_keys is computed at
21
- runtime)
22
- attention_dropout: The dropout rate to apply to the attention
23
- (default: 0.0)
24
- """
25
-
26
- def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
27
- super().__init__()
28
- self.softmax_scale = softmax_scale
29
- self.dropout_p = attention_dropout
30
-
31
- def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
32
- max_s=None, need_weights=False):
33
- """Implements the multihead softmax attention.
34
- Arguments
35
- ---------
36
- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
37
- if unpadded: (nnz, 3, h, d)
38
- key_padding_mask: a bool tensor of shape (B, S)
39
- """
40
- assert not need_weights
41
- assert qkv.dtype in [torch.float16, torch.bfloat16]
42
- assert qkv.is_cuda
43
-
44
- if cu_seqlens is None:
45
- batch_size = qkv.shape[0]
46
- seqlen = qkv.shape[1]
47
- if key_padding_mask is None:
48
- qkv = rearrange(qkv, 'b s ... -> (b s) ...')
49
- max_s = seqlen
50
- cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
51
- device=qkv.device)
52
- output = flash_attn_unpadded_qkvpacked_func(
53
- qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
54
- softmax_scale=self.softmax_scale, causal=causal
55
- )
56
- output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
57
- else:
58
- nheads = qkv.shape[-2]
59
- x = rearrange(qkv, 'b s three h d -> b s (three h d)')
60
- x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
61
- x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
62
- output_unpad = flash_attn_unpadded_qkvpacked_func(
63
- x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
64
- softmax_scale=self.softmax_scale, causal=causal
65
- )
66
- output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
67
- indices, batch_size, seqlen),
68
- 'b s (h d) -> b s h d', h=nheads)
69
- else:
70
- assert max_s is not None
71
- output = flash_attn_unpadded_qkvpacked_func(
72
- qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
73
- softmax_scale=self.softmax_scale, causal=causal
74
- )
75
-
76
- return output, None
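With the vendored wrapper removed and `_supports_flash_attn_2 = True` declared on the model class (see the modeling diff below), FlashAttention-2 can be requested through transformers itself at load time. A minimal sketch, assuming a transformers release that accepts the `attn_implementation` argument and that the `flash-attn` package is installed:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    'nvidia/Eagle2-2B',                       # placeholder checkpoint id
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,               # FA2 kernels require fp16/bf16 on CUDA
    attn_implementation='flash_attention_2',
).cuda().eval()
```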
modeling_eagle_chat.py CHANGED
@@ -11,26 +11,18 @@ import torch.utils.checkpoint
11
  import transformers
12
  from torch import nn
13
  from torch.nn import CrossEntropyLoss
14
- from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
15
- LlamaTokenizer)
16
  from transformers.modeling_outputs import CausalLMOutputWithPast
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers.utils import ModelOutput, logging
19
  from peft import LoraConfig, get_peft_model
20
- from .configuration_eagle_chat import Eagle2ChatConfig
21
- from .conversation import get_conv_template
22
- from .modeling_siglip import SiglipVisionModel
23
- from .modeling_qwen2 import Qwen2ForCausalLM
24
- from .flash_attention import *
25
- from .multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModel
26
- from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
27
- from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
28
- from .siglip_vision_tower import SiglipVisionTower
29
- from .convnext_encoder import ConvNextVisionTower
30
- from .convnext import ConvNeXt
31
 
32
- logger = logging.get_logger(__name__)
33
 
 
 
34
 
35
  def version_cmp(v1, v2, op='eq'):
36
  import operator
@@ -44,25 +36,25 @@ class Eagle2ChatModel(PreTrainedModel):
44
  config_class = Eagle2ChatConfig
45
  main_input_name = 'pixel_values'
46
  _no_split_modules = ['LlamaDecoderLayer']
47
-
 
 
 
 
 
 
 
48
  def __init__(self, config: Eagle2ChatConfig, vision_model=None, language_model=None):
49
  super().__init__(config)
50
 
51
- assert version_cmp(transformers.__version__, '4.37.2', 'ge')
52
- assert version_cmp(transformers.__version__, '4.39.2', 'le')
53
  image_size = config.force_image_size or config.vision_config.image_size
54
- if hasattr(config.vision_config, 'grid_size'):
55
- grid_size = config.vision_config.grid_size
56
- self.patch_size = 14
57
- self.num_image_token = int((grid_size * config.downsample_ratio) ** 2)
58
- else:
59
- patch_size = config.vision_config.patch_size
60
- self.patch_size = patch_size
61
- self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
62
 
63
  self.select_layer = config.select_layer
64
  self.template = config.template
65
-
66
  self.downsample_ratio = config.downsample_ratio
67
 
68
  logger.info(f'num_image_token: {self.num_image_token}')
@@ -70,9 +62,9 @@ class Eagle2ChatModel(PreTrainedModel):
70
  self.vision_model = vision_model
71
  else:
72
  if config.vision_config.model_type == 'siglip_vision_model':
 
 
73
  self.vision_model = SiglipVisionModel(config.vision_config)
74
- elif config.vision_config.model_type.startswith("MOB"):
75
- self.vision_model = MultiBackboneChannelConcatenationVisionModel(config.vision_config, config)
76
 
77
  if language_model is not None:
78
  self.language_model = language_model
@@ -85,35 +77,17 @@ class Eagle2ChatModel(PreTrainedModel):
85
  raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
86
 
87
  vit_hidden_size = config.vision_config.hidden_size
88
- if vit_hidden_size == 'lazy_calculation':
89
- # a hack for Mixture of Backbones
90
- vit_hidden_size = self.vision_model.hidden_size
91
- print("The lazy calculated hidden_size: {} .. ".format(vit_hidden_size))
92
  llm_hidden_size = config.llm_config.hidden_size
93
- self.moe_version_type = getattr(config.vision_config, 'moe_version_type', None)
94
-
95
- if self.moe_version_type in ['seq_concat', 'feat_concat']:
96
- raise NotImplementedError
97
- elif self.moe_version_type == 'convnext_512_siglip_448':
98
- convnext_hidden_size = vit_hidden_size['convnext']
99
- siglip_hidden_size = vit_hidden_size['siglip']
100
- feature_concat_hidden_size = convnext_hidden_size + siglip_hidden_size * int(1 / self.downsample_ratio) ** 2
101
- self.mlp1 = nn.Sequential(
102
- nn.LayerNorm(feature_concat_hidden_size),
103
- nn.Linear(feature_concat_hidden_size, llm_hidden_size),
104
- nn.GELU(),
105
- nn.Linear(llm_hidden_size, llm_hidden_size)
106
- )
107
- else:
108
- self.mlp1 = nn.Sequential(
109
  nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
110
  nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
111
  nn.GELU(),
112
  nn.Linear(llm_hidden_size, llm_hidden_size)
113
  )
114
  self.img_context_token_id = None
115
- self.conv_template = get_conv_template(self.template)
116
- self.system_message = self.conv_template.system_message
117
 
118
  if config.use_backbone_lora:
119
  self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
@@ -165,19 +139,13 @@ class Eagle2ChatModel(PreTrainedModel):
165
  image_flags = image_flags.squeeze(-1)
166
  input_embeds = self.language_model.get_input_embeddings()(input_ids)
167
 
168
-
169
- if self.moe_version_type in ['seq_concat', 'feat_concat'] and not isinstance(pixel_values, dict):
170
- raise NotImplementedError
171
  vit_embeds = self.extract_feature(pixel_values)
172
 
173
  if not isinstance(image_flags, list):
174
  image_flags = image_flags.squeeze(-1)
175
  vit_embeds = vit_embeds[image_flags == 1]
176
- if isinstance(pixel_values, dict):
177
- # for MOE
178
- vit_batch_size = sum(pixel_values['num_patches'])
179
- else:
180
- vit_batch_size = pixel_values.shape[0]
181
 
182
  B, N, C = input_embeds.shape
183
  input_embeds = input_embeds.reshape(B * N, C)
@@ -206,7 +174,6 @@ class Eagle2ChatModel(PreTrainedModel):
206
  use_cache=use_cache,
207
  output_attentions=output_attentions,
208
  output_hidden_states=output_hidden_states,
209
- return_dict=return_dict,
210
  )
211
  logits = outputs.logits
212
 
@@ -248,7 +215,6 @@ class Eagle2ChatModel(PreTrainedModel):
248
  return x
249
 
250
  def extract_feature(self, pixel_values):
251
-
252
  """
253
  """
254
 
@@ -256,8 +222,10 @@ class Eagle2ChatModel(PreTrainedModel):
256
  vit_embeds = self.vision_model(
257
  pixel_values=pixel_values,
258
  output_hidden_states=False,
259
- return_dict=True).last_hidden_state # torch.Size([B, 1025, 1024])
260
-
 
 
261
  else:
262
  vit_embeds = self.vision_model(
263
  pixel_values=pixel_values,
@@ -265,35 +233,24 @@ class Eagle2ChatModel(PreTrainedModel):
265
  return_dict=True).hidden_states[self.select_layer]
266
  if type(self.vision_model) == SiglipVisionModel:
267
  pass
268
- elif type(self.vision_model) == MultiBackboneChannelConcatenationVisionModel:
269
- pass
270
  else:
271
  vit_embeds = vit_embeds[:, 1:, :] # torch.Size([B, 1024, 1024])
272
 
273
  if self.training and self.neftune_alpha is not None:
274
  vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
275
 
276
- if self.moe_version_type in ['feat_concat', 'seq_concat']:
277
- raise NotImplementedError
278
- elif self.moe_version_type == 'convnext_512_siglip_448':
279
- siglip_embeds = vit_embeds['siglip']
280
- convnext_embeds = vit_embeds['convnext']
281
- h = w = int(siglip_embeds.shape[1] ** 0.5)
282
- siglip_embeds = siglip_embeds.reshape(siglip_embeds.shape[0], h, w, -1)
283
- siglip_embeds = self.pixel_shuffle(siglip_embeds, scale_factor=self.downsample_ratio)
284
- siglip_embeds = siglip_embeds.reshape(siglip_embeds.shape[0], -1, siglip_embeds.shape[-1])
285
- vit_embeds = self.mlp1(torch.cat([siglip_embeds, convnext_embeds], dim=-1))
286
- else:
287
- h = w = int(vit_embeds.shape[1] ** 0.5)
288
- vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
289
 
290
- vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
291
- vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
292
- vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device)
 
 
 
293
 
294
  return vit_embeds
295
 
296
- def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
 
297
  history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
298
  IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
299
  if history is not None or return_history:
@@ -316,10 +273,11 @@ class Eagle2ChatModel(PreTrainedModel):
316
  question = questions[idx]
317
  if pixel_values is not None and '<image>' not in question:
318
  question = '<image>\n' + question
319
- template = get_conv_template(self.template)
320
- template.append_message(template.roles[0], question)
321
- template.append_message(template.roles[1], None)
322
- query = template.get_prompt()
 
323
 
324
  image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
325
  query = query.replace('<image>', image_tokens, 1)
@@ -329,7 +287,7 @@ class Eagle2ChatModel(PreTrainedModel):
329
  model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
330
  input_ids = model_inputs['input_ids'].cuda()
331
  attention_mask = model_inputs['attention_mask'].cuda()
332
- eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
333
  generation_config['eos_token_id'] = eos_token_id
334
  generation_output = self.generate(
335
  pixel_values=pixel_values,
@@ -338,7 +296,7 @@ class Eagle2ChatModel(PreTrainedModel):
338
  **generation_config
339
  )
340
  responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
341
- responses = [response.split(template.sep)[0].strip() for response in responses]
342
  return responses
343
 
344
  def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
@@ -355,17 +313,18 @@ class Eagle2ChatModel(PreTrainedModel):
355
  img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
356
  self.img_context_token_id = img_context_token_id
357
 
358
- template = get_conv_template(self.template)
359
- template.system_message = self.system_message
360
- eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
 
361
 
362
  history = [] if history is None else history
363
  for (old_question, old_answer) in history:
364
- template.append_message(template.roles[0], old_question)
365
- template.append_message(template.roles[1], old_answer)
366
- template.append_message(template.roles[0], question)
367
- template.append_message(template.roles[1], None)
368
- query = template.get_prompt()
369
 
370
  if verbose and pixel_values is not None:
371
  image_bs = pixel_values.shape[0]
@@ -382,11 +341,6 @@ class Eagle2ChatModel(PreTrainedModel):
382
  input_ids = model_inputs['input_ids'].cuda()
383
  attention_mask = model_inputs['attention_mask'].cuda()
384
  generation_config['eos_token_id'] = eos_token_id
385
- if self.moe_version_type is not None and self.moe_version_type != 'all_tiling' and self.moe_version_type != 'convnext_512_siglip_448':
386
- pixel_values = {
387
- 'pixel_values': pixel_values,
388
- 'num_patches': num_patches_list # num patch of each image.
389
- }
390
  generation_output = self.generate(
391
  pixel_values=pixel_values,
392
  input_ids=input_ids,
@@ -394,7 +348,7 @@ class Eagle2ChatModel(PreTrainedModel):
394
  **generation_config
395
  )
396
  response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
397
- response = response.split(template.sep)[0].strip()
398
  history.append((question, response))
399
  if return_history:
400
  return response, history
@@ -405,6 +359,17 @@ class Eagle2ChatModel(PreTrainedModel):
405
  print(query_to_print, response)
406
  return response
407
 
 
 
 
 
 
 
 
 
 
 
 
408
  @torch.no_grad()
409
  def generate(
410
  self,
@@ -443,7 +408,6 @@ class Eagle2ChatModel(PreTrainedModel):
443
  attention_mask=attention_mask,
444
  generation_config=generation_config,
445
  output_hidden_states=output_hidden_states,
446
- return_dict=return_dict,
447
  use_cache=True,
448
  **generate_kwargs,
449
  )
 
11
  import transformers
12
  from torch import nn
13
  from torch.nn import CrossEntropyLoss
14
+ from transformers import (AutoModel, GenerationConfig,
15
+ LlamaTokenizer, LlamaForCausalLM)
16
  from transformers.modeling_outputs import CausalLMOutputWithPast
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers.utils import ModelOutput, logging
19
  from peft import LoraConfig, get_peft_model
20
+ from transformers.models.siglip.modeling_siglip import SiglipVisionModel
 
 
 
 
 
 
 
 
 
 
21
 
22
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
23
 
24
+ logger = logging.get_logger(__name__)
25
+ from .configuration_eagle_chat import Eagle2ChatConfig
26
 
27
  def version_cmp(v1, v2, op='eq'):
28
  import operator
 
36
  config_class = Eagle2ChatConfig
37
  main_input_name = 'pixel_values'
38
  _no_split_modules = ['LlamaDecoderLayer']
39
+ _supports_flash_attn_2 = True
40
+ _supports_sdpa = True
41
+ _supports_flex_attn = False
42
+ _supports_cache_class = False
43
+ _supports_quantized_cache = False
44
+ _supports_static_cache = False
45
+ _supports_attention_backend = False
46
+
47
  def __init__(self, config: Eagle2ChatConfig, vision_model=None, language_model=None):
48
  super().__init__(config)
49
 
 
 
50
  image_size = config.force_image_size or config.vision_config.image_size
51
+
52
+ patch_size = config.vision_config.patch_size
53
+ self.patch_size = patch_size
54
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
 
 
 
 
55
 
56
  self.select_layer = config.select_layer
57
  self.template = config.template
 
58
  self.downsample_ratio = config.downsample_ratio
59
 
60
  logger.info(f'num_image_token: {self.num_image_token}')
 
62
  self.vision_model = vision_model
63
  else:
64
  if config.vision_config.model_type == 'siglip_vision_model':
65
+ if version_cmp(transformers.__version__, '4.43.0', 'le'):
66
+ config.vision_config._attn_implementation = 'eager'
67
  self.vision_model = SiglipVisionModel(config.vision_config)
 
 
68
 
69
  if language_model is not None:
70
  self.language_model = language_model
 
77
  raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
78
 
79
  vit_hidden_size = config.vision_config.hidden_size
80
+
 
 
 
81
  llm_hidden_size = config.llm_config.hidden_size
82
+
83
+ self.mlp1 = nn.Sequential(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
85
  nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
86
  nn.GELU(),
87
  nn.Linear(llm_hidden_size, llm_hidden_size)
88
  )
89
  self.img_context_token_id = None
90
+ self.system_message = 'You are a helpful assistant.' # Default system message
 
91
 
92
  if config.use_backbone_lora:
93
  self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
 
139
  image_flags = image_flags.squeeze(-1)
140
  input_embeds = self.language_model.get_input_embeddings()(input_ids)
141
 
 
 
 
142
  vit_embeds = self.extract_feature(pixel_values)
143
 
144
  if not isinstance(image_flags, list):
145
  image_flags = image_flags.squeeze(-1)
146
  vit_embeds = vit_embeds[image_flags == 1]
147
+
148
+ vit_batch_size = pixel_values.shape[0]
 
 
 
149
 
150
  B, N, C = input_embeds.shape
151
  input_embeds = input_embeds.reshape(B * N, C)
 
174
  use_cache=use_cache,
175
  output_attentions=output_attentions,
176
  output_hidden_states=output_hidden_states,
 
177
  )
178
  logits = outputs.logits
179
 
 
215
  return x
216
 
217
  def extract_feature(self, pixel_values):
 
218
  """
219
  """
220
 
 
222
  vit_embeds = self.vision_model(
223
  pixel_values=pixel_values,
224
  output_hidden_states=False,
225
+ return_dict=True)
226
+ # if there is vit_embeds.last_hidden_state, use it.
227
+ if hasattr(vit_embeds, 'last_hidden_state'):
228
+ vit_embeds = vit_embeds.last_hidden_state
229
  else:
230
  vit_embeds = self.vision_model(
231
  pixel_values=pixel_values,
 
233
  return_dict=True).hidden_states[self.select_layer]
234
  if type(self.vision_model) == SiglipVisionModel:
235
  pass
 
 
236
  else:
237
  vit_embeds = vit_embeds[:, 1:, :] # torch.Size([B, 1024, 1024])
238
 
239
  if self.training and self.neftune_alpha is not None:
240
  vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ h = w = int(vit_embeds.shape[1] ** 0.5)
244
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
245
+
246
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) # torch.Size([B, 1024, 1024]) -> torch.Size([B, 16, 16, 4096])
247
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) # torch.Size([B, 16, 16, 4096]) -> torch.Size([B, 256, 4096])
248
+ vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device)
249
 
250
  return vit_embeds
251
 
252
+ def batch_chat(self,
253
+ tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
254
  history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
255
  IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
256
  if history is not None or return_history:
 
273
  question = questions[idx]
274
  if pixel_values is not None and '<image>' not in question:
275
  question = '<image>\n' + question
276
+ template_messages = []
277
+ sep = tokenizer.eos_token
278
+ template_messages.append(('<|im_start|>user', question))
279
+ template_messages.append(('<|im_end|>assistant', None))
280
+ query = self.get_prompt(self.system_message, template_messages, sep)
281
 
282
  image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
283
  query = query.replace('<image>', image_tokens, 1)
 
287
  model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
288
  input_ids = model_inputs['input_ids'].cuda()
289
  attention_mask = model_inputs['attention_mask'].cuda()
290
+ eos_token_id = tokenizer.convert_tokens_to_ids(sep)
291
  generation_config['eos_token_id'] = eos_token_id
292
  generation_output = self.generate(
293
  pixel_values=pixel_values,
 
296
  **generation_config
297
  )
298
  responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
299
+ responses = [response.split(sep)[0].strip() for response in responses]
300
  return responses
301
 
302
  def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
 
313
  img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
314
  self.img_context_token_id = img_context_token_id
315
 
316
+ template_messages = []
317
+ system_message = f'<|im_start|>system\n{self.system_message}'
318
+ sep = tokenizer.eos_token
319
+ eos_token_id = tokenizer.convert_tokens_to_ids(sep)
320
 
321
  history = [] if history is None else history
322
  for (old_question, old_answer) in history:
323
+ template_messages.append(('<|im_start|>user', old_question))
324
+ template_messages.append(('<|im_start|>assistant', old_answer))
325
+ template_messages.append(('<|im_start|>user', question))
326
+ template_messages.append(('<|im_end|>assistant', None))
327
+ query = self.get_prompt(system_message, template_messages, sep)
328
 
329
  if verbose and pixel_values is not None:
330
  image_bs = pixel_values.shape[0]
 
341
  input_ids = model_inputs['input_ids'].cuda()
342
  attention_mask = model_inputs['attention_mask'].cuda()
343
  generation_config['eos_token_id'] = eos_token_id
 
 
 
 
 
344
  generation_output = self.generate(
345
  pixel_values=pixel_values,
346
  input_ids=input_ids,
 
348
  **generation_config
349
  )
350
  response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
351
+ response = response.split(sep)[0].strip()
352
  history.append((question, response))
353
  if return_history:
354
  return response, history
 
359
  print(query_to_print, response)
360
  return response
361
 
362
+ def get_prompt(self, system_prompt, messages, sep) -> str:
363
+ """Get the prompt for generation."""
364
+
365
+ ret = '' if system_prompt == '' else system_prompt + sep + '\n'
366
+ for role, message in messages:
367
+ if message:
368
+ ret += role + '\n' + message + sep + '\n'
369
+ else:
370
+ ret += role + '\n'
371
+ return ret
372
+
373
  @torch.no_grad()
374
  def generate(
375
  self,
 
408
  attention_mask=attention_mask,
409
  generation_config=generation_config,
410
  output_hidden_states=output_hidden_states,
 
411
  use_cache=True,
412
  **generate_kwargs,
413
  )
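With the conversation-template import gone, prompts are now assembled inline by the new `get_prompt` helper, using the tokenizer's EOS token as the turn separator. A standalone sketch of the string it produces for a single user turn; the role markers below are illustrative ChatML-style strings (see the code above for the exact roles used):

```python
# Mirrors the get_prompt() logic added above, without the model class.
def get_prompt(system_prompt: str, messages, sep: str) -> str:
    ret = '' if system_prompt == '' else system_prompt + sep + '\n'
    for role, message in messages:
        if message:
            ret += role + '\n' + message + sep + '\n'
        else:
            ret += role + '\n'
    return ret

sep = '<|im_end|>'                       # stand-in for tokenizer.eos_token
system = '<|im_start|>system\nYou are a helpful assistant.'
messages = [('<|im_start|>user', 'Describe this image.'),
            ('<|im_start|>assistant', None)]
print(get_prompt(system, messages, sep))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Describe this image.<|im_end|>
# <|im_start|>assistant
```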
modeling_qwen2.py DELETED
@@ -1,1744 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
- #
9
- # Licensed under the Apache License, Version 2.0 (the "License");
10
- # you may not use this file except in compliance with the License.
11
- # You may obtain a copy of the License at
12
- #
13
- # http://www.apache.org/licenses/LICENSE-2.0
14
- #
15
- # Unless required by applicable law or agreed to in writing, software
16
- # distributed under the License is distributed on an "AS IS" BASIS,
17
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
- # See the License for the specific language governing permissions and
19
- # limitations under the License.
20
- """ PyTorch Qwen2 model."""
21
- import inspect
22
- import math
23
- import warnings
24
- from typing import List, Optional, Tuple, Union
25
-
26
- import torch
27
- import torch.nn.functional as F
28
- import torch.utils.checkpoint
29
- from torch import nn
30
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
-
32
- from transformers.activations import ACT2FN
33
- from transformers.cache_utils import Cache, DynamicCache
34
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
- from transformers.modeling_utils import PreTrainedModel
37
- from transformers.utils import (
38
- add_start_docstrings,
39
- add_start_docstrings_to_model_forward,
40
- is_flash_attn_2_available,
41
- is_flash_attn_greater_or_equal_2_10,
42
- logging,
43
- replace_return_docstrings,
44
- )
45
- from .configuration_qwen2 import Qwen2Config
46
-
47
-
48
- if is_flash_attn_2_available():
49
- from flash_attn import flash_attn_func, flash_attn_varlen_func
50
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
-
52
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
53
-
54
-
55
- logger = logging.get_logger(__name__)
56
-
57
-
58
- _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
59
- _CONFIG_FOR_DOC = "Qwen2Config"
60
-
61
- QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
- "Qwen/Qwen2-7B-beta",
63
- # See all Qwen2 models at https://huggingface.co/models?filter=qwen2
64
- ]
65
-
66
-
67
- # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
- def _get_unpad_data(attention_mask):
69
- seqlens_in_batch = (attention_mask>0).sum(dim=-1, dtype=torch.int32)
70
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
- max_seqlen_in_batch = seqlens_in_batch.max().item()
72
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
73
- return (
74
- indices,
75
- cu_seqlens,
76
- max_seqlen_in_batch,
77
- )
78
-
79
- def _get_unpad_data_packing(attention_mask, sub_sample_lengths):
80
- seqlens_in_batch = []
81
- for i, per_sub_sample_lengths in enumerate(sub_sample_lengths):
82
- if (attention_mask[i]==0).sum() == per_sub_sample_lengths[-1]:
83
- per_sub_sample_lengths = per_sub_sample_lengths[:-1]
84
- seqlens_in_batch.extend(per_sub_sample_lengths)
85
- seqlens_in_batch = torch.tensor(seqlens_in_batch, device=attention_mask.device, dtype=torch.int32)
86
-
87
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
88
- max_seqlen_in_batch = seqlens_in_batch.max().item()
89
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
90
- return (
91
- indices,
92
- cu_seqlens,
93
- max_seqlen_in_batch,
94
- )
95
-
96
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
97
- class Qwen2RMSNorm(nn.Module):
98
- def __init__(self, hidden_size, eps=1e-6):
99
- """
100
- Qwen2RMSNorm is equivalent to T5LayerNorm
101
- """
102
- super().__init__()
103
- self.weight = nn.Parameter(torch.ones(hidden_size))
104
- self.variance_epsilon = eps
105
-
106
- def forward(self, hidden_states):
107
- input_dtype = hidden_states.dtype
108
- hidden_states = hidden_states.to(torch.float32)
109
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
110
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
111
- return self.weight * hidden_states.to(input_dtype)
112
-
113
-
114
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
115
- class Qwen2RotaryEmbedding(nn.Module):
116
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
117
- super().__init__()
118
-
119
- self.dim = dim
120
- self.max_position_embeddings = max_position_embeddings
121
- self.base = base
122
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
123
- self.register_buffer("inv_freq", inv_freq, persistent=False)
124
-
125
- # Build here to make `torch.jit.trace` work.
126
- self._set_cos_sin_cache(
127
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
128
- )
129
-
130
- def _set_cos_sin_cache(self, seq_len, device, dtype):
131
- self.max_seq_len_cached = seq_len
132
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
133
-
134
- freqs = torch.outer(t, self.inv_freq)
135
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
136
- emb = torch.cat((freqs, freqs), dim=-1)
137
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
138
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
139
-
140
- def forward(self, x, seq_len=None):
141
- # x: [bs, num_attention_heads, seq_len, head_size]
142
- if seq_len > self.max_seq_len_cached:
143
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
144
-
145
- return (
146
- self.cos_cached[:seq_len].to(dtype=x.dtype),
147
- self.sin_cached[:seq_len].to(dtype=x.dtype),
148
- )
149
-
150
-
151
- # Copied from transformers.models.llama.modeling_llama.rotate_half
152
- def rotate_half(x):
153
- """Rotates half the hidden dims of the input."""
154
- x1 = x[..., : x.shape[-1] // 2]
155
- x2 = x[..., x.shape[-1] // 2 :]
156
- return torch.cat((-x2, x1), dim=-1)
157
-
158
-
159
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
160
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
161
- """Applies Rotary Position Embedding to the query and key tensors.
162
-
163
- Args:
164
- q (`torch.Tensor`): The query tensor.
165
- k (`torch.Tensor`): The key tensor.
166
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
167
- sin (`torch.Tensor`): The sine part of the rotary embedding.
168
- position_ids (`torch.Tensor`):
169
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
170
- used to pass offsetted position ids when working with a KV-cache.
171
- unsqueeze_dim (`int`, *optional*, defaults to 1):
172
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
173
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
174
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
175
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
176
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
177
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
178
- Returns:
179
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
180
- """
181
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
182
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
183
- q_embed = (q * cos) + (rotate_half(q) * sin)
184
- k_embed = (k * cos) + (rotate_half(k) * sin)
185
- return q_embed, k_embed
186
-
187
-
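A minimal, shape-level sketch of how the rotary helpers above fit together, assuming `Qwen2RotaryEmbedding` and `apply_rotary_pos_emb` from this file are in scope; the sizes are arbitrary.

```python
import torch

bsz, n_heads, n_kv_heads, seq, head_dim = 1, 4, 2, 6, 16
q = torch.randn(bsz, n_heads, seq, head_dim)
k = torch.randn(bsz, n_kv_heads, seq, head_dim)

rope = Qwen2RotaryEmbedding(head_dim, max_position_embeddings=128)
cos, sin = rope(q, seq_len=seq)                 # each: (seq, head_dim)
position_ids = torch.arange(seq).unsqueeze(0)   # (1, seq)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```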
188
- # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
189
- class Qwen2MLP(nn.Module):
190
- def __init__(self, config):
191
- super().__init__()
192
- self.config = config
193
- self.hidden_size = config.hidden_size
194
- self.intermediate_size = config.intermediate_size
195
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
196
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
197
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
198
- self.act_fn = ACT2FN[config.hidden_act]
199
-
200
- def forward(self, x):
201
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
202
-
203
-
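The MLP above is the usual SwiGLU-style block, `down_proj(act(gate_proj(x)) * up_proj(x))`. A small shape check, assuming `Qwen2MLP` is in scope; the config values are made up:

```python
import torch
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

cfg = Qwen2Config(hidden_size=64, intermediate_size=128, hidden_act="silu")
mlp = Qwen2MLP(cfg)
x = torch.randn(2, 3, 64)
assert mlp(x).shape == (2, 3, 64)
```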
204
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
205
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
206
- """
207
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
208
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
209
- """
210
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
211
- if n_rep == 1:
212
- return hidden_states
213
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
214
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
215
-
216
-
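`repeat_kv` simply expands grouped-query KV heads so they line up with the query heads, for example (shapes only):

```python
import torch

kv = torch.randn(1, 2, 5, 64)        # (batch, num_kv_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=4)         # 2 KV heads shared by 8 query heads
assert out.shape == (1, 8, 5, 64)
```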
217
- class Qwen2Attention(nn.Module):
218
- """
219
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
220
- and "Generating Long Sequences with Sparse Transformers".
221
- """
222
-
223
- def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
224
- super().__init__()
225
- self.config = config
226
- self.layer_idx = layer_idx
227
- if layer_idx is None:
228
- logger.warning_once(
229
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
230
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
231
- "when creating this class."
232
- )
233
-
234
- self.hidden_size = config.hidden_size
235
- self.num_heads = config.num_attention_heads
236
- self.head_dim = self.hidden_size // self.num_heads
237
- self.num_key_value_heads = config.num_key_value_heads
238
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
239
- self.max_position_embeddings = config.max_position_embeddings
240
- self.rope_theta = config.rope_theta
241
- self.is_causal = True
242
- self.attention_dropout = config.attention_dropout
243
-
244
- if (self.head_dim * self.num_heads) != self.hidden_size:
245
- raise ValueError(
246
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
247
- f" and `num_heads`: {self.num_heads})."
248
- )
249
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
250
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
251
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
252
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
253
-
254
- self.rotary_emb = Qwen2RotaryEmbedding(
255
- self.head_dim,
256
- max_position_embeddings=self.max_position_embeddings,
257
- base=self.rope_theta,
258
- )
259
-
260
- def forward(
261
- self,
262
- hidden_states: torch.Tensor,
263
- attention_mask: Optional[torch.Tensor] = None,
264
- position_ids: Optional[torch.LongTensor] = None,
265
- past_key_value: Optional[Cache] = None,
266
- output_attentions: bool = False,
267
- use_cache: bool = False,
268
- **kwargs,
269
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
270
- if "padding_mask" in kwargs:
271
- warnings.warn(
272
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
273
- )
274
- bsz, q_len, _ = hidden_states.size()
275
-
276
- query_states = self.q_proj(hidden_states)
277
- key_states = self.k_proj(hidden_states)
278
- value_states = self.v_proj(hidden_states)
279
-
280
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
281
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
282
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
283
-
284
- kv_seq_len = key_states.shape[-2]
285
- if past_key_value is not None:
286
- if self.layer_idx is None:
287
- raise ValueError(
288
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
289
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
290
- "with a layer index."
291
- )
292
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
293
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
294
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
295
-
296
- if past_key_value is not None:
297
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
298
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
299
-
300
- # repeat k/v heads if n_kv_heads < n_heads
301
- key_states = repeat_kv(key_states, self.num_key_value_groups)
302
- value_states = repeat_kv(value_states, self.num_key_value_groups)
303
-
304
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
305
-
306
- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
307
- raise ValueError(
308
- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
309
- f" {attn_weights.size()}"
310
- )
311
-
312
- if attention_mask is not None:
313
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
314
- raise ValueError(
315
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
316
- )
317
-
318
- attn_weights = attn_weights + attention_mask
319
-
320
- # upcast attention to fp32
321
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
322
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
323
- attn_output = torch.matmul(attn_weights, value_states)
324
-
325
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
326
- raise ValueError(
327
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
328
- f" {attn_output.size()}"
329
- )
330
-
331
- attn_output = attn_output.transpose(1, 2).contiguous()
332
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
333
-
334
- attn_output = self.o_proj(attn_output)
335
-
336
- if not output_attentions:
337
- attn_weights = None
338
-
339
- return attn_output, attn_weights, past_key_value
340
-
341
-
342
- class Qwen2FlashAttention2(Qwen2Attention):
343
- """
344
- Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
345
- as the weights of the module stay untouched. The only required change would be on the forward pass
346
- where it needs to correctly call the public API of flash attention and deal with padding tokens
347
- in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
348
- config.max_window_layers layers.
349
- """
350
-
351
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
352
- def __init__(self, *args, **kwargs):
353
- super().__init__(*args, **kwargs)
354
-
355
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
356
- # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
357
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
358
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
359
-
360
- def forward(
361
- self,
362
- hidden_states: torch.Tensor,
363
- attention_mask: Optional[torch.Tensor] = None,
364
- position_ids: Optional[torch.LongTensor] = None,
365
- past_key_value: Optional[Cache] = None,
366
- output_attentions: bool = False,
367
- use_cache: bool = False,
368
- **kwargs,
369
- ):
370
- if "padding_mask" in kwargs:
371
- warnings.warn(
372
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
373
- )
374
-
375
- # overwrite attention_mask with padding_mask
376
- attention_mask = kwargs.pop("padding_mask")
377
- bsz, q_len, _ = hidden_states.size()
378
-
379
- query_states = self.q_proj(hidden_states)
380
- key_states = self.k_proj(hidden_states)
381
- value_states = self.v_proj(hidden_states)
382
-
383
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
384
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
385
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
386
-
387
- kv_seq_len = key_states.shape[-2]
388
- if past_key_value is not None:
389
- if self.layer_idx is None:
390
- raise ValueError(
391
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
392
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
393
- "with a layer index."
394
- )
395
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
396
-
397
- # Because the input can be padded, the absolute sequence length depends on the max position id.
398
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
399
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
400
-
401
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
402
-
403
- use_sliding_windows = (
404
- _flash_supports_window_size
405
- and getattr(self.config, "sliding_window", None) is not None
406
- and kv_seq_len > self.config.sliding_window
407
- and self.config.use_sliding_window
408
- )
409
-
410
- if not _flash_supports_window_size:
411
- logger.warning_once(
412
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
413
- " make sure to upgrade flash-attn library."
414
- )
415
-
416
- if past_key_value is not None:
417
- # Activate cache slicing only if the config has a `sliding_window` attribute
418
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
419
- if (
420
- getattr(self.config, "sliding_window", None) is not None
421
- and kv_seq_len > self.config.sliding_window
422
- and cache_has_contents
423
- ):
424
- slicing_tokens = 1 - self.config.sliding_window
425
-
426
- past_key = past_key_value[self.layer_idx][0]
427
- past_value = past_key_value[self.layer_idx][1]
428
-
429
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
430
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
431
-
432
- if past_key.shape[-2] != self.config.sliding_window - 1:
433
- raise ValueError(
434
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
435
- f" {past_key.shape}"
436
- )
437
-
438
- if attention_mask is not None:
439
- attention_mask = attention_mask[:, slicing_tokens:]
440
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
441
-
442
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
443
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
444
-
445
- # repeat k/v heads if n_kv_heads < n_heads
446
- key_states = repeat_kv(key_states, self.num_key_value_groups)
447
- value_states = repeat_kv(value_states, self.num_key_value_groups)
448
- dropout_rate = 0.0 if not self.training else self.attention_dropout
449
-
450
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
451
- # therefore the input hidden states gets silently casted in float32. Hence, we need
452
- # cast them back in float16 just to be sure everything works as expected.
453
- input_dtype = query_states.dtype
454
- if input_dtype == torch.float32:
455
- if torch.is_autocast_enabled():
456
- target_dtype = torch.get_autocast_gpu_dtype()
457
- # Handle the case where the model is quantized
458
- elif hasattr(self.config, "_pre_quantization_dtype"):
459
- target_dtype = self.config._pre_quantization_dtype
460
- else:
461
- target_dtype = self.q_proj.weight.dtype
462
-
463
- logger.warning_once(
464
- f"The input hidden states seems to be silently casted in float32, this might be related to"
465
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
466
- f" {target_dtype}."
467
- )
468
-
469
- query_states = query_states.to(target_dtype)
470
- key_states = key_states.to(target_dtype)
471
- value_states = value_states.to(target_dtype)
472
-
473
- # Reshape to the expected shape for Flash Attention
474
- query_states = query_states.transpose(1, 2)
475
- key_states = key_states.transpose(1, 2)
476
- value_states = value_states.transpose(1, 2)
477
-
478
- attn_output = self._flash_attention_forward(
479
- query_states,
480
- key_states,
481
- value_states,
482
- attention_mask,
483
- q_len,
484
- dropout=dropout_rate,
485
- use_sliding_windows=use_sliding_windows,
486
- )
487
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
488
- attn_output = self.o_proj(attn_output)
489
-
490
- if not output_attentions:
491
- attn_weights = None
492
-
493
- return attn_output, attn_weights, past_key_value
494
-
495
- def _flash_attention_forward(
496
- self,
497
- query_states,
498
- key_states,
499
- value_states,
500
- attention_mask,
501
- query_length,
502
- dropout=0.0,
503
- softmax_scale=None,
504
- use_sliding_windows=False,
505
- ):
506
- """
507
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
508
- first unpad the input, then computes the attention scores and pad the final attention scores.
509
-
510
- Args:
511
- query_states (`torch.Tensor`):
512
- Input query states to be passed to Flash Attention API
513
- key_states (`torch.Tensor`):
514
- Input key states to be passed to Flash Attention API
515
- value_states (`torch.Tensor`):
516
- Input value states to be passed to Flash Attention API
517
- attention_mask (`torch.Tensor`):
518
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
519
- position of padding tokens and 1 for the position of non-padding tokens.
520
- dropout (`int`, *optional*):
521
- Attention dropout
522
- softmax_scale (`float`, *optional*):
523
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
524
- use_sliding_windows (`bool`, *optional*):
525
- Whether to activate sliding window attention.
526
- """
527
- if not self._flash_attn_uses_top_left_mask:
528
- causal = self.is_causal
529
- else:
530
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
531
- causal = self.is_causal and query_length != 1
532
-
533
- # Decide whether to use SWA or not by layer index.
534
- if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
535
- use_sliding_windows = False
536
-
537
- # Contains at least one padding token in the sequence
538
- if attention_mask is not None:
539
- batch_size = query_states.shape[0]
540
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
541
- query_states, key_states, value_states, attention_mask, query_length
542
- )
543
-
544
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
545
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
546
-
547
- if not use_sliding_windows:
548
- attn_output_unpad = flash_attn_varlen_func(
549
- query_states,
550
- key_states,
551
- value_states,
552
- cu_seqlens_q=cu_seqlens_q,
553
- cu_seqlens_k=cu_seqlens_k,
554
- max_seqlen_q=max_seqlen_in_batch_q,
555
- max_seqlen_k=max_seqlen_in_batch_k,
556
- dropout_p=dropout,
557
- softmax_scale=softmax_scale,
558
- causal=causal,
559
- )
560
- else:
561
- attn_output_unpad = flash_attn_varlen_func(
562
- query_states,
563
- key_states,
564
- value_states,
565
- cu_seqlens_q=cu_seqlens_q,
566
- cu_seqlens_k=cu_seqlens_k,
567
- max_seqlen_q=max_seqlen_in_batch_q,
568
- max_seqlen_k=max_seqlen_in_batch_k,
569
- dropout_p=dropout,
570
- softmax_scale=softmax_scale,
571
- causal=causal,
572
- window_size=(self.config.sliding_window, self.config.sliding_window),
573
- )
574
-
575
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
576
- else:
577
- if not use_sliding_windows:
578
- attn_output = flash_attn_func(
579
- query_states,
580
- key_states,
581
- value_states,
582
- dropout,
583
- softmax_scale=softmax_scale,
584
- causal=causal,
585
- )
586
- else:
587
- attn_output = flash_attn_func(
588
- query_states,
589
- key_states,
590
- value_states,
591
- dropout,
592
- softmax_scale=softmax_scale,
593
- causal=causal,
594
- window_size=(self.config.sliding_window, self.config.sliding_window),
595
- )
596
-
597
- return attn_output
598
-
599
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
600
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
601
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
602
-
603
- # On the first iteration we need to properly re-create the padding mask
604
- # by slicing it at the proper place
605
- if kv_seq_len != attention_mask.shape[-1]:
606
- attention_mask_num_tokens = attention_mask.shape[-1]
607
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
608
-
609
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
610
-
611
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
612
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
613
-
614
- if query_length == kv_seq_len:
615
- query_layer = index_first_axis(
616
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
617
- )
618
- cu_seqlens_q = cu_seqlens_k
619
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
620
- indices_q = indices_k
621
- elif query_length == 1:
622
- max_seqlen_in_batch_q = 1
623
- cu_seqlens_q = torch.arange(
624
- batch_size + 1, dtype=torch.int32, device=query_layer.device
625
- ) # There is a memcpy here, that is very bad.
626
- indices_q = cu_seqlens_q[:-1]
627
- query_layer = query_layer.squeeze(1)
628
- else:
629
- # The -q_len: slice assumes left padding.
630
- attention_mask = attention_mask[:, -query_length:]
631
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
632
-
633
- return (
634
- query_layer,
635
- key_layer,
636
- value_layer,
637
- indices_q,
638
- (cu_seqlens_q, cu_seqlens_k),
639
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
640
- )
641
- class Qwen2FlashAttention2_packing(Qwen2Attention):
642
- """
643
- Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
644
- as the weights of the module stay untouched. The only required change would be on the forward pass
645
- where it needs to correctly call the public API of flash attention and deal with padding tokens
646
- in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
647
- config.max_window_layers layers.
648
- """
649
-
650
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
651
- def __init__(self, *args, **kwargs):
652
- super().__init__(*args, **kwargs)
653
-
654
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
655
- # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
656
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
657
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
658
-
659
- def forward(
660
- self,
661
- hidden_states: torch.Tensor,
662
- attention_mask: Optional[torch.Tensor] = None,
663
- position_ids: Optional[torch.LongTensor] = None,
664
- past_key_value: Optional[Cache] = None,
665
- output_attentions: bool = False,
666
- use_cache: bool = False,
667
- sub_sample_lengths = None,
668
- **kwargs,
669
- ):
670
- if "padding_mask" in kwargs:
671
- warnings.warn(
672
- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
673
- )
674
-
675
- # overwrite attention_mask with padding_mask
676
- attention_mask = kwargs.pop("padding_mask")
677
- bsz, q_len, _ = hidden_states.size()
678
-
679
- query_states = self.q_proj(hidden_states)
680
- key_states = self.k_proj(hidden_states)
681
- value_states = self.v_proj(hidden_states)
682
-
683
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
684
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
685
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
686
-
687
- kv_seq_len = key_states.shape[-2]
688
- if past_key_value is not None:
689
- if self.layer_idx is None:
690
- raise ValueError(
691
- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
692
- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
693
- "with a layer index."
694
- )
695
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
696
-
697
- # Because the input can be padded, the absolute sequence length depends on the max position id.
698
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
699
- cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
700
-
701
- if sub_sample_lengths is not None:
702
- packing_position_ids = []
703
- for b in range(bsz):
704
- each_sum_sample_lengths = sub_sample_lengths[b]
705
- packing_position_ids.append(torch.cat([torch.arange(each) for each in each_sum_sample_lengths]))
706
- packing_position_ids = torch.stack(packing_position_ids)
707
- packing_position_ids = packing_position_ids.to(query_states.device)
708
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, packing_position_ids)
709
- else:
710
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
711
-
712
- use_sliding_windows = (
713
- _flash_supports_window_size
714
- and getattr(self.config, "sliding_window", None) is not None
715
- and kv_seq_len > self.config.sliding_window
716
- and self.config.use_sliding_window
717
- )
718
-
719
- if not _flash_supports_window_size:
720
- logger.warning_once(
721
- "The current flash attention version does not support sliding window attention, for a more memory efficient implementation"
722
- " make sure to upgrade flash-attn library."
723
- )
724
-
725
- if past_key_value is not None:
726
- # Activate cache slicing only if the config has a `sliding_window` attribute
727
- cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
728
- if (
729
- getattr(self.config, "sliding_window", None) is not None
730
- and kv_seq_len > self.config.sliding_window
731
- and cache_has_contents
732
- ):
733
- slicing_tokens = 1 - self.config.sliding_window
734
-
735
- past_key = past_key_value[self.layer_idx][0]
736
- past_value = past_key_value[self.layer_idx][1]
737
-
738
- past_key = past_key[:, :, slicing_tokens:, :].contiguous()
739
- past_value = past_value[:, :, slicing_tokens:, :].contiguous()
740
-
741
- if past_key.shape[-2] != self.config.sliding_window - 1:
742
- raise ValueError(
743
- f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
744
- f" {past_key.shape}"
745
- )
746
-
747
- if attention_mask is not None:
748
- attention_mask = attention_mask[:, slicing_tokens:]
749
- attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
750
-
751
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
752
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
753
-
754
- # repeat k/v heads if n_kv_heads < n_heads
755
- key_states = repeat_kv(key_states, self.num_key_value_groups)
756
- value_states = repeat_kv(value_states, self.num_key_value_groups)
757
- dropout_rate = 0.0 if not self.training else self.attention_dropout
758
-
759
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
760
- # therefore the input hidden states gets silently casted in float32. Hence, we need
761
- # cast them back in float16 just to be sure everything works as expected.
762
- input_dtype = query_states.dtype
763
- if input_dtype == torch.float32:
764
- if torch.is_autocast_enabled():
765
- target_dtype = torch.get_autocast_gpu_dtype()
766
- # Handle the case where the model is quantized
767
- elif hasattr(self.config, "_pre_quantization_dtype"):
768
- target_dtype = self.config._pre_quantization_dtype
769
- else:
770
- target_dtype = self.q_proj.weight.dtype
771
-
772
- logger.warning_once(
773
- f"The input hidden states seems to be silently casted in float32, this might be related to"
774
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
775
- f" {target_dtype}."
776
- )
777
-
778
- query_states = query_states.to(target_dtype)
779
- key_states = key_states.to(target_dtype)
780
- value_states = value_states.to(target_dtype)
781
-
782
- # Reshape to the expected shape for Flash Attention
783
- query_states = query_states.transpose(1, 2)
784
- key_states = key_states.transpose(1, 2)
785
- value_states = value_states.transpose(1, 2)
786
-
787
- attn_output = self._flash_attention_forward(
788
- query_states,
789
- key_states,
790
- value_states,
791
- attention_mask,
792
- q_len,
793
- dropout=dropout_rate,
794
- use_sliding_windows=use_sliding_windows,
795
- sub_sample_lengths=sub_sample_lengths
796
- )
797
-
798
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
799
- attn_output = self.o_proj(attn_output)
800
-
801
- if not output_attentions:
802
- attn_weights = None
803
-
804
- return attn_output, attn_weights, past_key_value
805
-
806
- def _flash_attention_forward(
807
- self,
808
- query_states,
809
- key_states,
810
- value_states,
811
- attention_mask,
812
- query_length,
813
- dropout=0.0,
814
- softmax_scale=None,
815
- use_sliding_windows=False,
816
- sub_sample_lengths=None,
817
- ):
818
- """
819
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
820
- first unpad the input, then computes the attention scores and pad the final attention scores.
821
-
822
- Args:
823
- query_states (`torch.Tensor`):
824
- Input query states to be passed to Flash Attention API
825
- key_states (`torch.Tensor`):
826
- Input key states to be passed to Flash Attention API
827
- value_states (`torch.Tensor`):
828
- Input value states to be passed to Flash Attention API
829
- attention_mask (`torch.Tensor`):
830
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
831
- position of padding tokens and 1 for the position of non-padding tokens.
832
- dropout (`int`, *optional*):
833
- Attention dropout
834
- softmax_scale (`float`, *optional*):
835
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
836
- use_sliding_windows (`bool`, *optional*):
837
- Whether to activate sliding window attention.
838
- """
839
- if not self._flash_attn_uses_top_left_mask:
840
- causal = self.is_causal
841
- else:
842
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
843
- causal = self.is_causal and query_length != 1
844
-
845
- # Decide whether to use SWA or not by layer index.
846
- if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
847
- use_sliding_windows = False
848
-
849
- # Contains at least one padding token in the sequence
850
-
851
- if attention_mask is not None:
852
- batch_size = query_states.shape[0]
853
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input_packing(
854
- query_states, key_states, value_states, attention_mask, query_length, sub_sample_lengths
855
- )
856
-
857
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
858
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
859
-
860
- if not use_sliding_windows:
861
- attn_output_unpad = flash_attn_varlen_func(
862
- query_states,
863
- key_states,
864
- value_states,
865
- cu_seqlens_q=cu_seqlens_q,
866
- cu_seqlens_k=cu_seqlens_k,
867
- max_seqlen_q=max_seqlen_in_batch_q,
868
- max_seqlen_k=max_seqlen_in_batch_k,
869
- dropout_p=dropout,
870
- softmax_scale=softmax_scale,
871
- causal=causal,
872
- )
873
- else:
874
- attn_output_unpad = flash_attn_varlen_func(
875
- query_states,
876
- key_states,
877
- value_states,
878
- cu_seqlens_q=cu_seqlens_q,
879
- cu_seqlens_k=cu_seqlens_k,
880
- max_seqlen_q=max_seqlen_in_batch_q,
881
- max_seqlen_k=max_seqlen_in_batch_k,
882
- dropout_p=dropout,
883
- softmax_scale=softmax_scale,
884
- causal=causal,
885
- window_size=(self.config.sliding_window, self.config.sliding_window),
886
- )
887
-
888
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
889
- else:
890
- if not use_sliding_windows:
891
- attn_output = flash_attn_func(
892
- query_states,
893
- key_states,
894
- value_states,
895
- dropout,
896
- softmax_scale=softmax_scale,
897
- causal=causal,
898
- )
899
- else:
900
- attn_output = flash_attn_func(
901
- query_states,
902
- key_states,
903
- value_states,
904
- dropout,
905
- softmax_scale=softmax_scale,
906
- causal=causal,
907
- window_size=(self.config.sliding_window, self.config.sliding_window),
908
- )
909
-
910
- return attn_output
911
-
912
- # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
913
- def _unpad_input_packing(self, query_layer, key_layer, value_layer, attention_mask, query_length, sub_sample_lengths):
914
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
915
-
916
- # On the first iteration we need to properly re-create the padding mask
917
- # by slicing it at the proper place
918
- if kv_seq_len != attention_mask.shape[-1]:
919
- attention_mask_num_tokens = attention_mask.shape[-1]
920
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
921
-
922
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data_packing(attention_mask, sub_sample_lengths)
923
-
924
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
925
- value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
926
-
927
- if query_length == kv_seq_len:
928
- query_layer = index_first_axis(
929
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
930
- )
931
- cu_seqlens_q = cu_seqlens_k
932
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
933
- indices_q = indices_k
934
- elif query_length == 1:
935
- max_seqlen_in_batch_q = 1
936
- cu_seqlens_q = torch.arange(
937
- batch_size + 1, dtype=torch.int32, device=query_layer.device
938
- ) # There is a memcpy here, that is very bad.
939
- indices_q = cu_seqlens_q[:-1]
940
- query_layer = query_layer.squeeze(1)
941
- else:
942
- # The -q_len: slice assumes left padding.
943
- attention_mask = attention_mask[:, -query_length:]
944
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
945
-
946
- return (
947
- query_layer,
948
- key_layer,
949
- value_layer,
950
- indices_q,
951
- (cu_seqlens_q, cu_seqlens_k),
952
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
953
- )
954
-
955
-
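For clarity, the packed variant above restarts position ids at zero for every sub-sample before applying RoPE. A standalone illustration with made-up lengths, mirroring the `torch.cat`-of-`arange` construction in `Qwen2FlashAttention2_packing.forward`:

```python
import torch

sub_sample_lengths = [[3, 2, 2]]      # one batch row packing three sub-samples
packing_position_ids = torch.stack([
    torch.cat([torch.arange(n) for n in row]) for row in sub_sample_lengths
])
print(packing_position_ids)           # tensor([[0, 1, 2, 0, 1, 0, 1]])
```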
956
- # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2
957
- class Qwen2SdpaAttention(Qwen2Attention):
958
- """
959
- Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
960
- `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
961
- SDPA API.
962
- """
963
-
964
- # Adapted from Qwen2Attention.forward
965
- def forward(
966
- self,
967
- hidden_states: torch.Tensor,
968
- attention_mask: Optional[torch.Tensor] = None,
969
- position_ids: Optional[torch.LongTensor] = None,
970
- past_key_value: Optional[Cache] = None,
971
- output_attentions: bool = False,
972
- use_cache: bool = False,
973
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
974
- if output_attentions:
975
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
976
- logger.warning_once(
977
- "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
978
- 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
979
- )
980
- return super().forward(
981
- hidden_states=hidden_states,
982
- attention_mask=attention_mask,
983
- position_ids=position_ids,
984
- past_key_value=past_key_value,
985
- output_attentions=output_attentions,
986
- use_cache=use_cache,
987
- )
988
-
989
- bsz, q_len, _ = hidden_states.size()
990
-
991
- query_states = self.q_proj(hidden_states)
992
- key_states = self.k_proj(hidden_states)
993
- value_states = self.v_proj(hidden_states)
994
-
995
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
996
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
997
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
998
-
999
- kv_seq_len = key_states.shape[-2]
1000
- if past_key_value is not None:
1001
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1002
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1003
-
1004
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1005
-
1006
- if past_key_value is not None:
1007
- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
1008
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1009
-
1010
- key_states = repeat_kv(key_states, self.num_key_value_groups)
1011
- value_states = repeat_kv(value_states, self.num_key_value_groups)
1012
-
1013
- if attention_mask is not None:
1014
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
1015
- raise ValueError(
1016
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
1017
- )
1018
-
1019
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1020
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
1021
- if query_states.device.type == "cuda" and attention_mask is not None:
1022
- query_states = query_states.contiguous()
1023
- key_states = key_states.contiguous()
1024
- value_states = value_states.contiguous()
1025
-
1026
- attn_output = torch.nn.functional.scaled_dot_product_attention(
1027
- query_states,
1028
- key_states,
1029
- value_states,
1030
- attn_mask=attention_mask,
1031
- dropout_p=self.attention_dropout if self.training else 0.0,
1032
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1033
- is_causal=self.is_causal and attention_mask is None and q_len > 1,
1034
- )
1035
-
1036
- attn_output = attn_output.transpose(1, 2).contiguous()
1037
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
1038
-
1039
- attn_output = self.o_proj(attn_output)
1040
-
1041
- return attn_output, None, past_key_value
1042
-
1043
-
1044
- QWEN2_ATTENTION_CLASSES = {
1045
- "eager": Qwen2Attention,
1046
- "flash_attention_2": Qwen2FlashAttention2,
1047
- "sdpa": Qwen2SdpaAttention,
1048
- "flash_attention_2_packing": Qwen2FlashAttention2_packing,
1049
- }
1050
-
1051
-
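The mapping above is keyed by `config.attn_implementation`, which is how the packed flash-attention variant gets wired into `Qwen2DecoderLayer` below. A hypothetical selection sketch (the config values are made up, and the flash variants additionally require the `flash-attn` package at runtime):

```python
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

config = Qwen2Config(hidden_size=64, num_attention_heads=4, num_key_value_heads=2)
config.attn_implementation = "flash_attention_2_packing"

attn_cls = QWEN2_ATTENTION_CLASSES[config.attn_implementation]
attn = attn_cls(config, layer_idx=0)   # -> Qwen2FlashAttention2_packing
```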
1052
- class Qwen2DecoderLayer(nn.Module):
1053
- def __init__(self, config: Qwen2Config, layer_idx: int):
1054
- super().__init__()
1055
- self.hidden_size = config.hidden_size
1056
-
1057
- if config.use_sliding_window and config.attn_implementation != "flash_attention_2":
1058
- logger.warning_once(
1059
- f"Sliding Window Attention is enabled but not implemented for `{config.attn_implementation}`; "
1060
- "unexpected results may be encountered."
1061
- )
1062
-
1063
- self.self_attn = QWEN2_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
1064
-
1065
- self.mlp = Qwen2MLP(config)
1066
- self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1067
- self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1068
-
1069
- def forward(
1070
- self,
1071
- hidden_states: torch.Tensor,
1072
- attention_mask: Optional[torch.Tensor] = None,
1073
- position_ids: Optional[torch.LongTensor] = None,
1074
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
1075
- sub_sample_lengths=None,
1076
- output_attentions: Optional[bool] = False,
1077
- use_cache: Optional[bool] = False,
1078
- **kwargs,
1079
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1080
- if "padding_mask" in kwargs:
1081
- warnings.warn(
1082
- "Passing `padding_mask` is deprecated and will be removed in v4.37. "
1083
- "Please make sure use `attention_mask` instead.`"
1084
- )
1085
- """
1086
- Args:
1087
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1088
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1089
- `(batch, sequence_length)` where padding elements are indicated by 0.
1090
- output_attentions (`bool`, *optional*):
1091
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1092
- returned tensors for more detail.
1093
- use_cache (`bool`, *optional*):
1094
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1095
- (see `past_key_values`).
1096
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1097
- """
1098
-
1099
- residual = hidden_states
1100
-
1101
- hidden_states = self.input_layernorm(hidden_states)
1102
-
1103
- # Self Attention
1104
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
1105
- hidden_states=hidden_states,
1106
- attention_mask=attention_mask,
1107
- position_ids=position_ids,
1108
- past_key_value=past_key_value,
1109
- output_attentions=output_attentions,
1110
- use_cache=use_cache,
1111
- sub_sample_lengths=sub_sample_lengths,
1112
- )
1113
- hidden_states = residual + hidden_states
1114
-
1115
- # Fully Connected
1116
- residual = hidden_states
1117
- hidden_states = self.post_attention_layernorm(hidden_states)
1118
- hidden_states = self.mlp(hidden_states)
1119
- hidden_states = residual + hidden_states
1120
-
1121
- outputs = (hidden_states,)
1122
-
1123
- if output_attentions:
1124
- outputs += (self_attn_weights,)
1125
-
1126
- if use_cache:
1127
- outputs += (present_key_value,)
1128
-
1129
- return outputs
1130
-
1131
-
1132
- QWEN2_START_DOCSTRING = r"""
1133
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1134
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1135
- etc.)
1136
-
1137
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1138
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1139
- and behavior.
1140
-
1141
- Parameters:
1142
- config ([`Qwen2Config`]):
1143
- Model configuration class with all the parameters of the model. Initializing with a config file does not
1144
- load the weights associated with the model, only the configuration. Check out the
1145
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1146
- """
1147
-
1148
-
1149
- @add_start_docstrings(
1150
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
1151
- QWEN2_START_DOCSTRING,
1152
- )
1153
- class Qwen2PreTrainedModel(PreTrainedModel):
1154
- config_class = Qwen2Config
1155
- base_model_prefix = "model"
1156
- supports_gradient_checkpointing = True
1157
- _no_split_modules = ["Qwen2DecoderLayer"]
1158
- _skip_keys_device_placement = "past_key_values"
1159
- _supports_flash_attn_2 = True
1160
- _supports_sdpa = True
1161
- _supports_cache_class = True
1162
-
1163
- def _init_weights(self, module):
1164
- std = self.config.initializer_range
1165
- if isinstance(module, nn.Linear):
1166
- module.weight.data.normal_(mean=0.0, std=std)
1167
- if module.bias is not None:
1168
- module.bias.data.zero_()
1169
- elif isinstance(module, nn.Embedding):
1170
- module.weight.data.normal_(mean=0.0, std=std)
1171
- if module.padding_idx is not None:
1172
- module.weight.data[module.padding_idx].zero_()
1173
-
1174
-
1175
- QWEN2_INPUTS_DOCSTRING = r"""
1176
- Args:
1177
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1178
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1179
- it.
1180
-
1181
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1182
- [`PreTrainedTokenizer.__call__`] for details.
1183
-
1184
- [What are input IDs?](../glossary#input-ids)
1185
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1186
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1187
-
1188
- - 1 for tokens that are **not masked**,
1189
- - 0 for tokens that are **masked**.
1190
-
1191
- [What are attention masks?](../glossary#attention-mask)
1192
-
1193
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1194
- [`PreTrainedTokenizer.__call__`] for details.
1195
-
1196
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1197
- `past_key_values`).
1198
-
1199
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1200
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1201
- information on the default strategy.
1202
-
1203
- - 1 indicates the head is **not masked**,
1204
- - 0 indicates the head is **masked**.
1205
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1206
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1207
- config.n_positions - 1]`.
1208
-
1209
- [What are position IDs?](../glossary#position-ids)
1210
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1211
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1212
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1213
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1214
-
1215
- Two formats are allowed:
1216
- - a [`~cache_utils.Cache`] instance;
1217
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1218
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1219
- cache format.
1220
-
1221
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1222
- legacy cache format will be returned.
1223
-
1224
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1225
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1226
- of shape `(batch_size, sequence_length)`.
1227
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1228
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1229
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1230
- model's internal embedding lookup matrix.
1231
- use_cache (`bool`, *optional*):
1232
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1233
- `past_key_values`).
1234
- output_attentions (`bool`, *optional*):
1235
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1236
- tensors for more detail.
1237
- output_hidden_states (`bool`, *optional*):
1238
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1239
- more detail.
1240
- return_dict (`bool`, *optional*):
1241
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1242
- """
1243
-
1244
-
1245
- @add_start_docstrings(
1246
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
1247
- QWEN2_START_DOCSTRING,
1248
- )
1249
- class Qwen2Model(Qwen2PreTrainedModel):
1250
- """
1251
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
1252
-
1253
- Args:
1254
- config: Qwen2Config
1255
- """
1256
-
1257
- def __init__(self, config: Qwen2Config):
1258
- super().__init__(config)
1259
- self.padding_idx = config.pad_token_id
1260
- self.vocab_size = config.vocab_size
1261
-
1262
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1263
- self.layers = nn.ModuleList(
1264
- [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1265
- )
1266
- self.attn_implementation = config.attn_implementation
1267
- self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1268
-
1269
- self.gradient_checkpointing = False
1270
- # Initialize weights and apply final processing
1271
- self.post_init()
1272
-
1273
- def get_input_embeddings(self):
1274
- return self.embed_tokens
1275
-
1276
- def set_input_embeddings(self, value):
1277
- self.embed_tokens = value
1278
-
1279
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1280
- def forward(
1281
- self,
1282
- input_ids: torch.LongTensor = None,
1283
- attention_mask: Optional[torch.Tensor] = None,
1284
- position_ids: Optional[torch.LongTensor] = None,
1285
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1286
- inputs_embeds: Optional[torch.FloatTensor] = None,
1287
- use_cache: Optional[bool] = None,
1288
- output_attentions: Optional[bool] = None,
1289
- output_hidden_states: Optional[bool] = None,
1290
- return_dict: Optional[bool] = None,
1291
- sub_sample_lengths=None,
1292
- ) -> Union[Tuple, BaseModelOutputWithPast]:
1293
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1294
- output_hidden_states = (
1295
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1296
- )
1297
- use_cache = use_cache if use_cache is not None else self.config.use_cache
1298
-
1299
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1300
-
1301
- # retrieve input_ids and inputs_embeds
1302
- if input_ids is not None and inputs_embeds is not None:
1303
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1304
- elif input_ids is not None:
1305
- batch_size, seq_length = input_ids.shape
1306
- elif inputs_embeds is not None:
1307
- batch_size, seq_length, _ = inputs_embeds.shape
1308
- else:
1309
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1310
-
1311
- if self.gradient_checkpointing and self.training:
1312
- if use_cache:
1313
- logger.warning_once(
1314
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1315
- )
1316
- use_cache = False
1317
-
1318
- past_key_values_length = 0
1319
-
1320
- if use_cache:
1321
- use_legacy_cache = not isinstance(past_key_values, Cache)
1322
- if use_legacy_cache:
1323
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1324
- past_key_values_length = past_key_values.get_usable_length(seq_length)
1325
-
1326
- if position_ids is None:
1327
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1328
- position_ids = torch.arange(
1329
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1330
- )
1331
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
1332
- else:
1333
- position_ids = position_ids.view(-1, seq_length).long()
1334
-
1335
- if inputs_embeds is None:
1336
- inputs_embeds = self.embed_tokens(input_ids)
1337
-
1338
- if attention_mask is not None and self.attn_implementation == "flash_attention_2" and use_cache:
1339
- is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1340
- if is_padding_right:
1341
- raise ValueError(
1342
- "You are attempting to perform batched generation with padding_side='right'"
1343
- " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
1344
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1345
- )
1346
-
1347
- if self.attn_implementation == "flash_attention_2" or self.config.attn_implementation =='flash_attention_2_packing':
1348
- # 2d mask is passed through the layers
1349
- if attention_mask is not None:
1350
- if attention_mask.dtype == torch.long:
1351
- pass
1352
- # attention_mask = attention_mask
1353
- else:
1354
- attention_mask = attention_mask if (0 in attention_mask) else None
1355
-
1356
- elif self.attn_implementation == "sdpa" and not output_attentions:
1357
- # output_attentions=True can not be supported when using SDPA, and we fall back on
1358
- # the manual implementation that requires a 4D causal mask in all cases.
1359
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1360
- attention_mask,
1361
- (batch_size, seq_length),
1362
- inputs_embeds,
1363
- past_key_values_length,
1364
- )
1365
- else:
1366
- # 4d mask is passed through the layers
1367
- attention_mask = _prepare_4d_causal_attention_mask(
1368
- attention_mask,
1369
- (batch_size, seq_length),
1370
- inputs_embeds,
1371
- past_key_values_length,
1372
- sliding_window=self.config.sliding_window,
1373
- )
1374
-
1375
- hidden_states = inputs_embeds
1376
-
1377
- # decoder layers
1378
- all_hidden_states = () if output_hidden_states else None
1379
- all_self_attns = () if output_attentions else None
1380
- next_decoder_cache = None
1381
-
1382
- for decoder_layer in self.layers:
1383
- if output_hidden_states:
1384
- all_hidden_states += (hidden_states,)
1385
- if self.gradient_checkpointing and self.training:
1386
- layer_outputs = self._gradient_checkpointing_func(
1387
- decoder_layer.__call__,
1388
- hidden_states,
1389
- attention_mask,
1390
- position_ids,
1391
- past_key_values,
1392
- sub_sample_lengths,
1393
- output_attentions,
1394
- use_cache,
1395
- )
1396
- else:
1397
- layer_outputs = decoder_layer(
1398
- hidden_states,
1399
- attention_mask=attention_mask,
1400
- position_ids=position_ids,
1401
- past_key_value=past_key_values,
1402
- sub_sample_lengths=sub_sample_lengths,
1403
- output_attentions=output_attentions,
1404
- use_cache=use_cache,
1405
- )
1406
-
1407
- hidden_states = layer_outputs[0]
1408
-
1409
- if use_cache:
1410
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1411
-
1412
- if output_attentions:
1413
- all_self_attns += (layer_outputs[1],)
1414
-
1415
- hidden_states = self.norm(hidden_states)
1416
-
1417
- # add hidden states from the last decoder layer
1418
- if output_hidden_states:
1419
- all_hidden_states += (hidden_states,)
1420
-
1421
- next_cache = None
1422
- if use_cache:
1423
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1424
-
1425
- if not return_dict:
1426
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1427
- return BaseModelOutputWithPast(
1428
- last_hidden_state=hidden_states,
1429
- past_key_values=next_cache,
1430
- hidden_states=all_hidden_states,
1431
- attentions=all_self_attns,
1432
- )
1433
-
1434
-
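A tiny smoke test of the decoder stack above on random token ids, using the eager attention path. Every config value is made up and deliberately small, and it assumes this module's classes are importable alongside a transformers version that still ships the 4D-mask helpers used above.

```python
import torch
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config

cfg = Qwen2Config(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    max_position_embeddings=64,
)
cfg.attn_implementation = "eager"

model = Qwen2Model(cfg)
out = model(input_ids=torch.randint(0, 128, (1, 10)))
print(out.last_hidden_state.shape)   # torch.Size([1, 10, 64])
```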
1435
- class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1436
- _tied_weights_keys = ["lm_head.weight"]
1437
-
1438
- def __init__(self, config):
1439
- super().__init__(config)
1440
- self.model = Qwen2Model(config)
1441
- self.vocab_size = config.vocab_size
1442
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1443
-
1444
- # Initialize weights and apply final processing
1445
- self.post_init()
1446
- self.support_packing = True
1447
-
1448
- def get_input_embeddings(self):
1449
- return self.model.embed_tokens
1450
-
1451
- def set_input_embeddings(self, value):
1452
- self.model.embed_tokens = value
1453
-
1454
- def get_output_embeddings(self):
1455
- return self.lm_head
1456
-
1457
- def set_output_embeddings(self, new_embeddings):
1458
- self.lm_head = new_embeddings
1459
-
1460
- def set_decoder(self, decoder):
1461
- self.model = decoder
1462
-
1463
- def get_decoder(self):
1464
- return self.model
1465
-
1466
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1467
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1468
- def forward(
1469
- self,
1470
- input_ids: torch.LongTensor = None,
1471
- attention_mask: Optional[torch.Tensor] = None,
1472
- position_ids: Optional[torch.LongTensor] = None,
1473
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1474
- inputs_embeds: Optional[torch.FloatTensor] = None,
1475
- labels: Optional[torch.LongTensor] = None,
1476
- use_cache: Optional[bool] = None,
1477
- output_attentions: Optional[bool] = None,
1478
- output_hidden_states: Optional[bool] = None,
1479
- return_dict: Optional[bool] = None,
1480
- sub_sample_lengths=None,
1481
- ) -> Union[Tuple, CausalLMOutputWithPast]:
1482
- r"""
1483
- Args:
1484
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1485
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1486
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1487
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1488
-
1489
- Returns:
1490
-
1491
- Example:
1492
-
1493
- ```python
1494
- >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
1495
-
1496
- >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1497
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1498
-
1499
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
1500
- >>> inputs = tokenizer(prompt, return_tensors="pt")
1501
-
1502
- >>> # Generate
1503
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1504
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1505
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1506
- ```"""
1507
-
1508
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1509
- output_hidden_states = (
1510
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1511
- )
1512
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1513
-
1514
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1515
- outputs = self.model(
1516
- input_ids=input_ids,
1517
- attention_mask=attention_mask,
1518
- position_ids=position_ids,
1519
- past_key_values=past_key_values,
1520
- inputs_embeds=inputs_embeds,
1521
- use_cache=use_cache,
1522
- output_attentions=output_attentions,
1523
- output_hidden_states=output_hidden_states,
1524
- return_dict=return_dict,
1525
- sub_sample_lengths=sub_sample_lengths
1526
- )
1527
-
1528
- hidden_states = outputs[0]
1529
- logits = self.lm_head(hidden_states)
1530
- logits = logits.float()
1531
-
1532
- loss = None
1533
- if labels is not None:
1534
- # Shift so that tokens < n predict n
1535
- shift_logits = logits[..., :-1, :].contiguous()
1536
- shift_labels = labels[..., 1:].contiguous()
1537
- # Flatten the tokens
1538
- loss_fct = CrossEntropyLoss()
1539
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1540
- shift_labels = shift_labels.view(-1)
1541
- # Enable model parallelism
1542
- shift_labels = shift_labels.to(shift_logits.device)
1543
- loss = loss_fct(shift_logits, shift_labels)
1544
-
1545
- if not return_dict:
1546
- output = (logits,) + outputs[1:]
1547
- return (loss,) + output if loss is not None else output
1548
-
1549
- return CausalLMOutputWithPast(
1550
- loss=loss,
1551
- logits=logits,
1552
- past_key_values=outputs.past_key_values,
1553
- hidden_states=outputs.hidden_states,
1554
- attentions=outputs.attentions,
1555
- )
1556
-
1557
- def prepare_inputs_for_generation(
1558
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1559
- ):
1560
- # Omit tokens covered by past_key_values
1561
- if past_key_values is not None:
1562
- if isinstance(past_key_values, Cache):
1563
- cache_length = past_key_values.get_seq_length()
1564
- past_length = past_key_values.seen_tokens
1565
- max_cache_length = past_key_values.get_max_length()
1566
- else:
1567
- cache_length = past_length = past_key_values[0][0].shape[2]
1568
- max_cache_length = None
1569
-
1570
- # Keep only the unprocessed tokens:
1571
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1572
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1573
- # input)
1574
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1575
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1576
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1577
- # input_ids based on the past_length.
1578
- elif past_length < input_ids.shape[1]:
1579
- input_ids = input_ids[:, past_length:]
1580
- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1581
-
1582
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1583
- if (
1584
- max_cache_length is not None
1585
- and attention_mask is not None
1586
- and cache_length + input_ids.shape[1] > max_cache_length
1587
- ):
1588
- attention_mask = attention_mask[:, -max_cache_length:]
1589
-
1590
- position_ids = kwargs.get("position_ids", None)
1591
- if attention_mask is not None and position_ids is None:
1592
- # create position_ids on the fly for batch generation
1593
- position_ids = attention_mask.long().cumsum(-1) - 1
1594
- position_ids.masked_fill_(attention_mask == 0, 1)
1595
- if past_key_values:
1596
- position_ids = position_ids[:, -input_ids.shape[1] :]
1597
-
1598
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1599
- if inputs_embeds is not None and past_key_values is None:
1600
- model_inputs = {"inputs_embeds": inputs_embeds}
1601
- else:
1602
- model_inputs = {"input_ids": input_ids}
1603
-
1604
- model_inputs.update(
1605
- {
1606
- "position_ids": position_ids,
1607
- "past_key_values": past_key_values,
1608
- "use_cache": kwargs.get("use_cache"),
1609
- "attention_mask": attention_mask,
1610
- }
1611
- )
1612
- return model_inputs
1613
-
1614
- @staticmethod
1615
- def _reorder_cache(past_key_values, beam_idx):
1616
- reordered_past = ()
1617
- for layer_past in past_key_values:
1618
- reordered_past += (
1619
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1620
- )
1621
- return reordered_past
1622
-
1623
-
1624
- @add_start_docstrings(
1625
- """
1626
- The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1627
-
1628
- [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1629
- (e.g. GPT-2) do.
1630
-
1631
- Since it does classification on the last token, it requires to know the position of the last token. If a
1632
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1633
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1634
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1635
- each row of the batch).
1636
- """,
1637
- QWEN2_START_DOCSTRING,
1638
- )
1639
- class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1640
- def __init__(self, config):
1641
- super().__init__(config)
1642
- self.num_labels = config.num_labels
1643
- self.model = Qwen2Model(config)
1644
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1645
-
1646
- # Initialize weights and apply final processing
1647
- self.post_init()
1648
-
1649
- def get_input_embeddings(self):
1650
- return self.model.embed_tokens
1651
-
1652
- def set_input_embeddings(self, value):
1653
- self.model.embed_tokens = value
1654
-
1655
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1656
- def forward(
1657
- self,
1658
- input_ids: torch.LongTensor = None,
1659
- attention_mask: Optional[torch.Tensor] = None,
1660
- position_ids: Optional[torch.LongTensor] = None,
1661
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1662
- inputs_embeds: Optional[torch.FloatTensor] = None,
1663
- labels: Optional[torch.LongTensor] = None,
1664
- use_cache: Optional[bool] = None,
1665
- output_attentions: Optional[bool] = None,
1666
- output_hidden_states: Optional[bool] = None,
1667
- return_dict: Optional[bool] = None,
1668
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1669
- r"""
1670
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1671
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1672
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1673
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1674
- """
1675
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1676
-
1677
- transformer_outputs = self.model(
1678
- input_ids,
1679
- attention_mask=attention_mask,
1680
- position_ids=position_ids,
1681
- past_key_values=past_key_values,
1682
- inputs_embeds=inputs_embeds,
1683
- use_cache=use_cache,
1684
- output_attentions=output_attentions,
1685
- output_hidden_states=output_hidden_states,
1686
- return_dict=return_dict,
1687
- )
1688
- hidden_states = transformer_outputs[0]
1689
- logits = self.score(hidden_states)
1690
-
1691
- if input_ids is not None:
1692
- batch_size = input_ids.shape[0]
1693
- else:
1694
- batch_size = inputs_embeds.shape[0]
1695
-
1696
- if self.config.pad_token_id is None and batch_size != 1:
1697
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1698
- if self.config.pad_token_id is None:
1699
- sequence_lengths = -1
1700
- else:
1701
- if input_ids is not None:
1702
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1703
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1704
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
1705
- sequence_lengths = sequence_lengths.to(logits.device)
1706
- else:
1707
- sequence_lengths = -1
1708
-
1709
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1710
-
1711
- loss = None
1712
- if labels is not None:
1713
- labels = labels.to(logits.device)
1714
- if self.config.problem_type is None:
1715
- if self.num_labels == 1:
1716
- self.config.problem_type = "regression"
1717
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1718
- self.config.problem_type = "single_label_classification"
1719
- else:
1720
- self.config.problem_type = "multi_label_classification"
1721
-
1722
- if self.config.problem_type == "regression":
1723
- loss_fct = MSELoss()
1724
- if self.num_labels == 1:
1725
- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1726
- else:
1727
- loss = loss_fct(pooled_logits, labels)
1728
- elif self.config.problem_type == "single_label_classification":
1729
- loss_fct = CrossEntropyLoss()
1730
- loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1731
- elif self.config.problem_type == "multi_label_classification":
1732
- loss_fct = BCEWithLogitsLoss()
1733
- loss = loss_fct(pooled_logits, labels)
1734
- if not return_dict:
1735
- output = (pooled_logits,) + transformer_outputs[1:]
1736
- return ((loss,) + output) if loss is not None else output
1737
-
1738
- return SequenceClassifierOutputWithPast(
1739
- loss=loss,
1740
- logits=pooled_logits,
1741
- past_key_values=transformer_outputs.past_key_values,
1742
- hidden_states=transformer_outputs.hidden_states,
1743
- attentions=transformer_outputs.attentions,
1744
- )
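
For reference, the next-token loss in the deleted `Qwen2ForCausalLM.forward` above boils down to shifting logits and labels by one position and applying cross-entropy with `-100` positions ignored. A minimal standalone sketch of that step; the batch, sequence, and vocabulary sizes below are invented for illustration:

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy sizes, chosen only for the example.
batch_size, seq_len, vocab_size = 2, 6, 32

logits = torch.randn(batch_size, seq_len, vocab_size)         # stand-in for the lm_head output
labels = torch.randint(0, vocab_size, (batch_size, seq_len))  # target token ids
labels[:, :2] = -100                                          # ignored positions (e.g. prompt tokens)

# Shift so that tokens < n predict token n, as in the deleted forward pass.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)
```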
 
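Likewise, the pooling rule described in the deleted `Qwen2ForSequenceClassification` docstring above (classify from the last non-padding token of each row) can be checked in isolation. A small sketch under assumed values; the pad token id, tensor shapes, and 3-label head are made up for the example:

```python
import torch

pad_token_id = 0  # assumption for illustration only
input_ids = torch.tensor([
    [5, 7, 9, 0, 0],  # three real tokens, then padding
    [4, 4, 4, 4, 2],  # no padding
])
hidden = torch.randn(2, 5, 8)              # (batch, seq, hidden) from the backbone
score = torch.nn.Linear(8, 3, bias=False)  # toy classification head with 3 labels
logits = score(hidden)                     # (batch, seq, num_labels)

# Last non-padding position per row, computed the ONNX-friendly way used in the deleted code.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([2, 3])
```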
modeling_siglip.py DELETED
@@ -1,1241 +0,0 @@
1
- # --------------------------------------------------------
2
- # Eagle2
3
- # Copyright (c) 2025 NVIDIA
4
- # Licensed under The MIT License [see LICENSE for details]
5
- # Support flash-attention in SigLIP
6
- # --------------------------------------------------------
7
-
8
-
9
- # coding=utf-8
10
- # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
11
- #
12
- # Licensed under the Apache License, Version 2.0 (the "License");
13
- # you may not use this file except in compliance with the License.
14
- # You may obtain a copy of the License at
15
- #
16
- # http://www.apache.org/licenses/LICENSE-2.0
17
- #
18
- # Unless required by applicable law or agreed to in writing, software
19
- # distributed under the License is distributed on an "AS IS" BASIS,
20
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
21
- # See the License for the specific language governing permissions and
22
- # limitations under the License.
23
- """ PyTorch Siglip model."""
24
-
25
-
26
- import math
27
- import warnings
28
- from dataclasses import dataclass
29
- from typing import Any, Optional, Tuple, Union
30
- from einops import rearrange
31
- import numpy as np
32
- import torch
33
- import torch.utils.checkpoint
34
- from torch import nn
35
- from torch.nn.init import _calculate_fan_in_and_fan_out
36
-
37
- from transformers.activations import ACT2FN
38
- from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
39
- from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
40
- from transformers.modeling_utils import PreTrainedModel
41
- from transformers.utils import (
42
- ModelOutput,
43
- add_start_docstrings,
44
- add_start_docstrings_to_model_forward,
45
- logging,
46
- replace_return_docstrings,
47
- )
48
- from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
49
-
50
- try:
51
- from .flash_attention import FlashAttention
52
- has_flash_attn = True
53
- except:
54
- print('FlashAttention is not installed.')
55
- has_flash_attn = False
56
-
57
- logger = logging.get_logger(__name__)
58
-
59
- _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
60
-
61
- SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
- "google/siglip-base-patch16-224",
63
- # See all SigLIP models at https://huggingface.co/models?filter=siglip
64
- ]
65
-
66
-
67
- def _trunc_normal_(tensor, mean, std, a, b):
68
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
69
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
70
- def norm_cdf(x):
71
- # Computes standard normal cumulative distribution function
72
- return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
73
-
74
- if (mean < a - 2 * std) or (mean > b + 2 * std):
75
- warnings.warn(
76
- "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
77
- "The distribution of values may be incorrect.",
78
- stacklevel=2,
79
- )
80
-
81
- # Values are generated by using a truncated uniform distribution and
82
- # then using the inverse CDF for the normal distribution.
83
- # Get upper and lower cdf values
84
- l = norm_cdf((a - mean) / std)
85
- u = norm_cdf((b - mean) / std)
86
-
87
- # Uniformly fill tensor with values from [l, u], then translate to
88
- # [2l-1, 2u-1].
89
- tensor.uniform_(2 * l - 1, 2 * u - 1)
90
-
91
- # Use inverse cdf transform for normal distribution to get truncated
92
- # standard normal
93
- tensor.erfinv_()
94
-
95
- # Transform to proper mean, std
96
- tensor.mul_(std * math.sqrt(2.0))
97
- tensor.add_(mean)
98
-
99
- # Clamp to ensure it's in the proper range
100
- tensor.clamp_(min=a, max=b)
101
-
102
-
103
- def trunc_normal_tf_(
104
- tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
105
- ) -> torch.Tensor:
106
- """Fills the input Tensor with values drawn from a truncated
107
- normal distribution. The values are effectively drawn from the
108
- normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
109
- with values outside :math:`[a, b]` redrawn until they are within
110
- the bounds. The method used for generating the random values works
111
- best when :math:`a \\leq \text{mean} \\leq b`.
112
-
113
- NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
114
- bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
115
- and the result is subsquently scaled and shifted by the mean and std args.
116
-
117
- Args:
118
- tensor: an n-dimensional `torch.Tensor`
119
- mean: the mean of the normal distribution
120
- std: the standard deviation of the normal distribution
121
- a: the minimum cutoff value
122
- b: the maximum cutoff value
123
- """
124
- with torch.no_grad():
125
- _trunc_normal_(tensor, 0, 1.0, a, b)
126
- tensor.mul_(std).add_(mean)
127
-
128
-
129
- def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
130
- fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
131
- if mode == "fan_in":
132
- denom = fan_in
133
- elif mode == "fan_out":
134
- denom = fan_out
135
- elif mode == "fan_avg":
136
- denom = (fan_in + fan_out) / 2
137
-
138
- variance = scale / denom
139
-
140
- if distribution == "truncated_normal":
141
- # constant is stddev of standard normal truncated to (-2, 2)
142
- trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
143
- elif distribution == "normal":
144
- with torch.no_grad():
145
- tensor.normal_(std=math.sqrt(variance))
146
- elif distribution == "uniform":
147
- bound = math.sqrt(3 * variance)
148
- with torch.no_grad():
149
- tensor.uniform_(-bound, bound)
150
- else:
151
- raise ValueError(f"invalid distribution {distribution}")
152
-
153
-
154
- def lecun_normal_(tensor):
155
- variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
156
-
157
-
158
- def default_flax_embed_init(tensor):
159
- variance_scaling_(tensor, mode="fan_in", distribution="normal")
160
-
161
-
162
- @dataclass
163
- # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
164
- class SiglipVisionModelOutput(ModelOutput):
165
- """
166
- Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
167
-
168
- Args:
169
- image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
170
- The image embeddings obtained by applying the projection layer to the pooler_output.
171
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
172
- Sequence of hidden-states at the output of the last layer of the model.
173
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
174
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
175
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
176
-
177
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
178
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
179
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
180
- sequence_length)`.
181
-
182
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
183
- heads.
184
- """
185
-
186
- image_embeds: Optional[torch.FloatTensor] = None
187
- last_hidden_state: torch.FloatTensor = None
188
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
189
- attentions: Optional[Tuple[torch.FloatTensor]] = None
190
-
191
-
192
- @dataclass
193
- # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
194
- class SiglipTextModelOutput(ModelOutput):
195
- """
196
- Base class for text model's outputs that also contains a pooling of the last hidden states.
197
-
198
- Args:
199
- text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
200
- The text embeddings obtained by applying the projection layer to the pooler_output.
201
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
202
- Sequence of hidden-states at the output of the last layer of the model.
203
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
204
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
205
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
206
-
207
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
208
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
209
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
210
- sequence_length)`.
211
-
212
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
213
- heads.
214
- """
215
-
216
- text_embeds: Optional[torch.FloatTensor] = None
217
- last_hidden_state: torch.FloatTensor = None
218
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
219
- attentions: Optional[Tuple[torch.FloatTensor]] = None
220
-
221
-
222
- @dataclass
223
- # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
224
- class SiglipOutput(ModelOutput):
225
- """
226
- Args:
227
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
228
- Contrastive loss for image-text similarity.
229
- logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
230
- The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
231
- similarity scores.
232
- logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
233
- The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
234
- similarity scores.
235
- text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
236
- The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
237
- image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
238
- The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
239
- text_model_output(`BaseModelOutputWithPooling`):
240
- The output of the [`SiglipTextModel`].
241
- vision_model_output(`BaseModelOutputWithPooling`):
242
- The output of the [`SiglipVisionModel`].
243
- """
244
-
245
- loss: Optional[torch.FloatTensor] = None
246
- logits_per_image: torch.FloatTensor = None
247
- logits_per_text: torch.FloatTensor = None
248
- text_embeds: torch.FloatTensor = None
249
- image_embeds: torch.FloatTensor = None
250
- text_model_output: BaseModelOutputWithPooling = None
251
- vision_model_output: BaseModelOutputWithPooling = None
252
-
253
- def to_tuple(self) -> Tuple[Any]:
254
- return tuple(
255
- self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
256
- for k in self.keys()
257
- )
258
-
259
-
260
- class SiglipVisionEmbeddings(nn.Module):
261
- def __init__(self, config: SiglipVisionConfig):
262
- super().__init__()
263
- self.config = config
264
- self.embed_dim = config.hidden_size
265
- self.image_size = config.image_size
266
- self.patch_size = config.patch_size
267
-
268
- self.patch_embedding = nn.Conv2d(
269
- in_channels=config.num_channels,
270
- out_channels=self.embed_dim,
271
- kernel_size=self.patch_size,
272
- stride=self.patch_size,
273
- padding="valid",
274
- )
275
-
276
- self.num_patches = (self.image_size // self.patch_size) ** 2
277
- self.num_positions = self.num_patches
278
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
279
- self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
280
-
281
- def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
282
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
283
- embeddings = patch_embeds.flatten(2).transpose(1, 2)
284
-
285
- embeddings = embeddings + self.position_embedding(self.position_ids)
286
- return embeddings
287
-
288
-
289
- # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
290
- class SiglipTextEmbeddings(nn.Module):
291
- def __init__(self, config: SiglipTextConfig):
292
- super().__init__()
293
- embed_dim = config.hidden_size
294
-
295
- self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
296
- self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
297
-
298
- # position_ids (1, len position emb) is contiguous in memory and exported when serialized
299
- self.register_buffer(
300
- "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
301
- )
302
-
303
- def forward(
304
- self,
305
- input_ids: Optional[torch.LongTensor] = None,
306
- position_ids: Optional[torch.LongTensor] = None,
307
- inputs_embeds: Optional[torch.FloatTensor] = None,
308
- ) -> torch.Tensor:
309
- seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
310
-
311
- if position_ids is None:
312
- position_ids = self.position_ids[:, :seq_length]
313
-
314
- if inputs_embeds is None:
315
- inputs_embeds = self.token_embedding(input_ids)
316
-
317
- position_embeddings = self.position_embedding(position_ids)
318
- embeddings = inputs_embeds + position_embeddings
319
-
320
- return embeddings
321
-
322
-
323
- class SiglipAttention(nn.Module):
324
- """Multi-headed attention from 'Attention Is All You Need' paper"""
325
-
326
- # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
327
- def __init__(self, config):
328
- super().__init__()
329
- self.config = config
330
- self.embed_dim = config.hidden_size
331
- self.num_heads = config.num_attention_heads
332
- self.head_dim = self.embed_dim // self.num_heads
333
- if self.head_dim * self.num_heads != self.embed_dim:
334
- raise ValueError(
335
- f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
336
- f" {self.num_heads})."
337
- )
338
- self.scale = self.head_dim**-0.5
339
- self.dropout = config.attention_dropout
340
-
341
- self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
342
- self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
343
- self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
344
- # self.use_flash_attn = config.use_flash_attn and has_flash_attn
345
- self.use_flash_attn = True if has_flash_attn else False
346
- if self.use_flash_attn:
347
- self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
348
- self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
349
-
350
- def _flash_attn(self,
351
- hidden_states: torch.Tensor,
352
- attention_mask: Optional[torch.Tensor] = None,
353
- output_attentions: Optional[bool] = False,
354
- key_padding_mask=None,
355
- need_weights=False
356
- ):
357
-
358
- batch_size, q_len, _ = hidden_states.size()
359
-
360
- query_states = self.q_proj(hidden_states)
361
- key_states = self.k_proj(hidden_states)
362
- value_states = self.v_proj(hidden_states)
363
-
364
- query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
365
- key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
366
- value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
367
-
368
- qkv = torch.stack([query_states, key_states, value_states], dim=2)
369
- context, attn_weights = self.inner_attn(
370
- qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
371
- )
372
- attn_output = self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
373
-
374
- return attn_output, attn_weights
375
-
376
- def forward(
377
- self,
378
- hidden_states: torch.Tensor,
379
- attention_mask: Optional[torch.Tensor] = None,
380
- output_attentions: Optional[bool] = False,
381
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
382
- """Input shape: Batch x Time x Channel"""
383
- if self.use_flash_attn:
384
- return self._flash_attn(hidden_states)
385
- else:
386
- return self._vanilla_attn(hidden_states, attention_mask, output_attentions)
387
-
388
- def _vanilla_attn(self, hidden_states, attention_mask=None, output_attentions=False):
389
- batch_size, q_len, _ = hidden_states.size()
390
-
391
- query_states = self.q_proj(hidden_states)
392
- key_states = self.k_proj(hidden_states)
393
- value_states = self.v_proj(hidden_states)
394
-
395
- query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
396
- key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
397
- value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
398
-
399
- k_v_seq_len = key_states.shape[-2]
400
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
401
-
402
- if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
403
- raise ValueError(
404
- f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
405
- f" {attn_weights.size()}"
406
- )
407
-
408
- if attention_mask is not None:
409
- if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
410
- raise ValueError(
411
- f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
412
- )
413
- attn_weights = attn_weights + attention_mask
414
-
415
- # upcast attention to fp32
416
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
417
- attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
418
- attn_output = torch.matmul(attn_weights, value_states)
419
-
420
- if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
421
- raise ValueError(
422
- f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
423
- f" {attn_output.size()}"
424
- )
425
-
426
- attn_output = attn_output.transpose(1, 2).contiguous()
427
- attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
428
-
429
- attn_output = self.out_proj(attn_output)
430
-
431
- return attn_output, attn_weights
432
-
433
-
434
- # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
435
- class SiglipMLP(nn.Module):
436
- def __init__(self, config):
437
- super().__init__()
438
- self.config = config
439
- self.activation_fn = ACT2FN[config.hidden_act]
440
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
441
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
442
-
443
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
444
- hidden_states = self.fc1(hidden_states)
445
- hidden_states = self.activation_fn(hidden_states)
446
- hidden_states = self.fc2(hidden_states)
447
- return hidden_states
448
-
449
-
450
- # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
451
- class SiglipEncoderLayer(nn.Module):
452
- def __init__(self, config: SiglipConfig):
453
- super().__init__()
454
- self.embed_dim = config.hidden_size
455
- self.self_attn = SiglipAttention(config)
456
- self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
457
- self.mlp = SiglipMLP(config)
458
- self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
459
-
460
- # Ignore copy
461
- def forward(
462
- self,
463
- hidden_states: torch.Tensor,
464
- attention_mask: torch.Tensor,
465
- output_attentions: Optional[bool] = False,
466
- ) -> Tuple[torch.FloatTensor]:
467
- """
468
- Args:
469
- hidden_states (`torch.FloatTensor`):
470
- Input to the layer of shape `(batch, seq_len, embed_dim)`.
471
- attention_mask (`torch.FloatTensor`):
472
- Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
473
- output_attentions (`bool`, *optional*, defaults to `False`):
474
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
475
- returned tensors for more detail.
476
- """
477
- residual = hidden_states
478
-
479
- hidden_states = self.layer_norm1(hidden_states)
480
- hidden_states, attn_weights = self.self_attn(
481
- hidden_states=hidden_states,
482
- attention_mask=attention_mask,
483
- output_attentions=output_attentions,
484
- )
485
- hidden_states = residual + hidden_states
486
-
487
- residual = hidden_states
488
- hidden_states = self.layer_norm2(hidden_states)
489
- hidden_states = self.mlp(hidden_states)
490
- hidden_states = residual + hidden_states
491
-
492
- outputs = (hidden_states,)
493
-
494
- if output_attentions:
495
- outputs += (attn_weights,)
496
-
497
- return outputs
498
-
499
-
500
- class SiglipPreTrainedModel(PreTrainedModel):
501
- """
502
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
503
- models.
504
- """
505
-
506
- config_class = SiglipConfig
507
- base_model_prefix = "siglip"
508
- supports_gradient_checkpointing = True
509
-
510
- def _init_weights(self, module):
511
- """Initialize the weights"""
512
- if isinstance(module, SiglipVisionEmbeddings):
513
- width = (
514
- self.config.vision_config.hidden_size
515
- if isinstance(self.config, SiglipConfig)
516
- else self.config.hidden_size
517
- )
518
- nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
519
- elif isinstance(module, nn.Embedding):
520
- default_flax_embed_init(module.weight)
521
- elif isinstance(module, SiglipAttention):
522
- nn.init.xavier_uniform_(module.q_proj.weight)
523
- nn.init.xavier_uniform_(module.k_proj.weight)
524
- nn.init.xavier_uniform_(module.v_proj.weight)
525
- nn.init.xavier_uniform_(module.out_proj.weight)
526
- nn.init.zeros_(module.q_proj.bias)
527
- nn.init.zeros_(module.k_proj.bias)
528
- nn.init.zeros_(module.v_proj.bias)
529
- nn.init.zeros_(module.out_proj.bias)
530
- elif isinstance(module, SiglipMLP):
531
- nn.init.xavier_uniform_(module.fc1.weight)
532
- nn.init.xavier_uniform_(module.fc2.weight)
533
- nn.init.normal_(module.fc1.bias, std=1e-6)
534
- nn.init.normal_(module.fc2.bias, std=1e-6)
535
- elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
536
- nn.init.xavier_uniform_(module.probe.data)
537
- nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
538
- nn.init.zeros_(module.attention.in_proj_bias.data)
539
- elif isinstance(module, SiglipModel):
540
- logit_scale_init = torch.log(torch.tensor(1.0))
541
- module.logit_scale.data.fill_(logit_scale_init)
542
- module.logit_bias.data.zero_()
543
- elif isinstance(module, (nn.Linear, nn.Conv2d)):
544
- lecun_normal_(module.weight)
545
- if module.bias is not None:
546
- nn.init.zeros_(module.bias)
547
- elif isinstance(module, nn.LayerNorm):
548
- module.bias.data.zero_()
549
- module.weight.data.fill_(1.0)
550
-
551
-
552
- SIGLIP_START_DOCSTRING = r"""
553
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
554
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
555
- etc.)
556
-
557
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
558
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
559
- and behavior.
560
-
561
- Parameters:
562
- config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
563
- Initializing with a config file does not load the weights associated with the model, only the
564
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
565
- """
566
-
567
- SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
568
- Args:
569
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
570
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
571
- it.
572
-
573
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
574
- [`PreTrainedTokenizer.__call__`] for details.
575
-
576
- [What are input IDs?](../glossary#input-ids)
577
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
578
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
579
-
580
- - 1 for tokens that are **not masked**,
581
- - 0 for tokens that are **masked**.
582
-
583
- [What are attention masks?](../glossary#attention-mask)
584
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
585
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
586
- config.max_position_embeddings - 1]`.
587
-
588
- [What are position IDs?](../glossary#position-ids)
589
- output_attentions (`bool`, *optional*):
590
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
591
- tensors for more detail.
592
- output_hidden_states (`bool`, *optional*):
593
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
594
- more detail.
595
- return_dict (`bool`, *optional*):
596
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
597
- """
598
-
599
- SIGLIP_VISION_INPUTS_DOCSTRING = r"""
600
- Args:
601
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
602
- Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
603
- [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
604
- output_attentions (`bool`, *optional*):
605
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
606
- tensors for more detail.
607
- output_hidden_states (`bool`, *optional*):
608
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
609
- more detail.
610
- return_dict (`bool`, *optional*):
611
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
612
- """
613
-
614
- SIGLIP_INPUTS_DOCSTRING = r"""
615
- Args:
616
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
617
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
618
- it.
619
-
620
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
621
- [`PreTrainedTokenizer.__call__`] for details.
622
-
623
- [What are input IDs?](../glossary#input-ids)
624
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
625
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
626
-
627
- - 1 for tokens that are **not masked**,
628
- - 0 for tokens that are **masked**.
629
-
630
- [What are attention masks?](../glossary#attention-mask)
631
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
632
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
633
- config.max_position_embeddings - 1]`.
634
-
635
- [What are position IDs?](../glossary#position-ids)
636
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
637
- Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
638
- [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
639
- return_loss (`bool`, *optional*):
640
- Whether or not to return the contrastive loss.
641
- output_attentions (`bool`, *optional*):
642
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
643
- tensors for more detail.
644
- output_hidden_states (`bool`, *optional*):
645
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
646
- more detail.
647
- return_dict (`bool`, *optional*):
648
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
649
- """
650
-
651
-
652
- # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
653
- class SiglipEncoder(nn.Module):
654
- """
655
- Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
656
- [`SiglipEncoderLayer`].
657
-
658
- Args:
659
- config: SiglipConfig
660
- """
661
-
662
- def __init__(self, config: SiglipConfig):
663
- super().__init__()
664
- self.config = config
665
- self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
666
- self.gradient_checkpointing = False
667
-
668
- # Ignore copy
669
- def forward(
670
- self,
671
- inputs_embeds,
672
- attention_mask: Optional[torch.Tensor] = None,
673
- output_attentions: Optional[bool] = None,
674
- output_hidden_states: Optional[bool] = None,
675
- return_dict: Optional[bool] = None,
676
- ) -> Union[Tuple, BaseModelOutput]:
677
- r"""
678
- Args:
679
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
680
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
681
- This is useful if you want more control over how to convert `input_ids` indices into associated vectors
682
- than the model's internal embedding lookup matrix.
683
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
684
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
685
-
686
- - 1 for tokens that are **not masked**,
687
- - 0 for tokens that are **masked**.
688
-
689
- [What are attention masks?](../glossary#attention-mask)
690
- output_attentions (`bool`, *optional*):
691
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
692
- returned tensors for more detail.
693
- output_hidden_states (`bool`, *optional*):
694
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
695
- for more detail.
696
- return_dict (`bool`, *optional*):
697
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
698
- """
699
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
700
- output_hidden_states = (
701
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
702
- )
703
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
704
-
705
- encoder_states = () if output_hidden_states else None
706
- all_attentions = () if output_attentions else None
707
-
708
- hidden_states = inputs_embeds
709
- for encoder_layer in self.layers:
710
- if output_hidden_states:
711
- encoder_states = encoder_states + (hidden_states,)
712
- if self.gradient_checkpointing and self.training:
713
- layer_outputs = self._gradient_checkpointing_func(
714
- encoder_layer.__call__,
715
- hidden_states,
716
- attention_mask,
717
- output_attentions,
718
- )
719
- else:
720
- layer_outputs = encoder_layer(
721
- hidden_states,
722
- attention_mask,
723
- output_attentions=output_attentions,
724
- )
725
-
726
- hidden_states = layer_outputs[0]
727
-
728
- if output_attentions:
729
- all_attentions = all_attentions + (layer_outputs[1],)
730
-
731
- if output_hidden_states:
732
- encoder_states = encoder_states + (hidden_states,)
733
-
734
- if not return_dict:
735
- return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
736
- return BaseModelOutput(
737
- last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
738
- )
739
-
740
-
741
- class SiglipTextTransformer(nn.Module):
742
- def __init__(self, config: SiglipTextConfig):
743
- super().__init__()
744
- self.config = config
745
- embed_dim = config.hidden_size
746
- self.embeddings = SiglipTextEmbeddings(config)
747
- self.encoder = SiglipEncoder(config)
748
- self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
749
-
750
- self.head = nn.Linear(embed_dim, embed_dim)
751
-
752
- @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
753
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
754
- def forward(
755
- self,
756
- input_ids: Optional[torch.Tensor] = None,
757
- attention_mask: Optional[torch.Tensor] = None,
758
- position_ids: Optional[torch.Tensor] = None,
759
- output_attentions: Optional[bool] = None,
760
- output_hidden_states: Optional[bool] = None,
761
- return_dict: Optional[bool] = None,
762
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
763
- r"""
764
- Returns:
765
-
766
- """
767
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
768
- output_hidden_states = (
769
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
770
- )
771
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
772
-
773
- if input_ids is None:
774
- raise ValueError("You have to specify input_ids")
775
-
776
- input_shape = input_ids.size()
777
- input_ids = input_ids.view(-1, input_shape[-1])
778
-
779
- hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
780
-
781
- # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
782
- # expand attention_mask
783
- if attention_mask is not None:
784
- # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
785
- attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
786
-
787
- encoder_outputs = self.encoder(
788
- inputs_embeds=hidden_states,
789
- attention_mask=attention_mask,
790
- output_attentions=output_attentions,
791
- output_hidden_states=output_hidden_states,
792
- return_dict=return_dict,
793
- )
794
-
795
- last_hidden_state = encoder_outputs[0]
796
- last_hidden_state = self.final_layer_norm(last_hidden_state)
797
-
798
- # Assuming "sticky" EOS tokenization, last token is always EOS.
799
- pooled_output = last_hidden_state[:, -1, :]
800
- pooled_output = self.head(pooled_output)
801
-
802
- if not return_dict:
803
- return (last_hidden_state, pooled_output) + encoder_outputs[1:]
804
-
805
- return BaseModelOutputWithPooling(
806
- last_hidden_state=last_hidden_state,
807
- pooler_output=pooled_output,
808
- hidden_states=encoder_outputs.hidden_states,
809
- attentions=encoder_outputs.attentions,
810
- )
811
-
812
-
813
- @add_start_docstrings(
814
- """The text model from SigLIP without any head or projection on top.""",
815
- SIGLIP_START_DOCSTRING,
816
- )
817
- class SiglipTextModel(SiglipPreTrainedModel):
818
- config_class = SiglipTextConfig
819
-
820
- _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
821
-
822
- def __init__(self, config: SiglipTextConfig):
823
- super().__init__(config)
824
- self.text_model = SiglipTextTransformer(config)
825
- # Initialize weights and apply final processing
826
- self.post_init()
827
-
828
- def get_input_embeddings(self) -> nn.Module:
829
- return self.text_model.embeddings.token_embedding
830
-
831
- def set_input_embeddings(self, value):
832
- self.text_model.embeddings.token_embedding = value
833
-
834
- @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
835
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
836
- def forward(
837
- self,
838
- input_ids: Optional[torch.Tensor] = None,
839
- attention_mask: Optional[torch.Tensor] = None,
840
- position_ids: Optional[torch.Tensor] = None,
841
- output_attentions: Optional[bool] = None,
842
- output_hidden_states: Optional[bool] = None,
843
- return_dict: Optional[bool] = None,
844
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
845
- r"""
846
- Returns:
847
-
848
- Examples:
849
-
850
- ```python
851
- >>> from transformers import AutoTokenizer, SiglipTextModel
852
-
853
- >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
854
- >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
855
-
856
- >>> # important: make sure to set padding="max_length" as that's how the model was trained
857
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
858
-
859
- >>> outputs = model(**inputs)
860
- >>> last_hidden_state = outputs.last_hidden_state
861
- >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
862
- ```"""
863
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
864
-
865
- return self.text_model(
866
- input_ids=input_ids,
867
- attention_mask=attention_mask,
868
- position_ids=position_ids,
869
- output_attentions=output_attentions,
870
- output_hidden_states=output_hidden_states,
871
- return_dict=return_dict,
872
- )
873
-
874
-
875
- class SiglipVisionTransformer(nn.Module):
876
- def __init__(self, config: SiglipVisionConfig):
877
- super().__init__()
878
- self.config = config
879
- embed_dim = config.hidden_size
880
-
881
- self.embeddings = SiglipVisionEmbeddings(config)
882
- self.encoder = SiglipEncoder(config)
883
- self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
884
- self.head = SiglipMultiheadAttentionPoolingHead(config)
885
-
886
- @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
887
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
888
- def forward(
889
- self,
890
- pixel_values,
891
- output_attentions: Optional[bool] = None,
892
- output_hidden_states: Optional[bool] = None,
893
- return_dict: Optional[bool] = None,
894
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
895
- r"""
896
- Returns:
897
-
898
- """
899
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
900
- output_hidden_states = (
901
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
902
- )
903
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
904
-
905
- hidden_states = self.embeddings(pixel_values)
906
-
907
- encoder_outputs = self.encoder(
908
- inputs_embeds=hidden_states,
909
- output_attentions=output_attentions,
910
- output_hidden_states=output_hidden_states,
911
- return_dict=return_dict,
912
- )
913
-
914
- last_hidden_state = encoder_outputs[0]
915
- last_hidden_state = self.post_layernorm(last_hidden_state)
916
-
917
- pooled_output = self.head(last_hidden_state)
918
-
919
- if not return_dict:
920
- return (last_hidden_state, pooled_output) + encoder_outputs[1:]
921
-
922
- return BaseModelOutputWithPooling(
923
- last_hidden_state=last_hidden_state,
924
- pooler_output=pooled_output,
925
- hidden_states=encoder_outputs.hidden_states,
926
- attentions=encoder_outputs.attentions,
927
- )
928
-
929
-
930
- class SiglipMultiheadAttentionPoolingHead(nn.Module):
931
- """Multihead Attention Pooling."""
932
-
933
- def __init__(self, config: SiglipVisionConfig):
934
- super().__init__()
935
-
936
- self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
937
- self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
938
- self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
939
- self.mlp = SiglipMLP(config)
940
-
941
- def forward(self, hidden_state):
942
- batch_size = hidden_state.shape[0]
943
- probe = self.probe.repeat(batch_size, 1, 1)
944
-
945
- hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
946
-
947
- residual = hidden_state
948
- hidden_state = self.layernorm(hidden_state)
949
- hidden_state = residual + self.mlp(hidden_state)
950
-
951
- return hidden_state[:, 0]
952
-
953
-
954
- @add_start_docstrings(
955
- """The vision model from SigLIP without any head or projection on top.""",
956
- SIGLIP_START_DOCSTRING,
957
- )
958
- class SiglipVisionModel(SiglipPreTrainedModel):
959
- config_class = SiglipVisionConfig
960
- main_input_name = "pixel_values"
961
- _no_split_modules = [
962
- "SiglipEncoderLayer",
963
- "SiglipVisionEmbeddings",
964
- "SiglipMultiheadAttentionPoolingHead",
965
- ]
966
-
967
- def __init__(self, config: SiglipVisionConfig):
968
- super().__init__(config)
969
-
970
- self.vision_model = SiglipVisionTransformer(config)
971
-
972
- # Initialize weights and apply final processing
973
- self.post_init()
974
-
975
- def get_input_embeddings(self) -> nn.Module:
976
- return self.vision_model.embeddings.patch_embedding
977
-
978
- @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
979
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
980
- def forward(
981
- self,
982
- pixel_values,
983
- output_attentions: Optional[bool] = None,
984
- output_hidden_states: Optional[bool] = None,
985
- return_dict: Optional[bool] = None,
986
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
987
- r"""
988
- Returns:
989
-
990
- Examples:
991
-
992
- ```python
993
- >>> from PIL import Image
994
- >>> import requests
995
- >>> from transformers import AutoProcessor, SiglipVisionModel
996
-
997
- >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
998
- >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
999
-
1000
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1001
- >>> image = Image.open(requests.get(url, stream=True).raw)
1002
-
1003
- >>> inputs = processor(images=image, return_tensors="pt")
1004
-
1005
- >>> outputs = model(**inputs)
1006
- >>> last_hidden_state = outputs.last_hidden_state
1007
- >>> pooled_output = outputs.pooler_output # pooled features
1008
- ```"""
1009
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1010
-
1011
- return self.vision_model(
1012
- pixel_values=pixel_values,
1013
- output_attentions=output_attentions,
1014
- output_hidden_states=output_hidden_states,
1015
- return_dict=return_dict,
1016
- )
1017
-
1018
-
1019
- @add_start_docstrings(SIGLIP_START_DOCSTRING)
1020
- class SiglipModel(SiglipPreTrainedModel):
1021
- config_class = SiglipConfig
1022
-
1023
- def __init__(self, config: SiglipConfig):
1024
- super().__init__(config)
1025
-
1026
- if not isinstance(config.text_config, SiglipTextConfig):
1027
- raise ValueError(
1028
- "config.text_config is expected to be of type SiglipTextConfig but is of type"
1029
- f" {type(config.text_config)}."
1030
- )
1031
-
1032
- if not isinstance(config.vision_config, SiglipVisionConfig):
1033
- raise ValueError(
1034
- "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1035
- f" {type(config.vision_config)}."
1036
- )
1037
-
1038
- text_config = config.text_config
1039
- vision_config = config.vision_config
1040
-
1041
- self.text_model = SiglipTextTransformer(text_config)
1042
- self.vision_model = SiglipVisionTransformer(vision_config)
1043
-
1044
- self.logit_scale = nn.Parameter(torch.randn(1))
1045
- self.logit_bias = nn.Parameter(torch.randn(1))
1046
-
1047
- # Initialize weights and apply final processing
1048
- self.post_init()
1049
-
1050
- @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1051
- def get_text_features(
1052
- self,
1053
- input_ids: Optional[torch.Tensor] = None,
1054
- attention_mask: Optional[torch.Tensor] = None,
1055
- position_ids: Optional[torch.Tensor] = None,
1056
- output_attentions: Optional[bool] = None,
1057
- output_hidden_states: Optional[bool] = None,
1058
- return_dict: Optional[bool] = None,
1059
- ) -> torch.FloatTensor:
1060
- r"""
1061
- Returns:
1062
- text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1063
- applying the projection layer to the pooled output of [`SiglipTextModel`].
1064
-
1065
- Examples:
1066
-
1067
- ```python
1068
- >>> from transformers import AutoTokenizer, AutoModel
1069
- >>> import torch
1070
-
1071
- >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1072
- >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1073
-
1074
- >>> # important: make sure to set padding="max_length" as that's how the model was trained
1075
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1076
- >>> with torch.no_grad():
1077
- ...     text_features = model.get_text_features(**inputs)
1078
- ```"""
1079
- # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1080
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1081
- output_hidden_states = (
1082
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1083
- )
1084
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1085
-
1086
- text_outputs = self.text_model(
1087
- input_ids=input_ids,
1088
- attention_mask=attention_mask,
1089
- position_ids=position_ids,
1090
- output_attentions=output_attentions,
1091
- output_hidden_states=output_hidden_states,
1092
- return_dict=return_dict,
1093
- )
1094
-
1095
- pooled_output = text_outputs[1]
1096
-
1097
- return pooled_output
1098
-
1099
- @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1100
- def get_image_features(
1101
- self,
1102
- pixel_values: Optional[torch.FloatTensor] = None,
1103
- output_attentions: Optional[bool] = None,
1104
- output_hidden_states: Optional[bool] = None,
1105
- return_dict: Optional[bool] = None,
1106
- ) -> torch.FloatTensor:
1107
- r"""
1108
- Returns:
1109
- image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1110
- applying the projection layer to the pooled output of [`SiglipVisionModel`].
1111
-
1112
- Examples:
1113
-
1114
- ```python
1115
- >>> from PIL import Image
1116
- >>> import requests
1117
- >>> from transformers import AutoProcessor, AutoModel
1118
- >>> import torch
1119
-
1120
- >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1121
- >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1122
-
1123
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1124
- >>> image = Image.open(requests.get(url, stream=True).raw)
1125
-
1126
- >>> inputs = processor(images=image, return_tensors="pt")
1127
-
1128
- >>> with torch.no_grad():
1129
- ...     image_features = model.get_image_features(**inputs)
1130
- ```"""
1131
- # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1132
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1133
- output_hidden_states = (
1134
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1135
- )
1136
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1137
-
1138
- vision_outputs = self.vision_model(
1139
- pixel_values=pixel_values,
1140
- output_attentions=output_attentions,
1141
- output_hidden_states=output_hidden_states,
1142
- return_dict=return_dict,
1143
- )
1144
-
1145
- pooled_output = vision_outputs[1]
1146
-
1147
- return pooled_output
1148
-
1149
- @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1150
- @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1151
- def forward(
1152
- self,
1153
- input_ids: Optional[torch.LongTensor] = None,
1154
- pixel_values: Optional[torch.FloatTensor] = None,
1155
- attention_mask: Optional[torch.Tensor] = None,
1156
- position_ids: Optional[torch.LongTensor] = None,
1157
- return_loss: Optional[bool] = None,
1158
- output_attentions: Optional[bool] = None,
1159
- output_hidden_states: Optional[bool] = None,
1160
- return_dict: Optional[bool] = None,
1161
- ) -> Union[Tuple, SiglipOutput]:
1162
- r"""
1163
- Returns:
1164
-
1165
- Examples:
1166
-
1167
- ```python
1168
- >>> from PIL import Image
1169
- >>> import requests
1170
- >>> from transformers import AutoProcessor, AutoModel
1171
- >>> import torch
1172
-
1173
- >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1174
- >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1175
-
1176
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1177
- >>> image = Image.open(requests.get(url, stream=True).raw)
1178
-
1179
- >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1180
- >>> # important: we pass `padding=max_length` since the model was trained with this
1181
- >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1182
-
1183
- >>> with torch.no_grad():
1184
- ...     outputs = model(**inputs)
1185
-
1186
- >>> logits_per_image = outputs.logits_per_image
1187
- >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1188
- >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1189
- 31.9% that image 0 is 'a photo of 2 cats'
1190
- ```"""
1191
- # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1192
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1193
- output_hidden_states = (
1194
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1195
- )
1196
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1197
-
1198
- vision_outputs = self.vision_model(
1199
- pixel_values=pixel_values,
1200
- output_attentions=output_attentions,
1201
- output_hidden_states=output_hidden_states,
1202
- return_dict=return_dict,
1203
- )
1204
-
1205
- text_outputs = self.text_model(
1206
- input_ids=input_ids,
1207
- attention_mask=attention_mask,
1208
- position_ids=position_ids,
1209
- output_attentions=output_attentions,
1210
- output_hidden_states=output_hidden_states,
1211
- return_dict=return_dict,
1212
- )
1213
-
1214
- image_embeds = vision_outputs[1]
1215
- text_embeds = text_outputs[1]
1216
-
1217
- # normalized features
1218
- image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1219
- text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1220
-
1221
- # cosine similarity as logits
1222
- logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
1223
- logits_per_image = logits_per_text.t()
1224
-
1225
- loss = None
1226
- if return_loss:
1227
- raise NotImplementedError("SigLIP loss to be implemented")
1228
-
1229
- if not return_dict:
1230
- output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1231
- return ((loss,) + output) if loss is not None else output
1232
-
1233
- return SiglipOutput(
1234
- loss=loss,
1235
- logits_per_image=logits_per_image,
1236
- logits_per_text=logits_per_text,
1237
- text_embeds=text_embeds,
1238
- image_embeds=image_embeds,
1239
- text_model_output=text_outputs,
1240
- vision_model_output=vision_outputs,
1241
- )
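For reference, the deleted `forward` above raises `NotImplementedError` when `return_loss=True`. Below is a minimal sketch of the pairwise sigmoid loss described in the SigLIP paper, assuming `logits_per_text` of shape `(num_texts, num_images)` with matching pairs on the diagonal; the helper name is ours and not part of the deleted file.

```python
import torch
import torch.nn.functional as F

def siglip_pairwise_sigmoid_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # logits_per_text is assumed to already include logit_scale and logit_bias,
    # exactly as computed in the forward pass above.
    n = logits_per_text.size(0)
    # +1 for the matching (diagonal) text/image pair, -1 for every other pair.
    labels = 2.0 * torch.eye(n, device=logits_per_text.device, dtype=logits_per_text.dtype) - 1.0
    # Negative log-sigmoid of label * logit, summed over all pairs and averaged over texts.
    return -F.logsigmoid(labels * logits_per_text).sum() / n
```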
 
multi_backbone_channel_concatenation_encoder.py DELETED
@@ -1,266 +0,0 @@
1
- # --------------------------------------------------------
2
- # Eagle2
3
- # Copyright (c) 2025 NVIDIA
4
- # Licensed under The Apache License [see LICENSE for details]
5
- # --------------------------------------------------------
6
-
7
- import torch, os
8
- import torch.nn as nn
9
- from torch.utils.checkpoint import checkpoint
10
-
11
- from .siglip_vision_tower import SiglipVisionTower
12
-
13
- import torch.nn.functional as F
14
- from torch.nn.init import trunc_normal_
15
- from copy import deepcopy
16
- import random
17
- import math
18
-
19
- class MultiBackboneChannelConcatenationVisionTower(nn.Module):
20
- def __init__(self,
21
- vision_tower,
22
- args,
23
- grid_size=32,
24
- convnext_img_size=1024,
25
- normalize_type=None, raw_config=None):
26
-
27
- super().__init__()
28
-
29
- self.is_loaded = False
30
- self.grid_size = grid_size
31
- self.num_tokens = self.grid_size ** 2
32
- self.normalize_type = args.normalize_type
33
- self.moe_version_type = args.moe_version_type
34
- self.raw_config = raw_config
35
- print("moe_version_type: ", self.moe_version_type)
36
- assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], f"Unknown self.moe_version_type: {self.moe_version_type}"
37
-
38
- vision_tower_name_list = vision_tower.split(";")
39
- self.input_image_size = 1024
40
- self.convnext_img_size = convnext_img_size
41
- self.load_vision_towers(vision_tower_name_list, args)
42
-
43
-
44
- def load_vision_towers(self, vision_tower_name_list, args):
45
- self.vision_towers = nn.ModuleList()
46
-
47
- freeze_backbone_list = args.freeze_backbones # note this is a str
48
- if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
49
- print("The frozen backbones: ", freeze_backbone_list)
50
- else:
51
- # make it a blank str
52
- freeze_backbone_list = ""
53
-
54
- for name in vision_tower_name_list:
55
-
56
- ## ConvNeXt
57
- if name == 'convnext-1024':
58
- convnext_args = deepcopy(args)
59
-
60
- convnext_args.freeze_vision = False
61
- if 'convnext-1024' in freeze_backbone_list:
62
- convnext_args.freeze_vision = True
63
-
64
- from .convnext_encoder import ConvNextVisionTower
65
- convnext_args.input_image_size = self.convnext_img_size
66
- convnext_vision_tower = args.vision_tower_convnext_path
67
- convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower,
68
- convnext_args, delay_load=args.delay_load, normalize_type=self.normalize_type)
69
- convnext_vision_tower.load_model()
70
- self.vision_towers.append(convnext_vision_tower)
71
-
72
- ## PaliSigLIP
73
- elif name == 'palisiglip':
74
- palisiglip_args = deepcopy(args)
75
- palisiglip_args.input_image_size = 448
76
-
77
- palisiglip_args.freeze_vision = False
78
- if 'palisiglip' in freeze_backbone_list:
79
- palisiglip_args.freeze_vision = True
80
-
81
- palisiglip_vision_tower = SiglipVisionTower(args.vision_tower_siglip_path, palisiglip_args, delay_load=args.delay_load, raw_config=self.raw_config)
82
-
83
- palisiglip_vision_tower.load_model()
84
- self.vision_towers.append(palisiglip_vision_tower)
85
-
86
- # Set the image processor
87
- self.image_processor = None
88
- self.is_loaded = True
89
-
90
- def load_model(self):
91
- assert self.is_loaded, "All the vision encoders should be loaded during initialization!"
92
-
93
- def forward(self, x):
94
- # x is a Tensor if moe_version_type is None or 'all_tiling'
95
- # else is a tuple(Tensor, Tensor)
96
- if self.moe_version_type in [None, 'all_tiling']:
97
- # The default pipeline
98
- features = []
99
- image_input_size = x.shape[2]
100
- assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
101
- for vision_tower in self.vision_towers:
102
-
103
- if vision_tower.input_image_size != image_input_size:
104
- resized_x = F.interpolate(x.float(),
105
- size=(vision_tower.input_image_size, vision_tower.input_image_size),
106
- mode='bilinear',
107
- align_corners=True).to(dtype=x.dtype)
108
- else:
109
- resized_x = x
110
-
111
- feature = vision_tower(resized_x)
112
-
113
- if len(feature.shape) == 3: # b, n, c
114
- b, n, c = feature.shape
115
- if n == self.num_tokens:
116
- features.append(feature)
117
- continue
118
- w = h = int(n**0.5)
119
- feature = feature.transpose(1,2).reshape(b, c, h, w)
120
- else:
121
- b, c, h, w = feature.shape
122
-
123
- if w != self.grid_size:
124
- feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
125
- features.append(feature.flatten(2,3).transpose(1,2))
126
-
127
- features = torch.cat(features, dim=-1)
128
- elif self.moe_version_type == 'convnext_512_siglip_448':
129
- features = {}
130
- image_input_size = x.shape[2]
131
- assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
132
- for vision_tower in self.vision_towers:
133
-
134
- if vision_tower.input_image_size != image_input_size:
135
- resized_x = F.interpolate(x.float(),
136
- size=(vision_tower.input_image_size, vision_tower.input_image_size),
137
- mode='bilinear',
138
- align_corners=True).to(dtype=x.dtype)
139
- else:
140
- resized_x = x
141
-
142
- feature = vision_tower(resized_x)
143
-
144
- # if len(feature.shape) == 3: # b, n, c
145
- # b, n, c = feature.shape
146
- # if n == self.num_tokens:
147
- # features.append(feature)
148
- # continue
149
- # w = h = int(n**0.5)
150
- # feature = feature.transpose(1,2).reshape(b, c, h, w)
151
- # else:
152
- # b, c, h, w = feature.shape
153
- features[vision_tower.name] = feature
154
-
155
- else:
156
- assert isinstance(x, dict), "x is expected to be a dict but got {}".format(type(x))
157
- pixel_values = x['pixel_values']
158
- num_patches = x['num_patches'] # number of patches per sample, as counted by padding tokens in the texts
159
-
160
- # calculate the real image patches
161
- if self.moe_version_type == 'seq_concat':
162
- image_in_num_patches = [i-1 for i in num_patches]
163
- else:
164
- image_in_num_patches = [i for i in num_patches]
165
-
166
-
167
- assert sum(image_in_num_patches) == pixel_values.size(0), "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(sum(image_in_num_patches), pixel_values.size(0))
168
-
169
- # find the thumbnail image id
170
- thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
171
- image_no_tiling = pixel_values[thumbnail_image_id]
172
-
173
- # By default, we use the 1st vision_tower for x, others for x_nt
174
- features = []
175
- for layer_id, vision_tower in enumerate(self.vision_towers):
176
- if layer_id == 0:
177
- x = pixel_values
178
- else:
179
- x = image_no_tiling
180
-
181
- if vision_tower.input_image_size != self.input_image_size:
182
- resized_x = F.interpolate(x.float(),
183
- size=(vision_tower.input_image_size, vision_tower.input_image_size),
184
- mode='bilinear',
185
- align_corners=True).to(dtype=x.dtype)
186
- else:
187
- resized_x = x
188
-
189
- feature = vision_tower(resized_x)
190
- if len(feature.shape) == 3: # b, n, c
191
- b, n, c = feature.shape
192
- if n == self.num_tokens:
193
- features.append(feature)
194
- continue
195
-
196
- w = h = int(n**0.5)
197
- feature = feature.transpose(1,2).reshape(b, c, h, w)
198
- else:
199
- b, c, h, w = feature.shape
200
-
201
- if w != self.grid_size:
202
- feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
203
- features.append(feature.flatten(2,3).transpose(1,2))
204
-
205
- clip_embeds = features[0]
206
- if len(features) <= 1:
207
- no_tiling_embeds = None
208
- else:
209
- no_tiling_embeds = torch.cat(features[1:], dim=-1)
210
-
211
- if self.moe_version_type == 'feat_concat':
212
- # concat thumbnail images features together
213
- clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
214
- if no_tiling_embeds is not None:
215
- no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
216
- else:
217
- no_tiling_embeds = clip_thumbnail_embeds
218
-
219
- # extra patch features
220
- clip_embeds_mask = ~torch.isin(torch.arange(clip_embeds.shape[0]).to(clip_embeds.device), thumbnail_image_id)
221
- clip_embeds = clip_embeds[clip_embeds_mask]
222
-
223
-
224
- features = {
225
- 'clip_embeds': clip_embeds,
226
- 'no_tiling_embeds': no_tiling_embeds,
227
- 'num_patches': num_patches
228
- }
229
-
230
- # features is a Tensor if not clip_tiling_only
231
-
232
- return features
233
-
234
- @property
235
- def dummy_feature(self):
236
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
237
-
238
- @property
239
- def dtype(self):
240
- return next(self.clip_vision_tower.parameters()).dtype
241
-
242
- @property
243
- def device(self):
244
- return next(self.clip_vision_tower.parameters()).device
245
-
246
- @property
247
- def config(self):
248
- raise NotImplementedError
249
- pass
250
-
251
- @property
252
- def hidden_size(self):
253
- if self.moe_version_type == 'convnext_512_siglip_448':
254
- res = {}
255
- for vision_tower in self.vision_towers:
256
- res[vision_tower.name] = vision_tower.hidden_size
257
- return res
258
- else:
259
- return sum([_.hidden_size for _ in self.vision_towers])
260
-
261
- @property
262
- def num_patches(self):
263
- return self.num_tokens
264
-
265
-
266
-
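For reference, a minimal sketch of what the default path above (`moe_version_type` of `None` or `'all_tiling'`) does with each backbone output: resize every feature map to a common token grid and concatenate along the channel dimension. The helper name and the example shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

def concat_backbone_features(feature_maps, grid_size=32):
    """Resize each (B, C_i, H_i, W_i) map to grid_size x grid_size and concatenate
    along channels, returning tokens of shape (B, grid_size**2, sum(C_i))."""
    tokens = []
    for feat in feature_maps:
        if feat.shape[-1] != grid_size:
            feat = F.interpolate(feat.float(), size=(grid_size, grid_size),
                                 mode='bilinear', align_corners=True)
        tokens.append(feat.flatten(2, 3).transpose(1, 2))  # (B, H*W, C_i)
    return torch.cat(tokens, dim=-1)

# e.g. a 1024-channel map at 32x32 and a 1152-channel map at 16x16 -> (2, 1024, 2176)
out = concat_backbone_features([torch.randn(2, 1024, 32, 32), torch.randn(2, 1152, 16, 16)])
```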
 
multi_backbone_channel_concatentation_model.py DELETED
@@ -1,95 +0,0 @@
1
- # --------------------------------------------------------
2
- # Eagle2
3
- # Copyright (c) 2025 NVIDIA
4
- # Licensed under The Apache License [see LICENSE for details]
5
- # --------------------------------------------------------
6
-
7
- import torch.nn as nn
8
-
9
- from transformers.modeling_outputs import BaseModelOutputWithPooling
10
- from typing import Optional, Tuple, Union
11
-
12
- from .multi_backbone_channel_concatenation_encoder import MultiBackboneChannelConcatenationVisionTower
13
- from .configuration_multi_backbone_channel_concatentation_model import MultiBackboneChannelConcatenationVisionModelConfig
14
-
15
-
16
- class MultiBackboneChannelConcatenationVisionModel(nn.Module):
17
-
18
- """
19
- A vision model wrapper that concatenates channels from multiple backbones.
20
-
21
- Args:
22
- config (MultiBackboneChannelConcatenationVisionModelConfig): The configuration for the model.
23
-
24
- Attributes:
25
- vision_model (MultiBackboneChannelConcatenationVisionTower): The vision tower that performs the channel concatenation.
26
-
27
- Notes:
28
- **This class does not inherit from PreTrainedModel in transformers**
29
-
30
- """
31
-
32
- config_class = MultiBackboneChannelConcatenationVisionModelConfig
33
- main_input_name = "pixel_values"
34
-
35
- def __init__(self, config: MultiBackboneChannelConcatenationVisionModelConfig, raw_config):
36
- super().__init__()
37
-
38
- self.vision_model = MultiBackboneChannelConcatenationVisionTower(
39
- vision_tower=config.vision_tower,
40
- args=config,
41
- grid_size=config.grid_size,
42
- convnext_img_size=config.convnext_img_size,
43
- normalize_type=config.normalize_type,
44
- raw_config=raw_config
45
- )
46
-
47
-
48
- def get_input_embeddings(self):
49
- # You might need to adjust this depending on how you want to handle input embeddings
50
- return self.vision_model.vision_towers[0].get_input_embeddings()
51
-
52
- def forward(
53
- self,
54
- pixel_values,
55
- return_dict: Optional[bool] = True,
56
- output_hidden_states: Optional[bool] = False,
57
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
58
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
59
-
60
- assert return_dict is True, "We only support return_dict"
61
- assert output_hidden_states is False, "We do not support output_hidden_states"
62
-
63
- features = self.vision_model(pixel_values)
64
-
65
- # We only support features as model outputs
66
- return BaseModelOutputWithPooling(
67
- last_hidden_state=features,
68
- pooler_output=None,
69
- hidden_states=None,
70
- attentions=None,
71
- )
72
-
73
- @property
74
- def dummy_feature(self):
75
- return self.vision_model.dummy_feature
76
-
77
- @property
78
- def dtype(self):
79
- return self.vision_model.dtype
80
-
81
- @property
82
- def device(self):
83
- return self.vision_model.device
84
-
85
- @property
86
- def config(self):
87
- return self.vision_model.config
88
-
89
- @property
90
- def hidden_size(self):
91
- return self.vision_model.hidden_size
92
-
93
- @property
94
- def num_patches(self):
95
- return self.vision_model.num_patches
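For reference, the wrapper above performs no pooling; it simply exposes whatever the tower returns as `last_hidden_state`. A minimal sketch of that contract, with an illustrative tensor shape:

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling

features = torch.randn(2, 1024, 2176)  # (batch, tokens, concatenated hidden size), illustrative
outputs = BaseModelOutputWithPooling(last_hidden_state=features, pooler_output=None)
assert outputs.last_hidden_state is features and outputs.pooler_output is None
```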
 
siglip_vision_tower.py DELETED
@@ -1,93 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.utils.checkpoint import checkpoint
4
-
5
- from .modeling_siglip import SiglipVisionModel
6
- from .configuration_siglip import SiglipVisionConfig
7
-
8
- import math
9
- import torch
10
- import torch.nn.functional as F
11
- from typing import List, Optional
12
- import os
13
-
14
- class SiglipVisionTower(nn.Module):
15
- # We use the same wrapper as the default clip encoder.
16
- # See `clip_encoder.py` in the same folder
17
- def __init__(self, vision_tower, args, delay_load=False, raw_config=None):
18
- super().__init__()
19
-
20
- self.is_loaded = False
21
- self.freeze_vision=args.freeze_vision
22
- self.input_image_size=args.input_image_size
23
- self.vision_tower_name = vision_tower
24
- self.select_layer = args.mm_vision_select_layer
25
- self.name = 'siglip'
26
- self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
27
- self.delay_load = delay_load
28
- self.raw_config = raw_config
29
- if not delay_load:
30
- self.load_model()
31
- else:
32
- if os.path.isfile(self.vision_tower_name):
33
- self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name, local_files_only=True)
34
- else:
35
- self.cfg_only = SiglipVisionConfig(**self.raw_config.vision_config.siglip_vision_config)
36
-
37
-
38
- def load_model(self):
39
- if self.is_loaded:
40
- print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
41
- return
42
-
43
- # self.image_processor = SiglipImageProcessor(size=1024)
44
- # self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, local_files_only=True, torch_dtype=torch.bfloat16)
45
- if self.delay_load:
46
- # cfg = SiglipVisionConfig.from_pretrained(self.vision_tower_name, local_files_only=True)
47
- self.vision_tower = SiglipVisionModel(self.cfg_only)
48
- else:
49
- self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, local_files_only=True)
50
-
51
- if self.freeze_vision:
52
- self.vision_tower.requires_grad_(False)
53
-
54
- self.vision_tower.vision_model.encoder.gradient_checkpointing = True
55
- self.is_loaded = True
56
-
57
- def forward(self, images):
58
- return self.vision_tower(
59
- pixel_values=images,
60
- output_hidden_states=False,
61
- return_dict=True).last_hidden_state
62
-
63
-
64
- @property
65
- def dummy_feature(self):
66
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
67
-
68
- @property
69
- def dtype(self):
70
- return self.vision_tower.dtype
71
-
72
- @property
73
- def device(self):
74
- return self.vision_tower.device
75
-
76
- @property
77
- def config(self):
78
- if self.is_loaded:
79
- return self.vision_tower.config
80
- else:
81
- return self.cfg_only
82
-
83
- @property
84
- def hidden_size(self):
85
- return self.config.hidden_size
86
-
87
- @property
88
- def num_patches_per_side(self):
89
- return self.config.image_size // self.config.patch_size
90
-
91
- @property
92
- def num_patches(self):
93
- return (self.config.image_size // self.config.patch_size) ** 2
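For reference, the deleted wrapper above boils down to calling `SiglipVisionModel` and keeping `last_hidden_state`. A minimal sketch of the equivalent call against the upstream class, using the checkpoint name that appears in the docstrings earlier in this diff (random pixel values stand in for real preprocessing):

```python
import torch
from transformers import SiglipVisionModel

vision_tower = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
pixel_values = torch.randn(1, 3, 224, 224)  # normally produced by the image processor
with torch.no_grad():
    patch_features = vision_tower(pixel_values=pixel_values, return_dict=True).last_hidden_state
print(patch_features.shape)  # (1, 196, 768) for a 224px image with 16px patches
```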
 
tokenization_qwen2.py DELETED
@@ -1,345 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Tokenization classes for Qwen2."""
16
-
17
- import json
18
- import os
19
- import unicodedata
20
- from functools import lru_cache
21
- from typing import Optional, Tuple
22
-
23
- import regex as re
24
-
25
- from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
26
- from transformers.utils import logging
27
-
28
-
29
- logger = logging.get_logger(__name__)
30
-
31
- VOCAB_FILES_NAMES = {
32
- "vocab_file": "vocab.json",
33
- "merges_file": "merges.txt",
34
- }
35
-
36
- PRETRAINED_VOCAB_FILES_MAP = {
37
- "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
38
- "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
39
- }
40
-
41
- MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
-
43
- PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
44
-
45
-
46
- @lru_cache()
47
- # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
48
- def bytes_to_unicode():
49
- """
50
- Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
51
- characters the bpe code barfs on.
52
-
53
- The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
54
- if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
55
- decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
56
- tables between utf-8 bytes and unicode strings.
57
- """
58
- bs = (
59
- list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
60
- )
61
- cs = bs[:]
62
- n = 0
63
- for b in range(2**8):
64
- if b not in bs:
65
- bs.append(b)
66
- cs.append(2**8 + n)
67
- n += 1
68
- cs = [chr(n) for n in cs]
69
- return dict(zip(bs, cs))
70
-
71
-
72
- # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
73
- def get_pairs(word):
74
- """
75
- Return set of symbol pairs in a word.
76
-
77
- Word is represented as tuple of symbols (symbols being variable-length strings).
78
- """
79
- pairs = set()
80
- prev_char = word[0]
81
- for char in word[1:]:
82
- pairs.add((prev_char, char))
83
- prev_char = char
84
- return pairs
85
-
86
-
87
- class Qwen2Tokenizer(PreTrainedTokenizer):
88
- """
89
- Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
90
-
91
- As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
92
- be encoded differently whether it is at the beginning of the sentence (without space) or not:
93
-
94
- ```python
95
- >>> from transformers import Qwen2Tokenizer
96
-
97
- >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
98
- >>> tokenizer("Hello world")["input_ids"]
99
- [9707, 1879]
100
-
101
- >>> tokenizer(" Hello world")["input_ids"]
102
- [21927, 1879]
103
- ```
104
- This is expected.
105
-
106
- You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
107
-
108
- This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
109
- this superclass for more information regarding those methods.
110
-
111
- Args:
112
- vocab_file (`str`):
113
- Path to the vocabulary file.
114
- merges_file (`str`):
115
- Path to the merges file.
116
- errors (`str`, *optional*, defaults to `"replace"`):
117
- Paradigm to follow when decoding bytes to UTF-8. See
118
- [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
119
- unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
120
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
121
- token instead.
122
- bos_token (`str`, *optional*):
123
- The beginning of sequence token. Not applicable for this tokenizer.
124
- eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
125
- The end of sequence token.
126
- pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
127
- The token used for padding, for example when batching sequences of different lengths.
128
- clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
129
- Whether or not the model should cleanup the spaces that were added when splitting the input text during the
130
- tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
131
- split_special_tokens (`bool`, *optional*, defaults to `False`):
132
- Whether or not the special tokens should be split during the tokenization process. The default behavior is
133
- to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>")` =
134
- `['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
135
- '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
136
- """
137
-
138
- vocab_files_names = VOCAB_FILES_NAMES
139
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
140
- max_model_input_sizes = MAX_MODEL_INPUT_SIZES
141
- model_input_names = ["input_ids", "attention_mask"]
142
-
143
- def __init__(
144
- self,
145
- vocab_file,
146
- merges_file,
147
- errors="replace",
148
- unk_token="<|endoftext|>",
149
- bos_token=None,
150
- eos_token="<|endoftext|>",
151
- pad_token="<|endoftext|>",
152
- clean_up_tokenization_spaces=False,
153
- split_special_tokens=False,
154
- **kwargs,
155
- ):
156
- # Qwen vocab does not contain control tokens; added tokens need to be special
157
- bos_token = (
158
- AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
159
- if isinstance(bos_token, str)
160
- else bos_token
161
- )
162
- eos_token = (
163
- AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
164
- if isinstance(eos_token, str)
165
- else eos_token
166
- )
167
- unk_token = (
168
- AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
169
- if isinstance(unk_token, str)
170
- else unk_token
171
- )
172
- pad_token = (
173
- AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
174
- if isinstance(pad_token, str)
175
- else pad_token
176
- )
177
-
178
- with open(vocab_file, encoding="utf-8") as vocab_handle:
179
- self.encoder = json.load(vocab_handle)
180
- self.decoder = {v: k for k, v in self.encoder.items()}
181
- self.errors = errors # how to handle errors in decoding
182
- self.byte_encoder = bytes_to_unicode()
183
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
184
- bpe_merges = []
185
- with open(merges_file, encoding="utf-8") as merges_handle:
186
- for line in merges_handle:
187
- line = line.strip()
188
- if not line or line.startswith("#"):
189
- continue
190
- bpe_merges.append(tuple(line.split()))
191
- self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
192
- # NOTE: the cache can grow without bound and will get really large for long running processes
193
- # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
194
- # not a memory leak but appears as one.
195
- # GPT2Tokenizer has the same problem, so let's be consistent.
196
- self.cache = {}
197
-
198
- self.pat = re.compile(PRETOKENIZE_REGEX)
199
-
200
- if kwargs.get("add_prefix_space", False):
201
- logger.warning_once(
202
- f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
203
- )
204
-
205
- super().__init__(
206
- errors=errors,
207
- bos_token=bos_token,
208
- eos_token=eos_token,
209
- pad_token=pad_token,
210
- unk_token=unk_token,
211
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
212
- split_special_tokens=split_special_tokens,
213
- **kwargs,
214
- )
215
-
216
- @property
217
- def vocab_size(self) -> int:
218
- return len(self.encoder)
219
-
220
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
221
- def get_vocab(self):
222
- return dict(self.encoder, **self.added_tokens_encoder)
223
-
224
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
225
- def bpe(self, token):
226
- if token in self.cache:
227
- return self.cache[token]
228
- word = tuple(token)
229
- pairs = get_pairs(word)
230
-
231
- if not pairs:
232
- return token
233
-
234
- while True:
235
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
236
- if bigram not in self.bpe_ranks:
237
- break
238
- first, second = bigram
239
- new_word = []
240
- i = 0
241
- while i < len(word):
242
- try:
243
- j = word.index(first, i)
244
- except ValueError:
245
- new_word.extend(word[i:])
246
- break
247
- else:
248
- new_word.extend(word[i:j])
249
- i = j
250
-
251
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
252
- new_word.append(first + second)
253
- i += 2
254
- else:
255
- new_word.append(word[i])
256
- i += 1
257
- new_word = tuple(new_word)
258
- word = new_word
259
- if len(word) == 1:
260
- break
261
- else:
262
- pairs = get_pairs(word)
263
- word = " ".join(word)
264
- self.cache[token] = word
265
- return word
266
-
267
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
268
- def _tokenize(self, text):
269
- """Tokenize a string."""
270
- bpe_tokens = []
271
- for token in re.findall(self.pat, text):
272
- token = "".join(
273
- self.byte_encoder[b] for b in token.encode("utf-8")
274
- ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
275
- bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
276
- return bpe_tokens
277
-
278
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
279
- def _convert_token_to_id(self, token):
280
- """Converts a token (str) in an id using the vocab."""
281
- return self.encoder.get(token, self.encoder.get(self.unk_token))
282
-
283
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
284
- def _convert_id_to_token(self, index):
285
- """Converts an index (integer) in a token (str) using the vocab."""
286
- return self.decoder.get(index)
287
-
288
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
289
- def convert_tokens_to_string(self, tokens):
290
- """Converts a sequence of tokens (string) in a single string."""
291
- text = "".join(tokens)
292
- text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
293
- return text
294
-
295
- def decode(
296
- self,
297
- token_ids,
298
- skip_special_tokens: bool = False,
299
- clean_up_tokenization_spaces: Optional[bool] = False,
300
- spaces_between_special_tokens: bool = False,
301
- **kwargs,
302
- ) -> str:
303
- # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
304
- # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
305
- return super().decode(
306
- token_ids,
307
- skip_special_tokens=skip_special_tokens,
308
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
309
- spaces_between_special_tokens=spaces_between_special_tokens,
310
- **kwargs,
311
- )
312
-
313
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
314
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
315
- if not os.path.isdir(save_directory):
316
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
317
- return
318
- vocab_file = os.path.join(
319
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
320
- )
321
- merge_file = os.path.join(
322
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
323
- )
324
-
325
- with open(vocab_file, "w", encoding="utf-8") as f:
326
- f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
327
-
328
- index = 0
329
- with open(merge_file, "w", encoding="utf-8") as writer:
330
- writer.write("#version: 0.2\n")
331
- for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
332
- if index != token_index:
333
- logger.warning(
334
- f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
335
- " Please check that the tokenizer is not corrupted!"
336
- )
337
- index = token_index
338
- writer.write(" ".join(bpe_tokens) + "\n")
339
- index += 1
340
-
341
- return vocab_file, merge_file
342
-
343
- def prepare_for_tokenization(self, text, **kwargs):
344
- text = unicodedata.normalize("NFC", text)
345
- return (text, kwargs)
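For reference, a minimal sketch of the byte-to-unicode round trip that the tokenizer above performs before and after BPE, reusing the upstream GPT-2 helper that the deleted `bytes_to_unicode` is copied from:

```python
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "Hello world"
mapped = "".join(byte_encoder[b] for b in text.encode("utf-8"))  # printable proxy string fed to BPE
restored = bytearray(byte_decoder[c] for c in mapped).decode("utf-8")
assert restored == text
```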
 
tokenization_qwen2_fast.py DELETED
@@ -1,143 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Tokenization classes for Qwen2."""
16
-
17
- from typing import Optional, Tuple
18
-
19
- from transformers.tokenization_utils import AddedToken
20
- from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
- from transformers.utils import logging
22
- from .tokenization_qwen2 import Qwen2Tokenizer
23
-
24
-
25
- logger = logging.get_logger(__name__)
26
-
27
- VOCAB_FILES_NAMES = {
28
- "vocab_file": "vocab.json",
29
- "merges_file": "merges.txt",
30
- "tokenizer_file": "tokenizer.json",
31
- }
32
-
33
- PRETRAINED_VOCAB_FILES_MAP = {
34
- "vocab_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/vocab.json"},
35
- "merges_file": {"qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/merges.txt"},
36
- "tokenizer_file": {
37
- "qwen/qwen-tokenizer": "https://huggingface.co/qwen/qwen-tokenizer/resolve/main/tokenizer.json"
38
- },
39
- }
40
-
41
- MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
42
-
43
-
44
- class Qwen2TokenizerFast(PreTrainedTokenizerFast):
45
- """
46
- Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
47
- Byte-Pair-Encoding.
48
-
49
- As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
50
- be encoded differently whether it is at the beginning of the sentence (without space) or not:
51
-
52
- ```python
53
- >>> from transformers import Qwen2TokenizerFast
54
-
55
- >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
56
- >>> tokenizer("Hello world")["input_ids"]
57
- [9707, 1879]
58
-
59
- >>> tokenizer(" Hello world")["input_ids"]
60
- [21927, 1879]
61
- ```
62
- This is expected.
63
-
64
- This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
65
- refer to this superclass for more information regarding those methods.
66
-
67
- Args:
68
- vocab_file (`str`, *optional*):
69
- Path to the vocabulary file.
70
- merges_file (`str`, *optional*):
71
- Path to the merges file.
72
- tokenizer_file (`str`, *optional*):
73
- Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
74
- contains everything needed to load the tokenizer.
75
- unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
76
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
77
- token instead. Not applicable to this tokenizer.
78
- bos_token (`str`, *optional*):
79
- The beginning of sequence token. Not applicable for this tokenizer.
80
- eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
81
- The end of sequence token.
82
- pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
83
- The token used for padding, for example when batching sequences of different lengths.
84
- """
85
-
86
- vocab_files_names = VOCAB_FILES_NAMES
87
- pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
88
- max_model_input_sizes = MAX_MODEL_INPUT_SIZES
89
- model_input_names = ["input_ids", "attention_mask"]
90
- slow_tokenizer_class = Qwen2Tokenizer
91
-
92
- def __init__(
93
- self,
94
- vocab_file=None,
95
- merges_file=None,
96
- tokenizer_file=None,
97
- unk_token="<|endoftext|>",
98
- bos_token=None,
99
- eos_token="<|endoftext|>",
100
- pad_token="<|endoftext|>",
101
- **kwargs,
102
- ):
103
- # We need to at least pass vocab_file and merges_file to base class
104
- # in case a slow tokenizer needs to be initialized; others can be
105
- # configured through files.
106
- # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
107
-
108
- bos_token = (
109
- AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
110
- if isinstance(bos_token, str)
111
- else bos_token
112
- )
113
- eos_token = (
114
- AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
115
- if isinstance(eos_token, str)
116
- else eos_token
117
- )
118
- unk_token = (
119
- AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
120
- if isinstance(unk_token, str)
121
- else unk_token
122
- )
123
- pad_token = (
124
- AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
125
- if isinstance(pad_token, str)
126
- else pad_token
127
- )
128
-
129
- super().__init__(
130
- vocab_file,
131
- merges_file,
132
- tokenizer_file=tokenizer_file,
133
- unk_token=unk_token,
134
- bos_token=bos_token,
135
- eos_token=eos_token,
136
- pad_token=pad_token,
137
- **kwargs,
138
- )
139
-
140
- # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
141
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
142
- files = self._tokenizer.model.save(save_directory, name=filename_prefix)
143
- return tuple(files)
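For reference, a minimal usage sketch matching the docstring above; `Qwen/Qwen-tokenizer` is the checkpoint the docstring itself references:

```python
from transformers import Qwen2TokenizerFast

tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
print(tokenizer("Hello world")["input_ids"])   # [9707, 1879], as shown in the docstring
print(tokenizer(" Hello world")["input_ids"])  # [21927, 1879]; the leading space changes the ids
```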