Slep committed on
Commit 11f6a98
1 Parent(s): 2fe5783

Upload CondViTForEmbedding

Files changed (4)
  1. config.json +20 -0
  2. hf_model.py +47 -0
  3. model.safetensors +3 -0
  4. module.py +171 -0
config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "_name_or_path": "CondViT-B16-cat",
+   "architectures": [
+     "CondViTForEmbedding"
+   ],
+   "auto_map": {
+     "AutoConfig": "hf_model.CondViTConfig",
+     "AutoModel": "hf_model.CondViTForEmbedding"
+   },
+   "heads": 12,
+   "input_resolution": 224,
+   "layers": 12,
+   "model_type": "condvit",
+   "n_categories": 10,
+   "output_dim": 512,
+   "patch_size": 16,
+   "torch_dtype": "float32",
+   "transformers_version": "4.37.1",
+   "width": 768
+ }
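
The auto_map entries above register the custom classes from hf_model.py with the transformers Auto classes, so the checkpoint can be loaded without importing the repository code by hand. A minimal loading sketch, assuming the files in this commit live in a Hub repository whose id is shown here as a placeholder (Slep/CondViT-B16-cat is hypothetical) and that passing trust_remote_code=True is acceptable:

from transformers import AutoConfig, AutoModel

# Hypothetical repo id; replace it with the actual repository holding this commit.
repo_id = "Slep/CondViT-B16-cat"

# trust_remote_code=True is needed because auto_map points AutoConfig/AutoModel
# at the custom classes hf_model.CondViTConfig and hf_model.CondViTForEmbedding.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(config.model_type, config.output_dim)  # condvit 512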
hf_model.py ADDED
@@ -0,0 +1,47 @@
+ from transformers import PreTrainedModel, PretrainedConfig
+ from .module import ConditionalViT
+
+
+ class CondViTConfig(PretrainedConfig):
+     model_type = "condvit"
+
+     def __init__(
+         self,
+         input_resolution: int = 224,
+         patch_size: int = 16,
+         width: int = 768,
+         layers: int = 12,
+         heads: int = 12,
+         output_dim: int = 512,
+         n_categories: int = 10,
+         **kwargs
+     ):
+         self.input_resolution = input_resolution
+         self.patch_size = patch_size
+         self.width = width
+         self.layers = layers
+         self.heads = heads
+         self.output_dim = output_dim
+         self.n_categories = n_categories
+
+         super().__init__(**kwargs)
+
+
+ class CondViTForEmbedding(PreTrainedModel):
+     config_class = CondViTConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.model = ConditionalViT(
+             input_resolution=config.input_resolution,
+             patch_size=config.patch_size,
+             width=config.width,
+             layers=config.layers,
+             heads=config.heads,
+             output_dim=config.output_dim,
+             n_categories=config.n_categories,
+         )
+
+     def forward(self, img_tensors, category_indices=None):
+         return self.model(img_tensors, category_indices)
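
CondViTForEmbedding is a thin PreTrainedModel wrapper: its forward takes a batch of image tensors plus optional category_indices and delegates to ConditionalViT, returning one output_dim-sized embedding per image. A usage sketch continuing from the loading snippet above (the random tensors merely stand in for preprocessed images; real inputs should be resized to input_resolution and normalized):

import torch

imgs = torch.randn(2, 3, 224, 224)   # stand-in for a preprocessed batch, input_resolution = 224
cats = torch.tensor([0, 3])          # one category index per image, in [0, n_categories)

model.eval()
with torch.no_grad():
    unconditional = model(imgs)                       # no category token is appended
    conditional = model(imgs, category_indices=cats)  # category token appended per image

print(unconditional.shape, conditional.shape)  # torch.Size([2, 512]) each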
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:017be653458d515c3c44a35a0049660b69d6af79806f1a555e1e63685a251438
+ size 344822004
module.py ADDED
@@ -0,0 +1,171 @@
+ import torch
+ from torch import nn
+
+ from collections import OrderedDict
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm to handle fp16."""
+
+     def forward(self, x: torch.Tensor):
+         if self.weight.dtype != x.dtype:
+             orig_type = x.dtype
+             ret = super().forward(x.type(self.weight.dtype))
+             return ret.type(orig_type)
+         else:
+             return super().forward(x)
+
+
+ class QuickGELU(nn.Module):
+     def forward(self, x: torch.Tensor):
+         return x * torch.sigmoid(1.702 * x)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         n_head: int,
+         attn_mask: torch.Tensor = None,
+     ):
+         super().__init__()
+
+         self.attn = nn.MultiheadAttention(d_model, n_head)
+         self.ln_1 = LayerNorm(d_model)
+         self.mlp = nn.Sequential(
+             OrderedDict(
+                 [
+                     (
+                         "c_fc",
+                         nn.Linear(d_model, d_model * 4),
+                     ),
+                     ("gelu", QuickGELU()),
+                     (
+                         "c_proj",
+                         nn.Linear(d_model * 4, d_model),
+                     ),
+                 ]
+             )
+         )
+         self.ln_2 = LayerNorm(d_model)
+         self.attn_mask = attn_mask
+
+     def attention(self, x: torch.Tensor):
+         self.attn_mask = (
+             self.attn_mask.to(dtype=x.dtype, device=x.device)
+             if self.attn_mask is not None
+             else None
+         )
+         return self.attn(
+             x,
+             x,
+             x,
+             need_weights=False,
+             attn_mask=self.attn_mask,
+         )[0]
+
+     def forward(self, x: torch.Tensor):
+         x = x + self.attention(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         width: int,
+         layers: int,
+         heads: int,
+         attn_mask: torch.Tensor = None,
+     ):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.resblocks = nn.Sequential(
+             *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
+         )
+
+     def forward(self, x: torch.Tensor):
+         return self.resblocks(x)
+
+
+ class ConditionalViT(nn.Module):
+     def __init__(
+         self,
+         input_resolution: int,
+         patch_size: int,
+         width: int,
+         layers: int,
+         heads: int,
+         output_dim: int,
+         n_categories: int,
+     ):
+         super().__init__()
+         self.input_resolution = input_resolution
+         self.output_dim = output_dim
+         self.conv1 = nn.Conv2d(
+             in_channels=3,
+             out_channels=width,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=False,
+         )
+
+         scale = width**-0.5
+
+         self.class_embedding = nn.Parameter(scale * torch.randn(width))
+
+         self.n_categories = n_categories
+         self.c_embedding = nn.Embedding(self.n_categories, width)
+         self.c_pos_embedding = nn.Parameter(scale * torch.randn(1, width))
+
+         self.positional_embedding = nn.Parameter(
+             scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
+         )
+         self.ln_pre = LayerNorm(width)
+
+         self.transformer = Transformer(width, layers, heads)
+         self.ln_post = LayerNorm(width)
+         self.logit_scale = torch.nn.Parameter(torch.ones([]) * 4.6052)
+
+         self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+     def forward(self, imgs: torch.Tensor, c: torch.Tensor = None):
+         """
+         imgs : Batch of images
+         c : category indices.
+         """
+
+         x = self.conv1(imgs)  # shape = [*, width, grid, grid]
+         # shape = [*, width, grid ** 2]
+         x = x.reshape(x.shape[0], x.shape[1], -1)
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+         # [CLS, grid] + maybe Categories.
+         tokens = [self.class_embedding.tile(x.shape[0], 1, 1), x]  # NLD
+         pos_embed = [self.positional_embedding]  # LD
+
+         if c is not None:  # If c is None, we don't add the token
+             tokens += [self.c_embedding(c).unsqueeze(1)]  # ND -> N1D
+             pos_embed += [self.c_pos_embedding]  # 1D
+
+         # shape = [*, grid ** 2 + 1|2, width] = N(L|L+1)D
+         x = torch.cat(tokens, dim=1)
+         pos_embed = torch.cat(pos_embed, dim=0).unsqueeze(0)  # 1(L|L+1)D
+
+         x = x + pos_embed
+         x = self.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+
+         x = self.transformer(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         x = self.ln_post(x[:, 0, :])
+
+         x = x @ self.proj
+
+         return x
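
ConditionalViT itself is a CLIP-style ViT-B/16 whose token sequence is [CLS] + patch tokens, with an optional category token (its own embedding and positional embedding) appended when category indices are passed; only the post-norm CLS token is projected to the output space. A standalone sketch with the hyperparameters from config.json, using randomly initialized weights purely to illustrate the shapes:

import torch
from module import ConditionalViT  # module.py uses no package-relative imports, so a plain import works

# Hyperparameters mirroring config.json: ViT-B/16 trunk, 512-d projection, 10 categories.
vit = ConditionalViT(
    input_resolution=224,
    patch_size=16,
    width=768,
    layers=12,
    heads=12,
    output_dim=512,
    n_categories=10,
)

imgs = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    emb_plain = vit(imgs)                      # tokens: [CLS] + 14*14 patches
    emb_cond = vit(imgs, c=torch.tensor([7]))  # tokens: [CLS] + 14*14 patches + 1 category token

print(emb_plain.shape, emb_cond.shape)  # torch.Size([1, 512]) in both cases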