summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7219.misc2
-rw-r--r--changelog.d/7281.misc1
-rw-r--r--changelog.d/7374.misc1
-rw-r--r--changelog.d/7382.misc1
-rw-r--r--changelog.d/7396.misc1
-rw-r--r--changelog.d/7429.misc1
-rw-r--r--synapse/federation/send_queue.py84
-rw-r--r--synapse/federation/sender/__init__.py12
-rw-r--r--synapse/federation/sender/per_destination_queue.py6
-rw-r--r--synapse/federation/sender/transaction_manager.py6
-rw-r--r--synapse/handlers/room.py42
-rw-r--r--synapse/replication/tcp/resource.py2
-rw-r--r--synapse/replication/tcp/streams/_base.py3
-rw-r--r--synapse/replication/tcp/streams/federation.py30
-rw-r--r--synapse/storage/data_stores/main/devices.py60
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql28
-rw-r--r--synapse/storage/database.py74
-rw-r--r--synapse/storage/util/id_generators.py169
-rw-r--r--tests/replication/tcp/streams/_base.py20
-rw-r--r--tests/replication/tcp/streams/test_federation.py81
-rw-r--r--tests/replication/tcp/streams/test_typing.py5
-rw-r--r--tests/storage/test_id_generators.py184
-rw-r--r--tox.ini7
23 files changed, 678 insertions, 142 deletions
diff --git a/changelog.d/7219.misc b/changelog.d/7219.misc
index 4af5da8646..dbf7a530be 100644
--- a/changelog.d/7219.misc
+++ b/changelog.d/7219.misc
@@ -1 +1 @@
-Add typing information to federation server code.
+Add typing annotations in `synapse.federation`.
diff --git a/changelog.d/7281.misc b/changelog.d/7281.misc
new file mode 100644
index 0000000000..86ad511e19
--- /dev/null
+++ b/changelog.d/7281.misc
@@ -0,0 +1 @@
+Add MultiWriterIdGenerator to support multiple concurrent writers of streams.
diff --git a/changelog.d/7374.misc b/changelog.d/7374.misc
new file mode 100644
index 0000000000..676f285377
--- /dev/null
+++ b/changelog.d/7374.misc
@@ -0,0 +1 @@
+Move catchup of replication streams logic to worker.
diff --git a/changelog.d/7382.misc b/changelog.d/7382.misc
new file mode 100644
index 0000000000..dbf7a530be
--- /dev/null
+++ b/changelog.d/7382.misc
@@ -0,0 +1 @@
+Add typing annotations in `synapse.federation`.
diff --git a/changelog.d/7396.misc b/changelog.d/7396.misc
new file mode 100644
index 0000000000..290d2befc7
--- /dev/null
+++ b/changelog.d/7396.misc
@@ -0,0 +1 @@
+Convert the room handler to async/await.
diff --git a/changelog.d/7429.misc b/changelog.d/7429.misc
new file mode 100644
index 0000000000..3c25cd9917
--- /dev/null
+++ b/changelog.d/7429.misc
@@ -0,0 +1 @@
+Improve performance of `mark_as_sent_devices_by_remote`.
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index e1700ca8aa..52f4f54215 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
 
 import logging
 from collections import namedtuple
+from typing import Dict, List, Tuple, Type
 
 from six import iteritems
 
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
         self.notifier = hs.get_notifier()
         self.is_mine_id = hs.is_mine_id
 
-        self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
-        self.presence_changed = SortedDict()  # Stream position -> list[user_id]
+        # Pending presence map user_id -> UserPresenceState
+        self.presence_map = {}  # type: Dict[str, UserPresenceState]
+
+        # Stream position -> list[user_id]
+        self.presence_changed = SortedDict()  # type: SortedDict[int, List[str]]
 
         # Stores the destinations we need to explicitly send presence to about a
         # given user.
         # Stream position -> (user_id, destinations)
-        self.presence_destinations = SortedDict()
+        self.presence_destinations = (
+            SortedDict()
+        )  # type: SortedDict[int, Tuple[str, List[str]]]
+
+        # (destination, key) -> EDU
+        self.keyed_edu = {}  # type: Dict[Tuple[str, tuple], Edu]
 
