Jiangxz01's picture
Upload 56 files
52f1bcb verified
raw
history blame
24.4 kB
"""
This is an implementation of EfficientViT, with some modifications (decode head, etc.).
Original paper at https://arxiv.org/abs/2205.14756
Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py
Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit
"""
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

from surya.model.detection.config import EfficientViTConfig
from surya.model.detection.processor import SegformerImageProcessor
from surya.settings import settings
def load_model(
    checkpoint=settings.DETECTOR_MODEL_CHECKPOINT,
    device=settings.TORCH_DEVICE_MODEL,
    dtype=settings.MODEL_DTYPE,
):
    """Load the detection model from ``checkpoint`` and prepare it for inference.

    The defaults are read from ``settings`` once, when this function is defined.

    Args:
        checkpoint: Identifier passed to ``from_pretrained`` for both the
            config and the model weights.
        device: Device the model is moved to.
        dtype: Passed to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The ``EfficientViTForSemanticSegmentation`` instance, on ``device`` and
        in evaluation mode.
    """
    model_config = EfficientViTConfig.from_pretrained(checkpoint)
    detector = EfficientViTForSemanticSegmentation.from_pretrained(
        checkpoint,
        torch_dtype=dtype,
        config=model_config,
        ignore_mismatched_sizes=True,
    )
    # Move to the target device first, then switch to eval mode.
    detector = detector.to(device)
    detector = detector.eval()
    print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
    return detector
def load_processor(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT):
    """Return the ``SegformerImageProcessor`` loaded from ``checkpoint``.

    The default checkpoint is read from ``settings`` when this function is
    defined.
    """
    return SegformerImageProcessor.from_pretrained(checkpoint)
def val2list(x: object, repeat_time: int = 1) -> list:
    """Return ``x`` as a list.

    The previous annotation ``list or tuple or any`` was a boolean expression
    that evaluated to plain ``list``; ``object`` states the real contract.

    Args:
        x: A list or tuple (copied into a new list), or any other value.
        repeat_time: How many times a non-list/tuple ``x`` is repeated.
            Ignored when ``x`` is a list or tuple.

    Returns:
        A new list: the elements of ``x`` when it is a list/tuple, otherwise
        ``repeat_time`` references to ``x``.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x] * repeat_time
def val2tuple(x: object, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return ``x`` as a tuple padded to at least ``min_len`` elements.

    The previous annotation ``list or tuple or any`` was a boolean expression
    that evaluated to plain ``list``; ``object`` states the real contract.

    Args:
        x: A list/tuple, or a single value (treated as a one-element list).
        min_len: Minimum length of the result. Longer inputs are not truncated.
        idx_repeat: Index of the element that is duplicated (and the position
            the copies are inserted at) when padding is needed.

    Returns:
        A tuple of the (possibly padded) elements. An empty input stays empty.
    """
    items = val2list(x)
    if items:
        missing = min_len - len(items)
        # Only touch items[idx_repeat] when padding is actually required, so
        # an out-of-range index is harmless for already-long-enough inputs.
        if missing > 0:
            items[idx_repeat:idx_repeat] = [items[idx_repeat]] * missing
    return tuple(items)
def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
    """Return ``kernel_size // 2`` for an odd kernel size.

    The previous annotations (``int or tuple[int, ...]``) were boolean
    expressions that evaluated to plain ``int``.

    Args:
        kernel_size: An odd integer, or a tuple of kernel sizes (handled
            element-wise, recursively).

    Returns:
        The padding as an int, or a tuple of paddings mirroring the input.

    Raises:
        AssertionError: If an integer kernel size is even.
    """
    if isinstance(kernel_size, tuple):
        return tuple(get_same_padding(ks) for ks in kernel_size)
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
    """Compute the padding ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.

    Args:
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.

    Returns:
        The integer (floored) padding.
    """
    dilated_span = dilation * (kernel_size - 1)
    return (stride - 1 + dilated_span) // 2
