summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/__init__.py4
-rw-r--r--synapse/storage/databases/main/account_data.py76
-rw-r--r--synapse/storage/databases/main/cache.py3
-rw-r--r--synapse/storage/databases/main/deviceinbox.py5
-rw-r--r--synapse/storage/databases/main/devices.py11
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py2
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py8
-rw-r--r--synapse/storage/databases/main/event_federation.py1
-rw-r--r--synapse/storage/databases/main/events.py5
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py7
-rw-r--r--synapse/storage/databases/main/events_worker.py7
-rw-r--r--synapse/storage/databases/main/filtering.py25
-rw-r--r--synapse/storage/databases/main/media_repository.py1
-rw-r--r--synapse/storage/databases/main/push_rule.py3
-rw-r--r--synapse/storage/databases/main/pusher.py6
-rw-r--r--synapse/storage/databases/main/receipts.py7
-rw-r--r--synapse/storage/databases/main/registration.py13
-rw-r--r--synapse/storage/databases/main/relations.py137
-rw-r--r--synapse/storage/databases/main/room.py391
-rw-r--r--synapse/storage/databases/main/search.py2
-rw-r--r--synapse/storage/databases/main/state.py1
-rw-r--r--synapse/storage/databases/main/stats.py2
-rw-r--r--synapse/storage/databases/main/stream.py1
-rw-r--r--synapse/storage/databases/main/transactions.py1
-rw-r--r--synapse/storage/databases/main/user_directory.py77
-rw-r--r--synapse/storage/databases/state/bg_updates.py1
-rw-r--r--synapse/storage/databases/state/store.py126
-rw-r--r--synapse/storage/engines/__init__.py4
-rw-r--r--synapse/storage/prepare_database.py4
-rw-r--r--synapse/storage/types.py74
-rw-r--r--synapse/storage/util/id_generators.py62
-rw-r--r--synapse/storage/util/sequence.py2
32 files changed, 613 insertions, 456 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 837dc7646e..dc3948c170 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -43,7 +43,7 @@ from .event_federation import EventFederationStore
 from .event_push_actions import EventPushActionsStore
 from .events_bg_updates import EventsBackgroundUpdatesStore
 from .events_forward_extremities import EventForwardExtremitiesStore
-from .filtering import FilteringStore
+from .filtering import FilteringWorkerStore
 from .keys import KeyStore
 from .lock import LockStore
 from .media_repository import MediaRepositoryStore
@@ -99,7 +99,7 @@ class DataStore(
     EventFederationStore,
     MediaRepositoryStore,
     RejectionsStore,
-    FilteringStore,
+    FilteringWorkerStore,
     PusherStore,
     PushRuleStore,
     ApplicationServiceTransactionStore,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 95567826f2..a9843f6e17 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -40,7 +40,6 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -64,14 +63,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
     ):
         super().__init__(database, db_conn, hs)
 
-        # `_can_write_to_account_data` indicates whether the current worker is allowed
-        # to write account data. A value of `True` implies that `_account_data_id_gen`
-        # is an `AbstractStreamIdGenerator` and not just a tracker.
-        self._account_data_id_gen: AbstractStreamIdTracker
         self._can_write_to_account_data = (
             self._instance_name in hs.config.worker.writers.account_data
         )
 
+        self._account_data_id_gen: AbstractStreamIdGenerator
+
         if isinstance(database.engine, PostgresEngine):
             self._account_data_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
@@ -237,6 +234,37 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         else:
             return None
 
+    async def get_latest_stream_id_for_global_account_data_by_type_for_user(
+        self, user_id: str, data_type: str
+    ) -> Optional[int]:
+        """
+        Returns:
+            The stream ID of the account data,
+            or None if there is no such account data.
+        """
+
+        def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[int]:
+            sql = """
+                SELECT stream_id FROM account_data
+                WHERE user_id = ? AND account_data_type = ?
+                ORDER BY stream_id DESC
+                LIMIT 1
+            """
+            txn.execute(sql, (user_id, data_type))
+
+            row = txn.fetchone()
+            if row:
+                return row[0]
+            else:
+                return None
+
+        return await self.db_pool.runInteraction(
+            "get_latest_stream_id_for_global_account_data_by_type_for_user",
+            get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
+        )
+
     @cached(num_args=2, tree=True)
     async def get_account_data_for_room(
         self, user_id: str, room_id: str
@@ -527,7 +555,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         content_json = json_encoder.encode(content)
 
@@ -554,7 +581,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
     async def remove_account_data_for_room(
         self, user_id: str, room_id: str, account_data_type: str
-    ) -> Optional[int]:
+    ) -> int:
         """Delete the room account data for the user of a given type.
 
         Args:
@@ -567,7 +594,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             data to delete.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def _remove_account_data_for_room_txn(
             txn: LoggingTransaction, next_id: int
@@ -606,15 +632,13 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 next_id,
             )
 
-            if not row_updated:
-                return None
-
-            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_room_account_data_for_user.invalidate((user_id,))
-            self.get_account_data_for_room.invalidate((user_id, room_id))
-            self.get_account_data_for_room_and_type.prefill(
-                (user_id, room_id, account_data_type), {}
-            )
+            if row_updated:
+                self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+                self.get_room_account_data_for_user.invalidate((user_id,))
+                self.get_account_data_for_room.invalidate((user_id, room_id))
+                self.get_account_data_for_room_and_type.prefill(
+                    (user_id, room_id, account_data_type), {}
+                )
 
         return self._account_data_id_gen.get_current_token()
 
@@ -632,7 +656,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             The maximum stream ID.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
@@ -722,7 +745,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         self,
         user_id: str,
         account_data_type: str,
-    ) -> Optional[int]:
+    ) -> int:
         """
         Delete a single piece of user account data by type.
 
