summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/http/_base.py19
-rw-r--r--synapse/replication/http/streams.py4
-rw-r--r--synapse/replication/slave/storage/_base.py15
-rw-r--r--synapse/replication/slave/storage/account_data.py8
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py5
-rw-r--r--synapse/replication/slave/storage/devices.py10
-rw-r--r--synapse/replication/slave/storage/events.py6
-rw-r--r--synapse/replication/slave/storage/groups.py5
-rw-r--r--synapse/replication/slave/storage/presence.py9
-rw-r--r--synapse/replication/slave/storage/push_rule.py5
-rw-r--r--synapse/replication/slave/storage/pushers.py5
-rw-r--r--synapse/replication/slave/storage/receipts.py5
-rw-r--r--synapse/replication/slave/storage/room.py5
-rw-r--r--synapse/replication/tcp/client.py31
-rw-r--r--synapse/replication/tcp/handler.py30
-rw-r--r--synapse/replication/tcp/redis.py3
-rw-r--r--synapse/replication/tcp/streams/_base.py50
-rw-r--r--synapse/replication/tcp/streams/events.py10
-rw-r--r--synapse/replication/tcp/streams/federation.py4
19 files changed, 96 insertions, 133 deletions
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 1be1ccbdf3..f88c80ae84 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,7 @@
 import abc
 import logging
 import re
+from inspect import signature
 from typing import Dict, List, Tuple
 
 from six import raise_from
@@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
     must call `register` to register the path with the HTTP server.
 
     Requests can be sent by calling the client returned by `make_client`.
+    Requests are sent to master process by default, but can be sent to other
+    named processes by specifying an `instance_name` keyword argument.
 
     Attributes:
         NAME (str): A name for the endpoint, added to the path as well as used
@@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
                 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
             )
 
+        # We reserve `instance_name` as a parameter to sending requests, so we
+        # assert here that sub classes don't try and use the name.
+        assert (
+            "instance_name" not in self.PATH_ARGS
+        ), "`instance_name` is a reserved paramater name"
+        assert (
+            "instance_name"
+            not in signature(self.__class__._serialize_payload).parameters
+        ), "`instance_name` is a reserved paramater name"
+
         assert self.METHOD in ("PUT", "POST", "GET")
 
     @abc.abstractmethod
@@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
 
         @trace(opname="outgoing_replication_request")
         @defer.inlineCallbacks
