summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-05-07 13:51:08 +0100
committerGitHub <noreply@github.com>2020-05-07 13:51:08 +0100
commitd7983b63a6746d92225295f1e9d521f847cf8ba7 (patch)
tree13d581210d94c26bd75036592996f6f53f7d4bb2 /synapse
parentMerge pull request #7398 from Starbix/alpine-3.11 (diff)
downloadsynapse-d7983b63a6746d92225295f1e9d521f847cf8ba7.tar.xz
Support any process writing to cache invalidation stream. (#7436)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/slave/storage/_base.py50
-rw-r--r--synapse/replication/slave/storage/account_data.py6
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py6
-rw-r--r--synapse/replication/slave/storage/devices.py6
-rw-r--r--synapse/replication/slave/storage/events.py6
-rw-r--r--synapse/replication/slave/storage/groups.py6
-rw-r--r--synapse/replication/slave/storage/presence.py6
-rw-r--r--synapse/replication/slave/storage/push_rule.py6
-rw-r--r--synapse/replication/slave/storage/pushers.py6
-rw-r--r--synapse/replication/slave/storage/receipts.py6
-rw-r--r--synapse/replication/slave/storage/room.py4
-rw-r--r--synapse/replication/tcp/client.py6
-rw-r--r--synapse/replication/tcp/commands.py33
-rw-r--r--synapse/replication/tcp/handler.py42
-rw-r--r--synapse/replication/tcp/resource.py22
-rw-r--r--synapse/replication/tcp/streams/_base.py87
-rw-r--r--synapse/replication/tcp/streams/events.py4
-rw-r--r--synapse/replication/tcp/streams/federation.py12
-rw-r--r--synapse/storage/_base.py3
-rw-r--r--synapse/storage/data_stores/main/__init__.py15
-rw-r--r--synapse/storage/data_stores/main/cache.py84
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres30
-rw-r--r--synapse/storage/prepare_database.py2
23 files changed, 223 insertions, 225 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 5d7c8871a4..2904bd0235 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,14 +18,10 @@ from typing import Optional
 
 import six
 
-from synapse.storage.data_stores.main.cache import (
-    CURRENT_STATE_CACHE_NAME,
-    CacheInvalidationWorkerStore,
-)
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
-
-from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id"
-            )  # type: Optional[SlavedIdTracker]
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name=hs.get_instance_name(),
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
+            )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
 
         self.hs = hs
-
-    def get_cache_stream_token(self):
-        if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
-        else:
-            return 0
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "caches":
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(token)
-            for row in rows:
-                if row.cache_func == CURRENT_STATE_CACHE_NAME:
-                    if row.keys is None:
-                        raise Exception(
-                            "Can't send an 'invalidate all' for current state cache"
-                        )
-
-                    room_id = row.keys[0]
-                    members_changed = set(row.keys[1:])
-                    self._invalidate_state_caches(room_id, members_changed)
-                else:
-                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
-
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
-        txn.call_after(cache_func.invalidate, keys)
-        txn.call_after(self._send_invalidation_poke, cache_func, keys)
-
-    def _send_invalidation_poke(self, cache_func, keys):
-        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 65e54b1c71..2a4f5c7cfd 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "tag_account_data":
             self._account_data_id_gen.advance(token)
             for row in rows:
@@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedAccountDataStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index c923751e50..6e7fd259d4 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "to_device":
             self._device_inbox_id_gen.advance(token)
             for row in rows:
@@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
                     self._device_federation_outbox_stream_cache.entity_has_changed(
                         row.entity, token
                     )
