summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py2
-rw-r--r--synapse/replication/slave/storage/client_ips.py4
-rw-r--r--synapse/replication/slave/storage/devices.py10
-rw-r--r--synapse/replication/slave/storage/groups.py8
-rw-r--r--synapse/replication/slave/storage/push_rule.py7
-rw-r--r--synapse/replication/slave/storage/pushers.py6
-rw-r--r--synapse/replication/tcp/client.py45
-rw-r--r--synapse/replication/tcp/commands.py74
-rw-r--r--synapse/replication/tcp/handler.py36
-rw-r--r--synapse/replication/tcp/protocol.py68
-rw-r--r--synapse/replication/tcp/redis.py32
-rw-r--r--synapse/replication/tcp/resource.py16
-rw-r--r--synapse/replication/tcp/streams/_base.py6
-rw-r--r--synapse/replication/tcp/streams/events.py23
14 files changed, 196 insertions, 141 deletions
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index fa132d10b4..8f3f953ed4 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -40,7 +40,7 @@ class SlavedIdTracker(AbstractStreamIdTracker):
             for table, column in extra_tables:
                 self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, instance_name: Optional[str], new_id: int):
+    def advance(self, instance_name: Optional[str], new_id: int) -> None:
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index bc888ce1a8..b5b84c09ae 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -37,7 +37,9 @@ class SlavedClientIpStore(BaseSlavedStore):
             cache_name="client_ip_last_seen", max_size=50000
         )
 
-    async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+    async def insert_client_ip(
+        self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
+    ) -> None:
         now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
 
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index a2aff75b70..0ffd34f1da 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -60,7 +60,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
@@ -70,7 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def _invalidate_caches_for_devices(self, token, rows):
+    def _invalidate_caches_for_devices(
+        self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
+    ) -> None:
         for row in rows:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 9d90e26375..d6f37d7479 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -44,10 +44,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
             self._group_updates_id_gen.get_current_token(),
         )
 
-    def get_group_stream_token(self):
+    def get_group_stream_token(self) -> int:
         return self._group_updates_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == GroupServerStream.NAME:
             self._group_updates_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 7541e21de9..52ee3f7e58 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Iterable
 
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -20,10 +21,12 @@ from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index cea90c0f1b..de642bba71 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable
 
 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -41,8 +41,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         return self._pushers_id_gen.get_current_token()
 
     def process_replication_rows(
-        self, stream_name: str, instance_name: str, token, rows
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
+            self._pushers_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e29ae1e375..d59ce7ccf9 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,10 +14,12 @@
 """A replication client for use by synapse workers.
 """
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
+from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes
 from synapse.federation import send_queue
@@ -79,10 +81,10 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
 
-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         logger.info("Connecting to replication: %r", connector.getDestination())
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
             self.hs,
@@ -92,11 +94,11 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
             self.command_handler,
         )
 
-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.error("Lost replication conn: %r", reason)
         ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
 
-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.error("Failed to connect to replication: %r", reason)
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
@@ -131,7 +133,7 @@ class ReplicationDataHandler:
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
-    ):
+    ) -> None:
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
@@ -252,14 +254,16 @@ class ReplicationDataHandler:
         # loop. (This maintains the order so no need to resort)
         waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
 
-    async def on_position(self, stream_name: str, instance_name: str, token: int):
+    async def on_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
         await self.on_rdata(stream_name, instance_name, token, [])
 
         # We poke the generic "replication" notifier to wake anything up that
         # may be streaming.
         self.notifier.notify_replication()
 
-    def on_remote_server_up(self, server: str):
+    def on_remote_server_up(self, server: str) -> None:
         """Called when get a new REMOTE_SERVER_UP command."""
 
         # Let's wake up the transaction queue for the server in case we have
@@ -269,7 +273,7 @@ class ReplicationDataHandler:
 
     async def wait_for_stream_position(
         self, instance_name: str, stream_name: str, position: int
-    ):
+    ) -> None:
         """Wait until this instance has received updates up to and including
         the given stream position.
         """
@@ -304,7 +308,7 @@ class ReplicationDataHandler:
                 "Finished waiting for repl stream %r to reach %s", stream_name, position
             )
 
-    def stop_pusher(self, user_id, app_id, pushkey):
+    def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
         if not self._notify_pushers:
             return
 
@@ -316,13 +320,13 @@ class ReplicationDataHandler:
         logger.info("Stopping pusher %r / %r", user_id, key)
         pusher.on_stop()
 
-    async def start_pusher(self, user_id, app_id, pushkey):
+    async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
         if not self._notify_pushers:
             return
 
         key = "%s:%s" % (app_id, pushkey)
         logger.info("Starting pusher %r / %r", user_id, key)
-        return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+        await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
 
 
 class FederationSenderHandler:
@@ -353,10 +357,12 @@ class FederationSenderHandler:
 
         self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
 
-    def wake_destination(self, server: str):
+    def wake_destination(self, server: str) -> None:
         self.federation_sender.wake_destination(server)
 
-    async def process_replication_rows(self, stream_name, token, rows):
+    async def process_replication_rows(
+        self, stream_name: str, token: int, rows: list
+    ) -> None:
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
         if stream_name == "federation":
@@ -384,11 +390,12 @@ class FederationSenderHandler:
             for host in hosts:
                 self.federation_sender.send_device_messages(host)
 
-    async def _on_new_receipts(self, rows):
+    async def _on_new_receipts(
+        self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
+    ) -> None:
         """
         Args:
-            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
-                new receipts to be processed
+            rows: new receipts to be processed
         """
         for receipt in rows:
             # we only want to send on receipts for our own users
@@ -408,7 +415,7 @@ class FederationSenderHandler:
             )
             await self.federation_sender.send_read_receipt(receipt_info)
 
-    async def update_token(self, token):
+    async def update_token(self, token: int) -> None:
         """Update the record of where we have processed to in the federation stream.
 
         Called after we have processed a an update received over replication. Sends
@@ -428,7 +435,7 @@ class FederationSenderHandler:
 
         run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
 
-    async def _save_and_send_ack(self):
+    async def _save_and_send_ack(self) -> None:
         """Save the current federation position in the database and send an ACK
         to master with where we're up to.
         """
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 1311b013da..3654f6c03c 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -18,12 +18,15 @@ allowed to be sent by which side.
 """
 import abc
 import logging
-from typing import Tuple, Type
+from typing import Optional, Tuple, Type, TypeVar
 
+from synapse.replication.tcp.streams._base import StreamRow
 from synapse.util import json_decoder, json_encoder
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound="Command")
+
 
 class Command(metaclass=abc.ABCMeta):
     """The base command class.
@@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta):
 
     @classmethod
     @abc.abstractmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[T], line: str) -> T:
         """Deserialises a line from the wire into this command. `line` does not
         include the command.
         """
@@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta):
         prefix.
         """
 
-    def get_logcontext_id(self):
+    def get_logcontext_id(self) -> str:
         """Get a suitable string for the logcontext when processing this command"""
 
         # by default, we just use the command name.
         return self.NAME
 
 
+SC = TypeVar("SC", bound="_SimpleCommand")
+
+
 class _SimpleCommand(Command):
     """An implementation of Command whose argument is just a 'data' string."""
 
-    def __init__(self, data):
+    def __init__(self, data: str):
         self.data = data
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[SC], line: str) -> SC:
         return cls(line)
 
     def to_line(self) -> str:
@@ -109,14 +115,16 @@ class RdataCommand(Command):
 
     NAME = "RDATA"
 
-    def __init__(self, stream_name, instance_name, token, row):
+    def __init__(
+        self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow
+    ):
         self.stream_name = stream_name
         self.instance_name = instance_name
         self.token = token
         self.row = row
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand":
         stream_name, instance_name, token, row_json = line.split(" ", 3)
         return cls(
             stream_name,
@@ -125,7 +133,7 @@ class RdataCommand(Command):
             json_decoder.decode(row_json),
         )
 
-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.stream_name,
@@ -135,7 +143,7 @@ class RdataCommand(Command):
             )
         )
 
-    def get_logcontext_id(self):
+    def get_logcontext_id(self) -> str:
         return "RDATA-" + self.stream_name
 
 
@@ -164,18 +172,20 @@ class PositionCommand(Command):
 
     NAME = "POSITION"
 
