author    Eric Eastwood <erice@element.io>  2022-08-09 14:50:10 -0500
committer Eric Eastwood <erice@element.io>  2022-08-09 14:50:10 -0500
commit    2a467fd26b369d937a04da950d9ecaf8f23aa382 (patch)
tree      f2c9f654a0f67d7bbcdf459b9b4a918923d041ca /synapse
parent    Only set attribute if going forward (diff)
parent    Merge branch 'develop' into madlittlemods/11850-migrate-to-opentelemetry (diff)
download  synapse-2a467fd26b369d937a04da950d9ecaf8f23aa382.tar.xz
Merge branch 'madlittlemods/11850-migrate-to-opentelemetry' into madlittlemods/13356-messages-investigation-scratch-v1
Conflicts:
	pyproject.toml
	synapse/logging/tracing.py
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/api/constants.py                              |   3
-rw-r--r--  synapse/config/experimental.py                        |   2
-rw-r--r--  synapse/events/snapshot.py                            |   7
-rw-r--r--  synapse/federation/federation_server.py               |  17
-rw-r--r--  synapse/handlers/federation.py                        |  17
-rw-r--r--  synapse/handlers/initial_sync.py                      |  11
-rw-r--r--  synapse/handlers/receipts.py                          |  36
-rw-r--r--  synapse/handlers/sync.py                              |  25
-rw-r--r--  synapse/logging/tracing.py                            |  36
-rw-r--r--  synapse/module_api/__init__.py                        |  59
-rw-r--r--  synapse/replication/tcp/client.py                     |   5
-rw-r--r--  synapse/rest/client/notifications.py                  |   7
-rw-r--r--  synapse/rest/client/read_marker.py                    |   8
-rw-r--r--  synapse/rest/client/receipts.py                       |  10
-rw-r--r--  synapse/rest/client/versions.py                       |   1
-rw-r--r--  synapse/state/v2.py                                   |  12
-rw-r--r--  synapse/storage/databases/main/event_push_actions.py | 402
-rw-r--r--  synapse/storage/databases/main/events.py              |   2
-rw-r--r--  synapse/storage/databases/main/events_worker.py       |  75
-rw-r--r--  synapse/storage/databases/main/roommember.py          |   2
-rw-r--r--  synapse/util/caches/lrucache.py                       |  17
21 files changed, 554 insertions, 200 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index fc04e4d4bd..7cacd107d4 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -260,7 +260,8 @@ class GuestAccess:
 
 class ReceiptTypes:
     READ: Final = "m.read"
-    READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
+    READ_PRIVATE: Final = "m.read.private"
+    UNSTABLE_READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
     FULLY_READ: Final = "m.fully_read"
 
 
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index c2ecd977cd..7d17c958bb 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -32,7 +32,7 @@ class ExperimentalConfig(Config):
         # MSC2716 (importing historical messages)
         self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)
 
-        # MSC2285 (private read receipts)
+        # MSC2285 (unstable private read receipts)
         self.msc2285_enabled: bool = experimental.get("msc2285_enabled", False)
 
         # MSC3244 (room version capabilities)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index b700cbbfa1..d3c8083e4a 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import attr
 from frozendict import frozendict
-from typing_extensions import Literal
 
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
@@ -33,7 +32,7 @@ class EventContext:
     Holds information relevant to persisting an event
 
     Attributes:
-        rejected: A rejection reason if the event was rejected, else False
+        rejected: A rejection reason if the event was rejected, else None
 
         _state_group: The ID of the state group for this event. Note that state events
             are persisted with a state group which includes the new event, so this is
@@ -85,7 +84,7 @@ class EventContext:
     """
 
     _storage: "StorageControllers"
-    rejected: Union[Literal[False], str] = False
+    rejected: Optional[str] = None
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
     _state_delta_due_to_event: Optional[StateMap[str]] = None
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index c9cd02ebeb..c5b3b3cedf 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -843,8 +843,25 @@ class FederationServer(FederationBase):
                 Codes.BAD_JSON,
             )
 
+        # Note that get_room_version raises if the room does not exist.
         room_version = await self.store.get_room_version(room_id)
 
+        if await self.store.is_partial_state_room(room_id):
+            # If our server is still only partially joined, we can't give a complete
+            # response to /send_join, /send_knock or /send_leave.
+            # This is because we will not be able to provide the server list (for partial
+            # joins) or the full state (for full joins).
+            # Return a 404 as we would if we weren't in the room at all.
+            logger.info(
+                f"Rejecting /send_{membership_type} to %s because it's a partial state room",
+                room_id,
+            )
+            raise SynapseError(
+                404,
+                f"Unable to handle /send_{membership_type} right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         if membership_type == Membership.KNOCK and not room_version.msc2403_knocking:
             raise SynapseError(
                 403,
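For reference, a partially-joined server now answers these endpoints with a standard Matrix error body. A hedged illustration in Python (the JSON shape follows the usual serialisation of a SynapseError with errcode M_NOT_FOUND):

    # The 404 body a joining server now receives for a partial state room:
    expected_error = {
        "errcode": "M_NOT_FOUND",
        "error": "Unable to handle /send_join right now; this server is not fully joined.",
    }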
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a27f9f6246..39be782937 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -756,6 +756,23 @@ class FederationHandler:
         # (and return a 404 otherwise)
         room_version = await self.store.get_room_version(room_id)
 
+        if await self.store.is_partial_state_room(room_id):
+            # If our server is still only partially joined, we can't give a complete
+            # response to /make_join, so return a 404 as we would if we weren't in the
+            # room at all.
+            # The main reason we can't respond properly is that we need to know about
+            # the auth events for the join event that we would return.
+            # We also should not bother entertaining the /make_join since we cannot
+            # handle the /send_join.
+            logger.info(
+                "Rejecting /make_join to %s because it's a partial state room", room_id
+            )
+            raise SynapseError(
+                404,
+                "Unable to handle /make_join right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         # now check that we are *still* in the room
         is_in_room = await self._event_auth_handler.check_host_in_room(
             room_id, self.server_name
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 85b472f250..6484e47e5f 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -143,8 +143,8 @@ class InitialSyncHandler:
             joined_rooms,
             to_key=int(now_token.receipt_key),
         )
-        if self.hs.config.experimental.msc2285_enabled:
-            receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
+
+        receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
@@ -456,11 +456,8 @@ class InitialSyncHandler:
             )
             if not receipts:
                 return []
-            if self.hs.config.experimental.msc2285_enabled:
-                receipts = ReceiptEventSource.filter_out_private_receipts(
-                    receipts, user_id
-                )
-            return receipts
+
+            return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
             gather_results(
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d2882b0a..d4a866b346 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -163,7 +163,10 @@ class ReceiptsHandler:
         if not is_new:
             return
 
-        if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE:
+        if self.federation_sender and receipt_type not in (
+            ReceiptTypes.READ_PRIVATE,
+            ReceiptTypes.UNSTABLE_READ_PRIVATE,
+        ):
             await self.federation_sender.send_read_receipt(receipt)
 
 
@@ -203,24 +206,38 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
             for event_id, orig_event_content in room.get("content", {}).items():
                 event_content = orig_event_content
                 # If there are private read receipts, additional logic is necessary.
-                if ReceiptTypes.READ_PRIVATE in event_content:
+                if (
+                    ReceiptTypes.READ_PRIVATE in event_content
+                    or ReceiptTypes.UNSTABLE_READ_PRIVATE in event_content
+                ):
                     # Make a copy without private read receipts to avoid leaking
                    # other users' private read receipts.
                     event_content = {
                         receipt_type: receipt_value
                         for receipt_type, receipt_value in event_content.items()
-                        if receipt_type != ReceiptTypes.READ_PRIVATE
+                        if receipt_type
+                        not in (
+                            ReceiptTypes.READ_PRIVATE,
+                            ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                        )
                     }
 
                     # Copy the current user's private read receipt from the
                     # original content, if it exists.
-                    user_private_read_receipt = orig_event_content[
-                        ReceiptTypes.READ_PRIVATE
-                    ].get(user_id, None)
+                    user_private_read_receipt = orig_event_content.get(
+                        ReceiptTypes.READ_PRIVATE, {}
+                    ).get(user_id, None)
                     if user_private_read_receipt:
                         event_content[ReceiptTypes.READ_PRIVATE] = {
                             user_id: user_private_read_receipt
                         }
+                    user_unstable_private_read_receipt = orig_event_content.get(
+                        ReceiptTypes.UNSTABLE_READ_PRIVATE, {}
+                    ).get(user_id, None)
+                    if user_unstable_private_read_receipt:
+                        event_content[ReceiptTypes.UNSTABLE_READ_PRIVATE] = {
+                            user_id: user_unstable_private_read_receipt
+                        }
 
                 # Include the event if there is at least one non-private read
                 # receipt or the current user has a private read receipt.
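To make the filtering rule above concrete, here is a minimal standalone sketch of the same logic (the helper name and data shapes are illustrative assumptions, not Synapse's actual API):

    def filter_private_receipts(event_content: dict, user_id: str) -> dict:
        """Drop other users' private read receipts; keep public receipts and
        the requesting user's own private receipts (sketch only)."""
        private_types = ("m.read.private", "org.matrix.msc2285.read.private")
        filtered = {
            receipt_type: receipt_value
            for receipt_type, receipt_value in event_content.items()
            if receipt_type not in private_types
        }
        for receipt_type in private_types:
            own_receipt = event_content.get(receipt_type, {}).get(user_id)
            if own_receipt:
                filtered[receipt_type] = {user_id: own_receipt}
        return filtered

    # Bob's private receipt is removed; Alice keeps her own.
    content = {
        "m.read": {"@alice:example.org": {"ts": 1}},
        "m.read.private": {
            "@alice:example.org": {"ts": 2},
            "@bob:example.org": {"ts": 3},
        },
    }
    assert filter_private_receipts(content, "@alice:example.org") == {
        "m.read": {"@alice:example.org": {"ts": 1}},
        "m.read.private": {"@alice:example.org": {"ts": 2}},
    }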
@@ -256,10 +273,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
             room_ids, from_key=from_key, to_key=to_key
         )
 
-        if self.config.experimental.msc2285_enabled:
-            events = ReceiptEventSource.filter_out_private_receipts(
-                events, user.to_string()
-            )
+        events = ReceiptEventSource.filter_out_private_receipts(
+            events, user.to_string()
+        )
 
         return events, to_key
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 8dc05d648d..cddfb4cec7 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1541,15 +1541,13 @@ class SyncHandler:
         ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
-                sync_result_builder, ignored_users, self.rooms_to_exclude
+                sync_result_builder, ignored_users
             )
             tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            room_changes = await self._get_all_rooms(
-                sync_result_builder, ignored_users, self.rooms_to_exclude
-            )
+            room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
         log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1628,13 +1626,14 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         ignored_users: FrozenSet[str],
-        excluded_rooms: List[str],
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 
         This function is a first pass at generating the rooms part of the sync response.
         It determines which rooms have changed during the sync period, and categorises
-        them into four buckets: "knock", "invite", "join" and "leave".
+        them into four buckets: "knock", "invite", "join" and "leave". It also
+        excludes any rooms that the server configuration marks as excluded from
+        sync results.
 
         1. Finds all membership changes for the user in the sync period (from
            `since_token` up to `now_token`).
@@ -1660,7 +1659,7 @@ class SyncHandler:
         #       _have_rooms_changed. We could keep the results in memory to avoid a
         #       second query, at the cost of more complicated source code.
         membership_change_events = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key, excluded_rooms
+            user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@@ -1867,7 +1866,6 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         ignored_users: FrozenSet[str],
-        ignored_rooms: List[str],
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
@@ -1889,7 +1887,7 @@ class SyncHandler:
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id,
             membership_list=Membership.LIST,
-            excluded_rooms=ignored_rooms,
+            excluded_rooms=self.rooms_to_exclude,
         )
 
         room_entries = []
@@ -2155,7 +2153,9 @@ class SyncHandler:
                 raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
     async def get_rooms_for_user_at(
-        self, user_id: str, room_key: RoomStreamToken
+        self,
+        user_id: str,
+        room_key: RoomStreamToken,
     ) -> FrozenSet[str]:
         """Get set of joined rooms for a user at the given stream ordering.
 
@@ -2181,7 +2181,12 @@ class SyncHandler:
         # If the membership's stream ordering is after the given stream
         # ordering, we need to go and work out if the user was in the room
         # before.
+        # We also need to check whether the room should be excluded from sync
+        # responses as per the homeserver config.
         for joined_room in joined_rooms:
+            if joined_room.room_id in self.rooms_to_exclude:
+                continue
+
             if not joined_room.event_pos.persisted_after(room_key):
                 joined_room_ids.add(joined_room.room_id)
                 continue
diff --git a/synapse/logging/tracing.py b/synapse/logging/tracing.py
index cb557a147d..38521d18df 100644
--- a/synapse/logging/tracing.py
+++ b/synapse/logging/tracing.py
@@ -885,16 +885,17 @@ def trace_with_opname(
 ) -> Callable[[Callable[P, R]], Callable[P, R]]:
     """
     Decorator to trace a function with a custom opname.
-
     See the module's doc string for usage examples.
     """
-
-    @contextlib.contextmanager
-    def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs):
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
         with start_active_span(opname, tracer=tracer):
             yield
 
-    def _decorator(func: Callable[P, R]):
+    def _decorator(func: Callable[P, R]) -> Callable[P, R]:
         if not opentelemetry:
             return func
 
@@ -906,9 +907,7 @@ def trace_with_opname(
 def trace(func: Callable[P, R]) -> Callable[P, R]:
     """
     Decorator to trace a function.
-
     Sets the operation name to that of the function's name.
-
     See the module's doc string for usage examples.
     """
 
@@ -917,19 +916,28 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
 
 def tag_args(func: Callable[P, R]) -> Callable[P, R]:
     """
-    Decorator to tag all of the args to the active span.
+    Tags all of the args to the active span.
     """
 
     if not opentelemetry:
         return func
 
-    @contextlib.contextmanager
-    def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs):
+    # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
+    @contextlib.contextmanager  # type: ignore[arg-type]
+    def _wrapping_logic(
+        func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
+    ) -> Generator[None, None, None]:
         argspec = inspect.getfullargspec(func)
-        for i, arg in enumerate(args[1:]):
-            set_attribute(SynapseTags.FUNC_ARG_PREFIX + argspec.args[i + 1], str(arg))  # type: ignore[index]
-        set_attribute(SynapseTags.FUNC_ARGS, str(args[len(argspec.args) :]))  # type: ignore[index]
-        set_attribute(SynapseTags.FUNC_KWARGS, str(kwargs))
+        # We use `[1:]` to skip the `self` object reference and `start=1` to
+        # make the index line up with `argspec.args`.
+        #
+        # FIXME: We could update this to handle any type of function by ignoring the
+        #   first argument only if it's named `self` or `cls`. This isn't fool-proof
+        #   but handles the idiomatic cases.
+        for i, arg in enumerate(args[1:], start=1):  # type: ignore[index]
+            set_attribute("ARG_" + argspec.args[i], str(arg))
+        set_attribute("args", str(args[len(argspec.args) :]))  # type: ignore[index]
+        set_attribute("kwargs", str(kwargs))
         yield
 
     return _custom_sync_async_decorator(func, _wrapping_logic)
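A hedged usage sketch for the decorators above (both importable from synapse.logging.tracing on this branch; the class and method here are hypothetical):

    from synapse.logging.tracing import tag_args, trace_with_opname

    class EventFetcher:
        @trace_with_opname("db.fetch_events")
        @tag_args
        async def fetch_events(self, event_ids: list) -> None:
            # Runs inside an active span named "db.fetch_events"; @tag_args
            # records the call's arguments on that span. Note the wrapper
            # skips the first positional argument (here, `self`).
            ...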
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 18d6d1058a..87ba154cb7 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -929,10 +929,12 @@ class ModuleApi:
         room_id: str,
         new_membership: str,
         content: Optional[JsonDict] = None,
+        remote_room_hosts: Optional[List[str]] = None,
     ) -> EventBase:
         """Updates the membership of a user to the given value.
 
         Added in Synapse v1.46.0.
+        Changed in Synapse v1.65.0: Added the 'remote_room_hosts' parameter.
 
         Args:
             sender: The user performing the membership change. Must be a user local to
@@ -946,6 +948,7 @@ class ModuleApi:
                 https://spec.matrix.org/unstable/client-server-api/#mroommember for the
                 list of allowed values.
             content: Additional values to include in the resulting event's content.
+            remote_room_hosts: Remote servers to use for remote joins/knocks/etc.
 
         Returns:
             The newly created membership event.
@@ -1005,15 +1008,12 @@ class ModuleApi:
             room_id=room_id,
             action=new_membership,
             content=content,
+            remote_room_hosts=remote_room_hosts,
         )
 
         # Try to retrieve the resulting event.
         event = await self._hs.get_datastores().main.get_event(event_id)
 
-        # update_membership is supposed to always return after the event has been
-        # successfully persisted.
-        assert event is not None
-
         return event
 
     async def create_and_send_event_into_room(self, event_dict: JsonDict) -> EventBase:
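A hedged example of the extended call from a module (assuming the method is ModuleApi.update_room_membership, per the surrounding class; the identifiers are illustrative):

    event = await module_api.update_room_membership(
        sender="@alice:example.org",
        target="@alice:example.org",
        room_id="!abc123:remote.example",
        new_membership="join",
        # New in this change: candidate servers for the remote join.
        remote_room_hosts=["remote.example"],
    )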
@@ -1476,6 +1476,57 @@ class ModuleApi:
 
         return room_id.to_string(), hosts
 
+    async def create_room(
+        self,
+        user_id: str,
+        config: JsonDict,
+        ratelimit: bool = True,
+        creator_join_profile: Optional[JsonDict] = None,
+    ) -> Tuple[str, Optional[str]]:
+        """Creates a new room.
+
+        Added in Synapse v1.65.0.
+
+        Args:
+            user_id:
+                The user who requested the room creation.
+            config: A dict of configuration options. See "Request body" of:
+                https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
+            ratelimit: Set to False to disable the rate limiter for this
+                specific operation.
+
+            creator_join_profile:
+                Set to override the displayname and avatar for the creating
+                user in this room. If unset, displayname and avatar will be
+                derived from the user's profile. If set, should contain the
+                values to go in the body of the 'join' event (typically
+                `avatar_url` and/or `displayname`).
+
+        Returns:
+                A tuple containing: 1) the room ID (str), and 2) the room alias
+                (str) if one was requested, otherwise None.
+
+        Raises:
+            ResourceLimitError if the server is blocked due to a resource limit
+            being exceeded.
+            RuntimeError if the user_id does not refer to a local user.
+            SynapseError if the user_id is invalid, the room ID couldn't be
+            stored, or something else went wrong.
+        """
+        if not self.is_mine(user_id):
+            raise RuntimeError(
+                "Tried to create a room as a user that isn't local to this homeserver",
+            )
+
+        requester = create_requester(user_id)
+        room_id_and_alias, _ = await self._hs.get_room_creation_handler().create_room(
+            requester=requester,
+            config=config,
+            ratelimit=ratelimit,
+            creator_join_profile=creator_join_profile,
+        )
+
+        return room_id_and_alias["room_id"], room_id_and_alias.get("room_alias", None)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index e4f2201c92..1ed7230e32 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -416,7 +416,10 @@ class FederationSenderHandler:
             if not self._is_mine_id(receipt.user_id):
                 continue
             # Private read receipts never get sent over federation.
-            if receipt.receipt_type == ReceiptTypes.READ_PRIVATE:
+            if receipt.receipt_type in (
+                ReceiptTypes.READ_PRIVATE,
+                ReceiptTypes.UNSTABLE_READ_PRIVATE,
+            ):
                 continue
             receipt_info = ReadReceipt(
                 receipt.room_id,
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 24bc7c9095..a73322a6a4 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -58,7 +58,12 @@ class NotificationsServlet(RestServlet):
         )
 
         receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
-            user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+            user_id,
+            [
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+                ReceiptTypes.UNSTABLE_READ_PRIVATE,
+            ],
         )
 
         notif_event_ids = [pa.event_id for pa in push_actions]
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 8896f2df50..aaad8b233f 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -40,9 +40,13 @@ class ReadMarkerRestServlet(RestServlet):
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
-        self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
+        self._known_receipt_types = {
+            ReceiptTypes.READ,
+            ReceiptTypes.FULLY_READ,
+            ReceiptTypes.READ_PRIVATE,
+        }
         if hs.config.experimental.msc2285_enabled:
-            self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE)
+            self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE)
 
     async def on_POST(
         self, request: SynapseRequest, room_id: str
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 409bfd43c1..c6108fc5eb 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -44,11 +44,13 @@ class ReceiptRestServlet(RestServlet):
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
-        self._known_receipt_types = {ReceiptTypes.READ}
+        self._known_receipt_types = {
+            ReceiptTypes.READ,
+            ReceiptTypes.READ_PRIVATE,
+            ReceiptTypes.FULLY_READ,
+        }
         if hs.config.experimental.msc2285_enabled:
-            self._known_receipt_types.update(
-                (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
-            )
+            self._known_receipt_types.add(ReceiptTypes.UNSTABLE_READ_PRIVATE)
 
     async def on_POST(
         self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0366986755..c9a830cbac 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -94,6 +94,7 @@ class VersionsRestServlet(RestServlet):
                     # Supports the busy presence state described in MSC3026.
                     "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
                     # Supports receiving private read receipts as per MSC2285
+                    "org.matrix.msc2285.stable": True,  # TODO: Remove when MSC2285 becomes a part of the spec
                     "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
                     # Supports filtering of /publicRooms by room type as per MSC3827
                     "org.matrix.msc3827.stable": True,
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 7db032203b..cf3045f82e 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -434,7 +434,7 @@ async def _add_event_and_auth_chain_to_graph(
     event_id: str,
     event_map: Dict[str, EventBase],
     state_res_store: StateResolutionStore,
-    auth_diff: Set[str],
+    full_conflicted_set: Set[str],
 ) -> None:
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
@@ -445,7 +445,7 @@ async def _add_event_and_auth_chain_to_graph(
         event_id: Event to add to the graph
         event_map
         state_res_store
-        auth_diff: Set of event IDs that are in the auth difference.
+        full_conflicted_set: Set of event IDs that are in the full conflicted set.
     """
 
     state = [event_id]
@@ -455,7 +455,7 @@ async def _add_event_and_auth_chain_to_graph(
 
         event = await _get_event(room_id, eid, event_map, state_res_store)
         for aid in event.auth_event_ids():
-            if aid in auth_diff:
+            if aid in full_conflicted_set:
                 if aid not in graph:
                     state.append(aid)
 
@@ -468,7 +468,7 @@ async def _reverse_topological_power_sort(
     event_ids: Iterable[str],
     event_map: Dict[str, EventBase],
     state_res_store: StateResolutionStore,
-    auth_diff: Set[str],
+    full_conflicted_set: Set[str],
 ) -> List[str]:
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
@@ -479,7 +479,7 @@ async def _reverse_topological_power_sort(
         event_ids: The events to sort
         event_map
         state_res_store
-        auth_diff: Set of event IDs that are in the auth difference.
+        full_conflicted_set: Set of event IDs that are in the full conflicted set.
 
     Returns:
         The sorted list
@@ -488,7 +488,7 @@ async def _reverse_topological_power_sort(
     graph: Dict[str, Set[str]] = {}
     for idx, event_id in enumerate(event_ids, start=1):
         await _add_event_and_auth_chain_to_graph(
-            graph, room_id, event_id, event_map, state_res_store, auth_diff
+            graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
         )
 
         # We await occasionally when we're working with large data sets to
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index dd2627037c..161aad0f89 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -12,6 +12,67 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+"""Responsible for storing and fetching push actions / notifications.
+
+There are two main uses for push actions:
+  1. Sending out push to a user's device; and
+  2. Tracking per-room per-user notification counts (used in sync requests).
+
+For the former we simply use the `event_push_actions` table, which contains all
+the calculated actions for a given user (which were calculated by the
+`BulkPushRuleEvaluator`).
+
+For the latter we could simply count the number of rows in the
+`event_push_actions` table for a given room/user, but in practice this is
+*very* heavyweight when there are a large number of notifications (e.g.
+because the user never reads a room). Plus, keeping all push actions
+indefinitely uses a lot of disk space.
+
+To fix these issues, we add a new table `event_push_summary` that tracks
+per-user per-room counts of all notifications that happened before a stream
+ordering S. Thus, to get the notification count for a user / room we can simply
+query a single row in `event_push_summary` and count the number of rows in
+`event_push_actions` with a stream ordering larger than S (and as long as S is
+"recent", the number of rows needing to be scanned will be small).
+
+The `event_push_summary` table is updated via a background job that periodically
+chooses a new stream ordering S' (usually the latest stream ordering), counts
+all notifications in `event_push_actions` between the existing S and S', and
+adds them to the existing counts in `event_push_summary`.
+
+This allows us to delete old rows from `event_push_actions` once those rows have
+been counted and added to `event_push_summary` (we call this process
+"rotation").
+
+
+We also need to handle the case where a user sends a read receipt to the room. Again this is
+done as a background process. For each receipt we clear the row in
+`event_push_summary` and count the number of notifications in
+`event_push_actions` that happened after the receipt but before S, and insert
+that count into `event_push_summary`. (If the receipt happened *after* S then
+we simply clear the row in `event_push_summary`.)
+
+Note that it's possible that, if the read receipt is for an old event, the relevant
+`event_push_actions` rows will have been rotated and we get the wrong count
+(it'll be too low). We accept this as a rare edge case that is unlikely to
+impact the user much (since the vast majority of read receipts will be for the
+latest event).
+
+The last complication is handling the race where we request the notification
+counts after a user sends a read receipt into the room, but *before* the
+background update handles the receipt (without any special handling the counts
+would be outdated). We fix this by including in `event_push_summary` the read
+receipt we used when updating `event_push_summary`, and every time we query the
+table we check if that matches the most recent read receipt in the room. If it
+does, we continue as above; if not, we query the `event_push_actions` table
+directly.
+
+Since read receipts are almost always for recent events, scanning the
+`event_push_actions` table in this case is unlikely to be a problem. Even if it
+is a problem, it is temporary until the background job handles the new read
+receipt.
+"""
+
 import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
@@ -19,7 +80,7 @@ import attr
 
 from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -198,7 +259,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             txn,
             user_id,
             room_id,
-            receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+            receipt_types=(
+                ReceiptTypes.READ,
+                ReceiptTypes.READ_PRIVATE,
+                ReceiptTypes.UNSTABLE_READ_PRIVATE,
+            ),
         )
 
         stream_ordering = None
@@ -265,7 +330,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             counts.notify_count += row[1]
             counts.unread_count += row[2]
 
-        # Next we need to count highlights, which aren't summarized
+        # Next we need to count highlights, which aren't summarised
         sql = """
             SELECT COUNT(*) FROM event_push_actions
             WHERE user_id = ?
@@ -280,7 +345,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         # Finally we need to count push actions that aren't included in the
         # summary returned above, e.g. recent events that haven't been
-        # summarized yet, or the summary is empty due to a recent read receipt.
+        # summarised yet, or the summary is empty due to a recent read receipt.
         stream_ordering = max(stream_ordering, summary_stream_ordering)
         notify_count, unread_count = self._get_notif_unread_count_for_user_room(
             txn, room_id, user_id, stream_ordering
@@ -304,6 +369,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         Does not consult `event_push_summary` table, which may include push
         actions that have been deleted from `event_push_actions` table.
+
+        Args:
+            txn: The database transaction.
+            room_id: The room ID to get unread counts for.
+            user_id: The user ID to get unread counts for.
+            stream_ordering: The (exclusive) minimum stream ordering to consider.
+            max_stream_ordering: The (inclusive) maximum stream ordering to consider.
+                If this is not given, then no maximum is applied.
+
+        Returns:
+            A tuple of the notif count and unread count in the given range.
         """
 
         # If there have been no events in the room since the stream ordering,
@@ -376,6 +452,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will be ordered by ascending stream_ordering.
             The list will have between 0~limit entries.
         """
+
         # find rooms that have a read receipt in them and return the next
         # push actions
         def get_after_receipt(
@@ -383,28 +460,41 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         ) -> List[Tuple[str, str, int, str, bool]]:
             # find rooms that have a read receipt in them and return the next
             # push actions
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight "
-                " FROM ("
-                "   SELECT room_id,"
-                "       MAX(stream_ordering) as stream_ordering"
-                "   FROM events"
-                "   INNER JOIN receipts_linearized USING (room_id, event_id)"
-                "   WHERE receipt_type = 'm.read' AND user_id = ?"
-                "   GROUP BY room_id"
-                ") AS rl,"
-                " event_push_actions AS ep"
-                " WHERE"
-                "   ep.room_id = rl.room_id"
-                "   AND ep.stream_ordering > rl.stream_ordering"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering ASC LIMIT ?"
+
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight
+                FROM (
+                    SELECT room_id,
+                        MAX(stream_ordering) as stream_ordering
+                    FROM events
+                    INNER JOIN receipts_linearized USING (room_id, event_id)
+                    WHERE {receipt_types_clause} AND user_id = ?
+                    GROUP BY room_id
+                ) AS rl,
+                event_push_actions AS ep
+                WHERE
+                    ep.room_id = rl.room_id
+                    AND ep.stream_ordering > rl.stream_ordering
+                    AND ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering ASC LIMIT ?
+            """
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
             )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
@@ -418,24 +508,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         def get_no_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight "
-                " FROM event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id NOT IN ("
-                "     SELECT room_id FROM receipts_linearized"
-                "       WHERE receipt_type = 'm.read' AND user_id = ?"
-                "       GROUP BY room_id"
-                "   )"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering ASC LIMIT ?"
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight
+                FROM event_push_actions AS ep
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    ep.room_id NOT IN (
+                        SELECT room_id FROM receipts_linearized
+                        WHERE {receipt_types_clause} AND user_id = ?
+                        GROUP BY room_id
+                    )
+                    AND ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering ASC LIMIT ?
+            """
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
             )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
 
@@ -485,34 +587,47 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will be ordered by descending received_ts.
             The list will have between 0~limit entries.
         """
+
         # find rooms that have a read receipt in them and return the most recent
         # push actions
         def get_after_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "  ep.highlight, e.received_ts"
-                " FROM ("
-                "   SELECT room_id,"
-                "       MAX(stream_ordering) as stream_ordering"
-                "   FROM events"
-                "   INNER JOIN receipts_linearized USING (room_id, event_id)"
-                "   WHERE receipt_type = 'm.read' AND user_id = ?"
-                "   GROUP BY room_id"
-                ") AS rl,"
-                " event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id = rl.room_id"
-                "   AND ep.stream_ordering > rl.stream_ordering"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering DESC LIMIT ?"
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight, e.received_ts
+                FROM (
+                    SELECT room_id,
+                        MAX(stream_ordering) as stream_ordering
+                    FROM events
+                    INNER JOIN receipts_linearized USING (room_id, event_id)
+                    WHERE {receipt_types_clause} AND user_id = ?
+                    GROUP BY room_id
+                ) AS rl,
+                event_push_actions AS ep
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    ep.room_id = rl.room_id
+                    AND ep.stream_ordering > rl.stream_ordering
+                    AND ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering DESC LIMIT ?
+            """
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
             )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
@@ -526,24 +641,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         def get_no_receipt(
             txn: LoggingTransaction,
         ) -> List[Tuple[str, str, int, str, bool, int]]:
-            sql = (
-                "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
-                "   ep.highlight, e.received_ts"
-                " FROM event_push_actions AS ep"
-                " INNER JOIN events AS e USING (room_id, event_id)"
-                " WHERE"
-                "   ep.room_id NOT IN ("
-                "     SELECT room_id FROM receipts_linearized"
-                "       WHERE receipt_type = 'm.read' AND user_id = ?"
-                "       GROUP BY room_id"
-                "   )"
-                "   AND ep.user_id = ?"
-                "   AND ep.stream_ordering > ?"
-                "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
-                " ORDER BY ep.stream_ordering DESC LIMIT ?"
+            receipt_types_clause, args = make_in_list_sql_clause(
+                self.database_engine,
+                "receipt_type",
+                (
+                    ReceiptTypes.READ,
+                    ReceiptTypes.READ_PRIVATE,
+                    ReceiptTypes.UNSTABLE_READ_PRIVATE,
+                ),
+            )
+
+            sql = f"""
+                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
+                    ep.highlight, e.received_ts
+                FROM event_push_actions AS ep
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    ep.room_id NOT IN (
+                        SELECT room_id FROM receipts_linearized
+                        WHERE {receipt_types_clause} AND user_id = ?
+                        GROUP BY room_id
+                    )
+                    AND ep.user_id = ?
+                    AND ep.stream_ordering > ?
+                    AND ep.stream_ordering <= ?
+                    AND ep.notif = 1
+                ORDER BY ep.stream_ordering DESC LIMIT ?
+            """
+            args.extend(
+                (user_id, user_id, min_stream_ordering, max_stream_ordering, limit)
             )
-            args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
             txn.execute(sql, args)
             return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
 
@@ -769,12 +896,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # [10, <none>, 20], we should treat this as being equivalent to
         # [10, 10, 20].
         #
-        sql = (
-            "SELECT received_ts FROM events"
-            " WHERE stream_ordering <= ?"
-            " ORDER BY stream_ordering DESC"
-            " LIMIT 1"
-        )
+        sql = """
+            SELECT received_ts FROM events
+            WHERE stream_ordering <= ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
+        """
 
         while range_end - range_start > 0:
             middle = (range_end + range_start) // 2
@@ -802,14 +929,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         self, stream_ordering: int
     ) -> Optional[int]:
         def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
-            sql = (
-                "SELECT e.received_ts"
-                " FROM event_push_actions AS ep"
-                " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
-                " WHERE ep.stream_ordering > ? AND notif = 1"
-                " ORDER BY ep.stream_ordering ASC"
-                " LIMIT 1"
-            )
+            sql = """
+                SELECT e.received_ts
+                FROM event_push_actions AS ep
+                JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id
+                WHERE ep.stream_ordering > ? AND notif = 1
+                ORDER BY ep.stream_ordering ASC
+                LIMIT 1
+            """
             txn.execute(sql, (stream_ordering,))
             return cast(Optional[Tuple[int]], txn.fetchone())
 
@@ -858,10 +985,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         Any push actions which predate the user's most recent read receipt are
         now redundant, so we can remove them from `event_push_actions` and
         update `event_push_summary`.
+
+        Returns true if all new receipts have been processed.
         """
 
         limit = 100
 
+        # The (inclusive) receipt stream ID that was previously processed.
         min_receipts_stream_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_last_receipt_stream_id",
@@ -871,6 +1001,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         max_receipts_stream_id = self._receipts_id_gen.get_current_token()
 
+        # The (inclusive) event stream ordering that was previously summarised.
+        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_push_summary_stream_ordering",
+            keyvalues={},
+            retcol="stream_ordering",
+        )
+
         sql = """
             SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
             FROM receipts_linearized AS r
@@ -895,13 +1033,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
         rows = txn.fetchall()
 
-        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
-            txn,
-            table="event_push_summary_stream_ordering",
-            keyvalues={},
-            retcol="stream_ordering",
-        )
-
         # For each new read receipt we delete push actions from before it and
         # recalculate the summary.
         for _, room_id, user_id, stream_ordering in rows:
@@ -920,10 +1051,13 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 (room_id, user_id, stream_ordering),
             )
 
+            # Fetch the notification counts between the stream ordering of the
+            # latest receipt and what was previously summarised.
             notif_count, unread_count = self._get_notif_unread_count_for_user_room(
                 txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
             )
 
+            # Replace the previous summary with the new counts.
             self.db_pool.simple_upsert_txn(
                 txn,
                 table="event_push_summary",
@@ -956,10 +1090,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         return len(rows) < limit
 
     def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
-        """Archives older notifications into event_push_summary. Returns whether
-        the archiving process has caught up or not.
+        """Archives older notifications (from event_push_actions) into event_push_summary.
+
+        Returns whether the archiving process has caught up or not.
         """
 
+        # The (inclusive) event stream ordering that was previously summarised.
         old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
@@ -974,7 +1110,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             SELECT stream_ordering FROM event_push_actions
             WHERE stream_ordering > ?
             ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
-        """,
+            """,
             (old_rotate_stream_ordering, self._rotate_count),
         )
         stream_row = txn.fetchone()
@@ -993,19 +1129,31 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
 
-        self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
+        self._rotate_notifs_before_txn(
+            txn, old_rotate_stream_ordering, rotate_to_stream_ordering
+        )
 
         return caught_up
 
     def _rotate_notifs_before_txn(
-        self, txn: LoggingTransaction, rotate_to_stream_ordering: int
+        self,
+        txn: LoggingTransaction,
+        old_rotate_stream_ordering: int,
+        rotate_to_stream_ordering: int,
     ) -> None:
-        old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
-            txn,
-            table="event_push_summary_stream_ordering",
-            keyvalues={},
-            retcol="stream_ordering",
-        )
+        """Archives older notifications (from event_push_actions) into event_push_summary.
+
+        Any event_push_actions between old_rotate_stream_ordering (exclusive) and
+        rotate_to_stream_ordering (inclusive) will be added to the event_push_summary
+        table.
+
+        Args:
+            txn: The database transaction.
+            old_rotate_stream_ordering: The previous maximum event stream ordering.
+            rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
+        """
 
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
@@ -1093,9 +1241,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
     async def _remove_old_push_actions_that_have_rotated(
         self,
     ) -> None:
-        """Clear out old push actions that have been summarized."""
+        """Clear out old push actions that have been summarised."""
 
-        # We want to clear out anything that older than a day that *has* already
+        # We want to clear out anything that is older than a day that *has* already
         # been rotated.
         rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
             table="event_push_summary_stream_ordering",
@@ -1119,7 +1267,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 SELECT stream_ordering FROM event_push_actions
                 WHERE stream_ordering <= ? AND highlight = 0
                 ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
-            """,
+                """,
                 (
                     max_stream_ordering_to_delete,
                     batch_size,
@@ -1215,16 +1363,18 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
 
             # NB. This assumes event_ids are globally unique since
             # it makes the query easier to index
-            sql = (
-                "SELECT epa.event_id, epa.room_id,"
-                " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
-                " FROM event_push_actions epa, events e"
-                " WHERE epa.event_id = e.event_id"
-                " AND epa.user_id = ? %s"
-                " AND epa.notif = 1"
-                " ORDER BY epa.stream_ordering DESC"
-                " LIMIT ?" % (before_clause,)
+            sql = """
+                SELECT epa.event_id, epa.room_id,
+                    epa.stream_ordering, epa.topological_ordering,
+                    epa.actions, epa.highlight, epa.profile_tag, e.received_ts
+                FROM event_push_actions epa, events e
+                WHERE epa.event_id = e.event_id
+                    AND epa.user_id = ? %s
+                    AND epa.notif = 1
+                ORDER BY epa.stream_ordering DESC
+                LIMIT ?
+            """ % (
+                before_clause,
             )
             txn.execute(sql, args)
             return cast(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 21ba7a540e..1c3b804da0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1492,7 +1492,7 @@ class PersistEventsStore:
                     event.sender,
                     "url" in event.content and isinstance(event.content["url"], str),
                     event.get_state_key(),
-                    context.rejected or None,
+                    context.rejected,
                 )
                 for event, context in events_and_contexts
             ),
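
The dropped `or None` deserves a remark: presumably `context.rejected` can no longer be `False` by the time it reaches this insert, and `x or None` is not a no-op on strings in any case, since it collapses every falsy value:

# `or None` collapses *any* falsy value, so an empty-string rejection
# reason would have been stored as NULL:
print("" or None)      # None
print(False or None)   # None
print("spam" or None)  # 'spam'
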
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c547ba7afd..8cf7ae4acd 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -603,7 +603,11 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             map from event id to result
         """
-        event_entry_map = await self._get_events_from_cache(
+        # Shortcut: check the *in-memory* cache first. This function may be called
+        # repeatedly for the same event, so for performance reasons we don't reach
+        # out to any external cache at this point. The external cache is checked
+        # later on, in the `get_missing_events_from_cache_or_db` function below.
+        event_entry_map = self._get_events_from_local_cache(
             event_ids,
         )
 
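
Together with the hunks below, event lookup now narrows its missing set through three tiers. A self-contained sketch of that shape (dict-backed stand-ins, not Synapse's actual helpers):

import asyncio

# Hypothetical stand-ins for the three tiers.
LOCAL = {"$a": "event-a"}
EXTERNAL = {"$b": "event-b"}
DB = {"$a": "event-a", "$b": "event-b", "$c": "event-c"}

def get_from_local_cache(ids):
    return {i: LOCAL[i] for i in ids if i in LOCAL}

async def get_from_external_cache(ids):
    return {i: EXTERNAL[i] for i in ids if i in EXTERNAL}

async def get_from_db(ids):
    return {i: DB[i] for i in ids if i in DB}

async def lookup(event_ids):
    # In-memory first (synchronous, no await), then the external cache, then
    # the database, narrowing the missing set at each tier.
    found = get_from_local_cache(event_ids)
    missing = set(event_ids) - found.keys()
    if missing:
        found.update(await get_from_external_cache(missing))
        missing -= found.keys()
    if missing:
        found.update(await get_from_db(missing))
    return found

print(asyncio.run(lookup(["$a", "$b", "$c"])))
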
@@ -635,7 +639,9 @@ class EventsWorkerStore(SQLBaseStore):
 
         if missing_events_ids:
 
-            async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]:
+            async def get_missing_events_from_cache_or_db() -> Dict[
+                str, EventCacheEntry
+            ]:
                 """Fetches the events in `missing_event_ids` from the database.
 
                 Also creates entries in `self._current_event_fetches` to allow
@@ -660,10 +666,18 @@ class EventsWorkerStore(SQLBaseStore):
                 # the events have been redacted, and if so pulling the redaction event
                 # out of the database to check it.
                 #
+                missing_events = {}
                 try:
-                    missing_events = await self._get_events_from_db(
+                    # Try to fetch from any external cache. We already checked the
+                    # in-memory cache above.
+                    missing_events = await self._get_events_from_external_cache(
                         missing_events_ids,
                     )
+                    # Now actually fetch any remaining events from the DB
+                    db_missing_events = await self._get_events_from_db(
+                        missing_events_ids - missing_events.keys(),
+                    )
+                    missing_events.update(db_missing_events)
                 except Exception as e:
                     with PreserveLoggingContext():
                         fetching_deferred.errback(e)
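
One small idiom above worth calling out: assuming `missing_events_ids` is a set, subtracting the dict's `keys()` view yields exactly the ids the external cache could not supply:

missing_events_ids = {"$a", "$b", "$c"}
missing_events = {"$a": "hit-from-external-cache"}
print(missing_events_ids - missing_events.keys())  # {'$b', '$c'}
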
@@ -682,7 +696,7 @@ class EventsWorkerStore(SQLBaseStore):
             # cancellations, since multiple `_get_events_from_cache_or_db` calls can
             # reuse the same fetch.
             missing_events: Dict[str, EventCacheEntry] = await delay_cancellation(
-                get_missing_events_from_db()
+                get_missing_events_from_cache_or_db()
             )
             event_entry_map.update(missing_events)
 
@@ -757,7 +771,54 @@ class EventsWorkerStore(SQLBaseStore):
     async def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
-        """Fetch events from the caches.
+        """Fetch events from the caches, both in memory and any external.
+
+        May return rejected events.
+
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
+        """
+        event_map = self._get_events_from_local_cache(
+            events, update_metrics=update_metrics
+        )
+
+        missing_event_ids = (e for e in events if e not in event_map)
+        event_map.update(
+            await self._get_events_from_external_cache(
+                events=missing_event_ids,
+                update_metrics=update_metrics,
+            )
+        )
+
+        return event_map
+
+    async def _get_events_from_external_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, EventCacheEntry]:
+        """Fetch events from any configured external cache.
+
+        May return rejected events.
+
+        Args:
+            events: list of event_ids to fetch
+            update_metrics: Whether to update the cache hit ratio metrics
+        """
+        event_map = {}
+
+        for event_id in events:
+            ret = await self._get_event_cache.get_external(
+                (event_id,), None, update_metrics=update_metrics
+            )
+            if ret:
+                event_map[event_id] = ret
+
+        return event_map
+
+    def _get_events_from_local_cache(
+        self, events: Iterable[str], update_metrics: bool = True
+    ) -> Dict[str, EventCacheEntry]:
+        """Fetch events from the local, in memory, caches.
 
         May return rejected events.
 
@@ -769,7 +830,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         for event_id in events:
             # First check if it's in the event cache
-            ret = await self._get_event_cache.get(
+            ret = self._get_event_cache.get_local(
                 (event_id,), None, update_metrics=update_metrics
             )
             if ret:
@@ -791,7 +852,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # We add the entry back into the cache as we want to keep
                 # recently queried events in the cache.
-                await self._get_event_cache.set((event_id,), cache_entry)
+                self._get_event_cache.set_local((event_id,), cache_entry)
 
         return event_map
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index e2cccc688c..93ff4816c8 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -896,7 +896,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We don't update the event cache hit ratio as it completely throws off
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
-        event_map = await self._get_events_from_cache(
+        event_map = self._get_events_from_local_cache(
             member_event_ids, update_metrics=False
         )
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b3bdedb04c..aa93109d13 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -834,9 +834,26 @@ class AsyncLruCache(Generic[KT, VT]):
     ) -> Optional[VT]:
         return self._lru_cache.get(key, update_metrics=update_metrics)
 
+    async def get_external(
+        self,
+        key: KT,
+        default: Optional[T] = None,
+        update_metrics: bool = True,
+    ) -> Optional[VT]:
+        # This method should fetch from any configured external cache; here it is a no-op.
+        return None
+
+    def get_local(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
     async def set(self, key: KT, value: VT) -> None:
         self._lru_cache.set(key, value)
 
+    def set_local(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
     async def invalidate(self, key: KT) -> None:
         # This method should invalidate any external cache and then invalidate the LruCache.
         return self._lru_cache.invalidate(key)
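
To round off: the new `get_local`/`get_external`/`set_local` split makes `AsyncLruCache` a two-tier cache whose external tier is a no-op until a subclass overrides it. A self-contained sketch of the pattern (a plain dict stands in for a real external store such as Redis; none of this is Synapse's actual wiring, and metrics are omitted):

import asyncio
from typing import Dict, Optional

class TwoTierCache:
    def __init__(self) -> None:
        self._local: Dict[str, str] = {}     # the in-memory LruCache tier
        self._external: Dict[str, str] = {}  # hypothetical external store

    def get_local(self, key: str) -> Optional[str]:
        # Synchronous and cheap: safe to call from non-async hot paths.
        return self._local.get(key)

    async def get_external(self, key: str) -> Optional[str]:
        # In the base AsyncLruCache this returns None; an override would
        # query the external cache here.
        return self._external.get(key)

    async def get(self, key: str) -> Optional[str]:
        ret = self.get_local(key)
        if ret is None:
            ret = await self.get_external(key)
        return ret

    def set_local(self, key: str, value: str) -> None:
        # Populate only the in-memory tier, e.g. to re-warm after a read.
        self._local[key] = value

    async def set(self, key: str, value: str) -> None:
        # Write through to both tiers.
        self._local[key] = value
        self._external[key] = value

async def demo() -> None:
    cache = TwoTierCache()
    await cache.set("$ev", "event body")
    cache._local.clear()           # simulate in-memory eviction
    print(await cache.get("$ev"))  # still served by the external tier

asyncio.run(demo())
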