summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2020-05-19 09:55:39 -0400
committerPatrick Cloke <patrickc@matrix.org>2020-05-19 09:55:39 -0400
commit13a82768ac4fdc1f5a24da9be5e51c5065382596 (patch)
treed14b2100876cde31e677d4f8a0db5036c882fcbe /synapse/replication
parentremove spurious changelog files (diff)
parentUpdate changelog based on feedback. (diff)
downloadsynapse-13a82768ac4fdc1f5a24da9be5e51c5065382596.tar.xz
Merge tag 'v1.13.0'
Synapse 1.13.0 (2020-05-19)
===========================

This release brings some potential changes necessary for certain
configurations of Synapse:

* If your Synapse is configured to use SSO and have a custom
  `sso_redirect_confirm_template_dir` configuration option set, you will need
  to duplicate the new `sso_auth_confirm.html`, `sso_auth_success.html` and
  `sso_account_deactivated.html` templates into that directory.
* Synapse plugins using the `complete_sso_login` method of
  `synapse.module_api.ModuleApi` should instead switch to the async/await
  version, `complete_sso_login_async`, which includes additional checks. The
  former version is now deprecated.
* A bug was introduced in Synapse 1.4.0 which could cause the room directory
  to be incomplete or empty if Synapse was upgraded directly from v1.2.1 or
  earlier, to versions between v1.4.0 and v1.12.x.

