summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/http/streams.py78
-rw-r--r--synapse/replication/slave/storage/_base.py14
-rw-r--r--synapse/replication/slave/storage/devices.py36
-rw-r--r--synapse/replication/slave/storage/pushers.py3
-rw-r--r--synapse/replication/tcp/__init__.py30
-rw-r--r--synapse/replication/tcp/client.py178
-rw-r--r--synapse/replication/tcp/commands.py119
-rw-r--r--synapse/replication/tcp/handler.py469
-rw-r--r--synapse/replication/tcp/protocol.py393
-rw-r--r--synapse/replication/tcp/resource.py198
-rw-r--r--synapse/replication/tcp/streams/__init__.py71
-rw-r--r--synapse/replication/tcp/streams/_base.py506
-rw-r--r--synapse/replication/tcp/streams/events.py15
-rw-r--r--synapse/replication/tcp/streams/federation.py42
15 files changed, 1189 insertions, 965 deletions
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 28dbc6fcba..4613b2538c 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -21,6 +21,7 @@ from synapse.replication.http import (
     membership,
     register,
     send_event,
+    streams,
 )
 
 REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
         login.register_servlets(hs, self)
         register.register_servlets(hs, self)
         devices.register_servlets(hs, self)
+        streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
new file mode 100644
index 0000000000..ffd4c61993
--- /dev/null
+++ b/synapse/replication/http/streams.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationGetStreamUpdates(ReplicationEndpoint):
+    """Fetches stream updates from a server. Used for streams not persisted to
+    the database, e.g. typing notifications.
+
+    The API looks like:
+
+        GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
+
+        200 OK
+
+        {
+            updates: [ ... ],
+            upto_token: 10,
+            limited: False,
+        }
+
+    """
+
+    NAME = "get_repl_stream_updates"
+    PATH_ARGS = ("stream_name",)
+    METHOD = "GET"
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        # We pull the streams from the replication steamer (if we try and make
+        # them ourselves we end up in an import loop).
+        self.streams = hs.get_replication_streamer().get_streams()
+
+    @staticmethod
+    def _serialize_payload(stream_name, from_token, upto_token, limit):
+        return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
+
+    async def _handle_request(self, request, stream_name):
+        stream = self.streams.get(stream_name)
+        if stream is None:
+            raise SynapseError(400, "Unknown stream")
+
+        from_token = parse_integer(request, "from_token", required=True)
+        upto_token = parse_integer(request, "upto_token", required=True)
+        limit = parse_integer(request, "limit", required=True)
+
+        updates, upto_token, limited = await stream.get_updates_since(
+            from_token, upto_token, limit
+        )
+
+        return (
+            200,
+            {"updates": updates, "upto_token": upto_token, "limited": limited},
+        )
+
+
+def register_servlets(hs, http_server):
+    ReplicationGetStreamUpdates(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f45cbd37a0..751c799d94 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,8 +18,10 @@ from typing import Dict, Optional
 
 import six
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+    CURRENT_STATE_CACHE_NAME,
+    CacheInvalidationWorkerStore,
+)
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
@@ -35,7 +37,7 @@ def __func__(inp):
         return inp.__func__
 
 
-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos
 
+    def get_cache_stream_token(self):
+        if self._cache_id_gen:
+            return self._cache_id_gen.get_current_token()
+        else:
+            return 0
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "caches":
             if self._cache_id_gen:
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..23b1650e41 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
-            for row in rows:
-                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+            self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
 
-    def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(user_id, token)
-
-        if destination:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
+    def _invalidate_caches_for_devices(self, token, rows):
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate_many((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
-        self.get_cached_devices_for_user.invalidate((user_id,))
-        self._get_cached_user_device.invalidate_many((user_id,))
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index f22c2d44a3..bce8a3d115 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 81c2ea7ee9..523a1358d4 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst
 
 
 Structure of the module:
- * client.py   - the client classes used for workers to connect to master
+ * handler.py  - the classes used to handle sending/receiving commands to
+                 replication
  * command.py  - the definitions of all the valid commands
- * protocol.py - contains bot the client and server protocol implementations,
-                 these should not be used directly
- * resource.py - the server classes that accepts and handle client connections
- * streams.py  - the definitons of all the valid streams
+ * protocol.py - the TCP protocol classes
+ * resource.py - handles streaming stream updates to replications
+ * streams/    - the definitons of all the valid streams
 
+
+The general interaction of the classes are:
+
+        +---------------------+
+        | ReplicationStreamer |
+        +---------------------+
+                    |
+                    v
+        +---------------------------+     +----------------------+
+        | ReplicationCommandHandler |---->|ReplicationDataHandler|
+        +---------------------------+     +----------------------+
+                    | ^
+                    v |
+            +-------------+
+            | Protocols   |
+            | (TCP/redis) |
+            +-------------+
+
+Where the ReplicationDataHandler (or subclasses) handles incoming stream
+updates.
 """
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..700ae79158 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,26 +16,16 @@
 """
 
 import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict
 
-from twisted.internet import defer
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.protocol import (
-    AbstractReplicationClientHandler,
-    ClientReplicationStreamProtocol,
-)
-
-from .commands import (
-    Command,
-    FederationAckCommand,
-    InvalidateCacheCommand,
-    RemoteServerUpCommand,
-    RemovePusherCommand,
-    UserIpCommand,
-    UserSyncCommand,
-)
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
 
 logger = logging.getLogger(__name__)
 
@@ -44,17 +34,22 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     """Factory for building connections to the master. Will reconnect if the
     connection is lost.
 
-    Accepts a handler that will be called when new data is available or data
-    is required.
+    Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
     """
 
     initialDelay = 0.1
     maxDelay = 1  # Try at least once every N seconds
 
-    def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        client_name: str,
+        command_handler: "ReplicationCommandHandler",
+    ):
         self.client_name = client_name
-        self.handler = handler
+        self.command_handler = command_handler
         self.server_name = hs.config.server_name
+        self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +60,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def buildProtocol(self, addr):
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
-            self.client_name, self.server_name, self._clock, self.handler
+            self.hs,
+            self.client_name,
+            self.server_name,
+            self._clock,
+            self.command_handler,
         )
 
     def clientConnectionLost(self, connector, reason):
@@ -77,41 +76,17 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
 
-class ReplicationClientHandler(AbstractReplicationClientHandler):
-    """A base handler that can be passed to the ReplicationClientFactory.
+class ReplicationDataHandler:
+    """Handles incoming stream updates from replication.
 
-    By default proxies incoming replication data to the SlaveStore.
+    This instance notifies the slave data store about updates. Can be subclassed
+    to handle updates in additional ways.
     """
 
     def __init__(self, store: BaseSlavedStore):
         self.store = store
 
-        # The current connection. None if we are currently (re)connecting
-        self.connection = None
-
-        # Any pending commands to be sent once a new connection has been
-        # established
-        self.pending_commands = []  # type: List[Command]
-
-        # Map from string -> deferred, to wake up when receiveing a SYNC with
-        # the given string.
-        # Used for tests.
-        self.awaiting_syncs = {}  # type: Dict[str, defer.Deferred]
-
-        # The factory used to create connections.
-        self.factory = None  # type: Optional[ReplicationClientFactory]
-
-    def start_replication(self, hs):
-        """Helper method to start a replication connection to the remote server
-        using TCP.
-        """
-        client_name = hs.config.worker_name
-        self.factory = ReplicationClientFactory(hs, client_name, self)
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_port
-        hs.get_reactor().connectTCP(host, port, self.factory)
-
-    async def on_rdata(self, stream_name, token, rows):
+    async def on_rdata(self, stream_name: str, token: int, rows: list):
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
@@ -123,30 +98,8 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
             rows (list): a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
         self.store.process_replication_rows(stream_name, token, rows)
 
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data. By default this just pokes
-        the slave store.
-
-        Can be overriden in subclasses to handle more.
-        """
-        self.store.process_replication_rows(stream_name, token, [])
-
-    def on_sync(self, data):
-        """When we received a SYNC we wake up any deferreds that were waiting
-        for the sync with the given data.
-
-        Used by tests.
-        """
-        d = self.awaiting_syncs.pop(data, None)
-        if d:
-            d.callback(data)
-
-    def on_remote_server_up(self, server: str):
-        """Called when get a new REMOTE_SERVER_UP command."""
-
     def get_streams_to_replicate(self) -> Dict[str, int]:
         """Called when a new connection has been established and we need to
         subscribe to streams.
@@ -162,83 +115,10 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
             args["account_data"] = user_account_data
         elif room_account_data:
             args["account_data"] = room_account_data
-
         return args
 
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users. (Overriden by the synchrotron's only)
-        """
-        return []
-
-    def send_command(self, cmd):
-        """Send a command to master (when we get establish a connection if we
-        don't have one already.)
-        """
-        if self.connection:
-            self.connection.send_command(cmd)
-        else:
-            logger.warning("Queuing command as not connected: %r", cmd.NAME)
-            self.pending_commands.append(cmd)
-
-    def send_federation_ack(self, token):
-        """Ack data for the federation stream. This allows the master to drop
-        data stored purely in memory.
-        """
-        self.send_command(FederationAckCommand(token))
-
-    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
-        """Poke the master that a user has started/stopped syncing.
-        """
-        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
-
-    def send_remove_pusher(self, app_id, push_key, user_id):
-        """Poke the master to remove a pusher for a user
-        """
-        cmd = RemovePusherCommand(app_id, push_key, user_id)
-        self.send_command(cmd)
-
-    def send_invalidate_cache(self, cache_func, keys):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
-    def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
-        """Tell the master that the user made a request.
-        """
-        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
-        self.send_command(cmd)
-
-    def send_remote_server_up(self, server: str):
-        self.send_command(RemoteServerUpCommand(server))
-
-    def await_sync(self, data):
-        """Returns a deferred that is resolved when we receive a SYNC command
-        with given data.
-
-        [Not currently] used by tests.
-        """
-        return self.awaiting_syncs.setdefault(data, defer.Deferred())
-
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        self.connection = connection
-        if connection:
-            for cmd in self.pending_commands:
-                connection.send_command(cmd)
-            self.pending_commands = []
-
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        logger.info("Finished connecting to server")
+    async def on_position(self, stream_name: str, token: int):
+        self.store.process_replication_rows(stream_name, token, [])
 