@@ -739,7 +762,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             to delete.
         """
         assert self._can_write_to_account_data
-        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)
 
         def _remove_account_data_for_user_txn(
             txn: LoggingTransaction, next_id: int
@@ -809,14 +831,12 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
                 next_id,
             )
 
-            if not row_updated:
-                return None
-
-            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
-            self.get_global_account_data_for_user.invalidate((user_id,))
-            self.get_global_account_data_by_type_for_user.prefill(
-                (user_id, account_data_type), {}
-            )
+            if row_updated:
+                self._account_data_stream_cache.entity_has_changed(user_id, next_id)
+                self.get_global_account_data_for_user.invalidate((user_id,))
+                self.get_global_account_data_by_type_for_user.prefill(
+                    (user_id, account_data_type), {}
+                )
 
         return self._account_data_id_gen.get_current_token()
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 5b66431691..096dec7f87 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -266,9 +266,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if relates_to:
             self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
             self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
-            self._attempt_to_invalidate_cache(
-                "get_aggregation_groups_for_event", (relates_to,)
-            )
             self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
             self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8e61aba454..0d75d9739a 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -721,8 +721,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                         ],
                     )
 
-                for (user_id, messages_by_device) in edu["messages"].items():
-                    for (device_id, msg) in messages_by_device.items():
+                for user_id, messages_by_device in edu["messages"].items():
+                    for device_id, msg in messages_by_device.items():
                         with start_active_span("store_outgoing_to_device_message"):
                             set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["sender"])
                             set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
@@ -959,7 +959,6 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
         def _remove_dead_devices_from_device_inbox_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, bool]:
-
             if "max_stream_id" in progress:
                 max_stream_id = progress["max_stream_id"]
             else:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 1ca66d57d4..5503621ad6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -52,7 +52,6 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     StreamIdGenerator,
 )
 from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
@@ -91,7 +90,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._device_list_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "device_lists_stream",
@@ -512,7 +511,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             results.append(("org.matrix.signing_key_update", result))
 
         if issue_8631_logger.isEnabledFor(logging.DEBUG):
-            for (user_id, edu) in results:
+            for user_id, edu in results:
                 issue_8631_logger.debug(
                     "device update to %s for %s from %s to %s: %s",
                     destination,
@@ -712,9 +711,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             The new stream ID.
         """
 