class ConvNormAct(nn.Module):
    """Dropout -> Conv2d -> normalization -> activation.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        groups: Number of convolution groups.
        bias: Whether the convolution has a bias term.
        dropout: Drop probability of the ``nn.Dropout`` applied to the input.
        norm_layer: Factory called as ``norm_layer(num_features=out_channels)``;
            a falsy value disables normalization.
        act_layer: Factory called as ``act_layer(inplace=True)``; ``None``
            disables the activation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=False,
        dropout=0.,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.ReLU,
    ):
        super(ConvNormAct, self).__init__()
        self.dropout = nn.Dropout(dropout, inplace=False)
        # Padding is derived from kernel size, stride and dilation.
        padding = get_padding(kernel_size, stride, dilation)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )
        self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()

    def forward(self, x):
        # The dropout module used to be constructed but never called, so the
        # ``dropout`` argument had no effect. With the default p=0, and always
        # in eval mode, this call is a no-op.
        x = self.dropout(x)
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x
class DSConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    ``use_bias``, ``norm_layer`` and ``act_layer`` may each be a single value
    or a pair; index 0 configures the depthwise stage and index 1 the
    pointwise stage.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        # Depthwise stage: one group per input channel, carries the stride.
        self.depth_conv = ConvNormAct(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        # Pointwise stage: 1x1 convolution changing the channel count.
        self.point_conv = ConvNormAct(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.point_conv(self.depth_conv(x))
class ConvBlock(nn.Module):
    """Two stacked ``ConvNormAct`` layers using the same kernel size.

    The first layer carries ``stride`` and maps to ``mid_channels`` (defaults
    to ``round(in_channels * expand_ratio)``); the second has stride 1 and maps
    to ``out_channels``. ``use_bias``, ``norm_layer`` and ``act_layer`` may each
    be a single value or a pair (index 0 -> first layer, index 1 -> second).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.conv1 = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.conv2 = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
class MBConv(nn.Module):
    """1x1 expansion -> depthwise convolution -> 1x1 projection.

    ``mid_channels`` defaults to ``round(in_channels * expand_ratio)``.
    ``use_bias``, ``norm_layer`` and ``act_layer`` may each be a single value
    or a triple; indices 0/1/2 configure the expansion, depthwise and
    projection stages respectively.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 3)
        norms = val2tuple(norm_layer, 3)
        acts = val2tuple(act_layer, 3)
        hidden = mid_channels or round(in_channels * expand_ratio)

        # Expansion: 1x1 convolution to the hidden width, never strided.
        self.inverted_conv = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=1,
            stride=1,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        # Depthwise: one group per hidden channel, carries the stride.
        self.depth_conv = ConvNormAct(
            hidden,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            groups=hidden,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )
        # Projection: 1x1 convolution to the output width.
        self.point_conv = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=1,
            bias=biases[2],
            norm_layer=norms[2],
            act_layer=acts[2],
        )

    def forward(self, x):
        expanded = self.inverted_conv(x)
        filtered = self.depth_conv(expanded)
        return self.point_conv(filtered)
class FusedMBConv(nn.Module):
    """Spatial convolution to the hidden width followed by a 1x1 projection.

    ``mid_channels`` defaults to ``round(in_channels * expand_ratio)``.
    ``use_bias``, ``norm_layer`` and ``act_layer`` may each be a single value
    or a pair (index 0 -> spatial stage, index 1 -> projection stage).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        stride=1,
        mid_channels=None,
        expand_ratio=6,
        groups=1,
        use_bias=False,
        norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
        act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden = mid_channels or round(in_channels * expand_ratio)

        # Spatial stage: carries the kernel size, stride and groups.
        self.spatial_conv = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        # Projection stage: 1x1 convolution to the output width.
        self.point_conv = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.point_conv(self.spatial_conv(x))