-        return super(SlavedDeviceInboxStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 58fb0eaae3..9d8067342f 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
             self._invalidate_caches_for_devices(token, rows)
@@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedDeviceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_caches_for_devices(self, token, rows):
         for row in rows:
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 15011259df..b313720a4b 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,7 +93,7 @@ class SlavedEventStore(
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "events":
             self._stream_id_gen.advance(token)
             for row in rows:
@@ -111,9 +111,7 @@ class SlavedEventStore(
                     row.relates_to,
                     backfilled=True,
                 )
-        return super(SlavedEventStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _process_event_stream_row(self, token, row):
         data = row.data
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 01bcf0e882..1851e7d525 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "groups":
             self._group_updates_id_gen.advance(token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
 
-        return super(SlavedGroupServerStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index fae3125072..bd79ba99be 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
-        return super(SlavedPresenceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6138796da4..5d5816d7eb 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "push_rules":
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedPushRuleStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 67be337945..cb78b49acb 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
     def get_pushers_stream_token(self):
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
-        return super(SlavedPusherStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 993432edcb..be716cc558 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "receipts":
             self._receipts_id_gen.advance(token)
             for row in rows:
@@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
                 )
                 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        return super(SlavedReceiptsStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 10dda8708f..8873bf37e5 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3bbf3c3569..20cb8a654f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -100,10 +100,10 @@ class ReplicationDataHandler:
             token: stream token for this batch of rows
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
-        self.store.process_replication_rows(stream_name, token, rows)
+        self.store.process_replication_rows(stream_name, instance_name, token, rows)
 
-    async def on_position(self, stream_name: str, token: int):
-        self.store.process_replication_rows(stream_name, token, [])
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        self.store.process_replication_rows(stream_name, instance_name, token, [])
 
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f58e384d17..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
         return " ".join((self.app_id, self.push_key, self.user_id))
 
 
-class InvalidateCacheCommand(Command):
-    """Sent by the client to invalidate an upstream cache.
-
-    THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
-    NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
-    Mainly used to invalidate destination retry timing caches.
-
-    Format::
-
-        INVALIDATE_CACHE <cache_func> <keys_json>
-
-    Where <keys_json> is a json list.
-    """
-
-    NAME = "INVALIDATE_CACHE"
-
-    def __init__(self, cache_func, keys):
-        self.cache_func = cache_func
-        self.keys = keys
-
-    @classmethod
-    def from_line(cls, line):
-        cache_func, keys_json = line.split(" ", 1)
-
-        return cls(cache_func, json.loads(keys_json))
-
-    def to_line(self):
-        return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
 class UserIpCommand(Command):
     """Sent periodically when a worker sees activity from a client.
 
@@ -439,7 +408,6 @@ _COMMANDS = (
     UserSyncCommand,
     FederationAckCommand,
     RemovePusherCommand,
-    InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
     ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
-    InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b14a3d9fca..7c5d6c76e7 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
 
 from prometheus_client import Counter
 
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
     FederationAckCommand,
-    InvalidateCacheCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -171,7 +159,7 @@ class ReplicationCommandHandler:
             return
 
         for stream_name, stream in self._streams.items():
-            current_token = stream.current_token()
+            current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(stream_name, self._instance_name, current_token)
             )
@@ -210,18 +198,6 @@ class ReplicationCommandHandler:
 
             self._notifier.on_new_replication_data()
 
-    async def on_INVALIDATE_CACHE(
-        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
-    ):
-        invalidate_cache_counter.inc()
-
-        if self._is_master:
-            # We invalidate the cache locally, but then also stream that to other
-            # workers.
-            await self._store.invalidate_cache_and_stream(
-                cmd.cache_func, tuple(cmd.keys)
-            )
-
     async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
         user_ip_cache_counter.inc()
 
@@ -295,7 +271,7 @@ class ReplicationCommandHandler:
             rows: a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
+        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
         await self._replication_data_handler.on_rdata(
             stream_name, instance_name, token, rows
         )
@@ -326,7 +302,7 @@ class ReplicationCommandHandler:
             self._pending_batches.pop(stream_name, [])
 
             # Find where we previously streamed up to.
-            current_token = stream.current_token()
+            current_token = stream.current_token(cmd.instance_name)
 
             # If the position token matches our current token then we're up to
             # date and there's nothing to do. Otherwise, fetch all updates
@@ -363,7 +339,9 @@ class ReplicationCommandHandler:
             logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
 
             # We've now caught up to position sent to us, notify handler.
-            await self._replication_data_handler.on_position(stream_name, cmd.token)
+            await self._replication_data_handler.on_position(
+                cmd.stream_name, cmd.instance_name, cmd.token
+            )
 
             self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
@@ -491,12 +469,6 @@ class ReplicationCommandHandler:
         cmd = RemovePusherCommand(app_id, push_key, user_id)
         self.send_command(cmd)
 
-    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
     def send_user_ip(
         self,
         user_id: str,
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index b690abedad..002171ce7c 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    CachesStream,
+    FederationStream,
+    Stream,
+)
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -71,11 +76,16 @@ class ReplicationStreamer(object):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
 
         self._replication_torture_level = hs.config.replication_torture_level
 
         # Work out list of streams that this instance is the source of.
         self.streams = []  # type: List[Stream]
+
+        # All workers can write to the cache invalidation stream.
+        self.streams.append(CachesStream(hs))
+
         if hs.config.worker_app is None:
             for stream in STREAMS_MAP.values():
                 if stream == FederationStream and hs.config.send_federation:
@@ -83,6 +93,10 @@ class ReplicationStreamer(object):
                     # has been disabled on the master.
                     continue
 
+                if stream == CachesStream:
+                    # We've already added it above.
+                    continue
+
                 self.streams.append(stream(hs))
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
@@ -145,7 +159,9 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.current_token():
+                        if stream.last_token == stream.current_token(
+                            self._instance_name
+                        ):
                             continue
 
                         if self._replication_torture_level:
@@ -157,7 +173,7 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.current_token(),
+                            stream.current_token(self._instance_name),
                         )
                         try:
                             updates, current_token, limited = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 084604e8b0..b48a6a3e91 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,20 +95,25 @@ class Stream(object):
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[], Token],
+        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream
 
-        current_token_function and update_function are callbacks which should be
-        implemented by subclasses.
+        `current_token_function` and `update_function` are callbacks which
+        should be implemented by subclasses.
 
-        current_token_function is called to get the current token of the underlying
-        stream. It is only meaningful on the process that is the source of the
-        replication stream (ie, usually the master).
+        `current_token_function` takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process). On the writer process this is where
+        the writer has successfully written up to, whereas on other processes
+        this is the position which we have received updates up to over
+        replication. (Note that most streams have a single writer and so their
+        implementations ignore the instance name passed in).
 
-        update_function is called to get updates for this stream between a pair of
-        stream tokens. See the UpdateFunction type definition for more info.
+        `update_function` is called to get updates for this stream between a
+        pair of stream tokens. See the `UpdateFunction` type definition for more
+        info.
 
         Args:
             local_instance_name: The instance name of the current process
@@ -120,13 +125,13 @@ class Stream(object):
         self.update_function = update_function
 
         # The token from which we last asked for updates
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
@@ -138,7 +143,7 @@ class Stream(object):
             position in stream, and `limited` is whether there are more updates
             to fetch.
         """
-        current_token = self.current_token()
+        current_token = self.current_token(self.local_instance_name)
         updates, current_token, limited = await self.get_updates_since(
             self.local_instance_name, self.last_token, current_token
         )
@@ -170,6 +175,16 @@ class Stream(object):
         return updates, upto_token, limited
 
 
+def current_token_without_instance(
+    current_token: Callable[[], int]
+) -> Callable[[str], int]:
+    """Takes a current token callback function for a single writer stream
+    that doesn't take an instance name parameter and wraps it in a function that
+    does accept an instance name parameter but ignores it.
+    """
+    return lambda instance_name: current_token()
+
+
 def db_query_to_update_function(
     query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
@@ -235,7 +250,7 @@ class BackfillStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_backfill_token,
+            current_token_without_instance(store.get_current_backfill_token),
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )
 
@@ -271,7 +286,9 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), store.get_current_presence_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_presence_token),
+            update_function,
         )
 
 
@@ -296,7 +313,9 @@ class TypingStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), typing_handler.get_current_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(typing_handler.get_current_token),
+            update_function,
         )
 
 
@@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_receipt_stream_id,
+            current_token_without_instance(store.get_max_receipt_stream_id),
             db_query_to_update_function(store.get_all_updated_receipts),
         )
 
@@ -339,7 +358,7 @@ class PushRulesStream(Stream):
             hs.get_instance_name(), self._current_token, self._update_function
         )
 
