# Page-header residue from the hosting site ("Spaces: Running"); not part of the module.
from typing import Tuple, List, Union | |
import torch | |
from torch import nn | |
from torch.utils.checkpoint import checkpoint | |
import torch.nn.functional as F | |
from timm.models.layers import trunc_normal_ | |
from sam_extension.distillation_models.fastervit import FasterViTLayer | |
from segment_anything.mobile_encoder.tiny_vit_sam import PatchEmbed, Conv2d_BN, LayerNorm2d, MBConv | |
class PatchMerging(nn.Module):
    """Channel-changing merge block: conv1 -> act -> conv2 (depthwise) -> act -> conv3.

    Args:
        input_resolution: ``(H, W)`` pair used to fold a 3-D token tensor back
            into a 4-D feature map in ``forward``.
        dim: Number of input channels.
        out_dim: Number of output channels.
        activation: Zero-argument callable (e.g. an ``nn.Module`` class) that
            builds the activation applied after ``conv1`` and ``conv2``.
    """

    def __init__(self, input_resolution, dim, out_dim, activation):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim
        self.act = activation()

        # Output widths 320/448/576 use a stride-1 depthwise conv; every other
        # width uses stride 2.
        # NOTE(review): Conv2d_BN's positional arguments are presumably
        # (in, out, kernel_size, stride, padding) -- confirm against its definition.
        depthwise_stride = 1 if out_dim in (320, 448, 576) else 2

        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, depthwise_stride, 1,
                               groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        """Apply the three convolutions, folding 3-D input to 4-D first.

        A 3-D input is viewed as ``(B, H, W, -1)`` using ``input_resolution``
        and permuted to ``(B, C, H, W)``; any other rank is passed through
        to the convolutions unchanged.
        """
        if x.ndim == 3:
            height, width = self.input_resolution
            batch = len(x)
            # (B, H*W, C) -> (B, C, H, W)
            x = x.view(batch, height, width, -1).permute(0, 3, 1, 2)

        for stage in (self.conv1, self.act, self.conv2, self.act):
            x = stage(x)
        return self.conv3(x)
class ConvLayer(nn.Module):
    """A stack of ``MBConv`` blocks followed by an optional downsample module.

    Args:
        dim: Channel count passed as both input and output width of every
            ``MBConv`` block.
        input_resolution: Stored on the instance and forwarded to
            ``downsample`` when one is given.
        depth: Number of ``MBConv`` blocks.
        activation: Activation factory forwarded to the blocks and to
            ``downsample``.
        drop_path: Either a single value used for every block or a list
            indexed per block.
        downsample: Optional callable invoked as
            ``downsample(input_resolution, dim=..., out_dim=..., activation=...)``
            to build the trailing module; ``None`` disables it.
        use_checkpoint: When true, each block is run through
            ``torch.utils.checkpoint.checkpoint`` in ``forward``.
        out_dim: Output width forwarded to ``downsample``.
        conv_expand_ratio: Expansion ratio forwarded to each ``MBConv``.
    """

    def __init__(self, dim, input_resolution, depth,
                 activation,
                 drop_path=0., downsample=None, use_checkpoint=False,
                 out_dim=None,
                 conv_expand_ratio=4.,
                 ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            MBConv(dim, dim, conv_expand_ratio, activation,
                   drop_path[i] if isinstance(drop_path, list) else drop_path,
                   )
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, out_dim=out_dim, activation=activation)
        else:
            self.downsample = None

    def forward(self, x):
        """Run ``x`` through every block, then through ``downsample`` if present."""
        for blk in self.blocks:
            if self.use_checkpoint:
                # `checkpoint` is imported as the function itself, so it must
                # be called directly; the previous `checkpoint.checkpoint(...)`
                # raised AttributeError. `use_reentrant` is passed explicitly,
                # as recent PyTorch releases ask for.
                x = checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
