summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-10-04 12:07:02 -0400
committerGitHub <noreply@github.com>2022-10-04 12:07:02 -0400
commitdcced5a8d76b94e372aefa7d1f05ec0dbc22ea0d (patch)
tree8673b6e27bf90c5ef57da90db53be9a7242ea13b
parentBump types-pyyaml from 6.0.4 to 6.0.12 (#14041) (diff)
downloadsynapse-dcced5a8d76b94e372aefa7d1f05ec0dbc22ea0d.tar.xz
Use threaded receipts when fetching events for push. (#13878)
Update the HTTP and email pushers to consider threaded read receipts
when fetching unread events.
-rw-r--r--changelog.d/13878.feature1
-rw-r--r--synapse/storage/databases/main/event_push_actions.py80
-rw-r--r--tests/storage/test_event_push_actions.py57
3 files changed, 97 insertions, 41 deletions
diff --git a/changelog.d/13878.feature b/changelog.d/13878.feature
new file mode 100644
index 0000000000..d0cb902dff
--- /dev/null
+++ b/changelog.d/13878.feature
@@ -0,0 +1 @@
+Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7469cd336c..332e13d1c9 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -119,6 +119,32 @@ DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
 ]
 
 
+@attr.s(slots=True, auto_attribs=True)
+class _RoomReceipt:
+    """
+    HttpPushAction instances include the information used to generate HTTP
+    requests to a push gateway.
+    """
+
+    unthreaded_stream_ordering: int = 0
+    # threaded_stream_ordering includes the main pseudo-thread.
+    threaded_stream_ordering: Dict[str, int] = attr.Factory(dict)
+
+    def is_unread(self, thread_id: str, stream_ordering: int) -> bool:
+        """Returns True if the stream ordering is unread according to the receipt information."""
+
+        # Only include push actions with a stream ordering after both the unthreaded
+        # and threaded receipt. Properly handles a user without any receipts present.
+        return (
+            self.unthreaded_stream_ordering < stream_ordering
+            and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering
+        )
+
+
+# A _RoomReceipt with no receipts in it.
+MISSING_ROOM_RECEIPT = _RoomReceipt()
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class HttpPushAction:
     """
@@ -716,7 +742,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
     def _get_receipts_by_room_txn(
         self, txn: LoggingTransaction, user_id: str
-    ) -> Dict[str, int]:
+    ) -> Dict[str, _RoomReceipt]:
         """
         Generate a map of room ID to the latest stream ordering that has been
         read by the given user.
@@ -726,7 +752,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id: The user to fetch receipts for.
 
         Returns:
-            A map of room ID to stream ordering for all rooms the user has a receipt in.
+            A map including all rooms the user is in with a receipt. It maps
+            room IDs to _RoomReceipt instances
         """
         receipt_types_clause, args = make_in_list_sql_clause(
             self.database_engine,
@@ -735,20 +762,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
         sql = f"""
-            SELECT room_id, MAX(stream_ordering)
+            SELECT room_id, thread_id, MAX(stream_ordering)
             FROM receipts_linearized
             INNER JOIN events USING (room_id, event_id)
             WHERE {receipt_types_clause}
             AND user_id = ?
-            GROUP BY room_id
+            GROUP BY room_id, thread_id
         """
 
         args.extend((user_id,))
         txn.execute(sql, args)
-        return {
-            room_id: latest_stream_ordering
-            for room_id, latest_stream_ordering in txn.fetchall()
-        }
+
+        result: Dict[str, _RoomReceipt] = {}
+        for room_id, thread_id, stream_ordering in txn:
+            room_receipt = result.setdefault(room_id, _RoomReceipt())
+            if thread_id is None:
+                room_receipt.unthreaded_stream_ordering = stream_ordering
+            else:
+                room_receipt.threaded_stream_ordering[thread_id] = stream_ordering
+
+        return result
 
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
@@ -781,9 +814,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         def get_push_actions_txn(
             txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool]]:
+        ) -> List[Tuple[str, str, str, int, str, bool]]:
             sql = """
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
+                SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
+                    ep.actions, ep.highlight
                 FROM event_push_actions AS ep
                 WHERE
                     ep.user_id = ?
@@ -793,7 +827,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 ORDER BY ep.stream_ordering ASC LIMIT ?
             """
             txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
-            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
+            return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall())
 
         push_actions = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
@@ -806,10 +840,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 stream_ordering=stream_ordering,
                 actions=_deserialize_action(actions, highlight),
             )
-            for event_id, room_id, stream_ordering, actions, highlight in push_actions
-            # Only include push actions with a stream ordering after any receipt, or without any
-            # receipt present (invited to but never read rooms).
-            if stream_ordering > receipts_by_room.get(room_id, 0)
+            for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions
+            if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
+                thread_id, stream_ordering
+            )
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -853,10 +887,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         def get_push_actions_txn(
             txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool, int]]:
+        ) -> List[Tuple[str, str, str, int, str, bool, int]]:
             sql = """
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight, e.received_ts
+                SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
+                    ep.actions, ep.highlight, e.received_ts
                 FROM event_push_actions AS ep
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
@@ -867,7 +901,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 ORDER BY ep.stream_ordering DESC LIMIT ?
             """
             txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
-            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
+            return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall())
 
         push_actions = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
@@ -882,10 +916,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 actions=_deserialize_action(actions, highlight),
                 received_ts=received_ts,
             )
-            for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
-            # Only include push actions with a stream ordering after any receipt, or without any
-            # receipt present (invited to but never read rooms).
-            if stream_ordering > receipts_by_room.get(room_id, 0)
+            for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions
+            if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
+                thread_id, stream_ordering
+            )
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 886585e9f2..ee48920f84 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -16,7 +16,7 @@ from typing import Optional, Tuple
 
 from twisted.test.proto_helpers import MemoryReactor
 
-from synapse.api.constants import MAIN_TIMELINE
+from synapse.api.constants import MAIN_TIMELINE, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
@@ -66,16 +66,23 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         user_id, token, _, other_token, room_id = self._create_users_and_room()
 
         # Create two events, one of which is a highlight.
-        self.helper.send_event(
+        first_event_id = self.helper.send_event(
             room_id,
             type="m.room.message",
             content={"msgtype": "m.text", "body": "msg"},
             tok=other_token,
-        )
-        event_id = self.helper.send_event(
+        )["event_id"]
+        second_event_id = self.helper.send_event(
             room_id,
             type="m.room.message",
-            content={"msgtype": "m.text", "body": user_id},
+            content={
+                "msgtype": "m.text",
+                "body": user_id,
+                "m.relates_to": {
+                    "rel_type": RelationTypes.THREAD,
+                    "event_id": first_event_id,
+                },
+            },
             tok=other_token,
         )["event_id"]
 
@@ -95,13 +102,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         )
         self.assertEqual(2, len(email_actions))
 
-        # Send a receipt, which should clear any actions.
+        # Send a receipt, which should clear the first action.
         self.get_success(
             self.store.insert_receipt(
                 room_id,
                 "m.read",
                 user_id=user_id,
-                event_ids=[event_id],
+                event_ids=[first_event_id],
                 thread_id=None,
                 data={},
             )
@@ -111,6 +118,30 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
                 user_id, 0, 1000, 20
             )
         )
+        self.assertEqual(1, len(http_actions))
+        email_actions = self.get_success(
+            self.store.get_unread_push_actions_for_user_in_range_for_email(
+                user_id, 0, 1000, 20
+            )
+        )
+        self.assertEqual(1, len(email_actions))
+
+        # Send a thread receipt to clear the thread action.
+        self.get_success(
+            self.store.insert_receipt(
+                room_id,
+                "m.read",
+                user_id=user_id,
+                event_ids=[second_event_id],
+                thread_id=first_event_id,
+                data={},
+            )
+        )
+        http_actions = self.get_success(
+            self.store.get_unread_push_actions_for_user_in_range_for_http(
+                user_id, 0, 1000, 20
+            )
+        )
         self.assertEqual([], http_actions)
         email_actions = self.get_success(
             self.store.get_unread_push_actions_for_user_in_range_for_email(
@@ -417,17 +448,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         sends both unthreaded and threaded receipts.
         """
 
-        # Create a user to receive notifications and send receipts.
-        user_id = self.register_user("user1235", "pass")
-        token = self.login("user1235", "pass")
-
-        # And another users to send events.
-        other_id = self.register_user("other", "pass")
-        other_token = self.login("other", "pass")
-
-        # Create a room and put both users in it.
-        room_id = self.helper.create_room_as(user_id, tok=token)
-        self.helper.join(room_id, other_id, tok=other_token)
+        user_id, token, _, other_token, room_id = self._create_users_and_room()
         thread_id: str
 
         last_event_id: str