summary refs log tree commit diff
path: root/synapse/replication/tcp/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/client.py')
-rw-r--r--synapse/replication/tcp/client.py315
1 files changed, 155 insertions, 160 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index a44ceb00e7..df29732f51 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,38 +14,55 @@
 # limitations under the License.
 """A replication client for use by synapse workers.
 """
-
+import heapq
 import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple
 
-from twisted.internet import defer
+from twisted.internet.defer import Deferred
 from twisted.internet.protocol import ReconnectingClientFactory
 
-from .commands import (
-    FederationAckCommand,
-    InvalidateCacheCommand,
-    RemovePusherCommand,
-    UserIpCommand,
-    UserSyncCommand,
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams.events import (
+    EventsStream,
+    EventsStreamEventRow,
+    EventsStreamRow,
 )
-from .protocol import ClientReplicationStreamProtocol
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.metrics import Measure
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
 
 logger = logging.getLogger(__name__)
 
 
-class ReplicationClientFactory(ReconnectingClientFactory):
+# How long we allow callers to wait for replication updates before timing out.
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+
+
+class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
     """Factory for building connections to the master. Will reconnect if the
     connection is lost.
 
-    Accepts a handler that will be called when new data is available or data
-    is required.
+    Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
     """
 
-    maxDelay = 30  # Try at least once every N seconds
+    initialDelay = 0.1
+    maxDelay = 1  # Try at least once every N seconds
 
-    def __init__(self, hs, client_name, handler):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        client_name: str,
+        command_handler: "ReplicationCommandHandler",
+    ):
         self.client_name = client_name
-        self.handler = handler
+        self.command_handler = command_handler
         self.server_name = hs.config.server_name
+        self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -56,7 +73,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def buildProtocol(self, addr):
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
-            self.client_name, self.server_name, self._clock, self.handler
+            self.hs,
+            self.client_name,
+            self.server_name,
+            self._clock,
+            self.command_handler,
         )
 
     def clientConnectionLost(self, connector, reason):
@@ -68,162 +89,136 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
 
-class ReplicationClientHandler(object):
-    """A base handler that can be passed to the ReplicationClientFactory.
+class ReplicationDataHandler:
+    """Handles incoming stream updates from replication.
 
-    By default proxies incoming replication data to the SlaveStore.
+    This instance notifies the slave data store about updates. Can be subclassed
+    to handle updates in additional ways.
     """
 
-    def __init__(self, store):
-        self.store = store
-
-        # The current connection. None if we are currently (re)connecting
-        self.connection = None
-
-        # Any pending commands to be sent once a new connection has been
-        # established
-        self.pending_commands = []
-
-        # Map from string -> deferred, to wake up when receiveing a SYNC with
-        # the given string.
-        # Used for tests.
-        self.awaiting_syncs = {}
-
-        # The factory used to create connections.
-        self.factory = None
-
-    def start_replication(self, hs):
-        """Helper method to start a replication connection to the remote server
-        using TCP.
-        """
-        client_name = hs.config.worker_name
-        self.factory = ReplicationClientFactory(hs, client_name, self)
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_port
-        hs.get_reactor().connectTCP(host, port, self.factory)
-
-    def on_rdata(self, stream_name, token, rows):
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.pusher_pool = hs.get_pusherpool()
+        self.notifier = hs.get_notifier()
+        self._reactor = hs.get_reactor()
+        self._clock = hs.get_clock()
+        self._streams = hs.get_replication_streams()
+        self._instance_name = hs.get_instance_name()
+
+        # Map from stream to list of deferreds waiting for the stream to
+        # arrive at a particular position. The lists are sorted by stream position.
+        self._streams_to_waiters = (
+            {}
+        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
         handle more.
 
         Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-
-        Returns:
-            Deferred|None
-        """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
-        return self.store.process_replication_rows(stream_name, token, rows)
-
-    def on_position(self, stream_name, token):
-        """Called when we get new position data. By default this just pokes
-        the slave store.
-
-        Can be overriden in subclasses to handle more.
-        """
-        return self.store.process_replication_rows(stream_name, token, [])
-
-    def on_sync(self, data):
-        """When we received a SYNC we wake up any deferreds that were waiting
-        for the sync with the given data.
-
-        Used by tests.
+            stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
-        d = self.awaiting_syncs.pop(data, None)
-        if d:
-            d.callback(data)
-
-    def get_streams_to_replicate(self):
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns a dictionary of stream name to token.
-        """
-        args = self.store.stream_positions()
-        user_account_data = args.pop("user_account_data", None)
-        room_account_data = args.pop("room_account_data", None)
-        if user_account_data:
-            args["account_data"] = user_account_data
-        elif room_account_data:
-            args["account_data"] = room_account_data
-
-        return args
-
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users. (Overriden by the synchrotron's only)
-        """
-        return []
-
-    def send_command(self, cmd):
-        """Send a command to master (when we get establish a connection if we
-        don't have one already.)
-        """
-        if self.connection:
-            self.connection.send_command(cmd)
-        else:
-            logger.warn("Queuing command as not connected: %r", cmd.NAME)
-            self.pending_commands.append(cmd)
-
-    def send_federation_ack(self, token):
-        """Ack data for the federation stream. This allows the master to drop
-        data stored purely in memory.
-        """
-        self.send_command(FederationAckCommand(token))
-
-    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
-        """Poke the master that a user has started/stopped syncing.
-        """
-        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
-
-    def send_remove_pusher(self, app_id, push_key, user_id):
-        """Poke the master to remove a pusher for a user
-        """
-        cmd = RemovePusherCommand(app_id, push_key, user_id)
-        self.send_command(cmd)
-
-    def send_invalidate_cache(self, cache_func, keys):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
-    def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
-        """Tell the master that the user made a request.
+        self.store.process_replication_rows(stream_name, instance_name, token, rows)
+
+        if stream_name == EventsStream.NAME:
+            # We shouldn't get multiple rows per token for events stream, so
+            # we don't need to optimise this for multiple rows.
+            for row in rows:
+                if row.type != EventsStreamEventRow.TypeId:
+                    continue
+                assert isinstance(row, EventsStreamRow)
+
+                event = await self.store.get_event(
+                    row.data.event_id, allow_rejected=True
+                )
+                if event.rejected_reason:
+                    continue
+
+                extra_users = ()  # type: Tuple[str, ...]
+                if event.type == EventTypes.Member:
+                    extra_users = (event.state_key,)
+                max_token = self.store.get_room_max_stream_ordering()
+                self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+            await self.pusher_pool.on_new_notifications(token, token)
+
+        # Notify any waiting deferreds. The list is ordered by position so we
+        # just iterate through the list until we reach a position that is
+        # greater than the received row position.
+        waiting_list = self._streams_to_waiters.get(stream_name, [])
+
+        # Index of first item with a position after the current token, i.e we
+        # have called all deferreds before this index. If not overwritten by
+        # loop below means either a) no items in list so no-op or b) all items
+        # in list were called and so the list should be cleared. Setting it to
+        # `len(list)` works for both cases.
+        index_of_first_deferred_not_called = len(waiting_list)
+
+        for idx, (position, deferred) in enumerate(waiting_list):
+            if position <= token:
+                try:
+                    with PreserveLoggingContext():
+                        deferred.callback(None)
+                except Exception:
+                    # The deferred has been cancelled or timed out.
+                    pass
+            else:
+                # The list is sorted by position so we don't need to continue
+                # checking any further entries in the list.
+                index_of_first_deferred_not_called = idx
+                break
+
+        # Drop all entries in the waiting list that were called in the above
+        # loop. (This maintains the order so no need to resort)
+        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        self.store.process_replication_rows(stream_name, instance_name, token, [])
+
+    def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+
+    async def wait_for_stream_position(
+        self, instance_name: str, stream_name: str, position: int
+    ):
+        """Wait until this instance has received updates up to and including
+        the given stream position.
         """
-        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
-        self.send_command(cmd)
 
-    def await_sync(self, data):
-        """Returns a deferred that is resolved when we receive a SYNC command
-        with given data.
+        if instance_name == self._instance_name:
+            # We don't get told about updates written by this process, and
+            # anyway in that case we don't need to wait.
+            return
+
+        current_position = self._streams[stream_name].current_token(self._instance_name)
+        if position <= current_position:
+            # We're already past the position
+            return
+
+        # Create a new deferred that times out after N seconds, as we don't want
+        # to wedge here forever.
+        deferred = Deferred()
+        deferred = timeout_deferred(
+            deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
+        )
 
-        [Not currently] used by tests.
-        """
-        return self.awaiting_syncs.setdefault(data, defer.Deferred())
+        waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
 
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        self.connection = connection
-        if connection:
-            for cmd in self.pending_commands:
-                connection.send_command(cmd)
-            self.pending_commands = []
-
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        logger.info("Finished connecting to server")
+        # We insert into the list using heapq as it is more efficient than
+        # pushing then resorting each time.
+        heapq.heappush(waiting_list, (position, deferred))
 
-        # We don't reset the delay any earlier as otherwise if there is a
-        # problem during start up we'll end up tight looping connecting to the
-        # server.
-        self.factory.resetDelay()
+        # We measure here to get in flight counts and average waiting time.
+        with Measure(self._clock, "repl.wait_for_stream_position"):
+            logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+            await make_deferred_yieldable(deferred)
+            logger.info(
+                "Finished waiting for repl stream %r to reach %s", stream_name, position
+            )