import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.layers import trunc_normal_
from segmenter_model.blocks import Block, FeedForward
from segmenter_model.utils import init_weights
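
# Simple linear decoder head: a per-token linear classifier over the encoder
# features, reshaped back into a 2D class-logit map at patch resolution.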
class DecoderLinear(nn.Module):
    def __init__(self, n_cls, patch_size, d_encoder):
        super().__init__()

        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_cls = n_cls

        self.head = nn.Linear(self.d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return set()

    def forward(self, x, im_size):
        H, W = im_size
        GS = H // self.patch_size

        x = self.head(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=GS)

        return x
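
# Mask transformer decoder (Segmenter-style): learnable class embeddings are
# appended to the patch tokens, refined jointly by transformer blocks, and the
# masks are obtained as dot products between L2-normalized patch embeddings
# and class embeddings.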
class MaskTransformer(nn.Module):
    def __init__(
        self,
        n_cls,
        patch_size,
        d_encoder,
        n_layers,
        n_heads,
        d_model,
        d_ff,
        drop_path_rate,
        dropout,
    ):
        super().__init__()
        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.n_cls = n_cls
        self.d_model = d_model
        self.d_ff = d_ff
        self.scale = d_model ** -0.5

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)]
        )

        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
        self.proj_dec = nn.Linear(d_encoder, d_model)

        self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model))
        self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model))

        self.decoder_norm = nn.LayerNorm(d_model)
        self.mask_norm = nn.LayerNorm(n_cls)

        self.apply(init_weights)
        trunc_normal_(self.cls_emb, std=0.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_emb"}

    def forward(self, x, im_size, features_only=False, no_rearrange=False):
        H, W = im_size
        GS = H // self.patch_size

        # project from the encoder dimensionality to the decoder dimensionality (usually the same)
        x = self.proj_dec(x)
        # broadcast the learnable class embeddings to the batch size
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        # concatenate the class embeddings to the patch tokens
        x = torch.cat((x, cls_emb), 1)
        # forward the concatenated tokens through the decoder blocks
        for blk in self.blocks:
            x = blk(x)
        # perform normalization
        x = self.decoder_norm(x)

        # split into patch features and class-segmentation features
        patches, cls_seg_feat = x[:, : -self.n_cls], x[:, -self.n_cls:]
        # project the patch features
        patches = patches @ self.proj_patch

        if features_only:
            if not no_rearrange:
                features = rearrange(patches, "b (h w) n -> b n h w", h=int(GS))
            else:
                features = patches
            return features

        # project the class-segmentation features
        cls_seg_feat = cls_seg_feat @ self.proj_classes

        # scalar product between L2-normalized patch embeddings and class embeddings -> masks
        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)
        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)

        if not no_rearrange:
            masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))

        return masks
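
    # Utility for visualization: runs the decoder up to block `layer_id` and
    # returns that block's attention maps instead of its output.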
    def get_attention_map(self, x, layer_id):
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(
                f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}."
            )
        x = self.proj_dec(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_emb), 1)
        for i, blk in enumerate(self.blocks):
            if i < layer_id:
                x = blk(x)
            else:
                return blk(x, return_attention=True)
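
# DeepLabV3-style segmentation head: ASPP followed by a 3x3 convolution and a
# 1x1 classifier. Accepts either (B, C, H, W) feature maps or (B, N, C) ViT
# tokens, which are first reshaped to a 2D grid of side H // patch_size.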
class DeepLabHead(nn.Sequential):
    def __init__(self, in_channels, num_classes, patch_size=None):
        super(DeepLabHead, self).__init__(
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1)
        )
        self.patch_size = patch_size

    def forward(self, x, im_size=None):
        if len(x.shape) == 3:
            # features from ViT
            assert im_size is not None and self.patch_size is not None
            H, W = im_size
            GS = H // self.patch_size
            x = rearrange(x, "b (h w) n -> b n h w", h=int(GS)).contiguous()
        for module in self:
            x = module(x)
        return x
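
# 3x3 atrous (dilated) convolution branch of ASPP.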
class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)
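
# Image-level pooling branch of ASPP: global average pooling and a 1x1
# convolution, bilinearly upsampled back to the input spatial size.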
class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
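
# Atrous Spatial Pyramid Pooling: a 1x1 convolution branch, one atrous branch
# per rate, and an image-pooling branch (5 branches in total for 3 rates),
# concatenated and projected back to out_channels.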
class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)
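
# The block below is a minimal shape-check sketch and not part of the original
# module: the sizes are illustrative assumptions (a ViT-Base-like encoder with
# d_encoder=768 and patch_size=16 on a 512x512 input, i.e. a 32x32 token grid,
# and n_cls=19 classes). It simply runs the three decoder heads and prints the
# shapes of their patch-resolution logits.
if __name__ == "__main__":
    B, n_cls, d_encoder, patch_size = 2, 19, 768, 16
    im_size = (512, 512)
    n_tokens = (im_size[0] // patch_size) * (im_size[1] // patch_size)
    feats = torch.randn(B, n_tokens, d_encoder)  # encoder patch tokens, (B, N, C)

    with torch.no_grad():
        linear_dec = DecoderLinear(n_cls=n_cls, patch_size=patch_size, d_encoder=d_encoder)
        print(linear_dec(feats, im_size).shape)  # -> (B, n_cls, 32, 32)

        mask_dec = MaskTransformer(
            n_cls=n_cls, patch_size=patch_size, d_encoder=d_encoder,
            n_layers=2, n_heads=12, d_model=768, d_ff=3072,
            drop_path_rate=0.0, dropout=0.0,
        )
        print(mask_dec(feats, im_size).shape)  # -> (B, n_cls, 32, 32)

        aspp_head = DeepLabHead(in_channels=d_encoder, num_classes=n_cls, patch_size=patch_size)
        print(aspp_head(feats, im_size).shape)  # -> (B, n_cls, 32, 32)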