-        self.keyed_edu = {}  # (destination, key) -> EDU
-        self.keyed_edu_changed = SortedDict()  # stream position -> (destination, key)
+        # stream position -> (destination, key)
+        self.keyed_edu_changed = (
+            SortedDict()
+        )  # type: SortedDict[int, Tuple[str, tuple]]
 
-        self.edus = SortedDict()  # stream position -> Edu
+        self.edus = SortedDict()  # type: SortedDict[int, Edu]
 
+        # stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
         self.pos = 1
-        self.pos_time = SortedDict()
+
+        # map from stream ID to the time that stream entry was generated, so that we
+        # can clear out entries after a while
+        self.pos_time = SortedDict()  # type: SortedDict[int, int]
 
         # EVERYTHING IS SAD. In particular, python only makes new scopes when
         # we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
             for edu_key in self.keyed_edu_changed.values():
                 live_keys.add(edu_key)
 
-            to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
-            for edu_key in to_del:
+            keys_to_del = [
+                edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
+            ]
+            for edu_key in keys_to_del:
                 del self.keyed_edu[edu_key]
 
             # Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
         self._clear_queue_before_pos(token)
 
     async def get_replication_rows(
-        self, from_token, to_token, limit, federation_ack=None
-    ):
+        self, instance_name: str, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
         """Get rows to be sent over federation between the two tokens
 
         Args:
-            from_token (int)
-            to_token(int)
-            limit (int)
-            federation_ack (int): Optional. The position where the worker is
-                explicitly acknowledged it has handled. Allows us to drop
-                data from before that point
+            instance_name: the name of the current process
+            from_token: the previous stream token: the starting point for fetching the
+                updates
+            to_token: the new stream token: the point to get updates up to
+            target_row_count: a target for the number of rows to be returned.
+
+        Returns: a triplet `(updates, new_last_token, limited)`, where:
+           * `updates` is a list of `(token, row)` entries.
+           * `new_last_token` is the new position in stream.
+           * `limited` is whether there are more updates to fetch.
         """
-        # TODO: Handle limit.
+        # TODO: Handle target_row_count.
 
         # To handle restarts where we wrap around
         if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
 
         # list of tuple(int, BaseFederationRow), where the first is the position
         # of the federation stream.
-        rows = []
-
-        # There should be only one reader, so lets delete everything its
-        # acknowledged its seen.
-        if federation_ack:
-            self._clear_queue_before_pos(federation_ack)
+        rows = []  # type: List[Tuple[int, BaseFederationRow]]
 
         # Fetch changed presence
         i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
         # Sort rows based on pos
         rows.sort()
 