-    def _current_token(self) -> int:
+    def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
@@ -373,7 +392,7 @@ class PushersStream(Stream):
 
         super().__init__(
             hs.get_instance_name(),
-            store.get_pushers_stream_token,
+            current_token_without_instance(store.get_pushers_stream_token),
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
 
@@ -402,12 +421,26 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
-            db_query_to_update_function(store.get_all_updated_caches),
+            self.store.get_cache_stream_token,
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, upto_token: int, limit: int
+    ):
+        rows = await self.store.get_all_updated_caches(
+            instance_name, from_token, upto_token, limit
         )
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
@@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_public_room_stream_id,
+            current_token_without_instance(store.get_current_public_room_stream_id),
             db_query_to_update_function(store.get_all_new_public_rooms),
         )
 
@@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )
 
@@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_to_device_stream_token,
+            current_token_without_instance(store.get_to_device_stream_token),
             db_query_to_update_function(store.get_all_new_device_messages),
         )
 
@@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_account_data_stream_id,
+            current_token_without_instance(store.get_max_account_data_stream_id),
             db_query_to_update_function(store.get_all_updated_tags),
         )
 
@@ -510,7 +543,7 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self.store.get_max_account_data_stream_id,
+            current_token_without_instance(self.store.get_max_account_data_stream_id),
             db_query_to_update_function(self._update_function),
         )
 
