|
|
|
|
|
|
|
|
|
|
|
|
|
import pickle |
|
import textwrap |
|
import unittest |
|
from dataclasses import dataclass, field, is_dataclass |
|
from enum import Enum |
|
from typing import Any, Dict, List, Optional, Tuple |
|
from unittest.mock import Mock |
|
|
|
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError |
|
from pytorch3d.implicitron.tools.config import ( |
|
_get_type_to_process, |
|
_is_actually_dataclass, |
|
_ProcessType, |
|
_Registry, |
|
Configurable, |
|
enable_get_default_args, |
|
expand_args_fields, |
|
get_default_args, |
|
get_default_args_field, |
|
registry, |
|
remove_unused_components, |
|
ReplaceableBase, |
|
run_auto_creation, |
|
) |
|
|
|
|
|
@dataclass
class Animal(ReplaceableBase):
    """Empty replaceable base type, separate from the Fruit hierarchy.

    The tests use it as a base under which no Fruit implementation can be
    looked up.
    """
|
|
|
|
|
class Fruit(ReplaceableBase):
    """Replaceable base type for the fruit implementations defined below."""
|
|
|
|
|
@registry.register
class Banana(Fruit):
    """Registered Fruit whose three fields all lack default values."""

    # No defaults here: test_remove_unused_components expects these to
    # appear as `???` in the YAML dump.
    pips: int
    spots: int
    bananame: str
|
|
|
|
|
@registry.register
class Pear(Fruit):
    """Registered Fruit with a single defaulted field."""

    # Several tests assert on this default value of 13.
    n_pips: int = 13
|
|
|
|
|
class Pineapple(Fruit):
    """Fruit that is not registered; MainTest instantiates it by hand."""
|
|
|
|
|
@registry.register
class Orange(Fruit):
    """Registered Fruit with no fields."""
|
|
|
|
|
@registry.register
class Kiwi(Fruit):
    """Registered Fruit with no fields."""
|
|
|
|
|
@registry.register
class LargePear(Pear):
    """Registered subclass of Pear, used to test lookups below a non-root base."""
|
|
|
|
|
class BoringConfigurable(Configurable):
    """Configurable with no fields, used as a plain member type in test_impls."""
|
|
|
|
|
class MainTest(Configurable):
    """Configurable with two Fruit members and two integer settings.

    `the_second_fruit` has a hand-written creation function, so it is always
    a Pineapple regardless of the configured class type.
    """

    # Field order matters: test_simple_replacement asserts on the order of
    # the generated creation functions.
    the_fruit: Fruit
    n_ids: int
    n_reps: int = 8
    the_second_fruit: Fruit

    def create_the_second_fruit(self):
        """Build `the_second_fruit` as a Pineapple, ignoring the config."""
        # Pineapple is not registered, so it is processed explicitly before
        # being instantiated.
        expand_args_fields(Pineapple)
        self.the_second_fruit = Pineapple()

    def __post_init__(self):
        """Run the creation functions once the dataclass fields are set."""
        run_auto_creation(self)
