summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/send_queue.py84
-rw-r--r--synapse/federation/sender/__init__.py12
-rw-r--r--synapse/federation/sender/per_destination_queue.py6
-rw-r--r--synapse/federation/sender/transaction_manager.py6
-rw-r--r--synapse/handlers/room.py42
-rw-r--r--synapse/handlers/saml_handler.py28
-rw-r--r--synapse/replication/slave/storage/_base.py50
-rw-r--r--synapse/replication/slave/storage/account_data.py6
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py6
-rw-r--r--synapse/replication/slave/storage/devices.py6
-rw-r--r--synapse/replication/slave/storage/events.py6
-rw-r--r--synapse/replication/slave/storage/groups.py6
-rw-r--r--synapse/replication/slave/storage/presence.py6
-rw-r--r--synapse/replication/slave/storage/push_rule.py6
-rw-r--r--synapse/replication/slave/storage/pushers.py6
-rw-r--r--synapse/replication/slave/storage/receipts.py6
-rw-r--r--synapse/replication/slave/storage/room.py4
-rw-r--r--synapse/replication/tcp/client.py6
-rw-r--r--synapse/replication/tcp/commands.py33
-rw-r--r--synapse/replication/tcp/handler.py42
-rw-r--r--synapse/replication/tcp/resource.py24
-rw-r--r--synapse/replication/tcp/streams/_base.py86
-rw-r--r--synapse/replication/tcp/streams/events.py4
-rw-r--r--synapse/replication/tcp/streams/federation.py36
-rw-r--r--synapse/res/templates/notice_expiry.html2
-rw-r--r--synapse/res/templates/notice_expiry.txt2
-rw-r--r--synapse/storage/_base.py3
-rw-r--r--synapse/storage/data_stores/main/__init__.py15
-rw-r--r--synapse/storage/data_stores/main/cache.py84
-rw-r--r--synapse/storage/data_stores/main/devices.py60
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py98
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql28
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres30
-rw-r--r--synapse/storage/database.py74
-rw-r--r--synapse/storage/prepare_database.py2
-rw-r--r--synapse/storage/util/id_generators.py169
36 files changed, 649 insertions, 435 deletions
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index e1700ca8aa..52f4f54215 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
 
 import logging
 from collections import namedtuple
+from typing import Dict, List, Tuple, Type
 
 from six import iteritems
 
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
         self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
 
-        self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
-        self.presence_changed = SortedDict()  # Stream position -> list[user_id]
+        # Pending presence map user_id -> UserPresenceState
+        self.presence_map = {}  # type: Dict[str, UserPresenceState]
+
+        # Stream position -> list[user_id]
+        self.presence_changed = SortedDict()  # type: SortedDict[int, List[str]]
 
         # Stores the destinations we need to explicitly send presence to about a
         # given user.
         # Stream position -> (user_id, destinations)
-        self.presence_destinations = SortedDict()
+        self.presence_destinations = (
+            SortedDict()
+        )  # type: SortedDict[int, Tuple[str, List[str]]]
+
+        # (destination, key) -> EDU
+        self.keyed_edu = {}  # type: Dict[Tuple[str, tuple], Edu]
 
-        self.keyed_edu = {}  # (destination, key) -> EDU
-        self.keyed_edu_changed = SortedDict()  # stream position -> (destination, key)
+        # stream position -> (destination, key)
+        self.keyed_edu_changed = (
+            SortedDict()
+        )  # type: SortedDict[int, Tuple[str, tuple]]
 
-        self.edus = SortedDict()  # stream position -> Edu
+        self.edus = SortedDict()  # type: SortedDict[int, Edu]
 
+        # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
         self.pos = 1
-        self.pos_time = SortedDict()
+
+        # map from stream ID to the time that stream entry was generated, so that we
+        # can clear out entries after a while
+        self.pos_time = SortedDict()  # type: SortedDict[int, int]
 
         # EVERYTHING IS SAD. In particular, python only makes new scopes when
         # we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
             for edu_key in self.keyed_edu_changed.values():
                 live_keys.add(edu_key)
 
-            to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
-            for edu_key in to_del:
+            keys_to_del = [
+                edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+            ]
+            for edu_key in keys_to_del:
                 del self.keyed_edu[edu_key]
 
             # Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
         self._clear_queue_before_pos(token)
 
     async def get_replication_rows(
-        self, from_token, to_token, limit, federation_ack=None
-    ):
+        self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
         """Get rows to be sent over federation between the two tokens
 
         Args:
-            from_token (int)
-            to_token(int)
-            limit (int)
-            federation_ack (int): Optional. The position where the worker is
-                explicitly acknowledged it has handled. Allows us to drop
-                data from before that point
+            instance_name: the name of the current process
+            from_token: the previous stream token: the starting point for fetching the
+                updates
+            to_token: the new stream token: the point to get updates up to
+            target_row_count: a target for the number of rows to be returned.
+
+        Returns: a triplet `(updates, new_last_token, limited)`, where:
+           * `updates` is a list of `(token, row)` entries.
+           * `new_last_token` is the new position in stream.
+           * `limited` is whether there are more updates to fetch.
         """
-        # TODO: Handle limit.
+        # TODO: Handle target_row_count.
 
         # To handle restarts where we wrap around
         if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
 
         # list of tuple(int, BaseFederationRow), where the first is the position
         # of the federation stream.
-        rows = []
-
-        # There should be only one reader, so lets delete everything its
-        # acknowledged its seen.
-        if federation_ack:
-            self._clear_queue_before_pos(federation_ack)
+        rows = []  # type: List[Tuple[int, BaseFederationRow]]
 
         # Fetch changed presence
         i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
         # Sort rows based on pos
         rows.sort()
 
