Upload 22 files
- sam_extension/distillation_models/__init__.py +4 -0
- sam_extension/distillation_models/__pycache__/__init__.cpython-38.pyc +0 -0
- sam_extension/distillation_models/__pycache__/dino.cpython-38.pyc +0 -0
- sam_extension/distillation_models/__pycache__/fastertinyvit.cpython-38.pyc +0 -0
- sam_extension/distillation_models/__pycache__/fastervit.cpython-38.pyc +0 -0
- sam_extension/distillation_models/__pycache__/sam.cpython-38.pyc +0 -0
- sam_extension/distillation_models/dino.py +122 -0
- sam_extension/distillation_models/fastertinyvit.py +233 -0
- sam_extension/distillation_models/fastervit.py +659 -0
- sam_extension/distillation_models/sam.py +369 -0
- sam_extension/pipeline/__init__.py +4 -0
- sam_extension/pipeline/__pycache__/__init__.cpython-38.pyc +0 -0
- sam_extension/pipeline/__pycache__/base.cpython-38.pyc +0 -0
- sam_extension/pipeline/__pycache__/groundingdino.cpython-38.pyc +0 -0
- sam_extension/pipeline/__pycache__/owlvit.cpython-38.pyc +0 -0
- sam_extension/pipeline/__pycache__/sam.cpython-38.pyc +0 -0
- sam_extension/pipeline/base.py +20 -0
- sam_extension/pipeline/groundingdino.py +97 -0
- sam_extension/pipeline/owlvit.py +372 -0
- sam_extension/pipeline/sam.py +722 -0
- sam_extension/utils/__init__.py +175 -0
- sam_extension/utils/__pycache__/__init__.cpython-38.pyc +0 -0
sam_extension/distillation_models/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .dino import DINO
from .sam import SAMEncoderViT, DINOSAMViT
from .fastertinyvit import FasterTinyViT
# from .flashvision_transformer import FlashVisionTransformer
sam_extension/distillation_models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (322 Bytes).
sam_extension/distillation_models/__pycache__/dino.cpython-38.pyc
ADDED
Binary file (4.72 kB).
sam_extension/distillation_models/__pycache__/fastertinyvit.cpython-38.pyc
ADDED
Binary file (6.26 kB).
sam_extension/distillation_models/__pycache__/fastervit.cpython-38.pyc
ADDED
Binary file (18 kB).
sam_extension/distillation_models/__pycache__/sam.cpython-38.pyc
ADDED
Binary file (10.7 kB).
sam_extension/distillation_models/dino.py
ADDED
@@ -0,0 +1,122 @@
import PIL
from PIL.Image import Image
from typing import Union

from sklearn.decomposition import PCA

import torch
from torch import nn
from torchvision import transforms as tfs


MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
DINO_MODEL_HUB = 'facebookresearch/dino:main'
DINO_MODEL_TYPE = ['dino_vits16',
                   'dino_vits8',
                   'dino_vitb16',
                   'dino_vitb8',
                   'dino_xcit_small_12_p16',
                   'dino_xcit_small_12_p8',
                   'dino_xcit_medium_24_p16',
                   'dino_xcit_medium_24_p8',
                   'dino_resnet50']

DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
DINOV2_MODEL_TYPE = ['dinov2_vits14',
                     'dinov2_vitb14',
                     'dinov2_vitl14',
                     'dinov2_vitg14']

class DINO(nn.Module):
    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super(DINO, self).__init__()
        assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must be in DINO_MODEL_TYPE!'
        self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
        self.device = device
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None
    def set_pca(self, dim=64):
        return PCA(n_components=dim)
    @torch.no_grad()
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        with torch.no_grad():
            out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
            out = out[:, 1:, :]  # we discard the [CLS] token
            h, w = int(img.shape[2] / self.model.patch_embed.patch_size), int(
                img.shape[3] / self.model.patch_embed.patch_size
            )
            dim = out.shape[-1]
            out = out.reshape(-1, h, w, dim)
            dtype = out.dtype
            if size is not None:
                out = torch.nn.functional.interpolate(out.permute(0, 3, 1, 2), size=size, mode='bilinear').permute(0, 2, 3, 1)
            if self.pca:
                B, H, W, C = out.shape
                out = out.view(-1, C).cpu().numpy()
                out = self.pca.fit_transform(out)
                out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out
    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)
    @staticmethod
    def transform(img, image_size):
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img

class DINOV2(nn.Module):
    def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
        super(DINOV2, self).__init__()
        assert model_type in DINOV2_MODEL_TYPE, 'Given DINOV2 model type must be in DINOV2_MODEL_TYPE!'
        self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
        self.device = device
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.img_size = img_size
        self.pca_dim = pca_dim
        self.pca = self.set_pca(pca_dim) if pca_dim else None
    def set_pca(self, dim=64):
        return PCA(n_components=dim)
    @torch.no_grad()
    def extract_features(
        self, img: Union[Image, torch.Tensor], transform=True, size=None
    ):
        if transform and isinstance(img, Image):
            img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
        with torch.no_grad():
            out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']
            h, w = int(img.shape[2] / self.model.patch_size), int(
                img.shape[3] / self.model.patch_size
            )
            dim = out.shape[-1]
            out = out.reshape(-1, h, w, dim)
            dtype = out.dtype
            if size is not None:
                out = torch.nn.functional.interpolate(out.permute(0, 3, 1, 2), size=size, mode='bilinear').permute(0, 2, 3, 1)
            if self.pca:
                B, H, W, C = out.shape
                out = out.view(-1, C).cpu().numpy()
                out = self.pca.fit_transform(out)
                out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
        return out
    def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
        return self.extract_features(img, transform, size)
    @staticmethod
    def transform(img, image_size):
        transforms = tfs.Compose(
            [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
        )
        img = transforms(img)
        return img
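For context, a minimal usage sketch of the DINO wrapper above. It assumes a local image file 'example.jpg' (a placeholder), that the facebookresearch/dino hub weights can be downloaded, and runs on CPU instead of the class's CUDA default; none of these choices come from the upload itself.

from PIL import Image
import torch
from sam_extension.distillation_models.dino import DINO

# Sketch only: 'example.jpg' is a placeholder path; device='cpu' avoids the CUDA default.
dino = DINO('dino_vits8', device='cpu', img_size=224, pca_dim=None)
img = Image.open('example.jpg').convert('RGB')
with torch.no_grad():
    feats = dino(img, transform=True)                     # channel-last grid: (1, 224//8, 224//8, embed_dim)
    feats_64 = dino(img, transform=True, size=(64, 64))   # patch grid bilinearly resized to 64x64 before returning
print(feats.shape, feats_64.shape)

The forward pass returns the patch-token grid in (B, H, W, C) layout, optionally resized and/or PCA-reduced, which is the format the adaptor in sam.py consumes.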
sam_extension/distillation_models/fastertinyvit.py
ADDED
@@ -0,0 +1,233 @@
from typing import Tuple, List, Union
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from sam_extension.distillation_models.fastervit import FasterViTLayer
from segment_anything.mobile_encoder.tiny_vit_sam import PatchEmbed, Conv2d_BN, LayerNorm2d, MBConv
class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, out_dim, activation):
        super().__init__()

        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()
        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        stride_c = 2
        if (out_dim == 320 or out_dim == 448 or out_dim == 576):  # handongshen 576
            stride_c = 1
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        if x.ndim == 3:
            H, W = self.input_resolution
            B = len(x)
            # (B, C, H, W)
            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)

        x = self.conv1(x)
        x = self.act(x)

        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        return x


class ConvLayer(nn.Module):
    def __init__(self, dim, input_resolution, depth,
                 activation,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 out_dim=None,
                 conv_expand_ratio=4.,
                 ):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            MBConv(dim, dim, conv_expand_ratio, activation,
                   drop_path[i] if isinstance(drop_path, list) else drop_path,
                   )
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # `checkpoint` is imported above as a function, not a module
                x = checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

class FasterTinyViT(nn.Module):
    def __init__(self, img_size=224,
                 in_chans=3,
                 out_chans=256,
                 embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_sizes=[7, 7, 14, 7],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 ct_size=2,
                 conv=False,
                 multi_scale=False,
                 output_shape=None,
                 ):
        super().__init__()
        self.img_size = img_size
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None

        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
                                                sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            kwargs_0 = dict(dim=embed_dims[i_layer],
                            input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
                                              patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
                            # input_resolution=(patches_resolution[0] // (2 ** i_layer),
                            #                   patches_resolution[1] // (2 ** i_layer)),
                            depth=depths[i_layer],
                            drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                            downsample=PatchMerging if (
                                i_layer < self.num_layers - 1) else None,
                            use_checkpoint=use_checkpoint,
                            out_dim=embed_dims[min(
                                i_layer + 1, len(embed_dims) - 1)],
                            activation=activation,
                            )
            kwargs_1 = dict(dim=embed_dims[i_layer],
                            out_dim=embed_dims[i_layer + 1] if (
                                i_layer < self.num_layers - 1) else embed_dims[i_layer],
                            input_resolution=patches_resolution[0] // (2 ** i_layer),
                            depth=depths[i_layer],
                            drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                            downsample=True if (i_layer < self.num_layers - 1) else False,
                            ct_size=ct_size,
                            conv=conv,
                            )
            if i_layer == 0:
                layer = ConvLayer(
                    conv_expand_ratio=mbconv_expand_ratio,
                    **kwargs_0,
                )
            else:
                layer = FasterViTLayer(
                    num_heads=num_heads[i_layer],
                    window_size=window_sizes[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                    **kwargs_1)
            self.layers.append(layer)

        # init weights
        self.apply(self._init_weights)

        self.neck = nn.Sequential(
            nn.Conv2d(
                sum(embed_dims) + embed_dims[-1] if self.multi_scale and self.output_shape else embed_dims[-1],
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'attention_biases'}

    def forward_features(self, x):
        if self.multi_scale and self.output_shape:
            output_list = []
            # x: (N, C, H, W)
            x = self.patch_embed(x)
            output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
            for layer in self.layers:
                x = layer(x)
                output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
            x = self.neck(torch.cat(output_list, dim=1))

        else:
            x = self.patch_embed(x)
            for layer in self.layers:
                x = layer(x)
            x = self.neck(x)
        return x


    def forward(self, x):
        x = self.forward_features(x)

        return x

if __name__ == '__main__':
    from distillation.utils import get_parameter_number
    x = torch.randn(1, 3, 1024, 1024).cuda()
    fastertinyvit = FasterTinyViT(img_size=1024, in_chans=3,
                                  embed_dims=[64, 128, 256],
                                  depths=[1, 2, 1],
                                  num_heads=[2, 4, 8],
                                  window_sizes=[8, 8, 8],
                                  mlp_ratio=4.,
                                  drop_rate=0.,
                                  drop_path_rate=0.0,
                                  use_checkpoint=False,
                                  mbconv_expand_ratio=4.0,
                                  multi_scale=False,
                                  output_shape='').cuda()
    print(fastertinyvit(x).shape)
    print(get_parameter_number(fastertinyvit))
    # torch.save(fastertinyvit, 'fastertinyvit.pt')
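As a quick sanity check of the configuration used in the __main__ block, a CPU smoke test along the following lines should produce a SAM-style embedding map. The (1, 256, 64, 64) shape is an expectation derived from the stage strides (4x patch embedding followed by two 2x downsamples) rather than a verified output, and keeping the model on CPU is a choice made here for portability.

import torch
from sam_extension.distillation_models.fastertinyvit import FasterTinyViT

# Same stage layout as the __main__ example above, but kept on CPU.
model = FasterTinyViT(img_size=1024, in_chans=3, out_chans=256,
                      embed_dims=[64, 128, 256],
                      depths=[1, 2, 1],
                      num_heads=[2, 4, 8],
                      window_sizes=[8, 8, 8],
                      drop_path_rate=0.0).eval()
x = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    feats = model(x)
print(feats.shape)  # expected: torch.Size([1, 256, 64, 64]), the same grid SAM's mask decoder consumes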
sam_extension/distillation_models/fastervit.py
ADDED
@@ -0,0 +1,659 @@
import torch
import numpy as np
import torch.nn as nn
from timm.models.layers import DropPath, LayerNorm2d
def window_partition(x, window_size):
    B, C, H, W = x.shape
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
    return windows


def window_reverse(windows, window_size, H, W, B):
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], H, W)
    return x


def ct_dewindow(ct, W, H, window_size):
    bs = ct.shape[0]
    N = ct.shape[2]
    ct2 = ct.view(-1, W//window_size, H//window_size, window_size, window_size, N).permute(0, 5, 1, 3, 2, 4)
    ct2 = ct2.reshape(bs, N, W*H).transpose(1, 2)
    return ct2


def ct_window(ct, W, H, window_size):
    bs = ct.shape[0]
    N = ct.shape[2]
    ct = ct.view(bs, H // window_size, window_size, W // window_size, window_size, N)
    ct = ct.permute(0, 1, 3, 2, 4, 5)
    return ct

class PosEmbMLPSwinv2D(nn.Module):
    def __init__(self,
                 window_size,
                 pretrained_window_size,
                 num_heads, seq_length,
                 ct_correct=False,
                 no_log=False):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                     nn.ReLU(inplace=True),
                                     nn.Linear(512, num_heads, bias=False))
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(
            torch.meshgrid([relative_coords_h,
                            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
        if pretrained_window_size[0] > 0:
            relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
        else:
            relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
            relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)

        if not no_log:
            relative_coords_table *= 8  # normalize to -8, 8
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)

        self.register_buffer("relative_coords_table", relative_coords_table)
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.grid_exists = False
        self.pos_emb = None
        self.deploy = False
        relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
        self.seq_length = seq_length
        self.register_buffer("relative_bias", relative_bias)
        self.ct_correct = ct_correct

    def switch_to_deploy(self):
        self.deploy = True

    def forward(self, input_tensor, local_window_size):
        if self.deploy:
            input_tensor += self.relative_bias
            return input_tensor
        else:
            self.grid_exists = False

        if not self.grid_exists:
            self.grid_exists = True

            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
                -1)
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
            n_global_feature = input_tensor.shape[2] - local_window_size
            if n_global_feature > 0 and self.ct_correct:

                step_for_ct = self.window_size[0]/(n_global_feature**0.5+1)
                seq_length = int(n_global_feature ** 0.5)
                indices = []
                for i in range(seq_length):
                    for j in range(seq_length):
                        ind = (i+1)*step_for_ct*self.window_size[0] + (j+1)*step_for_ct
                        indices.append(int(ind))

                top_part = relative_position_bias[:, indices, :]
                lefttop_part = relative_position_bias[:, indices, :][:, :, indices]
                left_part = relative_position_bias[:, :, indices]
            relative_position_bias = torch.nn.functional.pad(relative_position_bias, (n_global_feature,
                                                                                      0,
                                                                                      n_global_feature,
                                                                                      0)).contiguous()
            if n_global_feature > 0 and self.ct_correct:
                relative_position_bias = relative_position_bias*0.0
                relative_position_bias[:, :n_global_feature, :n_global_feature] = lefttop_part
                relative_position_bias[:, :n_global_feature, n_global_feature:] = top_part
                relative_position_bias[:, n_global_feature:, :n_global_feature] = left_part

            self.pos_emb = relative_position_bias.unsqueeze(0)
            self.relative_bias = self.pos_emb

        input_tensor += self.pos_emb
        return input_tensor


class PosEmbMLPSwinv1D(nn.Module):
    def __init__(self,
                 dim,
                 rank=2,
                 seq_length=4,
                 conv=False):
        super().__init__()
        self.rank = rank
        if not conv:
            self.cpb_mlp = nn.Sequential(nn.Linear(self.rank, 512, bias=True),
                                         nn.ReLU(),
                                         nn.Linear(512, dim, bias=False))
        else:
            self.cpb_mlp = nn.Sequential(nn.Conv1d(self.rank, 512, 1, bias=True),
                                         nn.ReLU(),
                                         nn.Conv1d(512, dim, 1, bias=False))
        self.grid_exists = False
        self.pos_emb = None
        self.deploy = False
        relative_bias = torch.zeros(1, seq_length, dim)
        self.register_buffer("relative_bias", relative_bias)
        self.conv = conv

    def switch_to_deploy(self):
        self.deploy = True

    def forward(self, input_tensor):
        seq_length = input_tensor.shape[1] if not self.conv else input_tensor.shape[2]
        if self.deploy:
            return input_tensor + self.relative_bias
        else:
            self.grid_exists = False
        if not self.grid_exists:
            self.grid_exists = True
            if self.rank == 1:
                relative_coords_h = torch.arange(0, seq_length, device=input_tensor.device, dtype=input_tensor.dtype)
                relative_coords_h -= seq_length//2
                relative_coords_h /= (seq_length//2)
                relative_coords_table = relative_coords_h
                self.pos_emb = self.cpb_mlp(relative_coords_table.unsqueeze(0).unsqueeze(2))
                self.relative_bias = self.pos_emb
            else:
                seq_length = int(seq_length**0.5)
                relative_coords_h = torch.arange(0, seq_length, device=input_tensor.device, dtype=input_tensor.dtype)
                relative_coords_w = torch.arange(0, seq_length, device=input_tensor.device, dtype=input_tensor.dtype)
                relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])).contiguous().unsqueeze(0)
                relative_coords_table -= seq_length // 2
                relative_coords_table /= (seq_length // 2)
                if not self.conv:
                    self.pos_emb = self.cpb_mlp(relative_coords_table.flatten(2).transpose(1, 2))
                else:
                    self.pos_emb = self.cpb_mlp(relative_coords_table.flatten(2))
                self.relative_bias = self.pos_emb
        input_tensor = input_tensor + self.pos_emb
        return input_tensor


class Mlp(nn.Module):
    """
    Multi-Layer Perceptron (MLP) block
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        """
        Args:
            in_features: input features dimension.
            hidden_features: hidden features dimension.
            out_features: output features dimension.
            act_layer: activation function.
            drop: dropout rate.
        """

        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x_size = x.size()
        x = x.view(-1, x_size[-1])
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        x = x.view(x_size)
        return x

class Downsample(nn.Module):
    """
    Down-sampling block based on: "Hatamizadeh et al.,
    FasterViT: Fast Vision Transformers with Hierarchical Attention
    """

    def __init__(self,
                 dim,
                 out_dim,
                 keep_dim=False,
                 stride=2,
                 ):
        """
        Args:
            dim: feature size dimension.
            norm_layer: normalization layer.
            keep_dim: bool argument for maintaining the resolution.
        """

        super().__init__()
        if keep_dim:
            out_dim = dim
        self.norm = LayerNorm2d(dim)
        self.reduction = nn.Sequential(
            nn.Conv2d(dim, out_dim, 3, stride, 1, bias=False),
        )

    def forward(self, x):
        x = self.norm(x)
        x = self.reduction(x)
        return x
class PatchEmbed(nn.Module):
    """
    Patch embedding block based on: "Hatamizadeh et al.,
    FasterViT: Fast Vision Transformers with Hierarchical Attention
    """

    def __init__(self, in_chans=3, in_dim=64, dim=96):
        """
        Args:
            in_chans: number of input channels.
            dim: feature size dimension.
        """
        super().__init__()
        self.proj = nn.Identity()
        self.conv_down = nn.Sequential(
            nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(in_dim, eps=1e-4),
            nn.ReLU(),
            nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(dim, eps=1e-4),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.proj(x)
        x = self.conv_down(x)
        return x


class ConvBlock(nn.Module):
    """
    Conv block based on: "Hatamizadeh et al.,
    FasterViT: Fast Vision Transformers with Hierarchical Attention
    """

    def __init__(self, dim,
                 drop_path=0.,
                 layer_scale=None,
                 kernel_size=3):
        super().__init__()
        """
        Args:
            drop_path: drop path.
            layer_scale: layer scale coefficient.
            kernel_size: kernel size.
        """
        self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
        self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
        self.act1 = nn.GELU()
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
        self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
        self.layer_scale = layer_scale
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
            self.layer_scale = True
        else:
            self.layer_scale = False
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, global_feature=None):
        input = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        if self.layer_scale:
            x = x * self.gamma.view(1, -1, 1, 1)
        x = input + self.drop_path(x)
        return x, global_feature


class WindowAttention(nn.Module):
    """
    Window attention based on: "Hatamizadeh et al.,
    FasterViT: Fast Vision Transformers with Hierarchical Attention
    """
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 resolution=0,
                 seq_length=0):
        super().__init__()
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention head.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
            resolution: feature resolution.
            seq_length: sequence length.
        """
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # attention positional bias
        self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
                                              pretrained_window_size=[resolution, resolution],
                                              num_heads=num_heads,
                                              seq_length=seq_length)

        self.resolution = resolution

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.pos_emb_funct(attn, self.resolution ** 2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class HAT(nn.Module):
    """
    Hierarchical attention (HAT) based on: "Hatamizadeh et al.,
    FasterViT: Fast Vision Transformers with Hierarchical Attention
    """
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1.,
                 window_size=7,
                 last=False,
                 layer_scale=None,
                 ct_size=1,
                 do_propagation=False):
        super().__init__()
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention head.
            mlp_ratio: MLP ratio.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
            act_layer: activation function.
            norm_layer: normalization layer.
            sr_ratio: input to window size ratio.
            window_size: window size.
            last: last layer flag.
            layer_scale: layer scale coefficient.
            ct_size: spatial dimension of carrier token local window.
            do_propagation: enable carrier token propagation.
        """
        # positional encoding for windowed attention tokens
        self.pos_embed = PosEmbMLPSwinv1D(dim, rank=2, seq_length=window_size**2)
        self.norm1 = norm_layer(dim)
        # number of carrier tokens per every window
        cr_tokens_per_window = ct_size**2 if sr_ratio > 1 else 0
        # total number of carrier tokens
        cr_tokens_total = cr_tokens_per_window*sr_ratio*sr_ratio
        self.cr_window = ct_size
        self.attn = WindowAttention(dim,
                                    num_heads=num_heads,
                                    qkv_bias=qkv_bias,
                                    qk_scale=qk_scale,
                                    attn_drop=attn_drop,
                                    proj_drop=drop,
                                    resolution=window_size,
                                    seq_length=window_size**2 + cr_tokens_per_window)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.window_size = window_size

        use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
        self.gamma3 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
        self.gamma4 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # if do hierarchical attention, this part is for carrier tokens
            self.hat_norm1 = norm_layer(dim)
            self.hat_norm2 = norm_layer(dim)
            self.hat_attn = WindowAttention(
                dim,
                num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                attn_drop=attn_drop, proj_drop=drop, resolution=int(cr_tokens_total**0.5),
                seq_length=cr_tokens_total)

            self.hat_mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
            self.hat_drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.hat_pos_embed = PosEmbMLPSwinv1D(dim, rank=2, seq_length=cr_tokens_total)
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
            self.upsampler = nn.Upsample(size=window_size, mode='nearest')

        # keep track for the last block to explicitly add carrier tokens to feature maps
        self.last = last
        self.do_propagation = do_propagation

    def forward(self, x, carrier_tokens):
        B, T, N = x.shape
        ct = carrier_tokens
        x = self.pos_embed(x)

        if self.sr_ratio > 1:
            # do hierarchical attention via carrier tokens
            # first do attention for carrier tokens
            Bg, Ng, Hg = ct.shape

            # ct are located quite differently
            ct = ct_dewindow(ct, self.cr_window*self.sr_ratio, self.cr_window*self.sr_ratio, self.cr_window)

            # positional bias for carrier tokens
            ct = self.hat_pos_embed(ct)

            # attention plus mlp
            ct = ct + self.hat_drop_path(self.gamma1*self.hat_attn(self.hat_norm1(ct)))
            ct = ct + self.hat_drop_path(self.gamma2*self.hat_mlp(self.hat_norm2(ct)))

            # ct are put back to windows
            ct = ct_window(ct, self.cr_window * self.sr_ratio, self.cr_window * self.sr_ratio, self.cr_window)

            ct = ct.reshape(x.shape[0], -1, N)
            # concatenate carrier_tokens to the windowed tokens
            x = torch.cat((ct, x), dim=1)

        # window attention together with carrier tokens
        x = x + self.drop_path(self.gamma3*self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma4*self.mlp(self.norm2(x)))

        if self.sr_ratio > 1:
            # for hierarchical attention we need to split carrier tokens and window tokens back
            ctr, x = x.split([x.shape[1] - self.window_size*self.window_size, self.window_size*self.window_size], dim=1)
            ct = ctr.reshape(Bg, Ng, Hg)  # reshape carrier tokens.
            if self.last and self.do_propagation:
                # propagate carrier token information into the image
                ctr_image_space = ctr.transpose(1, 2).reshape(B, N, self.cr_window, self.cr_window)
                x = x + self.gamma1 * self.upsampler(ctr_image_space.to(dtype=torch.float32)).flatten(2).transpose(1, 2).to(dtype=x.dtype)
        return x, ct


class TokenInitializer(nn.Module):
    """
    Carrier token Initializer based on: "Hatamizadeh et al.,
    FasterViT: Fast Vision Transformers with Hierarchical Attention
    """
    def __init__(self,
                 dim,
                 input_resolution,
                 window_size,
                 ct_size=1):
        """
        Args:
            dim: feature size dimension.
            input_resolution: input image resolution.
            window_size: window size.
            ct_size: spatial dimension of carrier token local window
        """
        super().__init__()

        output_size = int(ct_size * input_resolution/window_size)
        stride_size = int(input_resolution/output_size)
        kernel_size = input_resolution - (output_size - 1) * stride_size
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        to_global_feature = nn.Sequential()
        to_global_feature.add_module("pos", self.pos_embed)
        to_global_feature.add_module("pool", nn.AvgPool2d(kernel_size=kernel_size, stride=stride_size))
        self.to_global_feature = to_global_feature
        self.window_size = ct_size

    def forward(self, x):
        x = self.to_global_feature(x)
        B, C, H, W = x.shape
        ct = x.view(B, C, H // self.window_size, self.window_size, W // self.window_size, self.window_size)
        ct = ct.permute(0, 2, 4, 3, 5, 1).reshape(-1, H*W, C)
        return ct
class FasterViTLayer(nn.Module):
    """
    GCViT layer based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self,
                 dim,
                 out_dim,
                 depth,
                 input_resolution,
                 num_heads,
                 window_size,
                 ct_size=1,
                 conv=False,
                 downsample=True,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 layer_scale=None,
                 layer_scale_conv=None,
                 only_local=False,
                 hierarchy=True,
                 do_propagation=False
                 ):
        """
        Args:
            dim: feature size dimension.
            depth: layer depth.
            input_resolution: input resolution.
            num_heads: number of attention head.
            window_size: window size.
            ct_size: spatial dimension of carrier token local window.
            conv: conv_based stage flag.
            downsample: downsample flag.
            mlp_ratio: MLP ratio.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            drop: dropout rate.
            attn_drop: attention dropout rate.
            drop_path: drop path rate.
            layer_scale: layer scale coefficient.
            layer_scale_conv: conv layer scale coefficient.
            only_local: local attention flag.
            hierarchy: hierarchical attention flag.
            do_propagation: enable carrier token propagation.
        """
        super().__init__()
        self.conv = conv
        self.transformer_block = False
        if conv:
            self.blocks = nn.ModuleList([
                ConvBlock(dim=dim,
                          drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                          layer_scale=layer_scale_conv)
                for i in range(depth)])
            self.transformer_block = False
        else:
            sr_ratio = input_resolution // window_size if not only_local else 1
            self.blocks = nn.ModuleList([
                HAT(dim=dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    sr_ratio=sr_ratio,
                    window_size=window_size,
                    last=(i == depth-1),
                    layer_scale=layer_scale,
                    ct_size=ct_size,
                    do_propagation=do_propagation,
                    )
                for i in range(depth)])
            self.transformer_block = True
        self.downsample = Downsample(dim=dim, out_dim=out_dim, stride=1) if not downsample else Downsample(dim=dim, out_dim=out_dim, stride=2)
        if len(self.blocks) and not only_local and input_resolution // window_size > 1 and hierarchy and not self.conv:
            self.global_tokenizer = TokenInitializer(dim,
                                                     input_resolution,
                                                     window_size,
                                                     ct_size=ct_size)
            self.do_gt = True
        else:
            self.do_gt = False

        self.window_size = window_size

    def forward(self, x):
        ct = self.global_tokenizer(x) if self.do_gt else None
        B, C, H, W = x.shape
        if self.transformer_block:
            x = window_partition(x, self.window_size)
        for bn, blk in enumerate(self.blocks):
            x, ct = blk(x, ct)
        if self.transformer_block:
            x = window_reverse(x, self.window_size, H, W, B)
        if self.downsample is None:
            return x
        return self.downsample(x)
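The helpers at the top of this file fix the layout convention used throughout the stage: window_partition flattens a (B, C, H, W) feature map into (B * num_windows, window_size^2, C) token windows, and window_reverse undoes it exactly. A small round-trip sketch, with shapes chosen purely for illustration:

import torch
from sam_extension.distillation_models.fastervit import window_partition, window_reverse

B, C, H, W, ws = 2, 64, 32, 32, 8
x = torch.randn(B, C, H, W)
windows = window_partition(x, ws)              # (B * (H//ws) * (W//ws), ws*ws, C) -> (32, 64, 64)
x_back = window_reverse(windows, ws, H, W, B)  # back to (B, C, H, W)
print(windows.shape, torch.allclose(x, x_back))  # torch.Size([32, 64, 64]) True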
sam_extension/distillation_models/sam.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import functools
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
|
10 |
+
from typing import Optional, List, Union, Tuple, Type
|
11 |
+
|
12 |
+
from segment_anything import build_sam
|
13 |
+
from segment_anything.mobile_encoder.tiny_vit_sam import TinyViT
|
14 |
+
from segment_anything.modeling import PromptEncoder, MaskDecoder, TwoWayTransformer
|
15 |
+
from segment_anything.modeling.image_encoder import ImageEncoderViT, LayerNorm2d, PatchEmbed, Block, Attention
|
16 |
+
from segment_anything.mobile_encoder.setup_mobile_sam import load_mobile_sam
|
17 |
+
from segment_anything.modeling.sam import Sam
|
18 |
+
|
19 |
+
from sam_extension.distillation_models.fastertinyvit import FasterTinyViT
|
20 |
+
from sam_extension.distillation_models.dino import DINO
|
21 |
+
# from sam_extension.distillation_models.flashvision_transformer import FlashVisionTransformer
|
22 |
+
|
23 |
+
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
|
24 |
+
hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir_use_symlinks=True)
|
25 |
+
|
26 |
+
|
27 |
+
class SAMImageEncoder(nn.Module):
|
28 |
+
def __init__(self,
|
29 |
+
sam_checkpoint_path,
|
30 |
+
device='cuda'):
|
31 |
+
super(SAMImageEncoder, self).__init__()
|
32 |
+
sam = build_sam(sam_checkpoint_path).to(device)
|
33 |
+
self.image_encoder = sam.image_encoder
|
34 |
+
del sam
|
35 |
+
torch.cuda.empty_cache()
|
36 |
+
def forward(self, x):
|
37 |
+
return self.image_encoder(x)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
class MobileSAMImageEncoder(nn.Module):
|
42 |
+
def __init__(self,
|
43 |
+
sam_checkpoint_path,
|
44 |
+
device='cuda'):
|
45 |
+
super(MobileSAMImageEncoder, self).__init__()
|
46 |
+
sam = load_mobile_sam(sam_checkpoint_path, device)
|
47 |
+
self.image_encoder = sam.image_encoder
|
48 |
+
del sam
|
49 |
+
torch.cuda.empty_cache()
|
50 |
+
def forward(self, x):
|
51 |
+
return self.image_encoder(x)
|
52 |
+
|
53 |
+
class SAMEncoderViT(nn.Module):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
img_size: int = 1024,
|
57 |
+
patch_size: int = 16,
|
58 |
+
in_chans: int = 3,
|
59 |
+
embed_dim: int = 768,
|
60 |
+
depth: int = 12,
|
61 |
+
num_heads: int = 12,
|
62 |
+
mlp_ratio: float = 4.0,
|
63 |
+
out_chans: int = 256,
|
64 |
+
qkv_bias: bool = True,
|
65 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
66 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
67 |
+
use_abs_pos: bool = True,
|
68 |
+
use_rel_pos: bool = False,
|
69 |
+
rel_pos_zero_init: bool = True,
|
70 |
+
window_size: int = 0,
|
71 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
72 |
+
multi_scale: bool = False,
|
73 |
+
output_shape: Union[Tuple, List] = None
|
74 |
+
) -> None:
|
75 |
+
"""
|
76 |
+
Args:
|
77 |
+
img_size (int): Input image size.
|
78 |
+
patch_size (int): Patch size.
|
79 |
+
in_chans (int): Number of input image channels.
|
80 |
+
embed_dim (int): Patch embedding dimension.
|
81 |
+
depth (int): Depth of ViT.
|
82 |
+
num_heads (int): Number of attention heads in each ViT block.
|
83 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
84 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
85 |
+
norm_layer (nn.Module): Normalization layer.
|
86 |
+
act_layer (nn.Module): Activation layer.
|
87 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
88 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
89 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
90 |
+
window_size (int): Window size for window attention blocks.
|
91 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
92 |
+
"""
|
93 |
+
super().__init__()
|
94 |
+
self.img_size = img_size
|
95 |
+
self.multi_scale = multi_scale
|
96 |
+
self.output_shape = tuple(output_shape) if output_shape else None
|
97 |
+
|
98 |
+
|
99 |
+
self.patch_embed = PatchEmbed(
|
100 |
+
kernel_size=(patch_size, patch_size),
|
101 |
+
stride=(patch_size, patch_size),
|
102 |
+
in_chans=in_chans,
|
103 |
+
embed_dim=embed_dim,
|
104 |
+
)
|
105 |
+
|
106 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
107 |
+
if use_abs_pos:
|
108 |
+
# Initialize absolute positional embedding with pretrain image size.
|
109 |
+
self.pos_embed = nn.Parameter(
|
110 |
+
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
111 |
+
)
|
112 |
+
|
113 |
+
self.blocks = nn.ModuleList()
|
114 |
+
for i in range(depth):
|
115 |
+
block = Block(
|
116 |
+
dim=embed_dim,
|
117 |
+
num_heads=num_heads,
|
118 |
+
mlp_ratio=mlp_ratio,
|
119 |
+
qkv_bias=qkv_bias,
|
120 |
+
norm_layer=norm_layer,
|
121 |
+
act_layer=act_layer,
|
122 |
+
use_rel_pos=use_rel_pos,
|
123 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
124 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
125 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
126 |
+
)
|
127 |
+
self.blocks.append(block)
|
128 |
+
|
129 |
+
self.neck = nn.Sequential(
|
130 |
+
nn.Conv2d(
|
131 |
+
embed_dim*depth if self.multi_scale and self.output_shape else embed_dim,
|
132 |
+
out_chans,
|
133 |
+
kernel_size=1,
|
134 |
+
bias=False,
|
135 |
+
),
|
136 |
+
LayerNorm2d(out_chans),
|
137 |
+
nn.Conv2d(
|
138 |
+
out_chans,
|
139 |
+
out_chans,
|
140 |
+
kernel_size=3,
|
141 |
+
padding=1,
|
142 |
+
bias=False,
|
143 |
+
),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        if self.multi_scale and self.output_shape:
            output_list = []
            for blk in self.blocks:
                x = blk(x)
                output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))

            x = self.neck(torch.cat(output_list, dim=1))
        else:
            for blk in self.blocks:
                x = blk(x)
            x = self.neck(x.permute(0, 3, 1, 2))
        return x

class SAMEncoderAdaptor(nn.Module):
    def __init__(self,
                 img_size: int,
                 input_size: Optional[Tuple[int, int]],
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.0,
                 out_chans: int = 256,
                 qkv_bias: bool = True,
                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                 act_layer: Type[nn.Module] = nn.GELU,
                 use_abs_pos: bool = True,
                 use_rel_pos: bool = False,
                 rel_pos_zero_init: bool = True,
                 window_size: int = 0,
                 global_attn_indexes: Tuple[int, ...] = (),
                 multi_scale: bool = False,
                 output_shape: Union[Tuple, List] = None):
        super(SAMEncoderAdaptor, self).__init__()
        self.img_size = img_size
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, input_size[0], input_size[1], embed_dim)
            )
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=input_size,
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim * depth if self.multi_scale and self.output_shape else embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor, original_size: Union[Tuple, List] = None) -> torch.Tensor:
        if original_size:
            original_size = torch.LongTensor(original_size)
            output_shape = x.shape[-2:]
            if original_size.ndim == 1:
                original_size = original_size[None, ...]
            adaptor_inputs = []
            for i in range(original_size.shape[0]):
                h, w = original_size[i]
                if h > w:
                    new_h = output_shape[0]
                    new_w = int(w * new_h / h)
                else:
                    new_w = output_shape[1]
                    new_h = int(h * new_w / w)
                encoder_output = x[0].unsqueeze(0)
                encoder_output = F.interpolate(encoder_output, size=(new_h, new_w), mode='bilinear')
                pad_h = output_shape[0] - new_h
                pad_w = output_shape[1] - new_w
                encoder_output = F.pad(encoder_output, (0, pad_w, 0, pad_h))
                adaptor_inputs.append(encoder_output)
            adaptor_inputs = torch.cat(adaptor_inputs, dim=0)
            x = adaptor_inputs.permute(0, 2, 3, 1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        if self.multi_scale and self.output_shape:
            output_list = []
            for blk in self.blocks:
                x = blk(x)
                output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))

            x = self.neck(torch.cat(output_list, dim=1))
        else:
            for blk in self.blocks:
                x = blk(x)
            x = self.neck(x.permute(0, 3, 1, 2))
        return x


class DINOSAMViT(nn.Module):
    def __init__(self,
                 dino_model_type,
                 device='cuda',
                 pca_dim=None,
                 **kwargs
                 ):
        super(DINOSAMViT, self).__init__()
        self.img_size = kwargs['img_size']
        if not pca_dim:
            pca_dim = None
        self.dino = DINO(dino_model_type, device, self.img_size, pca_dim)
        self.input_size = tuple(kwargs['output_shape'])
        # input_size = self.dino.model.patch_embed.img_size // self.dino.model.patch_embed.img_size
        # self.input_size = (input_size, input_size)
        embed_dim = pca_dim if pca_dim is not None else self.dino.model.embed_dim
        kwargs.update({'input_size': self.input_size, 'embed_dim': embed_dim})
        self.adaptor = SAMEncoderAdaptor(**kwargs).to(device)

    def extract_dino_features(self, x, transform=False, size=None):
        return self.dino.extract_features(x, transform, size)

    def forward(self, x, transform=False, size=None):
        dino_feature = F.normalize(self.extract_dino_features(x, transform, size), dim=3)
        adaptor_input = F.interpolate(dino_feature.permute(0, 3, 1, 2), size=self.input_size, mode='bilinear').permute(0, 2, 3, 1)
        return self.adaptor(adaptor_input)

def setup_model(model_config):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    model = eval(model_config.pop('type'))(**model_config)
    if model.__class__.__name__ == 'SAMEncoderAdaptor':
        adaptor = model
        image_encoder = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cpu').image_encoder
    else:
        adaptor = None
        image_encoder = model
    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        adaptor=adaptor,
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    return sam

def load_distillation_sam(distillation_sam_ckpt_path,
                          device='cuda'):
    ckpt = torch.load(distillation_sam_ckpt_path)
    sam = setup_model(ckpt['model_config'])
    sam.load_state_dict(ckpt['model'])
    return sam.to(device)

def load_sam(sam_ckpt_path, sam_version, device):
    if not os.path.exists(sam_ckpt_path):
        parent_dir = os.path.dirname(sam_ckpt_path)
        os.makedirs(parent_dir, exist_ok=True)
        hf_sam_download(filename=os.path.basename(sam_ckpt_path), local_dir=parent_dir)
    if sam_version == 'sam':
        sam = build_sam(sam_ckpt_path).to(device)
    elif sam_version == 'mobile_sam':
        sam = load_mobile_sam(sam_ckpt_path, device)
    elif sam_version == 'distillation_sam':
        sam = load_distillation_sam(sam_ckpt_path, device)
    else:
        raise ValueError('sam version error, please give sam version in [sam, mobile_sam, distillation_sam]')
    return sam

if __name__ == '__main__':
    from distillation.utils import get_parameter_number
    vit = SAMEncoderViT(depth=3,
                        embed_dim=256,
                        img_size=512,
                        mlp_ratio=4,
                        num_heads=16,
                        patch_size=8,
                        qkv_bias=True,
                        use_rel_pos=True,
                        global_attn_indexes=[1],
                        window_size=16,
                        out_chans=256,
                        multi_scale=False,
                        output_shape='').cuda()
    x = torch.randn((1, 3, 512, 512)).cuda()
    print(vit(x).shape)
    print(get_parameter_number(vit))
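A minimal usage sketch for the loaders above (not part of the uploaded files); the checkpoint path, device and input resolution are illustrative assumptions:

# Hypothetical example: load MobileSAM weights through load_sam and run the image encoder once.
import torch
from sam_extension.distillation_models.sam import load_sam

sam = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cuda')   # downloads the checkpoint from the HF repo if missing
dummy = torch.randn(1, 3, 1024, 1024).cuda()                        # SAM-style encoders expect 1024x1024 inputs
with torch.no_grad():
    features = sam.image_encoder(dummy)                             # image embedding, (1, 256, 64, 64) for MobileSAM
print(features.shape)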
sam_extension/pipeline/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .base import Pipeline
from .sam import SAMEncoderPipeline, SAMDecoderPipeline
from .owlvit import OwlViTVisionEncoderPipeline, OwlViTDecoderPipeline
from .groundingdino import GroundingDinoPipeline
sam_extension/pipeline/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (421 Bytes)
sam_extension/pipeline/__pycache__/base.cpython-38.pyc
ADDED
Binary file (1.14 kB)
sam_extension/pipeline/__pycache__/groundingdino.cpython-38.pyc
ADDED
Binary file (3.28 kB)
sam_extension/pipeline/__pycache__/owlvit.cpython-38.pyc
ADDED
Binary file (10.8 kB)
sam_extension/pipeline/__pycache__/sam.cpython-38.pyc
ADDED
Binary file (19.6 kB)
sam_extension/pipeline/base.py
ADDED
@@ -0,0 +1,20 @@
import torch
from torch import nn
from typing import Union, Dict
from dataclasses import dataclass

@dataclass(repr=True)
class Output:
    pass

class Pipeline(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Pipeline, self).__init__()
        self.args = args
        self.kwargs = kwargs
    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        pass
    def forward(self, *args, **kwargs):
        pass
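Concrete pipelines in this package subclass Pipeline and fill in from_pretrained and forward. The sketch below is a hypothetical minimal subclass (EchoPipeline is not part of this commit), shown only to illustrate the intended extension pattern:

from sam_extension.pipeline.base import Pipeline

class EchoPipeline(Pipeline):                                     # hypothetical example subclass
    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        return cls(ckpt_path=ckpt_path, device=device)            # a real pipeline would load weights here
    def forward(self, x):
        return x                                                  # a real pipeline would wrap a model call here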
sam_extension/pipeline/groundingdino.py
ADDED
@@ -0,0 +1,97 @@
import os
import functools
import PIL
from PIL.Image import Image
import numpy as np
from typing import List, Union
import supervision as sv

import torch
import torchvision

from huggingface_hub import hf_hub_download
from sam_extension.pipeline import Pipeline
from groundingdino.util.inference import Model

GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "groundingdino_swint_ogc.pth"
SAM_REPO_ID = 'YouLiXiya/YL-SAM'
LOCAL_DIR = "weights/groundingdino"
hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=True)

class GroundingDinoPipeline(Pipeline):
    def __init__(self,
                 grounding_dino_config_path,
                 grounfing_dino_ckpt_path,
                 grounding_dino_model,
                 device,
                 *args,
                 **kwargs):
        super(GroundingDinoPipeline, self).__init__(*args, **kwargs)
        self.grounding_dino_config_path = grounding_dino_config_path
        self.grounfing_dino_ckpt_path = grounfing_dino_ckpt_path
        self.grounding_dino_model = grounding_dino_model
        self.device = device

    @classmethod
    def from_pretrained(cls, grounding_dino_config_path, grounfing_dino_ckpt_path, device='cuda', *args, **kwargs):
        if not os.path.exists(grounfing_dino_ckpt_path):
            hf_sam_download(filename=os.path.basename(grounfing_dino_ckpt_path))
        grounding_dino_model = Model(model_config_path=grounding_dino_config_path,
                                     model_checkpoint_path=grounfing_dino_ckpt_path,
                                     device=device)
        return cls(grounding_dino_config_path,
                   grounfing_dino_ckpt_path,
                   grounding_dino_model,
                   device,
                   *args,
                   **kwargs)

    def visualize_results(self,
                          img: Union[Image, np.ndarray],
                          class_list: [List],
                          box_threshold: float = 0.25,
                          text_threshold: float = 0.25,
                          nms_threshold: float = 0.8,
                          pil: bool = True):
        detections = self.forward(img, class_list, box_threshold, text_threshold)
        box_annotator = sv.BoxAnnotator()
        nms_idx = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy),
            torch.from_numpy(detections.confidence),
            nms_threshold
        ).numpy().tolist()

        detections.xyxy = detections.xyxy[nms_idx]
        detections.confidence = detections.confidence[nms_idx]
        detections.class_id = detections.class_id[nms_idx]
        labels = [
            f"{class_list[class_id]} {confidence:0.2f}"
            for _, _, confidence, class_id, _
            in detections]
        annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
        if pil:
            return PIL.Image.fromarray(annotated_frame[:, :, ::-1]), detections
        else:
            return annotated_frame, detections

    @torch.no_grad()
    def forward(self,
                img: Union[Image, np.ndarray],
                class_list: [List],
                box_threshold: float = 0.25,
                text_threshold: float = 0.25
                ) -> sv.Detections:
        if isinstance(img, Image):
            img = np.uint8(img)[:, :, ::-1]
        detections = self.grounding_dino_model.predict_with_classes(
            image=img,
            classes=class_list,
            box_threshold=box_threshold,
            text_threshold=text_threshold
        )
        return detections
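A short usage sketch for the detector above (illustrative only; the image path and class names are assumptions, the config and checkpoint constants come from this file):

import numpy as np
import PIL.Image
from sam_extension.pipeline.groundingdino import (GroundingDinoPipeline,
                                                  GROUNDING_DINO_CONFIG_PATH,
                                                  GROUNDING_DINO_CHECKPOINT_PATH)

pipeline = GroundingDinoPipeline.from_pretrained(GROUNDING_DINO_CONFIG_PATH,
                                                 'weights/groundingdino/' + GROUNDING_DINO_CHECKPOINT_PATH,
                                                 device='cuda')
image = np.uint8(PIL.Image.open('assets/example.jpg').convert('RGB'))[:, :, ::-1]   # BGR array, as forward expects
annotated, detections = pipeline.visualize_results(image, ['person', 'dog'], pil=True)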
sam_extension/pipeline/owlvit.py
ADDED
@@ -0,0 +1,372 @@
from typing import Optional, Tuple, Union, List
import numpy as np
import PIL
from PIL.Image import Image
import supervision as sv

import torch
from torch import nn

from transformers import OwlViTProcessor, OwlViTForObjectDetection, OwlViTVisionModel
from transformers.models.owlvit.modeling_owlvit import center_to_corners_format, box_iou, generalized_box_iou, OwlViTObjectDetectionOutput

from sam_extension.pipeline.base import Pipeline, Output

class OwlViTVisionEncoderPipeline(Pipeline):

    def __init__(self,
                 vision_model,
                 layer_norm,
                 processor,
                 device='cuda',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.vision_model = vision_model
        self.layer_norm = layer_norm
        self.processor = processor
        self.device = device
        torch.cuda.empty_cache()
    @classmethod
    def from_pretrained(cls, model_type, device='cuda', *args, **kwargs):
        owlvit_for_object_detection = OwlViTForObjectDetection.from_pretrained(model_type).to(device)
        processor = OwlViTProcessor.from_pretrained(model_type)
        return cls(owlvit_for_object_detection.owlvit.vision_model,
                   owlvit_for_object_detection.layer_norm,
                   processor,
                   device,
                   *args,
                   **kwargs)
    def process_image(self, image: Image):
        image = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
        return image
    @torch.no_grad()
    def forward(
        self,
        pixel_values: Union[torch.FloatTensor, Image] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        if isinstance(pixel_values, Image):
            pixel_values = self.process_image(pixel_values)
        pixel_values = pixel_values.to(self.device)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Get image embeddings
        last_hidden_state = vision_outputs[0]
        image_embeds = self.vision_model.post_layernorm(last_hidden_state)
        new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
        class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)

        # Merge image embedding with class tokens
        image_embeds = image_embeds[:, 1:, :] * class_token_out
        image_embeds = self.layer_norm(image_embeds)

        # Resize to [batch_size, num_patches, num_patches, hidden_size]
        new_size = (
            image_embeds.shape[0],
            int(np.sqrt(image_embeds.shape[1])),
            int(np.sqrt(image_embeds.shape[1])),
            image_embeds.shape[-1],
        )
        image_embeds = image_embeds.reshape(new_size)
        return image_embeds


class OwlViTDecoderPipeline(Pipeline):
    prompt_template: str = 'a photo of a '
    def __init__(self,
                 owlvit_text,
                 text_projection,
                 class_head,
                 box_head,
                 processor,
                 device='cuda',
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        self.owlvit_text = owlvit_text
        self.text_projection = text_projection
        self.class_head = class_head
        self.box_head = box_head

        self.sigmoid = nn.Sigmoid()
        self.processor = processor
        self.device = device
        torch.cuda.empty_cache()

    @classmethod
    def from_pretrained(cls, model_type, device='cuda', *args, **kwargs):
        owlvit_for_object_detection = OwlViTForObjectDetection.from_pretrained(model_type).to(device)
        processor = OwlViTProcessor.from_pretrained(model_type)
        return cls(owlvit_for_object_detection.owlvit.text_model,
                   owlvit_for_object_detection.owlvit.text_projection,
                   owlvit_for_object_detection.class_head,
                   owlvit_for_object_detection.box_head,
                   processor,
                   device,
                   *args,
                   **kwargs)
    def set_template(self, template: str):
        self.prompt_template = template
    def process_text(self, text: List, use_template: bool = True):
        if use_template:
            text = [[self.prompt_template + i for i in text[0]]]
        inputs = self.processor(text=text, return_tensors="pt")
        return inputs
    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
        # Computes normalized xy corner coordinates from feature_map.
        if not feature_map.ndim == 4:
            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")

        device = feature_map.device
        num_patches = feature_map.shape[1]

        box_coordinates = np.stack(
            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
        ).astype(np.float32)
        box_coordinates /= np.array([num_patches, num_patches], np.float32)

        # Flatten (h, w, 2) -> (h*w, 2)
        box_coordinates = box_coordinates.reshape(
            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
        )
        box_coordinates = torch.from_numpy(box_coordinates).to(device)

        return box_coordinates

    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
        # The box center is biased to its position on the feature grid
        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
        box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

        # Unnormalize xy
        box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)

        # The box size is biased to the patch size
        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
        box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

        # Compute box bias
        box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
        return box_bias

    def box_predictor(
        self,
        image_feats: torch.FloatTensor,
        feature_map: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Args:
            image_feats:
                Features extracted from the image, returned by the `image_text_embedder` method.
            feature_map:
                A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
        Returns:
            pred_boxes:
                List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
        """
        # Bounding box detection head [batch_size, num_boxes, 4].
        pred_boxes = self.box_head(image_feats)

        # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
        pred_boxes += self.compute_box_bias(feature_map)
        pred_boxes = self.sigmoid(pred_boxes)
        return pred_boxes

    def class_predictor(
        self,
        image_feats: torch.FloatTensor,
        query_embeds: Optional[torch.FloatTensor] = None,
        query_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            image_feats:
                Features extracted from the `image_text_embedder`.
            query_embeds:
                Text query embeddings.
            query_mask:
                Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
        """
        (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)

        return (pred_logits, image_class_embeds)

    def image_text_embedder(
        self,
        input_ids: torch.Tensor,
        image_embeds: torch.FloatTensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> Tuple[torch.FloatTensor]:

        # Encode text and image
        text_outputs = self.owlvit_text(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)
        text_embeds = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)

        return (text_embeds, image_embeds, text_outputs)

    def embed_image_query(
        self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor
    ) -> torch.FloatTensor:

        _, class_embeds = self.class_predictor(query_image_features)
        pred_boxes = self.box_predictor(query_image_features, query_feature_map)
        pred_boxes_as_corners = center_to_corners_format(pred_boxes)

        # Loop over query images
        best_class_embeds = []
        best_box_indices = []
        pred_boxes_device = pred_boxes_as_corners.device

        for i in range(query_image_features.shape[0]):
            each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device)
            each_query_pred_boxes = pred_boxes_as_corners[i]
            ious, _ = box_iou(each_query_box, each_query_pred_boxes)

            # If there are no overlapping boxes, fall back to generalized IoU
            if torch.all(ious[0] == 0.0):
                ious = generalized_box_iou(each_query_box, each_query_pred_boxes)

            # Use an adaptive threshold to include all boxes within 80% of the best IoU
            iou_threshold = torch.max(ious) * 0.8

            selected_inds = (ious[0] >= iou_threshold).nonzero()
            if selected_inds.numel():
                selected_embeddings = class_embeds[i][selected_inds[0]]
                mean_embeds = torch.mean(class_embeds[i], axis=0)
                mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
                best_box_ind = selected_inds[torch.argmin(mean_sim)]
                best_class_embeds.append(class_embeds[i][best_box_ind])
                best_box_indices.append(best_box_ind)

        if best_class_embeds:
            query_embeds = torch.stack(best_class_embeds)
            box_indices = torch.stack(best_box_indices)
        else:
            query_embeds, box_indices = None, None

        return query_embeds, box_indices, pred_boxes

    @torch.no_grad()
    def forward(
        self,
        image_embeds: torch.FloatTensor,
        input_ids: Optional[torch.Tensor] = None,
        text: Optional[List] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> OwlViTObjectDetectionOutput:
        if text is not None:
            inputs = self.process_text(text)
            input_ids = inputs.input_ids.to(self.device)
            attention_mask = inputs.attention_mask.to(self.device)
        input_ids = input_ids.to(self.device)
        image_embeds = image_embeds.to(self.device)
        attention_mask = attention_mask.to(self.device)
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else False
        )
        return_dict = return_dict if return_dict is not None else True

        # Embed images and text queries
        query_embeds, feature_map, text_outputs = self.image_text_embedder(
            input_ids=input_ids,
            image_embeds=image_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Text and vision model outputs

        batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
        image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))

        # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
        max_text_queries = input_ids.shape[0] // batch_size
        query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])

        # If first token is 0, then this is a padded query [batch_size, num_queries].
        input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
        query_mask = input_ids[..., 0] > 0

        # Predict object classes [batch_size, num_patches, num_queries+1]
        (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)

        # Predict object boxes
        pred_boxes = self.box_predictor(image_feats, feature_map)

        if not return_dict:
            output = (
                pred_logits,
                pred_boxes,
                query_embeds,
                feature_map,
                class_embeds,
                text_outputs.to_tuple(),
                None,
            )
            output = tuple(x for x in output if x is not None)
            return output

        return OwlViTObjectDetectionOutput(
            image_embeds=feature_map,
            text_embeds=query_embeds,
            pred_boxes=pred_boxes.cpu(),
            logits=pred_logits.cpu(),
            class_embeds=class_embeds,
            text_model_output=text_outputs,
            vision_model_output=None,
        )

    def owlvit_visualize(self,
                         image: Image,
                         texts: List,
                         owlvit_objectdetection_output: OwlViTObjectDetectionOutput,
                         score_threshold: float = 0.1,
                         pil=True):
        target_sizes = torch.Tensor([image.size[::-1]])
        # Convert outputs (bounding boxes and class logits) to COCO API
        results = self.processor.post_process(outputs=owlvit_objectdetection_output, target_sizes=target_sizes)

        text = texts[0]
        boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
        boxes_np = []
        labels_list = []
        # Print detected objects and rescaled box coordinates
        for box, score, label in zip(boxes, scores, labels):
            box = [int(i) for i in box.tolist()]
            if score >= score_threshold:
                labels_list.append(f"{text[label]} {round(score.item(), 3)}")
                boxes_np.append(box)
                print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
        boxes_np = np.array(boxes_np)
        detections = sv.Detections(xyxy=boxes_np)
        image_np = np.uint8(image)[:, :, ::-1]
        box_annotator = sv.BoxAnnotator()
        annotated_frame = box_annotator.annotate(scene=image_np.copy(), detections=detections, labels=labels_list)
        if pil:
            return PIL.Image.fromarray(annotated_frame[:, :, ::-1])
        else:
            return annotated_frame[:, :, ::-1]
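A short usage sketch chaining the two OWL-ViT pipelines above (the model id, image path and text queries are illustrative assumptions):

import PIL.Image
from sam_extension.pipeline.owlvit import OwlViTVisionEncoderPipeline, OwlViTDecoderPipeline

model_type = 'google/owlvit-base-patch32'                      # assumed OWL-ViT checkpoint on the HF hub
encoder = OwlViTVisionEncoderPipeline.from_pretrained(model_type, device='cuda')
decoder = OwlViTDecoderPipeline.from_pretrained(model_type, device='cuda')

image = PIL.Image.open('assets/example.jpg').convert('RGB')
texts = [['cat', 'remote control']]                            # one list of queries per image
image_embeds = encoder(image)                                  # patch-grid embedding, [1, 24, 24, 768] for the base model
output = decoder(image_embeds, text=texts)
annotated = decoder.owlvit_visualize(image, texts, output, score_threshold=0.1)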
sam_extension/pipeline/sam.py
ADDED
@@ -0,0 +1,722 @@
import functools
from dataclasses import dataclass
import PIL
from PIL.Image import Image
import numpy as np
from typing import Union, Tuple, List, Optional, Callable
from sklearn.decomposition import PCA
import supervision as sv

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

from segment_anything.utils.transforms import ResizeLongestSide
from segment_anything.predictor import preprocess, postprocess_masks
from segment_anything import build_sam, load_mobile_sam

from sam_extension.utils import add_prompts_tag, get_empty_detections, transform_coords
from sam_extension.pipeline.base import Pipeline, Output
from sam_extension.pipeline.groundingdino import GroundingDinoPipeline
from sam_extension.distillation_models.sam import load_distillation_sam, load_sam
from sam_extension.distillation_models import *

ORIGINAL_SAM_IMG_SIZE: int = 1024
PIXEL_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
PIXEL_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
PREPROCESS = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN, PIXEL_STD)
POSTPROCESS_MASKS = functools.partial(postprocess_masks, ORIGINAL_SAM_IMG_SIZE)

@dataclass(repr=True)
class SAMEncoderOutput(Output):
    features: torch.Tensor
    interm_features: List[torch.Tensor]
    original_size: Tuple
    input_size: Tuple

@dataclass(repr=True)
class SAMEncoderProcesImgOutput(Output):
    input_image: torch.Tensor
    original_size: Tuple
    input_size: Tuple

@dataclass(repr=True)
class SAMDecoderPredictOutput(Output):
    masks_np: np.ndarray
    iou_predictions_np: np.ndarray
    low_res_masks_np: np.ndarray

@dataclass(repr=True)
class SAMDecoderPredictTorchOutput(Output):
    masks: torch.Tensor
    iou_predictions: torch.Tensor
    low_res_masks: torch.Tensor


class SAMEncoderPipeline(Pipeline):
    def __init__(self,
                 encoder: nn.Module,
                 input_img_size: Tuple,
                 multi_output: bool,
                 preprocess: Callable,
                 transform: ResizeLongestSide,
                 device: str,
                 *args,
                 **kwargs):
        super(SAMEncoderPipeline, self).__init__(*args, **kwargs)
        self.encoder = encoder
        self.input_img_size = input_img_size
        self.multi_output = multi_output
        self.preprocess = preprocess
        self.transform = transform
        self.device = device
    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        if 'sam_version' not in kwargs.keys():
            sam_version = 'sam'
        else:
            sam_version = kwargs['sam_version']
        sam = load_sam(ckpt_path, sam_version, device)
        encoder = sam.image_encoder
        encoder_type = encoder.__class__.__name__
        if encoder_type in ['TinyViT', 'FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']:
            multi_output = False
            if encoder_type in ['FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']:
                input_img_size = (encoder.img_size, encoder.img_size)
                if encoder_type == 'DINOSAMViT':
                    encoder = encoder.dino
            else:
                input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE)
        else:
            multi_output = True
            input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE)
        if sam.adaptor is None:
            transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE)
            preprocess_ = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN.to(device), PIXEL_STD.to(device))
        else:
            transform = T.Compose([
                T.Resize(input_img_size),
                T.ToTensor(),
                T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            preprocess_ = None
        pipeline = cls(encoder=encoder,
                       input_img_size=input_img_size,
                       multi_output=multi_output,
                       preprocess=preprocess_,
                       transform=transform,
                       device=device)
        del sam, encoder
        torch.cuda.empty_cache()
        return pipeline

    def process_img(self, img: Union[Image, np.ndarray]) -> SAMEncoderProcesImgOutput:
        if self.preprocess is not None:
            if isinstance(img, Image):
                img = np.uint8(img)
            input_image = self.transform.apply_image(img)
            input_image_torch = torch.as_tensor(input_image, device=self.device)
            input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
            original_size = tuple(img.shape[:2])
            input_size = tuple(input_image_torch.shape[-2:])
            input_image = F.interpolate(self.preprocess(input_image_torch), size=self.input_img_size, mode='bilinear')
        else:
            if isinstance(img, np.ndarray):
                img = PIL.Image.fromarray(img)
            original_size = (img.size[1], img.size[0])
            if original_size[0] > original_size[1]:
                input_h = 1024
                input_w = int((1024 / original_size[0]) * original_size[1])
            else:
                input_w = 1024
                input_h = int((1024 / original_size[1]) * original_size[0])
            input_size = (input_h, input_w)
            input_image = self.transform(img)[None, ...].to(self.device)
        return SAMEncoderProcesImgOutput(input_image, original_size, input_size)
    @torch.no_grad()
    def get_visual_feature(self, x: Union[torch.Tensor, Image, np.ndarray] = None, **kwargs):
        pca_rgb = PCA(n_components=3)
        if 'sam_feature' in kwargs.keys() and 'original_size' in kwargs.keys():
            sam_feature = kwargs['sam_feature']
            original_size = kwargs['original_size']
        else:
            assert x is not None, 'please give x type Union[torch.Tensor, Image, np.ndarray] !'
            sam_encoder_output = self.forward(x, **kwargs)
            sam_feature = sam_encoder_output.features
            original_size = sam_encoder_output.original_size
        assert original_size is not None, 'please give original_size!'
        sam_feature = F.interpolate(sam_feature, size=original_size, mode='bilinear').permute(0, 2, 3, 1)
        b, h, w, c = sam_feature.shape
        sam_feature = sam_feature.view(-1, c).cpu().numpy()
        sam_feature = pca_rgb.fit_transform(sam_feature)
        sam_feature = torch.Tensor(sam_feature.reshape(h, w, 3))
        min_f, _ = sam_feature.min(-1)
        max_f, _ = sam_feature.max(-1)
        sam_feature = (sam_feature - min_f[..., None]) / (max_f[..., None] - min_f[..., None])
        sam_feature = sam_feature.cpu().numpy()
        sam_feature_image = PIL.Image.fromarray((sam_feature * 255).astype(np.uint8))
        return sam_feature_image
    def forward(self, x: Union[torch.Tensor, Image, np.ndarray], **kwargs) -> SAMEncoderOutput:
        if isinstance(x, (Image, np.ndarray)):
            process_img_output = self.process_img(x)
            x = process_img_output.input_image
            original_size = process_img_output.original_size
            input_size = process_img_output.input_size
        else:
            original_size = kwargs.pop('original_size') if 'original_size' in kwargs.keys() else None
            input_size = x.shape[-2:]
        with torch.no_grad():
            if self.multi_output:
                features, interm_features = self.encoder(x, **kwargs)
            else:
                features = self.encoder(x, **kwargs)
                if self.encoder.__class__.__name__ == 'DINO':
                    features = features.permute(0, 3, 1, 2)
                interm_features = None
        return SAMEncoderOutput(features, interm_features, original_size, input_size)

class SAMDecoderPipeline(Pipeline):
|
181 |
+
def __init__(self,
|
182 |
+
prompt_encoder: nn.Module,
|
183 |
+
mask_decoder: nn.Module,
|
184 |
+
adaptor: nn.Module,
|
185 |
+
mask_threshold: float,
|
186 |
+
transform: ResizeLongestSide,
|
187 |
+
postprocess_masks: Callable,
|
188 |
+
img_size: int,
|
189 |
+
device: str,
|
190 |
+
*args,
|
191 |
+
**kwargs):
|
192 |
+
super(SAMDecoderPipeline, self).__init__(*args, **kwargs)
|
193 |
+
self.prompt_encoder = prompt_encoder
|
194 |
+
self.mask_decoder = mask_decoder
|
195 |
+
self.adaptor = adaptor
|
196 |
+
self.mask_threshold = mask_threshold
|
197 |
+
self.transform = transform
|
198 |
+
self.postprocess_masks = postprocess_masks
|
199 |
+
self.img_size = img_size
|
200 |
+
self.device = device
|
201 |
+
@classmethod
|
202 |
+
def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
|
203 |
+
if 'sam_version' not in kwargs.keys():
|
204 |
+
sam_version = 'sam'
|
205 |
+
else:
|
206 |
+
sam_version = kwargs['sam_version']
|
207 |
+
sam = load_sam(ckpt_path, sam_version, device)
|
208 |
+
if sam.image_encoder.__class__.__name__ == 'DINOSAMViT':
|
209 |
+
adaptor = sam.image_encoder.adaptor
|
210 |
+
elif sam.adaptor is not None:
|
211 |
+
adaptor = sam.adaptor
|
212 |
+
else:
|
213 |
+
adaptor = None
|
214 |
+
img_size = sam.image_encoder.img_size
|
215 |
+
prompt_encoder = sam.prompt_encoder
|
216 |
+
mask_decoder = sam.mask_decoder
|
217 |
+
transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE)
|
218 |
+
pipeline = cls(prompt_encoder=prompt_encoder,
|
219 |
+
mask_decoder=mask_decoder,
|
220 |
+
adaptor=adaptor,
|
221 |
+
mask_threshold=sam.mask_threshold,
|
222 |
+
transform=transform,
|
223 |
+
postprocess_masks=POSTPROCESS_MASKS,
|
224 |
+
img_size=img_size,
|
225 |
+
device=device)
|
226 |
+
del sam, prompt_encoder, mask_decoder
|
227 |
+
torch.cuda.empty_cache()
|
228 |
+
return pipeline
|
229 |
+
def visualize_prompt(self,
|
230 |
+
img: Union[Image, np.ndarray],
|
231 |
+
des_img: Union[Image, np.ndarray] = None,
|
232 |
+
point_labels: Union[List[int], np.ndarray] = None,
|
233 |
+
point_coords: Union[List[List[int]], np.ndarray] = None,
|
234 |
+
boxes: Union[List[List[int]], np.ndarray] = None,
|
235 |
+
pil: bool = False
|
236 |
+
) -> Union[Image, np.ndarray]:
|
237 |
+
if des_img is not None:
|
238 |
+
if isinstance(des_img, np.ndarray):
|
239 |
+
des_shape = tuple(des_img.shape[:2])
|
240 |
+
|
241 |
+
else:
|
242 |
+
des_shape = (des_img.size[1], des_img.size[0])
|
243 |
+
src_shape = (img.size[1], img.size[0])
|
244 |
+
point_coords, boxes = transform_coords(src_shape, des_shape, point_coords, boxes)
|
245 |
+
return add_prompts_tag(des_img, point_labels, point_coords, boxes, pil)
|
246 |
+
else:
|
247 |
+
return add_prompts_tag(img, point_labels, point_coords, boxes, pil)
|
248 |
+
|
249 |
+
def visualize_results(self,
|
250 |
+
img: Union[Image, np.ndarray],
|
251 |
+
des_img: Union[Image, np.ndarray] = None,
|
252 |
+
sam_encoder_output: Optional[SAMEncoderOutput] = None,
|
253 |
+
features: Optional[torch.Tensor] = None,
|
254 |
+
interm_features: Optional[List[torch.Tensor]] = None,
|
255 |
+
original_size: Optional[Tuple] = None,
|
256 |
+
input_size: Optional[Tuple] = None,
|
257 |
+
point_coords: Optional[np.ndarray] = None,
|
258 |
+
point_labels: Optional[np.ndarray] = None,
|
259 |
+
boxes: Optional[np.ndarray] = None,
|
260 |
+
texts: Optional[List] = None,
|
261 |
+
grounding_dino_pipeline: GroundingDinoPipeline = None,
|
262 |
+
box_threshold: float = 0.25,
|
263 |
+
text_threshold: float = 0.25,
|
264 |
+
nms_threshold: float = 0.8,
|
265 |
+
detections: Optional[sv.Detections] = None,
|
266 |
+
multimask_output: bool = True,
|
267 |
+
visualize_promts: bool = True,
|
268 |
+
pil: bool = False):
|
269 |
+
if isinstance(img, Image):
|
270 |
+
img = np.uint8(img)
|
271 |
+
if des_img is not None:
|
272 |
+
if isinstance(des_img, np.ndarray):
|
273 |
+
des_shape = tuple(des_img.shape[:2])
|
274 |
+
else:
|
275 |
+
des_shape = (des_img.size[1], des_img.size[0])
|
276 |
+
src_shape = img.shape[:2]
|
277 |
+
if point_coords is not None or boxes is not None:
|
278 |
+
des_point_coords, des_boxes = transform_coords(src_shape, des_shape, point_coords, boxes)
|
279 |
+
else:
|
280 |
+
des_point_coords = None
|
281 |
+
des_boxes = None
|
282 |
+
else:
|
283 |
+
des_point_coords = None
|
284 |
+
des_boxes = None
|
285 |
+
src_shape = None
|
286 |
+
des_shape = None
|
287 |
+
detections = get_empty_detections() if detections is None else detections
|
288 |
+
mask_annotator = sv.MaskAnnotator()
|
289 |
+
result_list = []
|
290 |
+
mask_result_list = []
|
291 |
+
mask_list = []
|
292 |
+
if boxes is None and point_coords is None and point_labels is None and texts is None or \
|
293 |
+
(point_coords is not None and point_labels is not None and point_coords.shape[0] != point_labels.shape[0]):
|
294 |
+
print('no prompt given!')
|
295 |
+
result_list.append(img)
|
296 |
+
return result_list
|
297 |
+
# if boxes is not None and point_coords is not None and point_labels is not None:
|
298 |
+
# multimask_output = False
|
299 |
+
def get_annotated_image(mask_annotator,
|
300 |
+
detections,
|
301 |
+
img,
|
302 |
+
point_labels=None,
|
303 |
+
point_coords=None,
|
304 |
+
boxes=None,
|
305 |
+
visualize_promts=True,
|
306 |
+
pil=False):
|
307 |
+
annotated_image = mask_annotator.annotate(scene=img.copy(), detections=detections)
|
308 |
+
if visualize_promts:
|
309 |
+
annotated_image = add_prompts_tag(annotated_image, point_labels, point_coords, boxes=boxes, pil=pil)
|
310 |
+
else:
|
311 |
+
if pil:
|
312 |
+
annotated_image = PIL.Image.fromarray(annotated_image)
|
313 |
+
return annotated_image
|
314 |
+
def get_masked_image(img,
|
315 |
+
masks,
|
316 |
+
pil=True):
|
317 |
+
masked_image_list = []
|
318 |
+
for i in range(masks.shape[0]):
|
319 |
+
object_rgb = img * (masks[i].reshape(img.shape[0], img.shape[1], 1))
|
320 |
+
object_rgb = object_rgb.astype(np.uint8)
|
321 |
+
bkgd_mask = np.where(object_rgb == 0, 1, 0)
|
322 |
+
bkgd_mask *= 255
|
323 |
+
bkgd_mask = bkgd_mask.astype(np.uint8)
|
324 |
+
object_rgb += bkgd_mask
|
325 |
+
if pil:
|
326 |
+
masked_image_list.append(PIL.Image.fromarray(object_rgb))
|
327 |
+
else:
|
328 |
+
masked_image_list.append(object_rgb)
|
329 |
+
return masked_image_list
|
330 |
+
def interpolate_mask(mask_np, des_shape):
|
331 |
+
mask_tensor = torch.tensor(mask_np, dtype=torch.float32).unsqueeze(0)
|
332 |
+
mask_interpolate = F.interpolate(mask_tensor, size=des_shape, mode='bilinear')
|
333 |
+
mask_interpolate = (mask_interpolate+0.5).long()
|
334 |
+
mask_np = mask_interpolate.squeeze(0).numpy().astype(bool)
|
335 |
+
return mask_np
|
336 |
+
|
337 |
+
if point_coords is not None and point_labels is not None:
|
338 |
+
|
339 |
+
if src_shape is not None:
|
340 |
+
point_result = self.forward(sam_encoder_output,
|
341 |
+
features,
|
342 |
+
interm_features,
|
343 |
+
original_size,
|
344 |
+
input_size,
|
345 |
+
des_point_coords,
|
346 |
+
point_labels)
|
347 |
+
masks_np = interpolate_mask(point_result.masks_np, src_shape)
|
348 |
+
else:
|
349 |
+
point_result = self.forward(sam_encoder_output,
|
350 |
+
features,
|
351 |
+
interm_features,
|
352 |
+
original_size,
|
353 |
+
input_size,
|
354 |
+
point_coords,
|
355 |
+
point_labels)
|
356 |
+
masks_np = point_result.masks_np
|
357 |
+
if multimask_output:
|
358 |
+
for i in range(masks_np.shape[0]):
|
359 |
+
detections.mask = masks_np[i][None, ...]
|
360 |
+
mask_list.append(masks_np[i])
|
361 |
+
result_list.append(get_annotated_image(mask_annotator,
|
362 |
+
detections,
|
363 |
+
img,
|
364 |
+
point_labels=point_labels,
|
365 |
+
point_coords=point_coords,
|
366 |
+
visualize_promts=visualize_promts,
|
367 |
+
pil=pil))
|
368 |
+
mask_result_list += get_masked_image(img,
|
369 |
+
detections.mask,
|
370 |
+
pil=pil)
|
371 |
+
else:
|
372 |
+
index = np.argmax(point_result.iou_predictions_np)
|
373 |
+
detections.mask = masks_np[index][None, ...]
|
374 |
+
mask_list.append(masks_np[index])
|
375 |
+
result_list.append(get_annotated_image(mask_annotator,
|
376 |
+
detections,
|
377 |
+
img,
|
378 |
+
point_labels=point_labels,
|
379 |
+
point_coords=point_coords,
|
380 |
+
visualize_promts=visualize_promts,
|
381 |
+
pil=pil))
|
382 |
+
mask_result_list += get_masked_image(img,
|
383 |
+
detections.mask,
|
384 |
+
pil=pil)
|
385 |
+
|
386 |
+
if boxes is not None:
|
387 |
+
result_masks = []
|
388 |
+
if src_shape is not None:
|
389 |
+
boxes_ = des_boxes
|
390 |
+
else:
|
391 |
+
boxes_ = boxes
|
392 |
+
if boxes_.shape[0] > 1:
|
393 |
+
for i in range(len(boxes)):
|
394 |
+
box_result = self.forward(sam_encoder_output,
|
395 |
+
features,
|
396 |
+
interm_features,
|
397 |
+
original_size,
|
398 |
+
input_size,
|
399 |
+
box=boxes_[i])
|
400 |
+
index = np.argmax(box_result.iou_predictions_np)
|
401 |
+
result_masks.append(box_result.masks_np[index])
|
402 |
+
mask = np.array(result_masks)
|
403 |
+
if src_shape is not None:
|
404 |
+
masks_np = interpolate_mask(mask, src_shape)
|
405 |
+
else:
|
406 |
+
masks_np = mask
|
407 |
+
mask_list.append(masks_np)
|
408 |
+
detections.mask = masks_np
|
409 |
+
result_list.append(get_annotated_image(mask_annotator,
|
410 |
+
detections,
|
411 |
+
img,
|
412 |
+
boxes=boxes,
|
413 |
+
visualize_promts=visualize_promts,
|
414 |
+
pil=pil))
|
415 |
+
mask_result_list += get_masked_image(img,
|
416 |
+
detections.mask,
|
417 |
+
pil=pil)
|
418 |
+
else:
|
419 |
+
box_result = self.forward(sam_encoder_output,
|
420 |
+
features,
|
421 |
+
interm_features,
|
422 |
+
original_size,
|
423 |
+
input_size,
|
424 |
+
box=boxes_)
|
425 |
+
if src_shape is not None:
|
426 |
+
masks_np = interpolate_mask(box_result.masks_np, src_shape)
|
427 |
+
else:
|
428 |
+
masks_np = box_result.masks_np
|
429 |
+
|
430 |
+
if multimask_output:
|
431 |
+
for i in range(masks_np.shape[0]):
|
432 |
+
detections.mask = masks_np[i][None, ...]
|
433 |
+
mask_list.append(masks_np[i])
|
434 |
+
result_list.append(get_annotated_image(mask_annotator,
|
435 |
+
detections,
|
436 |
+
img,
|
437 |
+
boxes=boxes,
|
438 |
+
visualize_promts=visualize_promts,
|
439 |
+
pil=pil))
|
440 |
+
mask_result_list += get_masked_image(img,
|
441 |
+
detections.mask,
|
442 |
+
pil=pil)
|
443 |
+
else:
|
444 |
+
index = np.argmax(box_result.iou_predictions_np)
|
445 |
+
detections.mask = masks_np[index][None, ...]
|
446 |
+
mask_list.append(masks_np[index])
|
447 |
+
result_list.append(get_annotated_image(mask_annotator, detections, img, boxes=boxes, pil=pil))
|
448 |
+
mask_result_list += get_masked_image(img,
|
449 |
+
detections.mask,
|
450 |
+
pil=pil)
|
451 |
+
|
452 |
+
if texts is not None and grounding_dino_pipeline is not None:
|
453 |
+
detections = grounding_dino_pipeline(img[:, :, ::-1], texts, box_threshold, text_threshold)
|
454 |
+
box_annotator = sv.BoxAnnotator()
|
455 |
+
nms_idx = torchvision.ops.nms(
|
456 |
+
torch.from_numpy(detections.xyxy),
|
457 |
+
torch.from_numpy(detections.confidence),
|
458 |
+
nms_threshold
|
459 |
+
).numpy().tolist()
|
460 |
+
|
461 |
+
detections.xyxy = detections.xyxy[nms_idx]
|
462 |
+
detections.confidence = detections.confidence[nms_idx]
|
463 |
+
detections.class_id = detections.class_id[nms_idx]
|
464 |
+
labels = [
|
465 |
+
f"{texts[class_id]} {confidence:0.2f}"
|
466 |
+
for _, _, confidence, class_id, _
|
467 |
+
in detections]
|
468 |
+
result_masks = []
|
469 |
+
if src_shape is not None:
|
470 |
+
_, boxes_ = transform_coords(src_shape, des_shape, boxes=detections.xyxy)
|
471 |
+
else:
|
472 |
+
boxes_ = detections.xyxy
|
473 |
+
for box in boxes_:
|
474 |
+
box_result = self.forward(sam_encoder_output,
|
475 |
+
features,
|
476 |
+
interm_features,
|
477 |
+
original_size,
|
478 |
+
input_size,
|
479 |
+
box=box)
|
480 |
+
index = np.argmax(box_result.iou_predictions_np)
|
481 |
+
result_masks.append(box_result.masks_np[index])
|
482 |
+
mask = np.array(result_masks)
|
483 |
+
if src_shape is not None:
|
484 |
+
detections.mask = interpolate_mask(mask, src_shape)
|
485 |
+
else:
|
486 |
+
detections.mask = mask
|
487 |
+
for i in range(detections.mask.shape[0]):
|
488 |
+
mask_list.append(detections.mask[i, ...])
|
489 |
+
if visualize_promts:
|
490 |
+
annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections)
|
491 |
+
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
|
492 |
+
else:
|
493 |
+
annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections)
|
494 |
+
|
495 |
+
if pil:
|
496 |
+
result_list.append(PIL.Image.fromarray(annotated_image[:, :, ::-1]))
|
497 |
+
else:
|
498 |
+
result_list.append(annotated_image[:, :, ::-1])
|
499 |
+
mask_result_list += get_masked_image(img,
|
500 |
+
detections.mask,
|
501 |
+
pil=pil)
|
502 |
+
|
503 |
+
return result_list, mask_result_list, mask_list
|
504 |
+
|
505 |
+
def predict(
|
506 |
+
self,
|
507 |
+
features: torch.Tensor,
|
508 |
+
interm_features: List[torch.Tensor],
|
509 |
+
original_size: Tuple,
|
510 |
+
input_size: Tuple,
|
511 |
+
point_coords: Optional[np.ndarray] = None,
|
512 |
+
point_labels: Optional[np.ndarray] = None,
|
513 |
+
box: Optional[np.ndarray] = None,
|
514 |
+
mask_input: Optional[np.ndarray] = None,
|
515 |
+
multimask_output: bool = True,
|
516 |
+
return_logits: bool = False,
|
517 |
+
hq_token_only: bool = False,
|
518 |
+
) -> SAMDecoderPredictOutput:
|
519 |
+
"""
|
520 |
+
Predict masks for the given input prompts, using the currently set image.
|
521 |
+
|
522 |
+
Arguments:
|
523 |
+
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
524 |
+
model. Each point is in (X,Y) in pixels.
|
525 |
+
point_labels (np.ndarray or None): A length N array of labels for the
|
526 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
527 |
+
background point.
|
528 |
+
box (np.ndarray or None): A length 4 array given a box prompt to the
|
529 |
+
model, in XYXY format.
|
530 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
531 |
+
coming from a previous prediction iteration. Has form 1xHxW, where
|
532 |
+
for SAM, H=W=256.
|
533 |
+
multimask_output (bool): If true, the model will return three masks.
|
534 |
+
For ambiguous input prompts (such as a single click), this will often
|
535 |
+
produce better masks than a single prediction. If only a single
|
536 |
+
mask is needed, the model's predicted quality score can be used
|
537 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
538 |
+
input prompts, multimask_output=False can give better results.
|
539 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
540 |
+
instead of a binary mask.
|
541 |
+
|
542 |
+
Returns:
|
543 |
+
(np.ndarray): The output masks in CxHxW format, where C is the
|
544 |
+
number of masks, and (H, W) is the original image size.
|
545 |
+
(np.ndarray): An array of length C containing the model's
|
546 |
+
predictions for the quality of each mask.
|
547 |
+
(np.ndarray): An array of shape CxHxW, where C is the number
|
548 |
+
of masks and H=W=256. These low resolution logits can be passed to
|
549 |
+
a subsequent iteration as mask input.
|
550 |
+
"""
|
551 |
+
# Transform input prompts
|
552 |
+
|
553 |
+
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
|
554 |
+
if point_coords is not None:
|
555 |
+
assert (
|
556 |
+
point_labels is not None
|
557 |
+
), "point_labels must be supplied if point_coords is supplied."
|
558 |
+
point_coords = self.transform.apply_coords(point_coords, original_size)
|
559 |
+
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
560 |
+
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
561 |
+
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
562 |
+
if box is not None:
|
563 |
+
box = self.transform.apply_boxes(box, original_size)
|
564 |
+
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
565 |
+
box_torch = box_torch[None, :]
|
566 |
+
if mask_input is not None:
|
567 |
+
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
568 |
+
mask_input_torch = mask_input_torch[None, :, :, :]
|
569 |
+
|
570 |
+
sam_decoder_predict_torch_output = self.predict_torch(
|
571 |
+
features,
|
572 |
+
interm_features,
|
573 |
+
original_size,
|
574 |
+
input_size,
|
575 |
+
coords_torch,
|
576 |
+
labels_torch,
|
577 |
+
box_torch,
|
578 |
+
mask_input_torch,
|
579 |
+
multimask_output,
|
580 |
+
return_logits=return_logits,
|
581 |
+
hq_token_only=hq_token_only,
|
582 |
+
)
|
583 |
+
|
584 |
+
masks_np = sam_decoder_predict_torch_output.masks[0].detach().cpu().numpy()
|
585 |
+
iou_predictions_np = sam_decoder_predict_torch_output.iou_predictions[0].detach().cpu().numpy()
|
586 |
+
low_res_masks_np = sam_decoder_predict_torch_output.low_res_masks[0].detach().cpu().numpy()
|
587 |
+
return SAMDecoderPredictOutput(masks_np, iou_predictions_np, low_res_masks_np)
|
588 |
+

    @torch.no_grad()
    def predict_torch(
        self,
        features: torch.Tensor,
        interm_features: List[torch.Tensor],
        original_size: Tuple,
        input_size: Tuple,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        hq_token_only: bool = False,
    ) -> SAMDecoderPredictTorchOutput:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=features,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            hq_token_only=hq_token_only,
            interm_embeddings=interm_features,
        )

        # Upscale the masks to the original image resolution
        # masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
        masks = self.postprocess_masks(low_res_masks, input_size, original_size)

        if not return_logits:
            masks = masks > self.mask_threshold

        return SAMDecoderPredictTorchOutput(masks, iou_predictions, low_res_masks)
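A second aside, sketching the iterative-refinement pattern the docstring describes: the low-resolution logits from one call are fed back as `mask_input` on the next call together with the original prompts. `decoder_pipeline`, `features`, `interm_features` and the two sizes are assumed to come from a previous encoder run; the slicing shape follows the `predict` docstring.

# Illustrative only -- refine a click prompt using the previous low-res logits.
import numpy as np

first = decoder_pipeline.predict(features, interm_features, original_size, input_size,
                                 point_coords=point_coords, point_labels=point_labels,
                                 multimask_output=True)
best = int(np.argmax(first.iou_predictions))
refined = decoder_pipeline.predict(features, interm_features, original_size, input_size,
                                   point_coords=point_coords, point_labels=point_labels,
                                   mask_input=first.low_res_masks[best:best + 1],  # 1xHxW, H=W=256
                                   multimask_output=False)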
    def forward(self,
                sam_encoder_output: Optional[SAMEncoderOutput] = None,
                features: Optional[torch.Tensor] = None,
                interm_features: Optional[List[torch.Tensor]] = None,
                original_size: Optional[Tuple] = None,
                input_size: Optional[Tuple] = None,
                point_coords: Optional[np.ndarray] = None,
                point_labels: Optional[np.ndarray] = None,
                box: Optional[np.ndarray] = None,
                mask_input: Optional[np.ndarray] = None,
                multimask_output: bool = True,
                return_logits: bool = False,
                hq_token_only: bool = False,
                dino: bool = False
                ) -> SAMDecoderPredictOutput:
        assert sam_encoder_output or (features is not None and original_size is not None and input_size is not None), \
            'either sam_encoder_output or all of features, original_size and input_size must be given!'
        if sam_encoder_output:
            features = sam_encoder_output.features
            interm_features = sam_encoder_output.interm_features
            original_size = sam_encoder_output.original_size
            input_size = sam_encoder_output.input_size
        if self.adaptor is not None:
            if dino:
                # Normalize the DINO features, resize them to the 64x64 feature grid,
                # move channels last and map them through the adaptor.
                features = F.interpolate(F.normalize(features, dim=1), size=(64, 64), mode='bilinear').permute(0, 2, 3, 1)
                features = self.adaptor(features)
            #
            # else:
            #     features = self.adaptor(features, original_size)

        return self.predict(features,
                            interm_features,
                            original_size,
                            input_size,
                            point_coords,
                            point_labels,
                            box,
                            mask_input,
                            multimask_output,
                            return_logits,
                            hq_token_only)


'''
class SAMPipeline(Pipeline):
    @classmethod
    def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
        sam_encoder_pipeline = SAMEncoderPipeline(ckpt_path, device, *args, **kwargs)
        sam_decoder_pipeline = SAMDecoderPipeline(ckpt_path, device, *args, **kwargs)
        pipeline = cls(**dict(sam_encoder_pipeline=sam_encoder_pipeline,
                              sam_decoder_pipeline=sam_decoder_pipeline,
                              device=device))
        return pipeline
'''
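Before the next file, a hedged end-to-end sketch of wiring an encoder and a decoder pipeline together, mirroring the commented-out SAMPipeline above. `ckpt_path` and `image` are placeholders, and the exact constructor and call signatures of the two pipelines are assumptions, not confirmed by this listing.

# Illustrative only -- encoder features are computed once, then reused by forward().
import numpy as np

encoder = SAMEncoderPipeline(ckpt_path, 'cuda')       # constructed as in the commented block
decoder = SAMDecoderPipeline(ckpt_path, 'cuda')

enc_out = encoder(image)                              # assumed to return a SAMEncoderOutput
box = np.array([100, 150, 400, 380])                  # one XYXY box prompt
result = decoder.forward(sam_encoder_output=enc_out,  # forward() unpacks features and sizes
                         box=box,
                         multimask_output=False,
                         hq_token_only=True)
binary_mask = result.masks[0]                         # first mask at the original image size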
sam_extension/utils/__init__.py
ADDED
@@ -0,0 +1,175 @@
import math
import cv2
import PIL
import torch
from PIL.Image import Image
from typing import Union, Tuple, List, Optional
import numpy as np
import supervision as sv
from sklearn.decomposition import PCA

# def add_points_tag(img: Union[Image, np.ndarray],
#                    point_labels: Union[List[int], np.ndarray] = None,
#                    point_coords: Union[List[List[int]], np.ndarray] = None,
#                    pil: bool = False):
#     if point_labels is None or point_coords is None or \
#             not isinstance(point_labels, (List, np.ndarray)) or \
#             not isinstance(point_coords, (List, np.ndarray)):
#         return img
#     if len(point_labels) != len(point_coords):
#         print('length of point_label and point_coordinate must be same!')
#         return img
#     if isinstance(img, Image):
#         img = np.uint8(img)
#     start_angle = 40
#     x = 8
#     y = 2
#     def get_point(angle, d, base):
#         angle = angle / 180.0 * math.pi
#         _x, _y = math.cos(angle) * d, math.sin(angle) * d
#         return [base[0] + _x, base[1] - _y]
#     # assert len(point_labels) == len(point_coords), ''
#     for i in range(len(point_labels)):
#         points = []
#         for j in range(5):
#             _x, _y = math.cos(start_angle), math.sin(start_angle)
#             points.append(get_point(start_angle, x, point_coords[i]))
#             start_angle -= 36
#             points.append(get_point(start_angle, y, point_coords[i]))
#             start_angle -= 36
#         points = np.array([points], np.int32)
#         color = (255, 0, 0) if point_labels[i] == 0 else (0, 255, 0)
#         cv2.fillPoly(img, points, color, cv2.LINE_AA)
#     if pil:
#         img = PIL.Image.fromarray(img)
#     return img
def add_points_tag(img: Union[Image, np.ndarray],
                   point_labels: Union[List[int], np.ndarray] = None,
                   point_coords: Union[List[List[int]], np.ndarray] = None,
                   pil: bool = False):
    if point_labels is None or point_coords is None or \
            not isinstance(point_labels, (List, np.ndarray)) or \
            not isinstance(point_coords, (List, np.ndarray)):
        return img
    if len(point_labels) != len(point_coords):
        print('point_labels and point_coords must have the same length!')
        return img
    if isinstance(img, Image):
        img = np.array(img)
        # img.flags.writeable = True
    h, w = img.shape[:2]
    # Clamp the 8x8 marker squares to the image borders.
    x_start_list, x_end_list = np.where((point_coords[:, 0] - 4) > 0, point_coords[:, 0] - 4, 0), np.where((point_coords[:, 0] + 4) < w, point_coords[:, 0] + 4, w)
    y_start_list, y_end_list = np.where((point_coords[:, 1] - 4) > 0, point_coords[:, 1] - 4, 0), np.where((point_coords[:, 1] + 4) < h, point_coords[:, 1] + 4, h)
    for i in range(len(point_labels)):
        x_start, x_end = x_start_list[i], x_end_list[i]
        y_start, y_end = y_start_list[i], y_end_list[i]
        label = point_labels[i]
        # Green square for foreground points (label 1), red square for background points.
        color = [0, 255, 0] if int(label) == 1 else [255, 0, 0]
        for x in range(x_start, x_end):
            for y in range(y_start, y_end):
                img[y, x, :] = color
    if pil:
        img = PIL.Image.fromarray(img)
    return img


def add_boxes_tag(img: Union[Image, np.ndarray],
                  boxes: Union[List[List[int]], np.ndarray] = None,
                  pil: bool = False):
    if boxes is None or not isinstance(boxes, (List, np.ndarray)):
        return img
    # if isinstance(boxes, np.ndarray):
    #     if not boxes.all():
    #         return img
    # else:
    #     if not boxes:
    #         return img
    if isinstance(img, Image):
        img = np.uint8(img)
    thickness = 2
    for i in range(len(boxes)):
        color = (0, 255, 0)
        img = cv2.rectangle(img, (boxes[i][0], boxes[i][1]), (boxes[i][2], boxes[i][3]), color, thickness)
    if pil:
        img = PIL.Image.fromarray(img)
    return img


def add_prompts_tag(img: Union[Image, np.ndarray],
                    point_labels: Union[List[int], np.ndarray] = None,
                    point_coords: Union[List[List[int]], np.ndarray] = None,
                    boxes: Union[List[List[int]], np.ndarray] = None,
                    pil: bool = False):
    img = add_points_tag(img, point_labels, point_coords, pil=pil)
    img = add_boxes_tag(img, boxes, pil=pil)
    return img

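An illustrative aside (not part of the file): overlaying one foreground click and one box on an image with add_prompts_tag. `image` is a placeholder HxWx3 uint8 array or PIL image; note that point_coords must be an np.ndarray here because add_points_tag indexes it with [:, 0].

# Illustrative only -- preview the prompts that will be sent to the SAM decoder.
tagged = add_prompts_tag(image,
                         point_labels=np.array([1]),
                         point_coords=np.array([[320, 240]]),
                         boxes=np.array([[100, 150, 400, 380]]),
                         pil=True)
# tagged.save('prompts_preview.png')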
def get_empty_detections():
    detections = sv.Detections(xyxy=np.array([0, 0, 0, 0]).reshape(1, 4))
    detections.xyxy = None
    return detections


def pca_feature(feature: torch.Tensor, dim: int = 3, return_np: bool = True):
    pca = PCA(n_components=dim)
    H, W, C = feature.shape
    feature = feature.view(-1, C).cpu().numpy()
    feature = pca.fit_transform(feature)
    feature = torch.tensor(feature.reshape(H, W, dim))
    if return_np:
        return feature.numpy()
    else:
        return feature


def visual_feature_rgb(feature: torch.Tensor, pil: bool = True):
    assert feature.ndim >= 3, 'the dim of feature must >= 3!'
    if feature.ndim == 4:
        feature = feature.squeeze(0)
    if feature.shape[-1] != 3:
        feature = pca_feature(feature, 3, False)
    max_f, _ = feature.max(-1)
    min_f, _ = feature.min(-1)
    feature = (feature - min_f[..., None]) / (max_f[..., None] - min_f[..., None])
    feature = np.uint8((feature * 255).cpu().numpy())
    if pil:
        return PIL.Image.fromarray(feature)
    else:
        return feature

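Another aside: turning a channels-last feature map into an RGB preview with the two helpers above. `features_hwc` is a placeholder HxWxC tensor (for example a 64x64x256 encoder feature map permuted to channels-last); since pca_feature calls .view, the tensor should be contiguous.

# Illustrative only -- PCA-project the features to 3 channels and rescale to [0, 255].
rgb_preview = visual_feature_rgb(features_hwc.contiguous(), pil=True)
# rgb_preview.save('feature_pca.png')
pca3 = pca_feature(features_hwc.contiguous(), dim=3, return_np=True)  # raw 3-component projection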
def transform_coords(src_shape, des_shape, points=None, boxes=None):
    assert points is not None or boxes is not None, 'one of points and boxes must be given!'
    scale_h = des_shape[0] / src_shape[0]
    scale_w = des_shape[1] / src_shape[1]
    if points is not None:
        new_points = np.full_like(points, 0)
        new_points[:, 0] = points[:, 0] * scale_w
        new_points[:, 1] = points[:, 1] * scale_h
        # astype returns a new array, so the result must be reassigned.
        new_points = new_points.astype(np.int64)
    else:
        new_points = None
    if boxes is not None:
        new_boxes = np.full_like(boxes, 0)
        new_boxes[:, 0] = boxes[:, 0] * scale_w
        new_boxes[:, 1] = boxes[:, 1] * scale_h
        new_boxes[:, 2] = boxes[:, 2] * scale_w
        new_boxes[:, 3] = boxes[:, 3] * scale_h
        new_boxes = new_boxes.astype(np.int64)
    else:
        new_boxes = None
    return new_points, new_boxes


def mask2greyimg(mask_list, pil=True):
    grey_img_list = []
    for mask in mask_list:
        if pil:
            grey_img_list.append(PIL.Image.fromarray(np.uint8(mask * 255)))
        else:
            grey_img_list.append(np.uint8(mask * 255))
    return grey_img_list


if __name__ == '__main__':
    src_shape = (100, 100)
    des_shape = (200, 200)
    points = np.array([[20, 20], [40, 40]])
    boxes = np.array([[10, 10, 20, 20]])
    new_points, new_boxes = transform_coords(src_shape, des_shape, points, boxes)
    print(new_points, new_boxes)
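Finally, a hedged sketch of saving decoder output with mask2greyimg. `result` is assumed to be a SAMDecoderPredictOutput whose .masks attribute is a CxHxW boolean or float array.

# Illustrative only -- write each predicted mask as a greyscale image.
grey_imgs = mask2greyimg(result.masks, pil=True)
for i, g in enumerate(grey_imgs):
    g.save(f'mask_{i}.png')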
sam_extension/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (4.51 kB). View file