|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for official.benchmark.owner_utils.""" |
|
|
|
from absl.testing import absltest |
|
|
|
from official.benchmark import owner_utils |
|
|
|
|
|
@owner_utils.Owner('static_owner')
def static_function(foo=5):
    """Module-level fixture decorated with the 'static_owner' owner tag.

    Args:
        foo: Value to echo back; defaults to 5.

    Returns:
        `foo`, unchanged.
    """
    return foo
|
|
|
|
|
def static_function_without_owner(foo=5):
    """Module-level fixture that deliberately carries no owner decorator.

    Args:
        foo: Value to echo back; defaults to 5.

    Returns:
        `foo`, unchanged.
    """
    return foo
|
|
|
|
|
class BenchmarkClassWithoutOwner:
    """Fixture class with no class-level owner decorator.

    One method is left untagged and one carries its own owner decorator, so
    the tests can check both cases on an otherwise unowned class.
    """

    def method_without_owner(self):
        """Untagged method on an untagged class; returns the sentinel 100."""
        return 100

    @owner_utils.Owner('method_owner')
    def method_with_owner(self):
        """Method decorated with 'method_owner'; returns the sentinel 200."""
        return 200
|
|
|
|
|
@owner_utils.Owner('class_owner')
class SomeBenchmarkClass:
    """Fixture class decorated with the 'class_owner' owner tag.

    One method has no decorator of its own and one is decorated separately,
    so the tests can compare the owner reported for each.
    """

    def method_inherited_owner(self):
        """Method with no decorator of its own; returns the sentinel 123."""
        return 123

    @owner_utils.Owner('method_owner')
    def method_override_owner(self):
        """Method decorated with 'method_owner'; returns the sentinel 345."""
        return 345
|
|
|
|
|
@owner_utils.Owner('new_class_owner')
class InheritedClass(SomeBenchmarkClass):
    """Subclass fixture decorated with the 'new_class_owner' owner tag.

    Both methods of the base class are redefined here: one without a
    decorator and one with a new method-level owner decorator.
    """

    def method_inherited_owner(self):
        """Redefinition with no decorator of its own; returns 456."""
        return 456

    @owner_utils.Owner('new_method_owner')
    def method_override_owner(self):
        """Redefinition decorated with 'new_method_owner'; returns 567."""
        return 567
|
|
|
|
|
class OwnerUtilsTest(absltest.TestCase):
    """Tests to assert for owner decorator functionality."""

    def test_owner_tag_missing(self):
        """Undecorated callables report no owner; decorated methods still do."""
        # assertIsNone checks identity with None, which is stricter than the
        # equality check assertEqual(None, ...) performs.
        self.assertIsNone(owner_utils.GetOwner(static_function_without_owner))

        benchmark_class = BenchmarkClassWithoutOwner()
        self.assertIsNone(
            owner_utils.GetOwner(benchmark_class.method_without_owner))
        # The method must remain callable and return its usual value.
        self.assertEqual(100, benchmark_class.method_without_owner())

        # A method-level decorator applies even though the class has none.
        self.assertEqual(
            'method_owner',
            owner_utils.GetOwner(benchmark_class.method_with_owner))
        self.assertEqual(200, benchmark_class.method_with_owner())

    def test_owner_attributes_static(self):
        """A decorated module-level function reports its owner and still runs."""
        self.assertEqual('static_owner', owner_utils.GetOwner(static_function))
        self.assertEqual(5, static_function(5))

    def test_owner_attributes_per_class(self):
        """Class-level owner applies to undecorated methods; method tag wins."""
        level1 = SomeBenchmarkClass()

        # No decorator on the method itself: the class-level owner is expected.
        self.assertEqual(
            'class_owner',
            owner_utils.GetOwner(level1.method_inherited_owner))
        self.assertEqual(123, level1.method_inherited_owner())

        # Method-level decorator is expected to take precedence.
        self.assertEqual(
            'method_owner',
            owner_utils.GetOwner(level1.method_override_owner))
        self.assertEqual(345, level1.method_override_owner())

    def test_owner_attributes_inherited_class(self):
        """A decorated subclass reports its own owners, not the base class's."""
        level2 = InheritedClass()

        # Undecorated redefinition: the subclass-level owner is expected.
        self.assertEqual(
            'new_class_owner',
            owner_utils.GetOwner(level2.method_inherited_owner))
        self.assertEqual(456, level2.method_inherited_owner())

        # Decorated redefinition: the new method-level owner is expected.
        self.assertEqual(
            'new_method_owner',
            owner_utils.GetOwner(level2.method_override_owner))
        self.assertEqual(567, level2.method_override_owner())
|
|
|
|
|
# Run the test suite through absltest when this file is executed as a script.
if __name__ == '__main__':
    absltest.main()
|
|