-        # We don't reset the delay any earlier as otherwise if there is a
-        # problem during start up we'll end up tight looping connecting to the
-        # server.
-        if self.factory:
-            self.factory.resetDelay()
+    def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d..5ec89d0fb8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -17,7 +17,7 @@
 The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
 allowed to be sent by which side.
 """
-
+import abc
 import logging
 import platform
 from typing import Tuple, Type
@@ -34,34 +34,29 @@ else:
 logger = logging.getLogger(__name__)
 
 
-class Command(object):
+class Command(metaclass=abc.ABCMeta):
     """The base command class.
 
     All subclasses must set the NAME variable which equates to the name of the
     command on the wire.
 
     A full command line on the wire is constructed from `NAME + " " + to_line()`
-
-    The default implementation creates a command of form `<NAME> <data>`
     """
 
     NAME = None  # type: str
 
-    def __init__(self, data):
-        self.data = data
-
     @classmethod
+    @abc.abstractmethod
     def from_line(cls, line):
         """Deserialises a line from the wire into this command. `line` does not
         include the command.
         """
-        return cls(line)
 
-    def to_line(self):
+    @abc.abstractmethod
+    def to_line(self) -> str:
         """Serialises the comamnd for the wire. Does not include the command
         prefix.
         """
-        return self.data
 
     def get_logcontext_id(self):
         """Get a suitable string for the logcontext when processing this command"""
@@ -70,7 +65,21 @@ class Command(object):
         return self.NAME
 
 
-class ServerCommand(Command):
+class _SimpleCommand(Command):
+    """An implementation of Command whose argument is just a 'data' string."""
+
+    def __init__(self, data):
+        self.data = data
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self) -> str:
+        return self.data
+
+
+class ServerCommand(_SimpleCommand):
     """Sent by the server on new connection and includes the server_name.
 
     Format::
@@ -136,8 +145,8 @@ class PositionCommand(Command):
     """Sent by the server to tell the client the stream postition without
     needing to send an RDATA.
 
-    Sent to the client after all missing updates for a stream have been sent
-    to the client and they're now up to date.
+    On receipt of a POSITION command clients should check if they have missed
+    any updates, and if so then fetch them out of band.
     """
 
     NAME = "POSITION"
@@ -155,7 +164,7 @@ class PositionCommand(Command):
         return " ".join((self.stream_name, str(self.token)))
 
 
-class ErrorCommand(Command):
+class ErrorCommand(_SimpleCommand):
     """Sent by either side if there was an ERROR. The data is a string describing
     the error.
     """
@@ -163,14 +172,14 @@ class ErrorCommand(Command):
     NAME = "ERROR"
 
 
-class PingCommand(Command):
+class PingCommand(_SimpleCommand):
     """Sent by either side as a keep alive. The data is arbitary (often timestamp)
     """
 
     NAME = "PING"
 
 
-class NameCommand(Command):
+class NameCommand(_SimpleCommand):
     """Sent by client to inform the server of the client's identity. The data
     is the name
     """
@@ -179,42 +188,24 @@ class NameCommand(Command):
 
 
 class ReplicateCommand(Command):
-    """Sent by the client to subscribe to the stream.
+    """Sent by the client to subscribe to streams.
 
     Format::
 
-        REPLICATE <stream_name> <token>
-
-    Where <token> may be either:
-        * a numeric stream_id to stream updates from
-        * "NOW" to stream all subsequent updates.
-
-    The <stream_name> can be "ALL" to subscribe to all known streams, in which
-    case the <token> must be set to "NOW", i.e.::
-
-        REPLICATE ALL NOW
+        REPLICATE
     """
 
     NAME = "REPLICATE"
 
-    def __init__(self, stream_name, token):
-        self.stream_name = stream_name
-        self.token = token
+    def __init__(self):
+        pass
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        if token in ("NOW", "now"):
-            token = "NOW"
-        else:
-            token = int(token)
-        return cls(stream_name, token)
+        return cls()
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
-
-    def get_logcontext_id(self):
-        return "REPLICATE-" + self.stream_name
+        return ""
 
 
 class UserSyncCommand(Command):
@@ -225,30 +216,32 @@ class UserSyncCommand(Command):
 
     Format::
 
-        USER_SYNC <user_id> <state> <last_sync_ms>
+        USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
 
     Where <state> is either "start" or "stop"
     """
 
     NAME = "USER_SYNC"
 
-    def __init__(self, user_id, is_syncing, last_sync_ms):
+    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+        self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
     def from_line(cls, line):
-        user_id, state, last_sync_ms = line.split(" ", 2)
+        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
 
         if state not in ("start", "end"):
             raise Exception("Invalid USER_SYNC state %r" % (state,))
 
-        return cls(user_id, state == "start", int(last_sync_ms))
+        return cls(instance_id, user_id, state == "start", int(last_sync_ms))
 
     def to_line(self):
         return " ".join(
             (
+                self.instance_id,
                 self.user_id,
                 "start" if self.is_syncing else "end",
                 str(self.last_sync_ms),
@@ -256,6 +249,30 @@ class UserSyncCommand(Command):
         )
 
 
+class ClearUserSyncsCommand(Command):
+    """Sent by the client to inform the server that it should drop all
+    information about syncing users sent by the client.
+
+    Mainly used when client is about to shut down.
+
+    Format::
+
+        CLEAR_USER_SYNC <instance_id>
+    """
+
+    NAME = "CLEAR_USER_SYNC"
+
+    def __init__(self, instance_id):
+        self.instance_id = instance_id
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self):
+        return self.instance_id
+
+
 class FederationAckCommand(Command):
     """Sent by the client when it has processed up to a given point in the
     federation stream. This allows the master to drop in-memory caches of the
@@ -281,14 +298,6 @@ class FederationAckCommand(Command):
         return str(self.token)
 
 
-class SyncCommand(Command):
-    """Used for testing. The client protocol implementation allows waiting
-    on a SYNC command with a specified data.
-    """
-
-    NAME = "SYNC"
-
-
 class RemovePusherCommand(Command):
     """Sent by the client to request the master remove the given pusher.
 
@@ -387,7 +396,7 @@ class UserIpCommand(Command):
         )
 
 
