|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Global Registry for big_vision pp ops. |
|
|
|
Author: Joan Puigcerver (jpuigcerver@) |
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import ast |
|
import contextlib |
|
import functools |
|
|
|
|
|
def parse_name(string_to_parse):
    """Parses input to the registry's lookup function.

    Args:
      string_to_parse: can be either an arbitrary name or function call
        (optionally with positional and keyword arguments).
        e.g. "multiclass", "resnet50_v2(filters_factor=8)".

    Returns:
      A tuple of input name, argument tuple and a keyword argument dictionary.
      The name is rebuilt from the parsed expression, so whitespace that is
      not part of the name (trailing blanks, blanks around a dot) is dropped
      for plain names exactly as it is for calls.
      Examples:
        "multiclass" -> ("multiclass", (), {})
        "resnet50_v2(9, filters_factor=4)" ->
            ("resnet50_v2", (9,), {"filters_factor": 4})
        "ops. resize " -> ("ops.resize", (), {})

    Raises:
      SyntaxError: if the string is not a valid Python expression (raised by
        `ast.parse`; deliberately not converted to another type).
      ValueError: if the expression is not a name, an attribute access or a
        call, or if the called object is not a plain dotted name. Arguments
        that are not literals are rejected by `ast.literal_eval` (typically
        also with a ValueError).

    Author: Joan Puigcerver (jpuigcerver@)
    """
    expr = ast.parse(string_to_parse, mode="eval").body
    if not isinstance(expr, (ast.Attribute, ast.Call, ast.Name)):
        raise ValueError(
            "The given string should be a name or a call, but a {} was parsed from "
            "the string {!r}".format(type(expr), string_to_parse))

    # How the accepted inputs parse:
    #   "some_name"           -> ast.Name
    #   "module.some_name"    -> ast.Attribute
    #   "some_name()"         -> ast.Call
    #   "module.some_name()"  -> ast.Call

    def _get_func_name(node):
        # Rebuilds "a.b.c" from nested Attribute/Name nodes.
        if isinstance(node, ast.Attribute):
            return _get_func_name(node.value) + "." + node.attr
        if isinstance(node, ast.Name):
            return node.id
        raise ValueError(
            "Type {!r} is not supported in a function name, the string to parse "
            "was {!r}".format(type(node), string_to_parse))

    if not isinstance(expr, ast.Call):
        # Plain (possibly dotted) name: normalise it the same way call names are
        # normalised, so "ops.resize " and "ops.resize()" agree on the name.
        try:
            return _get_func_name(expr), (), {}
        except ValueError:
            # Attribute access on something that is not a dotted name (e.g.
            # "f().g"). Such strings used to be returned untouched; keep that.
            return string_to_parse, (), {}

    func_name = _get_func_name(expr.func)
    # Only literals are accepted as arguments; nothing from the string is
    # ever executed.
    func_args = tuple(ast.literal_eval(arg) for arg in expr.args)
    func_kwargs = {
        kwarg.arg: ast.literal_eval(kwarg.value) for kwarg in expr.keywords
    }
    return func_name, func_args, func_kwargs
|
|
|
|
|
class Registry(object):
    """Process-wide mapping from names to registered items.

    All state lives in one class-level dict, so the class is used through its
    static methods and is never instantiated by the code here.

    Authors: Joan Puigcerver (jpuigcerver@), Alexander Kolesnikov (akolesnikov@)
    """

    # The single dict shared by every user of the registry.
    _GLOBAL_REGISTRY = {}

    @staticmethod
    def global_registry():
        """Returns the dict backing the registry (the live object, not a copy)."""
        return Registry._GLOBAL_REGISTRY

    @staticmethod
    def register(name, replace=False):
        """Builds a decorator that stores its argument under `name`.

        Args:
          name: key to store the decorated item under.
          replace: when true, an existing entry with the same name is
            overwritten instead of being treated as an error.

        Returns:
          A function that registers the item it receives and returns that item
          unchanged. It raises KeyError when `name` is already taken and
          `replace` is false.
        """

        def _register(item):
            registry = Registry.global_registry()
            if not replace and name in registry:
                raise KeyError("The name {!r} was already registered.".format(name))
            registry[name] = item
            return item

        return _register

    @staticmethod
    def lookup(lookup_string, kwargs_extra=None):
        """Resolves a name or call string to a partially applied registry item.

        Args:
          lookup_string: a name or a call expression, as understood by
            `parse_name`.
          kwargs_extra: optional mapping merged over the keyword arguments
            parsed from `lookup_string`.

        Returns:
          A `functools.partial` of the registered item, bound to the parsed
          positional and keyword arguments.

        Raises:
          ValueError: if `parse_name` raises a ValueError for the string.
          KeyError: if the parsed name is not registered.
        """
        try:
            op_name, positional, keywords = parse_name(lookup_string)
        except ValueError as err:
            raise ValueError(f"Error parsing:\n{lookup_string}") from err

        if kwargs_extra:
            keywords.update(kwargs_extra)

        target = Registry.global_registry()[op_name]
        return functools.partial(target, *positional, **keywords)

    @staticmethod
    def knows(lookup_string):
        """Tells whether the name in `lookup_string` is registered.

        Only the name part of the string is checked; any arguments are parsed
        and then ignored. Raises ValueError if `parse_name` raises a ValueError
        for the string.
        """
        try:
            op_name = parse_name(lookup_string)[0]
        except ValueError as err:
            raise ValueError(f"Error parsing:\n{lookup_string}") from err
        return op_name in Registry.global_registry()
|
|
|
|
|
@contextlib.contextmanager
def temporary_ops(**kw):
    """Registers specified pp ops for use in a `with` block.

    Example use:

      with pp_registry.temporary_ops(
          pow=lambda alpha: lambda d: {k: v**alpha for k, v in d.items()}):
        pp = pp_builder.get_preprocess_fn("pow(alpha=2.0)|pow(alpha=0.5)")
        features = pp(features)

    Args:
      **kw: Names are preprocess string function names to be used to specify the
        preprocess function. Values are functions that can be called with params
        (e.g. the `alpha` param in above example) and return functions to be used
        to transform features. Each name is stored in the registry with a
        "preprocess_ops." prefix.

    Yields:
      Nothing. The ops are registered while the `with` body runs and are
      removed again when it exits, whether normally or through an exception.

    Raises:
      AssertionError: if any of the prefixed names is already registered. In
        that case the registry is left unmodified.
    """
    reg = Registry.global_registry()
    kw = {f"preprocess_ops.{k}": v for k, v in kw.items()}

    # Raised explicitly rather than via `assert`: under `python -O` an assert
    # is stripped, and the cleanup below would then delete a pre-existing op
    # that was silently overwritten.
    clashes = sorted(k for k in kw if k in reg)
    if clashes:
        raise AssertionError(
            f"Cannot temporarily register already registered ops: {clashes}")

    reg.update(kw)
    try:
        yield
    finally:
        for k in kw:
            # pop() rather than del: if the `with` body already removed an op,
            # a KeyError here would mask the exception being propagated.
            reg.pop(k, None)
|
|