Upload CondViTForEmbedding
Browse files- config.json +20 -0
- hf_model.py +47 -0
- model.safetensors +3 -0
- module.py +171 -0
config.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "CondViT-B16-cat",
|
3 |
+
"architectures": [
|
4 |
+
"CondViTForEmbedding"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "hf_model.CondViTConfig",
|
8 |
+
"AutoModel": "hf_model.CondViTForEmbedding"
|
9 |
+
},
|
10 |
+
"heads": 12,
|
11 |
+
"input_resolution": 224,
|
12 |
+
"layers": 12,
|
13 |
+
"model_type": "condvit",
|
14 |
+
"n_categories": 10,
|
15 |
+
"output_dim": 512,
|
16 |
+
"patch_size": 16,
|
17 |
+
"torch_dtype": "float32",
|
18 |
+
"transformers_version": "4.37.1",
|
19 |
+
"width": 768
|
20 |
+
}
|
hf_model.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
2 |
+
from .module import ConditionalViT
|
3 |
+
|
4 |
+
|
5 |
+
class CondViTConfig(PretrainedConfig):
|
6 |
+
model_type = "condvit"
|
7 |
+
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
input_resolution: int = 224,
|
11 |
+
patch_size: int = 16,
|
12 |
+
width: int = 768,
|
13 |
+
layers: int = 12,
|
14 |
+
heads: int = 12,
|
15 |
+
output_dim: int = 512,
|
16 |
+
n_categories: int = 10,
|
17 |
+
**kwargs
|
18 |
+
):
|
19 |
+
self.input_resolution = input_resolution
|
20 |
+
self.patch_size = patch_size
|
21 |
+
self.width = width
|
22 |
+
self.layers = layers
|
23 |
+
self.heads = heads
|
24 |
+
self.output_dim = output_dim
|
25 |
+
self.n_categories = n_categories
|
26 |
+
|
27 |
+
super().__init__(**kwargs)
|
28 |
+
|
29 |
+
|
30 |
+
class CondViTForEmbedding(PreTrainedModel):
|
31 |
+
config_class = CondViTConfig
|
32 |
+
|
33 |
+
def __init__(self, config):
|
34 |
+
super().__init__(config)
|
35 |
+
|
36 |
+
self.model = ConditionalViT(
|
37 |
+
input_resolution=config.input_resolution,
|
38 |
+
patch_size=config.patch_size,
|
39 |
+
width=config.width,
|
40 |
+
layers=config.layers,
|
41 |
+
heads=config.heads,
|
42 |
+
output_dim=config.output_dim,
|
43 |
+
n_categories=config.n_categories,
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(self, img_tensors, category_indices=None):
|
47 |
+
return self.model(img_tensors, category_indices)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:017be653458d515c3c44a35a0049660b69d6af79806f1a555e1e63685a251438
|
3 |
+
size 344822004
|
module.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from collections import OrderedDict
|
5 |
+
import logging
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
class LayerNorm(nn.LayerNorm):
|
11 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
12 |
+
|
13 |
+
def forward(self, x: torch.Tensor):
|
14 |
+
if self.weight.dtype != x.dtype:
|
15 |
+
orig_type = x.dtype
|
16 |
+
ret = super().forward(x.type(self.weight.dtype))
|
17 |
+
return ret.type(orig_type)
|
18 |
+
else:
|
19 |
+
return super().forward(x)
|
20 |
+
|
21 |
+
|
22 |
+
class QuickGELU(nn.Module):
|
23 |
+
def forward(self, x: torch.Tensor):
|
24 |
+
return x * torch.sigmoid(1.702 * x)
|
25 |
+
|
26 |
+
|
27 |
+
class ResidualAttentionBlock(nn.Module):
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
d_model: int,
|
31 |
+
n_head: int,
|
32 |
+
attn_mask: torch.Tensor = None,
|
33 |
+
):
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
37 |
+
self.ln_1 = LayerNorm(d_model)
|
38 |
+
self.mlp = nn.Sequential(
|
39 |
+
OrderedDict(
|
40 |
+
[
|
41 |
+
(
|
42 |
+
"c_fc",
|
43 |
+
nn.Linear(d_model, d_model * 4),
|
44 |
+
),
|
45 |
+
("gelu", QuickGELU()),
|
46 |
+
(
|
47 |
+
"c_proj",
|
48 |
+
nn.Linear(d_model * 4, d_model),
|
49 |
+
),
|
50 |
+
]
|
51 |
+
)
|
52 |
+
)
|
53 |
+
self.ln_2 = LayerNorm(d_model)
|
54 |
+
self.attn_mask = attn_mask
|
55 |
+
|
56 |
+
def attention(self, x: torch.Tensor):
|
57 |
+
self.attn_mask = (
|
58 |
+
self.attn_mask.to(dtype=x.dtype, device=x.device)
|
59 |
+
if self.attn_mask is not None
|
60 |
+
else None
|
61 |
+
)
|
62 |
+
return self.attn(
|
63 |
+
x,
|
64 |
+
x,
|
65 |
+
x,
|
66 |
+
need_weights=False,
|
67 |
+
attn_mask=self.attn_mask,
|
68 |
+
)[0]
|
69 |
+
|
70 |
+
def forward(self, x: torch.Tensor):
|
71 |
+
x = x + self.attention(self.ln_1(x))
|
72 |
+
x = x + self.mlp(self.ln_2(x))
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class Transformer(nn.Module):
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
width: int,
|
80 |
+
layers: int,
|
81 |
+
heads: int,
|
82 |
+
attn_mask: torch.Tensor = None,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
self.width = width
|
86 |
+
self.layers = layers
|
87 |
+
self.resblocks = nn.Sequential(
|
88 |
+
*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, x: torch.Tensor):
|
92 |
+
return self.resblocks(x)
|
93 |
+
|
94 |
+
|
95 |
+
class ConditionalViT(nn.Module):
|
96 |
+
def __init__(
|
97 |
+
self,
|
98 |
+
input_resolution: int,
|
99 |
+
patch_size: int,
|
100 |
+
width: int,
|
101 |
+
layers: int,
|
102 |
+
heads: int,
|
103 |
+
output_dim: int,
|
104 |
+
n_categories: int,
|
105 |
+
):
|
106 |
+
super().__init__()
|
107 |
+
self.input_resolution = input_resolution
|
108 |
+
self.output_dim = output_dim
|
109 |
+
self.conv1 = nn.Conv2d(
|
110 |
+
in_channels=3,
|
111 |
+
out_channels=width,
|
112 |
+
kernel_size=patch_size,
|
113 |
+
stride=patch_size,
|
114 |
+
bias=False,
|
115 |
+
)
|
116 |
+
|
117 |
+
scale = width**-0.5
|
118 |
+
|
119 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
120 |
+
|
121 |
+
self.n_categories = n_categories
|
122 |
+
self.c_embedding = nn.Embedding(self.n_categories, width)
|
123 |
+
self.c_pos_embedding = nn.Parameter(scale * torch.randn(1, width))
|
124 |
+
|
125 |
+
self.positional_embedding = nn.Parameter(
|
126 |
+
scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
|
127 |
+
)
|
128 |
+
self.ln_pre = LayerNorm(width)
|
129 |
+
|
130 |
+
self.transformer = Transformer(width, layers, heads)
|
131 |
+
self.ln_post = LayerNorm(width)
|
132 |
+
self.logit_scale = torch.nn.Parameter(torch.ones([]) * 4.6052)
|
133 |
+
|
134 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
135 |
+
|
136 |
+
def forward(self, imgs: torch.Tensor, c: torch.Tensor = None):
|
137 |
+
"""
|
138 |
+
imgs : Batch of images
|
139 |
+
c : category indices.
|
140 |
+
"""
|
141 |
+
|
142 |
+
x = self.conv1(imgs) # shape = [*, width, grid, grid]
|
143 |
+
# shape = [*, width, grid ** 2]
|
144 |
+
x = x.reshape(x.shape[0], x.shape[1], -1)
|
145 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
146 |
+
|
147 |
+
# [CLS, grid] + maybe Categories.
|
148 |
+
tokens = [self.class_embedding.tile(x.shape[0], 1, 1), x] # NLD
|
149 |
+
pos_embed = [self.positional_embedding] # LD
|
150 |
+
|
151 |
+
if c is not None: # If c is None, we don't add the token
|
152 |
+
tokens += [self.c_embedding(c).unsqueeze(1)] # ND -> N1D
|
153 |
+
pos_embed += [self.c_pos_embedding] # 1D
|
154 |
+
|
155 |
+
# shape = [*, grid ** 2 + 1|2, width] = N(L|L+1)D
|
156 |
+
x = torch.cat(tokens, dim=1)
|
157 |
+
pos_embed = torch.cat(pos_embed, dim=0).unsqueeze(0) # 1(L|L+1)D
|
158 |
+
|
159 |
+
x = x + pos_embed
|
160 |
+
x = self.ln_pre(x)
|
161 |
+
|
162 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
163 |
+
|
164 |
+
x = self.transformer(x)
|
165 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
166 |
+
|
167 |
+
x = self.ln_post(x[:, 0, :])
|
168 |
+
|
169 |
+
x = x @ self.proj
|
170 |
+
|
171 |
+
return x
|