summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/controllers/purge_events.py22
-rw-r--r--synapse/storage/controllers/state.py6
-rw-r--r--synapse/storage/database.py32
-rw-r--r--synapse/storage/databases/main/__init__.py4
-rw-r--r--synapse/storage/databases/main/account_data.py201
-rw-r--r--synapse/storage/databases/main/appservice.py2
-rw-r--r--synapse/storage/databases/main/cache.py3
-rw-r--r--synapse/storage/databases/main/deviceinbox.py5
-rw-r--r--synapse/storage/databases/main/devices.py60
-rw-r--r--synapse/storage/databases/main/directory.py4
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py2
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py59
-rw-r--r--synapse/storage/databases/main/event_federation.py12
-rw-r--r--synapse/storage/databases/main/event_push_actions.py7
-rw-r--r--synapse/storage/databases/main/events.py35
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py13
-rw-r--r--synapse/storage/databases/main/events_worker.py15
-rw-r--r--synapse/storage/databases/main/filtering.py25
-rw-r--r--synapse/storage/databases/main/media_repository.py1
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py4
-rw-r--r--synapse/storage/databases/main/purge_events.py17
-rw-r--r--synapse/storage/databases/main/push_rule.py3
-rw-r--r--synapse/storage/databases/main/pusher.py6
-rw-r--r--synapse/storage/databases/main/receipts.py17
-rw-r--r--synapse/storage/databases/main/registration.py17
-rw-r--r--synapse/storage/databases/main/relations.py138
-rw-r--r--synapse/storage/databases/main/room.py391
-rw-r--r--synapse/storage/databases/main/roommember.py19
-rw-r--r--synapse/storage/databases/main/search.py2
-rw-r--r--synapse/storage/databases/main/signatures.py6
-rw-r--r--synapse/storage/databases/main/state.py1
-rw-r--r--synapse/storage/databases/main/stats.py2
-rw-r--r--synapse/storage/databases/main/stream.py1
-rw-r--r--synapse/storage/databases/main/tags.py8
-rw-r--r--synapse/storage/databases/main/transactions.py1
-rw-r--r--synapse/storage/databases/main/user_directory.py81
-rw-r--r--synapse/storage/databases/state/bg_updates.py1
-rw-r--r--synapse/storage/databases/state/store.py128
-rw-r--r--synapse/storage/engines/__init__.py4
-rw-r--r--synapse/storage/prepare_database.py17
-rw-r--r--synapse/storage/schema/__init__.py9
-rw-r--r--synapse/storage/types.py74
-rw-r--r--synapse/storage/util/id_generators.py62
-rw-r--r--synapse/storage/util/sequence.py2
45 files changed, 896 insertions, 625 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 41d9111019..481fec72fe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,6 +37,8 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
+    db_pool: DatabasePool
+
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py
index 9ca50d6a09..c599397b86 100644
--- a/synapse/storage/controllers/purge_events.py
+++ b/synapse/storage/controllers/purge_events.py
@@ -16,6 +16,7 @@ import itertools
 import logging
 from typing import TYPE_CHECKING, Set
 
+from synapse.logging.context import nested_logging_context
 from synapse.storage.databases import Databases
 
 if TYPE_CHECKING:
@@ -33,8 +34,9 @@ class PurgeEventsStorageController:
     async def purge_room(self, room_id: str) -> None:
         """Deletes all record of a room"""
 
-        state_groups_to_delete = await self.stores.main.purge_room(room_id)
-        await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+        with nested_logging_context(room_id):
+            state_groups_to_delete = await self.stores.main.purge_room(room_id)
+            await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
 
     async def purge_history(
         self, room_id: str, token: str, delete_local_events: bool
@@ -51,15 +53,17 @@ class PurgeEventsStorageController:
                 (instead of just marking them as outliers and deleting their
                 state groups).
         """
-        state_groups = await self.stores.main.purge_history(
-            room_id, token, delete_local_events
-        )
-
-        logger.info("[purge] finding state groups that can be deleted")
+        with nested_logging_context(room_id):
+            state_groups = await self.stores.main.purge_history(
+                room_id, token, delete_local_events
+            )
 
-        sg_to_delete = await self._find_unreferenced_groups(state_groups)
+            logger.info("[purge] finding state groups that can be deleted")
+            sg_to_delete = await self._find_unreferenced_groups(state_groups)
 
-        await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+            await self.stores.state.purge_unreferenced_state_groups(
+                room_id, sg_to_delete
+            )
 
     async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
         """Used when purging history to figure out which state groups can be
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 52efd4a171..9d7a8a792f 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -14,6 +14,7 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    AbstractSet,
     Any,
     Awaitable,
     Callable,
@@ -23,7 +24,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Set,
     Tuple,
 )
 
@@ -527,7 +527,7 @@ class StateStorageController:
         )
         return state_map.get(key)
 
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
         """Get current hosts in room based on current state.
 
         Blocks until we have full state for the given room. This only happens for rooms
@@ -584,7 +584,7 @@ class StateStorageController:
 
     async def get_users_in_room_with_profiles(
         self, room_id: str
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Mapping[str, ProfileInfo]:
         """
         Get the current users in the room with their profiles.
         If the room is currently partial-stated, this will block until the room has
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e20c5c5302..fec4ae5b97 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -34,6 +34,7 @@ from typing import (
     Tuple,
     Type,
     TypeVar,
+    Union,
     cast,
     overload,
 )
@@ -100,6 +101,15 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 }
 
 
+class _PoolConnection(Connection):
+    """
+    A Connection from twisted.enterprise.adbapi.Connection.
+    """
+
+    def reconnect(self) -> None:
+        ...
+
+
 def make_pool(
     reactor: IReactorCore,
     db_config: DatabaseConnectionConfig,
@@ -499,6 +509,7 @@ class DatabasePool:
     """
 
     _TXN_ID = 0
+    engine: BaseDatabaseEngine
 
     def __init__(
         self,
@@ -671,7 +682,15 @@ class DatabasePool:
             f = cast(types.FunctionType, func)  # type: ignore[redundant-cast]
             if f.__closure__:
                 for i, cell in enumerate(f.__closure__):
-                    if inspect.isgenerator(cell.cell_contents):
+                    try:
+                        contents = cell.cell_contents
+                    except ValueError:
+                        # cell.cell_contents can raise if the "cell" is empty,
+                        # which indicates that the variable is currently
+                        # unbound.
+                        continue
+
+                    if inspect.isgenerator(contents):
                         logger.error(
                             "Programming error: function %s references generator %s "
                             "via its closure",
@@ -847,7 +866,8 @@ class DatabasePool:
             try:
                 with opentracing.start_active_span(f"db.{desc}"):
                     result = await self.runWithConnection(
-                        self.new_transaction,
+                        # mypy seems to have an issue with this, maybe a bug?
+                        self.new_transaction,  # type: ignore[arg-type]
                         desc,
                         after_callbacks,
                         async_after_callbacks,
@@ -883,7 +903,7 @@ class DatabasePool:
 
     async def runWithConnection(
         self,
-        func: Callable[..., R],
+        func: Callable[Concatenate[LoggingDatabaseConnection, P], R],
         *args: Any,
         db_autocommit: bool = False,
         isolation_level: Optional[int] = None,
@@ -917,7 +937,7 @@ class DatabasePool:
 
         start_time = monotonic_time()
 
-        def inner_func(conn, *args, **kwargs):
+        def inner_func(conn: _PoolConnection, *args: P.args, **kwargs: P.kwargs) -> R:
             # We shouldn't be in a transaction. If we are then something
             # somewhere hasn't committed after doing work. (This is likely only
             # possible during startup, as `run*` will ensure changes are
@@ -1010,7 +1030,7 @@ class DatabasePool:
         decoder: Optional[Callable[[Cursor], R]],
         query: str,
         *args: Any,
-    ) -> R:
+    ) -> Union[List[Tuple[Any, ...]], R]:
         """Runs a single query for a result set.
 
         Args:
@@ -1023,7 +1043,7 @@ class DatabasePool:
             The result of decoder(results)
         """
 
-        def interaction(txn):
+        def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
             txn.execute(query, args)
             if decoder:
                 return decoder(txn)
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 837dc7646e..dc3948c170 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
 from .events_forward_extremities import EventForwardExtremitiesStore
-from .filtering import FilteringStore
+from .filtering import FilteringWorkerStore
 from .keys import KeyStore
 from .lock import LockStore
 from .media_repository import MediaRepositoryStore
@@ -99,7 +99,7 @@ class DataStore(
     EventFederationStore,
     MediaRepositoryStore,
     RejectionsStore,
-    FilteringStore,
+    FilteringWorkerStore,
     PusherStore,
     PushRuleStore,
     ApplicationServiceTransactionStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..a9843f6e17 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     cast,
@@ -39,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -63,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     ):
         super().__init__(database, db_conn, hs)
 
-        # `_can_write_to_account_data` indicates whether the current worker is allowed
-        # to write account data. A value of `True` implies that `_account_data_id_gen`
-        # is an `AbstractStreamIdGenerator` and not just a tracker.
-        self._account_data_id_gen: AbstractStreamIdTracker
         self._can_write_to_account_data = (
             self._instance_name in hs.config.worker.writers.account_data
         )
 
+        self._account_data_id_gen: AbstractStreamIdGenerator
+
         if isinstance(database.engine, PostgresEngine):
             self._account_data_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
@@ -122,25 +120,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         return self._account_data_id_gen.get_current_token()
 
     @cached()
-    async def get_account_data_for_user(
+    async def get_global_account_data_for_user(
         self, user_id: str
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+    ) -> Mapping[str, JsonDict]:
         """
-        Get all the client account_data for a user.
+        Get all the global client account_data for a user.
 
         If experimental MSC3391 support is enabled, any entries with an empty
         content body are excluded; as this means they have been deleted.
 
         Args:
             user_id: The user to get the account_data for.
+
         Returns:
-            A 2-tuple of a dict of global account_data and a dict mapping from
-            room_id string to per room account_data dicts.
+            The global account_data.
         """
 
-        def get_account_data_for_user_txn(
+        def get_global_account_data_for_user(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+        ) -> Dict[str, JsonDict]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -158,10 +156,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn.execute(sql, (user_id,))
             rows = self.db_pool.cursor_to_dict(txn)
 
-            global_account_data = {
+            return {
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }
 
+        return await self.db_pool.runInteraction(
+            "get_global_account_data_for_user", get_global_account_data_for_user
+        )
+
+    @cached()
+    async def get_room_account_data_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
+        """
+        Get all of the per-room client account_data for a user.
+
+        If experimental MSC3391 support is enabled, any entries with an empty
+        content body are excluded; as this means they have been deleted.
+
+        Args:
+            user_id: The user to get the account_data for.
+
+        Returns:
+            A dict mapping from room_id string to per-room account_data dicts.
+        """
+
+        def get_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -185,10 +207,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
 
-            return global_account_data, by_room
+            return by_room
 
         return await self.db_pool.runInteraction(
-            "get_account_data_for_user", get_account_data_for_user_txn
+            "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
         )
 
     @cached(num_args=2, max_entries=5000, tree=True)
@@ -212,10 +234,41 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         else:
             return None
 
+    async def get_latest_stream_id_for_global_account_data_by_type_for_user(
+        self, user_id: str, data_type: str
+    ) -> Optional[int]:
+        """
+        Returns:
+            The stream ID of the account data,
+            or None if there is no such account data.
+        """
+
+        def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[int]:
+            sql = """
+                SELECT stream_id FROM account_data
+                WHERE user_id = ? AND account_data_type = ?
+                ORDER BY stream_id DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (user_id, data_type))
+
+            row = txn.fetchone()
+            if row:
+                return row[0]
+            else:
+                return None
+
+        return await self.db_pool.runInteraction(
+            "get_latest_stream_id_for_global_account_data_by_type_for_user",
+            get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
+        )
+
     @cached(num_args=2, tree=True)
     async def get_account_data_for_room(
         self, user_id: str, room_id: str
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonDict]:
         """Get all the client account_data for a user for a room.
 
         Args:
@@ -342,36 +395,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
-    async def get_updated_account_data_for_user(
+    async def get_updated_global_account_data_for_user(
         self, user_id: str, stream_id: int
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-        """Get all the client account_data for a that's changed for a user
+    ) -> Dict[str, JsonDict]:
+        """Get all the global account_data that's changed for a user.
 
         Args:
             user_id: The user to get the account_data for.
             stream_id: The point in the stream since which to get updates
+
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A dict of global account_data.
         """
 
