PhyscalX commited on
Commit
4cee877
·
1 Parent(s): 825a49c

TAP v1.1 models release

Browse files
app.py CHANGED
@@ -32,7 +32,7 @@ def parse_args():
32
  """Parse arguments."""
33
  parser = argparse.ArgumentParser(description="Launch gradio application")
34
  parser.add_argument("--model-type", type=str, default="tap_vit_l")
35
- parser.add_argument("--checkpoint", type=str, default="models/tap_vit_l_548184.pkl")
36
  parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
37
  parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices")
38
  return parser.parse_args()
 
32
  """Parse arguments."""
33
  parser = argparse.ArgumentParser(description="Launch gradio application")
34
  parser.add_argument("--model-type", type=str, default="tap_vit_l")
35
+ parser.add_argument("--checkpoint", type=str, default="models/tap_vit_l_v1_1.pkl")
36
  parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
37
  parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices")
38
  return parser.parse_args()
models/tap_vit_l_548184.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e1d3a11c572af8cb6bce8016d3a6c6948bba4959ea43811f0e984b9eafeee413
3
- size 811637521
 
 
 
 
models/{tap_vit_l_03f8ec.pkl → tap_vit_l_v1_1.pkl} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d63a5aba993c34bf29c0466026136e18e25d2bd4ac9e51b8fc407b76c431707d
3
- size 811637521
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7301f73267786bfafc7223931b3fd51baf111d1c24bcc1278ea6378139f1067
3
+ size 811637487
tokenize_anything/layers/utils.py CHANGED
@@ -26,7 +26,8 @@ def init_cross_conv(blocks):
26
  if isinstance(m, torch.nn.Conv2d):
27
  torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
28
  for blk in blocks:
29
- torch.nn.init.constant_(blk.norm3.weight, 0)
 
30
 
31
 
32
  def set_dropout(module, dropout):
 
26
  if isinstance(m, torch.nn.Conv2d):
27
  torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
28
  for blk in blocks:
29
+ if hasattr(blk, "norm3") and hasattr(blk.norm3, "weight"):
30
+ torch.nn.init.zeros_(blk.norm3.weight)
31
 
32
 
33
  def set_dropout(module, dropout):
tokenize_anything/modeling/concept_projector.py CHANGED
@@ -31,12 +31,12 @@ class ConceptProjector(nn.Module):
31
 
32
  def reset_weights(self, src_weights=None, tgt_weights=None):
33
  """Reset the normalized projection weights."""
34
- if src_weights is not None:
35
  with open(src_weights, "rb") as f:
36
  self.src_weights, self.concepts = pickle.load(f)
37
  self.src_weights = torch.from_numpy(self.src_weights)
38
  self.concepts = np.array(self.concepts)
39
- if tgt_weights is not None:
40
  with open(tgt_weights, "rb") as f:
41
  self.tgt_weights, self.concepts = pickle.load(f)
42
  self.tgt_weights = torch.from_numpy(self.tgt_weights)
 
31
 
32
  def reset_weights(self, src_weights=None, tgt_weights=None):
33
  """Reset the normalized projection weights."""
34
+ if src_weights:
35
  with open(src_weights, "rb") as f:
36
  self.src_weights, self.concepts = pickle.load(f)
37
  self.src_weights = torch.from_numpy(self.src_weights)
38
  self.concepts = np.array(self.concepts)
39
+ if tgt_weights:
40
  with open(tgt_weights, "rb") as f:
41
  self.tgt_weights, self.concepts = pickle.load(f)
42
  self.tgt_weights = torch.from_numpy(self.tgt_weights)
tokenize_anything/modeling/image_decoder.py CHANGED
@@ -50,13 +50,12 @@ class Attention(nn.Module):
50
 
51
  def __init__(self, dim=256, num_heads=8, attn_ratio=1):
52
  super(Attention, self).__init__()
53
- qkv_dim = int(dim * attn_ratio)
54
- self.num_heads = num_heads
55
- self.head_dim = qkv_dim // num_heads
56
- self.q_proj = nn.Linear(dim, qkv_dim)
57
- self.k_proj = nn.Linear(dim, qkv_dim)
58
- self.v_proj = nn.Linear(dim, qkv_dim)
59
- self.proj = nn.Linear(qkv_dim, dim)
60
  self.scale = self.head_dim**-0.5
61
 
62
  def forward(self, q, k, v):
@@ -100,8 +99,7 @@ class Block(nn.Module):
100
  q, k = query + query_pos, key + key_pos
101
  query = self.norm2(self.dropout(self.cross_attn_token_to_image(q, k, key)).add_(query))
102
  query = self.norm3(self.dropout(self.mlp(query)).add_(query))
103
- q = query + query_pos
104
- key = self.norm4(self.cross_attn_image_to_token(k, q, query).add_(key))
105
  return query, key
