diff options
Diffstat (limited to 'synapse/util/iterutils.py')
-rw-r--r-- | synapse/util/iterutils.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py index a0efb96d3b..f4c0194af0 100644 --- a/synapse/util/iterutils.py +++ b/synapse/util/iterutils.py @@ -135,3 +135,54 @@ def sorted_topologically( degree_map[edge] -= 1 if degree_map[edge] == 0: heapq.heappush(zero_degree, edge) + + +def sorted_topologically_batched( + nodes: Iterable[T], + graph: Mapping[T, Collection[T]], +) -> Generator[Collection[T], None, None]: + r"""Walk the graph topologically, returning batches of nodes where all nodes + that references it have been previously returned. + + For example, given the following graph: + + A + / \ + B C + \ / + D + + This function will return: `[[A], [B, C], [D]]`. + + This function is useful for e.g. batch persisting events in an auth chain, + where we can only persist an event if all its auth events have already been + persisted. + """ + + degree_map = {node: 0 for node in nodes} + reverse_graph: Dict[T, Set[T]] = {} + + for node, edges in graph.items(): + if node not in degree_map: + continue + + for edge in set(edges): + if edge in degree_map: + degree_map[node] += 1 + + reverse_graph.setdefault(edge, set()).add(node) + reverse_graph.setdefault(node, set()) + + zero_degree = [node for node, degree in degree_map.items() if degree == 0] + + while zero_degree: + new_zero_degree = [] + for node in zero_degree: + for edge in reverse_graph.get(node, []): + if edge in degree_map: + degree_map[edge] -= 1 + if degree_map[edge] == 0: + new_zero_degree.append(edge) + + yield zero_degree + zero_degree = new_zero_degree |