-        def get_updated_account_data_for_user_txn(
+        def get_updated_global_account_data_for_user(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-            sql = (
-                "SELECT account_data_type, content FROM account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
-
+        ) -> Dict[str, JsonDict]:
+            sql = """
+                SELECT account_data_type, content FROM account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
-            global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+            return {row[0]: db_to_json(row[1]) for row in txn}
 
-            sql = (
-                "SELECT room_id, account_data_type, content FROM room_account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
+        changed = self._account_data_stream_cache.has_entity_changed(
+            user_id, int(stream_id)
+        )
+        if not changed:
+            return {}
+
+        return await self.db_pool.runInteraction(
+            "get_updated_global_account_data_for_user",
+            get_updated_global_account_data_for_user,
+        )
+
+    async def get_updated_room_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Dict[str, Dict[str, JsonDict]]:
+        """Get all the room account_data that's changed for a user.
 
+        Args:
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
+
+        Returns:
+            A dict mapping from room_id string to per room account_data dicts.
+        """
+
+        def get_updated_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
+            sql = """
+                SELECT room_id, account_data_type, content FROM room_account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
             account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +457,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
 
-            return global_account_data, account_data_by_room
+            return account_data_by_room
 
         changed = self._account_data_stream_cache.has_entity_changed(
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return {}
 
         return await self.db_pool.runInteraction(
-            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+            "get_updated_room_account_data_for_user",
+            get_updated_room_account_data_for_user_txn,
         )
 
     @cached(max_entries=5000, iterable=True)
@@ -444,7 +523,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                     self.get_global_account_data_by_type_for_user.invalidate(
                         (row.user_id, row.data_type)
                     )
-                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_global_account_data_for_user.invalidate((row.user_id,))
+                self.get_room_account_data_for_user.invalidate((row.user_id,))
                 self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                 self.get_account_data_for_room_and_type.invalidate(
                     (row.user_id, row.room_id, row.data_type)
@@ -475,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -492,7 +571,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_room_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
                 (user_id, room_id, account_data_type), content
@@ -502,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
     async def remove_account_data_for_room(
         self, user_id: str, room_id: str, account_data_type: str
-    ) -> Optional[int]:
+    ) -> int:
         """Delete the room account data for the user of a given type.
 
         Args:
@@ -515,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             data to delete.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def _remove_account_data_for_room_txn(
             txn: LoggingTransaction, next_id: int
@@ -554,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 next_id,
             )
 
-            if not row_updated:
-                return None
-
-            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
-            self.get_account_data_for_room.invalidate((user_id, room_id))
-            self.get_account_data_for_room_and_type.prefill(
-                (user_id, room_id, account_data_type), {}
-            )
+            if row_updated:
+                self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+                self.get_room_account_data_for_user.invalidate((user_id,))
+                self.get_account_data_for_room.invalidate((user_id, room_id))
+                self.get_account_data_for_room_and_type.prefill(
+                    (user_id, room_id, account_data_type), {}
+                )
 
         return self._account_data_id_gen.get_current_token()
 
@@ -580,7 +656,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
@@ -593,7 +668,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
                 (user_id, account_data_type)
             )
@@ -670,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         self,
         user_id: str,
         account_data_type: str,
-    ) -> Optional[int]:
+    ) -> int:
         """
         Delete a single piece of user account data by type.
 
@@ -687,7 +762,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             to delete.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def _remove_account_data_for_user_txn(
             txn: LoggingTransaction, next_id: int
@@ -757,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 next_id,
             )
 
-            if not row_updated:
-                return None
-
-            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
-            self.get_global_account_data_by_type_for_user.prefill(
-                (user_id, account_data_type), {}
-            )
+            if row_updated:
+                self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+                self.get_global_account_data_for_user.invalidate((user_id,))
+                self.get_global_account_data_by_type_for_user.prefill(
+                    (user_id, account_data_type), {}
+                )
 
         return self._account_data_id_gen.get_current_token()
 
@@ -822,7 +894,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room_and_type, (user_id,)
         )
         self._invalidate_cache_and_stream(
-            txn, self.get_account_data_for_user, (user_id,)
+            txn, self.get_global_account_data_for_user, (user_id,)
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_room_account_data_for_user, (user_id,)
         )
         self._invalidate_cache_and_stream(
             txn, self.get_global_account_data_by_type_for_user, (user_id,)
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5fb152c4ff..484db175d0 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
         room_id: str,
         app_service: "ApplicationService",
         cache_context: _CacheContext,
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """
         Get all users in a room that the appservice controls.
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 5b66431691..096dec7f87 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -266,9 +266,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if relates_to:
             self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
             self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
-            self._attempt_to_invalidate_cache(
-                "get_aggregation_groups_for_event", (relates_to,)
-            )
             self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8e61aba454..0d75d9739a 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                         ],
                     )
 
-                for (user_id, messages_by_device) in edu["messages"].items():
-                    for (device_id, msg) in messages_by_device.items():
+                for user_id, messages_by_device in edu["messages"].items():
+                    for device_id, msg in messages_by_device.items():
                         with start_active_span("store_outgoing_to_device_message"):
                             set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
                             set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
@@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
         def _remove_dead_devices_from_device_inbox_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, bool]:
-
             if "max_stream_id" in progress:
                 max_stream_id = progress["max_stream_id"]
             else:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8b6cc6b80..5503621ad6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -51,7 +52,6 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     StreamIdGenerator,
 )
 from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
@@ -90,7 +90,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._device_list_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "device_lists_stream",
@@ -100,6 +100,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 ("device_lists_outbound_pokes", "stream_id"),
                 ("device_lists_changes_in_room", "stream_id"),
                 ("device_lists_remote_pending", "stream_id"),
+                ("device_lists_changes_converted_stream_position", "stream_id"),
             ],
             is_writer=hs.config.worker.worker_app is None,
         )
@@ -201,7 +202,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+    async def count_devices_by_users(
+        self, user_ids: Optional[Collection[str]] = None
+    ) -> int:
         """Retrieve number of all devices of given users.
         Only returns number of devices that are not marked as hidden.
 
@@ -212,7 +215,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         """
 
         def count_devices_by_users_txn(
-            txn: LoggingTransaction, user_ids: List[str]
+            txn: LoggingTransaction, user_ids: Collection[str]
         ) -> int:
             sql = """
                 SELECT count(*)
@@ -508,7 +511,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             results.append(("org.matrix.signing_key_update", result))
 
         if issue_8631_logger.isEnabledFor(logging.DEBUG):
-            for (user_id, edu) in results:
+            for user_id, edu in results:
                 issue_8631_logger.debug(
                     "device update to %s for %s from %s to %s: %s",
                     destination,
@@ -708,9 +711,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             The new stream ID.
         """
 
-        # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
-        #  than DeviceWorkerStore?
-        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -745,42 +746,47 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     @trace
     @cancellable
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, Optional[str]]]
-    ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+        self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
+    ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
         Args:
-            query_list: List of (user_id, device_ids), if device_ids is
-                falsey then return all device ids for that user.
+            user_ids: users which should have all device IDs returned
+            user_and_device_ids: List of (user_id, device_ids)
 
         Returns:
             A tuple of (user_ids_not_in_cache, results_map), where
             user_ids_not_in_cache is a set of user_ids and results_map is a
             mapping of user_id -> device_id -> device_info.
         """
-        user_ids = {user_id for user_id, _ in query_list}
-        user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
+        user_map = await self.get_device_list_last_stream_id_for_remotes(
+            list(unique_user_ids)
+        )
 
         # We go and check if any of the users need to have their device lists
         # resynced. If they do then we remove them from the cached list.
         users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
-            user_ids
+            unique_user_ids
         )
         user_ids_in_cache = {
             user_id for user_id, stream_id in user_map.items() if stream_id
         } - users_needing_resync
-        user_ids_not_in_cache = user_ids - user_ids_in_cache
-
-        results: Dict[str, Dict[str, JsonDict]] = {}
-        for user_id, device_id in query_list:
-            if user_id not in user_ids_in_cache:
-                continue
+        user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 
-            if device_id:
-                device = await self._get_cached_user_device(user_id, device_id)
-                results.setdefault(user_id, {})[device_id] = device
-            else:
+        # First fetch all the users which all devices are to be returned.
+        results: Dict[str, Mapping[str, JsonDict]] = {}
+        for user_id in user_ids:
+            if user_id in user_ids_in_cache:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
+        # Then fetch all device-specific requests, but skip users we've already
+        # fetched all devices for.
+        device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+        for user_id, device_id in user_and_device_ids:
+            if user_id in user_ids_in_cache and user_id not in user_ids:
+                device = await self._get_cached_user_device(user_id, device_id)
+                device_specific_results.setdefault(user_id, {})[device_id] = device
+        results.update(device_specific_results)
 
         set_tag("in_cache", str(results))
         set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -798,7 +804,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return db_to_json(content)
 
     @cached()
-    async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+    async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
         devices = await self.db_pool.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
@@ -1307,7 +1313,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 )
             """
             count = 0
-            for (destination, user_id, stream_id, device_id) in rows:
+            for destination, user_id, stream_id, device_id in rows:
                 txn.execute(
                     delete_sql, (destination, user_id, stream_id, stream_id, device_id)
                 )
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 5903fdaf00..44aa181174 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
 
 import attr
 
@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
         )
 
     @cached(max_entries=5000)
-    async def get_aliases_for_room(self, room_id: str) -> List[str]:
+    async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
         return await self.db_pool.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 6240f9a75e..9f8d2e4bea 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             raise StoreError(404, "No backup with that version exists")
 
         values = []
-        for (room_id, session_id, room_key) in room_keys:
+        for room_id, session_id, room_key in room_keys:
             values.append(
                 (
                     user_id,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c4ac6c33ba..a3b6c8ae8e 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -20,7 +20,9 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
     cast,
@@ -242,9 +244,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         set_tag("include_all_devices", include_all_devices)
         set_tag("include_deleted_devices", include_deleted_devices)
 
-        result = await self.db_pool.runInteraction(
-            "get_e2e_device_keys",
-            self._get_e2e_device_keys_txn,
+        result = await self._get_e2e_device_keys(
             query_list,
             include_all_devices,
             include_deleted_devices,
@@ -260,13 +260,13 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         for batch in batch_iter(signature_query, 50):
             cross_sigs_result = await self.db_pool.runInteraction(
-                "get_e2e_cross_signing_signatures",
+                "get_e2e_cross_signing_signatures_for_devices",
                 self._get_e2e_cross_signing_signatures_for_devices_txn,
                 batch,
             )
 
             # add each cross-signing signature to the correct device in the result dict.
-            for (user_id, key_id, device_id, signature) in cross_sigs_result:
+            for user_id, key_id, device_id, signature in cross_sigs_result:
                 target_device_result = result[user_id][device_id]
                 # We've only looked up cross-signatures for non-deleted devices with key
                 # data.
@@ -283,9 +283,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         log_kv(result)
         return result
 
-    def _get_e2e_device_keys_txn(
+    async def _get_e2e_device_keys(
         self,
-        txn: LoggingTransaction,
         query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
@@ -309,7 +308,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         # devices.
         user_list = []
         user_device_list = []
-        for (user_id, device_id) in query_list:
+        for user_id, device_id in query_list:
             if device_id is None:
                 user_list.append(user_id)
             else:
@@ -317,7 +316,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         if user_list:
             user_id_in_list_clause, user_args = make_in_list_sql_clause(
-                txn.database_engine, "user_id", user_list
+                self.database_engine, "user_id", user_list
             )
             query_clauses.append(user_id_in_list_clause)
             query_params_list.append(user_args)
@@ -330,13 +329,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     user_device_id_in_list_clause,
                     user_device_args,
                 ) = make_tuple_in_list_sql_clause(
-                    txn.database_engine, ("user_id", "device_id"), user_device_batch
+                    self.database_engine, ("user_id", "device_id"), user_device_batch
                 )
                 query_clauses.append(user_device_id_in_list_clause)
                 query_params_list.append(user_device_args)
 
         result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
-        for query_clause, query_params in zip(query_clauses, query_params_list):
+
+        def get_e2e_device_keys_txn(
+            txn: LoggingTransaction, query_clause: str, query_params: list
+        ) -> None:
             sql = (
                 "SELECT user_id, device_id, "
                 "    d.display_name, "
@@ -351,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
             txn.execute(sql, query_params)
 
-            for (user_id, device_id, display_name, key_json) in txn:
+            for user_id, device_id, display_name, key_json in txn:
                 assert device_id is not None
                 if include_deleted_devices:
                     deleted_devices.remove((user_id, device_id))
@@ -359,6 +361,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     display_name, db_to_json(key_json) if key_json else None
                 )
 
+        for query_clause, query_params in zip(query_clauses, query_params_list):
+            await self.db_pool.runInteraction(
+                "_get_e2e_device_keys",
+                get_e2e_device_keys_txn,
+                query_clause,
+                query_params,
+            )
+
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
                 if device_id is None:
@@ -380,7 +390,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         signature_query_clauses = []
         signature_query_params = []
 
-        for (user_id, device_id) in device_query:
+        for user_id, device_id in device_query:
             signature_query_clauses.append(
                 "target_user_id = ? AND target_device_id = ? AND user_id = ?"
             )
@@ -691,7 +701,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cached(max_entries=10000)
     async def get_e2e_unused_fallback_key_types(
         self, user_id: str, device_id: str
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """Returns the fallback key types that have an unused key.
 
         Args:
@@ -731,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         return user_keys.get(key_type)
 
     @cached(num_args=1)
-    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
+    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
         """Dummy function.  Only used to make a cache for
         _get_bare_e2e_cross_signing_keys_bulk.
         """
@@ -744,7 +754,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
@@ -765,7 +775,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         )
 
         # The `Optional` comes from the `@cachedList` decorator.
-        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+        return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
 
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
@@ -924,7 +934,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cancellable
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -940,11 +950,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
 
         if from_user_id:
-            result = await self.db_pool.runInteraction(
-                "get_e2e_cross_signing_signatures",
-                self._get_e2e_cross_signing_signatures_txn,
-                result,
-                from_user_id,
+            result = cast(
+                Dict[str, Optional[Mapping[str, JsonDict]]],
+                await self.db_pool.runInteraction(
+                    "get_e2e_cross_signing_signatures",
+                    self._get_e2e_cross_signing_signatures_txn,
+                    result,
+                    from_user_id,
+                ),
             )
 
         return result
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index bbee02ab18..ff3edeb716 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     cast,
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
+    async def get_max_depth_of(
+        self, event_ids: Collection[str]
+    ) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the max depth from a set of event IDs
 
         Args:
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+    async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
         return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @cancellable
     async def get_forward_extremities_for_room_at_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @cached(max_entries=5000, num_args=2)
     async def _get_forward_extremeties_for_room(
         self, room_id: str, stream_ordering: int
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -1609,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         latest_events: List[str],
         limit: int,
     ) -> List[str]:
-
         seen_events = set(earliest_events)
         front = set(latest_events) - seen_events
         event_results: List[str] = []
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a0c370fde..eeccf5db24 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -203,11 +203,18 @@ class RoomNotifCounts:
     # Map of thread ID to the notification counts.
     threads: Dict[str, NotifCounts]
 
+    @staticmethod
+    def empty() -> "RoomNotifCounts":
+        return _EMPTY_ROOM_NOTIF_COUNTS
+
     def __len__(self) -> int:
         # To properly account for the amount of space in any caches.
         return len(self.threads) + 1
 
 
+_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {})
+
+
 def _serialize_action(
     actions: Collection[Union[Mapping, str]], is_highlight: bool
 ) -> str:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1536937b67..a8a4ed4436 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,7 +16,6 @@
 import itertools
 import logging
 from collections import OrderedDict
-from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -26,7 +25,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
 )
