|
import torch |
|
|
|
from s3prl.nn import RNNEncoder |
|
|
|
|
|
def test_rnn(helpers):
    """Check RNNEncoder output shapes for both downsampling styles.

    Builds a 3-layer bidirectional LSTM encoder whose middle layer has
    ``sample_rate=2``. For each ``sample_style`` ("drop" and "concat"), it
    checks three things:
    - the batch axis is preserved;
    - the time axis of the longest sequence is divided by the product of the
      per-layer sample rates;
    - the feature axis equals ``output_size``.

    ``helpers`` is a pytest fixture that this test does not use.
    """
    batch_size = 32
    max_len = 50
    sample_rate = [1, 2, 1]

    # Overall temporal downsampling is the product of the per-layer rates.
    total_downsample = 1
    for rate in sample_rate:
        total_downsample *= rate
    expected_len = max_len // total_downsample

    def make_encoder(sample_style):
        # The encoders under test differ only in how the rate-2 layer
        # reduces the time axis.
        return RNNEncoder(
            input_size=8,
            output_size=6,
            module="LSTM",
            hidden_size=[10, 10, 10],
            dropout=[0.1, 0.1, 0.1],
            layer_norm=[True, True, True],
            proj=[True, True, True],
            sample_rate=sample_rate,
            sample_style=sample_style,
            bidirectional=True,
        )

    for sample_style in ("drop", "concat"):
        module = make_encoder(sample_style)

        xs = torch.randn(batch_size, max_len, module.input_size)
        # Distinct lengths (max_len - batch_size + 1) .. max_len, so the last
        # sequence fills the whole padded length.
        xs_len = torch.arange(batch_size) + (max_len - batch_size) + 1

        out, out_len = module(xs, xs_len)

        assert out.shape[0] == batch_size
        assert out.shape[1] == expected_len
        assert out.shape[2] == module.output_size
        assert out_len.max() == expected_len
|
|