| from collections import deque |
| from typing import List, Set |
| |
| |
class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes
        self._node = {}
        # Nested dict of node -> successor node -> True.
        # (edge data is not implemented; the inner dict acts as an ordered set)
        self._succ = {}
        # Nested dict of node -> predecessor node -> True.
        self._pred = {}

        # Keep track of the order in which nodes are added to the graph;
        # ``first_path`` uses it to pick the earliest-added predecessor.
        self._node_order = {}
        self._insertion_idx = 0

    def add_node(self, n, **kwargs):
        """Add a node to the graph.

        If ``n`` already exists, its attribute dict is updated with ``kwargs``.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node.
        """
        if n not in self._node:
            self._node[n] = kwargs
            self._succ[n] = {}
            self._pred[n] = {}
            self._node_order[n] = self._insertion_idx
            self._insertion_idx += 1
        else:
            self._node[n].update(kwargs)

    def add_edge(self, u, v):
        """Add an edge to graph between nodes ``u`` and ``v``

        ``u`` and ``v`` will be created if they do not already exist.
        """
        # add nodes
        self.add_node(u)
        self.add_node(v)

        # add the edge
        self._succ[u][v] = True
        self._pred[v][u] = True

    def successors(self, n):
        """Returns an iterator over successor nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    def predecessors(self, n):
        """Returns an iterator over predecessor nodes of n.

        Raises:
            ValueError: if ``n`` is not in the graph.
        """
        try:
            return iter(self._pred[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Returns an iterator over all edges (u, v) in the graph"""
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Returns a dictionary of all nodes to their attributes."""
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Returns True if ``n`` is a node in the graph, False otherwise."""
        try:
            return n in self._node
        except TypeError:
            # Unhashable objects can never be nodes.
            return False

    def _transitive_closure(self, src, neighbors) -> Set[str]:
        """Breadth-first search from ``src`` following ``neighbors``.

        Args:
            src: the starting node.
            neighbors: callable mapping a node to an iterable of adjacent
                nodes (``self.successors`` or ``self.predecessors``).

        Returns:
            ``src`` plus every node reachable from it.
        """
        # Wrap ``src`` in a container: ``set(src)`` / ``deque(src)`` would
        # iterate a string node character by character.
        result = {src}
        working_set = deque([src])
        while working_set:
            cur = working_set.popleft()
            for n in neighbors(cur):
                if n not in result:
                    result.add(n)
                    working_set.append(n)
        return result

    def forward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src (src included).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._transitive_closure(src, self.successors)

    def backward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src in reverse direction
        (src included).

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        return self._transitive_closure(src, self.predecessors)

    def all_paths(self, src: str, dst: str) -> str:
        """Returns the dot representation of the subgraph of all paths from src to dst.

        The subgraph contains every edge that lies on some walk from ``src``
        to ``dst``. If ``dst`` is not reachable from ``src``, the dot
        representation of an empty graph is returned.

        Raises:
            ValueError: if ``src`` is not in the graph.
        """
        result_graph = DiGraph()
        # First compute forward transitive closure of src (all things reachable from src).
        forward_reachable_from_src = self.forward_transitive_closure(src)

        if dst not in forward_reachable_from_src:
            return result_graph.to_dot()

        # Second walk the reverse dependencies of dst, adding each edge to the
        # output graph iff its source is also present in forward_reachable_from_src.
        # We don't use backward_transitive_closure so that nodes which cannot
        # be reached from src are never explored.
        visited = {dst}
        working_set = deque([dst])
        while working_set:
            cur = working_set.popleft()
            for n in self.predecessors(cur):
                if n in forward_reachable_from_src:
                    result_graph.add_edge(n, cur)
                    # Only explore further if it's reachable from src, and only
                    # once per node so that cycles terminate.
                    if n not in visited:
                        visited.add(n)
                        working_set.append(n)

        return result_graph.to_dot()

    def first_path(self, dst: str) -> List[str]:
        """Returns a list of nodes that show the first path that resulted in dst being added to the graph.

        Starting at ``dst``, repeatedly steps to the earliest-added predecessor
        that is not already on the path, until no such predecessor exists.
        The path is returned root-first, ending with ``dst``.

        Raises:
            KeyError: if ``dst`` is not in the graph.
        """
        path = []
        on_path = set()
        cur = dst
        while cur is not None:
            path.append(cur)
            on_path.add(cur)
            # Skip nodes already on the path so cycles cannot loop forever.
            candidates = [p for p in self._pred[cur] if p not in on_path]
            cur = min(candidates, key=self._node_order.__getitem__, default=None)

        path.reverse()
        return path

    def to_dot(self) -> str:
        """Returns the dot representation of the graph.

        Only edges are emitted, so nodes without any edge do not appear.
        Double quotes inside node names are escaped so the output stays valid DOT.

        Returns:
            A dot representation of the graph.
        """

        def quote(node) -> str:
            # In DOT quoted IDs, \" is the only escape sequence.
            return '"' + str(node).replace('"', '\\"') + '"'

        edges = "\n".join(f"{quote(f)} -> {quote(t)};" for f, t in self.edges)
        return f"""\
digraph G {{
rankdir = LR;
node [shape=box];
{edges}
}}
"""