Upload ClipMDModel.py
ClipMDModel.py
ADDED
@@ -0,0 +1,138 @@
from transformers import CLIPModel
import torch
from typing import Optional, Tuple


def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    # Cross-entropy against the diagonal: the i-th text matches the i-th image.
    return torch.nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def clip_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
    # Symmetric CLIP loss: the mean of the caption-side and image-side terms.
    caption_loss = contrastive_loss(logits_per_text)
    image_loss = contrastive_loss(logits_per_text.T)
    return (caption_loss + image_loss) / 2.0

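# Illustration with made-up numbers: for a batch of two matched image-text
# pairs, logits_per_text has shape (2, 2) and the implicit targets are the
# diagonal indices [0, 1]:
#
#   logits = torch.tensor([[5.0, 1.0],
#                          [0.5, 4.0]])
#   clip_loss(logits)  # cross-entropy toward the diagonal, averaged over
#                      # the text-to-image and image-to-text directions
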

class ClipMDModel(CLIPModel):

    def embed_text(self,
                   input_ids: torch.LongTensor,
                   attention_mask: torch.LongTensor,
                   output_attentions: Optional[bool] = None,
                   output_hidden_states: Optional[bool] = None,
                   position_ids: Optional[torch.LongTensor] = None,
                   ):
        """
        :param input_ids: tokenized text from CLIPProcessor.
        :param attention_mask: attention mask from CLIPProcessor.
        :return: text embeddings of input_ids (sequences longer than 77 tokens
            are embedded using a sliding window and mean pooling).
        """
        tokens = []
        masks = []
        pos = []  # [start, end) ranges of each text's chunks within the batch

        for i in range(input_ids.size(0)):
            ten = input_ids[i]
            mask = attention_mask[i]
            # Drop padding: keep only the positions where the attention mask is set.
            mask = mask[mask.nonzero().flatten()]
            ten = ten[:mask.size(0)]

            if not pos:
                pos.append([0, 0])
            else:
                pos.append([pos[-1][1], pos[-1][1]])
            # Splitting the tokenized text into model-sized chunks (77 tokens)
            # with an overlapping window (stride 70, i.e. 7 shared tokens).
            if ten.size(0) > 77:
                tokens.append(ten.unfold(dimension=0, size=77, step=70))
                masks.append(mask.unfold(dimension=0, size=77, step=70))

                pos[-1][1] += tokens[-1].size(0)

                ten = ten[tokens[-1].size(0) * 70:]
                mask = mask[tokens[-1].size(0) * 70:]

            if ten.size(0) > 0:
                # Pad the remaining tokens to a full 77-token chunk. The new mask
                # reuses the attention mask's dtype so that torch.cat below does
                # not fail on mixed dtypes.
                new_mask = torch.zeros((1, 77), dtype=mask.dtype, device=self.device)
                new_mask[:, 0:mask.size(0)] = mask

                # 49407 is CLIP's end-of-text token id, used here as padding.
                new_ten = torch.full((1, 77), 49407, dtype=ten.dtype, device=self.device)
                new_ten[:, 0:ten.size(0)] = ten

                tokens.append(new_ten)
                masks.append(new_mask)
                pos[-1][1] += 1
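            # Example walk-through (hypothetical 150-token caption): unfold yields
            # windows [0:77] and [70:147] (7 tokens of overlap), the remainder
            # ten[140:150] is padded into one more 77-token chunk, so this caption
            # contributes 3 rows to the batch and pos[-1] == [start, start + 3].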

        # Encoding all tokenized chunks in a single batch.
        embedded = self.get_text_features(input_ids=torch.cat(tokens, 0),
                                          attention_mask=torch.cat(masks, 0),
                                          output_attentions=output_attentions,
                                          output_hidden_states=output_hidden_states,
                                          position_ids=position_ids,
                                          )

        # Pooling the embeddings of chunks that came from the same original text.
        embeddings = []
        for p in pos:
            if p[1] - p[0] == 1:
                embeddings.append(embedded[p[0]].unsqueeze(0))
            else:
                embeddings.append(torch.mean(embedded[p[0]:p[1]], dim=0).unsqueeze(0))

        return torch.cat(embeddings, 0)

    def forward(self,
                input_ids: Optional[torch.LongTensor] = None,
                pixel_values: Optional[torch.FloatTensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                return_loss: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                ) -> Tuple:
        """
        :param input_ids: tokenized text from CLIPProcessor.
        :param attention_mask: attention mask from CLIPProcessor.
        :param pixel_values: pixel values from CLIPProcessor.
        :param return_loss: whether the contrastive loss should be returned.
        :return: image-caption cosine similarity as logits per image and per
            caption (plus the loss if return_loss is True).
        """
        # Use CLIP model's config for some fields (if specified) instead of those
        # of the vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # Kept for API compatibility; this forward always returns a tuple.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encoding the images.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        # Encoding the text captions (windowing any caption longer than 77 tokens).
        text_embeds = self.embed_text(input_ids=input_ids,
                                      attention_mask=attention_mask,
                                      output_attentions=output_attentions,
                                      output_hidden_states=output_hidden_states,
                                      position_ids=position_ids,
                                      )

        # Normalized features.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # Cosine similarity as logits.
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T

        if return_loss:
            loss = clip_loss(logits_per_text)
            return logits_per_image, logits_per_text, loss
        return logits_per_image, logits_per_text
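
A minimal usage sketch (assumes the stock CLIPProcessor; the base OpenAI checkpoint below stands in for the actual ClipMD weights):

from PIL import Image
import requests
from transformers import CLIPProcessor

model = ClipMDModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
texts = ["a photo of a cat", "a photo of a dog"]

inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
logits_per_image, logits_per_text = model(**inputs)
probs = logits_per_image.softmax(dim=1)  # image-to-text matching probabilities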