import pytest
import torch
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from torch import nn

from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
                                         InterpConv, UNet, UpConvBlock)


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


def test_unet_basic_conv_block():
    with pytest.raises(AssertionError):
        # Not implemented yet.
        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
        BasicConvBlock(64, 64, dcn=dcn)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]
        BasicConvBlock(64, 64, plugins=plugins)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]
        BasicConvBlock(64, 64, plugins=plugins)

    # test BasicConvBlock with checkpoint forward
    block = BasicConvBlock(16, 16, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 16, 64, 64, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 16, 64, 64])

    block = BasicConvBlock(16, 16, with_cp=False)
    assert not block.with_cp
    x = torch.randn(1, 16, 64, 64)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 16, 64, 64])

    # test BasicConvBlock with stride convolution to downsample 2X
    block = BasicConvBlock(16, 16, stride=2)
    x = torch.randn(1, 16, 64, 64)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 16, 32, 32])

    # test BasicConvBlock structure and forward
    block = BasicConvBlock(16, 64, num_convs=3, dilation=3)
    assert block.convs[0].conv.in_channels == 16
    assert block.convs[0].conv.out_channels == 64
    assert block.convs[0].conv.kernel_size == (3, 3)
    assert block.convs[0].conv.dilation == (1, 1)
    assert block.convs[0].conv.padding == (1, 1)

    assert block.convs[1].conv.in_channels == 64
    assert block.convs[1].conv.out_channels == 64
    assert block.convs[1].conv.kernel_size == (3, 3)
    assert block.convs[1].conv.dilation == (3, 3)
    assert block.convs[1].conv.padding == (3, 3)

    assert block.convs[2].conv.in_channels == 64
    assert block.convs[2].conv.out_channels == 64
    assert block.convs[2].conv.kernel_size == (3, 3)
    assert block.convs[2].conv.dilation == (3, 3)
    assert block.convs[2].conv.padding == (3, 3)


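# For the dilated 3x3 convs checked above, BasicConvBlock preserves the
# spatial size by matching padding to dilation: a 3x3 kernel with dilation d
# spans 2 * d + 1 pixels, so padding of d keeps H and W unchanged.

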
def test_deconv_module():
    with pytest.raises(AssertionError):
        # kernel_size should be greater than or equal to scale_factor.
        DeconvModule(64, 32, kernel_size=1, scale_factor=2)

    with pytest.raises(AssertionError):
        # (kernel_size - scale_factor) should be an even number.
        DeconvModule(64, 32, kernel_size=3, scale_factor=2)

    with pytest.raises(AssertionError):
        # (kernel_size - scale_factor) should be an even number.
        DeconvModule(64, 32, kernel_size=5, scale_factor=4)

    # test DeconvModule with checkpoint forward and upsample 2X.
    block = DeconvModule(64, 32, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 64, 128, 128, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    block = DeconvModule(64, 32, with_cp=False)
    assert not block.with_cp
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test DeconvModule with different kernel sizes for upsample 2X.
    x = torch.randn(1, 64, 64, 64)
    block = DeconvModule(64, 32, kernel_size=2, scale_factor=2)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 128, 128])

    block = DeconvModule(64, 32, kernel_size=6, scale_factor=2)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 128, 128])

    # test DeconvModule with different kernel sizes for upsample 4X.
    x = torch.randn(1, 64, 64, 64)
    block = DeconvModule(64, 32, kernel_size=4, scale_factor=4)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    block = DeconvModule(64, 32, kernel_size=6, scale_factor=4)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])


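# The kernel_size/scale_factor constraints exercised above follow from the
# transposed-convolution size formula
#     H_out = (H_in - 1) * stride - 2 * padding + kernel_size,
# which equals H_in * scale_factor when stride = scale_factor and
# padding = (kernel_size - scale_factor) // 2 with an even difference.
# A minimal plain-PyTorch sketch of that arithmetic (the stride/padding
# choice is an assumption about DeconvModule's internals):
def _deconv_size_sketch():
    kernel_size, scale_factor = 6, 2
    deconv = nn.ConvTranspose2d(
        64,
        32,
        kernel_size=kernel_size,
        stride=scale_factor,
        padding=(kernel_size - scale_factor) // 2)
    x = torch.randn(1, 64, 64, 64)
    # (64 - 1) * 2 - 2 * 2 + 6 = 128, i.e. exactly 2X upsampling
    assert deconv(x).shape == torch.Size([1, 32, 128, 128])

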
def test_interp_conv():
    # test InterpConv with checkpoint forward and upsample 2X.
    block = InterpConv(64, 32, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 64, 128, 128, requires_grad=True)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    block = InterpConv(64, 32, with_cp=False)
    assert not block.with_cp
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test InterpConv with conv_first=False (interpolate, then convolve).
    block = InterpConv(64, 32, conv_first=False)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert isinstance(block.interp_upsample[0], nn.Upsample)
    assert isinstance(block.interp_upsample[1], ConvModule)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test InterpConv with conv_first=True (convolve, then interpolate).
    block = InterpConv(64, 32, conv_first=True)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert isinstance(block.interp_upsample[0], ConvModule)
    assert isinstance(block.interp_upsample[1], nn.Upsample)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test InterpConv with bilinear upsample for upsample 2X. Note that
    # `upsampe_cfg` is the keyword spelling defined by InterpConv itself,
    # so it is kept as-is here.
    block = InterpConv(
        64,
        32,
        conv_first=False,
        upsampe_cfg=dict(scale_factor=2, mode='bilinear', align_corners=False))
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert isinstance(block.interp_upsample[0], nn.Upsample)
    assert isinstance(block.interp_upsample[1], ConvModule)
    assert x_out.shape == torch.Size([1, 32, 256, 256])
    assert block.interp_upsample[0].mode == 'bilinear'

    # test InterpConv with nearest upsample for upsample 2X.
    block = InterpConv(
        64,
        32,
        conv_first=False,
        upsampe_cfg=dict(scale_factor=2, mode='nearest'))
    x = torch.randn(1, 64, 128, 128)
    x_out = block(x)
    assert isinstance(block.interp_upsample[0], nn.Upsample)
    assert isinstance(block.interp_upsample[1], ConvModule)
    assert x_out.shape == torch.Size([1, 32, 256, 256])
    assert block.interp_upsample[0].mode == 'nearest'


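# InterpConv pairs a non-learned nn.Upsample with a ConvModule, in the order
# controlled by conv_first (asserted above). A minimal plain-PyTorch sketch
# of the interpolate-then-convolve variant, assuming a 1x1 projection conv:
def _interp_conv_sketch():
    upsample = nn.Upsample(
        scale_factor=2, mode='bilinear', align_corners=False)
    conv = nn.Conv2d(64, 32, kernel_size=1)
    x = torch.randn(1, 64, 128, 128)
    out = conv(upsample(x))
    assert out.shape == torch.Size([1, 32, 256, 256])

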
def test_up_conv_block():
    with pytest.raises(AssertionError):
        # Not implemented yet.
        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
        UpConvBlock(BasicConvBlock, 64, 32, 32, dcn=dcn)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]
        UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]
        UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)

    # test UpConvBlock with checkpoint forward and upsample 2X.
    block = UpConvBlock(BasicConvBlock, 64, 32, 32, with_cp=True)
    skip_x = torch.randn(1, 32, 256, 256, requires_grad=True)
    x = torch.randn(1, 64, 128, 128, requires_grad=True)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock with upsample 2X. The spatial size of skip_x is 2X
    # larger than x.
    block = UpConvBlock(
        BasicConvBlock, 64, 32, 32, upsample_cfg=dict(type='InterpConv'))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock without upsample. The spatial size of skip_x is the
    # same as x.
    block = UpConvBlock(BasicConvBlock, 64, 32, 32, upsample_cfg=None)
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 256, 256)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock with interpolation (bilinear) upsample for
    # upsample 2X.
    block = UpConvBlock(
        BasicConvBlock,
        64,
        32,
        32,
        upsample_cfg=dict(
            type='InterpConv',
            upsampe_cfg=dict(
                scale_factor=2, mode='bilinear', align_corners=False)))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock with deconvolution upsample for upsample 2X.
    block = UpConvBlock(
        BasicConvBlock,
        64,
        32,
        32,
        upsample_cfg=dict(type='DeconvModule', kernel_size=4, scale_factor=2))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # test UpConvBlock structure and forward.
    block = UpConvBlock(
        conv_block=BasicConvBlock,
        in_channels=64,
        skip_channels=32,
        out_channels=32,
        num_convs=3,
        dilation=3,
        upsample_cfg=dict(
            type='InterpConv',
            upsampe_cfg=dict(
                scale_factor=2, mode='bilinear', align_corners=False)))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    x_out = block(skip_x, x)
    assert x_out.shape == torch.Size([1, 32, 256, 256])

    # The first conv fuses the concatenated features: 32 skip channels plus
    # 32 upsampled channels give 64 input channels.
    assert block.conv_block.convs[0].conv.in_channels == 64
    assert block.conv_block.convs[0].conv.out_channels == 32
    assert block.conv_block.convs[0].conv.kernel_size == (3, 3)
    assert block.conv_block.convs[0].conv.dilation == (1, 1)
    assert block.conv_block.convs[0].conv.padding == (1, 1)

    assert block.conv_block.convs[1].conv.in_channels == 32
    assert block.conv_block.convs[1].conv.out_channels == 32
    assert block.conv_block.convs[1].conv.kernel_size == (3, 3)
    assert block.conv_block.convs[1].conv.dilation == (3, 3)
    assert block.conv_block.convs[1].conv.padding == (3, 3)

    assert block.conv_block.convs[2].conv.in_channels == 32
    assert block.conv_block.convs[2].conv.out_channels == 32
    assert block.conv_block.convs[2].conv.kernel_size == (3, 3)
    assert block.conv_block.convs[2].conv.dilation == (3, 3)
    assert block.conv_block.convs[2].conv.padding == (3, 3)

    # The upsample branch projects 64 -> 32 channels with a 1x1 conv.
    assert block.upsample.interp_upsample[1].conv.in_channels == 64
    assert block.upsample.interp_upsample[1].conv.out_channels == 32
    assert block.upsample.interp_upsample[1].conv.kernel_size == (1, 1)
    assert block.upsample.interp_upsample[1].conv.dilation == (1, 1)
    assert block.upsample.interp_upsample[1].conv.padding == (0, 0)


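# UpConvBlock implements the classic U-Net decoder step: upsample the deep
# feature map, concatenate it with the encoder skip feature, then fuse with
# the conv block. A minimal sketch of that data flow using hypothetical
# stand-in modules (not UpConvBlock's actual submodules):
def _up_conv_flow_sketch():
    upsample = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        nn.Conv2d(64, 32, kernel_size=1))
    fuse = nn.Conv2d(64, 32, kernel_size=3, padding=1)
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    # in this sketch, skip features come first in the concatenation
    out = fuse(torch.cat([skip_x, upsample(x)], dim=1))
    assert out.shape == torch.Size([1, 32, 256, 256])

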
def test_unet():
    with pytest.raises(AssertionError):
        # Not implemented yet.
        dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
        UNet(3, 64, 5, dcn=dcn)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]
        UNet(3, 64, 5, plugins=plugins)

    with pytest.raises(AssertionError):
        # Not implemented yet.
        plugins = [
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]
        UNet(3, 64, 5, plugins=plugins)

    with pytest.raises(AssertionError):
        # Check whether the input image size is divisible by the whole
        # downsample rate of the encoder. Here the rate is 8 (three
        # downsampling stages) but the input is 65 x 65.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=4,
            strides=(1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2),
            dec_num_convs=(2, 2, 2),
            downsamples=(True, True, True),
            enc_dilations=(1, 1, 1, 1),
            dec_dilations=(1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # The whole downsample rate is 16 (four downsampling stages) but
        # the input is 65 x 65.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # The whole downsample rate is 8 (the last downsample is disabled)
        # but the input is 65 x 65.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, False),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # The whole downsample rate is 8 (downsampling via stride-2 convs)
        # but the input is 65 x 65.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 2, 2, 2, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, False),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # The whole downsample rate is 32 (five downsampling stages) but
        # the input is 65 x 65.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=6,
            strides=(1, 1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2, 2),
            downsamples=(True, True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1, 1))
        x = torch.randn(2, 3, 65, 65)
        unet(x)

    with pytest.raises(AssertionError):
        # len(strides) should equal num_stages: here 4 vs 5.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # len(enc_num_convs) should equal num_stages: here 4 vs 5.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # len(dec_num_convs) should equal num_stages - 1: here 5 vs 4.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # len(downsamples) should equal num_stages - 1: here 3 vs 4.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # len(enc_dilations) should equal num_stages: here 4 vs 5.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    with pytest.raises(AssertionError):
        # len(dec_dilations) should equal num_stages - 1: here 5 vs 4.
        unet = UNet(
            in_channels=3,
            base_channels=64,
            num_stages=5,
            strides=(1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2),
            downsamples=(True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1, 1))
        x = torch.randn(2, 3, 64, 64)
        unet(x)

    # test UNet with norm_eval=True.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        norm_eval=True)
    unet.train()
    assert check_norm_state(unet.modules(), False)

    # test UNet with norm_eval=False.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1),
        norm_eval=False)
    unet.train()
    assert check_norm_state(unet.modules(), True)

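    # As asserted above, norm_eval=True makes UNet.train() put every
    # _BatchNorm back into eval mode, freezing running statistics during
    # fine-tuning, while norm_eval=False lets the norm layers follow
    # train() as usual.
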
    # test UNet forward with the whole downsample rate of 16.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

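    # UNet returns one feature map per stage: x_outs[0] is the bottleneck
    # with base_channels * 2**(num_stages - 1) = 1024 channels, and each
    # later entry is a decoder output at twice the previous resolution,
    # ending at the input resolution with base_channels = 64 channels.
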
    # test UNet forward with the whole downsample rate of 8 (the last
    # downsample is disabled).
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet forward with the whole downsample rate of 8, realized by
    # stride-2 convs instead of max pooling.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 2, 2, 2, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet forward with the whole downsample rate of 4.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet forward with the whole downsample rate of 4, realized by
    # stride-2 convs.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

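    # The whole downsample rate is 2**k, where k counts the stages that
    # shrink the feature map, whether via a stride-2 first conv
    # (strides[i] == 2) or via max pooling before the stage
    # (downsamples[i - 1] is True). The two configs above both give k = 2,
    # hence the same 4X-downsampled shapes.
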
    # test UNet forward with the whole downsample rate of 2.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, False, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 64, 64])
    assert x_outs[1].shape == torch.Size([2, 512, 64, 64])
    assert x_outs[2].shape == torch.Size([2, 256, 64, 64])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet forward with no downsampling (whole downsample rate of 1).
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(False, False, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))

    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 128, 128])
    assert x_outs[1].shape == torch.Size([2, 512, 128, 128])
    assert x_outs[2].shape == torch.Size([2, 256, 128, 128])
    assert x_outs[3].shape == torch.Size([2, 128, 128, 128])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet forward with the whole downsample rate of 16, mixing
    # stride-2 convs (stages 1-2) and max pooling (stages 3-4).
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    print(unet)
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet forward with the whole downsample rate of 8, mixing
    # stride-2 convs and max pooling.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    print(unet)
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet forward with the whole downsample rate of 8, realized by
    # stride-2 convs only.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 2, 2, 2, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    print(unet)
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
    assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet forward with the whole downsample rate of 4, realized by
    # stride-2 convs.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    print(unet)
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])

    # test UNet init_weights with pretrained=None.
    unet = UNet(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 2, 2, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, False, False),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    unet.init_weights(pretrained=None)
    print(unet)
    x = torch.randn(2, 3, 128, 128)
    x_outs = unet(x)
    assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
    assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
    assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
    assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
    assert x_outs[4].shape == torch.Size([2, 64, 128, 128])


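# A hypothetical helper (not part of mmseg) that mirrors the shape
# bookkeeping in the assertions above; it assumes square inputs and stride
# values of 1 or 2. E.g. for the first forward test:
#     _expected_shapes(64, 5, (1, 1, 1, 1, 1), (True, True, True, True),
#                      2, 128)[0] == torch.Size([2, 1024, 8, 8])
def _expected_shapes(base_channels, num_stages, strides, downsamples, batch,
                     in_size):
    # cumulative downsample rate at each stage's output
    rates = [1]
    for i in range(1, num_stages):
        down = strides[i] == 2 or downsamples[i - 1]
        rates.append(rates[-1] * (2 if down else 1))
    # x_outs[0] is the bottleneck; the rest are decoder outputs
    shapes = [
        torch.Size([
            batch, base_channels * 2**(num_stages - 1), in_size // rates[-1],
            in_size // rates[-1]
        ])
    ]
    for i in reversed(range(num_stages - 1)):
        shapes.append(
            torch.Size([
                batch, base_channels * 2**i, in_size // rates[i],
                in_size // rates[i]
            ]))
    return shapes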