"""Smoke tests for ``UtteranceLevel`` built with each pooling layer type."""

import pytest
import torch

from s3prl.nn.common import UtteranceLevel
from s3prl.nn.pooling import (
    AttentiveStatisticsPooling,
    MeanPooling,
    SelfAttentivePooling,
    TemporalStatisticsPooling,
)

# Pooling layer names handed to UtteranceLevel as its ``pooling_type`` argument.
_POOLING_TYPES = [
    "MeanPooling",
    "TemporalStatisticsPooling",
    "AttentiveStatisticsPooling",
    "SelfAttentivePooling",
]


@pytest.mark.parametrize("pooling_type", _POOLING_TYPES)
def test_utterance_level_with_pooling(pooling_type: str):
    """Check that UtteranceLevel yields one ``output_dim`` vector per utterance.

    A random batch of ``batch_size`` sequences, each ``max_frames`` long with
    ``input_dim`` features per frame, is fed through the model. The result
    must have shape ``(batch_size, output_dim)`` for every pooling type.
    """
    batch_size = 32
    max_frames = 100
    input_dim = 256
    output_dim = 64

    # NOTE(review): positional arguments presumably are (input size, output
    # size, hidden sizes, activation, activation conf, pooling type, pooling
    # conf) -- confirm against the UtteranceLevel signature.
    model = UtteranceLevel(
        input_dim, output_dim, [128], "ReLU", None, pooling_type, None
    )

    features = torch.randn(batch_size, max_frames, input_dim)
    # One length per utterance: 1, 2, ..., batch_size. Presumably these are
    # valid-frame counts; every value is below max_frames.
    lengths = torch.arange(1, batch_size + 1)

    pooled = model(features, lengths)

    assert pooled.shape == (batch_size, output_dim)