Spaces:
Running
Running
import torch | |
from rstor.architecture.nafnet import UNet | |
from rstor.properties import LEAKY_RELU | |
def test_unet():
    """Check UNet receptive-field estimation and forward-pass output shape.

    Builds a tiny UNet (width 2, encoder blocks [1, 2], 2 middle blocks,
    decoder blocks [2, 1]). It then asserts three things:

    - the estimated receptive field is square;
    - the receptive field equals 44 x 44;
    - a forward pass on a random input returns a tensor of the same shape as
      the input.
    """
    enc_blks = [1, 2]
    middle_blk_num = 2
    dec_blks = [2, 1]
    model = UNet(
        img_channel=3,
        width=2,
        # Leaky ReLU is required here. With plain ReLU, zeroed activations
        # appear to block the propagation of NaN values. The receptive field
        # estimation technique relies on that propagation, so NaN + ReLU
        # gives wrong results.
        activation=LEAKY_RELU,
        middle_blk_num=middle_blk_num,
        enc_blk_nums=enc_blks,
        dec_blk_nums=dec_blks,
    )
    rx, ry = model.receptive_field(channels=3)
    assert rx == ry, f"Receptive field should be square, got {rx} x {ry}"
    assert rx == 44, f"Receptive field should be 44 x 44, got {rx} x {ry}"
    x = torch.rand(2, 3, 128, 128)
    # Only the output shape is checked, so no autograd graph is needed.
    with torch.no_grad():
        y = model(x)
    assert y.shape == (2, 3, 128, 128)