File size: 405 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from rstor.architecture.nafnet import NAFNet


def test_nafnet():
    """Check that a small NAFNet returns an output shaped like its input batch.

    The network is built with a deliberately tiny configuration (width of 2,
    one block per encoder stage, one middle block, and decoder stages of one
    and two blocks) and run once on a random batch.
    """
    # Batch of 2 three-channel 128x128 inputs; also the expected output shape.
    expected_shape = (2, 3, 128, 128)

    encoder_depths = [1, 1]
    bottleneck_depth = 1
    decoder_depths = [1, 2]
    net_kwargs = {
        "img_channel": 3,
        "width": 2,
        "middle_blk_num": bottleneck_depth,
        "enc_blk_nums": encoder_depths,
        "dec_blk_nums": decoder_depths,
    }
    network = NAFNet(**net_kwargs)

    batch = torch.rand(*expected_shape)
    output = network(batch)
    assert output.shape == expected_shape