test
/
FoodSeg103
/Swin-Transformer-Semantic-Segmentation-main
/tests
/test_models
/test_segmentor.py
import numpy as np | |
import torch | |
from mmcv import ConfigDict | |
from torch import nn | |
from mmseg.models import BACKBONES, HEADS, build_segmentor | |
from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead | |
from mmseg.models.decode_heads.decode_head import BaseDecodeHead | |
def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10): | |
"""Create a superset of inputs needed to run test or train batches. | |
Args: | |
input_shape (tuple): | |
input batch dimensions | |
num_classes (int): | |
number of semantic classes | |
""" | |
(N, C, H, W) = input_shape | |
rng = np.random.RandomState(0) | |
imgs = rng.rand(*input_shape) | |
segs = rng.randint( | |
low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) | |
img_metas = [{ | |
'img_shape': (H, W, C), | |
'ori_shape': (H, W, C), | |
'pad_shape': (H, W, C), | |
'filename': '<demo>.png', | |
'scale_factor': 1.0, | |
'flip': False, | |
'flip_direction': 'horizontal' | |
} for _ in range(N)] | |
mm_inputs = { | |
'imgs': torch.FloatTensor(imgs), | |
'img_metas': img_metas, | |
'gt_semantic_seg': torch.LongTensor(segs) | |
} | |
return mm_inputs | |
class ExampleBackbone(nn.Module): | |
def __init__(self): | |
super(ExampleBackbone, self).__init__() | |
self.conv = nn.Conv2d(3, 3, 3) | |
def init_weights(self, pretrained=None): | |
pass | |
def forward(self, x): | |
return [self.conv(x)] | |
class ExampleDecodeHead(BaseDecodeHead): | |
def __init__(self): | |
super(ExampleDecodeHead, self).__init__(3, 3, num_classes=19) | |
def forward(self, inputs): | |
return self.cls_seg(inputs[0]) | |
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead): | |
def __init__(self): | |
super(ExampleCascadeDecodeHead, self).__init__(3, 3, num_classes=19) | |
def forward(self, inputs, prev_out): | |
return self.cls_seg(inputs[0]) | |
def _segmentor_forward_train_test(segmentor): | |
if isinstance(segmentor.decode_head, nn.ModuleList): | |
num_classes = segmentor.decode_head[-1].num_classes | |
else: | |
num_classes = segmentor.decode_head.num_classes | |
# batch_size=2 for BatchNorm | |
mm_inputs = _demo_mm_inputs(num_classes=num_classes) | |
imgs = mm_inputs.pop('imgs') | |
img_metas = mm_inputs.pop('img_metas') | |
gt_semantic_seg = mm_inputs['gt_semantic_seg'] | |
# convert to cuda Tensor if applicable | |
if torch.cuda.is_available(): | |
segmentor = segmentor.cuda() | |
imgs = imgs.cuda() | |
gt_semantic_seg = gt_semantic_seg.cuda() | |
# Test forward train | |
losses = segmentor.forward( | |
imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True) | |
assert isinstance(losses, dict) | |
# Test forward simple test | |
with torch.no_grad(): | |
segmentor.eval() | |
# pack into lists | |
img_list = [img[None, :] for img in imgs] | |
img_meta_list = [[img_meta] for img_meta in img_metas] | |
segmentor.forward(img_list, img_meta_list, return_loss=False) | |
# Test forward aug test | |
with torch.no_grad(): | |
segmentor.eval() | |
# pack into lists | |
img_list = [img[None, :] for img in imgs] | |
img_list = img_list + img_list | |
img_meta_list = [[img_meta] for img_meta in img_metas] | |
img_meta_list = img_meta_list + img_meta_list | |
segmentor.forward(img_list, img_meta_list, return_loss=False) | |
def test_encoder_decoder(): | |
# test 1 decode head, w.o. aux head | |
cfg = ConfigDict( | |
type='EncoderDecoder', | |
backbone=dict(type='ExampleBackbone'), | |
decode_head=dict(type='ExampleDecodeHead'), | |
train_cfg=None, | |
test_cfg=dict(mode='whole')) | |
segmentor = build_segmentor(cfg) | |
_segmentor_forward_train_test(segmentor) | |
# test slide mode | |
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) | |
segmentor = build_segmentor(cfg) | |
_segmentor_forward_train_test(segmentor) | |
# test 1 decode head, 1 aux head | |
cfg = ConfigDict( | |
type='EncoderDecoder', | |
backbone=dict(type='ExampleBackbone'), | |
decode_head=dict(type='ExampleDecodeHead'), | |
auxiliary_head=dict(type='ExampleDecodeHead')) | |
cfg.test_cfg = ConfigDict(mode='whole') | |
segmentor = build_segmentor(cfg) | |
_segmentor_forward_train_test(segmentor) | |
# test 1 decode head, 2 aux head | |
cfg = ConfigDict( | |
type='EncoderDecoder', | |
backbone=dict(type='ExampleBackbone'), | |
decode_head=dict(type='ExampleDecodeHead'), | |
auxiliary_head=[ | |
dict(type='ExampleDecodeHead'), | |
dict(type='ExampleDecodeHead') | |
]) | |
cfg.test_cfg = ConfigDict(mode='whole') | |
segmentor = build_segmentor(cfg) | |
_segmentor_forward_train_test(segmentor) | |
def test_cascade_encoder_decoder(): | |
# test 1 decode head, w.o. aux head | |
cfg = ConfigDict( | |
type='CascadeEncoderDecoder', | |
num_stages=2, | |
backbone=dict(type='ExampleBackbone'), | |
decode_head=[ | |
dict(type='ExampleDecodeHead'), | |
dict(type='ExampleCascadeDecodeHead') | |
]) | |
cfg.test_cfg = ConfigDict(mode='whole') | |
segmentor = build_segmentor(cfg) | |
_segmentor_forward_train_test(segmentor) | |
# test slide mode | |
cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2)) | |
segmentor = build_segmentor(cfg) | |
_segmentor_forward_train_test(segmentor) | |
# test 1 decode head, 1 aux head | |
cfg = ConfigDict( | |
type='CascadeEncoderDecoder', | |
num_stages=2, | |
backbone=dict(type='ExampleBackbone'), | |
decode_head=[ | |
dict(type='ExampleDecodeHead'), | |
dict(type='ExampleCascadeDecodeHead') | |
], | |
auxiliary_head=dict(type='ExampleDecodeHead')) | |
cfg.test_cfg = ConfigDict(mode='whole') | |
segmentor = build_segmentor(cfg) | |
_segmentor_forward_train_test(segmentor) | |
# test 1 decode head, 2 aux head | |
cfg = ConfigDict( | |
type='CascadeEncoderDecoder', | |
num_stages=2, | |
backbone=dict(type='ExampleBackbone'), | |
decode_head=[ | |
dict(type='ExampleDecodeHead'), | |
dict(type='ExampleCascadeDecodeHead') | |
], | |
auxiliary_head=[ | |
dict(type='ExampleDecodeHead'), | |
dict(type='ExampleDecodeHead') | |
]) | |
cfg.test_cfg = ConfigDict(mode='whole') | |
segmentor = build_segmentor(cfg) | |
_segmentor_forward_train_test(segmentor) | |