diff options
author | Šimon Brandner <simon.bra.ag@gmail.com> | 2022-05-04 17:59:22 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-04 11:59:22 -0400 |
commit | 116a4c8340b729ffde43be33df24d417384cb28b (patch) | |
tree | b74756a823802110beb1e0b90451973c886d270c /synapse/storage | |
parent | Disable device name lookup over federation by default (#12616) (diff) | |
download | synapse-116a4c8340b729ffde43be33df24d417384cb28b.tar.xz |
Implement changes to MSC2285 (hidden read receipts) (#12168)
* Changes hidden read receipts to be a separate receipt type (instead of a field on `m.read`). * Updates the `/receipts` endpoint to accept `m.fully_read`.
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/receipts.py | 142 |
1 files changed, 110 insertions, 32 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 7d96f4feda..9e3d838eab 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -144,43 +144,77 @@ class ReceiptsWorkerStore(SQLBaseStore): desc="get_receipts_for_room", ) - @cached() async def get_last_receipt_event_id_for_user( - self, user_id: str, room_id: str, receipt_type: str + self, user_id: str, room_id: str, receipt_types: Iterable[str] ) -> Optional[str]: """ - Fetch the event ID for the latest receipt in a room with the given receipt type. + Fetch the event ID for the latest receipt in a room with one of the given receipt types. Args: user_id: The user to fetch receipts for. room_id: The room ID to fetch the receipt for. - receipt_type: The receipt type to fetch. + receipt_type: The receipt types to fetch. Earlier receipt types + are given priority if multiple receipts point to the same event. Returns: - The event ID of the latest receipt, if one exists; otherwise `None`. + The latest receipt, if one exists. """ - return await self.db_pool.simple_select_one_onecol( - table="receipts_linearized", - keyvalues={ - "room_id": room_id, - "receipt_type": receipt_type, - "user_id": user_id, - }, - retcol="event_id", - desc="get_own_receipt_for_user", - allow_none=True, - ) + latest_event_id: Optional[str] = None + latest_stream_ordering = 0 + for receipt_type in receipt_types: + result = await self._get_last_receipt_event_id_for_user( + user_id, room_id, receipt_type + ) + if result is None: + continue + event_id, stream_ordering = result + + if latest_event_id is None or latest_stream_ordering < stream_ordering: + latest_event_id = event_id + latest_stream_ordering = stream_ordering + + return latest_event_id @cached() + async def _get_last_receipt_event_id_for_user( + self, user_id: str, room_id: str, receipt_type: str + ) -> Optional[Tuple[str, int]]: + """ + Fetch the event ID and stream ordering for the latest receipt. + + Args: + user_id: The user to fetch receipts for. + room_id: The room ID to fetch the receipt for. + receipt_type: The receipt type to fetch. + + Returns: + The event ID and stream ordering of the latest receipt, if one exists; + otherwise `None`. + """ + sql = """ + SELECT event_id, stream_ordering + FROM receipts_linearized + INNER JOIN events USING (room_id, event_id) + WHERE user_id = ? + AND room_id = ? + AND receipt_type = ? + """ + + def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]: + txn.execute(sql, (user_id, room_id, receipt_type)) + return cast(Optional[Tuple[str, int]], txn.fetchone()) + + return await self.db_pool.runInteraction("get_own_receipt_for_user", f) + async def get_receipts_for_user( - self, user_id: str, receipt_type: str + self, user_id: str, receipt_types: Iterable[str] ) -> Dict[str, str]: """ Fetch the event IDs for the latest receipts sent by the given user. Args: user_id: The user to fetch receipts for. - receipt_type: The receipt type to fetch. + receipt_types: The receipt types to check. Returns: A map of room ID to the event ID of the latest receipt for that room. @@ -188,16 +222,48 @@ class ReceiptsWorkerStore(SQLBaseStore): If the user has not sent a receipt to a room then it will not appear in the returned dictionary. """ - rows = await self.db_pool.simple_select_list( - table="receipts_linearized", - keyvalues={"user_id": user_id, "receipt_type": receipt_type}, - retcols=("room_id", "event_id"), - desc="get_receipts_for_user", + results = await self.get_receipts_for_user_with_orderings( + user_id, receipt_types ) - return {row["room_id"]: row["event_id"] for row in rows} + # Reduce the result to room ID -> event ID. + return { + room_id: room_result["event_id"] for room_id, room_result in results.items() + } async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_types: Iterable[str] + ) -> JsonDict: + """ + Fetch receipts for all rooms that the given user is joined to. + + Args: + user_id: The user to fetch receipts for. + receipt_types: The receipt types to fetch. Earlier receipt types + are given priority if multiple receipts point to the same event. + + Returns: + A map of room ID to the latest receipt (for the given types). + """ + results: JsonDict = {} + for receipt_type in receipt_types: + partial_result = await self._get_receipts_for_user_with_orderings( + user_id, receipt_type + ) + for room_id, room_result in partial_result.items(): + # If the room has not yet been seen, or the receipt is newer, + # use it. + if ( + room_id not in results + or results[room_id]["stream_ordering"] + < room_result["stream_ordering"] + ): + results[room_id] = room_result + + return results + + @cached() + async def _get_receipts_for_user_with_orderings( self, user_id: str, receipt_type: str ) -> JsonDict: """ @@ -220,8 +286,9 @@ class ReceiptsWorkerStore(SQLBaseStore): " WHERE rl.room_id = e.room_id" " AND rl.event_id = e.event_id" " AND user_id = ?" + " AND receipt_type = ?" ) - txn.execute(sql, (user_id,)) + txn.execute(sql, (user_id, receipt_type)) return cast(List[Tuple[str, str, int, int]], txn.fetchall()) rows = await self.db_pool.runInteraction( @@ -552,9 +619,9 @@ class ReceiptsWorkerStore(SQLBaseStore): def invalidate_caches_for_receipt( self, room_id: str, receipt_type: str, user_id: str ) -> None: - self.get_receipts_for_user.invalidate((user_id, receipt_type)) + self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) - self.get_last_receipt_event_id_for_user.invalidate( + self._get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) @@ -590,8 +657,8 @@ class ReceiptsWorkerStore(SQLBaseStore): """Inserts a receipt into the database if it's newer than the current one. Returns: - None if the RR is older than the current RR - otherwise, the rx timestamp of the event that the RR corresponds to + None if the receipt is older than the current receipt + otherwise, the rx timestamp of the event that the receipt corresponds to (or 0 if the event is unknown) """ assert self._can_write_to_receipts @@ -612,7 +679,7 @@ class ReceiptsWorkerStore(SQLBaseStore): if stream_ordering is not None: sql = ( "SELECT stream_ordering, event_id FROM events" - " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " INNER JOIN receipts_linearized AS r USING (event_id, room_id)" " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" ) txn.execute(sql, (room_id, receipt_type, user_id)) @@ -653,7 +720,10 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) - if receipt_type == ReceiptTypes.READ and stream_ordering is not None: + if ( + receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) + and stream_ordering is not None + ): self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) @@ -672,6 +742,10 @@ class ReceiptsWorkerStore(SQLBaseStore): Automatically does conversion between linearized and graph representations. + + Returns: + The new receipts stream ID and token, if the receipt is newer than + what was previously persisted. None, otherwise. """ assert self._can_write_to_receipts @@ -719,6 +793,7 @@ class ReceiptsWorkerStore(SQLBaseStore): stream_id=stream_id, ) + # If the receipt was older than the currently persisted one, nothing to do. if event_ts is None: return None @@ -774,7 +849,10 @@ class ReceiptsWorkerStore(SQLBaseStore): receipt_type, user_id, ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) + txn.call_after( + self._get_receipts_for_user_with_orderings.invalidate, + (user_id, receipt_type), + ) # FIXME: This shouldn't invalidate the whole cache txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) |