summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/account_data.py164
-rw-r--r--synapse/storage/databases/main/appservice.py24
-rw-r--r--synapse/storage/databases/main/cache.py16
-rw-r--r--synapse/storage/databases/main/deviceinbox.py279
-rw-r--r--synapse/storage/databases/main/devices.py22
-rw-r--r--synapse/storage/databases/main/event_federation.py325
-rw-r--r--synapse/storage/databases/main/events.py19
-rw-r--r--synapse/storage/databases/main/push_rule.py20
-rw-r--r--synapse/storage/databases/main/relations.py427
9 files changed, 916 insertions, 380 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 5bfa408f74..52146aacc8 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -106,6 +106,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             "AccountDataAndTagsChangeCache", account_max
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "delete_account_data_for_deactivated_users",
+            self._delete_account_data_for_deactivated_users,
+        )
+
     def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream
 
@@ -549,72 +554,121 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
     async def purge_account_data_for_user(self, user_id: str) -> None:
         """
-        Removes the account data for a user.
+        Removes ALL the account data for a user.
+        Intended to be used upon user deactivation.
 
-        This is intended to be used upon user deactivation and also removes any
-        derived information from account data (e.g. push rules and ignored users).
+        Also purges the user from the ignored_users cache table
+        and the push_rules cache tables.
+        """
 
-        Args:
-            user_id: The user ID to remove data for.
+        await self.db_pool.runInteraction(
+            "purge_account_data_for_user_txn",
+            self._purge_account_data_for_user_txn,
+            user_id,
+        )
+
+    def _purge_account_data_for_user_txn(
+        self, txn: LoggingTransaction, user_id: str
+    ) -> None:
         """
+        See `purge_account_data_for_user`.
+        """
+        # Purge from the primary account_data tables.
+        self.db_pool.simple_delete_txn(
+            txn, table="account_data", keyvalues={"user_id": user_id}
+        )
 
-        def purge_account_data_for_user_txn(txn: LoggingTransaction) -> None:
-            # Purge from the primary account_data tables.
-            self.db_pool.simple_delete_txn(
-                txn, table="account_data", keyvalues={"user_id": user_id}
-            )
+        self.db_pool.simple_delete_txn(
+            txn, table="room_account_data", keyvalues={"user_id": user_id}
+        )
 
-            self.db_pool.simple_delete_txn(
-                txn, table="room_account_data", keyvalues={"user_id": user_id}
-            )
+        # Purge from ignored_users where this user is the ignorer.
+        # N.B. We don't purge where this user is the ignoree, because that
+        #      interferes with other users' account data.
+        #      It's also not this user's data to delete!
+        self.db_pool.simple_delete_txn(
+            txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
+        )
 
-            # Purge from ignored_users where this user is the ignorer.
-            # N.B. We don't purge where this user is the ignoree, because that
-            #      interferes with other users' account data.
-            #      It's also not this user's data to delete!
-            self.db_pool.simple_delete_txn(
-                txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
-            )
+        # Remove the push rules
+        self.db_pool.simple_delete_txn(
+            txn, table="push_rules", keyvalues={"user_name": user_id}
+        )
+        self.db_pool.simple_delete_txn(
+            txn, table="push_rules_enable", keyvalues={"user_name": user_id}
+        )
+        self.db_pool.simple_delete_txn(
+            txn, table="push_rules_stream", keyvalues={"user_id": user_id}
+        )
 
-            # Remove the push rules
-            self.db_pool.simple_delete_txn(
-                txn, table="push_rules", keyvalues={"user_name": user_id}
-            )
-            self.db_pool.simple_delete_txn(
-                txn, table="push_rules_enable", keyvalues={"user_name": user_id}
-            )
-            self.db_pool.simple_delete_txn(
-                txn, table="push_rules_stream", keyvalues={"user_id": user_id}
-            )
+        # Invalidate caches as appropriate
+        self._invalidate_cache_and_stream(
+            txn, self.get_account_data_for_room_and_type, (user_id,)
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_account_data_for_user, (user_id,)
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_global_account_data_by_type_for_user, (user_id,)
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_account_data_for_room, (user_id,)
+        )
+        self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))
+        self._invalidate_cache_and_stream(
+            txn, self.get_push_rules_enabled_for_user, (user_id,)
+        )
+        # This user might be contained in the ignored_by cache for other users,
+        # so we have to invalidate it all.
+        self._invalidate_all_cache_and_stream(txn, self.ignored_by)
 
-            # Invalidate caches as appropriate
-            self._invalidate_cache_and_stream(
-                txn, self.get_account_data_for_room_and_type, (user_id,)
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_account_data_for_user, (user_id,)
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_global_account_data_by_type_for_user, (user_id,)
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_account_data_for_room, (user_id,)
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_push_rules_for_user, (user_id,)
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_push_rules_enabled_for_user, (user_id,)
-            )
-            # This user might be contained in the ignored_by cache for other users,
-            # so we have to invalidate it all.
-            self._invalidate_all_cache_and_stream(txn, self.ignored_by)
+    async def _delete_account_data_for_deactivated_users(
+        self, progress: dict, batch_size: int
+    ) -> int:
+        """
+        Retroactively purges account data for users that have already been deactivated.
+        Gets run as a background update caused by a schema delta.
+        """
 