-        def send_request(**kwargs):
+        def send_request(instance_name="master", **kwargs):
+            # Currently we only support sending requests to master process.
+            if instance_name != "master":
+                raise Exception("Unknown instance")
+
             data = yield cls._serialize_payload(**kwargs)
 
             url_args = [
diff --git a/synapse/replication/http/streams.py b/synapse/replication/http/streams.py
index f35cebc710..0459f582bf 100644
--- a/synapse/replication/http/streams.py
+++ b/synapse/replication/http/streams.py
@@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
     def __init__(self, hs):
         super().__init__(hs)
 
+        self._instance_name = hs.get_instance_name()
+
         # We pull the streams from the replication steamer (if we try and make
         # them ourselves we end up in an import loop).
         self.streams = hs.get_replication_streamer().get_streams()
@@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
         upto_token = parse_integer(request, "upto_token", required=True)
 
         updates, upto_token, limited = await stream.get_updates_since(
-            from_token, upto_token
+            self._instance_name, from_token, upto_token
         )
 
         return (
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 751c799d94..5d7c8871a4 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Optional
+from typing import Optional
 
 import six
 
@@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
 
         self.hs = hs
 
-    def stream_positions(self) -> Dict[str, int]:
-        """
-        Get the current positions of all the streams this store wants to subscribe to
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        pos = {}
-        if self._cache_id_gen:
-            pos["caches"] = self._cache_id_gen.get_current_token()
-        return pos
-
     def get_cache_stream_token(self):
         if self._cache_id_gen:
             return self._cache_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index ebe94909cb..65e54b1c71 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedAccountDataStore, self).stream_positions()
-        position = self._account_data_id_gen.get_current_token()
-        result["user_account_data"] = position
-        result["room_account_data"] = position
-        result["tag_account_data"] = position
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "tag_account_data":
             self._account_data_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 0c237c6e0f..c923751e50 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    def stream_positions(self):
-        result = super(SlavedDeviceInboxStore, self).stream_positions()
-        result["to_device"] = self._device_inbox_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "to_device":
             self._device_inbox_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 23b1650e41..58fb0eaae3 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
-    def stream_positions(self):
-        result = super(SlavedDeviceStore, self).stream_positions()
-        # The user signature stream uses the same stream ID generator as the
-        # device list stream, so set them both to the device list ID
-        # generator's current token.
-        current_token = self._device_list_id_gen.get_current_token()
-        result[DeviceListsStream.NAME] = current_token
-        result[UserSignatureStream.NAME] = current_token
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index e73342c657..15011259df 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,12 +93,6 @@ class SlavedEventStore(
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedEventStore, self).stream_positions()
-        result["events"] = self._stream_id_gen.get_current_token()
-        result["backfill"] = -self._backfill_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "events":
             self._stream_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 2d4fd08cf5..01bcf0e882 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedGroupServerStore, self).stream_positions()
-        result["groups"] = self._group_updates_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "groups":
             self._group_updates_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index ad8f0c15a9..fae3125072 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedPresenceStore, self).stream_positions()
-
-        if self.hs.config.use_presence:
-            position = self._presence_id_gen.get_current_token()
-            result["presence"] = position
-
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index eebd5a1fb6..6138796da4 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedPushRuleStore, self).stream_positions()
-        result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "push_rules":
             self._push_rules_stream_id_gen.advance(token)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index bce8a3d115..67be337945 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
-    def stream_positions(self):
-        result = super(SlavedPusherStore, self).stream_positions()
-        result["pushers"] = self._pushers_id_gen.get_current_token()
-        return result
-
     def get_pushers_stream_token(self):
         return self._pushers_id_gen.get_current_token()
 
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index d40dc6e1f5..993432edcb 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(SlavedReceiptsStore, self).stream_positions()
-        result["receipts"] = self._receipts_id_gen.get_current_token()
-        return result
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate_many((room_id,))
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 3a20f45316..10dda8708f 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def stream_positions(self):
-        result = super(RoomStore, self).stream_positions()
-        result["public_rooms"] = self._public_room_id_gen.get_current_token()
-        return result
-
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 2d07b8b2d0..3bbf3c3569 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,7 +16,7 @@
 """
 
 import logging
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 
 from twisted.internet.protocol import ReconnectingClientFactory
 
@@ -86,37 +86,22 @@ class ReplicationDataHandler:
     def __init__(self, store: BaseSlavedStore):
         self.store = store
 
-    async def on_rdata(self, stream_name: str, token: int, rows: list):
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
         """Called to handle a batch of replication data with a given stream token.
 
         By default this just pokes the slave store. Can be overridden in subclasses to
         handle more.
 
         Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
+            stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
         self.store.process_replication_rows(stream_name, token, rows)
 
-    def get_streams_to_replicate(self) -> Dict[str, int]:
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        args = self.store.stream_positions()
-        user_account_data = args.pop("user_account_data", None)
-        room_account_data = args.pop("room_account_data", None)
-        if user_account_data:
-            args["account_data"] = user_account_data
-        elif room_account_data:
-            args["account_data"] = room_account_data
-        return args
-
     async def on_position(self, stream_name: str, token: int):
         self.store.process_replication_rows(stream_name, token, [])
 
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 6f7054d5af..2d1d119c7c 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -278,19 +278,24 @@ class ReplicationCommandHandler:
                 # Check if this is the last of a batch of updates
                 rows = self._pending_batches.pop(stream_name, [])
                 rows.append(row)
-                await self.on_rdata(stream_name, cmd.token, rows)
+                await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
 
-    async def on_rdata(self, stream_name: str, token: int, rows: list):
+    async def on_rdata(
+        self, stream_name: str, instance_name: str, token: int, rows: list
+    ):
         """Called to handle a batch of replication data with a given stream token.
 
         Args:
             stream_name: name of the replication stream for this batch of rows
+            instance_name: the instance that wrote the rows.
             token: stream token for this batch of rows
             rows: a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
         logger.debug("Received rdata %s -> %s", stream_name, token)
-        await self._replication_data_handler.on_rdata(stream_name, token, rows)
+        await self._replication_data_handler.on_rdata(
+            stream_name, instance_name, token, rows
+        )
 
     async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
         if cmd.instance_name == self._instance_name:
@@ -314,15 +319,7 @@ class ReplicationCommandHandler:
             self._pending_batches.pop(cmd.stream_name, [])
 
             # Find where we previously streamed up to.
-            current_token = self._replication_data_handler.get_streams_to_replicate().get(
-                cmd.stream_name
-            )
-            if current_token is None:
-                logger.warning(
-                    "Got POSITION for stream we're not subscribed to: %s",
-                    cmd.stream_name,
-                )
-                return
+            current_token = stream.current_token()
 
             # If the position token matches our current token then we're up to
             # date and there's nothing to do. Otherwise, fetch all updates
@@ -333,7 +330,9 @@ class ReplicationCommandHandler:
                     updates,
                     current_token,
                     missing_updates,
-                ) = await stream.get_updates_since(current_token, cmd.token)
+                ) = await stream.get_updates_since(
+                    cmd.instance_name, current_token, cmd.token
+                )
 
                 # TODO: add some tests for this
 