-        # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
-        #  than DeviceWorkerStore?
-        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -1316,7 +1313,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 )
             """
             count = 0
-            for (destination, user_id, stream_id, device_id) in rows:
+            for destination, user_id, stream_id, device_id in rows:
                 txn.execute(
                     delete_sql, (destination, user_id, stream_id, stream_id, device_id)
                 )
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 6240f9a75e..9f8d2e4bea 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -108,7 +108,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             raise StoreError(404, "No backup with that version exists")
 
         values = []
-        for (room_id, session_id, room_key) in room_keys:
+        for room_id, session_id, room_key in room_keys:
             values.append(
                 (
                     user_id,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 2c2d145666..b9c39b1718 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -268,7 +268,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             )
 
             # add each cross-signing signature to the correct device in the result dict.
-            for (user_id, key_id, device_id, signature) in cross_sigs_result:
+            for user_id, key_id, device_id, signature in cross_sigs_result:
                 target_device_result = result[user_id][device_id]
                 # We've only looked up cross-signatures for non-deleted devices with key
                 # data.
@@ -311,7 +311,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         # devices.
         user_list = []
         user_device_list = []
-        for (user_id, device_id) in query_list:
+        for user_id, device_id in query_list:
             if device_id is None:
                 user_list.append(user_id)
             else:
@@ -353,7 +353,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
             txn.execute(sql, query_params)
 
-            for (user_id, device_id, display_name, key_json) in txn:
+            for user_id, device_id, display_name, key_json in txn:
                 assert device_id is not None
                 if include_deleted_devices:
                     deleted_devices.remove((user_id, device_id))
@@ -382,7 +382,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         signature_query_clauses = []
         signature_query_params = []
 
-        for (user_id, device_id) in device_query:
+        for user_id, device_id in device_query:
             signature_query_clauses.append(
                 "target_user_id = ? AND target_device_id = ? AND user_id = ?"
             )
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ca780cca36..ff3edeb716 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1612,7 +1612,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         latest_events: List[str],
         limit: int,
     ) -> List[str]:
-
         seen_events = set(earliest_events)
         front = set(latest_events) - seen_events
         event_results: List[str] = []
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 7996cbb557..a8a4ed4436 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -469,7 +469,6 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         events: List[EventBase],
     ) -> None:
-
         # We only care about state events, so this if there are no state events.
         if not any(e.is_state() for e in events):
             return
@@ -2025,10 +2024,6 @@ class PersistEventsStore:
         self.store._invalidate_cache_and_stream(
             txn, self.store.get_relations_for_event, (redacted_relates_to,)
         )
-        if rel_type == RelationTypes.ANNOTATION:
-            self.store._invalidate_cache_and_stream(
-                txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
-            )
         if rel_type == RelationTypes.REFERENCE:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_references_for_event, (redacted_relates_to,)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 584536111d..daef3685b0 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -709,7 +709,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             nbrows = 0
             last_row_event_id = ""
-            for (event_id, event_json_raw) in results:
+            for event_id, event_json_raw in results:
                 try:
                     event_json = db_to_json(event_json_raw)
 
@@ -1167,7 +1167,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             results = list(txn)
             # (event_id, parent_id, rel_type) for each relation
             relations_to_insert: List[Tuple[str, str, str]] = []
-            for (event_id, event_json_raw) in results:
+            for event_id, event_json_raw in results:
                 try:
                     event_json = db_to_json(event_json_raw)
                 except Exception as e:
@@ -1220,9 +1220,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
                         txn, self.get_relations_for_event, cache_tuple  # type: ignore[attr-defined]
                     )
                     self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
-                        txn, self.get_aggregation_groups_for_event, cache_tuple  # type: ignore[attr-defined]
-                    )
-                    self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                         txn, self.get_thread_summary, cache_tuple  # type: ignore[attr-defined]
                     )
 
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6d0ef10258..20b7a68362 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -72,7 +72,6 @@ from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -187,8 +186,8 @@ class EventsWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self._stream_id_gen: AbstractStreamIdTracker
-        self._backfill_id_gen: AbstractStreamIdTracker
+        self._stream_id_gen: AbstractStreamIdGenerator
+        self._backfill_id_gen: AbstractStreamIdGenerator
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.
@@ -1493,7 +1492,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             txn.execute(redactions_sql + clause, args)
 
-            for (redacter, redacted) in txn:
+            for redacter, redacted in txn:
                 d = event_dict.get(redacted)
                 if d:
                     d.redactions.append(redacter)
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 12f3b601f1..8e57c8e5a0 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Union, cast
 
 from canonicaljson import encode_canonical_json
 
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict
@@ -46,8 +46,6 @@ class FilteringWorkerStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-
-class FilteringStore(FilteringWorkerStore):
     async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
         def_json = encode_canonical_json(user_filter)
 
@@ -79,4 +77,23 @@ class FilteringStore(FilteringWorkerStore):
 
             return filter_id
 
-        return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+        attempts = 0
+        while True:
+            # Try a few times.
+            # This is technically needed if a user tries to create two filters at once,
+            # leading to two concurrent transactions.
+            # The failure case would be:
+            # - SELECT filter_id ... filter_json = ? → both transactions return no rows
+            # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
+            # - INSERT INTO ... → both transactions insert filter_id = 6
+            # One of the transactions will commit. The other will get a unique key
+            # constraint violation error (IntegrityError). This is not the same as a
+            # serialisability violation, which would be automatically retried by
+            # `runInteraction`.
+            try:
+                return await self.db_pool.runInteraction("add_user_filter", _do_txn)
+            except self.db_pool.engine.module.IntegrityError:
+                attempts += 1
+
+                if attempts >= 5:
+                    raise StoreError(500, "Couldn't generate a filter ID.")
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index b202c5eb87..fa8be214ce 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -196,7 +196,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def get_local_media_by_user_paginate_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[Dict[str, Any]], int]:
-
             # Set ordering
             order_by_column = MediaSortOrder(order_by).value
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9b2bbe060d..9f862f00c1 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -46,7 +46,6 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     IdGenerator,
     StreamIdGenerator,
 )
@@ -118,7 +117,7 @@ class PushRulesWorkerStore(
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._push_rules_stream_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "push_rules_stream",
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index df53e726e6..9a24f7a655 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -36,7 +36,6 @@ from synapse.storage.database import (
 )
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
-    AbstractStreamIdTracker,
     StreamIdGenerator,
 )
 from synapse.types import JsonDict
@@ -60,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+        self._pushers_id_gen = StreamIdGenerator(
             db_conn,
             hs.get_replication_notifier(),
             "pushers",
@@ -344,7 +343,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_user = progress.get("last_user", "")
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT name FROM users
                 WHERE deactivated = ? and name > ?
@@ -392,7 +390,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_pusher = progress.get("last_pusher", 0)
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT p.id, access_token FROM pushers AS p
                 LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
@@ -449,7 +446,6 @@ class PusherWorkerStore(SQLBaseStore):
         last_pusher = progress.get("last_pusher", 0)
 
         def _delete_pushers(txn: LoggingTransaction) -> int:
-
             sql = """
                 SELECT p.id, p.user_name, p.app_id, p.pushkey
                 FROM pushers AS p
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index dddf49c2d5..074942b167 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -39,7 +39,7 @@ from synapse.storage.database import (
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.engines._base import IsolationLevel
 from synapse.storage.util.id_generators import (
-    AbstractStreamIdTracker,
+    AbstractStreamIdGenerator,
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
@@ -65,7 +65,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._receipts_id_gen: AbstractStreamIdTracker
+        self._receipts_id_gen: AbstractStreamIdGenerator
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_receipts = (
@@ -768,7 +768,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
             )
 
-        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
+        async with self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self._insert_linearized_receipt_txn,
@@ -887,7 +887,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
         def _populate_receipt_event_stream_ordering_txn(
             txn: LoggingTransaction,
         ) -> bool:
-
             if "max_stream_id" in progress:
                 max_stream_id = progress["max_stream_id"]
             else:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 9a55e17624..717237e024 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1002,19 +1002,6 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="user_delete_threepid",
         )
 
-    async def user_delete_threepids(self, user_id: str) -> None:
-        """Delete all threepid this user has bound
-
-        Args:
-             user_id: The user id to delete all threepids of
-
-        """
-        await self.db_pool.simple_delete(
-            "user_threepids",
-            keyvalues={"user_id": user_id},
-            desc="user_delete_threepids",
-        )
-
     async def add_user_bound_threepid(
         self, user_id: str, medium: str, address: str, id_server: str
     ) -> None:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index fa3266c081..bc3a83919c 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -398,143 +398,6 @@ class RelationsWorkerStore(SQLBaseStore):
         return result is not None
 
     @cached()
-    async def get_aggregation_groups_for_event(
-        self, event_id: str
-    ) -> Sequence[JsonDict]:
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
-    )
-    async def get_aggregation_groups_for_events(
-        self, event_ids: Collection[str]
-    ) -> Mapping[str, Optional[List[JsonDict]]]:
-        """Get a list of annotations on the given events, grouped by event type and
-        aggregation key, sorted by count.
-
-        This is used e.g. to get the what and how many reactions have happend
-        on an event.
-
-        Args:
-            event_ids: Fetch events that relate to these event IDs.
-
-        Returns:
-            A map of event IDs to a list of groups of annotations that match.
-            Each entry is a dict with `type`, `key` and `count` fields.
-        """
-        # The number of entries to return per event ID.
-        limit = 5
-
-        clause, args = make_in_list_sql_clause(
-            self.database_engine, "relates_to_id", event_ids
-        )
-        args.append(RelationTypes.ANNOTATION)
-
-        sql = f"""
-            SELECT
-                relates_to_id,
-                annotation.type,
-                aggregation_key,
-                COUNT(DISTINCT annotation.sender)
-            FROM events AS annotation
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS parent ON
-                parent.event_id = relates_to_id
-                AND parent.room_id = annotation.room_id
-            WHERE
-                {clause}
-                AND relation_type = ?
-            GROUP BY relates_to_id, annotation.type, aggregation_key
-            ORDER BY relates_to_id, COUNT(*) DESC
-        """
-
-        def _get_aggregation_groups_for_events_txn(
-            txn: LoggingTransaction,
-        ) -> Mapping[str, List[JsonDict]]:
-            txn.execute(sql, args)
-
-            result: Dict[str, List[JsonDict]] = {}
-            for event_id, type, key, count in cast(
-                List[Tuple[str, str, str, int]], txn
-            ):
-                event_results = result.setdefault(event_id, [])
-
-                # Limit the number of results per event ID.
-                if len(event_results) == limit:
-                    continue
-
-                event_results.append({"type": type, "key": key, "count": count})
-
-            return result
-
-        return await self.db_pool.runInteraction(
-            "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
-        )
-
-    async def get_aggregation_groups_for_users(
-        self, event_ids: Collection[str], users: FrozenSet[str]
-    ) -> Dict[str, Dict[Tuple[str, str], int]]:
-        """Fetch the partial aggregations for an event for specific users.
-
-        This is used, in conjunction with get_aggregation_groups_for_event, to
-        remove information from the results for ignored users.
-
-        Args:
-            event_ids: Fetch events that relate to these event IDs.
-            users: The users to fetch information for.
-
-        Returns:
-            A map of event ID to a map of (event type, aggregation key) to a
-            count of users.
-        """
-
-        if not users:
-            return {}
-
-        events_sql, args = make_in_list_sql_clause(
-            self.database_engine, "relates_to_id", event_ids
-        )
-
-        users_sql, users_args = make_in_list_sql_clause(
-            self.database_engine, "annotation.sender", users
-        )
-        args.extend(users_args)
-        args.append(RelationTypes.ANNOTATION)
-
-        sql = f"""
-            SELECT
-                relates_to_id,
-                annotation.type,
-                aggregation_key,
-                COUNT(DISTINCT annotation.sender)
-            FROM events AS annotation
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS parent ON
-                parent.event_id = relates_to_id
-                AND parent.room_id = annotation.room_id
-            WHERE {events_sql} AND {users_sql} AND relation_type = ?
-            GROUP BY relates_to_id, annotation.type, aggregation_key
-            ORDER BY relates_to_id, COUNT(*) DESC
-        """
-
-        def _get_aggregation_groups_for_users_txn(
-            txn: LoggingTransaction,
-        ) -> Dict[str, Dict[Tuple[str, str], int]]:
-            txn.execute(sql, args)
-
-            result: Dict[str, Dict[Tuple[str, str], int]] = {}
-            for event_id, type, key, count in cast(
-                List[Tuple[str, str, str, int]], txn
-            ):
-                result.setdefault(event_id, {})[(type, key)] = count
-
-            return result
-
-        return await self.db_pool.runInteraction(
-            "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
-        )
-
-    @cached()
     async def get_references_for_event(self, event_id: str) -> List[JsonDict]:
         raise NotImplementedError()
 
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 644bbb8878..3825bd6079 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1417,6 +1417,204 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             get_un_partial_stated_rooms_from_stream_txn,
         )
 