@@ -36,7 +34,7 @@ from prometheus_client import Counter
 
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import PartialStateConflictError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
@@ -52,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 from synapse.util.stringutils import non_null_str_or_none
@@ -72,24 +70,6 @@ event_counter = Counter(
 )
 
 
-class PartialStateConflictError(SynapseError):
-    """An internal error raised when attempting to persist an event with partial state
-    after the room containing the event has been un-partial stated.
-
-    This error should be handled by recomputing the event context and trying again.
-
-    This error has an HTTP status code so that it can be transported over replication.
-    It should not be exposed to clients.
-    """
-
-    def __init__(self) -> None:
-        super().__init__(
-            HTTPStatus.CONFLICT,
-            msg="Cannot persist partial state event in un-partial stated room",
-            errcode=Codes.UNKNOWN,
-        )
-
-
 @attr.s(slots=True, auto_attribs=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -306,7 +286,7 @@ class PersistEventsStore:
 
         # The set of event_ids to return. This includes all soft-failed events
         # and their prev events.
-        existing_prevs = set()
+        existing_prevs: Set[str] = set()
 
         def _get_prevs_before_rejected_txn(
             txn: LoggingTransaction, batch: Collection[str]
@@ -489,7 +469,6 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events: List[EventBase],
     ) -> None:
-
         # We only care about state events, so this if there are no state events.
         if not any(e.is_state() for e in events):
             return
@@ -571,7 +550,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, Sequence[str]],
+        event_to_auth_chain: Dict[str, StrCollection],
     ) -> None:
         """Calculate the chain cover index for the given events.
 
@@ -865,7 +844,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, Sequence[str]],
+        event_to_auth_chain: Dict[str, StrCollection],
         events_to_calc_chain_id_for: Set[str],
         chain_map: Dict[str, Tuple[int, int]],
     ) -> Dict[str, Tuple[int, int]]:
@@ -2045,10 +2024,6 @@ class PersistEventsStore:
         self.store._invalidate_cache_and_stream(
             txn, self.store.get_relations_for_event, (redacted_relates_to,)
         )
-        if rel_type == RelationTypes.ANNOTATION:
-            self.store._invalidate_cache_and_stream(
-                txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
-            )
         if rel_type == RelationTypes.REFERENCE:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_references_for_event, (redacted_relates_to,)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index b9d3c36d60..daef3685b0 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
 
 import attr
 
@@ -29,7 +29,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             nbrows = 0
             last_row_event_id = ""
-            for (event_id, event_json_raw) in results:
+            for event_id, event_json_raw in results:
                 try:
                     event_json = db_to_json(event_json_raw)
 
@@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self.event_chain_id_gen,  # type: ignore[attr-defined]
             event_to_room_id,
             event_to_types,
-            cast(Dict[str, Sequence[str]], event_to_auth_chain),
+            cast(Dict[str, StrCollection], event_to_auth_chain),
         )
 
         return _CalculateChainCover(
@@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             results = list(txn)
             # (event_id, parent_id, rel_type) for each relation
             relations_to_insert: List[Tuple[str, str, str]] = []
-            for (event_id, event_json_raw) in results:
+            for event_id, event_json_raw in results:
                 try:
                     event_json = db_to_json(event_json_raw)
                 except Exception as e:
@@ -1220,9 +1220,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                         txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
                     )
                     self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
-                        txn, self.get_aggregation_groups_for_event, cache_tuple  # type: ignore[attr-defined]
-                    )
-                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                         txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
                     )
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a9259fe446..0d90407ebd 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -72,7 +72,6 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -187,8 +186,8 @@ class EventsWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self._stream_id_gen: AbstractStreamIdTracker
-        self._backfill_id_gen: AbstractStreamIdTracker
+        self._stream_id_gen: AbstractStreamIdGenerator
+        self._backfill_id_gen: AbstractStreamIdGenerator
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -1493,7 +1492,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             txn.execute(redactions_sql + clause, args)
 
-            for (redacter, redacted) in txn:
+            for redacter, redacted in txn:
                 d = event_dict.get(redacted)
                 if d:
                     d.redactions.append(redacter)
@@ -1779,7 +1778,7 @@ class EventsWorkerStore(SQLBaseStore):
             txn: LoggingTransaction,
         ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
-                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                "SELECT out.event_stream_ordering, e.event_id, e.room_id, e.type,"
                 " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
                 " e.outlier"
                 " FROM events AS e"
@@ -1791,10 +1790,10 @@ class EventsWorkerStore(SQLBaseStore):
                 " LEFT JOIN event_relations USING (event_id)"
                 " LEFT JOIN room_memberships USING (event_id)"
                 " LEFT JOIN rejections USING (event_id)"
-                " WHERE ? < event_stream_ordering"
-                " AND event_stream_ordering <= ?"
+                " WHERE ? < out.event_stream_ordering"
+                " AND out.event_stream_ordering <= ?"
                 " AND out.instance_name = ?"
-                " ORDER BY event_stream_ordering ASC"
+                " ORDER BY out.event_stream_ordering ASC"
             )
 
             txn.execute(sql, (last_id, current_id, instance_name))
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 12f3b601f1..8e57c8e5a0 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast
 
 from canonicaljson import encode_canonical_json
 
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict
@@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-
-class FilteringStore(FilteringWorkerStore):
     async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
         def_json = encode_canonical_json(user_filter)
 
@@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore):
 
             return filter_id
 
-        return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+        attempts = 0
+        while True:
+            # Try a few times.
+            # This is technically needed if a user tries to create two filters at once,
+            # leading to two concurrent transactions.
+            # The failure case would be:
+            # - SELECT filter_id ... filter_json = ? → both transactions return no rows
+            # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
+            # - INSERT INTO ... → both transactions insert filter_id = 6
+            # One of the transactions will commit. The other will get a unique key
+            # constraint violation error (IntegrityError). This is not the same as a
+            # serialisability violation, which would be automatically retried by
+            # `runInteraction`.
+            try:
+                return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+            except self.db_pool.engine.module.IntegrityError:
+                attempts += 1
+
+                if attempts >= 5:
+                    raise StoreError(500, "Couldn't generate a filter ID.")
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index b202c5eb87..fa8be214ce 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def get_local_media_by_user_paginate_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[Dict[str, Any]], int]:
-
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
 
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index db9a24db5e..4b1061e6d7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
         return await self.db_pool.runInteraction("count_users", _count_users)
 
     @cached(num_args=0)
-    async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
+    async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
         """Generates current count of monthly active users broken down by service.
         A service is typically an appservice but also includes native matrix users.
         Since the `monthly_active_users` table is populated from the `user_ips` table
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 9213ce0b5a..7a7c0d9c75 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -325,6 +325,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         # We then run the same purge a second time without this isolation level to
         # purge any of those rows which were added during the first.
 
