File size: 1,810 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import unittest

import torch

from training.loss import MultiResolutionSTFTLoss


class TestMultiResolutionSTFTLoss(unittest.TestCase):
    """Smoke and regression tests for ``MultiResolutionSTFTLoss``.

    The loss is called with two random tensors and is expected to return a
    pair ``(sc_loss, mag_loss)`` of zero-dimensional tensors.
    """

    # Shape of every random input tensor used by these tests.
    INPUT_SHAPE = (4, 16000)

    def setUp(self):
        """Seed the RNG, build the loss module and draw a pair of inputs."""
        torch.random.manual_seed(0)
        # Two resolution triples; presumably (fft size, hop size, window
        # length) -- confirm against MultiResolutionSTFTLoss.
        self.loss_fn = MultiResolutionSTFTLoss([(1024, 120, 600), (2048, 240, 1200)])

        self.x = torch.randn(*self.INPUT_SHAPE)
        self.y = torch.randn(*self.INPUT_SHAPE)

    def _assert_scalar_losses(self, sc_loss, mag_loss):
        """Assert that both loss values are zero-dimensional ``torch.Tensor``s."""
        for name, loss in (("sc_loss", sc_loss), ("mag_loss", mag_loss)):
            self.assertIsInstance(loss, torch.Tensor, msg=f"{name} is not a tensor")
            self.assertEqual(
                loss.shape,
                torch.Size([]),
                msg=f"{name} is not a scalar tensor",
            )

    def test_multi_resolution_stft_loss(self):
        """The loss returns two scalar tensors for the inputs built in setUp."""
        sc_loss, mag_loss = self.loss_fn(self.x, self.y)

        self._assert_scalar_losses(sc_loss, mag_loss)

    def test_multi_resolution_stft_loss_nonzero(self):
        """The loss matches fixed reference values for seed-0 random inputs."""
        # Re-seed so the inputs drawn here do not depend on whatever consumed
        # random numbers earlier; the reference values below assume this draw.
        torch.manual_seed(0)

        x = torch.randn(*self.INPUT_SHAPE)
        y = torch.randn(*self.INPUT_SHAPE)

        sc_loss, mag_loss = self.loss_fn(x, y)

        self._assert_scalar_losses(sc_loss, mag_loss)

        expected_sc_loss = torch.tensor(0.6571)
        # Pass a message so a failure reports the actual value rather than
        # the bare "False is not true" that assertTrue produces by default.
        self.assertTrue(
            torch.allclose(sc_loss, expected_sc_loss, atol=1e-4),
            msg=(
                f"sc_loss={sc_loss.item():.6f}, "
                f"expected {expected_sc_loss.item():.4f} (atol=1e-4)"
            ),
        )

        expected_mag_loss = torch.tensor(0.7007)
        self.assertTrue(
            torch.allclose(mag_loss, expected_mag_loss, rtol=1e-4, atol=1e-4),
            msg=(
                f"mag_loss={mag_loss.item():.6f}, "
                f"expected {expected_mag_loss.item():.4f} (rtol=1e-4, atol=1e-4)"
            ),
        )


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()