-        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+        return (
+            [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+            to_token,
+            False,
+        )
 
 
 class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
     Specifies how to identify, serialize and deserialize the different types.
     """
 
-    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+    TypeId = ""  # Unique string that ids the type. Must be overriden in sub classes.
 
     @staticmethod
     def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
 
-TypeToRow = {
-    Row.TypeId: Row
-    for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
+_rowtypes = (
+    PresenceRow,
+    PresenceDestinationsRow,
+    KeyedEduRow,
+    EduRow,
+)  # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
 
 
 ParsedFederationStreamData = namedtuple(
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a477578e44..d473576902 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
 
 from six import itervalues
 
@@ -498,14 +498,16 @@ class FederationSender(object):
 
         self._get_per_destination_queue(destination).attempt_new_transaction()
 
-    def get_current_token(self) -> int:
+    @staticmethod
+    def get_current_token() -> int:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
         return 0
 
+    @staticmethod
     async def get_replication_rows(
-        self, from_token, to_token, limit, federation_ack=None
-    ):
+        instance_name: str, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
-        return []
+        return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20ffa..276a2b596f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
 # limitations under the License.
 import datetime
 import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
 
 from prometheus_client import Counter
 
-import synapse.server
 from synapse.api.errors import (
     FederationDeniedError,
     HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
 from synapse.types import ReadReceipt
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
+if TYPE_CHECKING:
+    import synapse.server
+
 # This is defined in the Matrix spec and enforced by the receiver.
 MAX_EDUS_PER_TRANSACTION = 100
 
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3b3..a2752a54a5 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List
+from typing import TYPE_CHECKING, List
 
 from canonicaljson import json
 
-import synapse.server
 from synapse.api.errors import HttpResponseException
 from synapse.events import EventBase
 from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
 )
 from synapse.util.metrics import measure_func
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index da12df7f53..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
 
 from six import iteritems, string_types
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
 from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
-    @defer.inlineCallbacks
-    def upgrade_room(
+    async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ):
         """Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
         Returns:
             Deferred[unicode]: the new room id
         """
-        yield self.ratelimit(requester)
+        await self.ratelimit(requester)
 
         user_id = requester.user.to_string()
 
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
         # If this user has sent multiple upgrade requests for the same room
         # and one of them is not complete yet, cache the response and
         # return it to all subsequent requests
-        ret = yield self._upgrade_response_cache.wrap(
+        ret = await self._upgrade_response_cache.wrap(
             (old_room_id, user_id),
             self._upgrade_room,
             requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
         for (etype, state_key), content in initial_state.items():
             await send(etype=etype, state_key=state_key, content=content)
 
-    @defer.inlineCallbacks
-    def _generate_room_id(
+    async def _generate_room_id(
         self, creator_id: str, is_public: str, room_version: RoomVersion,
     ):
         # autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
                 gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
                 if isinstance(gen_room_id, bytes):
                     gen_room_id = gen_room_id.decode("utf-8")
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,
                     is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    @defer.inlineCallbacks
-    def get_event_context(self, user, room_id, event_id, limit, event_filter):
+    async def get_event_context(self, user, room_id, event_id, limit, event_filter):
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = yield self.store.get_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         is_peeking = user.to_string() not in users
 
         def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
                 self.storage, user.to_string(), events, is_peeking=is_peeking
             )
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, get_prev_content=True, allow_none=True
         )
         if not event:
             return None
 
-        filtered = yield (filter_evts([event]))
+        filtered = await filter_evts([event])
         if not filtered:
             raise AuthError(403, "You don't have permission to access that event.")
 
-        results = yield self.store.get_events_around(
+        results = await self.store.get_events_around(
             room_id, event_id, before_limit, after_limit, event_filter
         )
 
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
             results["events_before"] = event_filter.filter(results["events_before"])
             results["events_after"] = event_filter.filter(results["events_after"])
 
-        results["events_before"] = yield filter_evts(results["events_before"])
-        results["events_after"] = yield filter_evts(results["events_after"])
+        results["events_before"] = await filter_evts(results["events_before"])
+        results["events_after"] = await filter_evts(results["events_after"])
         # filter_evts can return a pruned event in case the user is allowed to see that
         # there's something there but not see the content, so use the event that's in
         # `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687
 
-        state = yield self.state_store.get_state_for_events(
+        state = await self.state_store.get_state_for_events(
             [last_event_id], state_filter=state_filter
         )
 
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
         if event_filter:
             state_events = event_filter.filter(state_events)
 
-        results["state"] = yield filter_evts(state_events)
+        results["state"] = await filter_evts(state_events)
 
         # We use a dummy token here as we only care about the room portion of
         # the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def get_new_events(
+    async def get_new_events(
         self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
     ):
         # We just ignore the key for now.
 
-        to_key = yield self.get_current_key()
+        to_key = await self.get_current_key()
 
         from_token = RoomStreamToken.parse(from_key)
         if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
             # See https://github.com/matrix-org/matrix-doc/issues/1144
             raise NotImplementedError()
         else:
-            room_events = yield self.store.get_membership_changes_for_user(
+            room_events = await self.store.get_membership_changes_for_user(
                 user.to_string(), from_key, to_key
             )
 
-            room_to_events = yield self.store.get_room_events_stream_for_rooms(
+            room_to_events = await self.store.get_room_events_stream_for_rooms(
                 room_ids=room_ids,
                 from_key=from_key,
                 to_key=to_key,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 96f2dd36ad..e7015c704f 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Set, Tuple
 
 import attr
 import saml2
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
 from synapse.module_api.errors import RedirectException
 from synapse.types import (
@@ -81,17 +82,19 @@ class SamlHandler:
 
         self._error_html_content = hs.config.saml2_error_html_content
 
-    def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
+    def handle_redirect_request(
+        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
+    ) -> bytes:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
-            client_redirect_url (bytes): the URL that we should redirect the
+            client_redirect_url: the URL that we should redirect the
                 client to when everything is done
-            ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
-            bytes: URL to redirect to
+            URL to redirect to
         """
         reqid, info = self._saml_client.prepare_for_authenticate(
             relay_state=client_redirect_url
@@ -109,15 +112,15 @@ class SamlHandler:
         # this shouldn't happen!
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
-    async def handle_saml_response(self, request):
+    async def handle_saml_response(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_matrix/saml2/authn_response
 
         Args:
-            request (SynapseRequest): the incoming request from the browser. We'll
+            request: the incoming request from the browser. We'll
                 respond to it with a redirect.
 
         Returns:
-            Deferred[none]: Completes once we have handled the request.
+            Completes once we have handled the request.
         """
         resp_bytes = parse_string(request, "SAMLResponse", required=True)
         relay_state = parse_string(request, "RelayState", required=True)
@@ -310,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile(
 
 
 def dot_replace_for_mxid(username: str) -> str:
+    """Replace any characters which are not allowed in Matrix IDs with a dot."""
     username = username.lower()
     username = DOT_REPLACE_PATTERN.sub(".", username)
 
@@ -321,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str:
 MXID_MAPPER_MAP = {
     "hexencode": map_username_to_mxid_localpart,
     "dotreplace": dot_replace_for_mxid,
-}
+}  # type: Dict[str, Callable[[str], str]]
 
 
 @attr.s
@@ -349,7 +353,7 @@ class DefaultSamlMappingProvider(object):
 
     def get_remote_user_id(
         self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
-    ):
+    ) -> str:
         """Extracts the remote user id from the SAML response"""
         try:
             return saml_response.ava["uid"][0]
@@ -428,14 +432,14 @@ class DefaultSamlMappingProvider(object):
         return SamlConfig(mxid_source_attribute, mxid_mapper)
 
     @staticmethod
-    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+    def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
         """Returns the required attributes of a SAML
 
         Args:
             config: A SamlConfig object containing configuration params for this provider
 
         Returns:
-            tuple[set,set]: The first set equates to the saml auth response
+            The first set equates to the saml auth response
                 attributes that are required for the module to function, whereas the
                 second set consists of those attributes which can be used if
                 available, but are not necessary
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 5d7c8871a4..2904bd0235 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,14 +18,10 @@ from typing import Optional
 
 import six
 
-from synapse.storage.data_stores.main.cache import (
-    CURRENT_STATE_CACHE_NAME,
-    CacheInvalidationWorkerStore,
-)
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
-
-from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id"
-            )  # type: Optional[SlavedIdTracker]
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name=hs.get_instance_name(),
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
+            )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
 
         self.hs = hs
-
-    def get_cache_stream_token(self):
-        if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
-        else:
-            return 0
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "caches":
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(token)
-            for row in rows:
-                if row.cache_func == CURRENT_STATE_CACHE_NAME:
-                    if row.keys is None:
-                        raise Exception(
-                            "Can't send an 'invalidate all' for current state cache"
-                        )
-
-                    room_id = row.keys[0]
-                    members_changed = set(row.keys[1:])
-                    self._invalidate_state_caches(room_id, members_changed)
-                else:
-                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
-
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
-        txn.call_after(cache_func.invalidate, keys)
-        txn.call_after(self._send_invalidation_poke, cache_func, keys)
-
-    def _send_invalidation_poke(self, cache_func, keys):
-        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 65e54b1c71..2a4f5c7cfd 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "tag_account_data":
             self._account_data_id_gen.advance(token)
             for row in rows:
@@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedAccountDataStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index c923751e50..6e7fd259d4 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "to_device":
             self._device_inbox_id_gen.advance(token)
             for row in rows:
@@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
                     self._device_federation_outbox_stream_cache.entity_has_changed(
                         row.entity, token
                     )
-        return super(SlavedDeviceInboxStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 58fb0eaae3..9d8067342f 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
             self._invalidate_caches_for_devices(token, rows)
@@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedDeviceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_caches_for_devices(self, token, rows):
         for row in rows:
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 15011259df..b313720a4b 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,7 +93,7 @@ class SlavedEventStore(
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "events":
             self._stream_id_gen.advance(token)
             for row in rows:
@@ -111,9 +111,7 @@ class SlavedEventStore(
                     row.relates_to,
                     backfilled=True,
                 )
-        return super(SlavedEventStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _process_event_stream_row(self, token, row):
         data = row.data
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 01bcf0e882..1851e7d525 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "groups":
             self._group_updates_id_gen.advance(token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
 
-        return super(SlavedGroupServerStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index fae3125072..bd79ba99be 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
-        return super(SlavedPresenceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6138796da4..5d5816d7eb 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "push_rules":
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedPushRuleStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 67be337945..cb78b49acb 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
     def get_pushers_stream_token(self):
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
-        return super(SlavedPusherStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 993432edcb..be716cc558 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "receipts":
             self._receipts_id_gen.advance(token)
             for row in rows:
@@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
                 )
                 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        return super(SlavedReceiptsStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 10dda8708f..8873bf37e5 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3bbf3c3569..20cb8a654f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -100,10 +100,10 @@ class ReplicationDataHandler:
             token: stream token for this batch of rows
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
-        self.store.process_replication_rows(stream_name, token, rows)
+        self.store.process_replication_rows(stream_name, instance_name, token, rows)
 
-    async def on_position(self, stream_name: str, token: int):
-        self.store.process_replication_rows(stream_name, token, [])
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        self.store.process_replication_rows(stream_name, instance_name, token, [])
 
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f58e384d17..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
         return " ".join((self.app_id, self.push_key, self.user_id))
 
 
-class InvalidateCacheCommand(Command):
-    """Sent by the client to invalidate an upstream cache.
-
-    THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
-    NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
-    Mainly used to invalidate destination retry timing caches.
-
-    Format::
-
-        INVALIDATE_CACHE <cache_func> <keys_json>
-
-    Where <keys_json> is a json list.
-    """
-
-    NAME = "INVALIDATE_CACHE"
-
-    def __init__(self, cache_func, keys):
-        self.cache_func = cache_func
-        self.keys = keys
-
-    @classmethod
-    def from_line(cls, line):
-        cache_func, keys_json = line.split(" ", 1)
-
-        return cls(cache_func, json.loads(keys_json))
-
-    def to_line(self):
-        return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
 class UserIpCommand(Command):
     """Sent periodically when a worker sees activity from a client.
 
@@ -439,7 +408,6 @@ _COMMANDS = (
     UserSyncCommand,
     FederationAckCommand,
     RemovePusherCommand,
-    InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
     ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
-    InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b14a3d9fca..7c5d6c76e7 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
 
 from prometheus_client import Counter
 
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
     FederationAckCommand,
-    InvalidateCacheCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -171,7 +159,7 @@ class ReplicationCommandHandler:
             return
 
         for stream_name, stream in self._streams.items():
-            current_token = stream.current_token()
+            current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(stream_name, self._instance_name, current_token)
             )
@@ -210,18 +198,6 @@ class ReplicationCommandHandler:
 
             self._notifier.on_new_replication_data()
 
-    async def on_INVALIDATE_CACHE(
-        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
-    ):
-        invalidate_cache_counter.inc()
-
-        if self._is_master:
-            # We invalidate the cache locally, but then also stream that to other
-            # workers.
-            await self._store.invalidate_cache_and_stream(
-                cmd.cache_func, tuple(cmd.keys)
-            )
-
     async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
         user_ip_cache_counter.inc()
 
@@ -295,7 +271,7 @@ class ReplicationCommandHandler:
             rows: a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
+        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
         await self._replication_data_handler.on_rdata(
             stream_name, instance_name, token, rows
         )
@@ -326,7 +302,7 @@ class ReplicationCommandHandler:
             self._pending_batches.pop(stream_name, [])
 
             # Find where we previously streamed up to.
-            current_token = stream.current_token()
+            current_token = stream.current_token(cmd.instance_name)
 
             # If the position token matches our current token then we're up to
             # date and there's nothing to do. Otherwise, fetch all updates
@@ -363,7 +339,9 @@ class ReplicationCommandHandler:
             logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
 
             # We've now caught up to position sent to us, notify handler.
-            await self._replication_data_handler.on_position(stream_name, cmd.token)
+            await self._replication_data_handler.on_position(
+                cmd.stream_name, cmd.instance_name, cmd.token
+            )
 
             self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
@@ -491,12 +469,6 @@ class ReplicationCommandHandler:
         cmd = RemovePusherCommand(app_id, push_key, user_id)
         self.send_command(cmd)
 
-    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
     def send_user_ip(
         self,
         user_id: str,
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589ac..002171ce7c 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    CachesStream,
+    FederationStream,
+    Stream,
+)
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -71,16 +76,25 @@ class ReplicationStreamer(object):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
 
         self._replication_torture_level = hs.config.replication_torture_level
 
         # Work out list of streams that this instance is the source of.
         self.streams = []  # type: List[Stream]
+
+        # All workers can write to the cache invalidation stream.
+        self.streams.append(CachesStream(hs))
+
         if hs.config.worker_app is None:
             for stream in STREAMS_MAP.values():
                 if stream == FederationStream and hs.config.send_federation:
                     # We only support federation stream if federation sending
-                    # hase been disabled on the master.
+                    # has been disabled on the master.
+                    continue
+
+                if stream == CachesStream:
+                    # We've already added it above.
                     continue
 
                 self.streams.append(stream(hs))
@@ -145,7 +159,9 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.current_token():
+                        if stream.last_token == stream.current_token(
+                            self._instance_name
+                        ):
                             continue
 
                         if self._replication_torture_level:
@@ -157,7 +173,7 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.current_token(),
+                            stream.current_token(self._instance_name),
                         )
                         try:
                             updates, current_token, limited = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c365b..b48a6a3e91 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,19 +95,25 @@ class Stream(object):
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[], Token],
+        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream
 
-        current_token_function and update_function are callbacks which should be
-        implemented by subclasses.
+        `current_token_function` and `update_function` are callbacks which
+        should be implemented by subclasses.
 
-        current_token_function is called to get the current token of the underlying
-        stream.
+        `current_token_function` takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process). On the writer process this is where
+        the writer has successfully written up to, whereas on other processes
+        this is the position which we have received updates up to over
+        replication. (Note that most streams have a single writer and so their
+        implementations ignore the instance name passed in).
 
-        update_function is called to get updates for this stream between a pair of
-        stream tokens. See the UpdateFunction type definition for more info.
+        `update_function` is called to get updates for this stream between a
+        pair of stream tokens. See the `UpdateFunction` type definition for more
+        info.
 
         Args:
             local_instance_name: The instance name of the current process
@@ -119,13 +125,13 @@ class Stream(object):
         self.update_function = update_function
 
         # The token from which we last asked for updates
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
@@ -137,7 +143,7 @@ class Stream(object):
             position in stream, and `limited` is whether there are more updates
             to fetch.
         """
-        current_token = self.current_token()
+        current_token = self.current_token(self.local_instance_name)
         updates, current_token, limited = await self.get_updates_since(
             self.local_instance_name, self.last_token, current_token
         )
@@ -169,6 +175,16 @@ class Stream(object):
         return updates, upto_token, limited
 
 
+def current_token_without_instance(
+    current_token: Callable[[], int]
+) -> Callable[[str], int]:
+    """Takes a current token callback function for a single writer stream
+    that doesn't take an instance name parameter and wraps it in a function that
+    does accept an instance name parameter but ignores it.
+    """
+    return lambda instance_name: current_token()
+
+
 def db_query_to_update_function(
     query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
@@ -234,7 +250,7 @@ class BackfillStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_backfill_token,
+            current_token_without_instance(store.get_current_backfill_token),
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )
 
@@ -270,7 +286,9 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), store.get_current_presence_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_presence_token),
+            update_function,
         )
 
 
@@ -295,7 +313,9 @@ class TypingStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), typing_handler.get_current_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(typing_handler.get_current_token),
+            update_function,
         )
 
 
@@ -318,7 +338,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_receipt_stream_id,
+            current_token_without_instance(store.get_max_receipt_stream_id),
             db_query_to_update_function(store.get_all_updated_receipts),
         )
 
@@ -338,7 +358,7 @@ class PushRulesStream(Stream):
             hs.get_instance_name(), self._current_token, self._update_function
         )
 
-    def _current_token(self) -> int:
+    def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
@@ -372,7 +392,7 @@ class PushersStream(Stream):
 
         super().__init__(
             hs.get_instance_name(),
-            store.get_pushers_stream_token,
+            current_token_without_instance(store.get_pushers_stream_token),
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
 
@@ -401,12 +421,26 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
-            db_query_to_update_function(store.get_all_updated_caches),
+            self.store.get_cache_stream_token,
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, upto_token: int, limit: int
+    ):
+        rows = await self.store.get_all_updated_caches(
+            instance_name, from_token, upto_token, limit
         )
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
@@ -430,7 +464,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_public_room_stream_id,
+            current_token_without_instance(store.get_current_public_room_stream_id),
             db_query_to_update_function(store.get_all_new_public_rooms),
         )
 