106
 
107
 
@@ -137,8 +135,7 @@ class Transformer(nn.Module):
137
  for blk in self.blocks:
138
  query, key = blk(query, key, query_pos, key_pos)
139
  q, k = query + query_pos, key + key_pos
140
- query = self.dropout(self.final_attn_token_to_image(q, k, key)).add_(query)
141
- query = self.norm(query)
142
  return query, key
143
 
144
 
@@ -164,10 +161,10 @@ class ImageDecoder(nn.Module):
164
  super(ImageDecoder, self).__init__()
165
  self.embed_dim = embed_dim
166
  self.num_mask_tokens = num_mask_tokens
167
- self.transformer = Transformer(embed_dim, num_heads=num_heads, depth=depth)
168
  self.iou_token = nn.Embedding(1, embed_dim)
169
- self.sem_tokens = nn.Embedding(self.num_mask_tokens, embed_dim)
170
- self.mask_tokens = nn.Embedding(self.num_mask_tokens, embed_dim)
171
  self.output_conv = nn.Sequential(
172
  nn.ConvTranspose2d(embed_dim, embed_dim // 4, 2, 2),
173
  TransposedLayerNorm(embed_dim // 4),
@@ -178,8 +175,8 @@ class ImageDecoder(nn.Module):
178
  self.mask_pred = nn.ModuleList(
179
  Predictor(embed_dim, embed_dim // 8) for _ in range(num_mask_tokens)
180
  )
181
- self.iou_pred = Predictor(embed_dim, self.num_mask_tokens)
182
- self.sem_pred = Predictor(embed_dim, sem_embed_dim, 1024)
183
 
184
  def get_outputs(self, inputs):
185
  img_embeds = inputs["img_embeds"]
@@ -201,18 +198,17 @@ class ImageDecoder(nn.Module):
201
  key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
202
  mask_embeds = self.output_conv(key).flatten(2)
203
  # Unpack query.
204
- tokens = query[:, :num_tokens].unbind(dim=1)
205
- iou_tokens = tokens[num_tokens - self.num_mask_tokens - 1]
206
- mask_tokens = tokens[num_tokens - self.num_mask_tokens :]
207
- sem_tokens = tokens[: self.num_mask_tokens]
208
  # Predict.
209
  mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
210
  mask_pred = torch.stack(mask_pred, dim=1) @ mask_embeds
211
  mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
212
  mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
213
  outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
214
- outputs["sem_tokens"] = torch.stack(sem_tokens, dim=1)
215
- outputs["sem_embeds"] = self.sem_pred(outputs["sem_tokens"])
216
  return outputs
217
 
218
  def forward(self, inputs):
 
50
 
51
  def __init__(self, dim=256, num_heads=8, attn_ratio=1):
52
  super(Attention, self).__init__()
53
+ self.num_heads = num_heads or dim // 64
54
+ self.head_dim = int(dim * attn_ratio) // self.num_heads
55
+ self.q_proj = nn.Linear(dim, self.num_heads * self.head_dim)
56
+ self.k_proj = nn.Linear(dim, self.num_heads * self.head_dim)
57
+ self.v_proj = nn.Linear(dim, self.num_heads * self.head_dim)
58
+ self.proj = nn.Linear(self.num_heads * self.head_dim, dim)
 
59
  self.scale = self.head_dim**-0.5
60
 
61
  def forward(self, q, k, v):
 
99
  q, k = query + query_pos, key + key_pos
100
  query = self.norm2(self.dropout(self.cross_attn_token_to_image(q, k, key)).add_(query))
101
  query = self.norm3(self.dropout(self.mlp(query)).add_(query))
102
+ key = self.norm4(self.cross_attn_image_to_token(k, query + query_pos, query).add_(key))
 
103
  return query, key
104
 
105
 
 
135
  for blk in self.blocks:
136
  query, key = blk(query, key, query_pos, key_pos)
137
  q, k = query + query_pos, key + key_pos
138
+ query = self.norm(self.dropout(self.final_attn_token_to_image(q, k, key)).add_(query))
 
139
  return query, key
140
 
141
 
 
161
  super(ImageDecoder, self).__init__()
162
  self.embed_dim = embed_dim
163
  self.num_mask_tokens = num_mask_tokens
164
+ self.transformer = Transformer(embed_dim, num_heads, depth=depth)
165
  self.iou_token = nn.Embedding(1, embed_dim)
166
+ self.sem_tokens = nn.Embedding(num_mask_tokens, embed_dim)
167
+ self.mask_tokens = nn.Embedding(num_mask_tokens, embed_dim)
168
  self.output_conv = nn.Sequential(
169
  nn.ConvTranspose2d(embed_dim, embed_dim // 4, 2, 2),
170
  TransposedLayerNorm(embed_dim // 4),
 
175
  self.mask_pred = nn.ModuleList(
176
  Predictor(embed_dim, embed_dim // 8) for _ in range(num_mask_tokens)
177
  )
178
+ self.iou_pred = Predictor(embed_dim, num_mask_tokens)
179
+ self.sem_pred = Predictor(embed_dim, sem_embed_dim, sem_embed_dim)
180
 
181
  def get_outputs(self, inputs):
182
  img_embeds = inputs["img_embeds"]
 
198
  key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
199
  mask_embeds = self.output_conv(key).flatten(2)
200
  # Unpack query.
201
+ sem_tokens = query[:, : self.num_mask_tokens]
202
+ sam_tokens = query[:, self.num_mask_tokens : num_tokens].unbind(1)
203
+ iou_tokens, mask_tokens = sam_tokens[0], sam_tokens[1:]
 
204
  # Predict.
205
  mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
206
  mask_pred = torch.stack(mask_pred, dim=1) @ mask_embeds
207
  mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
208
  mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
209
  outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
210
+ outputs["sem_tokens"] = sem_tokens.unsqueeze_(2)
211
+ outputs["sem_embeds"] = self.sem_pred(outputs["sem_tokens"].flatten(2))
212
  return outputs
213
 
214
  def forward(self, inputs):
tokenize_anything/modeling/prompt_encoder.py CHANGED
@@ -24,21 +24,14 @@ class PromptEncoder(nn.Module):
24
 
25
  def __init__(self, embed_dim, image_size):
26
  super(PromptEncoder, self).__init__()
27
- self.img_size = [image_size] * 2
28
  self.point_embed = nn.Embedding(5, embed_dim) # [bg, fg, lt, rb, pad]
29
  self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
30
  self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
31
- self.img_pos = None
32
 
33
- def to_tensor(self, input):
34
- """Convert input to tensor."""
35
- if input is None:
36
- return input
37
- if not isinstance(input, torch.Tensor):
38
- input = torch.from_numpy(input)
39
- if input.device != self.coord_matrix.device:
40
- input = input.to(device=self.coord_matrix.device)
41
- return input
42
 
43
  def to_points(self, points=None, boxes=None):
44
  """Convert points or boxes to point prompts."""
@@ -48,14 +41,14 @@ class PromptEncoder(nn.Module):
48
  else:
49
  coords, labels = points[:, :, :2], points[:, :, 2]
50
  coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
51
- coords = self.to_tensor(coords.clip(0, 1).astype("float32"))
52
- labels = self.to_tensor(labels.astype("int64"))
53
  return coords, labels
54
  if boxes is not None:
55
  coords = boxes.reshape((-1, 2, 2))
56
  coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
57
- coords = self.to_tensor(coords.clip(0, 1).astype("float32"))
58
- labels = self.to_tensor(self.corner_labels)
59
  return coords, labels
60
  return None
61
 
@@ -79,7 +72,7 @@ class PromptEncoder(nn.Module):
79
  grid = torch.ones(*grid_size, dtype=torch.float32)
80
  y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
81
  x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
82
- coords = self.to_tensor(torch.stack([x, y], dim=-1))
83
  return self.encode_coords(coords)
84
 
85
  def forward(self, inputs):
 
24
 
25
  def __init__(self, embed_dim, image_size):
26
  super(PromptEncoder, self).__init__()
 
27
  self.point_embed = nn.Embedding(5, embed_dim) # [bg, fg, lt, rb, pad]
28
  self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
29
  self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
30
+ self.img_pos, self.img_size = None, [image_size] * 2
31
 
32
+ def as_tensor(self, input):
33
+ """Convert input into a tensor."""
34
+ return torch.as_tensor(input, device=self.coord_matrix.device)
 
 
 
 
 
 
35
 
36
  def to_points(self, points=None, boxes=None):
37
  """Convert points or boxes to point prompts."""
 
41
  else:
42
  coords, labels = points[:, :, :2], points[:, :, 2]
43
  coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
44
+ coords = self.as_tensor(coords.clip(0, 1).astype("float32"))
45
+ labels = self.as_tensor(labels.astype("int64"))
46
  return coords, labels
47
  if boxes is not None:
48
  coords = boxes.reshape((-1, 2, 2))
49
  coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
50
+ coords = self.as_tensor(coords.clip(0, 1).astype("float32"))
51
+ labels = self.as_tensor(self.corner_labels)
52
  return coords, labels
53
  return None
54
 
 
72
  grid = torch.ones(*grid_size, dtype=torch.float32)
73
  y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
74
  x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
75
+ coords = self.as_tensor(torch.stack([x, y], dim=-1))
76
  return self.encode_coords(coords)
77
 
78
  def forward(self, inputs):
tokenize_anything/models/easy_build.py CHANGED
@@ -18,7 +18,6 @@
18
  from functools import partial
19
  import pickle
20
 
21
- import numpy as np
22
  import torch
23
 
24
  from tokenize_anything.modeling import ConceptProjector
@@ -45,9 +44,9 @@ def load_weights(module, weights_file, strict=True):
45
  with open(weights_file, "rb") as f:
46
  state_dict = pickle.load(f)
47
  for k, v in state_dict.items():
48
- state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
49
  else:
50
- state_dict = torch.load(weights_file)
51
  module.load_state_dict(state_dict, strict=strict)
52
 
53
 
@@ -68,21 +67,21 @@ def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
68
  def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", **kwargs):
69
  """Build an image tokenizer."""
70
  image_size = kwargs.get("image_size", 1024)
71
- prompt_embed_dim = kwargs.get("prompt_embed_dim", 256)
72
  sem_embed_dim = kwargs.get("sem_embed_dim", 1024)
73
  text_embed_dim = kwargs.get("text_embed_dim", 512)
74
  text_decoder_depth = kwargs.get("text_decoder_depth", 12)
75
  text_seq_len = kwargs.get("text_seq_len", 40)
76
  text_tokenizer = TextTokenizer()
77
  model = ImageTokenizer(
78
- image_encoder=image_encoder(out_dim=prompt_embed_dim, image_size=image_size),
79
- prompt_encoder=PromptEncoder(embed_dim=prompt_embed_dim, image_size=image_size),
80
  image_decoder=ImageDecoder(
 
 
 
81
  depth=2,
82
- embed_dim=prompt_embed_dim,
83
- num_heads=prompt_embed_dim // 32,
84
  num_mask_tokens=4,
85
- sem_embed_dim=sem_embed_dim,
86
  ),
87
  text_tokenizer=text_tokenizer,
88
  concept_projector=ConceptProjector(),
@@ -90,10 +89,10 @@ def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", *
90
  depth=text_decoder_depth,
91
  embed_dim=text_embed_dim,
92
  num_heads=text_embed_dim // 64,
93
- mlp_ratio=4,
94
- prompt_embed_dim=prompt_embed_dim,
95
  max_seq_len=text_seq_len,
96
  vocab_size=text_tokenizer.n_words,
 
97
  ),
98
  )
99
  load_weights(model, checkpoint)
 
18
  from functools import partial
19
  import pickle
20
 
 
21
  import torch
22
 
23
  from tokenize_anything.modeling import ConceptProjector
 
44
  with open(weights_file, "rb") as f:
45
  state_dict = pickle.load(f)
46
  for k, v in state_dict.items():
47
+ state_dict[k] = torch.as_tensor(v)
48
  else:
49
+ state_dict = torch.load(weights_file, map_location="cpu")
50
  module.load_state_dict(state_dict, strict=strict)
51
 
52
 
 
67
  def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", **kwargs):
68
  """Build an image tokenizer."""
69
  image_size = kwargs.get("image_size", 1024)
70
+ image_embed_dim = kwargs.get("image_embed_dim", 256)
71
  sem_embed_dim = kwargs.get("sem_embed_dim", 1024)
72
  text_embed_dim = kwargs.get("text_embed_dim", 512)
73
  text_decoder_depth = kwargs.get("text_decoder_depth", 12)
74
  text_seq_len = kwargs.get("text_seq_len", 40)
75
  text_tokenizer = TextTokenizer()
76
  model = ImageTokenizer(
77
+ image_encoder=image_encoder(out_dim=image_embed_dim, image_size=image_size),
78
+ prompt_encoder=PromptEncoder(embed_dim=image_embed_dim, image_size=image_size),
79
  image_decoder=ImageDecoder(
80
+ embed_dim=image_embed_dim,
81
+ num_heads=image_embed_dim // 32,
82
+ sem_embed_dim=sem_embed_dim,
83
  depth=2,
 
 
84
  num_mask_tokens=4,
 
85
  ),
86
  text_tokenizer=text_tokenizer,
87
  concept_projector=ConceptProjector(),
 
89
  depth=text_decoder_depth,
90
  embed_dim=text_embed_dim,
91
  num_heads=text_embed_dim // 64,
92
+ prompt_embed_dim=image_embed_dim,
 
93
  max_seq_len=text_seq_len,
94
  vocab_size=text_tokenizer.n_words,
95
+ mlp_ratio=4,
96
  ),
97
  )
98
  load_weights(model, checkpoint)
tokenize_anything/version.py CHANGED
@@ -1,3 +1,3 @@
1
- version = "0.1.0a0"
2
  git_version = "None"
3
  __version__ = version
 
1
+ version = "1.1.0a0"
2
  git_version = "None"
3
  __version__ = version