howard-hou committed
Commit 51b71a9
1 Parent(s): ced196e

Upload modeling.py

Files changed (1)
  1. modeling.py +66 -0
modeling.py ADDED
@@ -0,0 +1,66 @@
from transformers import CLIPVisionModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class VisualEncoderConfig:
    n_embd: int = 2048
    vision_tower_name: str = 'openai/clip-vit-large-patch14-336'
    # -1: no grid pooling, 0: take cls token only, 1: global avg pooling,
    # 2, 3, 4, ...: pool patch features onto a grid_size x grid_size grid
    grid_size: int = -1


class VisualEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.vit = CLIPVisionModel.from_pretrained(args.vision_tower_name)
        # project CLIP hidden states into the language model's embedding space
        self.proj = nn.Linear(self.vit.config.hidden_size, args.n_embd, bias=False)

    def encode_images(self, images):
        B, N, C, H, W = images.shape
        images = images.view(B * N, C, H, W)
        image_features = self.vit(images).last_hidden_state
        L, D = image_features.shape[1], image_features.shape[2]
        # rearrange [B*N, L, D] -> [B, N, L, D] and keep only the first image per sample
        image_features = image_features.view(B, N, L, D)[:, 0, :, :]
        image_features = self.grid_pooling(image_features)
        return self.proj(image_features)

    def grid_pooling(self, image_features):
        if self.args.grid_size == -1:  # no grid pooling
            return image_features
        if self.args.grid_size == 0:  # take cls token
            return image_features[:, 0:1, :]
        if self.args.grid_size == 1:  # global avg pooling
            return image_features.mean(dim=1, keepdim=True)
        cls_features = image_features[:, 0:1, :]
        image_features = image_features[:, 1:, :]  # drop cls token
        B, L, D = image_features.shape
        H_or_W = int(L ** 0.5)
        image_features = image_features.view(B, H_or_W, H_or_W, D)
        grid_stride = H_or_W // self.args.grid_size
        image_features = F.avg_pool2d(image_features.permute(0, 3, 1, 2),
                                      padding=0,
                                      kernel_size=grid_stride,
                                      stride=grid_stride)
        image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
        return torch.cat((cls_features, image_features), dim=1)


class EmbeddingMixer(nn.Module):
    def __init__(self, original_embedding, num_image_embeddings=4096):
        super().__init__()
        # reserve num_image_embeddings rows after the original vocabulary for image tokens
        image_embedding = torch.zeros(num_image_embeddings,
                                      original_embedding.shape[1],
                                      device=original_embedding.device,
                                      dtype=original_embedding.dtype)
        self.embedding = torch.cat((original_embedding, image_embedding), dim=0)
        self.image_start_index = len(original_embedding)

    def set_image_embeddings(self, image_embeddings):
        end_index = self.image_start_index + image_embeddings.shape[0]
        self.embedding[self.image_start_index:end_index] = image_embeddings

    def get_input_embeddings(self):
        return self.embedding
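For context, a minimal usage sketch of the two modules above (not part of this commit): it wires a VisualEncoder to an EmbeddingMixer. The batch shapes, the vocabulary size of 65536, and the random original_embedding below are illustrative assumptions, not values taken from this repository.

# Minimal usage sketch; shapes, vocab size, and random weights are illustrative assumptions.
import torch
from modeling import VisualEncoderConfig, VisualEncoder, EmbeddingMixer

config = VisualEncoderConfig(n_embd=2048, grid_size=4)
encoder = VisualEncoder(config)

# 2 samples, 1 image each, at CLIP ViT-L/14-336 resolution.
images = torch.randn(2, 1, 3, 336, 336)
with torch.no_grad():
    image_embeds = encoder.encode_images(images)
# grid_size=4 keeps the cls token plus a 4x4 grid of pooled patch features.
print(image_embeds.shape)  # torch.Size([2, 17, 2048])

# Hypothetical language-model embedding table of shape [vocab_size, n_embd].
original_embedding = torch.randn(65536, 2048)
mixer = EmbeddingMixer(original_embedding, num_image_embeddings=4096)

# Write one sample's image features into the reserved rows; image tokens
# then start at token id mixer.image_start_index.
mixer.set_image_embeddings(image_embeds[0])
print(mixer.get_input_embeddings().shape)  # torch.Size([69632, 2048])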