from unittest import TestCase import torch from chroma.layers.norm import MaskedBatchNorm1d class TestBatchNorm(TestCase): def test_norm(self): device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) B, C, L = (3, 5, 7) x1 = torch.randn(B, C, L).to(device) mean1 = x1.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (B * L) var1 = ((x1 - mean1) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / ( B * L ) x2 = torch.randn(B, C, L).to(device) mean2 = x2.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (B * L) var2 = ((x2 - mean2) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / ( B * L ) mbn = MaskedBatchNorm1d(C) mbn = mbn.to(device) # Test without mask in train mbn.train() out = mbn(x1) self.assertTrue(mean1.allclose(mbn.running_mean)) self.assertTrue(var1.allclose(mbn.running_var)) normed = (x1 - mean1) / torch.sqrt(var1 + mbn.eps) * mbn.weight + mbn.bias self.assertTrue(normed.allclose(out)) out = mbn(x2) normed = (x2 - mean2) / torch.sqrt(var2 + mbn.eps) * mbn.weight + mbn.bias self.assertTrue(normed.allclose(out)) self.assertTrue( mbn.running_mean.allclose((1 - mbn.momentum) * mean1 + mbn.momentum * mean2) ) self.assertTrue( mbn.running_var.allclose((1 - mbn.momentum) * var1 + mbn.momentum * var2) ) # Without mask in eval mbn.eval() out = mbn(x1) self.assertTrue( mbn.running_mean.allclose((1 - mbn.momentum) * mean1 + mbn.momentum * mean2) ) self.assertTrue( mbn.running_var.allclose((1 - mbn.momentum) * var1 + mbn.momentum * var2) ) normed = (x1 - mbn.running_mean) / torch.sqrt( mbn.running_var + mbn.eps ) * mbn.weight + mbn.bias self.assertTrue(normed.allclose(out)) # Check that masking with all ones doesn't change values mask = x1.new_ones((B, 1, L)) outm = mbn(x1, input_mask=mask) self.assertTrue(outm.allclose(out)) mbn.eval() out = mbn(x2) outm = mbn(x2, input_mask=mask) self.assertTrue(outm.allclose(out)) # With mask in train mask = torch.randn(B, 1, L) mask = mask > 0.0 mask = mask.to(device) n = mask.sum() mean1 = (x1 * mask).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n var1 = (((x1 * mask) - mean1) ** 2).sum(dim=0, keepdim=True).sum( dim=2, keepdim=True ) / n mbn = MaskedBatchNorm1d(C) mbn = mbn.to(device) mbn.train() out = mbn(x1, input_mask=mask) self.assertTrue(mean1.allclose(mbn.running_mean)) self.assertTrue(var1.allclose(mbn.running_var)) normed = (x1 * mask - mean1) / torch.sqrt( var1 + mbn.eps ) * mbn.weight + mbn.bias self.assertTrue(normed.allclose(out)) # With mask in eval mbn.eval() out = mbn(x1, input_mask=mask) normed = (x1 * mask - mbn.running_mean) / torch.sqrt( mbn.running_var + mbn.eps ) * mbn.weight + mbn.bias self.assertTrue(normed.allclose(out))