# Source: wavlm-large / s3prl_s3prl_main/test/test_beam_decoder.py
# (uploaded by lmzjms as part of "Upload 1162 files", commit 0b32ad6)
import pytest
import torch
from s3prl.nn import BeamDecoder
@pytest.mark.extra_dependency
def test_beam_decoder():
    """Smoke-test ``BeamDecoder.decode`` on random log-softmax emissions.

    The emissions are random, so no particular transcription is expected.
    The test checks that decoding runs and yields one result per batch item.
    """
    decoder = BeamDecoder()

    # NOTE(review): presumably the emission layout is (batch, frames, tokens),
    # and 31 matches the size of BeamDecoder's default token set -- confirm
    # against the decoder's token file.
    batch_size, num_frames, num_tokens = 4, 100, 31

    # Seed a local generator so a failing run is reproducible without
    # mutating torch's global RNG state for other tests.
    generator = torch.Generator().manual_seed(0)
    emissions = torch.randn((batch_size, num_frames, num_tokens), generator=generator)
    # Turn the last axis into log-probabilities, as a CTC-style decoder expects.
    emissions = torch.log_softmax(emissions, dim=2)

    hyps = decoder.decode(emissions)

    # NOTE(review): assumes decode returns one hypothesis list per batch item
    # -- confirm against BeamDecoder.decode.
    assert len(hyps) == batch_size