summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/database.py36
-rw-r--r--synapse/storage/databases/main/__init__.py31
-rw-r--r--synapse/storage/databases/main/account_data.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/devices.py22
-rw-r--r--synapse/storage/databases/main/directory.py4
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py8
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py43
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/events_worker.py14
-rw-r--r--synapse/storage/databases/main/group_server.py20
-rw-r--r--synapse/storage/databases/main/media_repository.py13
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py15
-rw-r--r--synapse/storage/databases/main/presence.py2
-rw-r--r--synapse/storage/databases/main/profile.py17
-rw-r--r--synapse/storage/databases/main/push_rule.py8
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/receipts.py9
-rw-r--r--synapse/storage/databases/main/registration.py35
-rw-r--r--synapse/storage/databases/main/rejections.py5
-rw-r--r--synapse/storage/databases/main/room.py16
-rw-r--r--synapse/storage/databases/main/state.py4
-rw-r--r--synapse/storage/databases/main/stats.py10
-rw-r--r--synapse/storage/databases/main/tags.py4
-rw-r--r--synapse/storage/databases/main/user_directory.py9
-rw-r--r--synapse/storage/util/id_generators.py113
-rw-r--r--synapse/storage/util/sequence.py8
27 files changed, 317 insertions, 145 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index bc327e344e..181c3ec249 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -29,9 +29,11 @@ from typing import (
     Tuple,
     TypeVar,
     Union,
+    overload,
 )
 
 from prometheus_client import Histogram
+from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
 from twisted.internet import defer
@@ -1020,14 +1022,36 @@ class DatabasePool(object):
 
         return txn.execute_batch(sql, args)
 