-        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+        return (
+            [(pos, (row.TypeId, row.to_data())) for pos, row in rows],
+            to_token,
+            False,
+        )
 
 
 class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
     Specifies how to identify, serialize and deserialize the different types.
     """
 
-    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+    TypeId = ""  # Unique string that ids the type. Must be overriden in sub classes.
 
     @staticmethod
     def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))):  # Edu
         buff.edus.setdefault(self.edu.destination, []).append(self.edu)
 
 
-TypeToRow = {
-    Row.TypeId: Row
-    for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
-}
+_rowtypes = (
+    PresenceRow,
+    PresenceDestinationsRow,
+    KeyedEduRow,
+    EduRow,
+)  # type: Tuple[Type[BaseFederationRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
 
 
 ParsedFederationStreamData = namedtuple(
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a477578e44..d473576902 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Hashable, Iterable, List, Optional, Set
+from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
 
 from six import itervalues
 
@@ -498,14 +498,16 @@ class FederationSender(object):
 
         self._get_per_destination_queue(destination).attempt_new_transaction()
 
-    def get_current_token(self) -> int:
+    @staticmethod
+    def get_current_token() -> int:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
         return 0
 
+    @staticmethod
     async def get_replication_rows(
-        self, from_token, to_token, limit, federation_ack=None
-    ):
+        instance_name: str, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
         # Dummy implementation for case where federation sender isn't offloaded
         # to a worker.
-        return []
+        return [], 0, False
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index e13cd20ffa..276a2b596f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
 # limitations under the License.
 import datetime
 import logging
-from typing import Dict, Hashable, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
 
 from prometheus_client import Counter
 
-import synapse.server
 from synapse.api.errors import (
     FederationDeniedError,
     HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
 from synapse.types import ReadReceipt
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
+if TYPE_CHECKING:
+    import synapse.server
+
 # This is defined in the Matrix spec and enforced by the receiver.
 MAX_EDUS_PER_TRANSACTION = 100
 
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 3c2a02a3b3..a2752a54a5 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,11 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List
+from typing import TYPE_CHECKING, List
 
 from canonicaljson import json
 
-import synapse.server
 from synapse.api.errors import HttpResponseException
 from synapse.events import EventBase
 from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
 )
 from synapse.util.metrics import measure_func
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index da12df7f53..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
 
 from six import iteritems, string_types
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
 from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
-    @defer.inlineCallbacks
-    def upgrade_room(
+    async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ):
         """Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
         Returns:
             Deferred[unicode]: the new room id
         """
-        yield self.ratelimit(requester)
+        await self.ratelimit(requester)
 
         user_id = requester.user.to_string()
 
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
         # If this user has sent multiple upgrade requests for the same room
         # and one of them is not complete yet, cache the response and
         # return it to all subsequent requests
-        ret = yield self._upgrade_response_cache.wrap(
+        ret = await self._upgrade_response_cache.wrap(
             (old_room_id, user_id),
             self._upgrade_room,
             requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
         for (etype, state_key), content in initial_state.items():
             await send(etype=etype, state_key=state_key, content=content)
 
-    @defer.inlineCallbacks
-    def _generate_room_id(
+    async def _generate_room_id(
         self, creator_id: str, is_public: str, room_version: RoomVersion,
     ):
         # autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
                 gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
                 if isinstance(gen_room_id, bytes):
                     gen_room_id = gen_room_id.decode("utf-8")
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,
                     is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    @defer.inlineCallbacks
-    def get_event_context(self, user, room_id, event_id, limit, event_filter):
+    async def get_event_context(self, user, room_id, event_id, limit, event_filter):
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = yield self.store.get_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         is_peeking = user.to_string() not in users
 
         def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
                 self.storage, user.to_string(), events, is_peeking=is_peeking
             )
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, get_prev_content=True, allow_none=True
         )
         if not event:
             return None
 
-        filtered = yield (filter_evts([event]))
+        filtered = await filter_evts([event])
         if not filtered:
             raise AuthError(403, "You don't have permission to access that event.")
 
-        results = yield self.store.get_events_around(
+        results = await self.store.get_events_around(
             room_id, event_id, before_limit, after_limit, event_filter
         )
 
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
             results["events_before"] = event_filter.filter(results["events_before"])
             results["events_after"] = event_filter.filter(results["events_after"])
 
-        results["events_before"] = yield filter_evts(results["events_before"])
-        results["events_after"] = yield filter_evts(results["events_after"])
+        results["events_before"] = await filter_evts(results["events_before"])
+        results["events_after"] = await filter_evts(results["events_after"])
         # filter_evts can return a pruned event in case the user is allowed to see that
         # there's something there but not see the content, so use the event that's in
         # `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687
 
-        state = yield self.state_store.get_state_for_events(
+        state = await self.state_store.get_state_for_events(
             [last_event_id], state_filter=state_filter
         )
 
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
         if event_filter:
             state_events = event_filter.filter(state_events)
 
-        results["state"] = yield filter_evts(state_events)
+        results["state"] = await filter_evts(state_events)
 
         # We use a dummy token here as we only care about the room portion of
         # the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def get_new_events(
+    async def get_new_events(
         self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
     ):
         # We just ignore the key for now.
 
-        to_key = yield self.get_current_key()
+        to_key = await self.get_current_key()
 
         from_token = RoomStreamToken.parse(from_key)
         if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
             # See https://github.com/matrix-org/matrix-doc/issues/1144
             raise NotImplementedError()
         else:
-            room_events = yield self.store.get_membership_changes_for_user(
+            room_events = await self.store.get_membership_changes_for_user(
                 user.to_string(), from_key, to_key
             )
 
-            room_to_events = yield self.store.get_room_events_stream_for_rooms(
+            room_to_events = await self.store.get_room_events_stream_for_rooms(
                 room_ids=room_ids,
                 from_key=from_key,
                 to_key=to_key,
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 33d2f589ac..b690abedad 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -80,7 +80,7 @@ class ReplicationStreamer(object):
             for stream in STREAMS_MAP.values():
                 if stream == FederationStream and hs.config.send_federation:
                     # We only support federation stream if federation sending
-                    # hase been disabled on the master.
+                    # has been disabled on the master.
                     continue
 
                 self.streams.append(stream(hs))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b0f87c365b..084604e8b0 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -104,7 +104,8 @@ class Stream(object):
         implemented by subclasses.
 
         current_token_function is called to get the current token of the underlying
-        stream.
+        stream. It is only meaningful on the process that is the source of the
+        replication stream (ie, usually the master).
 
         update_function is called to get updates for this stream between a pair of
         stream tokens. See the UpdateFunction type definition for more info.
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index e8bd52e389..b0505b8a2c 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 from collections import namedtuple
 
-from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
+from synapse.replication.tcp.streams._base import Stream, make_http_update_function
 
 
 class FederationStream(Stream):
@@ -35,21 +35,33 @@ class FederationStream(Stream):
     ROW_TYPE = FederationStreamRow
 
     def __init__(self, hs):
-        # Not all synapse instances will have a federation sender instance,
-        # whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
-        # so we stub the stream out when that is the case.
-        if hs.config.worker_app is None or hs.should_send_federation():
+        if hs.config.worker_app is None:
+            # master process: get updates from the FederationRemoteSendQueue.
+            # (if the master is configured to send federation itself, federation_sender
+            # will be a real FederationSender, which has stubs for current_token and
+            # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
             current_token = federation_sender.get_current_token
-            update_function = db_query_to_update_function(
-                federation_sender.get_replication_rows
-            )
+            update_function = federation_sender.get_replication_rows
+
+        elif hs.should_send_federation():
+            # federation sender: Query master process
+            update_function = make_http_update_function(hs, self.NAME)
+            current_token = self._stub_current_token
+
         else:
-            current_token = lambda: 0
+            # other worker: stub out the update function (we're not interested in
+            # any updates so when we get a POSITION we do nothing)
             update_function = self._stub_update_function
+            current_token = self._stub_current_token
 
         super().__init__(hs.get_instance_name(), current_token, update_function)
 
     @staticmethod
+    def _stub_current_token():
+        # dummy current-token method for use on workers
+        return 0
+
+    @staticmethod
     async def _stub_update_function(instance_name, from_token, upto_token, limit):
         return [], upto_token, False
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index ee3a2ab031..536cef3abd 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
+BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = (
+    "drop_device_lists_outbound_last_success_non_unique_idx"
+)
+
 
 class DeviceWorkerStore(SQLBaseStore):
     def get_device(self, user_id, device_id):
@@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore):
 
     def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
         # We update the device_lists_outbound_last_success with the successfully
-        # poked users. We do the join to see which users need to be inserted and
-        # which updated.
+        # poked users.
         sql = """
