summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-03-29 10:15:25 +0100
committerErik Johnston <erik@matrix.org>2022-03-29 10:15:25 +0100
commitfd1b6334f071c7bbd88928513e99e5c8971ac414 (patch)
tree7333cc7290baba787047b22d0e5bdb875cd8e59b /synapse/storage/databases
parentMerge branch 'release-v1.55' of github.com:matrix-org/synapse into matrix-org... (diff)
parentRemove unused `auth_event_ids` argument plumbing (#12304) (diff)
downloadsynapse-fd1b6334f071c7bbd88928513e99e5c8971ac414.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/account_data.py41
-rw-r--r--synapse/storage/databases/main/cache.py61
-rw-r--r--synapse/storage/databases/main/event_federation.py12
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/group_server.py156
-rw-r--r--synapse/storage/databases/main/media_repository.py2
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py38
-rw-r--r--synapse/storage/databases/main/receipts.py37
-rw-r--r--synapse/storage/databases/main/registration.py3
-rw-r--r--synapse/storage/databases/main/relations.py149
-rw-r--r--synapse/storage/databases/main/room.py3
-rw-r--r--synapse/storage/databases/main/roommember.py37
-rw-r--r--synapse/storage/databases/main/search.py26
-rw-r--r--synapse/storage/databases/main/state.py24
-rw-r--r--synapse/storage/databases/main/stats.py2
-rw-r--r--synapse/storage/databases/main/user_directory.py31
16 files changed, 330 insertions, 299 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 52146aacc8..9af9f4f18e 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -365,7 +375,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         )
 
     @cached(max_entries=5000, iterable=True)
-    async def ignored_by(self, user_id: str) -> Set[str]:
+    async def ignored_by(self, user_id: str) -> FrozenSet[str]:
         """
         Get users which ignore the given user.
 
@@ -375,7 +385,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         Return:
             The user IDs which ignore the given user.
         """
-        return set(
+        return frozenset(
             await self.db_pool.simple_select_onecol(
                 table="ignored_users",
                 keyvalues={"ignored_user_id": user_id},
@@ -384,6 +394,26 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             )
         )
 
+    @cached(max_entries=5000, iterable=True)
+    async def ignored_users(self, user_id: str) -> FrozenSet[str]:
+        """
+        Get users which the given user ignores.
+
+        Params:
+            user_id: The user ID which is making the request.
+
+        Return:
+            The user IDs which are ignored by the given user.
+        """
+        return frozenset(
+            await self.db_pool.simple_select_onecol(
+                table="ignored_users",
+                keyvalues={"ignorer_user_id": user_id},
+                retcol="ignored_user_id",
+                desc="ignored_users",
+            )
+        )
+
     def process_replication_rows(
         self,
         stream_name: str,
@@ -529,6 +559,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         else:
             currently_ignored_users = set()
 
+        # If the data has not changed, nothing to do.
+        if previously_ignored_users == currently_ignored_users:
+            return
+
         # Delete entries which are no longer ignored.
         self.db_pool.simple_delete_many_txn(
             txn,
@@ -551,6 +585,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         # Invalidate the cache for any ignored users which were added or removed.
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+        self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
 
     async def purge_account_data_for_user(self, user_id: str) -> None:
         """
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index d6a2df1afe..dd4e83a2ad 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStream,
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
+    EventsStreamRow,
 )
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -31,6 +32,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_caches_txn(txn):
+        def get_all_updated_caches_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # We purposefully don't bound by the current token, as we want to
             # send across cache invalidations as quickly as possible. Cache
             # invalidations are idempotent, so duplicates are fine.
@@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             "get_all_updated_caches", get_all_updated_caches_txn
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == EventsStream.NAME:
             for row in rows:
                 self._process_event_stream_row(token, row)
@@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def _process_event_stream_row(self, token, row):
+    def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data
 
         if row.type == EventsStreamEventRow.TypeId:
+            assert isinstance(data, EventsStreamEventRow)
             self._invalidate_caches_for_event(
                 token,
                 data.event_id,
@@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 backfilled=False,
             )
         elif row.type == EventsStreamCurrentStateRow.TypeId:
-            self._curr_state_delta_stream_cache.entity_has_changed(
-                row.data.room_id, token
-            )
+            assert isinstance(data, EventsStreamCurrentStateRow)
+            self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token)
 
             if data.type == EventTypes.Member:
                 self.get_rooms_for_user_with_stream_ordering.invalidate(
@@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
     def _invalidate_caches_for_event(
         self,
-        stream_ordering,
-        event_id,
-        room_id,
-        etype,
-        state_key,
-        redacts,
-        relates_to,
-        backfilled,
-    ):
+        stream_ordering: int,
+        event_id: str,
+        room_id: str,
+        etype: str,
+        state_key: Optional[str],
+        redacts: Optional[str],
+        relates_to: Optional[str],
+        backfilled: bool,
+    ) -> None:
         self._invalidate_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
 
@@ -186,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
 
+        # The `_get_membership_from_event_id` is immutable, except for the
+        # case where we look up an event *before* persisting it.
+        self._get_membership_from_event_id.invalidate((event_id,))
+
         if not backfilled:
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
@@ -207,7 +217,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_thread_summary.invalidate((relates_to,))
             self.get_thread_participated.invalidate((relates_to,))
 
-    async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+    async def invalidate_cache_and_stream(
+        self, cache_name: str, keys: Tuple[Any, ...]
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -227,7 +239,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             keys,
         )
 
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+    def _invalidate_cache_and_stream(
+        self,
+        txn: LoggingTransaction,
+        cache_func: _CachedFunction,
+        keys: Tuple[Any, ...],
+    ) -> None:
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
 
@@ -238,7 +255,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
-    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+    def _invalidate_all_cache_and_stream(
+        self, txn: LoggingTransaction, cache_func: _CachedFunction
+    ) -> None:
         """Invalidates the entire cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
         """
@@ -279,8 +298,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             )
 
     def _send_invalidation_to_replication(
-        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
-    ):
+        self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
+    ) -> None:
         """Notifies replication that given cache has been invalidated.
 
         Note that this does *not* invalidate the cache locally.
@@ -315,7 +334,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     "instance_name": self._instance_name,
                     "cache_func": cache_name,
                     "keys": keys,
-                    "invalidation_ts": self.clock.time_msec(),
+                    "invalidation_ts": self._clock.time_msec(),
                 },
             )
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 277e6422eb..634e19e035 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1073,9 +1073,15 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             /* Get the depth and stream_ordering of the prev_event_id from the events table */
             INNER JOIN events
             ON prev_event_id = events.event_id
+
+            /* exclude outliers from the results (we don't have the state, so cannot
+             * verify if the requesting server can see them).
+             */
+            WHERE NOT events.outlier
+
             /* Look for an edge which matches the given event_id */
-            WHERE event_edges.event_id = ?
-            AND event_edges.is_state = ?
+            AND event_edges.event_id = ? AND NOT event_edges.is_state
+
             /* Because we can have many events at the same depth,
             * we want to also tie-break and sort on stream_ordering */
             ORDER BY depth DESC, stream_ordering DESC
@@ -1084,7 +1090,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         txn.execute(
             connected_prev_event_query,
-            (event_id, False, limit),
+            (event_id, limit),
         )
         return [
             BackfillQueueNavigationItem(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f60aef180..d253243125 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1745,6 +1745,13 @@ class PersistEventsStore:
                 (event.state_key,),
             )
 
+            # The `_get_membership_from_event_id` is immutable, except for the
+            # case where we look up an event *before* persisting it.
+            txn.call_after(
+                self.store._get_membership_from_event_id.invalidate,
+                (event.event_id,),
+            )
+
             # We update the local_current_membership table only if the event is
             # "current", i.e., its something that has just happened.
             #
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 3f6086050b..0aef121d83 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -13,13 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
 
 from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
@@ -75,7 +79,7 @@ class GroupServerWorkerStore(SQLBaseStore):
     ) -> List[Dict[str, Any]]:
         # TODO: Pagination
 
-        keyvalues = {"group_id": group_id}
+        keyvalues: JsonDict = {"group_id": group_id}
         if not include_private:
             keyvalues["is_public"] = True
 
@@ -117,7 +121,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         # TODO: Pagination
 
-        def _get_rooms_in_group_txn(txn):
+        def _get_rooms_in_group_txn(txn: LoggingTransaction) -> List[_RoomInGroup]:
             sql = """
             SELECT room_id, is_public FROM group_rooms
                 WHERE group_id = ?
@@ -176,8 +180,10 @@ class GroupServerWorkerStore(SQLBaseStore):
                         * "order": int, the sort order of rooms in this category
         """
 
-        def _get_rooms_for_summary_txn(txn):
-            keyvalues = {"group_id": group_id}
+        def _get_rooms_for_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+            keyvalues: JsonDict = {"group_id": group_id}
             if not include_private:
                 keyvalues["is_public"] = True
 
@@ -241,7 +247,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_rooms_for_summary", _get_rooms_for_summary_txn
         )
 