@@ -451,7 +485,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )
 
@@ -469,7 +503,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_to_device_stream_token,
+            current_token_without_instance(store.get_to_device_stream_token),
             db_query_to_update_function(store.get_all_new_device_messages),
         )
 
@@ -489,7 +523,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_account_data_stream_id,
+            current_token_without_instance(store.get_max_account_data_stream_id),
             db_query_to_update_function(store.get_all_updated_tags),
         )
 
@@ -509,7 +543,7 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self.store.get_max_account_data_stream_id,
+            current_token_without_instance(self.store.get_max_account_data_stream_id),
             db_query_to_update_function(self._update_function),
         )
 
@@ -540,7 +574,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_group_stream_token,
+            current_token_without_instance(store.get_group_stream_token),
             db_query_to_update_function(store.get_all_groups_changes),
         )
 
@@ -558,7 +592,7 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
             ),
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 890e75d827..f370390331 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
 
 
 """Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self._store.get_current_events_token,
+            current_token_without_instance(self._store.get_current_events_token),
             self._update_function,
         )
 
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e389..9bcd13b009 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,11 @@
 # limitations under the License.
 from collections import namedtuple
 
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    current_token_without_instance,
+    make_http_update_function,
+)
 
 
 class FederationStream(Stream):
@@ -35,21 +39,35 @@ class FederationStream(Stream):
     ROW_TYPE = FederationStreamRow
 
     def __init__(self, hs):
-        # Not all synapse instances will have a federation sender instance,
-        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
-        # so we stub the stream out when that is the case.
-        if hs.config.worker_app is None or hs.should_send_federation():
+        if hs.config.worker_app is None:
+            # master process: get updates from the FederationRemoteSendQueue.
+            # (if the master is configured to send federation itself, federation_sender
+            # will be a real FederationSender, which has stubs for current_token and
+            # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = federation_sender.get_current_token
-            update_function = db_query_to_update_function(
-                federation_sender.get_replication_rows
+            current_token = current_token_without_instance(
+                federation_sender.get_current_token
             )
+            update_function = federation_sender.get_replication_rows
+
+        elif hs.should_send_federation():
+            # federation sender: Query master process
+            update_function = make_http_update_function(hs, self.NAME)
+            current_token = self._stub_current_token
+
         else:
-            current_token = lambda: 0
+            # other worker: stub out the update function (we're not interested in
+            # any updates so when we get a POSITION we do nothing)
             update_function = self._stub_update_function
+            current_token = self._stub_current_token
 
         super().__init__(hs.get_instance_name(), current_token, update_function)
 
     @staticmethod
+    def _stub_current_token(instance_name: str) -> int:
+        # dummy current-token method for use on workers
+        return 0
+
+    @staticmethod
     async def _stub_update_function(instance_name, from_token, upto_token, limit):
         return [], upto_token, False
diff --git a/synapse/res/templates/notice_expiry.html b/synapse/res/templates/notice_expiry.html
index f0d7c66e1b..6b94d8c367 100644
--- a/synapse/res/templates/notice_expiry.html
+++ b/synapse/res/templates/notice_expiry.html
@@ -30,7 +30,7 @@
                         <tr>
                           <td colspan="2">
                             <div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.</div>
-                            <div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div>
+                            <div class="noticetext">To extend the validity of your account, please click on the link below (or copy and paste it into a new browser tab):</div>
                             <div class="noticetext"><a href="{{ url }}">{{ url }}</a></div>
                           </td>
                         </tr>
diff --git a/synapse/res/templates/notice_expiry.txt b/synapse/res/templates/notice_expiry.txt
index 41f1c4279c..4ec27e8831 100644
--- a/synapse/res/templates/notice_expiry.txt
+++ b/synapse/res/templates/notice_expiry.txt
@@ -2,6 +2,6 @@ Hi {{ display_name }},
 
 Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
 
-To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab):
+To extend the validity of your account, please click on the link below (or copy and paste it to a new browser tab):
 
 {{ url }}
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 13de5f1f62..59073c0a42 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -47,6 +47,9 @@ class SQLBaseStore(metaclass=ABCMeta):
         self.db = database
         self.rand = random.SystemRandom()
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        pass
+
     def _invalidate_state_caches(self, room_id, members_changed):
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index ceba10882c..cd2a1f0461 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -26,13 +26,14 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     ChainedIdGenerator,
     IdGenerator,
+    MultiWriterIdGenerator,
     StreamIdGenerator,
 )
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .cache import CacheInvalidationStore
+from .cache import CacheInvalidationWorkerStore
 from .client_ips import ClientIpStore
 from .deviceinbox import DeviceInboxStore
 from .devices import DeviceStore
@@ -112,8 +113,8 @@ class DataStore(
     MonthlyActiveUsersStore,
     StatsStore,
     RelationsStore,
-    CacheInvalidationStore,
     UIAuthStore,
+    CacheInvalidationWorkerStore,
 ):
     def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
@@ -170,8 +171,14 @@ class DataStore(
         )
 
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = StreamIdGenerator(
-                db_conn, "cache_invalidation_stream", "stream_id"
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name="master",
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
             )
         else:
             self._cache_id_gen = None
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 4dc5da3fe8..342a87a46b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,11 +16,10 @@
 
 import itertools
 import logging
-from typing import Any, Iterable, Optional, Tuple
-
-from twisted.internet import defer
+from typing import Any, Iterable, Optional
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -33,47 +32,58 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
 class CacheInvalidationWorkerStore(SQLBaseStore):
-    def get_all_updated_caches(self, last_id, current_id, limit):
+    def __init__(self, database: Database, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._instance_name = hs.get_instance_name()
+
+    async def get_all_updated_caches(
+        self, instance_name: str, last_id: int, current_id: int, limit: int
+    ):
+        """Fetches cache invalidation rows between the two given IDs written
+        by the given instance. Returns at most `limit` rows.
+        """
+
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def get_all_updated_caches_txn(txn):
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
+            sql = """
+                SELECT stream_id, cache_func, keys, invalidation_ts
+                FROM cache_invalidation_stream_by_instance
+                WHERE stream_id > ? AND instance_name = ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+            txn.execute(sql, (last_id, instance_name, limit))
             return txn.fetchall()
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == "caches":
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
 
