image-deblurring / test /test_unet.py
balthou's picture
initiate demo
cec5823
raw
history blame contribute delete
799 Bytes
import torch
from rstor.architecture.nafnet import UNet
from rstor.properties import LEAKY_RELU
def test_unet():
    """Check a small UNet's receptive field and that it preserves input shape.

    Builds a UNet with two encoder stages, two middle blocks and two decoder
    stages, asserts that the receptive field reported by
    ``model.receptive_field`` is square and equal to 44, then runs a forward
    pass on a random batch and asserts the output shape equals the input shape.
    """
    enc_blks = [1, 2]
    middle_blk_num = 2
    dec_blks = [2, 1]
    expected_receptive_field = 44

    # We need leaky relu here: otherwise it seems like ReLU may block the
    # propagation of NaN (replacing them with zeros), and NaN + ReLU do not
    # work correctly with the receptive field estimation technique.
    model = UNet(
        img_channel=3,
        width=2,
        activation=LEAKY_RELU,
        middle_blk_num=middle_blk_num,
        enc_blk_nums=enc_blks,
        dec_blk_nums=dec_blks,
    )

    rx, ry = model.receptive_field(channels=3)
    assert rx == ry, f"Receptive field should be square, got {rx} x {ry}"
    # The message must be an f-string so that a failure reports the measured
    # values instead of the literal "{rx} x {ry}" placeholders.
    assert rx == expected_receptive_field, (
        f"Receptive field should be {expected_receptive_field} x "
        f"{expected_receptive_field}, got {rx} x {ry}"
    )

    x = torch.rand(2, 3, 128, 128)
    y = model(x)
    assert y.shape == (2, 3, 128, 128)