-    async def get_group_categories(self, group_id):
+    async def get_group_categories(self, group_id: str) -> JsonDict:
         rows = await self.db_pool.simple_select_list(
             table="group_room_categories",
             keyvalues={"group_id": group_id},
@@ -257,7 +263,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    async def get_group_category(self, group_id, category_id):
+    async def get_group_category(self, group_id: str, category_id: str) -> JsonDict:
         category = await self.db_pool.simple_select_one(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
@@ -269,7 +275,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
         return category
 
-    async def get_group_roles(self, group_id):
+    async def get_group_roles(self, group_id: str) -> JsonDict:
         rows = await self.db_pool.simple_select_list(
             table="group_roles",
             keyvalues={"group_id": group_id},
@@ -285,7 +291,7 @@ class GroupServerWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    async def get_group_role(self, group_id, role_id):
+    async def get_group_role(self, group_id: str, role_id: str) -> JsonDict:
         role = await self.db_pool.simple_select_one(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
@@ -311,15 +317,19 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_local_groups_for_room",
         )
 
-    async def get_users_for_summary_by_role(self, group_id, include_private=False):
+    async def get_users_for_summary_by_role(
+        self, group_id: str, include_private: bool = False
+    ) -> Tuple[List[JsonDict], JsonDict]:
         """Get the users and roles that should be included in a summary request
 
         Returns:
             ([users], [roles])
         """
 
-        def _get_users_for_summary_txn(txn):
-            keyvalues = {"group_id": group_id}
+        def _get_users_for_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], JsonDict]:
+            keyvalues: JsonDict = {"group_id": group_id}
             if not include_private:
                 keyvalues["is_public"] = True
 
@@ -406,7 +416,9 @@ class GroupServerWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    async def get_users_membership_info_in_group(self, group_id, user_id):
+    async def get_users_membership_info_in_group(
+        self, group_id: str, user_id: str
+    ) -> JsonDict:
         """Get a dict describing the membership of a user in a group.
 
         Example if joined:
@@ -421,7 +433,7 @@ class GroupServerWorkerStore(SQLBaseStore):
              An empty dict if the user is not join/invite/etc
         """
 
-        def _get_users_membership_in_group_txn(txn):
+        def _get_users_membership_in_group_txn(txn: LoggingTransaction) -> JsonDict:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="group_users",
@@ -463,10 +475,14 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_publicised_groups_for_user",
         )
 
-    async def get_attestations_need_renewals(self, valid_until_ms):
+    async def get_attestations_need_renewals(
+        self, valid_until_ms: int
+    ) -> List[Dict[str, Any]]:
         """Get all attestations that need to be renewed until givent time"""
 
-        def _get_attestations_need_renewals_txn(txn):
+        def _get_attestations_need_renewals_txn(
+            txn: LoggingTransaction,
+        ) -> List[Dict[str, Any]]:
             sql = """
                 SELECT group_id, user_id FROM group_attestations_renewals
                 WHERE valid_until_ms <= ?
@@ -478,7 +494,9 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_attestations_need_renewals", _get_attestations_need_renewals_txn
         )
 
-    async def get_remote_attestation(self, group_id, user_id):
+    async def get_remote_attestation(
+        self, group_id: str, user_id: str
+    ) -> Optional[JsonDict]:
         """Get the attestation that proves the remote agrees that the user is
         in the group.
         """
@@ -504,8 +522,8 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_joined_groups",
         )
 
-    async def get_all_groups_for_user(self, user_id, now_token):
-        def _get_all_groups_for_user_txn(txn):
+    async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+        def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
             sql = """
                 SELECT group_id, type, membership, u.content
                 FROM local_group_updates AS u
@@ -528,15 +546,16 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_all_groups_for_user", _get_all_groups_for_user_txn
         )
 
-    async def get_groups_changes_for_user(self, user_id, from_token, to_token):
-        from_token = int(from_token)
-        has_changed = self._group_updates_stream_cache.has_entity_changed(
+    async def get_groups_changes_for_user(
+        self, user_id: str, from_token: int, to_token: int
+    ) -> List[JsonDict]:
+        has_changed = self._group_updates_stream_cache.has_entity_changed(  # type: ignore[attr-defined]
             user_id, from_token
         )
         if not has_changed:
             return []
 
-        def _get_groups_changes_for_user_txn(txn):
+        def _get_groups_changes_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
             sql = """
                 SELECT group_id, membership, type, u.content
                 FROM local_group_updates AS u
@@ -583,12 +602,14 @@ class GroupServerWorkerStore(SQLBaseStore):
         """
 
         last_id = int(last_id)
-        has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
+        has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)  # type: ignore[attr-defined]
 
         if not has_changed:
             return [], current_id, False
 
-        def _get_all_groups_changes_txn(txn):
+        def _get_all_groups_changes_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT stream_id, group_id, user_id, type, content
                 FROM local_group_updates
