summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py80
1 files changed, 57 insertions, 23 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7469cd336c..332e13d1c9 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -119,6 +119,32 @@ DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
 ]
 
 
+@attr.s(slots=True, auto_attribs=True)
+class _RoomReceipt:
+    """
+    HttpPushAction instances include the information used to generate HTTP
+    requests to a push gateway.
+    """
+
+    unthreaded_stream_ordering: int = 0
+    # threaded_stream_ordering includes the main pseudo-thread.
+    threaded_stream_ordering: Dict[str, int] = attr.Factory(dict)
+
+    def is_unread(self, thread_id: str, stream_ordering: int) -> bool:
+        """Returns True if the stream ordering is unread according to the receipt information."""
+
+        # Only include push actions with a stream ordering after both the unthreaded
+        # and threaded receipt. Properly handles a user without any receipts present.
+        return (
+            self.unthreaded_stream_ordering < stream_ordering
+            and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering
+        )
+
+
+# A _RoomReceipt with no receipts in it.
+MISSING_ROOM_RECEIPT = _RoomReceipt()
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class HttpPushAction:
     """
@@ -716,7 +742,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
     def _get_receipts_by_room_txn(
         self, txn: LoggingTransaction, user_id: str
-    ) -> Dict[str, int]:
+    ) -> Dict[str, _RoomReceipt]:
         """
         Generate a map of room ID to the latest stream ordering that has been
         read by the given user.
@@ -726,7 +752,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id: The user to fetch receipts for.
 
         Returns:
-            A map of room ID to stream ordering for all rooms the user has a receipt in.
+            A map including all rooms the user is in with a receipt. It maps
+            room IDs to _RoomReceipt instances
         """
         receipt_types_clause, args = make_in_list_sql_clause(
             self.database_engine,
@@ -735,20 +762,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
         sql = f"""
-            SELECT room_id, MAX(stream_ordering)
+            SELECT room_id, thread_id, MAX(stream_ordering)
             FROM receipts_linearized
             INNER JOIN events USING (room_id, event_id)
             WHERE {receipt_types_clause}
             AND user_id = ?
-            GROUP BY room_id
+            GROUP BY room_id, thread_id
         """
 
         args.extend((user_id,))
         txn.execute(sql, args)
-        return {
-            room_id: latest_stream_ordering
-            for room_id, latest_stream_ordering in txn.fetchall()
-        }
+
+        result: Dict[str, _RoomReceipt] = {}
+        for room_id, thread_id, stream_ordering in txn:
+            room_receipt = result.setdefault(room_id, _RoomReceipt())
+            if thread_id is None:
+                room_receipt.unthreaded_stream_ordering = stream_ordering
+            else:
+                room_receipt.threaded_stream_ordering[thread_id] = stream_ordering
+
+        return result
 
     async def get_unread_push_actions_for_user_in_range_for_http(
         self,
@@ -781,9 +814,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         def get_push_actions_txn(
             txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool]]:
+        ) -> List[Tuple[str, str, str, int, str, bool]]:
             sql = """
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
+                SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
+                    ep.actions, ep.highlight
                 FROM event_push_actions AS ep
                 WHERE
                     ep.user_id = ?
@@ -793,7 +827,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 ORDER BY ep.stream_ordering ASC LIMIT ?
             """
             txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
-            return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
+            return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall())
 
         push_actions = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
@@ -806,10 +840,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 stream_ordering=stream_ordering,
                 actions=_deserialize_action(actions, highlight),
             )
-            for event_id, room_id, stream_ordering, actions, highlight in push_actions
-            # Only include push actions with a stream ordering after any receipt, or without any
-            # receipt present (invited to but never read rooms).
-            if stream_ordering > receipts_by_room.get(room_id, 0)
+            for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions
+            if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
+                thread_id, stream_ordering
+            )
         ]
 
         # Now sort it so it's ordered correctly, since currently it will
@@ -853,10 +887,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         def get_push_actions_txn(
             txn: LoggingTransaction,
-        ) -> List[Tuple[str, str, int, str, bool, int]]:
+        ) -> List[Tuple[str, str, str, int, str, bool, int]]:
             sql = """
-                SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
-                    ep.highlight, e.received_ts
+                SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
+                    ep.actions, ep.highlight, e.received_ts
                 FROM event_push_actions AS ep
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
@@ -867,7 +901,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 ORDER BY ep.stream_ordering DESC LIMIT ?
             """
             txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
-            return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
+            return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall())
 
         push_actions = await self.db_pool.runInteraction(
             "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
@@ -882,10 +916,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 actions=_deserialize_action(actions, highlight),
                 received_ts=received_ts,
             )
-            for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
-            # Only include push actions with a stream ordering after any receipt, or without any
-            # receipt present (invited to but never read rooms).
-            if stream_ordering > receipts_by_room.get(room_id, 0)
+            for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions
+            if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
+                thread_id, stream_ordering
+            )
         ]
 
         # Now sort it so it's ordered correctly, since currently it will