|
from unittest.mock import patch |
|
|
|
import pytest |
|
import torch |
|
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule |
|
from mmcv.utils import ConfigDict |
|
from mmcv.utils.parrots_wrapper import SyncBatchNorm |
|
|
|
from mmseg.models.decode_heads import (ANNHead, APCHead, ASPPHead, CCHead, |
|
DAHead, DepthwiseSeparableASPPHead, |
|
DepthwiseSeparableFCNHead, DMHead, |
|
DNLHead, EMAHead, EncHead, FCNHead, |
|
GCHead, LRASPPHead, NLHead, OCRHead, |
|
PointHead, PSAHead, PSPHead, UPerHead) |
|
from mmseg.models.decode_heads.decode_head import BaseDecodeHead |
|
|
|
|
|
def _conv_has_norm(module, sync_bn): |
|
for m in module.modules(): |
|
if isinstance(m, ConvModule): |
|
if not m.with_norm: |
|
return False |
|
if sync_bn: |
|
if not isinstance(m.bn, SyncBatchNorm): |
|
return False |
|
return True |
|
|
|
|
|
def to_cuda(module, data):
    """Move ``module`` and, when ``data`` is a list, its items to CUDA.

    Args:
        module: Object exposing ``cuda()``; the value returned by that call
            is what gets returned.
        data: Inputs for the module. Only a ``list`` is processed: each item
            is replaced in place by the result of its own ``cuda()`` call.
            Any other type is returned untouched.

    Returns:
        tuple: ``(module, data)`` where ``data`` is the same object that was
        passed in.
    """
    module = module.cuda()
    if isinstance(data, list):
        # Overwrite the slots in place so callers holding a reference to the
        # list observe the moved items too.
        for idx, item in enumerate(data):
            data[idx] = item.cuda()
    return module, data
|
|
|
|
|
@patch.multiple(BaseDecodeHead, __abstractmethods__=set())
def test_decode_head():
    """Exercise BaseDecodeHead argument checks, dropout and input transforms.

    The decorator empties ``__abstractmethods__`` so the base class can be
    instantiated directly for the duration of the test.
    """
    # Every (positional, keyword) combination below is expected to trip an
    # assertion in the constructor.
    rejected = (
        (([32, 16], 16), dict(num_classes=19)),
        ((32, 16), dict(num_classes=19, in_index=[-1, -2])),
        ((32, 16), dict(num_classes=19, input_transform='concat')),
        ((32, 16), dict(num_classes=19, input_transform='resize_concat')),
        (([32], 16),
         dict(in_index=-1, num_classes=19, input_transform='resize_concat')),
        (([32, 16], 16),
         dict(
             num_classes=19,
             in_index=[-1],
             input_transform='resize_concat')),
    )
    for positional, keywords in rejected:
        with pytest.raises(AssertionError):
            BaseDecodeHead(*positional, **keywords)

    # Dropout probability is expected to be 0.1 when ``dropout_ratio`` is
    # not passed, and to follow an explicit ``dropout_ratio`` otherwise.
    default_head = BaseDecodeHead(32, 16, num_classes=19)
    assert hasattr(default_head, 'dropout')
    assert default_head.dropout.p == 0.1

    custom_head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2)
    assert hasattr(custom_head, 'dropout')
    assert custom_head.dropout.p == 0.2

    # Single input, no input_transform: shape is expected to be preserved.
    feats = [torch.randn(1, 32, 45, 45)]
    single_head = BaseDecodeHead(32, 16, num_classes=19)
    if torch.cuda.is_available():
        single_head, feats = to_cuda(single_head, feats)
    assert single_head.in_channels == 32
    assert single_head.input_transform is None
    assert single_head._transform_inputs(feats).shape == (1, 32, 45, 45)

    # Two inputs with 'resize_concat': channels add up (32 + 16) and the
    # result is expected at the spatial size of the first input.
    feats = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)]
    concat_head = BaseDecodeHead(
        [32, 16],
        16,
        num_classes=19,
        in_index=[0, 1],
        input_transform='resize_concat')
    if torch.cuda.is_available():
        concat_head, feats = to_cuda(concat_head, feats)
    assert concat_head.in_channels == 48
    assert concat_head.input_transform == 'resize_concat'
    assert concat_head._transform_inputs(feats).shape == (1, 48, 45, 45)
