Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--  synapse/storage/databases/main/__init__.py           |  29
-rw-r--r--  synapse/storage/databases/main/account_data.py       |   3
-rw-r--r--  synapse/storage/databases/main/cache.py              |  26
-rw-r--r--  synapse/storage/databases/main/deviceinbox.py        |   2
-rw-r--r--  synapse/storage/databases/main/devices.py            |   4
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py    |   6
-rw-r--r--  synapse/storage/databases/main/event_federation.py   |   9
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 471
-rw-r--r--  synapse/storage/databases/main/events.py             |   4
-rw-r--r--  synapse/storage/databases/main/events_worker.py      | 193
-rw-r--r--  synapse/storage/databases/main/push_rule.py          | 127
-rw-r--r--  synapse/storage/databases/main/receipts.py           |   2
-rw-r--r--  synapse/storage/databases/main/registration.py       |   2
-rw-r--r--  synapse/storage/databases/main/relations.py          |   6
-rw-r--r--  synapse/storage/databases/main/room.py               |   8
-rw-r--r--  synapse/storage/databases/main/roommember.py         | 185
-rw-r--r--  synapse/storage/databases/main/state.py              |  13
-rw-r--r--  synapse/storage/databases/main/stream.py             |   2
18 files changed, 755 insertions, 337 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index a3d31d3737..4dccbb732a 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -24,9 +24,9 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -149,31 +149,6 @@ class DataStore(
             ],
         )
 
-        self._cache_id_gen: Optional[MultiWriterIdGenerator]
-        if isinstance(self.database_engine, PostgresEngine):
-            # We set the `writers` to an empty list here as we don't care about
-            # missing updates over restarts, as we'll not have anything in our
-            # caches to invalidate. (This reduces the amount of writes to the DB
-            # that happen).
-            self._cache_id_gen = MultiWriterIdGenerator(
-                db_conn,
-                database,
-                stream_name="caches",
-                instance_name=hs.get_instance_name(),
-                tables=[
-                    (
-                        "cache_invalidation_stream_by_instance",
-                        "instance_name",
-                        "stream_id",
-                    )
-                ],
-                sequence_name="cache_invalidation_stream_seq",
-                writers=[],
-            )
-
-        else:
-            self._cache_id_gen = None
-
         super().__init__(database, db_conn, hs)
 
         events_max = self._stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 9af9f4f18e..c38b8a9e5a 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -650,9 +650,6 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             txn, self.get_account_data_for_room, (user_id,)
         )
         self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
-        self._invalidate_cache_and_stream(
-            txn, self.get_push_rules_enabled_for_user, (user_id,)
-        )
         # This user might be contained in the ignored_by cache for other users,
         # so we have to invalidate it all.
         self._invalidate_all_cache_and_stream(txn, self.ignored_by)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2367ddeea3..12e9a42382 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -32,6 +32,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.util.caches.descriptors import _CachedFunction
 from synapse.util.iterutils import batch_iter
 
@@ -65,6 +66,31 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             psql_only=True,  # The table is only on postgres DBs.
         )
 
+        self._cache_id_gen: Optional[MultiWriterIdGenerator]
+        if isinstance(self.database_engine, PostgresEngine):
+            # We set the `writers` to an empty list here as we don't care about
+            # missing updates over restarts, as we'll not have anything in our
+            # caches to invalidate. (This reduces the amount of writes to the DB
+            # that happen).
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                stream_name="caches",
+                instance_name=hs.get_instance_name(),
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
+                sequence_name="cache_invalidation_stream_seq",
+                writers=[],
+            )
+
+        else:
+            self._cache_id_gen = None
+
     async def get_all_updated_caches(
         self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
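
The hunk above moves construction of the `caches` stream ID generator into `CacheInvalidationWorkerStore`, so the store that sends cache invalidations over replication also owns the generator that stamps them. The toy below is not Synapse's implementation; it only sketches the underlying idea: every invalidation gets a monotonically increasing stream ID, and a worker replays only the entries after the last ID it has seen. (This is also why `writers=[]` is safe here: caches start empty after a restart, so missed older entries never need to be replayed.)

from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class ToyCacheInvalidationStream:
    """In-memory stand-in for the cache_invalidation_stream_by_instance table."""

    _next_id: int = 0
    _rows: List[Tuple[int, str, Tuple[str, ...]]] = field(default_factory=list)

    def invalidate(self, cache_name: str, keys: Tuple[str, ...]) -> int:
        # Roughly what allocating an ID from the stream generator achieves:
        # give this invalidation a position in a totally ordered stream.
        self._next_id += 1
        self._rows.append((self._next_id, cache_name, keys))
        return self._next_id

    def updates_since(self, last_seen_id: int) -> List[Tuple[int, str, Tuple[str, ...]]]:
        # What get_all_updated_caches() does for replication: return only the
        # invalidations that a worker has not processed yet.
        return [row for row in self._rows if row[0] > last_seen_id]


stream = ToyCacheInvalidationStream()
stream.invalidate("get_push_rules_for_user", ("@alice:example.org",))
stream.invalidate("get_account_data_for_room", ("@alice:example.org", "!room:example.org"))

# A worker that last saw stream ID 1 only needs to replay the second entry.
print(stream.updates_since(1))
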
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 422e0e65ca..73c95ffb6f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             (user_id, device_id), None
         )
 
-        set_tag("last_deleted_stream_id", last_deleted_stream_id)
+        set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
 
         if last_deleted_stream_id:
             has_changed = self._device_inbox_stream_cache.has_entity_changed(
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 7a6ed332aa..ca0fe8c4be 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -706,8 +706,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             else:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
 
-        set_tag("in_cache", results)
-        set_tag("not_in_cache", user_ids_not_in_cache)
+        set_tag("in_cache", str(results))
+        set_tag("not_in_cache", str(user_ids_not_in_cache))
 
         return user_ids_not_in_cache, results
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 60f622ad71..46c0d06157 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             key data.  The key data will be a dict in the same format as the
             DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
         """
-        set_tag("query_list", query_list)
+        set_tag("query_list", str(query_list))
         if not query_list:
             return {}
 
@@ -418,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
-            set_tag("new_keys", new_keys)
+            set_tag("new_keys", str(new_keys))
             # We are protected from race between lookup and insertion due to
             # a unique constraint. If there is a race of two calls to
             # `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -1161,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
+            set_tag("device_keys", str(device_keys))
 
             old_key_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,
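
The `set_tag(..., str(...))` changes in deviceinbox.py, devices.py and end_to_end_keys.py all follow the same rule: OpenTracing tag values are expected to be strings, booleans or numbers, so dicts and lists are stringified before being attached to the span. A minimal sketch of that convention, with a toy `set_tag` stub standing in for the real tracing helper:

from typing import Dict, Union

TagValue = Union[str, bool, int, float]
_span_tags: Dict[str, TagValue] = {}


def set_tag(key: str, value: TagValue) -> None:
    # Toy stand-in for a tracing helper: only primitive tag values are accepted.
    assert isinstance(value, (str, bool, int, float)), "tag values must be primitives"
    _span_tags[key] = value


query_list = [("@alice:example.org", "DEVICEID")]

# set_tag("query_list", query_list) would hand the tracer a list, so stringify it.
set_tag("query_list", str(query_list))
print(_span_tags)
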
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..c836078da6 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.opentracing import tag_args, trace
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
@@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
         return await self.get_events_as_list(event_ids)
 
+    @trace
+    @tag_args
     async def get_auth_chain_ids(
         self,
         room_id: str,
@@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
+    @trace
+    @tag_args
     async def get_oldest_event_ids_with_depth_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
+    @trace
     async def get_insertion_event_backward_extremities_in_room(
         self, room_id: str
     ) -> List[Tuple[str, int]]:
@@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         event_results.reverse()
         return event_results
 
+    @trace
+    @tag_args
     async def get_successor_events(self, event_id: str) -> List[str]:
         """Fetch all events that have the given event as a prev event
 
@@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             _delete_old_forward_extrem_cache_txn,
         )
 
+    @trace
     async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
         await self.db_pool.simple_upsert(
             table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index dd2627037c..eabf9c9739 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -12,14 +12,85 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+"""Responsible for storing and fetching push actions / notifications.
+
+There are two main uses for push actions:
+  1. Sending out push to a user's device; and
+  2. Tracking per-room per-user notification counts (used in sync requests).
+
+For the former we simply use the `event_push_actions` table, which contains all
+the calculated actions for a given user (which were calculated by the
+`BulkPushRuleEvaluator`).
+
+For the latter we could simply count the number of rows in `event_push_actions`
+table for a given room/user, but in practice this is *very* heavyweight when
+there are a large number of notifications (e.g. because the user never reads a
+room). Plus, keeping all push actions indefinitely uses a lot of disk space.
+
+To fix these issues, we add a new table `event_push_summary` that tracks
+per-user per-room counts of all notifications that happened before a stream
+ordering S. Thus, to get the notification count for a user / room we can simply
+query a single row in `event_push_summary` and count the number of rows in
+`event_push_actions` with a stream ordering larger than S (and as long as S is
+"recent", the number of rows needing to be scanned will be small).
+
+The `event_push_summary` table is updated via a background job that periodically
+chooses a new stream ordering S' (usually the latest stream ordering), counts
+all notifications in `event_push_actions` between the existing S and S', and
+adds them to the existing counts in `event_push_summary`.
+
+This allows us to delete old rows from `event_push_actions` once those rows have
+been counted and added to `event_push_summary` (we call this process
+"rotation").
+
+
+We also need to handle a user sending a read receipt to the room. Again this is
+done as a background process. For each receipt we clear the row in
+`event_push_summary` and count the number of notifications in
+`event_push_actions` that happened after the receipt but before S, and insert
+that count into `event_push_summary`. (If the receipt happened *after* S then we
+simply clear the `event_push_summary` row.)
+
+Note that it's possible that, if the read receipt is for an old event, the
+relevant `event_push_actions` rows will already have been rotated and we will
+get the wrong count (it'll be too low). We accept this as a rare edge case that
+is unlikely to impact the user much (since the vast majority of read receipts
+will be for the latest event).
+
+The last complication is to handle the race where we request the notification
+counts after a user sends a read receipt into the room, but *before* the
+background update handles the receipt (without any special handling the counts
+would be outdated). We fix this by storing in `event_push_summary` the read
+receipt we used when updating it, and every time we query the table we check
+whether that matches the most recent read receipt in the room. If it does, we
+continue as above; if not, we simply query the `event_push_actions` table
+directly.
+
+Since read receipts are almost always for recent events, scanning the
+`event_push_actions` table in this case is unlikely to be a problem. Even if it
+is a problem, it is temporary until the background job handles the new read
+receipt.
+"""
+
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
 
 from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -93,7 +164,9 @@ class NotifCounts:
     highlight_count: int = 0
 
 
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+    actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
     """Custom serializer for actions. This allows us to "compress" common actions.
 
     We use the fact that most users have the same actions for notifs (and for
@@ -166,7 +239,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         user_id: str,
     ) -> NotifCounts:
         """Get the notification count, the highlight count and the unread message count
-        for a given user in a given room after the given read receipt.
+        for a given user in a given room after their latest read receipt.
 
         Note that this function assumes the user to be a current member of the room,
         since it's either called by the sync handler to handle joined room entries, or by
@@ -177,9 +250,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id: The user to retrieve the counts for.
 
         Returns
-            A dict containing the counts mentioned earlier in this docstring,
-            respectively under the keys "notify_count", "highlight_count" and
-            "unread_count".
+            A NotifCounts object containing the notification count, the highlight count
+            and the unread message count.
         """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
@@ -194,20 +266,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         room_id: str,
         user_id: str,
     ) -> NotifCounts:
+        # Get the stream ordering of the user's latest receipt in the room.
         result = self.get_last_receipt_for_user_txn(
             txn,
             user_id,
             room_id,
-            receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+            receipt_types=(
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+                ReceiptTypes.UNSTABLE_READ_PRIVATE,
+            ),
         )
 
-        stream_ordering = None
         if result:
             _, stream_ordering = result
 
-        if stream_ordering is None:
-            # Either last_read_event_id is None, or it's an event we don't have (e.g.
-            # because it's been purged), in which case retrieve the stream ordering for
+        else:
+            # If the user has no receipts in the room, retrieve the stream ordering for
             # the latest membership event from this user in this room (which we assume is
             # a join).
             event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -224,10 +299,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
     def _get_unread_counts_by_pos_txn(
-        self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        user_id: str,
+        receipt_stream_ordering: int,
     ) -> NotifCounts:
         """Get the number of unread messages for a user/room that have happened
         since the given stream ordering.
+
+        Args:
+            txn: The database transaction.
+            room_id: The room ID to get unread counts for.
+            user_id: The user ID to get unread counts for.
+            receipt_stream_ordering: The stream ordering of the user's latest
+                receipt in the room. If there are no receipts, the stream ordering
+                of the user's join event.
+
+        Returns
+            A NotifCounts object containing the notification count, the highlight count
+            and the unread message count.
         """
 
         counts = NotifCounts()
@@ -255,7 +346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     OR last_receipt_stream_ordering = ?
                 )
             """,
-            (room_id, user_id, stream_ordering, stream_ordering),
+            (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
         )
         row = txn.fetchone()
 
@@ -265,7 +356,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             counts.notify_count += row[1]
             counts.unread_count += row[2]
 
-        # Next we need to count highlights, which aren't summarized
+        # Next we need to count highlights, which aren't summarised
         sql = """
             SELECT COUNT(*) FROM event_push_actions
             WHERE user_id = ?
@@ -273,17 +364,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 AND stream_ordering > ?
                 AND highlight = 1
         """
-        txn.execute(sql, (user_id, room_id, stream_ordering))
+        txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
         row = txn.fetchone()
         if row:
             counts.highlight_count += row[0]
 
         # Finally we need to count push actions that aren't included in the
-        # summary returned above, e.g. recent events that haven't been
-        # summarized yet, or the summary is empty due to a recent read receipt.
-        stream_ordering = max(stream_ordering, summary_stream_ordering)
+        # summary returned above. This might be due to recent events that haven't
+        # been summarised yet or the summary is out of date due to a recent read
+        # receipt.
+        start_unread_stream_ordering = max(
+            receipt_stream_ordering, summary_stream_ordering
+        )
         notify_count, unread_count = self._get_notif_unread_count_for_user_room(
-            txn, room_id, user_id, stream_ordering
+            txn, room_id, user_id, start_unread_stream_ordering
         )
 
         counts.notify_count += notify_count
@@ -304,6 +398,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         Does not consult `event_push_summary` table, which may include push
         actions that have been deleted from `event_push_actions` table.
+
+        Args:
+            txn: The database transaction.
+            room_id: The room ID to get unread counts for.
+            user_id: The user ID to get unread counts for.
+            stream_ordering: The (exclusive) minimum stream ordering to consider.
+            max_stream_ordering: The (inclusive) maximum stream ordering to consider.
+                If this is not given, then no maximum is applied.
+
+        Return:
+            A tuple of the notif count and unread count in the given range.
         """
 
         # If there have been no events in the room since the stream ordering,
@@ -376,6 +481,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will be ordered by ascending stream_ordering.
             The list will have between 0~limit entries.
         """
+
         # find rooms that have a read receipt in them and return the next
         # push actions
         def get_after_receipt(
@@ -383,28 +489,41 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         ) -> List[Tuple[str, str, int, str, bool]]:
             # find rooms that have a read receipt in them and return the next
             # push actions
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight "
-                " FROM ("
-                "   SELECT room_id,"
-                "       MAX(stream_ordering) as stream_ordering"
-                "   FROM events"
-                "   INNER JOIN receipts_linearized USING (room_id, event_id)"
-                "   WHERE receipt_type = 'm.read' AND user_id = ?"
-                "   GROUP BY room_id"
-                ") AS rl,"
-                " event_push_actions AS ep"
-                " WHERE"
-                "   ep.room_id = rl.room_id"
-                "   AND ep.stream_ordering > rl.stream_ordering"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering ASC LIMIT ?"
+
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight
+                FROM (
+                    SELECT room_id,
+                        MAX(stream_ordering) as stream_ordering
+                    FROM events
+                    INNER JOIN receipts_linearized USING (room_id, event_id)
+                    WHERE {receipt_types_clause} AND user_id = ?
+                    GROUP BY room_id
+                ) AS rl,
+                event_push_actions AS ep
+                WHERE
+                    ep.room_id = rl.room_id
+                    AND ep.stream_ordering > rl.stream_ordering
+                    AND ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering ASC LIMIT ?
+            """
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
             )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
@@ -418,24 +537,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         def get_no_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight "
-                " FROM event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id NOT IN ("
-                "     SELECT room_id FROM receipts_linearized"
-                "       WHERE receipt_type = 'm.read' AND user_id = ?"
-                "       GROUP BY room_id"
-                "   )"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering ASC LIMIT ?"
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight
+                FROM event_push_actions AS ep
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    ep.room_id NOT IN (
+                        SELECT room_id FROM receipts_linearized
+                        WHERE {receipt_types_clause} AND user_id = ?
+                        GROUP BY room_id
+                    )
+                    AND ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering ASC LIMIT ?
+            """
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
             )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
@@ -485,34 +616,47 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will be ordered by descending received_ts.
             The list will have between 0~limit entries.
         """
+
         # find rooms that have a read receipt in them and return the most recent
         # push actions
         def get_after_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "  ep.highlight, e.received_ts"
-                " FROM ("
-                "   SELECT room_id,"
-                "       MAX(stream_ordering) as stream_ordering"
-                "   FROM events"
-                "   INNER JOIN receipts_linearized USING (room_id, event_id)"
-                "   WHERE receipt_type = 'm.read' AND user_id = ?"
-                "   GROUP BY room_id"
-                ") AS rl,"
-                " event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id = rl.room_id"
-                "   AND ep.stream_ordering > rl.stream_ordering"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering DESC LIMIT ?"
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight, e.received_ts
+                FROM (
+                    SELECT room_id,
+                        MAX(stream_ordering) as stream_ordering
+                    FROM events
+                    INNER JOIN receipts_linearized USING (room_id, event_id)
+                    WHERE {receipt_types_clause} AND user_id = ?
+                    GROUP BY room_id
+                ) AS rl,
+                event_push_actions AS ep
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    ep.room_id = rl.room_id
+                    AND ep.stream_ordering > rl.stream_ordering
+                    AND ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering DESC LIMIT ?
+            """
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
             )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
@@ -526,24 +670,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         def get_no_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight, e.received_ts"
-                " FROM event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id NOT IN ("
-                "     SELECT room_id FROM receipts_linearized"
-                "       WHERE receipt_type = 'm.read' AND user_id = ?"
-                "       GROUP BY room_id"
-                "   )"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering DESC LIMIT ?"
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight, e.received_ts
+                FROM event_push_actions AS ep
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    ep.room_id NOT IN (
+                        SELECT room_id FROM receipts_linearized
+                        WHERE {receipt_types_clause} AND user_id = ?
+                        GROUP BY room_id
+                    )
+                    AND ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering DESC LIMIT ?
+            """
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
             )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
@@ -606,7 +762,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
     async def add_push_actions_to_staging(
         self,
         event_id: str,
-        user_id_actions: Dict[str, List[Union[dict, str]]],
+        user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
         count_as_unread: bool,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
@@ -623,7 +779,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # This is a helper function for generating the necessary tuple that
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
-            user_id: str, actions: List[Union[dict, str]]
+            user_id: str, actions: Collection[Union[Mapping, str]]
         ) -> Tuple[str, str, str, int, int, int]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
@@ -769,12 +925,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # [10, <none>, 20], we should treat this as being equivalent to
         # [10, 10, 20].
         #
-        sql = (
-            "SELECT received_ts FROM events"
-            " WHERE stream_ordering <= ?"
-            " ORDER BY stream_ordering DESC"
-            " LIMIT 1"
-        )
+        sql = """
+            SELECT received_ts FROM events
+            WHERE stream_ordering <= ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
+        """
 
         while range_end - range_start > 0:
             middle = (range_end + range_start) // 2
@@ -802,14 +958,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         self, stream_ordering: int
     ) -> Optional[int]:
         def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
-            sql = (
-                "SELECT e.received_ts"
-                " FROM event_push_actions AS ep"
-                " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
-                " WHERE ep.stream_ordering > ? AND notif = 1"
-                " ORDER BY ep.stream_ordering ASC"
-                " LIMIT 1"
-            )
+            sql = """
+                SELECT e.received_ts
+                FROM event_push_actions AS ep
+                JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id
+                WHERE ep.stream_ordering > ? AND notif = 1
+                ORDER BY ep.stream_ordering ASC
+                LIMIT 1
+            """
             txn.execute(sql, (stream_ordering,))
             return cast(Optional[Tuple[int]], txn.fetchone())
 
@@ -858,10 +1014,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         Any push actions which predate the user's most recent read receipt are
         now redundant, so we can remove them from `event_push_actions` and
         update `event_push_summary`.
+
+        Returns true if all new receipts have been processed.
         """
 
         limit = 100
 
+        # The (inclusive) receipt stream ID that was previously processed.
         min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_last_receipt_stream_id",
@@ -871,6 +1030,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         max_receipts_stream_id = self._receipts_id_gen.get_current_token()
 
+        # The (inclusive) event stream ordering that was previously summarised.
+        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
         sql = """
             SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
             FROM receipts_linearized AS r
@@ -895,13 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
         rows = txn.fetchall()
 
-        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
-            txn,
-            table="event_push_summary_stream_ordering",
-            keyvalues={},
-            retcol="stream_ordering",
-        )
-
         # For each new read receipt we delete push actions from before it and
         # recalculate the summary.
         for _, room_id, user_id, stream_ordering in rows:
@@ -920,10 +1080,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 (room_id, user_id, stream_ordering),
             )
 
+            # Fetch the notification counts between the stream ordering of the
+            # latest receipt and what was previously summarised.
             notif_count, unread_count = self._get_notif_unread_count_for_user_room(
                 txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
             )
 
+            # Replace the previous summary with the new counts.
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="event_push_summary",
@@ -956,10 +1119,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         return len(rows) < limit
 
     def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
-        """Archives older notifications into event_push_summary. Returns whether
-        the archiving process has caught up or not.
+        """Archives older notifications (from event_push_actions) into event_push_summary.
+
+        Returns whether the archiving process has caught up or not.
         """
 
+        # The (inclusive) event stream ordering that was previously summarised.
         old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
@@ -974,7 +1139,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             SELECT stream_ordering FROM event_push_actions
             WHERE stream_ordering > ?
             ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
-        """,
+            """,
             (old_rotate_stream_ordering, self._rotate_count),
         )
         stream_row = txn.fetchone()
@@ -993,19 +1158,29 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
 
-        self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+        self._rotate_notifs_before_txn(
+            txn, old_rotate_stream_ordering, rotate_to_stream_ordering
+        )
 
         return caught_up
 
     def _rotate_notifs_before_txn(
-        self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+        self,
+        txn: LoggingTransaction,
+        old_rotate_stream_ordering: int,
+        rotate_to_stream_ordering: int,
     ) -> None:
-        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
-            txn,
-            table="event_push_summary_stream_ordering",
-            keyvalues={},
-            retcol="stream_ordering",
-        )
+        """Archives older notifications (from event_push_actions) into event_push_summary.
+
+        Any event_push_actions between old_rotate_stream_ordering (exclusive) and
+        rotate_to_stream_ordering (inclusive) will be added to the event_push_summary
+        table.
+
+        Args:
+            txn: The database transaction.
+            old_rotate_stream_ordering: The previous maximum event stream ordering.
+            rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
+        """
 
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
@@ -1090,12 +1265,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             (rotate_to_stream_ordering,),
         )
 
-    async def _remove_old_push_actions_that_have_rotated(
-        self,
-    ) -> None:
-        """Clear out old push actions that have been summarized."""
+    async def _remove_old_push_actions_that_have_rotated(self) -> None:
+        """Clear out old push actions that have been summarised."""
 
-        # We want to clear out anything that older than a day that *has* already
+        # We want to clear out anything that is older than a day that *has* already
         # been rotated.
         rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
             table="event_push_summary_stream_ordering",
@@ -1119,7 +1292,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 SELECT stream_ordering FROM event_push_actions
                 WHERE stream_ordering <= ? AND highlight = 0
                 ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
-            """,
+                """,
                 (
                     max_stream_ordering_to_delete,
                     batch_size,
@@ -1215,16 +1388,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
 
             # NB. This assumes event_ids are globally unique since
             # it makes the query easier to index
-            sql = (
-                "SELECT epa.event_id, epa.room_id,"
-                " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
-                " FROM event_push_actions epa, events e"
-                " WHERE epa.event_id = e.event_id"
-                " AND epa.user_id = ? %s"
-                " AND epa.notif = 1"
-                " ORDER BY epa.stream_ordering DESC"
-                " LIMIT ?" % (before_clause,)
+            sql = """
+                SELECT epa.event_id, epa.room_id,
+                    epa.stream_ordering, epa.topological_ordering,
+                    epa.actions, epa.highlight, epa.profile_tag, e.received_ts
+                FROM event_push_actions epa, events e
+                WHERE epa.event_id = e.event_id
+                    AND epa.user_id = ? %s
+                    AND epa.notif = 1
+                ORDER BY epa.stream_ordering DESC
+                LIMIT ?
+            """ % (
+                before_clause,
             )
             txn.execute(sql, args)
             return cast(
@@ -1247,7 +1422,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         ]
 
 
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
     for action in actions:
         if not isinstance(action, dict):
             continue
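
The module docstring added at the top of event_push_actions.py describes the counting scheme: `event_push_summary` holds per-user, per-room counts up to some stream ordering, and anything newer is counted directly from `event_push_actions`. A purely illustrative sketch of that arithmetic follows (toy data structures, not Synapse's schema, and it omits the read-receipt race handling described at the end of the docstring):

from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class SummaryRow:
    """Toy stand-in for one event_push_summary row (counts up to stream_ordering)."""

    notif_count: int
    unread_count: int
    stream_ordering: int  # everything at or before this is already counted


@dataclass
class PushAction:
    """Toy stand-in for one event_push_actions row."""

    stream_ordering: int
    notif: bool
    unread: bool


def unread_counts(
    summary: Optional[SummaryRow],
    actions: List[PushAction],
    receipt_stream_ordering: int,
) -> Tuple[int, int]:
    """Counts for one user/room: summarised counts plus any un-summarised actions."""
    notif = unread = 0
    summary_so = 0
    if summary is not None:
        notif += summary.notif_count
        unread += summary.unread_count
        summary_so = summary.stream_ordering

    # Anything newer than both the receipt and the summary has not been counted yet.
    start = max(receipt_stream_ordering, summary_so)
    for action in actions:
        if action.stream_ordering > start:
            notif += action.notif
            unread += action.unread
    return notif, unread


summary = SummaryRow(notif_count=5, unread_count=7, stream_ordering=100)
recent = [PushAction(105, True, True), PushAction(110, False, True)]

# The receipt is older than the summary, so only rows after stream ordering 100 are added.
print(unread_counts(summary, recent, receipt_stream_ordering=90))  # (6, 9)
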
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f600f1190..a4010ee28d 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, relation_from_event
 from synapse.events.snapshot import EventContext
+from synapse.logging.opentracing import trace
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
         self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
         self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
 
+    @trace
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
@@ -1490,7 +1492,7 @@ class PersistEventsStore:
                     event.sender,
                     "url" in event.content and isinstance(event.content["url"], str),
                     event.get_state_key(),
-                    context.rejected or None,
+                    context.rejected,
                 )
                 for event, context in events_and_contexts
             ),
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5914a35420..8a7cdb024d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
     current_context,
     make_deferred_yieldable,
 )
+from synapse.logging.opentracing import start_active_span, tag_args, trace
 from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
+    @trace
+    @tag_args
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -600,7 +603,11 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             map from event id to result
         """
-        event_entry_map = await self._get_events_from_cache(
+        # Shortcut: check if we have any events in the *in memory* cache - this function
+        # may be called repeatedly for the same event so at this point we cannot reach
+        # out to any external cache for performance reasons. The external cache is
+        # checked later on in the `get_missing_events_from_cache_or_db` function below.
+        event_entry_map = self._get_events_from_local_cache(
             event_ids,
         )
 
@@ -632,7 +639,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         if missing_events_ids:
 
-            async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+            async def get_missing_events_from_cache_or_db() -> Dict[
+                str, EventCacheEntry
+            ]:
                 """Fetches the events in `missing_event_ids` from the database.
 
                 Also creates entries in `self._current_event_fetches` to allow
@@ -657,10 +666,18 @@ class EventsWorkerStore(SQLBaseStore):
                 # the events have been redacted, and if so pulling the redaction event
                 # out of the database to check it.
                 #
+                missing_events = {}
                 try:
-                    missing_events = await self._get_events_from_db(
+                    # Try to fetch from any external cache. We already checked the
+                    # in-memory cache above.
+                    missing_events = await self._get_events_from_external_cache(
                         missing_events_ids,
                     )
+                    # Now actually fetch any remaining events from the DB
+                    db_missing_events = await self._get_events_from_db(
+                        missing_events_ids - missing_events.keys(),
+                    )
+                    missing_events.update(db_missing_events)
                 except Exception as e:
                     with PreserveLoggingContext():
                         fetching_deferred.errback(e)
@@ -679,7 +696,7 @@ class EventsWorkerStore(SQLBaseStore):
             # cancellations, since multiple `_get_events_from_cache_or_db` calls can
             # reuse the same fetch.
             missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
-                get_missing_events_from_db()
+                get_missing_events_from_cache_or_db()
             )
             event_entry_map.update(missing_events)
 
@@ -754,7 +771,54 @@ class EventsWorkerStore(SQLBaseStore):
     async def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
-        """Fetch events from the caches.
+        """Fetch events from the caches, both in memory and any external.
+
+        May return rejected events.
+
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
+        """
+        event_map = self._get_events_from_local_cache(
+            events, update_metrics=update_metrics
+        )
+
+        missing_event_ids = (e for e in events if e not in event_map)
+        event_map.update(
+            await self._get_events_from_external_cache(
+                events=missing_event_ids,
+                update_metrics=update_metrics,
+            )
+        )
+
+        return event_map
+
+    async def _get_events_from_external_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, EventCacheEntry]:
+        """Fetch events from any configured external cache.
+
+        May return rejected events.
+
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
+        """
+        event_map = {}
+
+        for event_id in events:
+            ret = await self._get_event_cache.get_external(
+                (event_id,), None, update_metrics=update_metrics
+            )
+            if ret:
+                event_map[event_id] = ret
+
+        return event_map
+
+    def _get_events_from_local_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, EventCacheEntry]:
+        """Fetch events from the local, in memory, caches.
 
         May return rejected events.
 
@@ -766,7 +830,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         for event_id in events:
             # First check if it's in the event cache
-            ret = await self._get_event_cache.get(
+            ret = self._get_event_cache.get_local(
                 (event_id,), None, update_metrics=update_metrics
             )
             if ret:
@@ -788,7 +852,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # We add the entry back into the cache as we want to keep
                 # recently queried events in the cache.
-                await self._get_event_cache.set((event_id,), cache_entry)
+                self._get_event_cache.set_local((event_id,), cache_entry)
 
         return event_map
 
@@ -1029,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
         """
         fetched_event_ids: Set[str] = set()
         fetched_events: Dict[str, _EventRow] = {}
-        events_to_fetch = event_ids
 
-        while events_to_fetch:
-            row_map = await self._enqueue_events(events_to_fetch)
+        async def _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids_to_fetch: Collection[str],
+        ) -> Collection[str]:
+            """
+            Fetch all of the given event_ids and return any associated redaction event_ids
+            that we still need to fetch in the next iteration.
+            """
+            row_map = await self._enqueue_events(event_ids_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids: Set[str] = set()
-            for event_id in events_to_fetch:
+            for event_id in event_ids_to_fetch:
                 row = row_map.get(event_id)
                 fetched_event_ids.add(event_id)
                 if row:
                     fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_event_ids)
-            if events_to_fetch:
-                logger.debug("Also fetching redaction events %s", events_to_fetch)
+            event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+            return event_ids_to_fetch
+
+        # Grab the initial list of events requested
+        event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+            event_ids
+        )
+        # Then go and recursively find all of the associated redactions
+        with start_active_span("recursively fetching redactions"):
+            while event_ids_to_fetch:
+                logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+                event_ids_to_fetch = (
+                    await _fetch_event_ids_and_get_outstanding_redactions(
+                        event_ids_to_fetch
+                    )
+                )
 
         # build a map from event_id to EventBase
         event_map: Dict[str, EventBase] = {}
@@ -1363,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
+    @trace
+    @tag_args
     async def have_seen_events(
         self, room_id: str, event_ids: Iterable[str]
     ) -> Set[str]:
@@ -2110,14 +2195,92 @@ class EventsWorkerStore(SQLBaseStore):
     def _get_partial_state_events_batch_txn(
         txn: LoggingTransaction, room_id: str
     ) -> List[str]:
+        # we want to work through the events from oldest to newest, so
+        # we only want events whose prev_events do *not* have partial state - hence
+        # the 'NOT EXISTS' clause in the below.
+        #
+        # This is necessary because ordering by stream ordering isn't quite enough
+        # to ensure that we work from oldest to newest event (in particular,
+        # if an event is initially persisted as an outlier and later de-outliered,
+        # it can end up with a lower stream_ordering than its prev_events).
+        #
+        # Typically this means we'll only return one event per batch, but that's
+        # hard to do much about.
+        #
+        # See also: https://github.com/matrix-org/synapse/issues/13001
         txn.execute(
             """
             SELECT event_id FROM partial_state_events AS pse
                 JOIN events USING (event_id)
-            WHERE pse.room_id = ?
+            WHERE pse.room_id = ? AND
+               NOT EXISTS(
+                  SELECT 1 FROM event_edges AS ee
+                     JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
+                     WHERE ee.event_id=pse.event_id
+               )
             ORDER BY events.stream_ordering
             LIMIT 100
             """,
             (room_id,),
         )
         return [row[0] for row in txn]
+
+    def mark_event_rejected_txn(
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        rejection_reason: Optional[str],
+    ) -> None:
+        """Mark an event that was previously accepted as rejected, or vice versa
+
+        This can happen, for example, when resyncing state during a faster join.
+
+        Args:
+            txn:
+            event_id: ID of event to update
+            rejection_reason: reason it has been rejected, or None if it is now accepted
+        """
+        if rejection_reason is None:
+            logger.info(
+                "Marking previously-processed event %s as accepted",
+                event_id,
+            )
+            self.db_pool.simple_delete_txn(
+                txn,
+                "rejections",
+                keyvalues={"event_id": event_id},
+            )
+        else:
+            logger.info(
+                "Marking previously-processed event %s as rejected (%s)",
+                event_id,
+                rejection_reason,
+            )
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="rejections",
+                keyvalues={"event_id": event_id},
+                values={
+                    "reason": rejection_reason,
+                    "last_check": self._clock.time_msec(),
+                },
+            )
+        self.db_pool.simple_update_txn(
+            txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            updatevalues={"rejection_reason": rejection_reason},
+        )
+
+        self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+        # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+        #   call '_send_invalidation_to_replication', but we actually need the other
+        #   end to call _invalidate_local_get_event_cache() rather than (just)
+        #   _get_event_cache.invalidate().
+        #
+        #   One solution might be to (somehow) get the workers to call
+        #   _invalidate_caches_for_event() (though that will invalidate more than
+        #   strictly necessary).
+        #
+        #   https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 768f95d16c..5079edd1e0 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+    cast,
+)
 
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -50,60 +62,30 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _is_experimental_rule_enabled(
-    rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
-    """Used by `_load_rules` to filter out experimental rules when they
-    have not been enabled.
-    """
-    if (
-        rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
-        and not experimental_config.msc3786_enabled
-    ):
-        return False
-    if (
-        rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
-        and not experimental_config.msc3772_enabled
-    ):
-        return False
-    return True
-
-
 def _load_rules(
     rawrules: List[JsonDict],
     enabled_map: Dict[str, bool],
     experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
-    ruleslist = []
-    for rawrule in rawrules:
-        rule = dict(rawrule)
-        rule["conditions"] = db_to_json(rawrule["conditions"])
-        rule["actions"] = db_to_json(rawrule["actions"])
-        rule["default"] = False
-        ruleslist.append(rule)
-
-    # We're going to be mutating this a lot, so copy it. We also filter out
-    # any experimental default push rules that aren't enabled.
-    rules = [
-        rule
-        for rule in list_with_base_rules(ruleslist)
-        if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
-    ]
+) -> FilteredPushRules:
+    """Take the DB rows returned from the DB and convert them into a full
+    """Take the rows returned from the database and convert them into a full
+    """
 
-    for i, rule in enumerate(rules):
-        rule_id = rule["rule_id"]
+    ruleslist = [
+        PushRule(
+            rule_id=rawrule["rule_id"],
+            priority_class=rawrule["priority_class"],
+            conditions=db_to_json(rawrule["conditions"]),
+            actions=db_to_json(rawrule["actions"]),
+        )
+        for rawrule in rawrules
+    ]
 
-        if rule_id not in enabled_map:
-            continue
-        if rule.get("enabled", True) == bool(enabled_map[rule_id]):
-            continue
+    push_rules = compile_push_rules(ruleslist)
 
-        # Rules are cached across users.
-        rule = dict(rule)
-        rule["enabled"] = bool(enabled_map[rule_id])
-        rules[i] = rule
+    filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
 
-    return rules
+    return filtered_rules
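
A rough sketch of the row-to-object conversion: SimplePushRule and load_rules below are simplified stand-ins rather than the real synapse.push.baserules classes, so they skip base-rule injection and experimental-rule filtering. The (rule, enabled) pairs mirror how copy_push_rules_from_room_to_room_for_user iterates the result further down.

import json
from dataclasses import dataclass
from typing import Any, Mapping, Sequence

@dataclass(frozen=True)
class SimplePushRule:
    # Simplified stand-in for synapse.push.baserules.PushRule.
    rule_id: str
    priority_class: int
    conditions: Sequence[Mapping[str, Any]]
    actions: Sequence[Any]

def load_rules(raw_rows, enabled_map):
    # Parse the JSON columns and pair each rule with its enabled flag,
    # roughly what _load_rules plus FilteredPushRules do together.
    rules = [
        SimplePushRule(
            rule_id=row["rule_id"],
            priority_class=row["priority_class"],
            conditions=json.loads(row["conditions"]),
            actions=json.loads(row["actions"]),
        )
        for row in raw_rows
    ]
    return [(rule, enabled_map.get(rule.rule_id, True)) for rule in rules]

rows = [
    {
        "rule_id": "global/override/.m.rule.example",
        "priority_class": 5,
        "conditions": '[{"kind": "event_match", "key": "room_id", "pattern": "!a:x"}]',
        "actions": '["notify"]',
    }
]
for rule, enabled in load_rules(rows, {"global/override/.m.rule.example": False}):
    print(rule.rule_id, enabled)
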
 
 
 # The ABCMeta metaclass ensures that it cannot be instantiated without
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
         raise NotImplementedError()
 
     @cached(max_entries=5000)
-    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+    async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
@@ -183,7 +165,6 @@ class PushRulesWorkerStore(
 
         return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
-    @cached(max_entries=5000)
     async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
@@ -216,11 +197,11 @@ class PushRulesWorkerStore(
     @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
     async def bulk_get_push_rules(
         self, user_ids: Collection[str]
-    ) -> Dict[str, List[JsonDict]]:
+    ) -> Dict[str, FilteredPushRules]:
         if not user_ids:
             return {}
 
-        results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+        raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
@@ -234,20 +215,19 @@ class PushRulesWorkerStore(
         rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
 
         for row in rows:
-            results.setdefault(row["user_name"], []).append(row)
+            raw_rules.setdefault(row["user_name"], []).append(row)
 
         enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
-        for user_id, rules in results.items():
+        results: Dict[str, FilteredPushRules] = {}
+
+        for user_id, rules in raw_rules.items():
             results[user_id] = _load_rules(
                 rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
             )
 
         return results
 
-    @cachedList(
-        cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids"
-    )
     async def bulk_get_push_rules_enabled(
         self, user_ids: Collection[str]
     ) -> Dict[str, Dict[str, bool]]:
@@ -262,6 +242,7 @@ class PushRulesWorkerStore(
             iterable=user_ids,
             retcols=("user_name", "rule_id", "enabled"),
             desc="bulk_get_push_rules_enabled",
+            batch_size=1000,
         )
         for row in rows:
             enabled = bool(row["enabled"])
@@ -345,8 +326,8 @@ class PushRuleStore(PushRulesWorkerStore):
         user_id: str,
         rule_id: str,
         priority_class: int,
-        conditions: List[Dict[str, str]],
-        actions: List[Union[JsonDict, str]],
+        conditions: Sequence[Mapping[str, str]],
+        actions: Sequence[Union[Mapping[str, Any], str]],
         before: Optional[str] = None,
         after: Optional[str] = None,
     ) -> None:
@@ -808,7 +789,6 @@ class PushRuleStore(PushRulesWorkerStore):
         self.db_pool.simple_insert_txn(txn, "push_rules_stream", values=values)
 
         txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
-        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
         txn.call_after(
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
@@ -817,7 +797,7 @@ class PushRuleStore(PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     async def copy_push_rule_from_room_to_room(
-        self, new_room_id: str, user_id: str, rule: dict
+        self, new_room_id: str, user_id: str, rule: PushRule
     ) -> None:
         """Copy a single push rule from one room to another for a specific user.
 
@@ -827,21 +807,27 @@ class PushRuleStore(PushRulesWorkerStore):
             rule: A push rule.
         """
         # Create new rule id
-        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+        rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
         new_rule_id = rule_id_scope + "/" + new_room_id
 
+        new_conditions = []
+
         # Change room id in each condition
-        for condition in rule.get("conditions", []):
+        for condition in rule.conditions:
+            new_condition = condition
             if condition.get("key") == "room_id":
-                condition["pattern"] = new_room_id
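+                # Copy the condition rather than mutating it in place, since the
+                # rule (and its conditions) may be shared via the push rules cache.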
+                new_condition = dict(condition)
+                new_condition["pattern"] = new_room_id
+
+            new_conditions.append(new_condition)
 
         # Add the rule for the new room
         await self.add_push_rule(
             user_id=user_id,
             rule_id=new_rule_id,
-            priority_class=rule["priority_class"],
-            conditions=rule["conditions"],
-            actions=rule["actions"],
+            priority_class=rule.priority_class,
+            conditions=new_conditions,
+            actions=rule.actions,
         )
 
     async def copy_push_rules_from_room_to_room_for_user(
@@ -859,8 +845,11 @@ class PushRuleStore(PushRulesWorkerStore):
         user_push_rules = await self.get_push_rules_for_user(user_id)
 
         # Get rules relating to the old room and copy them to the new room
-        for rule in user_push_rules:
-            conditions = rule.get("conditions", [])
+        for rule, enabled in user_push_rules:
+            if not enabled:
+                continue
+
+            conditions = rule.conditions
             if any(
                 (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                 for c in conditions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0090c9f225..124c70ad37 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -161,7 +161,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type: The receipt types to fetch.
 
         Returns:
-            The latest receipt, if one exists.
+            The event ID and stream ordering of the latest receipt, if one exists.
         """
 
         clause, args = make_in_list_sql_clause(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index cb63cd9b7d..7fb9c801da 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -69,9 +69,9 @@ class TokenLookupResult:
     """
 
     user_id: str
+    token_id: int
     is_guest: bool = False
     shadow_banned: bool = False
-    token_id: Optional[int] = None
     device_id: Optional[str] = None
     valid_until_ms: Optional[int] = None
     token_owner: str = attr.ib()
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
         from_token: Optional[StreamToken] = None,
@@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            aggregation_key: Only fetch events with this aggregation key, if given.
             limit: Only fetch the most recent `limit` events.
             direction: Whether to fetch the most recent first (`"b"`) or the
                 oldest first (`"f"`).
@@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
             where_clause.append("type = ?")
             where_args.append(event_type)
 
-        if aggregation_key:
-            where_clause.append("aggregation_key = ?")
-            where_args.append(aggregation_key)
-
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d6d485507b..b7d4baa6bb 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -207,7 +207,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
     def _construct_room_type_where_clause(
         self, room_types: Union[List[Union[str, None]], None]
     ) -> Tuple[Union[str, None], List[str]]:
-        if not room_types or not self.config.experimental.msc3827_enabled:
+        if not room_types:
             return None, []
         else:
             # We use None when we want get rooms without a type
@@ -2001,9 +2001,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
 
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
+            # We join on room_stats_state despite not using any columns from it
+            # because the join can influence the number of rows returned;
+            # e.g. a room with no room_stats_state row (perhaps because it was
+            # deleted) is dropped by the join.
+            # The query returning the total count should be consistent with
+            # the query returning the results.
             sql = """
                 SELECT COUNT(*) as total_event_reports
                 FROM event_reports AS er
+                JOIN room_stats_state ON room_stats_state.room_id = er.room_id
                 {}
                 """.format(
                 where_clause
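
A small sqlite3 illustration of why the count query needs the same join as the results query: a report for a room with no room_stats_state row would otherwise be counted but never returned. The table shapes are trimmed to the relevant columns.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE event_reports (id INTEGER PRIMARY KEY, room_id TEXT);
    CREATE TABLE room_stats_state (room_id TEXT PRIMARY KEY, name TEXT);

    INSERT INTO event_reports VALUES (1, '!alive'), (2, '!deleted');
    -- Only '!alive' still has a stats row; '!deleted' has none.
    INSERT INTO room_stats_state VALUES ('!alive', 'A room');
    """
)

count_sql = """
    SELECT COUNT(*) FROM event_reports AS er
    JOIN room_stats_state ON room_stats_state.room_id = er.room_id
"""
results_sql = """
    SELECT er.id, room_stats_state.name FROM event_reports AS er
    JOIN room_stats_state ON room_stats_state.room_id = er.room_id
    ORDER BY er.id
"""

total = conn.execute(count_sql).fetchone()[0]
results = conn.execute(results_sql).fetchall()
assert total == len(results) == 1
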
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index df6b82660e..046ad3a11c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -55,6 +56,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -183,7 +185,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 self._check_safe_current_state_events_membership_updated_txn,
             )
 
-    @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+    @cached(max_entries=100000, iterable=True)
     async def get_users_in_room(self, room_id: str) -> List[str]:
         return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
@@ -281,6 +283,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns:
             A mapping from user ID to ProfileInfo.
+
+        Preconditions:
+          - There is full state available for the room (it is not partial-stated).
         """
 
         def _get_users_in_room_with_profiles(
@@ -561,7 +566,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return results_dict.get("membership"), results_dict.get("event_id")
 
-    @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+    @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
         self, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -732,25 +737,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
         return frozenset(r.room_id for r in rooms)
 
-    @cached(
-        max_entries=500000,
-        cache_context=True,
-        iterable=True,
-        prune_unread_entries=False,
+    @cached(max_entries=10000)
+    async def does_pair_of_users_share_a_room(
+        self, user_id: str, other_user_id: str
+    ) -> bool:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
     )
-    async def get_users_who_share_room_with_user(
-        self, user_id: str, cache_context: _CacheContext
+    async def _do_users_share_a_room(
+        self, user_id: str, other_user_ids: Collection[str]
+    ) -> Mapping[str, Optional[bool]]:
+        """Return a mapping from user ID to whether they share a room with the
+        given user.
+
+        Note: `None` and `False` are equivalent and mean they don't share a
+        room.
+        """
+
+        def do_users_share_a_room_txn(
+            txn: LoggingTransaction, user_ids: Collection[str]
+        ) -> Dict[str, bool]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "state_key", user_ids
+            )
+
+            # This query works by fetching the rooms the target user has joined
+            # and the (room, user) pairs for the other users, then joining the
+            # two to check for any overlap.
+            sql = f"""
+                SELECT b.state_key
+                FROM (
+                    SELECT room_id FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+                ) AS a
+                INNER JOIN (
+                    SELECT room_id, state_key FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+                ) AS b using (room_id)
+                LIMIT 1
+            """
+
+            txn.execute(sql, (user_id, *args))
+            return {u: True for u, in txn}
+
+        to_return = {}
+        for batch_user_ids in batch_iter(other_user_ids, 1000):
+            res = await self.db_pool.runInteraction(
+                "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+            )
+            to_return.update(res)
+
+        return to_return
+
+    async def do_users_share_a_room(
+        self, user_id: str, other_user_ids: Collection[str]
     ) -> Set[str]:
+        """Return the subset of `other_user_ids` who share a room with `user_id`."""
+
+        user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+        return {u for u, share_room in user_dict.items() if share_room}
+
+    async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
         """Returns the set of users who share a room with `user_id`"""
-        room_ids = await self.get_rooms_for_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
+        room_ids = await self.get_rooms_for_user(user_id)
 
         user_who_share_room = set()
         for room_id in room_ids:
-            user_ids = await self.get_users_in_room(
-                room_id, on_invalidate=cache_context.invalidate
-            )
+            user_ids = await self.get_users_in_room(room_id)
             user_who_share_room.update(user_ids)
 
         return user_who_share_room
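
A sqlite3 sketch of the room-overlap query used by _do_users_share_a_room above, with a cut-down current_state_events. The real query builds its IN list per batch via make_in_list_sql_clause; here it is inlined for brevity.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE current_state_events (
        room_id TEXT, type TEXT, state_key TEXT, membership TEXT
    );
    INSERT INTO current_state_events VALUES
        ('!r1', 'm.room.member', '@alice:x', 'join'),
        ('!r1', 'm.room.member', '@bob:x',   'join'),
        ('!r2', 'm.room.member', '@carol:x', 'join');
    """
)

sql = """
    SELECT b.state_key
    FROM (
        SELECT room_id FROM current_state_events
        WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
    ) AS a
    INNER JOIN (
        SELECT room_id, state_key FROM current_state_events
        WHERE type = 'm.room.member' AND membership = 'join' AND state_key IN (?, ?)
    ) AS b USING (room_id)
"""

rows = conn.execute(sql, ("@alice:x", "@bob:x", "@carol:x")).fetchall()
# Only @bob:x shares a room (!r1) with @alice:x.
assert {u for (u,) in rows} == {"@bob:x"}
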
@@ -779,9 +835,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_state(
+    async def get_joined_user_ids_from_state(
         self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -792,25 +848,25 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
-            return await self._get_joined_users_from_context(
+            return await self._get_joined_user_ids_from_context(
                 room_id, state_group, state, context=state_entry
             )
 
     @cached(num_args=2, iterable=True, max_entries=100000)
-    async def _get_joined_users_from_context(
+    async def _get_joined_user_ids_from_context(
         self,
         room_id: str,
         state_group: Union[object, int],
         current_state_ids: StateMap[str],
         event: Optional[EventBase] = None,
         context: Optional["_StateCacheEntry"] = None,
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
         assert state_group is not None
 
-        users_in_room = {}
+        users_in_room = set()
         member_event_ids = [
             e_id
             for key, e_id in current_state_ids.items()
@@ -823,11 +879,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # If we do then we can reuse that result and simply update it with
             # any membership changes in `delta_ids`
             if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get_immediate(
+                prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
                     (room_id, context.prev_group), None
                 )
-                if prev_res and isinstance(prev_res, dict):
-                    users_in_room = dict(prev_res)
+                if prev_res and isinstance(prev_res, set):
+                    users_in_room = prev_res
                     member_event_ids = [
                         e_id
                         for key, e_id in context.delta_ids.items()
@@ -835,7 +891,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                     ]
                     for etype, state_key in context.delta_ids:
                         if etype == EventTypes.Member:
-                            users_in_room.pop(state_key, None)
+                            users_in_room.discard(state_key)
 
         # We check if we have any of the member event ids in the event cache
         # before we ask the DB
@@ -843,7 +899,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We don't update the event cache hit ratio as it completely throws off
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
-        event_map = await self._get_events_from_cache(
+        event_map = self._get_events_from_local_cache(
             member_event_ids, update_metrics=False
         )
 
@@ -852,71 +908,64 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             ev_entry = event_map.get(event_id)
             if ev_entry and not ev_entry.event.rejected_reason:
                 if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[ev_entry.event.state_key] = ProfileInfo(
-                        display_name=ev_entry.event.content.get("displayname", None),
-                        avatar_url=ev_entry.event.content.get("avatar_url", None),
-                    )
+                    users_in_room.add(ev_entry.event.state_key)
             else:
                 missing_member_event_ids.append(event_id)
 
         if missing_member_event_ids:
-            event_to_memberships = await self._get_joined_profiles_from_event_ids(
+            event_to_memberships = await self._get_user_ids_from_membership_event_ids(
                 missing_member_event_ids
             )
-            users_in_room.update(row for row in event_to_memberships.values() if row)
+            users_in_room.update(
+                user_id for user_id in event_to_memberships.values() if user_id
+            )
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = ProfileInfo(
-                        display_name=event.content.get("displayname", None),
-                        avatar_url=event.content.get("avatar_url", None),
-                    )
+                    users_in_room.add(event.state_key)
 
         return users_in_room
 
-    @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(
+    @cached(
+        max_entries=10000,
+        # This name matches the old function that has been replaced - the cache name
+        # is kept here to maintain backwards compatibility.
+        name="_get_joined_profile_from_event_id",
+    )
+    def _get_user_id_from_membership_event_id(
         self, event_id: str
     ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
+        cached_method_name="_get_user_id_from_membership_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(
+    async def _get_user_ids_from_membership_event_ids(
         self, event_ids: Iterable[str]
-    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+    ) -> Dict[str, Optional[str]]:
         """For given set of member event_ids check if they point to a join
-        event and if so return the associated user and profile info.
+        event.
 
         Args:
             event_ids: The member event IDs to lookup
 
         Returns:
-            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id`, or None if the event is not a join.
         """
 
         rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
-            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            retcols=("user_id", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=1000,
-            desc="_get_joined_profiles_from_event_ids",
+            desc="_get_user_ids_from_membership_event_ids",
         )
 
-        return {
-            row["event_id"]: (
-                row["user_id"],
-                ProfileInfo(
-                    avatar_url=row["avatar_url"], display_name=row["display_name"]
-                ),
-            )
-            for row in rows
-        }
+        return {row["event_id"]: row["user_id"] for row in rows}
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1075,12 +1124,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             else:
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
-                joined_users = await self.get_joined_users_from_state(
+                joined_user_ids = await self.get_joined_user_ids_from_state(
                     room_id, state, state_entry
                 )
 
                 cache.hosts_to_joined_users = {}
-                for user_id in joined_users:
+                for user_id in joined_user_ids:
                     host = intern_string(get_domain_from_id(user_id))
                     cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
 
@@ -1159,6 +1208,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
+    async def is_locally_forgotten_room(self, room_id: str) -> bool:
+        """Returns whether all local users have forgotten this room_id.
+
+        Args:
+            room_id: The room ID to query.
+
+        Returns:
+            Whether the room is forgotten.
+        """
+
+        sql = """
+            SELECT count(*) > 0 FROM local_current_membership
+            INNER JOIN room_memberships USING (room_id, event_id)
+            WHERE
+                room_id = ?
+                AND forgotten = 0;
+        """
+
+        rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+        # `count(*)` always returns an integer.
+        # If any rows still exist it means someone has not forgotten this room yet
+        return not rows[0][0]
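
A quick sqlite3 check of the count(*) > 0 logic above, with trimmed tables; sqlite returns the comparison as 0 or 1, so `not rows[0][0]` reads naturally.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE local_current_membership (room_id TEXT, event_id TEXT);
    CREATE TABLE room_memberships (room_id TEXT, event_id TEXT, forgotten INTEGER);

    INSERT INTO local_current_membership VALUES ('!r', '$m1');
    INSERT INTO room_memberships VALUES ('!r', '$m1', 1);  -- forgotten by the only local user
    """
)

sql = """
    SELECT count(*) > 0 FROM local_current_membership
    INNER JOIN room_memberships USING (room_id, event_id)
    WHERE room_id = ? AND forgotten = 0
"""
rows = conn.execute(sql, ("!r",)).fetchall()
assert not rows[0][0]  # no un-forgotten memberships left, so the room is locally forgotten
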
+
     async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
         """Get all rooms that the user has ever been in.
 
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 9674c4a757..0b10af0e58 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -419,15 +419,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         # anything that was rejected should have the same state as its
         # predecessor.
         if context.rejected:
-            assert context.state_group == context.state_group_before_event
+            state_group = context.state_group_before_event
+        else:
+            state_group = context.state_group
 
         self.db_pool.simple_update_txn(
             txn,
             table="event_to_state_groups",
             keyvalues={"event_id": event.event_id},
-            updatevalues={"state_group": context.state_group},
+            updatevalues={"state_group": state_group},
         )
 
+        # the event may now be rejected where it was not before, or vice versa,
+        # in which case we need to update the rejected flags.
+        if bool(context.rejected) != (event.rejected_reason is not None):
+            self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
         self.db_pool.simple_delete_one_txn(
             txn,
             table="partial_state_events",
@@ -440,7 +447,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         txn.call_after(
             self._get_state_group_for_event.prefill,
             (event.event_id,),
-            context.state_group,
+            state_group,
         )
 
 
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 2590b52f73..a347430aa7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -58,6 +58,7 @@ from twisted.internet import defer
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import trace
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -1346,6 +1347,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, next_token
 
+    @trace
     async def paginate_room_events(
         self,
         room_id: str,