@@ -596,10 +617,13 @@ class GroupServerWorkerStore(SQLBaseStore):
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [
-                (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
-                for stream_id, group_id, user_id, gtype, content_json in txn
-            ]
+            updates = cast(
+                List[Tuple[int, tuple]],
+                [
+                    (stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
+                    for stream_id, group_id, user_id, gtype, content_json in txn
+                ],
+            )
 
             limited = False
             upto_token = current_id
@@ -633,8 +657,8 @@ class GroupServerStore(GroupServerWorkerStore):
         self,
         group_id: str,
         room_id: str,
-        category_id: str,
-        order: int,
+        category_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) room's entry in summary.
@@ -661,11 +685,11 @@ class GroupServerStore(GroupServerWorkerStore):
 
     def _add_room_to_summary_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         group_id: str,
         room_id: str,
-        category_id: str,
-        order: int,
+        category_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) room's entry in summary.
@@ -750,7 +774,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 WHERE group_id = ? AND category_id = ?
             """
             txn.execute(sql, (group_id, category_id))
-            (order,) = txn.fetchone()
+            (order,) = cast(Tuple[int], txn.fetchone())
 
         if existing:
             to_update = {}
@@ -766,7 +790,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     "category_id": category_id,
                     "room_id": room_id,
                 },
-                values=to_update,
+                updatevalues=to_update,
             )
         else:
             if is_public is None:
@@ -785,7 +809,7 @@ class GroupServerStore(GroupServerWorkerStore):
             )
 
     async def remove_room_from_summary(
-        self, group_id: str, room_id: str, category_id: str
+        self, group_id: str, room_id: str, category_id: Optional[str]
     ) -> int:
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
@@ -808,8 +832,8 @@ class GroupServerStore(GroupServerWorkerStore):
         is_public: Optional[bool],
     ) -> None:
         """Add/update room category for group"""
-        insertion_values = {}
-        update_values = {"category_id": category_id}  # This cannot be empty
+        insertion_values: JsonDict = {}
+        update_values: JsonDict = {"category_id": category_id}  # This cannot be empty
 
         if profile is None:
             insertion_values["profile"] = "{}"
@@ -844,8 +868,8 @@ class GroupServerStore(GroupServerWorkerStore):
         is_public: Optional[bool],
     ) -> None:
         """Add/remove user role"""
-        insertion_values = {}
-        update_values = {"role_id": role_id}  # This cannot be empty
+        insertion_values: JsonDict = {}
+        update_values: JsonDict = {"role_id": role_id}  # This cannot be empty
 
         if profile is None:
             insertion_values["profile"] = "{}"
@@ -876,8 +900,8 @@ class GroupServerStore(GroupServerWorkerStore):
         self,
         group_id: str,
         user_id: str,
-        role_id: str,
-        order: int,
+        role_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
     ) -> None:
         """Add (or update) user's entry in summary.
@@ -904,13 +928,13 @@ class GroupServerStore(GroupServerWorkerStore):
 
     def _add_user_to_summary_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         group_id: str,
         user_id: str,
-        role_id: str,
-        order: int,
+        role_id: Optional[str],
+        order: Optional[int],
         is_public: Optional[bool],
-    ):
+    ) -> None:
         """Add (or update) user's entry in summary.
 
         Args:
@@ -989,7 +1013,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 WHERE group_id = ? AND role_id = ?
             """
             txn.execute(sql, (group_id, role_id))
-            (order,) = txn.fetchone()
+            (order,) = cast(Tuple[int], txn.fetchone())
 
         if existing:
             to_update = {}
@@ -1005,7 +1029,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     "role_id": role_id,
                     "user_id": user_id,
                 },
-                values=to_update,
+                updatevalues=to_update,
             )
         else:
             if is_public is None:
@@ -1024,7 +1048,7 @@ class GroupServerStore(GroupServerWorkerStore):
             )
 
     async def remove_user_from_summary(
-        self, group_id: str, user_id: str, role_id: str
+        self, group_id: str, user_id: str, role_id: Optional[str]
     ) -> int:
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
@@ -1065,7 +1089,7 @@ class GroupServerStore(GroupServerWorkerStore):
                 Optional if the user and group are on the same server
         """
 
-        def _add_user_to_group_txn(txn):
+        def _add_user_to_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_insert_txn(
                 txn,
                 table="group_users",
@@ -1108,7 +1132,7 @@ class GroupServerStore(GroupServerWorkerStore):
         await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
 
     async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
-        def _remove_user_from_group_txn(txn):
+        def _remove_user_from_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="group_users",
@@ -1159,7 +1183,7 @@ class GroupServerStore(GroupServerWorkerStore):
         )
 
     async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
-        def _remove_room_from_group_txn(txn):
+        def _remove_room_from_group_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="group_rooms",
@@ -1216,7 +1240,9 @@ class GroupServerStore(GroupServerWorkerStore):
 
         content = content or {}
 
-        def _register_user_group_membership_txn(txn, next_id):
+        def _register_user_group_membership_txn(
+            txn: LoggingTransaction, next_id: int
+        ) -> int:
             # TODO: Upsert?
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1249,7 +1275,7 @@ class GroupServerStore(GroupServerWorkerStore):
                     ),
                 },
             )
-            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
+            self._group_updates_stream_cache.entity_has_changed(user_id, next_id)  # type: ignore[attr-defined]
 
             # TODO: Insert profile to ensure it comes down stream if its a join.
 
@@ -1289,7 +1315,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        async with self._group_updates_id_gen.get_next() as next_id:
+        async with self._group_updates_id_gen.get_next() as next_id:  # type: ignore[attr-defined]
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
@@ -1298,7 +1324,13 @@ class GroupServerStore(GroupServerWorkerStore):
         return res
 
     async def create_group(
-        self, group_id, user_id, name, avatar_url, short_description, long_description
+        self,
+        group_id: str,
+        user_id: str,
+        name: str,
+        avatar_url: str,
+        short_description: str,
+        long_description: str,
     ) -> None:
         await self.db_pool.simple_insert(
             table="groups",
@@ -1313,7 +1345,7 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="create_group",
         )
 
-    async def update_group_profile(self, group_id, profile):
+    async def update_group_profile(self, group_id: str, profile: JsonDict) -> None:
         await self.db_pool.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
@@ -1361,8 +1393,8 @@ class GroupServerStore(GroupServerWorkerStore):
             desc="remove_attestation_renewal",
         )
 
-    def get_group_stream_token(self):
-        return self._group_updates_id_gen.get_current_token()
+    def get_group_stream_token(self) -> int:
+        return self._group_updates_id_gen.get_current_token()  # type: ignore[attr-defined]
 
     async def delete_group(self, group_id: str) -> None:
         """Deletes a group fully from the database.
