howard-hou commited on
Commit
f0c200a
·
1 Parent(s): c42ed6d

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -66
model.py DELETED
@@ -1,66 +0,0 @@
1
- from transformers import CLIPVisionModel
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from dataclasses import dataclass
6
-
7
- @dataclass
8
- class VisualEncoderConfig:
9
- n_embd: int = 2048
10
- vision_tower_name: str = 'openai/clip-vit-large-patch14-336'
11
- grid_size: int = -1 # -1: no grid pooling, 0: take cls token, 1: global avg pooling, 2, 3, 4, ...: grid pooling
12
-
13
- class VisualEncoder(nn.Module):
14
- def __init__(self, args):
15
- super().__init__()
16
- self.args = args
17
- self.vit = CLIPVisionModel.from_pretrained(args.vision_tower_name)
18
- self.proj = nn.Linear(self.vit.config.hidden_size, args.n_embd, bias=False)
19
-
20
- def encode_images(self, images):
21
- B, N, C, H, W = images.shape
22
- images = images.view(B*N, C, H, W)
23
- image_features = self.vit(images).last_hidden_state
24
- L, D = image_features.shape[1], image_features.shape[2]
25
- # rerange [B*N, L, D] -> [B, N, L, D]
26
- image_features = image_features.view(B, N, L, D)[:, 0, :, :]
27
- image_features = self.grid_pooling(image_features)
28
- return self.proj(image_features)
29
-
30
- def grid_pooling(self, image_features):
31
- if self.args.grid_size == -1: # no grid pooling
32
- return image_features
33
- if self.args.grid_size == 0: # take cls token
34
- return image_features[:, 0:1, :]
35
- if self.args.grid_size == 1: # global avg pooling
36
- return image_features.mean(dim=1, keepdim=True)
37
- cls_features = image_features[:, 0:1, :]
38
- image_features = image_features[:, 1:, :] #drop cls token
39
- B, L, D = image_features.shape
40
- H_or_W = int(L**0.5)
41
- image_features = image_features.view(B, H_or_W, H_or_W, D)
42
- grid_stride = H_or_W // self.args.grid_size
43
- image_features = F.avg_pool2d(image_features.permute(0, 3, 1, 2),
44
- padding=0,
45
- kernel_size=grid_stride,
46
- stride=grid_stride)
47
- image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
48
- return torch.cat((cls_features, image_features), dim=1)
49
-
50
-
51
- class EmbeddingMixer(nn.Module):
52
- def __init__(self, original_embedding, num_image_embeddings=4096):
53
- super().__init__()
54
- image_embedding = torch.zeros(num_image_embeddings,
55
- original_embedding.shape[1],
56
- device=original_embedding.device,
57
- dtype=original_embedding.dtype)
58
- self.embedding = torch.cat((original_embedding, image_embedding), dim=0)
59
- self.image_start_index = len(original_embedding)
60
-
61
- def set_image_embeddings(self, image_embeddings):
62
- end_index = self.image_start_index + image_embeddings.shape[0]
63
- self.embedding[self.image_start_index:end_index] = image_embeddings
64
-
65
- def get_input_embeddings(self):
66
- return self.embedding