class LiteMLA(nn.Module):
    """Lightweight multi-scale linear attention.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the final projection.
        heads: Number of heads; when falsy it is derived as
            ``int(in_channels // dim * heads_ratio)``.
        heads_ratio: Multiplier used when deriving ``heads``.
        dim: Per-head channel dimension.
        use_bias: Single value or pair; index 0 is used for the qkv and
            aggregation convolutions, index 1 for the output projection.
        norm_layer: Single value or pair (qkv, projection).
        act_layer: Single value or pair (qkv, projection).
        kernel_func: Factory called as ``kernel_func(inplace=False)``; the
            result is applied to the queries and keys.
        scales: Kernel sizes of the depthwise aggregation branches.
        eps: Added to the attention normaliser to avoid division by zero.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        # Was ``int or None``, which evaluates to plain ``int``.
        heads: Optional[int] = None,
        heads_ratio: float = 1.0,
        dim=8,
        use_bias=False,
        norm_layer=(None, nn.BatchNorm2d),
        act_layer=(None, None),
        kernel_func=nn.ReLU,
        scales=(5,),
        eps=1e-5,
    ):
        super(LiteMLA, self).__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        use_bias = val2tuple(use_bias, 2)
        norm_layer = val2tuple(norm_layer, 2)
        act_layer = val2tuple(act_layer, 2)
        self.dim = dim

        # 1x1 convolution producing q, k and v stacked along the channel axis.
        self.qkv = ConvNormAct(
            in_channels,
            3 * total_dim,
            1,
            bias=use_bias[0],
            norm_layer=norm_layer[0],
            act_layer=act_layer[0],
        )
        # One branch per scale: depthwise conv with that kernel size, then a
        # grouped 1x1 conv (3 * heads groups).
        self.aggreg = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    3 * total_dim,
                    3 * total_dim,
                    scale,
                    padding=get_same_padding(scale),
                    groups=3 * total_dim,
                    bias=use_bias[0],
                ),
                nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
            )
            for scale in scales
        ])
        self.kernel_func = kernel_func(inplace=False)
        # The projection consumes the base branch plus one branch per scale.
        self.proj = ConvNormAct(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            bias=use_bias[1],
            norm_layer=norm_layer[1],
            act_layer=act_layer[1],
        )

    def _attn(self, q, k, v):
        """Linear attention; ``v`` must carry a trailing column of ones.

        The computation runs in float32 and the result is cast back to the
        dtype ``v`` had on entry. The last channel of ``q @ (k^T @ v)`` is the
        normaliser the remaining channels are divided by.
        """
        dtype = v.dtype
        q, k, v = q.float(), k.float(), v.float()
        kv = k.transpose(-1, -2) @ v
        out = q @ kv
        out = out[..., :-1] / (out[..., -1:] + self.eps)
        return out.to(dtype)

    def forward(self, x):
        # Shape is B, C, H, W
        B, _, H, W = x.shape

        # generate multi-scale q, k, v: the base qkv plus one aggregated copy
        # per scale, concatenated along the channel axis
        qkv = self.qkv(x)
        multi_scale_qkv = torch.cat([qkv, *(op(qkv) for op in self.aggreg)], dim=1)

        # (B, groups, 3 * dim, HW) -> (B, groups, HW, 3 * dim)
        multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
        # Each of q, k, v is (B, groups, HW, dim)
        q, k, v = multi_scale_qkv.chunk(3, dim=-1)

        # lightweight global attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        # Append a column of ones so _attn can read the normaliser from it.
        v = F.pad(v, (0, 1), mode="constant", value=1.)
        out = self._attn(q, k, v)

        # final projection: back to (B, channels, H, W)
        out = out.transpose(-1, -2).reshape(B, -1, H, W)
        return self.proj(out)
class EfficientVitBlock(nn.Module):
    """Residual ``LiteMLA`` attention followed by a residual ``MBConv``.

    Both sub-modules keep the channel count at ``in_channels`` and use an
    identity shortcut.
    """

    def __init__(
        self,
        in_channels,
        heads_ratio=1.0,
        head_dim=32,
        expand_ratio=4,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super().__init__()

        # Global context: attention, normalised only after its projection.
        attention = LiteMLA(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
            dim=head_dim,
            norm_layer=(None, norm_layer),
        )
        self.context_module = ResidualBlock(attention, nn.Identity())

        # Local mixing: MBConv with bias in place of norm on the first two
        # stages and a norm only on the projection.
        local_conv = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm_layer=(None, None, norm_layer),
            act_layer=(act_layer, act_layer, None),
        )
        self.local_module = ResidualBlock(local_conv, nn.Identity())

    def forward(self, x):
        return self.local_module(self.context_module(x))
class ResidualBlock(nn.Module):
    """Wrap ``main`` with an optional pre-norm and an optional shortcut.

    Computes ``main(pre_norm(x))`` and, when ``shortcut`` is given, adds
    ``shortcut(x)`` to it.

    Args:
        main: The wrapped module. ``None`` turns the block into a pass-through
            that returns its input unchanged.
        shortcut: Module applied to the raw input and added to the main
            branch; ``None`` disables the residual connection.
        pre_norm: Module applied to the input before ``main``; ``None`` is
            replaced by ``nn.Identity()``.
    """

    def __init__(
        self,
        main: Optional[nn.Module],
        shortcut: Optional[nn.Module] = None,
        pre_norm: Optional[nn.Module] = None,
    ):
        super(ResidualBlock, self).__init__()
        self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
        self.main = main
        self.shortcut = shortcut

    def forward(self, x):
        # ``main`` is declared Optional; calling None used to raise TypeError.
        # Without a main branch the block is a pass-through.
        if self.main is None:
            return x
        res = self.main(self.pre_norm(x))
        if self.shortcut is not None:
            res = res + self.shortcut(x)
        return res
def build_local_block(
    in_channels: int,
    out_channels: int,
    stride: int,
    kernel_size: int,
    expand_ratio: float,
    # Were annotated ``str``, but callers in this file pass layer factories
    # (e.g. ``partial(nn.BatchNorm2d, eps=...)`` and ``nn.Hardswish``).
    norm_layer: Callable[..., nn.Module],
    act_layer: Callable[..., nn.Module],
    fewer_norm: bool = False,
    block_type: str = "default",
):
    """Build one local (convolutional) block.

    Block selection:
        * ``expand_ratio == 1``: ``DSConv`` for ``"default"``, else ``ConvBlock``.
        * otherwise: ``MBConv`` for ``"default"``, else ``FusedMBConv``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        stride: Stride of the block.
        kernel_size: Kernel size of the block's spatial convolution(s).
        expand_ratio: Channel expansion ratio; ``1`` selects the non-expanding
            block family.
        norm_layer: Normalization layer factory.
        act_layer: Activation layer factory.
        fewer_norm: When true, only the last stage of the block is normalised
            and the earlier stages use a bias instead.
        block_type: One of ``"default"``, ``"large"`` or ``"fused"``.

    Returns:
        The constructed block module.

    Raises:
        AssertionError: If ``block_type`` is not one of the accepted values.
    """
    assert block_type in ["default", "large", "fused"]

    shared = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        stride=stride,
        kernel_size=kernel_size,
    )
    # Two-stage configuration shared by DSConv, ConvBlock and FusedMBConv.
    two_stage = dict(
        use_bias=(True, False) if fewer_norm else False,
        norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
        act_layer=(act_layer, None),
    )

    if expand_ratio == 1:
        block_cls = DSConv if block_type == "default" else ConvBlock
        return block_cls(**shared, **two_stage)

    if block_type == "default":
        return MBConv(
            **shared,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False) if fewer_norm else False,
            norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
            act_layer=(act_layer, act_layer, None),
        )

    return FusedMBConv(**shared, expand_ratio=expand_ratio, **two_stage)
class Stem(nn.Sequential):
    """Input stem: a strided ``ConvNormAct`` followed by ``depth`` residual blocks.

    The input convolution (``in_conv``) uses kernel size ``stride + 1``. The
    residual blocks are registered as ``res0``, ``res1``, ... and each wraps a
    stride-1, kernel-3, non-expanding local block with an identity shortcut.
    """

    def __init__(self, in_chs, out_chs, depth, stride, norm_layer, act_layer, block_type='default'):
        super().__init__()
        self.stride = stride

        input_conv = ConvNormAct(
            in_chs,
            out_chs,
            kernel_size=stride + 1,
            stride=stride,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        self.add_module('in_conv', input_conv)

        for block_idx in range(depth):
            residual = ResidualBlock(
                build_local_block(
                    in_channels=out_chs,
                    out_channels=out_chs,
                    stride=1,
                    kernel_size=3,
                    expand_ratio=1,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    block_type=block_type,
                ),
                nn.Identity(),
            )
            self.add_module(f'res{block_idx}', residual)
class EfficientVitLargeStage(nn.Module):
    """One backbone stage: a strided entry block followed by ``depth`` blocks.

    The entry block changes ``in_chs`` to ``out_chs`` with the given stride and
    has no shortcut. The remaining blocks are ``EfficientVitBlock`` instances
    when ``vit_stage`` is true, otherwise residual local blocks with an
    identity shortcut.
    """

    def __init__(
        self,
        in_chs,
        out_chs,
        depth,
        stride,
        norm_layer,
        act_layer,
        head_dim,
        vit_stage=False,
        fewer_norm=False,
    ):
        super().__init__()
        local_type = 'default' if fewer_norm else 'fused'

        # Entry block: strided, no residual connection.
        entry = ResidualBlock(
            build_local_block(
                in_channels=in_chs,
                out_channels=out_chs,
                stride=stride,
                kernel_size=stride + 1,
                expand_ratio=24 if vit_stage else 16,
                norm_layer=norm_layer,
                act_layer=act_layer,
                fewer_norm=vit_stage or fewer_norm,
                block_type=local_type,
            ),
            None,
        )
        layers = [entry]
        width = out_chs

        for _ in range(depth):
            if vit_stage:
                # for stage 4
                layer = EfficientVitBlock(
                    in_channels=width,
                    head_dim=head_dim,
                    expand_ratio=6,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
            else:
                # for stage 1, 2, 3
                layer = ResidualBlock(
                    build_local_block(
                        in_channels=width,
                        out_channels=out_chs,
                        stride=1,
                        kernel_size=3,
                        expand_ratio=4,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                        fewer_norm=fewer_norm,
                        block_type=local_type,
                    ),
                    nn.Identity(),
                )
            layers.append(layer)

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)
class EfficientVitLarge(nn.Module):
    """Backbone: a ``Stem`` followed by a sequence of ``EfficientVitLargeStage``.

    The first entry of ``config.widths`` / ``config.depths`` / ``config.strides``
    configures the stem; every remaining entry configures one stage. Stages
    with index >= 2 use ``fewer_norm`` and stages with index >= 3 are ViT
    stages. ``forward`` returns the output of every stage.
    """

    def __init__(
        self,
        config: EfficientViTConfig,
        norm_layer=nn.BatchNorm2d,
        act_layer=nn.Hardswish,
    ):
        super().__init__()
        self.grad_checkpointing = False
        self.num_classes = config.num_classes
        self.norm_eps = config.layer_norm_eps
        # Bind the configured epsilon into every normalization layer built below.
        norm_layer = partial(norm_layer, eps=self.norm_eps)

        # input stem
        self.stem = Stem(
            config.num_channels,
            config.widths[0],
            config.depths[0],
            config.strides[0],
            norm_layer,
            act_layer,
            block_type='large',
        )
        total_stride = config.strides[0]

        # stages
        self.feature_info = []
        self.stages = nn.Sequential()
        prev_width = config.widths[0]
        stage_specs = zip(config.widths[1:], config.depths[1:], config.strides[1:])
        for idx, (width, depth, stage_stride) in enumerate(stage_specs):
            stage = EfficientVitLargeStage(
                prev_width,
                width,
                depth=depth,
                stride=stage_stride,
                norm_layer=norm_layer,
                act_layer=act_layer,
                head_dim=config.head_dim,
                vit_stage=idx >= 3,
                fewer_norm=idx >= 2,
            )
            self.stages.append(stage)
            total_stride *= stage_stride
            prev_width = width
            # Record channels and cumulative stride of this stage's output.
            self.feature_info.append(
                dict(num_chs=prev_width, reduction=total_stride, module=f'stages.{idx}')
            )

        self.num_features = prev_width

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Only stores the flag; nothing in this class reads it.
        self.grad_checkpointing = enable

    def forward(self, x):
        x = self.stem(x)
        stage_outputs = []
        for stage in self.stages:
            x = stage(x)
            stage_outputs.append(x)
        return stage_outputs
class EfficientViTPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that handles weight initialization and provides the
    shared interface for downloading and loading pretrained models.
    """

    config_class = EfficientViTConfig
    base_model_prefix = "efficientvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module in place."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            # Keep the padding embedding at zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
