diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3c44b4bf86..e2564e9340 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1559,14 +1559,9 @@ class FederationHandler:
# all the events are updated, so we can update current state and
# clear the lazy-loading flag.
logger.info("Updating current state for %s", room_id)
- # TODO(faster_joins): support workers
+ # TODO(faster_joins): notify workers in notify_room_un_partial_stated
# https://github.com/matrix-org/synapse/issues/12994
- assert (
- self._storage_controllers.persistence is not None
- ), "worker-mode deployments not currently supported here"
- await self._storage_controllers.persistence.update_current_state(
- room_id
- )
+ await self.state_handler.update_current_state(room_id)
logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index aec040ee19..53aa7fa4c6 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -25,6 +25,7 @@ from synapse.replication.http import (
push,
register,
send_event,
+ state,
streams,
)
@@ -48,6 +49,7 @@ class ReplicationRestResource(JsonResource):
streams.register_servlets(hs, self)
account_data.register_servlets(hs, self)
push.register_servlets(hs, self)
+ state.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py
new file mode 100644
index 0000000000..838b7584e5
--- /dev/null
+++ b/synapse/replication/http/state.py
@@ -0,0 +1,75 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.server import Request
+
+from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
+ """Recalculates the current state for a room, and persists it.
+
+ The API looks like:
+
+ POST /_synapse/replication/update_current_state/:room_id
+
+ {}
+
+ 200 OK
+
+ {}
+ """
+
+ NAME = "update_current_state"
+ PATH_ARGS = ("room_id",)
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self._state_handler = hs.get_state_handler()
+ self._events_shard_config = hs.config.worker.events_shard_config
+ self._instance_name = hs.get_instance_name()
+
+ @staticmethod
+ async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
+ return {}
+
+ async def _handle_request( # type: ignore[override]
+ self, request: Request, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ writer_instance = self._events_shard_config.get_instance(room_id)
+ if writer_instance != self._instance_name:
+ raise SynapseError(
+ 400, "/update_current_state request was routed to the wrong worker"
+ )
+
+ await self._state_handler.update_current_state(room_id)
+
+ return 200, {}
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ ReplicationUpdateCurrentStateRestServlet(hs).register(http_server)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index d5cbdb3eef..781d9f06da 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -43,6 +43,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import ContextResourceUsage
+from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
@@ -129,6 +130,12 @@ class StateHandler:
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage_controllers = hs.get_storage_controllers()
+ self._events_shard_config = hs.config.worker.events_shard_config
+ self._instance_name = hs.get_instance_name()
+
+ self._update_current_state_client = (
+ ReplicationUpdateCurrentStateRestServlet.make_client(hs)
+ )
async def get_current_state_ids(
self,
@@ -423,6 +430,24 @@ class StateHandler:
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
+ async def update_current_state(self, room_id: str) -> None:
+ """Recalculates the current state for a room, and persists it.
+
+ Raises:
+ SynapseError(502): if all attempts to connect to the event persister worker
+ fail
+ """
+ writer_instance = self._events_shard_config.get_instance(room_id)
+ if writer_instance != self._instance_name:
+ await self._update_current_state_client(
+ instance_name=writer_instance,
+ room_id=room_id,
+ )
+ return
+
+ assert self._storage_controllers.persistence is not None
+ await self._storage_controllers.persistence.update_current_state(room_id)
+
@attr.s(slots=True, auto_attribs=True)
class _StateResMetrics:
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index c248fccc81..ea499ce0f8 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -22,6 +22,7 @@ from typing import (
Any,
Awaitable,
Callable,
+ ClassVar,
Collection,
Deque,
Dict,
@@ -33,6 +34,7 @@ from typing import (
Set,
Tuple,
TypeVar,
+ Union,
)
import attr
@@ -111,9 +113,43 @@ times_pruned_extremities = Counter(
@attr.s(auto_attribs=True, slots=True)
-class _EventPersistQueueItem:
+class _PersistEventsTask:
+ """A batch of events to persist."""
+
+ name: ClassVar[str] = "persist_event_batch" # used for opentracing
+
events_and_contexts: List[Tuple[EventBase, EventContext]]
backfilled: bool
+
+ def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+ """Batches events with the same backfilled option together."""
+ if (
+ not isinstance(task, _PersistEventsTask)
+ or self.backfilled != task.backfilled
+ ):
+ return False
+
+ self.events_and_contexts.extend(task.events_and_contexts)
+ return True
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _UpdateCurrentStateTask:
+ """A room whose current state needs recalculating."""
+
+ name: ClassVar[str] = "update_current_state" # used for opentracing
+
+ def try_merge(self, task: "_EventPersistQueueTask") -> bool:
+ """Deduplicates consecutive recalculations of current state."""
+ return isinstance(task, _UpdateCurrentStateTask)
+
+
+_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
+
+
+@attr.s(auto_attribs=True, slots=True)
+class _EventPersistQueueItem:
+ task: _EventPersistQueueTask
deferred: ObservableDeferred
parent_opentracing_span_contexts: List = attr.ib(factory=list)
@@ -127,14 +163,16 @@ _PersistResult = TypeVar("_PersistResult")
class _EventPeristenceQueue(Generic[_PersistResult]):
- """Queues up events so that they can be persisted in bulk with only one
- concurrent transaction per room.
+ """Queues up tasks so that they can be processed with only one concurrent
+ transaction per room.
+
+ Tasks can be bulk persistence of events or recalculation of a room's current state.
"""
def __init__(
self,
per_item_callback: Callable[
- [List[Tuple[EventBase, EventContext]], bool],
+ [str, _EventPersistQueueTask],
Awaitable[_PersistResult],
],
):
@@ -150,18 +188,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
async def add_to_queue(
self,
room_id: str,
- events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
- backfilled: bool,
+ task: _EventPersistQueueTask,
) -> _PersistResult:
- """Add events to the queue, with the given persist_event options.
+ """Add a task to the queue.
- If we are not already processing events in this room, starts off a background
+ If we are not already processing tasks in this room, starts off a background
process to to so, calling the per_item_callback for each item.
Args:
room_id (str):
- events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
+ task (_EventPersistQueueTask): A _PersistEventsTask or
+ _UpdateCurrentStateTask to process.
Returns:
the result returned by the `_per_item_callback` passed to
@@ -169,26 +206,20 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
- # if the last item in the queue has the same `backfilled` setting,
- # we can just add these new events to that item.
- if queue and queue[-1].backfilled == backfilled:
+ if queue and queue[-1].task.try_merge(task):
+ # the new task has been merged into the last task in the queue
end_item = queue[-1]
else:
- # need to make a new queue item
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
defer.Deferred(), consumeErrors=True
)
end_item = _EventPersistQueueItem(
- events_and_contexts=[],
- backfilled=backfilled,
+ task=task,
deferred=deferred,
)
queue.append(end_item)
- # add our events to the queue item
- end_item.events_and_contexts.extend(events_and_contexts)
-
# also add our active opentracing span to the item so that we get a link back
span = opentracing.active_span()
if span:
@@ -202,7 +233,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
# add another opentracing span which links to the persist trace.
with opentracing.start_active_span_follows_from(
- "persist_event_batch_complete", (end_item.opentracing_span_context,)
+ f"{task.name}_complete", (end_item.opentracing_span_context,)
):
pass
@@ -234,16 +265,14 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
for item in queue:
try:
with opentracing.start_active_span_follows_from(
- "persist_event_batch",
+ item.task.name,
item.parent_opentracing_span_contexts,
inherit_force_tracing=True,
) as scope:
if scope:
item.opentracing_span_context = scope.span.context
- ret = await self._per_item_callback(
- item.events_and_contexts, item.backfilled
- )
+ ret = await self._per_item_callback(room_id, item.task)
except Exception:
with PreserveLoggingContext():
item.deferred.errback()
@@ -292,9 +321,32 @@ class EventsPersistenceStorageController:
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
- self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch)
+ self._event_persist_queue = _EventPeristenceQueue(
+ self._process_event_persist_queue_task
+ )
self._state_resolution_handler = hs.get_state_resolution_handler()
+ async def _process_event_persist_queue_task(
+ self,
+ room_id: str,
+ task: _EventPersistQueueTask,
+ ) -> Dict[str, str]:
+ """Callback for the _event_persist_queue
+
+ Returns:
+ A dictionary of event ID to event ID we didn't persist as we already
+ had another event persisted with the same TXN ID.
+ """
+ if isinstance(task, _PersistEventsTask):
+ return await self._persist_event_batch(room_id, task)
+ elif isinstance(task, _UpdateCurrentStateTask):
+ await self._update_current_state(room_id, task)
+ return {}
+ else:
+ raise AssertionError(
+ f"Found an unexpected task type in event persistence queue: {task}"
+ )
+
@opentracing.trace
async def persist_events(
self,
@@ -329,7 +381,8 @@ class EventsPersistenceStorageController:
) -> Dict[str, str]:
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
- room_id, evs_ctxs, backfilled=backfilled
+ room_id,
+ _PersistEventsTask(events_and_contexts=evs_ctxs, backfilled=backfilled),
)
ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
@@ -376,7 +429,10 @@ class EventsPersistenceStorageController:
# event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events.)
replaced_events = await self._event_persist_queue.add_to_queue(
- event.room_id, [(event, context)], backfilled=backfilled
+ event.room_id,
+ _PersistEventsTask(
+ events_and_contexts=[(event, context)], backfilled=backfilled
+ ),
)
replaced_event = replaced_events.get(event.event_id)
if replaced_event:
@@ -391,20 +447,22 @@ class EventsPersistenceStorageController:
async def update_current_state(self, room_id: str) -> None:
"""Recalculate the current state for a room, and persist it"""
+ await self._event_persist_queue.add_to_queue(
+ room_id,
+ _UpdateCurrentStateTask(),
+ )
+
+ async def _update_current_state(
+ self, room_id: str, _task: _UpdateCurrentStateTask
+ ) -> None:
+ """Callback for the _event_persist_queue
+
+ Recalculates the current state for a room, and persists it.
+ """
state = await self._calculate_current_state(room_id)
delta = await self._calculate_state_delta(room_id, state)
- # TODO(faster_joins): get a real stream ordering, to make this work correctly
- # across workers.
- # https://github.com/matrix-org/synapse/issues/12994
- #
- # TODO(faster_joins): this can race against event persistence, in which case we
- # will end up with incorrect state. Perhaps we should make this a job we
- # farm out to the event persister thread, somehow.
- # https://github.com/matrix-org/synapse/issues/13007
- #
- stream_id = self.main_store.get_room_max_stream_ordering()
- await self.persist_events_store.update_current_state(room_id, delta, stream_id)
+ await self.persist_events_store.update_current_state(room_id, delta)
async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
"""Calculate the current state of a room, based on the forward extremities
@@ -449,9 +507,7 @@ class EventsPersistenceStorageController:
return res.state
async def _persist_event_batch(
- self,
- events_and_contexts: List[Tuple[EventBase, EventContext]],
- backfilled: bool = False,
+ self, _room_id: str, task: _PersistEventsTask
) -> Dict[str, str]:
"""Callback for the _event_persist_queue
@@ -466,6 +522,9 @@ class EventsPersistenceStorageController:
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
+ events_and_contexts = task.events_and_contexts
+ backfilled = task.backfilled
+
replaced_events: Dict[str, str] = {}
if not events_and_contexts:
return replaced_events
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 8a0e4e9589..2ff3d21305 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1007,16 +1007,16 @@ class PersistEventsStore:
self,
room_id: str,
state_delta: DeltaState,
- stream_id: int,
) -> None:
"""Update the current state stored in the datatabase for the given room"""
- await self.db_pool.runInteraction(
- "update_current_state",
- self._update_current_state_txn,
- state_delta_by_room={room_id: state_delta},
- stream_id=stream_id,
- )
+ async with self._stream_id_gen.get_next() as stream_ordering:
+ await self.db_pool.runInteraction(
+ "update_current_state",
+ self._update_current_state_txn,
+ state_delta_by_room={room_id: state_delta},
+ stream_id=stream_ordering,
+ )
def _update_current_state_txn(
self,
|