-    def __init__(self, stream_name, instance_name, prev_token, new_token):
+    def __init__(
+        self, stream_name: str, instance_name: str, prev_token: int, new_token: int
+    ):
         self.stream_name = stream_name
         self.instance_name = instance_name
         self.prev_token = prev_token
         self.new_token = new_token
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand":
         stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
         return cls(stream_name, instance_name, int(prev_token), int(new_token))
 
-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.stream_name,
@@ -218,14 +228,14 @@ class ReplicateCommand(Command):
 
     NAME = "REPLICATE"
 
-    def __init__(self):
+    def __init__(self) -> None:
         pass
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[T], line: str) -> T:
         return cls()
 
-    def to_line(self):
+    def to_line(self) -> str:
         return ""
 
 
@@ -247,14 +257,16 @@ class UserSyncCommand(Command):
 
     NAME = "USER_SYNC"
 
-    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+    def __init__(
+        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+    ):
         self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
         instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
 
         if state not in ("start", "end"):
@@ -262,7 +274,7 @@ class UserSyncCommand(Command):
 
         return cls(instance_id, user_id, state == "start", int(last_sync_ms))
 
-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.instance_id,
@@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command):
 
     NAME = "CLEAR_USER_SYNC"
 
-    def __init__(self, instance_id):
+    def __init__(self, instance_id: str):
         self.instance_id = instance_id
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(
+        cls: Type["ClearUserSyncsCommand"], line: str
+    ) -> "ClearUserSyncsCommand":
         return cls(line)
 
-    def to_line(self):
+    def to_line(self) -> str:
         return self.instance_id
 
 
@@ -316,7 +330,9 @@ class FederationAckCommand(Command):
         self.token = token
 
     @classmethod
-    def from_line(cls, line: str) -> "FederationAckCommand":
+    def from_line(
+        cls: Type["FederationAckCommand"], line: str
+    ) -> "FederationAckCommand":
         instance_name, token = line.split(" ")
         return cls(instance_name, int(token))
 
@@ -334,7 +350,15 @@ class UserIpCommand(Command):
 
     NAME = "USER_IP"
 
-    def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+    def __init__(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: str,
+        last_seen: int,
+    ):
         self.user_id = user_id
         self.access_token = access_token
         self.ip = ip
@@ -343,14 +367,14 @@ class UserIpCommand(Command):
         self.last_seen = last_seen
 
     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand":
         user_id, jsn = line.split(" ", 1)
 
         access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
 
         return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
 
-    def to_line(self):
+    def to_line(self) -> str:
         return (
             self.user_id
             + " "
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index f7e6bc1e62..17e1572393 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -261,7 +261,7 @@ class ReplicationCommandHandler:
             "process-replication-data", self._unsafe_process_queue, stream_name
         )
 
-    async def _unsafe_process_queue(self, stream_name: str):
+    async def _unsafe_process_queue(self, stream_name: str) -> None:
         """Processes the command queue for the given stream, until it is empty
 
         Does not check if there is already a thread processing the queue, hence "unsafe"
@@ -294,7 +294,7 @@ class ReplicationCommandHandler:
             # This shouldn't be possible
             raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
 
-    def start_replication(self, hs: "HomeServer"):
+    def start_replication(self, hs: "HomeServer") -> None:
         """Helper method to start a replication connection to the remote server
         using TCP.
         """
@@ -345,10 +345,10 @@ class ReplicationCommandHandler:
         """Get a list of streams that this instances replicates."""
         return self._streams_to_replicate
 
-    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None:
         self.send_positions_to_connection(conn)
 
-    def send_positions_to_connection(self, conn: IReplicationConnection):
+    def send_positions_to_connection(self, conn: IReplicationConnection) -> None:
         """Send current position of all streams this process is source of to
         the connection.
         """
@@ -392,7 +392,7 @@ class ReplicationCommandHandler:
 
     def on_FEDERATION_ACK(
         self, conn: IReplicationConnection, cmd: FederationAckCommand
-    ):
+    ) -> None:
         federation_ack_counter.inc()
 
         if self._federation_sender:
@@ -408,7 +408,7 @@ class ReplicationCommandHandler:
         else:
             return None
 
-    async def _handle_user_ip(self, cmd: UserIpCommand):
+    async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
         await self._store.insert_client_ip(
             cmd.user_id,
             cmd.access_token,
@@ -421,7 +421,7 @@ class ReplicationCommandHandler:
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)
 
-    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -497,7 +497,7 @@ class ReplicationCommandHandler:
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
-    ):
+    ) -> None:
         """Called to handle a batch of replication data with a given stream token.
 
         Args:
