summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13937.feature1
-rw-r--r--synapse/storage/databases/main/event_push_actions.py12
-rw-r--r--synapse/storage/databases/main/receipts.py36
-rw-r--r--tests/storage/test_receipts.py74
4 files changed, 47 insertions, 76 deletions
diff --git a/changelog.d/13937.feature b/changelog.d/13937.feature
new file mode 100644
index 0000000000..d0cb902dff
--- /dev/null
+++ b/changelog.d/13937.feature
@@ -0,0 +1 @@
+Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index f4cdc2e399..7e0ffef7d3 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -366,14 +366,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         user_id: str,
     ) -> NotifCounts:
         # Get the stream ordering of the user's latest receipt in the room.
-        result = self.get_last_receipt_for_user_txn(
+        result = self.get_last_unthreaded_receipt_for_user_txn(
             txn,
             user_id,
             room_id,
-            receipt_types=(
-                ReceiptTypes.READ,
-                ReceiptTypes.READ_PRIVATE,
-            ),
+            receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
         )
 
         if result:
@@ -574,10 +571,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         receipt_types_clause, args = make_in_list_sql_clause(
             self.database_engine,
             "receipt_type",
-            (
-                ReceiptTypes.READ,
-                ReceiptTypes.READ_PRIVATE,
-            ),
+            (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
         )
 
         sql = f"""
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 52fe0db924..246f78ac1f 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -135,34 +135,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """Get the current max stream ID for receipts stream"""
         return self._receipts_id_gen.get_current_token()
 
-    async def get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_types: Collection[str]
-    ) -> Optional[str]:
-        """
-        Fetch the event ID for the latest receipt in a room with one of the given receipt types.
-
-        Args:
-            user_id: The user to fetch receipts for.
-            room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt types to fetch.
-
-        Returns:
-            The latest receipt, if one exists.
-        """
-        result = await self.db_pool.runInteraction(
-            "get_last_receipt_event_id_for_user",
-            self.get_last_receipt_for_user_txn,
-            user_id,
-            room_id,
-            receipt_types,
-        )
-        if not result:
-            return None
-
-        event_id, _ = result
-        return event_id
-
-    def get_last_receipt_for_user_txn(
+    def get_last_unthreaded_receipt_for_user_txn(
         self,
         txn: LoggingTransaction,
         user_id: str,
@@ -170,13 +143,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_types: Collection[str],
     ) -> Optional[Tuple[str, int]]:
         """
-        Fetch the event ID and stream_ordering for the latest receipt in a room
-        with one of the given receipt types.
+        Fetch the event ID and stream_ordering for the latest unthreaded receipt
+        in a room with one of the given receipt types.
 
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt types to fetch.
+            receipt_types: The receipt types to fetch.
 
         Returns:
             The event ID and stream ordering of the latest receipt, if one exists.
@@ -193,6 +166,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             WHERE {clause}
             AND user_id = ?
             AND room_id = ?
+            AND thread_id IS NULL
             ORDER BY stream_ordering DESC
             LIMIT 1
         """
diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py
index 9459ee1705..81253d0361 100644
--- a/tests/storage/test_receipts.py
+++ b/tests/storage/test_receipts.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Collection, Optional
 
 from synapse.api.constants import ReceiptTypes
 from synapse.types import UserID, create_requester
@@ -84,6 +85,33 @@ class ReceiptTestCase(HomeserverTestCase):
             )
         )
 
+    def get_last_unthreaded_receipt(
+        self, receipt_types: Collection[str], room_id: Optional[str] = None
+    ) -> Optional[str]:
+        """
+        Fetch the event ID for the latest unthreaded receipt in the test room for the test user.
+
+        Args:
+            receipt_types: The receipt types to fetch.
+
+        Returns:
+            The latest receipt, if one exists.
+        """
+        result = self.get_success(
+            self.store.db_pool.runInteraction(
+                "get_last_receipt_event_id_for_user",
+                self.store.get_last_unthreaded_receipt_for_user_txn,
+                OUR_USER_ID,
+                room_id or self.room_id1,
+                receipt_types,
+            )
+        )
+        if not result:
+            return None
+
+        event_id, _ = result
+        return event_id
+
     def test_return_empty_with_no_data(self) -> None:
         res = self.get_success(
             self.store.get_receipts_for_user(
@@ -107,16 +135,10 @@ class ReceiptTestCase(HomeserverTestCase):
         )
         self.assertEqual(res, {})
 
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID,
-                self.room_id1,
-                [
-                    ReceiptTypes.READ,
-                    ReceiptTypes.READ_PRIVATE,
-                ],
-            )
+        res = self.get_last_unthreaded_receipt(
+            [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
         )
+
         self.assertEqual(res, None)
 
     def test_get_receipts_for_user(self) -> None:
@@ -228,29 +250,17 @@ class ReceiptTestCase(HomeserverTestCase):
         )
 
         # Test we get the latest event when we want both private and public receipts
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID,
-                self.room_id1,
-                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
-            )
+        res = self.get_last_unthreaded_receipt(
+            [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
         )
         self.assertEqual(res, event1_2_id)
 
         # Test we get the older event when we want only public receipt
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
-            )
-        )
+        res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
         self.assertEqual(res, event1_1_id)
 
         # Test we get the latest event when we want only the private receipt
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
-            )
-        )
+        res = self.get_last_unthreaded_receipt([ReceiptTypes.READ_PRIVATE])
         self.assertEqual(res, event1_2_id)
 
         # Test receipt updating
@@ -259,11 +269,7 @@ class ReceiptTestCase(HomeserverTestCase):
                 self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
             )
         )
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
-            )
-        )
+        res = self.get_last_unthreaded_receipt([ReceiptTypes.READ])
         self.assertEqual(res, event1_2_id)
 
         # Send some events into the second room
@@ -282,11 +288,7 @@ class ReceiptTestCase(HomeserverTestCase):
                 {},
             )
         )
-        res = self.get_success(
-            self.store.get_last_receipt_event_id_for_user(
-                OUR_USER_ID,
-                self.room_id2,
-                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
-            )
+        res = self.get_last_unthreaded_receipt(
+            [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], room_id=self.room_id2
         )
         self.assertEqual(res, event2_1_id)