-class RemoteServerUpCommand(Command):
+class RemoteServerUpCommand(_SimpleCommand):
     """Sent when a worker has detected that a remote server is no longer
     "down" and retry timings should be reset.
 
@@ -411,11 +420,11 @@ _COMMANDS = (
     ReplicateCommand,
     UserSyncCommand,
     FederationAckCommand,
-    SyncCommand,
     RemovePusherCommand,
     InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
+    ClearUserSyncsCommand,
 )  # type: Tuple[Type[Command], ...]
 
 # Map of command name to command type.
@@ -428,7 +437,6 @@ VALID_SERVER_COMMANDS = (
     PositionCommand.NAME,
     ErrorCommand.NAME,
     PingCommand.NAME,
-    SyncCommand.NAME,
     RemoteServerUpCommand.NAME,
 )
 
@@ -438,6 +446,7 @@ VALID_CLIENT_COMMANDS = (
     ReplicateCommand.NAME,
     PingCommand.NAME,
     UserSyncCommand.NAME,
+    ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
     InvalidateCacheCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
new file mode 100644
index 0000000000..e32e68e8c4
--- /dev/null
+++ b/synapse/replication/tcp/handler.py
@@ -0,0 +1,469 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
+
+from prometheus_client import Counter
+
+from synapse.metrics import LaterGauge
+from synapse.replication.tcp.client import ReplicationClientFactory
+from synapse.replication.tcp.commands import (
+    ClearUserSyncsCommand,
+    Command,
+    FederationAckCommand,
+    InvalidateCacheCommand,
+    PositionCommand,
+    RdataCommand,
+    RemoteServerUpCommand,
+    RemovePusherCommand,
+    ReplicateCommand,
+    UserIpCommand,
+    UserSyncCommand,
+)
+from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.util.async_helpers import Linearizer
+
+logger = logging.getLogger(__name__)
+
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = Counter(
+    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter(
+    "synapse_replication_tcp_resource_invalidate_cache", ""
+)
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+
+
+class ReplicationCommandHandler:
+    """Handles incoming commands from replication as well as sending commands
+    back out to connections.
+    """
+
+    def __init__(self, hs):
+        self._replication_data_handler = hs.get_replication_data_handler()
+        self._presence_handler = hs.get_presence_handler()
+        self._store = hs.get_datastore()
+        self._notifier = hs.get_notifier()
+        self._clock = hs.get_clock()
+        self._instance_id = hs.get_instance_id()
+
+        # Set of streams that we've caught up with.
+        self._streams_connected = set()  # type: Set[str]
+
+        self._streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
+        self._position_linearizer = Linearizer("replication_position")
+
+        # Map of stream to batched updates. See RdataCommand for info on how
+        # batching works.
+        self._pending_batches = {}  # type: Dict[str, List[Any]]
+
+        # The factory used to create connections.
+        self._factory = None  # type: Optional[ReplicationClientFactory]
+
+        # The currently connected connections.
+        self._connections = []  # type: List[AbstractConnection]
+
+        LaterGauge(
+            "synapse_replication_tcp_resource_total_connections",
+            "",
+            [],
+            lambda: len(self._connections),
+        )
+
+        self._is_master = hs.config.worker_app is None
+
+        self._federation_sender = None
+        if self._is_master and not hs.config.send_federation:
+            self._federation_sender = hs.get_federation_sender()
+
+        self._server_notices_sender = None
+        if self._is_master:
+            self._server_notices_sender = hs.get_server_notices_sender()
+            self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
+
+    def start_replication(self, hs):
+        """Helper method to start a replication connection to the remote server
+        using TCP.
+        """
+        client_name = hs.config.worker_name
+        self._factory = ReplicationClientFactory(hs, client_name, self)
+        host = hs.config.worker_replication_host
+        port = hs.config.worker_replication_port
+        hs.get_reactor().connectTCP(host, port, self._factory)
+
+    async def on_REPLICATE(self, cmd: ReplicateCommand):
+        # We only want to announce positions by the writer of the streams.
+        # Currently this is just the master process.
+        if not self._is_master:
+            return
+
+        for stream_name, stream in self._streams.items():
+            current_token = stream.current_token()
+            self.send_command(PositionCommand(stream_name, current_token))
+
+    async def on_USER_SYNC(self, cmd: UserSyncCommand):
+        user_sync_counter.inc()
+
+        if self._is_master:
+            await self._presence_handler.update_external_syncs_row(
+                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+            )
+
+    async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
+        if self._is_master:
+            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+
+    async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
+        federation_ack_counter.inc()
+
+        if self._federation_sender:
+            self._federation_sender.federation_ack(cmd.token)
+
+    async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
+        remove_pusher_counter.inc()
+
+        if self._is_master:
+            await self._store.delete_pusher_by_app_id_pushkey_user_id(
+                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+            )
+
+            self._notifier.on_new_replication_data()
+
+    async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
+        invalidate_cache_counter.inc()
+
+        if self._is_master:
+            # We invalidate the cache locally, but then also stream that to other
+            # workers.
+            await self._store.invalidate_cache_and_stream(
+                cmd.cache_func, tuple(cmd.keys)
+            )
+
+    async def on_USER_IP(self, cmd: UserIpCommand):
+        user_ip_cache_counter.inc()
+
+        if self._is_master:
+            await self._store.insert_client_ip(
+                cmd.user_id,
+                cmd.access_token,
+                cmd.ip,
+                cmd.user_agent,
+                cmd.device_id,
+                cmd.last_seen,
+            )
+
+        if self._server_notices_sender:
+            await self._server_notices_sender.on_user_ip(cmd.user_id)
+
+    async def on_RDATA(self, cmd: RdataCommand):
+        stream_name = cmd.stream_name
+        inbound_rdata_count.labels(stream_name).inc()
+
+        try:
+            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+        except Exception:
+            logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
+            raise
+
+        # We linearize here for two reasons:
+        #   1. so we don't try and concurrently handle multiple rows for the
+        #      same stream, and
+        #   2. so we don't race with getting a POSITION command and fetching
+        #      missing RDATA.
+        with await self._position_linearizer.queue(cmd.stream_name):
+            if stream_name not in self._streams_connected:
+                # If the stream isn't marked as connected then we haven't seen a
+                # `POSITION` command yet, and so we may have missed some rows.
+                # Let's drop the row for now, on the assumption we'll receive a
+                # `POSITION` soon and we'll catch up correctly then.
+                logger.warning(
+                    "Discarding RDATA for unconnected stream %s -> %s",
+                    stream_name,
+                    cmd.token,
+                )
+                return
+
+            if cmd.token is None:
+                # I.e. this is part of a batch of updates for this stream (in
+                # which case batch until we get an update for the stream with a non
+                # None token).
+                self._pending_batches.setdefault(stream_name, []).append(row)
+            else:
+                # Check if this is the last of a batch of updates
+                rows = self._pending_batches.pop(stream_name, [])
+                rows.append(row)
+                await self.on_rdata(stream_name, cmd.token, rows)
+
+    async def on_rdata(self, stream_name: str, token: int, rows: list):
+        """Called to handle a batch of replication data with a given stream token.
+
+        Args:
+            stream_name: name of the replication stream for this batch of rows
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        logger.debug("Received rdata %s -> %s", stream_name, token)
+        await self._replication_data_handler.on_rdata(stream_name, token, rows)
+
+    async def on_POSITION(self, cmd: PositionCommand):
+        stream = self._streams.get(cmd.stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            return
+
+        # We protect catching up with a linearizer in case the replication
+        # connection reconnects under us.
+        with await self._position_linearizer.queue(cmd.stream_name):
+            # We're about to go and catch up with the stream, so remove from set
+            # of connected streams.
+            self._streams_connected.discard(cmd.stream_name)
+
+            # We clear the pending batches for the stream as the fetching of the
+            # missing updates below will fetch all rows in the batch.
+            self._pending_batches.pop(cmd.stream_name, [])
+
+            # Find where we previously streamed up to.
+            current_token = self._replication_data_handler.get_streams_to_replicate().get(
+                cmd.stream_name
+            )
+            if current_token is None:
+                logger.warning(
+                    "Got POSITION for stream we're not subscribed to: %s",
+                    cmd.stream_name,
+                )
+                return
+
+            # If the position token matches our current token then we're up to
+            # date and there's nothing to do. Otherwise, fetch all updates
+            # between then and now.
+            missing_updates = cmd.token != current_token
+            while missing_updates:
+                (
+                    updates,
+                    current_token,
+                    missing_updates,
+                ) = await stream.get_updates_since(current_token, cmd.token)
+
+                # TODO: add some tests for this
+
+                # Some streams return multiple rows with the same stream IDs,
+                # which need to be processed in batches.
+
+                for token, rows in _batch_updates(updates):
+                    await self.on_rdata(
+                        cmd.stream_name, token, [stream.parse_row(row) for row in rows],
+                    )
+
+            # We've now caught up to position sent to us, notify handler.
+            await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
+
+            self._streams_connected.add(cmd.stream_name)
+
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        """"Called when get a new REMOTE_SERVER_UP command."""
+        self._replication_data_handler.on_remote_server_up(cmd.data)
+
+        if self._is_master:
+            self._notifier.notify_remote_server_up(cmd.data)
+
+    def get_currently_syncing_users(self):
+        """Get the list of currently syncing users (if any). This is called
+        when a connection has been established and we need to send the
+        currently syncing users.
+        """
+        return self._presence_handler.get_currently_syncing_users()
+
+    def new_connection(self, connection: AbstractConnection):
+        """Called when we have a new connection.
+        """
+        self._connections.append(connection)
+
+        # If we are connected to replication as a client (rather than a server)
+        # we need to reset the reconnection delay on the client factory (which
+        # is used to do exponential back off when the connection drops).
+        #
+        # Ideally we would reset the delay when we've "fully established" the
+        # connection (for some definition thereof) to stop us from tightlooping
+        # on reconnection if something fails after this point and we drop the
+        # connection. Unfortunately, we don't really have a better definition of
+        # "fully established" than the connection being established.
+        if self._factory:
+            self._factory.resetDelay()
+
+        # Tell the server if we have any users currently syncing (should only
+        # happen on synchrotrons)
+        currently_syncing = self.get_currently_syncing_users()
+        now = self._clock.time_msec()
+        for user_id in currently_syncing:
+            connection.send_command(
+                UserSyncCommand(self._instance_id, user_id, True, now)
+            )
+
+    def lost_connection(self, connection: AbstractConnection):
+        """Called when a connection is closed/lost.
+        """
+        try:
+            self._connections.remove(connection)
+        except ValueError:
+            pass
+
+    def connected(self) -> bool:
+        """Do we have any replication connections open?
+
+        Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
+        """
+        return bool(self._connections)
+
+    def send_command(self, cmd: Command):
+        """Send a command to all connected connections.
+        """
+        if self._connections:
+            for connection in self._connections:
+                try:
+                    connection.send_command(cmd)
+                except Exception:
+                    # We probably want to catch some types of exceptions here
+                    # and log them as warnings (e.g. connection gone), but I
+                    # can't find what those exception types they would be.
+                    logger.exception(
+                        "Failed to write command %s to connection %s",
+                        cmd.NAME,
+                        connection,
+                    )
+        else:
+            logger.warning("Dropping command as not connected: %r", cmd.NAME)
+
+    def send_federation_ack(self, token: int):
+        """Ack data for the federation stream. This allows the master to drop
+        data stored purely in memory.
+        """
+        self.send_command(FederationAckCommand(token))
+
+    def send_user_sync(
+        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+    ):
+        """Poke the master that a user has started/stopped syncing.
+        """
+        self.send_command(
+            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+        )
+
+    def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
+        """Poke the master to remove a pusher for a user
+        """
+        cmd = RemovePusherCommand(app_id, push_key, user_id)
+        self.send_command(cmd)
+
+    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
+        """Poke the master to invalidate a cache.
+        """
+        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
+        self.send_command(cmd)
+
+    def send_user_ip(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: str,
+        last_seen: int,
+    ):
+        """Tell the master that the user made a request.
+        """
+        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+        self.send_command(cmd)
+
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
+    def stream_update(self, stream_name: str, token: str, data: Any):
+        """Called when a new update is available to stream to clients.
+
+        We need to check if the client is interested in the stream or not
+        """
+        self.send_command(RdataCommand(stream_name, token, data))
+
+
+UpdateToken = TypeVar("UpdateToken")
+UpdateRow = TypeVar("UpdateRow")
+
+
+def _batch_updates(
+    updates: Iterable[Tuple[UpdateToken, UpdateRow]]
+) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
+    """Collect stream updates with the same token together
+
+    Given a series of updates returned by Stream.get_updates_since(), collects
+    the updates which share the same stream_id together.
+
+    For example:
+
+        [(1, a), (1, b), (2, c), (3, d), (3, e)]
+
+    becomes:
+
+        [
+            (1, [a, b]),
+            (2, [c]),
+            (3, [d, e]),
+        ]
+    """
+
+    update_iter = iter(updates)
+
+    first_update = next(update_iter, None)
+    if first_update is None:
+        # empty input
+        return
+
+    current_batch_token = first_update[0]
+    current_batch = [first_update[1]]
+
+    for token, row in update_iter:
+        if token != current_batch_token:
+            # different token to the previous row: flush the previous
+            # batch and start anew
+            yield current_batch_token, current_batch
+            current_batch_token = token
+            current_batch = []
+
+        current_batch.append(row)
+
+    # flush the final batch
+    yield current_batch_token, current_batch
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..9276ed2965 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
     > POSITION events 1
     > POSITION backfill 1
     > POSITION caches 1
@@ -53,17 +51,15 @@ import fcntl
 import logging
 import struct
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import TYPE_CHECKING, DefaultDict, List
 
-from six import iteritems, iterkeys
+from six import iteritems
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
@@ -74,19 +70,18 @@ from synapse.replication.tcp.commands import (
     ErrorCommand,
     NameCommand,
     PingCommand,
-    PositionCommand,
-    RdataCommand,
-    RemoteServerUpCommand,
     ReplicateCommand,
     ServerCommand,
-    SyncCommand,
-    UserSyncCommand,
 )
-from synapse.replication.tcp.streams import STREAMS_MAP
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
+if TYPE_CHECKING:
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
+    from synapse.server import HomeServer
+
+
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
@@ -119,7 +114,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     are only sent by the server.
 
     On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
-    command.
+    command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
 
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
@@ -135,8 +130,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     max_line_buffer = 10000
 
-    def __init__(self, clock):
+    def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
         self.clock = clock
+        self.command_handler = handler
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
@@ -176,6 +172,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # can time us out.
         self.send_command(PingCommand(self.clock.time_msec()))
 
+        self.command_handler.new_connection(self)
+
     def send_ping(self):
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
@@ -203,15 +201,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 )
                 self.send_error("ping timeout")
 
