import unittest

import torch

from training.loss import STFTLoss


class TestSTFTLoss(unittest.TestCase):
    def test_stft_loss(self):
        # Both loss terms should be scalar tensors for random input
        torch.manual_seed(0)
        loss_fn = STFTLoss()

        # Two batches of 4 signals, 16000 samples each
        x = torch.randn(4, 16000)
        y = torch.randn(4, 16000)

        sc_loss, mag_loss = loss_fn(x, y)

        self.assertIsInstance(sc_loss, torch.Tensor)
        self.assertEqual(sc_loss.shape, torch.Size([]))
        self.assertIsInstance(mag_loss, torch.Tensor)
        self.assertEqual(mag_loss.shape, torch.Size([]))

    def test_stft_loss_nonzero(self):
        # With independent random inputs, both loss terms should be positive
        # and match the reference values checked below
        loss_fn = STFTLoss()

        # Reproducibility
        torch.manual_seed(0)

        # Random input signals: 4 signals of 16000 samples each
        x = torch.randn(4, 16000, dtype=torch.float32)
        y = torch.randn(4, 16000, dtype=torch.float32)

        sc_loss, mag_loss = loss_fn(x, y)

        self.assertIsInstance(sc_loss, torch.Tensor)
        self.assertEqual(sc_loss.shape, torch.Size([]))

        self.assertIsInstance(mag_loss, torch.Tensor)
        self.assertEqual(mag_loss.shape, torch.Size([]))

        self.assertGreater(sc_loss, 0.0)
        self.assertGreater(mag_loss, 0.0)

        # Reference values below hold for seed 0 and a default-constructed STFTLoss
        expected_sc = torch.tensor(0.6559)
        self.assertTrue(torch.allclose(sc_loss, expected_sc, rtol=1e-4, atol=1e-4))

        expected_mag = torch.tensor(0.6977)
        self.assertTrue(torch.allclose(mag_loss, expected_mag, rtol=1e-4, atol=1e-4))


if __name__ == "__main__":
    unittest.main()