Spaces:
Sleeping
Sleeping
File size: 2,265 Bytes
375a1cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
"""Utilities for union (sum type) disambiguation."""
from collections import OrderedDict
from functools import reduce
from operator import or_
from typing import ( # noqa: F401, imported for Mypy.
Callable,
Dict,
Mapping,
Optional,
Type,
)
from attr import fields, NOTHING
from cattr._compat import get_origin
def create_uniq_field_dis_func(*classes: Type) -> Callable:
    """Given attr classes, generate a disambiguation function.

    The function is based on unique fields. Classes are ordered by number of
    attributes (most first). Each class except the last is keyed on one
    attribute that:

    - has no default, and
    - is absent from every class ordered after it.

    The last class (fewest attributes) becomes the fallback, returned when
    no key is present in the input.

    :param classes: Two or more attrs classes (generic aliases are resolved
        to their origin class to read the fields).
    :return: A function taking a mapping and returning the matched class,
        or the fallback class when no key attribute is found.
    :raises ValueError: If fewer than two classes are given, if more than
        one class has no attributes, or if a class has no unique attribute
        without a default.
    """
    from collections.abc import Mapping as _AbcMapping

    if len(classes) < 2:
        raise ValueError("At least two classes required.")

    # (class, its attrs fields, set of its attribute names). Fields are
    # resolved once, via the origin when ``cl`` is a parametrized alias.
    cls_and_attrs = []
    for cl in classes:
        cl_fields = fields(get_origin(cl) or cl)
        cls_and_attrs.append((cl, cl_fields, {a.name for a in cl_fields}))

    if sum(1 for _, _, names in cls_and_attrs if not names) > 1:
        raise ValueError("At least two classes have no attributes.")
    # TODO: Deal with a single class having no required attrs.

    # Most attributes first. A class only needs to be distinguishable from
    # the classes after it, because dis_func checks earlier keys first.
    # list.sort is stable, so ties keep the argument order.
    cls_and_attrs.sort(key=lambda c_a: -len(c_a[2]))

    # For each class, a single unique required field -> that class.
    uniq_attrs_dict = OrderedDict()  # type: Dict[str, Type]
    fallback = None  # If none match, try this.

    for i, (cl, cl_fields, cl_names) in enumerate(cls_and_attrs):
        other_classes = cls_and_attrs[i + 1 :]
        if not other_classes:
            # Last (smallest) class: no key needed, it is the default answer.
            fallback = cl
            continue
        other_names = reduce(or_, (c_a[2] for c_a in other_classes))
        uniq = cl_names - other_names
        if not uniq:
            raise ValueError("{} has no usable unique attributes.".format(cl))
        # We need a unique attribute with no default. Scan the fields in
        # definition order rather than iterating the set, whose order
        # depends on string hash randomization and would make the chosen
        # key vary between interpreter runs.
        attr_name = next(
            (
                a.name
                for a in cl_fields
                if a.name in uniq and a.default is NOTHING
            ),
            None,
        )
        if attr_name is None:
            raise ValueError(f"{cl} has no usable non-default attributes.")
        uniq_attrs_dict[attr_name] = cl

    def dis_func(data):
        # type: (Mapping) -> Optional[Type]
        """Return the class whose key attribute is in ``data``, else the fallback.

        :raises ValueError: If ``data`` is not a mapping.
        """
        if not isinstance(data, _AbcMapping):
            raise ValueError("Only input mappings are supported.")
        for key, cl in uniq_attrs_dict.items():
            if key in data:
                return cl
        return fallback

    return dis_func
|