-            SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+            SELECT user_id, coalesce(max(o.stream_id), 0)
             FROM device_lists_outbound_pokes as o
-            LEFT JOIN device_lists_outbound_last_success as s
-                USING (destination, user_id)
             WHERE destination = ? AND o.stream_id <= ?
             GROUP BY user_id
         """
         txn.execute(sql, (destination, stream_id))
         rows = txn.fetchall()
 
-        sql = """
-            UPDATE device_lists_outbound_last_success
-            SET stream_id = ?
-            WHERE destination = ? AND user_id = ?
-        """
-        txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
-        sql = """
-            INSERT INTO device_lists_outbound_last_success
-            (destination, user_id, stream_id) VALUES (?, ?, ?)
-        """
-        txn.executemany(
-            sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+        self.db.simple_upsert_many_txn(
+            txn=txn,
+            table="device_lists_outbound_last_success",
+            key_names=("destination", "user_id"),
+            key_values=((destination, user_id) for user_id, _ in rows),
+            value_names=("stream_id",),
+            value_values=((stream_id,) for _, stream_id in rows),
         )
 
         # Delete all sent outbound pokes
@@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
         )
 
+        # create a unique index on device_lists_outbound_last_success
+        self.db.updates.register_background_index_update(
+            "device_lists_outbound_last_success_unique_idx",
+            index_name="device_lists_outbound_last_success_unique_idx",
+            table="device_lists_outbound_last_success",
+            columns=["destination", "user_id"],
+            unique=True,
+        )
+
+        # once that completes, we can remove the old non-unique index.
+        self.db.updates.register_background_update_handler(
+            BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX,
+            self._drop_device_lists_outbound_last_success_non_unique_idx,
+        )
+
     @defer.inlineCallbacks
     def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
         def f(conn):
@@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
         return rows
 
+    async def _drop_device_lists_outbound_last_success_non_unique_idx(
+        self, progress, batch_size
+    ):
+        def f(txn):
+            txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx")
+
+        await self.db.runInteraction(
+            "drop_device_lists_outbound_last_success_non_unique_idx", f,
+        )
+        await self.db.updates._end_background_update(
+            BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX
+        )
+        return 1
+
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def __init__(self, database: Database, db_conn, hs):
diff --git a/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
new file mode 100644
index 0000000000..d5e6deb878
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/04device_lists_outbound_last_success_unique_idx.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- register a background update which will create a unique index on
+-- device_lists_outbound_last_success
+INSERT into background_updates (ordering, update_name, progress_json)
+    VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}');
+
+-- once that completes, we can drop the old index.
+INSERT into background_updates (ordering, update_name, progress_json, depends_on)
+    VALUES (
+        5804,
+        'drop_device_lists_outbound_last_success_non_unique_idx',
+        '{}',
+        'device_lists_outbound_last_success_unique_idx'
+    );
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a7cd97b0b0..2b635d6ca0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
 from synapse.util.stringutils import exception_to_unicode
 
 logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
     "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
     "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
     "event_search": "event_search_event_id_idx",
+    "device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx",
 }
 
 
@@ -889,20 +891,24 @@ class Database(object):
         txn.execute(sql, list(allvalues.values()))
 
     def simple_upsert_many_txn(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
         """
         Upsert, many times.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
             return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +920,24 @@ class Database(object):
             )
 
     def simple_upsert_many_txn_emulated(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Iterable[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[str]],
+    ) -> None:
         """
         Upsert, many times, but without native UPSERT support or batching.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         # No value columns, therefore make a blank list so that the following
         # zip() works correctly.
@@ -941,20 +951,24 @@ class Database(object):
             self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
 
     def simple_upsert_many_txn_native_upsert(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        key_names: Collection[str],
+        key_values: Collection[Iterable[Any]],
+        value_names: Collection[str],
+        value_values: Iterable[Iterable[Any]],
+    ) -> None:
         """
         Upsert, many times, using batching where possible.
 
         Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
