|
"""Light HamHead Decoder. |
|
|
|
Adapted from: |
|
https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/decode_heads/ham_head.py |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from siclib.models import BaseModel |
|
from siclib.models.utils.modules import ConvModule, FeatureFusionBlock |
|
|
|
|
|
|
|
|
|
|
|
class _MatrixDecomposition2DBase(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
self.spatial = True |
|
|
|
self.S = 1 |
|
self.D = 512 |
|
self.R = 64 |
|
|
|
self.train_steps = 6 |
|
self.eval_steps = 7 |
|
|
|
self.inv_t = 100 |
|
self.eta = 0.9 |
|
|
|
self.rand_init = True |
|
|
|
def _build_bases(self, B, S, D, R, device="cpu"): |
|
raise NotImplementedError |
|
|
|
def local_step(self, x, bases, coef): |
|
raise NotImplementedError |
|
|
|
|
|
def local_inference(self, x, bases): |
|
|
|
coef = torch.bmm(x.transpose(1, 2), bases) |
|
coef = F.softmax(self.inv_t * coef, dim=-1) |
|
|
|
steps = self.train_steps if self.training else self.eval_steps |
|
for _ in range(steps): |
|
bases, coef = self.local_step(x, bases, coef) |
|
|
|
return bases, coef |
|
|
|
def compute_coef(self, x, bases, coef): |
|
raise NotImplementedError |
|
|
|
def forward(self, x, return_bases=False): |
|
B, C, H, W = x.shape |
|
|
|
|
|
if self.spatial: |
|
D = C // self.S |
|
N = H * W |
|
x = x.view(B * self.S, D, N) |
|
else: |
|
D = H * W |
|
N = C // self.S |
|
x = x.view(B * self.S, N, D).transpose(1, 2) |
|
|
|
if not self.rand_init and not hasattr(self, "bases"): |
|
bases = self._build_bases(1, self.S, D, self.R, device=x.device) |
|
self.register_buffer("bases", bases) |
|
|
|
|
|
if self.rand_init: |
|
bases = self._build_bases(B, self.S, D, self.R, device=x.device) |
|
else: |
|
bases = self.bases.repeat(B, 1, 1) |
|
|
|
bases, coef = self.local_inference(x, bases) |
|
|
|
|
|
coef = self.compute_coef(x, bases, coef) |
|
|
|
|
|
x = torch.bmm(bases, coef.transpose(1, 2)) |
|
|
|
|
|
x = x.view(B, C, H, W) if self.spatial else x.transpose(1, 2).view(B, C, H, W) |
|
|
|
bases = bases.view(B, self.S, D, self.R) |
|
|
|
return x |
|
|
|
|
|
class NMF2D(_MatrixDecomposition2DBase):
    """Non-negative matrix factorization via multiplicative updates."""

    def __init__(self):
        super().__init__()
        # NMF coefficients are already non-negative, so no temperature
        # sharpening is needed for the initial softmax.
        self.inv_t = 1

    def _build_bases(self, B, S, D, R, device="cpu"):
        """Draw random non-negative bases, L2-normalized along D."""
        raw = torch.rand((B * S, D, R)).to(device)
        return F.normalize(raw, dim=1)

    def local_step(self, x, bases, coef):
        """One multiplicative-update step for both coefficients and bases."""
        # coef <- coef * (X^T B) / (coef (B^T B))
        num = torch.bmm(x.transpose(1, 2), bases)
        den = torch.bmm(coef, torch.bmm(bases.transpose(1, 2), bases))
        coef = coef * num / (den + 1e-6)

        # B <- B * (X coef) / (B (coef^T coef))
        num = torch.bmm(x, coef)
        den = torch.bmm(bases, torch.bmm(coef.transpose(1, 2), coef))
        bases = bases * num / (den + 1e-6)

        return bases, coef

    def compute_coef(self, x, bases, coef):
        """Final coefficient refresh using the converged bases."""
        num = torch.bmm(x.transpose(1, 2), bases)
        den = torch.bmm(coef, torch.bmm(bases.transpose(1, 2), bases))
        return coef * num / (den + 1e-6)
|
|
|
|
|
class Hamburger(nn.Module):
    """Hamburger block: 1x1 conv in, matrix decomposition, 1x1 conv out.

    The decomposition (NMF2D) acts as the "ham" between two 1x1-conv
    "bread" layers; the result is added back to the input as a residual.
    """

    def __init__(self, ham_channels=512, norm_cfg=None, **kwargs):
        super().__init__()
        self.ham_in = ConvModule(ham_channels, ham_channels, 1)  # lower bread
        self.ham = NMF2D()  # the ham
        self.ham_out = ConvModule(ham_channels, ham_channels, 1)  # upper bread

    def forward(self, x):
        out = self.ham_in(x)
        out = F.relu(out, inplace=False)
        out = self.ham(out)
        out = self.ham_out(out)
        # Residual connection around the whole burger.
        return F.relu(x + out, inplace=False)
|
|
|
|
|
class LightHamHead(BaseModel):
    """Is Attention Better Than Matrix Decomposition?

    This head is the implementation of `HamNet
    <https://arxiv.org/abs/2109.04553>`_.

    Args:
        ham_channels (int): input channels for Hamburger.
        ham_kwargs (int): kwargs for Ham.
    """

    default_conf = {
        "predict_uncertainty": True,
        "out_channels": 64,
        "in_channels": [64, 128, 320, 512],
        "in_index": [0, 1, 2, 3],
        "ham_channels": 512,
        "with_low_level": True,
    }

    def _init(self, conf):
        self.in_index = conf.in_index
        self.in_channels = conf.in_channels
        self.out_channels = conf.out_channels
        self.ham_channels = conf.ham_channels
        self.align_corners = False
        self.predict_uncertainty = conf.predict_uncertainty

        # Squeeze the concatenated multi-scale features down to ham_channels.
        self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1)
        self.hamburger = Hamburger(self.ham_channels)
        # Project from ham_channels to the head's output width.
        self.align = ConvModule(self.ham_channels, self.out_channels, 1)

        if self.predict_uncertainty:
            # Small conv head producing a single-channel uncertainty map.
            self.linear_pred_uncertainty = nn.Sequential(
                ConvModule(
                    in_channels=self.out_channels,
                    out_channels=self.out_channels,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1),
            )

        self.with_ll = conf.with_low_level
        if self.with_ll:
            # Extra refinement + fusion path for low-level features.
            self.out_conv = ConvModule(
                self.out_channels, self.out_channels, 3, padding=1, bias=False
            )
            self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False)

    def _forward(self, features):
        """Fuse multi-scale features; return (features, uncertainty)."""
        selected = [features["hl"][i] for i in self.in_index]

        # Resize every selected level to the resolution of the first one.
        target_size = selected[0].shape[2:]
        resized = []
        for level in selected:
            resized.append(
                F.interpolate(
                    level,
                    size=target_size,
                    mode="bilinear",
                    align_corners=self.align_corners,
                )
            )

        x = self.squeeze(torch.cat(resized, dim=1))
        x = self.hamburger(x)
        feats = self.align(x)

        if self.with_ll:
            assert "ll" in features, "Low-level features are required for this model"
            # Two 2x bilinear upsamplings with a conv in between, then fuse
            # with the (cloned) low-level features.
            feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
            feats = self.out_conv(feats)
            feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
            feats = self.ll_fusion(feats, features["ll"].clone())

        if self.predict_uncertainty:
            uncertainty = self.linear_pred_uncertainty(feats).squeeze(1)
        else:
            uncertainty = None

        return feats, uncertainty

    def loss(self, pred, data):
        raise NotImplementedError
|
|