Spaces:
Sleeping
Sleeping
TAP v1.1 models release
Browse files- app.py +1 -1
- models/tap_vit_l_548184.pkl +0 -3
- models/{tap_vit_l_03f8ec.pkl → tap_vit_l_v1_1.pkl} +2 -2
- tokenize_anything/layers/utils.py +2 -1
- tokenize_anything/modeling/concept_projector.py +2 -2
- tokenize_anything/modeling/image_decoder.py +18 -22
- tokenize_anything/modeling/prompt_encoder.py +9 -16
- tokenize_anything/models/easy_build.py +10 -11
- tokenize_anything/version.py +1 -1
app.py
CHANGED
@@ -32,7 +32,7 @@ def parse_args():
|
|
32 |
"""Parse arguments."""
|
33 |
parser = argparse.ArgumentParser(description="Launch gradio application")
|
34 |
parser.add_argument("--model-type", type=str, default="tap_vit_l")
|
35 |
-
parser.add_argument("--checkpoint", type=str, default="models/
|
36 |
parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
|
37 |
parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices")
|
38 |
return parser.parse_args()
|
|
|
32 |
"""Parse arguments."""
|
33 |
parser = argparse.ArgumentParser(description="Launch gradio application")
|
34 |
parser.add_argument("--model-type", type=str, default="tap_vit_l")
|
35 |
+
parser.add_argument("--checkpoint", type=str, default="models/tap_vit_l_v1_1.pkl")
|
36 |
parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
|
37 |
parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices")
|
38 |
return parser.parse_args()
|
models/tap_vit_l_548184.pkl
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:e1d3a11c572af8cb6bce8016d3a6c6948bba4959ea43811f0e984b9eafeee413
|
3 |
-
size 811637521
|
|
|
|
|
|
|
|
models/{tap_vit_l_03f8ec.pkl → tap_vit_l_v1_1.pkl}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d7301f73267786bfafc7223931b3fd51baf111d1c24bcc1278ea6378139f1067
|
3 |
+
size 811637487
|
tokenize_anything/layers/utils.py
CHANGED
@@ -26,7 +26,8 @@ def init_cross_conv(blocks):
|
|
26 |
if isinstance(m, torch.nn.Conv2d):
|
27 |
torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
28 |
for blk in blocks:
|
29 |
-
|
|
|
30 |
|
31 |
|
32 |
def set_dropout(module, dropout):
|
|
|
26 |
if isinstance(m, torch.nn.Conv2d):
|
27 |
torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
28 |
for blk in blocks:
|
29 |
+
if hasattr(blk, "norm3") and hasattr(blk.norm3, "weight"):
|
30 |
+
torch.nn.init.zeros_(blk.norm3.weight)
|
31 |
|
32 |
|
33 |
def set_dropout(module, dropout):
|
tokenize_anything/modeling/concept_projector.py
CHANGED
@@ -31,12 +31,12 @@ class ConceptProjector(nn.Module):
|
|
31 |
|
32 |
def reset_weights(self, src_weights=None, tgt_weights=None):
|
33 |
"""Reset the normalized projection weights."""
|
34 |
-
if src_weights
|
35 |
with open(src_weights, "rb") as f:
|
36 |
self.src_weights, self.concepts = pickle.load(f)
|
37 |
self.src_weights = torch.from_numpy(self.src_weights)
|
38 |
self.concepts = np.array(self.concepts)
|
39 |
-
if tgt_weights
|
40 |
with open(tgt_weights, "rb") as f:
|
41 |
self.tgt_weights, self.concepts = pickle.load(f)
|
42 |
self.tgt_weights = torch.from_numpy(self.tgt_weights)
|
|
|
31 |
|
32 |
def reset_weights(self, src_weights=None, tgt_weights=None):
|
33 |
"""Reset the normalized projection weights."""
|
34 |
+
if src_weights:
|
35 |
with open(src_weights, "rb") as f:
|
36 |
self.src_weights, self.concepts = pickle.load(f)
|
37 |
self.src_weights = torch.from_numpy(self.src_weights)
|
38 |
self.concepts = np.array(self.concepts)
|
39 |
+
if tgt_weights:
|
40 |
with open(tgt_weights, "rb") as f:
|
41 |
self.tgt_weights, self.concepts = pickle.load(f)
|
42 |
self.tgt_weights = torch.from_numpy(self.tgt_weights)
|
tokenize_anything/modeling/image_decoder.py
CHANGED
@@ -50,13 +50,12 @@ class Attention(nn.Module):
|
|
50 |
|
51 |
def __init__(self, dim=256, num_heads=8, attn_ratio=1):
|
52 |
super(Attention, self).__init__()
|
53 |
-
|
54 |
-
self.
|
55 |
-
self.
|
56 |
-
self.
|
57 |
-
self.
|
58 |
-
self.
|
59 |
-
self.proj = nn.Linear(qkv_dim, dim)
|
60 |
self.scale = self.head_dim**-0.5
|
61 |
|
62 |
def forward(self, q, k, v):
|
@@ -100,8 +99,7 @@ class Block(nn.Module):
|
|
100 |
q, k = query + query_pos, key + key_pos
|
101 |
query = self.norm2(self.dropout(self.cross_attn_token_to_image(q, k, key)).add_(query))
|
102 |
query = self.norm3(self.dropout(self.mlp(query)).add_(query))
|
103 |
-
|
104 |
-
key = self.norm4(self.cross_attn_image_to_token(k, q, query).add_(key))
|
105 |
return query, key
|
106 |
|
107 |
|
@@ -137,8 +135,7 @@ class Transformer(nn.Module):
|
|
137 |
for blk in self.blocks:
|
138 |
query, key = blk(query, key, query_pos, key_pos)
|
139 |
q, k = query + query_pos, key + key_pos
|
140 |
-
query = self.dropout(self.final_attn_token_to_image(q, k, key)).add_(query)
|
141 |
-
query = self.norm(query)
|
142 |
return query, key
|
143 |
|
144 |
|
@@ -164,10 +161,10 @@ class ImageDecoder(nn.Module):
|
|
164 |
super(ImageDecoder, self).__init__()
|
165 |
self.embed_dim = embed_dim
|
166 |
self.num_mask_tokens = num_mask_tokens
|
167 |
-
self.transformer = Transformer(embed_dim, num_heads
|
168 |
self.iou_token = nn.Embedding(1, embed_dim)
|
169 |
-
self.sem_tokens = nn.Embedding(
|
170 |
-
self.mask_tokens = nn.Embedding(
|
171 |
self.output_conv = nn.Sequential(
|
172 |
nn.ConvTranspose2d(embed_dim, embed_dim // 4, 2, 2),
|
173 |
TransposedLayerNorm(embed_dim // 4),
|
@@ -178,8 +175,8 @@ class ImageDecoder(nn.Module):
|
|
178 |
self.mask_pred = nn.ModuleList(
|
179 |
Predictor(embed_dim, embed_dim // 8) for _ in range(num_mask_tokens)
|
180 |
)
|
181 |
-
self.iou_pred = Predictor(embed_dim,
|
182 |
-
self.sem_pred = Predictor(embed_dim, sem_embed_dim,
|
183 |
|
184 |
def get_outputs(self, inputs):
|
185 |
img_embeds = inputs["img_embeds"]
|
@@ -201,18 +198,17 @@ class ImageDecoder(nn.Module):
|
|
201 |
key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
|
202 |
mask_embeds = self.output_conv(key).flatten(2)
|
203 |
# Unpack query.
|
204 |
-
|
205 |
-
|
206 |
-
mask_tokens =
|
207 |
-
sem_tokens = tokens[: self.num_mask_tokens]
|
208 |
# Predict.
|
209 |
mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
|
210 |
mask_pred = torch.stack(mask_pred, dim=1) @ mask_embeds
|
211 |
mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
|
212 |
mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
|
213 |
outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
|
214 |
-
outputs["sem_tokens"] =
|
215 |
-
outputs["sem_embeds"] = self.sem_pred(outputs["sem_tokens"])
|
216 |
return outputs
|
217 |
|
218 |
def forward(self, inputs):
|
|
|
50 |
|
51 |
def __init__(self, dim=256, num_heads=8, attn_ratio=1):
|
52 |
super(Attention, self).__init__()
|
53 |
+
self.num_heads = num_heads or dim // 64
|
54 |
+
self.head_dim = int(dim * attn_ratio) // self.num_heads
|
55 |
+
self.q_proj = nn.Linear(dim, self.num_heads * self.head_dim)
|
56 |
+
self.k_proj = nn.Linear(dim, self.num_heads * self.head_dim)
|
57 |
+
self.v_proj = nn.Linear(dim, self.num_heads * self.head_dim)
|
58 |
+
self.proj = nn.Linear(self.num_heads * self.head_dim, dim)
|
|
|
59 |
self.scale = self.head_dim**-0.5
|
60 |
|
61 |
def forward(self, q, k, v):
|
|
|
99 |
q, k = query + query_pos, key + key_pos
|
100 |
query = self.norm2(self.dropout(self.cross_attn_token_to_image(q, k, key)).add_(query))
|
101 |
query = self.norm3(self.dropout(self.mlp(query)).add_(query))
|
102 |
+
key = self.norm4(self.cross_attn_image_to_token(k, query + query_pos, query).add_(key))
|
|
|
103 |
return query, key
|
104 |
|
105 |
|
|
|
135 |
for blk in self.blocks:
|
136 |
query, key = blk(query, key, query_pos, key_pos)
|
137 |
q, k = query + query_pos, key + key_pos
|
138 |
+
query = self.norm(self.dropout(self.final_attn_token_to_image(q, k, key)).add_(query))
|
|
|
139 |
return query, key
|
140 |
|
141 |
|
|
|
161 |
super(ImageDecoder, self).__init__()
|
162 |
self.embed_dim = embed_dim
|
163 |
self.num_mask_tokens = num_mask_tokens
|
164 |
+
self.transformer = Transformer(embed_dim, num_heads, depth=depth)
|
165 |
self.iou_token = nn.Embedding(1, embed_dim)
|
166 |
+
self.sem_tokens = nn.Embedding(num_mask_tokens, embed_dim)
|
167 |
+
self.mask_tokens = nn.Embedding(num_mask_tokens, embed_dim)
|
168 |
self.output_conv = nn.Sequential(
|
169 |
nn.ConvTranspose2d(embed_dim, embed_dim // 4, 2, 2),
|
170 |
TransposedLayerNorm(embed_dim // 4),
|
|
|
175 |
self.mask_pred = nn.ModuleList(
|
176 |
Predictor(embed_dim, embed_dim // 8) for _ in range(num_mask_tokens)
|
177 |
)
|
178 |
+
self.iou_pred = Predictor(embed_dim, num_mask_tokens)
|
179 |
+
self.sem_pred = Predictor(embed_dim, sem_embed_dim, sem_embed_dim)
|
180 |
|
181 |
def get_outputs(self, inputs):
|
182 |
img_embeds = inputs["img_embeds"]
|
|
|
198 |
key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
|
199 |
mask_embeds = self.output_conv(key).flatten(2)
|
200 |
# Unpack query.
|
201 |
+
sem_tokens = query[:, : self.num_mask_tokens]
|
202 |
+
sam_tokens = query[:, self.num_mask_tokens : num_tokens].unbind(1)
|
203 |
+
iou_tokens, mask_tokens = sam_tokens[0], sam_tokens[1:]
|
|
|
204 |
# Predict.
|
205 |
mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
|
206 |
mask_pred = torch.stack(mask_pred, dim=1) @ mask_embeds
|
207 |
mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
|
208 |
mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
|
209 |
outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
|
210 |
+
outputs["sem_tokens"] = sem_tokens.unsqueeze_(2)
|
211 |
+
outputs["sem_embeds"] = self.sem_pred(outputs["sem_tokens"].flatten(2))
|
212 |
return outputs
|
213 |
|
214 |
def forward(self, inputs):
|
tokenize_anything/modeling/prompt_encoder.py
CHANGED
@@ -24,21 +24,14 @@ class PromptEncoder(nn.Module):
|
|
24 |
|
25 |
def __init__(self, embed_dim, image_size):
|
26 |
super(PromptEncoder, self).__init__()
|
27 |
-
self.img_size = [image_size] * 2
|
28 |
self.point_embed = nn.Embedding(5, embed_dim) # [bg, fg, lt, rb, pad]
|
29 |
self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
|
30 |
self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
|
31 |
-
self.img_pos = None
|
32 |
|
33 |
-
def
|
34 |
-
"""Convert input
|
35 |
-
|
36 |
-
return input
|
37 |
-
if not isinstance(input, torch.Tensor):
|
38 |
-
input = torch.from_numpy(input)
|
39 |
-
if input.device != self.coord_matrix.device:
|
40 |
-
input = input.to(device=self.coord_matrix.device)
|
41 |
-
return input
|
42 |
|
43 |
def to_points(self, points=None, boxes=None):
|
44 |
"""Convert points or boxes to point prompts."""
|
@@ -48,14 +41,14 @@ class PromptEncoder(nn.Module):
|
|
48 |
else:
|
49 |
coords, labels = points[:, :, :2], points[:, :, 2]
|
50 |
coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
|
51 |
-
coords = self.
|
52 |
-
labels = self.
|
53 |
return coords, labels
|
54 |
if boxes is not None:
|
55 |
coords = boxes.reshape((-1, 2, 2))
|
56 |
coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
|
57 |
-
coords = self.
|
58 |
-
labels = self.
|
59 |
return coords, labels
|
60 |
return None
|
61 |
|
@@ -79,7 +72,7 @@ class PromptEncoder(nn.Module):
|
|
79 |
grid = torch.ones(*grid_size, dtype=torch.float32)
|
80 |
y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
|
81 |
x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
|
82 |
-
coords = self.
|
83 |
return self.encode_coords(coords)
|
84 |
|
85 |
def forward(self, inputs):
|
|
|
24 |
|
25 |
def __init__(self, embed_dim, image_size):
|
26 |
super(PromptEncoder, self).__init__()
|
|
|
27 |
self.point_embed = nn.Embedding(5, embed_dim) # [bg, fg, lt, rb, pad]
|
28 |
self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
|
29 |
self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
|
30 |
+
self.img_pos, self.img_size = None, [image_size] * 2
|
31 |
|
32 |
+
def as_tensor(self, input):
|
33 |
+
"""Convert input into a tensor."""
|
34 |
+
return torch.as_tensor(input, device=self.coord_matrix.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def to_points(self, points=None, boxes=None):
|
37 |
"""Convert points or boxes to point prompts."""
|
|
|
41 |
else:
|
42 |
coords, labels = points[:, :, :2], points[:, :, 2]
|
43 |
coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
|
44 |
+
coords = self.as_tensor(coords.clip(0, 1).astype("float32"))
|
45 |
+
labels = self.as_tensor(labels.astype("int64"))
|
46 |
return coords, labels
|
47 |
if boxes is not None:
|
48 |
coords = boxes.reshape((-1, 2, 2))
|
49 |
coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
|
50 |
+
coords = self.as_tensor(coords.clip(0, 1).astype("float32"))
|
51 |
+
labels = self.as_tensor(self.corner_labels)
|
52 |
return coords, labels
|
53 |
return None
|
54 |
|
|
|
72 |
grid = torch.ones(*grid_size, dtype=torch.float32)
|
73 |
y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
|
74 |
x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
|
75 |
+
coords = self.as_tensor(torch.stack([x, y], dim=-1))
|
76 |
return self.encode_coords(coords)
|
77 |
|
78 |
def forward(self, inputs):
|
tokenize_anything/models/easy_build.py
CHANGED
@@ -18,7 +18,6 @@
|
|
18 |
from functools import partial
|
19 |
import pickle
|
20 |
|
21 |
-
import numpy as np
|
22 |
import torch
|
23 |
|
24 |
from tokenize_anything.modeling import ConceptProjector
|
@@ -45,9 +44,9 @@ def load_weights(module, weights_file, strict=True):
|
|
45 |
with open(weights_file, "rb") as f:
|
46 |
state_dict = pickle.load(f)
|
47 |
for k, v in state_dict.items():
|
48 |
-
state_dict[k] = torch.
|
49 |
else:
|
50 |
-
state_dict = torch.load(weights_file)
|
51 |
module.load_state_dict(state_dict, strict=strict)
|
52 |
|
53 |
|
@@ -68,21 +67,21 @@ def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
|
|
68 |
def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", **kwargs):
|
69 |
"""Build an image tokenizer."""
|
70 |
image_size = kwargs.get("image_size", 1024)
|
71 |
-
|
72 |
sem_embed_dim = kwargs.get("sem_embed_dim", 1024)
|
73 |
text_embed_dim = kwargs.get("text_embed_dim", 512)
|
74 |
text_decoder_depth = kwargs.get("text_decoder_depth", 12)
|
75 |
text_seq_len = kwargs.get("text_seq_len", 40)
|
76 |
text_tokenizer = TextTokenizer()
|
77 |
model = ImageTokenizer(
|
78 |
-
image_encoder=image_encoder(out_dim=
|
79 |
-
prompt_encoder=PromptEncoder(embed_dim=
|
80 |
image_decoder=ImageDecoder(
|
|
|
|
|
|
|
81 |
depth=2,
|
82 |
-
embed_dim=prompt_embed_dim,
|
83 |
-
num_heads=prompt_embed_dim // 32,
|
84 |
num_mask_tokens=4,
|
85 |
-
sem_embed_dim=sem_embed_dim,
|
86 |
),
|
87 |
text_tokenizer=text_tokenizer,
|
88 |
concept_projector=ConceptProjector(),
|
@@ -90,10 +89,10 @@ def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", *
|
|
90 |
depth=text_decoder_depth,
|
91 |
embed_dim=text_embed_dim,
|
92 |
num_heads=text_embed_dim // 64,
|
93 |
-
|
94 |
-
prompt_embed_dim=prompt_embed_dim,
|
95 |
max_seq_len=text_seq_len,
|
96 |
vocab_size=text_tokenizer.n_words,
|
|
|
97 |
),
|
98 |
)
|
99 |
load_weights(model, checkpoint)
|
|
|
18 |
from functools import partial
|
19 |
import pickle
|
20 |
|
|
|
21 |
import torch
|
22 |
|
23 |
from tokenize_anything.modeling import ConceptProjector
|
|
|
44 |
with open(weights_file, "rb") as f:
|
45 |
state_dict = pickle.load(f)
|
46 |
for k, v in state_dict.items():
|
47 |
+
state_dict[k] = torch.as_tensor(v)
|
48 |
else:
|
49 |
+
state_dict = torch.load(weights_file, map_location="cpu")
|
50 |
module.load_state_dict(state_dict, strict=strict)
|
51 |
|
52 |
|
|
|
67 |
def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", **kwargs):
|
68 |
"""Build an image tokenizer."""
|
69 |
image_size = kwargs.get("image_size", 1024)
|
70 |
+
image_embed_dim = kwargs.get("image_embed_dim", 256)
|
71 |
sem_embed_dim = kwargs.get("sem_embed_dim", 1024)
|
72 |
text_embed_dim = kwargs.get("text_embed_dim", 512)
|
73 |
text_decoder_depth = kwargs.get("text_decoder_depth", 12)
|
74 |
text_seq_len = kwargs.get("text_seq_len", 40)
|
75 |
text_tokenizer = TextTokenizer()
|
76 |
model = ImageTokenizer(
|
77 |
+
image_encoder=image_encoder(out_dim=image_embed_dim, image_size=image_size),
|
78 |
+
prompt_encoder=PromptEncoder(embed_dim=image_embed_dim, image_size=image_size),
|
79 |
image_decoder=ImageDecoder(
|
80 |
+
embed_dim=image_embed_dim,
|
81 |
+
num_heads=image_embed_dim // 32,
|
82 |
+
sem_embed_dim=sem_embed_dim,
|
83 |
depth=2,
|
|
|
|
|
84 |
num_mask_tokens=4,
|
|
|
85 |
),
|
86 |
text_tokenizer=text_tokenizer,
|
87 |
concept_projector=ConceptProjector(),
|
|
|
89 |
depth=text_decoder_depth,
|
90 |
embed_dim=text_embed_dim,
|
91 |
num_heads=text_embed_dim // 64,
|
92 |
+
prompt_embed_dim=image_embed_dim,
|
|
|
93 |
max_seq_len=text_seq_len,
|
94 |
vocab_size=text_tokenizer.n_words,
|
95 |
+
mlp_ratio=4,
|
96 |
),
|
97 |
)
|
98 |
load_weights(model, checkpoint)
|
tokenize_anything/version.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
version = "
|
2 |
git_version = "None"
|
3 |
__version__ = version
|
|
|
1 |
+
version = "1.1.0a0"
|
2 |
git_version = "None"
|
3 |
__version__ = version
|