+            table: The table to upsert into
+            key_names: The key column names.
+            key_values: A list of each row's key column values.
+            value_names: The value column names
+            value_values: A list of each row's value column values.
+                Ignored if value_names is empty.
         """
         allnames = []  # type: List[str]
         allnames.extend(key_names)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..86d04ea9ac 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
 import contextlib
 import threading
 from collections import deque
+from typing import Dict, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
 
 
 class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
             )
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[int]
 
     def get_next(self):
         """
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
         self.chained_generator = chained_generator
         self._lock = threading.Lock()
         self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()
+        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
 
     def get_next(self):
         """
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
                 return stream_id - 1, chained_id
 
             return self._current_max, self.chained_generator.get_current_token()
+
+
+class MultiWriterIdGenerator:
+    """An ID generator that tracks a stream that can have multiple writers.
+
+    Uses a Postgres sequence to coordinate ID assignment, but positions of other
+    writers will only get updated when `advance` is called (by replication).
+
+    Note: Only works with Postgres.
+
+    Args:
+        db_conn
+        db
+        instance_name: The name of this instance.
+        table: Database table associated with stream.
+        instance_column: Column that stores the row's writer's instance name
+        id_column: Column that stores the stream ID.
+        sequence_name: The name of the postgres sequence used to generate new
+            IDs.
+    """
+
+    def __init__(
+        self,
+        db_conn,
+        db: Database,
+        instance_name: str,
+        table: str,
+        instance_column: str,
+        id_column: str,
+        sequence_name: str,
+    ):
+        self._db = db
+        self._instance_name = instance_name
+        self._sequence_name = sequence_name
+
+        # We lock as some functions may be called from DB threads.
+        self._lock = threading.Lock()
+
+        self._current_positions = self._load_current_ids(
+            db_conn, table, instance_column, id_column
+        )
+
+        # Set of local IDs that we're still processing. The current position
+        # should be less than the minimum of this set (if not empty).
+        self._unfinished_ids = set()  # type: Set[int]
+
+    def _load_current_ids(
+        self, db_conn, table: str, instance_column: str, id_column: str
+    ) -> Dict[str, int]:
+        sql = """
+            SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+            GROUP BY %(instance)s
+        """ % {
+            "instance": instance_column,
+            "id": id_column,
+            "table": table,
+        }
+
+        cur = db_conn.cursor()
+        cur.execute(sql)
+
+        # `cur` is an iterable over returned rows, which are 2-tuples.
+        current_positions = dict(cur)
+
+        cur.close()
+
+        return current_positions
+
+    def _load_next_id_txn(self, txn):
+        txn.execute("SELECT nextval(?)", (self._sequence_name,))
+        (next_id,) = txn.fetchone()
+        return next_id
+
+    async def get_next(self):
+        """
+        Usage:
+            with await stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+        # Assert the fetched ID is actually greater than what we currently
+        # believe the ID to be. If not, then the sequence and table have got
+        # out of sync somehow.
+        assert self.get_current_token() < next_id
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield next_id
+            finally:
+                self._mark_id_as_finished(next_id)
+
+        return manager()
+
+    def get_next_txn(self, txn: LoggingTransaction):
+        """
+        Usage:
+
+            stream_id = stream_id_gen.get_next(txn)
+            # ... persist event ...
+        """
+
+        next_id = self._load_next_id_txn(txn)
+
+        with self._lock:
+            self._unfinished_ids.add(next_id)
+
+        txn.call_after(self._mark_id_as_finished, next_id)
+        txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+        return next_id
+
+    def _mark_id_as_finished(self, next_id: int):
+        """The ID has finished being processed so we should advance the
+        current poistion if possible.
+        """
+
+        with self._lock:
+            self._unfinished_ids.discard(next_id)
+
+            # Figure out if its safe to advance the position by checking there
+            # aren't any lower allocated IDs that are yet to finish.
+            if all(c > next_id for c in self._unfinished_ids):
+                curr = self._current_positions.get(self._instance_name, 0)
+                self._current_positions[self._instance_name] = max(curr, next_id)
+
+    def get_current_token(self, instance_name: str = None) -> int:
+        """Gets the current position of a named writer (defaults to current
+        instance).
+
+        Returns 0 if we don't have a position for the named writer (likely due
+        to it being a new writer).
+        """
+
+        if instance_name is None:
+            instance_name = self._instance_name
+
+        with self._lock:
+            return self._current_positions.get(instance_name, 0)
+
+    def get_positions(self) -> Dict[str, int]:
+        """Get a copy of the current positon map.
+        """
+
+        with self._lock:
+            return dict(self._current_positions)
+
+    def advance(self, instance_name: str, new_id: int):
+        """Advance the postion of the named writer to the given ID, if greater
+        than existing entry.
+        """
+
+        with self._lock:
+            self._current_positions[instance_name] = max(
+                new_id, self._current_positions.get(instance_name, 0)
+            )
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 7b56d2028d..9d4f0bbe44 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -27,6 +27,7 @@ from synapse.app.generic_worker import (
     GenericWorkerServer,
 )
 from synapse.http.site import SynapseRequest