+    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
+        """Retrieve an event report
+
+        Args:
+            report_id: ID of reported event in database
+        Returns:
+            JSON dict of information from an event report or None if the
+            report does not exist.
+        """
+
+        def _get_event_report_txn(
+            txn: LoggingTransaction, report_id: int
+        ) -> Optional[Dict[str, Any]]:
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.content,
+                    events.sender,
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
+                WHERE er.id = ?
+            """
+
+            txn.execute(sql, [report_id])
+            row = txn.fetchone()
+
+            if not row:
+                return None
+
+            event_report = {
+                "id": row[0],
+                "received_ts": row[1],
+                "room_id": row[2],
+                "event_id": row[3],
+                "user_id": row[4],
+                "score": db_to_json(row[5]).get("score"),
+                "reason": db_to_json(row[5]).get("reason"),
+                "sender": row[6],
+                "canonical_alias": row[7],
+                "name": row[8],
+                "event_json": db_to_json(row[9]),
+            }
+
+            return event_report
+
+        return await self.db_pool.runInteraction(
+            "get_event_report", _get_event_report_txn, report_id
+        )
+
+    async def get_event_reports_paginate(
+        self,
+        start: int,
+        limit: int,
+        direction: Direction = Direction.BACKWARDS,
+        user_id: Optional[str] = None,
+        room_id: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Retrieve a paginated list of event reports
+
+        Args:
+            start: event offset to begin the query from
+            limit: number of rows to retrieve
+            direction: Whether to fetch the most recent first (backwards) or the
+                oldest first (forwards)
+            user_id: search for user_id. Ignored if user_id is None
+            room_id: search for room_id. Ignored if room_id is None
+        Returns:
+            Tuple of:
+                json list of event reports
+                total number of event reports matching the filter criteria
+        """
+
+        def _get_event_reports_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], int]:
+            filters = []
+            args: List[object] = []
+
+            if user_id:
+                filters.append("er.user_id LIKE ?")
+                args.extend(["%" + user_id + "%"])
+            if room_id:
+                filters.append("er.room_id LIKE ?")
+                args.extend(["%" + room_id + "%"])
+
+            if direction == Direction.BACKWARDS:
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            # We join on room_stats_state despite not using any columns from it
+            # because the join can influence the number of rows returned;
+            # e.g. a room that doesn't have state, maybe because it was deleted.
+            # The query returning the total count should be consistent with
+            # the query returning the results.
+            sql = """
+                SELECT COUNT(*) as total_event_reports
+                FROM event_reports AS er
+                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
+                {}
+                """.format(
+                where_clause
+            )
+            txn.execute(sql, args)
+            count = cast(Tuple[int], txn.fetchone())[0]
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.content,
+                    events.sender,
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name
+                FROM event_reports AS er
+                LEFT JOIN events
+                    ON events.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
+                {where_clause}
+                ORDER BY er.received_ts {order}
+                LIMIT ?
+                OFFSET ?
+            """.format(
+                where_clause=where_clause,
+                order=order,
+            )
+
+            args += [limit, start]
+            txn.execute(sql, args)
+
+            event_reports = []
+            for row in txn:
+                try:
+                    s = db_to_json(row[5]).get("score")
+                    r = db_to_json(row[5]).get("reason")
+                except Exception:
+                    logger.error("Unable to parse json from event_reports: %s", row[0])
+                    continue
+                event_reports.append(
+                    {
+                        "id": row[0],
+                        "received_ts": row[1],
+                        "room_id": row[2],
+                        "event_id": row[3],
+                        "user_id": row[4],
+                        "score": s,
+                        "reason": r,
+                        "sender": row[6],
+                        "canonical_alias": row[7],
+                        "name": row[8],
+                    }
+                )
+
+            return event_reports, count
+
+        return await self.db_pool.runInteraction(
+            "get_event_reports_paginate", _get_event_reports_paginate_txn
+        )
+
+    async def delete_event_report(self, report_id: int) -> bool:
+        """Remove an event report from database.
+
+        Args:
+            report_id: Report to delete
+
+        Returns:
+            Whether the report was successfully deleted or not.
+        """
+        try:
+            await self.db_pool.simple_delete_one(
+                table="event_reports",
+                keyvalues={"id": report_id},
+                desc="delete_event_report",
+            )
+        except StoreError:
+            # Deletion failed because report does not exist
+            return False
+
+        return True
+
 
 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -2139,7 +2337,19 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         reason: Optional[str],
         content: JsonDict,
         received_ts: int,
-    ) -> None:
+    ) -> int:
+        """Add an event report
+
+        Args:
+            room_id: Room that contains the reported event.
+            event_id: The reported event.
+            user_id: User who reports the event.
+            reason: Description that the user specifies.
+            content: Report request body (score and reason).
+            received_ts: Time when the user submitted the report (milliseconds).
+        Returns:
+            Id of the event report.
+        """
         next_id = self._event_reports_id_gen.get_next()
         await self.db_pool.simple_insert(
             table="event_reports",
@@ -2154,184 +2364,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             },
             desc="add_event_report",
         )
-
-    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
-        """Retrieve an event report
-
-        Args:
-            report_id: ID of reported event in database
-        Returns:
-            JSON dict of information from an event report or None if the
-            report does not exist.
-        """
-
-        def _get_event_report_txn(
-            txn: LoggingTransaction, report_id: int
-        ) -> Optional[Dict[str, Any]]:
-
-            sql = """
-                SELECT
-                    er.id,
-                    er.received_ts,
-                    er.room_id,
-                    er.event_id,
-                    er.user_id,
-                    er.content,
-                    events.sender,
-                    room_stats_state.canonical_alias,
-                    room_stats_state.name,
-                    event_json.json AS event_json
-                FROM event_reports AS er
-                LEFT JOIN events
-                    ON events.event_id = er.event_id
-                JOIN event_json
-                    ON event_json.event_id = er.event_id
-                JOIN room_stats_state
-                    ON room_stats_state.room_id = er.room_id
-                WHERE er.id = ?
-            """
-
-            txn.execute(sql, [report_id])
-            row = txn.fetchone()
-
-            if not row:
-                return None
-
-            event_report = {
-                "id": row[0],
-                "received_ts": row[1],
-                "room_id": row[2],
-                "event_id": row[3],
-                "user_id": row[4],
-                "score": db_to_json(row[5]).get("score"),
-                "reason": db_to_json(row[5]).get("reason"),
-                "sender": row[6],
-                "canonical_alias": row[7],
-                "name": row[8],
-                "event_json": db_to_json(row[9]),
-            }
-
-            return event_report
-
-        return await self.db_pool.runInteraction(
-            "get_event_report", _get_event_report_txn, report_id
-        )
-
-    async def get_event_reports_paginate(
-        self,
-        start: int,
-        limit: int,
-        direction: Direction = Direction.BACKWARDS,
-        user_id: Optional[str] = None,
-        room_id: Optional[str] = None,
-    ) -> Tuple[List[Dict[str, Any]], int]:
-        """Retrieve a paginated list of event reports
-
-        Args:
-            start: event offset to begin the query from
-            limit: number of rows to retrieve
-            direction: Whether to fetch the most recent first (backwards) or the
-                oldest first (forwards)
-            user_id: search for user_id. Ignored if user_id is None
-            room_id: search for room_id. Ignored if room_id is None
-        Returns:
-            Tuple of:
-                json list of event reports
-                total number of event reports matching the filter criteria
-        """
-
-        def _get_event_reports_paginate_txn(
-            txn: LoggingTransaction,
-        ) -> Tuple[List[Dict[str, Any]], int]:
-            filters = []
-            args: List[object] = []
-
-            if user_id:
-                filters.append("er.user_id LIKE ?")
-                args.extend(["%" + user_id + "%"])
-            if room_id:
-                filters.append("er.room_id LIKE ?")
-                args.extend(["%" + room_id + "%"])
-
-            if direction == Direction.BACKWARDS:
-                order = "DESC"
-            else:
-                order = "ASC"
-
-            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
-
-            # We join on room_stats_state despite not using any columns from it
-            # because the join can influence the number of rows returned;
-            # e.g. a room that doesn't have state, maybe because it was deleted.
-            # The query returning the total count should be consistent with
-            # the query returning the results.
-            sql = """
-                SELECT COUNT(*) as total_event_reports
-                FROM event_reports AS er
-                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
-                {}
-                """.format(
-                where_clause
-            )
-            txn.execute(sql, args)
-            count = cast(Tuple[int], txn.fetchone())[0]
-
-            sql = """
-                SELECT
-                    er.id,
-                    er.received_ts,
-                    er.room_id,
-                    er.event_id,
-                    er.user_id,
-                    er.content,
-                    events.sender,
-                    room_stats_state.canonical_alias,
-                    room_stats_state.name
-                FROM event_reports AS er
-                LEFT JOIN events
-                    ON events.event_id = er.event_id
-                JOIN room_stats_state
-                    ON room_stats_state.room_id = er.room_id
-                {where_clause}
-                ORDER BY er.received_ts {order}
-                LIMIT ?
-                OFFSET ?
-            """.format(
-                where_clause=where_clause,
-                order=order,
-            )
-
-            args += [limit, start]
-            txn.execute(sql, args)
-
-            event_reports = []
-            for row in txn:
-                try:
-                    s = db_to_json(row[5]).get("score")
-                    r = db_to_json(row[5]).get("reason")
-                except Exception:
-                    logger.error("Unable to parse json from event_reports: %s", row[0])
-                    continue
-                event_reports.append(
-                    {
-                        "id": row[0],
-                        "received_ts": row[1],
-                        "room_id": row[2],
-                        "event_id": row[3],
-                        "user_id": row[4],
-                        "score": s,
-                        "reason": r,
-                        "sender": row[6],
-                        "canonical_alias": row[7],
-                        "name": row[8],
-                    }
-                )
-
-            return event_reports, count
-
-        return await self.db_pool.runInteraction(
-            "get_event_reports_paginate", _get_event_reports_paginate_txn
-        )
+        return next_id
 
     async def block_room(self, room_id: str, user_id: str) -> None:
         """Marks the room as blocked.
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 3fe433f66c..a7aae661d8 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -122,7 +122,6 @@ class SearchWorkerStore(SQLBaseStore):
 
 
 class SearchBackgroundUpdateStore(SearchWorkerStore):
