michalk8 commited on
Commit
ae9d0c0
1 Parent(s): 8e1b368

Add LiT model

Browse files
README.md CHANGED
@@ -20,7 +20,39 @@ AIMv2 pre-training is simple and straightforward to train and to scale effective
20
  <img src="aimv2_overview_light.png" alt="AIMv2 Overview"/>
21
 
22
  ## Usage
23
- Under construction. Please consider using the models in the [ml-aim](https://github.com/apple/ml-aim) repository.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  ## Citation
26
  If you find our work useful, please consider citing us as:
 
20
  <img src="aimv2_overview_light.png" alt="AIMv2 Overview"/>
21
 
22
  ## Usage
23
+
24
+ ### PyTorch
25
+ ```python
26
+ import requests
27
+ from PIL import Image
28
+ from transformers import AutoProcessor, AutoModel
29
+
30
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
31
+ image = Image.open(requests.get(url, stream=True).raw)
32
+ text = ["Picture of a dog.", "Picture of a cat.", "Picture of a horse."]
33
+
34
+ processor = AutoProcessor.from_pretrained(
35
+ "apple/aimv2-large-patch14-224-lit",
36
+ )
37
+ model = AutoModel.from_pretrained(
38
+ "apple/aimv2-large-patch14-224-lit",
39
+ trust_remote_code=True,
40
+ )
41
+
42
+ inputs = processor(
43
+ images=image,
44
+ text=text,
45
+ add_special_tokens=True,
46
+ truncation=True,
47
+ padding=True,
48
+ return_tensors="pt",
49
+ )
50
+ outputs = model(**inputs)
51
+ probs = outputs.logits_per_image.softmax(dim=-1)
52
+ ```
53
+
54
+ ### JAX
55
+ Under construction.
56
 
57
  ## Citation
58
  If you find our work useful, please consider citing us as:
config.json ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "AIMv2Model"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_aimv2.AIMv2Config",
7
+ "AutoModel": "modeling_aimv2.AIMv2Model"
8
+ },
9
+ "init_temperature": 0.07,
10
+ "max_logit_scale": 100.0,
11
+ "model_type": "aimv2",
12
+ "projection_dim": 768,
13
+ "text_config": {
14
+ "_attn_implementation_autoset": true,
15
+ "_name_or_path": "",
16
+ "add_cross_attention": false,
17
+ "architectures": null,
18
+ "attention_dropout": 0.0,
19
+ "bad_words_ids": null,
20
+ "begin_suppress_tokens": null,
21
+ "bos_token_id": null,
22
+ "chunk_size_feed_forward": 0,
23
+ "cross_attention_hidden_size": null,
24
+ "decoder_start_token_id": null,
25
+ "diversity_penalty": 0.0,
26
+ "do_sample": false,
27
+ "early_stopping": false,
28
+ "encoder_no_repeat_ngram_size": 0,
29
+ "eos_token_id": 49407,
30
+ "exponential_decay_length_penalty": null,
31
+ "finetuning_task": null,
32
+ "forced_bos_token_id": null,
33
+ "forced_eos_token_id": null,
34
+ "hidden_size": 768,
35
+ "id2label": {
36
+ "0": "LABEL_0",
37
+ "1": "LABEL_1"
38
+ },
39
+ "intermediate_size": 2048,
40
+ "is_causal": true,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "length_penalty": 1.0,
48
+ "max_context_length": 77,
49
+ "max_length": 20,
50
+ "min_length": 0,
51
+ "model_type": "aimv2",
52
+ "no_repeat_ngram_size": 0,
53
+ "num_attention_heads": 6,
54
+ "num_beam_groups": 1,
55
+ "num_beams": 1,
56
+ "num_hidden_layers": 12,
57
+ "num_return_sequences": 1,
58
+ "output_attentions": false,
59
+ "output_hidden_states": false,
60
+ "output_scores": false,
61
+ "pad_token_id": null,
62
+ "prefix": null,
63
+ "problem_type": null,
64
+ "projection_dropout": 0.0,
65
+ "pruned_heads": {},
66
+ "qkv_bias": false,
67
+ "remove_invalid_values": false,
68
+ "repetition_penalty": 1.0,
69
+ "return_dict": true,
70
+ "return_dict_in_generate": false,
71
+ "rms_norm_eps": 1e-05,
72
+ "sep_token_id": null,
73
+ "suppress_tokens": null,
74
+ "task_specific_params": null,
75
+ "temperature": 1.0,
76
+ "tf_legacy_loss": false,
77
+ "tie_encoder_decoder": false,
78
+ "tie_word_embeddings": true,
79
+ "tokenizer_class": null,
80
+ "top_k": 50,
81
+ "top_p": 1.0,
82
+ "torch_dtype": null,
83
+ "torchscript": false,
84
+ "typical_p": 1.0,
85
+ "use_bfloat16": false,
86
+ "use_bias": false,
87
+ "vocab_size": 49408
88
+ },
89
+ "torch_dtype": "float32",
90
+ "transformers_version": "4.46.3",
91
+ "vision_config": {
92
+ "_attn_implementation_autoset": true,
93
+ "_name_or_path": "",
94
+ "add_cross_attention": false,
95
+ "architectures": null,
96
+ "attention_dropout": 0.0,
97
+ "bad_words_ids": null,
98
+ "begin_suppress_tokens": null,
99
+ "bos_token_id": null,
100
+ "chunk_size_feed_forward": 0,
101
+ "cross_attention_hidden_size": null,
102
+ "decoder_start_token_id": null,
103
+ "diversity_penalty": 0.0,
104
+ "do_sample": false,
105
+ "early_stopping": false,
106
+ "encoder_no_repeat_ngram_size": 0,
107
+ "eos_token_id": null,
108
+ "exponential_decay_length_penalty": null,
109
+ "finetuning_task": null,
110
+ "forced_bos_token_id": null,
111
+ "forced_eos_token_id": null,
112
+ "hidden_size": 1024,
113
+ "id2label": {
114
+ "0": "LABEL_0",
115
+ "1": "LABEL_1"
116
+ },
117
+ "image_size": 224,
118
+ "intermediate_size": 2816,
119
+ "is_causal": false,
120
+ "is_decoder": false,
121
+ "is_encoder_decoder": false,
122
+ "label2id": {
123
+ "LABEL_0": 0,
124
+ "LABEL_1": 1
125
+ },
126
+ "length_penalty": 1.0,
127
+ "max_length": 20,
128
+ "min_length": 0,
129
+ "model_type": "aimv2",
130
+ "no_repeat_ngram_size": 0,
131
+ "num_attention_heads": 8,
132
+ "num_beam_groups": 1,
133
+ "num_beams": 1,
134
+ "num_channels": 3,
135
+ "num_hidden_layers": 24,
136
+ "num_queries": 1,
137
+ "num_return_sequences": 1,
138
+ "output_attentions": false,
139
+ "output_hidden_states": false,
140
+ "output_scores": false,
141
+ "pad_token_id": null,
142
+ "patch_size": 14,
143
+ "prefix": null,
144
+ "problem_type": null,
145
+ "projection_dropout": 0.0,
146
+ "pruned_heads": {},
147
+ "qkv_bias": false,
148
+ "remove_invalid_values": false,
149
+ "repetition_penalty": 1.0,
150
+ "return_dict": true,
151
+ "return_dict_in_generate": false,
152
+ "rms_norm_eps": 1e-05,
153
+ "sep_token_id": null,
154
+ "suppress_tokens": null,
155
+ "task_specific_params": null,
156
+ "temperature": 1.0,
157
+ "tf_legacy_loss": false,
158
+ "tie_encoder_decoder": false,
159
+ "tie_word_embeddings": true,
160
+ "tokenizer_class": null,
161
+ "top_k": 50,
162
+ "top_p": 1.0,
163
+ "torch_dtype": null,
164
+ "torchscript": false,
165
+ "typical_p": 1.0,
166
+ "use_bfloat16": false,
167
+ "use_bias": false
168
+ }
169
+ }
configuration_aimv2.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+
5
+ __all__ = ["AIMv2VisionConfig", "AIMv2TextConfig", "AIMv2Config"]
6
+
7
+
8
+ class AIMv2VisionConfig(PretrainedConfig):
9
+ """This is the configuration class to store the configuration of an [`AIMv2VisionModel`].
10
+
11
+ Instantiating a configuration with the defaults will yield a similar configuration
12
+ to that of the [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit).
13
+
14
+ Args:
15
+ hidden_size: Dimension of the hidden representations.
16
+ intermediate_size: Dimension of the SwiGLU representations.
17
+ num_hidden_layers: Number of hidden layers in the Transformer.
18
+ num_attention_heads: Number of attention heads for each attention layer
19
+ in the Transformer.
20
+ num_queries: Number of learnable queries for the attention-pooling head.
21
+ num_channels: Number of input channels.
22
+ image_size: Image size.
23
+ patch_size: Patch size.
24
+ rms_norm_eps: Epsilon value used for the RMS normalization layer.
25
+ attention_dropout: Dropout ratio for attention probabilities.
26
+ projection_dropout: Dropout ratio for the projection layer after the attention.
27
+ qkv_bias: Whether to add a bias to the queries, keys and values.
28
+ use_bias: Whether to add a bias in the feed-forward and projection layers.
29
+ kwargs: Keyword arguments for the [`PretrainedConfig`].
30
+ """
31
+
32
+ model_type: str = "aimv2"
33
+ base_config_key: str = "vision_config"
34
+
35
+ def __init__(
36
+ self,
37
+ hidden_size: int = 1024,
38
+ intermediate_size: int = 2816,
39
+ num_hidden_layers: int = 24,
40
+ num_attention_heads: int = 8,
41
+ num_queries: int = 1,
42
+ num_channels: int = 3,
43
+ image_size: int = 224,
44
+ patch_size: int = 14,
45
+ rms_norm_eps: float = 1e-5,
46
+ attention_dropout: float = 0.0,
47
+ projection_dropout: float = 0.0,
48
+ qkv_bias: bool = False,
49
+ use_bias: bool = False,
50
+ **kwargs: Any,
51
+ ):
52
+ super().__init__(**kwargs)
53
+ self.hidden_size = hidden_size
54
+ self.intermediate_size = intermediate_size
55
+ self.num_hidden_layers = num_hidden_layers
56
+ self.num_attention_heads = num_attention_heads
57
+ self.num_queries = num_queries
58
+ self.num_channels = num_channels
59
+ self.patch_size = patch_size
60
+ self.image_size = image_size
61
+ self.attention_dropout = attention_dropout
62
+ self.rms_norm_eps = rms_norm_eps
63
+
64
+ self.projection_dropout = projection_dropout
65
+ self.qkv_bias = qkv_bias
66
+ self.use_bias = use_bias
67
+ self.is_causal = False
68
+
69
+
70
+ class AIMv2TextConfig(PretrainedConfig):
71
+ """This is the configuration class to store the configuration of an [`AIMv2TextModel`].
72
+
73
+ Instantiating a configuration with the defaults will yield a similar configuration
74
+ to that of the [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit).
75
+
76
+ Args:
77
+ vocab_size: Size of the vocabulary.
78
+ hidden_size: Dimension of the hidden representations.
79
+ intermediate_size: Dimension of the SwiGLU representations.
80
+ num_hidden_layers: Number of hidden layers in the Transformer.
81
+ num_attention_heads: Number of attention heads for each attention layer
82
+ in the Transformer.
83
+ rms_norm_eps: Epsilon value used for the RMS normalization layer.
84
+ attention_dropout: Dropout ratio for attention probabilities.
85
+ projection_dropout: Dropout ratio for the projection layer after the attention.
86
+ qkv_bias: Whether to add a bias to the queries, keys and values.
87
+ use_bias: Whether to add a bias in the feed-forward and projection layers.
88
+ eos_token_id: End-of-sequence token id.
89
+ max_context_length: Maximum number of tokens for the context.
90
+ kwargs: Keyword arguments for the [`PretrainedConfig`].
91
+ """
92
+
93
+ model_type: str = "aimv2"
94
+ base_config_key: str = "text_config"
95
+
96
+ def __init__(
97
+ self,
98
+ vocab_size: int = 49408,
99
+ hidden_size: int = 768,
100
+ intermediate_size: int = 2048,
101
+ num_hidden_layers: int = 12,
102
+ num_attention_heads: int = 6,
103
+ rms_norm_eps: float = 1e-5,
104
+ attention_dropout: float = 0.0,
105
+ projection_dropout: float = 0.0,
106
+ qkv_bias: bool = False,
107
+ use_bias: bool = False,
108
+ eos_token_id: int = 49407,
109
+ max_context_length: int = 77,
110
+ **kwargs: Any,
111
+ ):
112
+ super().__init__(**kwargs)
113
+ self.hidden_size = hidden_size
114
+ self.intermediate_size = intermediate_size
115
+ self.num_hidden_layers = num_hidden_layers
116
+ self.num_attention_heads = num_attention_heads
117
+ self.attention_dropout = attention_dropout
118
+ self.rms_norm_eps = rms_norm_eps
119
+
120
+ self.projection_dropout = projection_dropout
121
+ self.qkv_bias = qkv_bias
122
+ self.use_bias = use_bias
123
+ self.vocab_size = vocab_size
124
+ self.max_context_length = max_context_length
125
+ self.eos_token_id = eos_token_id
126
+ self.is_causal = True
127
+
128
+
129
+ class AIMv2Config(PretrainedConfig):
130
+ """This is the configuration class to store the configuration of an [`AIMv2Model`].
131
+
132
+ Instantiating a configuration with the defaults will yield a similar configuration
133
+ to that of the [apple/aimv2-large-patch14-224-lit](https://huggingface.co/apple/aimv2-large-patch14-224-lit).
134
+
135
+ Args:
136
+ vision_config: Vision config.
137
+ text_config: Text config.
138
+ projection_dim: Dimension of the image and text projection layers.
139
+ kwargs: Keyword arguments for the [`PretrainedConfig`].
140
+ """
141
+
142
+ model_type = "aimv2"
143
+ is_composition: bool = True
144
+ sub_configs: Dict[str, PretrainedConfig] = {
145
+ "vision_config": AIMv2VisionConfig,
146
+ "text_config": AIMv2TextConfig,
147
+ }
148
+
149
+ def __init__(
150
+ self,
151
+ vision_config: Optional[AIMv2VisionConfig] = None,
152
+ text_config: Optional[AIMv2TextConfig] = None,
153
+ projection_dim: int = 768,
154
+ init_temperature: float = 0.07,
155
+ max_logit_scale: float = 100.0,
156
+ **kwargs: Any,
157
+ ):
158
+ super().__init__(**kwargs)
159
+ if vision_config is None:
160
+ vision_config = AIMv2VisionConfig()
161
+ if text_config is None:
162
+ text_config = AIMv2TextConfig()
163
+
164
+ self.vision_config = vision_config
165
+ self.text_config = text_config
166
+ self.projection_dim = projection_dim
167
+
168
+ self.init_temperature = init_temperature
169
+ self.max_logit_scale = max_logit_scale
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c218dc23f746407cacb100506d83a1c5f169db11ff3cf141853ac28771af3222
3
- size 1746752308
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58b6d77170c9bdf42988d5b6d8fcec3c32cab90f4f853235706d4edbb05b0fb8
3
+ size 1135090512
modeling_aimv2.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import dataclasses
5
+ import math
6
+
7
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
8
+ from transformers.utils import ModelOutput
9
+
10
+ from configuration_aimv2 import AIMv2Config, AIMv2VisionConfig, AIMv2TextConfig
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ __all__ = ["AIMv2VisionModel", "AIMv2TextModel", "AIMv2Model"]
17
+
18
+ AIMv2VisionOrTextConfig = Union[AIMv2VisionConfig, AIMv2TextConfig]
19
+
20
+
21
+ @dataclasses.dataclass
22
+ class AIMv2Output(ModelOutput):
23
+ logits_per_image: torch.Tensor
24
+ logits_per_text: Optional[torch.Tensor] = None
25
+ image_features: Optional[torch.Tensor] = None
26
+ text_features: Optional[torch.Tensor] = None
27
+ vision_output: Optional[BaseModelOutputWithNoAttention] = None
28
+ text_output: Optional[BaseModelOutputWithNoAttention] = None
29
+
30
+
31
+ class AIMv2TextPreprocessor(nn.Module):
32
+ def __init__(self, config: AIMv2TextConfig):
33
+ super().__init__()
34
+ self.max_context_length = config.max_context_length
35
+ self.eos_token_id = config.eos_token_id
36
+
37
+ self.text_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
38
+ self.positional_embedding = nn.Parameter(
39
+ torch.zeros(self.max_context_length, config.hidden_size)
40
+ )
41
+
42
+ def forward(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
43
+ _, N = input_ids.shape
44
+ max_len = min(N, self.max_context_length)
45
+ eos_token_mask = input_ids == self.eos_token_id
46
+ tokens = self.text_embedding(input_ids)
47
+ tokens = tokens[:, :max_len] + self.positional_embedding[:max_len].unsqueeze(0)
48
+ return tokens, eos_token_mask
49
+
50
+
51
+ class AIMv2ExtractEOS(nn.Module):
52
+ def forward(
53
+ self, tokens: torch.Tensor, eos_token_mask: torch.Tensor
54
+ ) -> torch.Tensor:
55
+ B, _, D = tokens.shape
56
+ eos_token_mask = torch.argmax(eos_token_mask.float(), dim=-1)
57
+ assert eos_token_mask.shape == (B,)
58
+ eos_token_mask = eos_token_mask.reshape(B, 1, 1).expand(B, 1, D)
59
+ eos_token = torch.gather(tokens, 1, eos_token_mask)
60
+ eos_token = eos_token.squeeze(1)
61
+ return eos_token
62
+
63
+
64
+ class RMSNorm(nn.Module):
65
+ def __init__(self, dim: int, eps: float = 1e-6):
66
+ super().__init__()
67
+ self.weight = nn.Parameter(torch.ones(dim))
68
+ self.eps = eps
69
+
70
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
71
+ output = self._norm(x.float()).type_as(x)
72
+ return output * self.weight
73
+
74
+ def extra_repr(self) -> str:
75
+ return f"{tuple(self.weight.shape)}, eps={self.eps}"
76
+
77
+ def _norm(self, x: torch.Tensor) -> torch.Tensor:
78
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
79
+
80
+
81
+ class AIMv2SwiGLUFFN(nn.Module):
82
+ def __init__(self, config: AIMv2VisionOrTextConfig):
83
+ super().__init__()
84
+ hidden_features = config.intermediate_size
85
+ in_features = config.hidden_size
86
+ bias = config.use_bias
87
+
88
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
89
+ self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
90
+ self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ x = F.silu(self.fc1(x)) * self.fc3(x)
94
+ x = self.fc2(x)
95
+ return x
96
+
97
+
98
+ class AIMv2PatchEmbed(nn.Module):
99
+ def __init__(self, config: AIMv2VisionOrTextConfig):
100
+ super().__init__()
101
+ self.proj = nn.Conv2d(
102
+ config.num_channels,
103
+ config.hidden_size,
104
+ kernel_size=(config.patch_size, config.patch_size),
105
+ stride=(config.patch_size, config.patch_size),
106
+ )
107
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
108
+
109
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
110
+ x = self.proj(x).flatten(2).transpose(1, 2)
111
+ x = self.norm(x)
112
+ return x
113
+
114
+
115
+ class AIMv2ViTPreprocessor(nn.Module):
116
+ def __init__(self, config: AIMv2VisionConfig):
117
+ super().__init__()
118
+ num_patches = (config.image_size // config.patch_size) ** 2
119
+
120
+ self.patchifier = AIMv2PatchEmbed(config)
121
+ self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.hidden_size)))
122
+
123
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
124
+ tokens = self.patchifier(x)
125
+ _, N, _ = tokens.shape
126
+ pos_embed = self.pos_embed.to(tokens.device)
127
+ tokens = tokens + pos_embed[:, :N]
128
+ return tokens
129
+
130
+
131
+ class AIMv2Attention(nn.Module):
132
+ def __init__(self, config: AIMv2VisionOrTextConfig):
133
+ super().__init__()
134
+ dim = config.hidden_size
135
+
136
+ self.num_heads = config.num_attention_heads
137
+ self.is_causal = config.is_causal
138
+ self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias)
139
+ self.attn_drop = nn.Dropout(config.attention_dropout)
140
+ self.proj = nn.Linear(dim, dim, bias=config.use_bias)
141
+ self.proj_drop = nn.Dropout(config.projection_dropout)
142
+
143
+ def forward(
144
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
145
+ ) -> torch.Tensor:
146
+ B, N, C = x.shape
147
+ qkv = (
148
+ self.qkv(x)
149
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
150
+ .permute(2, 0, 3, 1, 4)
151
+ )
152
+ q, k, v = qkv.unbind(0)
153
+
154
+ if mask is None:
155
+ x = F.scaled_dot_product_attention(q, k, v, is_causal=self.is_causal)
156
+ else:
157
+ mask_converter = AttentionMaskConverter(self.is_causal)
158
+ mask = mask_converter.to_4d(
159
+ mask, key_value_length=N, query_length=N, dtype=q.dtype
160
+ )
161
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
162
+
163
+ x = x.transpose(1, 2).contiguous().reshape(B, N, C)
164
+ x = self.proj(x)
165
+ x = self.proj_drop(x)
166
+ return x
167
+
168
+
169
+ class AIMv2Block(nn.Module):
170
+ def __init__(self, config: AIMv2VisionOrTextConfig):
171
+ super().__init__()
172
+ self.attn = AIMv2Attention(config)
173
+ self.norm_1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
174
+ self.mlp = AIMv2SwiGLUFFN(config)
175
+ self.norm_2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
176
+
177
+ def forward(
178
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
179
+ ) -> torch.Tensor:
180
+ x = x + self.attn(self.norm_1(x), mask)
181
+ x = x + self.mlp(self.norm_2(x))
182
+ return x
183
+
184
+
185
+ class AIMv2AttentionPoolingHead(nn.Module):
186
+ def __init__(self, config: AIMv2VisionConfig):
187
+ super().__init__()
188
+ dim = config.hidden_size
189
+ qkv_bias = config.qkv_bias
190
+
191
+ self.num_heads = config.num_attention_heads
192
+ self.num_queries = config.num_queries
193
+
194
+ self.k = nn.Linear(dim, dim, bias=qkv_bias)
195
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
196
+ self.cls_token = nn.Parameter(torch.randn(1, self.num_queries, dim) * 0.02)
197
+ self.linear = nn.Linear(dim, dim, bias=True)
198
+
199
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
200
+ B, N, C = x.shape
201
+ cls_token = self.cls_token.expand(B, -1, -1)
202
+
203
+ q = cls_token.reshape(
204
+ B, self.num_queries, self.num_heads, C // self.num_heads
205
+ ).permute(0, 2, 1, 3)
206
+ k = (
207
+ self.k(x)
208
+ .reshape(B, N, self.num_heads, C // self.num_heads)
209
+ .permute(0, 2, 1, 3)
210
+ )
211
+ v = (
212
+ self.v(x)
213
+ .reshape(B, N, self.num_heads, C // self.num_heads)
214
+ .permute(0, 2, 1, 3)
215
+ )
216
+
217
+ x_cls = F.scaled_dot_product_attention(q, k, v)
218
+ x_cls = x_cls.transpose(1, 2).reshape(B, self.num_queries, C)
219
+ x_cls = x_cls.mean(dim=1)
220
+
221
+ out = self.linear(x_cls)
222
+ return out
223
+
224
+
225
+ class AIMv2Transformer(nn.Module):
226
+ def __init__(self, config: AIMv2VisionOrTextConfig):
227
+ super().__init__()
228
+ self.blocks = nn.ModuleList(
229
+ [AIMv2Block(config) for _ in range(config.num_hidden_layers)]
230
+ )
231
+ self.post_trunk_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
232
+
233
+ def forward(
234
+ self,
235
+ tokens: torch.Tensor,
236
+ mask: Optional[torch.Tensor] = None,
237
+ output_hidden_states: bool = False,
238
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
239
+ hidden_states = () if output_hidden_states else None
240
+ for block in self.blocks:
241
+ tokens = block(tokens, mask)
242
+ if output_hidden_states:
243
+ hidden_states += (tokens,)
244
+ tokens = self.post_trunk_norm(tokens)
245
+ return tokens, hidden_states
246
+
247
+
248
+ class AIMv2PretrainedModel(PreTrainedModel):
249
+ base_model_prefix = "aimv2"
250
+ _supports_sdpa = True
251
+
252
+
253
+ class AIMv2VisionModel(AIMv2PretrainedModel):
254
+ config_class = AIMv2VisionConfig
255
+ main_input_name = "pixel_values"
256
+ _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2Block"]
257
+
258
+ def __init__(self, config: AIMv2VisionConfig):
259
+ super().__init__(config)
260
+ self.preprocessor = AIMv2ViTPreprocessor(config)
261
+ self.trunk = AIMv2Transformer(config)
262
+ self.head = AIMv2AttentionPoolingHead(config)
263
+
264
+ def forward(
265
+ self,
266
+ pixel_values: torch.Tensor,
267
+ mask: Optional[torch.Tensor] = None,
268
+ output_hidden_states: Optional[bool] = None,
269
+ return_dict: Optional[bool] = None,
270
+ ) -> Union[
271
+ Tuple[torch.Tensor],
272
+ Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
273
+ BaseModelOutputWithNoAttention,
274
+ ]:
275
+ if output_hidden_states is None:
276
+ output_hidden_states = self.config.output_hidden_states
277
+ if return_dict is None:
278
+ return_dict = self.config.use_return_dict
279
+
280
+ x = self.preprocessor(pixel_values)
281
+ x, hidden_states = self.trunk(
282
+ x, mask, output_hidden_states=output_hidden_states
283
+ )
284
+ x = self.head(x)
285
+
286
+ if not return_dict:
287
+ res = (x,)
288
+ res += (hidden_states,) if output_hidden_states else ()
289
+ return res
290
+
291
+ return BaseModelOutputWithNoAttention(
292
+ last_hidden_state=x,
293
+ hidden_states=hidden_states,
294
+ )
295
+
296
+
297
+ class AIMv2TextModel(AIMv2PretrainedModel):
298
+ config_class = AIMv2TextConfig
299
+ main_input_name = "input_ids"
300
+ _no_split_modules = ["AIMv2TextPreprocessor", "AIMv2Block"]
301
+
302
+ def __init__(self, config: AIMv2TextConfig):
303
+ super().__init__(config)
304
+ self.preprocessor = AIMv2TextPreprocessor(config)
305
+ self.trunk = AIMv2Transformer(config)
306
+ self.head = AIMv2ExtractEOS()
307
+
308
+ def forward(
309
+ self,
310
+ pixel_values: torch.Tensor,
311
+ mask: Optional[torch.Tensor] = None,
312
+ output_hidden_states: Optional[bool] = None,
313
+ return_dict: Optional[bool] = None,
314
+ ) -> Union[
315
+ Tuple[torch.Tensor],
316
+ Tuple[torch.Tensor, Tuple[torch.Tensor, ...]],
317
+ BaseModelOutputWithNoAttention,
318
+ ]:
319
+ if output_hidden_states is None:
320
+ output_hidden_states = self.config.output_hidden_states
321
+ if return_dict is None:
322
+ return_dict = self.config.use_return_dict
323
+
324
+ x, eos_token_mask = self.preprocessor(pixel_values)
325
+ x, hidden_states = self.trunk(
326
+ x, mask, output_hidden_states=output_hidden_states
327
+ )
328
+ x = self.head(x, eos_token_mask)
329
+
330
+ if not return_dict:
331
+ res = (x,)
332
+ res += (hidden_states,) if output_hidden_states else ()
333
+ return res
334
+
335
+ return BaseModelOutputWithNoAttention(
336
+ last_hidden_state=x,
337
+ hidden_states=hidden_states,
338
+ )
339
+
340
+
341
+ class AIMv2Model(AIMv2PretrainedModel):
342
+ config_class = AIMv2Config
343
+ main_input_name = ["input_ids", "pixel_values"]
344
+ _no_split_modules = ["AIMv2ViTPreprocessor", "AIMv2TextPreprocessor", "AIMv2Block"]
345
+
346
+ def __init__(self, config: AIMv2Config):
347
+ super().__init__(config)
348
+ self.image_encoder = AIMv2VisionModel(config.vision_config)
349
+ self.text_encoder = AIMv2TextModel(config.text_config)
350
+
351
+ self.image_projector = nn.Linear(
352
+ config.vision_config.hidden_size, config.projection_dim, bias=False
353
+ )
354
+ self.text_projector = nn.Linear(
355
+ config.text_config.hidden_size, config.projection_dim, bias=False
356
+ )
357
+
358
+ self.log_logit_scale = nn.Parameter(
359
+ torch.full([], fill_value=math.log(1.0 / config.init_temperature))
360
+ )
361
+ self.max_log_logit_scale = math.log(config.max_logit_scale)
362
+
363
+ def forward(
364
+ self,
365
+ input_ids: torch.Tensor,
366
+ pixel_values: torch.Tensor,
367
+ attention_mask: Optional[torch.Tensor] = None,
368
+ output_hidden_states: Optional[bool] = None,
369
+ return_dict: Optional[bool] = None,
370
+ ) -> Union[
371
+ Tuple[
372
+ torch.Tensor,
373
+ torch.Tensor,
374
+ torch.Tensor,
375
+ torch.Tensor,
376
+ Union[Tuple[torch.Tensor, ...], BaseModelOutputWithNoAttention],
377
+ Union[Tuple[torch.Tensor, ...], BaseModelOutputWithNoAttention],
378
+ ],
379
+ AIMv2Output,
380
+ ]:
381
+ if return_dict is None:
382
+ return_dict = self.config.use_return_dict
383
+
384
+ image_out = self.image_encoder(
385
+ pixel_values,
386
+ output_hidden_states=output_hidden_states,
387
+ return_dict=return_dict,
388
+ )
389
+ image_features = image_out.last_hidden_state if return_dict else image_out[0]
390
+ image_features = self.image_projector(image_features)
391
+ image_features = F.normalize(image_features, p=2, dim=-1)
392
+
393
+ text_out = self.text_encoder(
394
+ input_ids,
395
+ mask=attention_mask,
396
+ output_hidden_states=output_hidden_states,
397
+ return_dict=return_dict,
398
+ )
399
+ text_features = text_out.last_hidden_state if return_dict else text_out[0]
400
+ text_features = self.text_projector(text_features)
401
+ text_features = F.normalize(text_features, p=2, dim=-1)
402
+
403
+ logit_scale = self.log_logit_scale.clamp(0.0, self.max_log_logit_scale).exp()
404
+ logits_per_text = (logit_scale * text_features) @ image_features.t()
405
+ logits_per_image = logits_per_text.t()
406
+
407
+ if not return_dict:
408
+ return (
409
+ logits_per_image,
410
+ logits_per_text,
411
+ image_features,
412
+ text_features,
413
+ image_out,
414
+ text_out,
415
+ )
416
+
417
+ return AIMv2Output(
418
+ logits_per_image=logits_per_image,
419
+ logits_per_text=logits_per_text,
420
+ image_features=image_features,
421
+ text_features=text_features,
422
+ vision_output=image_out,
423
+ text_output=text_out,
424
+ )
425
+
426
+ def get_image_features(
427
+ self,
428
+ input_pixels: torch.Tensor,
429
+ attention_mask: Optional[torch.Tensor] = None,
430
+ ) -> torch.Tensor:
431
+ out = self.image_encoder(input_pixels, mask=attention_mask, return_dict=True)
432
+ image_features = self.image_projector(out.last_hidden_state)
433
+ return image_features
434
+
435
+ def get_text_features(
436
+ self,
437
+ input_ids: torch.Tensor,
438
+ attention_mask: Optional[torch.Tensor] = None,
439
+ ) -> torch.Tensor:
440
+ out = self.text_encoder(input_ids, mask=attention_mask, return_dict=True)
441
+ text_features = self.text_projector(out.last_hidden_state)
442
+ return text_features
preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.48145466,
13
+ 0.4578275,
14
+ 0.40821073
15
+ ],
16
+ "image_processor_type": "CLIPImageProcessor",
17
+ "image_std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ],
22
+ "processor_class": "CLIPProcessor",
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<start_of_text>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<end_of_text>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<end_of_text>",
17
+ "unk_token": {
18
+ "content": "<end_of_text>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "49406": {
4
+ "content": "<start_of_text>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "49407": {
12
+ "content": "<end_of_text>",
13
+ "lstrip": false,
14
+ "normalized": true,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ }
19
+ },
20
+ "bos_token": "<start_of_text>",
21
+ "clean_up_tokenization_spaces": false,
22
+ "eos_token": "<end_of_text>",
23
+ "errors": "replace",
24
+ "model_max_length": 77,
25
+ "pad_token": "<end_of_text>",
26
+ "processor_class": "CLIPProcessor",
27
+ "tokenizer_class": "CLIPTokenizer",
28
+ "unk_token": "<end_of_text>"
29
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff