PhyscalX commited on
Commit
825a49c
Β·
1 Parent(s): ae507fe

Sync with main repo

Browse files
app.py CHANGED
@@ -23,18 +23,18 @@ import time
23
  import numpy as np
24
  import torch
25
 
26
- from tokenize_anything import test_engine
27
  from tokenize_anything.utils.image import im_rescale
28
  from tokenize_anything.utils.image import im_vstack
29
 
30
 
31
  def parse_args():
32
  """Parse arguments."""
33
- parser = argparse.ArgumentParser(description="Launch gradio app.")
34
  parser.add_argument("--model-type", type=str, default="tap_vit_l")
35
- parser.add_argument("--checkpoint", type=str, default="models/tap_vit_l_03f8ec.pkl")
36
  parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
37
- parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices.")
38
  return parser.parse_args()
39
 
40
 
@@ -94,7 +94,7 @@ class Predictor(object):
94
  # Generate captions.
95
  sem_tokens = outputs["sem_tokens"][mask_index].unsqueeze_(1)
96
  captions = self.model.generate_text(sem_tokens).reshape(batch_shape)
97
- # Postprecess results.
98
  results = []
99
  for i in range(batch_shape[0]):
100
  pred_h, pred_w = im_info[i, :2].astype("int")
@@ -227,7 +227,7 @@ if __name__ == "__main__":
227
  args = parse_args()
228
  queues = [mp.Queue(1024) for _ in range(len(args.device) + 1)]
229
  commands = [
230
- test_engine.InferenceCommand(
231
  queues[i],
232
  queues[-1],
233
  kwargs={
 
23
  import numpy as np
24
  import torch
25
 
26
+ from tokenize_anything import engine
27
  from tokenize_anything.utils.image import im_rescale
28
  from tokenize_anything.utils.image import im_vstack
29
 
30
 
31
  def parse_args():
32
  """Parse arguments."""
33
+ parser = argparse.ArgumentParser(description="Launch gradio application")
34
  parser.add_argument("--model-type", type=str, default="tap_vit_l")
35
+ parser.add_argument("--checkpoint", type=str, default="models/tap_vit_l_548184.pkl")
36
  parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
37
+ parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices")
38
  return parser.parse_args()
39
 
40
 
 
94
  # Generate captions.
95
  sem_tokens = outputs["sem_tokens"][mask_index].unsqueeze_(1)
96
  captions = self.model.generate_text(sem_tokens).reshape(batch_shape)
97
+ # Postprocess results.
98
  results = []
99
  for i in range(batch_shape[0]):
100
  pred_h, pred_w = im_info[i, :2].astype("int")
 
227
  args = parse_args()
228
  queues = [mp.Queue(1024) for _ in range(len(args.device) + 1)]
229
  commands = [
230
+ engine.InferenceCommand(
231
  queues[i],
232
  queues[-1],
233
  kwargs={
models/tap_vit_l_548184.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1d3a11c572af8cb6bce8016d3a6c6948bba4959ea43811f0e984b9eafeee413
3
+ size 811637521
tokenize_anything/__init__.py CHANGED
@@ -15,5 +15,5 @@
15
  # ------------------------------------------------------------------------
16
  """Tokenize Anything via Prompting."""
17
 
18
- from tokenize_anything.build_model import model_registry
19
  from tokenize_anything.version import __version__
 
15
  # ------------------------------------------------------------------------
16
  """Tokenize Anything via Prompting."""
17
 
18
+ from tokenize_anything.models import model_registry
19
  from tokenize_anything.version import __version__
tokenize_anything/engine/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Engine components."""
17
+
18
+ from tokenize_anything.engine.build import build_tensorboard
19
+ from tokenize_anything.engine.test_engine import InferenceCommand
20
+ from tokenize_anything.engine.utils import apply_ddp_group
21
+ from tokenize_anything.engine.utils import count_params
22
+ from tokenize_anything.engine.utils import create_ddp_group
23
+ from tokenize_anything.engine.utils import freeze_module
24
+ from tokenize_anything.engine.utils import get_ddp_group
25
+ from tokenize_anything.engine.utils import get_ddp_rank
26
+ from tokenize_anything.engine.utils import get_device
27
+ from tokenize_anything.engine.utils import get_param_groups
28
+ from tokenize_anything.engine.utils import load_weights
29
+ from tokenize_anything.engine.utils import manual_seed
tokenize_anything/engine/build.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Build for engine."""
17
+
18
+
19
+ def build_tensorboard(log_dir):
20
+ """Build the tensorboard."""
21
+ from tokenize_anything.utils.tensorboard import TensorBoard
22
+
23
+ if TensorBoard.is_available():
24
+ return TensorBoard(log_dir)
25
+ return None
tokenize_anything/engine/lr_scheduler.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Learning rate schedulers."""
17
+
18
+ import math
19
+
20
+
21
+ class ConstantLR(object):
22
+ """Constant LR scheduler."""
23
+
24
+ def __init__(self, **kwargs):
25
+ self._lr_max = kwargs.pop("lr_max")
26
+ self._lr_min = kwargs.pop("lr_min", 0)
27
+ self._warmup_steps = kwargs.pop("warmup_steps", 0)
28
+ self._warmup_factor = kwargs.pop("warmup_factor", 0)
29
+ if kwargs:
30
+ raise ValueError("Unexpected arguments: " + ",".join(v for v in kwargs))
31
+ self._step_count = 0
32
+ self._last_decay = 1.0
33
+
34
+ def step(self):
35
+ self._step_count += 1
36
+
37
+ def get_lr(self):
38
+ if self._step_count < self._warmup_steps:
39
+ alpha = (self._step_count + 1.0) / self._warmup_steps
40
+ return self._lr_max * (alpha + (1.0 - alpha) * self._warmup_factor)
41
+ return self._lr_min + (self._lr_max - self._lr_min) * self.get_decay()
42
+
43
+ def get_decay(self):
44
+ return self._last_decay
45
+
46
+
47
+ class CosineLR(ConstantLR):
48
+ """LR scheduler with cosine decay."""
49
+
50
+ def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
51
+ super(CosineLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
52
+ self._decay_step = decay_step
53
+ self._max_steps = max_steps
54
+
55
+ def get_decay(self):
56
+ t = self._step_count - self._warmup_steps
57
+ t_max = self._max_steps - self._warmup_steps
58
+ if t > 0 and t % self._decay_step == 0:
59
+ self._last_decay = 0.5 * (1.0 + math.cos(math.pi * t / t_max))
60
+ return self._last_decay
61
+
62
+
63
+ class LinearLR(ConstantLR):
64
+ """LR scheduler with linear decay."""
65
+
66
+ def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
67
+ super(LinearLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
68
+ self._decay_step = decay_step
69
+ self._max_steps = max_steps
70
+
71
+ def get_decay(self):
72
+ t = self._step_count - self._warmup_steps
73
+ t_max = self._max_steps - self._warmup_steps
74
+ if t > 0 and t % self._decay_step == 0:
75
+ self._last_decay = 1.0 - float(t) / t_max
76
+ return self._last_decay
tokenize_anything/{test_engine.py β†’ engine/test_engine.py} RENAMED
@@ -17,7 +17,7 @@
17
 
18
  import time
19
 
20
- from tokenize_anything.build_model import model_registry
21
 
22
 
23
  class InferenceCommand(object):
 
17
 
18
  import time
19
 
20
+ from tokenize_anything.models.easy_build import model_registry
21
 
22
 
23
  class InferenceCommand(object):
tokenize_anything/engine/utils.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Engine utilities."""
17
+
18
+ import collections
19
+ import functools
20
+ import pickle
21
+
22
+ import torch
23
+ import numpy as np
24
+
25
+ from tokenize_anything.utils import logging
26
+
27
+ GLOBAL_DDP_GROUP = None
28
+
29
+
30
+ def count_params(module, trainable=True, unit="M"):
31
+ """Return the number of parameters."""
32
+ counts = [v.size().numel() for v in module.parameters() if v.requires_grad or (not trainable)]
33
+ return sum(counts) / {"M": 1e6, "B": 1e9}[unit]
34
+
35
+
36
+ def freeze_module(module):
37
+ """Freeze parameters of given module."""
38
+ module.eval()
39
+ for param in module.parameters():
40
+ param.requires_grad = False
41
+
42
+
43
+ def get_device(index):
44
+ """Create the available device object."""
45
+ if torch.cuda.is_available():
46
+ return torch.device("cuda", index)
47
+ for device_type in ("mps",):
48
+ try:
49
+ if getattr(torch.backends, device_type).is_available():
50
+ return torch.device(device_type, index)
51
+ except AttributeError:
52
+ pass
53
+ return torch.device("cpu")
54
+
55
+
56
+ def get_param_groups(module, layer_lr_decay=1.0):
57
+ """Separate parameters into groups."""
58
+ memo, groups, inner = {}, collections.OrderedDict(), module
59
+ if isinstance(module, torch.nn.parallel.DistributedDataParallel):
60
+ inner = module.module
61
+ lr_scale_getter = None
62
+ if layer_lr_decay < 1.0 and hasattr(inner.image_encoder, "get_lr_scale"):
63
+ lr_scale_getter = functools.partial(inner.image_encoder.get_lr_scale, decay=layer_lr_decay)
64
+ for name, param in module.named_parameters():
65
+ if not param.requires_grad:
66
+ continue
67
+ attrs = collections.OrderedDict()
68
+ if lr_scale_getter:
69
+ attrs["lr_scale"] = lr_scale_getter(name)
70
+ memo[name] = param.shape
71
+ no_weight_decay = not (name.endswith("weight") and param.dim() > 1)
72
+ no_weight_decay = getattr(param, "no_weight_decay", no_weight_decay)
73
+ if no_weight_decay:
74
+ attrs["weight_decay"] = 0
75
+ group_name = "/".join(["%s:%s" % (v[0], v[1]) for v in list(attrs.items())])
76
+ if group_name not in groups:
77
+ groups[group_name] = {"params": []}
78
+ groups[group_name].update(attrs)
79
+ groups[group_name]["params"].append(param)
80
+ return list(groups.values())
81
+
82
+
83
+ def load_weights(module, weights_file, prefix_removed="", strict=True):
84
+ """Load a weights file."""
85
+ if not weights_file:
86
+ return
87
+ if weights_file.endswith(".pkl"):
88
+ with open(weights_file, "rb") as f:
89
+ state_dict = pickle.load(f)
90
+ for k, v in state_dict.items():
91
+ state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
92
+ else:
93
+ state_dict = torch.load(weights_file)
94
+ if prefix_removed:
95
+ new_state_dict = type(state_dict)()
96
+ for k in list(state_dict.keys()):
97
+ new_state_dict[k.replace(prefix_removed, "")] = state_dict.pop(k)
98
+ state_dict = new_state_dict
99
+ module.load_state_dict(state_dict, strict=strict)
100
+
101
+
102
+ def manual_seed(seed, device_and_seed=None):
103
+ """Set the cpu and device random seed."""
104
+ torch.manual_seed(seed)
105
+ if device_and_seed is not None:
106
+ device_index, device_seed = device_and_seed
107
+ device_type = get_device(device_index).type
108
+ np.random.seed(device_seed)
109
+ if device_type in ("cuda", "mps"):
110
+ getattr(torch, device_type).manual_seed(device_seed)
111
+
112
+
113
+ def synchronize_device(device):
114
+ """Synchronize the computation of device."""
115
+ if device.type in ("cuda", "mps"):
116
+ getattr(torch, device.type).synchronize(device)
117
+
118
+
119
+ def create_ddp_group(cfg, ranks=None, devices=None, num_nodes=1):
120
+ """Create group for data parallelism."""
121
+ if not torch.distributed.is_initialized():
122
+ torch.distributed.init_process_group(backend="nccl")
123
+ world_rank = torch.distributed.get_rank()
124
+ ranks = ranks if ranks else [i for i in range(cfg.NUM_GPUS)]
125
+ logging.set_root(world_rank == ranks[0])
126
+ devices_per_node = len(ranks) // num_nodes
127
+ devices = devices if devices else [i % devices_per_node for i in range(len(ranks))]
128
+ cfg.GPU_ID = devices[world_rank]
129
+ torch.cuda.set_device(cfg.GPU_ID)
130
+ global GLOBAL_DDP_GROUP
131
+ GLOBAL_DDP_GROUP = torch.distributed.new_group(ranks)
132
+ return GLOBAL_DDP_GROUP
133
+
134
+
135
+ def get_ddp_group():
136
+ """Return the process group for data parallelism."""
137
+ return GLOBAL_DDP_GROUP
138
+
139
+
140
+ def get_ddp_rank():
141
+ """Return the rank in the data parallelism group."""
142
+ ddp_group = get_ddp_group()
143
+ if ddp_group is None:
144
+ return 0
145
+ return torch.distributed.get_rank(ddp_group)
146
+
147
+
148
+ def apply_ddp_group(module):
149
+ """Apply data parallelism group for given module."""
150
+ ddp_group = get_ddp_group()
151
+ if ddp_group is None:
152
+ return module
153
+ return torch.nn.parallel.DistributedDataParallel(module, process_group=ddp_group)
tokenize_anything/layers/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Layers."""
17
+
18
+ from tokenize_anything.layers.drop import DropPath
19
+ from tokenize_anything.layers.loss import BinaryDiceLoss
20
+ from tokenize_anything.layers.loss import BinaryFocalLoss
21
+ from tokenize_anything.layers.loss import CrossEntropyLoss
22
+ from tokenize_anything.layers.utils import init_cross_conv
23
+ from tokenize_anything.layers.utils import resize_pos_embed
24
+ from tokenize_anything.layers.utils import set_dropout
25
+ from tokenize_anything.layers.utils import set_drop_path
26
+ from tokenize_anything.layers.utils import set_sync_batch_norm
tokenize_anything/layers/drop.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Drop regularization layers."""
17
+
18
+ from torch import nn
19
+
20
+
21
+ class DropPath(nn.Module):
22
+ """Set examples to zero randomly."""
23
+
24
+ def __init__(self, p=0.1, inplace=False):
25
+ super(DropPath, self).__init__()
26
+ self.p = p
27
+ self.inplace = inplace
28
+
29
+ def forward(self, input):
30
+ if not self.training or self.p <= 0:
31
+ return input
32
+ keep_p = 1 - self.p
33
+ shape = (input.shape[0],) + (1,) * (input.dim() - 1)
34
+ scale = input.new_empty(shape).bernoulli_(keep_p).div_(keep_p)
35
+ return input.mul_(scale) if self.inplace else input.mul(scale)
36
+
37
+ def extra_repr(self):
38
+ inplace_str = ", inplace" if self.inplace else ""
39
+ return "p={}{}".format(self.p, inplace_str)
tokenize_anything/layers/loss.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Loss layers."""
17
+
18
+ from torch import nn
19
+
20
+
21
+ def reduce_loss(loss, reduction="mean"):
22
+ """Reduce the loss."""
23
+ if reduction == "mean" or reduction == "sum":
24
+ return getattr(loss, reduction)()
25
+ if reduction == "batch_mean":
26
+ return loss.sum().mul_(1.0 / loss.size(0))
27
+ return loss
28
+
29
+
30
+ class BinaryFocalLoss(nn.Module):
31
+ """Binary focal loss."""
32
+
33
+ def __init__(self, alpha=0.25, reduction="none"):
34
+ super(BinaryFocalLoss, self).__init__()
35
+ self.alpha = alpha
36
+ self.reduction = reduction
37
+
38
+ def forward(self, input, target):
39
+ alpha, p = self.alpha, input.sigmoid()
40
+ neg_alpha, neg_target = 1.0 - alpha, 1.0 - target
41
+ alpha_weight = target.mul(alpha).add_(neg_target.mul(neg_alpha))
42
+ focal_weight = (1.0 - p).mul_(target).add_(p.mul(neg_target)).square()
43
+ loss = nn.functional.binary_cross_entropy_with_logits(input, target, reduction="none")
44
+ return reduce_loss(loss * focal_weight.mul_(alpha_weight), self.reduction)
45
+
46
+
47
+ class BinaryDiceLoss(nn.Module):
48
+ """Binary dice loss."""
49
+
50
+ def __init__(self, eps=1.0, reduction="none"):
51
+ super(BinaryDiceLoss, self).__init__()
52
+ self.eps = eps
53
+ self.reduction = reduction
54
+
55
+ def forward(self, input, target):
56
+ input = input.sigmoid()
57
+ num = input.mul(target).sum(-1).mul_(2).add_(self.eps)
58
+ den = input.add(target).sum(-1).add_(self.eps)
59
+ return reduce_loss(1.0 - num / den, self.reduction)
60
+
61
+
62
+ class CrossEntropyLoss(nn.Module):
63
+ """Cross entropy loss with label smoothing."""
64
+
65
+ def __init__(self, epsilon=0, reduction="none"):
66
+ super(CrossEntropyLoss, self).__init__()
67
+ self.epsilon = epsilon
68
+ self.reduction = reduction
69
+
70
+ def forward_dense(self, input, target):
71
+ dim, target = input.shape[-1], target.squeeze_()
72
+ x = nn.functional.log_softmax(input, dim=-1)
73
+ y = nn.functional.one_hot(target, dim).float()
74
+ x = x.permute([0, x.dim() - 1] + list(range(x.dim()))[1:-1]) if x.dim() > 2 else x
75
+ y = y.permute([0, y.dim() - 1] + list(range(y.dim()))[1:-1]) if y.dim() > 2 else y
76
+ loss = nn.functional.cross_entropy(x, y, reduction="none", label_smoothing=self.epsilon)
77
+ return reduce_loss(loss, self.reduction)
78
+
79
+ def forward(self, input, target):
80
+ if self.epsilon > 0:
81
+ return self.forward_dense(input, target)
82
+ return nn.functional.cross_entropy(input, target, reduction=self.reduction)
tokenize_anything/layers/utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Layer utilities."""
17
+
18
+ import cv2
19
+ import numpy as np
20
+ import torch
21
+
22
+
23
+ def init_cross_conv(blocks):
24
+ """Initialize convolutional cross attention."""
25
+ for m in blocks.modules():
26
+ if isinstance(m, torch.nn.Conv2d):
27
+ torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
28
+ for blk in blocks:
29
+ torch.nn.init.constant_(blk.norm3.weight, 0)
30
+
31
+
32
+ def set_dropout(module, dropout):
33
+ """Initialize dropout."""
34
+ for m in [m for m in module.modules() if isinstance(m, torch.nn.Dropout)]:
35
+ m.p = dropout
36
+
37
+
38
+ def set_drop_path(blocks, drop_path):
39
+ """Initialize drop path."""
40
+ if not isinstance(blocks, torch.nn.ModuleList):
41
+ blocks = getattr(blocks, "blocks", getattr(blocks, "layers", None))
42
+ for i, blk in enumerate(blocks):
43
+ for m in [m for m in blk.modules() if type(m).__name__ == "DropPath"]:
44
+ m.p = i * drop_path / (len(blocks) - 1)
45
+
46
+
47
+ def set_sync_batch_norm(module, ddp_group):
48
+ """Set data parallelism group for sync batch norm."""
49
+ for m in module.modules():
50
+ if isinstance(m, torch.nn.SyncBatchNorm):
51
+ m.process_group = ddp_group
52
+
53
+
54
+ def resize_pos_embed(weight, out_len):
55
+ """Resize position embedding weights."""
56
+ out_h = out_w = int(out_len**0.5)
57
+ h = w = int(weight.shape[0] ** 0.5)
58
+ weight = weight.reshape((h, w, weight.shape[1]))
59
+ out_weight = [
60
+ cv2.resize(x, (out_w, out_h), interpolation=cv2.INTER_CUBIC)
61
+ for x in np.split(weight.astype("float32", copy=False), 4, axis=-1)
62
+ ]
63
+ out_weight = np.concatenate(out_weight, axis=-1)
64
+ return out_weight.reshape((-1, weight.shape[-1])).astype(weight.dtype, copy=False)
tokenize_anything/modeling/concept_projector.py CHANGED
@@ -51,11 +51,11 @@ class ConceptProjector(nn.Module):
51
  proj = proj.to(device=embeds.device)
52
  return embeds, proj
53
 
54
- def encode_src(self, src_embeds):
55
  """Encode source visual embedding via concept projection."""
56
  src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
57
  logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
58
- return nn.functional.log_softmax(logits, dim=-1)
59
 
60
  def encode_tgt(self, tgt_embeds):
61
  """Encode target visual embedding via concept projection."""
 
51
  proj = proj.to(device=embeds.device)
52
  return embeds, proj
53
 
54
+ def encode_src(self, src_embeds, logpi=True):
55
  """Encode source visual embedding via concept projection."""
56
  src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
57
  logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
58
+ return nn.functional.log_softmax(logits, dim=-1) if logpi else logits
59
 
60
  def encode_tgt(self, tgt_embeds):
61
  """Encode target visual embedding via concept projection."""
tokenize_anything/modeling/image_decoder.py CHANGED
@@ -76,7 +76,6 @@ class Block(nn.Module):
76
  num_heads=8,
77
  attn_ratio=0.5,
78
  mlp_dim=2048,
79
- dropout=0.1,
80
  activation_type="ReLU",
81
  skip_first_query_pos=False,
82
  ):
@@ -89,7 +88,7 @@ class Block(nn.Module):
89
  self.norm3 = nn.LayerNorm(dim)
90
  self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
91
  self.norm4 = nn.LayerNorm(dim)
92
- self.dropout = nn.Dropout(dropout, inplace=True)
93
  self.skip_first_query_pos = skip_first_query_pos
94
 
95
  def forward(self, query, key, query_pos, key_pos):
@@ -115,7 +114,6 @@ class Transformer(nn.Module):
115
  num_heads=8,
116
  attn_ratio=0.5,
117
  mlp_dim=2048,
118
- dropout=0.1,
119
  activation_type="ReLU",
120
  depth=2,
121
  ):
@@ -126,7 +124,6 @@ class Transformer(nn.Module):
126
  num_heads,
127
  attn_ratio=attn_ratio,
128
  mlp_dim=mlp_dim,
129
- dropout=dropout,
130
  activation_type=activation_type,
131
  skip_first_query_pos=i == 0,
132
  )
@@ -134,7 +131,7 @@ class Transformer(nn.Module):
134
  )
135
  self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
136
  self.norm = nn.LayerNorm(embed_dim)
137
- self.dropout = nn.Dropout(dropout, inplace=True)
138
 
139
  def forward(self, query, key, query_pos, key_pos):
140
  for blk in self.blocks:
@@ -202,7 +199,7 @@ class ImageDecoder(nn.Module):
202
  query, key = self.transformer(query, key, query, inputs["img_pos"])
203
  # Upscale key.
204
  key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
205
- output_masks = self.output_conv(key).flatten(2)
206
  # Unpack query.
207
  tokens = query[:, :num_tokens].unbind(dim=1)
208
  iou_tokens = tokens[num_tokens - self.num_mask_tokens - 1]
@@ -210,7 +207,7 @@ class ImageDecoder(nn.Module):
210
  sem_tokens = tokens[: self.num_mask_tokens]
211
  # Predict.
212
  mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
213
- mask_pred = torch.stack(mask_pred, dim=1) @ output_masks
214
  mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
215
  mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
216
  outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
 
76
  num_heads=8,
77
  attn_ratio=0.5,
78
  mlp_dim=2048,
 
79
  activation_type="ReLU",
80
  skip_first_query_pos=False,
81
  ):
 
88
  self.norm3 = nn.LayerNorm(dim)
89
  self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
90
  self.norm4 = nn.LayerNorm(dim)
91
+ self.dropout = nn.Dropout(0.1, inplace=True)
92
  self.skip_first_query_pos = skip_first_query_pos
93
 
94
  def forward(self, query, key, query_pos, key_pos):
 
114
  num_heads=8,
115
  attn_ratio=0.5,
116
  mlp_dim=2048,
 
117
  activation_type="ReLU",
118
  depth=2,
119
  ):
 
124
  num_heads,
125
  attn_ratio=attn_ratio,
126
  mlp_dim=mlp_dim,
 
127
  activation_type=activation_type,
128
  skip_first_query_pos=i == 0,
129
  )
 
131
  )
132
  self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
133
  self.norm = nn.LayerNorm(embed_dim)
134
+ self.dropout = nn.Dropout(0.1, inplace=True)
135
 
136
  def forward(self, query, key, query_pos, key_pos):
137
  for blk in self.blocks:
 
199
  query, key = self.transformer(query, key, query, inputs["img_pos"])
200
  # Upscale key.
201
  key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
202
+ mask_embeds = self.output_conv(key).flatten(2)
203
  # Unpack query.
204
  tokens = query[:, :num_tokens].unbind(dim=1)
205
  iou_tokens = tokens[num_tokens - self.num_mask_tokens - 1]
 
207
  sem_tokens = tokens[: self.num_mask_tokens]
208
  # Predict.
209
  mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
210
+ mask_pred = torch.stack(mask_pred, dim=1) @ mask_embeds
211
  mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
212
  mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
213
  outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
tokenize_anything/modeling/image_encoder.py CHANGED
@@ -17,6 +17,8 @@
17
  import torch
18
  from torch import nn
19
 
 
 
20
 
21
  def space_to_depth(input, block_size):
22
  """Rearrange blocks of spatial data into depth."""
@@ -84,10 +86,11 @@ class Block(nn.Module):
84
  self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
85
  self.norm2 = nn.LayerNorm(dim)
86
  self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
 
87
 
88
  def forward(self, x):
89
- x = self.attn(self.norm1(x)).add_(x)
90
- return self.mlp(self.norm2(x)).add_(x)
91
 
92
 
93
  class Bottleneck(nn.Module):
@@ -245,7 +248,7 @@ class ImageEncoderViT(nn.Module):
245
  if i in self.cross_indices or i == len(self.blocks) - 1:
246
  x = self.norm(x) if i == len(self.blocks) - 1 else x
247
  x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
248
- x = x.permute(0, 3, 1, 2)
249
  if i in self.cross_indices:
250
  x = self.cross_conv[self.cross_indices.index(i)](x)
251
  if i in self.cross_indices and i < len(self.blocks) - 1:
 
17
  import torch
18
  from torch import nn
19
 
20
+ from tokenize_anything import layers
21
+
22
 
23
  def space_to_depth(input, block_size):
24
  """Rearrange blocks of spatial data into depth."""
 
86
  self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
87
  self.norm2 = nn.LayerNorm(dim)
88
  self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
89
+ self.drop_path = layers.DropPath(0.1, inplace=True)
90
 
91
  def forward(self, x):
92
+ x = self.drop_path(self.attn(self.norm1(x))).add_(x)
93
+ return self.drop_path(self.mlp(self.norm2(x))).add_(x)
94
 
95
 
96
  class Bottleneck(nn.Module):
 
248
  if i in self.cross_indices or i == len(self.blocks) - 1:
249
  x = self.norm(x) if i == len(self.blocks) - 1 else x
250
  x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
251
+ x = x.permute(0, 3, 1, 2).contiguous()
252
  if i in self.cross_indices:
253
  x = self.cross_conv[self.cross_indices.index(i)](x)
254
  if i in self.cross_indices and i < len(self.blocks) - 1:
tokenize_anything/modeling/image_tokenizer.py CHANGED
@@ -45,13 +45,15 @@ class ImageTokenizer(nn.Module):
45
  self.register_buffer("pixel_mean", torch.Tensor(pixel_mean))
46
  self.register_buffer("pixel_rsig", torch.Tensor(pixel_std).reciprocal_())
47
 
48
- def get_inputs(self, inputs):
49
  """Return the model inputs.
50
 
51
  Parameters
52
  ----------
53
  inputs : dict
54
  The initial inputs.
 
 
55
 
56
  Returns
57
  -------
@@ -59,13 +61,10 @@ class ImageTokenizer(nn.Module):
59
  The model inputs.
60
 
61
  """
62
- if not isinstance(inputs["img"], torch.Tensor):
63
- inputs["img"] = torch.from_numpy(inputs["img"])
64
- if inputs["img"].device != self.pixel_mean.device:
65
- inputs["img"] = inputs["img"].to(device=self.pixel_mean.device)
66
- inputs["img"] = inputs["img"].to(dtype=self.pixel_mean.dtype)
67
- inputs["img"] = inputs["img"].sub(self.pixel_mean).mul_(self.pixel_rsig)
68
- inputs["img"] = inputs["img"].permute(0, 3, 1, 2)
69
  return inputs
70
 
71
  def get_features(self, inputs):
@@ -177,7 +176,7 @@ class ImageTokenizer(nn.Module):
177
  An array of generated texts.
178
 
179
  """
180
- max_gen_len = max_gen_len or self.text_decoder.max_seq_len
181
  prompts = self.text_decoder.get_prompts(visual_tokens)
182
  out_shape = (prompts.size(0), self.text_decoder.max_text_len)
183
  tokens = np.full(out_shape, self.text_tokenizer.pad_id, "int64")
 
45
  self.register_buffer("pixel_mean", torch.Tensor(pixel_mean))
46
  self.register_buffer("pixel_rsig", torch.Tensor(pixel_std).reciprocal_())
47
 
48
+ def get_inputs(self, inputs, dtype=None):
49
  """Return the model inputs.
50
 
51
  Parameters
52
  ----------
53
  inputs : dict
54
  The initial inputs.
55
+ dtype : torch.dtype, optional
56
+ The optional input dtype.
57
 
58
  Returns
59
  -------
 
61
  The model inputs.
62
 
63
  """
64
+ img_dtype, img_device = self.pixel_mean.dtype, self.pixel_mean.device
65
+ inputs["img"] = torch.as_tensor(inputs["img"], dtype=img_dtype, device=img_device)
66
+ inputs["img"] = inputs["img"].sub(self.pixel_mean).mul_(self.pixel_rsig).permute(0, 3, 1, 2)
67
+ inputs["img"] = inputs["img"].to(dtype=dtype) if dtype else inputs["img"]
 
 
 
68
  return inputs
69
 
70
  def get_features(self, inputs):
 
176
  An array of generated texts.
177
 
178
  """
179
+ max_gen_len = max_gen_len or self.text_decoder.max_text_len
180
  prompts = self.text_decoder.get_prompts(visual_tokens)
181
  out_shape = (prompts.size(0), self.text_decoder.max_text_len)
182
  tokens = np.full(out_shape, self.text_tokenizer.pad_id, "int64")
tokenize_anything/modeling/text_decoder.py CHANGED
@@ -79,6 +79,7 @@ class TransformerCache(nn.Module):
79
  cache_v = self.cache_dict.get(f"{id(mixer)}_v", None)
80
  flash_args = {"softmax_scale": mixer.scale, "causal": True}