-
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
@@ -615,7 +614,6 @@ class SearchStore(SearchBackgroundUpdateStore):
             """
             count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):
-
             # We use CROSS JOIN here to ensure we use the right indexes.
             # https://sqlite.org/optoverview.html#crossjoin
             #
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ba325d390b..ebb2ae964f 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -490,7 +490,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
 
 class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
-
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index d7b7d0c3c9..d3393d8e49 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -461,7 +461,7 @@ class StatsStore(StateDeltasStore):
         insert_cols = []
         qargs = []
 
-        for (key, val) in chain(
+        for key, val in chain(
             keyvalues.items(), absolutes.items(), additive_relatives.items()
         ):
             insert_cols.append(key)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 818c46182e..ac5fbf6b86 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -87,6 +87,7 @@ MAX_STREAM_SIZE = 1000
 _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
+
 # Used as return values for pagination APIs
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _EventDictReturn:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 6b33d809b6..6d72bd9f67 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -573,7 +573,6 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         def get_destination_rooms_paginate_txn(
             txn: LoggingTransaction,
         ) -> Tuple[List[JsonDict], int]:
-
             if direction == Direction.BACKWARDS:
                 order = "DESC"
             else:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f6a6fd4079..f16a509ac4 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -14,6 +14,7 @@
 
 import logging
 import re
+import unicodedata
 from typing import (
     TYPE_CHECKING,
     Iterable,
@@ -98,7 +99,6 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     async def _populate_user_directory_createtables(
         self, progress: JsonDict, batch_size: int
     ) -> int:
-
         # Get all the rooms that we want to process.
         def _make_staging_area(txn: LoggingTransaction) -> None:
             sql = (
@@ -491,6 +491,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 values={"display_name": display_name, "avatar_url": avatar_url},
             )
 
+            # The display name that goes into the database index.
+            index_display_name = display_name
+            if index_display_name is not None:
+                index_display_name = _filter_text_for_index(index_display_name)
+
             if isinstance(self.database_engine, PostgresEngine):
                 # We weight the localpart most highly, then display name and finally
                 # server name
@@ -508,11 +513,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                         user_id,
                         get_localpart_from_id(user_id),
                         get_domain_from_id(user_id),
-                        display_name,
+                        index_display_name,
                     ),
                 )
             elif isinstance(self.database_engine, Sqlite3Engine):
-                value = "%s %s" % (user_id, display_name) if display_name else user_id
+                value = (
+                    "%s %s" % (user_id, index_display_name)
+                    if index_display_name
+                    else user_id
+                )
                 self.db_pool.simple_upsert_txn(
                     txn,
                     table="user_directory_search",
@@ -897,6 +906,41 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         return {"limited": limited, "results": results[0:limit]}
 
 
+def _filter_text_for_index(text: str) -> str:
+    """Transforms text before it is inserted into the user directory index, or searched
+    for in the user directory index.
+
+    Note that the user directory search table needs to be rebuilt whenever this function
+    changes.
+    """
+    # Lowercase the text, to make searches case-insensitive.
+    # This is necessary for both PostgreSQL and SQLite. PostgreSQL's
+    # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using
+    # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at
+    # all.
+    text = text.lower()
+
+    # Normalize the text. NFKC normalization has two effects:
+    #  1. It canonicalizes the text, ie. maps all visually identical strings to the same
+    #     string. For example, ["e", "◌́"] is mapped to ["é"].
+    #  2. It maps strings that are roughly equivalent to the same string.
+    #     For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to
+    #     ["i", "9"].
+    text = unicodedata.normalize("NFKC", text)
+
+    # Note that nothing is done to make searches accent-insensitive.
+    # That could be achieved by converting to NFKD form instead (with combining accents
+    # split out) and filtering out combining accents using `unicodedata.combining(c)`.
+    # The downside of this may be noisier search results, since search terms with
+    # explicit accents will match characters with no accents, or completely different
+    # accents.
+    #
+    # text = unicodedata.normalize("NFKD", text)
+    # text = "".join([c for c in text if not unicodedata.combining(c)])
+
+    return text
+
+
 def _parse_query_sqlite(search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
@@ -906,6 +950,7 @@ def _parse_query_sqlite(search_term: str) -> str:
     We specifically add both a prefix and non prefix matching term so that
     exact matches get ranked higher.
     """
