summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/federation.py13
-rw-r--r--synapse/replication/http/membership.py14
-rw-r--r--synapse/replication/http/send_event.py4
-rw-r--r--synapse/replication/http/streams.py5
-rw-r--r--synapse/replication/tcp/client.py90
5 files changed, 108 insertions, 18 deletions
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 7e23b565b9..c287c4e269 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
     """Handles events newly received from federation, including persisting and
-    notifying.
+    notifying. Returns the maximum stream ID of the persisted events.
 
     The API looks like:
 
@@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 "context": { .. serialized event context .. },
             }],
             "backfilled": false
+        }
+
+        200 OK
+
+        {
+            "max_stream_id": 32443,
+        }
     """
 
     NAME = "fed_send_events"
@@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
 
         logger.info("Got %d events from federation", len(event_and_contexts))
 
-        await self.federation_handler.persist_events_and_notify(
+        max_stream_id = await self.federation_handler.persist_events_and_notify(
             event_and_contexts, backfilled
         )
 
-        return 200, {}
+        return 200, {"max_stream_id": max_stream_id}
 
 
 class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 3577611fd7..050fd34562 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
 
         logger.info("remote_join: %s into room: %s", user_id, room_id)
 
-        await self.federation_handler.do_invite_join(
+        event_id, stream_id = await self.federation_handler.do_invite_join(
             remote_room_hosts, room_id, user_id, event_content
         )
 
-        return 200, {}
+        return 200, {"event_id": event_id, "stream_id": stream_id}
 
 
 class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
 
         try:
-            event = await self.federation_handler.do_remotely_reject_invite(
+            event, stream_id = await self.federation_handler.do_remotely_reject_invite(
                 remote_room_hosts, room_id, user_id, event_content,
             )
-            ret = event.get_pdu_json()
+            event_id = event.event_id
         except Exception as e:
             # if we were unable to reject the exception, just mark
             # it as rejected on our end and plough ahead.
@@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
             #
             logger.warning("Failed to reject invite: %s", e)
 
-            await self.store.locally_reject_invite(user_id, room_id)
-            ret = {}
+            stream_id = await self.store.locally_reject_invite(user_id, room_id)
+            event_id = None
 
-        return 200, ret
+        return 200, {"event_id": event_id, "stream_id": stream_id}
 
 
 class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index b74b088ff4..c981723c1a 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
         )
 
-        await self.event_creation_handler.persist_and_notify_client_event(
+        stream_id = await self.event_creation_handler.persist_and_notify_client_event(
             requester, event, context, ratelimit=ratelimit, extra_users=extra_users
         )
 
-        return 200, {}
+        return 200, {"stream_id": stream_id}
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index b705a8e16c..bde97eef32 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
         super().__init__(hs)
 
         self._instance_name = hs.get_instance_name()
-
-        # We pull the streams from the replication handler (if we try and make
-        # them ourselves we end up in an import loop).
-        self.streams = hs.get_tcp_replication().get_streams()
+        self.streams = hs.get_replication_streams()
 
     @staticmethod
     def _serialize_payload(stream_name, from_token, upto_token):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 28826302f5..508ad1b720 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,19 +14,23 @@
 # limitations under the License.
 """A replication client for use by synapse workers.
 """
-
+import heapq
 import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
 
+from twisted.internet.defer import Deferred
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamEventRow,
     EventsStreamRow,
 )
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -35,6 +39,10 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+# How long we allow callers to wait for replication updates before timing out.
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+
+
 class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
     """Factory for building connections to the master. Will reconnect if the
     connection is lost.
@@ -92,6 +100,16 @@ class ReplicationDataHandler:
         self.store = hs.get_datastore()
         self.pusher_pool = hs.get_pusherpool()
         self.notifier = hs.get_notifier()
+        self._reactor = hs.get_reactor()
+        self._clock = hs.get_clock()
+        self._streams = hs.get_replication_streams()
+        self._instance_name = hs.get_instance_name()
+
+        # Map from stream to list of deferreds waiting for the stream to
+        # arrive at a particular position. The lists are sorted by stream position.
+        self._streams_to_waiters = (
+            {}
+        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -131,8 +149,76 @@ class ReplicationDataHandler:
 
             await self.pusher_pool.on_new_notifications(token, token)
 
+        # Notify any waiting deferreds. The list is ordered by position so we
+        # just iterate through the list until we reach a position that is
+        # greater than the received row position.
+        waiting_list = self._streams_to_waiters.get(stream_name, [])
+
+        # Index of first item with a position after the current token, i.e we
+        # have called all deferreds before this index. If not overwritten by
+        # loop below means either a) no items in list so no-op or b) all items
+        # in list were called and so the list should be cleared. Setting it to
+        # `len(list)` works for both cases.
+        index_of_first_deferred_not_called = len(waiting_list)
+
+        for idx, (position, deferred) in enumerate(waiting_list):
+            if position <= token:
+                try:
+                    with PreserveLoggingContext():
+                        deferred.callback(None)
+                except Exception:
+                    # The deferred has been cancelled or timed out.
+                    pass
+            else:
+                # The list is sorted by position so we don't need to continue
+                # checking any futher entries in the list.
+                index_of_first_deferred_not_called = idx
+                break
+
+        # Drop all entries in the waiting list that were called in the above
+        # loop. (This maintains the order so no need to resort)
+        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+
     async def on_position(self, stream_name: str, instance_name: str, token: int):
         self.store.process_replication_rows(stream_name, instance_name, token, [])
 
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
+
+    async def wait_for_stream_position(
+        self, instance_name: str, stream_name: str, position: int
+    ):
+        """Wait until this instance has received updates up to and including
+        the given stream position.
+        """
+
+        if instance_name == self._instance_name:
+            # We don't get told about updates written by this process, and
+            # anyway in that case we don't need to wait.
+            return
+
+        current_position = self._streams[stream_name].current_token(self._instance_name)
+        if position <= current_position:
+            # We're already past the position
+            return
+
+        # Create a new deferred that times out after N seconds, as we don't want
+        # to wedge here forever.
+        deferred = Deferred()
+        deferred = timeout_deferred(
+            deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
+        )
+
+        waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
+
+        # We insert into the list using heapq as it is more efficient than
+        # pushing then resorting each time.
+        heapq.heappush(waiting_list, (position, deferred))
+
+        # We measure here to get in flight counts and average waiting time.
+        with Measure(self._clock, "repl.wait_for_stream_position"):
+            logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+            await make_deferred_yieldable(deferred)
+            logger.info(
+                "Finished waiting for repl stream %r to reach %s", stream_name, position
+            )