From 1391a76cd2b287daebe61f7d8ea03b258ed522f5 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 7 Jul 2022 13:19:31 +0100 Subject: Faster room joins: fix race in recalculation of current room state (#13151) Bounce recalculation of current state to the correct event persister and move recalculation of current state into the event persistence queue, to avoid concurrent updates to a room's current state. Also give recalculation of a room's current state a real stream ordering. Signed-off-by: Sean Quah --- synapse/handlers/federation.py | 9 +- synapse/replication/http/__init__.py | 2 + synapse/replication/http/state.py | 75 ++++++++++++++ synapse/state/__init__.py | 25 +++++ synapse/storage/controllers/persist_events.py | 141 ++++++++++++++++++-------- synapse/storage/databases/main/events.py | 14 +-- 6 files changed, 211 insertions(+), 55 deletions(-) create mode 100644 synapse/replication/http/state.py (limited to 'synapse') diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3c44b4bf86..e2564e9340 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1559,14 +1559,9 @@ class FederationHandler: # all the events are updated, so we can update current state and # clear the lazy-loading flag. logger.info("Updating current state for %s", room_id) - # TODO(faster_joins): support workers + # TODO(faster_joins): notify workers in notify_room_un_partial_stated # https://github.com/matrix-org/synapse/issues/12994 - assert ( - self._storage_controllers.persistence is not None - ), "worker-mode deployments not currently supported here" - await self._storage_controllers.persistence.update_current_state( - room_id - ) + await self.state_handler.update_current_state(room_id) logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index aec040ee19..53aa7fa4c6 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -25,6 +25,7 @@ from synapse.replication.http import ( push, register, send_event, + state, streams, ) @@ -48,6 +49,7 @@ class ReplicationRestResource(JsonResource): streams.register_servlets(hs, self) account_data.register_servlets(hs, self) push.register_servlets(hs, self) + state.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py new file mode 100644 index 0000000000..838b7584e5 --- /dev/null +++ b/synapse/replication/http/state.py @@ -0,0 +1,75 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.server import Request + +from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer +from synapse.replication.http._base import ReplicationEndpoint +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint): + """Recalculates the current state for a room, and persists it. + + The API looks like: + + POST /_synapse/replication/update_current_state/:room_id + + {} + + 200 OK + + {} + """ + + NAME = "update_current_state" + PATH_ARGS = ("room_id",) + + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self._state_handler = hs.get_state_handler() + self._events_shard_config = hs.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() + + @staticmethod + async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] + return {} + + async def _handle_request( # type: ignore[override] + self, request: Request, room_id: str + ) -> Tuple[int, JsonDict]: + writer_instance = self._events_shard_config.get_instance(room_id) + if writer_instance != self._instance_name: + raise SynapseError( + 400, "/update_current_state request was routed to the wrong worker" + ) + + await self._state_handler.update_current_state(room_id) + + return 200, {} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.get_instance_name() in hs.config.worker.writers.events: + ReplicationUpdateCurrentStateRestServlet(hs).register(http_server) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index d5cbdb3eef..781d9f06da 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -43,6 +43,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import ContextResourceUsage +from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo @@ -129,6 +130,12 @@ class StateHandler: self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() self._storage_controllers = hs.get_storage_controllers() + self._events_shard_config = hs.config.worker.events_shard_config + self._instance_name = hs.get_instance_name() + + self._update_current_state_client = ( + ReplicationUpdateCurrentStateRestServlet.make_client(hs) + ) async def get_current_state_ids( self, @@ -423,6 +430,24 @@ class StateHandler: return {key: state_map[ev_id] for key, ev_id in new_state.items()} + async def update_current_state(self, room_id: str) -> None: + """Recalculates the current state for a room, and persists it. + + Raises: + SynapseError(502): if all attempts to connect to the event persister worker + fail + """ + writer_instance = self._events_shard_config.get_instance(room_id) + if writer_instance != self._instance_name: + await self._update_current_state_client( + instance_name=writer_instance, + room_id=room_id, + ) + return + + assert self._storage_controllers.persistence is not None + await self._storage_controllers.persistence.update_current_state(room_id) + @attr.s(slots=True, auto_attribs=True) class _StateResMetrics: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index c248fccc81..ea499ce0f8 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -22,6 +22,7 @@ from typing import ( Any, Awaitable, Callable, + ClassVar, Collection, Deque, Dict, @@ -33,6 +34,7 @@ from typing import ( Set, Tuple, TypeVar, + Union, ) import attr @@ -111,9 +113,43 @@ times_pruned_extremities = Counter( @attr.s(auto_attribs=True, slots=True) -class _EventPersistQueueItem: +class _PersistEventsTask: + """A batch of events to persist.""" + + name: ClassVar[str] = "persist_event_batch" # used for opentracing + events_and_contexts: List[Tuple[EventBase, EventContext]] backfilled: bool + + def try_merge(self, task: "_EventPersistQueueTask") -> bool: + """Batches events with the same backfilled option together.""" + if ( + not isinstance(task, _PersistEventsTask) + or self.backfilled != task.backfilled + ): + return False + + self.events_and_contexts.extend(task.events_and_contexts) + return True + + +@attr.s(auto_attribs=True, slots=True) +class _UpdateCurrentStateTask: + """A room whose current state needs recalculating.""" + + name: ClassVar[str] = "update_current_state" # used for opentracing + + def try_merge(self, task: "_EventPersistQueueTask") -> bool: + """Deduplicates consecutive recalculations of current state.""" + return isinstance(task, _UpdateCurrentStateTask) + + +_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask] + + +@attr.s(auto_attribs=True, slots=True) +class _EventPersistQueueItem: + task: _EventPersistQueueTask deferred: ObservableDeferred parent_opentracing_span_contexts: List = attr.ib(factory=list) @@ -127,14 +163,16 @@ _PersistResult = TypeVar("_PersistResult") class _EventPeristenceQueue(Generic[_PersistResult]): - """Queues up events so that they can be persisted in bulk with only one - concurrent transaction per room. + """Queues up tasks so that they can be processed with only one concurrent + transaction per room. + + Tasks can be bulk persistence of events or recalculation of a room's current state. """ def __init__( self, per_item_callback: Callable[ - [List[Tuple[EventBase, EventContext]], bool], + [str, _EventPersistQueueTask], Awaitable[_PersistResult], ], ): @@ -150,18 +188,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]): async def add_to_queue( self, room_id: str, - events_and_contexts: Iterable[Tuple[EventBase, EventContext]], - backfilled: bool, + task: _EventPersistQueueTask, ) -> _PersistResult: - """Add events to the queue, with the given persist_event options. + """Add a task to the queue. - If we are not already processing events in this room, starts off a background + If we are not already processing tasks in this room, starts off a background process to to so, calling the per_item_callback for each item. Args: room_id (str): - events_and_contexts (list[(EventBase, EventContext)]): - backfilled (bool): + task (_EventPersistQueueTask): A _PersistEventsTask or + _UpdateCurrentStateTask to process. Returns: the result returned by the `_per_item_callback` passed to @@ -169,26 +206,20 @@ class _EventPeristenceQueue(Generic[_PersistResult]): """ queue = self._event_persist_queues.setdefault(room_id, deque()) - # if the last item in the queue has the same `backfilled` setting, - # we can just add these new events to that item. - if queue and queue[-1].backfilled == backfilled: + if queue and queue[-1].task.try_merge(task): + # the new task has been merged into the last task in the queue end_item = queue[-1] else: - # need to make a new queue item deferred: ObservableDeferred[_PersistResult] = ObservableDeferred( defer.Deferred(), consumeErrors=True ) end_item = _EventPersistQueueItem( - events_and_contexts=[], - backfilled=backfilled, + task=task, deferred=deferred, ) queue.append(end_item) - # add our events to the queue item - end_item.events_and_contexts.extend(events_and_contexts) - # also add our active opentracing span to the item so that we get a link back span = opentracing.active_span() if span: @@ -202,7 +233,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): # add another opentracing span which links to the persist trace. with opentracing.start_active_span_follows_from( - "persist_event_batch_complete", (end_item.opentracing_span_context,) + f"{task.name}_complete", (end_item.opentracing_span_context,) ): pass @@ -234,16 +265,14 @@ class _EventPeristenceQueue(Generic[_PersistResult]): for item in queue: try: with opentracing.start_active_span_follows_from( - "persist_event_batch", + item.task.name, item.parent_opentracing_span_contexts, inherit_force_tracing=True, ) as scope: if scope: item.opentracing_span_context = scope.span.context - ret = await self._per_item_callback( - item.events_and_contexts, item.backfilled - ) + ret = await self._per_item_callback(room_id, item.task) except Exception: with PreserveLoggingContext(): item.deferred.errback() @@ -292,9 +321,32 @@ class EventsPersistenceStorageController: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id - self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch) + self._event_persist_queue = _EventPeristenceQueue( + self._process_event_persist_queue_task + ) self._state_resolution_handler = hs.get_state_resolution_handler() + async def _process_event_persist_queue_task( + self, + room_id: str, + task: _EventPersistQueueTask, + ) -> Dict[str, str]: + """Callback for the _event_persist_queue + + Returns: + A dictionary of event ID to event ID we didn't persist as we already + had another event persisted with the same TXN ID. + """ + if isinstance(task, _PersistEventsTask): + return await self._persist_event_batch(room_id, task) + elif isinstance(task, _UpdateCurrentStateTask): + await self._update_current_state(room_id, task) + return {} + else: + raise AssertionError( + f"Found an unexpected task type in event persistence queue: {task}" + ) + @opentracing.trace async def persist_events( self, @@ -329,7 +381,8 @@ class EventsPersistenceStorageController: ) -> Dict[str, str]: room_id, evs_ctxs = item return await self._event_persist_queue.add_to_queue( - room_id, evs_ctxs, backfilled=backfilled + room_id, + _PersistEventsTask(events_and_contexts=evs_ctxs, backfilled=backfilled), ) ret_vals = await yieldable_gather_results(enqueue, partitioned.items()) @@ -376,7 +429,10 @@ class EventsPersistenceStorageController: # event was deduplicated. (The dict may also include other entries if # the event was persisted in a batch with other events.) replaced_events = await self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], backfilled=backfilled + event.room_id, + _PersistEventsTask( + events_and_contexts=[(event, context)], backfilled=backfilled + ), ) replaced_event = replaced_events.get(event.event_id) if replaced_event: @@ -391,20 +447,22 @@ class EventsPersistenceStorageController: async def update_current_state(self, room_id: str) -> None: """Recalculate the current state for a room, and persist it""" + await self._event_persist_queue.add_to_queue( + room_id, + _UpdateCurrentStateTask(), + ) + + async def _update_current_state( + self, room_id: str, _task: _UpdateCurrentStateTask + ) -> None: + """Callback for the _event_persist_queue + + Recalculates the current state for a room, and persists it. + """ state = await self._calculate_current_state(room_id) delta = await self._calculate_state_delta(room_id, state) - # TODO(faster_joins): get a real stream ordering, to make this work correctly - # across workers. - # https://github.com/matrix-org/synapse/issues/12994 - # - # TODO(faster_joins): this can race against event persistence, in which case we - # will end up with incorrect state. Perhaps we should make this a job we - # farm out to the event persister thread, somehow. - # https://github.com/matrix-org/synapse/issues/13007 - # - stream_id = self.main_store.get_room_max_stream_ordering() - await self.persist_events_store.update_current_state(room_id, delta, stream_id) + await self.persist_events_store.update_current_state(room_id, delta) async def _calculate_current_state(self, room_id: str) -> StateMap[str]: """Calculate the current state of a room, based on the forward extremities @@ -449,9 +507,7 @@ class EventsPersistenceStorageController: return res.state async def _persist_event_batch( - self, - events_and_contexts: List[Tuple[EventBase, EventContext]], - backfilled: bool = False, + self, _room_id: str, task: _PersistEventsTask ) -> Dict[str, str]: """Callback for the _event_persist_queue @@ -466,6 +522,9 @@ class EventsPersistenceStorageController: PartialStateConflictError: if attempting to persist a partial state event in a room that has been un-partial stated. """ + events_and_contexts = task.events_and_contexts + backfilled = task.backfilled + replaced_events: Dict[str, str] = {} if not events_and_contexts: return replaced_events diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 8a0e4e9589..2ff3d21305 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1007,16 +1007,16 @@ class PersistEventsStore: self, room_id: str, state_delta: DeltaState, - stream_id: int, ) -> None: """Update the current state stored in the datatabase for the given room""" - await self.db_pool.runInteraction( - "update_current_state", - self._update_current_state_txn, - state_delta_by_room={room_id: state_delta}, - stream_id=stream_id, - ) + async with self._stream_id_gen.get_next() as stream_ordering: + await self.db_pool.runInteraction( + "update_current_state", + self._update_current_state_txn, + state_delta_by_room={room_id: state_delta}, + stream_id=stream_ordering, + ) def _update_current_state_txn( self, -- cgit 1.4.1