class FasterTinyViT(nn.Module):
    """Image encoder: patch embedding, one ``ConvLayer`` stage, then
    ``FasterViTLayer`` stages, finished by a two-convolution neck.

    Args:
        img_size: Input resolution forwarded to ``PatchEmbed``.
        in_chans: Number of input image channels.
        out_chans: Number of channels produced by the neck.
        embed_dims: Channel width per stage.
        depths: Number of blocks per stage; its length sets the stage count.
        num_heads: Per-stage head count (index 0 is unused because stage 0
            is convolutional).
        window_sizes: Per-stage window size (index 0 is unused).
        mlp_ratio: Forwarded to every ``FasterViTLayer``.
        drop_rate: Forwarded to every ``FasterViTLayer`` as ``drop``.
        drop_path_rate: Upper end of the linearly spaced per-block
            stochastic-depth schedule.
        use_checkpoint: Forwarded to the stage-0 ``ConvLayer`` only.
        mbconv_expand_ratio: Expansion ratio of the stage-0 ``MBConv`` blocks.
        ct_size: Forwarded to every ``FasterViTLayer``.
        conv: Forwarded to every ``FasterViTLayer``.
        multi_scale: When true *and* ``output_shape`` is set, the neck
            consumes the concatenation of the patch embedding and every
            stage output, each resized to ``output_shape``.
        output_shape: Spatial size used for the multi-scale resize; any
            falsy value is stored as ``None``.
    """

    def __init__(self, img_size=224,
                 in_chans=3,
                 out_chans=256,
                 embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_sizes=[7, 7, 14, 7],
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 ct_size=2,
                 conv=False,
                 multi_scale=False,
                 output_shape=None,
                 ):
        super().__init__()
        # NOTE: the list defaults above are only read, never mutated.
        self.img_size = img_size
        self.depths = depths
        self.num_layers = len(depths)
        self.mlp_ratio = mlp_ratio
        self.multi_scale = multi_scale
        self.output_shape = tuple(output_shape) if output_shape else None

        activation = nn.GELU

        self.patch_embed = PatchEmbed(in_chans=in_chans,
                                      embed_dim=embed_dims[0],
                                      resolution=img_size,
                                      activation=activation)

        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # stochastic depth decay rule: one rate per block, linearly spaced
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
                                                sum(depths))]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            has_next = i_layer < self.num_layers - 1
            stage_drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])]

            if i_layer == 0:
                # Stage 0 is convolutional and works at the patch-embedding
                # resolution. Its arguments are built only here; previously
                # they were assembled (and discarded) for every stage.
                layer = ConvLayer(
                    dim=embed_dims[0],
                    input_resolution=(patches_resolution[0],
                                      patches_resolution[1]),
                    depth=depths[0],
                    activation=activation,
                    drop_path=stage_drop_path,
                    downsample=PatchMerging if has_next else None,
                    use_checkpoint=use_checkpoint,
                    out_dim=embed_dims[min(1, len(embed_dims) - 1)],
                    conv_expand_ratio=mbconv_expand_ratio,
                )
            else:
                layer = FasterViTLayer(
                    dim=embed_dims[i_layer],
                    out_dim=embed_dims[i_layer + 1] if has_next else embed_dims[i_layer],
                    input_resolution=patches_resolution[0] // (2 ** i_layer),
                    depth=depths[i_layer],
                    drop_path=stage_drop_path,
                    downsample=has_next,
                    ct_size=ct_size,
                    conv=conv,
                    num_heads=num_heads[i_layer],
                    window_size=window_sizes[i_layer],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                )
            self.layers.append(layer)

        # init weights (runs before the neck exists, so the neck keeps the
        # default initialisation of its modules)
        self.apply(self._init_weights)

        # The multi-scale neck sees the patch embedding (embed_dims[0]), the
        # output of every non-final stage (embed_dims[1:]) and the final stage
        # (embed_dims[-1] again), hence sum(embed_dims) + embed_dims[-1].
        if self.multi_scale and self.output_shape:
            neck_in_chans = sum(embed_dims) + embed_dims[-1]
        else:
            neck_in_chans = embed_dims[-1]

        self.neck = nn.Sequential(
            nn.Conv2d(
                neck_in_chans,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def _init_weights(self, m):
        """Truncated-normal init for ``nn.Linear``; unit/zero for ``nn.LayerNorm``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def no_weight_decay_keywords(self):
        """Return the parameter-name keywords to exclude from weight decay."""
        return {'attention_biases'}

    def _resize_to_output(self, feature):
        """Bilinearly resize ``feature`` to ``self.output_shape``."""
        return F.interpolate(feature, size=self.output_shape, mode='bilinear')

    def forward_features(self, x):
        """Encode ``x`` (N, C, H, W) and return the neck output.

        In multi-scale mode the patch embedding and every stage output are
        resized to ``output_shape`` and concatenated along the channel axis
        before the neck; otherwise only the last stage output is used.
        """
        x = self.patch_embed(x)

        if not (self.multi_scale and self.output_shape):
            for layer in self.layers:
                x = layer(x)
            return self.neck(x)

        pyramid = [self._resize_to_output(x)]
        for layer in self.layers:
            x = layer(x)
            pyramid.append(self._resize_to_output(x))
        return self.neck(torch.cat(pyramid, dim=1))

    def forward(self, x):
        """Alias for :meth:`forward_features`."""
        return self.forward_features(x)
if __name__ == '__main__':
    # Smoke test: build a small three-stage model, run one 1024x1024 image
    # through it and report the output shape and parameter count.
    from distillation.utils import get_parameter_number

    # Fall back to CPU so the check also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    x = torch.randn(1, 3, 1024, 1024).to(device)
    fastertinyvit = FasterTinyViT(img_size=1024, in_chans=3,
                                  embed_dims=[64, 128, 256],
                                  depths=[1, 2, 1],
                                  num_heads=[2, 4, 8],
                                  window_sizes=[8, 8, 8],
                                  mlp_ratio=4.,
                                  drop_rate=0.,
                                  drop_path_rate=0.0,
                                  use_checkpoint=False,
                                  mbconv_expand_ratio=4.0,
                                  multi_scale=False,
                                  # None rather than '' -- both are stored as
                                  # None by the constructor.
                                  output_shape=None).to(device)
    print(fastertinyvit(x).shape)
    print(get_parameter_number(fastertinyvit))
    # torch.save(fastertinyvit, 'fastertinyvit.pt')