class DecodeMLP(nn.Module):
    """Per-position linear projection of a channels-first feature map."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_states: torch.Tensor):
        # (B, C, H, W) -> (B, C, HW) -> (B, HW, C)
        tokens = hidden_states.flatten(2).transpose(1, 2)
        # Project the channel axis: (B, HW, C) -> (B, HW, output_dim)
        return self.proj(tokens)
class DecodeHead(EfficientViTPreTrainedModel):
    """Decode head that fuses the encoder stage outputs into per-label logits.

    Each stage output is projected to ``config.decoder_layer_hidden_size``
    channels, resized to the spatial size of the first stage output,
    concatenated, fused by a 1x1 convolution + batch norm + ReLU, passed
    through dropout and finally classified by a 1x1 convolution.
    """

    def __init__(self, config: EfficientViTConfig):
        super().__init__(config)
        # linear layers which unify the channel dimension of each encoder stage
        # output to config.decoder_layer_hidden_size
        self.linear_c = nn.ModuleList([
            DecodeMLP(input_dim=width, output_dim=config.decoder_layer_hidden_size)
            for width in config.widths[1:]
        ])

        # the following 3 layers implement the ConvModule of the original implementation
        # NOTE(review): the fuse layer expects decoder_layer_hidden_size *
        # config.num_stages input channels, while one MLP is built per entry of
        # config.widths[1:]; presumably these agree -- confirm against the config.
        self.linear_fuse = nn.Conv2d(
            in_channels=config.decoder_layer_hidden_size * config.num_stages,
            out_channels=config.decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
        self.activation = nn.ReLU()

        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

        self.config = config

    def forward(self, encoder_hidden_states: Union[list, Tuple[torch.Tensor, ...]]) -> torch.Tensor:
        """Compute logits from the per-stage encoder outputs.

        Args:
            encoder_hidden_states: Sequence of ``(B, C_i, H_i, W_i)`` tensors,
                one per encoder stage (the backbone in this file returns a list).

        Returns:
            Logits with ``config.num_labels`` channels at the spatial size of
            ``encoder_hidden_states[0]``.
        """
        batch_size = encoder_hidden_states[-1].shape[0]
        target_size = encoder_hidden_states[0].size()[2:]

        projected = []
        for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)  # Output is B, HW, C
            # Permute to B, C, HW
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
            # upsample to the resolution of the first stage output
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state, size=target_size, mode="bilinear", align_corners=False
            )
            projected.append(encoder_hidden_state)

        # Concatenate from the last stage to the first along the channel axis.
        hidden_states = self.linear_fuse(torch.cat(projected[::-1], dim=1))
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        # The dropout layer used to be built from config.classifier_dropout_prob
        # but never applied. It is a no-op in eval mode, so inference results
        # are unchanged.
        hidden_states = self.dropout(hidden_states)

        # logits have the spatial size of the first encoder stage output
        logits = self.classifier(hidden_states)
        return logits
class EfficientViTForSemanticSegmentation(EfficientViTPreTrainedModel):
    """EfficientViT backbone plus decode head, producing sigmoid-activated maps."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.vit = EfficientVitLarge(config)
        self.decode_head = DecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        """Run the backbone and decode head on ``pixel_values`` (B, C, H, W).

        Returns:
            A ``SemanticSegmenterOutput`` with ``loss=None``, ``logits`` holding
            the decode-head output after a sigmoid (values in 0-1), and
            ``hidden_states`` holding the per-stage backbone outputs.
        """
        # Pixel values should be B,C,H,W
        stage_features = self.vit(pixel_values)
        raw_logits = self.decode_head(stage_features)

        # Apply sigmoid to get 0-1 output
        probabilities = torch.special.expit(raw_logits)

        return SemanticSegmenterOutput(
            loss=None,
            logits=probabilities,
            hidden_states=stage_features,
        )