|
from collections import defaultdict |
|
from typing import Any, Dict, List, Optional, Set, Tuple |
|
|
|
import networkx as nx |
|
from networkx import DiGraph |
|
|
|
from inference.enterprise.workflows.complier.utils import ( |
|
construct_input_selector, |
|
construct_output_name, |
|
construct_step_selector, |
|
get_nodes_of_specific_kind, |
|
get_step_input_selectors, |
|
get_step_selector_from_its_output, |
|
is_condition_step, |
|
is_step_output_selector, |
|
) |
|
from inference.enterprise.workflows.constants import ( |
|
INPUT_NODE_KIND, |
|
OUTPUT_NODE_KIND, |
|
STEP_NODE_KIND, |
|
) |
|
from inference.enterprise.workflows.entities.outputs import JsonField |
|
from inference.enterprise.workflows.entities.steps import StepInterface |
|
from inference.enterprise.workflows.entities.validators import is_selector |
|
from inference.enterprise.workflows.entities.workflows_specification import ( |
|
InputType, |
|
StepType, |
|
WorkflowSpecificationV1, |
|
) |
|
from inference.enterprise.workflows.errors import ( |
|
AmbiguousPathDetected, |
|
NodesNotReachingOutputError, |
|
NotAcyclicGraphError, |
|
SelectorToUndefinedNodeError, |
|
) |
|
|
|
|
|
def prepare_execution_graph(workflow_specification: WorkflowSpecificationV1) -> DiGraph:
    """
    Build the execution graph for a workflow specification and validate it.

    The graph is constructed first and must be acyclic. Afterwards the remaining
    verifications run in a fixed order: every node must reach an output, steps with
    several parent steps must not depend on conflicting branches, and connected
    steps must have compatible inputs.

    Returns:
        The constructed and verified execution graph.

    Raises:
        NotAcyclicGraphError: when the constructed graph contains a cycle.
        Errors raised by the graph construction and verification helpers
        called below are propagated unchanged.
    """
    execution_graph = construct_graph(workflow_specification=workflow_specification)
    if not nx.is_directed_acyclic_graph(execution_graph):
        # Plain literal: the message has no placeholders, so no f-string is needed.
        raise NotAcyclicGraphError("Detected cycle in execution graph.")
    verify_each_node_reach_at_least_one_output(execution_graph=execution_graph)
    verify_each_node_step_has_parent_in_the_same_branch(execution_graph=execution_graph)
    verify_that_steps_are_connected_with_compatible_inputs(
        execution_graph=execution_graph
    )
    return execution_graph
|
|
|
|
|
def construct_graph(workflow_specification: WorkflowSpecificationV1) -> DiGraph:
    """
    Assemble the directed graph described by the workflow specification.

    Nodes are registered first (inputs, then steps, then outputs) so that the
    edge-building stages can check that both ends of each edge already exist.
    """
    graph = nx.DiGraph()
    # Stage 1: nodes.
    graph = add_input_nodes_for_graph(
        inputs=workflow_specification.inputs,
        execution_graph=graph,
    )
    graph = add_steps_nodes_for_graph(
        steps=workflow_specification.steps,
        execution_graph=graph,
    )
    graph = add_output_nodes_for_graph(
        outputs=workflow_specification.outputs,
        execution_graph=graph,
    )
    # Stage 2: edges between the registered nodes.
    graph = add_steps_edges(
        workflow_specification=workflow_specification,
        execution_graph=graph,
    )
    graph = add_edges_for_outputs(
        workflow_specification=workflow_specification,
        execution_graph=graph,
    )
    return graph
|
|
|
|
|
def add_input_nodes_for_graph(
    inputs: List[InputType],
    execution_graph: DiGraph,
) -> DiGraph:
    """
    Register one node of kind INPUT_NODE_KIND per workflow input.

    Each node is keyed by the selector built from the input name and carries the
    input specification under the `definition` attribute. The graph is modified
    in place and returned.
    """
    execution_graph.add_nodes_from(
        (
            construct_input_selector(input_name=spec.name),
            {"kind": INPUT_NODE_KIND, "definition": spec},
        )
        for spec in inputs
    )
    return execution_graph
