# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pickle | |
import textwrap | |
import unittest | |
from dataclasses import dataclass, field, is_dataclass | |
from enum import Enum | |
from typing import Any, Dict, List, Optional, Tuple | |
from unittest.mock import Mock | |
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError | |
from pytorch3d.implicitron.tools.config import ( | |
_get_type_to_process, | |
_is_actually_dataclass, | |
_ProcessType, | |
_Registry, | |
Configurable, | |
enable_get_default_args, | |
expand_args_fields, | |
get_default_args, | |
get_default_args_field, | |
registry, | |
remove_unused_components, | |
ReplaceableBase, | |
run_auto_creation, | |
) | |
class Animal(ReplaceableBase):
    """Replaceable base with no implementations registered in this file.

    Used to check that registry lookups are scoped to a base class: a name
    registered under Fruit is not found under Animal.
    """
class Fruit(ReplaceableBase):
    """Replaceable base class for the fruit implementations used in the tests."""
@registry.register
class Banana(Fruit):
    """Registered Fruit whose fields have no defaults.

    Registration is required: tests look it up with
    ``registry.get(Fruit, "Banana")`` and select it via ``*_class_type``.
    Its undefaulted fields appear as ``???`` in the generated YAML.
    """

    pips: int
    spots: int
    bananame: str
@registry.register
class Pear(Fruit):
    """Registered Fruit with a single defaulted field.

    Registration is required: tests select it with ``*_class_type = "Pear"``
    and look it up with ``registry.get(Pear, "Pear")``.
    """

    n_pips: int = 13
class Pineapple(Fruit):
    """Fruit that MainTest.create_the_second_fruit instantiates directly."""
@registry.register
class Orange(Fruit):
    """Registered Fruit with no fields.

    Registration is required: several test classes use "Orange" as their
    default ``*_class_type``.
    """
@registry.register
class Kiwi(Fruit):
    """Registered Fruit with no fields.

    Registration is required: it is selected by name ("Kiwi") and its
    ``*_Kiwi_args`` entries are expected in the generated defaults.
    """
@registry.register
class LargePear(Pear):
    """Registered second-level Fruit (subclass of Pear).

    Registration is required: tests look it up with
    ``registry.get(Pear, "LargePear")`` and expect it in ``get_all``.
    """
class BoringConfigurable(Configurable):
    """Configurable with no fields, used as a (possibly optional) member type."""
class MainTest(Configurable):
    """Configurable with two replaceable Fruit members and two int settings.

    ``the_second_fruit`` is produced by the hand-written
    ``create_the_second_fruit`` below, so its ``*_class_type`` setting has
    no effect on which class gets built.
    """

    # NOTE: field order matters; the generated creation functions follow it.
    the_fruit: Fruit
    n_ids: int
    n_reps: int = 8
    the_second_fruit: Fruit

    def create_the_second_fruit(self):
        """Always build a default Pineapple, ignoring the configuration."""
        expand_args_fields(Pineapple)
        pineapple = Pineapple()
        self.the_second_fruit = pineapple

    def __post_init__(self):
        # Build the replaceable members (the_fruit, the_second_fruit).
        run_auto_creation(self)