+from synapse.replication.http import streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -42,6 +43,10 @@ logger = logging.getLogger(__name__)
 class BaseStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests of the replication streams"""
 
+    servlets = [
+        streams.register_servlets,
+    ]
+
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.server = server_factory.buildProtocol(None)
 
         # Make a new HomeServer object for the worker
-        config = self.default_config()
-        config["worker_app"] = "synapse.app.generic_worker"
-        config["worker_replication_host"] = "testserv"
-        config["worker_replication_http_port"] = "8765"
-
         self.reactor.lookups["testserv"] = "1.2.3.4"
-
         self.worker_hs = self.setup_test_homeserver(
             http_client=None,
             homeserverToUse=GenericWorkerServer,
-            config=config,
+            config=self._get_worker_hs_config(),
             reactor=self.reactor,
         )
 
@@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None
 
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.generic_worker"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+        return config
+
     def _build_replication_data_handler(self):
         return TestReplicationDataHandler(self.worker_hs)
 
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
new file mode 100644
index 0000000000..eea4565da3
--- /dev/null
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.federation.send_queue import EduRow
+from synapse.replication.tcp.streams.federation import FederationStream
+
+from tests.replication.tcp.streams._base import BaseStreamTestCase
+
+
+class FederationStreamTestCase(BaseStreamTestCase):
+    def _get_worker_hs_config(self) -> dict:
+        # enable federation sending on the worker
+        config = super()._get_worker_hs_config()
+        # TODO: make it so we don't need both of these
+        config["send_federation"] = True
+        config["worker_app"] = "synapse.app.federation_sender"
+        return config
+
+    def test_catchup(self):
+        """Basic test of catchup on reconnect
+
+        Makes sure that updates sent while we are offline are received later.
+        """
+        fed_sender = self.hs.get_federation_sender()
+        received_rows = self.test_handler.received_rdata_rows
+
+        fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
+
+        self.reconnect()
+        self.reactor.advance(0)
+
+        # check we're testing what we think we are: no rows should yet have been
+        # received
+        self.assertEqual(received_rows, [])
+
+        # We should now see an attempt to connect to the master
+        request = self.handle_http_replication_attempt()
+        self.assert_request_is_get_repl_stream_updates(request, "federation")
+
+        # we should have received an update row
+        stream_name, token, row = received_rows.pop()
+        self.assertEqual(stream_name, "federation")
+        self.assertIsInstance(row, FederationStream.FederationStreamRow)
+        self.assertEqual(row.type, EduRow.TypeId)
+        edurow = EduRow.from_data(row.data)
+        self.assertEqual(edurow.edu.edu_type, "m.test_edu")
+        self.assertEqual(edurow.edu.origin, self.hs.hostname)
+        self.assertEqual(edurow.edu.destination, "testdest")
+        self.assertEqual(edurow.edu.content, {"a": "b"})
+
+        self.assertEqual(received_rows, [])
+
+        # additional updates should be transferred without an HTTP hit
+        fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
+        self.reactor.advance(0)
+        # there should be no http hit
+        self.assertEqual(len(self.reactor.tcpClients), 0)
+        # ... but we should have a row
+        self.assertEqual(len(received_rows), 1)
+
+        stream_name, token, row = received_rows.pop()
+        self.assertEqual(stream_name, "federation")
+        self.assertIsInstance(row, FederationStream.FederationStreamRow)
+        self.assertEqual(row.type, EduRow.TypeId)
+        edurow = EduRow.from_data(row.data)
+        self.assertEqual(edurow.edu.edu_type, "m.test1")
+        self.assertEqual(edurow.edu.origin, self.hs.hostname)
+        self.assertEqual(edurow.edu.destination, "testdest")
+        self.assertEqual(edurow.edu.content, {"c": "d"})
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index d25a7b194e..125c63dab5 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -15,7 +15,6 @@
 from mock import Mock
 
 from synapse.handlers.typing import RoomMember
-from synapse.replication.http import streams
 from synapse.replication.tcp.streams import TypingStream
 
 from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -24,10 +23,6 @@ USER_ID = "@feeling:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
-    servlets = [
-        streams.register_servlets,
-    ]
-
     def _build_replication_data_handler(self):
         return Mock(wraps=super()._build_replication_data_handler())
 
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
new file mode 100644
index 0000000000..55e9ecf264
--- /dev/null
+++ b/tests/storage/test_id_generators.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.db = self.store.db  # type: Database
+
+        self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
+
+    def _setup_db(self, txn):
+        txn.execute("CREATE SEQUENCE foobar_seq")
+        txn.execute(
+            """
+            CREATE TABLE foobar (
+                stream_id BIGINT NOT NULL,
+                instance_name TEXT NOT NULL,
+                data TEXT
+            );
+            """
+        )
+
+    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+        def _create(conn):
+            return MultiWriterIdGenerator(
+                conn,
+                self.db,
+                instance_name=instance_name,
+                table="foobar",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="foobar_seq",
+            )
+
+        return self.get_success(self.db.runWithConnection(_create))
+
+    def _insert_rows(self, instance_name: str, number: int):
+        def _insert(txn):
+            for _ in range(number):
+                txn.execute(
+                    "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
+                    (instance_name,),
+                )
+
+        self.get_success(self.db.runInteraction("test_single_instance", _insert))
+
+    def test_empty(self):
+        """Test an ID generator against an empty database gives sensible
+        current positions.
+        """
+
+        id_gen = self._create_id_generator()
+
+        # The table is empty so we expect an empty map for positions
+        self.assertEqual(id_gen.get_positions(), {})
+
+    def test_single_instance(self):
+        """Test that reads and writes from a single process are handled
+        correctly.
+        """
+
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
+
+        id_gen = self._create_id_generator()
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async():
+            with await id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(id_gen.get_positions(), {"master": 7})
+                self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(id_gen.get_positions(), {"master": 8})
+        self.assertEqual(id_gen.get_current_token("master"), 8)
+
+    def test_multi_instance(self):
+        """Test that reads and writes from multiple processes are handled
+        correctly.
+        """
+        self._insert_rows("first", 3)
+        self._insert_rows("second", 4)
+
+        first_id_gen = self._create_id_generator("first")
+        second_id_gen = self._create_id_generator("second")
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(first_id_gen.get_current_token("first"), 3)
+        self.assertEqual(first_id_gen.get_current_token("second"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        async def _get_next_async():
+            with await first_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 8)
+
+                self.assertEqual(
+                    first_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
+
+        # However the ID gen on the second instance won't have seen the update
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
+
+        # ... but calling `get_next` on the second instance should give a unique
+        # stream ID
+
+        async def _get_next_async():
+            with await second_id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 9)
+
+                self.assertEqual(
+                    second_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
+
+        self.get_success(_get_next_async())
+
+        self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
+
+        # If the second ID gen gets told about the first, it correctly updates
+        second_id_gen.advance("first", 8)
+        self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
+
+    def test_get_next_txn(self):
+        """Test that the `get_next_txn` function works correctly.
+        """
+
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
+
+        id_gen = self._create_id_generator()
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        # Try allocating a new ID gen and check that we only see position
+        # advanced after we leave the context manager.
+
+        def _get_next_txn(txn):
+            stream_id = id_gen.get_next_txn(txn)
+            self.assertEqual(stream_id, 8)
+
+            self.assertEqual(id_gen.get_positions(), {"master": 7})
+            self.assertEqual(id_gen.get_current_token("master"), 7)
+
+        self.get_success(self.db.runInteraction("test", _get_next_txn))
+
+        self.assertEqual(id_gen.get_positions(), {"master": 8})
+        self.assertEqual(id_gen.get_current_token("master"), 8)
diff --git a/tox.ini b/tox.ini
index eccc44e436..ad4ed8299e 100644
--- a/tox.ini
+++ b/tox.ini
@@ -181,11 +181,7 @@ commands = mypy \
             synapse/appservice \
             synapse/config \
             synapse/events/spamcheck.py \
-            synapse/federation/federation_base.py \
-            synapse/federation/federation_client.py \
-            synapse/federation/federation_server.py \
-            synapse/federation/sender \
-            synapse/federation/transport \
+            synapse/federation \
             synapse/handlers/auth.py \
             synapse/handlers/cas_handler.py \
             synapse/handlers/directory.py \
@@ -203,6 +199,7 @@ commands = mypy \
             synapse/storage/data_stores/main/ui_auth.py \
             synapse/storage/database.py \
             synapse/storage/engines \
+            synapse/storage/util \
             synapse/streams \
             synapse/util/caches/stream_change_cache.py \
             tests/replication/tcp/streams \