|
import abc |
|
from collections import defaultdict |
|
from queue import Queue |
|
from typing import List, Optional, Set |
|
|
|
import networkx as nx |
|
|
|
from inference.enterprise.workflows.complier.utils import get_nodes_of_specific_kind |
|
from inference.enterprise.workflows.constants import STEP_NODE_KIND |
|
|
|
|
|
class StepExecutionCoordinator(metaclass=abc.ABCMeta):
    """Abstract interface for objects that decide which steps run next.

    A concrete coordinator is created from an execution graph through
    ``init`` and is then polled with ``get_steps_to_execute_next``.
    """

    @classmethod
    @abc.abstractmethod
    def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator":
        """Create a coordinator for ``execution_graph``."""

    @abc.abstractmethod
    def get_steps_to_execute_next(
        self, steps_to_discard: Set[str]
    ) -> Optional[List[str]]:
        """Return the names of the steps that should be executed next.

        Args:
            steps_to_discard: Names of steps that must not be handed out.

        Returns:
            A list of step names, or ``None``. The implementations in this
            module return ``None`` once no step is left.
        """
|
|
|
|
|
class SerialExecutionCoordinator(StepExecutionCoordinator):
    """Coordinator that hands out one step at a time.

    Steps are served in a topological order of the execution graph, which is
    computed lazily on the first call to ``get_steps_to_execute_next``.
    """

    @classmethod
    def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator":
        """Create a coordinator for ``execution_graph``."""
        return cls(execution_graph=execution_graph)

    def __init__(self, execution_graph: nx.DiGraph):
        # Work on a private copy of the graph passed by the caller.
        self._execution_graph = execution_graph.copy()
        # Every step name ever passed in ``steps_to_discard``.
        self._discarded_steps: Set[str] = set()
        # Filled in on first use by ``__establish_execution_order``.
        self.__order: Optional[List[str]] = None
        # Index in ``__order`` of the next step to look at.
        self.__step_pointer = 0

    def get_steps_to_execute_next(
        self, steps_to_discard: Set[str]
    ) -> Optional[List[str]]:
        """Return the next non-discarded step as a one-element list.

        Args:
            steps_to_discard: Step names to add to the set of discarded
                steps before the next step is chosen.

        Returns:
            ``[step_name]`` for the next step that has not been discarded,
            or ``None`` when the whole order has been consumed.
        """
        if self.__order is None:
            self.__establish_execution_order()
        self._discarded_steps.update(steps_to_discard)
        steps_total = len(self.__order)
        while self.__step_pointer < steps_total:
            step_name = self.__order[self.__step_pointer]
            # Advance first, so a skipped step is never looked at again.
            self.__step_pointer += 1
            if step_name not in self._discarded_steps:
                return [step_name]
        return None

    def __establish_execution_order(self) -> None:
        """Compute the serial order of step nodes and rewind the pointer."""
        step_nodes = get_nodes_of_specific_kind(
            execution_graph=self._execution_graph, kind=STEP_NODE_KIND
        )
        ordered_steps = []
        for node in nx.topological_sort(self._execution_graph):
            # Only step nodes are kept; other nodes of the graph are skipped.
            if node in step_nodes:
                ordered_steps.append(node)
        self.__order = ordered_steps
        self.__step_pointer = 0
|
|
|
|
|
class ParallelStepExecutionCoordinator(StepExecutionCoordinator):
    """Coordinator that hands out whole groups of steps at once.

    The groups come from ``establish_execution_order`` and are computed
    lazily on the first call to ``get_steps_to_execute_next``.
    """

    @classmethod
    def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator":
        """Create a coordinator for ``execution_graph``."""
        return cls(execution_graph=execution_graph)

    def __init__(self, execution_graph: nx.DiGraph):
        # Work on a private copy of the graph passed by the caller.
        self._execution_graph = execution_graph.copy()
        # Every step name ever passed in ``steps_to_discard``.
        self._discarded_steps: Set[str] = set()
        # List of step groups; filled in on first use.
        self.__execution_order: Optional[List[List[str]]] = None
        # Index in ``__execution_order`` of the next group to look at.
        self.__execution_pointer = 0

    def get_steps_to_execute_next(
        self, steps_to_discard: Set[str]
    ) -> Optional[List[str]]:
        """Return the next group of steps with discarded steps removed.

        Args:
            steps_to_discard: Step names to add to the set of discarded
                steps before the next group is chosen.

        Returns:
            The non-discarded steps of the next group that still has any,
            or ``None`` when all groups have been consumed.
        """
        if self.__execution_order is None:
            self.__execution_order = establish_execution_order(
                execution_graph=self._execution_graph
            )
            self.__execution_pointer = 0
        self._discarded_steps.update(steps_to_discard)
        while self.__execution_pointer < len(self.__execution_order):
            runnable_steps = get_next_steps_to_execute(
                execution_order=self.__execution_order,
                execution_pointer=self.__execution_pointer,
                discarded_steps=self._discarded_steps,
            )
            # Advance first, so a fully discarded group is never revisited.
            self.__execution_pointer += 1
            if runnable_steps:
                return runnable_steps
        return None
