summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/receipts.py142
1 files changed, 110 insertions, 32 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 7d96f4feda..9e3d838eab 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -144,43 +144,77 @@ class ReceiptsWorkerStore(SQLBaseStore):
             desc="get_receipts_for_room",
         )
 
-    @cached()
     async def get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_type: str
+        self, user_id: str, room_id: str, receipt_types: Iterable[str]
     ) -> Optional[str]:
         """
-        Fetch the event ID for the latest receipt in a room with the given receipt type.
+        Fetch the event ID for the latest receipt in a room with one of the given receipt types.
 
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt type to fetch.
+            receipt_type: The receipt types to fetch. Earlier receipt types
+                are given priority if multiple receipts point to the same event.
 
         Returns:
-            The event ID of the latest receipt, if one exists; otherwise `None`.
+            The latest receipt, if one exists.
         """
-        return await self.db_pool.simple_select_one_onecol(
-            table="receipts_linearized",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
-            },
-            retcol="event_id",
-            desc="get_own_receipt_for_user",
-            allow_none=True,
-        )
+        latest_event_id: Optional[str] = None
+        latest_stream_ordering = 0
+        for receipt_type in receipt_types:
+            result = await self._get_last_receipt_event_id_for_user(
+                user_id, room_id, receipt_type
+            )
+            if result is None:
+                continue
+            event_id, stream_ordering = result
+
+            if latest_event_id is None or latest_stream_ordering < stream_ordering:
+                latest_event_id = event_id
+                latest_stream_ordering = stream_ordering
+
+        return latest_event_id
 
     @cached()
+    async def _get_last_receipt_event_id_for_user(
+        self, user_id: str, room_id: str, receipt_type: str
+    ) -> Optional[Tuple[str, int]]:
+        """
+        Fetch the event ID and stream ordering for the latest receipt.
+
+        Args:
+            user_id: The user to fetch receipts for.
+            room_id: The room ID to fetch the receipt for.
+            receipt_type: The receipt type to fetch.
+
+        Returns:
+            The event ID and stream ordering of the latest receipt, if one exists;
+            otherwise `None`.
+        """
+        sql = """
+            SELECT event_id, stream_ordering
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE user_id = ?
+            AND room_id = ?
+            AND receipt_type = ?
+        """
+
+        def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
+            txn.execute(sql, (user_id, room_id, receipt_type))
+            return cast(Optional[Tuple[str, int]], txn.fetchone())
+
+        return await self.db_pool.runInteraction("get_own_receipt_for_user", f)
+
     async def get_receipts_for_user(
-        self, user_id: str, receipt_type: str
+        self, user_id: str, receipt_types: Iterable[str]
     ) -> Dict[str, str]:
         """
         Fetch the event IDs for the latest receipts sent by the given user.
 
         Args:
             user_id: The user to fetch receipts for.
-            receipt_type: The receipt type to fetch.
+            receipt_types: The receipt types to check.
 
         Returns:
             A map of room ID to the event ID of the latest receipt for that room.
@@ -188,16 +222,48 @@ class ReceiptsWorkerStore(SQLBaseStore):
             If the user has not sent a receipt to a room then it will not appear
             in the returned dictionary.
         """
-        rows = await self.db_pool.simple_select_list(
-            table="receipts_linearized",
-            keyvalues={"user_id": user_id, "receipt_type": receipt_type},
-            retcols=("room_id", "event_id"),
-            desc="get_receipts_for_user",
+        results = await self.get_receipts_for_user_with_orderings(
+            user_id, receipt_types
         )
 
-        return {row["room_id"]: row["event_id"] for row in rows}
+        # Reduce the result to room ID -> event ID.
+        return {
+            room_id: room_result["event_id"] for room_id, room_result in results.items()
+        }
 
     async def get_receipts_for_user_with_orderings(
+        self, user_id: str, receipt_types: Iterable[str]
+    ) -> JsonDict:
+        """
+        Fetch receipts for all rooms that the given user is joined to.
+
+        Args:
+            user_id: The user to fetch receipts for.
+            receipt_types: The receipt types to fetch. Earlier receipt types
+                are given priority if multiple receipts point to the same event.
+
+        Returns:
+            A map of room ID to the latest receipt (for the given types).
+        """
+        results: JsonDict = {}
+        for receipt_type in receipt_types:
+            partial_result = await self._get_receipts_for_user_with_orderings(
+                user_id, receipt_type
+            )
+            for room_id, room_result in partial_result.items():
+                # If the room has not yet been seen, or the receipt is newer,
+                # use it.
+                if (
+                    room_id not in results
+                    or results[room_id]["stream_ordering"]
+                    < room_result["stream_ordering"]
+                ):
+                    results[room_id] = room_result
+
+        return results
+
+    @cached()
+    async def _get_receipts_for_user_with_orderings(
         self, user_id: str, receipt_type: str
     ) -> JsonDict:
         """
@@ -220,8 +286,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 " WHERE rl.room_id = e.room_id"
                 " AND rl.event_id = e.event_id"
                 " AND user_id = ?"
+                " AND receipt_type = ?"
             )
-            txn.execute(sql, (user_id,))
+            txn.execute(sql, (user_id, receipt_type))
             return cast(List[Tuple[str, str, int, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
@@ -552,9 +619,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
     def invalidate_caches_for_receipt(
         self, room_id: str, receipt_type: str, user_id: str
     ) -> None:
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
+        self._get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
@@ -590,8 +657,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """Inserts a receipt into the database if it's newer than the current one.
 
         Returns:
-            None if the RR is older than the current RR
-            otherwise, the rx timestamp of the event that the RR corresponds to
+            None if the receipt is older than the current receipt
+            otherwise, the rx timestamp of the event that the receipt corresponds to
                 (or 0 if the event is unknown)
         """
         assert self._can_write_to_receipts
@@ -612,7 +679,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if stream_ordering is not None:
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
-                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
                 " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
             )
             txn.execute(sql, (room_id, receipt_type, user_id))
@@ -653,7 +720,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             lock=False,
         )
 
-        if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
+        if (
+            receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
+            and stream_ordering is not None
+        ):
             self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
@@ -672,6 +742,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         Automatically does conversion between linearized and graph
         representations.
+
+        Returns:
+            The new receipts stream ID and token, if the receipt is newer than
+            what was previously persisted. None, otherwise.
         """
         assert self._can_write_to_receipts
 
@@ -719,6 +793,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 stream_id=stream_id,
             )
 
+        # If the receipt was older than the currently persisted one, nothing to do.
         if event_ts is None:
             return None
 
@@ -774,7 +849,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type,
             user_id,
         )
-        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
+        txn.call_after(
+            self._get_receipts_for_user_with_orderings.invalidate,
+            (user_id, receipt_type),
+        )
         # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))