summary refs log tree commit diff
path: root/synapse/replication/tcp/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/client.py')
-rw-r--r--synapse/replication/tcp/client.py45
1 files changed, 26 insertions, 19 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e29ae1e375..d59ce7ccf9 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,10 +14,12 @@
 """A replication client for use by synapse workers.
 """
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
+from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes
 from synapse.federation import send_queue
@@ -79,10 +81,10 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
 
-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         logger.info("Connecting to replication: %r", connector.getDestination())
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
             self.hs,
@@ -92,11 +94,11 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
             self.command_handler,
         )
 
-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.error("Lost replication conn: %r", reason)
         ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
 
-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.error("Failed to connect to replication: %r", reason)
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
@@ -131,7 +133,7 @@ class ReplicationDataHandler:
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
-    ):
+    ) -> None:
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
@@ -252,14 +254,16 @@ class ReplicationDataHandler:
         # loop. (This maintains the order so no need to resort)
         waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
 
-    async def on_position(self, stream_name: str, instance_name: str, token: int):
+    async def on_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
         await self.on_rdata(stream_name, instance_name, token, [])
 
         # We poke the generic "replication" notifier to wake anything up that
         # may be streaming.
         self.notifier.notify_replication()
 
-    def on_remote_server_up(self, server: str):
+    def on_remote_server_up(self, server: str) -> None:
         """Called when get a new REMOTE_SERVER_UP command."""
 
         # Let's wake up the transaction queue for the server in case we have
@@ -269,7 +273,7 @@ class ReplicationDataHandler:
 
     async def wait_for_stream_position(
         self, instance_name: str, stream_name: str, position: int
-    ):
+    ) -> None:
         """Wait until this instance has received updates up to and including
         the given stream position.
         """
@@ -304,7 +308,7 @@ class ReplicationDataHandler:
                 "Finished waiting for repl stream %r to reach %s", stream_name, position
             )
 
-    def stop_pusher(self, user_id, app_id, pushkey):
+    def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
         if not self._notify_pushers:
             return
 
@@ -316,13 +320,13 @@ class ReplicationDataHandler:
         logger.info("Stopping pusher %r / %r", user_id, key)
         pusher.on_stop()
 
-    async def start_pusher(self, user_id, app_id, pushkey):
+    async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
         if not self._notify_pushers:
             return
 
         key = "%s:%s" % (app_id, pushkey)
         logger.info("Starting pusher %r / %r", user_id, key)
-        return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+        await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
 
 
 class FederationSenderHandler:
@@ -353,10 +357,12 @@ class FederationSenderHandler:
 
         self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
 
-    def wake_destination(self, server: str):
+    def wake_destination(self, server: str) -> None:
         self.federation_sender.wake_destination(server)
 
-    async def process_replication_rows(self, stream_name, token, rows):
+    async def process_replication_rows(
+        self, stream_name: str, token: int, rows: list
+    ) -> None:
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
         if stream_name == "federation":
@@ -384,11 +390,12 @@ class FederationSenderHandler:
             for host in hosts:
                 self.federation_sender.send_device_messages(host)
 
-    async def _on_new_receipts(self, rows):
+    async def _on_new_receipts(
+        self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
+    ) -> None:
         """
         Args:
-            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
-                new receipts to be processed
+            rows: new receipts to be processed
         """
         for receipt in rows:
             # we only want to send on receipts for our own users
@@ -408,7 +415,7 @@ class FederationSenderHandler:
             )
             await self.federation_sender.send_read_receipt(receipt_info)
 
-    async def update_token(self, token):
+    async def update_token(self, token: int) -> None:
         """Update the record of where we have processed to in the federation stream.
 
         Called after we have processed a an update received over replication. Sends
@@ -428,7 +435,7 @@ class FederationSenderHandler:
 
         run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
 
-    async def _save_and_send_ack(self):
+    async def _save_and_send_ack(self) -> None:
         """Save the current federation position in the database and send an ACK
         to master with where we're up to.
         """