|
|
|
|
|
def add_steps_nodes_for_graph(
    steps: List[StepType],
    execution_graph: DiGraph,
) -> DiGraph:
    """
    Register one node of kind STEP_NODE_KIND per workflow step.

    Each node is keyed by the selector built from the step name and carries the
    step definition under the `definition` attribute. The graph is modified in
    place and returned.
    """
    execution_graph.add_nodes_from(
        (
            construct_step_selector(step_name=step_definition.name),
            {"kind": STEP_NODE_KIND, "definition": step_definition},
        )
        for step_definition in steps
    )
    return execution_graph
|
|
|
|
|
def add_output_nodes_for_graph(
    outputs: List[JsonField],
    execution_graph: DiGraph,
) -> DiGraph:
    """
    Register one node of kind OUTPUT_NODE_KIND per declared workflow output.

    Each node is keyed by the name built from the output's name and carries the
    output specification under the `definition` attribute. The graph is modified
    in place and returned.
    """
    execution_graph.add_nodes_from(
        (
            construct_output_name(name=output_definition.name),
            {"kind": OUTPUT_NODE_KIND, "definition": output_definition},
        )
        for output_definition in outputs
    )
    return execution_graph
|
|
|
|
|
def add_steps_edges(
    workflow_specification: WorkflowSpecificationV1,
    execution_graph: DiGraph,
) -> DiGraph:
    """
    Connect every step with the nodes it reads from and, for `Condition` steps,
    with both steps it may hand control to.

    Each edge is verified to link already registered nodes before it is added;
    the verification raises SelectorToUndefinedNodeError otherwise.
    """
    for step in workflow_specification.steps:
        selectors_used_by_step = get_step_input_selectors(step=step)
        step_selector = construct_step_selector(step_name=step.name)
        execution_graph = add_edges_for_step_inputs(
            execution_graph=execution_graph,
            input_selectors=selectors_used_by_step,
            step_selector=step_selector,
        )
        if step.type != "Condition":
            continue
        # A condition step points at its "true" branch first, then the "false" one.
        for branch_attribute in ("step_if_true", "step_if_false"):
            branch_target = getattr(step, branch_attribute)
            verify_edge_is_created_between_existing_nodes(
                execution_graph=execution_graph,
                start=step_selector,
                end=branch_target,
            )
            execution_graph.add_edge(step_selector, branch_target)
    return execution_graph
|
|
|
|
|
def add_edges_for_step_inputs(
    execution_graph: DiGraph,
    input_selectors: Set[str],
    step_selector: str,
) -> DiGraph:
    """
    Add an edge into `step_selector` from every node referenced by its inputs.

    A selector pointing at a step output is first reduced to the selector of the
    step that produces it, so the edge starts at that step's node. Each edge is
    verified to link existing nodes before being added.
    """
    for selector in input_selectors:
        if is_step_output_selector(selector_or_value=selector):
            source_node = get_step_selector_from_its_output(
                step_output_selector=selector
            )
        else:
            source_node = selector
        verify_edge_is_created_between_existing_nodes(
            execution_graph=execution_graph,
            start=source_node,
            end=step_selector,
        )
        execution_graph.add_edge(source_node, step_selector)
    return execution_graph
|
|
|
|
|
def add_edges_for_outputs(
    workflow_specification: WorkflowSpecificationV1,
    execution_graph: DiGraph,
) -> DiGraph:
    """
    Connect every declared workflow output with the node that feeds it.

    When the output's selector refers to a step output, the edge starts at the
    node of the step producing it; otherwise the selector itself is used as the
    source node. Each edge is verified to link existing nodes before being added.
    """
    for output_definition in workflow_specification.outputs:
        source_node = output_definition.selector
        if is_step_output_selector(selector_or_value=source_node):
            source_node = get_step_selector_from_its_output(
                step_output_selector=source_node
            )
        target_node = construct_output_name(name=output_definition.name)
        verify_edge_is_created_between_existing_nodes(
            execution_graph=execution_graph,
            start=source_node,
            end=target_node,
        )
        execution_graph.add_edge(source_node, target_node)
    return execution_graph
|
|
|
|
|
def verify_edge_is_created_between_existing_nodes(
    execution_graph: DiGraph,
    start: str,
    end: str,
) -> None:
    """
    Ensure both ends of an edge about to be created are nodes of the graph.

    Raises:
        SelectorToUndefinedNodeError: when `start` (checked first) or `end`
            is not a node of `execution_graph`.
    """
    for selector in (start, end):
        if execution_graph.has_node(selector):
            continue
        raise SelectorToUndefinedNodeError(
            f"Graph definition contains selector {selector} that points to not defined element."
        )