-    def lineReceived(self, line):
+    def lineReceived(self, line: bytes):
         """Called when we've received a line
         """
         if line.strip() == "":
             # Ignore blank lines
             return
 
-        line = line.decode("utf-8")
-        cmd_name, rest_of_line = line.split(" ", 1)
+        linestr = line.decode("utf-8")
+
+        # split at the first " ", handling one-word commands
+        idx = linestr.index(" ")
+        if idx >= 0:
+            cmd_name = linestr[:idx]
+            rest_of_line = linestr[idx + 1 :]
+        else:
+            cmd_name = linestr
+            rest_of_line = ""
 
         if cmd_name not in self.VALID_INBOUND_COMMANDS:
             logger.error("[%s] invalid command %s", self.id(), cmd_name)
@@ -244,13 +250,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     async def handle_command(self, cmd: Command):
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>, which should return an awaitable.
+        First calls `self.on_<COMMAND>` if it exists, then calls
+        `self.command_handler.on_<COMMAND>` if it exists. This allows for
+        protocol level handling of commands (e.g. PINGs), before delegating to
+        the handler.
 
         Args:
             cmd: received command
         """
-        handler = getattr(self, "on_%s" % (cmd.NAME,))
-        await handler(cmd)
+        handled = False
+
+        # First call any command handlers on this instance. These are for TCP
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
 
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
@@ -379,6 +403,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.CLOSED
         self.pending_commands = []
 
+        self.command_handler.lost_connection(self)
+
         if self.transport:
             self.transport.unregisterProducer()
 
@@ -405,232 +431,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
     VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
 
-    def __init__(self, server_name, clock, streamer):
-        BaseReplicationStreamProtocol.__init__(self, clock)  # Old style class
+    def __init__(
+        self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
+    ):
+        super().__init__(clock, handler)
 
         self.server_name = server_name
-        self.streamer = streamer
-
-        # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()  # type: Set[str]
-
-        # The streams the client is currently subscribing to.
-        self.connecting_streams = set()  # type:  Set[str]
-
-        # Map from stream name to list of updates to send once we've finished
-        # subscribing the client to the stream.
-        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
 
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
-        BaseReplicationStreamProtocol.connectionMade(self)
-        self.streamer.new_connection(self)
+        super().connectionMade()
 
     async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
-    async def on_USER_SYNC(self, cmd):
-        await self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
-        )
-
-    async def on_REPLICATE(self, cmd):
-        stream_name = cmd.stream_name
-        token = cmd.token
-
-        if stream_name == "ALL":
-            # Subscribe to all streams we're publishing to.
-            deferreds = [
-                run_in_background(self.subscribe_to_stream, stream, token)
-                for stream in iterkeys(self.streamer.streams_by_name)
-            ]
-
-            await make_deferred_yieldable(
-                defer.gatherResults(deferreds, consumeErrors=True)
-            )
-        else:
-            await self.subscribe_to_stream(stream_name, token)
-
-    async def on_FEDERATION_ACK(self, cmd):
-        self.streamer.federation_ack(cmd.token)
-
-    async def on_REMOVE_PUSHER(self, cmd):
-        await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
-    async def on_INVALIDATE_CACHE(self, cmd):
-        await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.streamer.on_remote_server_up(cmd.data)
-
-    async def on_USER_IP(self, cmd):
-        self.streamer.on_user_ip(
-            cmd.user_id,
-            cmd.access_token,
-            cmd.ip,
-            cmd.user_agent,
-            cmd.device_id,
-            cmd.last_seen,
-        )
-
-    async def subscribe_to_stream(self, stream_name, token):
-        """Subscribe the remote to a stream.
-
-        This invloves checking if they've missed anything and sending those
-        updates down if they have. During that time new updates for the stream
-        are queued and sent once we've sent down any missed updates.
-        """
-        self.replication_streams.discard(stream_name)
-        self.connecting_streams.add(stream_name)
-
-        try:
-            # Get missing updates
-            updates, current_token = await self.streamer.get_stream_updates(
-                stream_name, token
-            )
-
-            # Send all the missing updates
-            for update in updates:
-                token, row = update[0], update[1]
-                self.send_command(RdataCommand(stream_name, token, row))
-
-            # We send a POSITION command to ensure that they have an up to
-            # date token (especially useful if we didn't send any updates
-            # above)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-            # Now we can send any updates that came in while we were subscribing
-            pending_rdata = self.pending_rdata.pop(stream_name, [])
-            updates = []
-            for token, update in pending_rdata:
-                # If the token is null, it is part of a batch update. Batches
-                # are multiple updates that share a single token. To denote
-                # this, the token is set to None for all tokens in the batch
-                # except for the last. If we find a None token, we keep looking
-                # through tokens until we find one that is not None and then
-                # process all previous updates in the batch as if they had the
-                # final token.
-                if token is None:
-                    # Store this update as part of a batch
-                    updates.append(update)
-                    continue
-
-                if token <= current_token:
-                    # This update or batch of updates is older than
-                    # current_token, dismiss it
-                    updates = []
-                    continue
-
-                updates.append(update)
-
-                # Send all updates that are part of this batch with the
-                # found token
-                for update in updates:
-                    self.send_command(RdataCommand(stream_name, token, update))
-
-                # Clear stored updates
-                updates = []
-
-            # They're now fully subscribed
-            self.replication_streams.add(stream_name)
-        except Exception as e:
-            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-            self.send_error("failed to handle replicate: %r", e)
-        finally:
-            self.connecting_streams.discard(stream_name)
-
-    def stream_update(self, stream_name, token, data):
-        """Called when a new update is available to stream to clients.
-
-        We need to check if the client is interested in the stream or not
-        """
-        if stream_name in self.replication_streams:
-            # The client is subscribed to the stream
-            self.send_command(RdataCommand(stream_name, token, data))
-        elif stream_name in self.connecting_streams:
-            # The client is being subscribed to the stream
-            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-            self.pending_rdata.setdefault(stream_name, []).append((token, data))
-        else:
-            # The client isn't subscribed
-            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
-
-    def send_sync(self, data):
-        self.send_command(SyncCommand(data))
-
-    def send_remote_server_up(self, server: str):
-        self.send_command(RemoteServerUpCommand(server))
-
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.streamer.lost_connection(self)
-
-
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
-    """
-    The interface for the handler that should be passed to
-    ClientReplicationStreamProtocol
-    """
-
-    @abc.abstractmethod
-    async def on_rdata(self, stream_name, token, rows):
-        """Called to handle a batch of replication data with a given stream token.
-
-        Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def on_sync(self, data):
-        """Called when get a new SYNC command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_remote_server_up(self, server: str):
-        """Called when get a new REMOTE_SERVER_UP command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_streams_to_replicate(self):
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        raise NotImplementedError()
-
 
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
@@ -638,110 +453,51 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def __init__(
         self,
+        hs: "HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,
-        handler: AbstractReplicationClientHandler,
+        command_handler: "ReplicationCommandHandler",
     ):
