GoodBaiBai88
commited on
Commit
•
8c45550
1
Parent(s):
7a2bf1a
Upload M3DCLIP
Browse files- README.md +3 -3
- config.json +35 -35
- configuration_m3d_clip.py +42 -42
- model.safetensors +1 -1
- modeling_m3d_clip.py +225 -225
README.md
CHANGED
@@ -1,11 +1,11 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
-
metrics:
|
4 |
-
- accuracy
|
5 |
-
pipeline_tag: image-feature-extraction
|
6 |
tags:
|
7 |
- 3D medical CLIP
|
8 |
- Image-text retrieval
|
|
|
|
|
|
|
9 |
---
|
10 |
|
11 |
M3D-CLIP is one of the works in the [M3D](https://github.com/BAAI-DCAI/M3D) series.
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
|
|
|
|
|
|
3 |
tags:
|
4 |
- 3D medical CLIP
|
5 |
- Image-text retrieval
|
6 |
+
metrics:
|
7 |
+
- accuracy
|
8 |
+
pipeline_tag: image-feature-extraction
|
9 |
---
|
10 |
|
11 |
M3D-CLIP is one of the works in the [M3D](https://github.com/BAAI-DCAI/M3D) series.
|
config.json
CHANGED
@@ -1,35 +1,35 @@
|
|
1 |
-
{
|
2 |
-
"architectures": [
|
3 |
-
"M3DCLIP"
|
4 |
-
],
|
5 |
-
"auto_map": {
|
6 |
-
"AutoConfig": "configuration_m3d_clip.M3DCLIPConfig",
|
7 |
-
"AutoModel": "modeling_m3d_clip.M3DCLIP"
|
8 |
-
},
|
9 |
-
"dropout_rate": 0,
|
10 |
-
"gather_loss": true,
|
11 |
-
"hidden_size": 768,
|
12 |
-
"img_size": [
|
13 |
-
32,
|
14 |
-
256,
|
15 |
-
256
|
16 |
-
],
|
17 |
-
"in_channels": 1,
|
18 |
-
"language_model_name_or_path": "bert-base-uncased",
|
19 |
-
"local_loss": false,
|
20 |
-
"max_text_len": 128,
|
21 |
-
"mlp_dim": 3072,
|
22 |
-
"model_type": "m3d_clip",
|
23 |
-
"num_heads": 12,
|
24 |
-
"num_layers": 12,
|
25 |
-
"patch_size": [
|
26 |
-
4,
|
27 |
-
16,
|
28 |
-
16
|
29 |
-
],
|
30 |
-
"pos_embed": "perceptron",
|
31 |
-
"spatial_dims": 3,
|
32 |
-
"torch_dtype": "float32",
|
33 |
-
"transformers_version": "4.40.1",
|
34 |
-
"vocab_size": 30522
|
35 |
-
}
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"M3DCLIP"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_m3d_clip.M3DCLIPConfig",
|
7 |
+
"AutoModel": "modeling_m3d_clip.M3DCLIP"
|
8 |
+
},
|
9 |
+
"dropout_rate": 0,
|
10 |
+
"gather_loss": true,
|
11 |
+
"hidden_size": 768,
|
12 |
+
"img_size": [
|
13 |
+
32,
|
14 |
+
256,
|
15 |
+
256
|
16 |
+
],
|
17 |
+
"in_channels": 1,
|
18 |
+
"language_model_name_or_path": "bert-base-uncased",
|
19 |
+
"local_loss": false,
|
20 |
+
"max_text_len": 128,
|
21 |
+
"mlp_dim": 3072,
|
22 |
+
"model_type": "m3d_clip",
|
23 |
+
"num_heads": 12,
|
24 |
+
"num_layers": 12,
|
25 |
+
"patch_size": [
|
26 |
+
4,
|
27 |
+
16,
|
28 |
+
16
|
29 |
+
],
|
30 |
+
"pos_embed": "perceptron",
|
31 |
+
"spatial_dims": 3,
|
32 |
+
"torch_dtype": "float32",
|
33 |
+
"transformers_version": "4.40.1",
|
34 |
+
"vocab_size": 30522
|
35 |
+
}
|
configuration_m3d_clip.py
CHANGED
@@ -1,42 +1,42 @@
|
|
1 |
-
from transformers import PretrainedConfig
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
class M3DCLIPConfig(PretrainedConfig):
|
6 |
-
model_type = "m3d_clip"
|
7 |
-
|
8 |
-
def __init__(
|
9 |
-
self,
|
10 |
-
language_model_name_or_path: str = 'bert-base-uncased',
|
11 |
-
local_loss: bool = False,
|
12 |
-
gather_loss: bool = True,
|
13 |
-
in_channels: int = 1,
|
14 |
-
img_size: tuple = (32, 256, 256),
|
15 |
-
patch_size: tuple = (4, 16, 16),
|
16 |
-
hidden_size: int = 768,
|
17 |
-
mlp_dim: int = 3072,
|
18 |
-
num_layers: int = 12,
|
19 |
-
num_heads: int = 12,
|
20 |
-
pos_embed: str = "perceptron",
|
21 |
-
dropout_rate: float = 0,
|
22 |
-
spatial_dims: int = 3,
|
23 |
-
max_text_len: int = 128,
|
24 |
-
vocab_size: int = 30522,
|
25 |
-
**kwargs,
|
26 |
-
):
|
27 |
-
self.language_model_name_or_path = language_model_name_or_path
|
28 |
-
self.in_channels = in_channels
|
29 |
-
self.img_size = img_size
|
30 |
-
self.patch_size = patch_size
|
31 |
-
self.hidden_size = hidden_size
|
32 |
-
self.mlp_dim = mlp_dim
|
33 |
-
self.num_layers = num_layers
|
34 |
-
self.num_heads = num_heads
|
35 |
-
self.pos_embed = pos_embed
|
36 |
-
self.dropout_rate = dropout_rate
|
37 |
-
self.spatial_dims = spatial_dims
|
38 |
-
self.local_loss = local_loss
|
39 |
-
self.gather_loss = gather_loss
|
40 |
-
self.max_text_len = max_text_len
|
41 |
-
self.vocab_size = vocab_size
|
42 |
-
super().__init__(**kwargs)
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
class M3DCLIPConfig(PretrainedConfig):
|
6 |
+
model_type = "m3d_clip"
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
language_model_name_or_path: str = 'bert-base-uncased',
|
11 |
+
local_loss: bool = False,
|
12 |
+
gather_loss: bool = True,
|
13 |
+
in_channels: int = 1,
|
14 |
+
img_size: tuple = (32, 256, 256),
|
15 |
+
patch_size: tuple = (4, 16, 16),
|
16 |
+
hidden_size: int = 768,
|
17 |
+
mlp_dim: int = 3072,
|
18 |
+
num_layers: int = 12,
|
19 |
+
num_heads: int = 12,
|
20 |
+
pos_embed: str = "perceptron",
|
21 |
+
dropout_rate: float = 0,
|
22 |
+
spatial_dims: int = 3,
|
23 |
+
max_text_len: int = 128,
|
24 |
+
vocab_size: int = 30522,
|
25 |
+
**kwargs,
|
26 |
+
):
|
27 |
+
self.language_model_name_or_path = language_model_name_or_path
|
28 |
+
self.in_channels = in_channels
|
29 |
+
self.img_size = img_size
|
30 |
+
self.patch_size = patch_size
|
31 |
+
self.hidden_size = hidden_size
|
32 |
+
self.mlp_dim = mlp_dim
|
33 |
+
self.num_layers = num_layers
|
34 |
+
self.num_heads = num_heads
|
35 |
+
self.pos_embed = pos_embed
|
36 |
+
self.dropout_rate = dropout_rate
|
37 |
+
self.spatial_dims = spatial_dims
|
38 |
+
self.local_loss = local_loss
|
39 |
+
self.gather_loss = gather_loss
|
40 |
+
self.max_text_len = max_text_len
|
41 |
+
self.vocab_size = vocab_size
|
42 |
+
super().__init__(**kwargs)
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 792251956
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59abe6b2ad8b7cbab4d2026fcd2b7e0fe8f0d70d17ad5ede819b4100aa59b860
|
3 |
size 792251956
|
modeling_m3d_clip.py
CHANGED
@@ -1,225 +1,225 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.nn.functional as F
|
5 |
-
from transformers import PreTrainedModel
|
6 |
-
from collections.abc import Sequence
|
7 |
-
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
|
8 |
-
from monai.networks.blocks.transformerblock import TransformerBlock
|
9 |
-
try:
|
10 |
-
import torch.distributed.nn
|
11 |
-
from torch import distributed as dist
|
12 |
-
has_distributed = True
|
13 |
-
except ImportError:
|
14 |
-
has_distributed = False
|
15 |
-
from .configuration_m3d_clip import M3DCLIPConfig
|
16 |
-
from transformers import BertModel, BertConfig
|
17 |
-
|
18 |
-
|
19 |
-
def gather_features(
|
20 |
-
image_features,
|
21 |
-
text_features,
|
22 |
-
local_loss=False,
|
23 |
-
gather_with_grad=True,
|
24 |
-
rank=0,
|
25 |
-
world_size=1,
|
26 |
-
):
|
27 |
-
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
28 |
-
|
29 |
-
# We gather tensors from all gpus
|
30 |
-
if gather_with_grad:
|
31 |
-
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
32 |
-
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
33 |
-
else:
|
34 |
-
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
35 |
-
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
36 |
-
dist.all_gather(gathered_image_features, image_features)
|
37 |
-
dist.all_gather(gathered_text_features, text_features)
|
38 |
-
if not local_loss:
|
39 |
-
# ensure grads for local rank when all_* features don't have a gradient
|
40 |
-
gathered_image_features[rank] = image_features
|
41 |
-
gathered_text_features[rank] = text_features
|
42 |
-
all_image_features = torch.cat(gathered_image_features, dim=0)
|
43 |
-
all_text_features = torch.cat(gathered_text_features, dim=0)
|
44 |
-
|
45 |
-
return all_image_features, all_text_features
|
46 |
-
|
47 |
-
class ViT(nn.Module):
|
48 |
-
"""
|
49 |
-
Vision Transformer (ViT), based on: "Dosovitskiy et al.,
|
50 |
-
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
|
51 |
-
|
52 |
-
ViT supports Torchscript but only works for Pytorch after 1.8.
|
53 |
-
"""
|
54 |
-
|
55 |
-
def __init__(
|
56 |
-
self,
|
57 |
-
in_channels: int,
|
58 |
-
img_size: Sequence[int] | int,
|
59 |
-
patch_size: Sequence[int] | int,
|
60 |
-
hidden_size: int = 768,
|
61 |
-
mlp_dim: int = 3072,
|
62 |
-
num_layers: int = 12,
|
63 |
-
num_heads: int = 12,
|
64 |
-
pos_embed: str = "conv",
|
65 |
-
classification: bool = False,
|
66 |
-
num_classes: int = 2,
|
67 |
-
dropout_rate: float = 0.0,
|
68 |
-
spatial_dims: int = 3,
|
69 |
-
post_activation="Tanh",
|
70 |
-
qkv_bias: bool = False,
|
71 |
-
save_attn: bool = False,
|
72 |
-
) -> None:
|
73 |
-
"""
|
74 |
-
Args:
|
75 |
-
in_channels (int): dimension of input channels.
|
76 |
-
img_size (Union[Sequence[int], int]): dimension of input image.
|
77 |
-
patch_size (Union[Sequence[int], int]): dimension of patch size.
|
78 |
-
hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
|
79 |
-
mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
|
80 |
-
num_layers (int, optional): number of transformer blocks. Defaults to 12.
|
81 |
-
num_heads (int, optional): number of attention heads. Defaults to 12.
|
82 |
-
pos_embed (str, optional): position embedding layer type. Defaults to "conv".
|
83 |
-
classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
|
84 |
-
num_classes (int, optional): number of classes if classification is used. Defaults to 2.
|
85 |
-
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
|
86 |
-
spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
|
87 |
-
post_activation (str, optional): add a final acivation function to the classification head
|
88 |
-
when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
|
89 |
-
Set to other values to remove this function.
|
90 |
-
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
|
91 |
-
save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.
|
92 |
-
|
93 |
-
Examples::
|
94 |
-
|
95 |
-
# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
|
96 |
-
>>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
|
97 |
-
|
98 |
-
# for 3-channel with image size of (128,128,128), 24 layers and classification backbone
|
99 |
-
>>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
|
100 |
-
|
101 |
-
# for 3-channel with image size of (224,224), 12 layers and classification backbone
|
102 |
-
>>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
|
103 |
-
|
104 |
-
"""
|
105 |
-
|
106 |
-
super().__init__()
|
107 |
-
|
108 |
-
if not (0 <= dropout_rate <= 1):
|
109 |
-
raise ValueError("dropout_rate should be between 0 and 1.")
|
110 |
-
|
111 |
-
if hidden_size % num_heads != 0:
|
112 |
-
raise ValueError("hidden_size should be divisible by num_heads.")
|
113 |
-
self.hidden_size = hidden_size
|
114 |
-
self.classification = classification
|
115 |
-
self.patch_embedding = PatchEmbeddingBlock(
|
116 |
-
in_channels=in_channels,
|
117 |
-
img_size=img_size,
|
118 |
-
patch_size=patch_size,
|
119 |
-
hidden_size=hidden_size,
|
120 |
-
num_heads=num_heads,
|
121 |
-
pos_embed=pos_embed,
|
122 |
-
dropout_rate=dropout_rate,
|
123 |
-
spatial_dims=spatial_dims,
|
124 |
-
)
|
125 |
-
self.blocks = nn.ModuleList(
|
126 |
-
[
|
127 |
-
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
|
128 |
-
for i in range(num_layers)
|
129 |
-
]
|
130 |
-
)
|
131 |
-
self.norm = nn.LayerNorm(hidden_size)
|
132 |
-
if self.classification:
|
133 |
-
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
134 |
-
# if post_activation == "Tanh":
|
135 |
-
# self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
|
136 |
-
# else:
|
137 |
-
# self.classification_head = nn.Linear(hidden_size, num_classes) # type: ignore
|
138 |
-
|
139 |
-
def forward(self, x):
|
140 |
-
x = self.patch_embedding(x)
|
141 |
-
if hasattr(self, "cls_token"):
|
142 |
-
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
143 |
-
x = torch.cat((cls_token, x), dim=1)
|
144 |
-
hidden_states_out = []
|
145 |
-
for blk in self.blocks:
|
146 |
-
x = blk(x)
|
147 |
-
hidden_states_out.append(x)
|
148 |
-
x = self.norm(x)
|
149 |
-
# if hasattr(self, "classification_head"):
|
150 |
-
# x = self.classification_head(x[:, 0])
|
151 |
-
return x, hidden_states_out
|
152 |
-
|
153 |
-
|
154 |
-
class M3DCLIP(PreTrainedModel):
|
155 |
-
config_class = M3DCLIPConfig
|
156 |
-
|
157 |
-
def __init__(self, config):
|
158 |
-
super().__init__(config)
|
159 |
-
self.vision_encoder = ViT(
|
160 |
-
in_channels=config.in_channels,
|
161 |
-
img_size=config.img_size,
|
162 |
-
patch_size=config.patch_size,
|
163 |
-
hidden_size=config.hidden_size,
|
164 |
-
mlp_dim=config.mlp_dim,
|
165 |
-
num_layers=config.num_layers,
|
166 |
-
num_heads=config.num_heads,
|
167 |
-
pos_embed=config.pos_embed,
|
168 |
-
dropout_rate=config.dropout_rate,
|
169 |
-
spatial_dims=config.spatial_dims,
|
170 |
-
classification=True,
|
171 |
-
)
|
172 |
-
# configuration = BertConfig()
|
173 |
-
# self.language_encoder = BertModel(configuration)
|
174 |
-
self.language_encoder = BertModel.from_pretrained(config.language_model_name_or_path)
|
175 |
-
|
176 |
-
self.mm_vision_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
177 |
-
self.mm_language_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
178 |
-
|
179 |
-
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
180 |
-
|
181 |
-
self.local_loss = config.local_loss
|
182 |
-
self.gather_loss = config.gather_loss
|
183 |
-
|
184 |
-
def encode_image(self, image):
|
185 |
-
image_feats, _ = self.vision_encoder(image)
|
186 |
-
image_feats = self.mm_vision_proj(image_feats)
|
187 |
-
image_feats = F.normalize(image_feats, dim=-1)
|
188 |
-
|
189 |
-
return image_feats
|
190 |
-
|
191 |
-
def encode_text(self, input_id, attention_mask):
|
192 |
-
text_feats = self.language_encoder(input_id, attention_mask=attention_mask)["last_hidden_state"]
|
193 |
-
text_feats = self.mm_language_proj(text_feats)
|
194 |
-
text_feats = F.normalize(text_feats, dim=-1)
|
195 |
-
|
196 |
-
return text_feats
|
197 |
-
|
198 |
-
|
199 |
-
def forward(self, images, input_ids, attention_mask, labels, **kwargs):
|
200 |
-
image_features = self.encode_image(images)[:, 0]
|
201 |
-
text_features = self.encode_text(input_ids, attention_mask)[:, 0]
|
202 |
-
|
203 |
-
if self.gather_loss:
|
204 |
-
all_image_features, all_text_features = gather_features(image_features, text_features)
|
205 |
-
if self.local_loss:
|
206 |
-
logits_per_image = self.logit_scale * image_features @ all_text_features.T
|
207 |
-
logits_per_text = self.logit_scale * text_features @ all_image_features.T
|
208 |
-
else:
|
209 |
-
logits_per_image = self.logit_scale * all_image_features @ all_text_features.T
|
210 |
-
logits_per_text = logits_per_image.T
|
211 |
-
else:
|
212 |
-
logits_per_image = self.logit_scale * image_features @ text_features.T
|
213 |
-
logits_per_text = self.logit_scale * text_features @ image_features.T
|
214 |
-
|
215 |
-
loss = (
|
216 |
-
F.cross_entropy(logits_per_image, labels) +
|
217 |
-
F.cross_entropy(logits_per_text, labels)
|
218 |
-
) / 2
|
219 |
-
|
220 |
-
ret = {
|
221 |
-
"loss": loss,
|
222 |
-
"logits": (logits_per_image + logits_per_text) / 2.0,
|
223 |
-
}
|
224 |
-
|
225 |
-
return ret
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from transformers import PreTrainedModel
|
6 |
+
from collections.abc import Sequence
|
7 |
+
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
|
8 |
+
from monai.networks.blocks.transformerblock import TransformerBlock
|
9 |
+
try:
|
10 |
+
import torch.distributed.nn
|
11 |
+
from torch import distributed as dist
|
12 |
+
has_distributed = True
|
13 |
+
except ImportError:
|
14 |
+
has_distributed = False
|
15 |
+
from .configuration_m3d_clip import M3DCLIPConfig
|
16 |
+
from transformers import BertModel, BertConfig
|
17 |
+
|
18 |
+
|
19 |
+
def gather_features(
|
20 |
+
image_features,
|
21 |
+
text_features,
|
22 |
+
local_loss=False,
|
23 |
+
gather_with_grad=True,
|
24 |
+
rank=0,
|
25 |
+
world_size=1,
|
26 |
+
):
|
27 |
+
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
28 |
+
|
29 |
+
# We gather tensors from all gpus
|
30 |
+
if gather_with_grad:
|
31 |
+
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
32 |
+
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
33 |
+
else:
|
34 |
+
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
35 |
+
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
36 |
+
dist.all_gather(gathered_image_features, image_features)
|
37 |
+
dist.all_gather(gathered_text_features, text_features)
|
38 |
+
if not local_loss:
|
39 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
40 |
+
gathered_image_features[rank] = image_features
|
41 |
+
gathered_text_features[rank] = text_features
|
42 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
43 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
44 |
+
|
45 |
+
return all_image_features, all_text_features
|
46 |
+
|
47 |
+
class ViT(nn.Module):
|
48 |
+
"""
|
49 |
+
Vision Transformer (ViT), based on: "Dosovitskiy et al.,
|
50 |
+
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
|
51 |
+
|
52 |
+
ViT supports Torchscript but only works for Pytorch after 1.8.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
in_channels: int,
|
58 |
+
img_size: Sequence[int] | int,
|
59 |
+
patch_size: Sequence[int] | int,
|
60 |
+
hidden_size: int = 768,
|
61 |
+
mlp_dim: int = 3072,
|
62 |
+
num_layers: int = 12,
|
63 |
+
num_heads: int = 12,
|
64 |
+
pos_embed: str = "conv",
|
65 |
+
classification: bool = False,
|
66 |
+
num_classes: int = 2,
|
67 |
+
dropout_rate: float = 0.0,
|
68 |
+
spatial_dims: int = 3,
|
69 |
+
post_activation="Tanh",
|
70 |
+
qkv_bias: bool = False,
|
71 |
+
save_attn: bool = False,
|
72 |
+
) -> None:
|
73 |
+
"""
|
74 |
+
Args:
|
75 |
+
in_channels (int): dimension of input channels.
|
76 |
+
img_size (Union[Sequence[int], int]): dimension of input image.
|
77 |
+
patch_size (Union[Sequence[int], int]): dimension of patch size.
|
78 |
+
hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
|
79 |
+
mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
|
80 |
+
num_layers (int, optional): number of transformer blocks. Defaults to 12.
|
81 |
+
num_heads (int, optional): number of attention heads. Defaults to 12.
|
82 |
+
pos_embed (str, optional): position embedding layer type. Defaults to "conv".
|
83 |
+
classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
|
84 |
+
num_classes (int, optional): number of classes if classification is used. Defaults to 2.
|
85 |
+
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
|
86 |
+
spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
|
87 |
+
post_activation (str, optional): add a final acivation function to the classification head
|
88 |
+
when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
|
89 |
+
Set to other values to remove this function.
|
90 |
+
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
|
91 |
+
save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.
|
92 |
+
|
93 |
+
Examples::
|
94 |
+
|
95 |
+
# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
|
96 |
+
>>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
|
97 |
+
|
98 |
+
# for 3-channel with image size of (128,128,128), 24 layers and classification backbone
|
99 |
+
>>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
|
100 |
+
|
101 |
+
# for 3-channel with image size of (224,224), 12 layers and classification backbone
|
102 |
+
>>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
|
103 |
+
|
104 |
+
"""
|
105 |
+
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
if not (0 <= dropout_rate <= 1):
|
109 |
+
raise ValueError("dropout_rate should be between 0 and 1.")
|
110 |
+
|
111 |
+
if hidden_size % num_heads != 0:
|
112 |
+
raise ValueError("hidden_size should be divisible by num_heads.")
|
113 |
+
self.hidden_size = hidden_size
|
114 |
+
self.classification = classification
|
115 |
+
self.patch_embedding = PatchEmbeddingBlock(
|
116 |
+
in_channels=in_channels,
|
117 |
+
img_size=img_size,
|
118 |
+
patch_size=patch_size,
|
119 |
+
hidden_size=hidden_size,
|
120 |
+
num_heads=num_heads,
|
121 |
+
pos_embed=pos_embed,
|
122 |
+
dropout_rate=dropout_rate,
|
123 |
+
spatial_dims=spatial_dims,
|
124 |
+
)
|
125 |
+
self.blocks = nn.ModuleList(
|
126 |
+
[
|
127 |
+
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
|
128 |
+
for i in range(num_layers)
|
129 |
+
]
|
130 |
+
)
|
131 |
+
self.norm = nn.LayerNorm(hidden_size)
|
132 |
+
if self.classification:
|
133 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
134 |
+
# if post_activation == "Tanh":
|
135 |
+
# self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
|
136 |
+
# else:
|
137 |
+
# self.classification_head = nn.Linear(hidden_size, num_classes) # type: ignore
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
x = self.patch_embedding(x)
|
141 |
+
if hasattr(self, "cls_token"):
|
142 |
+
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
143 |
+
x = torch.cat((cls_token, x), dim=1)
|
144 |
+
hidden_states_out = []
|
145 |
+
for blk in self.blocks:
|
146 |
+
x = blk(x)
|
147 |
+
hidden_states_out.append(x)
|
148 |
+
x = self.norm(x)
|
149 |
+
# if hasattr(self, "classification_head"):
|
150 |
+
# x = self.classification_head(x[:, 0])
|
151 |
+
return x, hidden_states_out
|
152 |
+
|
153 |
+
|
154 |
+
class M3DCLIP(PreTrainedModel):
|
155 |
+
config_class = M3DCLIPConfig
|
156 |
+
|
157 |
+
def __init__(self, config):
|
158 |
+
super().__init__(config)
|
159 |
+
self.vision_encoder = ViT(
|
160 |
+
in_channels=config.in_channels,
|
161 |
+
img_size=config.img_size,
|
162 |
+
patch_size=config.patch_size,
|
163 |
+
hidden_size=config.hidden_size,
|
164 |
+
mlp_dim=config.mlp_dim,
|
165 |
+
num_layers=config.num_layers,
|
166 |
+
num_heads=config.num_heads,
|
167 |
+
pos_embed=config.pos_embed,
|
168 |
+
dropout_rate=config.dropout_rate,
|
169 |
+
spatial_dims=config.spatial_dims,
|
170 |
+
classification=True,
|
171 |
+
)
|
172 |
+
# configuration = BertConfig()
|
173 |
+
# self.language_encoder = BertModel(configuration)
|
174 |
+
self.language_encoder = BertModel.from_pretrained(config.language_model_name_or_path)
|
175 |
+
|
176 |
+
self.mm_vision_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
177 |
+
self.mm_language_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
178 |
+
|
179 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
180 |
+
|
181 |
+
self.local_loss = config.local_loss
|
182 |
+
self.gather_loss = config.gather_loss
|
183 |
+
|
184 |
+
def encode_image(self, image):
|
185 |
+
image_feats, _ = self.vision_encoder(image)
|
186 |
+
image_feats = self.mm_vision_proj(image_feats)
|
187 |
+
image_feats = F.normalize(image_feats, dim=-1)
|
188 |
+
|
189 |
+
return image_feats
|
190 |
+
|
191 |
+
def encode_text(self, input_id, attention_mask):
|
192 |
+
text_feats = self.language_encoder(input_id, attention_mask=attention_mask)["last_hidden_state"]
|
193 |
+
text_feats = self.mm_language_proj(text_feats)
|
194 |
+
text_feats = F.normalize(text_feats, dim=-1)
|
195 |
+
|
196 |
+
return text_feats
|
197 |
+
|
198 |
+
|
199 |
+
def forward(self, images, input_ids, attention_mask, labels, **kwargs):
|
200 |
+
image_features = self.encode_image(images)[:, 0]
|
201 |
+
text_features = self.encode_text(input_ids, attention_mask)[:, 0]
|
202 |
+
|
203 |
+
if self.gather_loss:
|
204 |
+
all_image_features, all_text_features = gather_features(image_features, text_features)
|
205 |
+
if self.local_loss:
|
206 |
+
logits_per_image = self.logit_scale * image_features @ all_text_features.T
|
207 |
+
logits_per_text = self.logit_scale * text_features @ all_image_features.T
|
208 |
+
else:
|
209 |
+
logits_per_image = self.logit_scale * all_image_features @ all_text_features.T
|
210 |
+
logits_per_text = logits_per_image.T
|
211 |
+
else:
|
212 |
+
logits_per_image = self.logit_scale * image_features @ text_features.T
|
213 |
+
logits_per_text = self.logit_scale * text_features @ image_features.T
|
214 |
+
|
215 |
+
loss = (
|
216 |
+
F.cross_entropy(logits_per_image, labels) +
|
217 |
+
F.cross_entropy(logits_per_text, labels)
|
218 |
+
) / 2
|
219 |
+
|
220 |
+
ret = {
|
221 |
+
"loss": loss,
|
222 |
+
"logits": (logits_per_image + logits_per_text) / 2.0,
|
223 |
+
}
|
224 |
+
|
225 |
+
return ret
|