Spaces:
Runtime error
Runtime error
import abc | |
from collections import defaultdict | |
from queue import Queue | |
from typing import List, Optional, Set | |
import networkx as nx | |
from inference.enterprise.workflows.complier.utils import get_nodes_of_specific_kind | |
from inference.enterprise.workflows.constants import STEP_NODE_KIND | |
class StepExecutionCoordinator(metaclass=abc.ABCMeta):
    """Interface for objects deciding which workflow steps should run next.

    Concrete coordinators must provide both the ``init`` factory and
    ``get_steps_to_execute_next``; the class cannot be instantiated until
    both are overridden.
    """

    @classmethod
    @abc.abstractmethod
    def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator":
        """Create a coordinator for the given execution graph.

        Args:
            execution_graph: Graph describing the workflow to coordinate.

        Returns:
            A coordinator instance.
        """
        pass

    @abc.abstractmethod
    def get_steps_to_execute_next(
        self, steps_to_discard: Set[str]
    ) -> Optional[List[str]]:
        """Return the names of the steps to run next.

        Args:
            steps_to_discard: Step names that must not be scheduled.

        Returns:
            A list of step names, or ``None`` when nothing is left to run.
        """
        pass
class SerialExecutionCoordinator(StepExecutionCoordinator):
    """Coordinator that hands out steps one at a time.

    Steps are visited in a topological order of the execution graph. The
    order is computed lazily on the first call to
    ``get_steps_to_execute_next``.
    """

    @classmethod
    def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator":
        """Create a coordinator for ``execution_graph``.

        Args:
            execution_graph: Graph describing the workflow to coordinate.

        Returns:
            A new instance of this class.
        """
        return cls(execution_graph=execution_graph)

    def __init__(self, execution_graph: nx.DiGraph):
        # The coordinator keeps the result of ``copy()`` rather than the
        # caller's graph object itself.
        self._execution_graph = execution_graph.copy()
        # Accumulates every step name ever passed as ``steps_to_discard``.
        self._discarded_steps: Set[str] = set()
        # Populated on first use by ``__establish_execution_order``.
        self.__order: Optional[List[str]] = None
        # Index in ``__order`` of the next candidate step.
        self.__step_pointer = 0

    def get_steps_to_execute_next(
        self, steps_to_discard: Set[str]
    ) -> Optional[List[str]]:
        """Return the next non-discarded step as a single-element list.

        Args:
            steps_to_discard: Step names to exclude from this and all
                subsequent calls.

        Returns:
            ``[step_name]`` for the next step to run, or ``None`` once the
            whole order has been consumed.
        """
        if self.__order is None:
            self.__establish_execution_order()
        self._discarded_steps.update(steps_to_discard)
        while self.__step_pointer < len(self.__order):
            candidate_step = self.__order[self.__step_pointer]
            # Advance first, so a skipped or returned step is never revisited.
            self.__step_pointer += 1
            if candidate_step not in self._discarded_steps:
                return [candidate_step]
        return None

    def __establish_execution_order(self) -> None:
        """Compute the step order and reset the pointer to its beginning."""
        step_nodes = get_nodes_of_specific_kind(
            execution_graph=self._execution_graph, kind=STEP_NODE_KIND
        )
        # Keep only step nodes, in the order the topological sort yields them.
        self.__order = [
            n for n in nx.topological_sort(self._execution_graph) if n in step_nodes
        ]
        self.__step_pointer = 0
class ParallelStepExecutionCoordinator(StepExecutionCoordinator):
    """Coordinator that hands out whole groups of steps at once.

    Groups come from ``establish_execution_order`` and are computed lazily
    on the first call to ``get_steps_to_execute_next``.
    """

    @classmethod
    def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator":
        """Create a coordinator for ``execution_graph``.

        Args:
            execution_graph: Graph describing the workflow to coordinate.

        Returns:
            A new instance of this class.
        """
        return cls(execution_graph=execution_graph)

    def __init__(self, execution_graph: nx.DiGraph):
        # The coordinator keeps the result of ``copy()`` rather than the
        # caller's graph object itself.
        self._execution_graph = execution_graph.copy()
        # Accumulates every step name ever passed as ``steps_to_discard``.
        self._discarded_steps: Set[str] = set()
        # Ordered groups of step names; populated on first use.
        self.__execution_order: Optional[List[List[str]]] = None
        # Index in ``__execution_order`` of the next candidate group.
        self.__execution_pointer = 0

    def get_steps_to_execute_next(
        self, steps_to_discard: Set[str]
    ) -> Optional[List[str]]:
        """Return the next group of steps that still has something to run.

        Args:
            steps_to_discard: Step names to exclude from this and all
                subsequent calls.

        Returns:
            The non-discarded steps of the next non-empty group, or ``None``
            once every group has been consumed.
        """
        if self.__execution_order is None:
            self.__execution_order = establish_execution_order(
                execution_graph=self._execution_graph
            )
            self.__execution_pointer = 0
        self._discarded_steps.update(steps_to_discard)
        while self.__execution_pointer < len(self.__execution_order):
            candidate_steps = [
                e
                for e in self.__execution_order[self.__execution_pointer]
                if e not in self._discarded_steps
            ]
            # Advance first, so a group is never offered twice.
            self.__execution_pointer += 1
            if candidate_steps:
                return candidate_steps
            # Every step of this group was discarded - try the next group.
        return None