@@ -1371,7 +1403,7 @@ class GroupServerStore(GroupServerWorkerStore):
             group_id: The group ID to delete.
         """
 
-        def _delete_group_txn(txn):
+        def _delete_group_txn(txn: LoggingTransaction) -> None:
             tables = [
                 "groups",
                 "group_users",
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index cbba356b4a..322ed05390 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -156,7 +156,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
     async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e9a0cdc6be..216622964a 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,15 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
+    LoggingTransaction,
     make_in_list_sql_clause,
 )
+from synapse.storage.databases.main.registration import RegistrationWorkerStore
 from synapse.util.caches.descriptors import cached
 from synapse.util.threepids import canonicalise_email
 
@@ -56,7 +58,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             Number of current monthly active users
         """
 
-        def _count_users(txn):
+        def _count_users(txn: LoggingTransaction) -> int:
             # Exclude app service users
             sql = """
                 SELECT COUNT(*)
@@ -66,7 +68,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
                 WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
             """
             txn.execute(sql)
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_users", _count_users)
@@ -84,7 +86,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
 
         """
 
-        def _count_users_by_service(txn):
+        def _count_users_by_service(txn: LoggingTransaction) -> Dict[str, int]:
             sql = """
                 SELECT COALESCE(appservice_id, 'native'), COUNT(*)
                 FROM monthly_active_users
@@ -93,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             """
 
             txn.execute(sql)
-            result = txn.fetchall()
+            result = cast(List[Tuple[str, int]], txn.fetchall())
             return dict(result)
 
         return await self.db_pool.runInteraction(
@@ -141,12 +143,12 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         )
 
     @wrap_as_background_process("reap_monthly_active_users")
-    async def reap_monthly_active_users(self):
+    async def reap_monthly_active_users(self) -> None:
         """Cleans out monthly active user table to ensure that no stale
         entries exist.
         """
 
-        def _reap_users(txn, reserved_users):
+        def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None:
             """
             Args:
                 reserved_users (tuple): reserved users to preserve
@@ -210,10 +212,10 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             # is racy.
             # Have resolved to invalidate the whole cache for now and do
             # something about it if and when the perf becomes significant
-            self._invalidate_all_cache_and_stream(
+            self._invalidate_all_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.user_last_seen_monthly_active
             )
-            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())  # type: ignore[attr-defined]
 
         reserved_users = await self.get_registered_reserved_users()
         await self.db_pool.runInteraction(
@@ -221,7 +223,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         )
 
 
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -242,13 +244,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
         )
 
-    def _initialise_reserved_users(self, txn, threepids):
+    def _initialise_reserved_users(
+        self, txn: LoggingTransaction, threepids: List[dict]
+    ) -> None:
         """Ensures that reserved threepids are accounted for in the MAU table, should
         be called on start up.
 
         Args:
-            txn (cursor):
-            threepids (list[dict]): List of threepid dicts to reserve
+            txn:
+            threepids: List of threepid dicts to reserve
         """
 
         # XXX what is this function trying to achieve?  It upserts into
@@ -299,7 +303,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )
 
-    def upsert_monthly_active_user_txn(self, txn, user_id):
+    def upsert_monthly_active_user_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> None:
         """Updates or inserts monthly active user member
 
         We consciously do not call is_support_txn from this method because it
@@ -336,7 +342,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             txn, self.user_last_seen_monthly_active, (user_id,)
         )
 
-    async def populate_monthly_active_users(self, user_id):
+    async def populate_monthly_active_users(self, user_id: str) -> None:
         """Checks on the state of monthly active user limits and optionally
         add the user to the monthly active tables
 