class TestConfig(unittest.TestCase):
    """Tests for the implicitron config system: Configurable, ReplaceableBase,
    the registry, get_default_args and run_auto_creation."""

    def test_is_actually_dataclass(self):
        """Only the decorated class itself is "actually" a dataclass; an
        undecorated subclass merely inherits dataclass-ness."""

        @dataclass
        class A:
            pass

        self.assertTrue(_is_actually_dataclass(A))
        self.assertTrue(is_dataclass(A))

        class B(A):
            a: int

        self.assertFalse(_is_actually_dataclass(B))
        self.assertTrue(is_dataclass(B))

    def test_get_type_to_process(self):
        """Only (Optional) Configurable / ReplaceableBase annotations are
        selected for processing; containers and primitives are not."""
        gt = _get_type_to_process
        self.assertIsNone(gt(int))
        self.assertEqual(gt(Fruit), (Fruit, _ProcessType.REPLACEABLE))
        self.assertEqual(
            gt(Optional[Fruit]), (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)
        )
        self.assertEqual(gt(MainTest), (MainTest, _ProcessType.CONFIGURABLE))
        self.assertEqual(
            gt(Optional[MainTest]), (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)
        )
        self.assertIsNone(gt(Optional[int]))
        self.assertIsNone(gt(Tuple[Fruit]))
        self.assertIsNone(gt(Tuple[Fruit, Animal]))
        self.assertIsNone(gt(Optional[List[int]]))

    def test_simple_replacement(self):
        """Select an implementation via ``*_class_type`` and configure it via
        ``*_<Impl>_args``."""
        struct = get_default_args(MainTest)
        struct.n_ids = 9780
        struct.the_fruit_Pear_args.n_pips = 3
        struct.the_fruit_class_type = "Pear"
        # Has no effect: MainTest defines its own create_the_second_fruit.
        struct.the_second_fruit_class_type = "Pear"

        main = MainTest(**struct)
        self.assertIsInstance(main.the_fruit, Pear)
        self.assertEqual(main.n_reps, 8)
        self.assertEqual(main.n_ids, 9780)
        self.assertEqual(main.the_fruit.n_pips, 3)
        self.assertIsInstance(main.the_second_fruit, Pineapple)

        # Editing one set of default args must not leak into a fresh one.
        struct2 = get_default_args(MainTest)
        self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)

        self.assertEqual(
            MainTest._creation_functions,
            ("create_the_fruit", "create_the_second_fruit"),
        )

    def test_detect_bases(self):
        """_base_class_from_class maps a replaceable class to the
        ReplaceableBase subclass it descends from, and anything else to None."""
        self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
        self.assertIsNone(_Registry._base_class_from_class(MainTest))
        self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
        self.assertIs(_Registry._base_class_from_class(Pear), Fruit)

        class PricklyPear(Pear):
            pass

        self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)

    def test_registry_entries(self):
        """Registry lookup by base class and name, get_all, and the errors for
        unknown, mismatched or unregistrable classes."""
        self.assertIs(registry.get(Fruit, "Banana"), Banana)
        with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
            registry.get(Animal, "Banana")
        # No class named PricklyPear is ever registered.
        with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
            registry.get(Fruit, "PricklyPear")

        self.assertIs(registry.get(Pear, "Pear"), Pear)
        self.assertIs(registry.get(Pear, "LargePear"), LargePear)
        # Banana is a Fruit but not a Pear.
        with self.assertRaisesRegex(ValueError, "Banana resolves to"):
            registry.get(Pear, "Banana")

        all_fruit = set(registry.get_all(Fruit))
        self.assertIn(Banana, all_fruit)
        self.assertIn(Pear, all_fruit)
        self.assertIn(LargePear, all_fruit)
        self.assertEqual(registry.get_all(Pear), [LargePear])

        @registry.register
        class Apple(Fruit):
            pass

        @registry.register
        class CrabApple(Apple):
            pass

        self.assertEqual(registry.get_all(Apple), [CrabApple])
        self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)

        # A class outside any ReplaceableBase hierarchy cannot be registered.
        with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):

            @registry.register
            class NotAFruit:
                pass

    def test_recursion(self):
        """Replaceable members may themselves contain replaceable members,
        but self-referential expansion is not performed."""

        class Shape(ReplaceableBase):
            pass

        @registry.register
        class Triangle(Shape):
            a: float = 5.0

        @registry.register
        class Square(Shape):
            a: float = 3.0

        @registry.register
        class LargeShape(Shape):
            inner: Shape

            def __post_init__(self):
                run_auto_creation(self)

        class ShapeContainer(Configurable):
            shape: Shape

        container = ShapeContainer(**get_default_args(ShapeContainer))
        # This is because ShapeContainer is missing __post_init__
        with self.assertRaises(AttributeError):
            container.shape

        class ShapeContainer2(Configurable):
            x: Shape
            x_class_type: str = "LargeShape"

            def __post_init__(self):
                self.x_LargeShape_args.inner_class_type = "Triangle"
                run_auto_creation(self)

        container2_args = get_default_args(ShapeContainer2)
        container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
        self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
        # We do not perform expansion that would result in an infinite recursion,
        # so this member is not present.
        self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
        container2_args.x_LargeShape_args.inner_Square_args.a += 100
        container2 = ShapeContainer2(**container2_args)
        self.assertIsInstance(container2.x, LargeShape)
        self.assertIsInstance(container2.x.inner, Triangle)
        self.assertEqual(container2.x.inner.a, 15.0)

    def test_simpleclass_member(self):
        """Members which are not dataclasses are tolerated (but not
        configurable); defaults are not shared between instances."""
        # It would be nice to be able to configure such members too.

        class Foo:
            def __init__(self, a: Any = 1, b: Any = 2):
                self.a, self.b = a, b

        enable_get_default_args(Foo)

        class Bar:
            aa: int = 9
            bb: int = 9

        class Container(Configurable):
            bar: Bar = Bar()
            # TODO make this work?
            # foo: Foo = Foo()
            fruit: Fruit
            fruit_class_type: str = "Orange"

            def __post_init__(self):
                run_auto_creation(self)

        self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
        container_args = get_default_args(Container)
        container = Container(**container_args)
        self.assertIsInstance(container.fruit, Orange)
        self.assertEqual(Container._processed_members, {"fruit": Fruit})
        self.assertEqual(container._processed_members, {"fruit": Fruit})

        # Mutating one instance's args must not change the defaults.
        container_defaulted = Container()
        container_defaulted.fruit_Pear_args.n_pips += 4

        container_args2 = get_default_args(Container)
        container = Container(**container_args2)
        self.assertEqual(container.fruit_Pear_args.n_pips, 13)

    def test_inheritance(self):
        """Replaceable members declared on a parent class are processed in
        subclasses. Also exercises optional replaceables."""

        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Orange"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        class LargeFruitBowl(FruitBowl):
            extra_fruit: Optional[Fruit]
            extra_fruit_class_type: str = "Kiwi"
            no_fruit: Optional[Fruit]
            no_fruit_class_type: Optional[str] = None

            def __post_init__(self):
                run_auto_creation(self)

        large_args = get_default_args(LargeFruitBowl)
        self.assertNotIn("extra_fruit", large_args)
        self.assertNotIn("main_fruit", large_args)
        large = LargeFruitBowl(**large_args)
        self.assertIsInstance(large.main_fruit, Orange)
        self.assertIsInstance(large.extra_fruit, Kiwi)
        self.assertIsNone(large.no_fruit)
        self.assertIn("no_fruit_Kiwi_args", large_args)

        remove_unused_components(large_args)
        large2 = LargeFruitBowl(**large_args)
        self.assertIsInstance(large2.main_fruit, Orange)
        self.assertIsInstance(large2.extra_fruit, Kiwi)
        self.assertIsNone(large2.no_fruit)
        needed_args = [
            "extra_fruit_Kiwi_args",
            "extra_fruit_class_type",
            "main_fruit_Orange_args",
            "main_fruit_class_type",
            "no_fruit_class_type",
        ]
        self.assertEqual(sorted(large_args.keys()), needed_args)

        with self.assertRaisesRegex(ValueError, "NotAFruit has not been registered."):
            LargeFruitBowl(extra_fruit_class_type="NotAFruit")

    def test_inheritance2(self):
        """A class could contain an instance of its own subclass; such a
        subclass is ignored when generating defaults but can still be used."""

        class Parent(ReplaceableBase):
            pass

        class Main(Configurable):
            parent: Parent
            # Note - no __post_init__

        @registry.register
        class Derived(Parent, Main):
            pass

        args = get_default_args(Main)
        # Derived has been ignored in processing Main.
        self.assertCountEqual(args.keys(), ["parent_class_type"])

        main = Main(**args)

        with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
            run_auto_creation(main)

        main.parent_class_type = "Derived"
        # Illustrates that a dict works fine instead of a DictConfig.
        main.parent_Derived_args = {}
        with self.assertRaises(AttributeError):
            main.parent
        run_auto_creation(main)
        self.assertIsInstance(main.parent, Derived)

    def test_redefine(self):
        """Re-registering a class under an existing name warns and the new
        implementation is used."""

        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Grape"

            def __post_init__(self):
                run_auto_creation(self)

        @registry.register
        @dataclass
        class Grape(Fruit):
            large: bool = False

            def get_color(self):
                return "red"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        bowl_args = get_default_args(FruitBowl)

        @registry.register
        @dataclass
        class Grape(Fruit):  # noqa: F811
            large: bool = True

            def get_color(self):
                return "green"

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            defaulted_bowl = FruitBowl()
        self.assertIsInstance(defaulted_bowl.main_fruit, Grape)
        self.assertEqual(defaulted_bowl.main_fruit.large, True)
        self.assertEqual(defaulted_bowl.main_fruit.get_color(), "green")

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            args_bowl = FruitBowl(**bowl_args)
        self.assertIsInstance(args_bowl.main_fruit, Grape)
        # Redefining the same class won't help with defaults because encoded in args
        self.assertEqual(args_bowl.main_fruit.large, False)
        # But the override worked.
        self.assertEqual(args_bowl.main_fruit.get_color(), "green")

        # 2. Try redefining without the dataclass modifier
        # This relies on the fact that default creation processes the class.
        # (otherwise incomprehensible messages)
        @registry.register
        class Grape(Fruit):  # noqa: F811
            large: bool = True

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            FruitBowl(**bowl_args)

        # 3. Adding a new class doesn't get picked up, because the first
        # get_default_args call has frozen FruitBowl. This is intrinsic to
        # the way dataclass and expand_args_fields work in-place but
        # expand_args_fields is not pure - it depends on the registry.
        @registry.register
        class Fig(Fruit):
            pass

        bowl_args2 = get_default_args(FruitBowl)
        self.assertIn("main_fruit_Grape_args", bowl_args2)
        self.assertNotIn("main_fruit_Fig_args", bowl_args2)

        # TODO Is it possible to make this work?
        # bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
        # bowl_args2.main_fruit_class_type = "Fig"
        # bowl2 = FruitBowl(**bowl_args2) <= unexpected argument

        # Note that it is possible to use Fig if you can set
        # bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
        # before run_auto_creation happens. See test_inheritance2
        # for an example.

    def test_no_replacement(self):
        """Configurable members (required and Optional) without any
        ReplaceableBase involved."""

        class A(Configurable):
            n: int = 9

        class B(Configurable):
            a: A

            def __post_init__(self):
                run_auto_creation(self)

        class C(Configurable):
            b1: B
            b2: Optional[B]
            b3: Optional[B]
            b2_enabled: bool = True
            b3_enabled: bool = False

            def __post_init__(self):
                run_auto_creation(self)

        c_args = get_default_args(C)
        c = C(**c_args)
        self.assertIsInstance(c.b1.a, A)
        self.assertEqual(c.b1.a.n, 9)
        self.assertFalse(hasattr(c, "b1_enabled"))
        self.assertIsInstance(c.b2.a, A)
        self.assertEqual(c.b2.a.n, 9)
        self.assertTrue(c.b2_enabled)
        self.assertIsNone(c.b3)
        self.assertFalse(c.b3_enabled)

    def test_doc(self):
        """The example from the config module's docstring."""

        class A(ReplaceableBase):
            k: int = 1

        @registry.register
        class A1(A):
            m: int = 3

        @registry.register
        class A2(A):
            n: str = "2"

        class B(Configurable):
            a: A
            a_class_type: str = "A2"

            def __post_init__(self):
                run_auto_creation(self)

        b_args = get_default_args(B)
        self.assertNotIn("a", b_args)
        b = B(**b_args)
        self.assertEqual(b.a.n, "2")

    def test_raw_types(self):
        """Defaults of plain dataclasses, classes and functions can be held as
        DictConfig members via get_default_args_field."""

        @dataclass
        class MyDataclass:
            int_field: int = 0
            none_field: Optional[int] = None
            float_field: float = 9.3
            bool_field: bool = True
            tuple_field: Tuple[int, ...] = (3,)

        class SimpleClass:
            def __init__(
                self,
                tuple_member_: Tuple[int, int] = (3, 4),
            ):
                self.tuple_member = tuple_member_

            def get_tuple(self):
                return self.tuple_member

        enable_get_default_args(SimpleClass)

        def f(*, a: int = 3, b: str = "kj"):
            self.assertEqual(a, 3)
            self.assertEqual(b, "kj")

        enable_get_default_args(f)

        class C(Configurable):
            simple: DictConfig = get_default_args_field(SimpleClass)
            # simple2: SimpleClass2 = SimpleClass2()
            mydata: DictConfig = get_default_args_field(MyDataclass)
            a_tuple: Tuple[float] = (4.0, 3.0)
            f_args: DictConfig = get_default_args_field(f)

        args = get_default_args(C)
        c = C(**args)
        self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])

        mydata = MyDataclass(**c.mydata)
        simple = SimpleClass(**c.simple)

        # OmegaConf converts tuples to ListConfigs (which act like lists).
        self.assertEqual(simple.get_tuple(), [3, 4])
        self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
        # get_default_args converts tuples to ListConfigs (which act like lists).
        self.assertEqual(c.a_tuple, [4.0, 3.0])
        self.assertTrue(isinstance(c.a_tuple, ListConfig))
        self.assertEqual(mydata.tuple_field, (3,))
        self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
        f(**c.f_args)

    def test_irrelevant_bases(self):
        """Annotated non-Configurable base classes are left untouched."""

        class NotADataclass:
            # Like torch.nn.Module, this class contains annotations
            # but is not designed to be dataclass'd.
            # This test ensures that such classes, when inherited from,
            # are not accidentally affected by expand_args_fields.
            a: int = 9
            b: int

        class LeftConfigured(Configurable, NotADataclass):
            left: int = 1

        class RightConfigured(NotADataclass, Configurable):
            right: int = 2

        class Outer(Configurable):
            left: LeftConfigured
            right: RightConfigured

            def __post_init__(self):
                run_auto_creation(self)

        outer = Outer(**get_default_args(Outer))
        self.assertEqual(outer.left.left, 1)
        self.assertEqual(outer.right.right, 2)
        # NotADataclass itself could not be made into a dataclass.
        with self.assertRaisesRegex(TypeError, "non-default argument"):
            dataclass(NotADataclass)

    def test_unprocessed(self):
        """Configurable and ReplaceableBase classes that were never expanded
        are processed (in __new__) on first instantiation."""

        class UnprocessedConfigurable(Configurable):
            a: int = 9

        class UnprocessedReplaceable(ReplaceableBase):
            a: int = 9

        for Unprocessed in [UnprocessedConfigurable, UnprocessedReplaceable]:
            self.assertFalse(_is_actually_dataclass(Unprocessed))
            unprocessed = Unprocessed()
            self.assertTrue(_is_actually_dataclass(Unprocessed))
            self.assertTrue(isinstance(unprocessed, Unprocessed))
            self.assertEqual(unprocessed.a, 9)

    def test_enum(self):
        """Enum values are kept, i.e. OmegaConf's runtime checks are in use."""

        class A(Enum):
            B1 = "b1"
            B2 = "b2"

        # Test for a Configurable class, a function, and a regular class.
        class C(Configurable):
            a: A = A.B1

        # Also test for a callable with enum arguments.
        def C_fn(a: A = A.B1):
            pass

        enable_get_default_args(C_fn)

        class C_cl:
            def __init__(self, a: A = A.B1) -> None:
                pass

        enable_get_default_args(C_cl)

        for C_ in [C, C_fn, C_cl]:
            base = get_default_args(C_)
            self.assertEqual(OmegaConf.to_yaml(base), "a: B1\n")
            self.assertEqual(base.a, A.B1)
            replaced = OmegaConf.merge(base, {"a": "B2"})
            self.assertEqual(replaced.a, A.B2)
            with self.assertRaises(ValidationError):
                # You can't use a value which is not one of the
                # choices, even if it is the str representation
                # of one of the choices.
                OmegaConf.merge(base, {"a": "b2"})

            remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
            self.assertEqual(remerged.a, A.B1)

    def test_pickle(self):
        """Default args of a function survive a pickle round trip."""

        def func(a: int = 1, b: str = "3"):
            pass

        enable_get_default_args(func)

        args = get_default_args(func)
        args2 = pickle.loads(pickle.dumps(args))
        self.assertEqual(args2.a, 1)
        self.assertEqual(args2.b, "3")

        # Both a regenerated and the original DictConfig stay picklable.
        args_regenerated = get_default_args(func)
        pickle.dumps(args_regenerated)
        pickle.dumps(args)

    def test_remove_unused_components(self):
        """Args of non-selected implementations are dropped, both from
        default args and from an instance's structured config."""
        struct = get_default_args(MainTest)
        struct.n_ids = 32
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Banana"
        remove_unused_components(struct)
        expected_keys = [
            "n_ids",
            "n_reps",
            "the_fruit_Pear_args",
            "the_fruit_class_type",
            "the_second_fruit_Banana_args",
            "the_second_fruit_class_type",
        ]
        expected_yaml = textwrap.dedent(
            """\
            n_ids: 32
            n_reps: 8
            the_fruit_class_type: Pear
            the_fruit_Pear_args:
              n_pips: 13
            the_second_fruit_class_type: Banana
            the_second_fruit_Banana_args:
              pips: ???
              spots: ???
              bananame: ???
            """
        )
        self.assertEqual(sorted(struct.keys()), expected_keys)

        # Check that struct is what we expect
        expected = OmegaConf.create(expected_yaml)
        self.assertEqual(struct, expected)

        # Check that we get what we expect when writing to yaml.
        self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)

        main = MainTest(**struct)
        instance_data = OmegaConf.structured(main)
        remove_unused_components(instance_data)
        self.assertEqual(sorted(instance_data.keys()), expected_keys)
        self.assertEqual(instance_data, expected)

    def test_remove_unused_components_optional(self):
        """A disabled optional member's args are removed."""

        class MainTestWrapper(Configurable):
            mt: Optional[MainTest]
            mt_enabled: bool = False

        args = get_default_args(MainTestWrapper)
        self.assertEqual(list(args.keys()), ["mt_enabled", "mt_args"])
        remove_unused_components(args)
        self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")

    def test_get_instance_args(self):
        """The config of an existing instance can be recovered and merged
        with the defaults."""
        mt1, mt2 = [
            MainTest(
                n_ids=0,
                n_reps=909,
                the_fruit_class_type="Pear",
                the_second_fruit_class_type="Pear",
                the_fruit_Pear_args=DictConfig({}),
                the_second_fruit_Pear_args={},
            )
            for _ in range(2)
        ]
        # Two equivalent ways to get the DictConfig back out of an instance.
        cfg1 = OmegaConf.structured(mt1)
        cfg2 = get_default_args(mt2)
        self.assertEqual(cfg1, cfg2)
        self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0)
        self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0)

        from_cfg = MainTest(**cfg2)
        self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0)

        # If you want the complete args, merge with the defaults.
        merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2)
        from_merged = MainTest(**merged_args)
        self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1)
        self.assertEqual(from_merged.n_reps, 909)

    def test_tweak_hook(self):
        """``<member>_tweak_args`` classmethods can edit a member's default
        args in place."""

        class A(Configurable):
            n: int = 9

        class Wrapper(Configurable):
            fruit: Fruit
            fruit_class_type: str = "Pear"
            fruit2: Fruit
            fruit2_class_type: str = "Pear"
            a: A
            a2: A
            a3: A

            @classmethod
            def a_tweak_args(cls, type, args):
                assert type == A
                args.n = 993

            @classmethod
            def a3_tweak_args(cls, type, args):
                del args["n"]

            @classmethod
            def fruit_tweak_args(cls, type, args):
                assert issubclass(type, Fruit)
                if type == Pear:
                    assert args.n_pips == 13
                    args.n_pips = 19

        args = get_default_args(Wrapper)
        self.assertEqual(args.a_args.n, 993)
        self.assertEqual(args.a2_args.n, 9)
        self.assertEqual(args.a3_args, {})
        self.assertEqual(args.fruit_Pear_args.n_pips, 19)
        self.assertEqual(args.fruit2_Pear_args.n_pips, 13)

    def test_impls(self):
        """create_x delegates its work to create_x_impl.

        Exercised for every member kind, once with a fake impl function and
        once without. Members suffixed _0 are optional and absent; those
        suffixed _o are optional and present.
        """
        control_args = []

        def fake_impl(self, control, args):
            control_args.append(control)

        for fake in [False, True]:

            class MyClass(Configurable):
                fruit: Fruit
                fruit_class_type: str = "Orange"
                fruit_o: Optional[Fruit]
                fruit_o_class_type: str = "Orange"
                fruit_0: Optional[Fruit]
                fruit_0_class_type: Optional[str] = None
                boring: BoringConfigurable
                boring_o: Optional[BoringConfigurable]
                boring_o_enabled: bool = True
                boring_0: Optional[BoringConfigurable]
                boring_0_enabled: bool = False

                def __post_init__(self):
                    run_auto_creation(self)

            if fake:
                MyClass.create_fruit_impl = fake_impl
                MyClass.create_fruit_o_impl = fake_impl
                MyClass.create_boring_impl = fake_impl
                MyClass.create_boring_o_impl = fake_impl

            expand_args_fields(MyClass)
            instance = MyClass()
            # With the fake impl nothing is assigned to the members.
            for name in ["fruit", "fruit_o", "boring", "boring_o"]:
                self.assertEqual(
                    hasattr(instance, name), not fake, msg=f"{name} {fake}"
                )

            self.assertIsNone(instance.fruit_0)
            self.assertIsNone(instance.boring_0)
            if not fake:
                self.assertIsInstance(instance.fruit, Orange)
                self.assertIsInstance(instance.fruit_o, Orange)
                self.assertIsInstance(instance.boring, BoringConfigurable)
                self.assertIsInstance(instance.boring_o, BoringConfigurable)

        self.assertEqual(control_args, ["Orange", "Orange", True, True])

    def test_pre_expand(self):
        """The pre_expand classmethod of a Configurable is called when
        expand_args_fields is called on the class."""

        class A(Configurable):
            n: int = 9

            @classmethod
            def pre_expand(cls):
                pass

        A.pre_expand = Mock()
        expand_args_fields(A)
        A.pre_expand.assert_called()

    def test_pre_expand_replaceable(self):
        """The pre_expand classmethod of a ReplaceableBase is called when
        expand_args_fields is called on one of its subclasses."""

        class A(ReplaceableBase):
            @classmethod
            def pre_expand(cls):
                pass

        class A1(A):
            n: int = 9

        A.pre_expand = Mock()
        expand_args_fields(A1)
        A.pre_expand.assert_called()