-class CacheInvalidationStore(CacheInvalidationWorkerStore):
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
-        """Invalidates the cache and adds it to the cache stream so slaves
-        will know to invalidate their caches.
+            for row in rows:
+                if row.cache_func == CURRENT_STATE_CACHE_NAME:
+                    if row.keys is None:
+                        raise Exception(
+                            "Can't send an 'invalidate all' for current state cache"
+                        )
 
-        This should only be used to invalidate caches where slaves won't
-        otherwise know from other replication streams that the cache should
-        be invalidated.
-        """
-        cache_func = getattr(self, cache_name, None)
-        if not cache_func:
-            return
-
-        cache_func.invalidate(keys)
-        await self.runInteraction(
-            "invalidate_cache_and_stream",
-            self._send_invalidation_to_replication,
-            cache_func.__name__,
-            keys,
-        )
+                    room_id = row.keys[0]
+                    members_changed = set(row.keys[1:])
+                    self._invalidate_state_caches(room_id, members_changed)
+                else:
+                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         """Invalidates the cache and adds it to the cache stream so slaves
@@ -147,10 +157,7 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
             # the transaction. However, we want to only get an ID when we want
             # to use it, here, so we need to call __enter__ manually, and have
             # __exit__ called after the transaction finishes.
-            ctx = self._cache_id_gen.get_next()
-            stream_id = ctx.__enter__()
-            txn.call_on_exception(ctx.__exit__, None, None, None)
-            txn.call_after(ctx.__exit__, None, None, None)
+            stream_id = self._cache_id_gen.get_next_txn(txn)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
             if keys is not None:
@@ -158,17 +165,18 @@ class CacheInvalidationStore(CacheInvalidationWorkerStore):
 
             self.db.simple_insert_txn(
                 txn,
-                table="cache_invalidation_stream",
+                table="cache_invalidation_stream_by_instance",
                 values={
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
                     "invalidation_ts": self.clock.time_msec(),
                 },
             )
 
-    def get_cache_stream_token(self):
+    def get_cache_stream_token(self, instance_name):
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
+            return self._cache_id_gen.get_current_token(instance_name)
         else:
             return 0
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 03f5141e6c..fe6d6ecfe0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
+BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = (
+    "drop_device_lists_outbound_last_success_non_unique_idx"
+)
+
 
 class DeviceWorkerStore(SQLBaseStore):
     def get_device(self, user_id, device_id):
@@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore):
 
     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
         # We update the device_lists_outbound_last_success with the successfully
-        # poked users. We do the join to see which users need to be inserted and
-        # which updated.
+        # poked users.
         sql = """