@@ -345,7 +351,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         """
         if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
-            is_guest = await self.is_guest(user_id)
+            is_guest = await self.is_guest(user_id)  # type: ignore[attr-defined]
             if is_guest:
                 return
             is_trial = await self.is_trial_user(user_id)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index bf0b903af2..e6f97aeece 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -24,10 +24,9 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
-from twisted.internet import defer
-
 from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import ReceiptsStream
@@ -38,7 +37,11 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -58,6 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         hs: "HomeServer",
     ):
         self._instance_name = hs.get_instance_name()
+        self._receipts_id_gen: AbstractStreamIdTracker
 
         if isinstance(database.engine, PostgresEngine):
             self._can_write_to_receipts = (
@@ -161,7 +165,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 " AND user_id = ?"
             )
             txn.execute(sql, (user_id,))
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, int, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             "get_receipts_for_user_with_orderings", f
@@ -257,7 +261,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if not rows:
             return []
 
-        content = {}
+        content: JsonDict = {}
         for row in rows:
             content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                 row["user_id"]
@@ -305,7 +309,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "_get_linearized_receipts_for_rooms", f
         )
 
-        results = {}
+        results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
@@ -370,7 +374,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "get_linearized_receipts_for_all_rooms", f
         )
 
-        results = {}
+        results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
             # receipts by room, event and type.
@@ -399,7 +403,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """
 
         if last_id == current_id:
-            return defer.succeed([])
+            return []
 
         def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
             sql = """
@@ -453,7 +457,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             """
             txn.execute(sql, (last_id, current_id, limit))
 
-            updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+            )
 
             limited = False
             upper_bound = current_id
@@ -496,7 +503,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == ReceiptsStream.NAME:
             self._receipts_id_gen.advance(instance_name, token)
             for row in rows:
@@ -584,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
         if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
