diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 7d96f4feda..9e3d838eab 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -144,43 +144,77 @@ class ReceiptsWorkerStore(SQLBaseStore):
desc="get_receipts_for_room",
)
- @cached()
async def get_last_receipt_event_id_for_user(
- self, user_id: str, room_id: str, receipt_type: str
+ self, user_id: str, room_id: str, receipt_types: Iterable[str]
) -> Optional[str]:
"""
- Fetch the event ID for the latest receipt in a room with the given receipt type.
+ Fetch the event ID for the latest receipt in a room with one of the given receipt types.
Args:
user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for.
- receipt_type: The receipt type to fetch.
+ receipt_type: The receipt types to fetch. Earlier receipt types
+ are given priority if multiple receipts point to the same event.
Returns:
- The event ID of the latest receipt, if one exists; otherwise `None`.
+ The latest receipt, if one exists.
"""
- return await self.db_pool.simple_select_one_onecol(
- table="receipts_linearized",
- keyvalues={
- "room_id": room_id,
- "receipt_type": receipt_type,
- "user_id": user_id,
- },
- retcol="event_id",
- desc="get_own_receipt_for_user",
- allow_none=True,
- )
+ latest_event_id: Optional[str] = None
+ latest_stream_ordering = 0
+ for receipt_type in receipt_types:
+ result = await self._get_last_receipt_event_id_for_user(
+ user_id, room_id, receipt_type
+ )
+ if result is None:
+ continue
+ event_id, stream_ordering = result
+
+ if latest_event_id is None or latest_stream_ordering < stream_ordering:
+ latest_event_id = event_id
+ latest_stream_ordering = stream_ordering
+
+ return latest_event_id
@cached()
+ async def _get_last_receipt_event_id_for_user(
+ self, user_id: str, room_id: str, receipt_type: str
+ ) -> Optional[Tuple[str, int]]:
+ """
+ Fetch the event ID and stream ordering for the latest receipt.
+
+ Args:
+ user_id: The user to fetch receipts for.
+ room_id: The room ID to fetch the receipt for.
+ receipt_type: The receipt type to fetch.
+
+ Returns:
+ The event ID and stream ordering of the latest receipt, if one exists;
+ otherwise `None`.
+ """
+ sql = """
+ SELECT event_id, stream_ordering
+ FROM receipts_linearized
+ INNER JOIN events USING (room_id, event_id)
+ WHERE user_id = ?
+ AND room_id = ?
+ AND receipt_type = ?
+ """
+
+ def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
+ txn.execute(sql, (user_id, room_id, receipt_type))
+ return cast(Optional[Tuple[str, int]], txn.fetchone())
+
+ return await self.db_pool.runInteraction("get_own_receipt_for_user", f)
+
async def get_receipts_for_user(
- self, user_id: str, receipt_type: str
+ self, user_id: str, receipt_types: Iterable[str]
) -> Dict[str, str]:
"""
Fetch the event IDs for the latest receipts sent by the given user.
Args:
user_id: The user to fetch receipts for.
- receipt_type: The receipt type to fetch.
+ receipt_types: The receipt types to check.
Returns:
A map of room ID to the event ID of the latest receipt for that room.
@@ -188,16 +222,48 @@ class ReceiptsWorkerStore(SQLBaseStore):
If the user has not sent a receipt to a room then it will not appear
in the returned dictionary.
"""
- rows = await self.db_pool.simple_select_list(
- table="receipts_linearized",
- keyvalues={"user_id": user_id, "receipt_type": receipt_type},
- retcols=("room_id", "event_id"),
- desc="get_receipts_for_user",
+ results = await self.get_receipts_for_user_with_orderings(
+ user_id, receipt_types
)
- return {row["room_id"]: row["event_id"] for row in rows}
+ # Reduce the result to room ID -> event ID.
+ return {
+ room_id: room_result["event_id"] for room_id, room_result in results.items()
+ }
async def get_receipts_for_user_with_orderings(
+ self, user_id: str, receipt_types: Iterable[str]
+ ) -> JsonDict:
+ """
+ Fetch receipts for all rooms that the given user is joined to.
+
+ Args:
+ user_id: The user to fetch receipts for.
+ receipt_types: The receipt types to fetch. Earlier receipt types
+ are given priority if multiple receipts point to the same event.
+
+ Returns:
+ A map of room ID to the latest receipt (for the given types).
+ """
+ results: JsonDict = {}
+ for receipt_type in receipt_types:
+ partial_result = await self._get_receipts_for_user_with_orderings(
+ user_id, receipt_type
+ )
+ for room_id, room_result in partial_result.items():
+ # If the room has not yet been seen, or the receipt is newer,
+ # use it.
+ if (
+ room_id not in results
+ or results[room_id]["stream_ordering"]
+ < room_result["stream_ordering"]
+ ):
+ results[room_id] = room_result
+
+ return results
+
+ @cached()
+ async def _get_receipts_for_user_with_orderings(
self, user_id: str, receipt_type: str
) -> JsonDict:
"""
@@ -220,8 +286,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
" WHERE rl.room_id = e.room_id"
" AND rl.event_id = e.event_id"
" AND user_id = ?"
+ " AND receipt_type = ?"
)
- txn.execute(sql, (user_id,))
+ txn.execute(sql, (user_id, receipt_type))
return cast(List[Tuple[str, str, int, int]], txn.fetchall())
rows = await self.db_pool.runInteraction(
@@ -552,9 +619,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
def invalidate_caches_for_receipt(
self, room_id: str, receipt_type: str, user_id: str
) -> None:
- self.get_receipts_for_user.invalidate((user_id, receipt_type))
+ self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,))
- self.get_last_receipt_event_id_for_user.invalidate(
+ self._get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
@@ -590,8 +657,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""Inserts a receipt into the database if it's newer than the current one.
Returns:
- None if the RR is older than the current RR
- otherwise, the rx timestamp of the event that the RR corresponds to
+ None if the receipt is older than the current receipt
+ otherwise, the rx timestamp of the event that the receipt corresponds to
(or 0 if the event is unknown)
"""
assert self._can_write_to_receipts
@@ -612,7 +679,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
if stream_ordering is not None:
sql = (
"SELECT stream_ordering, event_id FROM events"
- " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+ " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
)
txn.execute(sql, (room_id, receipt_type, user_id))
@@ -653,7 +720,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
lock=False,
)
- if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
+ if (
+ receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
+ and stream_ordering is not None
+ ):
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
@@ -672,6 +742,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
Automatically does conversion between linearized and graph
representations.
+
+ Returns:
+ The new receipts stream ID and token, if the receipt is newer than
+ what was previously persisted. None, otherwise.
"""
assert self._can_write_to_receipts
@@ -719,6 +793,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
stream_id=stream_id,
)
+ # If the receipt was older than the currently persisted one, nothing to do.
if event_ts is None:
return None
@@ -774,7 +849,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type,
user_id,
)
- txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
+ txn.call_after(
+ self._get_receipts_for_user_with_orderings.invalidate,
+ (user_id, receipt_type),
+ )
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
|