81
  if cache_k is None or cache_v is None:
 
82
  return flash_attn_func(q, k, v, **flash_args)
83
  flash_args["cache_seqlens"] = self.cache_dict["seq_lens"][: q.shape[0]]
84
  return flash_attn_with_kvcache(q, cache_k, cache_v, k, v, **flash_args)
@@ -94,6 +95,7 @@ class Attention(nn.Module):
94
  self.head_dim = dim // num_heads
95
  self.num_heads = num_heads
96
  self.scale = self.head_dim**-0.5
 
97
  self.cache = nn.Module()
98
 
99
  def forward(self, x):
@@ -126,10 +128,11 @@ class Block(nn.Module):
126
  self.mlp = MLP(dim, mlp_dim, bias=bias)
127
  self.norm1 = nn.LayerNorm(dim)
128
  self.norm2 = nn.LayerNorm(dim)
 
129
 
130
  def forward(self, x):
131
- x = self.attn(self.norm1(x)).add_(x)
132
- return self.mlp(self.norm2(x)).add_(x)
133
 
134
 
135
  class Transformer(nn.Module):
 
79
  cache_v = self.cache_dict.get(f"{id(mixer)}_v", None)
80
  flash_args = {"softmax_scale": mixer.scale, "causal": True}
81
  if cache_k is None or cache_v is None:
82
+ flash_args["dropout_p"] = mixer.dropout.p if mixer.training else 0
83
  return flash_attn_func(q, k, v, **flash_args)
84
  flash_args["cache_seqlens"] = self.cache_dict["seq_lens"][: q.shape[0]]
85
  return flash_attn_with_kvcache(q, cache_k, cache_v, k, v, **flash_args)
 
95
  self.head_dim = dim // num_heads
96
  self.num_heads = num_heads
97
  self.scale = self.head_dim**-0.5
98
+ self.dropout = nn.Dropout(0.1, inplace=False)
99
  self.cache = nn.Module()
100
 
101
  def forward(self, x):
 
128
  self.mlp = MLP(dim, mlp_dim, bias=bias)
129
  self.norm1 = nn.LayerNorm(dim)
130
  self.norm2 = nn.LayerNorm(dim)
131
+ self.dropout = nn.Dropout(0.1, inplace=True)
132
 
133
  def forward(self, x):
134
+ x = self.dropout(self.attn(self.norm1(x))).add_(x)
135
+ return self.dropout(self.mlp(self.norm2(x))).add_(x)
136
 
137
 
138
  class Transformer(nn.Module):
