summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/tcp/handler.py115
-rw-r--r--synapse/replication/tcp/protocol.py45
-rw-r--r--synapse/replication/tcp/redis.py37
4 files changed, 112 insertions, 87 deletions
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 5ef1c6c1dc..a84a064c8d 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -39,10 +39,10 @@ class ReplicationRestResource(JsonResource):
         federation.register_servlets(hs, self)
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
+        streams.register_servlets(hs, self)
 
         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:
             login.register_servlets(hs, self)
             register.register_servlets(hs, self)
             devices.register_servlets(hs, self)
-            streams.register_servlets(hs, self)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1de590bba2..1c303f3a46 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -16,6 +16,7 @@
 import logging
 from typing import (
     Any,
+    Awaitable,
     Dict,
     Iterable,
     Iterator,
@@ -33,6 +34,7 @@ from typing_extensions import Deque
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
 from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
@@ -152,7 +154,7 @@ class ReplicationCommandHandler:
         # When POSITION or RDATA commands arrive, we stick them in a queue and process
         # them in order in a separate background process.
 
-        # the streams which are currently being processed by _unsafe_process_stream
+        # the streams which are currently being processed by _unsafe_process_queue
         self._processing_streams = set()  # type: Set[str]
 
         # for each stream, a queue of commands that are awaiting processing, and the
@@ -185,7 +187,7 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
-    async def _add_command_to_stream_queue(
+    def _add_command_to_stream_queue(
         self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing
@@ -199,33 +201,34 @@ class ReplicationCommandHandler:
             logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
             return
 
-        # if we're already processing this stream, stick the new command in the
-        # queue, and we're done.
+        queue.append((cmd, conn))
+
+        # if we're already processing this stream, there's nothing more to do:
+        # the new entry on the queue will get picked up in due course
         if stream_name in self._processing_streams:
-            queue.append((cmd, conn))
             return
 
-        # otherwise, process the new command.
+        # fire off a background process to start processing the queue.
+        run_as_background_process(
+            "process-replication-data", self._unsafe_process_queue, stream_name
+        )
 
-        # arguably we should start off a new background process here, but nothing
-        # will be too upset if we don't return for ages, so let's save the overhead
-        # and use the existing logcontext.
+    async def _unsafe_process_queue(self, stream_name: str):
+        """Processes the command queue for the given stream, until it is empty
+
+        Does not check if there is already a thread processing the queue, hence "unsafe"
+        """
+        assert stream_name not in self._processing_streams
 
         self._processing_streams.add(stream_name)
         try:
-            # might as well skip the queue for this one, since it must be empty
-            assert not queue
-            await self._process_command(cmd, conn, stream_name)
-
-            # now process any other commands that have built up while we were
-            # dealing with that one.
+            queue = self._command_queues_by_stream.get(stream_name)
             while queue:
                 cmd, conn = queue.popleft()
                 try:
                     await self._process_command(cmd, conn, stream_name)
                 except Exception:
                     logger.exception("Failed to handle command %s", cmd)
-
         finally:
             self._processing_streams.discard(stream_name)
 
@@ -299,7 +302,7 @@ class ReplicationCommandHandler:
         """
         return self._streams_to_replicate
 
-    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)
 
     def send_positions_to_connection(self, conn: AbstractConnection):
@@ -318,57 +321,73 @@ class ReplicationCommandHandler:
                 )
             )
 
-    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+    def on_USER_SYNC(
+        self, conn: AbstractConnection, cmd: UserSyncCommand
+    ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
         if self._is_master:
-            await self._presence_handler.update_external_syncs_row(
+            return self._presence_handler.update_external_syncs_row(
                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
             )
+        else:
+            return None
 
-    async def on_CLEAR_USER_SYNC(
+    def on_CLEAR_USER_SYNC(
         self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         if self._is_master:
-            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+            return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+        else:
+            return None
 
-    async def on_FEDERATION_ACK(
-        self, conn: AbstractConnection, cmd: FederationAckCommand
-    ):
+    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
-    async def on_REMOVE_PUSHER(
+    def on_REMOVE_PUSHER(
         self, conn: AbstractConnection, cmd: RemovePusherCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         remove_pusher_counter.inc()
 
         if self._is_master:
-            await self._store.delete_pusher_by_app_id_pushkey_user_id(
-                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
-            )
+            return self._handle_remove_pusher(cmd)
+        else:
+            return None
+
+    async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
+        await self._store.delete_pusher_by_app_id_pushkey_user_id(
+            app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+        )
 
-            self._notifier.on_new_replication_data()
+        self._notifier.on_new_replication_data()
 
-    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+    def on_USER_IP(
+        self, conn: AbstractConnection, cmd: UserIpCommand
+    ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
         if self._is_master:
-            await self._store.insert_client_ip(
-                cmd.user_id,
-                cmd.access_token,
-                cmd.ip,
-                cmd.user_agent,
-                cmd.device_id,
-                cmd.last_seen,
-            )
+            return self._handle_user_ip(cmd)
+        else:
+            return None
+
+    async def _handle_user_ip(self, cmd: UserIpCommand):
+        await self._store.insert_client_ip(
+            cmd.user_id,
+            cmd.access_token,
+            cmd.ip,
+            cmd.user_agent,
+            cmd.device_id,
+            cmd.last_seen,
+        )
 
-        if self._server_notices_sender:
-            await self._server_notices_sender.on_user_ip(cmd.user_id)
+        assert self._server_notices_sender is not None
+        await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -382,7 +401,7 @@ class ReplicationCommandHandler:
         #   2. so we don't race with getting a POSITION command and fetching
         #      missing RDATA.
 
-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_rdata(
         self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
@@ -459,14 +478,14 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )
 
-    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
 
         logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
 
-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_position(
         self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
@@ -526,9 +545,7 @@ class ReplicationCommandHandler:
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
-    async def on_REMOTE_SERVER_UP(
-        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
-    ):
+    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 23191e3218..0350923898 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -50,6 +50,7 @@ import abc
 import fcntl
 import logging
 import struct
+from inspect import isawaitable
 from typing import TYPE_CHECKING, List
 
 from prometheus_client import Counter
@@ -128,6 +129,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
     command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
+    `ReplicationCommandHandler.on_<COMMAND_NAME>` can optionally return a coroutine;
+    if so, that will get run as a background process.
 
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
@@ -166,9 +169,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # a logcontext which we use for processing incoming commands. We declare it as a
         # background process so that the CPU stats get reported to prometheus.
-        self._logging_context = BackgroundProcessLoggingContext(
-            "replication_command_handler-%s" % self.conn_id
-        )
+        ctx_name = "replication-conn-%s" % self.conn_id
+        self._logging_context = BackgroundProcessLoggingContext(ctx_name)
+        self._logging_context.request = ctx_name
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -246,18 +249,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
 
-        # Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)
 
-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.
 
         First calls `self.on_<COMMAND>` if it exists, then calls
-        `self.command_handler.on_<COMMAND>` if it exists. This allows for
-        protocol level handling of commands (e.g. PINGs), before delegating to
-        the handler.
+        `self.command_handler.on_<COMMAND>` if it exists (which can optionally
+        return an Awaitable).
+
+        This allows for protocol level handling of commands (e.g. PINGs), before
+        delegating to the handler.
 
         Args:
             cmd: received command
@@ -268,13 +270,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # specific handling.
         cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(cmd)
+            cmd_func(cmd)
             handled = True
 
         # Then call out to the handler.
         cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
         if cmd_func:
-            await cmd_func(self, cmd)
+            res = cmd_func(self, cmd)
+
+            # the handler might be a coroutine: fire it off as a background process
+            # if so.
+
+            if isawaitable(res):
+                run_as_background_process(
+                    "replication-" + cmd.get_logcontext_id(), lambda: res
+                )
+
             handled = True
 
         if not handled:
@@ -350,10 +361,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         for cmd in pending:
             self.send_command(cmd)
 
-    async def on_PING(self, line):
+    def on_PING(self, line):
         self.received_ping = True
 
-    async def on_ERROR(self, cmd):
+    def on_ERROR(self, cmd):
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
     def pauseProducing(self):
@@ -448,7 +459,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.send_command(ServerCommand(self.server_name))
         super().connectionMade()
 
-    async def on_NAME(self, cmd):
+    def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
@@ -477,7 +488,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # Once we've connected subscribe to the necessary streams
         self.replicate()
 
-    async def on_SERVER(self, cmd):
+    def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index b5c533a607..f225e533de 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from inspect import isawaitable
 from typing import TYPE_CHECKING
 
 import txredisapi
@@ -124,36 +125,32 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # remote instances.
         tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
-        # Now lets try and call on_<CMD_NAME> function
-        run_as_background_process(
-            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
-        )
+        self.handle_command(cmd)
 
-    async def handle_command(self, cmd: Command):
+    def handle_command(self, cmd: Command) -> None:
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>, which should return an awaitable.
+        Delegates to `self.handler.on_<COMMAND>` (which can optionally return an
+        Awaitable).
 
         Args:
             cmd: received command
         """
-        handled = False
-
-        # First call any command handlers on this instance. These are for redis
-        # specific handling.
-        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(cmd)
-            handled = True
 
-        # Then call out to the handler.
         cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
-        if cmd_func:
-            await cmd_func(self, cmd)
-            handled = True
-
-        if not handled:
+        if not cmd_func:
             logger.warning("Unhandled command: %r", cmd)
+            return
+
+        res = cmd_func(self, cmd)
+
+        # the handler might be a coroutine: fire it off as a background process
+        # if so.
+
+        if isawaitable(res):
+            run_as_background_process(
+                "replication-" + cmd.get_logcontext_id(), lambda: res
+            )
 
     def connectionLost(self, reason):
         logger.info("Lost connection to redis")