@dataclass
class MockDataclass:
    """Dataclass with undefaulted, primitive, optional and list fields.

    The decorator is required: instances are built as
    ``MockDataclass(field_no_default=0)`` and ``field(...)`` only has meaning
    inside a dataclass.
    """

    field_no_default: int
    field_primitive_type: int = 42
    field_optional_none: Optional[int] = None
    field_optional_dict_none: Optional[Dict] = None
    field_optional_with_value: Optional[int] = 42
    # A fresh list per instance, so instances never share one mutable list.
    field_list_type: List[int] = field(default_factory=list)
class RefObject:
    """Arbitrary non-primitive type, used as a default argument value."""


# Module-level default for MockClassWithInit's field_reference_type.
REF_OBJECT = RefObject()


class MockClassWithInit:  # noqa: B903
    """Plain (non-dataclass) class configured through its ``__init__``.

    Its parameters mirror MockDataclass's fields, plus one unannotated
    required parameter and one defaulted to a non-primitive object.
    """

    def __init__(
        self,
        field_no_nothing,
        field_no_default: int,
        field_primitive_type: int = 42,
        field_optional_none: Optional[int] = None,
        field_optional_dict_none: Optional[Dict] = None,
        field_optional_with_value: Optional[int] = 42,
        # Deliberately a shared mutable default; the tests check that the
        # defaults handed out by get_default_args do not alias it.
        field_list_type: List[int] = [],  # noqa: B006
        field_reference_type: RefObject = REF_OBJECT,
    ):
        # Keep every argument verbatim under an attribute of the same name.
        vars(self).update(
            field_no_nothing=field_no_nothing,
            field_no_default=field_no_default,
            field_primitive_type=field_primitive_type,
            field_optional_none=field_optional_none,
            field_optional_dict_none=field_optional_dict_none,
            field_optional_with_value=field_optional_with_value,
            field_list_type=field_list_type,
            field_reference_type=field_reference_type,
        )