@@ -512,7 +512,7 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )
 
-    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None:
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
@@ -581,7 +581,7 @@ class ReplicationCommandHandler:
 
     def on_REMOTE_SERVER_UP(
         self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
-    ):
+    ) -> None:
         """Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
 
@@ -604,7 +604,7 @@ class ReplicationCommandHandler:
         # between two instances, but that is not currently supported).
         self.send_command(cmd, ignore_conn=conn)
 
-    def new_connection(self, connection: IReplicationConnection):
+    def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""
         self._connections.append(connection)
 
@@ -631,7 +631,7 @@ class ReplicationCommandHandler:
                 UserSyncCommand(self._instance_id, user_id, True, now)
             )
 
-    def lost_connection(self, connection: IReplicationConnection):
+    def lost_connection(self, connection: IReplicationConnection) -> None:
         """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
@@ -653,7 +653,7 @@ class ReplicationCommandHandler:
 
     def send_command(
         self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
-    ):
+    ) -> None:
         """Send a command to all connected connections.
 
         Args:
@@ -680,7 +680,7 @@ class ReplicationCommandHandler:
         else:
             logger.warning("Dropping command as not connected: %r", cmd.NAME)
 
-    def send_federation_ack(self, token: int):
+    def send_federation_ack(self, token: int) -> None:
         """Ack data for the federation stream. This allows the master to drop
         data stored purely in memory.
         """
@@ -688,7 +688,7 @@ class ReplicationCommandHandler:
 
     def send_user_sync(
         self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
-    ):
+    ) -> None:
         """Poke the master that a user has started/stopped syncing."""
         self.send_command(
             UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
@@ -702,15 +702,15 @@ class ReplicationCommandHandler:
         user_agent: str,
         device_id: str,
         last_seen: int,
-    ):
+    ) -> None:
         """Tell the master that the user made a request."""
         cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
         self.send_command(cmd)
 
-    def send_remote_server_up(self, server: str):
+    def send_remote_server_up(self, server: str) -> None:
         self.send_command(RemoteServerUpCommand(server))
 
-    def stream_update(self, stream_name: str, token: str, data: Any):
+    def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
         """Called when a new update is available to stream to clients.
 
         We need to check if the client is interested in the stream or not
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 7bae36db16..7763ffb2d0 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -49,7 +49,7 @@ import fcntl
 import logging
 import struct
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Collection, List, Optional
+from typing import TYPE_CHECKING, Any, Collection, List, Optional
 
 from prometheus_client import Counter
 from zope.interface import Interface, implementer
@@ -123,7 +123,7 @@ class ConnectionStates:
 class IReplicationConnection(Interface):
     """An interface for replication connections."""
 
-    def send_command(cmd: Command):
+    def send_command(cmd: Command) -> None:
         """Send the command down the connection"""
 
 
@@ -190,7 +190,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 "replication-conn", self.conn_id
             )
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.info("[%s] Connection established", self.id())
 
         self.state = ConnectionStates.ESTABLISHED
@@ -207,11 +207,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         # Always send the initial PING so that the other side knows that they
         # can time us out.
-        self.send_command(PingCommand(self.clock.time_msec()))
+        self.send_command(PingCommand(str(self.clock.time_msec())))
 
         self.command_handler.new_connection(self)
 
-    def send_ping(self):
+    def send_ping(self) -> None:
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
         """
@@ -226,7 +226,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 self.transport.abortConnection()
         else:
             if now - self.last_sent_command >= PING_TIME:
-                self.send_command(PingCommand(now))
+                self.send_command(PingCommand(str(now)))
 
             if (
                 self.received_ping
@@ -239,12 +239,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 )
                 self.send_error("ping timeout")
 
-    def lineReceived(self, line: bytes):
+    def lineReceived(self, line: bytes) -> None:
         """Called when we've received a line"""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_line(line)
 
-    def _parse_and_dispatch_line(self, line: bytes):
+    def _parse_and_dispatch_line(self, line: bytes) -> None:
         if line.strip() == "":
             # Ignore blank lines
             return
@@ -309,24 +309,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         if not handled:
             logger.warning("Unhandled command: %r", cmd)
 
