From 1531b214fc57714c14046a8f66c7b5fe5ec5dcdd Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 22 May 2020 14:21:54 +0100 Subject: Add ability to wait for replication streams (#7542) The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room). Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on. People probably want to look at this commit by commit. --- synapse/replication/http/federation.py | 13 +++-- synapse/replication/http/membership.py | 14 +++--- synapse/replication/http/send_event.py | 4 +- synapse/replication/http/streams.py | 5 +- synapse/replication/tcp/client.py | 90 +++++++++++++++++++++++++++++++++- 5 files changed, 108 insertions(+), 18 deletions(-) (limited to 'synapse/replication') diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 7e23b565b9..c287c4e269 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """Handles events newly received from federation, including persisting and - notifying. + notifying. Returns the maximum stream ID of the persisted events. The API looks like: @@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): "context": { .. serialized event context .. }, }], "backfilled": false + } + + 200 OK + + { + "max_stream_id": 32443, + } """ NAME = "fed_send_events" @@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - await self.federation_handler.persist_events_and_notify( + max_stream_id = await self.federation_handler.persist_events_and_notify( event_and_contexts, backfilled ) - return 200, {} + return 200, {"max_stream_id": max_stream_id} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 3577611fd7..050fd34562 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): logger.info("remote_join: %s into room: %s", user_id, room_id) - await self.federation_handler.do_invite_join( + event_id, stream_id = await self.federation_handler.do_invite_join( remote_room_hosts, room_id, user_id, event_content ) - return 200, {} + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) try: - event = await self.federation_handler.do_remotely_reject_invite( + event, stream_id = await self.federation_handler.do_remotely_reject_invite( remote_room_hosts, room_id, user_id, event_content, ) - ret = event.get_pdu_json() + event_id = event.event_id except Exception as e: # if we were unable to reject the exception, just mark # it as rejected on our end and plough ahead. @@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): # logger.warning("Failed to reject invite: %s", e) - await self.store.locally_reject_invite(user_id, room_id) - ret = {} + stream_id = await self.store.locally_reject_invite(user_id, room_id) + event_id = None - return 200, ret + return 200, {"event_id": event_id, "stream_id": stream_id} class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index b74b088ff4..c981723c1a 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - await self.event_creation_handler.persist_and_notify_client_event( + stream_id = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {} + return 200, {"stream_id": stream_id} def register_servlets(hs, http_server): diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py index b705a8e16c..bde97eef32 100644 --- a/synapse/replication/http/streams.py +++ b/synapse/replication/http/streams.py @@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): super().__init__(hs) self._instance_name = hs.get_instance_name() - - # We pull the streams from the replication handler (if we try and make - # them ourselves we end up in an import loop). - self.streams = hs.get_tcp_replication().get_streams() + self.streams = hs.get_replication_streams() @staticmethod def _serialize_payload(stream_name, from_token, upto_token): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 28826302f5..508ad1b720 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -14,19 +14,23 @@ # limitations under the License. """A replication client for use by synapse workers. """ - +import heapq import logging -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple +from twisted.internet.defer import Deferred from twisted.internet.protocol import ReconnectingClientFactory from synapse.api.constants import EventTypes +from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, EventsStreamRow, ) +from synapse.util.async_helpers import timeout_deferred +from synapse.util.metrics import Measure if TYPE_CHECKING: from synapse.server import HomeServer @@ -35,6 +39,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# How long we allow callers to wait for replication updates before timing out. +_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30 + + class DirectTcpReplicationClientFactory(ReconnectingClientFactory): """Factory for building connections to the master. Will reconnect if the connection is lost. @@ -92,6 +100,16 @@ class ReplicationDataHandler: self.store = hs.get_datastore() self.pusher_pool = hs.get_pusherpool() self.notifier = hs.get_notifier() + self._reactor = hs.get_reactor() + self._clock = hs.get_clock() + self._streams = hs.get_replication_streams() + self._instance_name = hs.get_instance_name() + + # Map from stream to list of deferreds waiting for the stream to + # arrive at a particular position. The lists are sorted by stream position. + self._streams_to_waiters = ( + {} + ) # type: Dict[str, List[Tuple[int, Deferred[None]]]] async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -131,8 +149,76 @@ class ReplicationDataHandler: await self.pusher_pool.on_new_notifications(token, token) + # Notify any waiting deferreds. The list is ordered by position so we + # just iterate through the list until we reach a position that is + # greater than the received row position. + waiting_list = self._streams_to_waiters.get(stream_name, []) + + # Index of first item with a position after the current token, i.e we + # have called all deferreds before this index. If not overwritten by + # loop below means either a) no items in list so no-op or b) all items + # in list were called and so the list should be cleared. Setting it to + # `len(list)` works for both cases. + index_of_first_deferred_not_called = len(waiting_list) + + for idx, (position, deferred) in enumerate(waiting_list): + if position <= token: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + else: + # The list is sorted by position so we don't need to continue + # checking any futher entries in the list. + index_of_first_deferred_not_called = idx + break + + # Drop all entries in the waiting list that were called in the above + # loop. (This maintains the order so no need to resort) + waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + async def on_position(self, stream_name: str, instance_name: str, token: int): self.store.process_replication_rows(stream_name, instance_name, token, []) def on_remote_server_up(self, server: str): """Called when get a new REMOTE_SERVER_UP command.""" + + async def wait_for_stream_position( + self, instance_name: str, stream_name: str, position: int + ): + """Wait until this instance has received updates up to and including + the given stream position. + """ + + if instance_name == self._instance_name: + # We don't get told about updates written by this process, and + # anyway in that case we don't need to wait. + return + + current_position = self._streams[stream_name].current_token(self._instance_name) + if position <= current_position: + # We're already past the position + return + + # Create a new deferred that times out after N seconds, as we don't want + # to wedge here forever. + deferred = Deferred() + deferred = timeout_deferred( + deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + ) + + waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + + # We insert into the list using heapq as it is more efficient than + # pushing then resorting each time. + heapq.heappush(waiting_list, (position, deferred)) + + # We measure here to get in flight counts and average waiting time. + with Measure(self._clock, "repl.wait_for_stream_position"): + logger.info("Waiting for repl stream %r to reach %s", stream_name, position) + await make_deferred_yieldable(deferred) + logger.info( + "Finished waiting for repl stream %r to reach %s", stream_name, position + ) -- cgit 1.4.1