""" Tests for distribution util functions.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import tensorflow.compat.v2 as tf |
|
|
|
from official.utils.misc import distribution_utils |
|
|
|
|
|
class GetDistributionStrategyTest(tf.test.TestCase):
  """Tests for get_distribution_strategy."""

  def test_one_device_strategy_cpu(self):
    ds = distribution_utils.get_distribution_strategy(num_gpus=0)
    self.assertEqual(ds.num_replicas_in_sync, 1)
    self.assertEqual(len(ds.extended.worker_devices), 1)
    self.assertIn('CPU', ds.extended.worker_devices[0])

  def test_one_device_strategy_gpu(self):
    ds = distribution_utils.get_distribution_strategy(num_gpus=1)
    self.assertEqual(ds.num_replicas_in_sync, 1)
    self.assertEqual(len(ds.extended.worker_devices), 1)
    self.assertIn('GPU', ds.extended.worker_devices[0])

  def test_mirrored_strategy(self):
    ds = distribution_utils.get_distribution_strategy(num_gpus=5)
    self.assertEqual(ds.num_replicas_in_sync, 5)
    self.assertEqual(len(ds.extended.worker_devices), 5)
    for device in ds.extended.worker_devices:
      self.assertIn('GPU', device)
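
  # A minimal usage sketch (an addition, not part of the original tests): the
  # strategy returned by get_distribution_strategy should be usable as a
  # model-building scope. The single Dense layer below is an illustrative
  # assumption, not something distribution_utils itself requires.
  def test_strategy_scope_usable(self):
    ds = distribution_utils.get_distribution_strategy(num_gpus=0)
    with ds.scope():
      model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    self.assertEqual(len(model.layers), 1)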


if __name__ == "__main__":
  tf.test.main()