+        logger.info("[purge] Starting initial main purge of [1/2]")
         state_groups_to_delete = await self.db_pool.runInteraction(
             "purge_room",
             self._purge_room_txn,
@@ -332,6 +333,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             isolation_level=IsolationLevel.READ_COMMITTED,
         )
 
+        logger.info("[purge] Starting secondary main purge of [2/2]")
         state_groups_to_delete.extend(
             await self.db_pool.runInteraction(
                 "purge_room",
@@ -339,6 +341,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 room_id=room_id,
             ),
         )
+        logger.info("[purge] Done with main purge")
 
         return state_groups_to_delete
 
@@ -376,7 +379,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         )
         referenced_chain_id_tuples = list(txn)
 
-        logger.info("[purge] removing events from event_auth_chain_links")
+        logger.info("[purge] removing from event_auth_chain_links")
         txn.executemany(
             """
             DELETE FROM event_auth_chain_links WHERE
@@ -399,7 +402,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "rejections",
             "state_events",
         ):
-            logger.info("[purge] removing %s from %s", room_id, table)
+            logger.info("[purge] removing from %s", table)
 
             txn.execute(
                 """
@@ -420,12 +423,14 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_push_actions",
             "event_search",
             "event_failed_pull_attempts",
+            # Note: the partial state tables have foreign keys between each other, and to
+            # `events` and `rooms`. We need to delete from them in the right order.
             "partial_state_events",
+            "partial_state_rooms_servers",
+            "partial_state_rooms",
             "events",
             "federation_inbound_events_staging",
             "local_current_membership",
-            "partial_state_rooms_servers",
-            "partial_state_rooms",
             "receipts_graph",
             "receipts_linearized",
             "room_aliases",
@@ -452,7 +457,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             # happy
             "rooms",
         ):
-            logger.info("[purge] removing %s from %s", room_id, table)
+            logger.info("[purge] removing from %s", table)
             txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
 
         # Other tables we do NOT need to clear out:
@@ -484,6 +489,4 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #   that already exist.
         self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
 
-        logger.info("[purge] done")
-
         return state_groups
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9b2bbe060d..9f862f00c1 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -46,7 +46,6 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     IdGenerator,
     StreamIdGenerator,
 )
@@ -118,7 +117,7 @@ class PushRulesWorkerStore(
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._push_rules_stream_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "push_rules_stream",
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df53e726e6..9a24f7a655 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -36,7 +36,6 @@ from synapse.storage.database import (
 )
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     StreamIdGenerator,
 )
 from synapse.types import JsonDict
@@ -60,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._pushers_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "pushers",
@@ -344,7 +343,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_user = progress.get("last_user", "")
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT name FROM users
                 WHERE deactivated = ? and name > ?
@@ -392,7 +390,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_pusher = progress.get("last_pusher", 0)
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT p.id, access_token FROM pushers AS p
                 LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
@@ -449,7 +446,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_pusher = progress.get("last_pusher", 0)
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT p.id, p.user_name, p.app_id, p.pushkey
                 FROM pushers AS p
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 29972d5204..074942b167 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -21,7 +21,9 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     cast,
 )
