summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/persist_events.py30
-rw-r--r--synapse/storage/controllers/state.py42
-rw-r--r--synapse/storage/databases/main/account_data.py3
-rw-r--r--synapse/storage/databases/main/event_federation.py9
-rw-r--r--synapse/storage/databases/main/event_push_actions.py303
-rw-r--r--synapse/storage/databases/main/events.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py38
-rw-r--r--synapse/storage/databases/main/push_rule.py127
-rw-r--r--synapse/storage/databases/main/receipts.py2
-rw-r--r--synapse/storage/databases/main/registration.py2
-rw-r--r--synapse/storage/databases/main/room.py6
-rw-r--r--synapse/storage/databases/main/roommember.py126
-rw-r--r--synapse/storage/util/id_generators.py13
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py3
14 files changed, 387 insertions, 319 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index cf98b0ab48..dad3731b9b 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -45,8 +45,14 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.logging import opentracing
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import (
+    SynapseTags,
+    active_span,
+    set_tag,
+    start_active_span_follows_from,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.controllers.state import StateStorageController
 from synapse.storage.databases import Databases
@@ -223,7 +229,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
             queue.append(end_item)
 
         # also add our active opentracing span to the item so that we get a link back
-        span = opentracing.active_span()
+        span = active_span()
         if span:
             end_item.parent_opentracing_span_contexts.append(span.context)
 
@@ -234,7 +240,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
         res = await make_deferred_yieldable(end_item.deferred.observe())
 
         # add another opentracing span which links to the persist trace.
-        with opentracing.start_active_span_follows_from(
+        with start_active_span_follows_from(
             f"{task.name}_complete", (end_item.opentracing_span_context,)
         ):
             pass
@@ -266,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        with opentracing.start_active_span_follows_from(
+                        with start_active_span_follows_from(
                             item.task.name,
                             item.parent_opentracing_span_contexts,
                             inherit_force_tracing=True,
@@ -355,7 +361,7 @@ class EventsPersistenceStorageController:
                 f"Found an unexpected task type in event persistence queue: {task}"
             )
 
-    @opentracing.trace
+    @trace
     async def persist_events(
         self,
         events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
@@ -380,9 +386,21 @@ class EventsPersistenceStorageController:
             PartialStateConflictError: if attempting to persist a partial state event in
                 a room that has been un-partial stated.
         """
+        event_ids: List[str] = []
         partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
+            event_ids.append(event.event_id)
+
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
 
         async def enqueue(
             item: Tuple[str, List[Tuple[EventBase, EventContext]]]
@@ -418,7 +436,7 @@ class EventsPersistenceStorageController:
             self.main_store.get_room_max_token(),
         )
 
-    @opentracing.trace
+    @trace
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 0d480f1014..f9ffd0e29e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -29,7 +29,8 @@ from typing import (
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.logging.opentracing import trace
+from synapse.logging.opentracing import tag_args, trace
+from synapse.storage.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
 from synapse.storage.util.partial_state_events_tracker import (
     PartialCurrentStateTracker,
@@ -228,10 +229,12 @@ class StateStorageController:
         return {event: event_to_state[event] for event in event_ids}
 
     @trace
+    @tag_args
     async def get_state_ids_for_events(
         self,
         event_ids: Collection[str],
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -240,6 +243,9 @@ class StateStorageController:
         Args:
             event_ids: events whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at these events and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
 
         Returns:
             A dict from event_id -> (type, state_key) -> event_id
@@ -248,8 +254,12 @@ class StateStorageController:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        await_full_state = True
-        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+        if (
+            await_full_state
+            and state_filter
+            and not state_filter.must_await_full_state(self._is_mine_id)
+        ):
+            # Full state is not required if the state filter is restrictive enough.
             await_full_state = False
 
         event_to_groups = await self.get_state_group_for_events(
@@ -292,7 +302,10 @@ class StateStorageController:
 
     @trace
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -300,6 +313,9 @@ class StateStorageController:
         Args:
             event_id: event whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
 
         Returns:
             A dict from (type, state_key) -> state_event_id
@@ -309,7 +325,9 @@ class StateStorageController:
                 outlier or is unknown)
         """
         state_map = await self.get_state_ids_for_events(
-            [event_id], state_filter or StateFilter.all()
+            [event_id],
+            state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
         return state_map[event_id]
 
@@ -332,6 +350,7 @@ class StateStorageController:
         )
 
     @trace
+    @tag_args
     async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
@@ -473,6 +492,7 @@ class StateStorageController:
             prev_stream_id, max_stream_id
         )
 
+    @trace
     async def get_current_state(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[EventBase]:
@@ -506,3 +526,15 @@ class StateStorageController:
         await self._partial_state_room_tracker.await_full_state(room_id)
 
         return await self.stores.main.get_current_hosts_in_room(room_id)
+
+    async def get_users_in_room_with_profiles(
+        self, room_id: str
+    ) -> Dict[str, ProfileInfo]:
+        """
+        Get the current users in the room with their profiles.
+        If the room is currently partial-stated, this will block until the room has
+        full state.
+        """
+        await self._partial_state_room_tracker.await_full_state(room_id)
+
+        return await self.stores.main.get_users_in_room_with_profiles(room_id)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room, (user_id,)
         )
         self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
-        self._invalidate_cache_and_stream(
-            txn, self.get_push_rules_enabled_for_user, (user_id,)
-        )
         # This user might be contained in the ignored_by cache for other users,
         # so we have to invalidate it all.
         self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..c836078da6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.opentracing import tag_args, trace
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
@@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         return await self.get_events_as_list(event_ids)
 
+    @trace
+    @tag_args
     async def get_auth_chain_ids(
         self,
         room_id: str,
@@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
+    @trace
+    @tag_args
     async def get_oldest_event_ids_with_depth_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
+    @trace
     async def get_insertion_event_backward_extremities_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
+    @trace
+    @tag_args
     async def get_successor_events(self, event_id: str) -> List[str]:
         """Fetch all events that have the given event as a prev event
 
@@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             _delete_old_forward_extrem_cache_txn,
         )
 
+    @trace
     async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
         await self.db_pool.simple_upsert(
             table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 161aad0f89..8dfa545c27 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,7 +74,17 @@ receipt.
 """
 
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
 
@@ -154,7 +164,9 @@ class NotifCounts:
     highlight_count: int = 0
 
 
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+    actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
     """Custom serializer for actions. This allows us to "compress" common actions.
 
     We use the fact that most users have the same actions for notifs (and for
@@ -227,7 +239,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         user_id: str,
     ) -> NotifCounts:
         """Get the notification count, the highlight count and the unread message count
-        for a given user in a given room after the given read receipt.
+        for a given user in a given room after their latest read receipt.
 
         Note that this function assumes the user to be a current member of the room,
         since it's either called by the sync handler to handle joined room entries, or by
@@ -238,9 +250,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id: The user to retrieve the counts for.
 
         Returns
-            A dict containing the counts mentioned earlier in this docstring,
-            respectively under the keys "notify_count", "highlight_count" and
-            "unread_count".
+            A NotifCounts object containing the notification count, the highlight count
+            and the unread message count.
         """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
@@ -255,6 +266,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         room_id: str,
         user_id: str,
     ) -> NotifCounts:
+        # Get the stream ordering of the user's latest receipt in the room.
         result = self.get_last_receipt_for_user_txn(
             txn,
             user_id,
@@ -266,13 +278,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             ),
         )
 
-        stream_ordering = None
         if result:
             _, stream_ordering = result
 
-        if stream_ordering is None:
-            # Either last_read_event_id is None, or it's an event we don't have (e.g.
-            # because it's been purged), in which case retrieve the stream ordering for
+        else:
+            # If the user has no receipts in the room, retrieve the stream ordering for
             # the latest membership event from this user in this room (which we assume is
             # a join).
             event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -289,10 +299,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
     def _get_unread_counts_by_pos_txn(
-        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        receipt_stream_ordering: int,
     ) -> NotifCounts:
         """Get the number of unread messages for a user/room that have happened
         since the given stream ordering.
+
+        Args:
+            txn: The database transaction.
+            room_id: The room ID to get unread counts for.
+            user_id: The user ID to get unread counts for.
+            receipt_stream_ordering: The stream ordering of the user's latest
+                receipt in the room. If there are no receipts, the stream ordering
+                of the user's join event.
+
+        Returns
+            A NotifCounts object containing the notification count, the highlight count
+            and the unread message count.
         """
 
         counts = NotifCounts()
@@ -320,7 +346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     OR last_receipt_stream_ordering = ?
                 )
             """,
-            (room_id, user_id, stream_ordering, stream_ordering),
+            (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
         )
         row = txn.fetchone()
 
@@ -338,17 +364,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 AND stream_ordering > ?
                 AND highlight = 1
         """
-        txn.execute(sql, (user_id, room_id, stream_ordering))
+        txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
         row = txn.fetchone()
         if row:
             counts.highlight_count += row[0]
 
         # Finally we need to count push actions that aren't included in the
-        # summary returned above, e.g. recent events that haven't been
-        # summarised yet, or the summary is empty due to a recent read receipt.
-        stream_ordering = max(stream_ordering, summary_stream_ordering)
+        # summary returned above. This might be due to recent events that haven't
+        # been summarised yet or the summary is out of date due to a recent read
+        # receipt.
+        start_unread_stream_ordering = max(
+            receipt_stream_ordering, summary_stream_ordering
+        )
         notify_count, unread_count = self._get_notif_unread_count_for_user_room(
-            txn, room_id, user_id, stream_ordering
+            txn, room_id, user_id, start_unread_stream_ordering
         )
 
         counts.notify_count += notify_count
@@ -430,6 +459,32 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
 
+    def _get_receipts_by_room_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> List[Tuple[str, int]]:
+        receipt_types_clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "receipt_type",
+            (
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+                ReceiptTypes.UNSTABLE_READ_PRIVATE,
+            ),
+        )
+
+        sql = f"""
+            SELECT room_id, MAX(stream_ordering)
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE {receipt_types_clause}
+            AND user_id = ?
+            GROUP BY room_id
+        """
+
+        args.extend((user_id,))
+        txn.execute(sql, args)
+        return cast(List[Tuple[str, int]], txn.fetchall())
+
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
         user_id: str,
@@ -453,106 +508,45 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        # find rooms that have a read receipt in them and return the next
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool]]:
-            # find rooms that have a read receipt in them and return the next
-            # push actions
-
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight
-                FROM (
-                    SELECT room_id,
-                        MAX(stream_ordering) as stream_ordering
-                    FROM events
-                    INNER JOIN receipts_linearized USING (room_id, event_id)
-                    WHERE {receipt_types_clause} AND user_id = ?
-                    GROUP BY room_id
-                ) AS rl,
-                event_push_actions AS ep
-                WHERE
-                    ep.room_id = rl.room_id
-                    AND ep.stream_ordering > rl.stream_ordering
-                    AND ep.user_id = ?
-                    AND ep.stream_ordering > ?
-                    AND ep.stream_ordering <= ?
-                    AND ep.notif = 1
-                ORDER BY ep.stream_ordering ASC LIMIT ?
-            """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
-
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_http_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight
+            sql = """
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
                 FROM event_push_actions AS ep
-                INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
-                    ep.room_id NOT IN (
-                        SELECT room_id FROM receipts_linearized
-                        WHERE {receipt_types_clause} AND user_id = ?
-                        GROUP BY room_id
-                    )
-                    AND ep.user_id = ?
+                    ep.user_id = ?
                     AND ep.stream_ordering > ?
                     AND ep.stream_ordering <= ?
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering ASC LIMIT ?
             """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
         )
 
         notifs = [
             HttpPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -588,106 +582,49 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        # find rooms that have a read receipt in them and return the most recent
-        # push actions
-        def get_after_receipt(
-            txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool, int]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight, e.received_ts
-                FROM (
-                    SELECT room_id,
-                        MAX(stream_ordering) as stream_ordering
-                    FROM events
-                    INNER JOIN receipts_linearized USING (room_id, event_id)
-                    WHERE {receipt_types_clause} AND user_id = ?
-                    GROUP BY room_id
-                ) AS rl,
-                event_push_actions AS ep
-                INNER JOIN events AS e USING (room_id, event_id)
-                WHERE
-                    ep.room_id = rl.room_id
-                    AND ep.stream_ordering > rl.stream_ordering
-                    AND ep.user_id = ?
-                    AND ep.stream_ordering > ?
-                    AND ep.stream_ordering <= ?
-                    AND ep.notif = 1
-                ORDER BY ep.stream_ordering DESC LIMIT ?
-            """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
-            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
-
-        after_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
+        receipts_by_room = dict(
+            await self.db_pool.runInteraction(
+                "get_unread_push_actions_for_user_in_range_email_receipts",
+                self._get_receipts_by_room_txn,
+                user_id=user_id,
+            ),
         )
 
-        # There are rooms with push actions in them but you don't have a read receipt in
-        # them e.g. rooms you've been invited to, so get push actions for rooms which do
-        # not have read receipts in them too.
-        def get_no_receipt(
+        def get_push_actions_txn(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            receipt_types_clause, args = make_in_list_sql_clause(
-                self.database_engine,
-                "receipt_type",
-                (
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
-                ),
-            )
-
-            sql = f"""
+            sql = """
                 SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
                     ep.highlight, e.received_ts
                 FROM event_push_actions AS ep
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
-                    ep.room_id NOT IN (
-                        SELECT room_id FROM receipts_linearized
-                        WHERE {receipt_types_clause} AND user_id = ?
-                        GROUP BY room_id
-                    )
-                    AND ep.user_id = ?
+                    ep.user_id = ?
                     AND ep.stream_ordering > ?
                     AND ep.stream_ordering <= ?
                     AND ep.notif = 1
                 ORDER BY ep.stream_ordering DESC LIMIT ?
             """
-            args.extend(
-                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
-            )
-            txn.execute(sql, args)
+            txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
-        no_read_receipt = await self.db_pool.runInteraction(
-            "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
+        push_actions = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
         )
 
         # Make a list of dicts from the two sets of results.
         notifs = [
             EmailPushAction(
-                event_id=row[0],
-                room_id=row[1],
-                stream_ordering=row[2],
-                actions=_deserialize_action(row[3], row[4]),
-                received_ts=row[5],
+                event_id=event_id,
+                room_id=room_id,
+                stream_ordering=stream_ordering,
+                actions=_deserialize_action(actions, highlight),
+                received_ts=received_ts,
             )
-            for row in after_read_receipt + no_read_receipt
+            for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
+            # Only include push actions with a stream ordering after any receipt, or without any
+            # receipt present (invited to but never read rooms).
+            if stream_ordering > receipts_by_room.get(room_id, 0)
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -733,7 +670,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
     async def add_push_actions_to_staging(
         self,
         event_id: str,
-        user_id_actions: Dict[str, List[Union[dict, str]]],
+        user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
         count_as_unread: bool,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
@@ -750,7 +687,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # This is a helper function for generating the necessary tuple that
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
-            user_id: str, actions: List[Union[dict, str]]
+            user_id: str, actions: Collection[Union[Mapping, str]]
         ) -> Tuple[str, str, str, int, int, int]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
@@ -1151,8 +1088,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             txn: The database transaction.
             old_rotate_stream_ordering: The previous maximum event stream ordering.
             rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
-
-        Returns whether the archiving process has caught up or not.
         """
 
         # Calculate the new counts that should be upserted into event_push_summary
@@ -1238,9 +1173,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             (rotate_to_stream_ordering,),
         )
 
-    async def _remove_old_push_actions_that_have_rotated(
-        self,
-    ) -> None:
+    async def _remove_old_push_actions_that_have_rotated(self) -> None:
         """Clear out old push actions that have been summarised."""
 
         # We want to clear out anything that is older than a day that *has* already
@@ -1397,7 +1330,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         ]
 
 
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
     for action in actions:
         if not isinstance(action, dict):
             continue
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5560b38a48..a4010ee28d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+    @trace
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index b07d812ae2..8a7cdb024d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.logging.opentracing import start_active_span, tag_args, trace
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
+    @trace
+    @tag_args
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -1090,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
         """
         fetched_event_ids: Set[str] = set()
         fetched_events: Dict[str, _EventRow] = {}
-        events_to_fetch = event_ids
 
-        while events_to_fetch:
-            row_map = await self._enqueue_events(events_to_fetch)
+        async def _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids_to_fetch: Collection[str],
+        ) -> Collection[str]:
+            """
+            Fetch all of the given event_ids and return any associated redaction event_ids
+            that we still need to fetch in the next iteration.
+            """
+            row_map = await self._enqueue_events(event_ids_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids: Set[str] = set()
-            for event_id in events_to_fetch:
+            for event_id in event_ids_to_fetch:
                 row = row_map.get(event_id)
                 fetched_event_ids.add(event_id)
                 if row:
                     fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_event_ids)
-            if events_to_fetch:
-                logger.debug("Also fetching redaction events %s", events_to_fetch)
+            event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+            return event_ids_to_fetch
+
+        # Grab the initial list of events requested
+        event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids
+        )
+        # Then go and recursively find all of the associated redactions
+        with start_active_span("recursively fetching redactions"):
+            while event_ids_to_fetch:
+                logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+                event_ids_to_fetch = (
+                    await _fetch_event_ids_and_get_outstanding_redactions(
+                        event_ids_to_fetch
+                    )
+                )
 
         # build a map from event_id to EventBase
         event_map: Dict[str, EventBase] = {}
@@ -1424,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
+    @trace
+    @tag_args
     async def have_seen_events(
         self, room_id: str, event_ids: Iterable[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 768f95d16c..5079edd1e0 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+)
 
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -50,60 +62,30 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _is_experimental_rule_enabled(
-    rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
-    """Used by `_load_rules` to filter out experimental rules when they
-    have not been enabled.
-    """
-    if (
-        rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
-        and not experimental_config.msc3786_enabled
-    ):
-        return False
-    if (
-        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
-        and not experimental_config.msc3772_enabled
-    ):
-        return False
-    return True
-
-
 def _load_rules(
     rawrules: List[JsonDict],
     enabled_map: Dict[str, bool],
     experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
-    ruleslist = []
-    for rawrule in rawrules:
-        rule = dict(rawrule)
-        rule["conditions"] = db_to_json(rawrule["conditions"])
-        rule["actions"] = db_to_json(rawrule["actions"])
-        rule["default"] = False
-        ruleslist.append(rule)
-
-    # We're going to be mutating this a lot, so copy it. We also filter out
-    # any experimental default push rules that aren't enabled.
-    rules = [
-        rule
-        for rule in list_with_base_rules(ruleslist)
-        if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
-    ]
+) -> FilteredPushRules:
+    """Take the DB rows returned from the DB and convert them into a full
+    `FilteredPushRules` object.
+    """
 
-    for i, rule in enumerate(rules):
-        rule_id = rule["rule_id"]
+    ruleslist = [
+        PushRule(
+            rule_id=rawrule["rule_id"],
+            priority_class=rawrule["priority_class"],
+            conditions=db_to_json(rawrule["conditions"]),
+            actions=db_to_json(rawrule["actions"]),
+        )
+        for rawrule in rawrules
+    ]
 
-        if rule_id not in enabled_map:
-            continue
-        if rule.get("enabled", True) == bool(enabled_map[rule_id]):
-            continue
+    push_rules = compile_push_rules(ruleslist)
 
-        # Rules are cached across users.
-        rule = dict(rule)
-        rule["enabled"] = bool(enabled_map[rule_id])
-        rules[i] = rule
+    filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
 
-    return rules
+    return filtered_rules
 
 
 # The ABCMeta metaclass ensures that it cannot be instantiated without
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
         raise NotImplementedError()
 
     @cached(max_entries=5000)
-    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+    async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
@@ -183,7 +165,6 @@ class PushRulesWorkerStore(
 
         return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
-    @cached(max_entries=5000)
     async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
@@ -216,11 +197,11 @@ class PushRulesWorkerStore(
     @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
     async def bulk_get_push_rules(
         self, user_ids: Collection[str]
-    ) -> Dict[str, List[JsonDict]]:
+    ) -> Dict[str, FilteredPushRules]:
         if not user_ids:
             return {}
 
-        results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+        raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
@@ -234,20 +215,19 @@ class PushRulesWorkerStore(
         rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
 
         for row in rows:
-            results.setdefault(row["user_name"], []).append(row)
+            raw_rules.setdefault(row["user_name"], []).append(row)
 
         enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
-        for user_id, rules in results.items():
+        results: Dict[str, FilteredPushRules] = {}
+
+        for user_id, rules in raw_rules.items():
             results[user_id] = _load_rules(
                 rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
             )
 
         return results
 
-    @cachedList(
-        cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
-    )
     async def bulk_get_push_rules_enabled(
         self, user_ids: Collection[str]
     ) -> Dict[str, Dict[str, bool]]:
@@ -262,6 +242,7 @@ class PushRulesWorkerStore(
             iterable=user_ids,
             retcols=("user_name", "rule_id", "enabled"),
             desc="bulk_get_push_rules_enabled",
+            batch_size=1000,
         )
         for row in rows:
             enabled = bool(row["enabled"])
@@ -345,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore):
         user_id: str,
         rule_id: str,
         priority_class: int,
-        conditions: List[Dict[str, str]],
-        actions: List[Union[JsonDict, str]],
+        conditions: Sequence[Mapping[str, str]],
+        actions: Sequence[Union[Mapping[str, Any], str]],
         before: Optional[str] = None,
         after: Optional[str] = None,
     ) -> None:
@@ -808,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore):
         self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
 
         txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
-        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
         txn.call_after(
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
@@ -817,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     async def copy_push_rule_from_room_to_room(
-        self, new_room_id: str, user_id: str, rule: dict
+        self, new_room_id: str, user_id: str, rule: PushRule
     ) -> None:
         """Copy a single push rule from one room to another for a specific user.
 
@@ -827,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore):
             rule: A push rule.
         """
         # Create new rule id
-        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+        rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
         new_rule_id = rule_id_scope + "/" + new_room_id
 
+        new_conditions = []
+
         # Change room id in each condition
-        for condition in rule.get("conditions", []):
+        for condition in rule.conditions:
+            new_condition = condition
             if condition.get("key") == "room_id":
-                condition["pattern"] = new_room_id
+                new_condition = dict(condition)
+                new_condition["pattern"] = new_room_id
+
+            new_conditions.append(new_condition)
 
         # Add the rule for the new room
         await self.add_push_rule(
             user_id=user_id,
             rule_id=new_rule_id,
-            priority_class=rule["priority_class"],
-            conditions=rule["conditions"],
-            actions=rule["actions"],
+            priority_class=rule.priority_class,
+            conditions=new_conditions,
+            actions=rule.actions,
         )
 
     async def copy_push_rules_from_room_to_room_for_user(
@@ -859,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore):
         user_push_rules = await self.get_push_rules_for_user(user_id)
 
         # Get rules relating to the old room and copy them to the new room
-        for rule in user_push_rules:
-            conditions = rule.get("conditions", [])
+        for rule, enabled in user_push_rules:
+            if not enabled:
+                continue
+
+            conditions = rule.conditions
             if any(
                 (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                 for c in conditions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0090c9f225..124c70ad37 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -161,7 +161,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type: The receipt types to fetch.
 
         Returns:
-            The latest receipt, if one exists.
+            The event ID and stream ordering of the latest receipt, if one exists.
         """
 
         clause, args = make_in_list_sql_clause(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cb63cd9b7d..7fb9c801da 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
     """
 
     user_id: str
+    token_id: int
     is_guest: bool = False
     shadow_banned: bool = False
-    token_id: Optional[int] = None
     device_id: Optional[str] = None
     valid_until_ms: Optional[int] = None
     token_owner: str = attr.ib()
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0f1f0d11ea..b7d4baa6bb 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -2001,9 +2001,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
 
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
+            # We join on room_stats_state despite not using any columns from it
+            # because the join can influence the number of rows returned;
+            # e.g. a room that doesn't have state, maybe because it was deleted.
+            # The query returning the total count should be consistent with
+            # the query returning the results.
             sql = """
                 SELECT COUNT(*) as total_event_reports
                 FROM event_reports AS er
+                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
                 {}
                 """.format(
                 where_clause
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 93ff4816c8..9e5034b401 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -283,6 +283,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns:
             A mapping from user ID to ProfileInfo.
+
+        Preconditions:
+          - There is full state available for the room (it is not partial-stated).
         """
 
         def _get_users_in_room_with_profiles(
@@ -531,6 +534,32 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_local_users_in_room",
         )
 
+    async def check_local_user_in_room(self, user_id: str, room_id: str) -> bool:
+        """
+        Check whether a given local user is currently joined to the given room.
+
+        Returns:
+            A boolean indicating whether the user is currently joined to the room
+
+        Raises:
+            Exeption when called with a non-local user to this homeserver
+        """
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'check_local_user_in_room' on "
+                "non-local user %s" % (user_id,),
+            )
+
+        (
+            membership,
+            member_event_id,
+        ) = await self.get_local_current_membership_for_user_in_room(
+            user_id=user_id,
+            room_id=room_id,
+        )
+
+        return membership == Membership.JOIN
+
     async def get_local_current_membership_for_user_in_room(
         self, user_id: str, room_id: str
     ) -> Tuple[Optional[str], Optional[str]]:
@@ -832,9 +861,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_state(
+    async def get_joined_user_ids_from_state(
         self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -845,25 +874,25 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
-            return await self._get_joined_users_from_context(
+            return await self._get_joined_user_ids_from_context(
                 room_id, state_group, state, context=state_entry
             )
 
     @cached(num_args=2, iterable=True, max_entries=100000)
-    async def _get_joined_users_from_context(
+    async def _get_joined_user_ids_from_context(
         self,
         room_id: str,
         state_group: Union[object, int],
         current_state_ids: StateMap[str],
         event: Optional[EventBase] = None,
         context: Optional["_StateCacheEntry"] = None,
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
         assert state_group is not None
 
-        users_in_room = {}
+        users_in_room = set()
         member_event_ids = [
             e_id
             for key, e_id in current_state_ids.items()
@@ -876,11 +905,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # If we do then we can reuse that result and simply update it with
             # any membership changes in `delta_ids`
             if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get_immediate(
+                prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
                     (room_id, context.prev_group), None
                 )
-                if prev_res and isinstance(prev_res, dict):
-                    users_in_room = dict(prev_res)
+                if prev_res and isinstance(prev_res, set):
+                    users_in_room = prev_res
                     member_event_ids = [
                         e_id
                         for key, e_id in context.delta_ids.items()
@@ -888,7 +917,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                     ]
                     for etype, state_key in context.delta_ids:
                         if etype == EventTypes.Member:
-                            users_in_room.pop(state_key, None)
+                            users_in_room.discard(state_key)
 
         # We check if we have any of the member event ids in the event cache
         # before we ask the DB
@@ -905,71 +934,64 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             ev_entry = event_map.get(event_id)
             if ev_entry and not ev_entry.event.rejected_reason:
                 if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[ev_entry.event.state_key] = ProfileInfo(
-                        display_name=ev_entry.event.content.get("displayname", None),
-                        avatar_url=ev_entry.event.content.get("avatar_url", None),
-                    )
+                    users_in_room.add(ev_entry.event.state_key)
             else:
                 missing_member_event_ids.append(event_id)
 
         if missing_member_event_ids:
-            event_to_memberships = await self._get_joined_profiles_from_event_ids(
+            event_to_memberships = await self._get_user_ids_from_membership_event_ids(
                 missing_member_event_ids
             )
-            users_in_room.update(row for row in event_to_memberships.values() if row)
+            users_in_room.update(
+                user_id for user_id in event_to_memberships.values() if user_id
+            )
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = ProfileInfo(
-                        display_name=event.content.get("displayname", None),
-                        avatar_url=event.content.get("avatar_url", None),
-                    )
+                    users_in_room.add(event.state_key)
 
         return users_in_room
 
-    @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(
+    @cached(
+        max_entries=10000,
+        # This name matches the old function that has been replaced - the cache name
+        # is kept here to maintain backwards compatibility.
+        name="_get_joined_profile_from_event_id",
+    )
+    def _get_user_id_from_membership_event_id(
         self, event_id: str
     ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
+        cached_method_name="_get_user_id_from_membership_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(
+    async def _get_user_ids_from_membership_event_ids(
         self, event_ids: Iterable[str]
-    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+    ) -> Dict[str, Optional[str]]:
         """For given set of member event_ids check if they point to a join
-        event and if so return the associated user and profile info.
+        event.
 
         Args:
             event_ids: The member event IDs to lookup
 
         Returns:
-            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id`, or None if event is not a join.
         """
 
         rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
-            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            retcols=("user_id", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=1000,
-            desc="_get_joined_profiles_from_event_ids",
+            desc="_get_user_ids_from_membership_event_ids",
         )
 
-        return {
-            row["event_id"]: (
-                row["user_id"],
-                ProfileInfo(
-                    avatar_url=row["avatar_url"], display_name=row["display_name"]
-                ),
-            )
-            for row in rows
-        }
+        return {row["event_id"]: row["user_id"] for row in rows}
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1128,12 +1150,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             else:
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
-                joined_users = await self.get_joined_users_from_state(
+                joined_user_ids = await self.get_joined_user_ids_from_state(
                     room_id, state, state_entry
                 )
 
                 cache.hosts_to_joined_users = {}
-                for user_id in joined_users:
+                for user_id in joined_user_ids:
                     host = intern_string(get_domain_from_id(user_id))
                     cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
 
@@ -1212,6 +1234,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
+    async def is_locally_forgotten_room(self, room_id: str) -> bool:
+        """Returns whether all local users have forgotten this room_id.
+
+        Args:
+            room_id: The room ID to query.
+
+        Returns:
+            Whether the room is forgotten.
+        """
+
+        sql = """
+            SELECT count(*) > 0 FROM local_current_membership
+            INNER JOIN room_memberships USING (room_id, event_id)
+            WHERE
+                room_id = ?
+                AND forgotten = 0;
+        """
+
+        rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+        # `count(*)` returns always an integer
+        # If any rows still exist it means someone has not forgotten this room yet
+        return not rows[0][0]
+
     async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
         """Get all rooms that the user has ever been in.
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 3c13859faa..2dfe4c0b66 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -460,8 +460,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                 # Cast safety: this corresponds to the types returned by the query above.
                 rows.extend(cast(Iterable[Tuple[str, int]], cur))
 
-            # Sort so that we handle rows in order for each instance.
-            rows.sort()
+            # Sort by stream_id (ascending, lowest -> highest) so that we handle
+            # rows in order for each instance because we don't want to overwrite
+            # the current_position of an instance to a lower stream ID than
+            # we're actually at.
+            def sort_by_stream_id_key_func(row: Tuple[str, int]) -> int:
+                (instance, stream_id) = row
+                # If `stream_id` is ever `None`, we will see a `TypeError: '<'
+                # not supported between instances of 'NoneType' and 'X'` error.
+                return stream_id
+
+            rows.sort(key=sort_by_stream_id_key_func)
 
             with self._lock:
                 for (
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 466e5137f2..b4bf49dace 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
 from twisted.internet.defer import Deferred
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
@@ -58,6 +59,7 @@ class PartialStateEventsTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialStateEventsTracker.await_full_state")
     async def await_full_state(self, event_ids: Collection[str]) -> None:
         """Wait for all the given events to have full state.
 
@@ -151,6 +153,7 @@ class PartialCurrentStateTracker:
             for o in observers:
                 o.callback(None)
 
+    @trace_with_opname("PartialCurrentStateTracker.await_full_state")
     async def await_full_state(self, room_id: str) -> None:
         # We add the deferred immediately so that the DB call to check for
         # partial state doesn't race when we unpartial the room.