summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/__init__.py23
-rw-r--r--synapse/replication/http/_base.py57
-rw-r--r--synapse/replication/http/devices.py73
-rw-r--r--synapse/replication/http/federation.py85
-rw-r--r--synapse/replication/http/login.py7
-rw-r--r--synapse/replication/http/membership.py69
-rw-r--r--synapse/replication/http/presence.py116
-rw-r--r--synapse/replication/http/register.py14
-rw-r--r--synapse/replication/http/send_event.py24
-rw-r--r--synapse/replication/http/streams.py79
-rw-r--r--synapse/replication/slave/storage/_base.py58
-rw-r--r--synapse/replication/slave/storage/account_data.py31
-rw-r--r--synapse/replication/slave/storage/appservice.py2
-rw-r--r--synapse/replication/slave/storage/client_ips.py10
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py18
-rw-r--r--synapse/replication/slave/storage/devices.py63
-rw-r--r--synapse/replication/slave/storage/directory.py2
-rw-r--r--synapse/replication/slave/storage/events.py130
-rw-r--r--synapse/replication/slave/storage/filtering.py7
-rw-r--r--synapse/replication/slave/storage/groups.py30
-rw-r--r--synapse/replication/slave/storage/keys.py2
-rw-r--r--synapse/replication/slave/storage/presence.py32
-rw-r--r--synapse/replication/slave/storage/profile.py2
-rw-r--r--synapse/replication/slave/storage/push_rule.py20
-rw-r--r--synapse/replication/slave/storage/pushers.py19
-rw-r--r--synapse/replication/slave/storage/receipts.py18
-rw-r--r--synapse/replication/slave/storage/registration.py2
-rw-r--r--synapse/replication/slave/storage/room.py16
-rw-r--r--synapse/replication/slave/storage/transactions.py2
-rw-r--r--synapse/replication/tcp/__init__.py30
-rw-r--r--synapse/replication/tcp/client.py315
-rw-r--r--synapse/replication/tcp/commands.py267
-rw-r--r--synapse/replication/tcp/handler.py596
-rw-r--r--synapse/replication/tcp/protocol.py428
-rw-r--r--synapse/replication/tcp/redis.py215
-rw-r--r--synapse/replication/tcp/resource.py206
-rw-r--r--synapse/replication/tcp/streams/__init__.py70
-rw-r--r--synapse/replication/tcp/streams/_base.py629
-rw-r--r--synapse/replication/tcp/streams/events.py133
-rw-r--r--synapse/replication/tcp/streams/federation.py54
40 files changed, 2542 insertions, 1412 deletions
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 81b85352b1..19b69e0e11 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,7 +14,16 @@
 # limitations under the License.
 
 from synapse.http.server import JsonResource
-from synapse.replication.http import federation, login, membership, register, send_event
+from synapse.replication.http import (
+    devices,
+    federation,
+    login,
+    membership,
+    presence,
+    register,
+    send_event,
+    streams,
+)
 
 REPLICATION_PREFIX = "/_synapse/replication"
 
@@ -26,7 +35,13 @@ class ReplicationRestResource(JsonResource):
 
     def register_servlets(self, hs):
         send_event.register_servlets(hs, self)
-        membership.register_servlets(hs, self)
         federation.register_servlets(hs, self)
-        login.register_servlets(hs, self)
-        register.register_servlets(hs, self)
+        presence.register_servlets(hs, self)
+        membership.register_servlets(hs, self)
+
+        # The following can't currently be instantiated on workers.
+        if hs.config.worker.worker_app is None:
+            login.register_servlets(hs, self)
+            register.register_servlets(hs, self)
+            devices.register_servlets(hs, self)
+            streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 03560c1f0e..793cef6c26 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,8 @@
 import abc
 import logging
 import re
+from inspect import signature
+from typing import Dict, List, Tuple
 
 from six import raise_from
 from six.moves import urllib
@@ -43,7 +45,7 @@ class ReplicationEndpoint(object):
     """Helper base class for defining new replication HTTP endpoints.
 
     This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
-    (with an `/:txn_id` prefix for cached requests.), where NAME is a name,
+    (with a `/:txn_id` suffix for cached requests), where NAME is a name,
     PATH_ARGS are a tuple of parameters to be encoded in the URL.
 
     For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
@@ -59,6 +61,8 @@ class ReplicationEndpoint(object):
     must call `register` to register the path with the HTTP server.
 
     Requests can be sent by calling the client returned by `make_client`.
+    Requests are sent to master process by default, but can be sent to other
+    named processes by specifying an `instance_name` keyword argument.
 
     Attributes:
         NAME (str): A name for the endpoint, added to the path as well as used
@@ -78,9 +82,8 @@ class ReplicationEndpoint(object):
 
     __metaclass__ = abc.ABCMeta
 
-    NAME = abc.abstractproperty()
-    PATH_ARGS = abc.abstractproperty()
-
+    NAME = abc.abstractproperty()  # type: str  # type: ignore
+    PATH_ARGS = abc.abstractproperty()  # type: Tuple[str, ...]  # type: ignore
     METHOD = "POST"
     CACHE = True
     RETRY_ON_TIMEOUT = True
@@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
                 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
             )
 
+        # We reserve `instance_name` as a parameter to sending requests, so we
+        # assert here that sub classes don't try and use the name.
+        assert (
+            "instance_name" not in self.PATH_ARGS
+        ), "`instance_name` is a reserved paramater name"
+        assert (
+            "instance_name"
+            not in signature(self.__class__._serialize_payload).parameters
+        ), "`instance_name` is a reserved paramater name"
+
         assert self.METHOD in ("PUT", "POST", "GET")
 
     @abc.abstractmethod
@@ -110,14 +123,14 @@ class ReplicationEndpoint(object):
         return {}
 
     @abc.abstractmethod
-    def _handle_request(self, request, **kwargs):
+    async def _handle_request(self, request, **kwargs):
         """Handle incoming request.
 
         This is called with the request object and PATH_ARGS.
 
         Returns:
-            Deferred[dict]: A JSON serialisable dict to be used as response
-            body of request.
+            tuple[int, dict]: HTTP status code and a JSON serialisable dict
+            to be used as response body of request.
         """
         pass
 
@@ -128,14 +141,30 @@ class ReplicationEndpoint(object):
         Returns a callable that accepts the same parameters as `_serialize_payload`.
         """
         clock = hs.get_clock()
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_http_port
-
         client = hs.get_simple_http_client()
+        local_instance_name = hs.get_instance_name()
+
+        master_host = hs.config.worker_replication_host
+        master_port = hs.config.worker_replication_http_port
+
+        instance_map = hs.config.worker.instance_map
 
         @trace(opname="outgoing_replication_request")
         @defer.inlineCallbacks