@@ -37,7 +39,7 @@ from synapse.storage.database import (
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.util.id_generators import (
-    AbstractStreamIdTracker,
+    AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -63,7 +65,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._receipts_id_gen: AbstractStreamIdTracker
+        self._receipts_id_gen: AbstractStreamIdGenerator
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_receipts = (
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     async def get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[dict]:
+    ) -> Sequence[JsonDict]:
         """Get receipts for a single room for sending to clients.
 
         Args:
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[JsonDict]:
+    ) -> Sequence[JsonDict]:
         """See get_linearized_receipts_for_room"""
 
         def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def _get_linearized_receipts_for_rooms(
         self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, List[JsonDict]]:
+    ) -> Dict[str, Sequence[JsonDict]]:
         if not room_ids:
             return {}
 
@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def get_linearized_receipts_for_all_rooms(
         self, to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonDict]:
         """Get receipts for all rooms between two stream_ids, up
         to a limit of the latest 100 read receipts.
 
@@ -766,7 +768,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
             )
 
-        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+        async with self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self._insert_linearized_receipt_txn,
@@ -885,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
         def _populate_receipt_event_stream_ordering_txn(
             txn: LoggingTransaction,
         ) -> bool:
-
             if "max_stream_id" in progress:
                 max_stream_id = progress["max_stream_id"]
             else:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 31f0f2bd3d..717237e024 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
 
 import attr
 
@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
 
     @cached()
-    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+    async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
         """Deprecated: use get_userinfo_by_id instead"""
 
         def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
@@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="user_delete_threepid",
         )
 
-    async def user_delete_threepids(self, user_id: str) -> None:
-        """Delete all threepid this user has bound
-
-        Args:
-             user_id: The user id to delete all threepids of
-
-        """
-        await self.db_pool.simple_delete(
-            "user_threepids",
-            keyvalues={"user_id": user_id},
-            desc="user_delete_threepids",
-        )
-
     async def add_user_bound_threepid(
         self, user_id: str, medium: str, address: str, id_server: str
     ) -> None:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0018d6f7ab..bc3a83919c 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -22,6 +22,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
-    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+    ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
@@ -397,141 +398,6 @@ class RelationsWorkerStore(SQLBaseStore):
         return result is not None
 
     @cached()
-    async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
-    )
-    async def get_aggregation_groups_for_events(
-        self, event_ids: Collection[str]
-    ) -> Mapping[str, Optional[List[JsonDict]]]:
-        """Get a list of annotations on the given events, grouped by event type and
-        aggregation key, sorted by count.
-
-        This is used e.g. to get the what and how many reactions have happend
-        on an event.
-
-        Args:
-            event_ids: Fetch events that relate to these event IDs.
-
-        Returns:
-            A map of event IDs to a list of groups of annotations that match.
-            Each entry is a dict with `type`, `key` and `count` fields.
-        """
-        # The number of entries to return per event ID.
-        limit = 5
-
-        clause, args = make_in_list_sql_clause(
-            self.database_engine, "relates_to_id", event_ids
-        )
-        args.append(RelationTypes.ANNOTATION)
-
-        sql = f"""
-            SELECT
-                relates_to_id,
-                annotation.type,
-                aggregation_key,
-                COUNT(DISTINCT annotation.sender)
-            FROM events AS annotation
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS parent ON
-                parent.event_id = relates_to_id
-                AND parent.room_id = annotation.room_id
-            WHERE
-                {clause}
-                AND relation_type = ?
-            GROUP BY relates_to_id, annotation.type, aggregation_key
-            ORDER BY relates_to_id, COUNT(*) DESC
-        """
-
-        def _get_aggregation_groups_for_events_txn(
-            txn: LoggingTransaction,
-        ) -> Mapping[str, List[JsonDict]]:
-            txn.execute(sql, args)
-
-            result: Dict[str, List[JsonDict]] = {}
-            for event_id, type, key, count in cast(
-                List[Tuple[str, str, str, int]], txn
-            ):
-                event_results = result.setdefault(event_id, [])
-
-                # Limit the number of results per event ID.
-                if len(event_results) == limit:
-                    continue
-
-                event_results.append({"type": type, "key": key, "count": count})
-
-            return result
-
-        return await self.db_pool.runInteraction(
-            "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
-        )
-
-    async def get_aggregation_groups_for_users(
-        self, event_ids: Collection[str], users: FrozenSet[str]
-    ) -> Dict[str, Dict[Tuple[str, str], int]]:
-        """Fetch the partial aggregations for an event for specific users.
-
-        This is used, in conjunction with get_aggregation_groups_for_event, to
-        remove information from the results for ignored users.
-
-        Args:
-            event_ids: Fetch events that relate to these event IDs.
-            users: The users to fetch information for.
-
-        Returns:
-            A map of event ID to a map of (event type, aggregation key) to a
-            count of users.
-        """
-
-        if not users:
-            return {}
-
-        events_sql, args = make_in_list_sql_clause(
-            self.database_engine, "relates_to_id", event_ids
-        )
-
-        users_sql, users_args = make_in_list_sql_clause(
-            self.database_engine, "annotation.sender", users
-        )
-        args.extend(users_args)
-        args.append(RelationTypes.ANNOTATION)
-
-        sql = f"""
-            SELECT
-                relates_to_id,
-                annotation.type,
-                aggregation_key,
-                COUNT(DISTINCT annotation.sender)
-            FROM events AS annotation
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS parent ON
-                parent.event_id = relates_to_id
-                AND parent.room_id = annotation.room_id
-            WHERE {events_sql} AND {users_sql} AND relation_type = ?
-            GROUP BY relates_to_id, annotation.type, aggregation_key
-            ORDER BY relates_to_id, COUNT(*) DESC
-        """
-
-        def _get_aggregation_groups_for_users_txn(
-            txn: LoggingTransaction,
-        ) -> Dict[str, Dict[Tuple[str, str], int]]:
-            txn.execute(sql, args)
-
-            result: Dict[str, Dict[Tuple[str, str], int]] = {}
-            for event_id, type, key, count in cast(
-                List[Tuple[str, str, str, int]], txn
-            ):
-                result.setdefault(event_id, {})[(type, key)] = count
-
-            return result
-
-        return await self.db_pool.runInteraction(
-            "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
-        )
-
-    @cached()
     async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
         raise NotImplementedError()
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 644bbb8878..3825bd6079 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1417,6 +1417,204 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             get_un_partial_stated_rooms_from_stream_txn,
         )
 
