summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py26
-rw-r--r--synapse/replication/tcp/commands.py59
-rw-r--r--synapse/replication/tcp/protocol.py121
-rw-r--r--synapse/replication/tcp/resource.py43
-rw-r--r--synapse/replication/tcp/streams/_base.py104
-rw-r--r--synapse/replication/tcp/streams/events.py25
-rw-r--r--synapse/replication/tcp/streams/federation.py4
7 files changed, 218 insertions, 164 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index bbcb84646c..fc06a7b053 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,7 +16,7 @@
 """
 
 import logging
-from typing import Dict
+from typing import Dict, List, Optional
 
 from twisted.internet import defer
 from twisted.internet.protocol import ReconnectingClientFactory
@@ -28,6 +28,7 @@ from synapse.replication.tcp.protocol import (
 )
 
 from .commands import (
+    Command,
     FederationAckCommand,
     InvalidateCacheCommand,
     RemovePusherCommand,
@@ -89,15 +90,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
 
         # Any pending commands to be sent once a new connection has been
         # established
-        self.pending_commands = []
+        self.pending_commands = []  # type: List[Command]
 
         # Map from string -> deferred, to wake up when receiveing a SYNC with
         # the given string.
         # Used for tests.
-        self.awaiting_syncs = {}
+        self.awaiting_syncs = {}  # type: Dict[str, defer.Deferred]
 
         # The factory used to create connections.
-        self.factory = None
+        self.factory = None  # type: Optional[ReplicationClientFactory]
 
     def start_replication(self, hs):
         """Helper method to start a replication connection to the remote server
@@ -109,7 +110,7 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         port = hs.config.worker_replication_port
         hs.get_reactor().connectTCP(host, port, self.factory)
 
-    def on_rdata(self, stream_name, token, rows):
+    async def on_rdata(self, stream_name, token, rows):
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
@@ -120,20 +121,17 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
             token (int): stream token for this batch of rows
             rows (list): a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
-
-        Returns:
-            Deferred|None
         """
         logger.debug("Received rdata %s -> %s", stream_name, token)
-        return self.store.process_replication_rows(stream_name, token, rows)
+        self.store.process_replication_rows(stream_name, token, rows)
 
-    def on_position(self, stream_name, token):
+    async def on_position(self, stream_name, token):
         """Called when we get new position data. By default this just pokes
         the slave store.
 
         Can be overriden in subclasses to handle more.
         """
-        return self.store.process_replication_rows(stream_name, token, [])
+        self.store.process_replication_rows(stream_name, token, [])
 
     def on_sync(self, data):
         """When we received a SYNC we wake up any deferreds that were waiting
@@ -145,6 +143,9 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         if d:
             d.callback(data)
 
+    def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+
     def get_streams_to_replicate(self) -> Dict[str, int]:
         """Called when a new connection has been established and we need to
         subscribe to streams.
@@ -235,4 +236,5 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
         # We don't reset the delay any earlier as otherwise if there is a
         # problem during start up we'll end up tight looping connecting to the
         # server.
-        self.factory.resetDelay()
+        if self.factory:
+            self.factory.resetDelay()
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0ff2a7199f..451671412d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -20,15 +20,16 @@ allowed to be sent by which side.
 
 import logging
 import platform
+from typing import Tuple, Type
 
 if platform.python_implementation() == "PyPy":
     import json
 
     _json_encoder = json.JSONEncoder()
 else:
-    import simplejson as json
+    import simplejson as json  # type: ignore[no-redef]  # noqa: F821
 
-    _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
+    _json_encoder = json.JSONEncoder(namedtuple_as_object=False)  # type: ignore[call-arg]  # noqa: F821
 
 logger = logging.getLogger(__name__)
 
@@ -44,7 +45,7 @@ class Command(object):
     The default implementation creates a command of form `<NAME> <data>`
     """
 
-    NAME = None
+    NAME = None  # type: str
 
     def __init__(self, data):
         self.data = data
@@ -386,25 +387,39 @@ class UserIpCommand(Command):
         )
 
 