|
|
|
|
|
def test_fcn_head():
    """Cover FCNHead construction options and forward output shapes."""
    base_cfg = dict(in_channels=32, channels=16, num_classes=19)

    def _conv_modules(head):
        """Collect the ConvModule instances found in ``head``."""
        return [m for m in head.modules() if isinstance(m, ConvModule)]

    # A negative number of convs is expected to be rejected.
    with pytest.raises(AssertionError):
        FCNHead(num_classes=19, num_convs=-1)

    # Without norm_cfg no ConvModule should carry a norm layer.
    plain_head = FCNHead(**base_cfg)
    assert not any(conv.with_norm for conv in _conv_modules(plain_head))

    # With SyncBN every ConvModule should carry a SyncBatchNorm.
    sync_head = FCNHead(norm_cfg=dict(type='SyncBN'), **base_cfg)
    assert all(conv.with_norm and isinstance(conv.bn, SyncBatchNorm)
               for conv in _conv_modules(sync_head))

    # concat_input=False: no ``conv_cat`` attribute.
    feats = [torch.randn(1, 32, 45, 45)]
    head = FCNHead(concat_input=False, **base_cfg)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    assert len(head.convs) == 2
    assert not head.concat_input
    assert not hasattr(head, 'conv_cat')
    assert head(feats).shape == (1, head.num_classes, 45, 45)

    # concat_input=True: ``conv_cat`` takes 32 + 16 input channels.
    feats = [torch.randn(1, 32, 45, 45)]
    head = FCNHead(concat_input=True, **base_cfg)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    assert len(head.convs) == 2
    assert head.concat_input
    assert head.conv_cat.in_channels == 48
    assert head(feats).shape == (1, head.num_classes, 45, 45)

    # Kernel size: first without the argument (3x3, padding 1 expected),
    # then with kernel_size=1 (1x1, padding 0 expected).
    kernel_cases = (
        (dict(), (3, 3), 1),
        (dict(kernel_size=1), (1, 1), 0),
    )
    for extra, expected_kernel, expected_padding in kernel_cases:
        feats = [torch.randn(1, 32, 45, 45)]
        head = FCNHead(**base_cfg, **extra)
        if torch.cuda.is_available():
            head, feats = to_cuda(head, feats)
        for conv in head.convs:
            assert conv.kernel_size == expected_kernel
            assert conv.padding == expected_padding
        assert head(feats).shape == (1, head.num_classes, 45, 45)

    # num_convs=1 yields a single conv.
    feats = [torch.randn(1, 32, 45, 45)]
    head = FCNHead(num_convs=1, **base_cfg)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    assert len(head.convs) == 1
    assert head(feats).shape == (1, head.num_classes, 45, 45)

    # num_convs=0 replaces the convs with an identity mapping.
    feats = [torch.randn(1, 32, 45, 45)]
    head = FCNHead(
        in_channels=32,
        channels=32,
        num_classes=19,
        num_convs=0,
        concat_input=False)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    assert isinstance(head.convs, torch.nn.Identity)
    assert head(feats).shape == (1, head.num_classes, 45, 45)
|
|
|
|
|
def test_psp_head():
    """Cover PSPHead argument checks, norm wiring and forward shape."""
    # A scalar ``pool_scales`` is expected to be rejected.
    with pytest.raises(AssertionError):
        PSPHead(in_channels=32, channels=16, num_classes=19, pool_scales=1)

    # No norm_cfg given.
    plain_head = PSPHead(in_channels=32, channels=16, num_classes=19)
    assert not _conv_has_norm(plain_head, sync_bn=False)

    # SyncBN requested.
    sync_head = PSPHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(sync_head, sync_bn=True)

    # Each pooling branch should use the matching entry of ``pool_scales``.
    scales = (1, 2, 3)
    feats = [torch.randn(1, 32, 45, 45)]
    head = PSPHead(
        in_channels=32, channels=16, num_classes=19, pool_scales=scales)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    for idx, scale in enumerate(scales):
        assert head.psp_modules[idx][0].output_size == scale
    result = head(feats)
    assert result.shape == (1, head.num_classes, 45, 45)
