Diffstat (limited to 'synapse/replication')
-rw-r--r--  synapse/replication/http/__init__.py            |   2
-rw-r--r--  synapse/replication/http/streams.py             |  78
-rw-r--r--  synapse/replication/slave/storage/_base.py      |  14
-rw-r--r--  synapse/replication/slave/storage/pushers.py    |   3
-rw-r--r--  synapse/replication/tcp/client.py               |   3
-rw-r--r--  synapse/replication/tcp/commands.py             |  34
-rw-r--r--  synapse/replication/tcp/protocol.py             | 206
-rw-r--r--  synapse/replication/tcp/resource.py             |  19
-rw-r--r--  synapse/replication/tcp/streams/__init__.py     |   8
-rw-r--r--  synapse/replication/tcp/streams/_base.py        | 160
-rw-r--r--  synapse/replication/tcp/streams/events.py       |   5
-rw-r--r--  synapse/replication/tcp/streams/federation.py   |  19
12 files changed, 319 insertions(+), 232 deletions(-)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 28dbc6fcba..4613b2538c 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -21,6 +21,7 @@ from synapse.replication.http import (
     membership,
     register,
     send_event,
+    streams,
 )
 
 REPLICATION_PREFIX = "/_synapse/replication"
@@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
         login.register_servlets(hs, self)
         register.register_servlets(hs, self)
         devices.register_servlets(hs, self)
+        streams.register_servlets(hs, self)
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
new file mode 100644
index 0000000000..ffd4c61993
--- /dev/null
+++ b/synapse/replication/http/streams.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationGetStreamUpdates(ReplicationEndpoint):
+    """Fetches stream updates from a server. Used for streams not persisted to
+    the database, e.g. typing notifications.
+
+    The API looks like:
+
+        GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100
+
+        200 OK
+
+        {
+            updates: [ ... ],
+            upto_token: 10,
+            limited: False,
+        }
+
+    """
+
+    NAME = "get_repl_stream_updates"
+    PATH_ARGS = ("stream_name",)
+    METHOD = "GET"
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        # We pull the streams from the replication streamer (if we try to make
+        # them ourselves we end up in an import loop).
+        self.streams = hs.get_replication_streamer().get_streams()
+
+    @staticmethod
+    def _serialize_payload(stream_name, from_token, upto_token, limit):
+        return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
+
+    async def _handle_request(self, request, stream_name):
+        stream = self.streams.get(stream_name)
+        if stream is None:
+            raise SynapseError(400, "Unknown stream")
+
+        from_token = parse_integer(request, "from_token", required=True)
+        upto_token = parse_integer(request, "upto_token", required=True)
+        limit = parse_integer(request, "limit", required=True)
+
+        updates, upto_token, limited = await stream.get_updates_since(
+            from_token, upto_token, limit
+        )
+
+        return (
+            200,
+            {"updates": updates, "upto_token": upto_token, "limited": limited},
+        )
+
+
+def register_servlets(hs, http_server):
+    ReplicationGetStreamUpdates(hs).register(http_server)
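
For context, a minimal sketch of how a worker might call this endpoint through the generated client helper (the `hs` object, stream name, and token values here are illustrative, not taken from the diff):

    from synapse.replication.http.streams import ReplicationGetStreamUpdates

    client = ReplicationGetStreamUpdates.make_client(hs)

    async def fetch_missed_typing_updates():
        # Ask the master for typing-stream rows in (from_token, upto_token].
        result = await client(
            stream_name="typing", from_token=0, upto_token=10, limit=100
        )
        return result["updates"], result["upto_token"], result["limited"]
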
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f45cbd37a0..751c799d94 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,8 +18,10 @@ from typing import Dict, Optional
 
 import six
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+    CURRENT_STATE_CACHE_NAME,
+    CacheInvalidationWorkerStore,
+)
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
@@ -35,7 +37,7 @@ def __func__(inp):
         return inp.__func__
 
 
-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos
 
+    def get_cache_stream_token(self):
+        if self._cache_id_gen:
+            return self._cache_id_gen.get_current_token()
+        else:
+            return 0
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "caches":
             if self._cache_id_gen:
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index f22c2d44a3..bce8a3d115 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 02ab5b66ea..7e7ad0f798 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
         self.client_name = client_name
         self.handler = handler
         self.server_name = hs.config.server_name
+        self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class
 
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def buildProtocol(self, addr):
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
-            self.client_name, self.server_name, self._clock, self.handler
+            self.hs, self.client_name, self.server_name, self._clock, self.handler,
         )
 
     def clientConnectionLost(self, connector, reason):
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 451671412d..5a6b734094 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -136,8 +136,8 @@ class PositionCommand(Command):
     """Sent by the server to tell the client the stream postition without
     needing to send an RDATA.
 
-    Sent to the client after all missing updates for a stream have been sent
-    to the client and they're now up to date.
+    On receipt of a POSITION command clients should check if they have missed
+    any updates, and if so then fetch them out of band.
     """
 
     NAME = "POSITION"
@@ -179,42 +179,24 @@ class NameCommand(Command):
 
 
 class ReplicateCommand(Command):
-    """Sent by the client to subscribe to the stream.
+    """Sent by the client to subscribe to streams.
 
     Format::
 
-        REPLICATE <stream_name> <token>
-
-    Where <token> may be either:
-        * a numeric stream_id to stream updates from
-        * "NOW" to stream all subsequent updates.
-
-    The <stream_name> can be "ALL" to subscribe to all known streams, in which
-    case the <token> must be set to "NOW", i.e.::
-
-        REPLICATE ALL NOW
+        REPLICATE
     """
 
     NAME = "REPLICATE"
 
-    def __init__(self, stream_name, token):
-        self.stream_name = stream_name
-        self.token = token
+    def __init__(self):
+        pass
 
     @classmethod
     def from_line(cls, line):
-        stream_name, token = line.split(" ", 1)
-        if token in ("NOW", "now"):
-            token = "NOW"
-        else:
-            token = int(token)
-        return cls(stream_name, token)
+        return cls()
 
     def to_line(self):
-        return " ".join((self.stream_name, str(self.token)))
-
-    def get_logcontext_id(self):
-        return "REPLICATE-" + self.stream_name
+        return ""
 
 
 class UserSyncCommand(Command):
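
To make the new handshake concrete, a hedged sketch of the round-trip implied by these command changes (the `streamer` and `connection` objects stand in for the real protocol plumbing shown in protocol.py below):

    # Client side: a bare REPLICATE now subscribes to every stream.
    cmd = ReplicateCommand.from_line("")  # the argument-less form parses to cls()
    assert cmd.to_line() == ""            # nothing left to serialise

    # Server side: on receipt, advertise the current position of each stream
    # so the client can fetch anything it missed out of band.
    for stream_name, stream in streamer.streams_by_name.items():
        connection.send_command(PositionCommand(stream_name, stream.current_token()))
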
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index bc1482a9bb..f81d2e2442 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
     > PING 1490197665618
     < NAME synapse.app.appservice
     < PING 1490197665618
-    < REPLICATE events 1
-    < REPLICATE backfill 1
-    < REPLICATE caches 1
+    < REPLICATE
     > POSITION events 1
     > POSITION backfill 1
     > POSITION caches 1
@@ -53,17 +51,15 @@ import fcntl
 import logging
 import struct
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Set, Tuple
+from typing import Any, DefaultDict, Dict, List, Set
 
-from six import iteritems, iterkeys
+from six import iteritems
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
 from twisted.python.failure import Failure
 
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
@@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
     SyncCommand,
     UserSyncCommand,
 )
-from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
 from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
+MYPY = False
+if MYPY:
+    from synapse.server import HomeServer
+
+
 connection_close_counter = Counter(
     "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
 )
@@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.streamer = streamer
 
-        # The streams the client has subscribed to and is up to date with
-        self.replication_streams = set()  # type: Set[str]
-
-        # The streams the client is currently subscribing to.
-        self.connecting_streams = set()  # type:  Set[str]
-
-        # Map from stream name to list of updates to send once we've finished
-        # subscribing the client to the stream.
-        self.pending_rdata = {}  # type: Dict[str, List[Tuple[int, Any]]]
-
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)
@@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
         )
 
     async def on_REPLICATE(self, cmd):
-        stream_name = cmd.stream_name
-        token = cmd.token
-
-        if stream_name == "ALL":
-            # Subscribe to all streams we're publishing to.
-            deferreds = [
-                run_in_background(self.subscribe_to_stream, stream, token)
-                for stream in iterkeys(self.streamer.streams_by_name)
-            ]
-
-            await make_deferred_yieldable(
-                defer.gatherResults(deferreds, consumeErrors=True)
-            )
-        else:
-            await self.subscribe_to_stream(stream_name, token)
+        # Subscribe to all streams we're publishing to.
+        for stream_name in self.streamer.streams_by_name:
+            current_token = self.streamer.get_stream_token(stream_name)
+            self.send_command(PositionCommand(stream_name, current_token))
 
     async def on_FEDERATION_ACK(self, cmd):
         self.streamer.federation_ack(cmd.token)
@@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
             cmd.last_seen,
         )
 
-    async def subscribe_to_stream(self, stream_name, token):
-        """Subscribe the remote to a stream.
-
-        This invloves checking if they've missed anything and sending those
-        updates down if they have. During that time new updates for the stream
-        are queued and sent once we've sent down any missed updates.
-        """
-        self.replication_streams.discard(stream_name)
-        self.connecting_streams.add(stream_name)
-
-        try:
-            # Get missing updates
-            updates, current_token = await self.streamer.get_stream_updates(
-                stream_name, token
-            )
-
-            # Send all the missing updates
-            for update in updates:
-                token, row = update[0], update[1]
-                self.send_command(RdataCommand(stream_name, token, row))
-
-            # We send a POSITION command to ensure that they have an up to
-            # date token (especially useful if we didn't send any updates
-            # above)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-            # Now we can send any updates that came in while we were subscribing
-            pending_rdata = self.pending_rdata.pop(stream_name, [])
-            updates = []
-            for token, update in pending_rdata:
-                # If the token is null, it is part of a batch update. Batches
-                # are multiple updates that share a single token. To denote
-                # this, the token is set to None for all tokens in the batch
-                # except for the last. If we find a None token, we keep looking
-                # through tokens until we find one that is not None and then
-                # process all previous updates in the batch as if they had the
-                # final token.
-                if token is None:
-                    # Store this update as part of a batch
-                    updates.append(update)
-                    continue
-
-                if token <= current_token:
-                    # This update or batch of updates is older than
-                    # current_token, dismiss it
-                    updates = []
-                    continue
-
-                updates.append(update)
-
-                # Send all updates that are part of this batch with the
-                # found token
-                for update in updates:
-                    self.send_command(RdataCommand(stream_name, token, update))
-
-                # Clear stored updates
-                updates = []
-
-            # They're now fully subscribed
-            self.replication_streams.add(stream_name)
-        except Exception as e:
-            logger.exception("[%s] Failed to handle REPLICATE command", self.id())
-            self.send_error("failed to handle replicate: %r", e)
-        finally:
-            self.connecting_streams.discard(stream_name)
-
     def stream_update(self, stream_name, token, data):
         """Called when a new update is available to stream to clients.
 
         We now just send the update to the client unconditionally.
         """
-        if stream_name in self.replication_streams:
-            # The client is subscribed to the stream
-            self.send_command(RdataCommand(stream_name, token, data))
-        elif stream_name in self.connecting_streams:
-            # The client is being subscribed to the stream
-            logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
-            self.pending_rdata.setdefault(stream_name, []).append((token, data))
-        else:
-            # The client isn't subscribed
-            logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
+        self.send_command(RdataCommand(stream_name, token, data))
 
     def send_sync(self, data):
         self.send_command(SyncCommand(data))
@@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
     def __init__(
         self,
+        hs: "HomeServer",
         client_name: str,
         server_name: str,
         clock: Clock,
@@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.server_name = server_name
         self.handler = handler
 
+        self.streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
         # Set of stream names that have been subscribed to, but haven't yet
         # caught up with. This is used to track when the client has been fully
         # connected to the remote.
-        self.streams_connecting = set()  # type: Set[str]
+        self.streams_connecting = set(STREAMS_MAP)  # type: Set[str]
 
         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.
-        self.pending_batches = {}  # type: Dict[str, Any]
+        self.pending_batches = {}  # type: Dict[str, List[Any]]
 
     def connectionMade(self):
         self.send_command(NameCommand(self.client_name))
         BaseReplicationStreamProtocol.connectionMade(self)
 
         # Once we've connected subscribe to the necessary streams
-        for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
-            self.replicate(stream_name, token)
+        self.replicate()
 
         # Tell the server if we have any users currently syncing (should only
         # happen on synchrotrons)
@@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         # We've now finished connecting, so inform the client handler
         self.handler.update_connection(self)
 
-        # This will happen if we don't actually subscribe to any streams
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             )
             raise
 
-        if cmd.token is None:
+        if cmd.token is None or stream_name in self.streams_connecting:
             # I.e. this is part of a batch of updates for this stream. Batch
             # until we get an update for the stream with a non None token
             self.pending_batches.setdefault(stream_name, []).append(row)
@@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
             rows.append(row)
             await self.handler.on_rdata(stream_name, cmd.token, rows)
 
-    async def on_POSITION(self, cmd):
-        # When we get a `POSITION` command it means we've finished getting
-        # missing updates for the given stream, and are now up to date.
+    async def on_POSITION(self, cmd: PositionCommand):
+        stream = self.streams.get(cmd.stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            return
+
+        # Find where we previously streamed up to.
+        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
+        if current_token is None:
+            logger.warning(
+                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
+            )
+            return
+
+        # Fetch all updates between then and now.
+        limited = True
+        while limited:
+            updates, current_token, limited = await stream.get_updates_since(
+                current_token, cmd.token
+            )
+
+            # Check if the connection was closed underneath us; if so, bail
+            # rather than risk having concurrent catch-ups going on.
+            if self.state == ConnectionStates.CLOSED:
+                return
+
+            if updates:
+                await self.handler.on_rdata(
+                    cmd.stream_name,
+                    current_token,
+                    [stream.parse_row(update[1]) for update in updates],
+                )
+
+        # We've now caught up to the position sent to us; notify the handler.
+        await self.handler.on_position(cmd.stream_name, cmd.token)
+
         self.streams_connecting.discard(cmd.stream_name)
         if not self.streams_connecting:
             self.handler.finished_connecting()
 
-        await self.handler.on_position(cmd.stream_name, cmd.token)
+        # Check if the connection was closed underneath us; if so, bail
+        # rather than risk having concurrent catch-ups going on.
+        if self.state == ConnectionStates.CLOSED:
+            return
+
+        # Handle any RDATA that came in while we were catching up.
+        rows = self.pending_batches.pop(cmd.stream_name, [])
+        if rows:
+            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
 
     async def on_SYNC(self, cmd):
         self.handler.on_sync(cmd.data)
@@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
         self.handler.on_remote_server_up(cmd.data)
 
-    def replicate(self, stream_name, token):
+    def replicate(self):
         """Send the subscription request to the server
         """
-        if stream_name not in STREAMS_MAP:
-            raise Exception("Invalid stream name %r" % (stream_name,))
-
-        logger.info(
-            "[%s] Subscribing to replication stream: %r from %r",
-            self.id(),
-            stream_name,
-            token,
-        )
-
-        self.streams_connecting.add(stream_name)
+        logger.info("[%s] Subscribing to replication streams", self.id())
 
-        self.send_command(ReplicateCommand(stream_name, token))
+        self.send_command(ReplicateCommand())
 
     def on_connection_closed(self):
         BaseReplicationStreamProtocol.on_connection_closed(self)
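
The heart of the new client-side catch-up in `on_POSITION` can be reduced to the following sketch (same names as in the diff above; the connection-closed checks and pending-batch handling are elided):

    async def catch_up(stream, handler, stream_name, from_token, position_token):
        # Page through missed updates until the stream stops reporting
        # `limited`, i.e. we have reached the advertised position.
        current_token = from_token
        limited = True
        while limited:
            updates, current_token, limited = await stream.get_updates_since(
                current_token, position_token
            )
            if updates:
                await handler.on_rdata(
                    stream_name,
                    current_token,
                    [stream.parse_row(update[1]) for update in updates],
                )
        await handler.on_position(stream_name, position_token)
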
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 6e2ebaf614..4374e99e32 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,7 +17,7 @@
 
 import logging
 import random
-from typing import Any, List
+from typing import Any, Dict, List
 
 from six import itervalues
 
@@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.metrics import Measure, measure_func
 
 from .protocol import ServerReplicationStreamProtocol
-from .streams import STREAMS_MAP
+from .streams import STREAMS_MAP, Stream
 from .streams.federation import FederationStream
 
 stream_updates_counter = Counter(
@@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
     """
 
     def __init__(self, hs):
-        self.streamer = ReplicationStreamer(hs)
+        self.streamer = hs.get_replication_streamer()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name
 
@@ -133,6 +133,11 @@ class ReplicationStreamer(object):
         for conn in self.connections:
             conn.send_error("server shutting down")
 
+    def get_streams(self) -> Dict[str, Stream]:
+        """Get a mapp from stream name to stream instance.
+        """
+        return self.streams_by_name
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -190,7 +195,8 @@ class ReplicationStreamer(object):
                             stream.current_token(),
                         )
                         try:
-                            updates, current_token = await stream.get_updates()
+                            updates, current_token, limited = await stream.get_updates()
+                            self.pending_updates |= limited
                         except Exception:
                             logger.info("Failed to handle stream %s", stream.NAME)
                             raise
@@ -226,8 +232,7 @@ class ReplicationStreamer(object):
             self.pending_updates = False
             self.is_looping = False
 
-    @measure_func("repl.get_stream_updates")
-    async def get_stream_updates(self, stream_name, token):
+    def get_stream_token(self, stream_name):
         """For a given stream get all updates since token. This is called when
         a client first subscribes to a stream.
         """
@@ -235,7 +240,7 @@ class ReplicationStreamer(object):
         if not stream:
             raise Exception("unknown stream %s", stream_name)
 
-        return await stream.get_updates_since(token)
+        return stream.current_token()
 
     @measure_func("repl.federation_ack")
     def federation_ack(self, token):
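
The new `limited` flag feeds straight back into the streamer's update loop: a capped fetch means more rows may remain, so the loop flags another pass instead of going back to sleep. A simplified excerpt-style sketch of that interplay:

    # Inside the streamer's update loop (simplified; `self` is the
    # ReplicationStreamer, error handling and metrics elided):
    while self.pending_updates:
        self.pending_updates = False
        for stream in self.streams:
            updates, current_token, limited = await stream.get_updates()
            # A capped fetch may have left rows behind; run another pass.
            self.pending_updates |= limited
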
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 29199f5b46..37bcd3de66 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -24,6 +24,9 @@ Each stream is defined by the following information:
     current_token:      The function that returns the current token for the stream
     update_function:    The function that returns a list of updates between two tokens
 """
+
+from typing import Dict, Type
+
 from synapse.replication.tcp.streams._base import (
     AccountDataStream,
     BackfillStream,
@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
     PushersStream,
     PushRulesStream,
     ReceiptsStream,
+    Stream,
     TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
@@ -63,10 +67,12 @@ STREAMS_MAP = {
         GroupServerStream,
         UserSignatureStream,
     )
-}
+}  # type: Dict[str, Type[Stream]]
+
 
 __all__ = [
     "STREAMS_MAP",
+    "Stream",
     "BackfillStream",
     "PresenceStream",
     "TypingStream",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 32d9514883..c14dff6c64 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import logging
 from collections import namedtuple
-from typing import Any, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
 
 import attr
 
+from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
 MAX_EVENTS_BEHIND = 500000
 
 
+# Some type aliases to make things a bit easier.
+
+# A stream position token
+Token = int
+
+# A pair of (stream position, args used to create an instance of `ROW_TYPE`).
+StreamRow = Tuple[Token, tuple]
+
+
 class Stream(object):
     """Base class for the streams.
 
@@ -56,6 +65,7 @@ class Stream(object):
         return cls.ROW_TYPE(*row)
 
     def __init__(self, hs):
+
         # The token from which we last asked for updates
         self.last_token = self.current_token()
 
@@ -65,61 +75,46 @@ class Stream(object):
         """
         self.last_token = self.current_token()
 
-    async def get_updates(self):
+    async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
         """Gets all updates since the last time this function was called (or
         since the stream was constructed if it hadn't been called before).
 
         Returns:
-            Deferred[Tuple[List[Tuple[int, Any]], int]:
-                Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
-                list of ``(token, row)`` entries. ``row`` will be json-serialised and
-                sent over the replication steam.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in the stream, and `limited` is whether there are more
+            updates to fetch.
         """
-        updates, current_token = await self.get_updates_since(self.last_token)
+        current_token = self.current_token()
+        updates, current_token, limited = await self.get_updates_since(
+            self.last_token, current_token
+        )
         self.last_token = current_token
 
-        return updates, current_token
+        return updates, current_token, limited
 
     async def get_updates_since(
-        self, from_token: int
-    ) -> Tuple[List[Tuple[int, JsonDict]], int]:
+        self, from_token: Token, upto_token: Token, limit: int = 100
+    ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
         """Like get_updates except allows specifying from when we should
         stream updates
 
         Returns:
-            Resolves to a pair `(updates, new_last_token)`, where `updates` is
-            a list of `(token, row)` entries and `new_last_token` is the new
-            position in stream.
+            A triplet `(updates, new_last_token, limited)`, where `updates` is
+            a list of `(token, row)` entries, `new_last_token` is the new
+            position in the stream, and `limited` is whether there are more
+            updates to fetch.
         """
 
-        if from_token in ("NOW", "now"):
-            return [], self.current_token()
-
-        current_token = self.current_token()
-
         from_token = int(from_token)
 
-        if from_token == current_token:
-            return [], current_token
+        if from_token == upto_token:
+            return [], upto_token, False
 
-        rows = await self.update_function(
-            from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
+        updates, upto_token, limited = await self.update_function(
+            from_token, upto_token, limit=limit,
         )
-
-        # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
-        rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
-
-        updates = [(row[0], row[1:]) for row in rows]
-
-        # check we didn't get more rows than the limit.
-        # doing it like this allows the update_function to be a generator.
-        if len(updates) >= MAX_EVENTS_BEHIND:
-            raise Exception("stream %s has fallen behind" % (self.NAME))
-
-        # The update function didn't hit the limit, so we must have got all
-        # the updates to `current_token`, and can return that as our new
-        # stream position.
-        return updates, current_token
+        return updates, upto_token, limited
 
     def current_token(self):
         """Gets the current token of the underlying streams. Should be provided
@@ -141,6 +136,48 @@ class Stream(object):
         raise NotImplementedError()
 
 
+def db_query_to_update_function(
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+    """Wraps a db query function which returns a list of rows to make it
+    suitable for use as an `update_function` for the Stream class
+    """
+
+    async def update_function(from_token, upto_token, limit):
+        rows = await query_function(from_token, upto_token, limit)
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) == limit:
+            upto_token = rows[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
+
+    return update_function
+
+
+def make_http_update_function(
+    hs, stream_name: str
+) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
+    """Makes a suitable function for use as an `update_function` that queries
+    the master process for updates.
+    """
+
+    client = ReplicationGetStreamUpdates.make_client(hs)
+
+    async def update_function(
+        from_token: int, upto_token: int, limit: int
+    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+        result = await client(
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
+            limit=limit,
+        )
+        return result["updates"], result["upto_token"], result["limited"]
+
+    return update_function
+
+
 class BackfillStream(Stream):
     """We fetched some old events and either we had never seen that event before
     or it went from being an outlier to not.
@@ -164,7 +201,7 @@ class BackfillStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         self.current_token = store.get_current_backfill_token  # type: ignore
-        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows)  # type: ignore
 
         super(BackfillStream, self).__init__(hs)
 
@@ -190,8 +227,15 @@ class PresenceStream(Stream):
         store = hs.get_datastore()
         presence_handler = hs.get_presence_handler()
 
+        self._is_worker = hs.config.worker_app is not None
+
         self.current_token = store.get_current_presence_token  # type: ignore
-        self.update_function = presence_handler.get_all_presence_updates  # type: ignore
+
+        if hs.config.worker_app is None:
+            self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore
 
         super(PresenceStream, self).__init__(hs)
 
@@ -208,7 +252,12 @@ class TypingStream(Stream):
         typing_handler = hs.get_typing_handler()
 
         self.current_token = typing_handler.get_current_token  # type: ignore
-        self.update_function = typing_handler.get_all_typing_updates  # type: ignore
+
+        if hs.config.worker_app is None:
+            self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates)  # type: ignore
+        else:
+            # Query master process
+            self.update_function = make_http_update_function(hs, self.NAME)  # type: ignore
 
         super(TypingStream, self).__init__(hs)
 
@@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_receipt_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_receipts  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_receipts)  # type: ignore
 
         super(ReceiptsStream, self).__init__(hs)
 
@@ -256,7 +305,13 @@ class PushRulesStream(Stream):
 
     async def update_function(self, from_token, to_token, limit):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
-        return [(row[0], row[2]) for row in rows]
+
+        limited = False
+        if len(rows) == limit:
+            to_token = rows[-1][0]
+            limited = True
+
+        return [(row[0], (row[2],)) for row in rows], to_token, limited
 
 
 class PushersStream(Stream):
@@ -275,7 +330,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_pushers_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows)  # type: ignore
 
         super(PushersStream, self).__init__(hs)
 
@@ -307,7 +362,7 @@ class CachesStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_cache_stream_token  # type: ignore
-        self.update_function = store.get_all_updated_caches  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_caches)  # type: ignore
 
         super(CachesStream, self).__init__(hs)
 
@@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_current_public_room_stream_id  # type: ignore
-        self.update_function = store.get_all_new_public_rooms  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_public_rooms)  # type: ignore
 
         super(PublicRoomsStream, self).__init__(hs)
 
@@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes)  # type: ignore
 
         super(DeviceListsStream, self).__init__(hs)
 
@@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_to_device_stream_token  # type: ignore
-        self.update_function = store.get_all_new_device_messages  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_new_device_messages)  # type: ignore
 
         super(ToDeviceStream, self).__init__(hs)
 
@@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_max_account_data_stream_id  # type: ignore
-        self.update_function = store.get_all_updated_tags  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_updated_tags)  # type: ignore
 
         super(TagAccountDataStream, self).__init__(hs)
 
@@ -412,10 +467,11 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
 
         self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(AccountDataStream, self).__init__(hs)
 
-    async def update_function(self, from_token, to_token, limit):
+    async def _update_function(self, from_token, to_token, limit):
         global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
@@ -442,7 +498,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_group_stream_token  # type: ignore
-        self.update_function = store.get_all_groups_changes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_groups_changes)  # type: ignore
 
         super(GroupServerStream, self).__init__(hs)
 
@@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
 
         self.current_token = store.get_device_stream_token  # type: ignore
-        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
+        self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes)  # type: ignore
 
         super(UserSignatureStream, self).__init__(hs)
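
A toy illustration of how `db_query_to_update_function` clamps the returned token when the database query fills its limit (the query function here is invented for the example):

    async def fake_query(from_token, upto_token, limit):
        # Pretend there is one row at every token in (from_token, upto_token].
        rows = [(t, "row-%d" % t) for t in range(from_token + 1, upto_token + 1)]
        return rows[:limit]

    update_function = db_query_to_update_function(fake_query)

    updates, new_token, limited = await update_function(0, 500, 100)
    # Only 100 rows fit, so the returned token is clamped to the last row
    # and `limited` tells the caller to fetch again from `new_token`.
    assert new_token == 100 and limited is True
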
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index b3afabb8cd..c6a595629f 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -19,7 +19,7 @@ from typing import Tuple, Type
 
 import attr
 
-from ._base import Stream
+from ._base import Stream, db_query_to_update_function
 
 
 """Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         self.current_token = self._store.get_current_events_token  # type: ignore
+        self.update_function = db_query_to_update_function(self._update_function)  # type: ignore
 
         super(EventsStream, self).__init__(hs)
 
-    async def update_function(self, from_token, current_token, limit=None):
+    async def _update_function(self, from_token, current_token, limit=None):
         event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index f5f9336430..48c1d45718 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 from collections import namedtuple
 
-from ._base import Stream
+from twisted.internet import defer
+
+from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
 
 
 class FederationStream(Stream):
@@ -33,11 +35,18 @@ class FederationStream(Stream):
 
     NAME = "federation"
     ROW_TYPE = FederationStreamRow
+    _QUERY_MASTER = True
 
     def __init__(self, hs):
-        federation_sender = hs.get_federation_sender()
-
-        self.current_token = federation_sender.get_current_token  # type: ignore
-        self.update_function = federation_sender.get_replication_rows  # type: ignore
+        # Not all synapse instances will have a federation sender instance,
+        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
+        # so we stub the stream out when that is the case.
+        if hs.config.worker_app is None or hs.should_send_federation():
+            federation_sender = hs.get_federation_sender()
+            self.current_token = federation_sender.get_current_token  # type: ignore
+            self.update_function = db_query_to_update_function(federation_sender.get_replication_rows)  # type: ignore
+        else:
+            self.current_token = lambda: 0  # type: ignore
+            self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, False))  # type: ignore
 
         super(FederationStream, self).__init__(hs)
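
A quick sanity sketch of the stubbed case (the `worker_hs` here is a hypothetical worker HomeServer with no federation sender):

    stream = FederationStream(worker_hs)
    assert stream.current_token() == 0

    updates, token, limited = await stream.get_updates_since(0, 10)
    assert updates == [] and token == 10 and limited is False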