-        await self.db_pool.runInteraction(
-            "purge_account_data_for_user_txn",
-            purge_account_data_for_user_txn,
+        last_user: str = progress.get("last_user", "")
+
+        def _delete_account_data_for_deactivated_users_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            sql = """
+                SELECT name FROM users
+                WHERE deactivated = ? and name > ?
+                ORDER BY name ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (1, last_user, batch_size))
+            users = [row[0] for row in txn]
+
+            for user in users:
+                self._purge_account_data_for_user_txn(txn, user_id=user)
+
+            if users:
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    "delete_account_data_for_deactivated_users",
+                    {"last_user": users[-1]},
+                )
+
+            return len(users)
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_delete_account_data_for_deactivated_users",
+            _delete_account_data_for_deactivated_users_txn,
         )
 
+        if number_deleted < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "delete_account_data_for_deactivated_users"
+            )
+
+        return number_deleted
+
 
 class AccountDataStore(AccountDataWorkerStore):
     pass
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 2bb5288431..304814af5d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -198,6 +198,7 @@ class ApplicationServiceTransactionWorkerStore(
         service: ApplicationService,
         events: List[EventBase],
         ephemeral: List[JsonDict],
+        to_device_messages: List[JsonDict],
     ) -> AppServiceTransaction:
         """Atomically creates a new transaction for this application service
         with the given list of events. Ephemeral events are NOT persisted to the
@@ -207,6 +208,7 @@ class ApplicationServiceTransactionWorkerStore(
             service: The service who the transaction is for.
             events: A list of persistent events to put in the transaction.
             ephemeral: A list of ephemeral events to put in the transaction.
+            to_device_messages: A list of to-device messages to put in the transaction.
 
         Returns:
             A new transaction.
@@ -237,7 +239,11 @@ class ApplicationServiceTransactionWorkerStore(
                 (service.id, new_txn_id, event_ids),
             )
             return AppServiceTransaction(
-                service=service, id=new_txn_id, events=events, ephemeral=ephemeral
+                service=service,
+                id=new_txn_id,
+                events=events,
+                ephemeral=ephemeral,
+                to_device_messages=to_device_messages,
             )
 
         return await self.db_pool.runInteraction(
@@ -330,7 +336,11 @@ class ApplicationServiceTransactionWorkerStore(
         events = await self.get_events_as_list(event_ids)
 
         return AppServiceTransaction(
-            service=service, id=entry["txn_id"], events=events, ephemeral=[]
+            service=service,
+            id=entry["txn_id"],
+            events=events,
+            ephemeral=[],
+            to_device_messages=[],
         )
 
     def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -391,7 +401,7 @@ class ApplicationServiceTransactionWorkerStore(
     async def get_type_stream_id_for_appservice(
         self, service: ApplicationService, type: str
     ) -> int:
-        if type not in ("read_receipt", "presence"):
+        if type not in ("read_receipt", "presence", "to_device"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (type,)
@@ -415,16 +425,16 @@ class ApplicationServiceTransactionWorkerStore(
             "get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
         )
 
-    async def set_type_stream_id_for_appservice(
+    async def set_appservice_stream_type_pos(
         self, service: ApplicationService, stream_type: str, pos: Optional[int]
     ) -> None:
-        if stream_type not in ("read_receipt", "presence"):
+        if stream_type not in ("read_receipt", "presence", "to_device"):
             raise ValueError(
                 "Expected type to be a valid application stream id type, got %s"
                 % (stream_type,)
             )
 
-        def set_type_stream_id_for_appservice_txn(txn):
+        def set_appservice_stream_type_pos_txn(txn):
             stream_id_type = "%s_stream_id" % stream_type
             txn.execute(
                 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
@@ -433,7 +443,7 @@ class ApplicationServiceTransactionWorkerStore(
             )
 
         await self.db_pool.runInteraction(
-            "set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
+            "set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
         )
 
 
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 0024348067..c428dd5596 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -15,7 +15,7 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -25,7 +25,11 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
 )
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.util.iterutils import batch_iter
 
@@ -236,7 +240,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate_all)
         self._send_invalidation_to_replication(txn, cache_func.__name__, None)
 
-    def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+    def _invalidate_state_caches_and_stream(
+        self, txn: LoggingTransaction, room_id: str, members_changed: Collection[str]
+    ) -> None:
         """Special case invalidation of caches based on current state.
 
         We special case this so that we can batch the cache invalidations into a
@@ -244,8 +250,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         Args:
             txn
-            room_id (str): Room where state changed
-            members_changed (iterable[str]): The user_ids of members that have changed
+            room_id: Room where state changed
+            members_changed: The user_ids of members that have changed
         """
         txn.call_after(self._invalidate_state_caches, room_id, members_changed)
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 4eca97189b..1392363de1 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -24,6 +24,7 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
@@ -136,63 +137,263 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     def get_to_device_stream_token(self):
         return self._device_inbox_id_gen.get_current_token()
 
-    async def get_new_messages_for_device(
+    async def get_messages_for_user_devices(
+        self,
+        user_ids: Collection[str],
+        from_stream_id: int,
+        to_stream_id: int,
+    ) -> Dict[Tuple[str, str], List[JsonDict]]:
+        """
+        Retrieve to-device messages for a given set of users.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
+        Args:
+            user_ids: The users to retrieve to-device messages for.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+
+        Returns:
+            A dictionary of (user id, device id) -> list of to-device messages.
+        """
+        # We expect the stream ID returned by _get_device_messages to always
+        # be to_stream_id. So, no need to return it from this function.
+        (
+            user_id_device_id_to_messages,
+            last_processed_stream_id,
+        ) = await self._get_device_messages(
+            user_ids=user_ids,
+            from_stream_id=from_stream_id,
+            to_stream_id=to_stream_id,
+        )
+
+        assert (
+            last_processed_stream_id == to_stream_id
+        ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"
+
+        return user_id_device_id_to_messages
+
+    async def get_messages_for_device(
         self,
         user_id: str,
-        device_id: Optional[str],
-        last_stream_id: int,
-        current_stream_id: int,
+        device_id: str,
+        from_stream_id: int,
+        to_stream_id: int,
         limit: int = 100,
-    ) -> Tuple[List[dict], int]:
+    ) -> Tuple[List[JsonDict], int]:
         """
+        Retrieve to-device messages for a single user device.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
         Args:
-            user_id: The recipient user_id.
-            device_id: The recipient device_id.
-            last_stream_id: The last stream ID checked.
-            current_stream_id: The current position of the to device
-                message stream.
-            limit: The maximum number of messages to retrieve.
+            user_id: The ID of the user to retrieve messages for.
+            device_id: The ID of the device to retrieve to-device messages for.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+            limit: A limit on the number of to-device messages returned.
 
         Returns:
             A tuple containing:
-                * A list of messages for the device.
-                * The max stream token of these messages. There may be more to retrieve
-                  if the given limit was reached.
+                * A list of to-device messages within the given stream id range intended for
+                  the given user / device combo.
+                * The last-processed stream ID. Subsequent calls of this function with the
+                  same device should pass this value as 'from_stream_id'.
         """
-        has_changed = self._device_inbox_stream_cache.has_entity_changed(
-            user_id, last_stream_id
+        (
+            user_id_device_id_to_messages,
+            last_processed_stream_id,
+        ) = await self._get_device_messages(
+            user_ids=[user_id],
+            device_id=device_id,
+            from_stream_id=from_stream_id,
+            to_stream_id=to_stream_id,
+            limit=limit,
         )
-        if not has_changed:
-            return [], current_stream_id
 
-        def get_new_messages_for_device_txn(txn):
-            sql = (
-                "SELECT stream_id, message_json FROM device_inbox"
-                " WHERE user_id = ? AND device_id = ?"
-                " AND ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC"
-                " LIMIT ?"
+        if not user_id_device_id_to_messages:
+            # There were no messages!
+            return [], to_stream_id
+
+        # Extract the messages, no need to return the user and device ID again
+        to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+
+        return to_device_messages, last_processed_stream_id
+
+    async def _get_device_messages(
+        self,
+        user_ids: Collection[str],
+        from_stream_id: int,
+        to_stream_id: int,
+        device_id: Optional[str] = None,
+        limit: Optional[int] = None,
+    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
+        """
+        Retrieve pending to-device messages for a collection of user devices.
+
+        Only to-device messages with stream ids between the given boundaries
+        (from < X <= to) are returned.
+
+        Note that a stream ID can be shared by multiple copies of the same message with
+        different recipient devices. Stream IDs are only unique in the context of a single
+        user ID / device ID pair. Thus, applying a limit (of messages to return) when working
+        with a sliding window of stream IDs is only possible when querying messages of a
+        single user device.
+
+        Finally, note that device IDs are not unique across users.
+
+        Args:
+            user_ids: The user IDs to filter device messages by.
+            from_stream_id: The lower boundary of stream id to filter with (exclusive).
+            to_stream_id: The upper boundary of stream id to filter with (inclusive).
+            device_id: A device ID to query to-device messages for. If not provided, to-device
+                messages from all device IDs for the given user IDs will be queried. May not be
+                provided if `user_ids` contains more than one entry.
+            limit: The maximum number of to-device messages to return. Can only be used when
+                passing a single user ID / device ID tuple.
+
+        Returns:
+            A tuple containing:
+                * A dict of (user_id, device_id) -> list of to-device messages
+                * The last-processed stream ID. If this is less than `to_stream_id`, then
+                    there may be more messages to retrieve. If `limit` is not set, then this
+                    is always equal to 'to_stream_id'.
+        """
+        if not user_ids:
+            logger.warning("No users provided upon querying for device IDs")
+            return {}, to_stream_id
+
+        # Prevent a query for one user's device also retrieving another user's device with
+        # the same device ID (device IDs are not unique across users).
+        if len(user_ids) > 1 and device_id is not None:
+            raise AssertionError(
+                "Programming error: 'device_id' cannot be supplied to "
+                "_get_device_messages when >1 user_id has been provided"
             )
-            txn.execute(
-                sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
+
+        # A limit can only be applied when querying for a single user ID / device ID tuple.
+        # See the docstring of this function for more details.
+        if limit is not None and device_id is None:
+            raise AssertionError(
+                "Programming error: _get_device_messages was passed 'limit' "
+                "without a specific user_id/device_id"
             )
 
-            messages = []
-            stream_pos = current_stream_id
+        user_ids_to_query: Set[str] = set()
+        device_ids_to_query: Set[str] = set()
+
+        # Note that a device ID could be an empty str
+        if device_id is not None:
+            # If a device ID was passed, use it to filter results.
+            # Otherwise, device IDs will be derived from the given collection of user IDs.
+            device_ids_to_query.add(device_id)
+
+        # Determine which users have devices with pending messages
+        for user_id in user_ids:
+            if self._device_inbox_stream_cache.has_entity_changed(
+                user_id, from_stream_id
+            ):
+                # This user has new messages sent to them. Query messages for them
+                user_ids_to_query.add(user_id)
+
+        def get_device_messages_txn(txn: LoggingTransaction):
+            # Build a query to select messages from any of the given devices that
+            # are between the given stream id bounds.
+
+            # If a list of device IDs was not provided, retrieve all devices IDs
+            # for the given users. We explicitly do not query hidden devices, as
+            # hidden devices should not receive to-device messages.
+            # Note that this is more efficient than just dropping `device_id` from the query,
+            # since device_inbox has an index on `(user_id, device_id, stream_id)`
+            if not device_ids_to_query:
+                user_device_dicts = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="devices",
+                    column="user_id",
+                    iterable=user_ids_to_query,
+                    keyvalues={"user_id": user_id, "hidden": False},
+                    retcols=("device_id",),
+                )
 
-            for row in txn:
-                stream_pos = row[0]
-                messages.append(db_to_json(row[1]))
+                device_ids_to_query.update(
+                    {row["device_id"] for row in user_device_dicts}
+                )
 
-            # If the limit was not reached we know that there's no more data for this
-            # user/device pair up to current_stream_id.
-            if len(messages) < limit:
-                stream_pos = current_stream_id
+            if not device_ids_to_query:
+                # We've ended up with no devices to query.
+                return {}, to_stream_id
 
-            return messages, stream_pos
+            # We include both user IDs and device IDs in this query, as we have an index
+            # (device_inbox_user_stream_id) for them.
+            user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
+                self.database_engine, "user_id", user_ids_to_query
+            )
+            (
+                device_id_many_clause_sql,
+                device_id_many_clause_args,
+            ) = make_in_list_sql_clause(
+                self.database_engine, "device_id", device_ids_to_query
+            )
+
+            sql = f"""
+                SELECT stream_id, user_id, device_id, message_json FROM device_inbox
+                WHERE {user_id_many_clause_sql}
+                AND {device_id_many_clause_sql}
+                AND ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+            sql_args = (
+                *user_id_many_clause_args,
+                *device_id_many_clause_args,
+                from_stream_id,
+                to_stream_id,
+            )
+
+            # If a limit was provided, limit the data retrieved from the database
+            if limit is not None:
+                sql += "LIMIT ?"
+                sql_args += (limit,)
+
+            txn.execute(sql, sql_args)
+
+            # Create and fill a dictionary of (user ID, device ID) -> list of messages
+            # intended for each device.
+            last_processed_stream_pos = to_stream_id
+            recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
+            rowcount = 0
+            for row in txn:
+                rowcount += 1
+
+                last_processed_stream_pos = row[0]
+                recipient_user_id = row[1]
+                recipient_device_id = row[2]
+                message_dict = db_to_json(row[3])
+
+                # Store the device details
+                recipient_device_to_messages.setdefault(
+                    (recipient_user_id, recipient_device_id), []
+                ).append(message_dict)
+
+            if limit is not None and rowcount == limit:
+                # We ended up bumping up against the message limit. There may be more messages
+                # to retrieve. Return what we have, as well as the last stream position that
+                # was processed.
+                #
+                # The caller is expected to set this as the lower (exclusive) bound
+                # for the next query of this device.
+                return recipient_device_to_messages, last_processed_stream_pos
+
+            # The limit was not reached, thus we know that recipient_device_to_messages
+            # contains all to-device messages for the given device and stream id range.
+            #
+            # We return to_stream_id, which the caller should then provide as the lower
+            # (exclusive) bound on the next query of this device.
+            return recipient_device_to_messages, to_stream_id
 
         return await self.db_pool.runInteraction(
-            "get_new_messages_for_device", get_new_messages_for_device_txn
+            "get_device_messages", get_device_messages_txn
         )
 
     @trace
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b2a5cd9a65..8d845fe951 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1496,13 +1496,23 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def add_device_change_to_streams(
-        self, user_id: str, device_ids: Collection[str], hosts: List[str]
-    ) -> int:
+        self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+    ) -> Optional[int]:
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
+
+        Args:
+            user_id: The ID of the user whose device changed.
+            device_ids: The IDs of any changed devices. If empty, this function will
+                return None.
+            hosts: The remote destinations that should be notified of the change.
+
+        Returns:
+            The maximum stream ID of device list updates that were added to the database, or
+            None if no updates were added.
         """
         if not device_ids:
-            return
+            return None
 
         async with self._device_list_id_gen.get_next_mult(
             len(device_ids)
@@ -1573,11 +1583,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self,
         txn: LoggingTransaction,
         user_id: str,
-        device_ids: Collection[str],
-        hosts: List[str],
+        device_ids: Iterable[str],
+        hosts: Collection[str],
         stream_ids: List[str],
         context: Dict[str, str],
-    ):
+    ) -> None:
         for host in hosts:
             txn.call_after(
                 self._device_list_federation_stream_cache.entity_has_changed,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index ca71f073fc..277e6422eb 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -16,9 +16,10 @@ import logging
 from queue import Empty, PriorityQueue
 from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
 
+import attr
 from prometheus_client import Counter, Gauge
 
-from synapse.api.constants import MAX_DEPTH
+from synapse.api.constants import MAX_DEPTH, EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
@@ -60,6 +61,15 @@ pdus_pruned_from_federation_queue = Counter(
 logger = logging.getLogger(__name__)
 
 
+# All the info we need while iterating the DAG while backfilling
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class BackfillQueueNavigationItem:
+    depth: int
+    stream_ordering: int
+    event_id: str
+    type: str
+
+
 class _NoChainCoverIndex(Exception):
     def __init__(self, room_id: str):
         super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -74,6 +84,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     ):
         super().__init__(database, db_conn, hs)
 
+        self.hs = hs
+
         if hs.config.worker.run_background_tasks:
             hs.get_clock().looping_call(
                 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@@ -109,7 +121,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         room_id: str,
         event_ids: Collection[str],
         include_given: bool = False,
-    ) -> List[str]:
+    ) -> Set[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
@@ -118,7 +130,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             include_given: include the given events in result
 
         Returns:
-            list of event_ids
+            set of event_ids
         """
 
         # Check if we have indexed the room so we can use the chain cover
@@ -147,7 +159,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     def _get_auth_chain_ids_using_cover_index_txn(
         self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
-    ) -> List[str]:
+    ) -> Set[str]:
         """Calculates the auth chain IDs using the chain index."""
 
         # First we look up the chain ID/sequence numbers for the given events.
@@ -260,11 +272,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 txn.execute(sql, (chain_id, max_no))
                 results.update(r for r, in txn)
 
-        return list(results)
+        return results
 
     def _get_auth_chain_ids_txn(
         self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
-    ) -> List[str]:
+    ) -> Set[str]:
         """Calculates the auth chain IDs.
 
         This is used when we don't have a cover index for the room.
@@ -319,7 +331,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             front = new_front
             results.update(front)
 
-        return list(results)
+        return results
 
     async def get_auth_chain_difference(
         self, room_id: str, state_sets: List[Set[str]]
@@ -737,7 +749,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             room_id,
         )
 
-    async def get_insertion_event_backwards_extremities_in_room(
+    async def get_insertion_event_backward_extremities_in_room(
         self, room_id
     ) -> Dict[str, int]:
         """Get the insertion events we know about that we haven't backfilled yet.
@@ -754,7 +766,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             Map from event_id to depth
         """
 
-        def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id):
+        def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
             sql = """
                 SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
                 /* We only want insertion events that are also marked as backwards extremities */
@@ -770,8 +782,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             return dict(txn)
 
         return await self.db_pool.runInteraction(
-            "get_insertion_event_backwards_extremities_in_room",
-            get_insertion_event_backwards_extremities_in_room_txn,
+            "get_insertion_event_backward_extremities_in_room",
+            get_insertion_event_backward_extremities_in_room_txn,
             room_id,
         )
 
@@ -997,143 +1009,242 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
-    async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
-        """Get a list of Events for a given topic that occurred before (and
-        including) the events in event_list. Return a list of max size `limit`
+    def _get_connected_batch_event_backfill_results_txn(
+        self, txn: LoggingTransaction, insertion_event_id: str, limit: int
+    ) -> List[BackfillQueueNavigationItem]:
+        """
+        Find any batch connections of a given insertion event.
+        A batch event points at a insertion event via:
+        batch_event.content[MSC2716_BATCH_ID] -> insertion_event.content[MSC2716_NEXT_BATCH_ID]
 
         Args:
-            room_id
-            event_list
-            limit
+            txn: The database transaction to use
+            insertion_event_id: The event ID to navigate from. We will find
+                batch events that point back at this insertion event.
+            limit: Max number of event ID's to query for and return
+
+        Returns:
+            List of batch events that the backfill queue can process
+        """
+        batch_connection_query = """
+            SELECT e.depth, e.stream_ordering, c.event_id, e.type FROM insertion_events AS i
+            /* Find the batch that connects to the given insertion event */
+            INNER JOIN batch_events AS c
+            ON i.next_batch_id = c.batch_id
+            /* Get the depth of the batch start event from the events table */
+            INNER JOIN events AS e USING (event_id)
+            /* Find an insertion event which matches the given event_id */
+            WHERE i.event_id = ?
+            LIMIT ?
         """
-        event_ids = await self.db_pool.runInteraction(
-            "get_backfill_events",
-            self._get_backfill_events,
-            room_id,
-            event_list,
-            limit,
-        )
-        events = await self.get_events_as_list(event_ids)
-        return sorted(events, key=lambda e: -e.depth)
 
-    def _get_backfill_events(self, txn, room_id, event_list, limit):
-        logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
+        # Find any batch connections for the given insertion event
+        txn.execute(
+            batch_connection_query,
+            (insertion_event_id, limit),
+        )
+        return [
+            BackfillQueueNavigationItem(
+                depth=row[0],
+                stream_ordering=row[1],
+                event_id=row[2],
+                type=row[3],
+            )
+            for row in txn
+        ]
 
-        event_results = set()
+    def _get_connected_prev_event_backfill_results_txn(
+        self, txn: LoggingTransaction, event_id: str, limit: int
+    ) -> List[BackfillQueueNavigationItem]:
+        """
+        Find any events connected by prev_event the specified event_id.
 
-        # We want to make sure that we do a breadth-first, "depth" ordered
-        # search.
+        Args:
+            txn: The database transaction to use
+            event_id: The event ID to navigate from
+            limit: Max number of event ID's to query for and return
 
+        Returns:
+            List of prev events that the backfill queue can process
+        """
         # Look for the prev_event_id connected to the given event_id
-        query = """
-            SELECT depth, prev_event_id FROM event_edges
-            /* Get the depth of the prev_event_id from the events table */
+        connected_prev_event_query = """
+            SELECT depth, stream_ordering, prev_event_id, events.type FROM event_edges
+            /* Get the depth and stream_ordering of the prev_event_id from the events table */
             INNER JOIN events
             ON prev_event_id = events.event_id
-            /* Find an event which matches the given event_id */
+            /* Look for an edge which matches the given event_id */
             WHERE event_edges.event_id = ?
             AND event_edges.is_state = ?
+            /* Because we can have many events at the same depth,
+            * we want to also tie-break and sort on stream_ordering */
+            ORDER BY depth DESC, stream_ordering DESC
             LIMIT ?
         """
 
-        # Look for the "insertion" events connected to the given event_id
-        connected_insertion_event_query = """
-            SELECT e.depth, i.event_id FROM insertion_event_edges AS i
-            /* Get the depth of the insertion event from the events table */
-            INNER JOIN events AS e USING (event_id)
-            /* Find an insertion event which points via prev_events to the given event_id */
-            WHERE i.insertion_prev_event_id = ?
-            LIMIT ?
+        txn.execute(
+            connected_prev_event_query,
+            (event_id, False, limit),
+        )
+        return [
+            BackfillQueueNavigationItem(
+                depth=row[0],
+                stream_ordering=row[1],
+                event_id=row[2],
+                type=row[3],
+            )
+            for row in txn
+        ]
+
+    async def get_backfill_events(
+        self, room_id: str, seed_event_id_list: list, limit: int
+    ):
+        """Get a list of Events for a given topic that occurred before (and
+        including) the events in seed_event_id_list. Return a list of max size `limit`
+
+        Args:
+            room_id
+            seed_event_id_list
+            limit
         """
+        event_ids = await self.db_pool.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events,
+            room_id,
+            seed_event_id_list,
+            limit,
+        )
+        events = await self.get_events_as_list(event_ids)
+        return sorted(
+            events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
+        )
 
-        # Find any batch connections of a given insertion event
-        batch_connection_query = """
-            SELECT e.depth, c.event_id FROM insertion_events AS i
-            /* Find the batch that connects to the given insertion event */
-            INNER JOIN batch_events AS c
-            ON i.next_batch_id = c.batch_id
-            /* Get the depth of the batch start event from the events table */
-            INNER JOIN events AS e USING (event_id)
-            /* Find an insertion event which matches the given event_id */
-            WHERE i.event_id = ?
-            LIMIT ?
+    def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
+        """
+        We want to make sure that we do a breadth-first, "depth" ordered search.
+        We also handle navigating historical branches of history connected by
+        insertion and batch events.
         """
+        logger.debug(
+            "_get_backfill_events(room_id=%s): seeding backfill with seed_event_id_list=%s limit=%s",
+            room_id,
+            seed_event_id_list,
+            limit,
+        )
+
+        event_id_results = set()
 
         # In a PriorityQueue, the lowest valued entries are retrieved first.
-        # We're using depth as the priority in the queue.
-        # Depth is lowest at the oldest-in-time message and highest and
-        # newest-in-time message. We add events to the queue with a negative depth so that
-        # we process the newest-in-time messages first going backwards in time.
+        # We're using depth as the priority in the queue and tie-break based on
+        # stream_ordering. Depth is lowest at the oldest-in-time message and
+        # highest and newest-in-time message. We add events to the queue with a
+        # negative depth so that we process the newest-in-time messages first
+        # going backwards in time. stream_ordering follows the same pattern.
         queue = PriorityQueue()
 
-        for event_id in event_list:
-            depth = self.db_pool.simple_select_one_onecol_txn(
+        for seed_event_id in seed_event_id_list:
+            event_lookup_result = self.db_pool.simple_select_one_txn(
                 txn,
                 table="events",
-                keyvalues={"event_id": event_id, "room_id": room_id},
-                retcol="depth",
+                keyvalues={"event_id": seed_event_id, "room_id": room_id},
+                retcols=(
+                    "type",
+                    "depth",
+                    "stream_ordering",
+                ),
                 allow_none=True,
             )
 
-            if depth:
-                queue.put((-depth, event_id))
+            if event_lookup_result is not None:
+                logger.debug(
+                    "_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
+                    room_id,
+                    seed_event_id,
+                    event_lookup_result["depth"],
+                    event_lookup_result["stream_ordering"],
+                    event_lookup_result["type"],
+                )
+
+                if event_lookup_result["depth"]:
+                    queue.put(
+                        (
+                            -event_lookup_result["depth"],
+                            -event_lookup_result["stream_ordering"],
+                            seed_event_id,
+                            event_lookup_result["type"],
+                        )
+                    )
 
-        while not queue.empty() and len(event_results) < limit:
+        while not queue.empty() and len(event_id_results) < limit:
             try:
-                _, event_id = queue.get_nowait()
+                _, _, event_id, event_type = queue.get_nowait()
             except Empty:
                 break
 
-            if event_id in event_results:
+            if event_id in event_id_results:
                 continue
 
-            event_results.add(event_id)
+            event_id_results.add(event_id)
 
             # Try and find any potential historical batches of message history.
-            #
-            # First we look for an insertion event connected to the current
-            # event (by prev_event). If we find any, we need to go and try to
-            # find any batch events connected to the insertion event (by
-            # batch_id). If we find any, we'll add them to the queue and
-            # navigate up the DAG like normal in the next iteration of the loop.
-            txn.execute(
-                connected_insertion_event_query, (event_id, limit - len(event_results))
-            )
-            connected_insertion_event_id_results = txn.fetchall()
-            logger.debug(
-                "_get_backfill_events: connected_insertion_event_query %s",
-                connected_insertion_event_id_results,
-            )
-            for row in connected_insertion_event_id_results:
-                connected_insertion_event_depth = row[0]
-                connected_insertion_event = row[1]
-                queue.put((-connected_insertion_event_depth, connected_insertion_event))
+            if self.hs.config.experimental.msc2716_enabled:
+                # We need to go and try to find any batch events connected
+                # to a given insertion event (by batch_id). If we find any, we'll
+                # add them to the queue and navigate up the DAG like normal in the
+                # next iteration of the loop.
+                if event_type == EventTypes.MSC2716_INSERTION:
+                    # Find any batch connections for the given insertion event
+                    connected_batch_event_backfill_results = (
+                        self._get_connected_batch_event_backfill_results_txn(
+                            txn, event_id, limit - len(event_id_results)
+                        )
+                    )
+                    logger.debug(
+                        "_get_backfill_events(room_id=%s): connected_batch_event_backfill_results=%s",
+                        room_id,
+                        connected_batch_event_backfill_results,
+                    )
+                    for (
+                        connected_batch_event_backfill_item
+                    ) in connected_batch_event_backfill_results:
+                        if (
+                            connected_batch_event_backfill_item.event_id
+                            not in event_id_results
+                        ):
+                            queue.put(
+                                (
+                                    -connected_batch_event_backfill_item.depth,
+                                    -connected_batch_event_backfill_item.stream_ordering,
+                                    connected_batch_event_backfill_item.event_id,
+                                    connected_batch_event_backfill_item.type,
+                                )
+                            )
 
-                # Find any batch connections for the given insertion event
-                txn.execute(
-                    batch_connection_query,
-                    (connected_insertion_event, limit - len(event_results)),
-                )
-                batch_start_event_id_results = txn.fetchall()
-                logger.debug(
-                    "_get_backfill_events: batch_start_event_id_results %s",
-                    batch_start_event_id_results,
+            # Now we just look up the DAG by prev_events as normal
+            connected_prev_event_backfill_results = (
+                self._get_connected_prev_event_backfill_results_txn(
+                    txn, event_id, limit - len(event_id_results)
                 )
-                for row in batch_start_event_id_results:
-                    if row[1] not in event_results:
-                        queue.put((-row[0], row[1]))
-
-            txn.execute(query, (event_id, False, limit - len(event_results)))
-            prev_event_id_results = txn.fetchall()
+            )
             logger.debug(
-                "_get_backfill_events: prev_event_ids %s", prev_event_id_results
+                "_get_backfill_events(room_id=%s): connected_prev_event_backfill_results=%s",
+                room_id,
+                connected_prev_event_backfill_results,
             )
+            for (
+                connected_prev_event_backfill_item
+            ) in connected_prev_event_backfill_results:
+                if connected_prev_event_backfill_item.event_id not in event_id_results:
+                    queue.put(
+                        (
+                            -connected_prev_event_backfill_item.depth,
+                            -connected_prev_event_backfill_item.stream_ordering,
+                            connected_prev_event_backfill_item.event_id,
+                            connected_prev_event_backfill_item.type,
+                        )
+                    )
 
-            for row in prev_event_id_results:
-                if row[1] not in event_results:
-                    queue.put((-row[0], row[1]))
-
-        return event_results
+        return event_id_results
 
     async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
         ids = await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b7554154ac..5246fccad5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1801,9 +1801,7 @@ class PersistEventsStore:
         )
 
         if rel_type == RelationTypes.REPLACE:
-            txn.call_after(
-                self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
-            )
+            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
         if rel_type == RelationTypes.THREAD:
             txn.call_after(
@@ -1814,7 +1812,7 @@ class PersistEventsStore:
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.room_id, event.sender),
+                (parent_id, event.sender),
             )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
@@ -2215,9 +2213,14 @@ class PersistEventsStore:
             "   SELECT 1 FROM event_backward_extremities"
             "   WHERE event_id = ? AND room_id = ?"
             " )"
+            # 1. Don't add an event as a extremity again if we already persisted it
+            # as a non-outlier.
+            # 2. Don't add an outlier as an extremity if it has no prev_events
             " AND NOT EXISTS ("
-            "   SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
-            "   AND outlier = ?"
+            "   SELECT 1 FROM events"
+            "   LEFT JOIN event_edges edge"
+            "   ON edge.event_id = events.event_id"
+            "   WHERE events.event_id = ? AND events.room_id = ? AND (events.outlier = ? OR edge.event_id IS NULL)"
             " )"
         )
 
@@ -2243,6 +2246,10 @@ class PersistEventsStore:
                 (ev.event_id, ev.room_id)
                 for ev in events
                 if not ev.internal_metadata.is_outlier()
+                # If we encountered an event with no prev_events, then we might
+                # as well remove it now because it won't ever have anything else
+                # to backfill from.
+                or len(ev.prev_event_ids()) == 0
             ],
         )
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index e01c94930a..92539f5d41 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def _load_rules(rawrules, enabled_map, use_new_defaults=False):
+def _load_rules(rawrules, enabled_map):
     ruleslist = []
     for rawrule in rawrules:
         rule = dict(rawrule)
@@ -52,7 +52,7 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
         ruleslist.append(rule)
 
     # We're going to be mutating this a lot, so do a deep copy
-    rules = list(list_with_base_rules(ruleslist, use_new_defaults))
+    rules = list(list_with_base_rules(ruleslist))
 
     for i, rule in enumerate(rules):
         rule_id = rule["rule_id"]
@@ -112,10 +112,6 @@ class PushRulesWorkerStore(
             prefilled_cache=push_rules_prefill,
         )
 
-        self._users_new_default_push_rules = (
-            hs.config.server.users_new_default_push_rules
-        )
-
     @abc.abstractmethod
     def get_max_push_rules_stream_id(self):
         """Get the position of the push rules stream.
@@ -145,9 +141,7 @@ class PushRulesWorkerStore(
 
         enabled_map = await self.get_push_rules_enabled_for_user(user_id)
 
-        use_new_defaults = user_id in self._users_new_default_push_rules
-
-        return _load_rules(rows, enabled_map, use_new_defaults)
+        return _load_rules(rows, enabled_map)
 
     @cached(max_entries=5000)
     async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
@@ -206,13 +200,7 @@ class PushRulesWorkerStore(
         enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
         for user_id, rules in results.items():
-            use_new_defaults = user_id in self._users_new_default_push_rules
-
-            results[user_id] = _load_rules(
-                rules,
-                enabled_map_by_user.get(user_id, {}),
-                use_new_defaults,
-            )
+            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
 
         return results
 
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 37468a5183..e2c27e594b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,12 +13,23 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
 from frozendict import frozendict
 
-from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
@@ -28,16 +39,14 @@ from synapse.storage.database import (
     make_in_list_sql_clause,
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
-from synapse.storage.relations import (
-    AggregationPaginationToken,
-    PaginationChunk,
-    RelationPaginationToken,
-)
-from synapse.types import JsonDict
-from synapse.util.caches.descriptors import cached
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -87,8 +96,8 @@ class RelationsWorkerStore(SQLBaseStore):
         aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
-        from_token: Optional[RelationPaginationToken] = None,
-        to_token: Optional[RelationPaginationToken] = None,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
     ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.
 
@@ -127,8 +136,10 @@ class RelationsWorkerStore(SQLBaseStore):
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
-            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
+            from_token=from_token.room_key.as_historical_tuple()
+            if from_token
+            else None,
+            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
             engine=self.database_engine,
         )
 
@@ -166,12 +177,27 @@ class RelationsWorkerStore(SQLBaseStore):
                 last_topo_id = row[1]
                 last_stream_id = row[2]
 
-            next_batch = None
+            # If there are more events, generate the next pagination key.
+            next_token = None
             if len(events) > limit and last_topo_id and last_stream_id:
-                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+                if from_token:
+                    next_token = from_token.copy_and_replace("room_key", next_key)
+                else:
+                    next_token = StreamToken(
+                        room_key=next_key,
+                        presence_key=0,
+                        typing_key=0,
+                        receipt_key=0,
+                        account_data_key=0,
+                        push_rules_key=0,
+                        to_device_key=0,
+                        device_list_key=0,
+                        groups_key=0,
+                    )
 
             return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+                chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
             )
 
         return await self.db_pool.runInteraction(
@@ -340,20 +366,24 @@ class RelationsWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    async def get_applicable_edit(
-        self, event_id: str, room_id: str
-    ) -> Optional[EventBase]:
+    def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
+    async def _get_applicable_edits(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, Optional[EventBase]]:
         """Get the most recent edit (if any) that has happened for the given
-        event.
+        events.
 
         Correctly handles checking whether edits were allowed to happen.
 
         Args:
-            event_id: The original event ID
-            room_id: The original event's room ID
+            event_ids: The original event IDs
 
         Returns:
-            The most recent edit, if any.
+            A map of the most recent edit for each event. If there are no edits,
+            the event will map to None.
         """
 
         # We only allow edits for `m.room.message` events that have the same sender
@@ -362,139 +392,238 @@ class RelationsWorkerStore(SQLBaseStore):
 
         # Fetches latest edit that has the same type and sender as the
         # original, and is an `m.room.message`.
-        sql = """
-            SELECT edit.event_id FROM events AS edit
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS original ON
-                original.event_id = relates_to_id
-                AND edit.type = original.type
-                AND edit.sender = original.sender
-            WHERE
-                relates_to_id = ?
-                AND relation_type = ?
-                AND edit.room_id = ?
-                AND edit.type = 'm.room.message'
-            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
-            LIMIT 1
-        """
+        if isinstance(self.database_engine, PostgresEngine):
+            # The `DISTINCT ON` clause will pick the *first* row it encounters,
+            # so ordering by origin server ts + event ID desc will ensure we get
+            # the latest edit.
+            sql = """
+                SELECT DISTINCT ON (original.event_id) original.event_id, edit.event_id FROM events AS edit
+                INNER JOIN event_relations USING (event_id)
+                INNER JOIN events AS original ON
+                    original.event_id = relates_to_id
+                    AND edit.type = original.type
+                    AND edit.sender = original.sender
+                    AND edit.room_id = original.room_id
+                WHERE
+                    %s
+                    AND relation_type = ?
+                    AND edit.type = 'm.room.message'
+                ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC
+            """
+        else:
+            # SQLite uses a simplified query which returns all edits for an
+            # original event. The results are then de-duplicated when turned into
+            # a dict. Due to the chosen ordering, the latest edit stomps on
+            # earlier edits.
+            sql = """
+                SELECT original.event_id, edit.event_id FROM events AS edit
+                INNER JOIN event_relations USING (event_id)
+                INNER JOIN events AS original ON
+                    original.event_id = relates_to_id
+                    AND edit.type = original.type
+                    AND edit.sender = original.sender
+                    AND edit.room_id = original.room_id
+                WHERE
+                    %s
+                    AND relation_type = ?
+                    AND edit.type = 'm.room.message'
+                ORDER by edit.origin_server_ts, edit.event_id
+            """
 
-        def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
-            txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
-            row = txn.fetchone()
-            if row:
-                return row[0]
-            return None
+        def _get_applicable_edits_txn(txn: LoggingTransaction) -> Dict[str, str]:
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", event_ids
+            )
+            args.append(RelationTypes.REPLACE)
 
-        edit_id = await self.db_pool.runInteraction(
-            "get_applicable_edit", _get_applicable_edit_txn
+            txn.execute(sql % (clause,), args)
+            return dict(cast(Iterable[Tuple[str, str]], txn.fetchall()))
+
+        edit_ids = await self.db_pool.runInteraction(
+            "get_applicable_edits", _get_applicable_edits_txn
         )
 
-        if not edit_id:
-            return None
+        edits = await self.get_events(edit_ids.values())  # type: ignore[attr-defined]
 
-        return await self.get_event(edit_id, allow_none=True)  # type: ignore[attr-defined]
+        # Map to the original event IDs to the edit events.
+        #
+        # There might not be an edit event due to there being no edits or
+        # due to the event not being known, either case is treated the same.
+        return {
+            original_event_id: edits.get(edit_ids.get(original_event_id))
+            for original_event_id in event_ids
+        }
 
     @cached()
-    async def get_thread_summary(
-        self, event_id: str, room_id: str
-    ) -> Tuple[int, Optional[EventBase]]:
+    def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
+    async def _get_thread_summaries(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
         """Get the number of threaded replies and the latest reply (if any) for the given event.
 
         Args:
-            event_id: Summarize the thread related to this event ID.
-            room_id: The room the event belongs to.
+            event_ids: Summarize the thread related to this event ID.
 
         Returns:
-            The number of items in the thread and the most recent response, if any.
+            A map of the thread summary each event. A missing event implies there
+            are no threaded replies.
+
+            Each summary includes the number of items in the thread and the most
+            recent response.
         """
 
-        def _get_thread_summary_txn(
+        def _get_thread_summaries_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, Optional[str]]:
-            # Fetch the latest event ID in the thread.
+        ) -> Tuple[Dict[str, int], Dict[str, str]]:
+            # Fetch the count of threaded events and the latest event ID.
             # TODO Should this only allow m.room.message events.
-            sql = """
-                SELECT event_id
-                FROM event_relations
-                INNER JOIN events USING (event_id)
-                WHERE
-                    relates_to_id = ?
-                    AND room_id = ?
-                    AND relation_type = ?
-                ORDER BY topological_ordering DESC, stream_ordering DESC
-                LIMIT 1
-            """
+            if isinstance(self.database_engine, PostgresEngine):
+                # The `DISTINCT ON` clause will pick the *first* row it encounters,
+                # so ordering by topologica ordering + stream ordering desc will
+                # ensure we get the latest event in the thread.
+                sql = """
+                    SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
+                    INNER JOIN event_relations USING (event_id)
+                    INNER JOIN events AS parent ON
+                        parent.event_id = relates_to_id
+                        AND parent.room_id = child.room_id
+                    WHERE
+                        %s
+                        AND relation_type = ?
+                    ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
+                """
+            else:
+                # SQLite uses a simplified query which returns all entries for a
+                # thread. The first result for each thread is chosen to and subsequent
+                # results for a thread are ignored.
+                sql = """
+                    SELECT parent.event_id, child.event_id FROM events AS child
+                    INNER JOIN event_relations USING (event_id)
+                    INNER JOIN events AS parent ON
+                        parent.event_id = relates_to_id
+                        AND parent.room_id = child.room_id
+                    WHERE
+                        %s
+                        AND relation_type = ?
+                    ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
+                """
+
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", event_ids
+            )
+            args.append(RelationTypes.THREAD)
 
-            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
-            row = txn.fetchone()
-            if row is None:
-                return 0, None
+            txn.execute(sql % (clause,), args)
+            latest_event_ids = {}
+            for parent_event_id, child_event_id in txn:
+                # Only consider the latest threaded reply (by topological ordering).
+                if parent_event_id not in latest_event_ids:
+                    latest_event_ids[parent_event_id] = child_event_id
 
-            latest_event_id = row[0]
+            # If no threads were found, bail.
+            if not latest_event_ids:
+                return {}, latest_event_ids
 
             # Fetch the number of threaded replies.
             sql = """
-                SELECT COUNT(event_id)
-                FROM event_relations
-                INNER JOIN events USING (event_id)
+                SELECT parent.event_id, COUNT(child.event_id) FROM events AS child
+                INNER JOIN event_relations USING (event_id)
+                INNER JOIN events AS parent ON
+                    parent.event_id = relates_to_id
+                    AND parent.room_id = child.room_id
                 WHERE
-                    relates_to_id = ?
-                    AND room_id = ?
+                    %s
                     AND relation_type = ?
+                GROUP BY parent.event_id
             """
-            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
-            count = cast(Tuple[int], txn.fetchone())[0]
 
-            return count, latest_event_id
+            # Regenerate the arguments since only threads found above could
+            # possibly have any replies.
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", latest_event_ids.keys()
+            )
+            args.append(RelationTypes.THREAD)
 
-        count, latest_event_id = await self.db_pool.runInteraction(
-            "get_thread_summary", _get_thread_summary_txn
+            txn.execute(sql % (clause,), args)
+            counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
+
+            return counts, latest_event_ids
+
+        counts, latest_event_ids = await self.db_pool.runInteraction(
+            "get_thread_summaries", _get_thread_summaries_txn
         )
 
-        latest_event = None
-        if latest_event_id:
-            latest_event = await self.get_event(latest_event_id, allow_none=True)  # type: ignore[attr-defined]
+        latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
-        return count, latest_event
+        # Map to the event IDs to the thread summary.
+        #
+        # There might not be a summary due to there not being a thread or
+        # due to the latest event not being known, either case is treated the same.
+        summaries = {}
+        for parent_event_id, latest_event_id in latest_event_ids.items():
+            latest_event = latest_events.get(latest_event_id)
+
+            summary = None
+            if latest_event:
+                summary = (counts[parent_event_id], latest_event)
+            summaries[parent_event_id] = summary
+
+        return summaries
 
     @cached()
-    async def get_thread_participated(
-        self, event_id: str, room_id: str, user_id: str
-    ) -> bool:
-        """Get whether the requesting user participated in a thread.
+    def get_thread_participated(self, event_id: str, user_id: str) -> bool:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
+    async def _get_threads_participated(
+        self, event_ids: Collection[str], user_id: str
+    ) -> Dict[str, bool]:
+        """Get whether the requesting user participated in the given threads.
 
-        This is separate from get_thread_summary since that can be cached across
-        all users while this value is specific to the requeser.
+        This is separate from get_thread_summaries since that can be cached across
+        all users while this value is specific to the requester.
 
         Args:
-            event_id: The thread related to this event ID.
-            room_id: The room the event belongs to.
+            event_ids: The thread related to these event IDs.
             user_id: The user requesting the summary.
 
         Returns:
-            True if the requesting user participated in the thread, otherwise false.
+            A map of event ID to a boolean which represents if the requesting
+            user participated in that event's thread, otherwise false.
         """
 
-        def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+        def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
             # Fetch whether the requester has participated or not.
             sql = """
-                SELECT 1
-                FROM event_relations
-                INNER JOIN events USING (event_id)
+                SELECT DISTINCT relates_to_id
+                FROM events AS child
+                INNER JOIN event_relations USING (event_id)
+                INNER JOIN events AS parent ON
+                    parent.event_id = relates_to_id
+                    AND parent.room_id = child.room_id
                 WHERE
-                    relates_to_id = ?
-                    AND room_id = ?
+                    %s
                     AND relation_type = ?
-                    AND sender = ?
+                    AND child.sender = ?
             """
 
-            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
-            return bool(txn.fetchone())
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", event_ids
+            )
+            args.extend((RelationTypes.THREAD, user_id))
 
-        return await self.db_pool.runInteraction(
+            txn.execute(sql % (clause,), args)
+            return {row[0] for row in txn.fetchall()}
+
+        participated_threads = await self.db_pool.runInteraction(
             "get_thread_summary", _get_thread_summary_txn
         )
 
+        return {event_id: event_id in participated_threads for event_id in event_ids}
+
     async def events_have_relations(
         self,
         parent_ids: List[str],
@@ -612,9 +741,6 @@ class RelationsWorkerStore(SQLBaseStore):
             The bundled aggregations for an event, if bundled aggregations are
             enabled and the event can have bundled aggregations.
         """
-        # State events and redacted events do not get bundled aggregations.
-        if event.is_state() or event.internal_metadata.is_redacted():
-            return None
 
         # Do not bundle aggregations for an event which represents an edit or an
         # annotation. It does not make sense for them to have related events.
@@ -634,43 +760,21 @@ class RelationsWorkerStore(SQLBaseStore):
 
         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations.annotations = annotations.to_dict()
+            aggregations.annotations = await annotations.to_dict(
+                cast("DataStore", self)
+            )
 
         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations.references = references.to_dict()
-
-        edit = None
-        if event.type == EventTypes.Message:
-            edit = await self.get_applicable_edit(event_id, room_id)
-
-        if edit:
-            aggregations.replace = edit
-
-        # If this event is the start of a thread, include a summary of the replies.
-        if self._msc3440_enabled:
-            thread_count, latest_thread_event = await self.get_thread_summary(
-                event_id, room_id
-            )
-            participated = await self.get_thread_participated(
-                event_id, room_id, user_id
-            )
-            if latest_thread_event:
-                aggregations.thread = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    count=thread_count,
-                    current_user_participated=participated,
-                )
+            aggregations.references = await references.to_dict(cast("DataStore", self))
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
 
     async def get_bundled_aggregations(
-        self,
-        events: Iterable[EventBase],
-        user_id: str,
+        self, events: Iterable[EventBase], user_id: str
     ) -> Dict[str, BundledAggregations]:
         """Generate bundled aggregations for events.
 
@@ -682,14 +786,59 @@ class RelationsWorkerStore(SQLBaseStore):
             A map of event ID to the bundled aggregation for the event. Not all
             events may have bundled aggregations in the results.
         """
+        # The already processed event IDs. Tracked separately from the result
+        # since the result omits events which do not have bundled aggregations.
+        seen_event_ids = set()
 
-        # TODO Parallelize.
-        results = {}
+        # State events and redacted events do not get bundled aggregations.
+        events = [
+            event
+            for event in events
+            if not event.is_state() and not event.internal_metadata.is_redacted()
+        ]
+
+        # event ID -> bundled aggregation in non-serialized form.
+        results: Dict[str, BundledAggregations] = {}
+
+        # Fetch other relations per event.
         for event in events:
+            # De-duplicate events by ID to handle the same event requested multiple
+            # times. The caches that _get_bundled_aggregation_for_event use should
+            # capture this, but best to reduce work.
+            if event.event_id in seen_event_ids:
+                continue
+            seen_event_ids.add(event.event_id)
+
             event_result = await self._get_bundled_aggregation_for_event(event, user_id)
             if event_result:
                 results[event.event_id] = event_result
 
+        # Fetch any edits.
+        edits = await self._get_applicable_edits(seen_event_ids)
+        for event_id, edit in edits.items():
+            results.setdefault(event_id, BundledAggregations()).replace = edit
+
+        # Fetch thread summaries.
+        if self._msc3440_enabled:
+            summaries = await self._get_thread_summaries(seen_event_ids)
+            # Only fetch participated for a limited selection based on what had
+            # summaries.
+            participated = await self._get_threads_participated(
+                summaries.keys(), user_id
+            )
+            for event_id, summary in summaries.items():
+                if summary:
+                    thread_count, latest_thread_event = summary
+                    results.setdefault(
+                        event_id, BundledAggregations()
+                    ).thread = _ThreadAggregation(
+                        latest_event=latest_thread_event,
+                        count=thread_count,
+                        # If there's a thread summary it must also exist in the
+                        # participated dictionary.
+                        current_user_participated=participated[event_id],
+                    )
+
         return results