|
|
|
|
|
def test_apc_head():
    """Cover APCHead argument checks, norm wiring and both fusion modes."""
    # A scalar ``pool_scales`` is expected to be rejected.
    with pytest.raises(AssertionError):
        APCHead(in_channels=32, channels=16, num_classes=19, pool_scales=1)

    # No norm_cfg given.
    plain_head = APCHead(in_channels=32, channels=16, num_classes=19)
    assert not _conv_has_norm(plain_head, sync_bn=False)

    # SyncBN requested.
    sync_head = APCHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(sync_head, sync_bn=True)

    # Same checks with fusion switched on, then off.
    scales = (1, 2, 3)
    for fusion in (True, False):
        feats = [torch.randn(1, 32, 45, 45)]
        head = APCHead(
            in_channels=32,
            channels=16,
            num_classes=19,
            pool_scales=scales,
            fusion=fusion)
        if torch.cuda.is_available():
            head, feats = to_cuda(head, feats)
        assert head.fusion is fusion
        for idx, scale in enumerate(scales):
            assert head.acm_modules[idx].pool_scale == scale
        result = head(feats)
        assert result.shape == (1, head.num_classes, 45, 45)
|
|
|
|
|
def test_dm_head():
    """Cover DMHead argument checks, norm wiring and both fusion modes."""
    # A scalar ``filter_sizes`` is expected to be rejected.
    with pytest.raises(AssertionError):
        DMHead(in_channels=32, channels=16, num_classes=19, filter_sizes=1)

    # No norm_cfg given.
    plain_head = DMHead(in_channels=32, channels=16, num_classes=19)
    assert not _conv_has_norm(plain_head, sync_bn=False)

    # SyncBN requested.
    sync_head = DMHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(sync_head, sync_bn=True)

    # Same checks with fusion switched on, then off.
    sizes = (1, 3, 5)
    for fusion in (True, False):
        feats = [torch.randn(1, 32, 45, 45)]
        head = DMHead(
            in_channels=32,
            channels=16,
            num_classes=19,
            filter_sizes=sizes,
            fusion=fusion)
        if torch.cuda.is_available():
            head, feats = to_cuda(head, feats)
        assert head.fusion is fusion
        for idx, size in enumerate(sizes):
            assert head.dcm_modules[idx].filter_size == size
        result = head(feats)
        assert result.shape == (1, head.num_classes, 45, 45)
|
|
|
|
|
def test_aspp_head():
    """Cover ASPPHead argument checks, norm wiring and dilation setup."""
    # A scalar ``dilations`` is expected to be rejected.
    with pytest.raises(AssertionError):
        ASPPHead(in_channels=32, channels=16, num_classes=19, dilations=1)

    # No norm_cfg given.
    plain_head = ASPPHead(in_channels=32, channels=16, num_classes=19)
    assert not _conv_has_norm(plain_head, sync_bn=False)

    # SyncBN requested.
    sync_head = ASPPHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(sync_head, sync_bn=True)

    # Each ASPP branch should use the matching dilation rate on both axes.
    rates = (1, 12, 24)
    feats = [torch.randn(1, 32, 45, 45)]
    head = ASPPHead(
        in_channels=32, channels=16, num_classes=19, dilations=rates)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    for idx, rate in enumerate(rates):
        assert head.aspp_modules[idx].conv.dilation == (rate, rate)
    result = head(feats)
    assert result.shape == (1, head.num_classes, 45, 45)
|
|
|
|
|
def test_psa_head():
    """Cover PSAHead argument checks, norm wiring and its forward variants."""
    # psa_type='gather' is expected to be rejected.
    with pytest.raises(AssertionError):
        PSAHead(
            in_channels=32,
            channels=16,
            num_classes=19,
            mask_size=(39, 39),
            psa_type='gather')

    # No norm_cfg given.
    plain_head = PSAHead(
        in_channels=32, channels=16, num_classes=19, mask_size=(39, 39))
    assert not _conv_has_norm(plain_head, sync_bn=False)

    # SyncBN requested.
    sync_head = PSAHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        mask_size=(39, 39),
        norm_cfg=dict(type='SyncBN'))
    assert _conv_has_norm(sync_head, sync_bn=True)

    # Option sets layered on the common arguments; each one must produce
    # an output at the input resolution. The empty dict covers the head
    # built with the common arguments only.
    variants = (
        dict(),
        dict(shrink_factor=1),
        dict(psa_softmax=True),
        dict(psa_type='collect'),
        dict(shrink_factor=1, psa_type='collect'),
        dict(psa_type='collect', shrink_factor=1, compact=True),
        dict(psa_type='distribute'),
    )
    for options in variants:
        feats = [torch.randn(1, 32, 39, 39)]
        head = PSAHead(
            in_channels=32,
            channels=16,
            num_classes=19,
            mask_size=(39, 39),
            **options)
        if torch.cuda.is_available():
            head, feats = to_cuda(head, feats)
        result = head(feats)
        assert result.shape == (1, head.num_classes, 39, 39)
|
|
|
|
|
def test_gc_head():
    """Check GCHead structure (two convs, a ``gc_block``) and output shape."""
    gc_head = GCHead(in_channels=32, channels=16, num_classes=19)
    assert len(gc_head.convs) == 2
    assert hasattr(gc_head, 'gc_block')

    feats = [torch.randn(1, 32, 45, 45)]
    if torch.cuda.is_available():
        gc_head, feats = to_cuda(gc_head, feats)
    result = gc_head(feats)
    expected_shape = (1, gc_head.num_classes, 45, 45)
    assert result.shape == expected_shape
|
|
|
|
|
def test_nl_head():
    """Check NLHead structure (two convs, an ``nl_block``) and output shape."""
    nl_head = NLHead(in_channels=32, channels=16, num_classes=19)
    assert len(nl_head.convs) == 2
    assert hasattr(nl_head, 'nl_block')

    feats = [torch.randn(1, 32, 45, 45)]
    if torch.cuda.is_available():
        nl_head, feats = to_cuda(nl_head, feats)
    result = nl_head(feats)
    expected_shape = (1, nl_head.num_classes, 45, 45)
    assert result.shape == expected_shape
|
|
|
|
|
def test_cc_head():
    """Check CCHead structure; the forward pass is skipped without CUDA."""
    cc_head = CCHead(in_channels=32, channels=16, num_classes=19)
    assert len(cc_head.convs) == 2
    assert hasattr(cc_head, 'cca')

    # The structural checks above still run on CPU-only machines; only the
    # forward pass below is gated on CUDA.
    cuda_ready = torch.cuda.is_available()
    if not cuda_ready:
        pytest.skip('CCHead requires CUDA')

    feats = [torch.randn(1, 32, 45, 45)]
    cc_head, feats = to_cuda(cc_head, feats)
    result = cc_head(feats)
    expected_shape = (1, cc_head.num_classes, 45, 45)
    assert result.shape == expected_shape
|
|
|
|
|
def test_uper_head():
    """Cover UPerHead argument checks, norm wiring and forward shape."""
    # A single (non-list) ``in_channels`` is expected to be rejected.
    with pytest.raises(AssertionError):
        UPerHead(in_channels=32, channels=16, num_classes=19)

    # No norm_cfg given.
    plain_head = UPerHead(
        in_channels=[32, 16], channels=16, num_classes=19, in_index=[-2, -1])
    assert not _conv_has_norm(plain_head, sync_bn=False)

    # SyncBN requested.
    sync_head = UPerHead(
        in_channels=[32, 16],
        channels=16,
        num_classes=19,
        norm_cfg=dict(type='SyncBN'),
        in_index=[-2, -1])
    assert _conv_has_norm(sync_head, sync_bn=True)

    # Two-level input: output is expected at the first level's resolution.
    feats = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)]
    head = UPerHead(
        in_channels=[32, 16], channels=16, num_classes=19, in_index=[-2, -1])
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    result = head(feats)
    assert result.shape == (1, head.num_classes, 45, 45)