-            SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+            SELECT user_id, coalesce(max(o.stream_id), 0)
             FROM device_lists_outbound_pokes as o
-            LEFT JOIN device_lists_outbound_last_success as s
-                USING (destination, user_id)
             WHERE destination = ? AND o.stream_id <= ?
             GROUP BY user_id
         """
         txn.execute(sql, (destination, stream_id))
         rows = txn.fetchall()
 
-        sql = """
-            UPDATE device_lists_outbound_last_success
-            SET stream_id = ?
-            WHERE destination = ? AND user_id = ?
-        """
-        txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
-        sql = """
-            INSERT INTO device_lists_outbound_last_success
-            (destination, user_id, stream_id) VALUES (?, ?, ?)
-        """
-        txn.executemany(
-            sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+        self.db.simple_upsert_many_txn(
+            txn=txn,
+            table="device_lists_outbound_last_success",
+            key_names=("destination", "user_id"),
+            key_values=((destination, user_id) for user_id, _ in rows),
+            value_names=("stream_id",),
+            value_values=((stream_id,) for _, stream_id in rows),
         )
 
         # Delete all sent outbound pokes
@@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
         )
 
+        # create a unique index on device_lists_outbound_last_success
+        self.db.updates.register_background_index_update(
+            "device_lists_outbound_last_success_unique_idx",
+            index_name="device_lists_outbound_last_success_unique_idx",
+            table="device_lists_outbound_last_success",
+            columns=["destination", "user_id"],
+            unique=True,
+        )
+
+        # once that completes, we can remove the old non-unique index.
+        self.db.updates.register_background_update_handler(
+            BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX,
+            self._drop_device_lists_outbound_last_success_non_unique_idx,
+        )
+
     @defer.inlineCallbacks
     def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
         def f(conn):
@@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         return rows
 
+    async def _drop_device_lists_outbound_last_success_non_unique_idx(
+        self, progress, batch_size
+    ):
+        def f(txn):
+            txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx")
+
+        await self.db.runInteraction(
+            "drop_device_lists_outbound_last_success_non_unique_idx", f,
+        )
+        await self.db.updates._end_background_update(
+            BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX
+        )
+        return 1
+
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index bcf746b7ef..20698bfd16 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -25,7 +25,9 @@ from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
-    def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
-        """Returns a user's cross-signing key.
-
-        Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            user_id (str): the user whose key is being requested
-            key_type (str): the type of key that is being requested: either 'master'
-                for a master key, 'self_signing' for a self-signing key, or
-                'user_signing' for a user-signing key
-            from_user_id (str): if specified, signatures made by this user on
-                the key will be included in the result
-
-        Returns:
-            dict of the key data or None if not found
-        """
-        sql = (
-            "SELECT keydata "
-            "  FROM e2e_cross_signing_keys "
-            " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
-        )
-        txn.execute(sql, (user_id, key_type))
-        row = txn.fetchone()
-        if not row:
-            return None
-        key = json.loads(row[0])
-
-        device_id = None
-        for k in key["keys"].values():
-            device_id = k
-
-        if from_user_id is not None:
-            sql = (
-                "SELECT key_id, signature "
-                "  FROM e2e_cross_signing_signatures "
-                " WHERE user_id = ? "
-                "   AND target_user_id = ? "
-                "   AND target_device_id = ? "
-            )
-            txn.execute(sql, (from_user_id, user_id, device_id))
-            row = txn.fetchone()
-            if row:
-                key.setdefault("signatures", {}).setdefault(from_user_id, {})[
-                    row[0]
-                ] = row[1]
-
-        return key
-
+    @defer.inlineCallbacks
     def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
         """Returns a user's cross-signing key.
 
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         Returns:
             dict of the key data or None if not found
         """
-        return self.db.runInteraction(
-            "get_e2e_cross_signing_key",
-            self._get_e2e_cross_signing_key_txn,
-            user_id,
-            key_type,
-            from_user_id,
-        )
+        res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+        user_keys = res.get(user_id)
+        if not user_keys:
+            return None
+        return user_keys.get(key_type)
 
     @cached(num_args=1)
     def _get_bare_e2e_cross_signing_keys(self, user_id):
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         """
         result = {}
 
