Spaces:
Runtime error
Runtime error
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : test_numeric_batchnorm_v2.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 11/01/2018
#
# Distributed under terms of the MIT license.
"""
Test the numerical implementation of batch normalization.
Author: acgtyrant.
See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
"""
import unittest | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from sync_batchnorm.unittest import TorchTestCase | |
from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl | |
class NumericTestCasev2(TorchTestCase):
    """Numerically compare ``BatchNorm2dReimpl`` with ``nn.BatchNorm2d``."""

    def testNumericBatchNorm(self):
        """Train both layers on identical data and check that they agree.

        Both modules start from the same ``weight``/``bias`` values. For 100
        iterations the same random batch is fed to each module,
        ``output.sum()`` is backpropagated, and one SGD step (``lr=0.01``)
        is applied. Afterwards the inputs, outputs, input gradients,
        parameter gradients and running statistics of the two modules are
        compared with ``assertTensorClose``.
        """
        CHANNELS = 16

        batchnorm1 = nn.BatchNorm2d(CHANNELS, momentum=1)
        optimizer1 = optim.SGD(batchnorm1.parameters(), lr=0.01)

        batchnorm2 = BatchNorm2dReimpl(CHANNELS, momentum=1)
        # Start the re-implementation from the reference layer's affine
        # parameters so that both optimisation trajectories are comparable.
        batchnorm2.weight.data.copy_(batchnorm1.weight.data)
        batchnorm2.bias.data.copy_(batchnorm1.bias.data)
        optimizer2 = optim.SGD(batchnorm2.parameters(), lr=0.01)

        # NOTE: gradients are never zeroed between iterations, so ``.grad``
        # accumulates in the same way on both modules.
        for _ in range(100):
            input_ = torch.rand(16, CHANNELS, 16, 16)

            # Each module gets its own leaf copy so input gradients are
            # tracked independently.
            input1 = input_.clone().requires_grad_(True)
            output1 = batchnorm1(input1)
            output1.sum().backward()
            optimizer1.step()

            input2 = input_.clone().requires_grad_(True)
            output2 = batchnorm2(input2)
            output2.sum().backward()
            optimizer2.step()

        self.assertTensorClose(input1, input2)
        self.assertTensorClose(output1, output2)
        self.assertTensorClose(input1.grad, input2.grad)
        self.assertTensorClose(batchnorm1.weight.grad, batchnorm2.weight.grad)
        self.assertTensorClose(batchnorm1.bias.grad, batchnorm2.bias.grad)
        self.assertTensorClose(batchnorm1.running_mean, batchnorm2.running_mean)
        # The original compared batchnorm2.running_mean with itself, which
        # can never fail; check the running variance of both modules instead.
        # NOTE(review): assumes BatchNorm2dReimpl exposes ``running_var``
        # like nn.BatchNorm2d does -- confirm against its implementation.
        self.assertTensorClose(batchnorm1.running_var, batchnorm2.running_var)
if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    unittest.main()