import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from timm.models.layers import trunc_normal_

from segmenter_model.blocks import Block, FeedForward
from segmenter_model.utils import init_weights


class DecoderLinear(nn.Module):
    def __init__(self, n_cls, patch_size, d_encoder):
        super().__init__()

        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_cls = n_cls

        self.head = nn.Linear(self.d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return set()

    def forward(self, x, im_size):
        H, W = im_size
        GS = H // self.patch_size
        x = self.head(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=GS)

        return x
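
# A minimal shape check for DecoderLinear (illustrative sketch; n_cls=21,
# patch_size=16, d_encoder=192 and a 224x224 input are assumed values, not
# configuration taken from this repository):
#
#   dec = DecoderLinear(n_cls=21, patch_size=16, d_encoder=192)
#   tokens = torch.randn(2, (224 // 16) ** 2, 192)   # (B, num_patches, d_encoder)
#   logits = dec(tokens, (224, 224))                 # -> (2, 21, 14, 14)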


class MaskTransformer(nn.Module):
    def __init__(
        self,
        n_cls,
        patch_size,
        d_encoder,
        n_layers,
        n_heads,
        d_model,
        d_ff,
        drop_path_rate,
        dropout,
    ):
        super().__init__()
        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.n_cls = n_cls
        self.d_model = d_model
        self.d_ff = d_ff
        self.scale = d_model ** -0.5

        # Stochastic depth: drop-path rates increase linearly across the decoder blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
        )

        # Learnable class embeddings, one token per output class.
        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
        self.proj_dec = nn.Linear(d_encoder, d_model)

        self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
        self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))

        self.decoder_norm = nn.LayerNorm(d_model)
        self.mask_norm = nn.LayerNorm(n_cls)

        self.apply(init_weights)
        trunc_normal_(self.cls_emb, std=0.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_emb"}

    def forward(self, x, im_size, features_only=False, no_rearrange=False):
        H, W = im_size
        GS = H // self.patch_size

        x = self.proj_dec(x)

        # Append the class tokens to the patch tokens and run the decoder blocks.
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        for blk in self.blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # Split the sequence back into patch tokens and class tokens.
        patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls:]

        patches = patches @ self.proj_patch

        if features_only:
            if not no_rearrange:
                features = rearrange(patches, "b (h w) n -> b n h w", h=int(GS))
            else:
                features = patches
            return features

        cls_seg_feat = cls_seg_feat @ self.proj_classes

        # Masks are the scalar products between L2-normalized patch and class embeddings.
        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
        masks = patches @ cls_seg_feat.transpose(1, 2)

        masks = self.mask_norm(masks)
        if not no_rearrange:
            masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))

        return masks

    def get_attention_map(self, x, layer_id):
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(
                f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
            )
        x = self.proj_dec(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        for i, blk in enumerate(self.blocks):
            if i < layer_id:
                x = blk(x)
            else:
                return blk(x, return_attention=True)
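
# A minimal shape check for MaskTransformer (illustrative sketch; the values
# below are assumptions, not configuration taken from this repository):
#
#   dec = MaskTransformer(n_cls=21, patch_size=16, d_encoder=192, n_layers=2,
#                         n_heads=3, d_model=192, d_ff=768,
#                         drop_path_rate=0.0, dropout=0.1)
#   tokens = torch.randn(2, (224 // 16) ** 2, 192)   # encoder patch tokens
#   masks = dec(tokens, (224, 224))                  # -> (2, 21, 14, 14)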


class DeepLabHead(nn.Sequential):
    def __init__(self, in_channels, num_classes, patch_size=None):
        super(DeepLabHead, self).__init__(
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1)
        )
        self.patch_size = patch_size

    def forward(self, x, im_size=None):
        if len(x.shape) == 3:
            # Token input (B, N, C): fold the patch tokens back into a 2D feature map.
            assert im_size is not None and self.patch_size is not None
            H, W = im_size
            GS = H // self.patch_size
            x = rearrange(x, "b (h w) n -> b n h w", h=int(GS)).contiguous()
        for module in self:
            x = module(x)
        return x
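
# A minimal shape check for DeepLabHead (illustrative sketch; the values below
# are assumptions, not configuration taken from this repository):
#
#   head = DeepLabHead(in_channels=192, num_classes=21, patch_size=16)
#   tokens = torch.randn(2, (224 // 16) ** 2, 192)   # ViT patch tokens
#   logits = head(tokens, (224, 224))                # -> (2, 21, 14, 14)
#   fmap = torch.randn(2, 192, 14, 14)               # or a 2D feature map
#   logits = head(fmap)                              # -> (2, 21, 14, 14)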


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            # padding=dilation keeps the spatial resolution unchanged for a 3x3 kernel.
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
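
# ASPPPooling is the image-level context branch of DeepLabv3: features are
# pooled to 1x1, projected, and bilinearly upsampled back to the input size.
# Note that the BatchNorm2d over a 1x1 map needs a batch size greater than 1
# in training mode.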


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        # Five branches are concatenated (one 1x1 conv, three dilated convs, one
        # pooling branch), so the projection assumes exactly three atrous rates.
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
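
# A minimal shape check for ASPP alone (illustrative sketch; the sizes are
# assumptions, not configuration taken from this repository):
#
#   aspp = ASPP(in_channels=192, atrous_rates=[12, 24, 36])
#   out = aspp(torch.randn(2, 192, 14, 14))   # -> (2, 256, 14, 14)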