diff --git a/changelog.d/16805.misc b/changelog.d/16805.misc
new file mode 100644
index 0000000000..0b54ab0f74
--- /dev/null
+++ b/changelog.d/16805.misc
@@ -0,0 +1 @@
+Optimize query for fetching to-device messages in `/sync`.
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 40477b9da0..fa47b471e8 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -245,33 +245,74 @@ class DeviceInboxWorkerStore(SQLBaseStore):
* The last-processed stream ID. Subsequent calls of this function with the
same device should pass this value as 'from_stream_id'.
"""
- (
- user_id_device_id_to_messages,
- last_processed_stream_id,
- ) = await self._get_device_messages(
- user_ids=[user_id],
- device_id=device_id,
- from_stream_id=from_stream_id,
- to_stream_id=to_stream_id,
- limit=limit,
- )
-
- if not user_id_device_id_to_messages:
+ if not self._device_inbox_stream_cache.has_entity_changed(
+ user_id, from_stream_id
+ ):
# There were no messages!
return [], to_stream_id
- # Extract the messages, no need to return the user and device ID again
- to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
+ def get_device_messages_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[JsonDict], int]:
+ sql = """
+ SELECT stream_id, message_json FROM device_inbox
+ WHERE user_id = ? AND device_id = ?
+ AND ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit))
+
+ # Create and fill a dictionary of (user ID, device ID) -> list of messages
+ # intended for each device.
+ last_processed_stream_pos = to_stream_id
+ to_device_messages: List[JsonDict] = []
+ rowcount = 0
+ for row in txn:
+ rowcount += 1
+
+ last_processed_stream_pos = row[0]
+ message_dict = db_to_json(row[1])
+
+ # Store the device details
+ to_device_messages.append(message_dict)
- return to_device_messages, last_processed_stream_id
+ # start a new span for each message, so that we can tag each separately
+ with start_active_span("get_to_device_message"):
+ set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
+ set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
+ set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
+ set_tag(
+ SynapseTags.TO_DEVICE_MSGID,
+ message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
+ )
+
+ if rowcount == limit:
+ # We ended up bumping up against the message limit. There may be more messages
+ # to retrieve. Return what we have, as well as the last stream position that
+ # was processed.
+ #
+ # The caller is expected to set this as the lower (exclusive) bound
+ # for the next query of this device.
+ return to_device_messages, last_processed_stream_pos
+
+ # The limit was not reached, thus we know that recipient_device_to_messages
+ # contains all to-device messages for the given device and stream id range.
+ #
+ # We return to_stream_id, which the caller should then provide as the lower
+ # (exclusive) bound on the next query of this device.
+ return to_device_messages, to_stream_id
+
+ return await self.db_pool.runInteraction(
+ "get_messages_for_device", get_device_messages_txn
+ )
async def _get_device_messages(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
- device_id: Optional[str] = None,
- limit: Optional[int] = None,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve pending to-device messages for a collection of user devices.
@@ -291,11 +332,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_ids: The user IDs to filter device messages by.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
- device_id: A device ID to query to-device messages for. If not provided, to-device
- messages from all device IDs for the given user IDs will be queried. May not be
- provided if `user_ids` contains more than one entry.
- limit: The maximum number of to-device messages to return. Can only be used when
- passing a single user ID / device ID tuple.
+
Returns:
A tuple containing:
@@ -308,30 +345,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
logger.warning("No users provided upon querying for device IDs")
return {}, to_stream_id
- # Prevent a query for one user's device also retrieving another user's device with
- # the same device ID (device IDs are not unique across users).
- if len(user_ids) > 1 and device_id is not None:
- raise AssertionError(
- "Programming error: 'device_id' cannot be supplied to "
- "_get_device_messages when >1 user_id has been provided"
- )
-
- # A limit can only be applied when querying for a single user ID / device ID tuple.
- # See the docstring of this function for more details.
- if limit is not None and device_id is None:
- raise AssertionError(
- "Programming error: _get_device_messages was passed 'limit' "
- "without a specific user_id/device_id"
- )
-
user_ids_to_query: Set[str] = set()
- device_ids_to_query: Set[str] = set()
-
- # Note that a device ID could be an empty str
- if device_id is not None:
- # If a device ID was passed, use it to filter results.
- # Otherwise, device IDs will be derived from the given collection of user IDs.
- device_ids_to_query.add(device_id)
# Determine which users have devices with pending messages
for user_id in user_ids:
@@ -355,20 +369,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# hidden devices should not receive to-device messages.
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
- if not device_ids_to_query:
- user_device_dicts = cast(
- List[Tuple[str]],
- self.db_pool.simple_select_many_txn(
- txn,
- table="devices",
- column="user_id",
- iterable=user_ids_to_query,
- keyvalues={"hidden": False},
- retcols=("device_id",),
- ),
- )
- device_ids_to_query.update({row[0] for row in user_device_dicts})
+ user_device_dicts = cast(
+ List[Tuple[str]],
+ self.db_pool.simple_select_many_txn(
+ txn,
+ table="devices",
+ column="user_id",
+ iterable=user_ids_to_query,
+ keyvalues={"hidden": False},
+ retcols=("device_id",),
+ ),
+ )
+
+ device_ids_to_query = {row[0] for row in user_device_dicts}
if not device_ids_to_query:
# We've ended up with no devices to query.
@@ -400,22 +414,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
to_stream_id,
)
- # If a limit was provided, limit the data retrieved from the database
- if limit is not None:
- sql += "LIMIT ?"
- sql_args += (limit,)
-
txn.execute(sql, sql_args)
# Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device.
- last_processed_stream_pos = to_stream_id
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
rowcount = 0
for row in txn:
rowcount += 1
- last_processed_stream_pos = row[0]
recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])
@@ -436,18 +443,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
)
- if limit is not None and rowcount == limit:
- # We ended up bumping up against the message limit. There may be more messages
- # to retrieve. Return what we have, as well as the last stream position that
- # was processed.
- #
- # The caller is expected to set this as the lower (exclusive) bound
- # for the next query of this device.
- return recipient_device_to_messages, last_processed_stream_pos
-
- # The limit was not reached, thus we know that recipient_device_to_messages
- # contains all to-device messages for the given device and stream id range.
- #
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return recipient_device_to_messages, to_stream_id
|