-    def simple_select_one(
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one",
+    ) -> Dict[str, Any]:
+        ...
+
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
+        ...
+
+    async def simple_select_one(
         self,
         table: str,
         keyvalues: Dict[str, Any],
         retcols: Iterable[str],
         allow_none: bool = False,
         desc: str = "simple_select_one",
-    ) -> defer.Deferred:
+    ) -> Optional[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
@@ -1038,18 +1062,18 @@ class DatabasePool(object):
             allow_none: If true, return None instead of failing if the SELECT
                 statement returns no rows
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
-    def simple_select_one_onecol(
+    async def simple_select_one_onecol(
         self,
         table: str,
         keyvalues: Dict[str, Any],
         retcol: Iterable[str],
         allow_none: bool = False,
         desc: str = "simple_select_one_onecol",
-    ) -> defer.Deferred:
+    ) -> Optional[Any]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
@@ -1061,7 +1085,7 @@ class DatabasePool(object):
                 statement returns no rows
             desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc,
             self.simple_select_one_onecol_txn,
             table,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..0934ae276c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -498,7 +498,7 @@ class DataStore(
         )
 
     def get_users_paginate(
-        self, start, limit, name=None, guests=True, deactivated=False
+        self, start, limit, user_id=None, name=None, guests=True, deactivated=False
     ):
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
@@ -507,7 +507,8 @@ class DataStore(
         Args:
             start (int): start number to begin the query from
             limit (int): number of rows to retrieve
-            name (string): filter for user names
+            user_id (string): search for user_id. ignored if name is not None
+            name (string): search for local part of user_id or display name
             guests (bool): whether to in include guest users
             deactivated (bool): whether to include deactivated users
         Returns:
@@ -516,11 +517,14 @@ class DataStore(
 
         def get_users_paginate_txn(txn):
             filters = []
-            args = []
+            args = [self.hs.config.server_name]
 
             if name:
+                filters.append("(name LIKE ? OR displayname LIKE ?)")
+                args.extend(["@%" + name + "%:%", "%" + name + "%"])
+            elif user_id:
                 filters.append("name LIKE ?")
-                args.append("%" + name + "%")
+                args.extend(["%" + user_id + "%"])
 
             if not guests:
                 filters.append("is_guest = 0")
@@ -530,20 +534,23 @@ class DataStore(
 
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
-            sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
-            txn.execute(sql, args)
-            count = txn.fetchone()[0]
-
-            args = [self.hs.config.server_name] + args + [limit, start]
-            sql = """
-                SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+            sql_base = """
                 FROM users as u
                 LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
                 {}
-                ORDER BY u.name LIMIT ? OFFSET ?
                 """.format(
                 where_clause
             )
+            sql = "SELECT COUNT(*) as total_users " + sql_base
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = (
+                "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+                + sql_base
+                + " ORDER BY u.name LIMIT ? OFFSET ?"
+            )
+            args += [limit, start]
             txn.execute(sql, args)
             users = self.db_pool.cursor_to_dict(txn)
             return users, count
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with self._device_inbox_id_gen.get_next() as stream_id:
+        with await self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with self._device_inbox_id_gen.get_next() as stream_id:
+        with await self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9a786e2929..a811a39eb5 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def get_device(self, user_id: str, device_id: str):
+    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
 
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to retrieve
         Returns:
-            defer.Deferred for a dict containing the device information
+            A dict containing the device information
         Raises:
             StoreError: if the device is not found
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
             THe new stream ID.
         """
 
-        with self._device_list_id_gen.get_next() as stream_id:
+        with await self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    def get_device_list_last_stream_id_for_remote(self, user_id: str):
+    async def get_device_list_last_stream_id_for_remote(
+        self, user_id: str
+    ) -> Optional[Any]:
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             retcol="stream_id",
@@ -1146,7 +1148,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+        with await self._device_list_id_gen.get_next_mult(
+            len(device_ids)
+        ) as stream_ids:
             await self.db_pool.runInteraction(
                 "add_device_change_to_stream",
                 self._add_device_change_to_stream_txn,
@@ -1159,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with self._device_list_id_gen.get_next_mult(
+        with await self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 037e02603c..301d5d845a 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
 
         return RoomAliasMapping(room_id, room_alias.to_string(), servers)
 
-    def get_room_alias_creator(self, room_alias):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_room_alias_creator(self, room_alias: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="room_aliases",
             keyvalues={"room_alias": room_alias},
             retcol="creator",
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 2eeb9f97dc..46c3e33cc6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return ret
 
-    def count_e2e_room_keys(self, user_id, version):
+    async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
         """Get the number of keys in a backup version.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup we're querying about
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup we're querying about
         """
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="e2e_room_keys",
             keyvalues={"user_id": user_id, "version": version},
             retcol="COUNT(*)",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..385868bdab 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
-    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
         """Set a user's cross-signing key.
 
         Args:
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
             key (dict): the key data
+            stream_id (int)
         """
         # the 'key' dict will look something like:
         # {
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             )
 
         # and finally, store the key itself
-        with self._cross_signing_id_gen.get_next() as stream_id:
-            self.db_pool.simple_insert_txn(
-                txn,
-                "e2e_cross_signing_keys",
-                values={
-                    "user_id": user_id,
-                    "keytype": key_type,
-                    "keydata": json_encoder.encode(key),
-                    "stream_id": stream_id,
-                },
-            )
+        self.db_pool.simple_insert_txn(
+            txn,
+            "e2e_cross_signing_keys",
+            values={
+                "user_id": user_id,
+                "keytype": key_type,
+                "keydata": json_encoder.encode(key),
+                "stream_id": stream_id,
+            },
+        )
 
         self._invalidate_cache_and_stream(
             txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
         )
 
-    def set_e2e_cross_signing_key(self, user_id, key_type, key):
+    async def set_e2e_cross_signing_key(self, user_id, key_type, key):
         """Set a user's cross-signing key.
 
         Args:
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key_type (str): the type of cross-signing key to set
             key (dict): the key data
         """
-        return self.db_pool.runInteraction(
-            "add_e2e_cross_signing_key",
-            self._set_e2e_cross_signing_key_txn,
-            user_id,
-            key_type,
-            key,
-        )
+
+        with await self._cross_signing_id_gen.get_next() as stream_id:
+            return await self.db_pool.runInteraction(
+                "add_e2e_cross_signing_key",
+                self._set_e2e_cross_signing_key_txn,
+                user_id,
+                key_type,
+                key,
+                stream_id,
+            )
 
     def store_e2e_cross_signing_signatures(self, user_id, signatures):
         """Stores cross-signing signatures.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b90e6de2d5..6313b41eef 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -153,11 +153,11 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
         if backfilled:
-            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e1241a724b..e6247d682d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -113,25 +113,25 @@ class EventsWorkerStore(SQLBaseStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == EventsStream.NAME:
-            self._stream_id_gen.advance(token)
+            self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
-            self._backfill_id_gen.advance(-token)
+            self._backfill_id_gen.advance(instance_name, -token)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def get_received_ts(self, event_id):
+    async def get_received_ts(self, event_id: str) -> Optional[int]:
         """Get received_ts (when it was persisted) for the event.
 
         Raises an exception for unknown events.
 
         Args:
-            event_id (str)
+            event_id: The event ID to query.
 
         Returns:
-            Deferred[int|None]: Timestamp in milliseconds, or None for events
-            that were persisted before received_ts was implemented.
+            Timestamp in milliseconds, or None for events that were persisted
+            before received_ts was implemented.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="events",
             keyvalues={"event_id": event_id},
             retcol="received_ts",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 0e3b8739c6..c39864f59f 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
 
 
 class GroupServerWorkerStore(SQLBaseStore):
-    def get_group(self, group_id):
-        return self.db_pool.simple_select_one(
+    async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="groups",
             keyvalues={"group_id": group_id},
             retcols=(
@@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
         )
         return bool(result)
 
-    def is_user_admin_in_group(self, group_id, user_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_user_admin_in_group(
+        self, group_id: str, user_id: str
+    ) -> Optional[bool]:
+        return await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="is_admin",
@@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="is_user_admin_in_group",
         )
 
-    def is_user_invited_to_local_group(self, group_id, user_id):
+    async def is_user_invited_to_local_group(
+        self, group_id: str, user_id: str
+    ) -> Optional[bool]:
         """Has the group server invited a user?
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
@@ -1182,7 +1186,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        with self._group_updates_id_gen.get_next() as next_id:
+        with await self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 80fc1cd009..4ae255ebd8 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Optional
+
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
 
-    def get_local_media(self, media_id):
+    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
+
         Returns:
             None if the media_id doesn't exist.
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_thumbnail",
         )
 
-    def get_cached_remote_media(self, origin, media_id):
-        return self.db_pool.simple_select_one(
+    async def get_cached_remote_media(
+        self, origin, media_id: str
+    ) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e71cdd2cb4..fe30552c08 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         return users
 
     @cached(num_args=1)
-    def user_last_seen_monthly_active(self, user_id):
+    async def user_last_seen_monthly_active(self, user_id: str) -> int:
         """
-            Checks if a given user is part of the monthly active user group
-            Arguments:
-                user_id (str): user to add/update
-            Return:
-                Deferred[int] : timestamp since last seen, None if never seen
+        Checks if a given user is part of the monthly active user group
 
+        Arguments:
+            user_id: user to add/update
+
+        Return:
+            Timestamp since last seen, None if never seen
         """
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
             retcol="timestamp",
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4e3ec02d14..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..b8233c4848 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
 
 
 class ProfileWorkerStore(SQLBaseStore):
-    async def get_profileinfo(self, user_localpart):
+    async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
         try:
             profile = await self.db_pool.simple_select_one(
                 table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
             avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
-    def get_profile_displayname(self, user_localpart):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_profile_displayname(self, user_localpart: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="displayname",
             desc="get_profile_displayname",
         )
 
-    def get_profile_avatar_url(self, user_localpart):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_profile_avatar_url(self, user_localpart: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="avatar_url",
             desc="get_profile_avatar_url",
         )
 
-    def get_from_remote_profile_cache(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_from_remote_profile_cache(
+        self, user_id: str
+    ) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             retcols=("displayname", "avatar_url"),
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index a585e54812..2fb5b02d7d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with self._push_rules_stream_id_gen.get_next() as stream_id:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             if before or after:
@@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with self._push_rules_stream_id_gen.get_next() as stream_id:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
             )
 
     async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
-        with self._push_rules_stream_id_gen.get_next() as stream_id:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with self._push_rules_stream_id_gen.get_next() as stream_id:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 1126fd0751..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
         last_stream_ordering,
         profile_tag="",
     ) -> None:
-        with self._pushers_id_gen.get_next() as stream_id:
+        with await self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
             await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with self._pushers_id_gen.get_next() as stream_id:
+        with await self._pushers_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 19ad1c056f..cea5ac9a68 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=3)
-    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_last_receipt_event_id_for_user(
+        self, user_id: str, room_id: str, receipt_type: str
+    ) -> Optional[str]:
+        return await self.db_pool.simple_select_one_onecol(
             table="receipts_linearized",
             keyvalues={
                 "room_id": room_id,
@@ -520,8 +522,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        stream_id_manager = self._receipts_id_gen.get_next()
-        with stream_id_manager as stream_id:
+        with await self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 068ad22b30..eced53d470 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,7 +17,7 @@
 
 import logging
 import re
-from typing import Awaitable, Dict, List, Optional
+from typing import Any, Awaitable, Dict, List, Optional
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    def get_user_by_id(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="users",
             keyvalues={"name": user_id},
             retcols=[
@@ -889,6 +889,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         super(RegistrationStore, self).__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
+        self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
         if self._account_validity.enabled:
             self._clock.call_later(
@@ -1258,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="del_user_pending_deactivation",
         )
 
-    def get_user_pending_deactivation(self):
+    async def get_user_pending_deactivation(self) -> Optional[str]:
         """
         Gets one user from the table of users waiting to be parted from all the rooms
         they're in.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             "users_pending_deactivation",
             keyvalues={},
             retcol="user_id",
@@ -1302,15 +1303,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
 
             if not row:
-                raise ThreepidValidationError(400, "Unknown session_id")
+                if self._ignore_unknown_session_error:
+                    # If we need to inhibit the error caused by an incorrect session ID,
+                    # use None as placeholder values for the client secret and the
+                    # validation timestamp.
+                    # It shouldn't be an issue because they're both only checked after
+                    # the token check, which should fail. And if it doesn't for some
+                    # reason, the next check is on the client secret, which is NOT NULL,
+                    # so we don't have to worry about the client secret matching by
+                    # accident.
+                    row = {"client_secret": None, "validated_at": None}
+                else:
+                    raise ThreepidValidationError(400, "Unknown session_id")
+
             retrieved_client_secret = row["client_secret"]
             validated_at = row["validated_at"]
 
-            if retrieved_client_secret != client_secret:
-                raise ThreepidValidationError(
-                    400, "This client_secret does not match the provided session_id"
-                )
-
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="threepid_validation_token",
@@ -1326,6 +1334,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             expires = row["expires"]
             next_link = row["next_link"]
 
+            if retrieved_client_secret != client_secret:
+                raise ThreepidValidationError(
+                    400, "This client_secret does not match the provided session_id"
+                )
+
             # If the session is already validated, no need to revalidate
             if validated_at:
                 return next_link
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index cf9ba51205..1e361aaa9a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
 
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
 
 
 class RejectionsStore(SQLBaseStore):
-    def get_rejection_reason(self, event_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+        return await self.db_pool.simple_select_one_onecol(
             table="rejections",
             retcol="reason",
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7d3ac47261..97ecdb16e4 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
 
         self.config = hs.config
 
-    def get_room(self, room_id):
+    async def get_room(self, room_id: str) -> dict:
         """Retrieve a room.
 
         Args:
-            room_id (str): The ID of the room to retrieve.
+            room_id: The ID of the room to retrieve.
         Returns:
             A dict containing the room information, or None if the room is unknown.
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
             retcols=("room_id", "is_public", "creator"),
@@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
         return ret_val
 
     @cached(max_entries=10000)
-    def is_room_blocked(self, room_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+        return await self.db_pool.simple_select_one_onecol(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             retcol="1",
@@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         },
                     )
 
-            with self._public_room_id_gen.get_next() as next_id:
+            with await self._public_room_id_gen.get_next() as next_id:
                 await self.db_pool.runInteraction(
                     "store_room_txn", store_room_txn, next_id
                 )
@@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with self._public_room_id_gen.get_next() as next_id:
+        with await self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with self._public_room_id_gen.get_next() as next_id:
+        with await self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 991233a9bc..458f169617 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return event.content.get("canonical_alias")
 
     @cached(max_entries=50000)
-    def _get_state_group_for_event(self, event_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+        return await self.db_pool.simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={"event_id": event_id},
             retcol="state_group",
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 802c9019b9..9fe97af56a 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
 
         return len(rooms_to_work_on)
 
-    def get_stats_positions(self):
+    async def get_stats_positions(self) -> int:
         """
         Returns the stats processor positions.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="stats_incremental_position",
             keyvalues={},
             retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    def get_earliest_token_for_stats(self, stats_type, id):
+    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
         being calculated.
 
         Returns:
-            Deferred[int]
+            The earliest token.
         """
         table, id_col = TYPE_TO_TABLE[stats_type]
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             "%s_current" % (table,),
             keyvalues={id_col: id},
             retcol="completed_delta_stream_id",
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index ade7abc927..0c34bbf21a 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..20cbcd851c 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,6 +15,7 @@
 
 import logging
 import re
+from typing import Any, Dict, Optional
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     @cached()
-    def get_user_in_directory(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
             retcols=("display_name", "avatar_url"),
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
-    def get_user_directory_stream_pos(self):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_user_directory_stream_pos(self) -> int:
+        return await self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
             retcol="stream_id",
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 0bf772d4d1..5b07847773 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,9 +14,10 @@
 # limitations under the License.
 
 import contextlib
+import heapq
 import threading
 from collections import deque
-from typing import Dict, Set
+from typing import Dict, List, Set
 
 from typing_extensions import Deque
 
@@ -80,7 +81,7 @@ class StreamIdGenerator(object):
             upwards, -1 to grow downwards.
 
     Usage:
-        with stream_id_gen.get_next() as stream_id:
+        with await stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -95,10 +96,10 @@ class StreamIdGenerator(object):
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    def get_next(self):
+    async def get_next(self):
         """
         Usage:
-            with stream_id_gen.get_next() as stream_id:
+            with await stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -117,10 +118,10 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_next_mult(self, n):
+    async def get_next_mult(self, n):
         """
         Usage:
-            with stream_id_gen.get_next(n) as stream_ids:
+            with await stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        # We track the max position where we know everything before has been
+        # persisted. This is done by a) looking at the min across all instances
+        # and b) noting that if we have seen a run of persisted positions
+        # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+        #
+        # Note: There is no guarentee that the IDs generated by the sequence
+        # will be gapless; gaps can form when e.g. a transaction was rolled
+        # back. This means that sometimes we won't be able to skip forward the
+        # position even though everything has been persisted. However, since
+        # gaps should be relatively rare it's still worth doing the book keeping
+        # that allows us to skip forwards when there are gapless runs of
+        # positions.
+        self._persisted_upto_position = (
+            min(self._current_positions.values()) if self._current_positions else 0
+        )
+        self._known_persisted_positions = []  # type: List[int]
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
     def _load_current_ids(
@@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
 
         return current_positions
 
-    def _load_next_id_txn(self, txn):
+    def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
 
+    def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+        return self._sequence_gen.get_next_mult_txn(txn, n)
+
     async def get_next(self):
         """
         Usage:
@@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
 
         return manager()
 
+    async def get_next_mult(self, n: int):
+        """
+        Usage:
+            with await stream_id_gen.get_next_mult(5) as stream_ids:
+                # ... persist events ...
+        """
+        next_ids = await self._db.runInteraction(
+            "_load_next_mult_id", self._load_next_mult_id_txn, n
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        assert max(self.get_positions().values(), default=0) < min(next_ids)
+
+        with self._lock:
+            self._unfinished_ids.update(next_ids)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield next_ids
+            finally:
+                for i in next_ids:
+                    self._mark_id_as_finished(i)
+
+        return manager()
+
     def get_next_txn(self, txn: LoggingTransaction):
         """
         Usage:
@@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
             self._current_positions[instance_name] = max(
                 new_id, self._current_positions.get(instance_name, 0)
             )
+
+            self._add_persisted_position(new_id)
+
+    def get_persisted_upto_position(self) -> int:
+        """Get the max position where all previous positions have been
+        persisted.
+
+        Note: In the worst case scenario this will be equal to the minimum
+        position across writers. This means that the returned position here can
+        lag if one writer doesn't write very often.
+        """
+
+        with self._lock:
+            return self._persisted_upto_position
+
+    def _add_persisted_position(self, new_id: int):
+        """Record that we have persisted a position.
+
+        This is used to keep the `_current_positions` up to date.
+        """
+
+        # We require that the lock is locked by caller
+        assert self._lock.locked()
+
+        heapq.heappush(self._known_persisted_positions, new_id)
+
+        # We move the current min position up if the minimum current positions
+        # of all instances is higher (since by definition all positions less
+        # that that have been persisted).
+        min_curr = min(self._current_positions.values())
+        self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+        # We now iterate through the seen positions, discarding those that are
+        # less than the current min positions, and incrementing the min position
+        # if its exactly one greater.
+        #
+        # This is also where we discard items from `_known_persisted_positions`
+        # (to ensure the list doesn't infinitely grow).
+        while self._known_persisted_positions:
+            if self._known_persisted_positions[0] <= self._persisted_upto_position:
+                heapq.heappop(self._known_persisted_positions)
+            elif (
+                self._known_persisted_positions[0] == self._persisted_upto_position + 1
+            ):
+                heapq.heappop(self._known_persisted_positions)
+                self._persisted_upto_position += 1
+            else:
+                # There was a gap in seen positions, so there is nothing more to
+                # do.
+                break
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import abc
 import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
 
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
         return txn.fetchone()[0]
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        txn.execute(
+            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+        )
+        return [i for (i,) in txn]
+
 
 GetFirstCallbackType = Callable[[Cursor], int]