-        BaseReplicationStreamProtocol.__init__(self, clock)
+        super().__init__(clock, command_handler)
 
         self.client_name = client_name
         self.server_name = server_name
-        self.handler = handler
-
-        # Set of stream names that have been subscribe to, but haven't yet
-        # caught up with. This is used to track when the client has been fully
-        # connected to the remote.
-        self.streams_connecting = set()  # type: Set[str]
-
-        # Map of stream to batched updates. See RdataCommand for info on how
-        # batching works.
-        self.pending_batches = {}  # type: Dict[str, Any]
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
-        BaseReplicationStreamProtocol.connectionMade(self)
+        super().connectionMade()
 
         # Once we've connected subscribe to the necessary streams
-        for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
-            self.replicate(stream_name, token)
-
-        # Tell the server if we have any users currently syncing (should only
-        # happen on synchrotrons)
-        currently_syncing = self.handler.get_currently_syncing_users()
-        now = self.clock.time_msec()
-        for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(user_id, True, now))
-
-        # We've now finished connecting to so inform the client handler
-        self.handler.update_connection(self)
-
-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
+        self.replicate()
 
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    async def on_RDATA(self, cmd):
-        stream_name = cmd.stream_name
-        inbound_rdata_count.labels(stream_name).inc()
-
-        try:
-            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
-        except Exception:
-            logger.exception(
-                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
-            )
-            raise
-
-        if cmd.token is None:
-            # I.e. this is part of a batch of updates for this stream. Batch
-            # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(stream_name, [])
-            rows.append(row)
-            await self.handler.on_rdata(stream_name, cmd.token, rows)
-
-    async def on_POSITION(self, cmd):
-        # When we get a `POSITION` command it means we've finished getting
-        # missing updates for the given stream, and are now up to date.
-        self.streams_connecting.discard(cmd.stream_name)
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
+    def replicate(self):
+        """Send the subscription request to the server
+        """
+        logger.info("[%s] Subscribing to replication streams", self.id())
 
-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        self.send_command(ReplicateCommand())
 
-    async def on_SYNC(self, cmd):
-        self.handler.on_sync(cmd.data)
 
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.handler.on_remote_server_up(cmd.data)
+class AbstractConnection(abc.ABC):
+    """An interface for replication connections.
+    """
 
-    def replicate(self, stream_name, token):
-        """Send the subscription request to the server
+    @abc.abstractmethod
+    def send_command(self, cmd: Command):
+        """Send the command down the connection
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r from %r",
-            self.id(),
-            stream_name,
-            token,
-        )
-
-        self.streams_connecting.add(stream_name)
+        pass
 
-        self.send_command(ReplicateCommand(stream_name, token))
 
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.handler.update_connection(None)
+# This tells python that `BaseReplicationStreamProtocol` implements the
+# interface.
+AbstractConnection.register(BaseReplicationStreamProtocol)
 
 
 # The following simply registers metrics for the replication connections
@@ -827,8 +583,3 @@ tcp_outbound_commands = LaterGauge(
         for k, count in iteritems(p.outbound_commands_counter)
     },
 )
-
-# number of updates received for each RDATA stream
-inbound_rdata_count = Counter(
-    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
-)
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce9d1fae12..b2d6baa2a2 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
 
 import logging
 import random
-from typing import Any, List
+from typing import Dict
 
 from six import itervalues
 
@@ -25,24 +25,14 @@ from prometheus_client import Counter
 
 from twisted.internet.protocol import Factory
 
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.metrics import Measure, measure_func
-
-from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
-from .streams.federation import FederationStream
+from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
+from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
     "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
 )
-user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
-federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
-remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter(
-    "synapse_replication_tcp_resource_invalidate_cache", ""
-)
-user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 logger = logging.getLogger(__name__)
 
