File size: 906 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from rstor.architecture.stacked_convolutions import StackedConvolutions
from rstor.properties import RELU


def test_stacked_convolutions():
    # Test case 1: Default parameters
    model = StackedConvolutions()
    assert isinstance(model, torch.nn.Module)

    # Test case 2: Number of layers is not even
    try:
        model = StackedConvolutions(num_layers=7)
        assert False, "Expected AssertionError"
    except AssertionError:
        pass

    # Test case 3: Custom parameters
    n, c, h, w = 1, 3, 64, 64
    model = StackedConvolutions(ch_in=c, ch_out=2, h_dim=32, num_layers=4, k_size=5, activation=RELU, bias=False)
    assert isinstance(model, torch.nn.Module)

    # Test case 4: Forward pass
    input_tensor = torch.randn(n, c, h, w)
    output_tensor = model(input_tensor)
    assert model.receptive_field() == (25, 25)
    assert output_tensor.shape == (1, 2, h, w)