Spaces:
Runtime error
Runtime error
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides a `ConditionalAction` abstraction."""
from typing import Any, Callable, Sequence, Union | |
from orbit import controller | |
from orbit import runner | |
import tensorflow as tf, tf_keras | |
# Type of a condition: a callable that receives train or eval outputs and
# returns a truth value, either a Python `bool` or a `tf.Tensor`.
Condition = Callable[
    [runner.Output],
    Union[bool, tf.Tensor],
]
def _as_sequence(maybe_sequence: Union[Any, Sequence[Any]]) -> Sequence[Any]:
  """Normalizes a single value or a sequence into a sequence.

  Args:
    maybe_sequence: Either something that is already a `Sequence` (such as a
      list or tuple) or any single value.

  Returns:
    The argument itself (not a copy) when it is a `Sequence`. Otherwise, a new
    one-element list wrapping it. Note that `str` is a `Sequence`, so strings
    are returned unchanged.
  """
  already_sequence = isinstance(maybe_sequence, Sequence)
  return maybe_sequence if already_sequence else [maybe_sequence]
class ConditionalAction:
  """Runs one or more actions only when a condition on the outputs holds.

  A `ConditionalAction` is itself an `Action`: a callable that can be applied
  to train or eval outputs. It keeps "when" something happens (the condition)
  separate from "what" happens (the action). Conditions and actions can then
  each be written once and reused in combination.
  """

  def __init__(
      self,
      condition: Condition,
      action: Union[controller.Action, Sequence[controller.Action]],
  ):
    """Initializes the instance.

    Args:
      condition: A callable that accepts train or eval outputs and returns a
        truth value (a `bool` or a `tf.Tensor`). The result is tested with
        Python truthiness.
      action: A single action, or a sequence of actions, to call with the same
        outputs whenever `condition` is met. It is stored as given.
    """
    self.action = action
    self.condition = condition

  def __call__(self, output: runner.Output) -> None:
    """Calls the wrapped action(s) on `output` if the condition is met.

    Args:
      output: The train or eval outputs. They are passed first to `condition`.
        If the result is truthy, they are then passed to each wrapped action
        in order.
    """
    if not self.condition(output):
      return
    for act in _as_sequence(self.action):
      act(output)