-        batch_size = 100
-        chunks = [
-            user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
-        ]
-        for user_chunk in chunks:
-            sql = """
+        for user_chunk in batch_iter(user_ids, 100):
+            clause, params = make_in_list_sql_clause(
+                txn.database_engine, "k.user_id", user_chunk
+            )
+            sql = (
+                """
                 SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                   FROM e2e_cross_signing_keys k
                   INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                                 FROM e2e_cross_signing_keys
                                GROUP BY user_id, keytype) s
                  USING (user_id, stream_id, keytype)
-                 WHERE k.user_id IN (%s)
-            """ % (
-                ",".join("?" for u in user_chunk),
+                 WHERE
+            """
+                + clause
             )
-            query_params = []
-            query_params.extend(user_chunk)
 
-            txn.execute(sql, query_params)
+            txn.execute(sql, params)
             rows = self.db.cursor_to_dict(txn)
 
             for row in rows:
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                     device_id = k
                 devices[(user_id, device_id)] = key_type
 
-        device_list = list(devices)
-
-        # split into batches
-        batch_size = 100
-        chunks = [
-            device_list[i : i + batch_size]
-            for i in range(0, len(device_list), batch_size)
-        ]
-        for user_chunk in chunks:
+        for batch in batch_iter(devices.keys(), size=100):
             sql = """
                 SELECT target_user_id, target_device_id, key_id, signature
                   FROM e2e_cross_signing_signatures
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                    AND (%s)
             """ % (
                 " OR ".join(
-                    "(target_user_id = ? AND target_device_id = ?)" for d in devices
+                    "(target_user_id = ? AND target_device_id = ?)" for _ in batch
                 )
             )
             query_params = [from_user_id]
