author     Mathieu Velten <mathieuv@matrix.org>  2023-03-13 13:59:16 +0100
committer  Mathieu Velten <mathieuv@matrix.org>  2023-03-13 13:59:16 +0100
commit     5980756e0900d23f83926a7223644586183fd0b5 (patch)
tree       a3513d14a88d1bed4d649c4945535888c2d91ef2 /synapse/storage/databases
parent     Update changelog.d/14519.misc (diff)
parent     Bump hiredis from 2.2.1 to 2.2.2 (#15252) (diff)
download   synapse-github/mv/mypy-unused-awaitable.tar.xz

Merge remote-tracking branch 'origin/develop' into mv/mypy-unused-awaitable  (github/mv/mypy-unused-awaitable, mv/mypy-unused-awaitable)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--  synapse/storage/databases/main/__init__.py | 4
-rw-r--r--  synapse/storage/databases/main/account_data.py | 201
-rw-r--r--  synapse/storage/databases/main/appservice.py | 2
-rw-r--r--  synapse/storage/databases/main/cache.py | 3
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py | 5
-rw-r--r--  synapse/storage/databases/main/devices.py | 60
-rw-r--r--  synapse/storage/databases/main/directory.py | 4
-rw-r--r--  synapse/storage/databases/main/e2e_room_keys.py | 2
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py | 59
-rw-r--r--  synapse/storage/databases/main/event_federation.py | 12
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 7
-rw-r--r--  synapse/storage/databases/main/events.py | 35
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py | 13
-rw-r--r--  synapse/storage/databases/main/events_worker.py | 15
-rw-r--r--  synapse/storage/databases/main/filtering.py | 25
-rw-r--r--  synapse/storage/databases/main/media_repository.py | 1
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py | 4
-rw-r--r--  synapse/storage/databases/main/purge_events.py | 17
-rw-r--r--  synapse/storage/databases/main/push_rule.py | 3
-rw-r--r--  synapse/storage/databases/main/pusher.py | 6
-rw-r--r--  synapse/storage/databases/main/receipts.py | 17
-rw-r--r--  synapse/storage/databases/main/registration.py | 17
-rw-r--r--  synapse/storage/databases/main/relations.py | 138
-rw-r--r--  synapse/storage/databases/main/room.py | 391
-rw-r--r--  synapse/storage/databases/main/roommember.py | 19
-rw-r--r--  synapse/storage/databases/main/search.py | 2
-rw-r--r--  synapse/storage/databases/main/signatures.py | 6
-rw-r--r--  synapse/storage/databases/main/state.py | 1
-rw-r--r--  synapse/storage/databases/main/stats.py | 2
-rw-r--r--  synapse/storage/databases/main/stream.py | 1
-rw-r--r--  synapse/storage/databases/main/tags.py | 8
-rw-r--r--  synapse/storage/databases/main/transactions.py | 1
-rw-r--r--  synapse/storage/databases/main/user_directory.py | 81
-rw-r--r--  synapse/storage/databases/state/bg_updates.py | 1
-rw-r--r--  synapse/storage/databases/state/store.py | 128
35 files changed, 720 insertions, 571 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 837dc7646e..dc3948c170 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
 from .events_forward_extremities import EventForwardExtremitiesStore
-from .filtering import FilteringStore
+from .filtering import FilteringWorkerStore
 from .keys import KeyStore
 from .lock import LockStore
 from .media_repository import MediaRepositoryStore
@@ -99,7 +99,7 @@ class DataStore(
     EventFederationStore,
     MediaRepositoryStore,
     RejectionsStore,
-    FilteringStore,
+    FilteringWorkerStore,
     PusherStore,
     PushRuleStore,
     ApplicationServiceTransactionStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..a9843f6e17 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     cast,
@@ -39,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -63,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     ):
         super().__init__(database, db_conn, hs)
 
-        # `_can_write_to_account_data` indicates whether the current worker is allowed
-        # to write account data. A value of `True` implies that `_account_data_id_gen`
-        # is an `AbstractStreamIdGenerator` and not just a tracker.
-        self._account_data_id_gen: AbstractStreamIdTracker
         self._can_write_to_account_data = (
             self._instance_name in hs.config.worker.writers.account_data
         )
 
+        self._account_data_id_gen: AbstractStreamIdGenerator
+
         if isinstance(database.engine, PostgresEngine):
             self._account_data_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
@@ -122,25 +120,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         return self._account_data_id_gen.get_current_token()
 
     @cached()
-    async def get_account_data_for_user(
+    async def get_global_account_data_for_user(
         self, user_id: str
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+    ) -> Mapping[str, JsonDict]:
         """
-        Get all the client account_data for a user.
+        Get all the global client account_data for a user.
 
         If experimental MSC3391 support is enabled, any entries with an empty
         content body are excluded; as this means they have been deleted.
 
         Args:
             user_id: The user to get the account_data for.
+
         Returns:
-            A 2-tuple of a dict of global account_data and a dict mapping from
-            room_id string to per room account_data dicts.
+            The global account_data.
         """
 
-        def get_account_data_for_user_txn(
+        def get_global_account_data_for_user(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+        ) -> Dict[str, JsonDict]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -158,10 +156,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn.execute(sql, (user_id,))
             rows = self.db_pool.cursor_to_dict(txn)
 
-            global_account_data = {
+            return {
                 row["account_data_type"]: db_to_json(row["content"]) for row in rows
             }
 
+        return await self.db_pool.runInteraction(
+            "get_global_account_data_for_user", get_global_account_data_for_user
+        )
+
+    @cached()
+    async def get_room_account_data_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
+        """
+        Get all of the per-room client account_data for a user.
+
+        If experimental MSC3391 support is enabled, any entries with an empty
+        content body are excluded; as this means they have been deleted.
+
+        Args:
+            user_id: The user to get the account_data for.
+
+        Returns:
+            A dict mapping from room_id string to per-room account_data dicts.
+        """
+
+        def get_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
             # The 'content != '{}' condition below prevents us from using
             # `simple_select_list_txn` here, as it doesn't support conditions
             # other than 'equals'.
@@ -185,10 +207,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
                 room_data[row["account_data_type"]] = db_to_json(row["content"])
 
-            return global_account_data, by_room
+            return by_room
 
         return await self.db_pool.runInteraction(
-            "get_account_data_for_user", get_account_data_for_user_txn
+            "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
         )
 
     @cached(num_args=2, max_entries=5000, tree=True)
@@ -212,10 +234,41 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         else:
             return None
 
+    async def get_latest_stream_id_for_global_account_data_by_type_for_user(
+        self, user_id: str, data_type: str
+    ) -> Optional[int]:
+        """
+        Returns:
+            The stream ID of the account data,
+            or None if there is no such account data.
+        """
+
+        def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[int]:
+            sql = """
+                SELECT stream_id FROM account_data
+                WHERE user_id = ? AND account_data_type = ?
+                ORDER BY stream_id DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (user_id, data_type))
+
+            row = txn.fetchone()
+            if row:
+                return row[0]
+            else:
+                return None
+
+        return await self.db_pool.runInteraction(
+            "get_latest_stream_id_for_global_account_data_by_type_for_user",
+            get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
+        )
+
     @cached(num_args=2, tree=True)
     async def get_account_data_for_room(
         self, user_id: str, room_id: str
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonDict]:
         """Get all the client account_data for a user for a room.
 
         Args:
@@ -342,36 +395,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
-    async def get_updated_account_data_for_user(
+    async def get_updated_global_account_data_for_user(
         self, user_id: str, stream_id: int
-    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-        """Get all the client account_data for a that's changed for a user
+    ) -> Dict[str, JsonDict]:
+        """Get all the global account_data that's changed for a user.
 
         Args:
             user_id: The user to get the account_data for.
             stream_id: The point in the stream since which to get updates
+
         Returns:
-            A deferred pair of a dict of global account_data and a dict
-            mapping from room_id string to per room account_data dicts.
+            A dict of global account_data.
         """
 
-        def get_updated_account_data_for_user_txn(
+        def get_updated_global_account_data_for_user(
             txn: LoggingTransaction,
-        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
-            sql = (
-                "SELECT account_data_type, content FROM account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
-
+        ) -> Dict[str, JsonDict]:
+            sql = """
+                SELECT account_data_type, content FROM account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
-            global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+            return {row[0]: db_to_json(row[1]) for row in txn}
 
-            sql = (
-                "SELECT room_id, account_data_type, content FROM room_account_data"
-                " WHERE user_id = ? AND stream_id > ?"
-            )
+        changed = self._account_data_stream_cache.has_entity_changed(
+            user_id, int(stream_id)
+        )
+        if not changed:
+            return {}
+
+        return await self.db_pool.runInteraction(
+            "get_updated_global_account_data_for_user",
+            get_updated_global_account_data_for_user,
+        )
+
+    async def get_updated_room_account_data_for_user(
+        self, user_id: str, stream_id: int
+    ) -> Dict[str, Dict[str, JsonDict]]:
+        """Get all the room account_data that's changed for a user.
 
+        Args:
+            user_id: The user to get the account_data for.
+            stream_id: The point in the stream since which to get updates
+
+        Returns:
+            A dict mapping from room_id string to per room account_data dicts.
+        """
+
+        def get_updated_room_account_data_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, Dict[str, JsonDict]]:
+            sql = """
+                SELECT room_id, account_data_type, content FROM room_account_data
+                WHERE user_id = ? AND stream_id > ?
+            """
             txn.execute(sql, (user_id, stream_id))
 
             account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +457,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 room_account_data = account_data_by_room.setdefault(row[0], {})
                 room_account_data[row[1]] = db_to_json(row[2])
 
-            return global_account_data, account_data_by_room
+            return account_data_by_room
 
         changed = self._account_data_stream_cache.has_entity_changed(
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return {}
 
         return await self.db_pool.runInteraction(
-            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+            "get_updated_room_account_data_for_user",
+            get_updated_room_account_data_for_user_txn,
         )
 
     @cached(max_entries=5000, iterable=True)
@@ -444,7 +523,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                     self.get_global_account_data_by_type_for_user.invalidate(
                         (row.user_id, row.data_type)
                     )
-                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_global_account_data_for_user.invalidate((row.user_id,))
+                self.get_room_account_data_for_user.invalidate((row.user_id,))
                 self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                 self.get_account_data_for_room_and_type.invalidate(
                     (row.user_id, row.room_id, row.data_type)
@@ -475,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -492,7 +571,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_room_account_data_for_user.invalidate((user_id,))
             self.get_account_data_for_room.invalidate((user_id, room_id))
             self.get_account_data_for_room_and_type.prefill(
                 (user_id, room_id, account_data_type), content
@@ -502,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
     async def remove_account_data_for_room(
         self, user_id: str, room_id: str, account_data_type: str
-    ) -> Optional[int]:
+    ) -> int:
         """Delete the room account data for the user of a given type.
 
         Args:
@@ -515,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             data to delete.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def _remove_account_data_for_room_txn(
             txn: LoggingTransaction, next_id: int
@@ -554,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 next_id,
             )
 
-            if not row_updated:
-                return None
-
-            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
-            self.get_account_data_for_room.invalidate((user_id, room_id))
-            self.get_account_data_for_room_and_type.prefill(
-                (user_id, room_id, account_data_type), {}
-            )
+            if row_updated:
+                self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+                self.get_room_account_data_for_user.invalidate((user_id,))
+                self.get_account_data_for_room.invalidate((user_id, room_id))
+                self.get_account_data_for_room_and_type.prefill(
+                    (user_id, room_id, account_data_type), {}
+                )
 
         return self._account_data_id_gen.get_current_token()
 
@@ -580,7 +656,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
@@ -593,7 +668,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
 
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
+            self.get_global_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
                 (user_id, account_data_type)
             )
@@ -670,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         self,
         user_id: str,
         account_data_type: str,
-    ) -> Optional[int]:
+    ) -> int:
         """
         Delete a single piece of user account data by type.
 
@@ -687,7 +762,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             to delete.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def _remove_account_data_for_user_txn(
             txn: LoggingTransaction, next_id: int
@@ -757,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 next_id,
             )
 
-            if not row_updated:
-                return None
-
-            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_account_data_for_user.invalidate((user_id,))
-            self.get_global_account_data_by_type_for_user.prefill(
-                (user_id, account_data_type), {}
-            )
+            if row_updated:
+                self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+                self.get_global_account_data_for_user.invalidate((user_id,))
+                self.get_global_account_data_by_type_for_user.prefill(
+                    (user_id, account_data_type), {}
+                )
 
         return self._account_data_id_gen.get_current_token()
 
@@ -822,7 +894,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room_and_type, (user_id,)
         )
         self._invalidate_cache_and_stream(
-            txn, self.get_account_data_for_user, (user_id,)
+            txn, self.get_global_account_data_for_user, (user_id,)
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_room_account_data_for_user, (user_id,)
         )
         self._invalidate_cache_and_stream(
             txn, self.get_global_account_data_by_type_for_user, (user_id,)
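
Note: this file splits the cached `get_account_data_for_user` into `get_global_account_data_for_user` and `get_room_account_data_for_user`, so the two caches can be invalidated independently (for example, `remove_account_data_for_room` above now only invalidates the room-scoped cache). A minimal sketch of how a caller that previously unpacked the old 2-tuple might adapt; the wrapper function is hypothetical, only the two store methods come from this diff:

    from typing import Mapping, Tuple

    from synapse.types import JsonDict


    async def fetch_all_account_data(
        store: "AccountDataWorkerStore", user_id: str
    ) -> Tuple[Mapping[str, JsonDict], Mapping[str, Mapping[str, JsonDict]]]:
        # Previously one call returned (global_account_data, by_room); callers
        # now ask for each half explicitly.
        global_account_data = await store.get_global_account_data_for_user(user_id)
        by_room = await store.get_room_account_data_for_user(user_id)
        return global_account_data, by_room
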
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5fb152c4ff..484db175d0 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
         room_id: str,
         app_service: "ApplicationService",
         cache_context: _CacheContext,
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """
         Get all users in a room that the appservice controls.
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 5b66431691..096dec7f87 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -266,9 +266,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if relates_to:
             self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
             self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
-            self._attempt_to_invalidate_cache(
-                "get_aggregation_groups_for_event", (relates_to,)
-            )
             self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8e61aba454..0d75d9739a 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                         ],
                     )
 
-                for (user_id, messages_by_device) in edu["messages"].items():
-                    for (device_id, msg) in messages_by_device.items():
+                for user_id, messages_by_device in edu["messages"].items():
+                    for device_id, msg in messages_by_device.items():
                         with start_active_span("store_outgoing_to_device_message"):
                             set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
                             set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
@@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
         def _remove_dead_devices_from_device_inbox_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, bool]:
-
             if "max_stream_id" in progress:
                 max_stream_id = progress["max_stream_id"]
             else:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8b6cc6b80..5503621ad6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -51,7 +52,6 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     StreamIdGenerator,
 )
 from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
@@ -90,7 +90,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._device_list_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "device_lists_stream",
@@ -100,6 +100,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 ("device_lists_outbound_pokes", "stream_id"),
                 ("device_lists_changes_in_room", "stream_id"),
                 ("device_lists_remote_pending", "stream_id"),
+                ("device_lists_changes_converted_stream_position", "stream_id"),
             ],
             is_writer=hs.config.worker.worker_app is None,
         )
@@ -201,7 +202,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
-    async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+    async def count_devices_by_users(
+        self, user_ids: Optional[Collection[str]] = None
+    ) -> int:
         """Retrieve number of all devices of given users.
         Only returns number of devices that are not marked as hidden.
 
@@ -212,7 +215,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         """
 
         def count_devices_by_users_txn(
-            txn: LoggingTransaction, user_ids: List[str]
+            txn: LoggingTransaction, user_ids: Collection[str]
         ) -> int:
             sql = """
                 SELECT count(*)
@@ -508,7 +511,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             results.append(("org.matrix.signing_key_update", result))
 
         if issue_8631_logger.isEnabledFor(logging.DEBUG):
-            for (user_id, edu) in results:
+            for user_id, edu in results:
                 issue_8631_logger.debug(
                     "device update to %s for %s from %s to %s: %s",
                     destination,
@@ -708,9 +711,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             The new stream ID.
         """
 
-        # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
-        #  than DeviceWorkerStore?
-        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -745,42 +746,47 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     @trace
     @cancellable
     async def get_user_devices_from_cache(
-        self, query_list: List[Tuple[str, Optional[str]]]
-    ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+        self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
+    ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
         """Get the devices (and keys if any) for remote users from the cache.
 
         Args:
-            query_list: List of (user_id, device_ids), if device_ids is
-                falsey then return all device ids for that user.
+            user_ids: users which should have all device IDs returned
+            user_and_device_ids: List of (user_id, device_ids)
 
         Returns:
             A tuple of (user_ids_not_in_cache, results_map), where
             user_ids_not_in_cache is a set of user_ids and results_map is a
             mapping of user_id -> device_id -> device_info.
         """
-        user_ids = {user_id for user_id, _ in query_list}
-        user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+        unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
+        user_map = await self.get_device_list_last_stream_id_for_remotes(
+            list(unique_user_ids)
+        )
 
         # We go and check if any of the users need to have their device lists
         # resynced. If they do then we remove them from the cached list.
         users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
-            user_ids
+            unique_user_ids
         )
         user_ids_in_cache = {
             user_id for user_id, stream_id in user_map.items() if stream_id
         } - users_needing_resync
-        user_ids_not_in_cache = user_ids - user_ids_in_cache
-
-        results: Dict[str, Dict[str, JsonDict]] = {}
-        for user_id, device_id in query_list:
-            if user_id not in user_ids_in_cache:
-                continue
+        user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
 
-            if device_id:
-                device = await self._get_cached_user_device(user_id, device_id)
-                results.setdefault(user_id, {})[device_id] = device
-            else:
+        # First fetch all the users which all devices are to be returned.
+        results: Dict[str, Mapping[str, JsonDict]] = {}
+        for user_id in user_ids:
+            if user_id in user_ids_in_cache:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
+        # Then fetch all device-specific requests, but skip users we've already
+        # fetched all devices for.
+        device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+        for user_id, device_id in user_and_device_ids:
+            if user_id in user_ids_in_cache and user_id not in user_ids:
+                device = await self._get_cached_user_device(user_id, device_id)
+                device_specific_results.setdefault(user_id, {})[device_id] = device
+        results.update(device_specific_results)
 
         set_tag("in_cache", str(results))
         set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -798,7 +804,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         return db_to_json(content)
 
     @cached()
-    async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+    async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
         devices = await self.db_pool.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
@@ -1307,7 +1313,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 )
             """
             count = 0
-            for (destination, user_id, stream_id, device_id) in rows:
+            for destination, user_id, stream_id, device_id in rows:
                 txn.execute(
                     delete_sql, (destination, user_id, stream_id, stream_id, device_id)
                 )
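
Note: `get_user_devices_from_cache` no longer takes a mixed `query_list` in which a falsey device ID meant "all devices for that user"; callers now pass the all-devices user set and the specific (user, device) pairs separately. A hypothetical helper (not part of this diff) showing how an old-style query list maps onto the new arguments:

    from typing import List, Optional, Set, Tuple


    def split_device_query(
        query_list: List[Tuple[str, Optional[str]]]
    ) -> Tuple[Set[str], List[Tuple[str, str]]]:
        user_ids: Set[str] = set()
        user_and_device_ids: List[Tuple[str, str]] = []
        for user_id, device_id in query_list:
            if device_id:
                user_and_device_ids.append((user_id, device_id))
            else:
                # A falsey device_id previously meant "all devices for this user".
                user_ids.add(user_id)
        return user_ids, user_and_device_ids
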
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 5903fdaf00..44aa181174 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
 
 import attr
 
@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
         )
 
     @cached(max_entries=5000)
-    async def get_aliases_for_room(self, room_id: str) -> List[str]:
+    async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
         return await self.db_pool.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 6240f9a75e..9f8d2e4bea 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             raise StoreError(404, "No backup with that version exists")
 
         values = []
-        for (room_id, session_id, room_key) in room_keys:
+        for room_id, session_id, room_key in room_keys:
             values.append(
                 (
                     user_id,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c4ac6c33ba..a3b6c8ae8e 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -20,7 +20,9 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     Union,
     cast,
@@ -242,9 +244,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         set_tag("include_all_devices", include_all_devices)
         set_tag("include_deleted_devices", include_deleted_devices)
 
-        result = await self.db_pool.runInteraction(
-            "get_e2e_device_keys",
-            self._get_e2e_device_keys_txn,
+        result = await self._get_e2e_device_keys(
             query_list,
             include_all_devices,
             include_deleted_devices,
@@ -260,13 +260,13 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         for batch in batch_iter(signature_query, 50):
             cross_sigs_result = await self.db_pool.runInteraction(
-                "get_e2e_cross_signing_signatures",
+                "get_e2e_cross_signing_signatures_for_devices",
                 self._get_e2e_cross_signing_signatures_for_devices_txn,
                 batch,
             )
 
             # add each cross-signing signature to the correct device in the result dict.
-            for (user_id, key_id, device_id, signature) in cross_sigs_result:
+            for user_id, key_id, device_id, signature in cross_sigs_result:
                 target_device_result = result[user_id][device_id]
                 # We've only looked up cross-signatures for non-deleted devices with key
                 # data.
@@ -283,9 +283,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         log_kv(result)
         return result
 
-    def _get_e2e_device_keys_txn(
+    async def _get_e2e_device_keys(
         self,
-        txn: LoggingTransaction,
         query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
@@ -309,7 +308,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         # devices.
         user_list = []
         user_device_list = []
-        for (user_id, device_id) in query_list:
+        for user_id, device_id in query_list:
             if device_id is None:
                 user_list.append(user_id)
             else:
@@ -317,7 +316,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         if user_list:
             user_id_in_list_clause, user_args = make_in_list_sql_clause(
-                txn.database_engine, "user_id", user_list
+                self.database_engine, "user_id", user_list
             )
             query_clauses.append(user_id_in_list_clause)
             query_params_list.append(user_args)
@@ -330,13 +329,16 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     user_device_id_in_list_clause,
                     user_device_args,
                 ) = make_tuple_in_list_sql_clause(
-                    txn.database_engine, ("user_id", "device_id"), user_device_batch
+                    self.database_engine, ("user_id", "device_id"), user_device_batch
                 )
                 query_clauses.append(user_device_id_in_list_clause)
                 query_params_list.append(user_device_args)
 
         result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
-        for query_clause, query_params in zip(query_clauses, query_params_list):
+
+        def get_e2e_device_keys_txn(
+            txn: LoggingTransaction, query_clause: str, query_params: list
+        ) -> None:
             sql = (
                 "SELECT user_id, device_id, "
                 "    d.display_name, "
@@ -351,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
             txn.execute(sql, query_params)
 
-            for (user_id, device_id, display_name, key_json) in txn:
+            for user_id, device_id, display_name, key_json in txn:
                 assert device_id is not None
                 if include_deleted_devices:
                     deleted_devices.remove((user_id, device_id))
@@ -359,6 +361,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                     display_name, db_to_json(key_json) if key_json else None
                 )
 
+        for query_clause, query_params in zip(query_clauses, query_params_list):
+            await self.db_pool.runInteraction(
+                "_get_e2e_device_keys",
+                get_e2e_device_keys_txn,
+                query_clause,
+                query_params,
+            )
+
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
                 if device_id is None:
@@ -380,7 +390,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         signature_query_clauses = []
         signature_query_params = []
 
-        for (user_id, device_id) in device_query:
+        for user_id, device_id in device_query:
             signature_query_clauses.append(
                 "target_user_id = ? AND target_device_id = ? AND user_id = ?"
             )
@@ -691,7 +701,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cached(max_entries=10000)
     async def get_e2e_unused_fallback_key_types(
         self, user_id: str, device_id: str
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """Returns the fallback key types that have an unused key.
 
         Args:
@@ -731,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         return user_keys.get(key_type)
 
     @cached(num_args=1)
-    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
+    def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
         """Dummy function.  Only used to make a cache for
         _get_bare_e2e_cross_signing_keys_bulk.
         """
@@ -744,7 +754,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     )
     async def _get_bare_e2e_cross_signing_keys_bulk(
         self, user_ids: Iterable[str]
-    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.  The output of this
         function should be passed to _get_e2e_cross_signing_signatures_txn if
         the signatures for the calling user need to be fetched.
@@ -765,7 +775,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         )
 
         # The `Optional` comes from the `@cachedList` decorator.
-        return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+        return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
 
     def _get_bare_e2e_cross_signing_keys_bulk_txn(
         self,
@@ -924,7 +934,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     @cancellable
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
-    ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+    ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
         """Returns the cross-signing keys for a set of users.
 
         Args:
@@ -940,11 +950,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
 
         if from_user_id:
-            result = await self.db_pool.runInteraction(
-                "get_e2e_cross_signing_signatures",
-                self._get_e2e_cross_signing_signatures_txn,
-                result,
-                from_user_id,
+            result = cast(
+                Dict[str, Optional[Mapping[str, JsonDict]]],
+                await self.db_pool.runInteraction(
+                    "get_e2e_cross_signing_signatures",
+                    self._get_e2e_cross_signing_signatures_txn,
+                    result,
+                    from_user_id,
+                ),
             )
 
         return result
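
Note: several cached accessors in this commit (here and in appservice.py, directory.py, event_federation.py and monthly_active_users.py) now advertise read-only `Mapping`/`Sequence` return types instead of `Dict`/`List`. A plausible reading, which is an assumption rather than something the diff states, is that cached results are shared between callers, so the narrower types let mypy reject in-place mutation of a cached value. A toy illustration of the hazard:

    from typing import Dict, Mapping

    _device_cache: Dict[str, Dict[str, str]] = {}


    def get_devices_mutable(user_id: str) -> Dict[str, str]:
        # A caller can do `get_devices_mutable(u)["d"] = ...`, silently
        # corrupting the cached entry seen by every other caller.
        return _device_cache.setdefault(user_id, {})


    def get_devices_readonly(user_id: str) -> Mapping[str, str]:
        # Same object, but mypy now flags mutation through the return value.
        return _device_cache.setdefault(user_id, {})
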
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index bbee02ab18..ff3edeb716 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -22,6 +22,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     cast,
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
+    async def get_max_depth_of(
+        self, event_ids: Collection[str]
+    ) -> Tuple[Optional[str], int]:
         """Returns the event ID and depth for the event that has the max depth from a set of event IDs
 
         Args:
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+    async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
         return await self.db_pool.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @cancellable
     async def get_forward_extremities_for_room_at_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @cached(max_entries=5000, num_args=2)
     async def _get_forward_extremeties_for_room(
         self, room_id: str, stream_ordering: int
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -1609,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         latest_events: List[str],
         limit: int,
     ) -> List[str]:
-
         seen_events = set(earliest_events)
         front = set(latest_events) - seen_events
         event_results: List[str] = []
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a0c370fde..eeccf5db24 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -203,11 +203,18 @@ class RoomNotifCounts:
     # Map of thread ID to the notification counts.
     threads: Dict[str, NotifCounts]
 
+    @staticmethod
+    def empty() -> "RoomNotifCounts":
+        return _EMPTY_ROOM_NOTIF_COUNTS
+
     def __len__(self) -> int:
         # To properly account for the amount of space in any caches.
         return len(self.threads) + 1
 
 
+_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {})
+
+
 def _serialize_action(
     actions: Collection[Union[Mapping, str]], is_highlight: bool
 ) -> str:
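
Note: `RoomNotifCounts.empty()` hands out a single shared `_EMPTY_ROOM_NOTIF_COUNTS` instance instead of building a fresh object for every room with no notifications, and `__len__` lets size-bounded caches account for the space each entry takes. A standalone sketch of the shared-empty-value pattern; the class name and the `frozen=True` are assumptions (sharing one instance is only safe if nobody mutates it):

    from typing import Dict

    import attr


    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class Counts:
        threads: Dict[str, int]

        @staticmethod
        def empty() -> "Counts":
            # Reuse one instance rather than allocating per call.
            return _EMPTY_COUNTS

        def __len__(self) -> int:
            # Used by size-based caches to weigh this entry.
            return len(self.threads) + 1


    _EMPTY_COUNTS = Counts(threads={})
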
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1536937b67..a8a4ed4436 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,7 +16,6 @@
 import itertools
 import logging
 from collections import OrderedDict
-from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -26,7 +25,6 @@ from typing import (
     Iterable,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
 )
@@ -36,7 +34,7 @@ from prometheus_client import Counter
 
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import PartialStateConflictError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
@@ -52,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
 from synapse.util.stringutils import non_null_str_or_none
@@ -72,24 +70,6 @@ event_counter = Counter(
 )
 
 
-class PartialStateConflictError(SynapseError):
-    """An internal error raised when attempting to persist an event with partial state
-    after the room containing the event has been un-partial stated.
-
-    This error should be handled by recomputing the event context and trying again.
-
-    This error has an HTTP status code so that it can be transported over replication.
-    It should not be exposed to clients.
-    """
-
-    def __init__(self) -> None:
-        super().__init__(
-            HTTPStatus.CONFLICT,
-            msg="Cannot persist partial state event in un-partial stated room",
-            errcode=Codes.UNKNOWN,
-        )
-
-
 @attr.s(slots=True, auto_attribs=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.
@@ -306,7 +286,7 @@ class PersistEventsStore:
 
         # The set of event_ids to return. This includes all soft-failed events
         # and their prev events.
-        existing_prevs = set()
+        existing_prevs: Set[str] = set()
 
         def _get_prevs_before_rejected_txn(
             txn: LoggingTransaction, batch: Collection[str]
@@ -489,7 +469,6 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events: List[EventBase],
     ) -> None:
-
         # We only care about state events, so this if there are no state events.
         if not any(e.is_state() for e in events):
             return
@@ -571,7 +550,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, Sequence[str]],
+        event_to_auth_chain: Dict[str, StrCollection],
     ) -> None:
         """Calculate the chain cover index for the given events.
 
@@ -865,7 +844,7 @@ class PersistEventsStore:
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
-        event_to_auth_chain: Dict[str, Sequence[str]],
+        event_to_auth_chain: Dict[str, StrCollection],
         events_to_calc_chain_id_for: Set[str],
         chain_map: Dict[str, Tuple[int, int]],
     ) -> Dict[str, Tuple[int, int]]:
@@ -2045,10 +2024,6 @@ class PersistEventsStore:
         self.store._invalidate_cache_and_stream(
             txn, self.store.get_relations_for_event, (redacted_relates_to,)
         )
-        if rel_type == RelationTypes.ANNOTATION:
-            self.store._invalidate_cache_and_stream(
-                txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
-            )
         if rel_type == RelationTypes.REFERENCE:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_references_for_event, (redacted_relates_to,)
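
Note: `PartialStateConflictError` moves out of this module and is now imported from `synapse.api.errors`; its docstring (shown in the definition removed above) says it should be handled by recomputing the event context and trying again. A hypothetical retry wrapper sketching that contract; only the exception import comes from this diff, both callables are assumptions:

    from typing import Awaitable, Callable, TypeVar

    from synapse.api.errors import PartialStateConflictError

    T = TypeVar("T")


    async def persist_with_retry(
        persist: Callable[[], Awaitable[T]],
        recompute_context: Callable[[], Awaitable[None]],
    ) -> T:
        try:
            return await persist()
        except PartialStateConflictError:
            # The room was un-partial-stated while we were persisting:
            # recompute the event context and try once more.
            await recompute_context()
            return await persist()
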
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index b9d3c36d60..daef3685b0 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
 
 import attr
 
@@ -29,7 +29,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             nbrows = 0
             last_row_event_id = ""
-            for (event_id, event_json_raw) in results:
+            for event_id, event_json_raw in results:
                 try:
                     event_json = db_to_json(event_json_raw)
 
@@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self.event_chain_id_gen,  # type: ignore[attr-defined]
             event_to_room_id,
             event_to_types,
-            cast(Dict[str, Sequence[str]], event_to_auth_chain),
+            cast(Dict[str, StrCollection], event_to_auth_chain),
         )
 
         return _CalculateChainCover(
@@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             results = list(txn)
             # (event_id, parent_id, rel_type) for each relation
             relations_to_insert: List[Tuple[str, str, str]] = []
-            for (event_id, event_json_raw) in results:
+            for event_id, event_json_raw in results:
                 try:
                     event_json = db_to_json(event_json_raw)
                 except Exception as e:
@@ -1220,9 +1220,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                         txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
                     )
                     self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
-                        txn, self.get_aggregation_groups_for_event, cache_tuple  # type: ignore[attr-defined]
-                    )
-                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                         txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
                     )
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a9259fe446..0d90407ebd 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -72,7 +72,6 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -187,8 +186,8 @@ class EventsWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self._stream_id_gen: AbstractStreamIdTracker
-        self._backfill_id_gen: AbstractStreamIdTracker
+        self._stream_id_gen: AbstractStreamIdGenerator
+        self._backfill_id_gen: AbstractStreamIdGenerator
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -1493,7 +1492,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             txn.execute(redactions_sql + clause, args)
 
-            for (redacter, redacted) in txn:
+            for redacter, redacted in txn:
                 d = event_dict.get(redacted)
                 if d:
                     d.redactions.append(redacter)
@@ -1779,7 +1778,7 @@ class EventsWorkerStore(SQLBaseStore):
             txn: LoggingTransaction,
         ) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
             sql = (
-                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                "SELECT out.event_stream_ordering, e.event_id, e.room_id, e.type,"
                 " se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
                 " e.outlier"
                 " FROM events AS e"
@@ -1791,10 +1790,10 @@ class EventsWorkerStore(SQLBaseStore):
                 " LEFT JOIN event_relations USING (event_id)"
                 " LEFT JOIN room_memberships USING (event_id)"
                 " LEFT JOIN rejections USING (event_id)"
-                " WHERE ? < event_stream_ordering"
-                " AND event_stream_ordering <= ?"
+                " WHERE ? < out.event_stream_ordering"
+                " AND out.event_stream_ordering <= ?"
                 " AND out.instance_name = ?"
-                " ORDER BY event_stream_ordering ASC"
+                " ORDER BY out.event_stream_ordering ASC"
             )
 
             txn.execute(sql, (last_id, current_id, instance_name))
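
Note: the stream query above now qualifies every `event_stream_ordering` reference with the `out.` alias. One general motivation for qualifying column references in a join (the diff itself does not spell out its reason) is that an unqualified name can be ambiguous once more than one joined relation exposes a column of that name. A toy sqlite reproduction of the ambiguous case, with invented table names:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE a (ordering INTEGER, x TEXT);
        CREATE TABLE b (ordering INTEGER, y TEXT);
        """
    )

    try:
        # Unqualified: both `a` and `b` have `ordering`, so this fails to compile.
        conn.execute("SELECT ordering FROM a JOIN b ON a.x = b.y")
    except sqlite3.OperationalError as exc:
        print(exc)  # ambiguous column name: ordering

    # Qualified with an alias, as the diff does with `out.event_stream_ordering`.
    conn.execute("SELECT a.ordering FROM a JOIN b ON a.x = b.y")
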
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 12f3b601f1..8e57c8e5a0 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast
 
 from canonicaljson import encode_canonical_json
 
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict
@@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-
-class FilteringStore(FilteringWorkerStore):
     async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
         def_json = encode_canonical_json(user_filter)
 
@@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore):
 
             return filter_id
 
-        return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+        attempts = 0
+        while True:
+            # Try a few times.
+            # This is technically needed if a user tries to create two filters at once,
+            # leading to two concurrent transactions.
+            # The failure case would be:
+            # - SELECT filter_id ... filter_json = ? → both transactions return no rows
+            # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
+            # - INSERT INTO ... → both transactions insert filter_id = 6
+            # One of the transactions will commit. The other will get a unique key
+            # constraint violation error (IntegrityError). This is not the same as a
+            # serialisability violation, which would be automatically retried by
+            # `runInteraction`.
+            try:
+                return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+            except self.db_pool.engine.module.IntegrityError:
+                attempts += 1
+
+                if attempts >= 5:
+                    raise StoreError(500, "Couldn't generate a filter ID.")
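
Note: the comment block added above explains why `add_user_filter` now retries: two concurrent inserts can compute the same `MAX(filter_id) + 1`, and the loser hits a unique-constraint `IntegrityError`, which `runInteraction` does not retry automatically. A standalone sketch of the same pattern against sqlite3; the `user_filters` schema and helper name are invented for illustration:

    import sqlite3


    def add_filter_with_retry(
        conn: sqlite3.Connection, user: str, max_attempts: int = 5
    ) -> int:
        # Assumes: CREATE TABLE user_filters (user TEXT, filter_id INTEGER,
        #                                     UNIQUE (user, filter_id));
        for _ in range(max_attempts):
            try:
                with conn:  # one transaction per attempt
                    (next_id,) = conn.execute(
                        "SELECT COALESCE(MAX(filter_id), -1) + 1"
                        " FROM user_filters WHERE user = ?",
                        (user,),
                    ).fetchone()
                    conn.execute(
                        "INSERT INTO user_filters (user, filter_id) VALUES (?, ?)",
                        (user, next_id),
                    )
                    return next_id
            except sqlite3.IntegrityError:
                # A concurrent writer claimed the same filter_id; try again.
                continue
        raise RuntimeError("Couldn't generate a filter ID.")
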
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index b202c5eb87..fa8be214ce 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def get_local_media_by_user_paginate_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[Dict[str, Any]], int]:
-
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
 
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index db9a24db5e..4b1061e6d7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
         return await self.db_pool.runInteraction("count_users", _count_users)
 
     @cached(num_args=0)
-    async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
+    async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
         """Generates current count of monthly active users broken down by service.
         A service is typically an appservice but also includes native matrix users.
         Since the `monthly_active_users` table is populated from the `user_ips` table
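
The `Dict[...]` → `Mapping[...]` change on this cached method (and the matching `List` → `Sequence` and `Set` → `AbstractSet` changes elsewhere in this diff) reflects how `@cached` behaves: callers receive the cache's own entry, so mutating the result would silently corrupt it for every other caller. Annotating the return type as read-only lets mypy catch that. A rough sketch of the idea, with a plain dict standing in for the cache:

    from typing import Dict, Mapping

    _cache: Dict[str, Dict[str, int]] = {}

    def get_monthly_active_counts(service: str) -> Mapping[str, int]:
        # Pretend this dict is an expensive query result that we memoise.
        if service not in _cache:
            _cache[service] = {"native": 10, "appservice": 2}
        return _cache[service]

    counts = get_monthly_active_counts("example.org")
    total = sum(counts.values())   # reading is fine
    # counts["native"] = 0         # rejected by mypy: Mapping has no __setitem__
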
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 9213ce0b5a..7a7c0d9c75 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -325,6 +325,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         # We then run the same purge a second time without this isolation level to
         # purge any of those rows which were added during the first.
 
+        logger.info("[purge] Starting initial main purge of [1/2]")
         state_groups_to_delete = await self.db_pool.runInteraction(
             "purge_room",
             self._purge_room_txn,
@@ -332,6 +333,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             isolation_level=IsolationLevel.READ_COMMITTED,
         )
 
+        logger.info("[purge] Starting secondary main purge of [2/2]")
         state_groups_to_delete.extend(
             await self.db_pool.runInteraction(
                 "purge_room",
@@ -339,6 +341,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 room_id=room_id,
             ),
         )
+        logger.info("[purge] Done with main purge")
 
         return state_groups_to_delete
 
@@ -376,7 +379,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         )
         referenced_chain_id_tuples = list(txn)
 
-        logger.info("[purge] removing events from event_auth_chain_links")
+        logger.info("[purge] removing from event_auth_chain_links")
         txn.executemany(
             """
             DELETE FROM event_auth_chain_links WHERE
@@ -399,7 +402,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "rejections",
             "state_events",
         ):
-            logger.info("[purge] removing %s from %s", room_id, table)
+            logger.info("[purge] removing from %s", table)
 
             txn.execute(
                 """
@@ -420,12 +423,14 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_push_actions",
             "event_search",
             "event_failed_pull_attempts",
+            # Note: the partial state tables have foreign keys between each other, and to
+            # `events` and `rooms`. We need to delete from them in the right order.
             "partial_state_events",
+            "partial_state_rooms_servers",
+            "partial_state_rooms",
             "events",
             "federation_inbound_events_staging",
             "local_current_membership",
-            "partial_state_rooms_servers",
-            "partial_state_rooms",
             "receipts_graph",
             "receipts_linearized",
             "room_aliases",
@@ -452,7 +457,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             # happy
             "rooms",
         ):
-            logger.info("[purge] removing %s from %s", room_id, table)
+            logger.info("[purge] removing from %s", table)
             txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
 
         # Other tables we do NOT need to clear out:
@@ -484,6 +489,4 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         #   that already exist.
         self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
 
-        logger.info("[purge] done")
-
         return state_groups
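
The reordering of the table list above matters because the partial-state tables reference each other and the `events`/`rooms` tables via foreign keys, so referencing rows must be deleted before the rows they point at. A toy schema, loosely modelled on those tables rather than Synapse's real DDL, shows the failure mode and the working order:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("PRAGMA foreign_keys = ON")
    conn.execute("CREATE TABLE rooms (room_id TEXT PRIMARY KEY)")
    conn.execute(
        "CREATE TABLE partial_state_rooms"
        " (room_id TEXT PRIMARY KEY REFERENCES rooms(room_id))"
    )
    conn.execute(
        "CREATE TABLE partial_state_rooms_servers"
        " (room_id TEXT REFERENCES partial_state_rooms(room_id), server_name TEXT)"
    )
    conn.execute("INSERT INTO rooms VALUES ('!r:example.org')")
    conn.execute("INSERT INTO partial_state_rooms VALUES ('!r:example.org')")
    conn.execute(
        "INSERT INTO partial_state_rooms_servers VALUES ('!r:example.org', 'example.org')"
    )

    # Deleting the referenced table first trips the constraint...
    try:
        conn.execute("DELETE FROM partial_state_rooms WHERE room_id = '!r:example.org'")
    except sqlite3.IntegrityError as e:
        print("wrong order:", e)  # FOREIGN KEY constraint failed

    # ...whereas deleting children first, as in the reordered list, succeeds.
    for table in ("partial_state_rooms_servers", "partial_state_rooms", "rooms"):
        conn.execute("DELETE FROM %s WHERE room_id = ?" % table, ("!r:example.org",))
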
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9b2bbe060d..9f862f00c1 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -46,7 +46,6 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     IdGenerator,
     StreamIdGenerator,
 )
@@ -118,7 +117,7 @@ class PushRulesWorkerStore(
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._push_rules_stream_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "push_rules_stream",
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df53e726e6..9a24f7a655 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -36,7 +36,6 @@ from synapse.storage.database import (
 )
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     StreamIdGenerator,
 )
 from synapse.types import JsonDict
@@ -60,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._pushers_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "pushers",
@@ -344,7 +343,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_user = progress.get("last_user", "")
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT name FROM users
                 WHERE deactivated = ? and name > ?
@@ -392,7 +390,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_pusher = progress.get("last_pusher", 0)
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT p.id, access_token FROM pushers AS p
                 LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
@@ -449,7 +446,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_pusher = progress.get("last_pusher", 0)
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT p.id, p.user_name, p.app_id, p.pushkey
                 FROM pushers AS p
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 29972d5204..074942b167 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -21,7 +21,9 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
+    Sequence,
     Tuple,
     cast,
 )
@@ -37,7 +39,7 @@ from synapse.storage.database import (
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.util.id_generators import (
-    AbstractStreamIdTracker,
+    AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -63,7 +65,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._receipts_id_gen: AbstractStreamIdTracker
+        self._receipts_id_gen: AbstractStreamIdGenerator
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_receipts = (
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     async def get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[dict]:
+    ) -> Sequence[JsonDict]:
         """Get receipts for a single room for sending to clients.
 
         Args:
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     @cached(tree=True)
     async def _get_linearized_receipts_for_room(
         self, room_id: str, to_key: int, from_key: Optional[int] = None
-    ) -> List[JsonDict]:
+    ) -> Sequence[JsonDict]:
         """See get_linearized_receipts_for_room"""
 
         def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def _get_linearized_receipts_for_rooms(
         self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, List[JsonDict]]:
+    ) -> Dict[str, Sequence[JsonDict]]:
         if not room_ids:
             return {}
 
@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
     )
     async def get_linearized_receipts_for_all_rooms(
         self, to_key: int, from_key: Optional[int] = None
-    ) -> Dict[str, JsonDict]:
+    ) -> Mapping[str, JsonDict]:
         """Get receipts for all rooms between two stream_ids, up
         to a limit of the latest 100 read receipts.
 
@@ -766,7 +768,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
             )
 
-        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+        async with self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self._insert_linearized_receipt_txn,
@@ -885,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
         def _populate_receipt_event_stream_ordering_txn(
             txn: LoggingTransaction,
         ) -> bool:
-
             if "max_stream_id" in progress:
                 max_stream_id = progress["max_stream_id"]
             else:
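
Narrowing the annotation from `AbstractStreamIdTracker` to `AbstractStreamIdGenerator` is what lets the `get_next()` call further down drop its `# type: ignore[attr-defined]`: a tracker only reports the current stream position, while a generator can also allocate new IDs. A much-simplified sketch of that relationship (the class names here are illustrative, not Synapse's):

    import abc

    class IdTracker(abc.ABC):
        """Knows the current stream position, but cannot advance it."""

        @abc.abstractmethod
        def get_current_token(self) -> int: ...

    class IdGenerator(IdTracker):
        """Can also hand out new stream IDs."""

        def __init__(self) -> None:
            self._current = 0

        def get_current_token(self) -> int:
            return self._current

        def get_next(self) -> int:
            self._current += 1
            return self._current

    ids: IdTracker = IdGenerator()
    # ids.get_next()        # mypy: "IdTracker" has no attribute "get_next"
    gen: IdGenerator = IdGenerator()
    print(gen.get_next())   # type-checks once the attribute is annotated as a generator
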
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 31f0f2bd3d..717237e024 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
 
 import attr
 
@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             )
 
     @cached()
-    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+    async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
         """Deprecated: use get_userinfo_by_id instead"""
 
         def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
@@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="user_delete_threepid",
         )
 
-    async def user_delete_threepids(self, user_id: str) -> None:
-        """Delete all threepid this user has bound
-
-        Args:
-             user_id: The user id to delete all threepids of
-
-        """
-        await self.db_pool.simple_delete(
-            "user_threepids",
-            keyvalues={"user_id": user_id},
-            desc="user_delete_threepids",
-        )
-
     async def add_user_bound_threepid(
         self, user_id: str, medium: str, address: str, id_server: str
     ) -> None:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0018d6f7ab..bc3a83919c 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -22,6 +22,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: Direction = Direction.BACKWARDS,
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
-    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+    ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
@@ -397,141 +398,6 @@ class RelationsWorkerStore(SQLBaseStore):
         return result is not None
 
     @cached()
-    async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
-    )
-    async def get_aggregation_groups_for_events(
-        self, event_ids: Collection[str]
-    ) -> Mapping[str, Optional[List[JsonDict]]]:
-        """Get a list of annotations on the given events, grouped by event type and
-        aggregation key, sorted by count.
-
-        This is used e.g. to get the what and how many reactions have happend
-        on an event.
-
-        Args:
-            event_ids: Fetch events that relate to these event IDs.
-
-        Returns:
-            A map of event IDs to a list of groups of annotations that match.
-            Each entry is a dict with `type`, `key` and `count` fields.
-        """
-        # The number of entries to return per event ID.
-        limit = 5
-
-        clause, args = make_in_list_sql_clause(
-            self.database_engine, "relates_to_id", event_ids
-        )
-        args.append(RelationTypes.ANNOTATION)
-
-        sql = f"""
-            SELECT
-                relates_to_id,
-                annotation.type,
-                aggregation_key,
-                COUNT(DISTINCT annotation.sender)
-            FROM events AS annotation
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS parent ON
-                parent.event_id = relates_to_id
-                AND parent.room_id = annotation.room_id
-            WHERE
-                {clause}
-                AND relation_type = ?
-            GROUP BY relates_to_id, annotation.type, aggregation_key
-            ORDER BY relates_to_id, COUNT(*) DESC
-        """
-
-        def _get_aggregation_groups_for_events_txn(
-            txn: LoggingTransaction,
-        ) -> Mapping[str, List[JsonDict]]:
-            txn.execute(sql, args)
-
-            result: Dict[str, List[JsonDict]] = {}
-            for event_id, type, key, count in cast(
-                List[Tuple[str, str, str, int]], txn
-            ):
-                event_results = result.setdefault(event_id, [])
-
-                # Limit the number of results per event ID.
-                if len(event_results) == limit:
-                    continue
-
-                event_results.append({"type": type, "key": key, "count": count})
-
-            return result
-
-        return await self.db_pool.runInteraction(
-            "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
-        )
-
-    async def get_aggregation_groups_for_users(
-        self, event_ids: Collection[str], users: FrozenSet[str]
-    ) -> Dict[str, Dict[Tuple[str, str], int]]:
-        """Fetch the partial aggregations for an event for specific users.
-
-        This is used, in conjunction with get_aggregation_groups_for_event, to
-        remove information from the results for ignored users.
-
-        Args:
-            event_ids: Fetch events that relate to these event IDs.
-            users: The users to fetch information for.
-
-        Returns:
-            A map of event ID to a map of (event type, aggregation key) to a
-            count of users.
-        """
-
-        if not users:
-            return {}
-
-        events_sql, args = make_in_list_sql_clause(
-            self.database_engine, "relates_to_id", event_ids
-        )
-
-        users_sql, users_args = make_in_list_sql_clause(
-            self.database_engine, "annotation.sender", users
-        )
-        args.extend(users_args)
-        args.append(RelationTypes.ANNOTATION)
-
-        sql = f"""
-            SELECT
-                relates_to_id,
-                annotation.type,
-                aggregation_key,
-                COUNT(DISTINCT annotation.sender)
-            FROM events AS annotation
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS parent ON
-                parent.event_id = relates_to_id
-                AND parent.room_id = annotation.room_id
-            WHERE {events_sql} AND {users_sql} AND relation_type = ?
-            GROUP BY relates_to_id, annotation.type, aggregation_key
-            ORDER BY relates_to_id, COUNT(*) DESC
-        """
-
-        def _get_aggregation_groups_for_users_txn(
-            txn: LoggingTransaction,
-        ) -> Dict[str, Dict[Tuple[str, str], int]]:
-            txn.execute(sql, args)
-
-            result: Dict[str, Dict[Tuple[str, str], int]] = {}
-            for event_id, type, key, count in cast(
-                List[Tuple[str, str, str, int]], txn
-            ):
-                result.setdefault(event_id, {})[(type, key)] = count
-
-            return result
-
-        return await self.db_pool.runInteraction(
-            "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
-        )
-
-    @cached()
     async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
         raise NotImplementedError()
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 644bbb8878..3825bd6079 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1417,6 +1417,204 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             get_un_partial_stated_rooms_from_stream_txn,
         )
 
+    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
+        """Retrieve an event report
+
+        Args:
+            report_id: ID of reported event in database
+        Returns:
+            JSON dict of information from an event report or None if the
+            report does not exist.
+        """
+
+        def _get_event_report_txn(
+            txn: LoggingTransaction, report_id: int
+        ) -> Optional[Dict[str, Any]]:
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.content,
+                    events.sender,
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
+                WHERE er.id = ?
+            """
+
+            txn.execute(sql, [report_id])
+            row = txn.fetchone()
+
+            if not row:
+                return None
+
+            event_report = {
+                "id": row[0],
+                "received_ts": row[1],
+                "room_id": row[2],
+                "event_id": row[3],
+                "user_id": row[4],
+                "score": db_to_json(row[5]).get("score"),
+                "reason": db_to_json(row[5]).get("reason"),
+                "sender": row[6],
+                "canonical_alias": row[7],
+                "name": row[8],
+                "event_json": db_to_json(row[9]),
+            }
+
+            return event_report
+
+        return await self.db_pool.runInteraction(
+            "get_event_report", _get_event_report_txn, report_id
+        )
+
+    async def get_event_reports_paginate(
+        self,
+        start: int,
+        limit: int,
+        direction: Direction = Direction.BACKWARDS,
+        user_id: Optional[str] = None,
+        room_id: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Retrieve a paginated list of event reports
+
+        Args:
+            start: event offset to begin the query from
+            limit: number of rows to retrieve
+            direction: Whether to fetch the most recent first (backwards) or the
+                oldest first (forwards)
+            user_id: search for user_id. Ignored if user_id is None
+            room_id: search for room_id. Ignored if room_id is None
+        Returns:
+            Tuple of:
+                json list of event reports
+                total number of event reports matching the filter criteria
+        """
+
+        def _get_event_reports_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
+            filters = []
+            args: List[object] = []
+
+            if user_id:
+                filters.append("er.user_id LIKE ?")
+                args.extend(["%" + user_id + "%"])
+            if room_id:
+                filters.append("er.room_id LIKE ?")
+                args.extend(["%" + room_id + "%"])
+
+            if direction == Direction.BACKWARDS:
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            # We join on room_stats_state despite not using any columns from it
+            # because the join can influence the number of rows returned;
+            # e.g. a room that doesn't have state, maybe because it was deleted.
+            # The query returning the total count should be consistent with
+            # the query returning the results.
+            sql = """
+                SELECT COUNT(*) as total_event_reports
+                FROM event_reports AS er
+                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
+                {}
+                """.format(
+                where_clause
+            )
+            txn.execute(sql, args)
+            count = cast(Tuple[int], txn.fetchone())[0]
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.content,
+                    events.sender,
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name
+                FROM event_reports AS er
+                LEFT JOIN events
+                    ON events.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
+                {where_clause}
+                ORDER BY er.received_ts {order}
+                LIMIT ?
+                OFFSET ?
+            """.format(
+                where_clause=where_clause,
+                order=order,
+            )
+
+            args += [limit, start]
+            txn.execute(sql, args)
+
+            event_reports = []
+            for row in txn:
+                try:
+                    s = db_to_json(row[5]).get("score")
+                    r = db_to_json(row[5]).get("reason")
+                except Exception:
+                    logger.error("Unable to parse json from event_reports: %s", row[0])
+                    continue
+                event_reports.append(
+                    {
+                        "id": row[0],
+                        "received_ts": row[1],
+                        "room_id": row[2],
+                        "event_id": row[3],
+                        "user_id": row[4],
+                        "score": s,
+                        "reason": r,
+                        "sender": row[6],
+                        "canonical_alias": row[7],
+                        "name": row[8],
+                    }
+                )
+
+            return event_reports, count
+
+        return await self.db_pool.runInteraction(
+            "get_event_reports_paginate", _get_event_reports_paginate_txn
+        )
+
+    async def delete_event_report(self, report_id: int) -> bool:
+        """Remove an event report from database.
+
+        Args:
+            report_id: Report to delete
+
+        Returns:
+            Whether the report was successfully deleted or not.
+        """
+        try:
+            await self.db_pool.simple_delete_one(
+                table="event_reports",
+                keyvalues={"id": report_id},
+                desc="delete_event_report",
+            )
+        except StoreError:
+            # Deletion failed because report does not exist
+            return False
+
+        return True
+
 
 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -2139,7 +2337,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         reason: Optional[str],
         content: JsonDict,
         received_ts: int,
-    ) -> None:
+    ) -> int:
+        """Add an event report
+
+        Args:
+            room_id: Room that contains the reported event.
+            event_id: The reported event.
+            user_id: User who reports the event.
+            reason: Description that the user specifies.
+            content: Report request body (score and reason).
+            received_ts: Time when the user submitted the report (milliseconds).
+        Returns:
+            Id of the event report.
+        """
         next_id = self._event_reports_id_gen.get_next()
         await self.db_pool.simple_insert(
             table="event_reports",
@@ -2154,184 +2364,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             },
             desc="add_event_report",
         )
-
-    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
-        """Retrieve an event report
-
-        Args:
-            report_id: ID of reported event in database
-        Returns:
-            JSON dict of information from an event report or None if the
-            report does not exist.
-        """
-
-        def _get_event_report_txn(
-            txn: LoggingTransaction, report_id: int
-        ) -> Optional[Dict[str, Any]]:
-
-            sql = """
-                SELECT
-                    er.id,
-                    er.received_ts,
-                    er.room_id,
-                    er.event_id,
-                    er.user_id,
-                    er.content,
-                    events.sender,
-                    room_stats_state.canonical_alias,
-                    room_stats_state.name,
-                    event_json.json AS event_json
-                FROM event_reports AS er
-                LEFT JOIN events
-                    ON events.event_id = er.event_id
-                JOIN event_json
-                    ON event_json.event_id = er.event_id
-                JOIN room_stats_state
-                    ON room_stats_state.room_id = er.room_id
-                WHERE er.id = ?
-            """
-
-            txn.execute(sql, [report_id])
-            row = txn.fetchone()
-
-            if not row:
-                return None
-
-            event_report = {
-                "id": row[0],
-                "received_ts": row[1],
-                "room_id": row[2],
-                "event_id": row[3],
-                "user_id": row[4],
-                "score": db_to_json(row[5]).get("score"),
-                "reason": db_to_json(row[5]).get("reason"),
-                "sender": row[6],
-                "canonical_alias": row[7],
-                "name": row[8],
-                "event_json": db_to_json(row[9]),
-            }
-
-            return event_report
-
-        return await self.db_pool.runInteraction(
-            "get_event_report", _get_event_report_txn, report_id
-        )
-
-    async def get_event_reports_paginate(
-        self,
-        start: int,
-        limit: int,
-        direction: Direction = Direction.BACKWARDS,
-        user_id: Optional[str] = None,
-        room_id: Optional[str] = None,
-    ) -> Tuple[List[Dict[str, Any]], int]:
-        """Retrieve a paginated list of event reports
-
-        Args:
-            start: event offset to begin the query from
-            limit: number of rows to retrieve
-            direction: Whether to fetch the most recent first (backwards) or the
-                oldest first (forwards)
-            user_id: search for user_id. Ignored if user_id is None
-            room_id: search for room_id. Ignored if room_id is None
-        Returns:
-            Tuple of:
-                json list of event reports
-                total number of event reports matching the filter criteria
-        """
-
-        def _get_event_reports_paginate_txn(
-            txn: LoggingTransaction,
-        ) -> Tuple[List[Dict[str, Any]], int]:
-            filters = []
-            args: List[object] = []
-
-            if user_id:
-                filters.append("er.user_id LIKE ?")
-                args.extend(["%" + user_id + "%"])
-            if room_id:
-                filters.append("er.room_id LIKE ?")
-                args.extend(["%" + room_id + "%"])
-
-            if direction == Direction.BACKWARDS:
-                order = "DESC"
-            else:
-                order = "ASC"
-
-            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
-
-            # We join on room_stats_state despite not using any columns from it
-            # because the join can influence the number of rows returned;
-            # e.g. a room that doesn't have state, maybe because it was deleted.
-            # The query returning the total count should be consistent with
-            # the query returning the results.
-            sql = """
-                SELECT COUNT(*) as total_event_reports
-                FROM event_reports AS er
-                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
-                {}
-                """.format(
-                where_clause
-            )
-            txn.execute(sql, args)
-            count = cast(Tuple[int], txn.fetchone())[0]
-
-            sql = """
-                SELECT
-                    er.id,
-                    er.received_ts,
-                    er.room_id,
-                    er.event_id,
-                    er.user_id,
-                    er.content,
-                    events.sender,
-                    room_stats_state.canonical_alias,
-                    room_stats_state.name
-                FROM event_reports AS er
-                LEFT JOIN events
-                    ON events.event_id = er.event_id
-                JOIN room_stats_state
-                    ON room_stats_state.room_id = er.room_id
-                {where_clause}
-                ORDER BY er.received_ts {order}
-                LIMIT ?
-                OFFSET ?
-            """.format(
-                where_clause=where_clause,
-                order=order,
-            )
-
-            args += [limit, start]
-            txn.execute(sql, args)
-
-            event_reports = []
-            for row in txn:
-                try:
-                    s = db_to_json(row[5]).get("score")
-                    r = db_to_json(row[5]).get("reason")
-                except Exception:
-                    logger.error("Unable to parse json from event_reports: %s", row[0])
-                    continue
-                event_reports.append(
-                    {
-                        "id": row[0],
-                        "received_ts": row[1],
-                        "room_id": row[2],
-                        "event_id": row[3],
-                        "user_id": row[4],
-                        "score": s,
-                        "reason": r,
-                        "sender": row[6],
-                        "canonical_alias": row[7],
-                        "name": row[8],
-                    }
-                )
-
-            return event_reports, count
-
-        return await self.db_pool.runInteraction(
-            "get_event_reports_paginate", _get_event_reports_paginate_txn
-        )
+        return next_id
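
Two things are worth noting about the relocated event-report code: `add_event_report` now returns the new row's ID so callers can refer back to the report, and `get_event_reports_paginate` keeps one WHERE clause and one `room_stats_state` join for both its count query and its page query, which is what keeps the reported total consistent with the rows actually returned. A toy, self-contained version of that count-plus-page pattern (table and column names are illustrative only):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_reports"
        " (id INTEGER PRIMARY KEY, room_id TEXT, received_ts INTEGER)"
    )
    conn.executemany(
        "INSERT INTO event_reports (room_id, received_ts) VALUES (?, ?)",
        [("!big:x", i) for i in range(7)] + [("!small:x", i) for i in range(3)],
    )

    def get_reports_paginate(room_id: str, start: int, limit: int):
        where, args = "", []
        if room_id:
            where, args = "WHERE room_id LIKE ?", ["%" + room_id + "%"]
        # The same WHERE clause feeds both queries, so count and page stay consistent.
        count = conn.execute(
            "SELECT COUNT(*) FROM event_reports " + where, args
        ).fetchone()[0]
        rows = conn.execute(
            "SELECT id, room_id, received_ts FROM event_reports " + where
            + " ORDER BY received_ts DESC LIMIT ? OFFSET ?",
            args + [limit, start],
        ).fetchall()
        return rows, count

    page, total = get_reports_paginate("!big", start=0, limit=5)
    print(total, len(page))  # 7 matching reports, 5 of them on the first page
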
 
     async def block_room(self, room_id: str, user_id: str) -> None:
         """Marks the room as blocked.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ea6a5e2f34..694a5b802c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -24,6 +24,7 @@ from typing import (
     List,
     Mapping,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Union,
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return self._known_servers_count
 
     @cached(max_entries=100000, iterable=True)
-    async def get_users_in_room(self, room_id: str) -> List[str]:
+    async def get_users_in_room(self, room_id: str) -> Sequence[str]:
         """Returns a list of users in the room.
 
         Will return inaccurate results for rooms with partial state, since the state for
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached()
-    def get_user_in_room_with_profile(
-        self, room_id: str, user_id: str
-    ) -> Dict[str, ProfileInfo]:
+    def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
         raise NotImplementedError()
 
     @cachedList(
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached(max_entries=100000, iterable=True)
     async def get_users_in_room_with_profiles(
         self, room_id: str
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Mapping[str, ProfileInfo]:
         """Get a mapping from user ID to profile information for all users in a given room.
 
         The profile information comes directly from this room's `m.room.member`
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     @cached(max_entries=100000)
-    async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
+    async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
         """Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached()
     async def get_invited_rooms_for_local_user(
         self, user_id: str
-    ) -> List[RoomsForUser]:
+    ) -> Sequence[RoomsForUser]:
         """Get all the rooms the *local* user is invited to.
 
         Args:
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cached(iterable=True)
-    async def get_local_users_in_room(self, room_id: str) -> List[str]:
+    async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
         """
         Retrieves a list of the current room members who are local to the server.
         """
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         """Returns the set of users who share a room with `user_id`"""
         room_ids = await self.get_rooms_for_user(user_id)
 
-        user_who_share_room = set()
+        user_who_share_room: Set[str] = set()
         for room_id in room_ids:
             user_ids = await self.get_users_in_room(room_id)
             user_who_share_room.update(user_ids)
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return True
 
     @cached(iterable=True, max_entries=10000)
-    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+    async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
         """Get current hosts in room based on current state."""
 
         # First we check if we already have `get_users_in_room` in the cache, as
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3fe433f66c..a7aae661d8 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore):
 
 
 class SearchBackgroundUpdateStore(SearchWorkerStore):
-
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
@@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore):
             """
             count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
-
             # We use CROSS JOIN here to ensure we use the right indexes.
             # https://sqlite.org/optoverview.html#crossjoin
             #
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 05da15074a..5dcb1fc0b5 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Collection, Dict, List, Tuple
+from typing import Collection, Dict, List, Mapping, Tuple
 
 from unpaddedbase64 import encode_base64
 
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
 
 class SignatureWorkerStore(EventsWorkerStore):
     @cached()
-    def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
+    def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
         # This is a dummy function to allow get_event_reference_hashes
         # to use its cache
         raise NotImplementedError()
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
     )
     async def get_event_reference_hashes(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Dict[str, bytes]]:
+    ) -> Mapping[str, Mapping[str, bytes]]:
         """Get all hashes for given events.
 
         Args:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ba325d390b..ebb2ae964f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
 
 class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
-
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7b7d0c3c9..d3393d8e49 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore):
         insert_cols = []
         qargs = []
 
-        for (key, val) in chain(
+        for key, val in chain(
             keyvalues.items(), absolutes.items(), additive_relatives.items()
         ):
             insert_cols.append(key)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 818c46182e..ac5fbf6b86 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000
 _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
+
 # Used as return values for pagination APIs
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _EventDictReturn:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index d5500cdd47..c149a9eacb 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, Iterable, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.tcp.streams import AccountDataStream
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
 
 class TagsWorkerStore(AccountDataWorkerStore):
     @cached()
-    async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
+    async def get_tags_for_user(
+        self, user_id: str
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
         """Get all the tags for a user.
 
 
@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
-    ) -> Dict[str, Dict[str, JsonDict]]:
+    ) -> Mapping[str, Mapping[str, JsonDict]]:
         """Get all the tags for the rooms where the tags have changed since the
         given version
 
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6b33d809b6..6d72bd9f67 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         def get_destination_rooms_paginate_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[JsonDict], int]:
-
             if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 14ef5b040d..f16a509ac4 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,11 +14,12 @@
 
 import logging
 import re
+import unicodedata
 from typing import (
     TYPE_CHECKING,
-    Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     async def _populate_user_directory_createtables(
         self, progress: JsonDict, batch_size: int
     ) -> int:
-
         # Get all the rooms that we want to process.
         def _make_staging_area(txn: LoggingTransaction) -> None:
             sql = (
@@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 values={"display_name": display_name, "avatar_url": avatar_url},
             )
 
+            # The display name that goes into the database index.
+            index_display_name = display_name
+            if index_display_name is not None:
+                index_display_name = _filter_text_for_index(index_display_name)
+
             if isinstance(self.database_engine, PostgresEngine):
                 # We weight the localpart most highly, then display name and finally
                 # server name
@@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                         user_id,
                         get_localpart_from_id(user_id),
                         get_domain_from_id(user_id),
-                        display_name,
+                        index_display_name,
                     ),
                 )
             elif isinstance(self.database_engine, Sqlite3Engine):
-                value = "%s %s" % (user_id, display_name) if display_name else user_id
+                value = (
+                    "%s %s" % (user_id, index_display_name)
+                    if index_display_name
+                    else user_id
+                )
                 self.db_pool.simple_upsert_txn(
                     txn,
                     table="user_directory_search",
@@ -586,7 +595,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     @cached()
-    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
+    async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
         return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
@@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         return {"limited": limited, "results": results[0:limit]}
 
 
+def _filter_text_for_index(text: str) -> str:
+    """Transforms text before it is inserted into the user directory index, or searched
+    for in the user directory index.
+
+    Note that the user directory search table needs to be rebuilt whenever this function
+    changes.
+    """
+    # Lowercase the text, to make searches case-insensitive.
+    # This is necessary for both PostgreSQL and SQLite. PostgreSQL's
+    # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using
+    # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at
+    # all.
+    text = text.lower()
+
+    # Normalize the text. NFKC normalization has two effects:
+    #  1. It canonicalizes the text, ie. maps all visually identical strings to the same
+    #     string. For example, ["e", "◌́"] is mapped to ["é"].
+    #  2. It maps strings that are roughly equivalent to the same string.
+    #     For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to
+    #     ["i", "9"].
+    text = unicodedata.normalize("NFKC", text)
+
+    # Note that nothing is done to make searches accent-insensitive.
+    # That could be achieved by converting to NFKD form instead (with combining accents
+    # split out) and filtering out combining accents using `unicodedata.combining(c)`.
+    # The downside of this may be noisier search results, since search terms with
+    # explicit accents will match characters with no accents, or completely different
+    # accents.
+    #
+    # text = unicodedata.normalize("NFKD", text)
+    # text = "".join([c for c in text if not unicodedata.combining(c)])
+
+    return text
+
+
 def _parse_query_sqlite(search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
@@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str:
     We specifically add both a prefix and non prefix matching term so that
     exact matches get ranked higher.
     """
+    search_term = _filter_text_for_index(search_term)
 
     # Pull out the individual words, discarding any non-word characters.
     results = _parse_words(search_term)
@@ -918,11 +963,21 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
     We use this so that we can add prefix matching, which isn't something
     that is supported by default.
     """
-    results = _parse_words(search_term)
+    search_term = _filter_text_for_index(search_term)
+
+    escaped_words = []
+    for word in _parse_words(search_term):
+        # Postgres tsvector and tsquery quoting rules:
+        # words potentially containing punctuation should be quoted
+        # and then existing quotes and backslashes should be doubled
+        # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY
 
-    both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
-    exact = " & ".join("%s" % (result,) for result in results)
-    prefix = " & ".join("%s:*" % (result,) for result in results)
+        quoted_word = word.replace("'", "''").replace("\\", "\\\\")
+        escaped_words.append(f"'{quoted_word}'")
+
+    both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words)
+    exact = " & ".join("%s" % (word,) for word in escaped_words)
+    prefix = " & ".join("%s:*" % (word,) for word in escaped_words)
 
     return both, exact, prefix
 
@@ -944,6 +999,14 @@ def _parse_words(search_term: str) -> List[str]:
     if USE_ICU:
         return _parse_words_with_icu(search_term)
 
+    return _parse_words_with_regex(search_term)
+
+
+def _parse_words_with_regex(search_term: str) -> List[str]:
+    """
+    Break down search term into words, when we don't have ICU available.
+    See: `_parse_words`
+    """
     return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
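
The two additions above work together: `_filter_text_for_index` lowercases and NFKC-normalizes both the indexed display names and the incoming search terms, and the PostgreSQL path then quotes each word before assembling the tsquery so that punctuation inside a word cannot break the query syntax. A quick standalone illustration (the expected outputs are shown as comments):

    import unicodedata

    def filter_text_for_index(text: str) -> str:
        # Same two steps as above: lowercase first, then apply NFKC normalization.
        return unicodedata.normalize("NFKC", text.lower())

    print(filter_text_for_index("Ⅻ ＭＡＲＴＩＮ"))  # -> "xii martin"

    words = ["o'brien", "élan"]
    quoted = ["'" + w.replace("'", "''").replace("\\", "\\\\") + "'" for w in words]
    prefix = " & ".join("%s:*" % (w,) for w in quoted)
    print(prefix)  # -> 'o''brien':* & 'élan':*
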
 
 
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index d743282f13..097dea5182 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
 
 
 class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
-
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 1a7232b276..29ff64e876 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
 import attr
 
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -257,14 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         member_filter, non_member_filter = state_filter.get_member_split()
 
         # Now we look them up in the member and non-member caches
-        (
-            non_member_state,
-            incomplete_groups_nm,
-        ) = self._get_state_for_groups_using_cache(
+        non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache(
             groups, self._state_group_cache, state_filter=non_member_filter
         )
 
-        (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
+        member_state, incomplete_groups_m = self._get_state_for_groups_using_cache(
             groups, self._state_group_members_cache, state_filter=member_filter
         )
 
@@ -404,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 fetched_keys=non_member_types,
             )
 
+    async def store_state_deltas_for_batched(
+        self,
+        events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
+        room_id: str,
+        prev_group: int,
+    ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+        """Generate and store state deltas for a group of events and contexts created to be
+        batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
+
+        Args:
+            events_and_context: the events to generate and store state groups for,
+                and their associated contexts
+            room_id: the id of the room the events were created for
+            prev_group: the state group of the last event persisted before the batched events
+            were created
+        """
+
+        def insert_deltas_group_txn(
+            txn: LoggingTransaction,
+            events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
+            prev_group: int,
+        ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+            """Generate and store state groups for the provided events and contexts.
+
+            Requires that we have the state as a delta from the last persisted state group.
+
+            Returns:
+                The events and contexts, updated with their newly created state groups
+            """
+            is_in_db = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="state_groups",
+                keyvalues={"id": prev_group},
+                retcol="id",
+                allow_none=True,
+            )
+            if not is_in_db:
+                raise Exception(
+                    "Trying to persist state with unpersisted prev_group: %r"
+                    % (prev_group,)
+                )
+
+            num_state_groups = sum(
+                1 for event, _ in events_and_context if event.is_state()
+            )
+
+            state_groups = self._state_group_seq_gen.get_next_mult_txn(
+                txn, num_state_groups
+            )
+
+            sg_before = prev_group
+            state_group_iter = iter(state_groups)
+            for event, context in events_and_context:
+                if not event.is_state():
+                    context.state_group_after_event = sg_before
+                    context.state_group_before_event = sg_before
+                    continue
+
+                sg_after = next(state_group_iter)
+                context.state_group_after_event = sg_after
+                context.state_group_before_event = sg_before
+                context.state_delta_due_to_event = {
+                    (event.type, event.state_key): event.event_id
+                }
+                sg_before = sg_after
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups",
+                keys=("id", "room_id", "event_id"),
+                values=[
+                    (context.state_group_after_event, room_id, event.event_id)
+                    for event, context in events_and_context
+                    if event.is_state()
+                ],
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_group_edges",
+                keys=("state_group", "prev_state_group"),
+                values=[
+                    (
+                        context.state_group_after_event,
+                        context.state_group_before_event,
+                    )
+                    for event, context in events_and_context
+                    if event.is_state()
+                ],
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (
+                        context.state_group_after_event,
+                        room_id,
+                        key[0],
+                        key[1],
+                        state_id,
+                    )
+                    for event, context in events_and_context
+                    if context.state_delta_due_to_event is not None
+                    for key, state_id in context.state_delta_due_to_event.items()
+                ],
+            )
+            return events_and_context
+
+        return await self.db_pool.runInteraction(
+            "store_state_deltas_for_batched.insert_deltas_group",
+            insert_deltas_group_txn,
+            events_and_context,
+            prev_group,
+        )
+
     async def store_state_group(
         self,
         event_id: str,
@@ -689,12 +805,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete: State groups to delete
         """
 
+        logger.info("[purge] Starting state purge")
         await self.db_pool.runInteraction(
             "purge_room_state",
             self._purge_room_state_txn,
             room_id,
             state_groups_to_delete,
         )
+        logger.info("[purge] Done with state purge")
 
     def _purge_room_state_txn(
         self,