@@ -52,13 +42,23 @@ class ReplicationStreamProtocolFactory(Factory):
     """
 
     def __init__(self, hs):
-        self.streamer = ReplicationStreamer(hs)
+        self.command_handler = hs.get_tcp_replication()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name
 
+        # If we've created a `ReplicationStreamProtocolFactory` then we're
+        # almost certainly registering a replication listener, so let's ensure
+        # that we've started a `ReplicationStreamer` instance to actually push
+        # data.
+        #
+        # (This is a bit of a weird place to do this, but the alternatives such
+        # as putting this in `HomeServer.setup()`, requires either passing the
+        # listener config again or always starting a `ReplicationStreamer`.)
+        hs.get_replication_streamer()
+
     def buildProtocol(self, addr):
         return ServerReplicationStreamProtocol(
-            self.server_name, self.clock, self.streamer
+            self.server_name, self.clock, self.command_handler
         )
 
 
@@ -78,16 +78,6 @@ class ReplicationStreamer(object):
 
         self._replication_torture_level = hs.config.replication_torture_level
 
-        # Current connections.
-        self.connections = []  # type: List[ServerReplicationStreamProtocol]
-
-        LaterGauge(
-            "synapse_replication_tcp_resource_total_connections",
-            "",
-            [],
-            lambda: len(self.connections),
-        )
-
         # List of streams that clients can subscribe to.
         # We only support federation stream if federation sending hase been
         # disabled on the master.
@@ -99,39 +89,22 @@ class ReplicationStreamer(object):
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
 
-        LaterGauge(
-            "synapse_replication_tcp_resource_connections_per_stream",
-            "",
-            ["stream_name"],
-            lambda: {
-                (stream_name,): len(
-                    [
-                        conn
-                        for conn in self.connections
-                        if stream_name in conn.replication_streams
-                    ]
-                )
-                for stream_name in self.streams_by_name
-            },
-        )
-
         self.federation_sender = None
         if not hs.config.send_federation:
             self.federation_sender = hs.get_federation_sender()
 
         self.notifier.add_replication_callback(self.on_notifier_poke)
-        self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
         self.pending_updates = False
 
-        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+        self.command_handler = hs.get_tcp_replication()
 
-    def on_shutdown(self):
-        # close all connections on shutdown
-        for conn in self.connections:
-            conn.send_error("server shutting down")
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a mapp from stream name to stream instance.
+        """
+        return self.streams_by_name
 
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
@@ -140,7 +113,7 @@ class ReplicationStreamer(object):
         This should get called each time new data is available, even if it
         is currently being executed, so that nothing gets missed
         """
-        if not self.connections:
+        if not self.command_handler.connected():
             # Don't bother if nothing is listening. We still need to advance
             # the stream tokens otherwise they'll fall beihind forever
             for stream in self.streams:
@@ -166,11 +139,6 @@ class ReplicationStreamer(object):
                 self.pending_updates = False
 
                 with Measure(self.clock, "repl.stream.get_updates"):
-                    # First we tell the streams that they should update their
-                    # current tokens.
-                    for stream in self.streams:
-                        stream.advance_current_token()
-
                     all_streams = self.streams
 
                     if self._replication_torture_level is not None:
@@ -180,7 +148,7 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.upto_token:
+                        if stream.last_token == stream.current_token():
                             continue
 
                         if self._replication_torture_level:
@@ -192,18 +160,17 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.upto_token,
+                            stream.current_token(),
                         )
                         try:
-                            updates, current_token = await stream.get_updates()
+                            updates, current_token, limited = await stream.get_updates()
+                            self.pending_updates |= limited
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise
 
                         logger.debug(
-                            "Sending %d updates to %d connections",
-                            len(updates),
-                            len(self.connections),
+                            "Sending %d updates", len(updates),
                         )
 
                         if updates:
@@ -219,116 +186,19 @@ class ReplicationStreamer(object):
                         # token. See RdataCommand for more details.
                         batched_updates = _batch_updates(updates)
 
-                        for conn in self.connections:
-                            for token, row in batched_updates:
-                                try:
-                                    conn.stream_update(stream.NAME, token, row)
-                                except Exception:
-                                    logger.exception("Failed to replicate")
+                        for token, row in batched_updates:
+                            try:
+                                self.command_handler.stream_update(
+                                    stream.NAME, token, row
+                                )
+                            except Exception:
+                                logger.exception("Failed to replicate")
 
             logger.debug("No more pending updates, breaking poke loop")
         finally:
             self.pending_updates = False
             self.is_looping = False
 
-    @measure_func("repl.get_stream_updates")
-    async def get_stream_updates(self, stream_name, token):
-        """For a given stream get all updates since token. This is called when
-        a client first subscribes to a stream.
-        """
-        stream = self.streams_by_name.get(stream_name, None)
-        if not stream:
-            raise Exception("unknown stream %s", stream_name)
-
-        return await stream.get_updates_since(token)
-
-    @measure_func("repl.federation_ack")
-    def federation_ack(self, token):
-        """We've received an ack for federation stream from a client.
-        """
-        federation_ack_counter.inc()
-        if self.federation_sender:
-            self.federation_sender.federation_ack(token)
-
-    @measure_func("repl.on_user_sync")
-    async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
-        """A client has started/stopped syncing on a worker.
-        """
-        user_sync_counter.inc()
-        await self.presence_handler.update_external_syncs_row(
-            conn_id, user_id, is_syncing, last_sync_ms
-        )
-
-    @measure_func("repl.on_remove_pusher")
-    async def on_remove_pusher(self, app_id, push_key, user_id):
-        """A client has asked us to remove a pusher
-        """
-        remove_pusher_counter.inc()
-        await self.store.delete_pusher_by_app_id_pushkey_user_id(
-            app_id=app_id, pushkey=push_key, user_id=user_id
-        )
-
-        self.notifier.on_new_replication_data()
-
-    @measure_func("repl.on_invalidate_cache")
-    async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
-        """The client has asked us to invalidate a cache
-        """
-        invalidate_cache_counter.inc()
-
-        # We invalidate the cache locally, but then also stream that to other
-        # workers.
-        await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
-
-    @measure_func("repl.on_user_ip")
-    async def on_user_ip(
-        self, user_id, access_token, ip, user_agent, device_id, last_seen
-    ):
-        """The client saw a user request
-        """
-        user_ip_cache_counter.inc()
-        await self.store.insert_client_ip(
-            user_id, access_token, ip, user_agent, device_id, last_seen
-        )
-        await self._server_notices_sender.on_user_ip(user_id)
-
-    @measure_func("repl.on_remote_server_up")
-    def on_remote_server_up(self, server: str):
-        self.notifier.notify_remote_server_up(server)
-
-    def send_remote_server_up(self, server: str):
-        for conn in self.connections:
-            conn.send_remote_server_up(server)
-
-    def send_sync_to_all_connections(self, data):
-        """Sends a SYNC command to all clients.
-
-        Used in tests.
-        """
-        for conn in self.connections:
-            conn.send_sync(data)
-
-    def new_connection(self, connection):
-        """A new client connection has been established
-        """
-        self.connections.append(connection)
-
-    def lost_connection(self, connection):
-        """A client connection has been lost
-        """
-        try:
-            self.connections.remove(connection)
-        except ValueError:
-            pass
-
-        # We need to tell the presence handler that the connection has been
-        # lost so that it can handle any ongoing syncs on that connection.
-        run_as_background_process(
-            "update_external_syncs_clear",
-            self.presence_handler.update_external_syncs_clear,
-            connection.conn_id,
-        )
-
 
 def _batch_updates(updates):
     """Takes a list of updates of form [(token, row)] and sets the token to
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 5f52264e84..d1a61c3314 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -25,26 +25,63 @@ Each stream is defined by the following information:
     update_function:    The function that returns a list of updates between two tokens
 """
 
-from . import _base, events, federation
+from synapse.replication.tcp.streams._base import (
+    AccountDataStream,
+    BackfillStream,
+    CachesStream,
+    DeviceListsStream,
+    GroupServerStream,
+    PresenceStream,
+    PublicRoomsStream,
+    PushersStream,
+    PushRulesStream,
+    ReceiptsStream,
+    Stream,
+    TagAccountDataStream,
+    ToDeviceStream,
+    TypingStream,
+    UserSignatureStream,
+)
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.federation import FederationStream
 
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
-        events.EventsStream,
-        _base.BackfillStream,
-        _base.PresenceStream,
-        _base.TypingStream,
-        _base.ReceiptsStream,
-        _base.PushRulesStream,
-        _base.PushersStream,
-        _base.CachesStream,
-        _base.PublicRoomsStream,
-        _base.DeviceListsStream,
-        _base.ToDeviceStream,
-        federation.FederationStream,
-        _base.TagAccountDataStream,
-        _base.AccountDataStream,
-        _base.GroupServerStream,
-        _base.UserSignatureStream,
+        EventsStream,
+        BackfillStream,
+        PresenceStream,
+        TypingStream,
+        ReceiptsStream,
+        PushRulesStream,
+        PushersStream,
+        CachesStream,
+        PublicRoomsStream,
+        DeviceListsStream,
+        ToDeviceStream,
+        FederationStream,
+        TagAccountDataStream,
+        AccountDataStream,
+        GroupServerStream,
+        UserSignatureStream,
     )
 }
+
+__all__ = [
+    "STREAMS_MAP",
+    "Stream",
+    "BackfillStream",
+    "PresenceStream",
+    "TypingStream",
+    "ReceiptsStream",
+    "PushRulesStream",
+    "PushersStream",
+    "CachesStream",
+    "PublicRoomsStream",
+    "DeviceListsStream",
+    "ToDeviceStream",
+    "TagAccountDataStream",
+    "AccountDataStream",
+    "GroupServerStream",
+    "UserSignatureStream",
+]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..a860072ccf 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,117 +14,66 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
 
 import attr
 
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
+
 logger = logging.getLogger(__name__)
 
 
 MAX_EVENTS_BEHIND = 500000
 
-BackfillStreamRow = namedtuple(
-    "BackfillStreamRow",
-    (
-        "event_id",  # str
-        "room_id",  # str
-        "type",  # str
-        "state_key",  # str, optional
-        "redacts",  # str, optional
-        "relates_to",  # str, optional
-    ),
-)
-PresenceStreamRow = namedtuple(
-    "PresenceStreamRow",
-    (
-        "user_id",  # str
-        "state",  # str
-        "last_active_ts",  # int
-        "last_federation_update_ts",  # int
-        "last_user_sync_ts",  # int
-        "status_msg",  # str
-        "currently_active",  # bool
-    ),
-)
-TypingStreamRow = namedtuple(
-    "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
-)
-ReceiptsStreamRow = namedtuple(
-    "ReceiptsStreamRow",
-    (
-        "room_id",  # str
-        "receipt_type",  # str
-        "user_id",  # str
-        "event_id",  # str
-        "data",  # dict
-    ),
-)
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
-PushersStreamRow = namedtuple(
-    "PushersStreamRow",
-    ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
-)
-
-
-@attr.s
-class CachesStreamRow:
-    """Stream to inform workers they should invalidate their cache.
-
-    Attributes:
-        cache_func: Name of the cached function.
-        keys: The entry in the cache to invalidate. If None then will
-            invalidate all.
-        invalidation_ts: Timestamp of when the invalidation took place.
-    """
 
-    cache_func = attr.ib(type=str)
-    keys = attr.ib(type=Optional[List[Any]])
-    invalidation_ts = attr.ib(type=int)
-
-
-PublicRoomsStreamRow = namedtuple(
-    "PublicRoomsStreamRow",
-    (
-        "room_id",  # str
-        "visibility",  # str
-        "appservice_id",  # str, optional
-        "network_id",  # str, optional
-    ),
-)
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
-TagAccountDataStreamRow = namedtuple(
-    "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
-)
-AccountDataStreamRow = namedtuple(
-    "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
-)
-GroupsStreamRow = namedtuple(
-    "GroupsStreamRow",
-    ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
-)
-UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = Tuple
+
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+#   * `updates` is a list of `(token, row)` entries.
+#   * `new_last_token` is the new position in stream.
+#   * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
+
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+#  * from_token: the previous stream token: the starting point for fetching the
+#    updates
+#  * to_token: the new stream token: the point to get updates up to
+#  * limit: the maximum number of rows to return
+#
+UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
-    time it was called up until the point `advance_current_token` was called.
+    time it was called.
     """
 
     NAME = None  # type: str  # The name of the stream
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
-    _LIMITED = True  # Whether the update function takes a limit
 
     @classmethod
