GoodBaiBai88 committed
Commit 8c45550
Parent: 7a2bf1a

Upload M3DCLIP

Files changed (5)
  1. README.md +3 -3
  2. config.json +35 -35
  3. configuration_m3d_clip.py +42 -42
  4. model.safetensors +1 -1
  5. modeling_m3d_clip.py +225 -225
README.md CHANGED
@@ -1,11 +1,11 @@
 ---
 license: apache-2.0
-metrics:
-- accuracy
-pipeline_tag: image-feature-extraction
 tags:
 - 3D medical CLIP
 - Image-text retrieval
+metrics:
+- accuracy
+pipeline_tag: image-feature-extraction
 ---
 
 M3D-CLIP is one of the works in the [M3D](https://github.com/BAAI-DCAI/M3D) series.
config.json CHANGED
@@ -1,35 +1,35 @@
(whole-file re-upload: the removed and re-added contents are identical, so the file is shown once)
{
  "architectures": [
    "M3DCLIP"
  ],
  "auto_map": {
    "AutoConfig": "configuration_m3d_clip.M3DCLIPConfig",
    "AutoModel": "modeling_m3d_clip.M3DCLIP"
  },
  "dropout_rate": 0,
  "gather_loss": true,
  "hidden_size": 768,
  "img_size": [
    32,
    256,
    256
  ],
  "in_channels": 1,
  "language_model_name_or_path": "bert-base-uncased",
  "local_loss": false,
  "max_text_len": 128,
  "mlp_dim": 3072,
  "model_type": "m3d_clip",
  "num_heads": 12,
  "num_layers": 12,
  "patch_size": [
    4,
    16,
    16
  ],
  "pos_embed": "perceptron",
  "spatial_dims": 3,
  "torch_dtype": "float32",
  "transformers_version": "4.40.1",
  "vocab_size": 30522
}
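Because `auto_map` points `AutoConfig` and `AutoModel` at the bundled `configuration_m3d_clip.py` and `modeling_m3d_clip.py`, the checkpoint can be loaded with `trust_remote_code=True`. A minimal sketch, assuming the repository id is `GoodBaiBai88/M3D-CLIP` (the id is not stated in this commit):

```python
from transformers import AutoConfig, AutoModel

repo_id = "GoodBaiBai88/M3D-CLIP"  # assumed repository id; substitute the real one

# trust_remote_code lets transformers import the M3DCLIPConfig / M3DCLIP
# classes shipped next to config.json via the "auto_map" entries above.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(config.img_size)    # [32, 256, 256] -> (depth, height, width) of the input volume
print(config.patch_size)  # [4, 16, 16]
```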
configuration_m3d_clip.py CHANGED
@@ -1,42 +1,42 @@
(whole-file re-upload: the removed and re-added contents are identical, so the file is shown once)
from transformers import PretrainedConfig


class M3DCLIPConfig(PretrainedConfig):
    model_type = "m3d_clip"

    def __init__(
        self,
        language_model_name_or_path: str = 'bert-base-uncased',
        local_loss: bool = False,
        gather_loss: bool = True,
        in_channels: int = 1,
        img_size: tuple = (32, 256, 256),
        patch_size: tuple = (4, 16, 16),
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "perceptron",
        dropout_rate: float = 0,
        spatial_dims: int = 3,
        max_text_len: int = 128,
        vocab_size: int = 30522,
        **kwargs,
    ):
        self.language_model_name_or_path = language_model_name_or_path
        self.in_channels = in_channels
        self.img_size = img_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.mlp_dim = mlp_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.pos_embed = pos_embed
        self.dropout_rate = dropout_rate
        self.spatial_dims = spatial_dims
        self.local_loss = local_loss
        self.gather_loss = gather_loss
        self.max_text_len = max_text_len
        self.vocab_size = vocab_size
        super().__init__(**kwargs)
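The config class can also be built directly in code; a short sketch (every field not overridden keeps the default shown above):

```python
# Assumes configuration_m3d_clip.py is importable from the working directory.
from configuration_m3d_clip import M3DCLIPConfig

# Defaults mirror config.json; individual fields can be overridden per experiment.
config = M3DCLIPConfig(max_text_len=64)

print(config.model_type)    # "m3d_clip"
print(config.img_size)      # (32, 256, 256)
print(config.hidden_size)   # 768
print(config.max_text_len)  # 64
```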
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6973db8a3018c033117b5f3fe9ff8ca3c6842532993e2c5723e892fcd18a235e
+oid sha256:59abe6b2ad8b7cbab4d2026fcd2b7e0fe8f0d70d17ad5ede819b4100aa59b860
 size 792251956
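The weights blob is the only binary that changes in this commit: the Git LFS pointer records a new SHA-256 while the size stays the same. A minimal sketch for checking a downloaded `model.safetensors` against the new pointer, assuming the file is in the current directory:

```python
import hashlib

EXPECTED_OID = "59abe6b2ad8b7cbab4d2026fcd2b7e0fe8f0d70d17ad5ede819b4100aa59b860"
EXPECTED_SIZE = 792_251_956

# Hash in 1 MiB chunks so the ~790 MB file is never held in memory at once.
sha256 = hashlib.sha256()
size = 0
with open("model.safetensors", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha256.update(chunk)
        size += len(chunk)

assert size == EXPECTED_SIZE, f"unexpected size: {size}"
assert sha256.hexdigest() == EXPECTED_OID, "SHA-256 does not match the LFS pointer"
print("model.safetensors matches the LFS pointer")
```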
modeling_m3d_clip.py CHANGED
@@ -1,225 +1,225 @@
(whole-file re-upload: the removed and re-added contents are identical, so the file is shown once)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from collections.abc import Sequence
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
try:
    import torch.distributed.nn
    from torch import distributed as dist
    has_distributed = True
except ImportError:
    has_distributed = False
from .configuration_m3d_clip import M3DCLIPConfig
from transformers import BertModel, BertConfig


def gather_features(
        image_features,
        text_features,
        local_loss=False,
        gather_with_grad=True,
        rank=0,
        world_size=1,
):
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'

    # We gather tensors from all gpus
    if gather_with_grad:
        all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
        all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
    else:
        gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
        gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
        dist.all_gather(gathered_image_features, image_features)
        dist.all_gather(gathered_text_features, text_features)
        if not local_loss:
            # ensure grads for local rank when all_* features don't have a gradient
            gathered_image_features[rank] = image_features
            gathered_text_features[rank] = text_features
        all_image_features = torch.cat(gathered_image_features, dim=0)
        all_text_features = torch.cat(gathered_text_features, dim=0)

    return all_image_features, all_text_features


class ViT(nn.Module):
    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    ViT supports Torchscript but only works for PyTorch after 1.8.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Sequence[int] | int,
        patch_size: Sequence[int] | int,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "conv",
        classification: bool = False,
        num_classes: int = 2,
        dropout_rate: float = 0.0,
        spatial_dims: int = 3,
        post_activation="Tanh",
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            in_channels (int): dimension of input channels.
            img_size (Union[Sequence[int], int]): dimension of input image.
            patch_size (Union[Sequence[int], int]): dimension of patch size.
            hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
            mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
            num_layers (int, optional): number of transformer blocks. Defaults to 12.
            num_heads (int, optional): number of attention heads. Defaults to 12.
            pos_embed (str, optional): position embedding layer type. Defaults to "conv".
            classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
            num_classes (int, optional): number of classes if classification is used. Defaults to 2.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
            post_activation (str, optional): add a final activation function to the classification head
                when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
                Set to other values to remove this function.
            qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

        Examples::

            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
            >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')

            # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
            >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)

            # for 3-channel with image size of (224,224), 12 layers and classification backbone
            >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)

        """

        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")
        self.hidden_size = hidden_size
        self.classification = classification
        self.patch_embedding = PatchEmbeddingBlock(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            pos_embed=pos_embed,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
                for i in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size)
        if self.classification:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
            # if post_activation == "Tanh":
            #     self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
            # else:
            #     self.classification_head = nn.Linear(hidden_size, num_classes)  # type: ignore

    def forward(self, x):
        x = self.patch_embedding(x)
        if hasattr(self, "cls_token"):
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_token, x), dim=1)
        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        # if hasattr(self, "classification_head"):
        #     x = self.classification_head(x[:, 0])
        return x, hidden_states_out


class M3DCLIP(PreTrainedModel):
    config_class = M3DCLIPConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_encoder = ViT(
            in_channels=config.in_channels,
            img_size=config.img_size,
            patch_size=config.patch_size,
            hidden_size=config.hidden_size,
            mlp_dim=config.mlp_dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            pos_embed=config.pos_embed,
            dropout_rate=config.dropout_rate,
            spatial_dims=config.spatial_dims,
            classification=True,
        )
        # configuration = BertConfig()
        # self.language_encoder = BertModel(configuration)
        self.language_encoder = BertModel.from_pretrained(config.language_model_name_or_path)

        self.mm_vision_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.mm_language_proj = nn.Linear(config.hidden_size, config.hidden_size)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.local_loss = config.local_loss
        self.gather_loss = config.gather_loss

    def encode_image(self, image):
        image_feats, _ = self.vision_encoder(image)
        image_feats = self.mm_vision_proj(image_feats)
        image_feats = F.normalize(image_feats, dim=-1)

        return image_feats

    def encode_text(self, input_id, attention_mask):
        text_feats = self.language_encoder(input_id, attention_mask=attention_mask)["last_hidden_state"]
        text_feats = self.mm_language_proj(text_feats)
        text_feats = F.normalize(text_feats, dim=-1)

        return text_feats

    def forward(self, images, input_ids, attention_mask, labels, **kwargs):
        image_features = self.encode_image(images)[:, 0]
        text_features = self.encode_text(input_ids, attention_mask)[:, 0]

        if self.gather_loss:
            all_image_features, all_text_features = gather_features(image_features, text_features)
            if self.local_loss:
                logits_per_image = self.logit_scale * image_features @ all_text_features.T
                logits_per_text = self.logit_scale * text_features @ all_image_features.T
            else:
                logits_per_image = self.logit_scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T
        else:
            logits_per_image = self.logit_scale * image_features @ text_features.T
            logits_per_text = self.logit_scale * text_features @ image_features.T

        loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2

        ret = {
            "loss": loss,
            "logits": (logits_per_image + logits_per_text) / 2.0,
        }

        return ret
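For retrieval-style use, `forward` expects CLIP-style contrastive `labels` (typically `torch.arange(batch_size)`, offset by `rank * batch_size` when `local_loss` is combined with gathering), but `encode_image` and `encode_text` can be called directly: both return L2-normalized token features, with the CLS feature at index 0 serving as the global embedding. A minimal inference sketch, assuming the repository id `GoodBaiBai88/M3D-CLIP`, a bundled tokenizer (`bert-base-uncased` is the text backbone named in config.json), and a random tensor in place of a preprocessed volume:

```python
import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "GoodBaiBai88/M3D-CLIP"  # assumed repository id

# Tokenizer is assumed to be bundled with the repo; the text tower is
# initialized from "bert-base-uncased", which also works as a fallback.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

# Dummy volume standing in for a preprocessed CT: (batch, channel, D, H, W),
# matching img_size = (32, 256, 256) with in_channels = 1.
image = torch.rand(1, 1, 32, 256, 256)

text = "A computed tomography scan of the chest."
tokens = tokenizer(text, max_length=128, truncation=True,
                   padding="max_length", return_tensors="pt")

with torch.no_grad():
    image_feats = model.encode_image(image)[:, 0]                   # (1, 768) CLS feature
    text_feats = model.encode_text(tokens["input_ids"],
                                   tokens["attention_mask"])[:, 0]  # (1, 768) CLS feature

# Both features are already L2-normalized, so the dot product is cosine similarity.
similarity = image_feats @ text_feats.T
print(similarity.item())
```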