+    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
+        """Retrieve an event report
+
+        Args:
+            report_id: ID of reported event in database
+        Returns:
+            JSON dict of information from an event report or None if the
+            report does not exist.
+        """
+
+        def _get_event_report_txn(
+            txn: LoggingTransaction, report_id: int
+        ) -> Optional[Dict[str, Any]]:
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.content,
+                    events.sender,
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
+                WHERE er.id = ?
+            """
+
+            txn.execute(sql, [report_id])
+            row = txn.fetchone()
+
+            if not row:
+                return None
+
+            event_report = {
+                "id": row[0],
+                "received_ts": row[1],
+                "room_id": row[2],
+                "event_id": row[3],
+                "user_id": row[4],
+                "score": db_to_json(row[5]).get("score"),
+                "reason": db_to_json(row[5]).get("reason"),
+                "sender": row[6],
+                "canonical_alias": row[7],
+                "name": row[8],
+                "event_json": db_to_json(row[9]),
+            }
+
+            return event_report
+
+        return await self.db_pool.runInteraction(
+            "get_event_report", _get_event_report_txn, report_id
+        )
+
+    async def get_event_reports_paginate(
+        self,
+        start: int,
+        limit: int,
+        direction: Direction = Direction.BACKWARDS,
+        user_id: Optional[str] = None,
+        room_id: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Retrieve a paginated list of event reports
+
+        Args:
+            start: event offset to begin the query from
+            limit: number of rows to retrieve
+            direction: Whether to fetch the most recent first (backwards) or the
+                oldest first (forwards)
+            user_id: search for user_id. Ignored if user_id is None
+            room_id: search for room_id. Ignored if room_id is None
+        Returns:
+            Tuple of:
+                json list of event reports
+                total number of event reports matching the filter criteria
+        """
+
+        def _get_event_reports_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
+            filters = []
+            args: List[object] = []
+
+            if user_id:
+                filters.append("er.user_id LIKE ?")
+                args.extend(["%" + user_id + "%"])
+            if room_id:
+                filters.append("er.room_id LIKE ?")
+                args.extend(["%" + room_id + "%"])
+
+            if direction == Direction.BACKWARDS:
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            # We join on room_stats_state despite not using any columns from it
+            # because the join can influence the number of rows returned;
+            # e.g. a room that doesn't have state, maybe because it was deleted.
+            # The query returning the total count should be consistent with
+            # the query returning the results.
+            sql = """
+                SELECT COUNT(*) as total_event_reports
+                FROM event_reports AS er
+                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
+                {}
+                """.format(
+                where_clause
+            )
+            txn.execute(sql, args)
+            count = cast(Tuple[int], txn.fetchone())[0]
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.content,
+                    events.sender,
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name
+                FROM event_reports AS er
+                LEFT JOIN events
+                    ON events.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
+                {where_clause}
+                ORDER BY er.received_ts {order}
+                LIMIT ?
+                OFFSET ?
+            """.format(
+                where_clause=where_clause,
+                order=order,
+            )
+
+            args += [limit, start]
+            txn.execute(sql, args)
+
+            event_reports = []
+            for row in txn:
+                try:
+                    s = db_to_json(row[5]).get("score")
+                    r = db_to_json(row[5]).get("reason")
+                except Exception:
+                    logger.error("Unable to parse json from event_reports: %s", row[0])
+                    continue
+                event_reports.append(
+                    {
+                        "id": row[0],
+                        "received_ts": row[1],
+                        "room_id": row[2],
+                        "event_id": row[3],
+                        "user_id": row[4],
+                        "score": s,
+                        "reason": r,
+                        "sender": row[6],
+                        "canonical_alias": row[7],
+                        "name": row[8],
+                    }
+                )
+
+            return event_reports, count
+
+        return await self.db_pool.runInteraction(
+            "get_event_reports_paginate", _get_event_reports_paginate_txn
+        )
+
+    async def delete_event_report(self, report_id: int) -> bool:
+        """Remove an event report from database.
+
+        Args:
+            report_id: Report to delete
+
+        Returns:
+            Whether the report was successfully deleted or not.
+        """
+        try:
+            await self.db_pool.simple_delete_one(
+                table="event_reports",
+                keyvalues={"id": report_id},
+                desc="delete_event_report",
+            )
+        except StoreError:
+            # Deletion failed because report does not exist
+            return False
+
+        return True
+
 
 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -2139,7 +2337,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         reason: Optional[str],
         content: JsonDict,
         received_ts: int,
-    ) -> None:
+    ) -> int:
+        """Add an event report
+
+        Args:
+            room_id: Room that contains the reported event.
+            event_id: The reported event.
+            user_id: User who reports the event.
+            reason: Description that the user specifies.
+            content: Report request body (score and reason).
+            received_ts: Time when the user submitted the report (milliseconds).
+        Returns:
+            Id of the event report.
+        """
         next_id = self._event_reports_id_gen.get_next()
         await self.db_pool.simple_insert(
             table="event_reports",
@@ -2154,184 +2364,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             },
             desc="add_event_report",
         )
-
-    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
-        """Retrieve an event report
-
-        Args:
-            report_id: ID of reported event in database
-        Returns:
-            JSON dict of information from an event report or None if the
-            report does not exist.
-        """
-
-        def _get_event_report_txn(
-            txn: LoggingTransaction, report_id: int
-        ) -> Optional[Dict[str, Any]]:
-
-            sql = """
-                SELECT
-                    er.id,
-                    er.received_ts,
-                    er.room_id,
-                    er.event_id,
-                    er.user_id,
-                    er.content,
-                    events.sender,
-                    room_stats_state.canonical_alias,
-                    room_stats_state.name,
-                    event_json.json AS event_json
-                FROM event_reports AS er
-                LEFT JOIN events
-                    ON events.event_id = er.event_id
-                JOIN event_json
-                    ON event_json.event_id = er.event_id
-                JOIN room_stats_state
-                    ON room_stats_state.room_id = er.room_id
-                WHERE er.id = ?
-            """
-
-            txn.execute(sql, [report_id])
-            row = txn.fetchone()
-
-            if not row:
-                return None
-
-            event_report = {
-                "id": row[0],
-                "received_ts": row[1],
-                "room_id": row[2],
-                "event_id": row[3],
-                "user_id": row[4],
-                "score": db_to_json(row[5]).get("score"),
-                "reason": db_to_json(row[5]).get("reason"),
-                "sender": row[6],
-                "canonical_alias": row[7],
-                "name": row[8],
-                "event_json": db_to_json(row[9]),
-            }
-
-            return event_report
-
-        return await self.db_pool.runInteraction(
-            "get_event_report", _get_event_report_txn, report_id
-        )
-
-    async def get_event_reports_paginate(
-        self,
-        start: int,
-        limit: int,
-        direction: Direction = Direction.BACKWARDS,
-        user_id: Optional[str] = None,
-        room_id: Optional[str] = None,
-    ) -> Tuple[List[Dict[str, Any]], int]:
-        """Retrieve a paginated list of event reports
-
-        Args:
-            start: event offset to begin the query from
-            limit: number of rows to retrieve
-            direction: Whether to fetch the most recent first (backwards) or the
-                oldest first (forwards)
-            user_id: search for user_id. Ignored if user_id is None
-            room_id: search for room_id. Ignored if room_id is None
-        Returns:
-            Tuple of:
-                json list of event reports
-                total number of event reports matching the filter criteria
-        """
-
-        def _get_event_reports_paginate_txn(
-            txn: LoggingTransaction,
-        ) -> Tuple[List[Dict[str, Any]], int]:
-            filters = []
-            args: List[object] = []
-
-            if user_id:
-                filters.append("er.user_id LIKE ?")
-                args.extend(["%" + user_id + "%"])
-            if room_id:
-                filters.append("er.room_id LIKE ?")
-                args.extend(["%" + room_id + "%"])
-
-            if direction == Direction.BACKWARDS:
-                order = "DESC"
-            else:
-                order = "ASC"
-
-            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
-
-            # We join on room_stats_state despite not using any columns from it
-            # because the join can influence the number of rows returned;
-            # e.g. a room that doesn't have state, maybe because it was deleted.
-            # The query returning the total count should be consistent with
-            # the query returning the results.
-            sql = """
-                SELECT COUNT(*) as total_event_reports
-                FROM event_reports AS er
-                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
-                {}
-                """.format(
-                where_clause
-            )
-            txn.execute(sql, args)
-            count = cast(Tuple[int], txn.fetchone())[0]
-
-            sql = """
-                SELECT
-                    er.id,
-                    er.received_ts,
-                    er.room_id,
-                    er.event_id,
-                    er.user_id,
-                    er.content,
-                    events.sender,
-                    room_stats_state.canonical_alias,
-                    room_stats_state.name
-                FROM event_reports AS er
-                LEFT JOIN events
-                    ON events.event_id = er.event_id
-                JOIN room_stats_state
-                    ON room_stats_state.room_id = er.room_id
-                {where_clause}
-                ORDER BY er.received_ts {order}
-                LIMIT ?
-                OFFSET ?
-            """.format(
-                where_clause=where_clause,
-                order=order,
-            )
-
-            args += [limit, start]
-            txn.execute(sql, args)
-
-            event_reports = []
-            for row in txn:
-                try:
-                    s = db_to_json(row[5]).get("score")
-                    r = db_to_json(row[5]).get("reason")
-                except Exception:
-                    logger.error("Unable to parse json from event_reports: %s", row[0])
-                    continue
-                event_reports.append(
-                    {
-                        "id": row[0],
-                        "received_ts": row[1],
-                        "room_id": row[2],
-                        "event_id": row[3],
-                        "user_id": row[4],
-                        "score": s,
-                        "reason": r,
-                        "sender": row[6],
-                        "canonical_alias": row[7],
-                        "name": row[8],
-                    }
-                )
-
-            return event_reports, count
-
-        return await self.db_pool.runInteraction(
-            "get_event_reports_paginate", _get_event_reports_paginate_txn
-        )
+        return next_id
 
     async def block_room(self, room_id: str, user_id: str) -> None:
         """Marks the room as blocked.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ea6a5e2f34..694a5b802c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -24,6 +24,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return self._known_servers_count
 
     @cached(max_entries=100000, iterable=True)
-    async def get_users_in_room(self, room_id: str) -> List[str]:
+    async def get_users_in_room(self, room_id: str) -> Sequence[str]:
         """Returns a list of users in the room.
 
         Will return inaccurate results for rooms with partial state, since the state for
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached()
-    def get_user_in_room_with_profile(
-        self, room_id: str, user_id: str
-    ) -> Dict[str, ProfileInfo]:
+    def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
         raise NotImplementedError()
 
     @cachedList(
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room_with_profiles(
         self, room_id: str
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Mapping[str, ProfileInfo]:
         """Get a mapping from user ID to profile information for all users in a given room.
 
         The profile information comes directly from this room's `m.room.member`
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached(max_entries=100000)
-    async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
+    async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
         """Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached()
     async def get_invited_rooms_for_local_user(
         self, user_id: str
-    ) -> List[RoomsForUser]:
+    ) -> Sequence[RoomsForUser]:
         """Get all the rooms the *local* user is invited to.
 
         Args:
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cached(iterable=True)
-    async def get_local_users_in_room(self, room_id: str) -> List[str]:
+    async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
         """
         Retrieves a list of the current roommembers who are local to the server.
         """
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns the set of users who share a room with `user_id`"""
         room_ids = await self.get_rooms_for_user(user_id)
 
-        user_who_share_room = set()
+        user_who_share_room: Set[str] = set()
         for room_id in room_ids:
             user_ids = await self.get_users_in_room(room_id)
             user_who_share_room.update(user_ids)
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
         """Get current hosts in room based on current state."""
 
         # First we check if we already have `get_users_in_room` in the cache, as
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3fe433f66c..a7aae661d8 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore):
 
 
 class SearchBackgroundUpdateStore(SearchWorkerStore):
-
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
@@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore):
             """
             count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
-
             # We use CROSS JOIN here to ensure we use the right indexes.
             # https://sqlite.org/optoverview.html#crossjoin
             #
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 05da15074a..5dcb1fc0b5 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Collection, Dict, List, Tuple
+from typing import Collection, Dict, List, Mapping, Tuple
 
 from unpaddedbase64 import encode_base64
 
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
 
 class SignatureWorkerStore(EventsWorkerStore):
     @cached()
-    def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
+    def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
         # This is a dummy function to allow get_event_reference_hashes
         # to use its cache
         raise NotImplementedError()
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
     )
     async def get_event_reference_hashes(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Dict[str, bytes]]:
+    ) -> Mapping[str, Mapping[str, bytes]]:
         """Get all hashes for given events.
 
         Args:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ba325d390b..ebb2ae964f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
 
 class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
-
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7b7d0c3c9..d3393d8e49 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore):
         insert_cols = []
         qargs = []
 
-        for (key, val) in chain(
+        for key, val in chain(
             keyvalues.items(), absolutes.items(), additive_relatives.items()
         ):
             insert_cols.append(key)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 818c46182e..ac5fbf6b86 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000
 _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
+
 # Used as return values for pagination APIs
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _EventDictReturn:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index d5500cdd47..c149a9eacb 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, Iterable, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.tcp.streams import AccountDataStream
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
 
 class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
-    async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
+    async def get_tags_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
         """Get all the tags for a user.
 
 
