summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13151.misc1
-rw-r--r--synapse/handlers/federation.py9
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/http/state.py75
-rw-r--r--synapse/state/__init__.py25
-rw-r--r--synapse/storage/controllers/persist_events.py141
-rw-r--r--synapse/storage/databases/main/events.py14
-rw-r--r--tests/test_state.py2
8 files changed, 214 insertions, 55 deletions
diff --git a/changelog.d/13151.misc b/changelog.d/13151.misc
new file mode 100644
index 0000000000..cfe3eed3a1
--- /dev/null
+++ b/changelog.d/13151.misc
@@ -0,0 +1 @@
+Faster room joins: fix race in recalculation of current room state.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3c44b4bf86..e2564e9340 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1559,14 +1559,9 @@ class FederationHandler:
                 # all the events are updated, so we can update current state and
                 # clear the lazy-loading flag.
                 logger.info("Updating current state for %s", room_id)
-                # TODO(faster_joins): support workers
+                # TODO(faster_joins): notify workers in notify_room_un_partial_stated
                 #   https://github.com/matrix-org/synapse/issues/12994
-                assert (
-                    self._storage_controllers.persistence is not None
-                ), "worker-mode deployments not currently supported here"
-                await self._storage_controllers.persistence.update_current_state(
-                    room_id
-                )
+                await self.state_handler.update_current_state(room_id)
 
                 logger.info("Clearing partial-state flag for %s", room_id)
                 success = await self.store.clear_partial_state_room(room_id)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index aec040ee19..53aa7fa4c6 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -25,6 +25,7 @@ from synapse.replication.http import (
     push,
     register,
     send_event,
+    state,
     streams,
 )
 
@@ -48,6 +49,7 @@ class ReplicationRestResource(JsonResource):
         streams.register_servlets(hs, self)
         account_data.register_servlets(hs, self)
         push.register_servlets(hs, self)
+        state.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py
new file mode 100644
index 0000000000..838b7584e5
--- /dev/null
+++ b/synapse/replication/http/state.py
@@ -0,0 +1,75 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
+    """Recalculates the current state for a room, and persists it.
+
+    The API looks like:
+
+        POST /_synapse/replication/update_current_state/:room_id
+
+        {}
+
+        200 OK
+
+        {}
+    """
+
+    NAME = "update_current_state"
+    PATH_ARGS = ("room_id",)
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._state_handler = hs.get_state_handler()
+        self._events_shard_config = hs.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
+
+    @staticmethod
+    async def _serialize_payload(room_id: str) -> JsonDict:  # type: ignore[override]
+        return {}
+
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        writer_instance = self._events_shard_config.get_instance(room_id)
+        if writer_instance != self._instance_name:
+            raise SynapseError(
+                400, "/update_current_state request was routed to the wrong worker"
+            )
+
+        await self._state_handler.update_current_state(room_id)
+
+        return 200, {}
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+    if hs.get_instance_name() in hs.config.worker.writers.events:
+        ReplicationUpdateCurrentStateRestServlet(hs).register(http_server)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index d5cbdb3eef..781d9f06da 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -43,6 +43,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.context import ContextResourceUsage
+from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
@@ -129,6 +130,12 @@ class StateHandler:
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
         self._storage_controllers = hs.get_storage_controllers()
