Spaces:
Runtime error
Runtime error
import warnings | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule | |
from mmengine.model import BaseModule | |
from typing import Tuple, List | |
from torch import Tensor | |
from mmpl.registry import MODELS | |
from mmseg.models import build_loss | |
from mmseg.models.utils import resize | |
from mmseg.structures import build_pixel_sampler | |
from mmseg.utils import SampleList, ConfigType | |
class SamSemSegHead(BaseModule):
    """Semantic segmentation head that fuses two feature maps.

    ``forward`` takes a pair ``(x0, x1)``. ``sigmoid(x0)`` is used as a set of
    spatial gates; each gate produces a gate-weighted spatial average of
    ``x1``, which a small MLP (``cls_seg``) turns into per-gate class scores.
    ``x0`` is separately upsampled 8x by ``up_conv`` and finally combined with
    those class scores to form the segmentation logits.

    Because ``cls_seg`` is applied over the channel axis of ``x1`` and
    ``up_conv`` convolves ``x0``, both inputs must have ``in_channels``
    channels (or ``inner_channels`` after the optional 1x1 projection).

    Args:
        in_channels (int): Channel count of both input feature maps.
        inner_channels (int | None): If given, each input is first projected
            to this many channels by its own 1x1 convolution; otherwise the
            inputs pass through unchanged.
        num_classes (int): Number of classes.
        ignore_index (int): Label value forwarded to the loss functions as
            ``ignore_index``.
        threshold (float | None): Stored on the head as ``self.threshold``.
            Defaults to 0.3 (with a warning) when ``out_channels == 1`` and
            no value is given.
        out_channels (int | None): Only validated against ``num_classes``;
            see the review note in ``__init__``.
        loss_decode (dict | Sequence[dict]): Config(s) of the loss(es), built
            through ``MODELS.build``. Required.
        sampler (dict | None): Optional pixel sampler config.
        align_corners (bool): Forwarded to ``resize`` when logits are
            rescaled.
        norm_cfg (dict): Norm config for the ``ConvModule`` layers of
            ``up_conv``.
        act_cfg (dict): Activation config for the ``ConvModule`` layers of
            ``up_conv``.
        train_cfg: Stored as ``self.train_cfg``; not used by this class.
        test_cfg: Stored as ``self.test_cfg``; not used by this class.

    Raises:
        ValueError: If ``out_channels`` is neither ``num_classes`` nor 1.
        TypeError: If ``loss_decode`` is not a dict or a list/tuple of dicts.
    """

    def __init__(self,
                 in_channels=2,
                 inner_channels=None,
                 num_classes=1,
                 ignore_index=255,
                 threshold=None,
                 out_channels=None,
                 loss_decode=None,
                 sampler=None,
                 align_corners=False,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU', inplace=True),
                 train_cfg=None,
                 test_cfg=None,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        if out_channels is None:
            if num_classes == 2:
                warnings.warn('For binary segmentation, we suggest using '
                              '`out_channels = 1` to define the output '
                              'channels of segmentor, and use `threshold` '
                              'to convert `seg_logits` into a prediction '
                              'applying a threshold')
            out_channels = num_classes

        if out_channels != num_classes and out_channels != 1:
            raise ValueError(
                'out_channels should be equal to num_classes, '
                'except binary segmentation set out_channels == 1 and '
                f'num_classes == 2, but got out_channels={out_channels} '
                f'and num_classes={num_classes}')

        if out_channels == 1 and threshold is None:
            threshold = 0.3
            warnings.warn('threshold is not defined for binary, and defaults '
                          'to 0.3')

        self.num_classes = num_classes
        # NOTE(review): the validated ``out_channels`` argument is not stored;
        # the head always predicts ``num_classes + 1`` channels (presumably one
        # extra background channel -- confirm). Kept as-is because changing it
        # would alter the shape of ``cls_seg`` and break existing checkpoints.
        self.out_channels = num_classes + 1
        self.threshold = threshold

        if isinstance(loss_decode, dict):
            self.loss_decode = MODELS.build(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(MODELS.build(loss))
        else:
            raise TypeError(
                'loss_decode must be a dict or sequence of dict, '
                f'but got {type(loss_decode)}')

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        # Optional per-input 1x1 projection; one module for each of the two
        # inputs consumed by ``forward``.
        if inner_channels is None:
            self.down_conv = nn.ModuleList([nn.Identity(), nn.Identity()])
        else:
            self.down_conv = nn.ModuleList([
                nn.Conv2d(in_channels, inner_channels, 1),
                nn.Conv2d(in_channels, inner_channels, 1)
            ])
            # Everything below operates on the projected width. Note that
            # ``self.in_channels`` keeps the original, unprojected value.
            in_channels = inner_channels

        # MLP mapping each pooled descriptor to ``self.out_channels`` scores.
        self.cls_seg = nn.Sequential(
            nn.Linear(in_channels, in_channels // 2),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // 2, self.out_channels)
        )

        # Three identical (x2 bilinear upsample -> 3x3 conv) stages, i.e. an
        # overall x8 upsampling. Built in a loop; the resulting module order
        # (and therefore the state_dict keys 0..5) matches a hand-written
        # six-element ``nn.Sequential``.
        up_stages = []
        for _ in range(3):
            up_stages.append(nn.UpsamplingBilinear2d(scale_factor=2))
            up_stages.append(
                ConvModule(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                ))
        self.up_conv = nn.Sequential(*up_stages)

    def forward(self, inputs):
        """Compute segmentation logits from a pair of feature maps.

        Args:
            inputs (Sequence[Tensor]): Exactly two 4-D tensors ``(x0, x1)``
                with matching batch and spatial sizes.

        Returns:
            Tensor: Logits of shape ``(B, out_channels, 8H, 8W)`` where
            ``(H, W)`` is the spatial size of the inputs.
        """
        x0, x1 = inputs
        x0 = self.down_conv[0](x0)
        x1 = self.down_conv[1](x1)

        gate_x0 = torch.sigmoid(x0)  # B N H W
        # Weight every channel of x1 by every gate, then average over space:
        # one C-dimensional descriptor per gate channel n.
        # NOTE(review): this materialises a (B, N, C, H, W) tensor; a fused
        # contraction would use less memory but may differ numerically under
        # mixed precision, so the original formulation is kept.
        x1 = torch.einsum('bnhw,bchw->bnchw', gate_x0, x1)
        x1 = torch.mean(x1, dim=(-2, -1))  # B N C
        x1 = self.cls_seg(x1)  # B N K

        x0 = self.up_conv(x0)  # B N 8H 8W
        # Mix the upsampled maps with the per-gate class scores.
        seg_logits = torch.einsum('bnhw,bnk->bkhw', x0, x1)
        return seg_logits

    def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
             train_cfg: ConfigType = None) -> dict:
        """Run the head and compute the training losses.

        Args:
            inputs (Tuple[Tensor]): The two feature maps for ``forward``.
            batch_data_samples (SampleList): Data samples providing
                ``gt_sem_seg`` for each image.
            train_cfg (dict): Accepted for interface compatibility; not used
                by this method.

        Returns:
            dict: Mapping from loss name to loss value.
        """
        seg_logits = self.forward(inputs)
        losses = self.loss_by_feat(seg_logits, batch_data_samples)
        return losses

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: ConfigType = None) -> Tensor:
        """Forward function for prediction.

        Args:
            inputs (Tuple[Tensor]): The two feature maps for ``forward``.
            batch_img_metas (List[dict]): Per-image meta information; the
                ``'img_shape'`` entry of the first element is used as the
                output size.
            test_cfg (dict): Accepted for interface compatibility; not used
                by this method.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        seg_logits = self.forward(inputs)
        return self.predict_by_feat(seg_logits, batch_img_metas)

    def predict_by_feat(self, seg_logits: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        """Transform a batch of output seg_logits to the input shape.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_img_metas (list[dict]): Meta information of each image. Only
                ``batch_img_metas[0]['img_shape']`` is read, so the whole
                batch is resized to the first image's shape.

        Returns:
            Tensor: Outputs segmentation logits map.
        """
        seg_logits = resize(
            input=seg_logits,
            size=batch_img_metas[0]['img_shape'],
            mode='bilinear',
            align_corners=self.align_corners)
        return seg_logits

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Stack every sample's ``gt_sem_seg.data`` along a new batch axis."""
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute the segmentation losses for already-computed logits.

        Args:
            seg_logits (Tensor): The output of ``forward``.
            batch_data_samples (SampleList): Data samples providing
                ``gt_sem_seg`` for each image.

        Returns:
            dict: Mapping from each loss module's ``loss_name`` to its value;
            losses sharing a name are summed.
        """
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()

        # Bring the logits to the label resolution (dims 2+ of the label).
        seg_logits = resize(
            input=seg_logits,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        # The sampler sees the label before its axis 1 is squeezed away.
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode

        for loss_decode in losses_decode:
            loss_name = loss_decode.loss_name
            loss_value = loss_decode(
                seg_logits,
                seg_label,
                weight=seg_weight,
                ignore_index=self.ignore_index)
            # Accumulate when several loss modules share the same name.
            if loss_name not in loss:
                loss[loss_name] = loss_value
            else:
                loss[loss_name] += loss_value
        return loss