Please review [UPGRADE.rst](https://github.com/matrix-org/synapse/blob/master/UPGRADE.rst)
for more details on these changes and for general upgrade guidance.

Notice of change to the default `git` branch for Synapse
--------------------------------------------------------

With the release of Synapse 1.13.0, the default `git` branch for Synapse has
changed to `develop`, which is the development tip. This is more consistent with
common practice and modern `git` usage.

The `master` branch, which tracks the latest release, is still available. It is
recommended that developers and distributors who have scripts which run builds
using the default branch of Synapse should therefore consider pinning their
scripts to `master`.

Features
--------

- Extend the `web_client_location` option to accept an absolute URL to use as a redirect. Adds a warning when running the web client on the same hostname as homeserver. Contributed by Martin Milata. ([\#7006](https://github.com/matrix-org/synapse/issues/7006))
- Set `Referrer-Policy` header to `no-referrer` on media downloads. ([\#7009](https://github.com/matrix-org/synapse/issues/7009))
- Add support for running replication over Redis when using workers. ([\#7040](https://github.com/matrix-org/synapse/issues/7040), [\#7325](https://github.com/matrix-org/synapse/issues/7325), [\#7352](https://github.com/matrix-org/synapse/issues/7352), [\#7401](https://github.com/matrix-org/synapse/issues/7401), [\#7427](https://github.com/matrix-org/synapse/issues/7427), [\#7439](https://github.com/matrix-org/synapse/issues/7439), [\#7446](https://github.com/matrix-org/synapse/issues/7446), [\#7450](https://github.com/matrix-org/synapse/issues/7450), [\#7454](https://github.com/matrix-org/synapse/issues/7454))
- Admin API `POST /_synapse/admin/v1/join/<roomIdOrAlias>` to join users to a room like `auto_join_rooms` for creation of users. ([\#7051](https://github.com/matrix-org/synapse/issues/7051))
- Add options to prevent users from changing their profile or associated 3PIDs. ([\#7096](https://github.com/matrix-org/synapse/issues/7096))
- Support SSO in the user interactive authentication workflow. ([\#7102](https://github.com/matrix-org/synapse/issues/7102), [\#7186](https://github.com/matrix-org/synapse/issues/7186), [\#7279](https://github.com/matrix-org/synapse/issues/7279), [\#7343](https://github.com/matrix-org/synapse/issues/7343))
- Allow server admins to define and enforce a password policy ([MSC2000](https://github.com/matrix-org/matrix-doc/issues/2000)). ([\#7118](https://github.com/matrix-org/synapse/issues/7118))
- Improve the support for SSO authentication on the login fallback page. ([\#7152](https://github.com/matrix-org/synapse/issues/7152), [\#7235](https://github.com/matrix-org/synapse/issues/7235))
- Always whitelist the login fallback in the SSO configuration if `public_baseurl` is set. ([\#7153](https://github.com/matrix-org/synapse/issues/7153))
- Admin users are no longer required to be in a room to create an alias for it. ([\#7191](https://github.com/matrix-org/synapse/issues/7191))
- Require admin privileges to enable room encryption by default. This does not affect existing rooms. ([\#7230](https://github.com/matrix-org/synapse/issues/7230))
- Add a config option for specifying the value of the Accept-Language HTTP header when generating URL previews. ([\#7265](https://github.com/matrix-org/synapse/issues/7265))
- Allow `/requestToken` endpoints to hide the existence (or lack thereof) of 3PID associations on the homeserver. ([\#7315](https://github.com/matrix-org/synapse/issues/7315))
- Add a configuration setting to tweak the threshold for dummy events. ([\#7422](https://github.com/matrix-org/synapse/issues/7422))

Bugfixes
--------

- Don't attempt to use an invalid sqlite config if no database configuration is provided. Contributed by @nekatak. ([\#6573](https://github.com/matrix-org/synapse/issues/6573))
- Fix single-sign on with CAS systems: pass the same service URL when requesting the CAS ticket and when calling the `proxyValidate` URL. Contributed by @Naugrimm. ([\#6634](https://github.com/matrix-org/synapse/issues/6634))
- Fix missing field `default` when fetching user-defined push rules. ([\#6639](https://github.com/matrix-org/synapse/issues/6639))
- Improve error responses when accessing remote public room lists. ([\#6899](https://github.com/matrix-org/synapse/issues/6899), [\#7368](https://github.com/matrix-org/synapse/issues/7368))
- Transfer alias mappings on room upgrade. ([\#6946](https://github.com/matrix-org/synapse/issues/6946))
- Ensure that a user interactive authentication session is tied to a single request. ([\#7068](https://github.com/matrix-org/synapse/issues/7068), [\#7455](https://github.com/matrix-org/synapse/issues/7455))
- Fix a bug in the federation API which could cause occasional "Failed to get PDU" errors. ([\#7089](https://github.com/matrix-org/synapse/issues/7089))
- Return the proper error (`M_BAD_ALIAS`) when a non-existant canonical alias is provided. ([\#7109](https://github.com/matrix-org/synapse/issues/7109))
- Fix a bug which meant that groups updates were not correctly replicated between workers. ([\#7117](https://github.com/matrix-org/synapse/issues/7117))
- Fix starting workers when federation sending not split out. ([\#7133](https://github.com/matrix-org/synapse/issues/7133))
- Ensure `is_verified` is a boolean in responses to `GET /_matrix/client/r0/room_keys/keys`. Also warn the user if they forgot the `version` query param. ([\#7150](https://github.com/matrix-org/synapse/issues/7150))
- Fix error page being shown when a custom SAML handler attempted to redirect when processing an auth response. ([\#7151](https://github.com/matrix-org/synapse/issues/7151))
- Avoid importing `sqlite3` when using the postgres backend. Contributed by David Vo. ([\#7155](https://github.com/matrix-org/synapse/issues/7155))
- Fix excessive CPU usage by `prune_old_outbound_device_pokes` job. ([\#7159](https://github.com/matrix-org/synapse/issues/7159))
- Fix a bug which could cause outbound federation traffic to stop working if a client uploaded an incorrect e2e device signature. ([\#7177](https://github.com/matrix-org/synapse/issues/7177))
- Fix a bug which could cause incorrect 'cyclic dependency' error. ([\#7178](https://github.com/matrix-org/synapse/issues/7178))
- Fix a bug that could cause a user to be invited to a server notices (aka System Alerts) room without any notice being sent. ([\#7199](https://github.com/matrix-org/synapse/issues/7199))
- Fix some worker-mode replication handling not being correctly recorded in CPU usage stats. ([\#7203](https://github.com/matrix-org/synapse/issues/7203))
- Do not allow a deactivated user to login via SSO. ([\#7240](https://github.com/matrix-org/synapse/issues/7240), [\#7259](https://github.com/matrix-org/synapse/issues/7259))
- Fix --help command-line argument. ([\#7249](https://github.com/matrix-org/synapse/issues/7249))
- Fix room publish permissions not being checked on room creation. ([\#7260](https://github.com/matrix-org/synapse/issues/7260))
- Reject unknown session IDs during user interactive authentication instead of silently creating a new session. ([\#7268](https://github.com/matrix-org/synapse/issues/7268))
- Fix a SQL query introduced in Synapse 1.12.0 which could cause large amounts of logging to the postgres slow-query log. ([\#7274](https://github.com/matrix-org/synapse/issues/7274))
- Persist user interactive authentication sessions across workers and Synapse restarts. ([\#7302](https://github.com/matrix-org/synapse/issues/7302))
- Fixed backwards compatibility logic of the first value of `trusted_third_party_id_servers` being used for `account_threepid_delegates.email`, which occurs when the former, deprecated option is set and the latter is not. ([\#7316](https://github.com/matrix-org/synapse/issues/7316))
- Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. ([\#7337](https://github.com/matrix-org/synapse/issues/7337), [\#7358](https://github.com/matrix-org/synapse/issues/7358))
- Fix bad error handling that would cause Synapse to crash if it's provided with a YAML configuration file that's either empty or doesn't parse into a key-value map. ([\#7341](https://github.com/matrix-org/synapse/issues/7341))
- Fix incorrect metrics reporting for `renew_attestations` background task. ([\#7344](https://github.com/matrix-org/synapse/issues/7344))
- Prevent non-federating rooms from appearing in responses to federated `POST /publicRoom` requests when a filter was included. ([\#7367](https://github.com/matrix-org/synapse/issues/7367))
- Fix a bug which would cause the room durectory to be incorrectly populated if Synapse was upgraded directly from v1.2.1 or earlier to v1.4.0 or later. Note that this fix does not apply retrospectively; see the [upgrade notes](UPGRADE.rst#upgrading-to-v1130) for more information. ([\#7387](https://github.com/matrix-org/synapse/issues/7387))
- Fix bug in `EventContext.deserialize`. ([\#7393](https://github.com/matrix-org/synapse/issues/7393))
- Fix a long-standing bug which could cause messages not to be sent over federation, when state events with state keys matching user IDs (such as custom user statuses) were received. ([\#7376](https://github.com/matrix-org/synapse/issues/7376))
- Restore compatibility with non-compliant clients during the user interactive authentication process, fixing a problem introduced in v1.13.0rc1. ([\#7483](https://github.com/matrix-org/synapse/issues/7483))
- Hash passwords as early as possible during registration. ([\#7523](https://github.com/matrix-org/synapse/issues/7523))

Improved Documentation
----------------------

- Update Debian installation instructions to recommend installing the `virtualenv` package instead of `python3-virtualenv`. ([\#6892](https://github.com/matrix-org/synapse/issues/6892))
- Improve the documentation for database configuration. ([\#6988](https://github.com/matrix-org/synapse/issues/6988))
- Improve the documentation of application service configuration files. ([\#7091](https://github.com/matrix-org/synapse/issues/7091))
- Update pre-built package name for FreeBSD. ([\#7107](https://github.com/matrix-org/synapse/issues/7107))
- Update postgres docs with login troubleshooting information. ([\#7119](https://github.com/matrix-org/synapse/issues/7119))
- Clean up INSTALL.md a bit. ([\#7141](https://github.com/matrix-org/synapse/issues/7141))
- Add documentation for running a local CAS server for testing. ([\#7147](https://github.com/matrix-org/synapse/issues/7147))
- Improve README.md by being explicit about public IP recommendation for TURN relaying. ([\#7167](https://github.com/matrix-org/synapse/issues/7167))
- Fix a small typo in the `metrics_flags` config option. ([\#7171](https://github.com/matrix-org/synapse/issues/7171))
- Update the contributed documentation on managing synapse workers with systemd, and bring it into the core distribution. ([\#7234](https://github.com/matrix-org/synapse/issues/7234))
- Add documentation to the `password_providers` config option. Add known password provider implementations to docs. ([\#7238](https://github.com/matrix-org/synapse/issues/7238), [\#7248](https://github.com/matrix-org/synapse/issues/7248))
- Modify suggested nginx reverse proxy configuration to match Synapse's default file upload size. Contributed by @ProCycleDev. ([\#7251](https://github.com/matrix-org/synapse/issues/7251))
- Documentation of media_storage_providers options updated to avoid misunderstandings. Contributed by Tristan Lins. ([\#7272](https://github.com/matrix-org/synapse/issues/7272))
- Add documentation on monitoring workers with Prometheus. ([\#7357](https://github.com/matrix-org/synapse/issues/7357))
- Clarify endpoint usage in the users admin api documentation. ([\#7361](https://github.com/matrix-org/synapse/issues/7361))

Deprecations and Removals
-------------------------

- Remove nonfunctional `captcha_bypass_secret` option from `homeserver.yaml`. ([\#7137](https://github.com/matrix-org/synapse/issues/7137))

Internal Changes
----------------

- Add benchmarks for LruCache. ([\#6446](https://github.com/matrix-org/synapse/issues/6446))
- Return total number of users and profile attributes in admin users endpoint. Contributed by Awesome Technologies Innovationslabor GmbH. ([\#6881](https://github.com/matrix-org/synapse/issues/6881))
- Change device list streams to have one row per ID. ([\#7010](https://github.com/matrix-org/synapse/issues/7010))
- Remove concept of a non-limited stream. ([\#7011](https://github.com/matrix-org/synapse/issues/7011))
- Move catchup of replication streams logic to worker. ([\#7024](https://github.com/matrix-org/synapse/issues/7024), [\#7195](https://github.com/matrix-org/synapse/issues/7195), [\#7226](https://github.com/matrix-org/synapse/issues/7226), [\#7239](https://github.com/matrix-org/synapse/issues/7239), [\#7286](https://github.com/matrix-org/synapse/issues/7286), [\#7290](https://github.com/matrix-org/synapse/issues/7290), [\#7318](https://github.com/matrix-org/synapse/issues/7318), [\#7326](https://github.com/matrix-org/synapse/issues/7326), [\#7378](https://github.com/matrix-org/synapse/issues/7378), [\#7421](https://github.com/matrix-org/synapse/issues/7421))
- Convert some of synapse.rest.media to async/await. ([\#7110](https://github.com/matrix-org/synapse/issues/7110), [\#7184](https://github.com/matrix-org/synapse/issues/7184), [\#7241](https://github.com/matrix-org/synapse/issues/7241))
- De-duplicate / remove unused REST code for login and auth. ([\#7115](https://github.com/matrix-org/synapse/issues/7115))
- Convert `*StreamRow` classes to inner classes. ([\#7116](https://github.com/matrix-org/synapse/issues/7116))
- Clean up some LoggingContext code. ([\#7120](https://github.com/matrix-org/synapse/issues/7120), [\#7181](https://github.com/matrix-org/synapse/issues/7181), [\#7183](https://github.com/matrix-org/synapse/issues/7183), [\#7408](https://github.com/matrix-org/synapse/issues/7408), [\#7426](https://github.com/matrix-org/synapse/issues/7426))
- Add explicit `instance_id` for USER_SYNC commands and remove implicit `conn_id` usage. ([\#7128](https://github.com/matrix-org/synapse/issues/7128))
- Refactored the CAS authentication logic to a separate class. ([\#7136](https://github.com/matrix-org/synapse/issues/7136))
- Run replication streamers on workers. ([\#7146](https://github.com/matrix-org/synapse/issues/7146))
- Add tests for outbound device pokes. ([\#7157](https://github.com/matrix-org/synapse/issues/7157))
- Fix device list update stream ids going backward. ([\#7158](https://github.com/matrix-org/synapse/issues/7158))
- Use `stream.current_token()` and remove `stream_positions()`. ([\#7172](https://github.com/matrix-org/synapse/issues/7172))
- Move client command handling out of TCP protocol. ([\#7185](https://github.com/matrix-org/synapse/issues/7185))
- Move server command handling out of TCP protocol. ([\#7187](https://github.com/matrix-org/synapse/issues/7187))
- Fix consistency of HTTP status codes reported in log lines. ([\#7188](https://github.com/matrix-org/synapse/issues/7188))
- Only run one background database update at a time. ([\#7190](https://github.com/matrix-org/synapse/issues/7190))
- Remove sent outbound device list pokes from the database. ([\#7192](https://github.com/matrix-org/synapse/issues/7192))
- Add a background database update job to clear out duplicate `device_lists_outbound_pokes`. ([\#7193](https://github.com/matrix-org/synapse/issues/7193))
- Remove some extraneous debugging log lines. ([\#7207](https://github.com/matrix-org/synapse/issues/7207))
- Add explicit Python build tooling as dependencies for the snapcraft build. ([\#7213](https://github.com/matrix-org/synapse/issues/7213))
- Add typing information to federation server code. ([\#7219](https://github.com/matrix-org/synapse/issues/7219))
- Extend room admin api (`GET /_synapse/admin/v1/rooms`) with additional attributes. ([\#7225](https://github.com/matrix-org/synapse/issues/7225))
- Unblacklist '/upgrade creates a new room' sytest for workers. ([\#7228](https://github.com/matrix-org/synapse/issues/7228))
- Remove redundant checks on `daemonize` from synctl. ([\#7233](https://github.com/matrix-org/synapse/issues/7233))
- Upgrade jQuery to v3.4.1 on fallback login/registration pages. ([\#7236](https://github.com/matrix-org/synapse/issues/7236))
- Change log line that told user to implement onLogin/onRegister fallback js functions to a warning, instead of an info, so it's more visible. ([\#7237](https://github.com/matrix-org/synapse/issues/7237))
- Correct the parameters of a test fixture. Contributed by Isaiah Singletary. ([\#7243](https://github.com/matrix-org/synapse/issues/7243))
- Convert auth handler to async/await. ([\#7261](https://github.com/matrix-org/synapse/issues/7261))
- Add some unit tests for replication. ([\#7278](https://github.com/matrix-org/synapse/issues/7278))
- Improve typing annotations in `synapse.replication.tcp.streams.Stream`. ([\#7291](https://github.com/matrix-org/synapse/issues/7291))
- Reduce log verbosity of url cache cleanup tasks. ([\#7295](https://github.com/matrix-org/synapse/issues/7295))
- Fix sample SAML Service Provider configuration. Contributed by @frcl. ([\#7300](https://github.com/matrix-org/synapse/issues/7300))
- Fix StreamChangeCache to work with multiple entities changing on the same stream id. ([\#7303](https://github.com/matrix-org/synapse/issues/7303))
- Fix an incorrect import in IdentityHandler. ([\#7319](https://github.com/matrix-org/synapse/issues/7319))
- Reduce logging verbosity for successful federation requests. ([\#7321](https://github.com/matrix-org/synapse/issues/7321))
- Convert some federation handler code to async/await. ([\#7338](https://github.com/matrix-org/synapse/issues/7338))
- Fix collation for postgres for unit tests. ([\#7359](https://github.com/matrix-org/synapse/issues/7359))
- Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await. ([\#7363](https://github.com/matrix-org/synapse/issues/7363))
- Add an `instance_name` to `RDATA` and `POSITION` replication commands. ([\#7364](https://github.com/matrix-org/synapse/issues/7364))
- Thread through instance name to replication client. ([\#7369](https://github.com/matrix-org/synapse/issues/7369))
- Convert synapse.server_notices to async/await. ([\#7394](https://github.com/matrix-org/synapse/issues/7394))
- Convert synapse.notifier to async/await. ([\#7395](https://github.com/matrix-org/synapse/issues/7395))
- Fix issues with the Python package manifest. ([\#7404](https://github.com/matrix-org/synapse/issues/7404))
- Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request. ([\#7420](https://github.com/matrix-org/synapse/issues/7420))
- Speed up fetching device lists changes when handling `/sync` requests. ([\#7423](https://github.com/matrix-org/synapse/issues/7423))
- Run group attestation renewal in series rather than parallel for performance. ([\#7442](https://github.com/matrix-org/synapse/issues/7442))
- Fix linting errors in new version of Flake8. ([\#7470](https://github.com/matrix-org/synapse/issues/7470))
- Update the version of dh-virtualenv we use to build debs, and add focal to the list of target distributions. ([\#7526](https://github.com/matrix-org/synapse/issues/7526))
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/__init__.py2
-rw-r--r--synapse/replication/http/_base.py19
-rw-r--r--synapse/replication/http/streams.py82
-rw-r--r--synapse/replication/slave/storage/_base.py25
-rw-r--r--synapse/replication/slave/storage/account_data.py8
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py5
-rw-r--r--synapse/replication/slave/storage/devices.py46
-rw-r--r--synapse/replication/slave/storage/events.py6
-rw-r--r--synapse/replication/slave/storage/groups.py5
-rw-r--r--synapse/replication/slave/storage/presence.py9
-rw-r--r--synapse/replication/slave/storage/push_rule.py5
-rw-r--r--synapse/replication/slave/storage/pushers.py6
-rw-r--r--synapse/replication/slave/storage/receipts.py5
-rw-r--r--synapse/replication/slave/storage/room.py5
-rw-r--r--synapse/replication/tcp/__init__.py30
-rw-r--r--synapse/replication/tcp/client.py201
-rw-r--r--synapse/replication/tcp/commands.py181
-rw-r--r--synapse/replication/tcp/handler.py569
-rw-r--r--synapse/replication/tcp/protocol.py459
-rw-r--r--synapse/replication/tcp/redis.py210
-rw-r--r--synapse/replication/tcp/resource.py227
-rw-r--r--synapse/replication/tcp/streams/__init__.py71
-rw-r--r--synapse/replication/tcp/streams/_base.py534
-rw-r--r--synapse/replication/tcp/streams/events.py114
-rw-r--r--synapse/replication/tcp/streams/federation.py42
25 files changed, 1715 insertions, 1151 deletions
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 28dbc6fcba..4613b2538c 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -21,6 +21,7 @@ from synapse.replication.http import (
     membership,
     register,
     send_event,
+    streams,
 )
 
 REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
         login.register_servlets(hs, self)
         register.register_servlets(hs, self)
         devices.register_servlets(hs, self)
+        streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 1be1ccbdf3..f88c80ae84 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,7 @@
 import abc
 import logging
 import re
+from inspect import signature
 from typing import Dict, List, Tuple
 
 from six import raise_from
@@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
     must call `register` to register the path with the HTTP server.
 
     Requests can be sent by calling the client returned by `make_client`.
+    Requests are sent to master process by default, but can be sent to other
+    named processes by specifying an `instance_name` keyword argument.
 
     Attributes:
         NAME (str): A name for the endpoint, added to the path as well as used
@@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
                 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
             )
 
+        # We reserve `instance_name` as a parameter to sending requests, so we
+        # assert here that sub classes don't try and use the name.
+        assert (
+            "instance_name" not in self.PATH_ARGS
+        ), "`instance_name` is a reserved paramater name"
+        assert (
+            "instance_name"
+            not in signature(self.__class__._serialize_payload).parameters
+        ), "`instance_name` is a reserved paramater name"
+
         assert self.METHOD in ("PUT", "POST", "GET")
 
     @abc.abstractmethod
@@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
 
         @trace(opname="outgoing_replication_request")
         @defer.inlineCallbacks
-        def send_request(**kwargs):
+        def send_request(instance_name="master", **kwargs):
+            # Currently we only support sending requests to master process.
+            if instance_name != "master":
+                raise Exception("Unknown instance")
+
             data = yield cls._serialize_payload(**kwargs)
 
             url_args = [
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
new file mode 100644
index 0000000000..0459f582bf
--- /dev/null
+++ b/synapse/replication/http/streams.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationGetStreamUpdates(ReplicationEndpoint):
+    """Fetches stream updates from a server. Used for streams not persisted to
+    the database, e.g. typing notifications.
+
+    The API looks like:
+
+        GET /_synapse/replication/get_repl_stream_updates/<stream name>?from_token=0&to_token=10
+
+        200 OK
+
+        {
+            updates: [ ... ],
+            upto_token: 10,
+            limited: False,
+        }
+
+    If there are more rows than can sensibly be returned in one lump, `limited` will be
+    set to true, and the caller should call again with a new `from_token`.
+
+    """
+
+    NAME = "get_repl_stream_updates"
+    PATH_ARGS = ("stream_name",)
+    METHOD = "GET"
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self._instance_name = hs.get_instance_name()
+
+        # We pull the streams from the replication steamer (if we try and make
+        # them ourselves we end up in an import loop).
+        self.streams = hs.get_replication_streamer().get_streams()
+
+    @staticmethod
+    def _serialize_payload(stream_name, from_token, upto_token):
+        return {"from_token": from_token, "upto_token": upto_token}
+
+    async def _handle_request(self, request, stream_name):
+        stream = self.streams.get(stream_name)
+        if stream is None:
+            raise SynapseError(400, "Unknown stream")
+
+        from_token = parse_integer(request, "from_token", required=True)
+        upto_token = parse_integer(request, "upto_token", required=True)
+
+        updates, upto_token, limited = await stream.get_updates_since(
+            self._instance_name, from_token, upto_token
+        )
+
+        return (
+            200,
+            {"updates": updates, "upto_token": upto_token, "limited": limited},
+        )
+
+
+def register_servlets(hs, http_server):
+    ReplicationGetStreamUpdates(hs).register(http_server)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f45cbd37a0..5d7c8871a4 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Optional
+from typing import Optional
 
 import six
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+    CURRENT_STATE_CACHE_NAME,
+    CacheInvalidationWorkerStore,
+)
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
@@ -35,7 +37,7 @@ def __func__(inp):
         return inp.__func__
 
 
-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
@@ -47,18 +49,11 @@ class BaseSlavedStore(SQLBaseStore):
 
         self.hs = hs
 
-    def stream_positions(self) -> Dict[str, int]:
-        """
-        Get the current positions of all the streams this store wants to subscribe to
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        pos = {}
+    def get_cache_stream_token(self):
         if self._cache_id_gen:
-            pos["caches"] = self._cache_id_gen.get_current_token()
-        return pos
+            return self._cache_id_gen.get_current_token()
+        else:
+            return 0
 
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "caches":
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index ebe94909cb..65e54b1c71 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedAccountDataStore, self).stream_positions()
-        position = self._account_data_id_gen.get_current_token()
-        result["user_account_data"] = position
-        result["room_account_data"] = position
-        result["tag_account_data"] = position
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "tag_account_data":
             self._account_data_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 0c237c6e0f..c923751e50 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    def stream_positions(self):
-        result = super(SlavedDeviceInboxStore, self).stream_positions()
-        result["to_device"] = self._device_inbox_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "to_device":
             self._device_inbox_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..58fb0eaae3 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(
@@ -42,36 +48,30 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
-    def stream_positions(self):
-        result = super(SlavedDeviceStore, self).stream_positions()
-        # The user signature stream uses the same stream ID generator as the
-        # device list stream, so set them both to the device list ID
-        # generator's current token.
-        current_token = self._device_list_id_gen.get_current_token()
-        result[DeviceListsStream.NAME] = current_token
-        result[UserSignatureStream.NAME] = current_token
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
-            for row in rows:
-                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+            self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )
 
-    def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(user_id, token)
-
-        if destination:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
+    def _invalidate_caches_for_devices(self, token, rows):
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate_many((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
 
-        self.get_cached_devices_for_user.invalidate((user_id,))
-        self._get_cached_user_device.invalidate_many((user_id,))
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index e73342c657..15011259df 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,12 +93,6 @@ class SlavedEventStore(
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedEventStore, self).stream_positions()
-        result["events"] = self._stream_id_gen.get_current_token()
-        result["backfill"] = -self._backfill_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "events":
             self._stream_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 2d4fd08cf5..01bcf0e882 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedGroupServerStore, self).stream_positions()
-        result["groups"] = self._group_updates_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "groups":
             self._group_updates_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index ad8f0c15a9..fae3125072 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedPresenceStore, self).stream_positions()
-
-        if self.hs.config.use_presence:
-            position = self._presence_id_gen.get_current_token()
-            result["presence"] = position
-
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index eebd5a1fb6..6138796da4 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedPushRuleStore, self).stream_positions()
-        result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "push_rules":
             self._push_rules_stream_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index f22c2d44a3..67be337945 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -28,10 +28,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
-    def stream_positions(self):
-        result = super(SlavedPusherStore, self).stream_positions()
-        result["pushers"] = self._pushers_id_gen.get_current_token()
-        return result
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "pushers":
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index d40dc6e1f5..993432edcb 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedReceiptsStore, self).stream_positions()
-        result["receipts"] = self._receipts_id_gen.get_current_token()
-        return result
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate_many((room_id,))
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 3a20f45316..10dda8708f 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(RoomStore, self).stream_positions()
-        result["public_rooms"] = self._public_room_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
diff --git a/synapse/replication/tcp/__init__.py b/synapse/replication/tcp/__init__.py
index 81c2ea7ee9..523a1358d4 100644
--- a/synapse/replication/tcp/__init__.py
+++ b/synapse/replication/tcp/__init__.py
@@ -20,11 +20,31 @@ Further details can be found in docs/tcp_replication.rst
 
 
 Structure of the module:
- * client.py   - the client classes used for workers to connect to master
+ * handler.py  - the classes used to handle sending/receiving commands to
+                 replication
  * command.py  - the definitions of all the valid commands
- * protocol.py - contains bot the client and server protocol implementations,
-                 these should not be used directly
- * resource.py - the server classes that accepts and handle client connections
- * streams.py  - the definitons of all the valid streams
+ * protocol.py - the TCP protocol classes
+ * resource.py - handles streaming stream updates to replications
+ * streams/    - the definitons of all the valid streams
 
+
+The general interaction of the classes are:
+
+        +---------------------+
+        | ReplicationStreamer |
+        +---------------------+
+                    |
+                    v
+        +---------------------------+     +----------------------+
+        | ReplicationCommandHandler |---->|ReplicationDataHandler|
+        +---------------------------+     +----------------------+
+                    | ^
+                    v |
+            +-------------+
+            | Protocols   |
+            | (TCP/redis) |
+            +-------------+
+
+Where the ReplicationDataHandler (or subclasses) handles incoming stream
+updates.
 """
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..3bbf3c3569 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,45 +16,40 @@
 """
 
 import logging
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING
 
-from twisted.internet import defer
 from twisted.internet.protocol import ReconnectingClientFactory
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.protocol import (
-    AbstractReplicationClientHandler,
-    ClientReplicationStreamProtocol,
-)
-
-from .commands import (
-    Command,
-    FederationAckCommand,
-    InvalidateCacheCommand,
-    RemoteServerUpCommand,
-    RemovePusherCommand,
-    UserIpCommand,
-    UserSyncCommand,
-)
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
 
 logger = logging.getLogger(__name__)
 
 
-class ReplicationClientFactory(ReconnectingClientFactory):
+class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
     """Factory for building connections to the master. Will reconnect if the
     connection is lost.
 
-    Accepts a handler that will be called when new data is available or data
-    is required.
+    Accepts a handler that is passed to `ClientReplicationStreamProtocol`.
     """
 
     initialDelay = 0.1
     maxDelay = 1  # Try at least once every N seconds
 
-    def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        client_name: str,
+        command_handler: "ReplicationCommandHandler",
+    ):
         self.client_name = client_name
-        self.handler = handler
+        self.command_handler = command_handler
         self.server_name = hs.config.server_name
+        self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +60,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def buildProtocol(self, addr):
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
-            self.client_name, self.server_name, self._clock, self.handler
+            self.hs,
+            self.client_name,
+            self.server_name,
+            self._clock,
+            self.command_handler,
         )
 
     def clientConnectionLost(self, connector, reason):
@@ -77,168 +76,34 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
 
 
-class ReplicationClientHandler(AbstractReplicationClientHandler):
-    """A base handler that can be passed to the ReplicationClientFactory.
+class ReplicationDataHandler:
+    """Handles incoming stream updates from replication.
 
-    By default proxies incoming replication data to the SlaveStore.
+    This instance notifies the slave data store about updates. Can be subclassed
+    to handle updates in additional ways.
     """
 
     def __init__(self, store: BaseSlavedStore):
         self.store = store
 
-        # The current connection. None if we are currently (re)connecting
-        self.connection = None
-
-        # Any pending commands to be sent once a new connection has been
-        # established
-        self.pending_commands = []  # type: List[Command]
-
-        # Map from string -> deferred, to wake up when receiveing a SYNC with
-        # the given string.
-        # Used for tests.
-        self.awaiting_syncs = {}  # type: Dict[str, defer.Deferred]
-
-        # The factory used to create connections.
-        self.factory = None  # type: Optional[ReplicationClientFactory]
-
-    def start_replication(self, hs):
-        """Helper method to start a replication connection to the remote server
-        using TCP.
-        """
-        client_name = hs.config.worker_name
-        self.factory = ReplicationClientFactory(hs, client_name, self)
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_port
-        hs.get_reactor().connectTCP(host, port, self.factory)
-
-    async def on_rdata(self, stream_name, token, rows):
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
         handle more.
 
         Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
+            stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
         self.store.process_replication_rows(stream_name, token, rows)
 
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data. By default this just pokes
-        the slave store.
-
-        Can be overriden in subclasses to handle more.
-        """
+    async def on_position(self, stream_name: str, token: int):
         self.store.process_replication_rows(stream_name, token, [])
 
-    def on_sync(self, data):
-        """When we received a SYNC we wake up any deferreds that were waiting
-        for the sync with the given data.
-
-        Used by tests.
-        """
-        d = self.awaiting_syncs.pop(data, None)
-        if d:
-            d.callback(data)
-
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
-
-    def get_streams_to_replicate(self) -> Dict[str, int]:
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        args = self.store.stream_positions()
-        user_account_data = args.pop("user_account_data", None)
-        room_account_data = args.pop("room_account_data", None)
-        if user_account_data:
-            args["account_data"] = user_account_data
-        elif room_account_data:
-            args["account_data"] = room_account_data
-
-        return args
-
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users. (Overriden by the synchrotron's only)
-        """
-        return []
-
-    def send_command(self, cmd):
-        """Send a command to master (when we get establish a connection if we
-        don't have one already.)
-        """
-        if self.connection:
-            self.connection.send_command(cmd)
-        else:
-            logger.warning("Queuing command as not connected: %r", cmd.NAME)
-            self.pending_commands.append(cmd)
-
-    def send_federation_ack(self, token):
-        """Ack data for the federation stream. This allows the master to drop
-        data stored purely in memory.
-        """
-        self.send_command(FederationAckCommand(token))
-
-    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
-        """Poke the master that a user has started/stopped syncing.
-        """
-        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
-
-    def send_remove_pusher(self, app_id, push_key, user_id):
-        """Poke the master to remove a pusher for a user
-        """
-        cmd = RemovePusherCommand(app_id, push_key, user_id)
-        self.send_command(cmd)
-
-    def send_invalidate_cache(self, cache_func, keys):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
-    def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
-        """Tell the master that the user made a request.
-        """
-        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
-        self.send_command(cmd)
-
-    def send_remote_server_up(self, server: str):
-        self.send_command(RemoteServerUpCommand(server))
-
-    def await_sync(self, data):
-        """Returns a deferred that is resolved when we receive a SYNC command
-        with given data.
-
-        [Not currently] used by tests.
-        """
-        return self.awaiting_syncs.setdefault(data, defer.Deferred())
-
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        self.connection = connection
-        if connection:
-            for cmd in self.pending_commands:
-                connection.send_command(cmd)
-            self.pending_commands = []
-
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        logger.info("Finished connecting to server")
-
-        # We don't reset the delay any earlier as otherwise if there is a
-        # problem during start up we'll end up tight looping connecting to the
-        # server.
-        if self.factory:
-            self.factory.resetDelay()
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d..f58e384d17 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -17,7 +17,7 @@
 The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
 allowed to be sent by which side.
 """
-
+import abc
 import logging
 import platform
 from typing import Tuple, Type
@@ -34,34 +34,29 @@ else:
 logger = logging.getLogger(__name__)
 
 
-class Command(object):
+class Command(metaclass=abc.ABCMeta):
     """The base command class.
 
     All subclasses must set the NAME variable which equates to the name of the
     command on the wire.
 
     A full command line on the wire is constructed from `NAME + " " + to_line()`
-
-    The default implementation creates a command of form `<NAME> <data>`
     """
 
     NAME = None  # type: str
 
-    def __init__(self, data):
-        self.data = data
-
     @classmethod
+    @abc.abstractmethod
     def from_line(cls, line):
         """Deserialises a line from the wire into this command. `line` does not
         include the command.
         """
-        return cls(line)
 
-    def to_line(self):
+    @abc.abstractmethod
+    def to_line(self) -> str:
         """Serialises the comamnd for the wire. Does not include the command
         prefix.
         """
-        return self.data
 
     def get_logcontext_id(self):
         """Get a suitable string for the logcontext when processing this command"""
@@ -70,7 +65,21 @@ class Command(object):
         return self.NAME
 
 
-class ServerCommand(Command):
+class _SimpleCommand(Command):
+    """An implementation of Command whose argument is just a 'data' string."""
+
+    def __init__(self, data):
+        self.data = data
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self) -> str:
+        return self.data
+
+
+class ServerCommand(_SimpleCommand):
     """Sent by the server on new connection and includes the server_name.
 
     Format::
@@ -86,7 +95,7 @@ class RdataCommand(Command):
 
     Format::
 
-        RDATA <stream_name> <token> <row_json>
+        RDATA <stream_name> <instance_name> <token> <row_json>
 
     The `<token>` may either be a numeric stream id OR "batch". The latter case
     is used to support sending multiple updates with the same stream ID. This
@@ -96,33 +105,40 @@ class RdataCommand(Command):
     The client should batch all incoming RDATA with a token of "batch" (per
     stream_name) until it sees an RDATA with a numeric stream ID.
 
+    The `<instance_name>` is the source of the new data (usually "master").
+
     `<token>` of "batch" maps to the instance variable `token` being None.
 
     An example of a batched series of RDATA::
 
-        RDATA presence batch ["@foo:example.com", "online", ...]
-        RDATA presence batch ["@bar:example.com", "online", ...]
-        RDATA presence 59 ["@baz:example.com", "online", ...]
+        RDATA presence master batch ["@foo:example.com", "online", ...]
+        RDATA presence master batch ["@bar:example.com", "online", ...]
+        RDATA presence master 59 ["@baz:example.com", "online", ...]
     """
 
     NAME = "RDATA"
 
-    def __init__(self, stream_name, token, row):
+    def __init__(self, stream_name, instance_name, token, row):
         self.stream_name = stream_name
+        self.instance_name = instance_name
         self.token = token
         self.row = row
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token, row_json = line.split(" ", 2)
+        stream_name, instance_name, token, row_json = line.split(" ", 3)
         return cls(
-            stream_name, None if token == "batch" else int(token), json.loads(row_json)
+            stream_name,
+            instance_name,
+            None if token == "batch" else int(token),
+            json.loads(row_json),
         )
 
     def to_line(self):
         return " ".join(
             (
                 self.stream_name,
+                self.instance_name,
                 str(self.token) if self.token is not None else "batch",
                 _json_encoder.encode(self.row),
             )
@@ -136,26 +152,34 @@ class PositionCommand(Command):
     """Sent by the server to tell the client the stream postition without
     needing to send an RDATA.
 
-    Sent to the client after all missing updates for a stream have been sent
-    to the client and they're now up to date.
+    Format::
+
+        POSITION <stream_name> <instance_name> <token>
+
+    On receipt of a POSITION command clients should check if they have missed
+    any updates, and if so then fetch them out of band.
+
+    The `<instance_name>` is the process that sent the command and is the source
+    of the stream.
     """
 
     NAME = "POSITION"
 
-    def __init__(self, stream_name, token):
+    def __init__(self, stream_name, instance_name, token):
         self.stream_name = stream_name
+        self.instance_name = instance_name
         self.token = token
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        return cls(stream_name, int(token))
+        stream_name, instance_name, token = line.split(" ", 2)
+        return cls(stream_name, instance_name, int(token))
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
+        return " ".join((self.stream_name, self.instance_name, str(self.token)))
 
 
-class ErrorCommand(Command):
+class ErrorCommand(_SimpleCommand):
     """Sent by either side if there was an ERROR. The data is a string describing
     the error.
     """
@@ -163,14 +187,14 @@ class ErrorCommand(Command):
     NAME = "ERROR"
 
 
-class PingCommand(Command):
+class PingCommand(_SimpleCommand):
     """Sent by either side as a keep alive. The data is arbitary (often timestamp)
     """
 
     NAME = "PING"
 
 
-class NameCommand(Command):
+class NameCommand(_SimpleCommand):
     """Sent by client to inform the server of the client's identity. The data
     is the name
     """
@@ -179,76 +203,63 @@ class NameCommand(Command):
 
 
 class ReplicateCommand(Command):
-    """Sent by the client to subscribe to the stream.
+    """Sent by the client to subscribe to streams.
 
     Format::
 
-        REPLICATE <stream_name> <token>
-
-    Where <token> may be either:
-        * a numeric stream_id to stream updates from
-        * "NOW" to stream all subsequent updates.
-
-    The <stream_name> can be "ALL" to subscribe to all known streams, in which
-    case the <token> must be set to "NOW", i.e.::
-
-        REPLICATE ALL NOW
+        REPLICATE
     """
 
     NAME = "REPLICATE"
 
-    def __init__(self, stream_name, token):
-        self.stream_name = stream_name
-        self.token = token
+    def __init__(self):
+        pass
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        if token in ("NOW", "now"):
-            token = "NOW"
-        else:
-            token = int(token)
-        return cls(stream_name, token)
+        return cls()
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
-
-    def get_logcontext_id(self):
-        return "REPLICATE-" + self.stream_name
+        return ""
 
 
 class UserSyncCommand(Command):
     """Sent by the client to inform the server that a user has started or
-    stopped syncing. Used to calculate presence on the master.
+    stopped syncing on this process.
+
+    This is used by the process handling presence (typically the master) to
+    calculate who is online and who is not.
 
     Includes a timestamp of when the last user sync was.
 
     Format::
 
-        USER_SYNC <user_id> <state> <last_sync_ms>
+        USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
 
-    Where <state> is either "start" or "stop"
+    Where <state> is either "start" or "end"
     """
 
     NAME = "USER_SYNC"
 
-    def __init__(self, user_id, is_syncing, last_sync_ms):
+    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+        self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms
 
     @classmethod
     def from_line(cls, line):
-        user_id, state, last_sync_ms = line.split(" ", 2)
+        instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
 
         if state not in ("start", "end"):
             raise Exception("Invalid USER_SYNC state %r" % (state,))
 
-        return cls(user_id, state == "start", int(last_sync_ms))
+        return cls(instance_id, user_id, state == "start", int(last_sync_ms))
 
     def to_line(self):
         return " ".join(
             (
+                self.instance_id,
                 self.user_id,
                 "start" if self.is_syncing else "end",
                 str(self.last_sync_ms),
@@ -256,6 +267,30 @@ class UserSyncCommand(Command):
         )
 
 
+class ClearUserSyncsCommand(Command):
+    """Sent by the client to inform the server that it should drop all
+    information about syncing users sent by the client.
+
+    Mainly used when client is about to shut down.
+
+    Format::
+
+        CLEAR_USER_SYNC <instance_id>
+    """
+
+    NAME = "CLEAR_USER_SYNC"
+
+    def __init__(self, instance_id):
+        self.instance_id = instance_id
+
+    @classmethod
+    def from_line(cls, line):
+        return cls(line)
+
+    def to_line(self):
+        return self.instance_id
+
+
 class FederationAckCommand(Command):
     """Sent by the client when it has processed up to a given point in the
     federation stream. This allows the master to drop in-memory caches of the
@@ -281,14 +316,6 @@ class FederationAckCommand(Command):
         return str(self.token)
 
 
-class SyncCommand(Command):
-    """Used for testing. The client protocol implementation allows waiting
-    on a SYNC command with a specified data.
-    """
-
-    NAME = "SYNC"
-
-
 class RemovePusherCommand(Command):
     """Sent by the client to request the master remove the given pusher.
 
@@ -387,7 +414,7 @@ class UserIpCommand(Command):
         )
 
 
-class RemoteServerUpCommand(Command):
+class RemoteServerUpCommand(_SimpleCommand):
     """Sent when a worker has detected that a remote server is no longer
     "down" and retry timings should be reset.
 
@@ -411,11 +438,11 @@ _COMMANDS = (
     ReplicateCommand,
     UserSyncCommand,
     FederationAckCommand,
-    SyncCommand,
     RemovePusherCommand,
     InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
+    ClearUserSyncsCommand,
 )  # type: Tuple[Type[Command], ...]
 
 # Map of command name to command type.
@@ -428,7 +455,6 @@ VALID_SERVER_COMMANDS = (
     PositionCommand.NAME,
     ErrorCommand.NAME,
     PingCommand.NAME,
-    SyncCommand.NAME,
     RemoteServerUpCommand.NAME,
 )
 
@@ -438,6 +464,7 @@ VALID_CLIENT_COMMANDS = (
     ReplicateCommand.NAME,
     PingCommand.NAME,
     UserSyncCommand.NAME,
+    ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
     InvalidateCacheCommand.NAME,
@@ -445,3 +472,21 @@ VALID_CLIENT_COMMANDS = (
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
 )
+
+
+def parse_command_from_line(line: str) -> Command:
+    """Parses a command from a received line.
+
+    Line should already be stripped of whitespace and be checked if blank.
+    """
+
+    idx = line.find(" ")
+    if idx >= 0:
+        cmd_name = line[:idx]
+        rest_of_line = line[idx + 1 :]
+    else:
+        cmd_name = line
+        rest_of_line = ""
+
+    cmd_cls = COMMAND_MAP[cmd_name]
+    return cmd_cls.from_line(rest_of_line)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
new file mode 100644
index 0000000000..4328b38e9d
--- /dev/null
+++ b/synapse/replication/tcp/handler.py
@@ -0,0 +1,569 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
+
+from prometheus_client import Counter
+
+from twisted.internet.protocol import ReconnectingClientFactory
+
+from synapse.metrics import LaterGauge
+from synapse.replication.tcp.client import DirectTcpReplicationClientFactory
+from synapse.replication.tcp.commands import (
+    ClearUserSyncsCommand,
+    Command,
+    FederationAckCommand,
+    InvalidateCacheCommand,
+    PositionCommand,
+    RdataCommand,
+    RemoteServerUpCommand,
+    RemovePusherCommand,
+    ReplicateCommand,
+    UserIpCommand,
+    UserSyncCommand,
+)
+from synapse.replication.tcp.protocol import AbstractConnection
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.util.async_helpers import Linearizer
+
+logger = logging.getLogger(__name__)
+
+
+# number of updates received for each RDATA stream
+inbound_rdata_count = Counter(
+    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
+)
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter(
+    "synapse_replication_tcp_resource_invalidate_cache", ""
+)
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+
+
+class ReplicationCommandHandler:
+    """Handles incoming commands from replication as well as sending commands
+    back out to connections.
+    """
+
+    def __init__(self, hs):
+        self._replication_data_handler = hs.get_replication_data_handler()
+        self._presence_handler = hs.get_presence_handler()
+        self._store = hs.get_datastore()
+        self._notifier = hs.get_notifier()
+        self._clock = hs.get_clock()
+        self._instance_id = hs.get_instance_id()
+        self._instance_name = hs.get_instance_name()
+
+        self._streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
+        self._position_linearizer = Linearizer(
+            "replication_position", clock=self._clock
+        )
+
+        # Map of stream to batched updates. See RdataCommand for info on how
+        # batching works.
+        self._pending_batches = {}  # type: Dict[str, List[Any]]
+
+        # The factory used to create connections.
+        self._factory = None  # type: Optional[ReconnectingClientFactory]
+
+        # The currently connected connections. (The list of places we need to send
+        # outgoing replication commands to.)
+        self._connections = []  # type: List[AbstractConnection]
+
+        # For each connection, the incoming streams that are coming from that connection
+        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+
+        LaterGauge(
+            "synapse_replication_tcp_resource_total_connections",
+            "",
+            [],
+            lambda: len(self._connections),
+        )
+
+        self._is_master = hs.config.worker_app is None
+
+        self._federation_sender = None
+        if self._is_master and not hs.config.send_federation:
+            self._federation_sender = hs.get_federation_sender()
+
+        self._server_notices_sender = None
+        if self._is_master:
+            self._server_notices_sender = hs.get_server_notices_sender()
+
+    def start_replication(self, hs):
+        """Helper method to start a replication connection to the remote server
+        using TCP.
+        """
+        if hs.config.redis.redis_enabled:
+            from synapse.replication.tcp.redis import (
+                RedisDirectTcpReplicationClientFactory,
+            )
+            import txredisapi
+
+            logger.info(
+                "Connecting to redis (host=%r port=%r)",
+                hs.config.redis_host,
+                hs.config.redis_port,
+            )
+
+            # We need two connections to redis, one for the subscription stream and
+            # one to send commands to (as you can't send further redis commands to a
+            # connection after SUBSCRIBE is called).
+
+            # First create the connection for sending commands.
+            outbound_redis_connection = txredisapi.lazyConnection(
+                host=hs.config.redis_host,
+                port=hs.config.redis_port,
+                password=hs.config.redis.redis_password,
+                reconnect=True,
+            )
+
+            # Now create the factory/connection for the subscription stream.
+            self._factory = RedisDirectTcpReplicationClientFactory(
+                hs, outbound_redis_connection
+            )
+            hs.get_reactor().connectTCP(
+                hs.config.redis.redis_host, hs.config.redis.redis_port, self._factory,
+            )
+        else:
+            client_name = hs.get_instance_name()
+            self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
+            host = hs.config.worker_replication_host
+            port = hs.config.worker_replication_port
+            hs.get_reactor().connectTCP(host, port, self._factory)
+
+    async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
+        # We only want to announce positions by the writer of the streams.
+        # Currently this is just the master process.
+        if not self._is_master:
+            return
+
+        for stream_name, stream in self._streams.items():
+            current_token = stream.current_token()
+            self.send_command(
+                PositionCommand(stream_name, self._instance_name, current_token)
+            )
+
+    async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
+        user_sync_counter.inc()
+
+        if self._is_master:
+            await self._presence_handler.update_external_syncs_row(
+                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+            )
+
+    async def on_CLEAR_USER_SYNC(
+        self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
+    ):
+        if self._is_master:
+            await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
+
+    async def on_FEDERATION_ACK(
+        self, conn: AbstractConnection, cmd: FederationAckCommand
+    ):
+        federation_ack_counter.inc()
+
+        if self._federation_sender:
+            self._federation_sender.federation_ack(cmd.token)
+
+    async def on_REMOVE_PUSHER(
+        self, conn: AbstractConnection, cmd: RemovePusherCommand
+    ):
+        remove_pusher_counter.inc()
+
+        if self._is_master:
+            await self._store.delete_pusher_by_app_id_pushkey_user_id(
+                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+            )
+
+            self._notifier.on_new_replication_data()
+
+    async def on_INVALIDATE_CACHE(
+        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
+    ):
+        invalidate_cache_counter.inc()
+
+        if self._is_master:
+            # We invalidate the cache locally, but then also stream that to other
+            # workers.
+            await self._store.invalidate_cache_and_stream(
+                cmd.cache_func, tuple(cmd.keys)
+            )
+
+    async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
+        user_ip_cache_counter.inc()
+
+        if self._is_master:
+            await self._store.insert_client_ip(
+                cmd.user_id,
+                cmd.access_token,
+                cmd.ip,
+                cmd.user_agent,
+                cmd.device_id,
+                cmd.last_seen,
+            )
+
+        if self._server_notices_sender:
+            await self._server_notices_sender.on_user_ip(cmd.user_id)
+
+    async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
+        if cmd.instance_name == self._instance_name:
+            # Ignore RDATA that are just our own echoes
+            return
+
+        stream_name = cmd.stream_name
+        inbound_rdata_count.labels(stream_name).inc()
+
+        try:
+            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+        except Exception:
+            logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
+            raise
+
+        # We linearize here for two reasons:
+        #   1. so we don't try and concurrently handle multiple rows for the
+        #      same stream, and
+        #   2. so we don't race with getting a POSITION command and fetching
+        #      missing RDATA.
+        with await self._position_linearizer.queue(cmd.stream_name):
+            # make sure that we've processed a POSITION for this stream *on this
+            # connection*. (A POSITION on another connection is no good, as there
+            # is no guarantee that we have seen all the intermediate updates.)
+            sbc = self._streams_by_connection.get(conn)
+            if not sbc or stream_name not in sbc:
+                # Let's drop the row for now, on the assumption we'll receive a
+                # `POSITION` soon and we'll catch up correctly then.
+                logger.debug(
+                    "Discarding RDATA for unconnected stream %s -> %s",
+                    stream_name,
+                    cmd.token,
+                )
+                return
+
+            if cmd.token is None:
+                # I.e. this is part of a batch of updates for this stream (in
+                # which case batch until we get an update for the stream with a non
+                # None token).
+                self._pending_batches.setdefault(stream_name, []).append(row)
+            else:
+                # Check if this is the last of a batch of updates
+                rows = self._pending_batches.pop(stream_name, [])
+                rows.append(row)
+                await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
+
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
+        """Called to handle a batch of replication data with a given stream token.
+
+        Args:
+            stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        logger.debug("Received rdata %s -> %s", stream_name, token)
+        await self._replication_data_handler.on_rdata(
+            stream_name, instance_name, token, rows
+        )
+
+    async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
+        if cmd.instance_name == self._instance_name:
+            # Ignore POSITION that are just our own echoes
+            return
+
+        logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
+
+        stream_name = cmd.stream_name
+        stream = self._streams.get(stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", stream_name)
+            return
+
+        # We protect catching up with a linearizer in case the replication
+        # connection reconnects under us.
+        with await self._position_linearizer.queue(stream_name):
+            # We're about to go and catch up with the stream, so remove from set
+            # of connected streams.
+            for streams in self._streams_by_connection.values():
+                streams.discard(stream_name)
+
+            # We clear the pending batches for the stream as the fetching of the
+            # missing updates below will fetch all rows in the batch.
+            self._pending_batches.pop(stream_name, [])
+
+            # Find where we previously streamed up to.
+            current_token = stream.current_token()
+
+            # If the position token matches our current token then we're up to
+            # date and there's nothing to do. Otherwise, fetch all updates
+            # between then and now.
+            missing_updates = cmd.token != current_token
+            while missing_updates:
+                logger.info(
+                    "Fetching replication rows for '%s' between %i and %i",
+                    stream_name,
+                    current_token,
+                    cmd.token,
+                )
+                (
+                    updates,
+                    current_token,
+                    missing_updates,
+                ) = await stream.get_updates_since(
+                    cmd.instance_name, current_token, cmd.token
+                )
+
+                # TODO: add some tests for this
+
+                # Some streams return multiple rows with the same stream IDs,
+                # which need to be processed in batches.
+
+                for token, rows in _batch_updates(updates):
+                    await self.on_rdata(
+                        stream_name,
+                        cmd.instance_name,
+                        token,
+                        [stream.parse_row(row) for row in rows],
+                    )
+
+            logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
+            # We've now caught up to position sent to us, notify handler.
+            await self._replication_data_handler.on_position(stream_name, cmd.token)
+
+            self._streams_by_connection.setdefault(conn, set()).add(stream_name)
+
+    async def on_REMOTE_SERVER_UP(
+        self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+    ):
+        """"Called when get a new REMOTE_SERVER_UP command."""
+        self._replication_data_handler.on_remote_server_up(cmd.data)
+
+        self._notifier.notify_remote_server_up(cmd.data)
+
+        # We relay to all other connections to ensure every instance gets the
+        # notification.
+        #
+        # When configured to use redis we'll always only have one connection and
+        # so this is a no-op (all instances will have already received the same
+        # REMOTE_SERVER_UP command).
+        #
+        # For direct TCP connections this will relay to all other connections
+        # connected to us. When on master this will correctly fan out to all
+        # other direct TCP clients and on workers there'll only be the one
+        # connection to master.
+        #
+        # (The logic here should also be sound if we have a mix of Redis and
+        # direct TCP connections so long as there is only one traffic route
+        # between two instances, but that is not currently supported).
+        self.send_command(cmd, ignore_conn=conn)
+
+    def new_connection(self, connection: AbstractConnection):
+        """Called when we have a new connection.
+        """
+        self._connections.append(connection)
+
+        # If we are connected to replication as a client (rather than a server)
+        # we need to reset the reconnection delay on the client factory (which
+        # is used to do exponential back off when the connection drops).
+        #
+        # Ideally we would reset the delay when we've "fully established" the
+        # connection (for some definition thereof) to stop us from tightlooping
+        # on reconnection if something fails after this point and we drop the
+        # connection. Unfortunately, we don't really have a better definition of
+        # "fully established" than the connection being established.
+        if self._factory:
+            self._factory.resetDelay()
+
+        # Tell the other end if we have any users currently syncing.
+        currently_syncing = (
+            self._presence_handler.get_currently_syncing_users_for_replication()
+        )
+
+        now = self._clock.time_msec()
+        for user_id in currently_syncing:
+            connection.send_command(
+                UserSyncCommand(self._instance_id, user_id, True, now)
+            )
+
+    def lost_connection(self, connection: AbstractConnection):
+        """Called when a connection is closed/lost.
+        """
+        # we no longer need _streams_by_connection for this connection.
+        streams = self._streams_by_connection.pop(connection, None)
+        if streams:
+            logger.info(
+                "Lost replication connection; streams now disconnected: %s", streams
+            )
+        try:
+            self._connections.remove(connection)
+        except ValueError:
+            pass
+
+    def connected(self) -> bool:
+        """Do we have any replication connections open?
+
+        Is used by e.g. `ReplicationStreamer` to no-op if nothing is connected.
+        """
+        return bool(self._connections)
+
+    def send_command(
+        self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
+    ):
+        """Send a command to all connected connections.
+
+        Args:
+            cmd
+            ignore_conn: If set don't send command to the given connection.
+                Used when relaying commands from one connection to all others.
+        """
+        if self._connections:
+            for connection in self._connections:
+                if connection == ignore_conn:
+                    continue
+
+                try:
+                    connection.send_command(cmd)
+                except Exception:
+                    # We probably want to catch some types of exceptions here
+                    # and log them as warnings (e.g. connection gone), but I
+                    # can't find what those exception types they would be.
+                    logger.exception(
+                        "Failed to write command %s to connection %s",
+                        cmd.NAME,
+                        connection,
+                    )
+        else:
+            logger.warning("Dropping command as not connected: %r", cmd.NAME)
+
+    def send_federation_ack(self, token: int):
+        """Ack data for the federation stream. This allows the master to drop
+        data stored purely in memory.
+        """
+        self.send_command(FederationAckCommand(token))
+
+    def send_user_sync(
+        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+    ):
+        """Poke the master that a user has started/stopped syncing.
+        """
+        self.send_command(
+            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+        )
+
+    def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
+        """Poke the master to remove a pusher for a user
+        """
+        cmd = RemovePusherCommand(app_id, push_key, user_id)
+        self.send_command(cmd)
+
+    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
+        """Poke the master to invalidate a cache.
+        """
+        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
+        self.send_command(cmd)
+
+    def send_user_ip(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: str,
+        last_seen: int,
+    ):
+        """Tell the master that the user made a request.
+        """
+        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+        self.send_command(cmd)
+
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
+    def stream_update(self, stream_name: str, token: str, data: Any):
+        """Called when a new update is available to stream to clients.
+
+        We need to check if the client is interested in the stream or not
+        """
+        self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
+
+
+UpdateToken = TypeVar("UpdateToken")
+UpdateRow = TypeVar("UpdateRow")
+
+
+def _batch_updates(
+    updates: Iterable[Tuple[UpdateToken, UpdateRow]]
+) -> Iterator[Tuple[UpdateToken, List[UpdateRow]]]:
+    """Collect stream updates with the same token together
+
+    Given a series of updates returned by Stream.get_updates_since(), collects
+    the updates which share the same stream_id together.
+
+    For example:
+
+        [(1, a), (1, b), (2, c), (3, d), (3, e)]
+
+    becomes:
+
+        [
+            (1, [a, b]),
+            (2, [c]),
+            (3, [d, e]),
+        ]
+    """
+
+    update_iter = iter(updates)
+
+    first_update = next(update_iter, None)
+    if first_update is None:
+        # empty input
+        return
+
+    current_batch_token = first_update[0]
+    current_batch = [first_update[1]]
+
+    for token, row in update_iter:
+        if token != current_batch_token:
+            # different token to the previous row: flush the previous
+            # batch and start anew
+            yield current_batch_token, current_batch
+            current_batch_token = token
+            current_batch = []
+
+        current_batch.append(row)
+
+    # flush the final batch
+    yield current_batch_token, current_batch
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..4198eece71 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
     > POSITION events 1
     > POSITION backfill 1
     > POSITION caches 1
@@ -52,45 +50,51 @@ import abc
 import fcntl
 import logging
 import struct
-from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
-
-from six import iteritems, iterkeys
+from typing import TYPE_CHECKING, List
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
-    COMMAND_MAP,
     VALID_CLIENT_COMMANDS,
     VALID_SERVER_COMMANDS,
     Command,
     ErrorCommand,
     NameCommand,
     PingCommand,
-    PositionCommand,
-    RdataCommand,
-    RemoteServerUpCommand,
     ReplicateCommand,
     ServerCommand,
-    SyncCommand,
-    UserSyncCommand,
+    parse_command_from_line,
 )
-from synapse.replication.tcp.streams import STREAMS_MAP
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
+if TYPE_CHECKING:
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
+    from synapse.server import HomeServer
+
+
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
 
+tcp_inbound_commands_counter = Counter(
+    "synapse_replication_tcp_protocol_inbound_commands",
+    "Number of commands received from replication, by command and name of process connected to",
+    ["command", "name"],
+)
+
+tcp_outbound_commands_counter = Counter(
+    "synapse_replication_tcp_protocol_outbound_commands",
+    "Number of commands sent to replication, by command and name of process connected to",
+    ["command", "name"],
+)
+
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
 connected_connections = []
@@ -119,7 +123,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     are only sent by the server.
 
     On receiving a new command it calls `on_<COMMAND_NAME>` with the parsed
-    command.
+    command before delegating to `ReplicationCommandHandler.on_<COMMAND_NAME>`.
 
     It also sends `PING` periodically, and correctly times out remote connections
     (if they send a `PING` command)
@@ -135,8 +139,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     max_line_buffer = 10000
 
-    def __init__(self, clock):
+    def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"):
         self.clock = clock
+        self.command_handler = handler
 
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
@@ -155,9 +160,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # The LoopingCall for sending pings.
         self._send_ping_loop = None
 
-        self.inbound_commands_counter = defaultdict(int)  # type: DefaultDict[str, int]
-        self.outbound_commands_counter = defaultdict(int)  # type: DefaultDict[str, int]
-
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
 
@@ -176,6 +178,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # can time us out.
         self.send_command(PingCommand(self.clock.time_msec()))
 
+        self.command_handler.new_connection(self)
+
     def send_ping(self):
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
@@ -203,38 +207,30 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 )
                 self.send_error("ping timeout")
 
-    def lineReceived(self, line):
+    def lineReceived(self, line: bytes):
         """Called when we've received a line
         """
         if line.strip() == "":
             # Ignore blank lines
             return
 
-        line = line.decode("utf-8")
-        cmd_name, rest_of_line = line.split(" ", 1)
+        linestr = line.decode("utf-8")
 
-        if cmd_name not in self.VALID_INBOUND_COMMANDS:
-            logger.error("[%s] invalid command %s", self.id(), cmd_name)
-            self.send_error("invalid command: %s", cmd_name)
+        try:
+            cmd = parse_command_from_line(linestr)
+        except Exception as e:
+            logger.exception("[%s] failed to parse line: %r", self.id(), linestr)
+            self.send_error("failed to parse line: %r (%r):" % (e, linestr))
             return
 
-        self.last_received_command = self.clock.time_msec()
+        if cmd.NAME not in self.VALID_INBOUND_COMMANDS:
+            logger.error("[%s] invalid command %s", self.id(), cmd.NAME)
+            self.send_error("invalid command: %s", cmd.NAME)
+            return
 
-        self.inbound_commands_counter[cmd_name] = (
-            self.inbound_commands_counter[cmd_name] + 1
-        )
+        self.last_received_command = self.clock.time_msec()
 
-        cmd_cls = COMMAND_MAP[cmd_name]
-        try:
-            cmd = cmd_cls.from_line(rest_of_line)
-        except Exception as e:
-            logger.exception(
-                "[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
-            )
-            self.send_error(
-                "failed to parse line for  %r: %r (%r):" % (cmd_name, e, rest_of_line)
-            )
-            return
+        tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc()
 
         # Now lets try and call on_<CMD_NAME> function
         run_as_background_process(
@@ -244,13 +240,31 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
     async def handle_command(self, cmd: Command):
         """Handle a command we have received over the replication stream.
 
-        By default delegates to on_<COMMAND>, which should return an awaitable.
+        First calls `self.on_<COMMAND>` if it exists, then calls
+        `self.command_handler.on_<COMMAND>` if it exists. This allows for
+        protocol level handling of commands (e.g. PINGs), before delegating to
+        the handler.
 
         Args:
             cmd: received command
         """
-        handler = getattr(self, "on_%s" % (cmd.NAME,))
-        await handler(cmd)
+        handled = False
+
+        # First call any command handlers on this instance. These are for TCP
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(self, cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
 
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
@@ -282,9 +296,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             self._queue_command(cmd)
             return
 
-        self.outbound_commands_counter[cmd.NAME] = (
-            self.outbound_commands_counter[cmd.NAME] + 1
-        )
+        tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc()
+
         string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
             raise Exception("Unexpected newline in command: %r", string)
@@ -379,6 +392,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.CLOSED
         self.pending_commands = []
 
+        self.command_handler.lost_connection(self)
+
         if self.transport:
             self.transport.unregisterProducer()
 
@@ -405,232 +420,21 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
     VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
 
-    def __init__(self, server_name, clock, streamer):
-        BaseReplicationStreamProtocol.__init__(self, clock)  # Old style class
+    def __init__(
+        self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler"
+    ):
+        super().__init__(clock, handler)
 
         self.server_name = server_name
-        self.streamer = streamer
-
-        # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()  # type: Set[str]
-
-        # The streams the client is currently subscribing to.
-        self.connecting_streams = set()  # type:  Set[str]
-
-        # Map from stream name to list of updates to send once we've finished
-        # subscribing the client to the stream.
-        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
 
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
-        BaseReplicationStreamProtocol.connectionMade(self)
-        self.streamer.new_connection(self)
+        super().connectionMade()
 
     async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
 
-    async def on_USER_SYNC(self, cmd):
-        await self.streamer.on_user_sync(
-            self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
-        )
-
-    async def on_REPLICATE(self, cmd):
-        stream_name = cmd.stream_name
-        token = cmd.token
-
-        if stream_name == "ALL":
-            # Subscribe to all streams we're publishing to.
-            deferreds = [
-                run_in_background(self.subscribe_to_stream, stream, token)
-                for stream in iterkeys(self.streamer.streams_by_name)
-            ]
-
-            await make_deferred_yieldable(
-                defer.gatherResults(deferreds, consumeErrors=True)
-            )
-        else:
-            await self.subscribe_to_stream(stream_name, token)
-
-    async def on_FEDERATION_ACK(self, cmd):
-        self.streamer.federation_ack(cmd.token)
-
-    async def on_REMOVE_PUSHER(self, cmd):
-        await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
-    async def on_INVALIDATE_CACHE(self, cmd):
-        await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.streamer.on_remote_server_up(cmd.data)
-
-    async def on_USER_IP(self, cmd):
-        self.streamer.on_user_ip(
-            cmd.user_id,
-            cmd.access_token,
-            cmd.ip,
-            cmd.user_agent,
-            cmd.device_id,
-            cmd.last_seen,
-        )
-
-    async def subscribe_to_stream(self, stream_name, token):
-        """Subscribe the remote to a stream.
-
-        This invloves checking if they've missed anything and sending those
-        updates down if they have. During that time new updates for the stream
-        are queued and sent once we've sent down any missed updates.
-        """
-        self.replication_streams.discard(stream_name)
-        self.connecting_streams.add(stream_name)
-
-        try:
-            # Get missing updates
-            updates, current_token = await self.streamer.get_stream_updates(
-                stream_name, token
-            )
-
-            # Send all the missing updates
-            for update in updates:
-                token, row = update[0], update[1]
-                self.send_command(RdataCommand(stream_name, token, row))
-
-            # We send a POSITION command to ensure that they have an up to
-            # date token (especially useful if we didn't send any updates
-            # above)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-            # Now we can send any updates that came in while we were subscribing
-            pending_rdata = self.pending_rdata.pop(stream_name, [])
-            updates = []
-            for token, update in pending_rdata:
-                # If the token is null, it is part of a batch update. Batches
-                # are multiple updates that share a single token. To denote
-                # this, the token is set to None for all tokens in the batch
-                # except for the last. If we find a None token, we keep looking
-                # through tokens until we find one that is not None and then
-                # process all previous updates in the batch as if they had the
-                # final token.
-                if token is None:
-                    # Store this update as part of a batch
-                    updates.append(update)
-                    continue
-
-                if token <= current_token:
-                    # This update or batch of updates is older than
-                    # current_token, dismiss it
-                    updates = []
-                    continue
-
-                updates.append(update)
-
-                # Send all updates that are part of this batch with the
-                # found token
-                for update in updates:
-                    self.send_command(RdataCommand(stream_name, token, update))
-
-                # Clear stored updates
-                updates = []
-
-            # They're now fully subscribed
-            self.replication_streams.add(stream_name)
-        except Exception as e:
-            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-            self.send_error("failed to handle replicate: %r", e)
-        finally:
-            self.connecting_streams.discard(stream_name)
-
-    def stream_update(self, stream_name, token, data):
-        """Called when a new update is available to stream to clients.
-
-        We need to check if the client is interested in the stream or not
-        """
-        if stream_name in self.replication_streams:
-            # The client is subscribed to the stream
-            self.send_command(RdataCommand(stream_name, token, data))
-        elif stream_name in self.connecting_streams:
-            # The client is being subscribed to the stream
-            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-            self.pending_rdata.setdefault(stream_name, []).append((token, data))
-        else:
-            # The client isn't subscribed
-            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
-
-    def send_sync(self, data):
-        self.send_command(SyncCommand(data))
-
-    def send_remote_server_up(self, server: str):
-        self.send_command(RemoteServerUpCommand(server))
-
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.streamer.lost_connection(self)
-
-
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
-    """
-    The interface for the handler that should be passed to
-    ClientReplicationStreamProtocol
-    """
-
-    @abc.abstractmethod
-    async def on_rdata(self, stream_name, token, rows):
-        """Called to handle a batch of replication data with a given stream token.
-
-        Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def on_sync(self, data):
-        """Called when get a new SYNC command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_remote_server_up(self, server: str):
-        """Called when get a new REMOTE_SERVER_UP command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_streams_to_replicate(self):
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        raise NotImplementedError()
-
 
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
@@ -638,110 +442,51 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def __init__(
         self,
+        hs: "HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,
-        handler: AbstractReplicationClientHandler,
+        command_handler: "ReplicationCommandHandler",
     ):
-        BaseReplicationStreamProtocol.__init__(self, clock)
+        super().__init__(clock, command_handler)
 
         self.client_name = client_name
         self.server_name = server_name
-        self.handler = handler
-
-        # Set of stream names that have been subscribe to, but haven't yet
-        # caught up with. This is used to track when the client has been fully
-        # connected to the remote.
-        self.streams_connecting = set()  # type: Set[str]
-
-        # Map of stream to batched updates. See RdataCommand for info on how
-        # batching works.
-        self.pending_batches = {}  # type: Dict[str, Any]
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
-        BaseReplicationStreamProtocol.connectionMade(self)
+        super().connectionMade()
 
         # Once we've connected subscribe to the necessary streams
-        for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
-            self.replicate(stream_name, token)
-
-        # Tell the server if we have any users currently syncing (should only
-        # happen on synchrotrons)
-        currently_syncing = self.handler.get_currently_syncing_users()
-        now = self.clock.time_msec()
-        for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(user_id, True, now))
-
-        # We've now finished connecting to so inform the client handler
-        self.handler.update_connection(self)
-
-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
+        self.replicate()
 
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
 
-    async def on_RDATA(self, cmd):
-        stream_name = cmd.stream_name
-        inbound_rdata_count.labels(stream_name).inc()
-
-        try:
-            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
-        except Exception:
-            logger.exception(
-                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
-            )
-            raise
-
-        if cmd.token is None:
-            # I.e. this is part of a batch of updates for this stream. Batch
-            # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(stream_name, [])
-            rows.append(row)
-            await self.handler.on_rdata(stream_name, cmd.token, rows)
-
-    async def on_POSITION(self, cmd):
-        # When we get a `POSITION` command it means we've finished getting
-        # missing updates for the given stream, and are now up to date.
-        self.streams_connecting.discard(cmd.stream_name)
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
+    def replicate(self):
+        """Send the subscription request to the server
+        """
+        logger.info("[%s] Subscribing to replication streams", self.id())
 
-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        self.send_command(ReplicateCommand())
 
-    async def on_SYNC(self, cmd):
-        self.handler.on_sync(cmd.data)
 
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.handler.on_remote_server_up(cmd.data)
+class AbstractConnection(abc.ABC):
+    """An interface for replication connections.
+    """
 
-    def replicate(self, stream_name, token):
-        """Send the subscription request to the server
+    @abc.abstractmethod
+    def send_command(self, cmd: Command):
+        """Send the command down the connection
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r from %r",
-            self.id(),
-            stream_name,
-            token,
-        )
-
-        self.streams_connecting.add(stream_name)
+        pass
 
-        self.send_command(ReplicateCommand(stream_name, token))
 
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.handler.update_connection(None)
+# This tells python that `BaseReplicationStreamProtocol` implements the
+# interface.
+AbstractConnection.register(BaseReplicationStreamProtocol)
 
 
 # The following simply registers metrics for the replication connections
@@ -804,31 +549,3 @@ tcp_transport_kernel_read_buffer = LaterGauge(
         for p in connected_connections
     },
 )
-
-
-tcp_inbound_commands = LaterGauge(
-    "synapse_replication_tcp_protocol_inbound_commands",
-    "",
-    ["command", "name"],
-    lambda: {
-        (k, p.name): count
-        for p in connected_connections
-        for k, count in iteritems(p.inbound_commands_counter)
-    },
-)
-
-tcp_outbound_commands = LaterGauge(
-    "synapse_replication_tcp_protocol_outbound_commands",
-    "",
-    ["command", "name"],
-    lambda: {
-        (k, p.name): count
-        for p in connected_connections
-        for k, count in iteritems(p.outbound_commands_counter)
-    },
-)
-
-# number of updates received for each RDATA stream
-inbound_rdata_count = Counter(
-    "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"]
-)
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
new file mode 100644
index 0000000000..55bfa71dfd
--- /dev/null
+++ b/synapse/replication/tcp/redis.py
@@ -0,0 +1,210 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+import txredisapi
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.tcp.commands import (
+    Command,
+    ReplicateCommand,
+    parse_command_from_line,
+)
+from synapse.replication.tcp.protocol import (
+    AbstractConnection,
+    tcp_inbound_commands_counter,
+    tcp_outbound_commands_counter,
+)
+
+if TYPE_CHECKING:
+    from synapse.replication.tcp.handler import ReplicationCommandHandler
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
+    """Connection to redis subscribed to replication stream.
+
+    This class fulfils two functions:
+
+    (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
+    connection, parsing *incoming* messages into replication commands, and passing them
+    to `ReplicationCommandHandler`
+
+    (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+    onto outbound_redis_connection.
+
+    Due to the vagaries of `txredisapi` we don't want to have a custom
+    constructor, so instead we expect the defined attributes below to be set
+    immediately after initialisation.
+
+    Attributes:
+        handler: The command handler to handle incoming commands.
+        stream_name: The *redis* stream name to subscribe to and publish from
+            (not anything to do with Synapse replication streams).
+        outbound_redis_connection: The connection to redis to use to send
+            commands.
+    """
+
+    handler = None  # type: ReplicationCommandHandler
+    stream_name = None  # type: str
+    outbound_redis_connection = None  # type: txredisapi.RedisProtocol
+
+    def connectionMade(self):
+        logger.info("Connected to redis")
+        super().connectionMade()
+        run_as_background_process("subscribe-replication", self._send_subscribe)
+        self.handler.new_connection(self)
+
+    async def _send_subscribe(self):
+        # it's important to make sure that we only send the REPLICATE command once we
+        # have successfully subscribed to the stream - otherwise we might miss the
+        # POSITION response sent back by the other end.
+        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
+        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info(
+            "Successfully subscribed to redis stream, sending REPLICATE command"
+        )
+        await self._async_send_command(ReplicateCommand())
+        logger.info("REPLICATE successfully sent")
+
+    def messageReceived(self, pattern: str, channel: str, message: str):
+        """Received a message from redis.
+        """
+
+        if message.strip() == "":
+            # Ignore blank lines
+            return
+
+        try:
+            cmd = parse_command_from_line(message)
+        except Exception:
+            logger.exception(
+                "Failed to parse replication line: %r", message,
+            )
+            return
+
+        # We use "redis" as the name here as we don't have 1:1 connections to
+        # remote instances.
+        tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc()
+
+        # Now lets try and call on_<CMD_NAME> function
+        run_as_background_process(
+            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
+        )
+
+    async def handle_command(self, cmd: Command):
+        """Handle a command we have received over the replication stream.
+
+        By default delegates to on_<COMMAND>, which should return an awaitable.
+
+        Args:
+            cmd: received command
+        """
+        handled = False
+
+        # First call any command handlers on this instance. These are for redis
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(self, cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
+
+    def connectionLost(self, reason):
+        logger.info("Lost connection to redis")
+        super().connectionLost(reason)
+        self.handler.lost_connection(self)
+
+    def send_command(self, cmd: Command):
+        """Send a command if connection has been established.
+
+        Args:
+            cmd (Command)
+        """
+        run_as_background_process("send-cmd", self._async_send_command, cmd)
+
+    async def _async_send_command(self, cmd: Command):
+        """Encode a replication command and send it over our outbound connection"""
+        string = "%s %s" % (cmd.NAME, cmd.to_line())
+        if "\n" in string:
+            raise Exception("Unexpected newline in command: %r", string)
+
+        encoded_string = string.encode("utf-8")
+
+        # We use "redis" as the name here as we don't have 1:1 connections to
+        # remote instances.
+        tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
+
+        await make_deferred_yieldable(
+            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+        )
+
+
+class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+    """This is a reconnecting factory that connects to redis and immediately
+    subscribes to a stream.
+
+    Args:
+        hs
+        outbound_redis_connection: A connection to redis that will be used to
+            send outbound commands (this is seperate to the redis connection
+            used to subscribe).
+    """
+
+    maxDelay = 5
+    continueTrying = True
+    protocol = RedisSubscriber
+
+    def __init__(
+        self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
+    ):
+
+        super().__init__()
+
+        # This sets the password on the RedisFactory base class (as
+        # SubscriberFactory constructor doesn't pass it through).
+        self.password = hs.config.redis.redis_password
+
+        self.handler = hs.get_tcp_replication()
+        self.stream_name = hs.hostname
+
+        self.outbound_redis_connection = outbound_redis_connection
+
+    def buildProtocol(self, addr):
+        p = super().buildProtocol(addr)  # type: RedisSubscriber
+
+        # We do this here rather than add to the constructor of `RedisSubcriber`
+        # as to do so would involve overriding `buildProtocol` entirely, however
+        # the base method does some other things than just instantiating the
+        # protocol.
+        p.handler = self.handler
+        p.outbound_redis_connection = self.outbound_redis_connection
+        p.stream_name = self.stream_name
+        p.password = self.password
+
+        return p
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index ce9d1fae12..33d2f589ac 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,32 +17,20 @@
 
 import logging
 import random
-from typing import Any, List
-
-from six import itervalues
+from typing import Dict, List
 
 from prometheus_client import Counter
 
 from twisted.internet.protocol import Factory
 
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.metrics import Measure, measure_func
-
-from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
-from .streams.federation import FederationStream
+from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
+from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
     "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
 )
-user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
-federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
-remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter(
-    "synapse_replication_tcp_resource_invalidate_cache", ""
-)
-user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 logger = logging.getLogger(__name__)
 
@@ -52,13 +40,23 @@ class ReplicationStreamProtocolFactory(Factory):
     """
 
     def __init__(self, hs):
-        self.streamer = ReplicationStreamer(hs)
+        self.command_handler = hs.get_tcp_replication()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name
 
+        # If we've created a `ReplicationStreamProtocolFactory` then we're
+        # almost certainly registering a replication listener, so let's ensure
+        # that we've started a `ReplicationStreamer` instance to actually push
+        # data.
+        #
+        # (This is a bit of a weird place to do this, but the alternatives such
+        # as putting this in `HomeServer.setup()`, requires either passing the
+        # listener config again or always starting a `ReplicationStreamer`.)
+        hs.get_replication_streamer()
+
     def buildProtocol(self, addr):
         return ServerReplicationStreamProtocol(
-            self.server_name, self.clock, self.streamer
+            self.server_name, self.clock, self.command_handler
         )
 
 
@@ -71,67 +69,39 @@ class ReplicationStreamer(object):
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        self.presence_handler = hs.get_presence_handler()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
-        self._server_notices_sender = hs.get_server_notices_sender()
 
         self._replication_torture_level = hs.config.replication_torture_level
 
-        # Current connections.
-        self.connections = []  # type: List[ServerReplicationStreamProtocol]
-
-        LaterGauge(
-            "synapse_replication_tcp_resource_total_connections",
-            "",
-            [],
-            lambda: len(self.connections),
-        )
+        # Work out list of streams that this instance is the source of.
+        self.streams = []  # type: List[Stream]
+        if hs.config.worker_app is None:
+            for stream in STREAMS_MAP.values():
+                if stream == FederationStream and hs.config.send_federation:
+                    # We only support federation stream if federation sending
+                    # hase been disabled on the master.
+                    continue
 
-        # List of streams that clients can subscribe to.
-        # We only support federation stream if federation sending hase been
-        # disabled on the master.
-        self.streams = [
-            stream(hs)
-            for stream in itervalues(STREAMS_MAP)
-            if stream != FederationStream or not hs.config.send_federation
-        ]
+                self.streams.append(stream(hs))
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
 
-        LaterGauge(
-            "synapse_replication_tcp_resource_connections_per_stream",
-            "",
-            ["stream_name"],
-            lambda: {
-                (stream_name,): len(
-                    [
-                        conn
-                        for conn in self.connections
-                        if stream_name in conn.replication_streams
-                    ]
-                )
-                for stream_name in self.streams_by_name
-            },
-        )
-
-        self.federation_sender = None
-        if not hs.config.send_federation:
-            self.federation_sender = hs.get_federation_sender()
-
-        self.notifier.add_replication_callback(self.on_notifier_poke)
-        self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
+        # Only bother registering the notifier callback if we have streams to
+        # publish.
+        if self.streams:
+            self.notifier.add_replication_callback(self.on_notifier_poke)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
         self.pending_updates = False
 
-        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
+        self.command_handler = hs.get_tcp_replication()
 
-    def on_shutdown(self):
-        # close all connections on shutdown
-        for conn in self.connections:
-            conn.send_error("server shutting down")
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a mapp from stream name to stream instance.
+        """
+        return self.streams_by_name
 
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
@@ -140,7 +110,7 @@ class ReplicationStreamer(object):
         This should get called each time new data is available, even if it
         is currently being executed, so that nothing gets missed
         """
-        if not self.connections:
+        if not self.command_handler.connected():
             # Don't bother if nothing is listening. We still need to advance
             # the stream tokens otherwise they'll fall beihind forever
             for stream in self.streams:
@@ -166,11 +136,6 @@ class ReplicationStreamer(object):
                 self.pending_updates = False
 
                 with Measure(self.clock, "repl.stream.get_updates"):
-                    # First we tell the streams that they should update their
-                    # current tokens.
-                    for stream in self.streams:
-                        stream.advance_current_token()
-
                     all_streams = self.streams
 
                     if self._replication_torture_level is not None:
@@ -180,7 +145,7 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.upto_token:
+                        if stream.last_token == stream.current_token():
                             continue
 
                         if self._replication_torture_level:
@@ -192,18 +157,17 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.upto_token,
+                            stream.current_token(),
                         )
                         try:
-                            updates, current_token = await stream.get_updates()
+                            updates, current_token, limited = await stream.get_updates()
+                            self.pending_updates |= limited
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise
 
                         logger.debug(
-                            "Sending %d updates to %d connections",
-                            len(updates),
-                            len(self.connections),
+                            "Sending %d updates", len(updates),
                         )
 
                         if updates:
@@ -219,116 +183,19 @@ class ReplicationStreamer(object):
                         # token. See RdataCommand for more details.
                         batched_updates = _batch_updates(updates)
 
-                        for conn in self.connections:
-                            for token, row in batched_updates:
-                                try:
-                                    conn.stream_update(stream.NAME, token, row)
-                                except Exception:
-                                    logger.exception("Failed to replicate")
+                        for token, row in batched_updates:
+                            try:
+                                self.command_handler.stream_update(
+                                    stream.NAME, token, row
+                                )
+                            except Exception:
+                                logger.exception("Failed to replicate")
 
             logger.debug("No more pending updates, breaking poke loop")
         finally:
             self.pending_updates = False
             self.is_looping = False
 
-    @measure_func("repl.get_stream_updates")
-    async def get_stream_updates(self, stream_name, token):
-        """For a given stream get all updates since token. This is called when
-        a client first subscribes to a stream.
-        """
-        stream = self.streams_by_name.get(stream_name, None)
-        if not stream:
-            raise Exception("unknown stream %s", stream_name)
-
-        return await stream.get_updates_since(token)
-
-    @measure_func("repl.federation_ack")
-    def federation_ack(self, token):
-        """We've received an ack for federation stream from a client.
-        """
-        federation_ack_counter.inc()
-        if self.federation_sender:
-            self.federation_sender.federation_ack(token)
-
-    @measure_func("repl.on_user_sync")
-    async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
-        """A client has started/stopped syncing on a worker.
-        """
-        user_sync_counter.inc()
-        await self.presence_handler.update_external_syncs_row(
-            conn_id, user_id, is_syncing, last_sync_ms
-        )
-
-    @measure_func("repl.on_remove_pusher")
-    async def on_remove_pusher(self, app_id, push_key, user_id):
-        """A client has asked us to remove a pusher
-        """
-        remove_pusher_counter.inc()
-        await self.store.delete_pusher_by_app_id_pushkey_user_id(
-            app_id=app_id, pushkey=push_key, user_id=user_id
-        )
-
-        self.notifier.on_new_replication_data()
-
-    @measure_func("repl.on_invalidate_cache")
-    async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
-        """The client has asked us to invalidate a cache
-        """
-        invalidate_cache_counter.inc()
-
-        # We invalidate the cache locally, but then also stream that to other
-        # workers.
-        await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
-
-    @measure_func("repl.on_user_ip")
-    async def on_user_ip(
-        self, user_id, access_token, ip, user_agent, device_id, last_seen
-    ):
-        """The client saw a user request
-        """
-        user_ip_cache_counter.inc()
-        await self.store.insert_client_ip(
-            user_id, access_token, ip, user_agent, device_id, last_seen
-        )
-        await self._server_notices_sender.on_user_ip(user_id)
-
-    @measure_func("repl.on_remote_server_up")
-    def on_remote_server_up(self, server: str):
-        self.notifier.notify_remote_server_up(server)
-
-    def send_remote_server_up(self, server: str):
-        for conn in self.connections:
-            conn.send_remote_server_up(server)
-
-    def send_sync_to_all_connections(self, data):
-        """Sends a SYNC command to all clients.
-
-        Used in tests.
-        """
-        for conn in self.connections:
-            conn.send_sync(data)
-
-    def new_connection(self, connection):
-        """A new client connection has been established
-        """
-        self.connections.append(connection)
-
-    def lost_connection(self, connection):
-        """A client connection has been lost
-        """
-        try:
-            self.connections.remove(connection)
-        except ValueError:
-            pass
-
-        # We need to tell the presence handler that the connection has been
-        # lost so that it can handle any ongoing syncs on that connection.
-        run_as_background_process(
-            "update_external_syncs_clear",
-            self.presence_handler.update_external_syncs_clear,
-            connection.conn_id,
-        )
-
 
 def _batch_updates(updates):
     """Takes a list of updates of form [(token, row)] and sets the token to
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 5f52264e84..d1a61c3314 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -25,26 +25,63 @@ Each stream is defined by the following information:
     update_function:    The function that returns a list of updates between two tokens
 """
 
-from . import _base, events, federation
+from synapse.replication.tcp.streams._base import (
+    AccountDataStream,
+    BackfillStream,
+    CachesStream,
+    DeviceListsStream,
+    GroupServerStream,
+    PresenceStream,
+    PublicRoomsStream,
+    PushersStream,
+    PushRulesStream,
+    ReceiptsStream,
+    Stream,
+    TagAccountDataStream,
+    ToDeviceStream,
+    TypingStream,
+    UserSignatureStream,
+)
+from synapse.replication.tcp.streams.events import EventsStream
+from synapse.replication.tcp.streams.federation import FederationStream
 
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
-        events.EventsStream,
-        _base.BackfillStream,
-        _base.PresenceStream,
-        _base.TypingStream,
-        _base.ReceiptsStream,
-        _base.PushRulesStream,
-        _base.PushersStream,
-        _base.CachesStream,
-        _base.PublicRoomsStream,
-        _base.DeviceListsStream,
-        _base.ToDeviceStream,
-        federation.FederationStream,
-        _base.TagAccountDataStream,
-        _base.AccountDataStream,
-        _base.GroupServerStream,
-        _base.UserSignatureStream,
+        EventsStream,
+        BackfillStream,
+        PresenceStream,
+        TypingStream,
+        ReceiptsStream,
+        PushRulesStream,
+        PushersStream,
+        CachesStream,
+        PublicRoomsStream,
+        DeviceListsStream,
+        ToDeviceStream,
+        FederationStream,
+        TagAccountDataStream,
+        AccountDataStream,
+        GroupServerStream,
+        UserSignatureStream,
     )
 }
+
+__all__ = [
+    "STREAMS_MAP",
+    "Stream",
+    "BackfillStream",
+    "PresenceStream",
+    "TypingStream",
+    "ReceiptsStream",
+    "PushRulesStream",
+    "PushersStream",
+    "CachesStream",
+    "PublicRoomsStream",
+    "DeviceListsStream",
+    "ToDeviceStream",
+    "TagAccountDataStream",
+    "AccountDataStream",
+    "GroupServerStream",
+    "UserSignatureStream",
+]
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..b0f87c365b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,117 +14,71 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
 
 import attr
 
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
+
 logger = logging.getLogger(__name__)
 
+# the number of rows to request from an update_function.
+_STREAM_UPDATE_TARGET_ROW_COUNT = 100
 
-MAX_EVENTS_BEHIND = 500000
-
-BackfillStreamRow = namedtuple(
-    "BackfillStreamRow",
-    (
-        "event_id",  # str
-        "room_id",  # str
-        "type",  # str
-        "state_key",  # str, optional
-        "redacts",  # str, optional
-        "relates_to",  # str, optional
-    ),
-)
-PresenceStreamRow = namedtuple(
-    "PresenceStreamRow",
-    (
-        "user_id",  # str
-        "state",  # str
-        "last_active_ts",  # int
-        "last_federation_update_ts",  # int
-        "last_user_sync_ts",  # int
-        "status_msg",  # str
-        "currently_active",  # bool
-    ),
-)
-TypingStreamRow = namedtuple(
-    "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
-)
-ReceiptsStreamRow = namedtuple(
-    "ReceiptsStreamRow",
-    (
-        "room_id",  # str
-        "receipt_type",  # str
-        "user_id",  # str
-        "event_id",  # str
-        "data",  # dict
-    ),
-)
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
-PushersStreamRow = namedtuple(
-    "PushersStreamRow",
-    ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
-)
-
-
-@attr.s
-class CachesStreamRow:
-    """Stream to inform workers they should invalidate their cache.
-
-    Attributes:
-        cache_func: Name of the cached function.
-        keys: The entry in the cache to invalidate. If None then will
-            invalidate all.
-        invalidation_ts: Timestamp of when the invalidation took place.
-    """
 
-    cache_func = attr.ib(type=str)
-    keys = attr.ib(type=Optional[List[Any]])
-    invalidation_ts = attr.ib(type=int)
-
-
-PublicRoomsStreamRow = namedtuple(
-    "PublicRoomsStreamRow",
-    (
-        "room_id",  # str
-        "visibility",  # str
-        "appservice_id",  # str, optional
-        "network_id",  # str, optional
-    ),
-)
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
-TagAccountDataStreamRow = namedtuple(
-    "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
-)
-AccountDataStreamRow = namedtuple(
-    "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
-)
-GroupsStreamRow = namedtuple(
-    "GroupsStreamRow",
-    ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
-)
-UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# The type of a stream update row, after JSON deserialisation, but before
+# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
+# just a row from a database query, though this is dependent on the stream in question.
+#
+StreamRow = Tuple
+
+# The type returned by the update_function of a stream, as well as get_updates(),
+# get_updates_since, etc.
+#
+# It consists of a triplet `(updates, new_last_token, limited)`, where:
+#   * `updates` is a list of `(token, row)` entries.
+#   * `new_last_token` is the new position in stream.
+#   * `limited` is whether there are more updates to fetch.
+#
+StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
+
+# The type of an update_function for a stream
+#
+# The arguments are:
+#
+#  * instance_name: the writer of the stream
+#  * from_token: the previous stream token: the starting point for fetching the
+#    updates
+#  * to_token: the new stream token: the point to get updates up to
+#  * target_row_count: a target for the number of rows to be returned.
+#
+# The update_function is expected to return up to _approximately_ target_row_count rows.
+# If there are more updates available, it should set `limited` in the result, and
+# it will be called again to get the next batch.
+#
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
     """Base class for the streams.
 
     Provides a `get_updates()` function that returns new updates since the last
-    time it was called up until the point `advance_current_token` was called.
+    time it was called.
     """
 
     NAME = None  # type: str  # The name of the stream
     # The type of the row. Used by the default impl of parse_row.
     ROW_TYPE = None  # type: Any
-    _LIMITED = True  # Whether the update function takes a limit
 
     @classmethod
-    def parse_row(cls, row):
+    def parse_row(cls, row: StreamRow):
         """Parse a row received over replication
 
         By default, assumes that the row data is an array object and passes its contents
@@ -138,101 +92,122 @@ class Stream(object):
         """
         return cls.ROW_TYPE(*row)
 
-    def __init__(self, hs):
-        # The token from which we last asked for updates
-        self.last_token = self.current_token()
+    def __init__(
+        self,
+        local_instance_name: str,
+        current_token_function: Callable[[], Token],
+        update_function: UpdateFunction,
+    ):
+        """Instantiate a Stream
+
+        current_token_function and update_function are callbacks which should be
+        implemented by subclasses.
+
+        current_token_function is called to get the current token of the underlying
+        stream.
 
-        # The token that we will get updates up to
-        self.upto_token = self.current_token()
+        update_function is called to get updates for this stream between a pair of
+        stream tokens. See the UpdateFunction type definition for more info.
 
-    def advance_current_token(self):
-        """Updates `upto_token` to "now", which updates up until which point
-        get_updates[_since] will fetch rows till.
+        Args:
+            local_instance_name: The instance name of the current process
+            current_token_function: callback to get the current token, as above
+            update_function: callback go get stream updates, as above
         """
-        self.upto_token = self.current_token()
+        self.local_instance_name = local_instance_name
+        self.current_token = current_token_function
+        self.update_function = update_function
+
+        # The token from which we last asked for updates
+        self.last_token = self.current_token()
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.upto_token = self.current_token()
-        self.last_token = self.upto_token
+        self.last_token = self.current_token()
 
-    async def get_updates(self):
+    async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
-        since the stream was constructed if it hadn't been called before),
-        until the `upto_token`
+        since the stream was constructed if it hadn't been called before).
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        updates, current_token = await self.get_updates_since(self.last_token)
+        current_token = self.current_token()
+        updates, current_token, limited = await self.get_updates_since(
+            self.local_instance_name, self.last_token, current_token
+        )
         self.last_token = current_token
 
-        return updates, current_token
+        return updates, current_token, limited
 
-    async def get_updates_since(self, from_token):
+    async def get_updates_since(
+        self, instance_name: str, from_token: Token, upto_token: Token
+    ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in stream, and `limited` is whether there are more updates
+            to fetch.
         """
-        if from_token in ("NOW", "now"):
-            return [], self.upto_token
-
-        current_token = self.upto_token
 
         from_token = int(from_token)
 
-        if from_token == current_token:
-            return [], current_token
+        if from_token == upto_token:
+            return [], upto_token, False
 
-        logger.info("get_updates_since: %s", self.__class__)
-        if self._LIMITED:
-            rows = await self.update_function(
-                from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
-            )
+        updates, upto_token, limited = await self.update_function(
+            instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+        )
+        return updates, upto_token, limited
 
-            # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-            rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-        else:
-            rows = await self.update_function(from_token, current_token)
 
+def db_query_to_update_function(
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> UpdateFunction:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(instance_name, from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
         updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
 
-        # check we didn't get more rows than the limit.
-        # doing it like this allows the update_function to be a generator.
-        if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
-            raise Exception("stream %s has fallen behind" % (self.NAME))
+        return updates, upto_token, limited
 
-        return updates, current_token
+    return update_function
 
-    def current_token(self):
-        """Gets the current token of the underlying streams. Should be provided
-        by the sub classes
 
-        Returns:
-            int
-        """
-        raise NotImplementedError()
+def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+    """Makes a suitable function for use as an `update_function` that queries
+    the master process for updates.
+    """
 
-    def update_function(self, from_token, current_token, limit=None):
-        """Get updates between from_token and to_token. If Stream._LIMITED is
-        True then limit is provided, otherwise it's not.
+    client = ReplicationGetStreamUpdates.make_client(hs)
 
-        Returns:
-            Deferred(list(tuple)): the first entry in the tuple is the token for
-                that update, and the rest of the tuple gets used to construct
-                a ``ROW_TYPE`` instance
-        """
-        raise NotImplementedError()
+    async def update_function(
+        instance_name: str, from_token: int, upto_token: int, limit: int
+    ) -> StreamUpdateResult:
+        result = await client(
+            instance_name=instance_name,
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
+        )
+        return result["updates"], result["upto_token"], result["limited"]
+
+    return update_function
 
 
 class BackfillStream(Stream):
@@ -240,93 +215,166 @@ class BackfillStream(Stream):
     or it went from being an outlier to not.
     """
 
+    BackfillStreamRow = namedtuple(
+        "BackfillStreamRow",
+        (
+            "event_id",  # str
+            "room_id",  # str
+            "type",  # str
+            "state_key",  # str, optional
+            "redacts",  # str, optional
+            "relates_to",  # str, optional
+        ),
+    )
+
     NAME = "backfill"
     ROW_TYPE = BackfillStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
-
-        super(BackfillStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_current_backfill_token,
+            db_query_to_update_function(store.get_all_new_backfill_event_rows),
+        )
 
 
 class PresenceStream(Stream):
+    PresenceStreamRow = namedtuple(
+        "PresenceStreamRow",
+        (
+            "user_id",  # str
+            "state",  # str
+            "last_active_ts",  # int
+            "last_federation_update_ts",  # int
+            "last_user_sync_ts",  # int
+            "status_msg",  # str
+            "currently_active",  # bool
+        ),
+    )
+
     NAME = "presence"
-    _LIMITED = False
     ROW_TYPE = PresenceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        presence_handler = hs.get_presence_handler()
 
-        self.current_token = store.get_current_presence_token  # type: ignore
-        self.update_function = presence_handler.get_all_presence_updates  # type: ignore
+        if hs.config.worker_app is None:
+            # on the master, query the presence handler
+            presence_handler = hs.get_presence_handler()
+            update_function = db_query_to_update_function(
+                presence_handler.get_all_presence_updates
+            )
+        else:
+            # Query master process
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(PresenceStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(), store.get_current_presence_token, update_function
+        )
 
 
 class TypingStream(Stream):
+    TypingStreamRow = namedtuple(
+        "TypingStreamRow", ("room_id", "user_ids")  # str  # list(str)
+    )
+
     NAME = "typing"
-    _LIMITED = False
     ROW_TYPE = TypingStreamRow
 
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
-        self.current_token = typing_handler.get_current_token  # type: ignore
-        self.update_function = typing_handler.get_all_typing_updates  # type: ignore
+        if hs.config.worker_app is None:
+            # on the master, query the typing handler
+            update_function = db_query_to_update_function(
+                typing_handler.get_all_typing_updates
+            )
+        else:
+            # Query master process
+            update_function = make_http_update_function(hs, self.NAME)
 
-        super(TypingStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(), typing_handler.get_current_token, update_function
+        )
 
 
 class ReceiptsStream(Stream):
+    ReceiptsStreamRow = namedtuple(
+        "ReceiptsStreamRow",
+        (
+            "room_id",  # str
+            "receipt_type",  # str
+            "user_id",  # str
+            "event_id",  # str
+            "data",  # dict
+        ),
+    )
+
     NAME = "receipts"
     ROW_TYPE = ReceiptsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_receipts  # type: ignore
-
-        super(ReceiptsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_max_receipt_stream_id,
+            db_query_to_update_function(store.get_all_updated_receipts),
+        )
 
 
 class PushRulesStream(Stream):
     """A user has changed their push rules
     """
 
+    PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))  # str
+
     NAME = "push_rules"
     ROW_TYPE = PushRulesStreamRow
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
-        super(PushRulesStream, self).__init__(hs)
+        super(PushRulesStream, self).__init__(
+            hs.get_instance_name(), self._current_token, self._update_function
+        )
 
-    def current_token(self):
+    def _current_token(self) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(
+        self, instance_name: str, from_token: Token, to_token: Token, limit: int
+    ):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
     """A user has added/changed/removed a pusher
     """
 
+    PushersStreamRow = namedtuple(
+        "PushersStreamRow",
+        ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
+    )
+
     NAME = "pushers"
     ROW_TYPE = PushersStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
-
-        super(PushersStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_pushers_stream_token,
+            db_query_to_update_function(store.get_all_updated_pushers_rows),
+        )
 
 
 class CachesStream(Stream):
@@ -334,98 +382,138 @@ class CachesStream(Stream):
     the cache on the workers
     """
 
+    @attr.s
+    class CachesStreamRow:
+        """Stream to inform workers they should invalidate their cache.
+
+        Attributes:
+            cache_func: Name of the cached function.
+            keys: The entry in the cache to invalidate. If None then will
+                invalidate all.
+            invalidation_ts: Timestamp of when the invalidation took place.
+        """
+
+        cache_func = attr.ib(type=str)
+        keys = attr.ib(type=Optional[List[Any]])
+        invalidation_ts = attr.ib(type=int)
+
     NAME = "caches"
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_caches  # type: ignore
-
-        super(CachesStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_cache_stream_token,
+            db_query_to_update_function(store.get_all_updated_caches),
+        )
 
 
 class PublicRoomsStream(Stream):
     """The public rooms list changed
     """
 
+    PublicRoomsStreamRow = namedtuple(
+        "PublicRoomsStreamRow",
+        (
+            "room_id",  # str
+            "visibility",  # str
+            "appservice_id",  # str, optional
+            "network_id",  # str, optional
+        ),
+    )
+
     NAME = "public_rooms"
     ROW_TYPE = PublicRoomsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = store.get_all_new_public_rooms  # type: ignore
-
-        super(PublicRoomsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_current_public_room_stream_id,
+            db_query_to_update_function(store.get_all_new_public_rooms),
+        )
 
 
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """
 
+    @attr.s
+    class DeviceListsStreamRow:
+        entity = attr.ib(type=str)
+
     NAME = "device_lists"
-    _LIMITED = False
     ROW_TYPE = DeviceListsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
-
-        super(DeviceListsStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_device_stream_token,
+            db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
+        )
 
 
 class ToDeviceStream(Stream):
     """New to_device messages for a client
     """
 
+    ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
+
     NAME = "to_device"
     ROW_TYPE = ToDeviceStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = store.get_all_new_device_messages  # type: ignore
-
-        super(ToDeviceStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_to_device_stream_token,
+            db_query_to_update_function(store.get_all_new_device_messages),
+        )
 
 
 class TagAccountDataStream(Stream):
     """Someone added/removed a tag for a room
     """
 
+    TagAccountDataStreamRow = namedtuple(
+        "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict
+    )
+
     NAME = "tag_account_data"
     ROW_TYPE = TagAccountDataStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_tags  # type: ignore
-
-        super(TagAccountDataStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_max_account_data_stream_id,
+            db_query_to_update_function(store.get_all_updated_tags),
+        )
 
 
 class AccountDataStream(Stream):
     """Global or per room account data was changed
     """
 
+    AccountDataStreamRow = namedtuple(
+        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+    )
+
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
     def __init__(self, hs):
         self.store = hs.get_datastore()
+        super().__init__(
+            hs.get_instance_name(),
+            self.store.get_max_account_data_stream_id,
+            db_query_to_update_function(self._update_function),
+        )
 
-        self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
-
-        super(AccountDataStream, self).__init__(hs)
-
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
@@ -440,30 +528,38 @@ class AccountDataStream(Stream):
 
 
 class GroupServerStream(Stream):
+    GroupsStreamRow = namedtuple(
+        "GroupsStreamRow",
+        ("group_id", "user_id", "type", "content"),  # str  # str  # str  # dict
+    )
+
     NAME = "groups"
     ROW_TYPE = GroupsStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = store.get_all_groups_changes  # type: ignore
-
-        super(GroupServerStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_group_stream_token,
+            db_query_to_update_function(store.get_all_groups_changes),
+        )
 
 
 class UserSignatureStream(Stream):
     """A user has signed their own device with their user-signing key
     """
 
+    UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id"))  # str
+
     NAME = "user_signature"
-    _LIMITED = False
     ROW_TYPE = UserSignatureStreamRow
 
     def __init__(self, hs):
         store = hs.get_datastore()
-
-        self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
-
-        super(UserSignatureStream, self).__init__(hs)
+        super().__init__(
+            hs.get_instance_name(),
+            store.get_device_stream_token,
+            db_query_to_update_function(
+                store.get_all_user_signature_changes_for_remotes
+            ),
+        )
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..890e75d827 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 
 import heapq
-from typing import Tuple, Type
+from collections import Iterable
+from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream
+from ._base import Stream, StreamUpdateResult, Token
 
 
 """Handling of the 'events' replication stream
@@ -116,28 +117,107 @@ class EventsStream(Stream):
 
     def __init__(self, hs):
         self._store = hs.get_datastore()
-        self.current_token = self._store.get_current_events_token  # type: ignore
+        super().__init__(
+            hs.get_instance_name(),
+            self._store.get_current_events_token,
+            self._update_function,
+        )
 
-        super(EventsStream, self).__init__(hs)
+    async def _update_function(
+        self,
+        instance_name: str,
+        from_token: Token,
+        current_token: Token,
+        target_row_count: int,
+    ) -> StreamUpdateResult:
+
+        # the events stream merges together three separate sources:
+        #  * new events
+        #  * current_state changes
+        #  * events which were previously outliers, but have now been de-outliered.
+        #
+        # The merge operation is complicated by the fact that we only have a single
+        # "stream token" which is supposed to indicate how far we have got through
+        # all three streams. It's therefore no good to return rows 1-1000 from the
+        # "new events" table if the state_deltas are limited to rows 1-100 by the
+        # target_row_count.
+        #
+        # In other words: we must pick a new upper limit, and must return *all* rows
+        # up to that point for each of the three sources.
+        #
+        # Start by trying to split the target_row_count up. We expect to have a
+        # negligible number of ex-outliers, and a rough approximation based on recent
+        # traffic on sw1v.org shows that there are approximately the same number of
+        # event rows between a given pair of stream ids as there are state
+        # updates, so let's split our target_row_count among those two types. The target
+        # is only an approximation - it doesn't matter if we end up going a bit over it.
+
+        target_row_count //= 2
+
+        # now we fetch up to that many rows from the events table
 
-    async def update_function(self, from_token, current_token, limit=None):
         event_rows = await self._store.get_all_new_forward_event_rows(
-            from_token, current_token, limit
-        )
-        event_updates = (
-            (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
+            from_token, current_token, target_row_count
+        )  # type: List[Tuple]
+
+        # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
+        # that we know it is safe to just take upper_limit = event_rows[-1][0].
+        assert (
+            len(event_rows) <= target_row_count
+        ), "get_all_new_forward_event_rows did not honour row limit"
+
+        # if we hit the limit on event_updates, there's no point in going beyond the
+        # last stream_id in the batch for the other sources.
+
+        if len(event_rows) == target_row_count:
+            limited = True
+            upper_limit = event_rows[-1][0]  # type: int
+        else:
+            limited = False
+            upper_limit = current_token
+
+        # next up is the state delta table.
+        (
+            state_rows,
+            upper_limit,
+            state_rows_limited,
+        ) = await self._store.get_all_updated_current_state_deltas(
+            from_token, upper_limit, target_row_count
         )
 
-        state_rows = await self._store.get_all_updated_current_state_deltas(
-            from_token, current_token, limit
-        )
-        state_updates = (
-            (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
-        )
+        limited = limited or state_rows_limited
 
-        all_updates = heapq.merge(event_updates, state_updates)
+        # finally, fetch the ex-outliers rows. We assume there are few enough of these
+        # not to bother with the limit.
 
-        return all_updates
+        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
+            from_token, upper_limit
+        )  # type: List[Tuple]
+
+        # we now need to turn the raw database rows returned into tuples suitable
+        # for the replication protocol (basically, we add an identifier to
+        # distinguish the row type). At the same time, we can limit the event_rows
+        # to the max stream_id from state_rows.
+
+        event_updates = (
+            (stream_id, (EventsStreamEventRow.TypeId, rest))
+            for (stream_id, *rest) in event_rows
+            if stream_id <= upper_limit
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        state_updates = (
+            (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
+            for (stream_id, *rest) in state_rows
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        ex_outliers_updates = (
+            (stream_id, (EventsStreamEventRow.TypeId, rest))
+            for (stream_id, *rest) in ex_outliers_rows
+        )  # type: Iterable[Tuple[int, Tuple]]
+
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
+        return updates, upper_limit, limited
 
     @classmethod
     def parse_row(cls, row):
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 615f3dc9ac..e8bd52e389 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,15 +15,7 @@
 # limitations under the License.
 from collections import namedtuple
 
-from ._base import Stream
-
-FederationStreamRow = namedtuple(
-    "FederationStreamRow",
-    (
-        "type",  # str, the type of data as defined in the BaseFederationRows
-        "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
-    ),
-)
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
 class FederationStream(Stream):
@@ -31,13 +23,33 @@ class FederationStream(Stream):
     sending disabled.
     """
 
+    FederationStreamRow = namedtuple(
+        "FederationStreamRow",
+        (
+            "type",  # str, the type of data as defined in the BaseFederationRows
+            "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+        ),
+    )
+
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
 
     def __init__(self, hs):
-        federation_sender = hs.get_federation_sender()
-
-        self.current_token = federation_sender.get_current_token  # type: ignore
-        self.update_function = federation_sender.get_replication_rows  # type: ignore
-
-        super(FederationStream, self).__init__(hs)
+        # Not all synapse instances will have a federation sender instance,
+        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+        # so we stub the stream out when that is the case.
+        if hs.config.worker_app is None or hs.should_send_federation():
+            federation_sender = hs.get_federation_sender()
+            current_token = federation_sender.get_current_token
+            update_function = db_query_to_update_function(
+                federation_sender.get_replication_rows
+            )
+        else:
+            current_token = lambda: 0
+            update_function = self._stub_update_function
+
+        super().__init__(hs.get_instance_name(), current_token, update_function)
+
+    @staticmethod
+    async def _stub_update_function(instance_name, from_token, upto_token, limit):
+        return [], upto_token, False