summary refs log tree commit diff
path: root/synapse/replication/tcp/handler.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/replication/tcp/handler.py115
1 files changed, 66 insertions, 49 deletions
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 1de590bba2..1c303f3a46 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -16,6 +16,7 @@
 import logging
 from typing import (
     Any,
+    Awaitable,
     Dict,
     Iterable,
     Iterator,
@@ -33,6 +34,7 @@ from typing_extensions import Deque
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
 from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
@@ -152,7 +154,7 @@ class ReplicationCommandHandler:
         # When POSITION or RDATA commands arrive, we stick them in a queue and process
         # them in order in a separate background process.
 
-        # the streams which are currently being processed by _unsafe_process_stream
+        # the streams which are currently being processed by _unsafe_process_queue
         self._processing_streams = set()  # type: Set[str]
 
         # for each stream, a queue of commands that are awaiting processing, and the
@@ -185,7 +187,7 @@ class ReplicationCommandHandler:
         if self._is_master:
             self._server_notices_sender = hs.get_server_notices_sender()
 
-    async def _add_command_to_stream_queue(
+    def _add_command_to_stream_queue(
         self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
     ) -> None:
         """Queue the given received command for processing
@@ -199,33 +201,34 @@ class ReplicationCommandHandler:
             logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
             return
 
-        # if we're already processing this stream, stick the new command in the
-        # queue, and we're done.
+        queue.append((cmd, conn))
+
+        # if we're already processing this stream, there's nothing more to do:
+        # the new entry on the queue will get picked up in due course
         if stream_name in self._processing_streams:
-            queue.append((cmd, conn))
             return
 
-        # otherwise, process the new command.
+        # fire off a background process to start processing the queue.
+        run_as_background_process(
+            "process-replication-data", self._unsafe_process_queue, stream_name
+        )
 
-        # arguably we should start off a new background process here, but nothing
-        # will be too upset if we don't return for ages, so let's save the overhead
-        # and use the existing logcontext.
+    async def _unsafe_process_queue(self, stream_name: str):
+        """Processes the command queue for the given stream, until it is empty
+
+        Does not check if there is already a thread processing the queue, hence "unsafe"
+        """
+        assert stream_name not in self._processing_streams
 
         self._processing_streams.add(stream_name)
         try:
-            # might as well skip the queue for this one, since it must be empty
-            assert not queue
-            await self._process_command(cmd, conn, stream_name)
-
-            # now process any other commands that have built up while we were
-            # dealing with that one.
+            queue = self._command_queues_by_stream.get(stream_name)
             while queue:
                 cmd, conn = queue.popleft()
                 try:
                     await self._process_command(cmd, conn, stream_name)
                 except Exception:
                     logger.exception("Failed to handle command %s", cmd)
-
         finally:
             self._processing_streams.discard(stream_name)
 
@@ -299,7 +302,7 @@ class ReplicationCommandHandler:
         """
         return self._streams_to_replicate
 
-    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
         self.send_positions_to_connection(conn)
 
     def send_positions_to_connection(self, conn: AbstractConnection):
@@ -318,57 +321,73 @@ class ReplicationCommandHandler:
                 )
             )
 
-    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+    def on_USER_SYNC(
+        self, conn: AbstractConnection, cmd: UserSyncCommand
+    ) -> Optional[Awaitable[None]]:
         user_sync_counter.inc()
 
         if self._is_master:
-            await self._presence_handler.update_external_syncs_row(
+            return self._presence_handler.update_external_syncs_row(
                 cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
             )
+        else:
+            return None
 
-    async def on_CLEAR_USER_SYNC(
+    def on_CLEAR_USER_SYNC(
         self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         if self._is_master:
-            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+            return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+        else:
+            return None
 
-    async def on_FEDERATION_ACK(
-        self, conn: AbstractConnection, cmd: FederationAckCommand
-    ):
+    def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
         federation_ack_counter.inc()
 
         if self._federation_sender:
             self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
 
-    async def on_REMOVE_PUSHER(
+    def on_REMOVE_PUSHER(
         self, conn: AbstractConnection, cmd: RemovePusherCommand
-    ):
+    ) -> Optional[Awaitable[None]]:
         remove_pusher_counter.inc()
 
         if self._is_master:
-            await self._store.delete_pusher_by_app_id_pushkey_user_id(
-                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
-            )
+            return self._handle_remove_pusher(cmd)
+        else:
+            return None
+
+    async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
+        await self._store.delete_pusher_by_app_id_pushkey_user_id(
+            app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+        )
 
-            self._notifier.on_new_replication_data()
+        self._notifier.on_new_replication_data()
 
-    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+    def on_USER_IP(
+        self, conn: AbstractConnection, cmd: UserIpCommand
+    ) -> Optional[Awaitable[None]]:
         user_ip_cache_counter.inc()
 
         if self._is_master:
-            await self._store.insert_client_ip(
-                cmd.user_id,
-                cmd.access_token,
-                cmd.ip,
-                cmd.user_agent,
-                cmd.device_id,
-                cmd.last_seen,
-            )
+            return self._handle_user_ip(cmd)
+        else:
+            return None
+
+    async def _handle_user_ip(self, cmd: UserIpCommand):
+        await self._store.insert_client_ip(
+            cmd.user_id,
+            cmd.access_token,
+            cmd.ip,
+            cmd.user_agent,
+            cmd.device_id,
+            cmd.last_seen,
+        )
 
-        if self._server_notices_sender:
-            await self._server_notices_sender.on_user_ip(cmd.user_id)
+        assert self._server_notices_sender is not None
+        await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -382,7 +401,7 @@ class ReplicationCommandHandler:
         #   2. so we don't race with getting a POSITION command and fetching
         #      missing RDATA.
 
-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_rdata(
         self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
@@ -459,14 +478,14 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )
 
-    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
 
         logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
 
-        await self._add_command_to_stream_queue(conn, cmd)
+        self._add_command_to_stream_queue(conn, cmd)
 
     async def _process_position(
         self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
@@ -526,9 +545,7 @@ class ReplicationCommandHandler:
 
         self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
-    async def on_REMOTE_SERVER_UP(
-        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
-    ):
+    def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
         """"Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)