-            self._remove_old_push_actions_before_txn(
+            self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
 
@@ -637,7 +650,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        async with self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a698d10cc5..7f3d190e94 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -22,6 +22,7 @@ import attr
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.config.homeserver import HomeServerConfig
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.database import (
     DatabasePool,
@@ -123,7 +124,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
+        self.config: HomeServerConfig = hs.config
 
         # Note: we don't check this sequence for consistency as we'd have to
         # call `find_max_generated_user_id_localpart` each time, which is
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index c4869d64e6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
 
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
-    # The latest event in the thread.
-    latest_event: EventBase
-    # The latest edit to the latest event in the thread.
-    latest_edit: Optional[EventBase]
-    # The total number of events in the thread.
-    count: int
-    # True if the current user has sent an event to the thread.
-    current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
-    """
-    The bundled aggregations for an event.
-
-    Some values require additional processing during serialization.
-    """
-
-    annotations: Optional[JsonDict] = None
-    references: Optional[JsonDict] = None
-    replace: Optional[EventBase] = None
-    thread: Optional[_ThreadAggregation] = None
-
-    def __bool__(self) -> bool:
-        return bool(self.annotations or self.references or self.replace or self.thread)
-
-
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
-    async def _get_applicable_edits(
+    async def get_applicable_edits(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[EventBase]]:
         """Get the most recent edit (if any) that has happened for the given
@@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
-    async def _get_thread_summaries(
+    async def get_thread_summaries(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
         """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
         # Check to see if any of those events are edited.
-        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+        latest_edits = await self.get_applicable_edits(latest_event_ids.values())
 
         # Map to the event IDs to the thread summary.
         #
@@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
-    async def _get_threads_participated(
+    async def get_threads_participated(
         self, event_ids: Collection[str], user_id: str
     ) -> Dict[str, bool]:
         """Get whether the requesting user participated in the given threads.
@@ -766,114 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
-    async def _get_bundled_aggregation_for_event(
-        self, event: EventBase, user_id: str
-    ) -> Optional[BundledAggregations]:
-        """Generate bundled aggregations for an event.
-
-        Note that this does not use a cache, but depends on cached methods.
-
-        Args:
-            event: The event to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            The bundled aggregations for an event, if bundled aggregations are
-            enabled and the event can have bundled aggregations.
-        """
-
-        # Do not bundle aggregations for an event which represents an edit or an
-        # annotation. It does not make sense for them to have related events.
-        relates_to = event.content.get("m.relates_to")
-        if isinstance(relates_to, (dict, frozendict)):
-            relation_type = relates_to.get("rel_type")
-            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
-                return None
-
-        event_id = event.event_id
-        room_id = event.room_id
-
-        # The bundled aggregations to include, a mapping of relation type to a
-        # type-specific value. Some types include the direct return type here
-        # while others need more processing during serialization.
-        aggregations = BundledAggregations()
-
-        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
-
-        references = await self.get_relations_for_event(
-            event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
-        )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
-
-        # Store the bundled aggregations in the event metadata for later use.
-        return aggregations
-
-    async def get_bundled_aggregations(
-        self, events: Iterable[EventBase], user_id: str
-    ) -> Dict[str, BundledAggregations]:
-        """Generate bundled aggregations for events.
-
-        Args:
-            events: The iterable of events to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            A map of event ID to the bundled aggregation for the event. Not all
-            events may have bundled aggregations in the results.
-        """
-        # De-duplicate events by ID to handle the same event requested multiple times.
-        #
-        # State events do not get bundled aggregations.
-        events_by_id = {
-            event.event_id: event for event in events if not event.is_state()
-        }
-
-        # event ID -> bundled aggregation in non-serialized form.
-        results: Dict[str, BundledAggregations] = {}
-
-        # Fetch other relations per event.
-        for event in events_by_id.values():
-            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result:
-                results[event.event_id] = event_result
-
-        # Fetch any edits (but not for redacted events).
-        edits = await self._get_applicable_edits(
-            [
-                event_id
-                for event_id, event in events_by_id.items()
-                if not event.internal_metadata.is_redacted()
-            ]
-        )
-        for event_id, edit in edits.items():
-            results.setdefault(event_id, BundledAggregations()).replace = edit
-
-        # Fetch thread summaries.
-        summaries = await self._get_thread_summaries(events_by_id.keys())
-        # Only fetch participated for a limited selection based on what had
-        # summaries.
-        participated = await self._get_threads_participated(summaries.keys(), user_id)
-        for event_id, summary in summaries.items():
-            if summary:
-                thread_count, latest_thread_event, edit = summary
-                results.setdefault(
-                    event_id, BundledAggregations()
-                ).thread = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    latest_edit=edit,
-                    count=thread_count,
-                    # If there's a thread summary it must also exist in the
-                    # participated dictionary.
-                    current_user_participated=participated[event_id],
-                )
-
-        return results
-
 
 class RelationsStore(RelationsWorkerStore):
     pass
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 94068940b9..18b1acd9e1 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -34,6 +34,7 @@ import attr
 from synapse.api.constants import EventContentFields, EventTypes, JoinRules
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -98,7 +99,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.config = hs.config
+        self.config: HomeServerConfig = hs.config
 
     async def store_room(
         self,
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bef675b845..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+    """Returned by `get_membership_from_event_ids`"""
+
+    user_id: str
+    membership: str
+
+
 class RoomMemberWorkerStore(EventsWorkerStore):
     def __init__(
         self,
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             retcols=("user_id", "display_name", "avatar_url", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=500,
-            desc="_get_membership_from_event_ids",
+            desc="_get_joined_profiles_from_event_ids",
         )
 
         return {
@@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
+    @cached(max_entries=5000)
+    async def _get_membership_from_event_id(
+        self, member_event_id: str
+    ) -> Optional[EventIdMembership]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+    )
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
-    ) -> List[dict]:
-        """Get user_id and membership of a set of event IDs."""
+    ) -> Dict[str, Optional[EventIdMembership]]:
+        """Get user_id and membership of a set of event IDs.
+
+        Returns:
+            Mapping from event ID to `EventIdMembership` if the event is a
+            membership event, otherwise the value is None.
+        """
 
-        return await self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
@@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_membership_from_event_ids",
         )
 
+        return {
+            row["event_id"]: EventIdMembership(
+                membership=row["membership"], user_id=row["user_id"]
+            )
+            for row in rows
+        }
+
     async def is_local_host_in_room_ignoring_users(
         self, room_id: str, ignore_users: Collection[str]
     ) -> bool:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index bb41beb827..79abe758e6 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
 
 import attr
 
@@ -74,7 +74,7 @@ class SearchWorkerStore(SQLBaseStore):
                 " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
             )
 
-            args = (
+            args1 = (
                 (
                     entry.event_id,
                     entry.room_id,
@@ -86,14 +86,14 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.execute_batch(sql, args)
+            txn.execute_batch(sql, args1)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
                 "INSERT INTO event_search (event_id, room_id, key, value)"
                 " VALUES (?,?,?,?)"
             )
-            args = (
+            args2 = (
                 (
                     entry.event_id,
                     entry.room_id,
@@ -102,7 +102,7 @@ class SearchWorkerStore(SQLBaseStore):
                 )
                 for entry in entries
             )
-            txn.execute_batch(sql, args)
+            txn.execute_batch(sql, args2)
 
         else:
             # This should be unreachable.
@@ -427,7 +427,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         search_query = _parse_query(self.database_engine, search_term)
 
-        args = []
+        args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
@@ -496,7 +496,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
         # search results (which is a data leak)
-        events = await self.get_events_as_list(
+        events = await self.get_events_as_list(  # type: ignore[attr-defined]
             [r["event_id"] for r in results],
             redact_behaviour=EventRedactBehaviour.BLOCK,
         )
@@ -530,7 +530,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         room_ids: Collection[str],
         search_term: str,
         keys: Iterable[str],
-        limit,
+        limit: int,
         pagination_token: Optional[str] = None,
     ) -> JsonDict:
         """Performs a full text search over events with given keys.
@@ -549,7 +549,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         search_query = _parse_query(self.database_engine, search_term)
 
-        args = []
+        args: List[Any] = []
 
         # Make sure we don't explode because the person is in too many rooms.
         # We filter the results below regardless.
@@ -573,9 +573,9 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         if pagination_token:
             try:
-                origin_server_ts, stream = pagination_token.split(",")
-                origin_server_ts = int(origin_server_ts)
-                stream = int(stream)
+                origin_server_ts_str, stream_str = pagination_token.split(",")
+                origin_server_ts = int(origin_server_ts_str)
+                stream = int(stream_str)
             except Exception:
                 raise SynapseError(400, "Invalid pagination token")
 
@@ -654,7 +654,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
         # search results (which is a data leak)
-        events = await self.get_events_as_list(
+        events = await self.get_events_as_list(  # type: ignore[attr-defined]
             [r["event_id"] for r in results],
             redact_behaviour=EventRedactBehaviour.BLOCK,
         )
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 417aef1dbc..28460fd364 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import collections.abc
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +29,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             # We delegate to the cached version
             return await self.get_current_state_ids(room_id)
 
-        def _get_filtered_current_state_ids_txn(txn):
+        def _get_filtered_current_state_ids_txn(
+            txn: LoggingTransaction,
+        ) -> StateMap[str]:
             results = {}
             sql = """
                 SELECT type, state_key, event_id FROM current_state_events
@@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         event_id = state.get((EventTypes.CanonicalAlias, ""))
         if not event_id:
-            return
+            return None
 
         event = await self.get_event(event_id, allow_none=True)
         if not event:
-            return
+            return None
 
         return event.content.get("canonical_alias")
 
@@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         list_name="event_ids",
         num_args=1,
     )
-    async def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
         """Returns mapping event_id -> state_group"""
         rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
@@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
         self.db_pool.updates.register_background_index_update(
             self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             self._background_remove_left_rooms,
         )
 
-    async def _background_remove_left_rooms(self, progress, batch_size):
+    async def _background_remove_left_rooms(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to delete rows from `current_state_events` and
         `event_forward_extremities` tables of rooms that the server is no
         longer joined to.
@@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
 
         last_room_id = progress.get("last_room_id", "")
 
-        def _background_remove_left_rooms_txn(txn):
+        def _background_remove_left_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[bool, Set[str]]:
             # get a batch of room ids to consider
             sql = """
                 SELECT DISTINCT room_id FROM current_state_events
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 427ae1f649..b95dbef678 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -108,7 +108,7 @@ class StatsStore(StateDeltasStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
         self.clock = self.hs.get_clock()
         self.stats_enabled = hs.config.stats.stats_enabled
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index e7fddd2426..df772d4721 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -26,6 +26,8 @@ from typing import (
     cast,
 )
 
+from typing_extensions import TypedDict
+
 from synapse.api.errors import StoreError
 
 if TYPE_CHECKING:
@@ -40,7 +42,12 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.state import StateFilter
 from synapse.storage.databases.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
+from synapse.types import (
+    JsonDict,
+    UserProfile,
+    get_domain_from_id,
+    get_localpart_from_id,
+)
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -61,7 +68,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     ) -> None:
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
         self.db_pool.updates.register_background_update_handler(
             "populate_user_directory_createtables",
@@ -591,6 +598,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
 
+class SearchResult(TypedDict):
+    limited: bool
+    results: List[UserProfile]
+
+
 class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
@@ -718,7 +730,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
-    async def get_shared_rooms_for_users(
+    async def get_mutual_rooms_for_users(
         self, user_id: str, other_user_id: str
     ) -> Set[str]:
         """
@@ -732,7 +744,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             A set of room ID's that the users share.
         """
 
-        def _get_shared_rooms_for_users_txn(
+        def _get_mutual_rooms_for_users_txn(
             txn: LoggingTransaction,
         ) -> List[Dict[str, str]]:
             txn.execute(
@@ -756,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             return rows
 
         rows = await self.db_pool.runInteraction(
-            "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
+            "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
         )
 
         return {row["room_id"] for row in rows}
@@ -777,7 +789,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
 
     async def search_user_dir(
         self, user_id: str, search_term: str, limit: int
-    ) -> JsonDict:
+    ) -> SearchResult:
         """Searches for users in directory
 
         Returns:
@@ -910,8 +922,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        results = await self.db_pool.execute(
-            "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+        results = cast(
+            List[UserProfile],
+            await self.db_pool.execute(
+                "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
+            ),
         )
 
         limited = len(results) > limit