+    search_term = _filter_text_for_index(search_term)
 
     # Pull out the individual words, discarding any non-word characters.
     results = _parse_words(search_term)
@@ -918,11 +963,21 @@ def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
     We use this so that we can add prefix matching, which isn't something
     that is supported by default.
     """
-    results = _parse_words(search_term)
+    search_term = _filter_text_for_index(search_term)
+
+    escaped_words = []
+    for word in _parse_words(search_term):
+        # Postgres tsvector and tsquery quoting rules:
+        # words potentially containing punctuation should be quoted
+        # and then existing quotes and backslashes should be doubled
+        # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY
 
-    both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
-    exact = " & ".join("%s" % (result,) for result in results)
-    prefix = " & ".join("%s:*" % (result,) for result in results)
+        quoted_word = word.replace("'", "''").replace("\\", "\\\\")
+        escaped_words.append(f"'{quoted_word}'")
+
+    both = " & ".join("(%s:* | %s)" % (word, word) for word in escaped_words)
+    exact = " & ".join("%s" % (word,) for word in escaped_words)
+    prefix = " & ".join("%s:*" % (word,) for word in escaped_words)
 
     return both, exact, prefix
 
@@ -944,6 +999,14 @@ def _parse_words(search_term: str) -> List[str]:
     if USE_ICU:
         return _parse_words_with_icu(search_term)
 
+    return _parse_words_with_regex(search_term)
+
+
+def _parse_words_with_regex(search_term: str) -> List[str]:
+    """
+    Break down search term into words, when we don't have ICU available.
+    See: `_parse_words`
+    """
     return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
 
 
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index d743282f13..097dea5182 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -251,7 +251,6 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
 
 
 class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
-
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 1a7232b276..bf4cdfdf29 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Se
 import attr
 
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -257,14 +259,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         member_filter, non_member_filter = state_filter.get_member_split()
 
         # Now we look them up in the member and non-member caches
-        (
-            non_member_state,
-            incomplete_groups_nm,
-        ) = self._get_state_for_groups_using_cache(
+        non_member_state, incomplete_groups_nm = self._get_state_for_groups_using_cache(
             groups, self._state_group_cache, state_filter=non_member_filter
         )
 
-        (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
+        member_state, incomplete_groups_m = self._get_state_for_groups_using_cache(
             groups, self._state_group_members_cache, state_filter=member_filter
         )
 
@@ -404,6 +403,123 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 fetched_keys=non_member_types,
             )
 
+    async def store_state_deltas_for_batched(
+        self,
+        events_and_context: List[Tuple[EventBase, UnpersistedEventContextBase]],
+        room_id: str,
+        prev_group: int,
+    ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+        """Generate and store state deltas for a group of events and contexts created to be
+        batch persisted. Note that all the events must be in a linear chain (ie a <- b <- c).
+
+        Args:
+            events_and_context: the events to generate and store a state groups for
+            and their associated contexts
+            room_id: the id of the room the events were created for
+            prev_group: the state group of the last event persisted before the batched events
+            were created
+        """
+
+        def insert_deltas_group_txn(
+            txn: LoggingTransaction,
+            events_and_context: List[Tuple[EventBase, UnpersistedEventContext]],
+            prev_group: int,
+        ) -> List[Tuple[EventBase, UnpersistedEventContext]]:
+            """Generate and store state groups for the provided events and contexts.
+
+            Requires that we have the state as a delta from the last persisted state group.
+
+            Returns:
+                A list of state groups
+            """
+            is_in_db = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="state_groups",
+                keyvalues={"id": prev_group},
+                retcol="id",
+                allow_none=True,
+            )
+            if not is_in_db:
+                raise Exception(
+                    "Trying to persist state with unpersisted prev_group: %r"
+                    % (prev_group,)
+                )
+
+            num_state_groups = sum(
+                1 for event, _ in events_and_context if event.is_state()
+            )
+
+            state_groups = self._state_group_seq_gen.get_next_mult_txn(
+                txn, num_state_groups
+            )
+
+            sg_before = prev_group
+            state_group_iter = iter(state_groups)
+            for event, context in events_and_context:
+                if not event.is_state():
+                    context.state_group_after_event = sg_before
+                    context.state_group_before_event = sg_before
+                    continue
+
+                sg_after = next(state_group_iter)
+                context.state_group_after_event = sg_after
+                context.state_group_before_event = sg_before
+                context.state_delta_due_to_event = {
+                    (event.type, event.state_key): event.event_id
+                }
+                sg_before = sg_after
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups",
+                keys=("id", "room_id", "event_id"),
+                values=[
+                    (context.state_group_after_event, room_id, event.event_id)
+                    for event, context in events_and_context
+                    if event.is_state()
+                ],
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_group_edges",
+                keys=("state_group", "prev_state_group"),
+                values=[
+                    (
+                        context.state_group_after_event,
+                        context.state_group_before_event,
+                    )
+                    for event, context in events_and_context
+                    if event.is_state()
+                ],
+            )
+
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="state_groups_state",
+                keys=("state_group", "room_id", "type", "state_key", "event_id"),
+                values=[
+                    (
+                        context.state_group_after_event,
+                        room_id,
+                        key[0],
+                        key[1],
+                        state_id,
+                    )
+                    for event, context in events_and_context
+                    if context.state_delta_due_to_event is not None
+                    for key, state_id in context.state_delta_due_to_event.items()
+                ],
+            )
+            return events_and_context
+
+        return await self.db_pool.runInteraction(
+            "store_state_deltas_for_batched.insert_deltas_group",
+            insert_deltas_group_txn,
+            events_and_context,
+            prev_group,
+        )
+
     async def store_state_group(
         self,
         event_id: str,
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index a182e8a098..d1ccb7390a 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -25,7 +25,7 @@ try:
 except ImportError:
 
     class PostgresEngine(BaseDatabaseEngine):  # type: ignore[no-redef]
-        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:  # type: ignore[misc]
+        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
             raise RuntimeError(
                 f"Cannot create {cls.__name__} -- psycopg2 module is not installed"
             )
@@ -36,7 +36,7 @@ try:
 except ImportError:
 
     class Sqlite3Engine(BaseDatabaseEngine):  # type: ignore[no-redef]
-        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:  # type: ignore[misc]
+        def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
             raise RuntimeError(
                 f"Cannot create {cls.__name__} -- sqlite3 module is not installed"
             )
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6c335a9315..2a1c6fa31b 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -563,7 +563,7 @@ def _apply_module_schemas(
     """
     # This is the old way for password_auth_provider modules to make changes
     # to the database. This should instead be done using the module API