tokenize_anything/models/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Models."""
17
+
18
+ from tokenize_anything.models.easy_build import model_registry
tokenize_anything/{build_model.py β†’ models/easy_build.py} RENAMED
@@ -13,7 +13,7 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  # ------------------------------------------------------------------------
16
- """Build model."""
17
 
18
  from functools import partial
19
  import pickle
@@ -40,7 +40,7 @@ def get_device(device_index):
40
  def load_weights(module, weights_file, strict=True):
41
  """Load a weights file."""
42
  if not weights_file:
43
- return module._IncompatibleKeys([], [])
44
  if weights_file.endswith(".pkl"):
45
  with open(weights_file, "rb") as f:
46
  state_dict = pickle.load(f)
@@ -48,7 +48,7 @@ def load_weights(module, weights_file, strict=True):
48
  state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
49
  else:
50
  state_dict = torch.load(weights_file)
51
- return module.load_state_dict(state_dict, strict=strict)
52
 
53
 
54
  def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  # ------------------------------------------------------------------------
16
+ """Easy model builder."""
17
 
18
  from functools import partial
19
  import pickle
 
40
  def load_weights(module, weights_file, strict=True):
41
  """Load a weights file."""
42
  if not weights_file:
43
+ return
44
  if weights_file.endswith(".pkl"):
45
  with open(weights_file, "rb") as f:
46
  state_dict = pickle.load(f)
 
48
  state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
49
  else:
50
  state_dict = torch.load(weights_file)
51
+ module.load_state_dict(state_dict, strict=strict)
52
 
53
 
54
  def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
tokenize_anything/prompters/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Prompters."""
17
+
18
+ from tokenize_anything.prompters.visual_prompter import VisualPrompter
tokenize_anything/prompters/visual_prompter.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Generate visual prompts."""
17
+
18
+ import collections
19
+
20
+ import numpy as np
21
+ import numpy.random as npr
22
+
23
+
24
+ class VisualPrompter(object):
25
+ """Generate visual prompts."""
26
+
27
+ def __init__(self, image_size=1024, max_points=9, num_experts=4, padding_index=4):
28
+ super(VisualPrompter, self).__init__()
29
+ self.num_stages = 2
30
+ self.max_points = max_points
31
+ self.point_weight = [1000] + [0] * (num_experts - 1)
32
+ self.image_size = image_size if isinstance(image_size, (tuple, list)) else [image_size] * 2
33
+ self.padding_index = padding_index
34
+ self.coord_count = collections.defaultdict(int)
35
+ self.coords = self.labels = self.boxes_turn = None
36
+ self.stage_count = 0
37
+ self.box_prob = 0.5
38
+
39
+ @property
40
+ def is_last_stage(self):
41
+ return self.stage_count == self.num_stages - 1
42
+
43
+ def add_point(self, index, gt_masks, error_masks=None, num=1):
44
+ def sample(mask):
45
+ ys, xs = np.nonzero(mask)
46
+ if ys.shape[0] > 0:
47
+ idx = npr.choice(ys.shape[0], size=(num,), replace=num > ys.shape[0])
48
+ return xs[idx], ys[idx]
49
+ return [-0.5] * num, [-0.5] * num
50
+
51
+ labels = [self.padding_index] * num
52
+ if error_masks is not None: # FP or FN point.
53
+ xs, ys = sample(error_masks[index])
54
+ labels = gt_masks[index, ys, xs] if ys[0] >= 0 else labels
55
+ if labels[0] == self.padding_index: # GT point.
56
+ xs, ys = sample(gt_masks[index])
57
+ labels = [1] * num if ys[0] >= 0 else labels
58
+ xs = (np.array(xs, "float32") + 0.5) * (self.image_size[1] / gt_masks.shape[2]) - 0.5
59
+ ys = (np.array(ys, "float32") + 0.5) * (self.image_size[0] / gt_masks.shape[1]) - 0.5
60
+ slice_index = slice(self.coord_count[index], self.coord_count[index] + num)
61
+ self.coords[index, slice_index] = np.vstack([xs, ys]).T
62
+ self.labels[index, slice_index] = labels
63
+ self.coord_count[index] += num
64
+
65
+ def add_box(self, index, gt_boxes):
66
+ x1, y1, x2, y2 = gt_boxes[index, :4]
67
+ dx1, dx2 = np.clip(npr.normal(0.0, 0.1 * (x2 - x1), (2,)), -20, 20)
68
+ dy1, dy2 = np.clip(npr.normal(0.0, 0.1 * (y2 - y1), (2,)), -20, 20)
69
+ x1, y1 = x1 + np.minimum(dx1, 0), y1 + np.minimum(dy1, 0)
70
+ x2, y2 = x2 + np.maximum(dx2, 0), y2 + np.maximum(dy2, 0)
71
+ self.coords[index, self.coord_count[index]] = (x1, y1)
72
+ self.coords[index, self.coord_count[index] + 1] = (x2, y2)
73
+ self.labels[index, self.coord_count[index]] = 2
74
+ self.labels[index, self.coord_count[index] + 1] = 3
75
+ self.coord_count[index] += 2
76
+
77
+ def reset(self, num):
78
+ self.stage_count = 0
79
+ self.coord_count.clear()
80
+ self.coords = np.full((num, self.max_points + 1, 2), -0.5, "float32")
81
+ self.labels = np.full((num, self.max_points + 1), self.padding_index, "int64")
82
+ self.boxes_turn = npr.rand(num) < self.box_prob
83
+
84
+ def get_prompts(self, gt_boxes, gt_masks=None, masks=None):
85
+ num = gt_boxes.shape[0]
86
+ if self.stage_count == 0:
87
+ self.reset(num)
88
+ coords = labels = error_masks = None
89
+ if masks is not None:
90
+ masks = masks.reshape(gt_masks.shape)
91
+ error_masks = (masks | gt_masks) ^ (masks & gt_masks)
92
+ num_points = 1
93
+ if self.stage_count > 0:
94
+ num_points = npr.randint(1, self.max_points + 1 - self.stage_count)
95
+ if self.stage_count == 0 and self.box_prob == 0:
96
+ num_points = npr.randint(2, self.max_points + 1)
97
+ for index in range(num):
98
+ is_box = self.stage_count == 0 and self.boxes_turn[index]
99
+ if gt_masks is None or is_box:
100
+ self.add_box(index, gt_boxes)
101
+ else:
102
+ self.add_point(index, gt_masks, error_masks, num_points)
103
+ coords = self.coords[:, : 1 + self.stage_count + num_points]
104
+ labels = self.labels[:, : 1 + self.stage_count + num_points]
105
+ scores = (self.boxes_turn[:, None] - 0.5) * self.point_weight
106
+ return {"points": (coords, labels), "point_score": scores}
tokenize_anything/utils/logging.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Logging utilities."""
17
+
18
+ import inspect
19
+ import logging as _logging
20
+ import os
21
+ import sys as _sys
22
+ import threading
23
+
24
+
25
+ _logger = None
26
+ _logger_lock = threading.Lock()
27
+
28
+
29
+ def get_logger():
30
+ global _logger
31
+ # Use double-checked locking to avoid taking lock unnecessarily.
32
+ if _logger:
33
+ return _logger
34
+ _logger_lock.acquire()
35
+ try:
36
+ if _logger:
37
+ return _logger
38
+ logger = _logging.getLogger("tokenize-anything")
39
+ logger.setLevel("INFO")
40
+ logger.propagate = False
41
+ logger._is_root = True
42
+ if True:
43
+ # Determine whether we are in an interactive environment.
44
+ _interactive = False
45
+ try:
46
+ # This is only defined in interactive shells.
47
+ if _sys.ps1:
48
+ _interactive = True
49
+ except AttributeError:
50
+ # Even now, we may be in an interactive shell with `python -i`.
51
+ _interactive = _sys.flags.interactive
52
+ # If we are in an interactive environment (like Jupyter), set loglevel
53
+ # to INFO and pipe the output to stdout.
54
+ if _interactive:
55
+ logger.setLevel("INFO")
56
+ _logging_target = _sys.stdout
57
+ else:
58
+ _logging_target = _sys.stderr
59
+ # Add the output handler.
60
+ _handler = _logging.StreamHandler(_logging_target)
61
+ _handler.setFormatter(_logging.Formatter("%(levelname)s %(message)s"))
62
+ logger.addHandler(_handler)
63
+ _logger = logger
64
+ return _logger
65
+ finally:
66
+ _logger_lock.release()
67
+
68
+
69
+ def _detailed_msg(msg):
70
+ file, lineno = inspect.stack()[:3][2][1:3]
71
+ return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
72
+
73
+
74
+ def log(level, msg, *args, **kwargs):
75
+ get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
76
+
77
+
78
+ def debug(msg, *args, **kwargs):
79
+ if is_root():
80
+ get_logger().debug(_detailed_msg(msg), *args, **kwargs)
81
+
82
+
83
+ def error(msg, *args, **kwargs):
84
+ get_logger().error(_detailed_msg(msg), *args, **kwargs)
85
+ assert 0
86
+
87
+
88
+ def fatal(msg, *args, **kwargs):
89
+ get_logger().fatal(_detailed_msg(msg), *args, **kwargs)
90
+ assert 0
91
+
92
+
93
+ def info(msg, *args, **kwargs):
94
+ if is_root():
95
+ get_logger().info(_detailed_msg(msg), *args, **kwargs)
96
+
97
+
98
+ def warning(msg, *args, **kwargs):
99
+ if is_root():
100
+ get_logger().warning(_detailed_msg(msg), *args, **kwargs)
101
+
102
+
103
+ def get_verbosity():
104
+ """Return how much logging output will be produced."""
105
+ return get_logger().getEffectiveLevel()
106
+
107
+
108
+ def set_verbosity(v):
109
+ """Set the threshold for what messages will be logged."""
110
+ get_logger().setLevel(v)
111
+
112
+
113
+ def set_formatter(fmt=None, datefmt=None):
114
+ """Set the formatter."""
115
+ handler = _logging.StreamHandler(_sys.stderr)
116
+ handler.setFormatter(_logging.Formatter(fmt, datefmt))
117
+ logger = get_logger()
118
+ logger.removeHandler(logger.handlers[0])
119
+ logger.addHandler(handler)
120
+
121
+
122
+ def set_root(is_root=True):
123
+ """Set logger to the root."""
124
+ get_logger()._is_root = is_root
125
+
126
+
127
+ def is_root():
128
+ """Return logger is the root."""
129
+ return get_logger()._is_root
tokenize_anything/utils/profiler/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Profiler utilities."""
17
+
18
+ from tokenize_anything.utils.profiler.stats import SmoothedValue
19
+ from tokenize_anything.utils.profiler.timer import Timer
20
+ from tokenize_anything.utils.profiler.timer import get_progress
tokenize_anything/utils/profiler/stats.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Trackable statistics."""
17
+
18
+ import collections
19
+ import numpy as np
20
+
21
+
22
+ class SmoothedValue(object):
23
+ """Track values and provide smoothed report."""
24
+
25
+ def __init__(self, window_size=None):
26
+ self.deque = collections.deque(maxlen=window_size)
27
+ self.total = 0.0
28
+ self.count = 0
29
+
30
+ def update(self, value):
31
+ self.deque.append(value)
32
+ self.count += 1
33
+ self.total += value
34
+
35
+ def mean(self):
36
+ return np.mean(self.deque)
37
+
38
+ def median(self):
39
+ return np.median(self.deque)
40
+
41
+ def average(self):
42
+ return self.total / self.count
tokenize_anything/utils/{timer.py β†’ profiler/timer.py} RENAMED
@@ -9,13 +9,14 @@
9
  #