@@ -541,7 +574,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_group_stream_token,
+            current_token_without_instance(store.get_group_stream_token),
             db_query_to_update_function(store.get_all_groups_changes),
         )
 
@@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
             ),
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 890e75d827..f370390331 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
 
 
 """Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self._store.get_current_events_token,
+            current_token_without_instance(self._store.get_current_events_token),
             self._update_function,
         )
 
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index b0505b8a2c..9bcd13b009 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,11 @@
 # limitations under the License.
 from collections import namedtuple
 
-from synapse.replication.tcp.streams._base import Stream, make_http_update_function
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    current_token_without_instance,
+    make_http_update_function,
+)
 
 
 class FederationStream(Stream):
@@ -41,7 +45,9 @@ class FederationStream(Stream):
             # will be a real FederationSender, which has stubs for current_token and
             # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = federation_sender.get_current_token
+            current_token = current_token_without_instance(
+                federation_sender.get_current_token
+            )
             update_function = federation_sender.get_replication_rows
 
         elif hs.should_send_federation():
@@ -58,7 +64,7 @@ class FederationStream(Stream):
         super().__init__(hs.get_instance_name(), current_token, update_function)
 
     @staticmethod
-    def _stub_current_token():
+    def _stub_current_token(instance_name: str) -> int:
         # dummy current-token method for use on workers
         return 0
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 13de5f1f62..59073c0a42 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta):
         self.db = database
         self.rand = random.SystemRandom()
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        pass
+
     def _invalidate_state_caches(self, room_id, members_changed):
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index ceba10882c..cd2a1f0461 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     ChainedIdGenerator,
     IdGenerator,
+    MultiWriterIdGenerator,
     StreamIdGenerator,
 )
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .cache import CacheInvalidationStore
+from .cache import CacheInvalidationWorkerStore
 from .client_ips import ClientIpStore
 from .deviceinbox import DeviceInboxStore
 from .devices import DeviceStore
@@ -112,8 +113,8 @@ class DataStore(
     MonthlyActiveUsersStore,
     StatsStore,
     RelationsStore,
-    CacheInvalidationStore,
     UIAuthStore,
+    CacheInvalidationWorkerStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
@@ -170,8 +171,14 @@ class DataStore(
         )
 
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = StreamIdGenerator(
-                db_conn, "cache_invalidation_stream", "stream_id"
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name="master",
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
             )
         else:
             self._cache_id_gen = None
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 4dc5da3fe8..342a87a46b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,11 +16,10 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Any, Iterable, Optional
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def get_all_updated_caches(self, last_id, current_id, limit):
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+    async def get_all_updated_caches(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ):
+        """Fetches cache invalidation rows between the two given IDs written
+        by the given instance. Returns at most `limit` rows.
+        """
+
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def get_all_updated_caches_txn(txn):
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
+            sql = """
+                SELECT stream_id, cache_func, keys, invalidation_ts
+                FROM cache_invalidation_stream_by_instance
+                WHERE stream_id > ? AND instance_name = ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, instance_name, limit))
             return txn.fetchall()
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "caches":
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
 