|
|
|
|
|
class TestConfig(unittest.TestCase):
    """Tests for Configurable / ReplaceableBase processing and the registry."""

    def test_is_actually_dataclass(self):
        """A subclass of a dataclass is not itself "actually" a dataclass."""

        @dataclass
        class A:
            pass

        self.assertTrue(_is_actually_dataclass(A))
        self.assertTrue(is_dataclass(A))

        # B only inherits from a dataclass: the stdlib check still accepts
        # it, the stricter helper does not.
        class B(A):
            a: int

        self.assertFalse(_is_actually_dataclass(B))
        self.assertTrue(is_dataclass(B))

    def test_get_type_to_process(self):
        """Check which annotations are recognised as members to process."""
        classify = _get_type_to_process

        recognised = [
            (Fruit, (Fruit, _ProcessType.REPLACEABLE)),
            (Optional[Fruit], (Fruit, _ProcessType.OPTIONAL_REPLACEABLE)),
            (MainTest, (MainTest, _ProcessType.CONFIGURABLE)),
            (Optional[MainTest], (MainTest, _ProcessType.OPTIONAL_CONFIGURABLE)),
        ]
        for annotation, expected in recognised:
            self.assertEqual(classify(annotation), expected)

        ignored = [
            int,
            Optional[int],
            Tuple[Fruit],
            Tuple[Fruit, Animal],
            Optional[List[int]],
        ]
        for annotation in ignored:
            self.assertIsNone(classify(annotation))

    def test_simple_replacement(self):
        """Configure MainTest through its default args and instantiate it."""
        cfg = get_default_args(MainTest)
        cfg.n_ids = 9780
        cfg.the_fruit_Pear_args.n_pips = 3
        cfg.the_fruit_class_type = "Pear"
        cfg.the_second_fruit_class_type = "Pear"

        main = MainTest(**cfg)
        self.assertIsInstance(main.the_fruit, Pear)
        self.assertEqual(main.n_reps, 8)
        self.assertEqual(main.n_ids, 9780)
        self.assertEqual(main.the_fruit.n_pips, 3)
        # MainTest.create_the_second_fruit overrides the configured type.
        self.assertIsInstance(main.the_second_fruit, Pineapple)

        # Editing one config must not leak into freshly generated defaults.
        fresh_cfg = get_default_args(MainTest)
        self.assertEqual(fresh_cfg.the_fruit_Pear_args.n_pips, 13)

        self.assertEqual(
            MainTest._creation_functions,
            ("create_the_fruit", "create_the_second_fruit"),
        )

    def test_detect_bases(self):
        """_base_class_from_class finds the replaceable root of a class."""
        find_base = _Registry._base_class_from_class

        self.assertIsNone(find_base(ReplaceableBase))
        self.assertIsNone(find_base(MainTest))
        self.assertIs(find_base(Fruit), Fruit)
        self.assertIs(find_base(Pear), Fruit)

        class PricklyPear(Pear):
            pass

        self.assertIs(find_base(PricklyPear), Fruit)

    def test_registry_entries(self):
        """Exercise registry.get / get_all, including the failure modes."""
        self.assertIs(registry.get(Fruit, "Banana"), Banana)
        with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
            registry.get(Animal, "Banana")
        with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
            registry.get(Fruit, "PricklyPear")

        self.assertIs(registry.get(Pear, "Pear"), Pear)
        self.assertIs(registry.get(Pear, "LargePear"), LargePear)
        with self.assertRaisesRegex(ValueError, "Banana resolves to"):
            registry.get(Pear, "Banana")

        fruit_types = set(registry.get_all(Fruit))
        for expected_type in (Banana, Pear, LargePear):
            self.assertIn(expected_type, fruit_types)
        self.assertEqual(registry.get_all(Pear), [LargePear])

        @registry.register
        class Apple(Fruit):
            pass

        @registry.register
        class CrabApple(Apple):
            pass

        self.assertEqual(registry.get_all(Apple), [CrabApple])

        self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)

        # A class with no replaceable base cannot be registered.
        with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):

            @registry.register
            class NotAFruit:
                pass

    def test_recursion(self):
        """A replaceable member nested inside another replaceable member."""

        class Shape(ReplaceableBase):
            pass

        @registry.register
        class Triangle(Shape):
            a: float = 5.0

        @registry.register
        class Square(Shape):
            a: float = 3.0

        @registry.register
        class LargeShape(Shape):
            inner: Shape

            def __post_init__(self):
                run_auto_creation(self)

        class ShapeContainer(Configurable):
            shape: Shape

        container = ShapeContainer(**get_default_args(ShapeContainer))

        # ShapeContainer has no __post_init__ calling run_auto_creation, and
        # the test expects `shape` to be absent.
        with self.assertRaises(AttributeError):
            container.shape

        class ShapeContainer2(Configurable):
            x: Shape
            x_class_type: str = "LargeShape"

            def __post_init__(self):
                self.x_LargeShape_args.inner_class_type = "Triangle"
                run_auto_creation(self)

        container2_args = get_default_args(ShapeContainer2)
        large_shape_args = container2_args.x_LargeShape_args
        large_shape_args.inner_Triangle_args.a += 10
        self.assertIn("inner_Square_args", large_shape_args)

        # LargeShape's own args are expected not to be nested inside itself.
        self.assertNotIn("inner_LargeShape_args", large_shape_args)
        large_shape_args.inner_Square_args.a += 100
        container2 = ShapeContainer2(**container2_args)
        self.assertIsInstance(container2.x, LargeShape)
        self.assertIsInstance(container2.x.inner, Triangle)
        self.assertEqual(container2.x.inner.a, 15.0)

    def test_simpleclass_member(self):
        """A plain dataclass member coexists with a replaceable member."""

        class Foo:
            def __init__(self, a: Any = 1, b: Any = 2):
                self.a, self.b = a, b

        enable_get_default_args(Foo)

        @dataclass()
        class Bar:
            aa: int = 9
            bb: int = 9

        class Container(Configurable):
            # Bar instances are unhashable (eq=True without an explicit
            # hash), and Python 3.11+ rejects unhashable objects as dataclass
            # field defaults, so the default is built by a factory.
            bar: Bar = field(default_factory=Bar)

            fruit: Fruit
            fruit_class_type: str = "Orange"

            def __post_init__(self):
                run_auto_creation(self)

        self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
        container_args = get_default_args(Container)
        container = Container(**container_args)
        self.assertIsInstance(container.fruit, Orange)
        self.assertEqual(Container._processed_members, {"fruit": Fruit})
        self.assertEqual(container._processed_members, {"fruit": Fruit})

        # Mutating the args of a default-constructed instance ...
        container_defaulted = Container()
        container_defaulted.fruit_Pear_args.n_pips += 4

        # ... must not change what later instances receive.
        container_args2 = get_default_args(Container)
        container = Container(**container_args2)
        self.assertEqual(container.fruit_Pear_args.n_pips, 13)

    def test_inheritance(self):
        """Members of a base class are processed in a derived class too."""

        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Orange"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        class LargeFruitBowl(FruitBowl):
            extra_fruit: Optional[Fruit]
            extra_fruit_class_type: str = "Kiwi"
            no_fruit: Optional[Fruit]
            no_fruit_class_type: Optional[str] = None

            def __post_init__(self):
                run_auto_creation(self)

        large_args = get_default_args(LargeFruitBowl)
        self.assertNotIn("extra_fruit", large_args)
        self.assertNotIn("main_fruit", large_args)
        large = LargeFruitBowl(**large_args)
        self.assertIsInstance(large.main_fruit, Orange)
        self.assertIsInstance(large.extra_fruit, Kiwi)
        self.assertIsNone(large.no_fruit)
        self.assertIn("no_fruit_Kiwi_args", large_args)

        # After pruning, the remaining args still build an equivalent object.
        remove_unused_components(large_args)
        large2 = LargeFruitBowl(**large_args)
        self.assertIsInstance(large2.main_fruit, Orange)
        self.assertIsInstance(large2.extra_fruit, Kiwi)
        self.assertIsNone(large2.no_fruit)
        needed_args = [
            "extra_fruit_Kiwi_args",
            "extra_fruit_class_type",
            "main_fruit_Orange_args",
            "main_fruit_class_type",
            "no_fruit_class_type",
        ]
        self.assertEqual(sorted(large_args.keys()), needed_args)

        with self.assertRaisesRegex(ValueError, "NotAFruit has not been registered."):
            LargeFruitBowl(extra_fruit_class_type="NotAFruit")

    def test_inheritance2(self):
        """A registered class deriving from both a base and its container."""

        class Parent(ReplaceableBase):
            pass

        # Main deliberately has no __post_init__; creation is run by hand.
        class Main(Configurable):
            parent: Parent

        @registry.register
        class Derived(Parent, Main):
            pass

        args = get_default_args(Main)

        # No parent_Derived_args entry is expected in Main's defaults.
        self.assertCountEqual(args.keys(), ["parent_class_type"])

        main = Main(**args)

        with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
            run_auto_creation(main)

        main.parent_class_type = "Derived"

        # A plain dict is supplied here, not a DictConfig.
        main.parent_Derived_args = {}
        with self.assertRaises(AttributeError):
            main.parent
        run_auto_creation(main)
        self.assertIsInstance(main.parent, Derived)

    def test_redefine(self):
        """Re-registering a class name replaces the implementation with a warning."""

        class FruitBowl(ReplaceableBase):
            main_fruit: Fruit
            main_fruit_class_type: str = "Grape"

            def __post_init__(self):
                run_auto_creation(self)

        @registry.register
        @dataclass
        class Grape(Fruit):
            large: bool = False

            def get_color(self):
                return "red"

            def __post_init__(self):
                raise ValueError("This doesn't get called")

        bowl_args = get_default_args(FruitBowl)

        @registry.register
        @dataclass
        class Grape(Fruit):  # noqa: F811
            large: bool = True

            def get_color(self):
                return "green"

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            defaulted_bowl = FruitBowl()
            self.assertIsInstance(defaulted_bowl.main_fruit, Grape)
            self.assertEqual(defaulted_bowl.main_fruit.large, True)
            self.assertEqual(defaulted_bowl.main_fruit.get_color(), "green")

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            args_bowl = FruitBowl(**bowl_args)
            self.assertIsInstance(args_bowl.main_fruit, Grape)

            # bowl_args was captured before the redefinition, so it still
            # carries the old default value ...
            self.assertEqual(args_bowl.main_fruit.large, False)

            # ... but the method comes from the new implementation.
            self.assertEqual(args_bowl.main_fruit.get_color(), "green")

        # Redefine once more, this time without the dataclass decorator.
        @registry.register
        class Grape(Fruit):  # noqa: F811
            large: bool = True

        with self.assertWarnsRegex(
            UserWarning, "New implementation of Grape is being chosen."
        ):
            FruitBowl(**bowl_args)

        # A class registered after FruitBowl's defaults were first generated
        # is expected not to show up in later defaults.
        @registry.register
        class Fig(Fruit):
            pass

        bowl_args2 = get_default_args(FruitBowl)
        self.assertIn("main_fruit_Grape_args", bowl_args2)
        self.assertNotIn("main_fruit_Fig_args", bowl_args2)

    def test_no_replacement(self):
        """Configurable (non-replaceable) members, required and optional."""

        class A(Configurable):
            n: int = 9

        class B(Configurable):
            a: A

            def __post_init__(self):
                run_auto_creation(self)

        class C(Configurable):
            b1: B
            b2: Optional[B]
            b3: Optional[B]
            b2_enabled: bool = True
            b3_enabled: bool = False

            def __post_init__(self):
                run_auto_creation(self)

        c_args = get_default_args(C)
        c = C(**c_args)

        # Required member: always built, and has no *_enabled flag.
        self.assertIsInstance(c.b1.a, A)
        self.assertEqual(c.b1.a.n, 9)
        self.assertFalse(hasattr(c, "b1_enabled"))

        # Optional member that is enabled.
        self.assertIsInstance(c.b2.a, A)
        self.assertEqual(c.b2.a.n, 9)
        self.assertTrue(c.b2_enabled)

        # Optional member that is disabled.
        self.assertIsNone(c.b3)
        self.assertFalse(c.b3_enabled)

    def test_doc(self):
        """Minimal replaceable-member example with a chosen default type."""

        class A(ReplaceableBase):
            k: int = 1

        @registry.register
        class A1(A):
            m: int = 3

        @registry.register
        class A2(A):
            n: str = "2"

        class B(Configurable):
            a: A
            a_class_type: str = "A2"

            def __post_init__(self):
                run_auto_creation(self)

        b_args = get_default_args(B)
        self.assertNotIn("a", b_args)
        b = B(**b_args)
        self.assertEqual(b.a.n, "2")

    def test_raw_types(self):
        """Defaults of a plain dataclass, class and function as config fields."""

        @dataclass
        class MyDataclass:
            int_field: int = 0
            none_field: Optional[int] = None
            float_field: float = 9.3
            bool_field: bool = True
            tuple_field: Tuple[int, ...] = (3,)

        class SimpleClass:
            def __init__(
                self,
                tuple_member_: Tuple[int, int] = (3, 4),
            ):
                self.tuple_member = tuple_member_

            def get_tuple(self):
                return self.tuple_member

        enable_get_default_args(SimpleClass)

        def f(*, a: int = 3, b: str = "kj"):
            self.assertEqual(a, 3)
            self.assertEqual(b, "kj")

        enable_get_default_args(f)

        class C(Configurable):
            simple: DictConfig = get_default_args_field(SimpleClass)

            mydata: DictConfig = get_default_args_field(MyDataclass)
            a_tuple: Tuple[float] = (4.0, 3.0)
            f_args: DictConfig = get_default_args_field(f)

        args = get_default_args(C)
        c = C(**args)
        self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])

        mydata = MyDataclass(**c.mydata)
        simple = SimpleClass(**c.simple)

        # Tuple defaults come back as ListConfig objects.
        self.assertEqual(simple.get_tuple(), [3, 4])
        self.assertIsInstance(simple.get_tuple(), ListConfig)

        self.assertEqual(c.a_tuple, [4.0, 3.0])
        self.assertIsInstance(c.a_tuple, ListConfig)
        self.assertEqual(mydata.tuple_field, (3,))
        self.assertIsInstance(mydata.tuple_field, ListConfig)
        f(**c.f_args)

    def test_irrelevant_bases(self):
        """Annotated non-dataclass bases do not interfere with processing."""

        class NotADataclass:
            # A defaulted annotation followed by an undefaulted one: turning
            # this class into a dataclass fails (asserted at the end).
            a: int = 9
            b: int

        class LeftConfigured(Configurable, NotADataclass):
            left: int = 1

        class RightConfigured(NotADataclass, Configurable):
            right: int = 2

        class Outer(Configurable):
            left: LeftConfigured
            right: RightConfigured

            def __post_init__(self):
                run_auto_creation(self)

        outer = Outer(**get_default_args(Outer))
        self.assertEqual(outer.left.left, 1)
        self.assertEqual(outer.right.right, 2)
        with self.assertRaisesRegex(TypeError, "non-default argument"):
            dataclass(NotADataclass)

    def test_unprocessed(self):
        """Instantiating an unprocessed class processes it on the fly."""

        class UnprocessedConfigurable(Configurable):
            a: int = 9

        class UnprocessedReplaceable(ReplaceableBase):
            a: int = 9

        for Unprocessed in [UnprocessedConfigurable, UnprocessedReplaceable]:
            self.assertFalse(_is_actually_dataclass(Unprocessed))
            unprocessed = Unprocessed()
            self.assertTrue(_is_actually_dataclass(Unprocessed))
            self.assertIsInstance(unprocessed, Unprocessed)
            self.assertEqual(unprocessed.a, 9)

    def test_enum(self):
        """Enum-typed defaults stay enums and reject non-member names."""

        class A(Enum):
            B1 = "b1"
            B2 = "b2"

        # The same enum default is exposed through a Configurable, a
        # function and a plain class.
        class C(Configurable):
            a: A = A.B1

        def C_fn(a: A = A.B1):
            pass

        enable_get_default_args(C_fn)

        class C_cl:
            def __init__(self, a: A = A.B1) -> None:
                pass

        enable_get_default_args(C_cl)

        for C_ in [C, C_fn, C_cl]:
            base = get_default_args(C_)
            self.assertEqual(OmegaConf.to_yaml(base), "a: B1\n")
            self.assertEqual(base.a, A.B1)
            replaced = OmegaConf.merge(base, {"a": "B2"})
            self.assertEqual(replaced.a, A.B2)
            with self.assertRaises(ValidationError):
                # Merging the enum *value* ("b2") instead of the member name
                # ("B2") is rejected.
                OmegaConf.merge(base, {"a": "b2"})

            # A YAML round trip merges back cleanly.
            remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
            self.assertEqual(remerged.a, A.B1)

    def test_pickle(self):
        """Default args generated for a local function can be pickled."""

        def func(a: int = 1, b: str = "3"):
            pass

        enable_get_default_args(func)

        args = get_default_args(func)
        args2 = pickle.loads(pickle.dumps(args))
        self.assertEqual(args2.a, 1)
        self.assertEqual(args2.b, "3")

        # Pickling must keep working after the defaults are generated again.
        args_regenerated = get_default_args(func)
        pickle.dumps(args_regenerated)
        pickle.dumps(args)

    def test_remove_unused_components(self):
        """Pruning keeps only the args of the selected class types."""
        struct = get_default_args(MainTest)
        struct.n_ids = 32
        struct.the_fruit_class_type = "Pear"
        struct.the_second_fruit_class_type = "Banana"
        remove_unused_components(struct)
        expected_keys = [
            "n_ids",
            "n_reps",
            "the_fruit_Pear_args",
            "the_fruit_class_type",
            "the_second_fruit_Banana_args",
            "the_second_fruit_class_type",
        ]
        # Nested keys use two-space YAML indentation so that the string can
        # be compared with the OmegaConf.to_yaml output below.
        expected_yaml = textwrap.dedent(
            """\
            n_ids: 32
            n_reps: 8
            the_fruit_class_type: Pear
            the_fruit_Pear_args:
              n_pips: 13
            the_second_fruit_class_type: Banana
            the_second_fruit_Banana_args:
              pips: ???
              spots: ???
              bananame: ???
            """
        )
        self.assertEqual(sorted(struct.keys()), expected_keys)

        # Compare as config objects first ...
        expected = OmegaConf.create(expected_yaml)
        self.assertEqual(struct, expected)

        # ... then as text, which also checks key order.
        self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)

        # The same pruning applies to a config extracted from an instance.
        main = MainTest(**struct)
        instance_data = OmegaConf.structured(main)
        remove_unused_components(instance_data)
        self.assertEqual(sorted(instance_data.keys()), expected_keys)
        self.assertEqual(instance_data, expected)

    def test_remove_unused_components_optional(self):
        """Args of a disabled optional Configurable member are pruned."""

        class MainTestWrapper(Configurable):
            mt: Optional[MainTest]
            mt_enabled: bool = False

        args = get_default_args(MainTestWrapper)
        self.assertEqual(list(args.keys()), ["mt_enabled", "mt_args"])
        remove_unused_components(args)
        self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")

    def test_get_instance_args(self):
        """Recover a config from an instance and rebuild from it."""
        mt1, mt2 = [
            MainTest(
                n_ids=0,
                n_reps=909,
                the_fruit_class_type="Pear",
                the_second_fruit_class_type="Pear",
                the_fruit_Pear_args=DictConfig({}),
                the_second_fruit_Pear_args={},
            )
            for _ in range(2)
        ]

        # Both routes to an instance's config must agree.
        cfg1 = OmegaConf.structured(mt1)
        cfg2 = get_default_args(mt2)
        self.assertEqual(cfg1, cfg2)
        self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0)
        self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0)

        from_cfg = MainTest(**cfg2)
        self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0)

        # Merging with the class defaults fills the empty args back in.
        merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2)
        from_merged = MainTest(**merged_args)
        self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1)
        self.assertEqual(from_merged.n_reps, 909)

    def test_tweak_hook(self):
        """<member>_tweak_args hooks edit the defaults of that member only."""

        class A(Configurable):
            n: int = 9

        class Wrapper(Configurable):
            fruit: Fruit
            fruit_class_type: str = "Pear"
            fruit2: Fruit
            fruit2_class_type: str = "Pear"
            a: A
            a2: A
            a3: A

            @classmethod
            def a_tweak_args(cls, type, args):
                assert type == A
                args.n = 993

            @classmethod
            def a3_tweak_args(cls, type, args):
                del args["n"]

            @classmethod
            def fruit_tweak_args(cls, type, args):
                assert issubclass(type, Fruit)
                if type == Pear:
                    assert args.n_pips == 13
                    args.n_pips = 19

        args = get_default_args(Wrapper)
        self.assertEqual(args.a_args.n, 993)
        # a2 and fruit2 have no hook and keep their class defaults.
        self.assertEqual(args.a2_args.n, 9)
        self.assertEqual(args.a3_args, {})
        self.assertEqual(args.fruit_Pear_args.n_pips, 19)
        self.assertEqual(args.fruit2_Pear_args.n_pips, 13)

    def test_impls(self):
        """Members are created through overridable create_<member>_impl hooks.

        The loop runs once with the generated implementations and once with
        a fake that only records its `control` argument.
        """
        control_args = []

        def fake_impl(self, control, args):
            control_args.append(control)

        for fake in [False, True]:

            class MyClass(Configurable):
                fruit: Fruit
                fruit_class_type: str = "Orange"
                fruit_o: Optional[Fruit]
                fruit_o_class_type: str = "Orange"
                fruit_0: Optional[Fruit]
                fruit_0_class_type: Optional[str] = None
                boring: BoringConfigurable
                boring_o: Optional[BoringConfigurable]
                boring_o_enabled: bool = True
                boring_0: Optional[BoringConfigurable]
                boring_0_enabled: bool = False

                def __post_init__(self):
                    run_auto_creation(self)

            if fake:
                MyClass.create_fruit_impl = fake_impl
                MyClass.create_fruit_o_impl = fake_impl
                MyClass.create_boring_impl = fake_impl
                MyClass.create_boring_o_impl = fake_impl

            expand_args_fields(MyClass)
            instance = MyClass()
            for name in ["fruit", "fruit_o", "boring", "boring_o"]:
                # The fake never assigns the member, so it only exists when
                # the generated implementation ran.
                self.assertEqual(
                    hasattr(instance, name), not fake, msg=f"{name} {fake}"
                )

            self.assertIsNone(instance.fruit_0)
            self.assertIsNone(instance.boring_0)
            if not fake:
                self.assertIsInstance(instance.fruit, Orange)
                self.assertIsInstance(instance.fruit_o, Orange)
                self.assertIsInstance(instance.boring, BoringConfigurable)
                self.assertIsInstance(instance.boring_o, BoringConfigurable)

        self.assertEqual(control_args, ["Orange", "Orange", True, True])

    def test_pre_expand(self):
        """expand_args_fields calls the class's pre_expand hook."""

        class A(Configurable):
            n: int = 9

            @classmethod
            def pre_expand(cls):
                pass

        A.pre_expand = Mock()
        expand_args_fields(A)
        A.pre_expand.assert_called()

    def test_pre_expand_replaceable(self):
        """Expanding a subclass calls pre_expand defined on its base."""

        class A(ReplaceableBase):
            pass

            @classmethod
            def pre_expand(cls):
                pass

        class A1(A):
            # The original annotation was the bare value `9`, which is not a
            # type; an int field defaulting to 9 matches test_pre_expand.
            n: int = 9

        A.pre_expand = Mock()
        expand_args_fields(A1)
        A.pre_expand.assert_called()