|
|
|
|
|
def verify_each_node_reach_at_least_one_output(
    execution_graph: DiGraph,
) -> None:
    """
    Ensure the graph has no dangling nodes.

    A node is accepted when a path leads from it to an output node or to a step
    that declares no outputs (such steps are treated as terminal nodes here);
    the terminal nodes themselves are accepted as well.

    Raises:
        NodesNotReachingOutputError: when at least one node fails the check.
    """
    terminal_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=OUTPUT_NODE_KIND,
    ).union(get_nodes_that_do_not_produce_outputs(execution_graph=execution_graph))
    nodes_reaching_terminal = (
        get_nodes_that_are_reachable_from_pointed_ones_in_reversed_graph(
            execution_graph=execution_graph,
            pointed_nodes=terminal_nodes,
        )
    )
    nodes_not_reaching_output = set(execution_graph.nodes()) - nodes_reaching_terminal
    if not nodes_not_reaching_output:
        return None
    raise NodesNotReachingOutputError(
        f"Detected {len(nodes_not_reaching_output)} nodes not reaching any of output node:"
        f"{nodes_not_reaching_output}."
    )
|
|
|
|
|
def get_nodes_that_do_not_produce_outputs(execution_graph: DiGraph) -> Set[str]:
    """
    Return the step nodes whose definition reports no output names.

    NOTE(review): presumably such steps exist for their side effects, which would
    explain why the caller treats them like output nodes - confirm with the author.
    """
    nodes_without_outputs = set()
    for step_node in get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=STEP_NODE_KIND,
    ):
        step_definition = execution_graph.nodes[step_node]["definition"]
        if len(step_definition.get_output_names()) == 0:
            nodes_without_outputs.add(step_node)
    return nodes_without_outputs
|
|
|
|
|
def get_nodes_that_are_reachable_from_pointed_ones_in_reversed_graph(
    execution_graph: DiGraph,
    pointed_nodes: Set[str],
) -> Set[str]:
    """
    Collect the pointed nodes together with every node that can reach one of them.

    The traversal walks the edges backwards: starting at each pointed node it
    follows the reversed graph, so the result holds the pointed nodes plus all
    nodes from which a pointed node is reachable in the original graph.

    Args:
        execution_graph: graph to analyse; it is not modified.
        pointed_nodes: nodes of `execution_graph` to start the traversals from.

    Returns:
        Set of visited nodes (empty when `pointed_nodes` is empty).
    """
    result = set()
    # A read-only reversed view is enough for traversal, so there is no need to
    # build a full reversed copy of the graph (which `copy=True` would do).
    reversed_graph = execution_graph.reverse(copy=False)
    for pointed_node in pointed_nodes:
        if pointed_node in result:
            # Already visited by a traversal started elsewhere, so everything
            # reachable from this node has been collected as well.
            continue
        result.update(nx.dfs_postorder_nodes(reversed_graph, pointed_node))
    return result
|
|
|
|
|
def verify_each_node_step_has_parent_in_the_same_branch(
    execution_graph: DiGraph,
) -> None:
    r"""
    Reject graphs in which conditional branching makes the execution path ambiguous.

    Conditional branching creates a bit of a mess when it comes to determining
    which steps to execute. Let's imagine the graph:

                   / -> B -> C -> D
        A -> IF <             \
                   \ -> E -> F -> G -> H

    where node G requires node C even though IF branched the execution. In other
    words - problem (1) emerges if a node of kind STEP has a parent (a node from
    which it can be reached) of kind STEP and this parent is in a different branch.
    Note that a single step is allowed to take multiple steps as input, but they
    must be on the same execution path - for instance D requiring outputs from
    both C and B is allowed.
    Additionally, we must prevent the situation where outcomes of branches started
    by two or more condition steps merge with each other, as condition evaluation
    may result in contradictory execution - problem (2).

    We need to detect those situations upfront, so that we can raise an error about
    an ambiguous execution path rather than run time-consuming computations that
    will end up in an error.

    To detect problem (1), we first find steps with more than one parent step.
    For each such problematic step we take its parent steps and analyse the paths
    from those parents in reversed topological order (from the parents towards the
    entry nodes of the execution graph). During the analysis, on each path we note
    the `Condition` steps and the result of condition evaluation that must have
    been observed at runtime to reach the problematic step when the graph is
    executed in the normal direction. If for any `Condition` step both True and
    False would be needed (more than one registered next step of that `Condition`
    step) - we raise an error.
    To detect problem (2), we only let the number of different condition steps
    considered be at most the maximum number of condition steps found in a single
    path from origin to a parent of the problematic step.

    Beware that the latter part of the algorithm has quite bad time complexity in
    the general case. The worst part runs at O(V^4) - at least taking coarse,
    worst-case estimations. In practice it is not that bad:
    * The number of step nodes with multiple parents that we loop over in the main
      loop reduces the number of steps we iterate through in the inner loops, as we
      are dealing with a DAG (with a quite limited amount of edges) and each
      multi-parent node takes at least two other nodes (to construct a suspicious
      group) - so the expected number of iterations of the main loop is low, let's
      say 1-3 for a real graph.
    * For any reasonable execution graph, the complexity should be acceptable.

    Raises:
        AmbiguousPathDetected: propagated from the per-step verification when an
            ambiguous execution path is found.
    """
    steps_with_more_than_one_parent = detect_steps_with_more_than_one_parent_step(
        execution_graph=execution_graph
    )
    if not steps_with_more_than_one_parent:
        # No step joins several step parents, so there is nothing to verify.
        return None
    reversed_steps_graph = construct_reversed_steps_graph(
        execution_graph=execution_graph
    )
    # Computed once and shared by the verification of every multi-parent step.
    reversed_topological_order = list(nx.topological_sort(reversed_steps_graph))
    for step in steps_with_more_than_one_parent:
        verify_multi_parent_step_execution_paths(
            reversed_steps_graph=reversed_steps_graph,
            reversed_topological_order=reversed_topological_order,
            step=step,
        )