-class CacheInvalidationStore(CacheInvalidationWorkerStore):
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
-        """Invalidates the cache and adds it to the cache stream so slaves
-        will know to invalidate their caches.
+            for row in rows:
+                if row.cache_func == CURRENT_STATE_CACHE_NAME:
+                    if row.keys is None:
+                        raise Exception(
+                            "Can't send an 'invalidate all' for current state cache"
+                        )
 
-        This should only be used to invalidate caches where slaves won't
-        otherwise know from other replication streams that the cache should
-        be invalidated.
-        """
-        cache_func = getattr(self, cache_name, None)
-        if not cache_func:
-            return
-
-        cache_func.invalidate(keys)
-        await self.runInteraction(
-            "invalidate_cache_and_stream",
-            self._send_invalidation_to_replication,
-            cache_func.__name__,
-            keys,
-        )
+                    room_id = row.keys[0]
+                    members_changed = set(row.keys[1:])
+                    self._invalidate_state_caches(room_id, members_changed)
+                else:
+                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         """Invalidates the cache and adds it to the cache stream so slaves
@@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
             # the transaction. However, we want to only get an ID when we want
             # to use it, here, so we need to call __enter__ manually, and have
             # __exit__ called after the transaction finishes.
-            ctx = self._cache_id_gen.get_next()
-            stream_id = ctx.__enter__()
-            txn.call_on_exception(ctx.__exit__, None, None, None)
-            txn.call_after(ctx.__exit__, None, None, None)
+            stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             if keys is not None:
@@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
 
             self.db.simple_insert_txn(
                 txn,
-                table="cache_invalidation_stream",
+                table="cache_invalidation_stream_by_instance",
                 values={
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
                     "invalidation_ts": self.clock.time_msec(),
                 },
             )
 
-    def get_cache_stream_token(self):
+    def get_cache_stream_token(self, instance_name):
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
+            return self._cache_id_gen.get_current_token(instance_name)
         else:
             return 0
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
new file mode 100644
index 0000000000..aa46eb0e10
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We keep the old table here to enable us to roll back. It doesn't matter
+-- that we have dropped all the data here.
+TRUNCATE cache_invalidation_stream;
+
+CREATE TABLE cache_invalidation_stream_by_instance (
+    stream_id       BIGINT NOT NULL,
+    instance_name   TEXT NOT NULL,
+    cache_func      TEXT NOT NULL,
+    keys            TEXT[],
+    invalidation_ts BIGINT
+);
+
+CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
+
+CREATE SEQUENCE cache_invalidation_stream_seq;
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1712932f31..640f242584 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,6 +29,8 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
 SCHEMA_VERSION = 58
 
 dir_path = os.path.abspath(os.path.dirname(__file__))