import unittest import torch from models.helpers import ( stride_lens_downsampling, ) class TestStrideLens(unittest.TestCase): def test_stride_lens(self): # Define test case inputs input_lengths = torch.tensor([5, 7, 10, 12]) stride = 2 # Correct output for this would be ceil([5, 7, 10, 12] / 2) => [3, 4, 5, 6] expected_output = torch.tensor([3, 4, 5, 6]) # Call the function with the test cases output = stride_lens_downsampling(input_lengths, stride) # Check if the output is a tensor self.assertIsInstance(output, torch.Tensor) # Check if the output shape is as expected self.assertEqual(output.shape, expected_output.shape) # Check if the output values are as expected self.assertTrue(torch.all(output.eq(expected_output))) def test_stride_lens_default_stride(self): # Define test case inputs. Here, we do not provide the stride. input_lengths = torch.tensor([10, 20, 4, 11]) # Correct output for this would be ceil([10, 20, 4, 11] / 2) => [5, 10, 2, 6] expected_output = torch.tensor([5, 10, 2, 6]) # Call the function with the test cases output = stride_lens_downsampling(input_lengths) # Check if the output is a tensor self.assertIsInstance(output, torch.Tensor) # Check if the output shape is as expected self.assertEqual(output.shape, expected_output.shape) # Check if the output values are as expected self.assertTrue(torch.all(output.eq(expected_output))) if __name__ == "__main__": unittest.main()