|
|
|
|
|
@dataclass(eq=False)
class MockDataclass:
    """Plain dataclass fixture covering several kinds of field default."""

    # The only field without a default.
    field_no_default: int
    field_primitive_type: int = 42
    field_optional_none: Optional[int] = None
    field_optional_dict_none: Optional[Dict] = None
    field_optional_with_value: Optional[int] = 42
    # A factory gives every instance its own empty list.
    field_list_type: List[int] = field(default_factory=list)
|
|
|
|
|
class RefObject:
    """Opaque object type used as a non-primitive default value."""


# Single shared instance, used as a default argument below.
REF_OBJECT = RefObject()
|
|
|
|
|
class MockClassWithInit:
    """Plain class fixture whose __init__ mirrors the MockDataclass fields.

    It adds an unannotated parameter and a parameter defaulting to an
    arbitrary object.
    """

    def __init__(
        self,
        field_no_nothing,
        field_no_default: int,
        field_primitive_type: int = 42,
        field_optional_none: Optional[int] = None,
        field_optional_dict_none: Optional[Dict] = None,
        field_optional_with_value: Optional[int] = 42,
        # NOTE(review): the mutable default looks deliberate - it is part of
        # the signature under test, and test_get_default_args_readonly
        # checks that editing the generated defaults leaves it untouched.
        field_list_type: List[int] = [],
        field_reference_type: RefObject = REF_OBJECT,
    ):
        """Store every argument on the instance under the same name."""
        # Required arguments.
        self.field_no_nothing = field_no_nothing
        self.field_no_default = field_no_default

        # Defaulted arguments.
        self.field_primitive_type = field_primitive_type
        self.field_optional_none = field_optional_none
        self.field_optional_dict_none = field_optional_dict_none
        self.field_optional_with_value = field_optional_with_value
        self.field_list_type = field_list_type
        self.field_reference_type = field_reference_type