-        def send_request(**kwargs):
+        def send_request(instance_name="master", **kwargs):
+            if instance_name == local_instance_name:
+                raise Exception("Trying to send HTTP request to self")
+            if instance_name == "master":
+                host = master_host
+                port = master_port
+            elif instance_name in instance_map:
+                host = instance_map[instance_name].host
+                port = instance_map[instance_name].port
+            else:
+                raise Exception(
+                    "Instance %r not in 'instance_map' config" % (instance_name,)
+                )
+
             data = yield cls._serialize_payload(**kwargs)
 
             url_args = [
@@ -171,7 +200,7 @@ class ReplicationEndpoint(object):
                 # have a good idea that the request has either succeeded or failed on
                 # the master, and so whether we should clean up or not.
                 while True:
-                    headers = {}
+                    headers = {}  # type: Dict[bytes, List[bytes]]
                     inject_active_span_byte_dict(headers, None, check_destination=False)
                     try:
                         result = yield request_func(uri, data, headers=headers)
@@ -180,7 +209,7 @@ class ReplicationEndpoint(object):
                         if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
                             raise
 
-                    logger.warn("%s request timed out", cls.NAME)
+                    logger.warning("%s request timed out", cls.NAME)
 
                     # If we timed out we probably don't need to worry about backing
                     # off too much, but lets just wait a little anyway.
@@ -207,7 +236,7 @@ class ReplicationEndpoint(object):
         method = self.METHOD
 
         if self.CACHE:
-            handler = self._cached_handler
+            handler = self._cached_handler  # type: ignore
             url_args.append("txn_id")
 
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
new file mode 100644
index 0000000000..e32aac0a25
--- /dev/null
+++ b/synapse/replication/http/devices.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
+    """Ask master to resync the device list for a user by contacting their
+    server.
+
+    This must happen on master so that the results can be correctly cached in
+    the database and streamed to workers.
+
+    Request format:
+
+        POST /_synapse/replication/user_device_resync/:user_id
+
+        {}
+
+    Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id`
+    response, e.g.:
+
+        {
+            "user_id": "@alice:example.org",
+            "devices": [
+                {
+                    "device_id": "JLAFKJWSCS",
+                    "keys": { ... },
+                    "device_display_name": "Alice's Mobile Phone"
+                }
+            ]
+        }
+    """
+
+    NAME = "user_device_resync"
+    PATH_ARGS = ("user_id",)
+    CACHE = False
+
+    def __init__(self, hs):
+        super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+
+        self.device_list_updater = hs.get_device_handler().device_list_updater
+        self.store = hs.get_datastore()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    def _serialize_payload(user_id):
+        return {}
+
+    async def _handle_request(self, request, user_id):
+        user_devices = await self.device_list_updater.user_device_resync(user_id)
+
+        return 200, user_devices
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 2f16955954..c287c4e269 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -17,7 +17,8 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.events import event_type_from_format_version
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
@@ -28,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
     """Handles events newly received from federation, including persisting and
-    notifying.
+    notifying. Returns the maximum stream ID of the persisted events.
 
     The API looks like:
 
@@ -37,11 +38,21 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         {
             "events": [{
                 "event": { .. serialized event .. },
+                "room_version": .., // "1", "2", "3", etc: the version of the room
+                                    // containing the event
+                "event_format_version": .., // 1,2,3 etc: the event format version
                 "internal_metadata": { .. serialized internal_metadata .. },
                 "rejected_reason": ..,   // The event.rejected_reason field
                 "context": { .. serialized event context .. },
             }],
             "backfilled": false
+        }
+
+        200 OK
+
+        {
+            "max_stream_id": 32443,
+        }
     """
 
     NAME = "fed_send_events"
@@ -51,6 +62,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
         super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
 
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.clock = hs.get_clock()
         self.federation_handler = hs.get_handlers().federation_handler
 
@@ -71,6 +83,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
             event_payloads.append(
                 {
                     "event": event.get_pdu_json(),
+                    "room_version": event.room_version.identifier,
                     "event_format_version": event.format_version,
                     "internal_metadata": event.internal_metadata.get_dict(),
                     "rejected_reason": event.rejected_reason,
@@ -82,8 +95,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
 
         return payload
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request):
+    async def _handle_request(self, request):
         with Measure(self.clock, "repl_fed_send_events_parse"):
             content = parse_json_object_from_request(request)
 
@@ -94,26 +106,27 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
             event_and_contexts = []
             for event_payload in event_payloads:
                 event_dict = event_payload["event"]
-                format_ver = event_payload["event_format_version"]
+                room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
                 internal_metadata = event_payload["internal_metadata"]
                 rejected_reason = event_payload["rejected_reason"]
 
-                EventType = event_type_from_format_version(format_ver)
-                event = EventType(event_dict, internal_metadata, rejected_reason)
+                event = make_event_from_dict(
+                    event_dict, room_ver, internal_metadata, rejected_reason
+                )
 
-                context = yield EventContext.deserialize(
-                    self.store, event_payload["context"]
+                context = EventContext.deserialize(
+                    self.storage, event_payload["context"]
                 )
 
                 event_and_contexts.append((event, context))
 
         logger.info("Got %d events from federation", len(event_and_contexts))
 
-        yield self.federation_handler.persist_events_and_notify(
+        max_stream_id = await self.federation_handler.persist_events_and_notify(
             event_and_contexts, backfilled
         )
 
-        return 200, {}
+        return 200, {"max_stream_id": max_stream_id}
 
 
 class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
@@ -144,8 +157,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
     def _serialize_payload(edu_type, origin, content):
         return {"origin": origin, "content": content}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, edu_type):
+    async def _handle_request(self, request, edu_type):
         with Measure(self.clock, "repl_fed_send_edu_parse"):
             content = parse_json_object_from_request(request)
 
@@ -154,7 +166,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
 
         logger.info("Got %r edu from %s", edu_type, origin)
 
-        result = yield self.registry.on_edu(edu_type, origin, edu_content)
+        result = await self.registry.on_edu(edu_type, origin, edu_content)
 
         return 200, result
 
@@ -193,8 +205,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
         """
         return {"args": args}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, query_type):
+    async def _handle_request(self, request, query_type):
         with Measure(self.clock, "repl_fed_query_parse"):
             content = parse_json_object_from_request(request)
 
@@ -202,7 +213,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
 
         logger.info("Got %r query", query_type)
 
-        result = yield self.registry.on_query(query_type, args)
+        result = await self.registry.on_query(query_type, args)
 
         return 200, result
 
@@ -213,7 +224,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
 
     Request format:
 
-        POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
+        POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id
 
         {}
     """
@@ -234,10 +245,41 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
         """
         return {}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id):
-        yield self.store.clean_room_for_join(room_id)
+    async def _handle_request(self, request, room_id):
+        await self.store.clean_room_for_join(room_id)
+
+        return 200, {}
+
+
+class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
+    """Called to clean up any data in DB for a given room, ready for the
+    server to join the room.
+
+    Request format:
+
+        POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
+
+        {
+            "room_version": "1",
+        }
+    """
+
+    NAME = "store_room_on_invite"
+    PATH_ARGS = ("room_id",)
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.store = hs.get_datastore()
+
+    @staticmethod
+    def _serialize_payload(room_id, room_version):
+        return {"room_version": room_version.identifier}
 
+    async def _handle_request(self, request, room_id):
+        content = parse_json_object_from_request(request)
+        room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
+        await self.store.maybe_store_room_on_invite(room_id, room_version)
         return 200, {}
 
 
@@ -246,3 +288,4 @@ def register_servlets(hs, http_server):
     ReplicationFederationSendEduRestServlet(hs).register(http_server)
     ReplicationGetQueryRestServlet(hs).register(http_server)
     ReplicationCleanRoomRestServlet(hs).register(http_server)
+    ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 786f5232b2..798b9d3af5 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 
@@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             "is_guest": is_guest,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, user_id):
+    async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
         device_id = content["device_id"]
         initial_display_name = content["initial_display_name"]
         is_guest = content["is_guest"]
 
-        device_id, access_token = yield self.registration_handler.register_device(
+        device_id, access_token = await self.registration_handler.register_device(
             user_id, device_id, initial_display_name, is_guest
         )
 
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index b9ce3477ad..a7174c4a8f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -14,14 +14,16 @@
 # limitations under the License.
 
 import logging
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING
 
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import Requester, UserID
 from synapse.util.distributor import user_joined_room, user_left_room
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -65,8 +67,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
             "content": content,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id, user_id):
+    async def _handle_request(self, request, room_id, user_id):
         content = parse_json_object_from_request(request)
 
         remote_room_hosts = content["remote_room_hosts"]
@@ -79,11 +80,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
 
         logger.info("remote_join: %s into room: %s", user_id, room_id)
 
-        yield self.federation_handler.do_invite_join(
+        event_id, stream_id = await self.federation_handler.do_invite_join(
             remote_room_hosts, room_id, user_id, event_content
         )
 
-        return 200, {}
+        return 200, {"event_id": event_id, "stream_id": stream_id}
 
 
 class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -96,6 +97,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         {
             "requester": ...,
             "remote_room_hosts": [...],
+            "content": { ... }
         }
     """
 
@@ -108,9 +110,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         self.federation_handler = hs.get_handlers().federation_handler
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
+        self.member_handler = hs.get_room_member_handler()
 
     @staticmethod
-    def _serialize_payload(requester, room_id, user_id, remote_room_hosts):
+    def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
         """
         Args:
             requester(Requester)
@@ -121,13 +124,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         return {
             "requester": requester.serialize(),
             "remote_room_hosts": remote_room_hosts,
+            "content": content,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id, user_id):
+    async def _handle_request(self, request, room_id, user_id):
         content = parse_json_object_from_request(request)
 
         remote_room_hosts = content["remote_room_hosts"]
+        event_content = content["content"]
 
         requester = Requester.deserialize(self.store, content["requester"])
 
@@ -137,10 +141,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
 
         try:
-            event = yield self.federation_handler.do_remotely_reject_invite(
-                remote_room_hosts, room_id, user_id
+            event, stream_id = await self.federation_handler.do_remotely_reject_invite(
+                remote_room_hosts, room_id, user_id, event_content,
             )
-            ret = event.get_pdu_json()
+            event_id = event.event_id
         except Exception as e:
             # if we were unable to reject the exception, just mark
             # it as rejected on our end and plough ahead.
@@ -148,12 +152,44 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
             # The 'except' clause is very broad, but we need to
             # capture everything from DNS failures upwards
             #
-            logger.warn("Failed to reject invite: %s", e)
+            logger.warning("Failed to reject invite: %s", e)
+
+            stream_id = await self.member_handler.locally_reject_invite(
+                user_id, room_id
+            )
+            event_id = None
+
+        return 200, {"event_id": event_id, "stream_id": stream_id}
+
+
+class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
+    """Rejects the invite for the user and room locally.
+
+    Request format:
+
+        POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
+
+        {}
+    """
+
+    NAME = "locally_reject_invite"
+    PATH_ARGS = ("room_id", "user_id")
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self.member_handler = hs.get_room_member_handler()
+
+    @staticmethod
+    def _serialize_payload(room_id, user_id):
+        return {}
+
+    async def _handle_request(self, request, room_id, user_id):
+        logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
 
-            yield self.store.locally_reject_invite(user_id, room_id)
-            ret = {}
+        stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
 
-        return 200, ret
+        return 200, {"stream_id": stream_id}
 
 
 class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
@@ -209,3 +245,4 @@ def register_servlets(hs, http_server):
     ReplicationRemoteJoinRestServlet(hs).register(http_server)
     ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
     ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
+    ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/presence.py b/synapse/replication/http/presence.py
new file mode 100644
index 0000000000..ea1b33331b
--- /dev/null
+++ b/synapse/replication/http/presence.py
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+from synapse.types import UserID
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
+    """We've seen the user do something that indicates they're interacting
+    with the app.
+
+    The POST looks like:
+
+        POST /_synapse/replication/bump_presence_active_time/<user_id>
+
+        200 OK
+
+        {}
+    """
+
+    NAME = "bump_presence_active_time"
+    PATH_ARGS = ("user_id",)
+    METHOD = "POST"
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._presence_handler = hs.get_presence_handler()
+
+    @staticmethod
+    def _serialize_payload(user_id):
+        return {}
+
+    async def _handle_request(self, request, user_id):
+        await self._presence_handler.bump_presence_active_time(
+            UserID.from_string(user_id)
+        )
+
+        return (
+            200,
+            {},
+        )
+
+
+class ReplicationPresenceSetState(ReplicationEndpoint):
+    """Set the presence state for a user.
+
+    The POST looks like:
+
+        POST /_synapse/replication/presence_set_state/<user_id>
+
+        {
+            "state": { ... },
+            "ignore_status_msg": false,
+        }
+
+        200 OK
+
+        {}
+    """
+
+    NAME = "presence_set_state"
+    PATH_ARGS = ("user_id",)
+    METHOD = "POST"
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        self._presence_handler = hs.get_presence_handler()
+
+    @staticmethod
+    def _serialize_payload(user_id, state, ignore_status_msg=False):
+        return {
+            "state": state,
+            "ignore_status_msg": ignore_status_msg,
+        }
+
+    async def _handle_request(self, request, user_id):
+        content = parse_json_object_from_request(request)
+
+        await self._presence_handler.set_state(
+            UserID.from_string(user_id), content["state"], content["ignore_status_msg"]
+        )
+
+        return (
+            200,
+            {},
+        )
+
+
+def register_servlets(hs, http_server):
+    ReplicationBumpPresenceActiveTime(hs).register(http_server)
+    ReplicationPresenceSetState(hs).register(http_server)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 38260256cf..0c4aca1291 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 
@@ -74,11 +72,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             "address": address,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, user_id):
+    async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
-        yield self.registration_handler.register_with_store(
+        self.registration_handler.check_registration_ratelimit(content["address"])
+
+        await self.registration_handler.register_with_store(
             user_id=user_id,
             password_hash=content["password_hash"],
             was_guest=content["was_guest"],
@@ -117,14 +116,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
         """
         return {"auth_result": auth_result, "access_token": access_token}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, user_id):
+    async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
         auth_result = content["auth_result"]
         access_token = content["access_token"]
 
-        yield self.registration_handler.post_registration_actions(
+        await self.registration_handler.post_registration_actions(
             user_id=user_id, auth_result=auth_result, access_token=access_token
         )
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index adb9b2f7f4..c981723c1a 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -17,7 +17,8 @@ import logging
 
 from twisted.internet import defer
 
-from synapse.events import event_type_from_format_version
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
@@ -37,6 +38,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
         {
             "event": { .. serialized event .. },
+            "room_version": .., // "1", "2", "3", etc: the version of the room
+                                // containing the event
+            "event_format_version": .., // 1,2,3 etc: the event format version
             "internal_metadata": { .. serialized internal_metadata .. },
             "rejected_reason": ..,   // The event.rejected_reason field
             "context": { .. serialized event context .. },
@@ -54,6 +58,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
         self.event_creation_handler = hs.get_event_creation_handler()
         self.store = hs.get_datastore()
+        self.storage = hs.get_storage()
         self.clock = hs.get_clock()
 
     @staticmethod
@@ -76,6 +81,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
         payload = {
             "event": event.get_pdu_json(),
+            "room_version": event.room_version.identifier,
             "event_format_version": event.format_version,
             "internal_metadata": event.internal_metadata.get_dict(),
             "rejected_reason": event.rejected_reason,
@@ -87,21 +93,21 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
         return payload
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, event_id):
+    async def _handle_request(self, request, event_id):
         with Measure(self.clock, "repl_send_event_parse"):
             content = parse_json_object_from_request(request)
 
             event_dict = content["event"]
-            format_ver = content["event_format_version"]
+            room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
             internal_metadata = content["internal_metadata"]
             rejected_reason = content["rejected_reason"]
 
-            EventType = event_type_from_format_version(format_ver)
-            event = EventType(event_dict, internal_metadata, rejected_reason)
+            event = make_event_from_dict(
+                event_dict, room_ver, internal_metadata, rejected_reason
+            )
 
             requester = Requester.deserialize(self.store, content["requester"])
-            context = yield EventContext.deserialize(self.store, content["context"])
+            context = EventContext.deserialize(self.storage, content["context"])
 
             ratelimit = content["ratelimit"]
             extra_users = [UserID.from_string(u) for u in content["extra_users"]]
@@ -113,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
         )
 
-        yield self.event_creation_handler.persist_and_notify_client_event(
+        stream_id = await self.event_creation_handler.persist_and_notify_client_event(
             requester, event, context, ratelimit=ratelimit, extra_users=extra_users
         )
 
-        return 200, {}
+        return 200, {"stream_id": stream_id}
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
new file mode 100644
index 0000000000..bde97eef32
--- /dev/null
+++ b/synapse/replication/http/streams.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationGetStreamUpdates(ReplicationEndpoint):
+    """Fetches stream updates from a server. Used for streams not persisted to
+    the database, e.g. typing notifications.
+
+    The API looks like:
+
+        GET /_synapse/replication/get_repl_stream_updates/<stream name>?from_token=0&to_token=10
+
+        200 OK
+
+        {
+            updates: [ ... ],
+            upto_token: 10,
+            limited: False,
+        }
+
+    If there are more rows than can sensibly be returned in one lump, `limited` will be
+    set to true, and the caller should call again with a new `from_token`.
+
+    """
+
+    NAME = "get_repl_stream_updates"
+    PATH_ARGS = ("stream_name",)
+    METHOD = "GET"
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self._instance_name = hs.get_instance_name()
+        self.streams = hs.get_replication_streams()
+
+    @staticmethod
+    def _serialize_payload(stream_name, from_token, upto_token):
+        return {"from_token": from_token, "upto_token": upto_token}
+
+    async def _handle_request(self, request, stream_name):
+        stream = self.streams.get(stream_name)
+        if stream is None:
+            raise SynapseError(400, "Unknown stream")
+
+        from_token = parse_integer(request, "from_token", required=True)
+        upto_token = parse_integer(request, "upto_token", required=True)
+
+        updates, upto_token, limited = await stream.get_updates_since(
+            self._instance_name, from_token, upto_token
+        )
+
+        return (
+            200,
+            {"updates": updates, "upto_token": upto_token, "limited": limited},
+        )
+
+
+def register_servlets(hs, http_server):
+    ReplicationGetStreamUpdates(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 182cb2a1d8..f9e2533e96 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,56 +14,30 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
-import six
-
-from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
-
-from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 logger = logging.getLogger(__name__)
 
 
-def __func__(inp):
-    if six.PY3:
-        return inp
-    else:
-        return inp.__func__
-
-
-class BaseSlavedStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(BaseSlavedStore, self).__init__(db_conn, hs)
+class BaseSlavedStore(CacheInvalidationWorkerStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id"
-            )
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name=hs.get_instance_name(),
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
+            )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
 
         self.hs = hs
-
-    def stream_positions(self):
-        pos = {}
-        if self._cache_id_gen:
-            pos["caches"] = self._cache_id_gen.get_current_token()
-        return pos
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "caches":
-            self._cache_id_gen.advance(token)
-            for row in rows:
-                if row.cache_func == _CURRENT_STATE_CACHE_NAME:
-                    room_id = row.keys[0]
-                    members_changed = set(row.keys[1:])
-                    self._invalidate_state_caches(room_id, members_changed)
-                else:
-                    self._attempt_to_invalidate_cache(row.cache_func, tuple(row.keys))
-
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
-        txn.call_after(cache_func.invalidate, keys)
-        txn.call_after(self._send_invalidation_poke, cache_func, keys)
-
-    def _send_invalidation_poke(self, cache_func, keys):
-        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 3c44d1d48d..9db6c62bc7 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,30 +16,29 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.account_data import AccountDataWorkerStore
-from synapse.storage.tags import TagsWorkerStore
+from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
+from synapse.storage.data_stores.main.tags import TagsWorkerStore
+from synapse.storage.database import Database
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = SlavedIdTracker(
-            db_conn, "account_data_max_stream_id", "stream_id"
+            db_conn,
+            "account_data",
+            "stream_id",
+            extra_tables=[
+                ("room_account_data", "stream_id"),
+                ("room_tags_revisions", "stream_id"),
+            ],
         )
 
-        super(SlavedAccountDataStore, self).__init__(db_conn, hs)
+        super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedAccountDataStore, self).stream_positions()
-        position = self._account_data_id_gen.get_current_token()
-        result["user_account_data"] = position
-        result["room_account_data"] = position
-        result["tag_account_data"] = position
-        return result
-
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "tag_account_data":
             self._account_data_id_gen.advance(token)
             for row in rows:
@@ -58,6 +57,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedAccountDataStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index cda12ea70d..a67fbeffb7 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.appservice import (
+from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceTransactionWorkerStore,
     ApplicationServiceWorkerStore,
 )
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 14ced32333..1a38f53dfb 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
-from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import Cache
 
 from ._base import BaseSlavedStore
 
 
 class SlavedClientIpStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedClientIpStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
 
         self.client_ip_last_seen = Cache(
-            name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+            name="client_ip_last_seen", keylen=4, max_entries=50000
         )
 
     def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 284fd30d89..6e7fd259d4 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,14 +15,15 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
         self._device_inbox_id_gen = SlavedIdTracker(
             db_conn, "device_max_stream_id", "stream_id"
         )
@@ -42,12 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    def stream_positions(self):
-        result = super(SlavedDeviceInboxStore, self).stream_positions()
-        result["to_device"] = self._device_inbox_id_gen.get_current_token()
-        return result
-
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "to_device":
             self._device_inbox_id_gen.advance(token)
             for row in rows:
@@ -59,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
                     self._device_federation_outbox_stream_cache.entity_has_changed(
                         row.entity, token
                     )
-        return super(SlavedDeviceInboxStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index d9300fce33..9d8067342f 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,50 +15,61 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.devices import DeviceWorkerStore
-from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
+from synapse.storage.data_stores.main.devices import DeviceWorkerStore
+from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
 class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
             "DeviceListStreamChangeCache", device_list_max
         )
+        self._user_signature_stream_cache = StreamChangeCache(
+            "UserSignatureStreamChangeCache", device_list_max
+        )
         self._device_list_federation_stream_cache = StreamChangeCache(
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
-    def stream_positions(self):
-        result = super(SlavedDeviceStore, self).stream_positions()
-        result["device_lists"] = self._device_list_id_gen.get_current_token()
-        return result
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "device_lists":
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == DeviceListsStream.NAME:
+            self._device_list_id_gen.advance(token)
+            self._invalidate_caches_for_devices(token, rows)
+        elif stream_name == UserSignatureStream.NAME:
             self._device_list_id_gen.advance(token)
             for row in rows:
-                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
-        return super(SlavedDeviceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
-
-    def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(user_id, token)
+                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
-        if destination:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
+    def _invalidate_caches_for_devices(self, token, rows):
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate_many((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
-        self._get_cached_devices_for_user.invalidate((user_id,))
-        self._get_cached_user_device.invalidate_many((user_id,))
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 1d1d48709a..8b9717c46f 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.directory import DirectoryWorkerStore
+from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index ab5937e638..1a1a50a24f 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -15,23 +15,21 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import EventTypes
-from synapse.replication.tcp.streams.events import (
-    EventsStreamCurrentStateRow,
-    EventsStreamEventRow,
+from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
+from synapse.storage.data_stores.main.event_push_actions import (
+    EventPushActionsWorkerStore,
 )
-from synapse.storage.event_federation import EventFederationWorkerStore
-from synapse.storage.event_push_actions import EventPushActionsWorkerStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.relations import RelationsWorkerStore
-from synapse.storage.roommember import RoomMemberWorkerStore
-from synapse.storage.signatures import SignatureWorkerStore
-from synapse.storage.state import StateGroupWorkerStore
-from synapse.storage.stream import StreamWorkerStore
-from synapse.storage.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.relations import RelationsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.data_stores.main.stream import StreamWorkerStore
+from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.database import Database
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
 
 logger = logging.getLogger(__name__)
 
@@ -57,13 +55,23 @@ class SlavedEventStore(
     RelationsWorkerStore,
     BaseSlavedStore,
 ):
-    def __init__(self, db_conn, hs):
-        self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
-        self._backfill_id_gen = SlavedIdTracker(
-            db_conn, "events", "stream_ordering", step=-1
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedEventStore, self).__init__(database, db_conn, hs)
+
+        events_max = self._stream_id_gen.get_current_token()
+        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+            db_conn,
+            "current_state_delta_stream",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=events_max,  # As we share the stream id with events token
+            limit=1000,
+        )
+        self._curr_state_delta_stream_cache = StreamChangeCache(
+            "_curr_state_delta_stream_cache",
+            min_curr_state_delta_id,
+            prefilled_cache=curr_state_delta_prefill,
         )
-
-        super(SlavedEventStore, self).__init__(db_conn, hs)
 
     # Cached functions can't be accessed through a class instance so we need
     # to reach inside the __dict__ to extract them.
@@ -73,85 +81,3 @@ class SlavedEventStore(
 
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
-
-    def stream_positions(self):
-        result = super(SlavedEventStore, self).stream_positions()
-        result["events"] = self._stream_id_gen.get_current_token()
-        result["backfill"] = -self._backfill_id_gen.get_current_token()
-        return result
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "events":
-            self._stream_id_gen.advance(token)
-            for row in rows:
-                self._process_event_stream_row(token, row)
-        elif stream_name == "backfill":
-            self._backfill_id_gen.advance(-token)
-            for row in rows:
-                self.invalidate_caches_for_event(
-                    -token,
-                    row.event_id,
-                    row.room_id,
-                    row.type,
-                    row.state_key,
-                    row.redacts,
-                    row.relates_to,
-                    backfilled=True,
-                )
-        return super(SlavedEventStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
-
-    def _process_event_stream_row(self, token, row):
-        data = row.data
-
-        if row.type == EventsStreamEventRow.TypeId:
-            self.invalidate_caches_for_event(
-                token,
-                data.event_id,
-                data.room_id,
-                data.type,
-                data.state_key,
-                data.redacts,
-                data.relates_to,
-                backfilled=False,
-            )
-        elif row.type == EventsStreamCurrentStateRow.TypeId:
-            if data.type == EventTypes.Member:
-                self.get_rooms_for_user_with_stream_ordering.invalidate(
-                    (data.state_key,)
-                )
-        else:
-            raise Exception("Unknown events stream row type %s" % (row.type,))
-
-    def invalidate_caches_for_event(
-        self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
-        self._invalidate_get_event_cache(event_id)
-
-        self.get_latest_event_ids_in_room.invalidate((room_id,))
-
-        self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
-
-        if not backfilled:
-            self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
-
-        if redacts:
-            self._invalidate_get_event_cache(redacts)
-
-        if etype == EventTypes.Member:
-            self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
-            self.get_invited_rooms_for_user.invalidate((state_key,))
-
-        if relates_to:
-            self.get_relations_for_event.invalidate_many((relates_to,))
-            self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
-            self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 456a14cd5c..bcb0688954 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,14 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.filtering import FilteringStore
+from synapse.storage.data_stores.main.filtering import FilteringStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 
 
 class SlavedFilteringStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedFilteringStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
 
     # Filters are immutable so this cache doesn't need to be expired
     get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 28a46edd28..1851e7d525 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -13,16 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage import DataStore
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-from ._base import BaseSlavedStore, __func__
-from ._slaved_id_tracker import SlavedIdTracker
 
-
-class SlavedGroupServerStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
 
         self.hs = hs
 
@@ -34,21 +34,13 @@ class SlavedGroupServerStore(BaseSlavedStore):
             self._group_updates_id_gen.get_current_token(),
         )
 
-    get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user)
-    get_group_stream_token = __func__(DataStore.get_group_stream_token)
-    get_all_groups_for_user = __func__(DataStore.get_all_groups_for_user)
-
-    def stream_positions(self):
-        result = super(SlavedGroupServerStore, self).stream_positions()
-        result["groups"] = self._group_updates_id_gen.get_current_token()
-        return result
+    def get_group_stream_token(self):
+        return self._group_updates_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "groups":
             self._group_updates_id_gen.advance(token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
 
-        return super(SlavedGroupServerStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index cc6f7f009f..3def367ae9 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage import KeyStore
+from synapse.storage.data_stores.main.keys import KeyStore
 
 # KeyStore isn't really safe to use from a worker, but for now we do so and hope that
 # the races it creates aren't too bad.
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 82d808af4c..4e0124842d 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -14,47 +14,37 @@
 # limitations under the License.
 
 from synapse.storage import DataStore
-from synapse.storage.presence import PresenceStore
+from synapse.storage.data_stores.main.presence import PresenceStore
+from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-from ._base import BaseSlavedStore, __func__
+from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPresenceStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedPresenceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
-        self._presence_on_startup = self._get_active_presence(db_conn)
+        self._presence_on_startup = self._get_active_presence(db_conn)  # type: ignore
 
-        self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
+        self.presence_stream_cache = StreamChangeCache(
             "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
         )
 
-    _get_active_presence = __func__(DataStore._get_active_presence)
-    take_presence_startup_info = __func__(DataStore.take_presence_startup_info)
+    _get_active_presence = DataStore._get_active_presence
+    take_presence_startup_info = DataStore.take_presence_startup_info
     _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
     get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
 
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedPresenceStore, self).stream_positions()
-
-        if self.hs.config.use_presence:
-            position = self._presence_id_gen.get_current_token()
-            result["presence"] = position
-
-        return result
-
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
-        return super(SlavedPresenceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 46c28d4171..28c508aad3 100644
--- a/synapse/replication/slave/storage/profile.py
+++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.profile import ProfileWorkerStore
+from synapse.storage.data_stores.main.profile import ProfileWorkerStore
 
 
 class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index af7012702e..6adb19463a 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,19 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.push_rule import PushRulesWorkerStore
+from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
 
-from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
 
 
 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def __init__(self, db_conn, hs):
-        self._push_rules_stream_id_gen = SlavedIdTracker(
-            db_conn, "push_rules_stream", "stream_id"
-        )
-        super(SlavedPushRuleStore, self).__init__(db_conn, hs)
-
     def get_push_rules_stream_token(self):
         return (
             self._push_rules_stream_id_gen.get_current_token(),
@@ -36,18 +29,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedPushRuleStore, self).stream_positions()
-        result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
-        return result
-
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "push_rules":
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedPushRuleStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 8eeb267d61..cb78b49acb 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,27 +14,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedPusherStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SlavedPusherStore, self).__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
-    def stream_positions(self):
-        result = super(SlavedPusherStore, self).stream_positions()
-        result["pushers"] = self._pushers_id_gen.get_current_token()
-        return result
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
-        return super(SlavedPusherStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 91afa5a72b..be716cc558 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -29,23 +30,18 @@ from ._slaved_id_tracker import SlavedIdTracker
 
 
 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = SlavedIdTracker(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(SlavedReceiptsStore, self).__init__(db_conn, hs)
+        super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedReceiptsStore, self).stream_positions()
-        result["receipts"] = self._receipts_id_gen.get_current_token()
-        return result
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate_many((room_id,))
@@ -55,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "receipts":
             self._receipts_id_gen.advance(token)
             for row in rows:
@@ -64,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
                 )
                 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        return super(SlavedReceiptsStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 408d91df1c..4b8553e250 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.registration import RegistrationWorkerStore
+from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index f68b3378e3..8873bf37e5 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,15 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.room import RoomWorkerStore
+from synapse.storage.data_stores.main.room import RoomWorkerStore
+from synapse.storage.database import Database
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
 class RoomStore(RoomWorkerStore, BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(RoomStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
         self._public_room_id_gen = SlavedIdTracker(
             db_conn, "public_room_list_stream", "stream_id"
         )
@@ -29,13 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(RoomStore, self).stream_positions()
-        result["public_rooms"] = self._public_room_id_gen.get_current_token()
-        return result
-
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 3527beb3c9..ac88e6b8c3 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.transactions import TransactionStore
+from synapse.storage.data_stores.main.transactions import TransactionStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 81c2ea7ee9..523a1358d4 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst
 
 
 Structure of the module:
- * client.py   - the client classes used for workers to connect to master
+ * handler.py  - the classes used to handle sending/receiving commands to
+                 replication
  * command.py  - the definitions of all the valid commands
- * protocol.py - contains bot the client and server protocol implementations,
-                 these should not be used directly
- * resource.py - the server classes that accepts and handle client connections
- * streams.py  - the definitons of all the valid streams
+ * protocol.py - the TCP protocol classes
+ * resource.py - handles streaming stream updates to replications
+ * streams/    - the definitons of all the valid streams
 
+
+The general interaction of the classes are:
+
+        +---------------------+
+        | ReplicationStreamer |
+        +---------------------+
+                    |
+                    v
+        +---------------------------+     +----------------------+
+        | ReplicationCommandHandler |---->|ReplicationDataHandler|
+        +---------------------------+     +----------------------+
+                    | ^
+                    v |
+            +-------------+
+            | Protocols   |
+            | (TCP/redis) |
+            +-------------+
+
+Where the ReplicationDataHandler (or subclasses) handles incoming stream
+updates.
 """
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index a44ceb00e7..df29732f51 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -14,38 +14,55 @@
 # limitations under the License.
 """A replication client for use by synapse workers.
 """
-
+import heapq
 import logging
+from typing import TYPE_CHECKING, Dict, List, Tuple
 
-from twisted.internet import defer
+from twisted.internet.defer import Deferred
 from twisted.internet.protocol import ReconnectingClientFactory
 
-from .commands import (
-    FederationAckCommand,
-    InvalidateCacheCommand,
-    RemovePusherCommand,
-    UserIpCommand,
-    UserSyncCommand,
+from synapse.api.constants import EventTypes
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.streams.events import (
+    EventsStream,
+    EventsStreamEventRow,
+    EventsStreamRow,
 )
-from .protocol import ClientReplicationStreamProtocol
+from synapse.util.async_helpers import timeout_deferred
+from synapse.util.metrics import Measure
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
 
 logger = logging.getLogger(__name__)
 
 
-class ReplicationClientFactory(ReconnectingClientFactory):
+# How long we allow callers to wait for replication updates before timing out.
+_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
+
+
+class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
     """Factory for building connections to the master. Will reconnect if the
     connection is lost.
 
-    Accepts a handler that will be called when new data is available or data
-    is required.
+    Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
     """
 
-    maxDelay = 30  # Try at least once every N seconds
+    initialDelay = 0.1
+    maxDelay = 1  # Try at least once every N seconds
 
-    def __init__(self, hs, client_name, handler):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        client_name: str,
+        command_handler: "ReplicationCommandHandler",
+    ):
         self.client_name = client_name
-        self.handler = handler
+        self.command_handler = command_handler
         self.server_name = hs.config.server_name
+        self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -56,7 +73,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def buildProtocol(self, addr):
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
-            self.client_name, self.server_name, self._clock, self.handler
+            self.hs,
+            self.client_name,
+            self.server_name,
+            self._clock,
+            self.command_handler,
         )
 
     def clientConnectionLost(self, connector, reason):
@@ -68,162 +89,136 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
 
-class ReplicationClientHandler(object):
-    """A base handler that can be passed to the ReplicationClientFactory.
+class ReplicationDataHandler:
+    """Handles incoming stream updates from replication.
 
-    By default proxies incoming replication data to the SlaveStore.
+    This instance notifies the slave data store about updates. Can be subclassed
+    to handle updates in additional ways.
     """
 
-    def __init__(self, store):
-        self.store = store
-
-        # The current connection. None if we are currently (re)connecting
-        self.connection = None
-
-        # Any pending commands to be sent once a new connection has been
-        # established
-        self.pending_commands = []
-
-        # Map from string -> deferred, to wake up when receiveing a SYNC with
-        # the given string.
-        # Used for tests.
-        self.awaiting_syncs = {}
-
-        # The factory used to create connections.
-        self.factory = None
-
-    def start_replication(self, hs):
-        """Helper method to start a replication connection to the remote server
-        using TCP.
-        """
-        client_name = hs.config.worker_name
-        self.factory = ReplicationClientFactory(hs, client_name, self)
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_port
-        hs.get_reactor().connectTCP(host, port, self.factory)
-
-    def on_rdata(self, stream_name, token, rows):
+    def __init__(self, hs: "HomeServer"):
+        self.store = hs.get_datastore()
+        self.pusher_pool = hs.get_pusherpool()
+        self.notifier = hs.get_notifier()
+        self._reactor = hs.get_reactor()
+        self._clock = hs.get_clock()
+        self._streams = hs.get_replication_streams()
+        self._instance_name = hs.get_instance_name()
+
+        # Map from stream to list of deferreds waiting for the stream to
+        # arrive at a particular position. The lists are sorted by stream position.
+        self._streams_to_waiters = (
+            {}
+        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
         handle more.
 
         Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-
-        Returns:
-            Deferred|None
-        """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
-        return self.store.process_replication_rows(stream_name, token, rows)
-
-    def on_position(self, stream_name, token):
-        """Called when we get new position data. By default this just pokes
-        the slave store.
-
-        Can be overriden in subclasses to handle more.
-        """
-        return self.store.process_replication_rows(stream_name, token, [])
-
-    def on_sync(self, data):
-        """When we received a SYNC we wake up any deferreds that were waiting
-        for the sync with the given data.
-
-        Used by tests.
+            stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
-        d = self.awaiting_syncs.pop(data, None)
-        if d:
-            d.callback(data)
-
-    def get_streams_to_replicate(self):
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns a dictionary of stream name to token.
-        """
-        args = self.store.stream_positions()
-        user_account_data = args.pop("user_account_data", None)
-        room_account_data = args.pop("room_account_data", None)
-        if user_account_data:
-            args["account_data"] = user_account_data
-        elif room_account_data:
-            args["account_data"] = room_account_data
-
-        return args
-
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users. (Overriden by the synchrotron's only)
-        """
-        return []
-
-    def send_command(self, cmd):
-        """Send a command to master (when we get establish a connection if we
-        don't have one already.)
-        """
-        if self.connection:
-            self.connection.send_command(cmd)
-        else:
-            logger.warn("Queuing command as not connected: %r", cmd.NAME)
-            self.pending_commands.append(cmd)
-
-    def send_federation_ack(self, token):
-        """Ack data for the federation stream. This allows the master to drop
-        data stored purely in memory.
-        """
-        self.send_command(FederationAckCommand(token))
-
-    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
-        """Poke the master that a user has started/stopped syncing.
-        """
-        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
-
-    def send_remove_pusher(self, app_id, push_key, user_id):
-        """Poke the master to remove a pusher for a user
-        """
-        cmd = RemovePusherCommand(app_id, push_key, user_id)
-        self.send_command(cmd)
-
-    def send_invalidate_cache(self, cache_func, keys):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
-    def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
-        """Tell the master that the user made a request.
+        self.store.process_replication_rows(stream_name, instance_name, token, rows)
+
+        if stream_name == EventsStream.NAME:
+            # We shouldn't get multiple rows per token for events stream, so
+            # we don't need to optimise this for multiple rows.
+            for row in rows:
+                if row.type != EventsStreamEventRow.TypeId:
+                    continue
+                assert isinstance(row, EventsStreamRow)
+
+                event = await self.store.get_event(
+                    row.data.event_id, allow_rejected=True
+                )
+                if event.rejected_reason:
+                    continue
+
+                extra_users = ()  # type: Tuple[str, ...]
+                if event.type == EventTypes.Member:
+                    extra_users = (event.state_key,)
+                max_token = self.store.get_room_max_stream_ordering()
+                self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+            await self.pusher_pool.on_new_notifications(token, token)
+
+        # Notify any waiting deferreds. The list is ordered by position so we
+        # just iterate through the list until we reach a position that is
+        # greater than the received row position.
+        waiting_list = self._streams_to_waiters.get(stream_name, [])
+
+        # Index of first item with a position after the current token, i.e we
+        # have called all deferreds before this index. If not overwritten by
+        # loop below means either a) no items in list so no-op or b) all items
+        # in list were called and so the list should be cleared. Setting it to
+        # `len(list)` works for both cases.
+        index_of_first_deferred_not_called = len(waiting_list)
+
+        for idx, (position, deferred) in enumerate(waiting_list):
+            if position <= token:
+                try:
+                    with PreserveLoggingContext():
+                        deferred.callback(None)
+                except Exception:
+                    # The deferred has been cancelled or timed out.
+                    pass
+            else:
+                # The list is sorted by position so we don't need to continue
+                # checking any further entries in the list.
+                index_of_first_deferred_not_called = idx
+                break
+
+        # Drop all entries in the waiting list that were called in the above
+        # loop. (This maintains the order so no need to resort)
+        waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        self.store.process_replication_rows(stream_name, instance_name, token, [])
+
+    def on_remote_server_up(self, server: str):
+        """Called when get a new REMOTE_SERVER_UP command."""
+
+    async def wait_for_stream_position(
+        self, instance_name: str, stream_name: str, position: int
+    ):
+        """Wait until this instance has received updates up to and including
+        the given stream position.
         """
-        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
-        self.send_command(cmd)
 
-    def await_sync(self, data):
-        """Returns a deferred that is resolved when we receive a SYNC command
-        with given data.
+        if instance_name == self._instance_name:
+            # We don't get told about updates written by this process, and
+            # anyway in that case we don't need to wait.
+            return
+
+        current_position = self._streams[stream_name].current_token(self._instance_name)
+        if position <= current_position:
+            # We're already past the position
+            return
+
+        # Create a new deferred that times out after N seconds, as we don't want
+        # to wedge here forever.
+        deferred = Deferred()
+        deferred = timeout_deferred(
+            deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
+        )
 
-        [Not currently] used by tests.
-        """
-        return self.awaiting_syncs.setdefault(data, defer.Deferred())
+        waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
 
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        self.connection = connection
-        if connection:
-            for cmd in self.pending_commands:
-                connection.send_command(cmd)
-            self.pending_commands = []
-
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        logger.info("Finished connecting to server")
+        # We insert into the list using heapq as it is more efficient than
+        # pushing then resorting each time.
+        heapq.heappush(waiting_list, (position, deferred))
 
-        # We don't reset the delay any earlier as otherwise if there is a
-        # problem during start up we'll end up tight looping connecting to the
-        # server.
-        self.factory.resetDelay()
+        # We measure here to get in flight counts and average waiting time.
+        with Measure(self._clock, "repl.wait_for_stream_position"):
+            logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
+            await make_deferred_yieldable(deferred)
+            logger.info(
+                "Finished waiting for repl stream %r to reach %s", stream_name, position
+            )
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0ff2a7199f..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -17,50 +17,46 @@
 The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
 allowed to be sent by which side.
 """
-
+import abc
 import logging
 import platform
+from typing import Tuple, Type
 
 if platform.python_implementation() == "PyPy":
     import json
 
     _json_encoder = json.JSONEncoder()
 else:
-    import simplejson as json
+    import simplejson as json  # type: ignore[no-redef]  # noqa: F821
 
-    _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
+    _json_encoder = json.JSONEncoder(namedtuple_as_object=False)  # type: ignore[call-arg]  # noqa: F821
 
 logger = logging.getLogger(__name__)
 
 
-class Command(object):
+class Command(metaclass=abc.ABCMeta):
     """The base command class.
 
     All subclasses must set the NAME variable which equates to the name of the
     command on the wire.
 
     A full command line on the wire is constructed from `NAME + " " + to_line()`
-
-    The default implementation creates a command of form `<NAME> <data>`
     """
 
-    NAME = None
-
-    def __init__(self, data):
-        self.data = data
+    NAME = None  # type: str
 
     @classmethod
+    @abc.abstractmethod
     def from_line(cls, line):
         """Deserialises a line from the wire into this command. `line` does not
         include the command.
         """
-        return cls(line)
 
-    def to_line(self):
+    @abc.abstractmethod
+    def to_line(self) -> str:
         """Serialises the comamnd for the wire. Does not include the command
         prefix.
         """
-        return self.data
 
     def get_logcontext_id(self):
         """Get a suitable string for the logcontext when processing this command"""
@@ -69,7 +65,21 @@ class Command(object):
         return self.NAME
 
 
-class ServerCommand(Command):
+class _SimpleCommand(Command):
+    """An implementation of Command whose argument is just a 'data' string."""
+
+    def __init__(self, data):
+        self.data = data
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self) -> str:
+        return self.data
+
+
+class ServerCommand(_SimpleCommand):
     """Sent by the server on new connection and includes the server_name.
 
     Format::
@@ -85,7 +95,7 @@ class RdataCommand(Command):
 
     Format::
 
-        RDATA <stream_name> <token> <row_json>
+        RDATA <stream_name> <instance_name> <token> <row_json>
 
     The `<token>` may either be a numeric stream id OR "batch". The latter case
     is used to support sending multiple updates with the same stream ID. This
@@ -95,33 +105,40 @@ class RdataCommand(Command):
     The client should batch all incoming RDATA with a token of "batch" (per
     stream_name) until it sees an RDATA with a numeric stream ID.
 
+    The `<instance_name>` is the source of the new data (usually "master").
+
     `<token>` of "batch" maps to the instance variable `token` being None.
 
     An example of a batched series of RDATA::
 
-        RDATA presence batch ["@foo:example.com", "online", ...]
-        RDATA presence batch ["@bar:example.com", "online", ...]
-        RDATA presence 59 ["@baz:example.com", "online", ...]
+        RDATA presence master batch ["@foo:example.com", "online", ...]
+        RDATA presence master batch ["@bar:example.com", "online", ...]
+        RDATA presence master 59 ["@baz:example.com", "online", ...]
     """
 
     NAME = "RDATA"
 
-    def __init__(self, stream_name, token, row):
+    def __init__(self, stream_name, instance_name, token, row):
         self.stream_name = stream_name
+        self.instance_name = instance_name
         self.token = token
         self.row = row
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token, row_json = line.split(" ", 2)
+        stream_name, instance_name, token, row_json = line.split(" ", 3)
         return cls(
-            stream_name, None if token == "batch" else int(token), json.loads(row_json)
+            stream_name,
+            instance_name,
+            None if token == "batch" else int(token),
+            json.loads(row_json),
         )
 
     def to_line(self):
         return " ".join(
             (
                 self.stream_name,
+                self.instance_name,
                 str(self.token) if self.token is not None else "batch",
                 _json_encoder.encode(self.row),
             )
@@ -135,26 +152,34 @@ class PositionCommand(Command):
     """Sent by the server to tell the client the stream postition without
     needing to send an RDATA.
 
-    Sent to the client after all missing updates for a stream have been sent
-    to the client and they're now up to date.
+    Format::
+
+        POSITION <stream_name> <instance_name> <token>
+
+    On receipt of a POSITION command clients should check if they have missed
+    any updates, and if so then fetch them out of band.
+
+    The `<instance_name>` is the process that sent the command and is the source
+    of the stream.
     """
 
     NAME = "POSITION"
 
-    def __init__(self, stream_name, token):
+    def __init__(self, stream_name, instance_name, token):
         self.stream_name = stream_name
+        self.instance_name = instance_name
         self.token = token
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        return cls(stream_name, int(token))
+        stream_name, instance_name, token = line.split(" ", 2)
+        return cls(stream_name, instance_name, int(token))
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
+        return " ".join((self.stream_name, self.instance_name, str(self.token)))
 
 
-class ErrorCommand(Command):
+class ErrorCommand(_SimpleCommand):
     """Sent by either side if there was an ERROR. The data is a string describing
     the error.
     """
@@ -162,14 +187,14 @@ class ErrorCommand(Command):
     NAME = "ERROR"
 
 
-class PingCommand(Command):
+class PingCommand(_SimpleCommand):
     """Sent by either side as a keep alive. The data is arbitary (often timestamp)
     """
 
     NAME = "PING"
 
 
-class NameCommand(Command):
+class NameCommand(_SimpleCommand):
     """Sent by client to inform the server of the client's identity. The data
     is the name
     """
@@ -178,76 +203,63 @@ class NameCommand(Command):
 
 
 class ReplicateCommand(Command):
-    """Sent by the client to subscribe to the stream.
+    """Sent by the client to subscribe to streams.
 
     Format::
 
-        REPLICATE <stream_name> <token>
-
-    Where <token> may be either:
-        * a numeric stream_id to stream updates from
-        * "NOW" to stream all subsequent updates.
-
-    The <stream_name> can be "ALL" to subscribe to all known streams, in which
-    case the <token> must be set to "NOW", i.e.::
-
-        REPLICATE ALL NOW
+        REPLICATE
     """
 
     NAME = "REPLICATE"
 
-    def __init__(self, stream_name, token):
-        self.stream_name = stream_name
-        self.token = token
+    def __init__(self):
+        pass
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        if token in ("NOW", "now"):
-            token = "NOW"
-        else:
-            token = int(token)
-        return cls(stream_name, token)
+        return cls()
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
-
-    def get_logcontext_id(self):
-        return "REPLICATE-" + self.stream_name
+        return ""
 
 
 class UserSyncCommand(Command):
     """Sent by the client to inform the server that a user has started or
-    stopped syncing. Used to calculate presence on the master.
+    stopped syncing on this process.
+
+    This is used by the process handling presence (typically the master) to
+    calculate who is online and who is not.
 
     Includes a timestamp of when the last user sync was.
 
     Format::
 
-        USER_SYNC <user_id> <state> <last_sync_ms>
+        USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
 
-    Where <state> is either "start" or "stop"
+    Where <state> is either "start" or "end"
     """
 
     NAME = "USER_SYNC"
 
-    def __init__(self, user_id, is_syncing, last_sync_ms):
+    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+        self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
     def from_line(cls, line):
-        user_id, state, last_sync_ms = line.split(" ", 2)
+        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
 
         if state not in ("start", "end"):
             raise Exception("Invalid USER_SYNC state %r" % (state,))
 
-        return cls(user_id, state == "start", int(last_sync_ms))
+        return cls(instance_id, user_id, state == "start", int(last_sync_ms))
 
     def to_line(self):
         return " ".join(
             (
+                self.instance_id,
                 self.user_id,
                 "start" if self.is_syncing else "end",
                 str(self.last_sync_ms),
@@ -255,6 +267,30 @@ class UserSyncCommand(Command):
         )
 
 
+class ClearUserSyncsCommand(Command):
+    """Sent by the client to inform the server that it should drop all
+    information about syncing users sent by the client.
+
+    Mainly used when client is about to shut down.
+
+    Format::
+
+        CLEAR_USER_SYNC <instance_id>
+    """
+
+    NAME = "CLEAR_USER_SYNC"
+
+    def __init__(self, instance_id):
+        self.instance_id = instance_id
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self):
+        return self.instance_id
+
+
 class FederationAckCommand(Command):
     """Sent by the client when it has processed up to a given point in the
     federation stream. This allows the master to drop in-memory caches of the
@@ -280,14 +316,6 @@ class FederationAckCommand(Command):
         return str(self.token)
 
 
-class SyncCommand(Command):
-    """Used for testing. The client protocol implementation allows waiting
-    on a SYNC command with a specified data.
-    """
-
-    NAME = "SYNC"
-
-
 class RemovePusherCommand(Command):
     """Sent by the client to request the master remove the given pusher.
 
@@ -313,37 +341,6 @@ class RemovePusherCommand(Command):
         return " ".join((self.app_id, self.push_key, self.user_id))
 
 
-class InvalidateCacheCommand(Command):
-    """Sent by the client to invalidate an upstream cache.
-
-    THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
-    NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
-    Mainly used to invalidate destination retry timing caches.
-
-    Format::
-
-        INVALIDATE_CACHE <cache_func> <keys_json>
-
-    Where <keys_json> is a json list.
-    """
-
-    NAME = "INVALIDATE_CACHE"
-
-    def __init__(self, cache_func, keys):
-        self.cache_func = cache_func
-        self.keys = keys
-
-    @classmethod
-    def from_line(cls, line):
-        cache_func, keys_json = line.split(" ", 1)
-
-        return cls(cache_func, json.loads(keys_json))
-
-    def to_line(self):
-        return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
 class UserIpCommand(Command):
     """Sent periodically when a worker sees activity from a client.
 
@@ -386,25 +383,38 @@ class UserIpCommand(Command):
         )
 
 
+class RemoteServerUpCommand(_SimpleCommand):
+    """Sent when a worker has detected that a remote server is no longer
+    "down" and retry timings should be reset.
+
+    If sent from a client the server will relay to all other workers.
+
+    Format::
+
+        REMOTE_SERVER_UP <server>
+    """
+
+    NAME = "REMOTE_SERVER_UP"
+
+
+_COMMANDS = (
+    ServerCommand,
+    RdataCommand,
+    PositionCommand,
+    ErrorCommand,
+    PingCommand,
+    NameCommand,
+    ReplicateCommand,
+    UserSyncCommand,
+    FederationAckCommand,
+    RemovePusherCommand,
+    UserIpCommand,
+    RemoteServerUpCommand,
+    ClearUserSyncsCommand,
+)  # type: Tuple[Type[Command], ...]
+
 # Map of command name to command type.
-COMMAND_MAP = {
-    cmd.NAME: cmd
-    for cmd in (
-        ServerCommand,
-        RdataCommand,
-        PositionCommand,
-        ErrorCommand,
-        PingCommand,
-        NameCommand,
-        ReplicateCommand,
-        UserSyncCommand,
-        FederationAckCommand,
-        SyncCommand,
-        RemovePusherCommand,
-        InvalidateCacheCommand,
-        UserIpCommand,
-    )
-}
+COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
 
 # The commands the server is allowed to send
 VALID_SERVER_COMMANDS = (
@@ -413,7 +423,7 @@ VALID_SERVER_COMMANDS = (
     PositionCommand.NAME,
     ErrorCommand.NAME,
     PingCommand.NAME,
-    SyncCommand.NAME,
+    RemoteServerUpCommand.NAME,
 )
 
 # The commands the client is allowed to send
@@ -422,9 +432,28 @@ VALID_CLIENT_COMMANDS = (
     ReplicateCommand.NAME,
     PingCommand.NAME,
     UserSyncCommand.NAME,
+    ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
-    InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
+    RemoteServerUpCommand.NAME,
 )
+
+
+def parse_command_from_line(line: str) -> Command:
+    """Parses a command from a received line.
+
+    Line should already be stripped of whitespace and be checked if blank.
+    """
+
+    idx = line.find(" ")
+    if idx >= 0:
+        cmd_name = line[:idx]
+        rest_of_line = line[idx + 1 :]
+    else:
+        cmd_name = line
+        rest_of_line = ""
+
+    cmd_cls = COMMAND_MAP[cmd_name]
+    return cmd_cls.from_line(rest_of_line)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
new file mode 100644
index 0000000000..cbcf46f3ae
--- /dev/null
+++ b/synapse/replication/tcp/handler.py
@@ -0,0 +1,596 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
+
+from prometheus_client import Counter
+
+from twisted.internet.protocol import ReconnectingClientFactory
+
+from synapse.metrics import LaterGauge
+from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
+from synapse.replication.tcp.commands import (
+    ClearUserSyncsCommand,
+    Command,
+    FederationAckCommand,
+    PositionCommand,
+    RdataCommand,
+    RemoteServerUpCommand,
+    RemovePusherCommand,
+    ReplicateCommand,
+    UserIpCommand,
+    UserSyncCommand,
+)
+from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    BackfillStream,
+    CachesStream,
+    EventsStream,
+    FederationStream,
+    Stream,
+)
+from synapse.util.async_helpers import Linearizer
+
+logger = logging.getLogger(__name__)
+
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = Counter(
+    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter(
+    "synapse_replication_tcp_resource_invalidate_cache", ""
+)
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+
+
+class ReplicationCommandHandler:
+    """Handles incoming commands from replication as well as sending commands
+    back out to connections.
+    """
+
+    def __init__(self, hs):
+        self._replication_data_handler = hs.get_replication_data_handler()
+        self._presence_handler = hs.get_presence_handler()
+        self._store = hs.get_datastore()
+        self._notifier = hs.get_notifier()
+        self._clock = hs.get_clock()
+        self._instance_id = hs.get_instance_id()
+        self._instance_name = hs.get_instance_name()
+
+        self._streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
+        # List of streams that this instance is the source of
+        self._streams_to_replicate = []  # type: List[Stream]
+
+        for stream in self._streams.values():
+            if stream.NAME == CachesStream.NAME:
+                # All workers can write to the cache invalidation stream.
+                self._streams_to_replicate.append(stream)
+                continue
+
+            if isinstance(stream, (EventsStream, BackfillStream)):
+                # Only add EventStream and BackfillStream as a source on the
+                # instance in charge of event persistence.
+                if hs.config.worker.writers.events == hs.get_instance_name():
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            # Only add any other streams if we're on master.
+            if hs.config.worker_app is not None:
+                continue
+
+            if stream.NAME == FederationStream.NAME and hs.config.send_federation:
+                # We only support federation stream if federation sending
+                # has been disabled on the master.
+                continue
+
+            self._streams_to_replicate.append(stream)
+
+        self._position_linearizer = Linearizer(
+            "replication_position", clock=self._clock
+        )
+
+        # Map of stream to batched updates. See RdataCommand for info on how
+        # batching works.
+        self._pending_batches = {}  # type: Dict[str, List[Any]]
+
+        # The factory used to create connections.
+        self._factory = None  # type: Optional[ReconnectingClientFactory]
+
+        # The currently connected connections. (The list of places we need to send
+        # outgoing replication commands to.)
+        self._connections = []  # type: List[AbstractConnection]
+
+        # For each connection, the incoming streams that are coming from that connection
+        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+
+        LaterGauge(
+            "synapse_replication_tcp_resource_total_connections",
+            "",
+            [],
+            lambda: len(self._connections),
+        )
+
+        self._is_master = hs.config.worker_app is None
+
+        self._federation_sender = None
+        if self._is_master and not hs.config.send_federation:
+            self._federation_sender = hs.get_federation_sender()
+
+        self._server_notices_sender = None
+        if self._is_master:
+            self._server_notices_sender = hs.get_server_notices_sender()
+
+    def start_replication(self, hs):
+        """Helper method to start a replication connection to the remote server
+        using TCP.
+        """
+        if hs.config.redis.redis_enabled:
+            from synapse.replication.tcp.redis import (
+                RedisDirectTcpReplicationClientFactory,
+            )
+            import txredisapi
+
+            logger.info(
+                "Connecting to redis (host=%r port=%r)",
+                hs.config.redis_host,
+                hs.config.redis_port,
+            )
+
+            # First let's ensure that we have a ReplicationStreamer started.
+            hs.get_replication_streamer()
+
+            # We need two connections to redis, one for the subscription stream and
+            # one to send commands to (as you can't send further redis commands to a
+            # connection after SUBSCRIBE is called).
+
+            # First create the connection for sending commands.
+            outbound_redis_connection = txredisapi.lazyConnection(
+                host=hs.config.redis_host,
+                port=hs.config.redis_port,
+                password=hs.config.redis.redis_password,
+                reconnect=True,
+            )
+
+            # Now create the factory/connection for the subscription stream.
+            self._factory = RedisDirectTcpReplicationClientFactory(
+                hs, outbound_redis_connection
+            )
+            hs.get_reactor().connectTCP(
+                hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
+            )
+        else:
+            client_name = hs.get_instance_name()
+            self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
+            host = hs.config.worker_replication_host
+            port = hs.config.worker_replication_port
+            hs.get_reactor().connectTCP(host, port, self._factory)
+
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a map from stream name to all streams.
+        """
+        return self._streams
+
+    def get_streams_to_replicate(self) -> List[Stream]:
+        """Get a list of streams that this instances replicates.
+        """
+        return self._streams_to_replicate
+
+    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+        self.send_positions_to_connection(conn)
+
+    def send_positions_to_connection(self, conn: AbstractConnection):
+        """Send current position of all streams this process is source of to
+        the connection.
+        """
+
+        # We respond with current position of all streams this instance
+        # replicates.
+        for stream in self.get_streams_to_replicate():
+            self.send_command(
+                PositionCommand(
+                    stream.NAME,
+                    self._instance_name,
+                    stream.current_token(self._instance_name),
+                )
+            )
+
+    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+        user_sync_counter.inc()
+
+        if self._is_master:
+            await self._presence_handler.update_external_syncs_row(
+                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+            )
+
+    async def on_CLEAR_USER_SYNC(
+        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+    ):
+        if self._is_master:
+            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+
+    async def on_FEDERATION_ACK(
+        self, conn: AbstractConnection, cmd: FederationAckCommand
+    ):
+        federation_ack_counter.inc()
+
+        if self._federation_sender:
+            self._federation_sender.federation_ack(cmd.token)
+
+    async def on_REMOVE_PUSHER(
+        self, conn: AbstractConnection, cmd: RemovePusherCommand
+    ):
+        remove_pusher_counter.inc()
+
+        if self._is_master:
+            await self._store.delete_pusher_by_app_id_pushkey_user_id(
+                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+            )
+
+            self._notifier.on_new_replication_data()
+
+    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+        user_ip_cache_counter.inc()
+
+        if self._is_master:
+            await self._store.insert_client_ip(
+                cmd.user_id,
+                cmd.access_token,
+                cmd.ip,
+                cmd.user_agent,
+                cmd.device_id,
+                cmd.last_seen,
+            )
+
+        if self._server_notices_sender:
+            await self._server_notices_sender.on_user_ip(cmd.user_id)
+
+    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+        if cmd.instance_name == self._instance_name:
+            # Ignore RDATA that are just our own echoes
+            return
+
+        stream_name = cmd.stream_name
+        inbound_rdata_count.labels(stream_name).inc()
+
+        try:
+            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+        except Exception:
+            logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
+            raise
+
+        # We linearize here for two reasons:
+        #   1. so we don't try and concurrently handle multiple rows for the
+        #      same stream, and
+        #   2. so we don't race with getting a POSITION command and fetching
+        #      missing RDATA.
+        with await self._position_linearizer.queue(cmd.stream_name):
+            # make sure that we've processed a POSITION for this stream *on this
+            # connection*. (A POSITION on another connection is no good, as there
+            # is no guarantee that we have seen all the intermediate updates.)
+            sbc = self._streams_by_connection.get(conn)
+            if not sbc or stream_name not in sbc:
+                # Let's drop the row for now, on the assumption we'll receive a
+                # `POSITION` soon and we'll catch up correctly then.
+                logger.debug(
+                    "Discarding RDATA for unconnected stream %s -> %s",
+                    stream_name,
+                    cmd.token,
+                )
+                return
+
+            if cmd.token is None:
+                # I.e. this is part of a batch of updates for this stream (in
+                # which case batch until we get an update for the stream with a non
+                # None token).
+                self._pending_batches.setdefault(stream_name, []).append(row)
+            else:
+                # Check if this is the last of a batch of updates
+                rows = self._pending_batches.pop(stream_name, [])
+                rows.append(row)
+                await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        """Called to handle a batch of replication data with a given stream token.
+
+        Args:
+            stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
+        await self._replication_data_handler.on_rdata(
+            stream_name, instance_name, token, rows
+        )
+
+    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+        if cmd.instance_name == self._instance_name:
+            # Ignore POSITION that are just our own echoes
+            return
+
+        logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
+
+        stream_name = cmd.stream_name
+        stream = self._streams.get(stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", stream_name)
+            return
+
+        # We protect catching up with a linearizer in case the replication
+        # connection reconnects under us.
+        with await self._position_linearizer.queue(stream_name):
+            # We're about to go and catch up with the stream, so remove from set
+            # of connected streams.
+            for streams in self._streams_by_connection.values():
+                streams.discard(stream_name)
+
+            # We clear the pending batches for the stream as the fetching of the
+            # missing updates below will fetch all rows in the batch.
+            self._pending_batches.pop(stream_name, [])
+
+            # Find where we previously streamed up to.
+            current_token = stream.current_token(cmd.instance_name)
+
+            # If the position token matches our current token then we're up to
+            # date and there's nothing to do. Otherwise, fetch all updates
+            # between then and now.
+            missing_updates = cmd.token != current_token
+            while missing_updates:
+                logger.info(
+                    "Fetching replication rows for '%s' between %i and %i",
+                    stream_name,
+                    current_token,
+                    cmd.token,
+                )
+                (
+                    updates,
+                    current_token,
+                    missing_updates,
+                ) = await stream.get_updates_since(
+                    cmd.instance_name, current_token, cmd.token
+                )
+
+                # TODO: add some tests for this
+
+                # Some streams return multiple rows with the same stream IDs,
+                # which need to be processed in batches.
+
+                for token, rows in _batch_updates(updates):
+                    await self.on_rdata(
+                        stream_name,
+                        cmd.instance_name,
+                        token,
+                        [stream.parse_row(row) for row in rows],
+                    )
+
+            logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+            # We've now caught up to position sent to us, notify handler.
+            await self._replication_data_handler.on_position(
+                cmd.stream_name, cmd.instance_name, cmd.token
+            )
+
+            self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+
+    async def on_REMOTE_SERVER_UP(
+        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+    ):
+        """"Called when get a new REMOTE_SERVER_UP command."""
+        self._replication_data_handler.on_remote_server_up(cmd.data)
+
+        self._notifier.notify_remote_server_up(cmd.data)
+
+        # We relay to all other connections to ensure every instance gets the
+        # notification.
+        #
+        # When configured to use redis we'll always only have one connection and
+        # so this is a no-op (all instances will have already received the same
+        # REMOTE_SERVER_UP command).
+        #
+        # For direct TCP connections this will relay to all other connections
+        # connected to us. When on master this will correctly fan out to all
+        # other direct TCP clients and on workers there'll only be the one
+        # connection to master.
+        #
+        # (The logic here should also be sound if we have a mix of Redis and
+        # direct TCP connections so long as there is only one traffic route
+        # between two instances, but that is not currently supported).
+        self.send_command(cmd, ignore_conn=conn)
+
+    def new_connection(self, connection: AbstractConnection):
+        """Called when we have a new connection.
+        """
+        self._connections.append(connection)
+
+        # If we are connected to replication as a client (rather than a server)
+        # we need to reset the reconnection delay on the client factory (which
+        # is used to do exponential back off when the connection drops).
+        #
+        # Ideally we would reset the delay when we've "fully established" the
+        # connection (for some definition thereof) to stop us from tightlooping
+        # on reconnection if something fails after this point and we drop the
+        # connection. Unfortunately, we don't really have a better definition of
+        # "fully established" than the connection being established.
+        if self._factory:
+            self._factory.resetDelay()
+
+        # Tell the other end if we have any users currently syncing.
+        currently_syncing = (
+            self._presence_handler.get_currently_syncing_users_for_replication()
+        )
+
+        now = self._clock.time_msec()
+        for user_id in currently_syncing:
+            connection.send_command(
+                UserSyncCommand(self._instance_id, user_id, True, now)
+            )
+
+    def lost_connection(self, connection: AbstractConnection):
+        """Called when a connection is closed/lost.
+        """
+        # we no longer need _streams_by_connection for this connection.
+        streams = self._streams_by_connection.pop(connection, None)
+        if streams:
+            logger.info(
+                "Lost replication connection; streams now disconnected: %s", streams
+            )
+        try:
+            self._connections.remove(connection)
+        except ValueError:
+            pass
+
+    def connected(self) -> bool:
+        """Do we have any replication connections open?
+
+        Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
+        """
+        return bool(self._connections)
+
+    def send_command(
+        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+    ):
+        """Send a command to all connected connections.
+
+        Args:
+            cmd
+            ignore_conn: If set don't send command to the given connection.
+                Used when relaying commands from one connection to all others.
+        """
+        if self._connections:
+            for connection in self._connections:
+                if connection == ignore_conn:
+                    continue
+
+                try:
+                    connection.send_command(cmd)
+                except Exception:
+                    # We probably want to catch some types of exceptions here
+                    # and log them as warnings (e.g. connection gone), but I
+                    # can't find what those exception types they would be.
+                    logger.exception(
+                        "Failed to write command %s to connection %s",
+                        cmd.NAME,
+                        connection,
+                    )
+        else:
+            logger.warning("Dropping command as not connected: %r", cmd.NAME)
+
+    def send_federation_ack(self, token: int):
+        """Ack data for the federation stream. This allows the master to drop
+        data stored purely in memory.
+        """
+        self.send_command(FederationAckCommand(token))
+
+    def send_user_sync(
+        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+    ):
+        """Poke the master that a user has started/stopped syncing.
+        """
+        self.send_command(
+            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+        )
+
+    def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
+        """Poke the master to remove a pusher for a user
+        """
+        cmd = RemovePusherCommand(app_id, push_key, user_id)
+        self.send_command(cmd)
+
+    def send_user_ip(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: str,
+        last_seen: int,
+    ):
+        """Tell the master that the user made a request.
+        """
+        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+        self.send_command(cmd)
+
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
+    def stream_update(self, stream_name: str, token: str, data: Any):
+        """Called when a new update is available to stream to clients.
+
+        We need to check if the client is interested in the stream or not
+        """
+        self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
+
+
+UpdateToken = TypeVar("UpdateToken")
+UpdateRow = TypeVar("UpdateRow")
+
+
+def _batch_updates(
+    updates: Iterable[Tuple[UpdateToken, UpdateRow]]
+) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
+    """Collect stream updates with the same token together
+
+    Given a series of updates returned by Stream.get_updates_since(), collects
+    the updates which share the same stream_id together.
+
+    For example:
+
+        [(1, a), (1, b), (2, c), (3, d), (3, e)]
+
+    becomes:
+
+        [
+            (1, [a, b]),
+            (2, [c]),
+            (3, [d, e]),
+        ]
+    """
+
+    update_iter = iter(updates)
+
+    first_update = next(update_iter, None)
+    if first_update is None:
+        # empty input
+        return
+
+    current_batch_token = first_update[0]
+    current_batch = [first_update[1]]
+
+    for token, row in update_iter:
+        if token != current_batch_token:
+            # different token to the previous row: flush the previous
+            # batch and start anew
+            yield current_batch_token, current_batch
+            current_batch_token = token
+            current_batch = []
+
+        current_batch.append(row)
+
+    # flush the final batch
+    yield current_batch_token, current_batch
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 5ffdf2675d..4198eece71 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
     > POSITION events 1
     > POSITION backfill 1
     > POSITION caches 1
@@ -48,45 +46,55 @@ indicate which side is sending, these are *not* included on the wire::
     > ERROR server stopping
     * connection closed by server *
 """
-
+import abc
 import fcntl
 import logging
 import struct
-from collections import defaultdict
-
-from six import iteritems, iterkeys
+from typing import TYPE_CHECKING, List
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.stringutils import random_string
-
-from .commands import (
-    COMMAND_MAP,
+from synapse.replication.tcp.commands import (
     VALID_CLIENT_COMMANDS,
     VALID_SERVER_COMMANDS,
+    Command,
     ErrorCommand,
     NameCommand,
     PingCommand,
-    PositionCommand,
-    RdataCommand,
     ReplicateCommand,
     ServerCommand,
-    SyncCommand,
-    UserSyncCommand,
+    parse_command_from_line,
 )
-from .streams import STREAMS_MAP
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
+    from synapse.server import HomeServer
+
 
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
 
+tcp_inbound_commands_counter = Counter(
+    "synapse_replication_tcp_protocol_inbound_commands",
+    "Number of commands received from replication, by command and name of process connected to",
+    ["command", "name"],
+)
+
+tcp_outbound_commands_counter = Counter(
+    "synapse_replication_tcp_protocol_outbound_commands",
+    "Number of commands sent to replication, by command and name of process connected to",
+    ["command", "name"],
+)
+
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
 connected_connections = []
@@ -115,7 +123,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     are only sent by the server.
 
     On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
-    command.
+    command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
 
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
@@ -123,13 +131,17 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     delimiter = b"\n"
 
-    VALID_INBOUND_COMMANDS = []  # Valid commands we expect to receive
-    VALID_OUTBOUND_COMMANDS = []  # Valid commans we can send
+    # Valid commands we expect to receive
+    VALID_INBOUND_COMMANDS = []  # type: Collection[str]
+
+    # Valid commands we can send
+    VALID_OUTBOUND_COMMANDS = []  # type: Collection[str]
 
     max_line_buffer = 10000
 
-    def __init__(self, clock):
+    def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
         self.clock = clock
+        self.command_handler = handler
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
@@ -143,14 +155,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.conn_id = random_string(5)  # To dedupe in case of name clashes.
 
         # List of pending commands to send once we've established the connection
-        self.pending_commands = []
+        self.pending_commands = []  # type: List[Command]
 
         # The LoopingCall for sending pings.
         self._send_ping_loop = None
 
-        self.inbound_commands_counter = defaultdict(int)
-        self.outbound_commands_counter = defaultdict(int)
-
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
 
@@ -169,6 +178,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # can time us out.
         self.send_command(PingCommand(self.clock.time_msec()))
 
+        self.command_handler.new_connection(self)
+
     def send_ping(self):
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
@@ -196,60 +207,67 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 )
                 self.send_error("ping timeout")
 
-    def lineReceived(self, line):
+    def lineReceived(self, line: bytes):
         """Called when we've received a line
         """
         if line.strip() == "":
             # Ignore blank lines
             return
 
-        line = line.decode("utf-8")
-        cmd_name, rest_of_line = line.split(" ", 1)
+        linestr = line.decode("utf-8")
 
-        if cmd_name not in self.VALID_INBOUND_COMMANDS:
-            logger.error("[%s] invalid command %s", self.id(), cmd_name)
-            self.send_error("invalid command: %s", cmd_name)
+        try:
+            cmd = parse_command_from_line(linestr)
+        except Exception as e:
+            logger.exception("[%s] failed to parse line: %r", self.id(), linestr)
+            self.send_error("failed to parse line: %r (%r):" % (e, linestr))
             return
 
-        self.last_received_command = self.clock.time_msec()
+        if cmd.NAME not in self.VALID_INBOUND_COMMANDS:
+            logger.error("[%s] invalid command %s", self.id(), cmd.NAME)
+            self.send_error("invalid command: %s", cmd.NAME)
+            return
 
-        self.inbound_commands_counter[cmd_name] = (
-            self.inbound_commands_counter[cmd_name] + 1
-        )
+        self.last_received_command = self.clock.time_msec()
 
-        cmd_cls = COMMAND_MAP[cmd_name]
-        try:
-            cmd = cmd_cls.from_line(rest_of_line)
-        except Exception as e:
-            logger.exception(
-                "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
-            )
-            self.send_error(
-                "failed to parse line for  %r: %r (%r):" % (cmd_name, e, rest_of_line)
-            )
-            return
+        tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
 
         # Now lets try and call on_<CMD_NAME> function
         run_as_background_process(
             "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
         )
 
-    def handle_command(self, cmd):
+    async def handle_command(self, cmd: Command):
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>
+        First calls `self.on_<COMMAND>` if it exists, then calls
+        `self.command_handler.on_<COMMAND>` if it exists. This allows for
+        protocol level handling of commands (e.g. PINGs), before delegating to
+        the handler.
 
         Args:
-            cmd (synapse.replication.tcp.commands.Command): received command
-
-        Returns:
-            Deferred
+            cmd: received command
         """
-        handler = getattr(self, "on_%s" % (cmd.NAME,))
-        return handler(cmd)
+        handled = False
+
+        # First call any command handlers on this instance. These are for TCP
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(self, cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
 
     def close(self):
-        logger.warn("[%s] Closing connection", self.id())
+        logger.warning("[%s] Closing connection", self.id())
         self.time_we_closed = self.clock.time_msec()
         self.transport.loseConnection()
         self.on_connection_closed()
@@ -278,9 +296,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             self._queue_command(cmd)
             return
 
-        self.outbound_commands_counter[cmd.NAME] = (
-            self.outbound_commands_counter[cmd.NAME] + 1
-        )
+        tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc()
+
         string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
             raise Exception("Unexpected newline in command: %r", string)
@@ -319,10 +336,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         for cmd in pending:
             self.send_command(cmd)
 
-    def on_PING(self, line):
+    async def on_PING(self, line):
         self.received_ping = True
 
-    def on_ERROR(self, cmd):
+    async def on_ERROR(self, cmd):
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
 
     def pauseProducing(self):
@@ -375,6 +392,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.CLOSED
         self.pending_commands = []
 
+        self.command_handler.lost_connection(self)
+
         if self.transport:
             self.transport.unregisterProducer()
 
@@ -401,264 +420,73 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
     VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
 
-    def __init__(self, server_name, clock, streamer):
-        BaseReplicationStreamProtocol.__init__(self, clock)  # Old style class
+    def __init__(
+        self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
+    ):
+        super().__init__(clock, handler)
 
         self.server_name = server_name
-        self.streamer = streamer
-
-        # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()
-
-        # The streams the client is currently subscribing to.
-        self.connecting_streams = set()
-
-        # Map from stream name to list of updates to send once we've finished
-        # subscribing the client to the stream.
-        self.pending_rdata = {}
 
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
-        BaseReplicationStreamProtocol.connectionMade(self)
-        self.streamer.new_connection(self)
+        super().connectionMade()
 
-    def on_NAME(self, cmd):
+    async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
-    def on_USER_SYNC(self, cmd):
-        return self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
-        )
-
-    def on_REPLICATE(self, cmd):
-        stream_name = cmd.stream_name
-        token = cmd.token
-
-        if stream_name == "ALL":
-            # Subscribe to all streams we're publishing to.
-            deferreds = [
-                run_in_background(self.subscribe_to_stream, stream, token)
-                for stream in iterkeys(self.streamer.streams_by_name)
-            ]
-
-            return make_deferred_yieldable(
-                defer.gatherResults(deferreds, consumeErrors=True)
-            )
-        else:
-            return self.subscribe_to_stream(stream_name, token)
-
-    def on_FEDERATION_ACK(self, cmd):
-        return self.streamer.federation_ack(cmd.token)
-
-    def on_REMOVE_PUSHER(self, cmd):
-        return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
-    def on_INVALIDATE_CACHE(self, cmd):
-        return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
-    def on_USER_IP(self, cmd):
-        return self.streamer.on_user_ip(
-            cmd.user_id,
-            cmd.access_token,
-            cmd.ip,
-            cmd.user_agent,
-            cmd.device_id,
-            cmd.last_seen,
-        )
-
-    @defer.inlineCallbacks
-    def subscribe_to_stream(self, stream_name, token):
-        """Subscribe the remote to a stream.
-
-        This invloves checking if they've missed anything and sending those
-        updates down if they have. During that time new updates for the stream
-        are queued and sent once we've sent down any missed updates.
-        """
-        self.replication_streams.discard(stream_name)
-        self.connecting_streams.add(stream_name)
-
-        try:
-            # Get missing updates
-            updates, current_token = yield self.streamer.get_stream_updates(
-                stream_name, token
-            )
-
-            # Send all the missing updates
-            for update in updates:
-                token, row = update[0], update[1]
-                self.send_command(RdataCommand(stream_name, token, row))
-
-            # We send a POSITION command to ensure that they have an up to
-            # date token (especially useful if we didn't send any updates
-            # above)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-            # Now we can send any updates that came in while we were subscribing
-            pending_rdata = self.pending_rdata.pop(stream_name, [])
-            updates = []
-            for token, update in pending_rdata:
-                # If the token is null, it is part of a batch update. Batches
-                # are multiple updates that share a single token. To denote
-                # this, the token is set to None for all tokens in the batch
-                # except for the last. If we find a None token, we keep looking
-                # through tokens until we find one that is not None and then
-                # process all previous updates in the batch as if they had the
-                # final token.
-                if token is None:
-                    # Store this update as part of a batch
-                    updates.append(update)
-                    continue
-
-                if token <= current_token:
-                    # This update or batch of updates is older than
-                    # current_token, dismiss it
-                    updates = []
-                    continue
-
-                updates.append(update)
-
-                # Send all updates that are part of this batch with the
-                # found token
-                for update in updates:
-                    self.send_command(RdataCommand(stream_name, token, update))
-
-                # Clear stored updates
-                updates = []
-
-            # They're now fully subscribed
-            self.replication_streams.add(stream_name)
-        except Exception as e:
-            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-            self.send_error("failed to handle replicate: %r", e)
-        finally:
-            self.connecting_streams.discard(stream_name)
-
-    def stream_update(self, stream_name, token, data):
-        """Called when a new update is available to stream to clients.
-
-        We need to check if the client is interested in the stream or not
-        """
-        if stream_name in self.replication_streams:
-            # The client is subscribed to the stream
-            self.send_command(RdataCommand(stream_name, token, data))
-        elif stream_name in self.connecting_streams:
-            # The client is being subscribed to the stream
-            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-            self.pending_rdata.setdefault(stream_name, []).append((token, data))
-        else:
-            # The client isn't subscribed
-            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
-
-    def send_sync(self, data):
-        self.send_command(SyncCommand(data))
-
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.streamer.lost_connection(self)
-
 
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
     VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
 
-    def __init__(self, client_name, server_name, clock, handler):
-        BaseReplicationStreamProtocol.__init__(self, clock)
+    def __init__(
+        self,
+        hs: "HomeServer",
+        client_name: str,
+        server_name: str,
+        clock: Clock,
+        command_handler: "ReplicationCommandHandler",
+    ):
+        super().__init__(clock, command_handler)
 
         self.client_name = client_name
         self.server_name = server_name
-        self.handler = handler
-
-        # Set of stream names that have been subscribe to, but haven't yet
-        # caught up with. This is used to track when the client has been fully
-        # connected to the remote.
-        self.streams_connecting = set()
-
-        # Map of stream to batched updates. See RdataCommand for info on how
-        # batching works.
-        self.pending_batches = {}
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
-        BaseReplicationStreamProtocol.connectionMade(self)
+        super().connectionMade()
 
         # Once we've connected subscribe to the necessary streams
-        for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
-            self.replicate(stream_name, token)
-
-        # Tell the server if we have any users currently syncing (should only
-        # happen on synchrotrons)
-        currently_syncing = self.handler.get_currently_syncing_users()
-        now = self.clock.time_msec()
-        for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(user_id, True, now))
-
-        # We've now finished connecting to so inform the client handler
-        self.handler.update_connection(self)
+        self.replicate()
 
-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
-    def on_SERVER(self, cmd):
+    async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    def on_RDATA(self, cmd):
-        stream_name = cmd.stream_name
-        inbound_rdata_count.labels(stream_name).inc()
-
-        try:
-            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
-        except Exception:
-            logger.exception(
-                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
-            )
-            raise
-
-        if cmd.token is None:
-            # I.e. this is part of a batch of updates for this stream. Batch
-            # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(stream_name, [])
-            rows.append(row)
-            return self.handler.on_rdata(stream_name, cmd.token, rows)
+    def replicate(self):
+        """Send the subscription request to the server
+        """
+        logger.info("[%s] Subscribing to replication streams", self.id())
 
-    def on_POSITION(self, cmd):
-        # When we get a `POSITION` command it means we've finished getting
-        # missing updates for the given stream, and are now up to date.
-        self.streams_connecting.discard(cmd.stream_name)
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
+        self.send_command(ReplicateCommand())
 
-        return self.handler.on_position(cmd.stream_name, cmd.token)
 
-    def on_SYNC(self, cmd):
-        return self.handler.on_sync(cmd.data)
+class AbstractConnection(abc.ABC):
+    """An interface for replication connections.
+    """
 
-    def replicate(self, stream_name, token):
-        """Send the subscription request to the server
+    @abc.abstractmethod
+    def send_command(self, cmd: Command):
+        """Send the command down the connection
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r from %r",
-            self.id(),
-            stream_name,
-            token,
-        )
-
-        self.streams_connecting.add(stream_name)
+        pass
 
-        self.send_command(ReplicateCommand(stream_name, token))
 
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.handler.update_connection(None)
+# This tells python that `BaseReplicationStreamProtocol` implements the
+# interface.
+AbstractConnection.register(BaseReplicationStreamProtocol)
 
 
 # The following simply registers metrics for the replication connections
@@ -696,7 +524,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
             op = SIOCINQ
         else:
             op = SIOCOUTQ
-        size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
+        size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
         return size
     return 0
 
@@ -721,31 +549,3 @@ tcp_transport_kernel_read_buffer = LaterGauge(
         for p in connected_connections
     },
 )
-
-
-tcp_inbound_commands = LaterGauge(
-    "synapse_replication_tcp_protocol_inbound_commands",
-    "",
-    ["command", "name"],
-    lambda: {
-        (k, p.name): count
-        for p in connected_connections
-        for k, count in iteritems(p.inbound_commands_counter)
-    },
-)
-
-tcp_outbound_commands = LaterGauge(
-    "synapse_replication_tcp_protocol_outbound_commands",
-    "",
-    ["command", "name"],
-    lambda: {
-        (k, p.name): count
-        for p in connected_connections
-        for k, count in iteritems(p.outbound_commands_counter)
-    },
-)
-
-# number of updates received for each RDATA stream
-inbound_rdata_count = Counter(
-    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
-)
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
new file mode 100644
index 0000000000..e776b63183
--- /dev/null
+++ b/synapse/replication/tcp/redis.py
@@ -0,0 +1,215 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+import txredisapi
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import (
+    Command,
+    ReplicateCommand,
+    parse_command_from_line,
+)
+from synapse.replication.tcp.protocol import (
+    AbstractConnection,
+    tcp_inbound_commands_counter,
+    tcp_outbound_commands_counter,
+)
+
+if TYPE_CHECKING:
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+    """Connection to redis subscribed to replication stream.
+
+    This class fulfils two functions:
+
+    (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
+    connection, parsing *incoming* messages into replication commands, and passing them
+    to `ReplicationCommandHandler`
+
+    (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+    onto outbound_redis_connection.
+
+    Due to the vagaries of `txredisapi` we don't want to have a custom
+    constructor, so instead we expect the defined attributes below to be set
+    immediately after initialisation.
+
+    Attributes:
+        handler: The command handler to handle incoming commands.
+        stream_name: The *redis* stream name to subscribe to and publish from
+            (not anything to do with Synapse replication streams).
+        outbound_redis_connection: The connection to redis to use to send
+            commands.
+    """
+
+    handler = None  # type: ReplicationCommandHandler
+    stream_name = None  # type: str
+    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+
+    def connectionMade(self):
+        logger.info("Connected to redis")
+        super().connectionMade()
+        run_as_background_process("subscribe-replication", self._send_subscribe)
+
+    async def _send_subscribe(self):
+        # it's important to make sure that we only send the REPLICATE command once we
+        # have successfully subscribed to the stream - otherwise we might miss the
+        # POSITION response sent back by the other end.
+        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
+        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info(
+            "Successfully subscribed to redis stream, sending REPLICATE command"
+        )
+        self.handler.new_connection(self)
+        await self._async_send_command(ReplicateCommand())
+        logger.info("REPLICATE successfully sent")
+
+        # We send out our positions when there is a new connection in case the
+        # other side missed updates. We do this for Redis connections as the
+        # otherside won't know we've connected and so won't issue a REPLICATE.
+        self.handler.send_positions_to_connection(self)
+
+    def messageReceived(self, pattern: str, channel: str, message: str):
+        """Received a message from redis.
+        """
+
+        if message.strip() == "":
+            # Ignore blank lines
+            return
+
+        try:
+            cmd = parse_command_from_line(message)
+        except Exception:
+            logger.exception(
+                "Failed to parse replication line: %r", message,
+            )
+            return
+
+        # We use "redis" as the name here as we don't have 1:1 connections to
+        # remote instances.
+        tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
+
+        # Now lets try and call on_<CMD_NAME> function
+        run_as_background_process(
+            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
+        )
+
+    async def handle_command(self, cmd: Command):
+        """Handle a command we have received over the replication stream.
+
+        By default delegates to on_<COMMAND>, which should return an awaitable.
+
+        Args:
+            cmd: received command
+        """
+        handled = False
+
+        # First call any command handlers on this instance. These are for redis
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(self, cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
+
+    def connectionLost(self, reason):
+        logger.info("Lost connection to redis")
+        super().connectionLost(reason)
+        self.handler.lost_connection(self)
+
+    def send_command(self, cmd: Command):
+        """Send a command if connection has been established.
+
+        Args:
+            cmd (Command)
+        """
+        run_as_background_process("send-cmd", self._async_send_command, cmd)
+
+    async def _async_send_command(self, cmd: Command):
+        """Encode a replication command and send it over our outbound connection"""
+        string = "%s %s" % (cmd.NAME, cmd.to_line())
+        if "\n" in string:
+            raise Exception("Unexpected newline in command: %r", string)
+
+        encoded_string = string.encode("utf-8")
+
+        # We use "redis" as the name here as we don't have 1:1 connections to
+        # remote instances.
+        tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
+
+        await make_deferred_yieldable(
+            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+        )
+
+
+class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+    """This is a reconnecting factory that connects to redis and immediately
+    subscribes to a stream.
+
+    Args:
+        hs
+        outbound_redis_connection: A connection to redis that will be used to
+            send outbound commands (this is seperate to the redis connection
+            used to subscribe).
+    """
+
+    maxDelay = 5
+    continueTrying = True
+    protocol = RedisSubscriber
+
+    def __init__(
+        self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
+    ):
+
+        super().__init__()
+
+        # This sets the password on the RedisFactory base class (as
+        # SubscriberFactory constructor doesn't pass it through).
+        self.password = hs.config.redis.redis_password
+
+        self.handler = hs.get_tcp_replication()
+        self.stream_name = hs.hostname
+
+        self.outbound_redis_connection = outbound_redis_connection
+
+    def buildProtocol(self, addr):
+        p = super().buildProtocol(addr)  # type: RedisSubscriber
+
+        # We do this here rather than add to the constructor of `RedisSubcriber`
+        # as to do so would involve overriding `buildProtocol` entirely, however
+        # the base method does some other things than just instantiating the
+        # protocol.
+        p.handler = self.handler
+        p.outbound_redis_connection = self.outbound_redis_connection
+        p.stream_name = self.stream_name
+        p.password = self.password
+
+        return p
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index d1e98428bc..41569305df 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -18,31 +18,17 @@
 import logging
 import random
 
-from six import itervalues
-
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.internet.protocol import Factory
 
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.metrics import Measure, measure_func
-
-from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
-from .streams.federation import FederationStream
+from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
     "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
 )
-user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
-federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
-remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter(
-    "synapse_replication_tcp_resource_invalidate_cache", ""
-)
-user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 logger = logging.getLogger(__name__)
 
@@ -52,13 +38,23 @@ class ReplicationStreamProtocolFactory(Factory):
     """
 
     def __init__(self, hs):
-        self.streamer = ReplicationStreamer(hs)
+        self.command_handler = hs.get_tcp_replication()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name
 
+        # If we've created a `ReplicationStreamProtocolFactory` then we're
+        # almost certainly registering a replication listener, so let's ensure
+        # that we've started a `ReplicationStreamer` instance to actually push
+        # data.
+        #
+        # (This is a bit of a weird place to do this, but the alternatives such
+        # as putting this in `HomeServer.setup()`, requires either passing the
+        # listener config again or always starting a `ReplicationStreamer`.)
+        hs.get_replication_streamer()
+
     def buildProtocol(self, addr):
         return ServerReplicationStreamProtocol(
-            self.server_name, self.clock, self.streamer
+            self.server_name, self.clock, self.command_handler
         )
 
 
@@ -71,66 +67,22 @@ class ReplicationStreamer(object):
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        self.presence_handler = hs.get_presence_handler()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
-        self._server_notices_sender = hs.get_server_notices_sender()
+        self._instance_name = hs.get_instance_name()
 
         self._replication_torture_level = hs.config.replication_torture_level
 
-        # Current connections.
-        self.connections = []
-
-        LaterGauge(
-            "synapse_replication_tcp_resource_total_connections",
-            "",
-            [],
-            lambda: len(self.connections),
-        )
-
-        # List of streams that clients can subscribe to.
-        # We only support federation stream if federation sending hase been
-        # disabled on the master.
-        self.streams = [
-            stream(hs)
-            for stream in itervalues(STREAMS_MAP)
-            if stream != FederationStream or not hs.config.send_federation
-        ]
-
-        self.streams_by_name = {stream.NAME: stream for stream in self.streams}
-
-        LaterGauge(
-            "synapse_replication_tcp_resource_connections_per_stream",
-            "",
-            ["stream_name"],
-            lambda: {
-                (stream_name,): len(
-                    [
-                        conn
-                        for conn in self.connections
-                        if stream_name in conn.replication_streams
-                    ]
-                )
-                for stream_name in self.streams_by_name
-            },
-        )
-
-        self.federation_sender = None
-        if not hs.config.send_federation:
-            self.federation_sender = hs.get_federation_sender()
-
         self.notifier.add_replication_callback(self.on_notifier_poke)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
         self.pending_updates = False
 
-        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+        self.command_handler = hs.get_tcp_replication()
 
-    def on_shutdown(self):
-        # close all connections on shutdown
-        for conn in self.connections:
-            conn.send_error("server shutting down")
+        # Set of streams to replicate.
+        self.streams = self.command_handler.get_streams_to_replicate()
 
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
@@ -139,7 +91,7 @@ class ReplicationStreamer(object):
         This should get called each time new data is available, even if it
         is currently being executed, so that nothing gets missed
         """
-        if not self.connections:
+        if not self.command_handler.connected():
             # Don't bother if nothing is listening. We still need to advance
             # the stream tokens otherwise they'll fall beihind forever
             for stream in self.streams:
@@ -154,8 +106,7 @@ class ReplicationStreamer(object):
 
         run_as_background_process("replication_notifier", self._run_notifier_loop)
 
-    @defer.inlineCallbacks
-    def _run_notifier_loop(self):
+    async def _run_notifier_loop(self):
         self.is_looping = True
 
         try:
@@ -166,11 +117,6 @@ class ReplicationStreamer(object):
                 self.pending_updates = False
 
                 with Measure(self.clock, "repl.stream.get_updates"):
-                    # First we tell the streams that they should update their
-                    # current tokens.
-                    for stream in self.streams:
-                        stream.advance_current_token()
-
                     all_streams = self.streams
 
                     if self._replication_torture_level is not None:
@@ -180,11 +126,13 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.upto_token:
+                        if stream.last_token == stream.current_token(
+                            self._instance_name
+                        ):
                             continue
 
                         if self._replication_torture_level:
-                            yield self.clock.sleep(
+                            await self.clock.sleep(
                                 self._replication_torture_level / 1000.0
                             )
 
@@ -192,18 +140,17 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.upto_token,
+                            stream.current_token(self._instance_name),
                         )
                         try:
-                            updates, current_token = yield stream.get_updates()
+                            updates, current_token, limited = await stream.get_updates()
+                            self.pending_updates |= limited
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise
 
                         logger.debug(
-                            "Sending %d updates to %d connections",
-                            len(updates),
-                            len(self.connections),
+                            "Sending %d updates", len(updates),
                         )
 
                         if updates:
@@ -219,102 +166,19 @@ class ReplicationStreamer(object):
                         # token. See RdataCommand for more details.
                         batched_updates = _batch_updates(updates)
 
-                        for conn in self.connections:
-                            for token, row in batched_updates:
-                                try:
-                                    conn.stream_update(stream.NAME, token, row)
-                                except Exception:
-                                    logger.exception("Failed to replicate")
+                        for token, row in batched_updates:
+                            try:
+                                self.command_handler.stream_update(
+                                    stream.NAME, token, row
+                                )
+                            except Exception:
+                                logger.exception("Failed to replicate")
 
             logger.debug("No more pending updates, breaking poke loop")
         finally:
             self.pending_updates = False
             self.is_looping = False
 
-    @measure_func("repl.get_stream_updates")
-    def get_stream_updates(self, stream_name, token):
-        """For a given stream get all updates since token. This is called when
-        a client first subscribes to a stream.
-        """
-        stream = self.streams_by_name.get(stream_name, None)
-        if not stream:
-            raise Exception("unknown stream %s", stream_name)
-
-        return stream.get_updates_since(token)
-
-    @measure_func("repl.federation_ack")
-    def federation_ack(self, token):
-        """We've received an ack for federation stream from a client.
-        """
-        federation_ack_counter.inc()
-        if self.federation_sender:
-            self.federation_sender.federation_ack(token)
-
-    @measure_func("repl.on_user_sync")
-    @defer.inlineCallbacks
-    def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
-        """A client has started/stopped syncing on a worker.
-        """
-        user_sync_counter.inc()
-        yield self.presence_handler.update_external_syncs_row(
-            conn_id, user_id, is_syncing, last_sync_ms
-        )
-
-    @measure_func("repl.on_remove_pusher")
-    @defer.inlineCallbacks
-    def on_remove_pusher(self, app_id, push_key, user_id):
-        """A client has asked us to remove a pusher
-        """
-        remove_pusher_counter.inc()
-        yield self.store.delete_pusher_by_app_id_pushkey_user_id(
-            app_id=app_id, pushkey=push_key, user_id=user_id
-        )
-
-        self.notifier.on_new_replication_data()
-
-    @measure_func("repl.on_invalidate_cache")
-    def on_invalidate_cache(self, cache_func, keys):
-        """The client has asked us to invalidate a cache
-        """
-        invalidate_cache_counter.inc()
-        getattr(self.store, cache_func).invalidate(tuple(keys))
-
-    @measure_func("repl.on_user_ip")
-    @defer.inlineCallbacks
-    def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
-        """The client saw a user request
-        """
-        user_ip_cache_counter.inc()
-        yield self.store.insert_client_ip(
-            user_id, access_token, ip, user_agent, device_id, last_seen
-        )
-        yield self._server_notices_sender.on_user_ip(user_id)
-
-    def send_sync_to_all_connections(self, data):
-        """Sends a SYNC command to all clients.
-
-        Used in tests.
-        """
-        for conn in self.connections:
-            conn.send_sync(data)
-
-    def new_connection(self, connection):
-        """A new client connection has been established
-        """
-        self.connections.append(connection)
-
-    def lost_connection(self, connection):
-        """A client connection has been lost
-        """
-        try:
-            self.connections.remove(connection)
-        except ValueError:
-            pass
-
-        # We need to tell the presence handler that the connection has been
-        # lost so that it can handle any ongoing syncs on that connection.
-        self.presence_handler.update_external_syncs_clear(connection.conn_id)
-
 
 def _batch_updates(updates):
     """Takes a list of updates of form [(token, row)] and sets the token to
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc9..d1a61c3314 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -25,25 +25,63 @@ Each stream is defined by the following information:
     update_function:    The function that returns a list of updates between two tokens
 """
 
-from . import _base, events, federation
+from synapse.replication.tcp.streams._base import (
+    AccountDataStream,
+    BackfillStream,
+    CachesStream,
+    DeviceListsStream,
+    GroupServerStream,
+    PresenceStream,
+    PublicRoomsStream,
+    PushersStream,
+    PushRulesStream,
+    ReceiptsStream,
+    Stream,
+    TagAccountDataStream,
+    ToDeviceStream,
+    TypingStream,
+    UserSignatureStream,
+)
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.federation import FederationStream
 
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
-        events.EventsStream,
-        _base.BackfillStream,
-        _base.PresenceStream,
-        _base.TypingStream,
-        _base.ReceiptsStream,
-        _base.PushRulesStream,
-        _base.PushersStream,
-        _base.CachesStream,
-        _base.PublicRoomsStream,
-        _base.DeviceListsStream,
-        _base.ToDeviceStream,
-        federation.FederationStream,
-        _base.TagAccountDataStream,
-        _base.AccountDataStream,
-        _base.GroupServerStream,
+        EventsStream,
+        BackfillStream,
+        PresenceStream,
+        TypingStream,
+        ReceiptsStream,
+        PushRulesStream,
+        PushersStream,
+        CachesStream,
+        PublicRoomsStream,
+        DeviceListsStream,
+        ToDeviceStream,
+        FederationStream,
+        TagAccountDataStream,
+        AccountDataStream,
+        GroupServerStream,
+        UserSignatureStream,
     )
 }
+
+__all__ = [
+    "STREAMS_MAP",
+    "Stream",
+    "BackfillStream",
+    "PresenceStream",
+    "TypingStream",
+    "ReceiptsStream",
+    "PushRulesStream",
+    "PushersStream",
+    "CachesStream",
+    "PublicRoomsStream",
+    "DeviceListsStream",
+    "ToDeviceStream",
+    "TagAccountDataStream",
+    "AccountDataStream",
+    "GroupServerStream",
+    "UserSignatureStream",
+]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index f03111c259..4acefc8a96 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,102 +14,84 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import itertools
+import heapq
 import logging
 from collections import namedtuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
+
+import attr
+
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
-from twisted.internet import defer
+if TYPE_CHECKING:
+    import synapse.server
 
 logger = logging.getLogger(__name__)
 
+# the number of rows to request from an update_function.
+_STREAM_UPDATE_TARGET_ROW_COUNT = 100
 
-MAX_EVENTS_BEHIND = 10000
 
-BackfillStreamRow = namedtuple(
-    "BackfillStreamRow",
-    (
-        "event_id",  # str
-        "room_id",  # str
-        "type",  # str
-        "state_key",  # str, optional
-        "redacts",  # str, optional
-        "relates_to",  # str, optional
-    ),
-)
-PresenceStreamRow = namedtuple(
-    "PresenceStreamRow",
-    (
-        "user_id",  # str
-        "state",  # str
-        "last_active_ts",  # int
-        "last_federation_update_ts",  # int
-        "last_user_sync_ts",  # int
-        "status_msg",  # str
-        "currently_active",  # bool
-    ),
-)
-TypingStreamRow = namedtuple(
-    "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
-)
-ReceiptsStreamRow = namedtuple(
-    "ReceiptsStreamRow",
-    (
-        "room_id",  # str
-        "receipt_type",  # str
-        "user_id",  # str
-        "event_id",  # str
-        "data",  # dict
-    ),
-)
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
-PushersStreamRow = namedtuple(
-    "PushersStreamRow",
-    ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
-)
-CachesStreamRow = namedtuple(
-    "CachesStreamRow",
-    ("cache_func", "keys", "invalidation_ts"),  # str  # list(str)  # int
-)
-PublicRoomsStreamRow = namedtuple(
-    "PublicRoomsStreamRow",
-    (
-        "room_id",  # str
-        "visibility",  # str
-        "appservice_id",  # str, optional
-        "network_id",  # str, optional
-    ),
-)
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
-TagAccountDataStreamRow = namedtuple(
-    "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
-)
-AccountDataStreamRow = namedtuple(
-    "AccountDataStream",
-    ("user_id", "room_id", "data_type", "data"),  # str  # str  # str  # dict
-)
-GroupsStreamRow = namedtuple(
-    "GroupsStreamRow",
-    ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
-)
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = TypeVar("StreamRow", bound=Tuple)
+
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+#   * `updates` is a list of `(token, row)` entries.
+#   * `new_last_token` is the new position in stream.
+#   * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
+
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+#  * instance_name: the writer of the stream
+#  * from_token: the previous stream token: the starting point for fetching the
+#    updates
+#  * to_token: the new stream token: the point to get updates up to
+#  * target_row_count: a target for the number of rows to be returned.
+#
+# The update_function is expected to return up to _approximately_ target_row_count rows.
+# If there are more updates available, it should set `limited` in the result, and
+# it will be called again to get the next batch.
+#
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
-    time it was called up until the point `advance_current_token` was called.
+    time it was called.
     """
 
-    NAME = None  # The name of the stream
-    ROW_TYPE = None  # The type of the row. Used by the default impl of parse_row.
-    _LIMITED = True  # Whether the update function takes a limit
+    NAME = None  # type: str  # The name of the stream
+    # The type of the row. Used by the default impl of parse_row.
+    ROW_TYPE = None  # type: Any
 
     @classmethod
-    def parse_row(cls, row):
+    def parse_row(cls, row: StreamRow):
         """Parse a row received over replication
 
         By default, assumes that the row data is an array object and passes its contents
@@ -123,102 +105,138 @@ class Stream(object):
         """
         return cls.ROW_TYPE(*row)
 
-    def __init__(self, hs):
-        # The token from which we last asked for updates
-        self.last_token = self.current_token()
-
-        # The token that we will get updates up to
-        self.upto_token = self.current_token()
+    def __init__(
+        self,
+        local_instance_name: str,
+        current_token_function: Callable[[str], Token],
+        update_function: UpdateFunction,
+    ):
+        """Instantiate a Stream
+
+        `current_token_function` and `update_function` are callbacks which
+        should be implemented by subclasses.
+
+        `current_token_function` takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process). On the writer process this is where
+        the writer has successfully written up to, whereas on other processes
+        this is the position which we have received updates up to over
+        replication. (Note that most streams have a single writer and so their
+        implementations ignore the instance name passed in).
+
+        `update_function` is called to get updates for this stream between a
+        pair of stream tokens. See the `UpdateFunction` type definition for more
+        info.
 
-    def advance_current_token(self):
-        """Updates `upto_token` to "now", which updates up until which point
-        get_updates[_since] will fetch rows till.
+        Args:
+            local_instance_name: The instance name of the current process
+            current_token_function: callback to get the current token, as above
+            update_function: callback go get stream updates, as above
         """
-        self.upto_token = self.current_token()
+        self.local_instance_name = local_instance_name
+        self.current_token = current_token_function
+        self.update_function = update_function
+
+        # The token from which we last asked for updates
+        self.last_token = self.current_token(self.local_instance_name)
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.upto_token = self.current_token()
-        self.last_token = self.upto_token
+        self.last_token = self.current_token(self.local_instance_name)
 
-    @defer.inlineCallbacks
-    def get_updates(self):
+    async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
-        since the stream was constructed if it hadn't been called before),
-        until the `upto_token`
+        since the stream was constructed if it hadn't been called before).
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        updates, current_token = yield self.get_updates_since(self.last_token)
+        current_token = self.current_token(self.local_instance_name)
+        updates, current_token, limited = await self.get_updates_since(
+            self.local_instance_name, self.last_token, current_token
+        )
         self.last_token = current_token
 
-        return updates, current_token
+        return updates, current_token, limited
 
-    @defer.inlineCallbacks
-    def get_updates_since(self, from_token):
+    async def get_updates_since(
+        self, instance_name: str, from_token: Token, upto_token: Token
+    ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        if from_token in ("NOW", "now"):
-            return [], self.upto_token
-
-        current_token = self.upto_token
 
         from_token = int(from_token)
 
-        if from_token == current_token:
-            return [], current_token
+        if from_token == upto_token:
+            return [], upto_token, False
 
-        if self._LIMITED:
-            rows = yield self.update_function(
-                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-            )
+        updates, upto_token, limited = await self.update_function(
+            instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+        )
+        return updates, upto_token, limited
 
-            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-        else:
-            rows = yield self.update_function(from_token, current_token)
 
+def current_token_without_instance(
+    current_token: Callable[[], int]
+) -> Callable[[str], int]:
+    """Takes a current token callback function for a single writer stream
+    that doesn't take an instance name parameter and wraps it in a function that
+    does accept an instance name parameter but ignores it.
+    """
+    return lambda instance_name: current_token()
+
+
+def db_query_to_update_function(
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> UpdateFunction:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(instance_name, from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
         updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
 
-        # check we didn't get more rows than the limit.
-        # doing it like this allows the update_function to be a generator.
-        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
-            raise Exception("stream %s has fallen behind" % (self.NAME))
+        return updates, upto_token, limited
 
-        return updates, current_token
+    return update_function
 
-    def current_token(self):
-        """Gets the current token of the underlying streams. Should be provided
-        by the sub classes
 
-        Returns:
-            int
-        """
-        raise NotImplementedError()
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+    """Makes a suitable function for use as an `update_function` that queries
+    the master process for updates.
+    """
 
-    def update_function(self, from_token, current_token, limit=None):
-        """Get updates between from_token and to_token. If Stream._LIMITED is
-        True then limit is provided, otherwise it's not.
+    client = ReplicationGetStreamUpdates.make_client(hs)
 
-        Returns:
-            Deferred(list(tuple)): the first entry in the tuple is the token for
-                that update, and the rest of the tuple gets used to construct
-                a ``ROW_TYPE`` instance
-        """
-        raise NotImplementedError()
+    async def update_function(
+        instance_name: str, from_token: int, upto_token: int, limit: int
+    ) -> StreamUpdateResult:
+        result = await client(
+            instance_name=instance_name,
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
+        )
+        return result["updates"], result["upto_token"], result["limited"]
+
+    return update_function
 
 
 class BackfillStream(Stream):
@@ -226,94 +244,170 @@ class BackfillStream(Stream):
     or it went from being an outlier to not.
     """
 
+    BackfillStreamRow = namedtuple(
+        "BackfillStreamRow",
+        (
+            "event_id",  # str
+            "room_id",  # str
+            "type",  # str
+            "state_key",  # str, optional
+            "redacts",  # str, optional
+            "relates_to",  # str, optional
+        ),
+    )
+
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        self.current_token = store.get_current_backfill_token
-        self.update_function = store.get_all_new_backfill_event_rows
-
-        super(BackfillStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_backfill_token),
+            db_query_to_update_function(store.get_all_new_backfill_event_rows),
+        )
 
 
 class PresenceStream(Stream):
+    PresenceStreamRow = namedtuple(
+        "PresenceStreamRow",
+        (
+            "user_id",  # str
+            "state",  # str
+            "last_active_ts",  # int
+            "last_federation_update_ts",  # int
+            "last_user_sync_ts",  # int
+            "status_msg",  # str
+            "currently_active",  # bool
+        ),
+    )
+
     NAME = "presence"
-    _LIMITED = False
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        presence_handler = hs.get_presence_handler()
 
-        self.current_token = store.get_current_presence_token
-        self.update_function = presence_handler.get_all_presence_updates
+        if hs.config.worker_app is None:
+            # on the master, query the presence handler
+            presence_handler = hs.get_presence_handler()
+            update_function = db_query_to_update_function(
+                presence_handler.get_all_presence_updates
+            )
+        else:
+            # Query master process
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(PresenceStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_presence_token),
+            update_function,
+        )
 
 
 class TypingStream(Stream):
+    TypingStreamRow = namedtuple(
+        "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
+    )
+
     NAME = "typing"
-    _LIMITED = False
     ROW_TYPE = TypingStreamRow
 
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
-        self.current_token = typing_handler.get_current_token
-        self.update_function = typing_handler.get_all_typing_updates
+        if hs.config.worker_app is None:
+            # on the master, query the typing handler
+            update_function = db_query_to_update_function(
+                typing_handler.get_all_typing_updates
+            )
+        else:
+            # Query master process
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(TypingStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(typing_handler.get_current_token),
+            update_function,
+        )
 
 
 class ReceiptsStream(Stream):
+    ReceiptsStreamRow = namedtuple(
+        "ReceiptsStreamRow",
+        (
+            "room_id",  # str
+            "receipt_type",  # str
+            "user_id",  # str
+            "event_id",  # str
+            "data",  # dict
+        ),
+    )
+
     NAME = "receipts"
     ROW_TYPE = ReceiptsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_receipt_stream_id
-        self.update_function = store.get_all_updated_receipts
-
-        super(ReceiptsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_max_receipt_stream_id),
+            db_query_to_update_function(store.get_all_updated_receipts),
+        )
 
 
 class PushRulesStream(Stream):
     """A user has changed their push rules
     """
 
+    PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        super(PushRulesStream, self).__init__(hs)
+        super(PushRulesStream, self).__init__(
+            hs.get_instance_name(), self._current_token, self._update_function
+        )
 
-    def current_token(self):
+    def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+    async def _update_function(
+        self, instance_name: str, from_token: Token, to_token: Token, limit: int
+    ):
+        rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher
     """
 
+    PushersStreamRow = namedtuple(
+        "PushersStreamRow",
+        ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
+    )
+
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_pushers_stream_token
-        self.update_function = store.get_all_updated_pushers_rows
-
-        super(PushersStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_pushers_stream_token),
+            db_query_to_update_function(store.get_all_updated_pushers_rows),
+        )
 
 
 class CachesStream(Stream):
@@ -321,120 +415,235 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
+    @attr.s
+    class CachesStreamRow:
+        """Stream to inform workers they should invalidate their cache.
+
+        Attributes:
+            cache_func: Name of the cached function.
+            keys: The entry in the cache to invalidate. If None then will
+                invalidate all.
+            invalidation_ts: Timestamp of when the invalidation took place.
+        """
+
+        cache_func = attr.ib(type=str)
+        keys = attr.ib(type=Optional[List[Any]])
+        invalidation_ts = attr.ib(type=int)
+
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            self.store.get_cache_stream_token,
+            self._update_function,
+        )
 
-        self.current_token = store.get_cache_stream_token
-        self.update_function = store.get_all_updated_caches
+    async def _update_function(
+        self, instance_name: str, from_token: int, upto_token: int, limit: int
+    ):
+        rows = await self.store.get_all_updated_caches(
+            instance_name, from_token, upto_token, limit
+        )
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
 
-        super(CachesStream, self).__init__(hs)
+        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
     """The public rooms list changed
     """
 
+    PublicRoomsStreamRow = namedtuple(
+        "PublicRoomsStreamRow",
+        (
+            "room_id",  # str
+            "visibility",  # str
+            "appservice_id",  # str, optional
+            "network_id",  # str, optional
+        ),
+    )
+
     NAME = "public_rooms"
     ROW_TYPE = PublicRoomsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_current_public_room_stream_id
-        self.update_function = store.get_all_new_public_rooms
-
-        super(PublicRoomsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_public_room_stream_id),
+            db_query_to_update_function(store.get_all_new_public_rooms),
+        )
 
 
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """
 
+    @attr.s
+    class DeviceListsStreamRow:
+        entity = attr.ib(type=str)
+
     NAME = "device_lists"
-    _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token
-        self.update_function = store.get_all_device_list_changes_for_remotes
-
-        super(DeviceListsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_device_stream_token),
+            db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+        )
 
 
 class ToDeviceStream(Stream):
     """New to_device messages for a client
     """
 
+    ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_to_device_stream_token
-        self.update_function = store.get_all_new_device_messages
-
-        super(ToDeviceStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_to_device_stream_token),
+            db_query_to_update_function(store.get_all_new_device_messages),
+        )
 
 
 class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room
     """
 
+    TagAccountDataStreamRow = namedtuple(
+        "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
+    )
+
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_account_data_stream_id
-        self.update_function = store.get_all_updated_tags
-
-        super(TagAccountDataStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_max_account_data_stream_id),
+            db_query_to_update_function(store.get_all_updated_tags),
+        )
 
 
 class AccountDataStream(Stream):
     """Global or per room account data was changed
     """
 
+    AccountDataStreamRow = namedtuple(
+        "AccountDataStream",
+        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
+    )
+
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(self.store.get_max_account_data_stream_id),
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, to_token: int, limit: int
+    ) -> StreamUpdateResult:
+        limited = False
+        global_results = await self.store.get_updated_global_account_data(
+            from_token, to_token, limit
+        )
 
-        self.current_token = self.store.get_max_account_data_stream_id
+        # if the global results hit the limit, we'll need to limit the room results to
+        # the same stream token.
+        if len(global_results) >= limit:
+            to_token = global_results[-1][0]
+            limited = True
 
-        super(AccountDataStream, self).__init__(hs)
+        room_results = await self.store.get_updated_room_account_data(
+            from_token, to_token, limit
+        )
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        global_results, room_results = yield self.store.get_all_updated_account_data(
-            from_token, from_token, to_token, limit
+        # likewise, if the room results hit the limit, limit the global results to
+        # the same stream token.
+        if len(room_results) >= limit:
+            to_token = room_results[-1][0]
+            limited = True
+
+        # convert the global results to the right format, and limit them to the to_token
+        # at the same time
+        global_rows = (
+            (stream_id, (user_id, None, account_data_type))
+            for stream_id, user_id, account_data_type in global_results
+            if stream_id <= to_token
         )
 
-        results = list(room_results)
-        results.extend(
-            (stream_id, user_id, None, account_data_type, content)
-            for stream_id, user_id, account_data_type, content in global_results
+        # we know that the room_results are already limited to `to_token` so no need
+        # for a check on `stream_id` here.
+        room_rows = (
+            (stream_id, (user_id, room_id, account_data_type))
+            for stream_id, user_id, room_id, account_data_type in room_results
         )
 
-        return results
+        # We need to return a sorted list, so merge them together.
+        #
+        # Note: We order only by the stream ID to work around a bug where the
+        # same stream ID could appear in both `global_rows` and `room_rows`,
+        # leading to a comparison between the data tuples. The comparison could
+        # fail due to attempting to compare the `room_id` which results in a
+        # `TypeError` from comparing a `str` vs `None`.
+        updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+        return updates, to_token, limited
 
 
 class GroupServerStream(Stream):
+    GroupsStreamRow = namedtuple(
+        "GroupsStreamRow",
+        ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
+    )
+
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_group_stream_token),
+            db_query_to_update_function(store.get_all_groups_changes),
+        )
 
-        self.current_token = store.get_group_stream_token
-        self.update_function = store.get_all_groups_changes
 
-        super(GroupServerStream, self).__init__(hs)
+class UserSignatureStream(Stream):
+    """A user has signed their own device with their user-signing key
+    """
+
+    UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+
+    NAME = "user_signature"
+    ROW_TYPE = UserSignatureStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_device_stream_token),
+            db_query_to_update_function(
+                store.get_all_user_signature_changes_for_remotes
+            ),
+        )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index d97669c886..f370390331 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,13 +13,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import heapq
+from collections import Iterable
+from typing import List, Tuple, Type
 
 import attr
 
-from twisted.internet import defer
-
-from ._base import Stream
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
 
 
 """Handling of the 'events' replication stream
@@ -63,7 +64,8 @@ class BaseEventsStreamRow(object):
     Specifies how to identify, serialize and deserialize the different types.
     """
 
-    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+    # Unique string that ids the type. Must be overriden in sub classes.
+    TypeId = None  # type: str
 
     @classmethod
     def from_data(cls, data):
@@ -99,9 +101,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
     event_id = attr.ib()  # str, optional
 
 
-TypeToRow = {
-    Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
-}
+_EventRows = (
+    EventsStreamEventRow,
+    EventsStreamCurrentStateRow,
+)  # type: Tuple[Type[BaseEventsStreamRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
 
 class EventsStream(Stream):
@@ -112,29 +117,107 @@ class EventsStream(Stream):
 
     def __init__(self, hs):
         self._store = hs.get_datastore()
-        self.current_token = self._store.get_current_events_token
-
-        super(EventsStream, self).__init__(hs)
-
-    @defer.inlineCallbacks
-    def update_function(self, from_token, current_token, limit=None):
-        event_rows = yield self._store.get_all_new_forward_event_rows(
-            from_token, current_token, limit
-        )
-        event_updates = (
-            (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
+        super().__init__(
+            hs.get_instance_name(),
+            current_token_without_instance(self._store.get_current_events_token),
+            self._update_function,
         )
 
-        state_rows = yield self._store.get_all_updated_current_state_deltas(
-            from_token, current_token, limit
-        )
-        state_updates = (
-            (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
+    async def _update_function(
+        self,
+        instance_name: str,
+        from_token: Token,
+        current_token: Token,
+        target_row_count: int,
+    ) -> StreamUpdateResult:
+
+        # the events stream merges together three separate sources:
+        #  * new events
+        #  * current_state changes
+        #  * events which were previously outliers, but have now been de-outliered.
+        #
+        # The merge operation is complicated by the fact that we only have a single
+        # "stream token" which is supposed to indicate how far we have got through
+        # all three streams. It's therefore no good to return rows 1-1000 from the
+        # "new events" table if the state_deltas are limited to rows 1-100 by the
+        # target_row_count.
+        #
+        # In other words: we must pick a new upper limit, and must return *all* rows
+        # up to that point for each of the three sources.
+        #
+        # Start by trying to split the target_row_count up. We expect to have a
+        # negligible number of ex-outliers, and a rough approximation based on recent
+        # traffic on sw1v.org shows that there are approximately the same number of
+        # event rows between a given pair of stream ids as there are state
+        # updates, so let's split our target_row_count among those two types. The target
+        # is only an approximation - it doesn't matter if we end up going a bit over it.
+
+        target_row_count //= 2
+
+        # now we fetch up to that many rows from the events table
+
+        event_rows = await self._store.get_all_new_forward_event_rows(
+            from_token, current_token, target_row_count
+        )  # type: List[Tuple]
+
+        # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
+        # that we know it is safe to just take upper_limit = event_rows[-1][0].
+        assert (
+            len(event_rows) <= target_row_count
+        ), "get_all_new_forward_event_rows did not honour row limit"
+
+        # if we hit the limit on event_updates, there's no point in going beyond the
+        # last stream_id in the batch for the other sources.
+
+        if len(event_rows) == target_row_count:
+            limited = True
+            upper_limit = event_rows[-1][0]  # type: int
+        else:
+            limited = False
+            upper_limit = current_token
+
+        # next up is the state delta table.
+        (
+            state_rows,
+            upper_limit,
+            state_rows_limited,
+        ) = await self._store.get_all_updated_current_state_deltas(
+            from_token, upper_limit, target_row_count
         )
 
-        all_updates = heapq.merge(event_updates, state_updates)
+        limited = limited or state_rows_limited
+
+        # finally, fetch the ex-outliers rows. We assume there are few enough of these
+        # not to bother with the limit.
+
+        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+            from_token, upper_limit
+        )  # type: List[Tuple]
+
+        # we now need to turn the raw database rows returned into tuples suitable
+        # for the replication protocol (basically, we add an identifier to
+        # distinguish the row type). At the same time, we can limit the event_rows
+        # to the max stream_id from state_rows.
 
-        return all_updates
+        event_updates = (
+            (stream_id, (EventsStreamEventRow.TypeId, rest))
+            for (stream_id, *rest) in event_rows
+            if stream_id <= upper_limit
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        state_updates = (
+            (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
+            for (stream_id, *rest) in state_rows
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        ex_outliers_updates = (
+            (stream_id, (EventsStreamEventRow.TypeId, rest))
+            for (stream_id, *rest) in ex_outliers_rows
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
+        return updates, upper_limit, limited
 
     @classmethod
     def parse_row(cls, row):
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index dc2484109d..9bcd13b009 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,14 +15,10 @@
 # limitations under the License.
 from collections import namedtuple
 
-from ._base import Stream
-
-FederationStreamRow = namedtuple(
-    "FederationStreamRow",
-    (
-        "type",  # str, the type of data as defined in the BaseFederationRows
-        "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
-    ),
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    current_token_without_instance,
+    make_http_update_function,
 )
 
 
@@ -31,13 +27,47 @@ class FederationStream(Stream):
     sending disabled.
     """
 
+    FederationStreamRow = namedtuple(
+        "FederationStreamRow",
+        (
+            "type",  # str, the type of data as defined in the BaseFederationRows
+            "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+        ),
+    )
+
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
 
     def __init__(self, hs):
-        federation_sender = hs.get_federation_sender()
+        if hs.config.worker_app is None:
+            # master process: get updates from the FederationRemoteSendQueue.
+            # (if the master is configured to send federation itself, federation_sender
+            # will be a real FederationSender, which has stubs for current_token and
+            # get_replication_rows.)
+            federation_sender = hs.get_federation_sender()
+            current_token = current_token_without_instance(
+                federation_sender.get_current_token
+            )
+            update_function = federation_sender.get_replication_rows
+
+        elif hs.should_send_federation():
+            # federation sender: Query master process
+            update_function = make_http_update_function(hs, self.NAME)
+            current_token = self._stub_current_token
+
+        else:
+            # other worker: stub out the update function (we're not interested in
+            # any updates so when we get a POSITION we do nothing)
+            update_function = self._stub_update_function
+            current_token = self._stub_current_token
+
+        super().__init__(hs.get_instance_name(), current_token, update_function)
 
-        self.current_token = federation_sender.get_current_token
-        self.update_function = federation_sender.get_replication_rows
+    @staticmethod
+    def _stub_current_token(instance_name: str) -> int:
+        # dummy current-token method for use on workers
+        return 0
 
-        super(FederationStream, self).__init__(hs)
+    @staticmethod
+    async def _stub_update_function(instance_name, from_token, upto_token, limit):
+        return [], upto_token, False