File size: 277 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import pytest
import torch
from s3prl.nn import BeamDecoder
@pytest.mark.extra_dependency
def test_beam_decoder():
    """Smoke-test ``BeamDecoder.decode`` on random log-probabilities.

    A default-constructed ``BeamDecoder`` is given a random tensor of shape
    (4, 100, 31) that has been log-softmax-normalized over its last axis.
    Nothing is asserted about the returned hypotheses; the test passes as
    long as construction and decoding finish without raising.
    """
    beam_decoder = BeamDecoder()

    # NOTE(review): the axes are presumably (batch, frames, vocabulary) --
    # confirm against the BeamDecoder documentation.
    log_probs = torch.randn(4, 100, 31).log_softmax(dim=2)

    # The decoded result is intentionally discarded: this only checks that
    # the call runs end to end.
    beam_decoder.decode(log_probs)
|