-            for item in devices:
+            for item in batch:
                 # item is a (user_id, device_id) tuple
                 query_params.extend(item)
 
diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
new file mode 100644
index 0000000000..d5e6deb878
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will create a unique index on
+-- device_lists_outbound_last_success
+INSERT into background_updates (ordering, update_name, progress_json)
+    VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}');
+
+-- once that completes, we can drop the old index.
+INSERT into background_updates (ordering, update_name, progress_json, depends_on)
+    VALUES (
+        5804,
+        'drop_device_lists_outbound_last_success_non_unique_idx',
+        '{}',
+        'device_lists_outbound_last_success_unique_idx'
+    );
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
new file mode 100644
index 0000000000..aa46eb0e10
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We keep the old table here to enable us to roll back. It doesn't matter
+-- that we have dropped all the data here.
+TRUNCATE cache_invalidation_stream;
+
+CREATE TABLE cache_invalidation_stream_by_instance (
+    stream_id       BIGINT NOT NULL,
+    instance_name   TEXT NOT NULL,
+    cache_func      TEXT NOT NULL,
+    keys            TEXT[],
+    invalidation_ts BIGINT
+);
+
+CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
+
+CREATE SEQUENCE cache_invalidation_stream_seq;
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a7cd97b0b0..2b635d6ca0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
 from synapse.util.stringutils import exception_to_unicode
 
 logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
     "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
     "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
     "event_search": "event_search_event_id_idx",