-    def parse_row(cls, row):
+    def parse_row(cls, row: StreamRow):
         """Parse a row received over replication
 
         By default, assumes that the row data is an array object and passes its contents
@@ -138,101 +87,120 @@ class Stream(object):
         """
         return cls.ROW_TYPE(*row)
 
-    def __init__(self, hs):
-        # The token from which we last asked for updates
-        self.last_token = self.current_token()
+    def __init__(
+        self,
+        current_token_function: Callable[[], Token],
+        update_function: UpdateFunction,
+    ):
+        """Instantiate a Stream
 
-        # The token that we will get updates up to
-        self.upto_token = self.current_token()
+        current_token_function and update_function are callbacks which should be
+        implemented by subclasses.
 
-    def advance_current_token(self):
-        """Updates `upto_token` to "now", which updates up until which point
-        get_updates[_since] will fetch rows till.
+        current_token_function is called to get the current token of the underlying
+        stream.
+
+        update_function is called to get updates for this stream between a pair of
+        stream tokens. See the UpdateFunction type definition for more info.
+
+        Args:
+            current_token_function: callback to get the current token, as above
+            update_function: callback go get stream updates, as above
         """
-        self.upto_token = self.current_token()
+        self.current_token = current_token_function
+        self.update_function = update_function
+
+        # The token from which we last asked for updates
+        self.last_token = self.current_token()
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.upto_token = self.current_token()
-        self.last_token = self.upto_token
+        self.last_token = self.current_token()
 
-    async def get_updates(self):
+    async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
-        since the stream was constructed if it hadn't been called before),
-        until the `upto_token`
+        since the stream was constructed if it hadn't been called before).
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        updates, current_token = await self.get_updates_since(self.last_token)
+        current_token = self.current_token()
+        updates, current_token, limited = await self.get_updates_since(
+            self.last_token, current_token
+        )
         self.last_token = current_token
 
-        return updates, current_token
+        return updates, current_token, limited
 
-    async def get_updates_since(self, from_token):
+    async def get_updates_since(
+        self, from_token: Token, upto_token: Token, limit: int = 100
+    ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        if from_token in ("NOW", "now"):
-            return [], self.upto_token
-
-        current_token = self.upto_token
 
         from_token = int(from_token)
 
-        if from_token == current_token:
-            return [], current_token
+        if from_token == upto_token:
+            return [], upto_token, False
 
-        logger.info("get_updates_since: %s", self.__class__)
-        if self._LIMITED:
-            rows = await self.update_function(
-                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-            )
+        updates, upto_token, limited = await self.update_function(
+            from_token, upto_token, limit,
+        )
+        return updates, upto_token, limited
 
-            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-        else:
-            rows = await self.update_function(from_token, current_token)
 
+def db_query_to_update_function(
+    query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+) -> UpdateFunction:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
         updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) == limit:
+            upto_token = updates[-1][0]
+            limited = True
+        assert len(updates) <= limit
 
-        # check we didn't get more rows than the limit.
-        # doing it like this allows the update_function to be a generator.
-        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
-            raise Exception("stream %s has fallen behind" % (self.NAME))
+        return updates, upto_token, limited
 
-        return updates, current_token
+    return update_function
 
-    def current_token(self):
-        """Gets the current token of the underlying streams. Should be provided
-        by the sub classes
 
-        Returns:
-            int
-        """
-        raise NotImplementedError()
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+    """Makes a suitable function for use as an `update_function` that queries
+    the master process for updates.
+    """
 
-    def update_function(self, from_token, current_token, limit=None):
-        """Get updates between from_token and to_token. If Stream._LIMITED is
-        True then limit is provided, otherwise it's not.
+    client = ReplicationGetStreamUpdates.make_client(hs)
 
-        Returns:
-            Deferred(list(tuple)): the first entry in the tuple is the token for
-                that update, and the rest of the tuple gets used to construct
-                a ``ROW_TYPE`` instance
-        """
-        raise NotImplementedError()
+    async def update_function(
+        from_token: int, upto_token: int, limit: int
+    ) -> StreamUpdateResult:
+        result = await client(
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
+            limit=limit,
+        )
+        return result["updates"], result["upto_token"], result["limited"]
+
+    return update_function
 
 
 class BackfillStream(Stream):
@@ -240,93 +208,157 @@ class BackfillStream(Stream):
     or it went from being an outlier to not.
     """
 
+    BackfillStreamRow = namedtuple(
+        "BackfillStreamRow",
+        (
+            "event_id",  # str
+            "room_id",  # str
+            "type",  # str
+            "state_key",  # str, optional
+            "redacts",  # str, optional
+            "relates_to",  # str, optional
+        ),
+    )
+
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
-
-        super(BackfillStream, self).__init__(hs)
+        super().__init__(
+            store.get_current_backfill_token,
+            db_query_to_update_function(store.get_all_new_backfill_event_rows),
+        )
 
 
 class PresenceStream(Stream):
+    PresenceStreamRow = namedtuple(
+        "PresenceStreamRow",
+        (
+            "user_id",  # str
+            "state",  # str
+            "last_active_ts",  # int
+            "last_federation_update_ts",  # int
+            "last_user_sync_ts",  # int
+            "status_msg",  # str
+            "currently_active",  # bool
+        ),
+    )
+
     NAME = "presence"
-    _LIMITED = False
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        presence_handler = hs.get_presence_handler()
 
