summary refs log tree commit diff
path: root/synapse/storage/databases/main/receipts.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/receipts.py')
-rw-r--r--synapse/storage/databases/main/receipts.py74
1 files changed, 41 insertions, 33 deletions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index b6106affa6..bec6d60577 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -118,7 +118,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         return self._receipts_id_gen.get_current_token()
 
     async def get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_types: Iterable[str]
+        self, user_id: str, room_id: str, receipt_types: Collection[str]
     ) -> Optional[str]:
         """
         Fetch the event ID for the latest receipt in a room with one of the given receipt types.
@@ -126,58 +126,63 @@ class ReceiptsWorkerStore(SQLBaseStore):
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt types to fetch. Earlier receipt types
-                are given priority if multiple receipts point to the same event.
+            receipt_type: The receipt types to fetch.
 
         Returns:
             The latest receipt, if one exists.
         """
-        latest_event_id: Optional[str] = None
-        latest_stream_ordering = 0
-        for receipt_type in receipt_types:
-            result = await self._get_last_receipt_event_id_for_user(
-                user_id, room_id, receipt_type
-            )
-            if result is None:
-                continue
-            event_id, stream_ordering = result
-
-            if latest_event_id is None or latest_stream_ordering < stream_ordering:
-                latest_event_id = event_id
-                latest_stream_ordering = stream_ordering
+        result = await self.db_pool.runInteraction(
+            "get_last_receipt_event_id_for_user",
+            self.get_last_receipt_for_user_txn,
+            user_id,
+            room_id,
+            receipt_types,
+        )
+        if not result:
+            return None
 
-        return latest_event_id
+        event_id, _ = result
+        return event_id
 
-    @cached()
-    async def _get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_type: str
+    def get_last_receipt_for_user_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        room_id: str,
+        receipt_types: Collection[str],
     ) -> Optional[Tuple[str, int]]:
         """
-        Fetch the event ID and stream ordering for the latest receipt.
+        Fetch the event ID and stream_ordering for the latest receipt in a room
+        with one of the given receipt types.
 
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt type to fetch.
+            receipt_type: The receipt types to fetch.
 
         Returns:
-            The event ID and stream ordering of the latest receipt, if one exists;
-            otherwise `None`.
+            The latest receipt, if one exists.
         """
-        sql = """
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "receipt_type", receipt_types
+        )
+
+        sql = f"""
             SELECT event_id, stream_ordering
             FROM receipts_linearized
             INNER JOIN events USING (room_id, event_id)
-            WHERE user_id = ?
+            WHERE {clause}
+            AND user_id = ?
             AND room_id = ?
-            AND receipt_type = ?
+            ORDER BY stream_ordering DESC
+            LIMIT 1
         """
 
-        def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
-            txn.execute(sql, (user_id, room_id, receipt_type))
-            return cast(Optional[Tuple[str, int]], txn.fetchone())
+        args.extend((user_id, room_id))
+        txn.execute(sql, args)
 
-        return await self.db_pool.runInteraction("get_own_receipt_for_user", f)
+        return cast(Optional[Tuple[str, int]], txn.fetchone())
 
     async def get_receipts_for_user(
         self, user_id: str, receipt_types: Iterable[str]
@@ -577,8 +582,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
     ) -> None:
         self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
-        self._get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
+
+        # We use this method to invalidate so that we don't end up with circular
+        # dependencies between the receipts and push action stores.
+        self._attempt_to_invalidate_cache(
+            "get_unread_event_push_actions_by_room_for_user", (room_id,)
         )
 
     def process_replication_rows(