summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/persist_events.py141
-rw-r--r--synapse/storage/databases/main/events.py14
2 files changed, 107 insertions, 48 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index c248fccc81..ea499ce0f8 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -22,6 +22,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    ClassVar,
     Collection,
     Deque,
     Dict,
@@ -33,6 +34,7 @@ from typing import (
     Set,
     Tuple,
     TypeVar,
+    Union,
 )
 
 import attr
@@ -111,9 +113,43 @@ times_pruned_extremities = Counter(
 
 
 @attr.s(auto_attribs=True, slots=True)
-class _EventPersistQueueItem:
+class _PersistEventsTask:
+    """A batch of events to persist."""
+
+    name: ClassVar[str] = "persist_event_batch"  # used for opentracing
+
     events_and_contexts: List[Tuple[EventBase, EventContext]]
     backfilled: bool
+
+    def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+        """Batches events with the same backfilled option together."""
+        if (
+            not isinstance(task, _PersistEventsTask)
+            or self.backfilled != task.backfilled
+        ):
+            return False
+
+        self.events_and_contexts.extend(task.events_and_contexts)
+        return True
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _UpdateCurrentStateTask:
+    """A room whose current state needs recalculating."""
+
+    name: ClassVar[str] = "update_current_state"  # used for opentracing
+
+    def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+        """Deduplicates consecutive recalculations of current state."""
+        return isinstance(task, _UpdateCurrentStateTask)
+
+
+_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _EventPersistQueueItem:
+    task: _EventPersistQueueTask
     deferred: ObservableDeferred
 
     parent_opentracing_span_contexts: List = attr.ib(factory=list)
@@ -127,14 +163,16 @@ _PersistResult = TypeVar("_PersistResult")
 
 
 class _EventPeristenceQueue(Generic[_PersistResult]):
-    """Queues up events so that they can be persisted in bulk with only one
-    concurrent transaction per room.
+    """Queues up tasks so that they can be processed with only one concurrent
+    transaction per room.
+
+    Tasks can be bulk persistence of events or recalculation of a room's current state.
     """
 
     def __init__(
         self,
         per_item_callback: Callable[
-            [List[Tuple[EventBase, EventContext]], bool],
+            [str, _EventPersistQueueTask],
             Awaitable[_PersistResult],
         ],
     ):
@@ -150,18 +188,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
     async def add_to_queue(
         self,
         room_id: str,
-        events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
-        backfilled: bool,
+        task: _EventPersistQueueTask,
     ) -> _PersistResult:
-        """Add events to the queue, with the given persist_event options.
+        """Add a task to the queue.
 
-        If we are not already processing events in this room, starts off a background
+        If we are not already processing tasks in this room, starts off a background
         process to to so, calling the per_item_callback for each item.
 
         Args:
             room_id (str):
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
+            task (_EventPersistQueueTask): A _PersistEventsTask or
+                _UpdateCurrentStateTask to process.
 
         Returns:
             the result returned by the `_per_item_callback` passed to
@@ -169,26 +206,20 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
 
-        # if the last item in the queue has the same `backfilled` setting,
-        # we can just add these new events to that item.
-        if queue and queue[-1].backfilled == backfilled:
+        if queue and queue[-1].task.try_merge(task):
+            # the new task has been merged into the last task in the queue
             end_item = queue[-1]
         else:
-            # need to make a new queue item
             deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
                 defer.Deferred(), consumeErrors=True
             )
 
             end_item = _EventPersistQueueItem(
-                events_and_contexts=[],
-                backfilled=backfilled,
+                task=task,
                 deferred=deferred,
             )
             queue.append(end_item)
 
-        # add our events to the queue item
-        end_item.events_and_contexts.extend(events_and_contexts)
-
         # also add our active opentracing span to the item so that we get a link back
         span = opentracing.active_span()
         if span:
@@ -202,7 +233,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
 
         # add another opentracing span which links to the persist trace.
         with opentracing.start_active_span_follows_from(
-            "persist_event_batch_complete", (end_item.opentracing_span_context,)
+            f"{task.name}_complete", (end_item.opentracing_span_context,)
         ):
             pass
 
@@ -234,16 +265,14 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                 for item in queue:
                     try:
                         with opentracing.start_active_span_follows_from(
-                            "persist_event_batch",
+                            item.task.name,
                             item.parent_opentracing_span_contexts,
                             inherit_force_tracing=True,
                         ) as scope:
                             if scope:
                                 item.opentracing_span_context = scope.span.context
 
-                            ret = await self._per_item_callback(
-                                item.events_and_contexts, item.backfilled
-                            )
+                            ret = await self._per_item_callback(room_id, item.task)
                     except Exception:
                         with PreserveLoggingContext():
                             item.deferred.errback()
@@ -292,9 +321,32 @@ class EventsPersistenceStorageController:
         self._clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
         self.is_mine_id = hs.is_mine_id
-        self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
+        self._event_persist_queue = _EventPeristenceQueue(
+            self._process_event_persist_queue_task
+        )
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
+    async def _process_event_persist_queue_task(
+        self,
+        room_id: str,
+        task: _EventPersistQueueTask,
+    ) -> Dict[str, str]:
+        """Callback for the _event_persist_queue
+
+        Returns:
+            A dictionary of event ID to event ID we didn't persist as we already
+            had another event persisted with the same TXN ID.
+        """
+        if isinstance(task, _PersistEventsTask):
+            return await self._persist_event_batch(room_id, task)
+        elif isinstance(task, _UpdateCurrentStateTask):
+            await self._update_current_state(room_id, task)
+            return {}
+        else:
+            raise AssertionError(
+                f"Found an unexpected task type in event persistence queue: {task}"
+            )
+
     @opentracing.trace
     async def persist_events(
         self,
@@ -329,7 +381,8 @@ class EventsPersistenceStorageController:
         ) -> Dict[str, str]:
             room_id, evs_ctxs = item
             return await self._event_persist_queue.add_to_queue(
-                room_id, evs_ctxs, backfilled=backfilled
+                room_id,
+                _PersistEventsTask(events_and_contexts=evs_ctxs, backfilled=backfilled),
             )
 
         ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
@@ -376,7 +429,10 @@ class EventsPersistenceStorageController:
         # event was deduplicated. (The dict may also include other entries if
         # the event was persisted in a batch with other events.)
         replaced_events = await self._event_persist_queue.add_to_queue(
-            event.room_id, [(event, context)], backfilled=backfilled
+            event.room_id,
+            _PersistEventsTask(
+                events_and_contexts=[(event, context)], backfilled=backfilled
+            ),
         )
         replaced_event = replaced_events.get(event.event_id)
         if replaced_event:
@@ -391,20 +447,22 @@ class EventsPersistenceStorageController:
 
     async def update_current_state(self, room_id: str) -> None:
         """Recalculate the current state for a room, and persist it"""
+        await self._event_persist_queue.add_to_queue(
+            room_id,
+            _UpdateCurrentStateTask(),
+        )
+
+    async def _update_current_state(
+        self, room_id: str, _task: _UpdateCurrentStateTask
+    ) -> None:
+        """Callback for the _event_persist_queue
+
+        Recalculates the current state for a room, and persists it.
+        """
         state = await self._calculate_current_state(room_id)
         delta = await self._calculate_state_delta(room_id, state)
 
-        # TODO(faster_joins): get a real stream ordering, to make this work correctly
-        #    across workers.
-        #    https://github.com/matrix-org/synapse/issues/12994
-        #
-        # TODO(faster_joins): this can race against event persistence, in which case we
-        #    will end up with incorrect state. Perhaps we should make this a job we
-        #    farm out to the event persister thread, somehow.
-        #    https://github.com/matrix-org/synapse/issues/13007
-        #
-        stream_id = self.main_store.get_room_max_stream_ordering()
-        await self.persist_events_store.update_current_state(room_id, delta, stream_id)
+        await self.persist_events_store.update_current_state(room_id, delta)
 
     async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
         """Calculate the current state of a room, based on the forward extremities
@@ -449,9 +507,7 @@ class EventsPersistenceStorageController:
         return res.state
 
     async def _persist_event_batch(
-        self,
-        events_and_contexts: List[Tuple[EventBase, EventContext]],
-        backfilled: bool = False,
+        self, _room_id: str, task: _PersistEventsTask
     ) -> Dict[str, str]:
         """Callback for the _event_persist_queue
 
@@ -466,6 +522,9 @@ class EventsPersistenceStorageController:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
+        events_and_contexts = task.events_and_contexts
+        backfilled = task.backfilled
+
         replaced_events: Dict[str, str] = {}
         if not events_and_contexts:
             return replaced_events
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8a0e4e9589..2ff3d21305 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1007,16 +1007,16 @@ class PersistEventsStore:
         self,
         room_id: str,
         state_delta: DeltaState,
-        stream_id: int,
     ) -> None:
         """Update the current state stored in the datatabase for the given room"""
 
-        await self.db_pool.runInteraction(
-            "update_current_state",
-            self._update_current_state_txn,
-            state_delta_by_room={room_id: state_delta},
-            stream_id=stream_id,
-        )
+        async with self._stream_id_gen.get_next() as stream_ordering:
+            await self.db_pool.runInteraction(
+                "update_current_state",
+                self._update_current_state_txn,
+                state_delta_by_room={room_id: state_delta},
+                stream_id=stream_ordering,
+            )
 
     def _update_current_state_txn(
         self,