summary refs log tree commit diff
path: root/synapse/state/v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/v2.py')
-rw-r--r--synapse/state/v2.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 1b9d7d8457..1752f95db8 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -19,7 +19,6 @@ from typing import (
     Any,
     Awaitable,
     Callable,
-    Collection,
     Dict,
     Generator,
     Iterable,
@@ -39,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, StrCollection
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +55,7 @@ class StateResolutionStore(Protocol):
     # This is usually synapse.state.StateResolutionStore, but it's replaced with a
     # TestStateResolutionStore in tests.
     def get_events(
-        self, event_ids: Collection[str], allow_rejected: bool = False
+        self, event_ids: StrCollection, allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         ...
 
@@ -366,7 +365,7 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
-        auth_difference_unpersisted_part: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: StrCollection = union - intersection
     else:
         auth_difference_unpersisted_part = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
@@ -667,7 +666,7 @@ async def _mainline_sort(
     order_map = {}
     for idx, ev_id in enumerate(event_ids, start=1):
         depth = await _get_mainline_depth_for_event(
-            event_map[ev_id], mainline_map, event_map, state_res_store
+            clock, event_map[ev_id], mainline_map, event_map, state_res_store
         )
         order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
 
@@ -682,6 +681,7 @@ async def _mainline_sort(
 
 
 async def _get_mainline_depth_for_event(
+    clock: Clock,
     event: EventBase,
     mainline_map: Dict[str, int],
     event_map: Dict[str, EventBase],
@@ -704,6 +704,7 @@ async def _get_mainline_depth_for_event(
 
     # We do an iterative search, replacing `event with the power level in its
     # auth events (if any)
+    idx = 0
     while tmp_event:
         depth = mainline_map.get(tmp_event.event_id)
         if depth is not None:
@@ -720,6 +721,11 @@ async def _get_mainline_depth_for_event(
                 tmp_event = aev
                 break
 
+        idx += 1
+
+        if idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
+
     # Didn't find a power level auth event, so we just return 0
     return 0