+class RemoteServerUpCommand(Command):
+    """Sent when a worker has detected that a remote server is no longer
+    "down" and retry timings should be reset.
+
+    If sent from a client the server will relay to all other workers.
+
+    Format::
+
+        REMOTE_SERVER_UP <server>
+    """
+
+    NAME = "REMOTE_SERVER_UP"
+
+
+_COMMANDS = (
+    ServerCommand,
+    RdataCommand,
+    PositionCommand,
+    ErrorCommand,
+    PingCommand,
+    NameCommand,
+    ReplicateCommand,
+    UserSyncCommand,
+    FederationAckCommand,
+    SyncCommand,
+    RemovePusherCommand,
+    InvalidateCacheCommand,
+    UserIpCommand,
+    RemoteServerUpCommand,
+)  # type: Tuple[Type[Command], ...]
+
 # Map of command name to command type.
-COMMAND_MAP = {
-    cmd.NAME: cmd
-    for cmd in (
-        ServerCommand,
-        RdataCommand,
-        PositionCommand,
-        ErrorCommand,
-        PingCommand,
-        NameCommand,
-        ReplicateCommand,
-        UserSyncCommand,
-        FederationAckCommand,
-        SyncCommand,
-        RemovePusherCommand,
-        InvalidateCacheCommand,
-        UserIpCommand,
-    )
-}
+COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
 
 # The commands the server is allowed to send
 VALID_SERVER_COMMANDS = (
@@ -414,6 +429,7 @@ VALID_SERVER_COMMANDS = (
     ErrorCommand.NAME,
     PingCommand.NAME,
     SyncCommand.NAME,
+    RemoteServerUpCommand.NAME,
 )
 
 # The commands the client is allowed to send
@@ -427,4 +443,5 @@ VALID_CLIENT_COMMANDS = (
     InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
+    RemoteServerUpCommand.NAME,
 )
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index afaf002fe6..131e5acb09 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -53,6 +53,7 @@ import fcntl
 import logging
 import struct
 from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Tuple
 
 from six import iteritems, iterkeys
 
@@ -65,24 +66,26 @@ from twisted.python.failure import Failure
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import Clock
-from synapse.util.stringutils import random_string
-
-from .commands import (
+from synapse.replication.tcp.commands import (
     COMMAND_MAP,
     VALID_CLIENT_COMMANDS,
     VALID_SERVER_COMMANDS,
+    Command,
     ErrorCommand,
     NameCommand,
     PingCommand,
     PositionCommand,
     RdataCommand,
+    RemoteServerUpCommand,
     ReplicateCommand,
     ServerCommand,
     SyncCommand,
     UserSyncCommand,
 )
-from .streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
 
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
@@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     delimiter = b"\n"
 
-    VALID_INBOUND_COMMANDS = []  # Valid commands we expect to receive
-    VALID_OUTBOUND_COMMANDS = []  # Valid commans we can send
+    # Valid commands we expect to receive
+    VALID_INBOUND_COMMANDS = []  # type: Collection[str]
+
+    # Valid commands we can send
+    VALID_OUTBOUND_COMMANDS = []  # type: Collection[str]
 
     max_line_buffer = 10000
 
@@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.conn_id = random_string(5)  # To dedupe in case of name clashes.
 
         # List of pending commands to send once we've established the connection
-        self.pending_commands = []
+        self.pending_commands = []  # type: List[Command]
 
         # The LoopingCall for sending pings.
         self._send_ping_loop = None
 
-        self.inbound_commands_counter = defaultdict(int)
-        self.outbound_commands_counter = defaultdict(int)
+        self.inbound_commands_counter = defaultdict(int)  # type: DefaultDict[str, int]
+        self.outbound_commands_counter = defaultdict(int)  # type: DefaultDict[str, int]
 
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
@@ -235,19 +241,16 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
         )
 
-    def handle_command(self, cmd):
+    async def handle_command(self, cmd: Command):
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>
+        By default delegates to on_<COMMAND>, which should return an awaitable.
 
         Args:
-            cmd (synapse.replication.tcp.commands.Command): received command
-
-        Returns:
-            Deferred
+            cmd: received command
         """
         handler = getattr(self, "on_%s" % (cmd.NAME,))
-        return handler(cmd)
+        await handler(cmd)
 
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
@@ -320,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         for cmd in pending:
             self.send_command(cmd)
 
-    def on_PING(self, line):
+    async def on_PING(self, line):
         self.received_ping = True
 
-    def on_ERROR(self, cmd):
+    async def on_ERROR(self, cmd):
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
     def pauseProducing(self):
@@ -409,30 +412,30 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.streamer = streamer
 
         # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()
+        self.replication_streams = set()  # type: Set[str]
 
         # The streams the client is currently subscribing to.
-        self.connecting_streams = set()
+        self.connecting_streams = set()  # type:  Set[str]
 
         # Map from stream name to list of updates to send once we've finished
         # subscribing the client to the stream.
-        self.pending_rdata = {}
+        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
 
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)
         self.streamer.new_connection(self)
 
-    def on_NAME(self, cmd):
+    async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
-    def on_USER_SYNC(self, cmd):
-        return self.streamer.on_user_sync(
+    async def on_USER_SYNC(self, cmd):
+        await self.streamer.on_user_sync(
             self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
         )
 
-    def on_REPLICATE(self, cmd):
+    async def on_REPLICATE(self, cmd):
         stream_name = cmd.stream_name
         token = cmd.token
 
@@ -443,23 +446,26 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
                 for stream in iterkeys(self.streamer.streams_by_name)
             ]
 
-            return make_deferred_yieldable(
+            await make_deferred_yieldable(
                 defer.gatherResults(deferreds, consumeErrors=True)
             )
         else:
-            return self.subscribe_to_stream(stream_name, token)
+            await self.subscribe_to_stream(stream_name, token)
 
-    def on_FEDERATION_ACK(self, cmd):
-        return self.streamer.federation_ack(cmd.token)
+    async def on_FEDERATION_ACK(self, cmd):
+        self.streamer.federation_ack(cmd.token)
 
-    def on_REMOVE_PUSHER(self, cmd):
-        return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
+    async def on_REMOVE_PUSHER(self, cmd):
+        await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
 
-    def on_INVALIDATE_CACHE(self, cmd):
-        return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+    async def on_INVALIDATE_CACHE(self, cmd):
+        self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
-    def on_USER_IP(self, cmd):
-        return self.streamer.on_user_ip(
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        self.streamer.on_remote_server_up(cmd.data)
+
+    async def on_USER_IP(self, cmd):
+        self.streamer.on_user_ip(
             cmd.user_id,
             cmd.access_token,
             cmd.ip,
@@ -468,8 +474,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
             cmd.last_seen,
         )
 
-    @defer.inlineCallbacks
-    def subscribe_to_stream(self, stream_name, token):
+    async def subscribe_to_stream(self, stream_name, token):
         """Subscribe the remote to a stream.
 
         This invloves checking if they've missed anything and sending those
@@ -481,7 +486,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
         try:
             # Get missing updates
-            updates, current_token = yield self.streamer.get_stream_updates(
+            updates, current_token = await self.streamer.get_stream_updates(
                 stream_name, token
             )
 
@@ -554,6 +559,9 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     def send_sync(self, data):
         self.send_command(SyncCommand(data))
 
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)
         self.streamer.lost_connection(self)
@@ -566,7 +574,7 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
     """
 
     @abc.abstractmethod
-    def on_rdata(self, stream_name, token, rows):
+    async def on_rdata(self, stream_name, token, rows):
         """Called to handle a batch of replication data with a given stream token.
 
         Args:
@@ -574,14 +582,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
             token (int): stream token for this batch of rows
             rows (list): a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
-
-        Returns:
-            Deferred|None
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def on_position(self, stream_name, token):
+    async def on_position(self, stream_name, token):
         """Called when we get new position data."""
         raise NotImplementedError()
 
@@ -591,6 +596,11 @@ class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    async def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+        raise NotImplementedError()
+
+    @abc.abstractmethod
     def get_streams_to_replicate(self):
         """Called when a new connection has been established and we need to
         subscribe to streams.
@@ -642,11 +652,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # Set of stream names that have been subscribe to, but haven't yet
         # caught up with. This is used to track when the client has been fully
         # connected to the remote.
-        self.streams_connecting = set()
+        self.streams_connecting = set()  # type: Set[str]
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
-        self.pending_batches = {}
+        self.pending_batches = {}  # type: Dict[str, Any]
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
@@ -670,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-    def on_SERVER(self, cmd):
+    async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    def on_RDATA(self, cmd):
+    async def on_RDATA(self, cmd):
         stream_name = cmd.stream_name
         inbound_rdata_count.labels(stream_name).inc()
 
@@ -695,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             # Check if this is the last of a batch of updates
             rows = self.pending_batches.pop(stream_name, [])
             rows.append(row)
-            return self.handler.on_rdata(stream_name, cmd.token, rows)
+            await self.handler.on_rdata(stream_name, cmd.token, rows)
 
-    def on_POSITION(self, cmd):
+    async def on_POSITION(self, cmd):
         # When we get a `POSITION` command it means we've finished getting
         # missing updates for the given stream, and are now up to date.
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-        return self.handler.on_position(cmd.stream_name, cmd.token)
+        await self.handler.on_position(cmd.stream_name, cmd.token)
+
+    async def on_SYNC(self, cmd):
+        self.handler.on_sync(cmd.data)
 
-    def on_SYNC(self, cmd):
-        return self.handler.on_sync(cmd.data)
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        self.handler.on_remote_server_up(cmd.data)
 
     def replicate(self, stream_name, token):
         """Send the subscription request to the server
@@ -766,7 +779,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
             op = SIOCINQ
         else:
             op = SIOCOUTQ
-        size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
+        size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
         return size
     return 0
 
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index d1e98428bc..6ebf944f66 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,12 +17,12 @@
 
 import logging
 import random
+from typing import List
 
 from six import itervalues
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.internet.protocol import Factory
 
 from synapse.metrics import LaterGauge
@@ -79,7 +79,7 @@ class ReplicationStreamer(object):
         self._replication_torture_level = hs.config.replication_torture_level
 
         # Current connections.
-        self.connections = []
+        self.connections = []  # type: List[ServerReplicationStreamProtocol]
 
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
@@ -120,6 +120,7 @@ class ReplicationStreamer(object):
             self.federation_sender = hs.get_federation_sender()
 
         self.notifier.add_replication_callback(self.on_notifier_poke)
+        self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
@@ -154,8 +155,7 @@ class ReplicationStreamer(object):
 
         run_as_background_process("replication_notifier", self._run_notifier_loop)
 
-    @defer.inlineCallbacks
-    def _run_notifier_loop(self):
+    async def _run_notifier_loop(self):
         self.is_looping = True
 
         try:
@@ -184,7 +184,7 @@ class ReplicationStreamer(object):
                             continue
 
                         if self._replication_torture_level:
-                            yield self.clock.sleep(
+                            await self.clock.sleep(
                                 self._replication_torture_level / 1000.0
                             )
 
@@ -195,7 +195,7 @@ class ReplicationStreamer(object):
                             stream.upto_token,
                         )
                         try:
-                            updates, current_token = yield stream.get_updates()
+                            updates, current_token = await stream.get_updates()
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise
@@ -232,7 +232,7 @@ class ReplicationStreamer(object):
             self.is_looping = False
 
     @measure_func("repl.get_stream_updates")
-    def get_stream_updates(self, stream_name, token):
+    async def get_stream_updates(self, stream_name, token):
         """For a given stream get all updates since token. This is called when
         a client first subscribes to a stream.
         """
@@ -240,7 +240,7 @@ class ReplicationStreamer(object):
         if not stream:
             raise Exception("unknown stream %s", stream_name)
 
-        return stream.get_updates_since(token)
+        return await stream.get_updates_since(token)
 
     @measure_func("repl.federation_ack")
     def federation_ack(self, token):
@@ -251,22 +251,20 @@ class ReplicationStreamer(object):
             self.federation_sender.federation_ack(token)
 
     @measure_func("repl.on_user_sync")
-    @defer.inlineCallbacks
-    def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+    async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
         """A client has started/stopped syncing on a worker.
         """
         user_sync_counter.inc()
-        yield self.presence_handler.update_external_syncs_row(
+        await self.presence_handler.update_external_syncs_row(
             conn_id, user_id, is_syncing, last_sync_ms
         )
 
     @measure_func("repl.on_remove_pusher")
-    @defer.inlineCallbacks
-    def on_remove_pusher(self, app_id, push_key, user_id):
+    async def on_remove_pusher(self, app_id, push_key, user_id):
         """A client has asked us to remove a pusher
         """
         remove_pusher_counter.inc()
-        yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+        await self.store.delete_pusher_by_app_id_pushkey_user_id(
             app_id=app_id, pushkey=push_key, user_id=user_id
         )
 
@@ -280,15 +278,24 @@ class ReplicationStreamer(object):
         getattr(self.store, cache_func).invalidate(tuple(keys))
 
     @measure_func("repl.on_user_ip")
-    @defer.inlineCallbacks
-    def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+    async def on_user_ip(
+        self, user_id, access_token, ip, user_agent, device_id, last_seen
+    ):
         """The client saw a user request
         """
         user_ip_cache_counter.inc()
-        yield self.store.insert_client_ip(
+        await self.store.insert_client_ip(
             user_id, access_token, ip, user_agent, device_id, last_seen
         )
-        yield self._server_notices_sender.on_user_ip(user_id)
+        await self._server_notices_sender.on_user_ip(user_id)
+
+    @measure_func("repl.on_remote_server_up")
+    def on_remote_server_up(self, server: str):
+        self.notifier.notify_remote_server_up(server)
+
+    def send_remote_server_up(self, server: str):
+        for conn in self.connections:
+            conn.send_remote_server_up(server)
 
     def send_sync_to_all_connections(self, data):
         """Sends a SYNC command to all clients.
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 8512923eae..a8d568b14a 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,12 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import itertools
 import logging
 from collections import namedtuple
+from typing import Any, List, Optional
 
-from twisted.internet import defer
+import attr
 
 logger = logging.getLogger(__name__)
 
@@ -67,10 +67,24 @@ PushersStreamRow = namedtuple(
     "PushersStreamRow",
     ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
 )
-CachesStreamRow = namedtuple(
-    "CachesStreamRow",
-    ("cache_func", "keys", "invalidation_ts"),  # str  # list(str)  # int
-)
+
+
+@attr.s
+class CachesStreamRow:
+    """Stream to inform workers they should invalidate their cache.
+
+    Attributes:
+        cache_func: Name of the cached function.
+        keys: The entry in the cache to invalidate. If None then will
+            invalidate all.
+        invalidation_ts: Timestamp of when the invalidation took place.
+    """
+
+    cache_func = attr.ib(type=str)
+    keys = attr.ib(type=Optional[List[Any]])
+    invalidation_ts = attr.ib(type=int)
+
+
 PublicRoomsStreamRow = namedtuple(
     "PublicRoomsStreamRow",
     (
@@ -104,8 +118,9 @@ class Stream(object):
     time it was called up until the point `advance_current_token` was called.
     """
 
-    NAME = None  # The name of the stream
-    ROW_TYPE = None  # The type of the row. Used by the default impl of parse_row.
+    NAME = None  # type: str  # The name of the stream
+    # The type of the row. Used by the default impl of parse_row.
+    ROW_TYPE = None  # type: Any
     _LIMITED = True  # Whether the update function takes a limit
 
     @classmethod
@@ -143,8 +158,7 @@ class Stream(object):
         self.upto_token = self.current_token()
         self.last_token = self.upto_token
 
-    @defer.inlineCallbacks
-    def get_updates(self):
+    async def get_updates(self):
         """Gets all updates since the last time this function was called (or
         since the stream was constructed if it hadn't been called before),
         until the `upto_token`
@@ -155,13 +169,12 @@ class Stream(object):
                 list of ``(token, row)`` entries. ``row`` will be json-serialised and
                 sent over the replication steam.
         """
-        updates, current_token = yield self.get_updates_since(self.last_token)
+        updates, current_token = await self.get_updates_since(self.last_token)
         self.last_token = current_token
 
         return updates, current_token
 
-    @defer.inlineCallbacks
-    def get_updates_since(self, from_token):
+    async def get_updates_since(self, from_token):
         """Like get_updates except allows specifying from when we should
         stream updates
 
@@ -181,15 +194,16 @@ class Stream(object):
         if from_token == current_token:
             return [], current_token
 
+        logger.info("get_updates_since: %s", self.__class__)
         if self._LIMITED:
-            rows = yield self.update_function(
+            rows = await self.update_function(
                 from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
             )
 
             # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
             rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
         else:
-            rows = yield self.update_function(from_token, current_token)
+            rows = await self.update_function(from_token, current_token)
 
         updates = [(row[0], row[1:]) for row in rows]
 
@@ -231,8 +245,8 @@ class BackfillStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        self.current_token = store.get_current_backfill_token
-        self.update_function = store.get_all_new_backfill_event_rows
+        self.current_token = store.get_current_backfill_token  # type: ignore
+        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
 
         super(BackfillStream, self).__init__(hs)
 
@@ -246,8 +260,8 @@ class PresenceStream(Stream):
         store = hs.get_datastore()
         presence_handler = hs.get_presence_handler()
 
-        self.current_token = store.get_current_presence_token
-        self.update_function = presence_handler.get_all_presence_updates
+        self.current_token = store.get_current_presence_token  # type: ignore
+        self.update_function = presence_handler.get_all_presence_updates  # type: ignore
 
         super(PresenceStream, self).__init__(hs)
 
@@ -260,8 +274,8 @@ class TypingStream(Stream):
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
-        self.current_token = typing_handler.get_current_token
-        self.update_function = typing_handler.get_all_typing_updates
+        self.current_token = typing_handler.get_current_token  # type: ignore
+        self.update_function = typing_handler.get_all_typing_updates  # type: ignore
 
         super(TypingStream, self).__init__(hs)
 
@@ -273,8 +287,8 @@ class ReceiptsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_max_receipt_stream_id
-        self.update_function = store.get_all_updated_receipts
+        self.current_token = store.get_max_receipt_stream_id  # type: ignore
+        self.update_function = store.get_all_updated_receipts  # type: ignore
 
         super(ReceiptsStream, self).__init__(hs)
 
@@ -294,9 +308,8 @@ class PushRulesStream(Stream):
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
+    async def update_function(self, from_token, to_token, limit):
+        rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
         return [(row[0], row[2]) for row in rows]
 
 
@@ -310,8 +323,8 @@ class PushersStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_pushers_stream_token
-        self.update_function = store.get_all_updated_pushers_rows
+        self.current_token = store.get_pushers_stream_token  # type: ignore
+        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
 
         super(PushersStream, self).__init__(hs)
 
@@ -327,8 +340,8 @@ class CachesStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_cache_stream_token
-        self.update_function = store.get_all_updated_caches
+        self.current_token = store.get_cache_stream_token  # type: ignore
+        self.update_function = store.get_all_updated_caches  # type: ignore
 
         super(CachesStream, self).__init__(hs)
 
@@ -343,8 +356,8 @@ class PublicRoomsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_current_public_room_stream_id
-        self.update_function = store.get_all_new_public_rooms
+        self.current_token = store.get_current_public_room_stream_id  # type: ignore
+        self.update_function = store.get_all_new_public_rooms  # type: ignore
 
         super(PublicRoomsStream, self).__init__(hs)
 
@@ -360,8 +373,8 @@ class DeviceListsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_device_stream_token
-        self.update_function = store.get_all_device_list_changes_for_remotes
+        self.current_token = store.get_device_stream_token  # type: ignore
+        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
 
         super(DeviceListsStream, self).__init__(hs)
 
@@ -376,8 +389,8 @@ class ToDeviceStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_to_device_stream_token
-        self.update_function = store.get_all_new_device_messages
+        self.current_token = store.get_to_device_stream_token  # type: ignore
+        self.update_function = store.get_all_new_device_messages  # type: ignore
 
         super(ToDeviceStream, self).__init__(hs)
 
@@ -392,8 +405,8 @@ class TagAccountDataStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_max_account_data_stream_id
-        self.update_function = store.get_all_updated_tags
+        self.current_token = store.get_max_account_data_stream_id  # type: ignore
+        self.update_function = store.get_all_updated_tags  # type: ignore
 
         super(TagAccountDataStream, self).__init__(hs)
 
@@ -408,13 +421,12 @@ class AccountDataStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-        self.current_token = self.store.get_max_account_data_stream_id
+        self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
 
         super(AccountDataStream, self).__init__(hs)
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        global_results, room_results = yield self.store.get_all_updated_account_data(
+    async def update_function(self, from_token, to_token, limit):
+        global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
 
@@ -434,8 +446,8 @@ class GroupServerStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_group_stream_token
-        self.update_function = store.get_all_groups_changes
+        self.current_token = store.get_group_stream_token  # type: ignore
+        self.update_function = store.get_all_groups_changes  # type: ignore
 
         super(GroupServerStream, self).__init__(hs)
 
@@ -451,7 +463,7 @@ class UserSignatureStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_device_stream_token
-        self.update_function = store.get_all_user_signature_changes_for_remotes
+        self.current_token = store.get_device_stream_token  # type: ignore
+        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
 
         super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index d97669c886..b3afabb8cd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import heapq
+from typing import Tuple, Type
 
 import attr
 
-from twisted.internet import defer
-
 from ._base import Stream
 
 
@@ -63,7 +63,8 @@ class BaseEventsStreamRow(object):
     Specifies how to identify, serialize and deserialize the different types.
     """
 
-    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+    # Unique string that ids the type. Must be overriden in sub classes.
+    TypeId = None  # type: str
 
     @classmethod
     def from_data(cls, data):
@@ -99,9 +100,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
     event_id = attr.ib()  # str, optional
 
 
-TypeToRow = {
-    Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
-}
+_EventRows = (
+    EventsStreamEventRow,
+    EventsStreamCurrentStateRow,
+)  # type: Tuple[Type[BaseEventsStreamRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
 
 class EventsStream(Stream):
@@ -112,20 +116,19 @@ class EventsStream(Stream):
 
     def __init__(self, hs):
         self._store = hs.get_datastore()
-        self.current_token = self._store.get_current_events_token
+        self.current_token = self._store.get_current_events_token  # type: ignore
 
         super(EventsStream, self).__init__(hs)
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, current_token, limit=None):
-        event_rows = yield self._store.get_all_new_forward_event_rows(
+    async def update_function(self, from_token, current_token, limit=None):
+        event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
         event_updates = (
             (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
         )
 
-        state_rows = yield self._store.get_all_updated_current_state_deltas(
+        state_rows = await self._store.get_all_updated_current_state_deltas(
             from_token, current_token, limit
         )
         state_updates = (
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index dc2484109d..615f3dc9ac 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -37,7 +37,7 @@ class FederationStream(Stream):
     def __init__(self, hs):
         federation_sender = hs.get_federation_sender()
 
-        self.current_token = federation_sender.get_current_token
-        self.update_function = federation_sender.get_replication_rows
+        self.current_token = federation_sender.get_current_token  # type: ignore
+        self.update_function = federation_sender.get_replication_rows  # type: ignore
 
         super(FederationStream, self).__init__(hs)