|
|
|
|
|
def detect_steps_with_more_than_one_parent_step(execution_graph: DiGraph) -> Set[str]:
    """
    Return the step nodes that have two or more distinct step nodes as direct parents.

    Only edges whose both ends are of kind STEP_NODE_KIND are taken into account.
    """
    step_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=STEP_NODE_KIND,
    )
    parents_by_step = defaultdict(set)
    for parent, child in execution_graph.edges():
        if parent in step_nodes and child in step_nodes:
            parents_by_step[child].add(parent)
    return {
        step
        for step, step_parents in parents_by_step.items()
        if len(step_parents) > 1
    }
|
|
|
|
|
def construct_reversed_steps_graph(execution_graph: DiGraph) -> DiGraph:
    """
    Return a reversed copy of the graph that keeps only nodes of kind STEP_NODE_KIND.

    The input graph is left untouched; nodes are removed from the reversed copy.
    """
    reversed_graph = execution_graph.reverse()
    nodes_to_drop = [
        node
        for node, attributes in reversed_graph.nodes(data=True)
        if attributes.get("kind") != STEP_NODE_KIND
    ]
    reversed_graph.remove_nodes_from(nodes_to_drop)
    return reversed_graph
|
|
|
|
|
def verify_multi_parent_step_execution_paths(
    reversed_steps_graph: nx.DiGraph,
    reversed_topological_order: List[str],
    step: str,
) -> None:
    """
    Check that all parents of a multi-parent step can be executed together.

    For each parent of `step` (its successors in the reversed graph) the reversed
    path is built, and the condition steps found on it are recorded together with
    the node that directly precedes each of them on that path.

    Args:
        reversed_steps_graph: reversed graph containing step nodes only.
        reversed_topological_order: topological order of `reversed_steps_graph`.
        step: step node that has more than one parent step.

    Raises:
        AmbiguousPathDetected: when more distinct condition steps were recorded in
            total than on the single path with the most condition steps, or when
            some condition step was recorded with more than one next step.
    """
    condition_steps_successors = defaultdict(set)
    max_conditions_steps = 0
    for normal_flow_predecessor in reversed_steps_graph.successors(step):
        reversed_flow_path = construct_reversed_path_to_multi_parent_step_parent(
            reversed_steps_graph=reversed_steps_graph,
            reversed_topological_order=reversed_topological_order,
            parent_step=normal_flow_predecessor,
            step=step,
        )
        (
            condition_steps_successors,
            condition_steps,
        ) = denote_condition_steps_successors_in_normal_flow(
            reversed_steps_graph=reversed_steps_graph,
            reversed_flow_path=reversed_flow_path,
            condition_steps_successors=condition_steps_successors,
        )
        max_conditions_steps = max(condition_steps, max_conditions_steps)
    if len(condition_steps_successors) > max_conditions_steps:
        # Plain literal: the message has no placeholders, so no f-string is needed.
        raise AmbiguousPathDetected(
            "In execution graph, detected collision of branches that originate in different condition steps."
        )
    for condition_step, potential_next_steps in condition_steps_successors.items():
        if len(potential_next_steps) > 1:
            raise AmbiguousPathDetected(
                f"In execution graph, condition step: {condition_step} creates ambiguous execution paths."
            )
