# File: RSPrompter/mmpl/models/heads/sam_semseg_head.py
# Repository page header: KyanChen -- "Upload 159 files" (commit 1c3eb47)
import warnings
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from typing import Tuple, List
from torch import Tensor
from mmpl.registry import MODELS
from mmseg.models import build_loss
from mmseg.models.utils import resize
from mmseg.structures import build_pixel_sampler
from mmseg.utils import SampleList, ConfigType
@MODELS.register_module()
class SamSemSegHead(BaseModule):
    """Segmentation head that fuses two feature maps into per-class logits.

    The head receives a pair of 4-D feature maps ``(x0, x1)``. ``x0`` is
    turned into a sigmoid gate that weights a global average pooling of
    ``x1``; the pooled descriptors go through a small MLP that yields one
    weight per (``x0`` channel, output class). ``x0`` itself is upsampled
    by a factor of 8 and linearly combined with those weights to form the
    logits.

    Args:
        in_channels (int): Channel count expected for each of the two
            input feature maps. Defaults to 2.
        inner_channels (int | None): If given, each input is projected to
            this many channels by its own 1x1 convolution. If ``None`` the
            inputs are used unchanged. Defaults to None.
        num_classes (int): Number of classes. The head emits
            ``num_classes + 1`` logit channels. Defaults to 1.
        ignore_index (int): Forwarded to every loss function as
            ``ignore_index``. Defaults to 255.
        threshold (float | None): Stored as ``self.threshold``; it is not
            used elsewhere in this class. When the resolved
            ``out_channels`` is 1 and no threshold is given, 0.3 is used
            and a warning is issued. Defaults to None.
        out_channels (int | None): Only validated against ``num_classes``
            (see the note in ``__init__``). Defaults to None.
        loss_decode (dict | Sequence[dict]): Config (or configs) of the
            loss(es), built through ``MODELS.build``. Required: the
            default ``None`` raises ``TypeError``.
        sampler (dict | None): Config passed to ``build_pixel_sampler``.
            Defaults to None (no pixel sampling).
        align_corners (bool): ``align_corners`` argument of the final
            ``resize`` calls. Defaults to False.
        norm_cfg (dict): Norm config for the upsampling ``ConvModule``s.
        act_cfg (dict): Activation config for the upsampling
            ``ConvModule``s.
        train_cfg (dict | None): Stored as ``self.train_cfg``.
        test_cfg (dict | None): Stored as ``self.test_cfg``.

    Raises:
        ValueError: If ``out_channels`` is neither ``num_classes`` nor 1.
        TypeError: If ``loss_decode`` is not a dict, list or tuple.
    """

    def __init__(self,
                 in_channels=2,
                 inner_channels=None,
                 num_classes=1,
                 ignore_index=255,
                 threshold=None,
                 out_channels=None,
                 loss_decode=None,
                 sampler=None,
                 align_corners=False,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU', inplace=True),
                 train_cfg=None,
                 test_cfg=None,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if out_channels is None:
            if num_classes == 2:
                warnings.warn(
                    'For binary segmentation, we suggest using '
                    '`out_channels = 1` to define the output '
                    'channels of segmentor, and use `threshold` '
                    'to convert `seg_logits` into a prediction '
                    'applying a threshold')
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, '
                'except binary segmentation set out_channels == 1 and '
                f'num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')

        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary, and '
                          'defaults to 0.3')

        self.num_classes = num_classes
        # NOTE(review): the validated `out_channels` argument is not stored;
        # the head always predicts `num_classes + 1` channels (presumably
        # one extra background channel -- confirm with the training
        # configs). Kept as-is because changing it would alter the output
        # shape and break existing checkpoints.
        self.out_channels = num_classes + 1
        self.threshold = threshold

        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss_cfg in loss_decode:
                self.loss_decode.append(MODELS.build(loss_cfg))
        else:
            raise TypeError(
                'loss_decode must be a dict or sequence of dict, '
                f'but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        # Optional per-input 1x1 projection; index 0 handles x0, index 1
        # handles x1 (see `forward`).
        feat_channels = in_channels
        if inner_channels is None:
            self.down_conv = nn.ModuleList([nn.Identity(), nn.Identity()])
        else:
            self.down_conv = nn.ModuleList([
                nn.Conv2d(in_channels, inner_channels, 1),
                nn.Conv2d(in_channels, inner_channels, 1)
            ])
            feat_channels = inner_channels

        # MLP applied on the channel axis of the pooled x1 descriptors.
        self.cls_seg = nn.Sequential(
            nn.Linear(feat_channels, feat_channels // 2),
            nn.ReLU(inplace=True),
            nn.Linear(feat_channels // 2, self.out_channels)
        )

        # Three (x2 bilinear upsample -> 3x3 conv) stages, i.e. x8 overall.
        # Modules are appended in the order upsample, conv, upsample, ...
        # so the Sequential indices (and state_dict keys) are 0..5.
        upsample_stages = []
        for _ in range(3):
            upsample_stages.append(nn.UpsamplingBilinear2d(scale_factor=2))
            upsample_stages.append(
                ConvModule(
                    in_channels=feat_channels,
                    out_channels=feat_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                ))
        self.up_conv = nn.Sequential(*upsample_stages)

    def forward(self, inputs):
        """Compute segmentation logits from a pair of feature maps.

        Args:
            inputs (Sequence[Tensor]): Exactly two 4-D tensors ``(x0, x1)``
                with the same batch size and spatial size. After the
                optional 1x1 projection both must have the channel count
                the head was built with (``inner_channels`` if set, else
                ``in_channels``).

        Returns:
            Tensor: Logits of shape ``(B, K, 8H, 8W)`` where
            ``K = self.out_channels``.
        """
        x0, x1 = inputs
        x0 = self.down_conv[0](x0)
        x1 = self.down_conv[1](x1)

        gate = torch.sigmoid(x0)  # B N H W
        # Gate-weighted global average pooling of x1 for every gate channel.
        # Contracting h, w inside the einsum avoids materialising the
        # B x N x C x H x W outer product that a separate `mean` would need;
        # the result equals mean_{h,w}(gate[b,n] * x1[b,c]).
        num_pixels = x1.shape[-2] * x1.shape[-1]
        pooled = torch.einsum('bnhw,bchw->bnc', gate, x1) / num_pixels
        class_weights = self.cls_seg(pooled)  # B N K

        upsampled = self.up_conv(x0)  # B N 8H 8W
        # Mix the N upsampled channels into K class maps.
        seg_logits = torch.einsum('bnhw,bnk->bkhw', upsampled, class_weights)
        return seg_logits

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType = None) -> dict:
        """Run the head and compute the training losses.

        Args:
            inputs (Tuple[Tensor]): The two feature maps consumed by
                :meth:`forward`.
            batch_data_samples (SampleList): Data samples; each must
                provide ``gt_sem_seg.data``.
            train_cfg (dict, optional): Accepted for interface
                compatibility; not used by this method.

        Returns:
            dict: Loss values keyed by each loss module's ``loss_name``.
        """
        seg_logits = self.forward(inputs)
        return self.loss_by_feat(seg_logits, batch_data_samples)

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType = None) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): The two feature maps consumed by
                :meth:`forward`.
            batch_img_metas (List[dict]): Image meta information. Only
                ``batch_img_metas[0]['img_shape']`` is read.
            test_cfg (dict, optional): Accepted for interface
                compatibility; not used by this method.

        Returns:
            Tensor: Segmentation logits resized to ``img_shape``.
        """
        seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas)

    def predict_by_feat(self, seg_logits: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Resize a batch of logits to the input image shape.

        Args:
            seg_logits (Tensor): Output of :meth:`forward`.
            batch_img_metas (List[dict]): Image meta information. The
                whole batch is resized to the ``'img_shape'`` of the
                first entry.

        Returns:
            Tensor: Bilinearly resized segmentation logits.
        """
        seg_logits = resize(
            input=seg_logits,
            size=batch_img_metas[0]['img_shape'],
            mode='bilinear',
            align_corners=self.align_corners)
        return seg_logits

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Stack every sample's ``gt_sem_seg.data`` along a new batch axis."""
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute the segmentation losses for a batch of logits.

        Args:
            seg_logits (Tensor): Output of :meth:`forward`.
            batch_data_samples (SampleList): Data samples providing the
                ground-truth maps via ``gt_sem_seg.data``.

        Returns:
            dict: Loss values keyed by ``loss_name``. Loss modules that
            share a name have their values summed.
        """
        # The stacked label keeps a singleton axis at dim 1 until after the
        # (optional) pixel sampler has seen it.
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        seg_logits = resize(
            input=seg_logits,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = self.loss_decode
        else:
            losses_decode = [self.loss_decode]

        for loss_fn in losses_decode:
            loss_value = loss_fn(
                seg_logits,
                seg_label,
                weight=seg_weight,
                ignore_index=self.ignore_index)
            if loss_fn.loss_name in loss:
                loss[loss_fn.loss_name] += loss_value
            else:
                loss[loss_fn.loss_name] = loss_value
        return loss