-    for (mod, _config) in config.authproviders.password_providers:
+    for mod, _config in config.authproviders.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
             continue
         modname = ".".join((mod.__module__, mod.__name__))
@@ -591,7 +591,7 @@ def _apply_module_schema_files(
         (modname,),
     )
     applied_deltas = {d for d, in cur}
-    for (name, stream) in names_and_streams:
+    for name, stream in names_and_streams:
         if name in applied_deltas:
             continue
 
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 0031df1e06..56a0048539 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -12,7 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from types import TracebackType
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 from typing_extensions import Protocol
 
@@ -112,15 +123,35 @@ class DBAPI2Module(Protocol):
     #   extends from this hierarchy. See
     #     https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions
     #     https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
-    Warning: Type[Exception]
-    Error: Type[Exception]
+    #
+    # Note: rather than
+    #     x: T
+    # we write
+    #     @property
+    #     def x(self) -> T: ...
+    # which expresses that the protocol attribute `x` is read-only. The mypy docs
+    #     https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected
+    # explain why this is necessary for safety. TL;DR: we shouldn't be able to write
+    # to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 .
+    @property
+    def Warning(self) -> Type[Exception]:
+        ...
+
+    @property
+    def Error(self) -> Type[Exception]:
+        ...
 
     # Errors are divided into `InterfaceError`s (something went wrong in the database
     # driver) and `DatabaseError`s (something went wrong in the database). These are
     # both subclasses of `Error`, but we can't currently express this in type
     # annotations due to https://github.com/python/mypy/issues/8397