|
|
|
|
|
def establish_execution_order(
    execution_graph: nx.DiGraph,
) -> List[List[str]]:
    """Split the steps of ``execution_graph`` into ordered groups.

    The function builds a steps-only flow graph, annotates its nodes with
    their distance from the ``"start"`` node and groups the steps by that
    distance.

    Returns:
        Groups of step names, ordered by increasing distance.
    """
    flow_graph = construct_steps_flow_graph(execution_graph=execution_graph)
    annotated_graph = assign_max_distances_from_start(steps_flow_graph=flow_graph)
    return get_groups_execution_order(steps_flow_graph=annotated_graph)
|
|
|
|
|
def construct_steps_flow_graph(execution_graph: nx.DiGraph) -> nx.DiGraph:
    """Build a graph that contains only steps plus ``"start"`` and ``"end"``.

    Every edge of ``execution_graph`` that touches a step is copied; a
    neighbour that is not a step is replaced by the artificial ``"start"``
    node (for predecessors) or ``"end"`` node (for successors).

    A step without any predecessor is linked directly to ``"start"`` and a
    step without any successor is linked directly to ``"end"``, so that every
    step is part of the flow graph and reachable from ``"start"``.

    Args:
        execution_graph: Graph holding step nodes and possibly other nodes.

    Returns:
        A new directed graph over the step nodes, ``"start"`` and ``"end"``.
    """
    steps_flow_graph = nx.DiGraph()
    steps_flow_graph.add_node("start")
    steps_flow_graph.add_node("end")
    step_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph, kind=STEP_NODE_KIND
    )
    for step_node in step_nodes:
        sources = [
            predecessor if predecessor in step_nodes else "start"
            for predecessor in execution_graph.predecessors(step_node)
        ]
        # Without this fallback a step with no incoming edge would be
        # unreachable from "start" (or missing from the graph entirely) and
        # would never receive a distance.
        for source in sources or ["start"]:
            steps_flow_graph.add_edge(source, step_node)
        targets = [
            successor if successor in step_nodes else "end"
            for successor in execution_graph.successors(step_node)
        ]
        for target in targets or ["end"]:
            steps_flow_graph.add_edge(step_node, target)
    return steps_flow_graph
|
|
|
|
|
def assign_max_distances_from_start(steps_flow_graph: nx.DiGraph) -> nx.DiGraph:
    """Annotate nodes with their longest distance from the ``"start"`` node.

    The graph is walked breadth-first from ``"start"``. A node receives a
    ``"distance"`` attribute only once all of its predecessors have one; the
    value is ``0`` for a node without predecessors and otherwise the largest
    predecessor distance plus one. Nodes that are never reached this way are
    left without the attribute.

    Args:
        steps_flow_graph: Graph containing a ``"start"`` node. It is modified
            in place.

    Returns:
        The same graph object that was passed in.
    """
    node_attributes = steps_flow_graph.nodes
    # Nodes whose distance has been written during this call.
    settled: Set[str] = set()
    nodes_to_consider = Queue()
    nodes_to_consider.put("start")
    while not nodes_to_consider.empty():
        node_to_consider = nodes_to_consider.get()
        predecessors = list(steps_flow_graph.predecessors(node_to_consider))
        if any(node_attributes[p].get("distance") is None for p in predecessors):
            # We can establish the distance only when all parents have theirs;
            # the node is enqueued again when another parent gets resolved.
            continue
        distance_from_start = (
            max((node_attributes[p]["distance"] for p in predecessors), default=-1)
            + 1
        )
        if (
            node_to_consider in settled
            and node_attributes[node_to_consider]["distance"] == distance_from_start
        ):
            # Nothing changed since this node was last propagated. Enqueueing
            # its successors again would only repeat identical work, and that
            # repetition multiplies with every layer of the graph.
            continue
        node_attributes[node_to_consider]["distance"] = distance_from_start
        settled.add(node_to_consider)
        for neighbour in steps_flow_graph.successors(node_to_consider):
            nodes_to_consider.put(neighbour)
    return steps_flow_graph
|
|
|
|
|
def get_groups_execution_order(steps_flow_graph: nx.DiGraph) -> List[List[str]]:
    """Group the nodes of the flow graph by their ``"distance"`` attribute.

    The artificial ``"start"`` and ``"end"`` nodes are left out. Every other
    node must carry a ``"distance"`` attribute, otherwise ``KeyError`` is
    raised.

    Returns:
        Lists of node names, one list per distinct distance, ordered by
        increasing distance.
    """
    steps_by_distance = defaultdict(list)
    for name, attributes in steps_flow_graph.nodes(data=True):
        if name not in {"start", "end"}:
            steps_by_distance[attributes["distance"]].append(name)
    return [steps_by_distance[distance] for distance in sorted(steps_by_distance)]
|
|
|
|
|
def get_next_steps_to_execute(
    execution_order: List[List[str]],
    execution_pointer: int,
    discarded_steps: Set[str],
) -> List[str]:
    """Return the steps of one group that have not been discarded.

    Args:
        execution_order: Groups of step names.
        execution_pointer: Index of the group to read from ``execution_order``.
        discarded_steps: Step names to leave out of the result.

    Returns:
        The steps of the selected group, in their original order, without
        the discarded ones.
    """
    selected_group = execution_order[execution_pointer]
    remaining_steps = []
    for step_name in selected_group:
        if step_name in discarded_steps:
            continue
        remaining_steps.append(step_name)
    return remaining_steps
|
|