diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index c248fccc81..ea499ce0f8 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -22,6 +22,7 @@ from typing import (
Any,
Awaitable,
Callable,
+ ClassVar,
Collection,
Deque,
Dict,
@@ -33,6 +34,7 @@ from typing import (
Set,
Tuple,
TypeVar,
+ Union,
)
import attr
@@ -111,9 +113,43 @@ times_pruned_extremities = Counter(
@attr.s(auto_attribs=True, slots=True)
-class _EventPersistQueueItem:
+class _PersistEventsTask:
+ """A batch of events to persist, all with the same `backfilled` setting."""
+
+ name: ClassVar[str] = "persist_event_batch" # used for opentracing
+
events_and_contexts: List[Tuple[EventBase, EventContext]]
backfilled: bool
+
+ def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+ """Merges in `task` if it is a batch with the same `backfilled` option."""
+ if (
+ not isinstance(task, _PersistEventsTask)
+ or self.backfilled != task.backfilled
+ ):
+ return False
+
+ self.events_and_contexts.extend(task.events_and_contexts)
+ return True
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _UpdateCurrentStateTask:
+ """A request to recalculate a room's current state; carries no data itself."""
+
+ name: ClassVar[str] = "update_current_state" # used for opentracing
+
+ def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+ """Reports `task` as merged if it is also a current state recalculation."""
+ return isinstance(task, _UpdateCurrentStateTask)
+
+
+_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _EventPersistQueueItem:
+ task: _EventPersistQueueTask
deferred: ObservableDeferred
parent_opentracing_span_contexts: List = attr.ib(factory=list)
@@ -127,14 +163,16 @@ _PersistResult = TypeVar("_PersistResult")
class _EventPeristenceQueue(Generic[_PersistResult]):
- """Queues up events so that they can be persisted in bulk with only one
- concurrent transaction per room.
+ """Queues up tasks so that they can be processed with only one concurrent
+ transaction per room.
+
+ Tasks can be bulk persistence of events or recalculation of a room's current state.
"""
def __init__(
self,
per_item_callback: Callable[
- [List[Tuple[EventBase, EventContext]], bool],
+ [str, _EventPersistQueueTask],
Awaitable[_PersistResult],
],
):
@@ -150,18 +188,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
async def add_to_queue(
self,
room_id: str,
- events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
- backfilled: bool,
+ task: _EventPersistQueueTask,
) -> _PersistResult:
- """Add events to the queue, with the given persist_event options.
+ """Add a task to the queue.
- If we are not already processing events in this room, starts off a background
+ If we are not already processing tasks in this room, starts off a background
process to to so, calling the per_item_callback for each item.
Args:
room_id (str):
- events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
+ task (_EventPersistQueueTask): A _PersistEventsTask or
+ _UpdateCurrentStateTask to process.
Returns:
the result returned by the `_per_item_callback` passed to
@@ -169,26 +206,20 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
- # if the last item in the queue has the same `backfilled` setting,
- # we can just add these new events to that item.
- if queue and queue[-1].backfilled == backfilled:
+ if queue and queue[-1].task.try_merge(task):
+ # the new task has been merged into the last task in the queue
end_item = queue[-1]
else:
- # need to make a new queue item
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
defer.Deferred(), consumeErrors=True
)
end_item = _EventPersistQueueItem(
- events_and_contexts=[],
- backfilled=backfilled,
+ task=task,
deferred=deferred,
)
queue.append(end_item)
- # add our events to the queue item
- end_item.events_and_contexts.extend(events_and_contexts)
-
# also add our active opentracing span to the item so that we get a link back
span = opentracing.active_span()
if span:
@@ -202,7 +233,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
# add another opentracing span which links to the persist trace.
with opentracing.start_active_span_follows_from(
- "persist_event_batch_complete", (end_item.opentracing_span_context,)
+ f"{task.name}_complete", (end_item.opentracing_span_context,)
):
pass
@@ -234,16 +265,14 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
for item in queue:
try:
with opentracing.start_active_span_follows_from(
- "persist_event_batch",
+ item.task.name,
item.parent_opentracing_span_contexts,
inherit_force_tracing=True,
) as scope:
if scope:
item.opentracing_span_context = scope.span.context
- ret = await self._per_item_callback(
- item.events_and_contexts, item.backfilled
- )
+ ret = await self._per_item_callback(room_id, item.task)
except Exception:
with PreserveLoggingContext():
item.deferred.errback()
@@ -292,9 +321,32 @@ class EventsPersistenceStorageController:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
- self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
+ self._event_persist_queue = _EventPeristenceQueue(
+ self._process_event_persist_queue_task
+ )
self._state_resolution_handler = hs.get_state_resolution_handler()
+ async def _process_event_persist_queue_task(
+ self,
+ room_id: str,
+ task: _EventPersistQueueTask,
+ ) -> Dict[str, str]:
+ """Callback for the _event_persist_queue; dispatches on the task's type.
+
+ Returns:
+ For event batches, a dict from each event ID we didn't persist to the ID of
+ the event already persisted with the same TXN ID; for state updates, `{}`.
+ """
+ if isinstance(task, _PersistEventsTask):
+ return await self._persist_event_batch(room_id, task)
+ elif isinstance(task, _UpdateCurrentStateTask):
+ await self._update_current_state(room_id, task)
+ return {}
+ else:
+ raise AssertionError(
+ f"Found an unexpected task type in event persistence queue: {task}"
+ )
+
@opentracing.trace
async def persist_events(
self,
@@ -329,7 +381,8 @@ class EventsPersistenceStorageController:
) -> Dict[str, str]:
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
- room_id, evs_ctxs, backfilled=backfilled
+ room_id,
+ _PersistEventsTask(events_and_contexts=evs_ctxs, backfilled=backfilled),
)
ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
@@ -376,7 +429,10 @@ class EventsPersistenceStorageController:
# event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events.)
replaced_events = await self._event_persist_queue.add_to_queue(
- event.room_id, [(event, context)], backfilled=backfilled
+ event.room_id,
+ _PersistEventsTask(
+ events_and_contexts=[(event, context)], backfilled=backfilled
+ ),
)
replaced_event = replaced_events.get(event.event_id)
if replaced_event:
@@ -391,20 +447,22 @@ class EventsPersistenceStorageController:
async def update_current_state(self, room_id: str) -> None:
"""Recalculate the current state for a room, and persist it"""
+ await self._event_persist_queue.add_to_queue(
+ room_id,
+ _UpdateCurrentStateTask(),
+ )
+
+ async def _update_current_state(
+ self, room_id: str, _task: _UpdateCurrentStateTask
+ ) -> None:
+ """Callback for the _event_persist_queue
+
+ Recalculates the room's current state and persists the delta; `_task` is unused.
+ """
state = await self._calculate_current_state(room_id)
delta = await self._calculate_state_delta(room_id, state)
- # TODO(faster_joins): get a real stream ordering, to make this work correctly
- # across workers.
- # https://github.com/matrix-org/synapse/issues/12994
- #
- # TODO(faster_joins): this can race against event persistence, in which case we
- # will end up with incorrect state. Perhaps we should make this a job we
- # farm out to the event persister thread, somehow.
- # https://github.com/matrix-org/synapse/issues/13007
- #
- stream_id = self.main_store.get_room_max_stream_ordering()
- await self.persist_events_store.update_current_state(room_id, delta, stream_id)
+ await self.persist_events_store.update_current_state(room_id, delta)
async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
"""Calculate the current state of a room, based on the forward extremities
@@ -449,9 +507,7 @@ class EventsPersistenceStorageController:
return res.state
async def _persist_event_batch(
- self,
- events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool = False,
+ self, _room_id: str, task: _PersistEventsTask
) -> Dict[str, str]:
"""Callback for the _event_persist_queue
@@ -466,6 +522,9 @@ class EventsPersistenceStorageController:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
+ events_and_contexts = task.events_and_contexts
+ backfilled = task.backfilled
+
replaced_events: Dict[str, str] = {}
if not events_and_contexts:
return replaced_events
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8a0e4e9589..2ff3d21305 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1007,16 +1007,16 @@ class PersistEventsStore:
self,
room_id: str,
state_delta: DeltaState,
- stream_id: int,
) -> None:
"""Update the current state stored in the datatabase for the given room"""
- await self.db_pool.runInteraction(
- "update_current_state",
- self._update_current_state_txn,
- state_delta_by_room={room_id: state_delta},
- stream_id=stream_id,
- )
+ async with self._stream_id_gen.get_next() as stream_ordering:
+ await self.db_pool.runInteraction(
+ "update_current_state",
+ self._update_current_state_txn,
+ state_delta_by_room={room_id: state_delta},
+ stream_id=stream_ordering,
+ )
def _update_current_state_txn(
self,
|