-        self.current_token = store.get_current_presence_token  # type: ignore
-        self.update_function = presence_handler.get_all_presence_updates  # type: ignore
+        if hs.config.worker_app is None:
+            # on the master, query the presence handler
+            presence_handler = hs.get_presence_handler()
+            update_function = db_query_to_update_function(
+                presence_handler.get_all_presence_updates
+            )
+        else:
+            # Query master process
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(PresenceStream, self).__init__(hs)
+        super().__init__(store.get_current_presence_token, update_function)
 
 
 class TypingStream(Stream):
+    TypingStreamRow = namedtuple(
+        "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
+    )
+
     NAME = "typing"
-    _LIMITED = False
     ROW_TYPE = TypingStreamRow
 
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
-        self.current_token = typing_handler.get_current_token  # type: ignore
-        self.update_function = typing_handler.get_all_typing_updates  # type: ignore
+        if hs.config.worker_app is None:
+            # on the master, query the typing handler
+            update_function = db_query_to_update_function(
+                typing_handler.get_all_typing_updates
+            )
+        else:
+            # Query master process
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(TypingStream, self).__init__(hs)
+        super().__init__(typing_handler.get_current_token, update_function)
 
 
 class ReceiptsStream(Stream):
+    ReceiptsStreamRow = namedtuple(
+        "ReceiptsStreamRow",
+        (
+            "room_id",  # str
+            "receipt_type",  # str
+            "user_id",  # str
+            "event_id",  # str
+            "data",  # dict
+        ),
+    )
+
     NAME = "receipts"
     ROW_TYPE = ReceiptsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_receipts  # type: ignore
-
-        super(ReceiptsStream, self).__init__(hs)
+        super().__init__(
+            store.get_max_receipt_stream_id,
+            db_query_to_update_function(store.get_all_updated_receipts),
+        )
 
 
 class PushRulesStream(Stream):
     """A user has changed their push rules
     """
 
+    PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        super(PushRulesStream, self).__init__(hs)
+        super(PushRulesStream, self).__init__(
+            self._current_token, self._update_function
+        )
 
-    def current_token(self):
+    def _current_token(self) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token: Token, to_token: Token, limit: int):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher
     """
 
+    PushersStreamRow = namedtuple(
+        "PushersStreamRow",
+        ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
+    )
+
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
-
-        super(PushersStream, self).__init__(hs)
+        super().__init__(
+            store.get_pushers_stream_token,
+            db_query_to_update_function(store.get_all_updated_pushers_rows),
+        )
 
 
 class CachesStream(Stream):
@@ -334,98 +366,132 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
+    @attr.s
+    class CachesStreamRow:
+        """Stream to inform workers they should invalidate their cache.
+
+        Attributes:
+            cache_func: Name of the cached function.
+            keys: The entry in the cache to invalidate. If None then will
+                invalidate all.
+            invalidation_ts: Timestamp of when the invalidation took place.
+        """
+
+        cache_func = attr.ib(type=str)
+        keys = attr.ib(type=Optional[List[Any]])
+        invalidation_ts = attr.ib(type=int)
+
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_caches  # type: ignore
-
-        super(CachesStream, self).__init__(hs)
+        super().__init__(
+            store.get_cache_stream_token,
+            db_query_to_update_function(store.get_all_updated_caches),
+        )
 
 
 class PublicRoomsStream(Stream):
     """The public rooms list changed
     """
 
+    PublicRoomsStreamRow = namedtuple(
+        "PublicRoomsStreamRow",
+        (
+            "room_id",  # str
+            "visibility",  # str
+            "appservice_id",  # str, optional
+            "network_id",  # str, optional
+        ),
+    )
+
     NAME = "public_rooms"
     ROW_TYPE = PublicRoomsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = store.get_all_new_public_rooms  # type: ignore
-
-        super(PublicRoomsStream, self).__init__(hs)
+        super().__init__(
+            store.get_current_public_room_stream_id,
+            db_query_to_update_function(store.get_all_new_public_rooms),
+        )
 
 
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """
 
+    @attr.s
+    class DeviceListsStreamRow:
+        entity = attr.ib(type=str)
+
     NAME = "device_lists"
-    _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
-
-        super(DeviceListsStream, self).__init__(hs)
+        super().__init__(
+            store.get_device_stream_token,
+            db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+        )
 
 
 class ToDeviceStream(Stream):
     """New to_device messages for a client
     """
 
+    ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = store.get_all_new_device_messages  # type: ignore
-
-        super(ToDeviceStream, self).__init__(hs)
+        super().__init__(
+            store.get_to_device_stream_token,
+            db_query_to_update_function(store.get_all_new_device_messages),
+        )
 
 
 class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room
     """
 
+    TagAccountDataStreamRow = namedtuple(
+        "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
+    )
+
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_tags  # type: ignore
-
-        super(TagAccountDataStream, self).__init__(hs)
+        super().__init__(
+            store.get_max_account_data_stream_id,
+            db_query_to_update_function(store.get_all_updated_tags),
+        )
 
 
 class AccountDataStream(Stream):
     """Global or per room account data was changed
     """
 
+    AccountDataStreamRow = namedtuple(
+        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+    )
+
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
+        super().__init__(
+            self.store.get_max_account_data_stream_id,
+            db_query_to_update_function(self._update_function),
+        )
 
-        self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
-
-        super(AccountDataStream, self).__init__(hs)
-
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
@@ -440,30 +506,36 @@ class AccountDataStream(Stream):
 
 
 class GroupServerStream(Stream):
+    GroupsStreamRow = namedtuple(
+        "GroupsStreamRow",
+        ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
+    )
+
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = store.get_all_groups_changes  # type: ignore
-
-        super(GroupServerStream, self).__init__(hs)
+        super().__init__(
+            store.get_group_stream_token,
+            db_query_to_update_function(store.get_all_groups_changes),
+        )
 
 
 class UserSignatureStream(Stream):
     """A user has signed their own device with their user-signing key
     """
 
+    UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+
     NAME = "user_signature"
-    _LIMITED = False
     ROW_TYPE = UserSignatureStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
-
-        super(UserSignatureStream, self).__init__(hs)
+        super().__init__(
+            store.get_device_stream_token,
+            db_query_to_update_function(
+                store.get_all_user_signature_changes_for_remotes
+            ),
+        )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..051114596b 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 
 import heapq
-from typing import Tuple, Type
+from typing import Iterable, Tuple, Type
 
 import attr
 
-from ._base import Stream
+from ._base import Stream, Token, db_query_to_update_function
 
 
 """Handling of the 'events' replication stream
@@ -116,11 +116,14 @@ class EventsStream(Stream):
 
     def __init__(self, hs):
         self._store = hs.get_datastore()
-        self.current_token = self._store.get_current_events_token  # type: ignore
-
-        super(EventsStream, self).__init__(hs)
+        super().__init__(
+            self._store.get_current_events_token,
+            db_query_to_update_function(self._update_function),
+        )
 
-    async def update_function(self, from_token, current_token, limit=None):
+    async def _update_function(
+        self, from_token: Token, current_token: Token, limit: int
+    ) -> Iterable[tuple]:
         event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 615f3dc9ac..75133d7e40 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,15 +15,7 @@
 # limitations under the License.
 from collections import namedtuple
 
-from ._base import Stream
-
-FederationStreamRow = namedtuple(
-    "FederationStreamRow",
-    (
-        "type",  # str, the type of data as defined in the BaseFederationRows
-        "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
-    ),
-)
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
 class FederationStream(Stream):
@@ -31,13 +23,33 @@ class FederationStream(Stream):
     sending disabled.
     """
 
+    FederationStreamRow = namedtuple(
+        "FederationStreamRow",
+        (
+            "type",  # str, the type of data as defined in the BaseFederationRows
+            "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+        ),
+    )
+
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
 
     def __init__(self, hs):
-        federation_sender = hs.get_federation_sender()
-
-        self.current_token = federation_sender.get_current_token  # type: ignore
-        self.update_function = federation_sender.get_replication_rows  # type: ignore
-
-        super(FederationStream, self).__init__(hs)
+        # Not all synapse instances will have a federation sender instance,
+        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+        # so we stub the stream out when that is the case.
+        if hs.config.worker_app is None or hs.should_send_federation():
+            federation_sender = hs.get_federation_sender()
+            current_token = federation_sender.get_current_token
+            update_function = db_query_to_update_function(
+                federation_sender.get_replication_rows
+            )
+        else:
+            current_token = lambda: 0
+            update_function = self._stub_update_function
+
+        super().__init__(current_token, update_function)
+
+    @staticmethod
+    async def _stub_update_function(from_token, upto_token, limit):
+        return [], upto_token, False