summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13893.feature1
-rw-r--r--synapse/storage/databases/main/event_push_actions.py38
-rw-r--r--tests/storage/test_event_push_actions.py88
3 files changed, 97 insertions, 30 deletions
diff --git a/changelog.d/13893.feature b/changelog.d/13893.feature
new file mode 100644
index 0000000000..d0cb902dff
--- /dev/null
+++ b/changelog.d/13893.feature
@@ -0,0 +1 @@
+Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 6b8668d2dc..f4cdc2e399 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -559,7 +559,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
     def _get_receipts_by_room_txn(
         self, txn: LoggingTransaction, user_id: str
-    ) -> List[Tuple[str, int]]:
+    ) -> Dict[str, int]:
+        """
+        Generate a map of room ID to the latest stream ordering that has been
+        read by the given user.
+
+        Args:
+            txn:
+            user_id: The user to fetch receipts for.
+
+        Returns:
+            A map of room ID to stream ordering for all rooms the user has a receipt in.
+        """
         receipt_types_clause, args = make_in_list_sql_clause(
             self.database_engine,
             "receipt_type",
@@ -580,7 +591,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         args.extend((user_id,))
         txn.execute(sql, args)
-        return cast(List[Tuple[str, int]], txn.fetchall())
+        return {
+            room_id: latest_stream_ordering
+            for room_id, latest_stream_ordering in txn.fetchall()
+        }
 
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
@@ -605,12 +619,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        receipts_by_room = dict(
-            await self.db_pool.runInteraction(
-                "get_unread_push_actions_for_user_in_range_http_receipts",
-                self._get_receipts_by_room_txn,
-                user_id=user_id,
-            ),
+        receipts_by_room = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_http_receipts",
+            self._get_receipts_by_room_txn,
+            user_id=user_id,
         )
 
         def get_push_actions_txn(
@@ -679,12 +691,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             The list will have between 0~limit entries.
         """
 
-        receipts_by_room = dict(
-            await self.db_pool.runInteraction(
-                "get_unread_push_actions_for_user_in_range_email_receipts",
-                self._get_receipts_by_room_txn,
-                user_id=user_id,
-            ),
+        receipts_by_room = await self.db_pool.runInteraction(
+            "get_unread_push_actions_for_user_in_range_email_receipts",
+            self._get_receipts_by_room_txn,
+            user_id=user_id,
         )
 
         def get_push_actions_txn(
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 08c74b93e3..473c965e19 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Tuple
+
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.rest import admin
@@ -22,8 +24,6 @@ from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
-USER_ID = "@user:example.com"
-
 
 class EventPushActionsStoreTestCase(HomeserverTestCase):
     servlets = [
@@ -38,21 +38,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         assert persist_events_store is not None
         self.persist_events_store = persist_events_store
 
-    def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None:
-        self.get_success(
-            self.store.get_unread_push_actions_for_user_in_range_for_http(
-                USER_ID, 0, 1000, 20
-            )
-        )
+    def _create_users_and_room(self) -> Tuple[str, str, str, str, str]:
+        """
+        Creates two users and a shared room.
 
-    def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None:
-        self.get_success(
-            self.store.get_unread_push_actions_for_user_in_range_for_email(
-                USER_ID, 0, 1000, 20
-            )
-        )
-
-    def test_count_aggregation(self) -> None:
+        Returns:
+            Tuple of (user 1 ID, user 1 token, user 2 ID, user 2 token, room ID).
+        """
         # Create a user to receive notifications and send receipts.
         user_id = self.register_user("user1235", "pass")
         token = self.login("user1235", "pass")
@@ -65,6 +57,70 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         room_id = self.helper.create_room_as(user_id, tok=token)
         self.helper.join(room_id, other_id, tok=other_token)
 
+        return user_id, token, other_id, other_token, room_id
+
+    def test_get_unread_push_actions_for_user_in_range(self) -> None:
+        """Test getting unread push actions for HTTP and email pushers."""
+        user_id, token, _, other_token, room_id = self._create_users_and_room()
+
+        # Create two events, one of which is a highlight.
+        self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={"msgtype": "m.text", "body": "msg"},
+            tok=other_token,
+        )
+        event_id = self.helper.send_event(
+            room_id,
+            type="m.room.message",
+            content={"msgtype": "m.text", "body": user_id},
+            tok=other_token,
+        )["event_id"]
+
+        # Fetch unread actions for HTTP pushers.
+        http_actions = self.get_success(
+            self.store.get_unread_push_actions_for_user_in_range_for_http(
+                user_id, 0, 1000, 20
+            )
+        )
+        self.assertEqual(2, len(http_actions))
+
+        # Fetch unread actions for email pushers.
+        email_actions = self.get_success(
+            self.store.get_unread_push_actions_for_user_in_range_for_email(
+                user_id, 0, 1000, 20
+            )
+        )
+        self.assertEqual(2, len(email_actions))
+
+        # Send a receipt, which should clear any actions.
+        self.get_success(
+            self.store.insert_receipt(
+                room_id,
+                "m.read",
+                user_id=user_id,
+                event_ids=[event_id],
+                thread_id=None,
+                data={},
+            )
+        )
+        http_actions = self.get_success(
+            self.store.get_unread_push_actions_for_user_in_range_for_http(
+                user_id, 0, 1000, 20
+            )
+        )
+        self.assertEqual([], http_actions)
+        email_actions = self.get_success(
+            self.store.get_unread_push_actions_for_user_in_range_for_email(
+                user_id, 0, 1000, 20
+            )
+        )
+        self.assertEqual([], email_actions)
+
+    def test_count_aggregation(self) -> None:
+        # Create a user to receive notifications and send receipts.
+        user_id, token, _, other_token, room_id = self._create_users_and_room()
+
         last_event_id: str
 
         def _assert_counts(noitf_count: int, highlight_count: int) -> None: