File size: 799 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from rstor.architecture.nafnet import UNet
from rstor.properties import LEAKY_RELU


def test_unet():
    enc_blks = [1, 2]
    middle_blk_num = 2
    dec_blks = [2, 1]

    model = UNet(
        img_channel=3,
        width=2,
        activation=LEAKY_RELU,
        # We need leaky relu ...
        # otherwise it seems like ReLU may block propagation of NaN (with zeros!)
        # NaN and ReLu do not work correctly for receptive field estimation technique
        middle_blk_num=middle_blk_num,
        enc_blk_nums=enc_blks,
        dec_blk_nums=dec_blks,
    )
    rx, ry = model.receptive_field(channels=3)
    assert rx == ry
    assert rx == 44, "Receptive field should be {rx} x {ry}"
    x = torch.rand(2, 3, 128, 128)
    y = model(x)
    assert y.shape == (2, 3, 128, 128)