# Opt the plain class in, so the tests below can call get_default_args on it.
enable_get_default_args(MockClassWithInit)
|
|
|
|
|
class TestRawClasses(unittest.TestCase):
    """Tests of get_default_args on a plain dataclass and a plain class."""

    def setUp(self) -> None:
        """Build one reference instance of each fixture class."""
        dataclass_instance = MockDataclass(field_no_default=0)
        plain_instance = MockClassWithInit(
            field_no_nothing="tratata", field_no_default=0
        )
        self._instances = {
            MockDataclass: dataclass_instance,
            MockClassWithInit: plain_instance,
        }

    def test_get_default_args(self):
        """The generated defaults match the values a real instance holds."""
        defaulted_names = [
            "field_primitive_type",
            "field_optional_none",
            "field_optional_dict_none",
            "field_optional_with_value",
            "field_list_type",
        ]
        for cls in [MockDataclass, MockClassWithInit]:
            instance = self._instances[cls]
            defaults = get_default_args(cls)

            # None of these names may be reported as present.
            for absent in (
                "field_no_default",
                "field_no_nothing",
                "field_reference_type",
            ):
                self.assertNotIn(absent, defaults)

            expected_names = list(defaulted_names)
            if cls is MockDataclass:
                # For the dataclass the undefaulted field is expected to be
                # listed first once it is given a value.
                defaults.field_no_default = 0
                expected_names = ["field_no_default"] + expected_names
            self.assertEqual(list(defaults), expected_names)

            for name, value in defaults.items():
                self.assertTrue(hasattr(instance, name))
                self.assertEqual(value, getattr(instance, name))

    def test_get_default_args_readonly(self):
        """Editing the generated defaults must not touch the real defaults."""
        for cls in [MockDataclass, MockClassWithInit]:
            defaults = get_default_args(cls)
            defaults["field_list_type"].append(13)
            self.assertEqual(self._instances[cls].field_list_type, [])
|
|