Spaces:
Sleeping
Sleeping
File size: 3,375 Bytes
ce7bf5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
from unittest import TestCase
import torch
from chroma.layers.norm import MaskedBatchNorm1d
class TestBatchNorm(TestCase):
def test_norm(self):
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
B, C, L = (3, 5, 7)
x1 = torch.randn(B, C, L).to(device)
mean1 = x1.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (B * L)
var1 = ((x1 - mean1) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (
B * L
)
x2 = torch.randn(B, C, L).to(device)
mean2 = x2.sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (B * L)
var2 = ((x2 - mean2) ** 2).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / (
B * L
)
mbn = MaskedBatchNorm1d(C)
mbn = mbn.to(device)
# Test without mask in train
mbn.train()
out = mbn(x1)
self.assertTrue(mean1.allclose(mbn.running_mean))
self.assertTrue(var1.allclose(mbn.running_var))
normed = (x1 - mean1) / torch.sqrt(var1 + mbn.eps) * mbn.weight + mbn.bias
self.assertTrue(normed.allclose(out))
out = mbn(x2)
normed = (x2 - mean2) / torch.sqrt(var2 + mbn.eps) * mbn.weight + mbn.bias
self.assertTrue(normed.allclose(out))
self.assertTrue(
mbn.running_mean.allclose((1 - mbn.momentum) * mean1 + mbn.momentum * mean2)
)
self.assertTrue(
mbn.running_var.allclose((1 - mbn.momentum) * var1 + mbn.momentum * var2)
)
# Without mask in eval
mbn.eval()
out = mbn(x1)
self.assertTrue(
mbn.running_mean.allclose((1 - mbn.momentum) * mean1 + mbn.momentum * mean2)
)
self.assertTrue(
mbn.running_var.allclose((1 - mbn.momentum) * var1 + mbn.momentum * var2)
)
normed = (x1 - mbn.running_mean) / torch.sqrt(
mbn.running_var + mbn.eps
) * mbn.weight + mbn.bias
self.assertTrue(normed.allclose(out))
# Check that masking with all ones doesn't change values
mask = x1.new_ones((B, 1, L))
outm = mbn(x1, input_mask=mask)
self.assertTrue(outm.allclose(out))
mbn.eval()
out = mbn(x2)
outm = mbn(x2, input_mask=mask)
self.assertTrue(outm.allclose(out))
# With mask in train
mask = torch.randn(B, 1, L)
mask = mask > 0.0
mask = mask.to(device)
n = mask.sum()
mean1 = (x1 * mask).sum(dim=0, keepdim=True).sum(dim=2, keepdim=True) / n
var1 = (((x1 * mask) - mean1) ** 2).sum(dim=0, keepdim=True).sum(
dim=2, keepdim=True
) / n
mbn = MaskedBatchNorm1d(C)
mbn = mbn.to(device)
mbn.train()
out = mbn(x1, input_mask=mask)
self.assertTrue(mean1.allclose(mbn.running_mean))
self.assertTrue(var1.allclose(mbn.running_var))
normed = (x1 * mask - mean1) / torch.sqrt(
var1 + mbn.eps
) * mbn.weight + mbn.bias
self.assertTrue(normed.allclose(out))
# With mask in eval
mbn.eval()
out = mbn(x1, input_mask=mask)
normed = (x1 * mask - mbn.running_mean) / torch.sqrt(
mbn.running_var + mbn.eps
) * mbn.weight + mbn.bias
self.assertTrue(normed.allclose(out))
|