|
|
|
|
|
def test_ann_head():
    """Check that ANNHead outputs at the second input's 21x21 resolution."""
    low_level = torch.randn(1, 16, 45, 45)
    high_level = torch.randn(1, 32, 21, 21)
    feats = [low_level, high_level]

    ann_head = ANNHead(
        in_channels=[16, 32],
        channels=16,
        num_classes=19,
        in_index=[-2, -1],
        project_channels=8)
    if torch.cuda.is_available():
        ann_head, feats = to_cuda(ann_head, feats)

    result = ann_head(feats)
    expected_shape = (1, ann_head.num_classes, 21, 21)
    assert result.shape == expected_shape
|
|
|
|
|
def test_da_head():
    """Check DAHead's three training outputs and its single test output."""
    feats = [torch.randn(1, 32, 45, 45)]
    da_head = DAHead(
        in_channels=32, channels=16, num_classes=19, pam_channels=8)
    if torch.cuda.is_available():
        da_head, feats = to_cuda(da_head, feats)

    expected_shape = (1, da_head.num_classes, 45, 45)

    # The plain forward returns a 3-tuple of equally shaped tensors.
    results = da_head(feats)
    assert isinstance(results, tuple)
    assert len(results) == 3
    assert all(item.shape == expected_shape for item in results)

    # forward_test returns a single tensor of that shape.
    test_result = da_head.forward_test(feats, None, None)
    assert test_result.shape == expected_shape
|
|
|
|
|
def test_ocr_head():
    """Check that OCRHead, fed with an FCNHead prediction, keeps resolution.

    The FCNHead output is passed as the second argument to OCRHead's
    forward, so both heads and the inputs must live on the same device.
    """
    inputs = [torch.randn(1, 32, 45, 45)]
    ocr_head = OCRHead(
        in_channels=32, channels=16, num_classes=19, ocr_channels=8)
    fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19)
    if torch.cuda.is_available():
        # Rebind each head to the value returned by ``to_cuda`` instead of
        # discarding it into an unused name; this no longer relies on
        # ``cuda()`` moving the module in place.
        ocr_head, inputs = to_cuda(ocr_head, inputs)
        fcn_head, inputs = to_cuda(fcn_head, inputs)
    prev_output = fcn_head(inputs)
    output = ocr_head(inputs, prev_output)
    assert output.shape == (1, ocr_head.num_classes, 45, 45)
|
|
|
|
|
def test_enc_head():
    """Cover EncHead with and without SE loss and with lateral inputs."""
    # Head built without ``use_se_loss``: forward is expected to return
    # (segmentation logits, per-class vector).
    feats = [torch.randn(1, 32, 21, 21)]
    head = EncHead(
        in_channels=[32], channels=16, num_classes=19, in_index=[-1])
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    results = head(feats)
    assert isinstance(results, tuple)
    assert len(results) == 2
    seg_logits, se_logits = results[0], results[1]
    assert seg_logits.shape == (1, head.num_classes, 21, 21)
    assert se_logits.shape == (1, head.num_classes)

    # use_se_loss=False: only the segmentation logits are returned.
    feats = [torch.randn(1, 32, 21, 21)]
    head = EncHead(
        in_channels=[32],
        channels=16,
        use_se_loss=False,
        num_classes=19,
        in_index=[-1])
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    seg_logits = head(feats)
    assert seg_logits.shape == (1, head.num_classes, 21, 21)

    # add_lateral=True with two input levels; output is expected at the
    # resolution of the last level.
    feats = [torch.randn(1, 16, 45, 45), torch.randn(1, 32, 21, 21)]
    head = EncHead(
        in_channels=[16, 32],
        channels=16,
        add_lateral=True,
        num_classes=19,
        in_index=[-2, -1])
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    results = head(feats)
    assert isinstance(results, tuple)
    assert len(results) == 2
    seg_logits, se_logits = results[0], results[1]
    assert seg_logits.shape == (1, head.num_classes, 21, 21)
    assert se_logits.shape == (1, head.num_classes)
    test_result = head.forward_test(feats, None, None)
    assert test_result.shape == (1, head.num_classes, 21, 21)
|
|
|
|
|
def test_dw_aspp_head():
    """Cover DepthwiseSeparableASPPHead with and without the c1 branch."""
    rates = (1, 12, 24)

    def _check_dilations(head):
        """Branch 0 is a plain conv; later branches are depthwise convs."""
        assert head.aspp_modules[0].conv.dilation == (1, 1)
        for idx in (1, 2):
            depthwise = head.aspp_modules[idx].depthwise_conv
            assert depthwise.dilation == (rates[idx], rates[idx])

    # c1_in_channels=0 / c1_channels=0: no c1 bottleneck is built.
    feats = [torch.randn(1, 32, 45, 45)]
    head = DepthwiseSeparableASPPHead(
        c1_in_channels=0,
        c1_channels=0,
        in_channels=32,
        channels=16,
        num_classes=19,
        dilations=rates)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    assert head.c1_bottleneck is None
    _check_dilations(head)
    result = head(feats)
    assert result.shape == (1, head.num_classes, 45, 45)

    # With a c1 branch: the bottleneck maps 8 -> 4 channels and the output
    # is expected at the first input's resolution.
    feats = [torch.randn(1, 8, 45, 45), torch.randn(1, 32, 21, 21)]
    head = DepthwiseSeparableASPPHead(
        c1_in_channels=8,
        c1_channels=4,
        in_channels=32,
        channels=16,
        num_classes=19,
        dilations=rates)
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    assert head.c1_bottleneck.in_channels == 8
    assert head.c1_bottleneck.out_channels == 4
    _check_dilations(head)
    result = head(feats)
    assert result.shape == (1, head.num_classes, 45, 45)
|
|
|
|
|
def test_sep_fcn_head():
    """Cover DepthwiseSeparableFCNHead with concat_input off and on."""
    bn_cfg_args = dict(type='BN', requires_grad=True, momentum=0.01)

    # concat_input=False, 128 channels, batch of 2.
    head = DepthwiseSeparableFCNHead(
        in_channels=128,
        channels=128,
        concat_input=False,
        num_classes=19,
        in_index=-1,
        norm_cfg=dict(bn_cfg_args))
    feats = [torch.rand(2, 128, 32, 32)]
    result = head(feats)
    assert result.shape == (2, head.num_classes, 32, 32)
    assert not head.concat_input
    for idx in (0, 1):
        assert isinstance(head.convs[idx], DepthwiseSeparableConvModule)
    assert head.conv_seg.kernel_size == (1, 1)

    # concat_input=True, 64 channels, batch of 3.
    head = DepthwiseSeparableFCNHead(
        in_channels=64,
        channels=64,
        concat_input=True,
        num_classes=19,
        in_index=-1,
        norm_cfg=dict(bn_cfg_args))
    feats = [torch.rand(3, 64, 32, 32)]
    result = head(feats)
    assert result.shape == (3, head.num_classes, 32, 32)
    assert head.concat_input
    for idx in (0, 1):
        assert isinstance(head.convs[idx], DepthwiseSeparableConvModule)
|
|
|
|
|
def test_dnl_head():
    """Cover DNLHead built without ``mode`` and with three explicit modes."""
    # Head built without ``mode``: structure plus the block's temperature.
    head = DNLHead(in_channels=32, channels=16, num_classes=19)
    assert len(head.convs) == 2
    assert hasattr(head, 'dnl_block')
    assert head.dnl_block.temperature == 0.05
    feats = [torch.randn(1, 32, 45, 45)]
    if torch.cuda.is_available():
        head, feats = to_cuda(head, feats)
    result = head(feats)
    assert result.shape == (1, head.num_classes, 45, 45)

    # Each explicit mode must keep the input resolution.
    for mode in ('dot_product', 'gaussian', 'concatenation'):
        head = DNLHead(
            in_channels=32, channels=16, num_classes=19, mode=mode)
        feats = [torch.randn(1, 32, 45, 45)]
        if torch.cuda.is_available():
            head, feats = to_cuda(head, feats)
        result = head(feats)
        assert result.shape == (1, head.num_classes, 45, 45)
