author     Brendan Abolivier <babolivier@matrix.org>  2022-03-28 13:54:02 +0100
committer  Brendan Abolivier <babolivier@matrix.org>  2022-03-28 13:54:02 +0100
commit     25507bffc67c40e83cbcd4a79fdfee3667855a7c (patch)
tree       5620b2a06a5a9894ac875ddcf3b232db45cae48d /synapse/storage
parent     Merge branch 'develop' of github.com:matrix-org/synapse into babolivier/sign_... (diff)
parent     Add restrictions by default to open registration in Synapse (#12091) (diff)
download   synapse-github/babolivier/sign_json_module.tar.xz

Merge branch 'develop' into babolivier/sign_json_module
(refs: github/babolivier/sign_json_module, babolivier/sign_json_module)
Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/background_updates.py                   |  47
-rw-r--r--  synapse/storage/database.py                              |  75
-rw-r--r--  synapse/storage/databases/main/account_data.py           |  41
-rw-r--r--  synapse/storage/databases/main/cache.py                  |  67
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py            |   3
-rw-r--r--  synapse/storage/databases/main/events.py                 |  61
-rw-r--r--  synapse/storage/databases/main/events_worker.py          |   4
-rw-r--r--  synapse/storage/databases/main/group_server.py           | 156
-rw-r--r--  synapse/storage/databases/main/monthly_active_users.py   |  38
-rw-r--r--  synapse/storage/databases/main/registration.py           |   2
-rw-r--r--  synapse/storage/databases/main/relations.py              | 218
-rw-r--r--  synapse/storage/databases/main/roommember.py             |  51
-rw-r--r--  synapse/storage/databases/main/search.py                 |  13
-rw-r--r--  synapse/storage/databases/main/stream.py                 |  18
-rw-r--r--  synapse/storage/databases/main/user_directory.py         |  29
-rw-r--r--  synapse/storage/engines/__init__.py                      |   2
-rw-r--r--  synapse/storage/engines/postgres.py                      |  45
-rw-r--r--  synapse/storage/persist_events.py                        |  15
-rw-r--r--  synapse/storage/relations.py                             |  31
-rw-r--r--  synapse/storage/schema/main/delta/30/as_users.py         |   1
-rw-r--r--  synapse/storage/state.py                                 |   8
21 files changed, 506 insertions, 419 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index d64910aded..08c6eabc6d 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -60,18 +60,19 @@ class _BackgroundUpdateHandler:
 
 
 class _BackgroundUpdateContextManager:
-    BACKGROUND_UPDATE_INTERVAL_MS = 1000
-    BACKGROUND_UPDATE_DURATION_MS = 100
-
-    def __init__(self, sleep: bool, clock: Clock):
+    def __init__(
+        self, sleep: bool, clock: Clock, sleep_duration_ms: int, update_duration: int
+    ):
         self._sleep = sleep
         self._clock = clock
+        self._sleep_duration_ms = sleep_duration_ms
+        self._update_duration_ms = update_duration
 
     async def __aenter__(self) -> int:
         if self._sleep:
-            await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+            await self._clock.sleep(self._sleep_duration_ms / 1000)
 
-        return self.BACKGROUND_UPDATE_DURATION_MS
+        return self._update_duration_ms
 
     async def __aexit__(self, *exc) -> None:
         pass
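
A minimal usage sketch of the rewritten context manager (the constructor signature is taken from the hunk above; the `do_next_batch` call is hypothetical):

    # Sleep for the configured interval (if enabled), then run one batch of
    # background-update work for the yielded target duration.
    async with _BackgroundUpdateContextManager(
        sleep=True,
        clock=clock,
        sleep_duration_ms=1000,
        update_duration=100,
    ) as desired_duration_ms:
        await do_next_batch(desired_duration_ms)  # hypothetical worker call
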
@@ -102,10 +103,12 @@ class BackgroundUpdatePerformance:
         Returns:
             A duration in ms as a float
         """
-        if self.avg_duration_ms == 0:
-            return 0
-        elif self.total_item_count == 0:
+        # We want to return None if this is the first background update item
+        if self.total_item_count == 0:
             return None
+        # Avoid dividing by zero
+        elif self.avg_duration_ms == 0:
+            return 0
         else:
             # Use the exponential moving average so that we can adapt to
             # changes in how long the update process takes.
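
For reference, the exponential moving average mentioned here updates as follows (the smoothing factor is illustrative, not taken from this code):

    # Recent batch timings dominate; older measurements decay geometrically.
    alpha = 0.1
    avg_duration_ms = alpha * last_duration_ms + (1 - alpha) * avg_duration_ms
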
@@ -131,9 +134,6 @@ class BackgroundUpdater:
     process and autotuning the batch size.
     """
 
-    MINIMUM_BACKGROUND_BATCH_SIZE = 1
-    DEFAULT_BACKGROUND_BATCH_SIZE = 100
-
     def __init__(self, hs: "HomeServer", database: "DatabasePool"):
         self._clock = hs.get_clock()
         self.db_pool = database
@@ -158,6 +158,14 @@ class BackgroundUpdater:
         # enable/disable background updates via the admin API.
         self.enabled = True
 
+        self.minimum_background_batch_size = hs.config.background_updates.min_batch_size
+        self.default_background_batch_size = (
+            hs.config.background_updates.default_batch_size
+        )
+        self.update_duration_ms = hs.config.background_updates.update_duration_ms
+        self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms
+        self.sleep_enabled = hs.config.background_updates.sleep_enabled
+
     def register_update_controller_callbacks(
         self,
         on_update: ON_UPDATE_CALLBACK,
@@ -214,7 +222,9 @@ class BackgroundUpdater:
         if self._on_update_callback is not None:
             return self._on_update_callback(update_name, database_name, oneshot)
 
-        return _BackgroundUpdateContextManager(sleep, self._clock)
+        return _BackgroundUpdateContextManager(
+            sleep, self._clock, self.sleep_duration_ms, self.update_duration_ms
+        )
 
     async def _default_batch_size(self, update_name: str, database_name: str) -> int:
         """The batch size to use for the first iteration of a new background
@@ -223,7 +233,7 @@ class BackgroundUpdater:
         if self._default_batch_size_callback is not None:
             return await self._default_batch_size_callback(update_name, database_name)
 
-        return self.DEFAULT_BACKGROUND_BATCH_SIZE
+        return self.default_background_batch_size
 
     async def _min_batch_size(self, update_name: str, database_name: str) -> int:
         """A lower bound on the batch size of a new background update.
@@ -233,7 +243,7 @@ class BackgroundUpdater:
         if self._min_batch_size_callback is not None:
             return await self._min_batch_size_callback(update_name, database_name)
 
-        return self.MINIMUM_BACKGROUND_BATCH_SIZE
+        return self.minimum_background_batch_size
 
     def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
         """Returns the current background update, if any."""
@@ -252,9 +262,12 @@ class BackgroundUpdater:
         if self.enabled:
             # if we start a new background update, not all updates are done.
             self._all_done = False
-            run_as_background_process("background_updates", self.run_background_updates)
+            sleep = self.sleep_enabled
+            run_as_background_process(
+                "background_updates", self.run_background_updates, sleep
+            )
 
-    async def run_background_updates(self, sleep: bool = True) -> None:
+    async def run_background_updates(self, sleep: bool) -> None:
         if self._running or not self.enabled:
             return
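
Taken together, the hunks above replace the class-level constants with per-homeserver configuration. A sketch of the resulting knobs (attribute and config names are from the diff; the comments describe assumed semantics):

    updater = BackgroundUpdater(hs, database)
    updater.minimum_background_batch_size  # lower bound on rows per batch
    updater.default_background_batch_size  # batch size for a new update
    updater.sleep_duration_ms              # pause between batches, in ms
    updater.update_duration_ms             # target duration of one batch, in ms
    updater.sleep_enabled                  # forwarded to run_background_updates(sleep)
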
 
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 99802228c9..367709a1a7 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
 from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
+from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -286,7 +288,7 @@ class LoggingTransaction:
         """
 
         if isinstance(self.database_engine, PostgresEngine):
-            from psycopg2.extras import execute_batch  # type: ignore
+            from psycopg2.extras import execute_batch
 
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
@@ -300,10 +302,18 @@ class LoggingTransaction:
         rows (e.g. INSERTs).
         """
         assert isinstance(self.database_engine, PostgresEngine)
-        from psycopg2.extras import execute_values  # type: ignore
+        from psycopg2.extras import execute_values
 
         return self._do_execute(
-            lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
+            # Type ignore: mypy is unhappy because if `x` is a 5-tuple, then there will
+            # be two values for `fetch`: one given positionally, and another given
+            # as a keyword argument. We might be able to fix this by
+            # - propagating the signature of psycopg2.extras.execute_values to this
+            #   function, or
+            # - changing `*args: Any` to `values: T` for some appropriate T.
+            lambda *x: execute_values(self.txn, *x, fetch=fetch),  # type: ignore[misc]
+            sql,
+            *args,
         )
 
     def execute(self, sql: str, *args: Any) -> None:
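
For context, the psycopg2 helper being wrapped above batches many rows into a single statement. A minimal standalone sketch (cursor, table and values are hypothetical):

    from psycopg2.extras import execute_values

    # The helper expands the single "%s" into one multi-row VALUES list,
    # avoiding a network round-trip per row.
    execute_values(
        cur,
        "INSERT INTO hypothetical_table (id, name) VALUES %s",
        [(1, "a"), (2, "b"), (3, "c")],
        fetch=False,
    )
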
@@ -732,34 +742,45 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks: List[_CallbackListEntry] = []
-        exception_callbacks: List[_CallbackListEntry] = []
 
-        if not current_context():
-            logger.warning("Starting db txn '%s' from sentinel context", desc)
+        async def _runInteraction() -> R:
+            after_callbacks: List[_CallbackListEntry] = []
+            exception_callbacks: List[_CallbackListEntry] = []
 
-        try:
-            with opentracing.start_active_span(f"db.{desc}"):
-                result = await self.runWithConnection(
-                    self.new_transaction,
-                    desc,
-                    after_callbacks,
-                    exception_callbacks,
-                    func,
-                    *args,
-                    db_autocommit=db_autocommit,
-                    isolation_level=isolation_level,
-                    **kwargs,
-                )
+            if not current_context():
+                logger.warning("Starting db txn '%s' from sentinel context", desc)
 
-            for after_callback, after_args, after_kwargs in after_callbacks:
-                after_callback(*after_args, **after_kwargs)
-        except Exception:
-            for after_callback, after_args, after_kwargs in exception_callbacks:
-                after_callback(*after_args, **after_kwargs)
-            raise
+            try:
+                with opentracing.start_active_span(f"db.{desc}"):
+                    result = await self.runWithConnection(
+                        self.new_transaction,
+                        desc,
+                        after_callbacks,
+                        exception_callbacks,
+                        func,
+                        *args,
+                        db_autocommit=db_autocommit,
+                        isolation_level=isolation_level,
+                        **kwargs,
+                    )
 
-        return cast(R, result)
+                for after_callback, after_args, after_kwargs in after_callbacks:
+                    after_callback(*after_args, **after_kwargs)
+
+                return cast(R, result)
+            except Exception:
+                for after_callback, after_args, after_kwargs in exception_callbacks:
+                    after_callback(*after_args, **after_kwargs)
+                raise
+
+        # To handle cancellation, we ensure that `after_callback`s and
+        # `exception_callback`s are always run, since the transaction will complete
+        # on another thread regardless of cancellation.
+        #
+        # We also wait until everything above is done before releasing the
+        # `CancelledError`, so that logging contexts won't get used after they have been
+        # finished.
+        return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
 
     async def runWithConnection(
         self,
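
The semantics of `delay_cancellation` can be approximated with an asyncio analogue (a behavioural sketch based on the comment above, not Synapse's Twisted implementation):

    import asyncio

    async def delay_cancellation(awaitable):
        # If the caller is cancelled, let the wrapped work run to completion
        # before re-raising, so its after/exception callbacks still fire.
        task = asyncio.ensure_future(awaitable)
        try:
            return await asyncio.shield(task)
        except asyncio.CancelledError:
            if not task.cancelled():
                await asyncio.wait({task})
            raise
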
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 52146aacc8..9af9f4f18e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def ignored_by(self, user_id: str) -> Set[str]:
+    async def ignored_by(self, user_id: str) -> FrozenSet[str]:
         """
         Get users which ignore the given user.
 
@@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         Return:
             The user IDs which ignore the given user.
         """
-        return set(
+        return frozenset(
             await self.db_pool.simple_select_onecol(
                 table="ignored_users",
                 keyvalues={"ignored_user_id": user_id},
@@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
         )
 
+    @cached(max_entries=5000, iterable=True)
+    async def ignored_users(self, user_id: str) -> FrozenSet[str]:
+        """
+        Get users which the given user ignores.
+
+        Params:
+            user_id: The user ID which is making the request.
+
+        Return:
+            The user IDs which are ignored by the given user.
+        """
+        return frozenset(
+            await self.db_pool.simple_select_onecol(
+                table="ignored_users",
+                keyvalues={"ignorer_user_id": user_id},
+                retcol="ignored_user_id",
+                desc="ignored_users",
+            )
+        )
+
     def process_replication_rows(
         self,
         stream_name: str,
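
Both methods return `frozenset` rather than `set` because a `@cached` result is shared by every caller, so it must not be mutable (illustrative snippet; the user IDs are hypothetical):

    ignorers = await store.ignored_by("@alice:example.org")
    # frozenset has no mutating methods, so accidental cache corruption
    # fails loudly instead of propagating silently:
    ignorers.add("@mallory:example.org")  # AttributeError
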
@@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         else:
             currently_ignored_users = set()
 
+        # If the data has not changed, nothing to do.
+        if previously_ignored_users == currently_ignored_users:
+            return
+
         # Delete entries which are no longer ignored.
         self.db_pool.simple_delete_many_txn(
             txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         # Invalidate the cache for any ignored users which were added or removed.
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+        self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
 
     async def purge_account_data_for_user(self, user_id: str) -> None:
         """
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index c428dd5596..dd4e83a2ad 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
+    EventsStreamRow,
 )
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_caches_txn(txn):
+        def get_all_updated_caches_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == EventsStream.NAME:
             for row in rows:
                 self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def _process_event_stream_row(self, token, row):
+    def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data
 
         if row.type == EventsStreamEventRow.TypeId:
+            assert isinstance(data, EventsStreamEventRow)
             self._invalidate_caches_for_event(
                 token,
                 data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 backfilled=False,
             )
         elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
+            assert isinstance(data, EventsStreamCurrentStateRow)
+            self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
 
             if data.type == EventTypes.Member:
                 self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
     def _invalidate_caches_for_event(
         self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
+        stream_ordering: int,
+        event_id: str,
+        room_id: str,
+        etype: str,
+        state_key: Optional[str],
+        redacts: Optional[str],
+        relates_to: Optional[str],
+        backfilled: bool,
+    ) -> None:
         self._invalidate_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
 
@@ -186,11 +192,19 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
 
+        # The `_get_membership_from_event_id` cache is immutable, except for the
+        # case where we look up an event *before* persisting it.
+        self._get_membership_from_event_id.invalidate((event_id,))
+
         if not backfilled:
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
         if redacts:
             self._invalidate_get_event_cache(redacts)
+            # Caches which might leak edits must be invalidated for the event being
+            # redacted.
+            self.get_relations_for_event.invalidate((redacts,))
+            self.get_applicable_edit.invalidate((redacts,))
 
         if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
@@ -200,8 +214,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_relations_for_event.invalidate((relates_to,))
             self.get_aggregation_groups_for_event.invalidate((relates_to,))
             self.get_applicable_edit.invalidate((relates_to,))
+            self.get_thread_summary.invalidate((relates_to,))
+            self.get_thread_participated.invalidate((relates_to,))
 
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+    async def invalidate_cache_and_stream(
+        self, cache_name: str, keys: Tuple[Any, ...]
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -221,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             keys,
         )
 
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+    def _invalidate_cache_and_stream(
+        self,
+        txn: LoggingTransaction,
+        cache_func: _CachedFunction,
+        keys: Tuple[Any, ...],
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -232,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
-    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+    def _invalidate_all_cache_and_stream(
+        self, txn: LoggingTransaction, cache_func: _CachedFunction
+    ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
         """
@@ -273,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             )
 
     def _send_invalidation_to_replication(
-        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
-    ):
+        self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+    ) -> None:
         """Notifies replication that given cache has been invalidated.
 
         Note that this does *not* invalidate the cache locally.
@@ -309,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
-                    "invalidation_ts": self.clock.time_msec(),
+                    "invalidation_ts": self._clock.time_msec(),
                 },
             )
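
The `assert isinstance(...)` lines added to `_process_event_stream_row` narrow the union-typed `row.data` for mypy. The same idiom in isolation (types hypothetical):

    from typing import Union

    def describe(value: Union[int, str]) -> str:
        if isinstance(value, str):
            # mypy narrows `value` to str in this branch; an
            # assert isinstance(...) achieves the same narrowing.
            return value.upper()
        return str(value)
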
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1392363de1..b4a1b041b1 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -298,6 +298,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 # This user has new messages sent to them. Query messages for them
                 user_ids_to_query.add(user_id)
 
+        if not user_ids_to_query:
+            return {}, to_stream_id
+
         def get_device_messages_txn(txn: LoggingTransaction):
             # Build a query to select messages from any of the given devices that
             # are between the given stream id bounds.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ca2a9ba9d1..d253243125 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1518,7 +1518,7 @@ class PersistEventsStore:
                 )
 
                 # Remove from relations table.
-                self._handle_redaction(txn, event.redacts)
+                self._handle_redact_relations(txn, event.redacts)
 
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
@@ -1619,9 +1619,12 @@ class PersistEventsStore:
 
         txn.call_after(prefill)
 
-    def _store_redaction(self, txn, event):
-        # invalidate the cache for the redacted event
+    def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
+        # Invalidate the caches for the redacted event, note that these caches
+        # are also cleared as part of event replication in _invalidate_caches_for_event.
         txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+        txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
+        txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
 
         self.db_pool.simple_upsert_txn(
             txn,
@@ -1742,6 +1745,13 @@ class PersistEventsStore:
                 (event.state_key,),
             )
 
+            # The `_get_membership_from_event_id` cache is immutable, except for the
+            # case where we look up an event *before* persisting it.
+            txn.call_after(
+                self.store._get_membership_from_event_id.invalidate,
+                (event.event_id,),
+            )
+
             # We update the local_current_membership table only if the event is
             # "current", i.e., its something that has just happened.
             #
@@ -1811,10 +1821,11 @@ class PersistEventsStore:
         if rel_type == RelationTypes.REPLACE:
             txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
-        if rel_type == RelationTypes.THREAD:
-            txn.call_after(
-                self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
-            )
+        if (
+            rel_type == RelationTypes.THREAD
+            or rel_type == RelationTypes.UNSTABLE_THREAD
+        ):
+            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
@@ -1943,15 +1954,43 @@ class PersistEventsStore:
 
         txn.execute(sql, (batch_id,))
 
-    def _handle_redaction(self, txn, redacted_event_id):
-        """Handles receiving a redaction and checking whether we need to remove
-        any redacted relations from the database.
+    def _handle_redact_relations(
+        self, txn: LoggingTransaction, redacted_event_id: str
+    ) -> None:
+        """Handles receiving a redaction and checking whether the redacted event
+        has any relations which must be removed from the database.
 
         Args:
             txn
-            redacted_event_id (str): The event that was redacted.
+            redacted_event_id: The event that was redacted.
         """
 
+        # Fetch the current relation of the event being redacted.
+        redacted_relates_to = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_relations",
+            keyvalues={"event_id": redacted_event_id},
+            retcol="relates_to_id",
+            allow_none=True,
+        )
+        # Any relation information for the related event must be cleared.
+        if redacted_relates_to is not None:
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_relations_for_event, (redacted_relates_to,)
+            )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
+            )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_applicable_edit, (redacted_relates_to,)
+            )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_thread_summary, (redacted_relates_to,)
+            )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_thread_participated, (redacted_relates_to,)
+            )
+
         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
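
Roughly, the `simple_select_one_onecol_txn` call above issues a single-column lookup; an approximation of its behaviour (assumed from the helper's name and arguments, not its source):

    txn.execute(
        "SELECT relates_to_id FROM event_relations WHERE event_id = ?",
        (redacted_event_id,),
    )
    row = txn.fetchone()
    redacted_relates_to = row[0] if row else None  # allow_none=True behaviour
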
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 26784f755e..59454a47df 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -1286,7 +1286,7 @@ class EventsWorkerStore(SQLBaseStore):
         )
         return {eid for ((_rid, eid), have_event) in res.items() if have_event}
 
-    @cachedList("have_seen_event", "keys")
+    @cachedList(cached_method_name="have_seen_event", list_name="keys")
     async def _have_seen_events_dict(
         self, keys: Iterable[Tuple[str, str]]
     ) -> Dict[Tuple[str, str], bool]:
@@ -1954,7 +1954,7 @@ class EventsWorkerStore(SQLBaseStore):
             get_event_id_for_timestamp_txn,
         )
 
-    @cachedList("is_partial_state_event", list_name="event_ids")
+    @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
     async def get_partial_state_events(
         self, event_ids: Collection[str]
     ) -> Dict[str, bool]:
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 3f6086050b..0aef121d83 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,13 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
 
 from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
     ) -> List[Dict[str, Any]]:
         # TODO: Pagination
 
-        keyvalues = {"group_id": group_id}
+        keyvalues: JsonDict = {"group_id": group_id}
         if not include_private:
             keyvalues["is_public"] = True
 
@@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         # TODO: Pagination
 
-        def _get_rooms_in_group_txn(txn):
+        def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
             sql = """
             SELECT room_id, is_public FROM group_rooms
                 WHERE group_id = ?
@@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
                         * "order": int, the sort order of rooms in this category
         """
 
-        def _get_rooms_for_summary_txn(txn):
-            keyvalues = {"group_id": group_id}
+        def _get_rooms_for_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+            keyvalues: JsonDict = {"group_id": group_id}
             if not include_private:
                 keyvalues["is_public"] = True
 
@@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_rooms_for_summary", _get_rooms_for_summary_txn
         )
 
-    async def get_group_categories(self, group_id):
+    async def get_group_categories(self, group_id: str) -> JsonDict:
         rows = await self.db_pool.simple_select_list(
             table="group_room_categories",
             keyvalues={"group_id": group_id},
@@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    async def get_group_category(self, group_id, category_id):
+    async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
         category = await self.db_pool.simple_select_one(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
@@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         return category
 
-    async def get_group_roles(self, group_id):
+    async def get_group_roles(self, group_id: str) -> JsonDict:
         rows = await self.db_pool.simple_select_list(
             table="group_roles",
             keyvalues={"group_id": group_id},
@@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    async def get_group_role(self, group_id, role_id):
+    async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
         role = await self.db_pool.simple_select_one(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
@@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_local_groups_for_room",
         )
 
-    async def get_users_for_summary_by_role(self, group_id, include_private=False):
+    async def get_users_for_summary_by_role(
+        self, group_id: str, include_private: bool = False
+    ) -> Tuple[List[JsonDict], JsonDict]:
         """Get the users and roles that should be included in a summary request
 
         Returns:
             ([users], [roles])
         """
 
-        def _get_users_for_summary_txn(txn):
-            keyvalues = {"group_id": group_id}
+        def _get_users_for_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], JsonDict]:
+            keyvalues: JsonDict = {"group_id": group_id}
             if not include_private:
                 keyvalues["is_public"] = True
 
@@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    async def get_users_membership_info_in_group(self, group_id, user_id):
+    async def get_users_membership_info_in_group(
+        self, group_id: str, user_id: str
+    ) -> JsonDict:
         """Get a dict describing the membership of a user in a group.
 
         Example if joined:
@@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             An empty dict if the user is not joined/invited/etc.
         """
 
-        def _get_users_membership_in_group_txn(txn):
+        def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="group_users",
@@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_publicised_groups_for_user",
         )
 
-    async def get_attestations_need_renewals(self, valid_until_ms):
+    async def get_attestations_need_renewals(
+        self, valid_until_ms: int
+    ) -> List[Dict[str, Any]]:
         """Get all attestations that need to be renewed until givent time"""
 
-        def _get_attestations_need_renewals_txn(txn):
+        def _get_attestations_need_renewals_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Any]]:
             sql = """
                 SELECT group_id, user_id FROM group_attestations_renewals
                 WHERE valid_until_ms <= ?
@@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_attestations_need_renewals", _get_attestations_need_renewals_txn
         )
 
-    async def get_remote_attestation(self, group_id, user_id):
+    async def get_remote_attestation(
+        self, group_id: str, user_id: str
+    ) -> Optional[JsonDict]:
         """Get the attestation that proves the remote agrees that the user is
         in the group.
         """
@@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_joined_groups",
         )
 
-    async def get_all_groups_for_user(self, user_id, now_token):
-        def _get_all_groups_for_user_txn(txn):
+    async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+        def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
             sql = """
                 SELECT group_id, type, membership, u.content
                 FROM local_group_updates AS u
@@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_all_groups_for_user", _get_all_groups_for_user_txn
         )
 
-    async def get_groups_changes_for_user(self, user_id, from_token, to_token):
-        from_token = int(from_token)
-        has_changed = self._group_updates_stream_cache.has_entity_changed(
+    async def get_groups_changes_for_user(
+        self, user_id: str, from_token: int, to_token: int
+    ) -> List[JsonDict]:
+        has_changed = self._group_updates_stream_cache.has_entity_changed(  # type: ignore[attr-defined]
             user_id, from_token
         )
         if not has_changed:
             return []
 
-        def _get_groups_changes_for_user_txn(txn):
+        def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
             sql = """
                 SELECT group_id, membership, type, u.content
                 FROM local_group_updates AS u
@@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
         """
 
         last_id = int(last_id)
-        has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)  # type: ignore[attr-defined]
 
         if not has_changed:
             return [], current_id, False
 
-        def _get_all_groups_changes_txn(txn):
+        def _get_all_groups_changes_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT stream_id, group_id, user_id, type, content
                 FROM local_group_updates
@@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [
-                (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
-                for stream_id, group_id, user_id, gtype, content_json in txn
-            ]
+            updates = cast(
+                List[Tuple[int, tuple]],
+                [
+                    (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
+                    for stream_id, group_id, user_id, gtype, content_json in txn
+                ],
+            )
 
             limited = False
             upto_token = current_id
@@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
         self,
         group_id: str,
         room_id: str,
-        category_id: str,
-        order: int,
+        category_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) room's entry in summary.
@@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
 
     def _add_room_to_summary_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         group_id: str,
         room_id: str,
-        category_id: str,
-        order: int,
+        category_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) room's entry in summary.
@@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 WHERE group_id = ? AND category_id = ?
             """
             txn.execute(sql, (group_id, category_id))
-            (order,) = txn.fetchone()
+            (order,) = cast(Tuple[int], txn.fetchone())
 
         if existing:
             to_update = {}
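
The `cast` is needed because `fetchone()` is typed as returning an optional row even when the query is guaranteed to produce one. The pattern in miniature (table hypothetical):

    from typing import Tuple, cast

    def first_count(txn) -> int:
        txn.execute("SELECT COUNT(*) FROM hypothetical_table")
        # COUNT(*) always yields exactly one row, so the cast is safe here.
        (count,) = cast(Tuple[int], txn.fetchone())
        return count
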
@@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     "category_id": category_id,
                     "room_id": room_id,
                 },
-                values=to_update,
+                updatevalues=to_update,
             )
         else:
             if is_public is None:
@@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
             )
 
     async def remove_room_from_summary(
-        self, group_id: str, room_id: str, category_id: str
+        self, group_id: str, room_id: str, category_id: Optional[str]
     ) -> int:
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
@@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
         is_public: Optional[bool],
     ) -> None:
         """Add/update room category for group"""
-        insertion_values = {}
-        update_values = {"category_id": category_id}  # This cannot be empty
+        insertion_values: JsonDict = {}
+        update_values: JsonDict = {"category_id": category_id}  # This cannot be empty
 
         if profile is None:
             insertion_values["profile"] = "{}"
@@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
         is_public: Optional[bool],
     ) -> None:
         """Add/remove user role"""
-        insertion_values = {}
-        update_values = {"role_id": role_id}  # This cannot be empty
+        insertion_values: JsonDict = {}
+        update_values: JsonDict = {"role_id": role_id}  # This cannot be empty
 
         if profile is None:
             insertion_values["profile"] = "{}"
@@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
         self,
         group_id: str,
         user_id: str,
-        role_id: str,
-        order: int,
+        role_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) user's entry in summary.
@@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
 
     def _add_user_to_summary_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         group_id: str,
         user_id: str,
-        role_id: str,
-        order: int,
+        role_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
-    ):
+    ) -> None:
         """Add (or update) user's entry in summary.
 
         Args:
@@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 WHERE group_id = ? AND role_id = ?
             """
             txn.execute(sql, (group_id, role_id))
-            (order,) = txn.fetchone()
+            (order,) = cast(Tuple[int], txn.fetchone())
 
         if existing:
             to_update = {}
@@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     "role_id": role_id,
                     "user_id": user_id,
                 },
-                values=to_update,
+                updatevalues=to_update,
             )
         else:
             if is_public is None:
@@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
             )
 
     async def remove_user_from_summary(
-        self, group_id: str, user_id: str, role_id: str
+        self, group_id: str, user_id: str, role_id: Optional[str]
     ) -> int:
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
@@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 Optional if the user and group are on the same server
         """
 
-        def _add_user_to_group_txn(txn):
+        def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_insert_txn(
                 txn,
                 table="group_users",
@@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
         await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
 
     async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
-        def _remove_user_from_group_txn(txn):
+        def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="group_users",
@@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
-        def _remove_room_from_group_txn(txn):
+        def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="group_rooms",
@@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
 
         content = content or {}
 
-        def _register_user_group_membership_txn(txn, next_id):
+        def _register_user_group_membership_txn(
+            txn: LoggingTransaction, next_id: int
+        ) -> int:
             # TODO: Upsert?
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     ),
                 },
             )
-            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)  # type: ignore[attr-defined]
 
             # TODO: Insert profile to ensure it comes down the stream if it's a join.
 
@@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        async with self._group_updates_id_gen.get_next() as next_id:
+        async with self._group_updates_id_gen.get_next() as next_id:  # type: ignore[attr-defined]
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
@@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
         return res
 
     async def create_group(
-        self, group_id, user_id, name, avatar_url, short_description, long_description
+        self,
+        group_id: str,
+        user_id: str,
+        name: str,
+        avatar_url: str,
+        short_description: str,
+        long_description: str,
     ) -> None:
         await self.db_pool.simple_insert(
             table="groups",
@@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="create_group",
         )
 
-    async def update_group_profile(self, group_id, profile):
+    async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
         await self.db_pool.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
@@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="remove_attestation_renewal",
         )
 
-    def get_group_stream_token(self):
-        return self._group_updates_id_gen.get_current_token()
+    def get_group_stream_token(self) -> int:
+        return self._group_updates_id_gen.get_current_token()  # type: ignore[attr-defined]
 
     async def delete_group(self, group_id: str) -> None:
         """Deletes a group fully from the database.
@@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
             group_id: The group ID to delete.
         """
 
-        def _delete_group_txn(txn):
+        def _delete_group_txn(txn: LoggingTransaction) -> None:
             tables = [
                 "groups",
                 "group_users",
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e9a0cdc6be..216622964a 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,15 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
+    LoggingTransaction,
     make_in_list_sql_clause,
 )
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 from synapse.util.caches.descriptors import cached
 from synapse.util.threepids import canonicalise_email
 
@@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             Number of current monthly active users
         """
 
-        def _count_users(txn):
+        def _count_users(txn: LoggingTransaction) -> int:
             # Exclude app service users
             sql = """
                 SELECT COUNT(*)
@@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
                 WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
             """
             txn.execute(sql)
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_users", _count_users)
@@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
         """
 
-        def _count_users_by_service(txn):
+        def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
             sql = """
                 SELECT COALESCE(appservice_id, 'native'), COUNT(*)
                 FROM monthly_active_users
@@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             """
 
             txn.execute(sql)
-            result = txn.fetchall()
+            result = cast(List[Tuple[str, int]], txn.fetchall())
             return dict(result)
 
         return await self.db_pool.runInteraction(
@@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         )
 
     @wrap_as_background_process("reap_monthly_active_users")
-    async def reap_monthly_active_users(self):
+    async def reap_monthly_active_users(self) -> None:
         """Cleans out monthly active user table to ensure that no stale
         entries exist.
         """
 
-        def _reap_users(txn, reserved_users):
+        def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
             """
             Args:
                 reserved_users (tuple): reserved users to preserve
@@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             # is racy.
             # Have resolved to invalidate the whole cache for now and do
             # something about it if and when the perf becomes significant
-            self._invalidate_all_cache_and_stream(
+            self._invalidate_all_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.user_last_seen_monthly_active
             )
-            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())  # type: ignore[attr-defined]
 
         reserved_users = await self.get_registered_reserved_users()
         await self.db_pool.runInteraction(
@@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         )
 
 
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
         )
 
-    def _initialise_reserved_users(self, txn, threepids):
+    def _initialise_reserved_users(
+        self, txn: LoggingTransaction, threepids: List[dict]
+    ) -> None:
         """Ensures that reserved threepids are accounted for in the MAU table, should
         be called on start up.
 
         Args:
-            txn (cursor):
-            threepids (list[dict]): List of threepid dicts to reserve
+            txn:
+            threepids: List of threepid dicts to reserve
         """
 
         # XXX what is this function trying to achieve?  It upserts into
@@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )
 
-    def upsert_monthly_active_user_txn(self, txn, user_id):
+    def upsert_monthly_active_user_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> None:
         """Updates or inserts monthly active user member
 
         We consciously do not call is_support_txn from this method because it
@@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             txn, self.user_last_seen_monthly_active, (user_id,)
         )
 
-    async def populate_monthly_active_users(self, user_id):
+    async def populate_monthly_active_users(self, user_id: str) -> None:
         """Checks on the state of monthly active user limits and optionally
         add the user to the monthly active tables
 
@@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         """
         if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
-            is_guest = await self.is_guest(user_id)
+            is_guest = await self.is_guest(user_id)  # type: ignore[attr-defined]
             if is_guest:
                 return
             is_trial = await self.is_trial_user(user_id)
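
The new `RegistrationWorkerStore` base class lets mypy resolve methods such as `is_trial_user` that were previously only available at runtime via the final `DataStore` mix-in; attributes still outside the declared bases keep a `type: ignore[attr-defined]`. The shape of the problem in miniature (simplified, hypothetical classes):

    class RegistrationWorkerStore:
        async def is_guest(self, user_id: str) -> bool:
            return False

    class MonthlyActiveUsersWorkerStore:
        async def populate(self, user_id: str) -> None:
            # Defined on a sibling base of the final store, so mypy cannot
            # see it here without an ignore:
            await self.is_guest(user_id)  # type: ignore[attr-defined]

    # Declaring the dependency, as the diff does, removes the need to ignore:
    class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
        pass
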
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index dc6665237a..a698d10cc5 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -48,8 +48,6 @@ class ExternalIDReuseException(Exception):
     """Exception if writing an external id for a user fails,
     because this external id is given to an other user."""
 
-    pass
-
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class TokenLookupResult:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 36aa1092f6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
 
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
-    # The latest event in the thread.
-    latest_event: EventBase
-    # The latest edit to the latest event in the thread.
-    latest_edit: Optional[EventBase]
-    # The total number of events in the thread.
-    count: int
-    # True if the current user has sent an event to the thread.
-    current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
-    """
-    The bundled aggregations for an event.
-
-    Some values require additional processing during serialization.
-    """
-
-    annotations: Optional[JsonDict] = None
-    references: Optional[JsonDict] = None
-    replace: Optional[EventBase] = None
-    thread: Optional[_ThreadAggregation] = None
-
-    def __bool__(self) -> bool:
-        return bool(self.annotations or self.references or self.replace or self.thread)
-
-
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -91,10 +60,11 @@ class RelationsWorkerStore(SQLBaseStore):
 
         self._msc3440_enabled = hs.config.experimental.msc3440_enabled
 
-    @cached(tree=True)
+    @cached(uncached_args=("event",), tree=True)
     async def get_relations_for_event(
         self,
         event_id: str,
+        event: EventBase,
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
@@ -108,6 +78,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            event: The matching EventBase to event_id.
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
@@ -122,9 +93,13 @@ class RelationsWorkerStore(SQLBaseStore):
             List of event IDs that match relations requested. The rows are of
             the form `{"event_id": "..."}`.
         """
+        # We don't use `event_id`, it's there so that we can cache based on
+        # it. The `event_id` must match the `event.event_id`.
+        assert event.event_id == event_id
 
         where_clause = ["relates_to_id = ?", "room_id = ?"]
-        where_args: List[Union[str, int]] = [event_id, room_id]
+        where_args: List[Union[str, int]] = [event.event_id, room_id]
+        is_redacted = event.internal_metadata.is_redacted()
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
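
Why cache on `event_id` but still take the full `event`? The cache key must be cheap and stable, while the body needs the event itself to check redaction status; the assert keeps the two in sync. A minimal sketch of the `uncached_args` idea with a toy memoizing decorator (illustrative only, not Synapse's actual `@cached` machinery):

    import functools
    from typing import Any, Awaitable, Callable, Dict, Tuple

    def cached_excluding(*uncached: str) -> Callable:
        """Memoize an async function on its positional args, skipping the
        named ones when building the cache key (still passed through)."""
        def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
            cache: Dict[Tuple[Any, ...], Any] = {}
            names = func.__code__.co_varnames[: func.__code__.co_argcount]

            @functools.wraps(func)
            async def wrapper(*args: Any) -> Any:
                key = tuple(a for n, a in zip(names, args) if n not in uncached)
                if key not in cache:
                    cache[key] = await func(*args)
                return cache[key]

            return wrapper
        return decorator

Excluding `event` is safe here because it is fully determined by `event_id`, so leaving it out of the key cannot produce stale hits.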
@@ -157,7 +132,7 @@ class RelationsWorkerStore(SQLBaseStore):
             order = "ASC"
 
         sql = """
-            SELECT event_id, topological_ordering, stream_ordering
+            SELECT event_id, relation_type, topological_ordering, stream_ordering
             FROM event_relations
             INNER JOIN events USING (event_id)
             WHERE %s
@@ -178,9 +153,12 @@ class RelationsWorkerStore(SQLBaseStore):
             last_stream_id = None
             events = []
             for row in txn:
-                events.append({"event_id": row[0]})
-                last_topo_id = row[1]
-                last_stream_id = row[2]
+                # Do not include edits for redacted events as they leak event
+                # content.
+                if not is_redacted or row[1] != RelationTypes.REPLACE:
+                    events.append({"event_id": row[0]})
+                last_topo_id = row[2]
+                last_stream_id = row[3]
 
             # If there are more events, generate the next pagination key.
             next_token = None
@@ -375,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
-    async def _get_applicable_edits(
+    async def get_applicable_edits(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[EventBase]]:
         """Get the most recent edit (if any) that has happened for the given
@@ -464,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
-    async def _get_thread_summaries(
+    async def get_thread_summaries(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
         """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -499,7 +477,7 @@ class RelationsWorkerStore(SQLBaseStore):
                         AND parent.room_id = child.room_id
                     WHERE
                         %s
-                        AND relation_type = ?
+                        AND %s
                     ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
                 """
             else:
@@ -514,16 +492,22 @@ class RelationsWorkerStore(SQLBaseStore):
                         AND parent.room_id = child.room_id
                     WHERE
                         %s
-                        AND relation_type = ?
+                        AND %s
                     ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
                 """
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", event_ids
             )
-            args.append(RelationTypes.THREAD)
 
-            txn.execute(sql % (clause,), args)
+            if self._msc3440_enabled:
+                relations_clause = "(relation_type = ? OR relation_type = ?)"
+                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+            else:
+                relations_clause = "relation_type = ?"
+                args.append(RelationTypes.THREAD)
+
+            txn.execute(sql % (clause, relations_clause), args)
             latest_event_ids = {}
             for parent_event_id, child_event_id in txn:
                 # Only consider the latest threaded reply (by topological ordering).
@@ -543,7 +527,7 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND parent.room_id = child.room_id
                 WHERE
                     %s
-                    AND relation_type = ?
+                    AND %s
                 GROUP BY parent.event_id
             """
 
@@ -552,9 +536,15 @@ class RelationsWorkerStore(SQLBaseStore):
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", latest_event_ids.keys()
             )
-            args.append(RelationTypes.THREAD)
 
-            txn.execute(sql % (clause,), args)
+            if self._msc3440_enabled:
+                relations_clause = "(relation_type = ? OR relation_type = ?)"
+                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+            else:
+                relations_clause = "relation_type = ?"
+                args.append(RelationTypes.THREAD)
+
+            txn.execute(sql % (clause, relations_clause), args)
             counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
 
             return counts, latest_event_ids
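
The same `relations_clause` construction repeats in each thread query: `make_in_list_sql_clause` supplies the `relates_to_id IN (...)` fragment plus its arguments, and the rel-type clause widens to the unstable MSC3440 identifier when the experiment is enabled. A self-contained sketch of the composition, with a toy `IN`-clause helper and assuming the usual stable/unstable thread identifiers `m.thread` / `io.element.thread`:

    from typing import Iterable, List, Tuple

    def make_in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
        vals = list(values)
        return "%s IN (%s)" % (column, ", ".join("?" for _ in vals)), vals

    def thread_filter(event_ids: Iterable[str], msc3440_enabled: bool) -> Tuple[str, List[str]]:
        clause, args = make_in_list_clause("relates_to_id", event_ids)
        if msc3440_enabled:
            relations_clause = "(relation_type = ? OR relation_type = ?)"
            args.extend(("m.thread", "io.element.thread"))
        else:
            relations_clause = "relation_type = ?"
            args.append("m.thread")
        return "%s AND %s" % (clause, relations_clause), args

    # thread_filter(["$a", "$b"], True) ->
    #   ("relates_to_id IN (?, ?) AND (relation_type = ? OR relation_type = ?)",
    #    ["$a", "$b", "m.thread", "io.element.thread"])

Extending the args list in the same branch that picks the placeholder string is what keeps the `?` count and the argument count aligned.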
@@ -566,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
         # Check to see if any of those events are edited.
-        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+        latest_edits = await self.get_applicable_edits(latest_event_ids.values())
 
         # Map to the event IDs to the thread summary.
         #
@@ -589,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
-    async def _get_threads_participated(
+    async def get_threads_participated(
         self, event_ids: Collection[str], user_id: str
     ) -> Dict[str, bool]:
         """Get whether the requesting user participated in the given threads.
@@ -617,16 +607,24 @@ class RelationsWorkerStore(SQLBaseStore):
                     AND parent.room_id = child.room_id
                 WHERE
                     %s
-                    AND relation_type = ?
+                    AND %s
                     AND child.sender = ?
             """
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", event_ids
             )
-            args.extend((RelationTypes.THREAD, user_id))
 
-            txn.execute(sql % (clause,), args)
+            if self._msc3440_enabled:
+                relations_clause = "(relation_type = ? OR relation_type = ?)"
+                args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+            else:
+                relations_clause = "relation_type = ?"
+                args.append(RelationTypes.THREAD)
+
+            args.append(user_id)
+
+            txn.execute(sql % (clause, relations_clause), args)
             return {row[0] for row in txn.fetchall()}
 
         participated_threads = await self.db_pool.runInteraction(
@@ -737,122 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
-    async def _get_bundled_aggregation_for_event(
-        self, event: EventBase, user_id: str
-    ) -> Optional[BundledAggregations]:
-        """Generate bundled aggregations for an event.
-
-        Note that this does not use a cache, but depends on cached methods.
-
-        Args:
-            event: The event to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            The bundled aggregations for an event, if bundled aggregations are
-            enabled and the event can have bundled aggregations.
-        """
-
-        # Do not bundle aggregations for an event which represents an edit or an
-        # annotation. It does not make sense for them to have related events.
-        relates_to = event.content.get("m.relates_to")
-        if isinstance(relates_to, (dict, frozendict)):
-            relation_type = relates_to.get("rel_type")
-            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
-                return None
-
-        event_id = event.event_id
-        room_id = event.room_id
-
-        # The bundled aggregations to include, a mapping of relation type to a
-        # type-specific value. Some types include the direct return type here
-        # while others need more processing during serialization.
-        aggregations = BundledAggregations()
-
-        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
-
-        references = await self.get_relations_for_event(
-            event_id, room_id, RelationTypes.REFERENCE, direction="f"
-        )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
-
-        # Store the bundled aggregations in the event metadata for later use.
-        return aggregations
-
-    async def get_bundled_aggregations(
-        self, events: Iterable[EventBase], user_id: str
-    ) -> Dict[str, BundledAggregations]:
-        """Generate bundled aggregations for events.
-
-        Args:
-            events: The iterable of events to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            A map of event ID to the bundled aggregation for the event. Not all
-            events may have bundled aggregations in the results.
-        """
-        # The already processed event IDs. Tracked separately from the result
-        # since the result omits events which do not have bundled aggregations.
-        seen_event_ids = set()
-
-        # State events and redacted events do not get bundled aggregations.
-        events = [
-            event
-            for event in events
-            if not event.is_state() and not event.internal_metadata.is_redacted()
-        ]
-
-        # event ID -> bundled aggregation in non-serialized form.
-        results: Dict[str, BundledAggregations] = {}
-
-        # Fetch other relations per event.
-        for event in events:
-            # De-duplicate events by ID to handle the same event requested multiple
-            # times. The caches that _get_bundled_aggregation_for_event use should
-            # capture this, but best to reduce work.
-            if event.event_id in seen_event_ids:
-                continue
-            seen_event_ids.add(event.event_id)
-
-            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result:
-                results[event.event_id] = event_result
-
-        # Fetch any edits.
-        edits = await self._get_applicable_edits(seen_event_ids)
-        for event_id, edit in edits.items():
-            results.setdefault(event_id, BundledAggregations()).replace = edit
-
-        # Fetch thread summaries.
-        if self._msc3440_enabled:
-            summaries = await self._get_thread_summaries(seen_event_ids)
-            # Only fetch participated for a limited selection based on what had
-            # summaries.
-            participated = await self._get_threads_participated(
-                summaries.keys(), user_id
-            )
-            for event_id, summary in summaries.items():
-                if summary:
-                    thread_count, latest_thread_event, edit = summary
-                    results.setdefault(
-                        event_id, BundledAggregations()
-                    ).thread = _ThreadAggregation(
-                        latest_event=latest_thread_event,
-                        latest_edit=edit,
-                        count=thread_count,
-                        # If there's a thread summary it must also exist in the
-                        # participated dictionary.
-                        current_user_participated=participated[event_id],
-                    )
-
-        return results
-
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e48ec5f495..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -46,7 +46,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id
+from synapse.types import PersistedEventPosition, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+    """Returned by `get_membership_from_event_ids`"""
+
+    user_id: str
+    membership: str
+
+
 class RoomMemberWorkerStore(EventsWorkerStore):
     def __init__(
         self,
@@ -273,7 +281,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (room_id,))
             res = {}
             for count, membership in txn:
-                summary = res.setdefault(membership, MemberSummary([], count))
+                res.setdefault(membership, MemberSummary([], count))
 
             # we order by membership and then fairly arbitrarily by event_id so
             # heroes are consistent
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             retcols=("user_id", "display_name", "avatar_url", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=500,
-            desc="_get_membership_from_event_ids",
+            desc="_get_joined_profiles_from_event_ids",
         )
 
         return {
@@ -839,18 +847,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
-                room_id, state_group, state_entry.state, state_entry=state_entry
+                room_id, state_group, state_entry=state_entry
             )
 
     @cached(num_args=2, max_entries=10000, iterable=True)
     async def _get_joined_hosts(
-        self,
-        room_id: str,
-        state_group: int,
-        current_state_ids: StateMap[str],
-        state_entry: "_StateCacheEntry",
+        self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
     ) -> FrozenSet[str]:
-        # We don't use `state_group`, its there so that we can cache based on
+        # We don't use `state_group`, it's there so that we can cache based on
         # it. However, it's important that it's never None, since two
         # current_state's with a state_group of None are likely to be different.
         #
@@ -1004,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
+    @cached(max_entries=5000)
+    async def _get_membership_from_event_id(
+        self, member_event_id: str
+    ) -> Optional[EventIdMembership]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+    )
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
-    ) -> List[dict]:
-        """Get user_id and membership of a set of event IDs."""
+    ) -> Dict[str, Optional[EventIdMembership]]:
+        """Get user_id and membership of a set of event IDs.
 
-        return await self.db_pool.simple_select_many_batch(
+        Returns:
+            Mapping from event ID to `EventIdMembership` if the event is a
+            membership event, otherwise the value is None.
+        """
+
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
@@ -1019,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_membership_from_event_ids",
         )
 
+        return {
+            row["event_id"]: EventIdMembership(
+                membership=row["membership"], user_id=row["user_id"]
+            )
+            for row in rows
+        }
+
     async def is_local_host_in_room_ignoring_users(
         self, room_id: str, ignore_users: Collection[str]
     ) -> bool:
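
The `_get_membership_from_event_id` stub that only raises `NotImplementedError` is the usual pairing with `@cachedList`: the batch method does one query and back-fills the per-item cache, so the single-item method never needs a body of its own. A rough sketch of those mechanics, using a plain dict in place of the real cache machinery:

    from typing import Dict, Iterable, Optional

    class BatchLookup:
        def __init__(self) -> None:
            self._cache: Dict[str, Optional[str]] = {}

        def get_one(self, event_id: str) -> Optional[str]:
            # Per-item entry point: normally served straight from the cache
            # that get_many() has already populated.
            if event_id not in self._cache:
                return self.get_many([event_id])[event_id]
            return self._cache[event_id]

        def get_many(self, event_ids: Iterable[str]) -> Dict[str, Optional[str]]:
            event_ids = list(event_ids)
            missing = [e for e in event_ids if e not in self._cache]
            if missing:
                rows = self._fetch_from_db(missing)  # one batched query
                for event_id in missing:
                    # Explicitly cache None for IDs the query did not return.
                    self._cache[event_id] = rows.get(event_id)
            return {e: self._cache[e] for e in event_ids}

        def _fetch_from_db(self, event_ids: Iterable[str]) -> Dict[str, str]:
            # Stand-in for the simple_select_many_batch() call above.
            return {e: "join" for e in event_ids}

Note the contract change that comes with it: callers now receive an entry for every requested ID, with None for events that were not membership events.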
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index e23b119072..c5e9010c83 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -125,9 +125,6 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        if not hs.config.server.enable_search:
-            return
-
         self.db_pool.updates.register_background_update_handler(
             self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
         )
@@ -243,9 +240,13 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
             return len(event_search_rows)
 
-        result = await self.db_pool.runInteraction(
-            self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
-        )
+        if self.hs.config.server.enable_search:
+            result = await self.db_pool.runInteraction(
+                self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
+            )
+        else:
+            # Don't index anything if search is not enabled.
+            result = 0
 
         if not result:
             await self.db_pool.updates._end_background_update(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index a898f847e7..39e1efe373 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
         args.extend(event_filter.labels)
 
     # Filter on relation_senders / relation types from the joined tables.
-    if event_filter.relation_senders:
+    if event_filter.related_by_senders:
         clauses.append(
             "(%s)"
             % " OR ".join(
-                "related_event.sender = ?" for _ in event_filter.relation_senders
+                "related_event.sender = ?" for _ in event_filter.related_by_senders
             )
         )
-        args.extend(event_filter.relation_senders)
+        args.extend(event_filter.related_by_senders)
 
-    if event_filter.relation_types:
+    if event_filter.related_by_rel_types:
         clauses.append(
             "(%s)"
-            % " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
+            % " OR ".join(
+                "relation_type = ?" for _ in event_filter.related_by_rel_types
+            )
         )
-        args.extend(event_filter.relation_types)
+        args.extend(event_filter.related_by_rel_types)
 
     return " AND ".join(clauses), args
 
@@ -1203,7 +1205,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         # If there is a filter on relation_senders and relation_types join to the
         # relations table.
         if event_filter and (
-            event_filter.relation_senders or event_filter.relation_types
+            event_filter.related_by_senders or event_filter.related_by_rel_types
         ):
             # Filtering by relations could cause the same event to appear multiple
             # times (since there's no limit on the number of relations to an event).
@@ -1211,7 +1213,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             join_clause += """
                 LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
             """
-            if event_filter.relation_senders:
+            if event_filter.related_by_senders:
                 join_clause += """
                     LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
                 """
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e7fddd2426..0595df01d3 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
     cast,
 )
 
+from typing_extensions import TypedDict
+
 from synapse.api.errors import StoreError
 
 if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
+from synapse.types import (
+    JsonDict,
+    UserProfile,
+    get_domain_from_id,
+    get_localpart_from_id,
+)
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
 
+class SearchResult(TypedDict):
+    limited: bool
+    results: List[UserProfile]
+
+
 class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
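
Typing the search response as a `TypedDict` documents the wire shape without changing it, and lets mypy check both the keys and the `UserProfile` element type. A minimal sketch of how the shape is assembled (plain dicts stand in for `UserProfile` rows):

    from typing import List
    from typing_extensions import TypedDict

    class SearchResult(TypedDict):
        limited: bool
        results: List[dict]

    def page(rows: List[dict], limit: int) -> SearchResult:
        # Callers fetch limit + 1 rows so `limited` can signal a next page.
        return {"limited": len(rows) > limit, "results": rows[:limit]}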
@@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
-    async def get_shared_rooms_for_users(
+    async def get_mutual_rooms_for_users(
         self, user_id: str, other_user_id: str
     ) -> Set[str]:
         """
@@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             A set of room IDs that the users share.
         """
 
-        def _get_shared_rooms_for_users_txn(
+        def _get_mutual_rooms_for_users_txn(
             txn: LoggingTransaction,
         ) -> List[Dict[str, str]]:
             txn.execute(
@@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             return rows
 
         rows = await self.db_pool.runInteraction(
-            "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+            "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
         )
 
         return {row["room_id"] for row in rows}
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
     async def search_user_dir(
         self, user_id: str, search_term: str, limit: int
-    ) -> JsonDict:
+    ) -> SearchResult:
         """Searches for users in directory
 
         Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        results = await self.db_pool.execute(
-            "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+        results = cast(
+            List[UserProfile],
+            await self.db_pool.execute(
+                "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+            ),
         )
 
         limited = len(results) > limit
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9abc02046e..afb7d5054d 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -27,7 +27,7 @@ def create_engine(database_config) -> BaseDatabaseEngine:
 
     if name == "psycopg2":
         # Note that psycopg2cffi-compat provides the psycopg2 module on pypy.
-        import psycopg2  # type: ignore
+        import psycopg2
 
         return PostgresEngine(psycopg2, database_config)
 
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 808342fafb..e8d29e2870 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
         self.default_isolation_level = (
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
+        self.config = database_config
 
     @property
     def single_threaded(self) -> bool:
         return False
 
+    def get_db_locale(self, txn):
+        txn.execute(
+            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+        )
+        collation, ctype = txn.fetchone()
+        return collation, ctype
+
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
         # together. For example, version 8.1.5 will be returned as 80105
         self._version = db_conn.server_version
+        allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
 
         # Are we on a supported PostgreSQL version?
         if not allow_outdated_version and self._version < 100000:
@@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine):
                     "See docs/postgres.md for more information." % (rows[0][0],)
                 )
 
-            txn.execute(
-                "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
-            )
-            collation, ctype = txn.fetchone()
+            collation, ctype = self.get_db_locale(txn)
             if collation != "C":
                 logger.warning(
-                    "Database has incorrect collation of %r. Should be 'C'\n"
-                    "See docs/postgres.md for more information.",
+                    "Database has incorrect collation of %r. Should be 'C'",
                     collation,
                 )
+                if not allow_unsafe_locale:
+                    raise IncorrectDatabaseSetup(
+                        "Database has incorrect collation of %r. Should be 'C'\n"
+                        "See docs/postgres.md for more information. You can override this check by"
+                        "setting 'allow_unsafe_locale' to true in the database config.",
+                        collation,
+                    )
 
             if ctype != "C":
-                logger.warning(
-                    "Database has incorrect ctype of %r. Should be 'C'\n"
-                    "See docs/postgres.md for more information.",
-                    ctype,
-                )
+                logger.warning(
+                    "Database has incorrect ctype of %r. Should be 'C'",
+                    ctype,
+                )
+                if not allow_unsafe_locale:
+                    raise IncorrectDatabaseSetup(
+                        "Database has incorrect ctype of %r. Should be 'C'\n"
+                        "See docs/postgres.md for more information. You can override this check by "
+                        "setting 'allow_unsafe_locale' to true in the database config.",
+                        ctype,
+                    )
 
     def check_new_database(self, txn):
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing databases.
         """
 
-        txn.execute(
-            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
-        )
-        collation, ctype = txn.fetchone()
+        collation, ctype = self.get_db_locale(txn)
 
         errors = []
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 7d543fdbe0..b402922817 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
 
         # Check if any of the changes that we don't have events for are joins.
         if events_to_check:
-            rows = await self.main_store.get_membership_from_event_ids(events_to_check)
-            is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+            members = await self.main_store.get_membership_from_event_ids(
+                events_to_check
+            )
+            is_still_joined = any(
+                member and member.membership == Membership.JOIN
+                for member in members.values()
+            )
             if is_still_joined:
                 return True
 
@@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
             ), event_id in current_state.items()
             if typ == EventTypes.Member and not self.is_mine_id(state_key)
         ]
-        rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+        members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
         potentially_left_users.update(
-            row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+            member.user_id
+            for member in members.values()
+            if member and member.membership == Membership.JOIN
         )
 
         return False
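
Both call sites above illustrate the new contract of `get_membership_from_event_ids`: every requested ID maps to an `Optional[EventIdMembership]`, so the value must be checked for None before reading `.membership`. In miniature:

    from dataclasses import dataclass
    from typing import Dict, Optional

    @dataclass(frozen=True)
    class EventIdMembership:
        user_id: str
        membership: str

    JOIN = "join"  # Membership.JOIN in Synapse

    def is_still_joined(members: Dict[str, Optional[EventIdMembership]]) -> bool:
        # Entries are None when the event ID was not a membership event.
        return any(m is not None and m.membership == JOIN for m in members.values())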
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 36ca2b8273..fba270150b 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -55,37 +55,6 @@ class PaginationChunk:
 
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
-class RelationPaginationToken:
-    """Pagination token for relation pagination API.
-
-    As the results are in topological order, we can use the
-    `topological_ordering` and `stream_ordering` fields of the events at the
-    boundaries of the chunk as pagination tokens.
-
-    Attributes:
-        topological: The topological ordering of the boundary event
-        stream: The stream ordering of the boundary event.
-    """
-
-    topological: int
-    stream: int
-
-    @staticmethod
-    def from_string(string: str) -> "RelationPaginationToken":
-        try:
-            t, s = string.split("-")
-            return RelationPaginationToken(int(t), int(s))
-        except ValueError:
-            raise SynapseError(400, "Invalid relation pagination token")
-
-    async def to_string(self, store: "DataStore") -> str:
-        return "%d-%d" % (self.topological, self.stream)
-
-    def as_tuple(self) -> Tuple[Any, ...]:
-        return attr.astuple(self)
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
 class AggregationPaginationToken:
     """Pagination token for relation aggregation pagination API.
 
diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py
index 22a7901e15..4b4b166e37 100644
--- a/synapse/storage/schema/main/delta/30/as_users.py
+++ b/synapse/storage/schema/main/delta/30/as_users.py
@@ -36,7 +36,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
         config_files = config.appservice.app_service_config_files
     except AttributeError:
         logger.warning("Could not get app_service_config_files from config")
-        pass
 
     appservices = load_appservices(config.server.server_name, config_files)
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e79ecf64a0..86f1a5373b 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -561,7 +561,7 @@ class StateGroupStorage:
         return state_group_delta.prev_group, state_group_delta.delta_ids
 
     async def get_state_groups_ids(
-        self, _room_id: str, event_ids: Iterable[str]
+        self, _room_id: str, event_ids: Collection[str]
     ) -> Dict[int, MutableStateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
@@ -596,7 +596,7 @@ class StateGroupStorage:
         return group_to_state[state_group]
 
     async def get_state_groups(
-        self, room_id: str, event_ids: Iterable[str]
+        self, room_id: str, event_ids: Collection[str]
     ) -> Dict[int, List[EventBase]]:
         """Get the state groups for the given list of event_ids
 
@@ -648,7 +648,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     async def get_state_for_events(
-        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
+        self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -684,7 +684,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
+        self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids