Spaces:
Running
on
L4
Running
on
L4
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : [email protected]
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest | |
import numpy as np | |
from torch.autograd import Variable | |
def as_numpy(v):
    """Return the contents of ``v`` as a NumPy array.

    If ``v`` is an autograd ``Variable`` its underlying ``.data`` is used;
    otherwise ``v`` is used directly. The result of ``.cpu().numpy()`` on
    that object is returned.
    """
    # Unwrap autograd Variables so that .numpy() is called on the raw data.
    payload = v.data if isinstance(v, Variable) else v
    host_copy = payload.cpu()
    return host_copy.numpy()
class TorchTestCase(unittest.TestCase):
    """``unittest.TestCase`` extended with a tensor closeness assertion."""

    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        """Fail the test unless ``a`` and ``b`` are element-wise close.

        Both arguments are converted with ``as_numpy`` and compared with
        ``np.allclose(npa, npb, atol=atol, rtol=rtol)``.

        Args:
            a: First value; anything ``as_numpy`` accepts.
            b: Second value; anything ``as_numpy`` accepts.
            atol: Absolute tolerance forwarded to ``np.allclose``.
            rtol: Relative tolerance forwarded to ``np.allclose``.

        Raises:
            self.failureException: If the comparison fails. The message
                contains both inputs, the maximum absolute difference
                (``adiff``) and the maximum difference relative to
                ``|a|`` (``rdiff``).
        """
        npa, npb = as_numpy(a), as_numpy(b)

        # rtol must be forwarded explicitly; otherwise np.allclose silently
        # falls back to its own default and the parameter has no effect.
        if np.allclose(npa, npb, atol=atol, rtol=rtol):
            return

        # Diagnostics are built only on failure: formatting whole tensors is
        # wasted work on success, and .max() raises on zero-size arrays
        # (which np.allclose treats as close, so they never reach this point).
        abs_diff = np.abs(npa - npb)
        # Use |a| as the denominator so that negative entries are not clamped
        # to the 1e-5 floor, which would inflate the reported relative error.
        rel_diff = abs_diff / np.fmax(np.abs(npa), 1e-5)
        self.fail(
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(
                a, b, abs_diff.max(), rel_diff.max()
            )
        )