# Enable get_default_args(MockClassWithInit), which TestRawClasses calls.
enable_get_default_args(MockClassWithInit)
class TestRawClasses(unittest.TestCase):
    """get_default_args on a plain dataclass and on a class with ``__init__``."""

    def setUp(self) -> None:
        # One instance of each mock class, built with only the required args.
        dataclass_instance = MockDataclass(field_no_default=0)
        init_instance = MockClassWithInit(
            field_no_nothing="tratata", field_no_default=0
        )
        self._instances = {
            MockDataclass: dataclass_instance,
            MockClassWithInit: init_instance,
        }

    def test_get_default_args(self):
        """Defaulted fields appear in declaration order with the values a
        default-constructed instance has; undefaulted and non-primitive
        fields are absent."""
        for cls in (MockDataclass, MockClassWithInit):
            defaults = get_default_args(cls)

            # DictConfig fields with missing values are `not in`
            for absent_key in (
                "field_no_default",
                "field_no_nothing",
                "field_reference_type",
            ):
                self.assertNotIn(absent_key, defaults)

            expected_keys = [
                "field_primitive_type",
                "field_optional_none",
                "field_optional_dict_none",
                "field_optional_with_value",
                "field_list_type",
            ]
            if cls is MockDataclass:
                # Undefaulted fields are kept (as missing) for dataclasses;
                # giving it a value makes the key count as present.
                defaults.field_no_default = 0
                expected_keys = ["field_no_default"] + expected_keys
            self.assertEqual(list(defaults), expected_keys)

            instance = self._instances[cls]
            for key, value in defaults.items():
                self.assertTrue(hasattr(instance, key))
                self.assertEqual(value, getattr(instance, key))

    def test_get_default_args_readonly(self):
        """Mutating returned defaults must not affect existing instances."""
        for cls in (MockDataclass, MockClassWithInit):
            get_default_args(cls)["field_list_type"].append(13)
            self.assertEqual(self._instances[cls].field_list_type, [])