|
|
|
|
|
def test_emanet_head():
    """Check EMAHead's frozen ``ema_mid_conv`` and its output shape."""
    ema_head = EMAHead(
        in_channels=32,
        ema_channels=24,
        channels=16,
        num_stages=3,
        num_bases=16,
        num_classes=19)

    # None of the ema_mid_conv parameters may require gradients.
    mid_conv_params = ema_head.ema_mid_conv.parameters()
    assert not any(param.requires_grad for param in mid_conv_params)
    assert hasattr(ema_head, 'ema_module')

    feats = [torch.randn(1, 32, 45, 45)]
    if torch.cuda.is_available():
        ema_head, feats = to_cuda(ema_head, feats)
    result = ema_head(feats)
    expected_shape = (1, ema_head.num_classes, 45, 45)
    assert result.shape == expected_shape
|
|
|
|
|
def test_point_head():
    """Check PointHead.forward_test on top of an FCNHead prediction.

    With ``subdivision_steps=2`` and ``scale_factor=2`` the 45x45 input is
    expected to come out at 180x180.
    """
    inputs = [torch.randn(1, 32, 45, 45)]
    point_head = PointHead(
        in_channels=[32], in_index=[0], channels=16, num_classes=19)
    assert len(point_head.fcs) == 3
    fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19)
    if torch.cuda.is_available():
        # Rebind each head to the value returned by ``to_cuda`` instead of
        # discarding it into an unused name; this no longer relies on
        # ``cuda()`` moving the module in place.
        point_head, inputs = to_cuda(point_head, inputs)
        fcn_head, inputs = to_cuda(fcn_head, inputs)
    prev_output = fcn_head(inputs)
    test_cfg = ConfigDict(
        subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
    output = point_head.forward_test(inputs, prev_output, None, test_cfg)
    assert output.shape == (1, point_head.num_classes, 180, 180)
|
|
|
|
|
def test_lraspp_head():
    """Cover LRASPPHead argument validation and forward behaviour."""

    def _head_kwargs(**overrides):
        """Build fresh constructor kwargs, applying ``overrides`` on top.

        New dict objects are created on every call so no config dict is
        shared between two head instances.
        """
        kwargs = dict(
            in_channels=(16, 16, 576),
            in_index=(0, 1, 2),
            channels=128,
            input_transform='multiple_select',
            dropout_ratio=0.1,
            num_classes=19,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'),
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
        kwargs.update(overrides)
        return kwargs

    # input_transform='resize_concat' is expected to raise ValueError.
    with pytest.raises(ValueError):
        LRASPPHead(**_head_kwargs(input_transform='resize_concat'))

    # branch_channels=64 is expected to raise AssertionError.
    with pytest.raises(AssertionError):
        LRASPPHead(**_head_kwargs(branch_channels=64))

    lraspp_head = LRASPPHead(**_head_kwargs())

    # These feature-map sizes are expected to make the forward pass raise
    # RuntimeError.
    small_feats = [
        torch.randn(2, 16, 45, 45),
        torch.randn(2, 16, 28, 28),
        torch.randn(2, 576, 14, 14),
    ]
    with pytest.raises(RuntimeError):
        lraspp_head(small_feats)

    # Larger feature maps go through; the output is expected at the first
    # input's resolution.
    large_feats = [
        torch.randn(2, 16, 111, 111),
        torch.randn(2, 16, 77, 77),
        torch.randn(2, 576, 55, 55),
    ]
    result = lraspp_head(large_feats)
    assert result.shape == (2, 19, 111, 111)
|
|