+    "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx",
 }
 
 
@@ -889,20 +891,24 @@ class Database(object):
         txn.execute(sql, list(allvalues.values()))
 
     def simple_upsert_many_txn(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
         """
         Upsert, many times.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +920,24 @@ class Database(object):
             )
 
     def simple_upsert_many_txn_emulated(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Iterable[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
         """
         Upsert, many times, but without native UPSERT support or batching.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         # No value columns, therefore make a blank list so that the following
         # zip() works correctly.
@@ -941,20 +951,24 @@ class Database(object):
             self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
 
     def simple_upsert_many_txn_native_upsert(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+    ) -> None:
         """
         Upsert, many times, using batching where possible.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         allnames = []  # type: List[str]
         allnames.extend(key_names)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1712932f31..640f242584 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,6 +29,8 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
 SCHEMA_VERSION = 58
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..86d04ea9ac 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
 import contextlib
 import threading
 from collections import deque
+from typing import Dict, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
 
 
 class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[int]
 
     def get_next(self):
         """
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
         self.chained_generator = chained_generator
         self._lock = threading.Lock()
         self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
 
     def get_next(self):
         """
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
                 return stream_id - 1, chained_id
 
             return self._current_max, self.chained_generator.get_current_token()
+
+
+class MultiWriterIdGenerator:
+    """An ID generator that tracks a stream that can have multiple writers.
+
+    Uses a Postgres sequence to coordinate ID assignment, but positions of other
+    writers will only get updated when `advance` is called (by replication).
+
+    Note: Only works with Postgres.
+
+    Args:
+        db_conn
+        db
+        instance_name: The name of this instance.
+        table: Database table associated with stream.
+        instance_column: Column that stores the row's writer's instance name
+        id_column: Column that stores the stream ID.
+        sequence_name: The name of the postgres sequence used to generate new
+            IDs.
+    """
+
+    def __init__(
+        self,
+        db_conn,
+        db: Database,
+        instance_name: str,
+        table: str,
+        instance_column: str,
+        id_column: str,
+        sequence_name: str,
+    ):
+        self._db = db
+        self._instance_name = instance_name
+        self._sequence_name = sequence_name
+
+        # We lock as some functions may be called from DB threads.
+        self._lock = threading.Lock()
+
+        self._current_positions = self._load_current_ids(
+            db_conn, table, instance_column, id_column
+        )
+
+        # Set of local IDs that we're still processing. The current position
+        # should be less than the minimum of this set (if not empty).
+        self._unfinished_ids = set()  # type: Set[int]
+
+    def _load_current_ids(
+        self, db_conn, table: str, instance_column: str, id_column: str
+    ) -> Dict[str, int]:
+        sql = """
+            SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+            GROUP BY %(instance)s
+        """ % {
+            "instance": instance_column,
+            "id": id_column,
+            "table": table,
+        }
+
+        cur = db_conn.cursor()
+        cur.execute(sql)
+
+        # `cur` is an iterable over returned rows, which are 2-tuples.
+        current_positions = dict(cur)
+
+        cur.close()
+
+        return current_positions
+
+    def _load_next_id_txn(self, txn):
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        (next_id,) = txn.fetchone()
+        return next_id
+
+    async def get_next(self):
+        """
+        Usage:
+            with await stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+        # Assert the fetched ID is actually greater than what we currently
+        # believe the ID to be. If not, then the sequence and table have got
+        # out of sync somehow.
+        assert self.get_current_token() < next_id
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield next_id
+            finally:
+                self._mark_id_as_finished(next_id)
+
+        return manager()
+
+    def get_next_txn(self, txn: LoggingTransaction):
+        """
+        Usage:
+
+            stream_id = stream_id_gen.get_next(txn)
+            # ... persist event ...
+        """
+
+        next_id = self._load_next_id_txn(txn)
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        txn.call_after(self._mark_id_as_finished, next_id)
+        txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+        return next_id
+
+    def _mark_id_as_finished(self, next_id: int):
+        """The ID has finished being processed so we should advance the
+        current poistion if possible.
+        """
+
+        with self._lock:
+            self._unfinished_ids.discard(next_id)
+
+            # Figure out if its safe to advance the position by checking there
+            # aren't any lower allocated IDs that are yet to finish.
+            if all(c > next_id for c in self._unfinished_ids):
+                curr = self._current_positions.get(self._instance_name, 0)
+                self._current_positions[self._instance_name] = max(curr, next_id)
+
+    def get_current_token(self, instance_name: str = None) -> int:
+        """Gets the current position of a named writer (defaults to current
+        instance).
+
+        Returns 0 if we don't have a position for the named writer (likely due
+        to it being a new writer).
+        """
+
+        if instance_name is None:
+            instance_name = self._instance_name
+
+        with self._lock:
+            return self._current_positions.get(instance_name, 0)
+
+    def get_positions(self) -> Dict[str, int]:
+        """Get a copy of the current positon map.
+        """
+
+        with self._lock:
+            return dict(self._current_positions)
+
+    def advance(self, instance_name: str, new_id: int):
+        """Advance the postion of the named writer to the given ID, if greater
+        than existing entry.
+        """
+
+        with self._lock:
+            self._current_positions[instance_name] = max(
+                new_id, self._current_positions.get(instance_name, 0)
+            )