summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorBrendan Abolivier <babolivier@matrix.org>2020-09-02 17:19:37 +0100
committerGitHub <noreply@github.com>2020-09-02 17:19:37 +0100
commit5a1dd297c3ce105a7f516d9d9fe87b94b9d356c8 (patch)
treee7be5283e17b93e1e9de477b5cf08a8d87bfcbb3 /synapse/storage/databases
parentRefactor `_get_e2e_device_keys_for_federation_query_txn` (#8225) (diff)
downloadsynapse-5a1dd297c3ce105a7f516d9d9fe87b94b9d356c8.tar.xz
Re-implement unread counts (again) (#8059)
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/event_push_actions.py224
-rw-r--r--synapse/storage/databases/main/events.py4
-rw-r--r--synapse/storage/databases/main/schema/delta/58/15unread_count.sql26
-rw-r--r--synapse/storage/databases/main/stream.py21
4 files changed, 205 insertions, 70 deletions
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 384c39f4d7..001d06378d 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -15,7 +15,9 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Tuple, Union
+
+import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
-        self, room_id, user_id, last_read_event_id
-    ):
+        self, room_id: str, user_id: str, last_read_event_id: Optional[str],
+    ) -> Dict[str, int]:
+        """Get the notification count, the highlight count and the unread message count
+        for a given user in a given room after the given read receipt.
+
+        Note that this function assumes the user to be a current member of the room,
+        since it's either called by the sync handler to handle joined room entries, or by
+        the HTTP pusher to calculate the badge of unread joined rooms.
+
+        Args:
+            room_id: The room to retrieve the counts in.
+            user_id: The user to retrieve the counts for.
+            last_read_event_id: The event associated with the latest read receipt for
+                this user in this room. None if no receipt for this user in this room.
+
+        Returns
+            A dict containing the counts mentioned earlier in this docstring,
+            respectively under the keys "notify_count", "highlight_count" and
+            "unread_count".
+        """
         return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
@@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     def _get_unread_counts_by_receipt_txn(
-        self, txn, room_id, user_id, last_read_event_id
+        self, txn, room_id, user_id, last_read_event_id,
     ):
-        sql = (
-            "SELECT stream_ordering"
-            " FROM events"
-            " WHERE room_id = ? AND event_id = ?"
-        )
-        txn.execute(sql, (room_id, last_read_event_id))
-        results = txn.fetchall()
-        if len(results) == 0:
-            return {"notify_count": 0, "highlight_count": 0}
+        stream_ordering = None
+
+        if last_read_event_id is not None:
+            stream_ordering = self.get_stream_id_for_event_txn(
+                txn, last_read_event_id, allow_none=True,
+            )
+
+        if stream_ordering is None:
+            # Either last_read_event_id is None, or it's an event we don't have (e.g.
+            # because it's been purged), in which case retrieve the stream ordering for
+            # the latest membership event from this user in this room (which we assume is
+            # a join).
+            event_id = self.db_pool.simple_select_one_onecol_txn(
+                txn=txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                retcol="event_id",
+            )
 
-        stream_ordering = results[0][0]
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
 
         return self._get_unread_counts_by_pos_txn(
             txn, room_id, user_id, stream_ordering
         )
 
     def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
-
-        # First get number of notifications.
-        # We don't need to put a notif=1 clause as all rows always have
-        # notif=1
         sql = (
-            "SELECT count(*)"
+            "SELECT"
+            "   COUNT(CASE WHEN notif = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN highlight = 1 THEN 1 END),"
+            "   COUNT(CASE WHEN unread = 1 THEN 1 END)"
             " FROM event_push_actions ea"
-            " WHERE"
-            " user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
+            " WHERE user_id = ?"
+            "   AND room_id = ?"
+            "   AND stream_ordering > ?"
         )
 
         txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        notify_count = row[0] if row else 0