-    def close(self):
+    def close(self) -> None:
         logger.warning("[%s] Closing connection", self.id())
         self.time_we_closed = self.clock.time_msec()
         assert self.transport is not None
         self.transport.loseConnection()
         self.on_connection_closed()
 
-    def send_error(self, error_string, *args):
+    def send_error(self, error_string: str, *args: Any) -> None:
         """Send an error to remote and close the connection."""
         self.send_command(ErrorCommand(error_string % args))
         self.close()
 
-    def send_command(self, cmd, do_buffer=True):
+    def send_command(self, cmd: Command, do_buffer: bool = True) -> None:
         """Send a command if connection has been established.
 
         Args:
-            cmd (Command)
-            do_buffer (bool): Whether to buffer the message or always attempt
+            cmd
+            do_buffer: Whether to buffer the message or always attempt
                 to send the command. This is mostly used to send an error
                 message if we're about to close the connection due our buffers
                 becoming full.
@@ -357,7 +357,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.last_sent_command = self.clock.time_msec()
 
-    def _queue_command(self, cmd):
+    def _queue_command(self, cmd: Command) -> None:
         """Queue the command until the connection is ready to write to again."""
         logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
         self.pending_commands.append(cmd)
@@ -370,20 +370,20 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
             self.close()
 
-    def _send_pending_commands(self):
+    def _send_pending_commands(self) -> None:
         """Send any queued commandes"""
         pending = self.pending_commands
         self.pending_commands = []
         for cmd in pending:
             self.send_command(cmd)
 
-    def on_PING(self, line):
+    def on_PING(self, cmd: PingCommand) -> None:
         self.received_ping = True
 
-    def on_ERROR(self, cmd):
+    def on_ERROR(self, cmd: ErrorCommand) -> None:
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         """This is called when both the kernel send buffer and the twisted
         tcp connection send buffers have become full.
 
@@ -394,26 +394,26 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         logger.info("[%s] Pause producing", self.id())
         self.state = ConnectionStates.PAUSED
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         """The remote has caught up after we started buffering!"""
         logger.info("[%s] Resume producing", self.id())
         self.state = ConnectionStates.ESTABLISHED
         self._send_pending_commands()
 
-    def stopProducing(self):
+    def stopProducing(self) -> None:
         """We're never going to send any more data (normally because either
         we or the remote has closed the connection)
         """
         logger.info("[%s] Stop producing", self.id())
         self.on_connection_closed()
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure) -> None:  # type: ignore[override]
         logger.info("[%s] Replication connection closed: %r", self.id(), reason)
         if isinstance(reason, Failure):
             assert reason.type is not None
             connection_close_counter.labels(reason.type.__name__).inc()
         else:
-            connection_close_counter.labels(reason.__class__.__name__).inc()
+            connection_close_counter.labels(reason.__class__.__name__).inc()  # type: ignore[unreachable]
 
         try:
             # Remove us from list of connections to be monitored
@@ -427,7 +427,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.on_connection_closed()
 
-    def on_connection_closed(self):
+    def on_connection_closed(self) -> None:
         logger.info("[%s] Connection was closed", self.id())
 
         self.state = ConnectionStates.CLOSED
@@ -445,7 +445,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             # the sentinel context is now active, which may not be correct.
             # PreserveLoggingContext() will restore the correct logging context.
 
-    def __str__(self):
+    def __str__(self) -> str:
         addr = None
         if self.transport:
             addr = str(self.transport.getPeer())
@@ -455,10 +455,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             addr,
         )
 
-    def id(self):
+    def id(self) -> str:
         return "%s-%s" % (self.name, self.conn_id)
 
-    def lineLengthExceeded(self, line):
+    def lineLengthExceeded(self, line: str) -> None:
         """Called when we receive a line that is above the maximum line length"""
         self.send_error("Line length exceeded")
 
@@ -474,11 +474,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         self.server_name = server_name
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.send_command(ServerCommand(self.server_name))
         super().connectionMade()
 
-    def on_NAME(self, cmd):
+    def on_NAME(self, cmd: NameCommand) -> None:
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
@@ -500,19 +500,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.client_name = client_name
         self.server_name = server_name
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.send_command(NameCommand(self.client_name))
         super().connectionMade()
 
         # Once we've connected subscribe to the necessary streams
         self.replicate()
 
-    def on_SERVER(self, cmd):
+    def on_SERVER(self, cmd: ServerCommand) -> None:
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    def replicate(self):
+    def replicate(self) -> None:
         """Send the subscription request to the server"""
         logger.info("[%s] Subscribing to replication streams", self.id())
 
@@ -529,7 +529,7 @@ pending_commands = LaterGauge(
 )
 
 
-def transport_buffer_size(protocol):
+def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int:
     if protocol.transport:
         size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
         return size
@@ -544,7 +544,9 @@ transport_send_buffer = LaterGauge(
 )
 
 
-def transport_kernel_read_buffer_size(protocol, read=True):
+def transport_kernel_read_buffer_size(
+    protocol: BaseReplicationStreamProtocol, read: bool = True
+) -> int:
     SIOCINQ = 0x541B
     SIOCOUTQ = 0x5411
 
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 5b37f379d0..3170f7c59b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -14,7 +14,7 @@
 
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
 
 import attr
 import txredisapi
@@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]):
     def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
         return self.constant
 
-    def __set__(self, obj: Optional[T], value: V):
+    def __set__(self, obj: Optional[T], value: V) -> None:
         pass
 
 
@@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
     synapse_stream_name: str
     synapse_outbound_redis_connection: txredisapi.RedisProtocol
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
 
         # a logcontext which we use for processing incoming commands. We declare it as a
@@ -108,12 +108,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
                 "replication_command_handler"
             )
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.info("Connected to redis")
         super().connectionMade()
         run_as_background_process("subscribe-replication", self._send_subscribe)
 
-    async def _send_subscribe(self):
+    async def _send_subscribe(self) -> None:
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
@@ -131,12 +131,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         # otherside won't know we've connected and so won't issue a REPLICATE.
         self.synapse_handler.send_positions_to_connection(self)
 
-    def messageReceived(self, pattern: str, channel: str, message: str):
+    def messageReceived(self, pattern: str, channel: str, message: str) -> None:
         """Received a message from redis."""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_message(message)
 
-    def _parse_and_dispatch_message(self, message: str):
+    def _parse_and_dispatch_message(self, message: str) -> None:
         if message.strip() == "":
             # Ignore blank lines
             return
@@ -181,7 +181,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
                 "replication-" + cmd.get_logcontext_id(), lambda: res
             )
 
-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure) -> None:  # type: ignore[override]
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
         self.synapse_handler.lost_connection(self)
@@ -193,17 +193,17 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
             # the sentinel context is now active, which may not be correct.
             # PreserveLoggingContext() will restore the correct logging context.
 
-    def send_command(self, cmd: Command):
+    def send_command(self, cmd: Command) -> None:
         """Send a command if connection has been established.
 
         Args:
-            cmd (Command)
+            cmd: The command to send
         """
         run_as_background_process(
             "send-cmd", self._async_send_command, cmd, bg_start_span=False
         )
 
-    async def _async_send_command(self, cmd: Command):
+    async def _async_send_command(self, cmd: Command) -> None:
         """Encode a replication command and send it over our outbound connection"""
         string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
@@ -259,7 +259,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
         hs.get_clock().looping_call(self._send_ping, 30 * 1000)
 
     @wrap_as_background_process("redis_ping")
-    async def _send_ping(self):
+    async def _send_ping(self) -> None:
         for connection in self.pool:
             try:
                 await make_deferred_yieldable(connection.ping())
@@ -269,13 +269,13 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
     # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
     # it's rubbish. We add our own here.
 
-    def startedConnecting(self, connector: IConnector):
+    def startedConnecting(self, connector: IConnector) -> None:
         logger.info(
             "Connecting to redis server %s", format_address(connector.getDestination())
         )
         super().startedConnecting(connector)
 
-    def clientConnectionFailed(self, connector: IConnector, reason: Failure):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.info(
             "Connection to redis server %s failed: %s",
             format_address(connector.getDestination()),
@@ -283,7 +283,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
         )
         super().clientConnectionFailed(connector, reason)
 
-    def clientConnectionLost(self, connector: IConnector, reason: Failure):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.info(
             "Connection to redis server %s lost: %s",
             format_address(connector.getDestination()),
@@ -330,7 +330,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
 
         self.synapse_outbound_redis_connection = outbound_redis_connection
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
         p = super().buildProtocol(addr)
         p = cast(RedisSubscriber, p)
 
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index a9d85f4f6c..ecd6190f5b 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -16,16 +16,18 @@
 
 import logging
 import random
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from prometheus_client import Counter
 
+from twisted.internet.interfaces import IAddress
 from twisted.internet.protocol import ServerFactory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import PositionCommand
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
 from synapse.replication.tcp.streams import EventsStream
+from synapse.replication.tcp.streams._base import StreamRow, Token
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -56,7 +58,7 @@ class ReplicationStreamProtocolFactory(ServerFactory):
         # listener config again or always starting a `ReplicationStreamer`.)
         hs.get_replication_streamer()
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol:
         return ServerReplicationStreamProtocol(
             self.server_name, self.clock, self.command_handler
         )
@@ -105,7 +107,7 @@ class ReplicationStreamer:
         if any(EventsStream.NAME == s.NAME for s in self.streams):
             self.clock.looping_call(self.on_notifier_poke, 1000)
 
-    def on_notifier_poke(self):
+    def on_notifier_poke(self) -> None:
         """Checks if there is actually any new data and sends it to the
         connections if there are.
 
@@ -137,7 +139,7 @@ class ReplicationStreamer:
 
         run_as_background_process("replication_notifier", self._run_notifier_loop)
 
-    async def _run_notifier_loop(self):
+    async def _run_notifier_loop(self) -> None:
         self.is_looping = True
 
         try:
@@ -238,7 +240,9 @@ class ReplicationStreamer:
             self.is_looping = False
 
 
-def _batch_updates(updates):
+def _batch_updates(
+    updates: List[Tuple[Token, StreamRow]]
+) -> List[Tuple[Optional[Token], StreamRow]]:
     """Takes a list of updates of form [(token, row)] and sets the token to
     None for all rows where the next row has the same token. This is used to
     implement batching.
@@ -254,7 +258,7 @@ def _batch_updates(updates):
     if not updates:
         return []
 
-    new_updates = []
+    new_updates: List[Tuple[Optional[Token], StreamRow]] = []
     for i, update in enumerate(updates[:-1]):
         if update[0] == updates[i + 1][0]:
             new_updates.append((None, update[1]))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 5a2d90c530..914b9eae84 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -90,7 +90,7 @@ class Stream:
     ROW_TYPE: Any = None
 
     @classmethod
-    def parse_row(cls, row: StreamRow):
+    def parse_row(cls, row: StreamRow) -> Any:
         """Parse a row received over replication
 
         By default, assumes that the row data is an array object and passes its contents
@@ -139,7 +139,7 @@ class Stream:
         # The token from which we last asked for updates
         self.last_token = self.current_token(self.local_instance_name)
 
-    def discard_updates_and_advance(self):
+    def discard_updates_and_advance(self) -> None:
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
@@ -200,7 +200,7 @@ def current_token_without_instance(
     return lambda instance_name: current_token()
 
 
-def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction:
     """Makes a suitable function for use as an `update_function` that queries
     the master process for updates.
     """
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 4f4f1ad453..50c4a5ba03 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import heapq
-from collections.abc import Iterable
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    StreamRow,
+    StreamUpdateResult,
+    Token,
+)
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -58,6 +62,9 @@ class EventsStreamRow:
     data: "BaseEventsStreamRow"
 
 
+T = TypeVar("T", bound="BaseEventsStreamRow")
+
+
 class BaseEventsStreamRow:
     """Base class for rows to be sent in the events stream.
 
@@ -68,7 +75,7 @@ class BaseEventsStreamRow:
     TypeId: str
 
     @classmethod
-    def from_data(cls, data):
+    def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
         """Parse the data from the replication stream into a row.
 
         By default we just call the constructor with the data list as arguments
@@ -221,7 +228,7 @@ class EventsStream(Stream):
         return updates, upper_limit, limited
 
     @classmethod
-    def parse_row(cls, row):
-        (typ, data) = row
-        data = TypeToRow[typ].from_data(data)
-        return EventsStreamRow(typ, data)
+    def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
+        (typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
+        event_stream_row_data = TypeToRow[typ].from_data(data)
+        return EventsStreamRow(typ, event_stream_row_data)