author     Patrick Cloke <patrickc@matrix.org>   2022-07-06 14:53:31 -0400
committer  Patrick Cloke <patrickc@matrix.org>   2022-08-05 08:18:31 -0400
commit     fd972df8f9ca9a6805d532a5502fc89478bddc67
tree       db655551fa3fefa9fa7cab3b022cd09eef4ea390
parent     Send the thread ID over replication.
download   synapse-fd972df8f9ca9a6805d532a5502fc89478bddc67.tar.xz
Mark thread notifications as read.
 synapse/storage/databases/main/event_push_actions.py | 156
 synapse/storage/databases/main/receipts.py           |   1
 tests/storage/test_event_push_actions.py             |  17
 3 files changed, 104 insertions(+), 70 deletions(-)
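
In outline, the change keys unread counts by thread: read receipts now carry a thread_id, and the summary and push-action queries below resolve each thread's counts against the latest receipt in that thread instead of a single room-wide receipt. A minimal sketch of that behaviour as a toy in-memory model (RoomUnreadCounts and its methods are illustrative only, not Synapse APIs):

    from typing import Dict, Optional

    class RoomUnreadCounts:
        """Toy model of per-thread unread tracking.

        A key of None stands in for the room's main timeline (thread_id IS NULL
        in the real schema).
        """

        def __init__(self) -> None:
            self.counts: Dict[Optional[str], int] = {}

        def on_notifiable_event(self, thread_id: Optional[str]) -> None:
            # An event only bumps the count of the thread it belongs to.
            self.counts[thread_id] = self.counts.get(thread_id, 0) + 1

        def on_read_receipt(self, thread_id: Optional[str]) -> None:
            # A receipt clears only the thread it names: an unthreaded receipt
            # no longer clears threaded notifications.
            self.counts.pop(thread_id, None)

The diff below implements the same rule in SQL against event_push_summary, event_push_actions and receipts_linearized.
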
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 78191ee4bd..3a007f7539 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -78,7 +78,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import attr
 
-from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -264,30 +263,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         room_id: str,
         user_id: str,
     ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
-        result = self.get_last_receipt_for_user_txn(
-            txn,
-            user_id,
-            room_id,
-            receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+        # Read receipts are now joined against directly in the unread-count
+        # queries, so all that is needed here is a fallback stream ordering:
+        # that of this user's latest membership event in the room (which we
+        # assume is a join), used whenever there is no receipt for a thread.
+        event_id = self.db_pool.simple_select_one_onecol_txn(
+            txn=txn,
+            table="local_current_membership",
+            keyvalues={"room_id": room_id, "user_id": user_id},
+            retcol="event_id",
         )
 
-        stream_ordering = None
-        if result:
-            _, stream_ordering = result
-
-        if stream_ordering is None:
-            # Either last_read_event_id is None, or it's an event we don't have (e.g.
-            # because it's been purged), in which case retrieve the stream ordering for
-            # the latest membership event from this user in this room (which we assume is
-            # a join).
-            event_id = self.db_pool.simple_select_one_onecol_txn(
-                txn=txn,
-                table="local_current_membership",
-                keyvalues={"room_id": room_id, "user_id": user_id},
-                retcol="event_id",
-            )
-
-            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
+        stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
@@ -325,18 +312,30 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # receipt).
         txn.execute(
             """
-                SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id
+                SELECT notif_count, COALESCE(unread_count, 0), thread_id, MAX(events.stream_ordering)
                 FROM event_push_summary
-                WHERE room_id = ? AND user_id = ?
+                LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
+                LEFT JOIN events ON (
+                    events.room_id = receipts_linearized.room_id AND
+                    events.event_id = receipts_linearized.event_id
+                )
+                WHERE event_push_summary.room_id = ? AND user_id = ?
                 AND (
-                    (last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
-                    OR last_receipt_stream_ordering = ?
+                    (
+                        last_receipt_stream_ordering IS NULL
+                        AND event_push_summary.stream_ordering > COALESCE(events.stream_ordering, ?)
+                    )
+                    OR last_receipt_stream_ordering = COALESCE(events.stream_ordering, ?)
                 )
+                AND (receipt_type = 'm.read' OR receipt_type = 'org.matrix.msc2285.read.private')
             """,
             (room_id, user_id, stream_ordering, stream_ordering),
         )
-        max_summary_stream_ordering = 0
-        for summary_stream_ordering, notif_count, unread_count, thread_id in txn:
+        for notif_count, unread_count, thread_id, _ in txn:
+            # XXX The MAX(...) aggregation can yield a row of NULL counts when nothing matches; skip it.
+            if notif_count is None:
+                continue
+
             if not thread_id:
                 counts = NotifCounts(
                     notify_count=notif_count, unread_count=unread_count
@@ -347,22 +346,22 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     notify_count=notif_count, unread_count=unread_count
                 )
 
-            # XXX All threads should have the same stream ordering?
-            max_summary_stream_ordering = max(
-                summary_stream_ordering, max_summary_stream_ordering
-            )
-
         # Next we need to count highlights, which aren't summarised
         sql = """
-            SELECT COUNT(*), thread_id FROM event_push_actions
+            SELECT COUNT(*), thread_id, MAX(events.stream_ordering) FROM event_push_actions
+            LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
+            LEFT JOIN events ON (
+                events.room_id = receipts_linearized.room_id AND
+                events.event_id = receipts_linearized.event_id
+            )
             WHERE user_id = ?
-                AND room_id = ?
-                AND stream_ordering > ?
+                AND event_push_actions.room_id = ?
+                AND event_push_actions.stream_ordering > COALESCE(events.stream_ordering, ?)
                 AND highlight = 1
             GROUP BY thread_id
         """
         txn.execute(sql, (user_id, room_id, stream_ordering))
-        for highlight_count, thread_id in txn:
+        for highlight_count, thread_id, _ in txn:
             if not thread_id:
                 counts.highlight_count += highlight_count
             elif highlight_count:
@@ -376,7 +375,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # Finally we need to count push actions that aren't included in the
         # summary returned above, e.g. recent events that haven't been
         # summarised yet, or the summary is empty due to a recent read receipt.
-        stream_ordering = max(stream_ordering, max_summary_stream_ordering)
         unread_counts = self._get_notif_unread_count_for_user_room(
             txn, room_id, user_id, stream_ordering
         )
@@ -425,8 +423,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
 
         # If there have been no events in the room since the stream ordering,
         # there can't be any push actions either.
-        if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
-            return []
+        #
+        # XXX
+        # if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
+        #     return []
 
         clause = ""
         args = [user_id, room_id, stream_ordering]
@@ -434,26 +434,29 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             clause = "AND ea.stream_ordering <= ?"
             args.append(max_stream_ordering)
 
-            # If the max stream ordering is less than the min stream ordering,
-            # then obviously there are zero push actions in that range.
-            if max_stream_ordering <= stream_ordering:
-                return []
-
         sql = f"""
             SELECT
                COUNT(CASE WHEN notif = 1 THEN 1 END),
                COUNT(CASE WHEN unread = 1 THEN 1 END),
-               thread_id
+               thread_id,
+               MAX(events.stream_ordering)
             FROM event_push_actions ea
+            LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
+            LEFT JOIN events ON (
+                events.room_id = receipts_linearized.room_id AND
+                events.event_id = receipts_linearized.event_id
+            )
             WHERE user_id = ?
-               AND room_id = ?
-               AND ea.stream_ordering > ?
+               AND ea.room_id = ?
+               AND ea.stream_ordering > COALESCE(events.stream_ordering, ?)
                {clause}
             GROUP BY thread_id
         """
 
         txn.execute(sql, args)
-        return cast(List[Tuple[int, int, str]], txn.fetchall())
+        # The max stream ordering is simply there to select the latest receipt;
+        # it doesn't need to be returned.
+        return [cast(Tuple[int, int, str], row[:3]) for row in txn.fetchall()]
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering: int, max_stream_ordering: int
@@ -1000,7 +1003,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
 
         sql = """
-            SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
+            SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering
             FROM receipts_linearized AS r
             INNER JOIN events AS e USING (event_id)
             WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
@@ -1023,28 +1026,42 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )
         rows = txn.fetchall()
 
-        # For each new read receipt we delete push actions from before it and
-        # recalculate the summary.
-        for _, room_id, user_id, stream_ordering in rows:
+        # Group the rows by room ID / user ID.
+        rows_by_room_user: Dict[Tuple[str, str], List[Tuple[int, Optional[str], int]]] = {}
+        for stream_id, room_id, user_id, thread_id, stream_ordering in rows:
             # Only handle our own read receipts.
             if not self.hs.is_mine_id(user_id):
                 continue
 
-            txn.execute(
-                """
-                DELETE FROM event_push_actions
-                WHERE room_id = ?
-                    AND user_id = ?
-                    AND stream_ordering <= ?
-                    AND highlight = 0
-                """,
-                (room_id, user_id, stream_ordering),
+            rows_by_room_user.setdefault((room_id, user_id), []).append(
+                (stream_id, thread_id, stream_ordering)
             )
 
+        # For each new read receipt we delete push actions from before it and
+        # recalculate the summary.
+        for (room_id, user_id), room_rows in rows_by_room_user.items():
+            for _, thread_id, stream_ordering in room_rows:
+                txn.execute(
+                    """
+                    DELETE FROM event_push_actions
+                    WHERE room_id = ?
+                        AND user_id = ?
+                        AND thread_id = ?
+                        AND stream_ordering <= ?
+                        AND highlight = 0
+                    """,
+                    (room_id, user_id, thread_id, stream_ordering),
+                )
+
             # Fetch the notification counts between the stream ordering of the
             # latest receipt and what was previously summarised.
+            earliest_stream_ordering = min(r[2] for r in room_rows)
             unread_counts = self._get_notif_unread_count_for_user_room(
-                txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+                txn,
+                room_id,
+                user_id,
+                earliest_stream_ordering,
+                old_rotate_stream_ordering,
             )
 
             # Updated threads get their notification count and unread count updated.
@@ -1060,11 +1077,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                     "last_receipt_stream_ordering",
                 ),
                 value_values=[
-                    (row[0], row[1], old_rotate_stream_ordering, stream_ordering)
+                    # XXX Which stream ordering should be used here?
+                    (
+                        row[0],
+                        row[1],
+                        old_rotate_stream_ordering,
+                        earliest_stream_ordering,
+                    )
                     for row in unread_counts
                 ],
             )
 
+            # XXX Is this reset correct?
             # Other threads should be marked as reset at the old stream ordering.
             txn.execute(
                 """
@@ -1074,7 +1098,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 """,
                 (
                     old_rotate_stream_ordering,
-                    stream_ordering,
+                    min_receipts_stream_id,
                     user_id,
                     room_id,
                     old_rotate_stream_ordering,
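
The recurring pattern in the hunks above is to resolve, per thread, the stream ordering of the user's latest read receipt (LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id), then a join to events), and to fall back to a caller-supplied stream ordering when no such receipt exists. A stripped-down SQLite illustration of that COALESCE fallback, using simplified stand-in tables (receipts, events, push_actions) rather than Synapse's real schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE receipts (room_id TEXT, user_id TEXT, thread_id TEXT, event_id TEXT);
        CREATE TABLE events (room_id TEXT, event_id TEXT, stream_ordering INTEGER);
        CREATE TABLE push_actions (room_id TEXT, user_id TEXT, thread_id TEXT,
                                   stream_ordering INTEGER, notif INTEGER);

        -- A receipt in the main timeline only.  (The real schema uses a NULL
        -- thread_id for the main timeline; 'main' keeps the USING join simple.)
        INSERT INTO events VALUES ('!room', '$read_upto', 10);
        INSERT INTO receipts VALUES ('!room', '@user', 'main', '$read_upto');

        INSERT INTO push_actions VALUES ('!room', '@user', 'main', 5, 1);     -- already read
        INSERT INTO push_actions VALUES ('!room', '@user', 'main', 15, 1);    -- unread
        INSERT INTO push_actions VALUES ('!room', '@user', 'thread1', 8, 1);  -- unread, no receipt
        """
    )

    # Count notifications per thread that arrived after the latest receipt in
    # that thread, falling back to a caller-supplied ordering (0 here) when the
    # user has no receipt in the thread; the same COALESCE trick as the diff.
    rows = conn.execute(
        """
        SELECT pa.thread_id, COUNT(*)
        FROM push_actions AS pa
        LEFT JOIN receipts AS r USING (room_id, user_id, thread_id)
        LEFT JOIN events AS e ON (e.room_id = r.room_id AND e.event_id = r.event_id)
        WHERE pa.user_id = ? AND pa.room_id = ?
          AND pa.stream_ordering > COALESCE(e.stream_ordering, ?)
          AND pa.notif = 1
        GROUP BY pa.thread_id
        """,
        ("@user", "!room", 0),
    ).fetchall()

    print(dict(rows))  # {'main': 1, 'thread1': 1}
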
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 51347c1cf9..238b6a8b91 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -117,6 +117,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """Get the current max stream ID for receipts stream"""
         return self._receipts_id_gen.get_current_token()
 
+    # XXX Move this to the tests.
     async def get_last_receipt_event_id_for_user(
         self, user_id: str, room_id: str, receipt_types: Collection[str]
     ) -> Optional[str]:
diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py
index 53a68b6d17..175437c765 100644
--- a/tests/storage/test_event_push_actions.py
+++ b/tests/storage/test_event_push_actions.py
@@ -252,14 +252,14 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         def _rotate() -> None:
             self.get_success(self.store._rotate_notifs())
 
-        def _mark_read(event_id: str) -> None:
+        def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
             self.get_success(
                 self.store.insert_receipt(
                     room_id,
                     "m.read",
                     user_id=user_id,
                     event_ids=[event_id],
-                    thread_id=None,
+                    thread_id=thread_id,
                     data={},
                 )
             )
@@ -288,9 +288,12 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         _create_event()
         _create_event(thread_id=thread_id)
         _mark_read(event_id)
+        _assert_counts(1, 0, 0, 3, 0, 0)
+        _mark_read(event_id, thread_id)
         _assert_counts(1, 0, 0, 1, 0, 0)
 
         _mark_read(last_event_id)
+        _mark_read(last_event_id, thread_id)
         _assert_counts(0, 0, 0, 0, 0, 0)
 
         _create_event()
@@ -304,6 +307,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         _assert_counts(1, 0, 0, 1, 0, 0)
 
         _mark_read(last_event_id)
+        _mark_read(last_event_id, thread_id)
         _assert_counts(0, 0, 0, 0, 0, 0)
 
         _create_event(True)
@@ -320,23 +324,28 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
         # works.
         _create_event()
         _rotate()
-        _assert_counts(2, 0, 1, 1, 1, 1)
+        _assert_counts(2, 1, 1, 1, 1, 1)
 
         _create_event(thread_id=thread_id)
         _rotate()
-        _assert_counts(2, 0, 1, 2, 0, 1)
+        _assert_counts(2, 1, 1, 2, 1, 1)
 
         # Check that sending read receipts at different points results in the
         # right counts.
         _mark_read(event_id)
+        _assert_counts(1, 0, 0, 2, 1, 1)
+        _mark_read(event_id, thread_id)
         _assert_counts(1, 0, 0, 1, 0, 0)
         _mark_read(last_event_id)
+        _assert_counts(0, 0, 0, 1, 0, 0)
+        _mark_read(last_event_id, thread_id)
         _assert_counts(0, 0, 0, 0, 0, 0)
 
         _create_event(True)
         _create_event(True, thread_id)
         _assert_counts(1, 1, 1, 1, 1, 1)
         _mark_read(last_event_id)
+        _mark_read(last_event_id, thread_id)
         _assert_counts(0, 0, 0, 0, 0, 0)
         _rotate()
         _assert_counts(0, 0, 0, 0, 0, 0)
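
The updated test encodes the new contract: an unthreaded receipt no longer clears threaded notifications, so _mark_read grew an optional thread_id and the counts only drop to zero once both the main timeline and the thread have been acknowledged. A self-contained sketch of that progression (create_event and mark_read are illustrative stand-ins for the test's _create_event and _mark_read helpers, not code taken from the test):

    from typing import Dict, Optional

    counts: Dict[Optional[str], int] = {}  # per-thread counts; None = main timeline

    def create_event(thread_id: Optional[str] = None) -> None:
        counts[thread_id] = counts.get(thread_id, 0) + 1

    def mark_read(thread_id: Optional[str] = None) -> None:
        counts.pop(thread_id, None)

    create_event()                    # _create_event()
    create_event("$thread")           # _create_event(thread_id=thread_id)
    mark_read()                       # _mark_read(last_event_id): main timeline only
    assert counts == {"$thread": 1}   # the thread is still unread
    mark_read("$thread")              # _mark_read(last_event_id, thread_id)
    assert counts == {}               # now everything is read
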