-    InterfaceError: Type[Exception]
-    DatabaseError: Type[Exception]
+    @property
+    def InterfaceError(self) -> Type[Exception]:
+        ...
+
+    @property
+    def DatabaseError(self) -> Type[Exception]:
+        ...
 
     # Everything below is a subclass of `DatabaseError`.
 
@@ -128,7 +159,9 @@ class DBAPI2Module(Protocol):
     # - An integer was too big for its data type.
     # - An invalid date time was provided.
     # - A string contained a null code point.
-    DataError: Type[Exception]
+    @property
+    def DataError(self) -> Type[Exception]:
+        ...
 
     # Roughly: something went wrong in the database, but it's not within the application
     # programmer's control. Examples:
@@ -138,28 +171,45 @@ class DBAPI2Module(Protocol):
     # - A serialisation failure occurred.
     # - The database ran out of resources, such as storage, memory, connections, etc.
     # - The database encountered an error from the operating system.
-    OperationalError: Type[Exception]
+    @property
+    def OperationalError(self) -> Type[Exception]:
+        ...
 
     # Roughly: we've given the database data which breaks a rule we asked it to enforce.
     # Examples:
     # - Stop, criminal scum! You violated the foreign key constraint
     # - Also check constraints, non-null constraints, etc.
-    IntegrityError: Type[Exception]
+    @property
+    def IntegrityError(self) -> Type[Exception]:
+        ...
 
     # Roughly: something went wrong within the database server itself.
-    InternalError: Type[Exception]
+    @property
+    def InternalError(self) -> Type[Exception]:
+        ...
 
     # Roughly: the application did something silly that needs to be fixed. Examples:
     # - We don't have permissions to do something.
     # - We tried to create a table with duplicate column names.
     # - We tried to use a reserved name.
     # - We referred to a column that doesn't exist.
-    ProgrammingError: Type[Exception]
+    @property
+    def ProgrammingError(self) -> Type[Exception]:
+        ...
 
     # Roughly: we've tried to do something that this database doesn't support.
-    NotSupportedError: Type[Exception]
+    @property
+    def NotSupportedError(self) -> Type[Exception]:
+        ...
 
-    def connect(self, **parameters: object) -> Connection:
+    # We originally wrote
+    # def connect(self, *args, **kwargs) -> Connection: ...
+    # But mypy doesn't seem to like that because sqlite3.connect takes a mandatory
+    # positional argument. We can't make that part of the signature though, because
+    # psycopg2.connect doesn't have a mandatory positional argument. Instead, we use
+    # the following slightly unusual workaround.
+    @property
+    def connect(self) -> Callable[..., Connection]:
         ...
 
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9adff3f4f5..d2c874b9a8 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -93,8 +93,11 @@ def _load_current_id(
     return res
 
 
-class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
-    """Tracks the "current" stream ID of a stream that may have multiple writers.
+class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
+    """Generates or tracks stream IDs for a stream that may have multiple writers.
+
+    Each stream ID represents a write transaction, whose completion is tracked
+    so that the "current" stream ID of the stream can be determined.
 
     Stream IDs are monotonically increasing or decreasing integers representing write
     transactions. The "current" stream ID is the stream ID such that all transactions
@@ -130,16 +133,6 @@ class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()
 
-
-class AbstractStreamIdGenerator(AbstractStreamIdTracker):
-    """Generates stream IDs for a stream that may have multiple writers.
-
-    Each stream ID represents a write transaction, whose completion is tracked
-    so that the "current" stream ID of the stream can be determined.
-
-    See `AbstractStreamIdTracker` for more details.
-    """
-
     @abc.abstractmethod
     def get_next(self) -> AsyncContextManager[int]:
         """
@@ -158,6 +151,15 @@ class AbstractStreamIdGenerator(AbstractStreamIdTracker):
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
+        """
+        Usage:
+            stream_id_gen.get_next_txn(txn)
+            # ... persist events ...
+        """
+        raise NotImplementedError()
+
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
     """Generates and tracks stream IDs for a stream with a single writer.
@@ -263,6 +265,40 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
 
         return _AsyncCtxManagerWrapper(manager())
 
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
+        """
+        Retrieve the next stream ID from within a database transaction.
+
+        Clean-up functions will be called when the transaction finishes.
+
+        Args:
+            txn: The database transaction object.
+
+        Returns:
+            The next stream ID.
+        """
+        if not self._is_writer:
+            raise Exception("Tried to allocate stream ID on non-writer")
+
+        # Get the next stream ID.
+        with self._lock:
+            self._current += self._step
+            next_id = self._current
+
+            self._unfinished_ids[next_id] = next_id
+
+        def clear_unfinished_id(id_to_clear: int) -> None:
+            """A function to mark processing this ID as finished"""
+            with self._lock:
+                self._unfinished_ids.pop(id_to_clear)
+
+        # Mark this ID as finished once the database transaction itself finishes.
+        txn.call_after(clear_unfinished_id, next_id)
+        txn.call_on_exception(clear_unfinished_id, next_id)
+
+        # Return the new ID.
+        return next_id
+
     def get_current_token(self) -> int:
         if not self._is_writer:
             return self._current
@@ -568,7 +604,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         """
         Usage:
 
-            stream_id = stream_id_gen.get_next(txn)
+            stream_id = stream_id_gen.get_next_txn(txn)
             # ... persist event ...
         """
 
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 75268cbe15..80915216de 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -205,7 +205,7 @@ class LocalSequenceGenerator(SequenceGenerator):
         """
         Args:
             get_first_callback: a callback which is called on the first call to
-                 get_next_id_txn; should return the curreent maximum id
+                 get_next_id_txn; should return the current maximum id
         """
         # the callback. this is cleared after it is called, so that it can be GCed.
         self._callback: Optional[GetFirstCallbackType] = get_first_callback