GoodBaiBai88
commited on
Commit
•
1a597d0
1
Parent(s):
83c9d54
Upload M3DCLIP
Browse files- config.json +35 -0
- configuration_m3d_clip.py +42 -0
- model.safetensors +3 -0
- modeling_m3d_clip.py +225 -0
config.json
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"M3DCLIP"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_m3d_clip.M3DCLIPConfig",
|
7 |
+
"AutoModel": "modeling_m3d_clip.M3DCLIP"
|
8 |
+
},
|
9 |
+
"dropout_rate": 0,
|
10 |
+
"gather_loss": true,
|
11 |
+
"hidden_size": 768,
|
12 |
+
"img_size": [
|
13 |
+
32,
|
14 |
+
256,
|
15 |
+
256
|
16 |
+
],
|
17 |
+
"in_channels": 1,
|
18 |
+
"language_model_name_or_path": "bert-base-uncased",
|
19 |
+
"local_loss": false,
|
20 |
+
"max_text_len": 128,
|
21 |
+
"mlp_dim": 3072,
|
22 |
+
"model_type": "m3d_clip",
|
23 |
+
"num_heads": 12,
|
24 |
+
"num_layers": 12,
|
25 |
+
"patch_size": [
|
26 |
+
4,
|
27 |
+
16,
|
28 |
+
16
|
29 |
+
],
|
30 |
+
"pos_embed": "perceptron",
|
31 |
+
"spatial_dims": 3,
|
32 |
+
"torch_dtype": "float32",
|
33 |
+
"transformers_version": "4.40.1",
|
34 |
+
"vocab_size": 30522
|
35 |
+
}
|
configuration_m3d_clip.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
class M3DCLIPConfig(PretrainedConfig):
|
6 |
+
model_type = "m3d_clip"
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
language_model_name_or_path: str = 'bert-base-uncased',
|
11 |
+
local_loss: bool = False,
|
12 |
+
gather_loss: bool = True,
|
13 |
+
in_channels: int = 1,
|
14 |
+
img_size: tuple = (32, 256, 256),
|
15 |
+
patch_size: tuple = (4, 16, 16),
|
16 |
+
hidden_size: int = 768,
|
17 |
+
mlp_dim: int = 3072,
|
18 |
+
num_layers: int = 12,
|
19 |
+
num_heads: int = 12,
|
20 |
+
pos_embed: str = "perceptron",
|
21 |
+
dropout_rate: float = 0,
|
22 |
+
spatial_dims: int = 3,
|
23 |
+
max_text_len: int = 128,
|
24 |
+
vocab_size: int = 30522,
|
25 |
+
**kwargs,
|
26 |
+
):
|
27 |
+
self.language_model_name_or_path = language_model_name_or_path
|
28 |
+
self.in_channels = in_channels
|
29 |
+
self.img_size = img_size
|
30 |
+
self.patch_size = patch_size
|
31 |
+
self.hidden_size = hidden_size
|
32 |
+
self.mlp_dim = mlp_dim
|
33 |
+
self.num_layers = num_layers
|
34 |
+
self.num_heads = num_heads
|
35 |
+
self.pos_embed = pos_embed
|
36 |
+
self.dropout_rate = dropout_rate
|
37 |
+
self.spatial_dims = spatial_dims
|
38 |
+
self.local_loss = local_loss
|
39 |
+
self.gather_loss = gather_loss
|
40 |
+
self.max_text_len = max_text_len
|
41 |
+
self.vocab_size = vocab_size
|
42 |
+
super().__init__(**kwargs)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c222448e87235ee2eaf6b903d6a9ab62e5d7e1300a5de823a60459eae2cfde32
|
3 |
+
size 792251956
|
modeling_m3d_clip.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from transformers import PreTrainedModel
|
6 |
+
from collections.abc import Sequence
|
7 |
+
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
|
8 |
+
from monai.networks.blocks.transformerblock import TransformerBlock
|
9 |
+
try:
|
10 |
+
import torch.distributed.nn
|
11 |
+
from torch import distributed as dist
|
12 |
+
has_distributed = True
|
13 |
+
except ImportError:
|
14 |
+
has_distributed = False
|
15 |
+
from .configuration_m3d_clip import M3DCLIPConfig
|
16 |
+
from transformers import BertModel, BertConfig
|
17 |
+
|
18 |
+
|
19 |
+
def gather_features(
|
20 |
+
image_features,
|
21 |
+
text_features,
|
22 |
+
local_loss=False,
|
23 |
+
gather_with_grad=True,
|
24 |
+
rank=0,
|
25 |
+
world_size=1,
|
26 |
+
):
|
27 |
+
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
28 |
+
|
29 |
+
# We gather tensors from all gpus
|
30 |
+
if gather_with_grad:
|
31 |
+
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
32 |
+
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
33 |
+
else:
|
34 |
+
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
35 |
+
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
36 |
+
dist.all_gather(gathered_image_features, image_features)
|
37 |
+
dist.all_gather(gathered_text_features, text_features)
|
38 |
+
if not local_loss:
|
39 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
40 |
+
gathered_image_features[rank] = image_features
|
41 |
+
gathered_text_features[rank] = text_features
|
42 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
43 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
44 |
+
|
45 |
+
return all_image_features, all_text_features
|
46 |
+
|
47 |
+
class ViT(nn.Module):
|
48 |
+
"""
|
49 |
+
Vision Transformer (ViT), based on: "Dosovitskiy et al.,
|
50 |
+
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
|
51 |
+
|
52 |
+
ViT supports Torchscript but only works for Pytorch after 1.8.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
in_channels: int,
|
58 |
+
img_size: Sequence[int] | int,
|
59 |
+
patch_size: Sequence[int] | int,
|
60 |
+
hidden_size: int = 768,
|
61 |
+
mlp_dim: int = 3072,
|
62 |
+
num_layers: int = 12,
|
63 |
+
num_heads: int = 12,
|
64 |
+
pos_embed: str = "conv",
|
65 |
+
classification: bool = False,
|
66 |
+
num_classes: int = 2,
|
67 |
+
dropout_rate: float = 0.0,
|
68 |
+
spatial_dims: int = 3,
|
69 |
+
post_activation="Tanh",
|
70 |
+
qkv_bias: bool = False,
|
71 |
+
save_attn: bool = False,
|
72 |
+
) -> None:
|
73 |
+
"""
|
74 |
+
Args:
|
75 |
+
in_channels (int): dimension of input channels.
|
76 |
+
img_size (Union[Sequence[int], int]): dimension of input image.
|
77 |
+
patch_size (Union[Sequence[int], int]): dimension of patch size.
|
78 |
+
hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
|
79 |
+
mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
|
80 |
+
num_layers (int, optional): number of transformer blocks. Defaults to 12.
|
81 |
+
num_heads (int, optional): number of attention heads. Defaults to 12.
|
82 |
+
pos_embed (str, optional): position embedding layer type. Defaults to "conv".
|
83 |
+
classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
|
84 |
+
num_classes (int, optional): number of classes if classification is used. Defaults to 2.
|
85 |
+
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
|
86 |
+
spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
|
87 |
+
post_activation (str, optional): add a final acivation function to the classification head
|
88 |
+
when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
|
89 |
+
Set to other values to remove this function.
|
90 |
+
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
|
91 |
+
save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.
|
92 |
+
|
93 |
+
Examples::
|
94 |
+
|
95 |
+
# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
|
96 |
+
>>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
|
97 |
+
|
98 |
+
# for 3-channel with image size of (128,128,128), 24 layers and classification backbone
|
99 |
+
>>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
|
100 |
+
|
101 |
+
# for 3-channel with image size of (224,224), 12 layers and classification backbone
|
102 |
+
>>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
|
103 |
+
|
104 |
+
"""
|
105 |
+
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
if not (0 <= dropout_rate <= 1):
|
109 |
+
raise ValueError("dropout_rate should be between 0 and 1.")
|
110 |
+
|
111 |
+
if hidden_size % num_heads != 0:
|
112 |
+
raise ValueError("hidden_size should be divisible by num_heads.")
|
113 |
+
self.hidden_size = hidden_size
|
114 |
+
self.classification = classification
|
115 |
+
self.patch_embedding = PatchEmbeddingBlock(
|
116 |
+
in_channels=in_channels,
|
117 |
+
img_size=img_size,
|
118 |
+
patch_size=patch_size,
|
119 |
+
hidden_size=hidden_size,
|
120 |
+
num_heads=num_heads,
|
121 |
+
pos_embed=pos_embed,
|
122 |
+
dropout_rate=dropout_rate,
|
123 |
+
spatial_dims=spatial_dims,
|
124 |
+
)
|
125 |
+
self.blocks = nn.ModuleList(
|
126 |
+
[
|
127 |
+
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
|
128 |
+
for i in range(num_layers)
|
129 |
+
]
|
130 |
+
)
|
131 |
+
self.norm = nn.LayerNorm(hidden_size)
|
132 |
+
if self.classification:
|
133 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
134 |
+
# if post_activation == "Tanh":
|
135 |
+
# self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
|
136 |
+
# else:
|
137 |
+
# self.classification_head = nn.Linear(hidden_size, num_classes) # type: ignore
|
138 |
+
|
139 |
+
def forward(self, x):
|
140 |
+
x = self.patch_embedding(x)
|
141 |
+
if hasattr(self, "cls_token"):
|
142 |
+
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
143 |
+
x = torch.cat((cls_token, x), dim=1)
|
144 |
+
hidden_states_out = []
|
145 |
+
for blk in self.blocks:
|
146 |
+
x = blk(x)
|
147 |
+
hidden_states_out.append(x)
|
148 |
+
x = self.norm(x)
|
149 |
+
# if hasattr(self, "classification_head"):
|
150 |
+
# x = self.classification_head(x[:, 0])
|
151 |
+
return x, hidden_states_out
|
152 |
+
|
153 |
+
|
154 |
+
class M3DCLIP(PreTrainedModel):
|
155 |
+
config_class = M3DCLIPConfig
|
156 |
+
|
157 |
+
def __init__(self, config):
|
158 |
+
super().__init__(config)
|
159 |
+
self.vision_encoder = ViT(
|
160 |
+
in_channels=config.in_channels,
|
161 |
+
img_size=config.img_size,
|
162 |
+
patch_size=config.patch_size,
|
163 |
+
hidden_size=config.hidden_size,
|
164 |
+
mlp_dim=config.mlp_dim,
|
165 |
+
num_layers=config.num_layers,
|
166 |
+
num_heads=config.num_heads,
|
167 |
+
pos_embed=config.pos_embed,
|
168 |
+
dropout_rate=config.dropout_rate,
|
169 |
+
spatial_dims=config.spatial_dims,
|
170 |
+
classification=True,
|
171 |
+
)
|
172 |
+
# configuration = BertConfig()
|
173 |
+
# self.language_encoder = BertModel(configuration)
|
174 |
+
self.language_encoder = BertModel.from_pretrained(config.language_model_name_or_path)
|
175 |
+
|
176 |
+
self.mm_vision_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
177 |
+
self.mm_language_proj = nn.Linear(config.hidden_size, config.hidden_size)
|
178 |
+
|
179 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
180 |
+
|
181 |
+
self.local_loss = config.local_loss
|
182 |
+
self.gather_loss = config.gather_loss
|
183 |
+
|
184 |
+
def encode_image(self, image):
|
185 |
+
image_feats, _ = self.vision_encoder(image)
|
186 |
+
image_feats = self.mm_vision_proj(image_feats)
|
187 |
+
image_feats = F.normalize(image_feats, dim=-1)
|
188 |
+
|
189 |
+
return image_feats
|
190 |
+
|
191 |
+
def encode_text(self, input_id, attention_mask):
|
192 |
+
text_feats = self.language_encoder(input_id, attention_mask=attention_mask)["last_hidden_state"]
|
193 |
+
text_feats = self.mm_language_proj(text_feats)
|
194 |
+
text_feats = F.normalize(text_feats, dim=-1)
|
195 |
+
|
196 |
+
return text_feats
|
197 |
+
|
198 |
+
|
199 |
+
def forward(self, images, input_ids, attention_mask, labels, **kwargs):
|
200 |
+
image_features = self.encode_image(images)[:, 0]
|
201 |
+
text_features = self.encode_text(input_ids, attention_mask)[:, 0]
|
202 |
+
|
203 |
+
if self.gather_loss:
|
204 |
+
all_image_features, all_text_features = gather_features(image_features, text_features)
|
205 |
+
if self.local_loss:
|
206 |
+
logits_per_image = self.logit_scale * image_features @ all_text_features.T
|
207 |
+
logits_per_text = self.logit_scale * text_features @ all_image_features.T
|
208 |
+
else:
|
209 |
+
logits_per_image = self.logit_scale * all_image_features @ all_text_features.T
|
210 |
+
logits_per_text = logits_per_image.T
|
211 |
+
else:
|
212 |
+
logits_per_image = self.logit_scale * image_features @ text_features.T
|
213 |
+
logits_per_text = self.logit_scale * text_features @ image_features.T
|
214 |
+
|
215 |
+
loss = (
|
216 |
+
F.cross_entropy(logits_per_image, labels) +
|
217 |
+
F.cross_entropy(logits_per_text, labels)
|
218 |
+
) / 2
|
219 |
+
|
220 |
+
ret = {
|
221 |
+
"loss": loss,
|
222 |
+
"logits": (logits_per_image + logits_per_text) / 2.0,
|
223 |
+
}
|
224 |
+
|
225 |
+
return ret
|