from typing import Tuple, List, Union
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from sam_extension.distillation_models.fastervit import FasterViTLayer
from segment_anything.mobile_encoder.tiny_vit_sam import PatchEmbed, Conv2d_BN, LayerNorm2d, MBConv


class PatchMerging(nn.Module):
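    """Downsample block between stages.

    Accepts (B, L, C) tokens (reshaped to (B, C, H, W) via `input_resolution`)
    or a (B, C, H, W) map, then applies 1x1 Conv2d_BN -> act -> 3x3 depthwise
    Conv2d_BN (stride 2, or stride 1 when out_dim is 320/448/576) -> act ->
    1x1 Conv2d_BN.
    """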
def __init__(self, input_resolution, dim, out_dim, activation):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.out_dim = out_dim
self.act = activation()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
        stride_c = 2
        if out_dim in (320, 448, 576):  # handongshen 576
            stride_c = 1
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)

    def forward(self, x):
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
# (B, C, H, W)
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
        return x


class ConvLayer(nn.Module):
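    """A stage of `depth` MBConv blocks (with optional gradient checkpointing),
    followed by an optional downsample module (e.g. PatchMerging)."""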
def __init__(self, dim, input_resolution, depth,
activation,
drop_path=0., downsample=None, use_checkpoint=False,
out_dim=None,
conv_expand_ratio=4.,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
MBConv(dim, dim, conv_expand_ratio, activation,
drop_path[i] if isinstance(drop_path, list) else drop_path,
)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                # `checkpoint` is imported as a function from
                # torch.utils.checkpoint, so call it directly rather than
                # as `checkpoint.checkpoint`.
                x = checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
        return x


class FasterTinyViT(nn.Module):
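    """Hybrid TinyViT/FasterViT backbone.

    Stage 0 is a ConvLayer of MBConv blocks; later stages are FasterViTLayers.
    A two-conv neck with LayerNorm2d projects the final feature map to
    `out_chans`. When `multi_scale` and `output_shape` are set, the neck
    instead fuses all stage outputs resized to `output_shape`.
    """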
def __init__(self, img_size=224,
in_chans=3,
out_chans=256,
embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
ct_size=2,
conv=False,
multi_scale=False,
output_shape=None,
):
super().__init__()
self.img_size = img_size
self.depths = depths
self.num_layers = len(depths)
self.mlp_ratio = mlp_ratio
self.multi_scale = multi_scale
self.output_shape = tuple(output_shape) if output_shape else None
activation = nn.GELU
self.patch_embed = PatchEmbed(in_chans=in_chans,
embed_dim=embed_dims[0],
resolution=img_size,
activation=activation)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
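            # kwargs_0 configures the stage-0 ConvLayer (2-D input resolution);
            # kwargs_1 configures the FasterViTLayer stages, which take a scalar
            # resolution and downsample internally when downsample=True.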
kwargs_0 = dict(dim=embed_dims[i_layer],
input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
downsample=PatchMerging if (
i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(
i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
kwargs_1 = dict(dim=embed_dims[i_layer],
out_dim=embed_dims[i_layer+1] if (
i_layer < self.num_layers - 1) else embed_dims[i_layer],
input_resolution=patches_resolution[0] // (2 ** i_layer),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                            downsample=i_layer < self.num_layers - 1,
ct_size=ct_size,
conv=conv,
)
if i_layer == 0:
layer = ConvLayer(
conv_expand_ratio=mbconv_expand_ratio,
**kwargs_0,
)
else:
layer = FasterViTLayer(
num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
**kwargs_1)
self.layers.append(layer)
# init weights
self.apply(self._init_weights)
self.neck = nn.Sequential(
nn.Conv2d(
                # Multi-scale mode concatenates the patch embedding and every
                # stage output; the last stage keeps its channel width (no
                # downsample), hence the extra embed_dims[-1].
                (sum(embed_dims) + embed_dims[-1]) if self.multi_scale and self.output_shape else embed_dims[-1],
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
        )

    def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
def no_weight_decay_keywords(self):
        return {'attention_biases'}

    def forward_features(self, x):
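        """Run patch embedding and all stages, then the neck.

        In multi-scale mode, the patch embedding and every stage output are
        bilinearly resized to `output_shape` and concatenated along channels
        before the neck.
        """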
if self.multi_scale and self.output_shape:
output_list = []
# x: (N, C, H, W)
x = self.patch_embed(x)
output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
for layer in self.layers:
x = layer(x)
output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
x = self.neck(torch.cat(output_list, dim=1))
else:
x = self.patch_embed(x)
for layer in self.layers:
x = layer(x)
x = self.neck(x)
        return x

    def forward(self, x):
x = self.forward_features(x)
        return x


if __name__ == '__main__':
    from distillation.utils import get_parameter_number

    x = torch.randn(1, 3, 1024, 1024).cuda()
fastertinyvit = FasterTinyViT(img_size=1024, in_chans=3,
embed_dims=[64, 128, 256],
depths=[1, 2, 1],
num_heads=[2, 4, 8],
window_sizes=[8, 8, 8],
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
multi_scale=False,
                                  output_shape=None).cuda()
print(fastertinyvit(x).shape)
print(get_parameter_number(fastertinyvit))
# torch.save(fastertinyvit, 'fastertinyvit.pt')
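
    # A minimal multi-scale sketch (illustrative, not part of the original
    # file): with multi_scale=True and output_shape=(64, 64), forward_features
    # resizes the patch embedding and all three stage outputs to 64x64 and
    # concatenates 64 + 128 + 256 + 256 = 704 channels before the neck, which
    # still projects to out_chans=256.
    # fastertinyvit_ms = FasterTinyViT(img_size=1024, in_chans=3,
    #                                  embed_dims=[64, 128, 256],
    #                                  depths=[1, 2, 1],
    #                                  num_heads=[2, 4, 8],
    #                                  window_sizes=[8, 8, 8],
    #                                  multi_scale=True,
    #                                  output_shape=(64, 64)).cuda()
    # print(fastertinyvit_ms(x).shape)  # expected: torch.Size([1, 256, 64, 64])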