def establish_execution_order(
    execution_graph: nx.DiGraph,
) -> List[List[str]]:
    """Turn an execution graph into ordered groups of step names.

    The graph is reduced to a steps-only flow graph, every node of that
    graph is annotated with its distance from the start, and the steps are
    then grouped by that distance.

    Args:
        execution_graph: Graph describing the workflow.

    Returns:
        Groups of step names, as produced by ``get_groups_execution_order``.
    """
    flow_graph_with_distances = assign_max_distances_from_start(
        steps_flow_graph=construct_steps_flow_graph(execution_graph=execution_graph)
    )
    return get_groups_execution_order(steps_flow_graph=flow_graph_with_distances)
def construct_steps_flow_graph(execution_graph: nx.DiGraph) -> nx.DiGraph:
    """Build a graph holding only step nodes plus "start" and "end" markers.

    For every step node, each incoming edge from a non-step node is replaced
    by an edge from "start", and each outgoing edge to a non-step node is
    replaced by an edge to "end". Edges between two steps are kept as they
    are.

    Args:
        execution_graph: Graph describing the workflow.

    Returns:
        A new directed graph with the step-to-step flow.
    """
    step_names = get_nodes_of_specific_kind(
        execution_graph=execution_graph, kind=STEP_NODE_KIND
    )
    flow_graph = nx.DiGraph()
    flow_graph.add_nodes_from(["start", "end"])
    for step_name in step_names:
        # Incoming side: anything that is not a step collapses into "start".
        for parent in execution_graph.predecessors(step_name):
            if parent in step_names:
                flow_graph.add_edge(parent, step_name)
            else:
                flow_graph.add_edge("start", step_name)
        # Outgoing side: anything that is not a step collapses into "end".
        for child in execution_graph.successors(step_name):
            if child in step_names:
                flow_graph.add_edge(step_name, child)
            else:
                flow_graph.add_edge(step_name, "end")
    return flow_graph
def assign_max_distances_from_start(steps_flow_graph: nx.DiGraph) -> nx.DiGraph:
    """Annotate nodes with their maximal distance from the "start" node.

    Starting at "start", each node receives a ``"distance"`` attribute equal
    to one more than the largest distance among its predecessors (``0`` for
    a node without predecessors). A node is resolved only after all of its
    predecessors have a distance, so nodes that never reach that state are
    left without the attribute.

    Each node is resolved at most once per call. Re-processing an already
    resolved node would recompute the same value and re-enqueue all of its
    successors, which multiplies the work on every branch-and-merge pattern
    in the graph.

    NOTE(review): this assumes the graph carries no ``"distance"`` attributes
    from an earlier run - confirm if pre-annotated graphs are ever passed in.

    Args:
        steps_flow_graph: Graph containing a "start" node; modified in place.

    Returns:
        The same graph object that was passed in.
    """
    nodes_to_consider = Queue()
    nodes_to_consider.put("start")
    settled_nodes = set()
    while not nodes_to_consider.empty():
        node_to_consider = nodes_to_consider.get()
        if node_to_consider in settled_nodes:
            # Distance already final - nothing new to propagate.
            continue
        predecessors = list(steps_flow_graph.predecessors(node_to_consider))
        if any(
            steps_flow_graph.nodes[p].get("distance") is None for p in predecessors
        ):
            # We can establish the distance only once all parents have theirs;
            # the last parent to be resolved will enqueue this node again.
            continue
        if not predecessors:
            distance_from_start = 0
        else:
            distance_from_start = (
                max(steps_flow_graph.nodes[p]["distance"] for p in predecessors) + 1
            )
        steps_flow_graph.nodes[node_to_consider]["distance"] = distance_from_start
        settled_nodes.add(node_to_consider)
        for neighbour in steps_flow_graph.successors(node_to_consider):
            nodes_to_consider.put(neighbour)
    return steps_flow_graph
def get_groups_execution_order(steps_flow_graph: nx.DiGraph) -> List[List[str]]:
    """Group step names by their ``"distance"`` attribute.

    The "start" and "end" marker nodes are left out. Every other node is
    expected to carry a ``"distance"`` entry in its node data.

    Args:
        steps_flow_graph: Flow graph whose nodes carry a ``"distance"``.

    Returns:
        Lists of node names, one list per distinct distance, ordered by
        ascending distance.
    """
    steps_by_distance = defaultdict(list)
    for name, attributes in steps_flow_graph.nodes(data=True):
        if name not in {"start", "end"}:
            steps_by_distance[attributes["distance"]].append(name)
    return [steps_by_distance[distance] for distance in sorted(steps_by_distance)]
def get_next_steps_to_execute(
    execution_order: List[List[str]],
    execution_pointer: int,
    discarded_steps: Set[str],
) -> List[str]:
    """Return the steps of one group that have not been discarded.

    Args:
        execution_order: Ordered groups of step names.
        execution_pointer: Index of the group to read.
        discarded_steps: Step names to leave out.

    Returns:
        The remaining step names of the selected group, in their original
        order.
    """
    selected_group = execution_order[execution_pointer]
    remaining_steps = []
    for step_name in selected_group:
        if step_name in discarded_steps:
            continue
        remaining_steps.append(step_name)
    return remaining_steps