File size: 679 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import unittest

import torch

from training.loss.stft import stft


class TestSTFT(unittest.TestCase):
    def test_stft(self):
        # Test the STFT function with a random input signal
        x = torch.randn(4, 16384)
        fft_size = 1024
        hop_size = 256
        win_length = 1024
        window = torch.hann_window(win_length)
        output = stft(x, fft_size, hop_size, win_length, window)
        self.assertEqual(output.shape[0], 4)
        self.assertEqual(output.shape[2], fft_size // 2 + 1)
        self.assertEqual(
            output.shape[1], (16384 - win_length) // hop_size + x.shape[0] + 1,
        )


if __name__ == "__main__":
    unittest.main()