@@ -342,7 +341,10 @@ class ReplicationCommandHandler:
 
                 for token, rows in _batch_updates(updates):
                     await self.on_rdata(
-                        cmd.stream_name, token, [stream.parse_row(row) for row in rows],
+                        cmd.stream_name,
+                        cmd.instance_name,
+                        token,
+                        [stream.parse_row(row) for row in rows],
                     )
 
             # We've now caught up to position sent to us, notify handler.
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 617e860f95..41c623d737 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -61,6 +61,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     outbound_redis_connection = None  # type: txredisapi.RedisProtocol
 
     def connectionMade(self):
+        super().connectionMade()
         logger.info("Connected to redis instance")
         self.subscribe(self.stream_name)
         self.send_command(ReplicateCommand())
@@ -119,6 +120,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             logger.warning("Unhandled command: %r", cmd)
 
     def connectionLost(self, reason):
+        super().connectionLost(reason)
         logger.info("Lost connection to redis instance")
         self.handler.lost_connection(self)
 
@@ -189,5 +191,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
         p.handler = self.handler
         p.outbound_redis_connection = self.outbound_redis_connection
         p.stream_name = self.stream_name
+        p.password = self.password
 
         return p
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 4af1afd119..b0f87c365b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -16,7 +16,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, List, Optional, Tuple
 
 import attr
 
@@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
 #
 # The arguments are:
 #
+#  * instance_name: the writer of the stream
 #  * from_token: the previous stream token: the starting point for fetching the
 #    updates
 #  * to_token: the new stream token: the point to get updates up to
@@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
 # If there are more updates available, it should set `limited` in the result, and
 # it will be called again to get the next batch.
 #
-UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
+UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
 
 
 class Stream(object):
@@ -93,6 +94,7 @@ class Stream(object):
 
     def __init__(
         self,
+        local_instance_name: str,
         current_token_function: Callable[[], Token],
         update_function: UpdateFunction,
     ):
@@ -108,9 +110,11 @@ class Stream(object):
         stream tokens. See the UpdateFunction type definition for more info.
 
         Args:
+            local_instance_name: The instance name of the current process
             current_token_function: callback to get the current token, as above
             update_function: callback go get stream updates, as above
         """
+        self.local_instance_name = local_instance_name
         self.current_token = current_token_function
         self.update_function = update_function
 
@@ -135,14 +139,14 @@ class Stream(object):
         """
         current_token = self.current_token()
         updates, current_token, limited = await self.get_updates_since(
-            self.last_token, current_token
+            self.local_instance_name, self.last_token, current_token
         )
         self.last_token = current_token
 
         return updates, current_token, limited
 
     async def get_updates_since(
-        self, from_token: Token, upto_token: Token
+        self, instance_name: str, from_token: Token, upto_token: Token
     ) -> StreamUpdateResult:
         """Like get_updates except allows specifying from when we should
         stream updates
@@ -160,19 +164,19 @@ class Stream(object):
             return [], upto_token, False
 
         updates, upto_token, limited = await self.update_function(
-            from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
+            instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
         )
         return updates, upto_token, limited
 
 
 def db_query_to_update_function(
-    query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
+    query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
     """Wraps a db query function which returns a list of rows to make it
     suitable for use as an `update_function` for the Stream class
     """
 
-    async def update_function(from_token, upto_token, limit):
+    async def update_function(instance_name, from_token, upto_token, limit):
         rows = await query_function(from_token, upto_token, limit)
         updates = [(row[0], row[1:]) for row in rows]
         limited = False
@@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
     client = ReplicationGetStreamUpdates.make_client(hs)
 
     async def update_function(
-        from_token: int, upto_token: int, limit: int
+        instance_name: str, from_token: int, upto_token: int, limit: int
     ) -> StreamUpdateResult:
         result = await client(
-            stream_name=stream_name, from_token=from_token, upto_token=upto_token,
+            instance_name=instance_name,
+            stream_name=stream_name,
+            from_token=from_token,
+            upto_token=upto_token,
         )
         return result["updates"], result["upto_token"], result["limited"]
 
@@ -226,6 +233,7 @@ class BackfillStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_current_backfill_token,
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )
@@ -261,7 +269,9 @@ class PresenceStream(Stream):
             # Query master process
             update_function = make_http_update_function(hs, self.NAME)
 
-        super().__init__(store.get_current_presence_token, update_function)
+        super().__init__(
+            hs.get_instance_name(), store.get_current_presence_token, update_function
+        )
 
 
 class TypingStream(Stream):
@@ -284,7 +294,9 @@ class TypingStream(Stream):
             # Query master process
             update_function = make_http_update_function(hs, self.NAME)
 
-        super().__init__(typing_handler.get_current_token, update_function)
+        super().__init__(
+            hs.get_instance_name(), typing_handler.get_current_token, update_function
+        )
 
 
 class ReceiptsStream(Stream):
@@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_max_receipt_stream_id,
             db_query_to_update_function(store.get_all_updated_receipts),
         )
@@ -322,14 +335,16 @@ class PushRulesStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
         super(PushRulesStream, self).__init__(
-            self._current_token, self._update_function
+            hs.get_instance_name(), self._current_token, self._update_function
         )
 
     def _current_token(self) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    async def _update_function(self, from_token: Token, to_token: Token, limit: int):
+    async def _update_function(
+        self, instance_name: str, from_token: Token, to_token: Token, limit: int
+    ):
         rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
 
         limited = False
@@ -356,6 +371,7 @@ class PushersStream(Stream):
         store = hs.get_datastore()
 
         super().__init__(
+            hs.get_instance_name(),
             store.get_pushers_stream_token,
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
@@ -387,6 +403,7 @@ class CachesStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_cache_stream_token,
             db_query_to_update_function(store.get_all_updated_caches),
         )
@@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_current_public_room_stream_id,
             db_query_to_update_function(store.get_all_new_public_rooms),
         )
@@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_device_stream_token,
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )
@@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_to_device_stream_token,
             db_query_to_update_function(store.get_all_new_device_messages),
         )
@@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_max_account_data_stream_id,
             db_query_to_update_function(store.get_all_updated_tags),
         )
@@ -487,6 +508,7 @@ class AccountDataStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             self.store.get_max_account_data_stream_id,
             db_query_to_update_function(self._update_function),
         )
@@ -517,6 +539,7 @@ class GroupServerStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_group_stream_token,
             db_query_to_update_function(store.get_all_groups_changes),
         )
@@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
         super().__init__(
+            hs.get_instance_name(),
             store.get_device_stream_token,
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 52df81b1bd..890e75d827 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -118,11 +118,17 @@ class EventsStream(Stream):
     def __init__(self, hs):
         self._store = hs.get_datastore()
         super().__init__(
-            self._store.get_current_events_token, self._update_function,
+            hs.get_instance_name(),
+            self._store.get_current_events_token,
+            self._update_function,
         )
 
     async def _update_function(
-        self, from_token: Token, current_token: Token, target_row_count: int
+        self,
+        instance_name: str,
+        from_token: Token,
+        current_token: Token,
+        target_row_count: int,
     ) -> StreamUpdateResult:
 
         # the events stream merges together three separate sources:
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 75133d7e40..e8bd52e389 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -48,8 +48,8 @@ class FederationStream(Stream):
             current_token = lambda: 0
             update_function = self._stub_update_function
 
-        super().__init__(current_token, update_function)
+        super().__init__(hs.get_instance_name(), current_token, update_function)
 
     @staticmethod
-    async def _stub_update_function(from_token, upto_token, limit):
+    async def _stub_update_function(instance_name, from_token, upto_token, limit):
         return [], upto_token, False