import unittest

import torch

from training.loss import STFTLoss


class TestSTFTLoss(unittest.TestCase):
    """Tests for ``STFTLoss`` called on pairs of random ``(4, 16000)`` tensors.

    The loss is called as ``loss_fn(x, y)`` and is expected to return a pair
    ``(sc_loss, mag_loss)`` of zero-dimensional tensors.
    """

    def _assert_scalar_tensor(self, value, name):
        """Assert that ``value`` is a ``torch.Tensor`` with shape ``[]``.

        Args:
            value: Object returned by the loss function.
            name: Label used in the failure message.
        """
        self.assertIsInstance(
            value, torch.Tensor, msg=f"{name} is not a torch.Tensor",
        )
        self.assertEqual(
            value.shape, torch.Size([]), msg=f"{name} is not a scalar tensor",
        )

    def _assert_allclose(self, actual, expected, name, rtol=1e-4, atol=1e-4):
        """Assert ``torch.allclose(actual, expected)`` with a readable message.

        A bare ``assertTrue(torch.allclose(...))`` only reports
        "False is not true"; the message here shows the compared values.

        Args:
            actual: Tensor produced by the code under test.
            expected: Reference tensor.
            name: Label used in the failure message.
            rtol: Relative tolerance forwarded to ``torch.allclose``.
            atol: Absolute tolerance forwarded to ``torch.allclose``.
        """
        self.assertTrue(
            torch.allclose(actual, expected, rtol=rtol, atol=atol),
            msg=(
                f"{name}: got {actual}, expected {expected} "
                f"(rtol={rtol}, atol={atol})"
            ),
        )

    def test_stft_loss(self):
        """Both returned losses are scalar tensors for random inputs."""
        torch.random.manual_seed(0)

        loss_fn = STFTLoss()

        x = torch.randn(4, 16000)
        y = torch.randn(4, 16000)

        sc_loss, mag_loss = loss_fn(x, y)

        self._assert_scalar_tensor(sc_loss, "sc_loss")
        self._assert_scalar_tensor(mag_loss, "mag_loss")

    def test_stft_loss_nonzero(self):
        """Both losses are positive and match the values pinned for seed 0."""
        loss_fn = STFTLoss()

        # Reproducibility: seeded after constructing the loss so the pinned
        # expected values below depend only on the two randn draws.
        torch.manual_seed(0)

        x = torch.randn(4, 16000, dtype=torch.float32)
        y = torch.randn(4, 16000, dtype=torch.float32)

        sc_loss, mag_loss = loss_fn(x, y)

        self._assert_scalar_tensor(sc_loss, "sc_loss")
        self._assert_scalar_tensor(mag_loss, "mag_loss")

        self.assertGreater(sc_loss, 0.0)
        self.assertGreater(mag_loss, 0.0)

        # Regression values recorded for the seed-0 inputs above.
        self._assert_allclose(sc_loss, torch.tensor(0.6559), "sc_loss")
        self._assert_allclose(mag_loss, torch.tensor(0.6977), "mag_loss")


if __name__ == "__main__":
    unittest.main()