|
|
|
|
|
def construct_reversed_path_to_multi_parent_step_parent(
    reversed_steps_graph: nx.DiGraph,
    reversed_topological_order: List[str],
    parent_step: str,
    step: str,
) -> List[str]:
    """
    List, in reversed topological order, the nodes relevant for one parent of `step`.

    The selection consists of `step`, `parent_step` and every node reachable from
    `parent_step` in the reversed graph.
    """
    selected_nodes = nx.descendants(reversed_steps_graph, parent_step) | {
        parent_step,
        step,
    }
    return [node for node in reversed_topological_order if node in selected_nodes]
|
|
|
|
|
def denote_condition_steps_successors_in_normal_flow(
    reversed_steps_graph: nx.DiGraph,
    reversed_flow_path: List[str],
    condition_steps_successors: Dict[str, Set[str]],
) -> Tuple[Dict[str, Set[str]], int]:
    """
    Record, for every condition step on the path, the node listed right before it.

    The path is walked pairwise; whenever the later element of a pair is a
    condition step, the earlier element is added to that step's entry in
    `condition_steps_successors` (the mapping is updated in place).

    Returns:
        The updated mapping and the number of condition steps met on this path.
    """
    condition_steps_on_path = 0
    consecutive_pairs = zip(reversed_flow_path, reversed_flow_path[1:])
    for node_listed_before, node in consecutive_pairs:
        if not is_condition_step(execution_graph=reversed_steps_graph, node=node):
            continue
        condition_steps_successors[node].add(node_listed_before)
        condition_steps_on_path += 1
    return condition_steps_successors, condition_steps_on_path
|
|
|
|
|
def verify_that_steps_are_connected_with_compatible_inputs(
    execution_graph: nx.DiGraph,
) -> None:
    """
    Run the input-selector verification for every node of kind STEP_NODE_KIND.
    """
    for step_node in get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=STEP_NODE_KIND,
    ):
        verify_step_inputs_selectors(step=step_node, execution_graph=execution_graph)
|
|
|
|
|
def verify_step_inputs_selectors(step: str, execution_graph: nx.DiGraph) -> None:
    """
    Validate every input declared by the definition of the given step.

    Inputs holding a list are validated element by element, with the element's
    position passed as `index`; any other input is validated as a single value.

    Args:
        step: node of `execution_graph` whose `definition` attribute is checked.
        execution_graph: graph holding the step and the nodes its inputs refer to.
    """
    step_definition = execution_graph.nodes[step]["definition"]
    for input_name in step_definition.get_input_names():
        input_selector_or_value = getattr(step_definition, input_name)
        if not isinstance(input_selector_or_value, list):
            validate_step_definition_input(
                step_definition=step_definition,
                input_name=input_name,
                execution_graph=execution_graph,
                input_selector_or_value=input_selector_or_value,
            )
            continue
        for idx, single_selector_or_value in enumerate(input_selector_or_value):
            validate_step_definition_input(
                step_definition=step_definition,
                input_name=input_name,
                execution_graph=execution_graph,
                input_selector_or_value=single_selector_or_value,
                index=idx,
            )
|
|
|
|
|
def validate_step_definition_input(
    step_definition: StepInterface,
    input_name: str,
    execution_graph: nx.DiGraph,
    input_selector_or_value: Any,
    index: Optional[int] = None,
) -> None:
    """
    Validate a single input of a step against the node its selector refers to.

    Values that are not selectors are skipped. A selector of a step output is
    reduced to the selector of the producing step; the definition stored on the
    referenced node is then handed to `step_definition.validate_field_selector`
    together with the input name and the optional `index`.
    """
    if not is_selector(selector_or_value=input_selector_or_value):
        return None
    referenced_node = input_selector_or_value
    if is_step_output_selector(selector_or_value=referenced_node):
        referenced_node = get_step_selector_from_its_output(
            step_output_selector=referenced_node
        )
    referenced_definition = execution_graph.nodes[referenced_node]["definition"]
    step_definition.validate_field_selector(
        field_name=input_name,
        input_step=referenced_definition,
        index=index,
    )
|
|