10
  # Unless required by applicable law or agreed to in writing, software
11
  # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  # ------------------------------------------------------------------------
16
  """Timing functions."""
17
 
18
  import contextlib
 
19
  import time
20
 
21
 
@@ -49,3 +50,13 @@ class Timer(object):
49
  def toc(self, n=1, average=True):
50
  self.diff = time.time() - self.start_time
51
  return self.add_diff(self.diff, n, average)
 
 
 
 
 
 
 
 
 
 
 
9
  #
10
  # Unless required by applicable law or agreed to in writing, software
11
  # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  # ------------------------------------------------------------------------
16
  """Timing functions."""
17
 
18
  import contextlib
19
+ import datetime
20
  import time
21
 
22
 
 
50
  def toc(self, n=1, average=True):
51
  self.diff = time.time() - self.start_time
52
  return self.add_diff(self.diff, n, average)
53
+
54
+
55
+ def get_progress(timer, step, max_steps):
56
+ """Return the progress information."""
57
+ eta_seconds = timer.average_time * (max_steps - step)
58
+ eta = str(datetime.timedelta(seconds=int(eta_seconds)))
59
+ progress = (step + 1.0) / max_steps
60
+ return "< PROGRESS: {:.2%} | SPEED: {:.3f}s / iter | ETA: {} >".format(
61
+ progress, timer.average_time, eta
62
+ )
tokenize_anything/utils/registry.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Registry utilities."""
17
+
18
+ import collections
19
+ import functools
20
+
21
+
22
+ class Registry(object):
23
+ """Registry class."""
24
+
25
+ def __init__(self, name):
26
+ self.name = name
27
+ self.registry = collections.OrderedDict()
28
+
29
+ def has(self, key):
30
+ return key in self.registry
31
+
32
+ def register(self, name, func=None, **kwargs):
33
+ def decorated(inner_function):
34
+ for key in name if isinstance(name, (tuple, list)) else [name]:
35
+ self.registry[key] = functools.partial(inner_function, **kwargs)
36
+ return inner_function
37
+
38
+ if func is not None:
39
+ return decorated(func)
40
+ return decorated
41
+
42
+ def get(self, name, default=None):
43
+ if name is None:
44
+ return None
45
+ if not self.has(name):
46
+ if default is not None:
47
+ return default
48
+ raise KeyError("`%s` is not registered in <%s>." % (name, self.name))
49
+ return self.registry[name]
50
+
51
+ def try_get(self, name):
52
+ if self.has(name):
53
+ return self.get(name)
54
+ return None
tokenize_anything/utils/tensorboard.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Tensorboard application."""
17
+
18
+ import time
19
+
20
+ import numpy as np
21
+
22
+ try:
23
+ import tensorflow as tf
24
+ except ImportError:
25
+ tf = None
26
+
27
+
28
+ class TensorBoard(object):
29
+ """TensorBoard application."""
30
+
31
+ def __init__(self, log_dir=None):
32
+ """Create a summary writer logging to log_dir."""
33
+ if tf is None:
34
+ raise ImportError("Failed to import ``tensorflow`` package.")
35
+ tf.config.set_visible_devices([], "GPU")
36
+ if log_dir is None:
37
+ log_dir = "./logs/" + time.strftime("%Y%m%d_%H%M%S", time.localtime(time.time()))
38
+ self.writer = tf.summary.create_file_writer(log_dir)
39
+
40
+ @staticmethod
41
+ def is_available():
42
+ """Return if tensor board is available."""
43
+ return tf is not None
44
+
45
+ def close(self):
46
+ """Close board and apply all cached summaries."""
47
+ self.writer.close()
48
+
49
+ def histogram_summary(self, tag, values, step, buckets=10):
50
+ """Write a histogram of values."""
51
+ with self.writer.as_default():
52
+ tf.summary.histogram(tag, values, step, buckets=buckets)
53
+
54
+ def image_summary(self, tag, images, step, order="BGR"):
55
+ """Write a list of images."""
56
+ if isinstance(images, (tuple, list)):
57
+ images = np.stack(images)
58
+ if len(images.shape) != 4:
59
+ raise ValueError("Images can not be packed to (N, H, W, C).")
60
+ if order == "BGR":
61
+ images = images[:, :, :, ::-1]
62
+ with self.writer.as_default():
63
+ tf.summary.image(tag, images, step, max_outputs=images.shape[0])
64
+
65
+ def scalar_summary(self, tag, value, step):
66
+ """Write a scalar."""
67
+ with self.writer.as_default():
68
+ tf.summary.scalar(tag, value, step)