@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
-    ) -> Dict[str, Dict[str, JsonDict]]:
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version
 
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6b33d809b6..6d72bd9f67 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         def get_destination_rooms_paginate_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[JsonDict], int]:
-
             if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 14ef5b040d..f16a509ac4 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,11 +14,12 @@
 
 import logging
 import re
+import unicodedata
 from typing import (
     TYPE_CHECKING,
-    Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     async def _populate_user_directory_createtables(
         self, progress: JsonDict, batch_size: int
     ) -> int:
-
         # Get all the rooms that we want to process.
         def _make_staging_area(txn: LoggingTransaction) -> None:
             sql = (
@@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 values={"display_name": display_name, "avatar_url": avatar_url},
             )
 
+            # The display name that goes into the database index.
+            index_display_name = display_name
+            if index_display_name is not None:
+                index_display_name = _filter_text_for_index(index_display_name)
+
             if isinstance(self.database_engine, PostgresEngine):
                 # We weight the localpart most highly, then display name and finally
                 # server name
@@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                         user_id,
                         get_localpart_from_id(user_id),
                         get_domain_from_id(user_id),
-                        display_name,
+                        index_display_name,
                     ),
                 )
             elif isinstance(self.database_engine, Sqlite3Engine):
-                value = "%s %s" % (user_id, display_name) if display_name else user_id
+                value = (
+                    "%s %s" % (user_id, index_display_name)
+                    if index_display_name
+                    else user_id
+                )
                 self.db_pool.simple_upsert_txn(
                     txn,
                     table="user_directory_search",
@@ -586,7 +595,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     @cached()
-    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
+    async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
         return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
@@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         return {"limited": limited, "results": results[0:limit]}
 
 
+def _filter_text_for_index(text: str) -> str:
+    """Transforms text before it is inserted into the user directory index, or searched
+    for in the user directory index.
+
+    Note that the user directory search table needs to be rebuilt whenever this function
+    changes.
+    """
+    # Lowercase the text, to make searches case-insensitive.
+    # This is necessary for both PostgreSQL and SQLite. PostgreSQL's
+    # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using
+    # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at
+    # all.
+    text = text.lower()
+
+    # Normalize the text. NFKC normalization has two effects:
+    #  1. It canonicalizes the text, ie. maps all visually identical strings to the same
+    #     string. For example, ["e", "◌́"] is mapped to ["é"].
+    #  2. It maps strings that are roughly equivalent to the same string.
+    #     For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to
+    #     ["i", "9"].
+    text = unicodedata.normalize("NFKC", text)
+
+    # Note that nothing is done to make searches accent-insensitive.
+    # That could be achieved by converting to NFKD form instead (with combining accents
+    # split out) and filtering out combining accents using `unicodedata.combining(c)`.
+    # The downside of this may be noisier search results, since search terms with
+    # explicit accents will match characters with no accents, or completely different
+    # accents.
+    #
+    # text = unicodedata.normalize("NFKD", text)
+    # text = "".join([c for c in text if not unicodedata.combining(c)])
+
+    return text
+
+
 def _parse_query_sqlite(search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
@@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str:
     We specifically add both a prefix and non prefix matching term so that
     exact matches get ranked higher.
     """
+    search_term = _filter_text_for_index(search_term)
 
     # Pull out the individual words, discarding any non-word characters.
     results = _parse_words(search_term)
@@ -918,11 +963,21 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
     We use this so that we can add prefix matching, which isn't something
     that is supported by default.
     """
-    results = _parse_words(search_term)
+    search_term = _filter_text_for_index(search_term)
+
+    escaped_words = []
+    for word in _parse_words(search_term):
+        # Postgres tsvector and tsquery quoting rules:
+        # words potentially containing punctuation should be quoted
+        # and then existing quotes and backslashes should be doubled
+        # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY
 
-    both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
-    exact = " & ".join("%s" % (result,) for result in results)
-    prefix = " & ".join("%s:*" % (result,) for result in results)
+        quoted_word = word.replace("'", "''").replace("\\", "\\\\")
+        escaped_words.append(f"'{quoted_word}'")
+
+    both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words)
+    exact = " & ".join("%s" % (word,) for word in escaped_words)
+    prefix = " & ".join("%s:*" % (word,) for word in escaped_words)
 
     return both, exact, prefix
 
@@ -944,6 +999,14 @@ def _parse_words(search_term: str) -> List[str]:
     if USE_ICU:
         return _parse_words_with_icu(search_term)
 
+    return _parse_words_with_regex(search_term)
+
+
+def _parse_words_with_regex(search_term: str) -> List[str]:
+    """
+    Break down search term into words, when we don't have ICU available.
+    See: `_parse_words`
+    """
     return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
 
 
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index d743282f13..097dea5182 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
 
 
 class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
-
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 1a7232b276..29ff64e876 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
 import attr
 
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -257,14 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         member_filter, non_member_filter = state_filter.get_member_split()
 
         # Now we look them up in the member and non-member caches
-        (
-            non_member_state,
-            incomplete_groups_nm,
-        ) = self._get_state_for_groups_using_cache(
+        non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache(
             groups, self._state_group_cache, state_filter=non_member_filter
         )
 
-        (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
+        member_state, incomplete_groups_m = self._get_state_for_groups_using_cache(
             groups, self._state_group_members_cache, state_filter=member_filter
         )
 
@@ -404,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 fetched_keys=non_member_types,
             )
 
+    async def store_state_deltas_for_batched(
+        self,
+        events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
+        room_id: str,
+        prev_group: int,
+    ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+        """Generate and store state deltas for a group of events and contexts created to be
+        batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
+
+        Args:
+            events_and_context: the events to generate and store a state groups for
+            and their associated contexts
+            room_id: the id of the room the events were created for
+            prev_group: the state group of the last event persisted before the batched events
+            were created
+        """
+
+        def insert_deltas_group_txn(
+            txn: LoggingTransaction,
+            events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
+            prev_group: int,
+        ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+            """Generate and store state groups for the provided events and contexts.
+
+            Requires that we have the state as a delta from the last persisted state group.
+
+            Returns:
+                A list of state groups
+            """
+            is_in_db = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="state_groups",
+                keyvalues={"id": prev_group},
+                retcol="id",
+                allow_none=True,
+            )
+            if not is_in_db:
+                raise Exception(
+                    "Trying to persist state with unpersisted prev_group: %r"
+                    % (prev_group,)
+                )
+
+            num_state_groups = sum(
+                1 for event, _ in events_and_context if event.is_state()
+            )
+
+            state_groups = self._state_group_seq_gen.get_next_mult_txn(
+                txn, num_state_groups
+            )
+
+            sg_before = prev_group
+            state_group_iter = iter(state_groups)
+            for event, context in events_and_context:
+                if not event.is_state():
+                    context.state_group_after_event = sg_before
+                    context.state_group_before_event = sg_before
+                    continue
+
+                sg_after = next(state_group_iter)
+                context.state_group_after_event = sg_after
+                context.state_group_before_event = sg_before
+                context.state_delta_due_to_event = {
+                    (event.type, event.state_key): event.event_id
+                }
+                sg_before = sg_after
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups",
+                keys=("id", "room_id", "event_id"),
+                values=[
+                    (context.state_group_after_event, room_id, event.event_id)
+                    for event, context in events_and_context
+                    if event.is_state()
+                ],
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_group_edges",
+                keys=("state_group", "prev_state_group"),
+                values=[
+                    (
+                        context.state_group_after_event,
+                        context.state_group_before_event,
+                    )
+                    for event, context in events_and_context
+                    if event.is_state()
+                ],
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (
+                        context.state_group_after_event,
+                        room_id,
+                        key[0],
+                        key[1],
+                        state_id,
+                    )
+                    for event, context in events_and_context
+                    if context.state_delta_due_to_event is not None
+                    for key, state_id in context.state_delta_due_to_event.items()
+                ],
+            )
+            return events_and_context
+
+        return await self.db_pool.runInteraction(
+            "store_state_deltas_for_batched.insert_deltas_group",
+            insert_deltas_group_txn,
+            events_and_context,
+            prev_group,
+        )
+
     async def store_state_group(
         self,
         event_id: str,
@@ -689,12 +805,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete: State groups to delete
         """
 
+        logger.info("[purge] Starting state purge")
         await self.db_pool.runInteraction(
             "purge_room_state",
             self._purge_room_state_txn,
             room_id,
             state_groups_to_delete,
         )
+        logger.info("[purge] Done with state purge")
 
     def _purge_room_state_txn(
         self,
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index a182e8a098..d1ccb7390a 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -25,7 +25,7 @@ try:
 except ImportError:
 
     class PostgresEngine(BaseDatabaseEngine):  # type: ignore[no-redef]
-        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:  # type: ignore[misc]
+        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
             raise RuntimeError(
                 f"Cannot create {cls.__name__} -- psycopg2 module is not installed"
             )
@@ -36,7 +36,7 @@ try:
 except ImportError:
 
     class Sqlite3Engine(BaseDatabaseEngine):  # type: ignore[no-redef]
-        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:  # type: ignore[misc]
+        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
             raise RuntimeError(
                 f"Cannot create {cls.__name__} -- sqlite3 module is not installed"
             )
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 3acdb39da7..2a1c6fa31b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -23,7 +23,7 @@ from typing_extensions import Counter as CounterType
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
 from synapse.storage.types import Cursor
 
@@ -108,9 +108,14 @@ def prepare_database(
         # so we start one before running anything. This ensures that any upgrades
         # are either applied completely, or not at all.
         #
-        # (psycopg2 automatically starts a transaction as soon as we run any statements
-        # at all, so this is redundant but harmless there.)
-        cur.execute("BEGIN TRANSACTION")
+        # psycopg2 does not automatically start transactions when in autocommit mode.
+        # While it is technically harmless to nest transactions in postgres, doing so
+        # results in a warning in Postgres' logs per query. And we'd rather like to
+        # avoid doing that.
+        if isinstance(database_engine, Sqlite3Engine) or (
+            isinstance(database_engine, PostgresEngine) and db_conn.autocommit
+        ):
+            cur.execute("BEGIN TRANSACTION")
 
         logger.info("%r: Checking existing schema version", databases)
         version_info = _get_or_create_schema_state(cur, database_engine)
@@ -558,7 +563,7 @@ def _apply_module_schemas(
     """
     # This is the old way for password_auth_provider modules to make changes
     # to the database. This should instead be done using the module API
-    for (mod, _config) in config.authproviders.password_providers:
+    for mod, _config in config.authproviders.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
             continue
         modname = ".".join((mod.__module__, mod.__name__))
@@ -586,7 +591,7 @@ def _apply_module_schema_files(
         (modname,),
     )
     applied_deltas = {d for d, in cur}
-    for (name, stream) in names_and_streams:
+    for name, stream in names_and_streams:
         if name in applied_deltas:
             continue
 
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 19dbf2da7f..d3103a6c7a 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SCHEMA_VERSION = 73  # remember to update the list below when updating
+SCHEMA_VERSION = 74  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema
 
 This should be incremented whenever the codebase changes its requirements on the
@@ -78,7 +78,7 @@ Changes in SCHEMA_VERSION = 72:
     - Unused column application_services_state.last_txn is dropped
     - Cache invalidation stream id sequence now begins at 2 to match code expectation.
 
-Changes in SCHEMA_VERSION = 73;
+Changes in SCHEMA_VERSION = 73:
     - thread_id column is added to event_push_actions, event_push_actions_staging
       event_push_summary, receipts_linearized, and receipts_graph.
     - Add table `event_failed_pull_attempts` to keep track when we fail to pull
@@ -86,6 +86,11 @@ Changes in SCHEMA_VERSION = 73;
     - Add indexes to various tables (`event_failed_pull_attempts`, `insertion_events`,
       `batch_events`) to make it easy to delete all associated rows when purging a room.
     - `inserted_ts` column is added to `event_push_actions_staging` table.
+
+Changes in SCHEMA_VERSION = 74:
+    - A query on `event_stream_ordering` column has now been disambiguated (i.e. the
+      codebase can handle the `current_state_events`, `local_current_memberships` and
+      `room_memberships` tables having an `event_stream_ordering` column).
 """
 
 
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 0031df1e06..56a0048539 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from types import TracebackType
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 from typing_extensions import Protocol
 
@@ -112,15 +123,35 @@ class DBAPI2Module(Protocol):
     #   extends from this hierarchy. See
     #     https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions
     #     https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
-    Warning: Type[Exception]
-    Error: Type[Exception]
+    #
+    # Note: rather than
+    #     x: T
+    # we write
+    #     @property
+    #     def x(self) -> T: ...
+    # which expresses that the protocol attribute `x` is read-only. The mypy docs
+    #     https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected
+    # explain why this is necessary for safety. TL;DR: we shouldn't be able to write
+    # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 .
+    @property
+    def Warning(self) -> Type[Exception]:
+        ...
+
+    @property
+    def Error(self) -> Type[Exception]:
+        ...
 
     # Errors are divided into `InterfaceError`s (something went wrong in the database
     # driver) and `DatabaseError`s (something went wrong in the database). These are
     # both subclasses of `Error`, but we can't currently express this in type
     # annotations due to https://github.com/python/mypy/issues/8397
-    InterfaceError: Type[Exception]
-    DatabaseError: Type[Exception]
+    @property
+    def InterfaceError(self) -> Type[Exception]:
+        ...
+
+    @property
+    def DatabaseError(self) -> Type[Exception]:
+        ...
 
     # Everything below is a subclass of `DatabaseError`.
 
@@ -128,7 +159,9 @@ class DBAPI2Module(Protocol):
     # - An integer was too big for its data type.
     # - An invalid date time was provided.
     # - A string contained a null code point.
-    DataError: Type[Exception]
+    @property
+    def DataError(self) -> Type[Exception]:
+        ...
 
     # Roughly: something went wrong in the database, but it's not within the application
     # programmer's control. Examples:
@@ -138,28 +171,45 @@ class DBAPI2Module(Protocol):
     # - A serialisation failure occurred.
     # - The database ran out of resources, such as storage, memory, connections, etc.
     # - The database encountered an error from the operating system.
-    OperationalError: Type[Exception]
+    @property
+    def OperationalError(self) -> Type[Exception]:
+        ...
 
     # Roughly: we've given the database data which breaks a rule we asked it to enforce.
     # Examples:
     # - Stop, criminal scum! You violated the foreign key constraint
     # - Also check constraints, non-null constraints, etc.
-    IntegrityError: Type[Exception]
+    @property
+    def IntegrityError(self) -> Type[Exception]:
+        ...
 
     # Roughly: something went wrong within the database server itself.
-    InternalError: Type[Exception]
+    @property
+    def InternalError(self) -> Type[Exception]:
+        ...
 
     # Roughly: the application did something silly that needs to be fixed. Examples:
     # - We don't have permissions to do something.
     # - We tried to create a table with duplicate column names.
     # - We tried to use a reserved name.
     # - We referred to a column that doesn't exist.
-    ProgrammingError: Type[Exception]
+    @property
+    def ProgrammingError(self) -> Type[Exception]:
+        ...
 
     # Roughly: we've tried to do something that this database doesn't support.
-    NotSupportedError: Type[Exception]
+    @property
+    def NotSupportedError(self) -> Type[Exception]:
+        ...
 
-    def connect(self, **parameters: object) -> Connection:
+    # We originally wrote
+    # def connect(self, *args, **kwargs) -> Connection: ...
+    # But mypy doesn't seem to like that because sqlite3.connect takes a mandatory
+    # positional argument. We can't make that part of the signature though, because
+    # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use
+    # the following slightly unusual workaround.
+    @property
+    def connect(self) -> Callable[..., Connection]:
         ...
 
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9adff3f4f5..d2c874b9a8 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -93,8 +93,11 @@ def _load_current_id(
     return res
 
 
-class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
-    """Tracks the "current" stream ID of a stream that may have multiple writers.
+class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
+    """Generates or tracks stream IDs for a stream that may have multiple writers.
+
+    Each stream ID represents a write transaction, whose completion is tracked
+    so that the "current" stream ID of the stream can be determined.
 
     Stream IDs are monotonically increasing or decreasing integers representing write
     transactions. The "current" stream ID is the stream ID such that all transactions
@@ -130,16 +133,6 @@ class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()
 
-
-class AbstractStreamIdGenerator(AbstractStreamIdTracker):
-    """Generates stream IDs for a stream that may have multiple writers.
-
-    Each stream ID represents a write transaction, whose completion is tracked
-    so that the "current" stream ID of the stream can be determined.
-
-    See `AbstractStreamIdTracker` for more details.
-    """
-
     @abc.abstractmethod
     def get_next(self) -> AsyncContextManager[int]:
         """
@@ -158,6 +151,15 @@ class AbstractStreamIdGenerator(AbstractStreamIdTracker):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
+        """
+        Usage:
+            stream_id_gen.get_next_txn(txn)
+            # ... persist events ...
+        """
+        raise NotImplementedError()
+
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
     """Generates and tracks stream IDs for a stream with a single writer.
@@ -263,6 +265,40 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
 
         return _AsyncCtxManagerWrapper(manager())
 
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
+        """
+        Retrieve the next stream ID from within a database transaction.
+
+        Clean-up functions will be called when the transaction finishes.
+
+        Args:
+            txn: The database transaction object.
+
+        Returns:
+            The next stream ID.
+        """
+        if not self._is_writer:
+            raise Exception("Tried to allocate stream ID on non-writer")
+
+        # Get the next stream ID.
+        with self._lock:
+            self._current += self._step
+            next_id = self._current
+
+            self._unfinished_ids[next_id] = next_id
+
+        def clear_unfinished_id(id_to_clear: int) -> None:
+            """A function to mark processing this ID as finished"""
+            with self._lock:
+                self._unfinished_ids.pop(id_to_clear)
+
+        # Mark this ID as finished once the database transaction itself finishes.
+        txn.call_after(clear_unfinished_id, next_id)
+        txn.call_on_exception(clear_unfinished_id, next_id)
+
+        # Return the new ID.
+        return next_id
+
     def get_current_token(self) -> int:
         if not self._is_writer:
             return self._current
@@ -568,7 +604,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         """
         Usage:
 
-            stream_id = stream_id_gen.get_next(txn)
+            stream_id = stream_id_gen.get_next_txn(txn)
             # ... persist event ...
         """
 
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 75268cbe15..80915216de 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -205,7 +205,7 @@ class LocalSequenceGenerator(SequenceGenerator):
         """
         Args:
             get_first_callback: a callback which is called on the first call to
-                 get_next_id_txn; should return the curreent maximum id
+                 get_next_id_txn; should return the current maximum id
         """
         # the callback. this is cleared after it is called, so that it can be GCed.
         self._callback: Optional[GetFirstCallbackType] = get_first_callback