+        self._events_shard_config = hs.config.worker.events_shard_config
+        self._instance_name = hs.get_instance_name()
+
+        self._update_current_state_client = (
+            ReplicationUpdateCurrentStateRestServlet.make_client(hs)
+        )
 
     async def get_current_state_ids(
         self,
@@ -423,6 +430,24 @@ class StateHandler:
 
         return {key: state_map[ev_id] for key, ev_id in new_state.items()}
 
+    async def update_current_state(self, room_id: str) -> None:
+        """Recalculates the current state for a room, and persists it.
+
+        Raises:
+            SynapseError(502): if all attempts to connect to the event persister worker
+                fail
+        """
+        writer_instance = self._events_shard_config.get_instance(room_id)
+        if writer_instance != self._instance_name:
+            await self._update_current_state_client(
+                instance_name=writer_instance,
+                room_id=room_id,
+            )
+            return
+
+        assert self._storage_controllers.persistence is not None
+        await self._storage_controllers.persistence.update_current_state(room_id)
+
 
 @attr.s(slots=True, auto_attribs=True)
 class _StateResMetrics:
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index c248fccc81..ea499ce0f8 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -22,6 +22,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    ClassVar,
     Collection,
     Deque,
     Dict,
@@ -33,6 +34,7 @@ from typing import (
     Set,
     Tuple,
     TypeVar,
+    Union,
 )
 
 import attr
@@ -111,9 +113,43 @@ times_pruned_extremities = Counter(
 
 
 @attr.s(auto_attribs=True, slots=True)
-class _EventPersistQueueItem:
+class _PersistEventsTask:
+    """A batch of events to persist."""
+
+    name: ClassVar[str] = "persist_event_batch"  # used for opentracing
+
     events_and_contexts: List[Tuple[EventBase, EventContext]]
     backfilled: bool
+
+    def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+        """Batches events with the same backfilled option together."""
+        if (
+            not isinstance(task, _PersistEventsTask)
+            or self.backfilled != task.backfilled
+        ):
+            return False
+
+        self.events_and_contexts.extend(task.events_and_contexts)
+        return True
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _UpdateCurrentStateTask:
+    """A room whose current state needs recalculating."""
+
+    name: ClassVar[str] = "update_current_state"  # used for opentracing
+
+    def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+        """Deduplicates consecutive recalculations of current state."""
+        return isinstance(task, _UpdateCurrentStateTask)
+
+
+_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _EventPersistQueueItem:
+    task: _EventPersistQueueTask
     deferred: ObservableDeferred
 
     parent_opentracing_span_contexts: List = attr.ib(factory=list)
@@ -127,14 +163,16 @@ _PersistResult = TypeVar("_PersistResult")
 
 
 class _EventPeristenceQueue(Generic[_PersistResult]):
-    """Queues up events so that they can be persisted in bulk with only one
-    concurrent transaction per room.
+    """Queues up tasks so that they can be processed with only one concurrent
+    transaction per room.
+
+    Tasks can be bulk persistence of events or recalculation of a room's current state.
     """
 
     def __init__(
         self,
         per_item_callback: Callable[
-            [List[Tuple[EventBase, EventContext]], bool],
+            [str, _EventPersistQueueTask],
             Awaitable[_PersistResult],
         ],
     ):
@@ -150,18 +188,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
     async def add_to_queue(
         self,
         room_id: str,
-        events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
-        backfilled: bool,
+        task: _EventPersistQueueTask,
     ) -> _PersistResult:
-        """Add events to the queue, with the given persist_event options.
+        """Add a task to the queue.
 
-        If we are not already processing events in this room, starts off a background
+        If we are not already processing tasks in this room, starts off a background
         process to to so, calling the per_item_callback for each item.
 
         Args:
             room_id (str):
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
+            task (_EventPersistQueueTask): A _PersistEventsTask or
+                _UpdateCurrentStateTask to process.
 
         Returns:
             the result returned by the `_per_item_callback` passed to
@@ -169,26 +206,20 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
 
-        # if the last item in the queue has the same `backfilled` setting,
-        # we can just add these new events to that item.
-        if queue and queue[-1].backfilled == backfilled:
+        if queue and queue[-1].task.try_merge(task):
+            # the new task has been merged into the last task in the queue
             end_item = queue[-1]
         else:
-            # need to make a new queue item
             deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
                 defer.Deferred(), consumeErrors=True
             )
 
             end_item = _EventPersistQueueItem(
-                events_and_contexts=[],
-                backfilled=backfilled,
+                task=task,
                 deferred=deferred,
             )
             queue.append(end_item)
 
-        # add our events to the queue item
-        end_item.events_and_contexts.extend(events_and_contexts)
-
         # also add our active opentracing span to the item so that we get a link back
         span = opentracing.active_span()
         if span:
@@ -202,7 +233,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         # add another opentracing span which links to the persist trace.
         with opentracing.start_active_span_follows_from(
-            "persist_event_batch_complete", (end_item.opentracing_span_context,)
+            f"{task.name}_complete", (end_item.opentracing_span_context,)
         ):
             pass
 
@@ -234,16 +265,14 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                 for item in queue:
                     try:
                         with opentracing.start_active_span_follows_from(
-                            "persist_event_batch",
+                            item.task.name,
                             item.parent_opentracing_span_contexts,
                             inherit_force_tracing=True,
                         ) as scope:
                             if scope:
                                 item.opentracing_span_context = scope.span.context
 
-                            ret = await self._per_item_callback(
-                                item.events_and_contexts, item.backfilled
-                            )
+                            ret = await self._per_item_callback(room_id, item.task)
                     except Exception:
                         with PreserveLoggingContext():
                             item.deferred.errback()
@@ -292,9 +321,32 @@ class EventsPersistenceStorageController:
         self._clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
         self.is_mine_id = hs.is_mine_id
-        self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
+        self._event_persist_queue = _EventPeristenceQueue(
+            self._process_event_persist_queue_task
+        )
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
+    async def _process_event_persist_queue_task(
+        self,
+        room_id: str,
+        task: _EventPersistQueueTask,
+    ) -> Dict[str, str]:
+        """Callback for the _event_persist_queue
+
+        Returns:
+            A dictionary of event ID to event ID we didn't persist as we already
+            had another event persisted with the same TXN ID.
+        """
+        if isinstance(task, _PersistEventsTask):
+            return await self._persist_event_batch(room_id, task)
+        elif isinstance(task, _UpdateCurrentStateTask):
+            await self._update_current_state(room_id, task)
+            return {}
+        else:
+            raise AssertionError(
+                f"Found an unexpected task type in event persistence queue: {task}"
+            )
+
     @opentracing.trace
     async def persist_events(
         self,
@@ -329,7 +381,8 @@ class EventsPersistenceStorageController:
         ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
-                room_id, evs_ctxs, backfilled=backfilled
+                room_id,
+                _PersistEventsTask(events_and_contexts=evs_ctxs, backfilled=backfilled),
             )
 
         ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
@@ -376,7 +429,10 @@ class EventsPersistenceStorageController:
         # event was deduplicated. (The dict may also include other entries if
         # the event was persisted in a batch with other events.)
         replaced_events = await self._event_persist_queue.add_to_queue(
-            event.room_id, [(event, context)], backfilled=backfilled
+            event.room_id,
+            _PersistEventsTask(
+                events_and_contexts=[(event, context)], backfilled=backfilled
+            ),
         )
         replaced_event = replaced_events.get(event.event_id)
         if replaced_event:
@@ -391,20 +447,22 @@ class EventsPersistenceStorageController:
 
     async def update_current_state(self, room_id: str) -> None:
         """Recalculate the current state for a room, and persist it"""
+        await self._event_persist_queue.add_to_queue(
+            room_id,
+            _UpdateCurrentStateTask(),
+        )
+
+    async def _update_current_state(
+        self, room_id: str, _task: _UpdateCurrentStateTask
+    ) -> None:
+        """Callback for the _event_persist_queue
+
+        Recalculates the current state for a room, and persists it.
+        """
         state = await self._calculate_current_state(room_id)
         delta = await self._calculate_state_delta(room_id, state)
 
-        # TODO(faster_joins): get a real stream ordering, to make this work correctly
-        #    across workers.
-        #    https://github.com/matrix-org/synapse/issues/12994
-        #
-        # TODO(faster_joins): this can race against event persistence, in which case we
-        #    will end up with incorrect state. Perhaps we should make this a job we
-        #    farm out to the event persister thread, somehow.
-        #    https://github.com/matrix-org/synapse/issues/13007
-        #
-        stream_id = self.main_store.get_room_max_stream_ordering()
-        await self.persist_events_store.update_current_state(room_id, delta, stream_id)
+        await self.persist_events_store.update_current_state(room_id, delta)
 
     async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
         """Calculate the current state of a room, based on the forward extremities
@@ -449,9 +507,7 @@ class EventsPersistenceStorageController:
         return res.state
 
     async def _persist_event_batch(
-        self,
-        events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool = False,
+        self, _room_id: str, task: _PersistEventsTask
     ) -> Dict[str, str]:
         """Callback for the _event_persist_queue
 
@@ -466,6 +522,9 @@ class EventsPersistenceStorageController:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
+        events_and_contexts = task.events_and_contexts
+        backfilled = task.backfilled
+
         replaced_events: Dict[str, str] = {}
         if not events_and_contexts:
             return replaced_events
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8a0e4e9589..2ff3d21305 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1007,16 +1007,16 @@ class PersistEventsStore:
         self,
         room_id: str,
         state_delta: DeltaState,
-        stream_id: int,
     ) -> None:
         """Update the current state stored in the datatabase for the given room"""
 
-        await self.db_pool.runInteraction(
-            "update_current_state",
-            self._update_current_state_txn,
-            state_delta_by_room={room_id: state_delta},
-            stream_id=stream_id,
-        )
+        async with self._stream_id_gen.get_next() as stream_ordering:
+            await self.db_pool.runInteraction(
+                "update_current_state",
+                self._update_current_state_txn,
+                state_delta_by_room={room_id: state_delta},
+                stream_id=stream_ordering,
+            )
 
     def _update_current_state_txn(
         self,
diff --git a/tests/test_state.py b/tests/test_state.py
index 7b3f52f68e..6ca8d8f21d 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -195,6 +195,8 @@ class StateTestCase(unittest.TestCase):
                 "get_state_resolution_handler",
                 "get_account_validity_handler",
                 "get_macaroon_generator",
+                "get_instance_name",
+                "get_simple_http_client",
                 "hostname",
             ]
         )