Diffstat:
 -rw-r--r--   changelog.d/7735.bugfix           1
 -rw-r--r--   synapse/handlers/federation.py    1
 -rw-r--r--   synapse/state/__init__.py         6
 -rw-r--r--   synapse/state/v2.py              56
 -rw-r--r--   tests/state/test_v2.py            9
5 files changed, 62 insertions, 11 deletions
diff --git a/changelog.d/7735.bugfix b/changelog.d/7735.bugfix
new file mode 100644
index 0000000000..86959a5ca4
--- /dev/null
+++ b/changelog.d/7735.bugfix
@@ -0,0 +1 @@
+Prevent large state resolutions from stalling Synapse for seconds at a time.
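
Not part of the patch, for context: a Twisted @defer.inlineCallbacks function only hands control back to the reactor when it yields a Deferred that has not yet fired. A long loop whose yielded Deferreds all resolve immediately (for example, because every event it needs is already in memory) therefore runs to completion before any other queued work can proceed, which is what let large state resolutions stall the whole process. A minimal sketch of the shape of the fix, where process_items and the per-item work are hypothetical stand-ins and only the every-N-iterations clock.sleep(0) pacing mirrors the real change:

    from twisted.internet import defer

    _YIELD_AFTER_ITERATIONS = 100

    @defer.inlineCallbacks
    def process_items(clock, items):
        """Walk a large iterable without monopolising the reactor."""
        results = []
        for idx, item in enumerate(items, start=1):
            results.append(item * 2)  # stand-in for the real per-item work

            # Every N iterations, wait on a Deferred that only fires on a
            # later reactor tick, so other queued work can run in between.
            if idx % _YIELD_AFTER_ITERATIONS == 0:
                yield clock.sleep(0)

        return results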
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 873f6bc39f..3828ff0ef0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -376,6 +376,7 @@ class FederationHandler(BaseHandler):
 
                     room_version = await self.store.get_room_version_id(room_id)
                     state_map = await resolve_events_with_store(
+                        self.clock,
                         room_id,
                         room_version,
                         state_maps,
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 50fd843f66..495d9f04c8 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -32,6 +32,7 @@ from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import StateMap
+from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import Measure, measure_func
@@ -414,6 +415,7 @@ class StateHandler(object):
 
         with Measure(self.clock, "state._resolve_events"):
             new_state = yield resolve_events_with_store(
+                self.clock,
                 event.room_id,
                 room_version,
                 state_set_ids,
@@ -516,6 +518,7 @@ class StateResolutionHandler(object):
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_store(
+                        self.clock,
                         room_id,
                         room_version,
                         list(state_groups_ids.values()),
@@ -589,6 +592,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
 
 
 def resolve_events_with_store(
+    clock: Clock,
     room_id: str,
     room_version: str,
     state_sets: List[StateMap[str]],
@@ -625,7 +629,7 @@ def resolve_events_with_store(
         )
     else:
         return v2.resolve_events_with_store(
-            room_id, room_version, state_sets, event_map, state_res_store
+            clock, room_id, room_version, state_sets, event_map, state_res_store
         )
 
 
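The clock threaded through above exists only so that the v2 algorithm can call clock.sleep(0). synapse.util.Clock.sleep is, roughly, a Deferred-returning wrapper around the reactor's callLater; a minimal stand-in (an illustration, not Synapse's actual implementation) could look like:

    from twisted.internet import defer, reactor

    class MinimalClock(object):
        """Illustrative stand-in for the sleep() part of synapse.util.Clock."""

        def sleep(self, seconds):
            # Even with a zero delay, callLater schedules the callback on a
            # later reactor iteration, so yielding on this Deferred genuinely
            # returns control to the event loop.
            d = defer.Deferred()
            reactor.callLater(seconds, d.callback, None)
            return d
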
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 57eadce4e6..7181ecda9a 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -27,12 +27,20 @@ from synapse.api.errors import AuthError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.types import StateMap
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
 
+# We want to yield to the reactor occasionally during state res when dealing
+# with large data sets, so that we don't block it for too long. This is done
+# by yielding to the reactor every N iterations of the larger loops.
+_YIELD_AFTER_ITERATIONS = 100
+
+
 @defer.inlineCallbacks
 def resolve_events_with_store(
+    clock: Clock,
     room_id: str,
     room_version: str,
     state_sets: List[StateMap[str]],
@@ -42,13 +50,11 @@ def resolve_events_with_store(
     """Resolves the state using the v2 state resolution algorithm
 
     Args:
+        clock: used to yield to the reactor during long loops
         room_id: the room we are working in
-
         room_version: The room version
-
         state_sets: List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
-
         event_map:
             a dict from event_id to event, for any events that we happen to
             have in flight (eg, those currently being persisted). This will be
@@ -113,7 +119,7 @@ def resolve_events_with_store(
     )
 
     sorted_power_events = yield _reverse_topological_power_sort(
-        room_id, power_events, event_map, state_res_store, full_conflicted_set
+        clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
     )
 
     logger.debug("sorted %d power events", len(sorted_power_events))
@@ -142,7 +148,7 @@ def resolve_events_with_store(
 
     pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
     leftover_events = yield _mainline_sort(
-        room_id, leftover_events, pl, event_map, state_res_store
+        clock, room_id, leftover_events, pl, event_map, state_res_store
     )
 
     logger.debug("resolving remaining events")
@@ -317,12 +323,13 @@ def _add_event_and_auth_chain_to_graph(
 
 @defer.inlineCallbacks
 def _reverse_topological_power_sort(
-    room_id, event_ids, event_map, state_res_store, auth_diff
+    clock, room_id, event_ids, event_map, state_res_store, auth_diff
 ):
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
 
     Args:
+        clock (Clock): used to yield to the reactor during long loops
         room_id (str): the room we are working in
         event_ids (list[str]): The events to sort
         event_map (dict[str,FrozenEvent])
@@ -334,18 +341,28 @@ def _reverse_topological_power_sort(
     """
 
     graph = {}
-    for event_id in event_ids:
+    for idx, event_id in enumerate(event_ids, start=1):
         yield _add_event_and_auth_chain_to_graph(
             graph, room_id, event_id, event_map, state_res_store, auth_diff
         )
 
+        # We yield occasionally when we're working with large data sets to
+        # ensure that we don't block the reactor loop for too long.
+        if idx % _YIELD_AFTER_ITERATIONS == 0:
+            yield clock.sleep(0)
+
     event_to_pl = {}
-    for event_id in graph:
+    for idx, event_id in enumerate(graph, start=1):
         pl = yield _get_power_level_for_sender(
             room_id, event_id, event_map, state_res_store
         )
         event_to_pl[event_id] = pl
 
+        # We yield occasionally when we're working with large data sets to
+        # ensure that we don't block the reactor loop for too long.
+        if idx % _YIELD_AFTER_ITERATIONS == 0:
+            yield clock.sleep(0)
+
     def _get_power_order(event_id):
         ev = event_map[event_id]
         pl = event_to_pl[event_id]
@@ -423,12 +440,13 @@ def _iterative_auth_checks(
 
 @defer.inlineCallbacks
 def _mainline_sort(
-    room_id, event_ids, resolved_power_event_id, event_map, state_res_store
+    clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
 ):
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
 
     Args:
+        clock (Clock): used to yield to the reactor during long loops
         room_id (str): room we're working in
         event_ids (list[str]): Events to sort
         resolved_power_event_id (str): The final resolved power level event ID
@@ -438,8 +456,14 @@ def _mainline_sort(
     Returns:
         Deferred[list[str]]: The sorted list
     """
+    if not event_ids:
+        # It's possible for there to be no event IDs here to sort, so we can
+        # skip calculating the mainline in that case.
+        return []
+
     mainline = []
     pl = resolved_power_event_id
+    idx = 0
     while pl:
         mainline.append(pl)
         pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
@@ -453,17 +477,29 @@ def _mainline_sort(
                 pl = aid
                 break
 
+        # We yield occasionally when we're working with large data sets to
+        # ensure that we don't block the reactor loop for too long.
+        if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
+            yield clock.sleep(0)
+
+        idx += 1
+
     mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
 
     event_ids = list(event_ids)
 
     order_map = {}
-    for ev_id in event_ids:
+    for idx, ev_id in enumerate(event_ids, start=1):
         depth = yield _get_mainline_depth_for_event(
             event_map[ev_id], mainline_map, event_map, state_res_store
         )
         order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
 
+        # We yield occasionally when we're working with large data sets to
+        # ensure that we don't block the reactor loop for too long.
+        if idx % _YIELD_AFTER_ITERATIONS == 0:
+            yield clock.sleep(0)
+
     event_ids.sort(key=lambda ev_id: order_map[ev_id])
 
     return event_ids
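
The pacing above appears in four loops: three iterate with enumerate(..., start=1) and sleep after every _YIELD_AFTER_ITERATIONS items, while the mainline walk is a while loop and keeps a manual counter, guarded with idx != 0 so the first pass never sleeps. In isolation, with a hypothetical get_parent lookup standing in for the auth-chain walk, that while-loop shape is roughly:

    from twisted.internet import defer

    _YIELD_AFTER_ITERATIONS = 100

    @defer.inlineCallbacks
    def walk_chain(clock, start, get_parent):
        """Follow a parent chain, yielding to the reactor every N steps."""
        chain = []
        ev_id = start
        idx = 0
        while ev_id:
            chain.append(ev_id)
            ev_id = yield get_parent(ev_id)  # hypothetical, returns a Deferred

            # Skip the very first pass, then pause every N iterations so the
            # reactor can service other work.
            if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
                yield clock.sleep(0)

            idx += 1

        return chain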
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index cdc347bc53..38f9b423ef 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -17,6 +17,8 @@ import itertools
 
 import attr
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
@@ -41,6 +43,11 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
 ORIGIN_SERVER_TS = 0
 
 
+class FakeClock:
+    def sleep(self, msec):
+        return defer.succeed(None)
+
+
 class FakeEvent(object):
     """A fake event we use as a convenience.
 
@@ -417,6 +424,7 @@ class StateTestCase(unittest.TestCase):
                 state_before = dict(state_at_event[prev_events[0]])
             else:
                 state_d = resolve_events_with_store(
+                    FakeClock(),
                     ROOM_ID,
                     RoomVersions.V2.identifier,
                     [state_at_event[n] for n in prev_events],
@@ -565,6 +573,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
         # Test that we correctly handle passing `None` as the event_map
 
         state_d = resolve_events_with_store(
+            FakeClock(),
             ROOM_ID,
             RoomVersions.V2.identifier,
             [self.state_at_bob, self.state_at_charlie],
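
FakeClock keeps these unit tests deterministic: defer.succeed(None) has already fired, so the yield clock.sleep(0) in the production code resumes immediately instead of waiting for a reactor tick. If a test also wanted to check that the pacing actually triggers, one option (an illustration, not something this patch adds) is a counting variant:

    from twisted.internet import defer

    class CountingFakeClock(object):
        """Test double that records how often the code under test sleeps."""

        def __init__(self):
            self.sleep_calls = 0

        def sleep(self, msec):
            self.sleep_calls += 1
            return defer.succeed(None)

A test could then resolve a state set containing more than _YIELD_AFTER_ITERATIONS conflicted events and assert that sleep_calls is non-zero.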