+
+        (notif_count, highlight_count, unread_count) = (0, 0, 0)
+
+        if row:
+            (notif_count, highlight_count, unread_count) = row
 
         txn.execute(
             """
-            SELECT notif_count FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
-        """,
+                SELECT notif_count, unread_count FROM event_push_summary
+                WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
+            """,
             (room_id, user_id, stream_ordering),
         )
-        rows = txn.fetchall()
-        if rows:
-            notify_count += rows[0][0]
-
-        # Now get the number of highlights
-        sql = (
-            "SELECT count(*)"
-            " FROM event_push_actions ea"
-            " WHERE"
-            " highlight = 1"
-            " AND user_id = ?"
-            " AND room_id = ?"
-            " AND stream_ordering > ?"
-        )
-
-        txn.execute(sql, (user_id, room_id, stream_ordering))
         row = txn.fetchone()
-        highlight_count = row[0] if row else 0
 
-        return {"notify_count": notify_count, "highlight_count": highlight_count}
+        if row:
+            notif_count += row[0]
+            unread_count += row[1]
+
+        return {
+            "notify_count": notif_count,
+            "unread_count": unread_count,
+            "highlight_count": highlight_count,
+        }
 
     async def get_push_action_users_in_range(
         self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
+                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -402,7 +428,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         def _get_if_maybe_push_in_range_for_user_txn(txn):
             sql = """
                 SELECT 1 FROM event_push_actions
-                WHERE user_id = ? AND stream_ordering > ?
+                WHERE user_id = ? AND stream_ordering > ? AND notif = 1
                 LIMIT 1
             """
 
@@ -415,7 +441,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
 
     async def add_push_actions_to_staging(
-        self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
+        self,
+        event_id: str,
+        user_id_actions: Dict[str, List[Union[dict, str]]],
+        count_as_unread: bool,
     ) -> None:
         """Add the push actions for the event to the push action staging area.
 
@@ -423,21 +452,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             event_id
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
+            count_as_unread: Whether this event should increment unread counts.
         """
-
         if not user_id_actions:
             return
 
         # This is a helper function for generating the necessary tuple that
-        # can be used to inert into the `event_push_actions_staging` table.
+        # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(user_id, actions):
             is_highlight = 1 if _action_has_highlight(actions) else 0
+            notif = 1 if "notify" in actions else 0
             return (
                 event_id,  # event_id column
                 user_id,  # user_id column
                 _serialize_action(actions, is_highlight),  # actions column
-                1,  # notif column
+                notif,  # notif column
                 is_highlight,  # highlight column
+                int(count_as_unread),  # unread column
             )
 
         def _add_push_actions_to_staging_txn(txn):
@@ -446,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight)
-                VALUES (?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread)
+                VALUES (?, ?, ?, ?, ?, ?)
             """
 
             txn.executemany(
@@ -811,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id,
-                coalesce(old.notif_count, 0) + upd.notif_count,
+                coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering,
                 old.user_id
             FROM (
-                SELECT user_id, room_id, count(*) as notif_count,
+                SELECT user_id, room_id, count(*) as cnt,
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                     AND highlight = 0
+                    AND %s = 1
                 GROUP BY user_id, room_id
             ) AS upd
             LEFT JOIN event_push_summary AS old USING (user_id, room_id)
         """
 
-        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
-        rows = txn.fetchall()
+        # First get the count of unread messages.
+        txn.execute(
+            sql % ("unread_count", "unread"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        # We need to merge results from the two requests (the one that retrieves the
+        # unread count and the one that retrieves the notifications count) into a single
+        # object because we might not have the same amount of rows in each of them. To do
+        # this, we use a dict indexed on the user ID and room ID to make it easier to
+        # populate.
+        summaries = {}  # type: Dict[Tuple[str, str], _EventPushSummary]
+        for row in txn:
+            summaries[(row[0], row[1])] = _EventPushSummary(
+                unread_count=row[2],
+                stream_ordering=row[3],
+                old_user_id=row[4],
+                notif_count=0,
+            )
+
+        # Then get the count of notifications.
+        txn.execute(
+            sql % ("notif_count", "notif"),
+            (old_rotate_stream_ordering, rotate_to_stream_ordering),
+        )
+
+        for row in txn:
+            if (row[0], row[1]) in summaries:
+                summaries[(row[0], row[1])].notif_count = row[2]
+            else:
+                # Because the rules on notifying are different than the rules on marking
+                # a message unread, we might end up with messages that notify but aren't
+                # marked unread, so we might not have a summary for this (user, room)
+                # tuple to complete.
+                summaries[(row[0], row[1])] = _EventPushSummary(
+                    unread_count=0,
+                    stream_ordering=row[3],
+                    old_user_id=row[4],
+                    notif_count=row[2],
+                )
 
-        logger.info("Rotating notifications, handling %d rows", len(rows))
+        logger.info("Rotating notifications, handling %d rows", len(summaries))
 
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
@@ -838,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             table="event_push_summary",
             values=[
                 {
-                    "user_id": row[0],
-                    "room_id": row[1],
-                    "notif_count": row[2],
-                    "stream_ordering": row[3],
+                    "user_id": user_id,
+                    "room_id": room_id,
+                    "notif_count": summary.notif_count,
+                    "unread_count": summary.unread_count,
+                    "stream_ordering": summary.stream_ordering,
                 }
-                for row in rows
-                if row[4] is None
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is None
             ],
         )
 
         txn.executemany(
             """
-                UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
+                UPDATE event_push_summary
+                SET notif_count = ?, unread_count = ?, stream_ordering = ?
                 WHERE user_id = ? AND room_id = ?
             """,
-            ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
+            (
+                (
+                    summary.notif_count,
+                    summary.unread_count,
+                    summary.stream_ordering,
+                    user_id,
+                    room_id,
+                )
+                for ((user_id, room_id), summary) in summaries.items()
+                if summary.old_user_id is not None
+            ),
         )
 
         txn.execute(
@@ -879,3 +961,15 @@ def _action_has_highlight(actions):
             pass
 
     return False
+
+
+@attr.s
+class _EventPushSummary:
+    """Summary of pending event push actions for a given user in a given room.
+    Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+    """
+
+    unread_count = attr.ib(type=int)
+    stream_ordering = attr.ib(type=int)
+    old_user_id = attr.ib(type=str)
+    notif_count = attr.ib(type=int)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 46b11e705b..b94fe7ac17 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1298,9 +1298,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight
+                topological_ordering, notif, highlight, unread
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
             FROM event_push_actions_staging
             WHERE event_id = ?
         """
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
new file mode 100644
index 0000000000..b451e8663a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We're hijacking the push actions to store unread messages and unread counts (specified
+-- in MSC2654) because doing otherwise would result in either performance issues or
+-- reimplementing a consequent bit of the push actions.
+
+-- Add columns to event_push_actions and event_push_actions_staging to track unread
+-- messages and calculate unread counts.
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+
+-- Add column to event_push_summary
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 24f44a7e36..83c1ddf95a 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -47,7 +47,11 @@ from synapse.api.filtering import Filter
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.types import RoomStreamToken
@@ -593,8 +597,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A stream ID.
         """
-        return await self.db_pool.simple_select_one_onecol(
-            table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
+        return await self.db_pool.runInteraction(
+            "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
+        )
+
+    def get_stream_id_for_event_txn(
+        self, txn: LoggingTransaction, event_id: str, allow_none=False,
+    ) -> int:
+        return self.db_pool.simple_select_one_onecol_txn(
+            txn=txn,
+            table="events",
+            keyvalues={"event_id": event_id},
+            retcol="stream_ordering",
+            allow_none=allow_none,
         )
 
     async def get_stream_token_for_event(self, event_id: str) -> str: