summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py133
-rw-r--r--synapse/storage/data_stores/main/media_repository.py9
-rw-r--r--synapse/storage/data_stores/main/room.py42
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres (renamed from synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql)11
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite18
5 files changed, 65 insertions, 148 deletions
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 815d52ab4c..bc9f4f08ea 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
-# Copyright 2015-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2015 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,9 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Tuple
 
-import attr
 from canonicaljson import json
 
 from twisted.internet import defer
@@ -37,16 +36,6 @@ DEFAULT_HIGHLIGHT_ACTION = [
 ]
 
 
-@attr.s
-class EventPushSummary:
-    """Summary of pending event push actions for a given user in a given room."""
-
-    unread_count = attr.ib(type=int)
-    stream_ordering = attr.ib(type=int)
-    old_user_id = attr.ib(type=str)
-    notif_count = attr.ib(type=int)
-
-
 def _serialize_action(actions, is_highlight):
     """Custom serializer for actions. This allows us to "compress" common actions.
 
@@ -123,7 +112,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         txn.execute(sql, (room_id, last_read_event_id))
         results = txn.fetchall()
         if len(results) == 0:
-            return {"notify_count": 0, "highlight_count": 0, "unread_count": 0}
+            return {"notify_count": 0, "highlight_count": 0}
 
         stream_ordering = results[0][0]
 
@@ -133,42 +122,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
     def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
 
-        # First get number of actions, grouped on whether the action notifies.
+        # First get number of notifications.
+        # We don't need to put a notif=1 clause as all rows always have
+        # notif=1
         sql = (
-            "SELECT count(*), notif"
+            "SELECT count(*)"
             " FROM event_push_actions ea"
             " WHERE"
             " user_id = ?"
             " AND room_id = ?"
             " AND stream_ordering > ?"
-            " GROUP BY notif"
         )
-        txn.execute(sql, (user_id, room_id, stream_ordering))
-        rows = txn.fetchall()
 
-        # We should get a maximum number of two rows: one for notif = 0, which is the
-        # number of actions that contribute to the unread_count but not to the
-        # notify_count, and one for notif = 1, which is the number of actions that
-        # contribute to both counters. If one or both rows don't appear, then the
-        # value for the matching counter should be 0.
-        unread_count = 0
-        notify_count = 0
-        for row in rows:
-            # We always increment unread_count because actions that notify also
-            # contribute to it.
-            unread_count += row[0]
-            if row[1] == 1:
-                notify_count = row[0]
-            elif row[1] != 0:
-                logger.warning(
-                    "Unexpected value %d for column 'notif' in table"
-                    " 'event_push_actions'",
-                    row[1],
-                )
+        txn.execute(sql, (user_id, room_id, stream_ordering))
+        row = txn.fetchone()
+        notify_count = row[0] if row else 0
 
         txn.execute(
             """
-            SELECT notif_count, unread_count FROM event_push_summary
+            SELECT notif_count FROM event_push_summary
             WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
         """,
             (room_id, user_id, stream_ordering),
@@ -176,7 +148,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         rows = txn.fetchall()
         if rows:
             notify_count += rows[0][0]
-            unread_count += rows[0][1]
 
         # Now get the number of highlights
         sql = (
@@ -193,11 +164,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         row = txn.fetchone()
         highlight_count = row[0] if row else 0
 
-        return {
-            "unread_count": unread_count,
-            "notify_count": notify_count,
-            "highlight_count": highlight_count,
-        }
+        return {"notify_count": notify_count, "highlight_count": highlight_count}
 
     @defer.inlineCallbacks
     def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
@@ -255,7 +222,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -284,7 +250,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering ASC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -357,7 +322,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -386,7 +350,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "   AND ep.user_id = ?"
                 "   AND ep.stream_ordering > ?"
                 "   AND ep.stream_ordering <= ?"
-                "   AND ep.notif = 1"
                 " ORDER BY ep.stream_ordering DESC LIMIT ?"
             )
             args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -436,7 +399,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         def _get_if_maybe_push_in_range_for_user_txn(txn):
             sql = """
                 SELECT 1 FROM event_push_actions
-                WHERE user_id = ? AND stream_ordering > ? AND notif = 1
+                WHERE user_id = ? AND stream_ordering > ?
                 LIMIT 1
             """
 
@@ -465,15 +428,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             return
 
         # This is a helper function for generating the necessary tuple that
-        # can be used to insert into the `event_push_actions_staging` table.
+        # can be used to inert into the `event_push_actions_staging` table.
         def _gen_entry(user_id, actions):
             is_highlight = 1 if _action_has_highlight(actions) else 0
-            notif = 0 if "org.matrix.msc2625.mark_unread" in actions else 1
             return (
                 event_id,  # event_id column
                 user_id,  # user_id column
                 _serialize_action(actions, is_highlight),  # actions column
-                notif,  # notif column
+                1,  # notif column
                 is_highlight,  # highlight column
             )
 
@@ -855,51 +817,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
             SELECT user_id, room_id,
-                coalesce(old.%s, 0) + upd.cnt,
+                coalesce(old.notif_count, 0) + upd.notif_count,
                 upd.stream_ordering,
                 old.user_id
             FROM (
-                SELECT user_id, room_id, count(*) as cnt,
+                SELECT user_id, room_id, count(*) as notif_count,
                     max(stream_ordering) as stream_ordering
                 FROM event_push_actions
                 WHERE ? <= stream_ordering AND stream_ordering < ?
                     AND highlight = 0
-                    %s
                 GROUP BY user_id, room_id
             ) AS upd
             LEFT JOIN event_push_summary AS old USING (user_id, room_id)
         """
 
-        # First get the count of unread messages.
-        txn.execute(
-            sql % ("unread_count", ""),
-            (old_rotate_stream_ordering, rotate_to_stream_ordering),
-        )
-
-        # We need to merge both lists into a single object because we might not have the
-        # same amount of rows in each of them. In this case we use a dict indexed on the
-        # user ID and room ID to make it easier to populate.
-        summaries = {}  # type: Dict[Tuple[str, str], EventPushSummary]
-        for row in txn:
-            summaries[(row[0], row[1])] = EventPushSummary(
-                unread_count=row[2],
-                stream_ordering=row[3],
-                old_user_id=row[4],
-                notif_count=0,
-            )
-
-        # Then get the count of notifications.
-        txn.execute(
-            sql % ("notif_count", "AND notif = 1"),
-            (old_rotate_stream_ordering, rotate_to_stream_ordering),
-        )
-
-        # notif_rows is populated based on a subset of the query used to populate
-        # unread_rows, so we can be sure that there will be no KeyError here.
-        for row in txn:
-            summaries[(row[0], row[1])].notif_count = row[2]
+        txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
+        rows = txn.fetchall()
 
-        logger.info("Rotating notifications, handling %d rows", len(summaries))
+        logger.info("Rotating notifications, handling %d rows", len(rows))
 
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
@@ -909,34 +844,22 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             table="event_push_summary",
             values=[
                 {
-                    "user_id": user_id,
-                    "room_id": room_id,
-                    "notif_count": summary.notif_count,
-                    "unread_count": summary.unread_count,
-                    "stream_ordering": summary.stream_ordering,
+                    "user_id": row[0],
+                    "room_id": row[1],
+                    "notif_count": row[2],
+                    "stream_ordering": row[3],
                 }
-                for ((user_id, room_id), summary) in summaries.items()
-                if summary.old_user_id is None
+                for row in rows
+                if row[4] is None
             ],
         )
 
         txn.executemany(
             """
-                UPDATE event_push_summary
-                SET notif_count = ?, unread_count = ?, stream_ordering = ?
+                UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
                 WHERE user_id = ? AND room_id = ?
             """,
-            (
-                (
-                    summary.notif_count,
-                    summary.unread_count,
-                    summary.stream_ordering,
-                    user_id,
-                    room_id,
-                )
-                for ((user_id, room_id), summary) in summaries.items()
-                if summary.old_user_id is not None
-            ),
+            ((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
         )
 
         txn.execute(
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 8aecd414c2..15bc13cbd0 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -81,6 +81,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_media",
         )
 
+    def mark_local_media_as_safe(self, media_id: str):
+        """Mark a local media as safe from quarantining."""
+        return self.db.simple_update_one(
+            table="local_media_repository",
+            keyvalues={"media_id": media_id},
+            updatevalues={"safe_from_quarantine": True},
+            desc="mark_local_media_as_safe",
+        )
+
     def get_url_cache(self, url, ts):
         """Get the media_id and ts for a cached URL as of the given timestamp
         Returns:
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 55c1defc73..d72d6affb8 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -649,36 +649,10 @@ class RoomWorkerStore(SQLBaseStore):
 
         def _quarantine_media_in_room_txn(txn):
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
-            total_media_quarantined = 0
-
-            # Now update all the tables to set the quarantined_by flag
-
-            txn.executemany(
-                """
-                UPDATE local_media_repository
-                SET quarantined_by = ?
-                WHERE media_id = ?
-            """,
-                ((quarantined_by, media_id) for media_id in local_mxcs),
-            )
-
-            txn.executemany(
-                """
-                    UPDATE remote_media_cache
-                    SET quarantined_by = ?
-                    WHERE media_origin = ? AND media_id = ?
-                """,
-                (
-                    (quarantined_by, origin, media_id)
-                    for origin, media_id in remote_mxcs
-                ),
+            return self._quarantine_media_txn(
+                txn, local_mxcs, remote_mxcs, quarantined_by
             )
 
-            total_media_quarantined += len(local_mxcs)
-            total_media_quarantined += len(remote_mxcs)
-
-            return total_media_quarantined
-
         return self.db.runInteraction(
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
@@ -828,17 +802,17 @@ class RoomWorkerStore(SQLBaseStore):
         Returns:
             The total number of media items quarantined
         """
-        total_media_quarantined = 0
-
         # Update all the tables to set the quarantined_by flag
         txn.executemany(
             """
             UPDATE local_media_repository
             SET quarantined_by = ?
-            WHERE media_id = ?
+            WHERE media_id = ? AND safe_from_quarantine = ?
         """,
-            ((quarantined_by, media_id) for media_id in local_mxcs),
+            ((quarantined_by, media_id, False) for media_id in local_mxcs),
         )
+        # Note that a rowcount of -1 can be used to indicate no rows were affected.
+        total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
 
         txn.executemany(
             """
@@ -848,9 +822,7 @@ class RoomWorkerStore(SQLBaseStore):
             """,
             ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
         )
-
-        total_media_quarantined += len(local_mxcs)
-        total_media_quarantined += len(remote_mxcs)
+        total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
 
         return total_media_quarantined
 
diff --git a/synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
index f1459ef7f0..597f2ffd3d 100644
--- a/synapse/storage/data_stores/main/schema/delta/58/07push_summary_unread_count.sql
+++ b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.postgres
@@ -13,11 +13,6 @@
  * limitations under the License.
  */
 
--- Store the number of unread messages, i.e. messages that triggered either a notify
--- action or a mark_unread one.
-ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
-
--- Pre-populate the new column with the count of pending notifications.
--- We expect event_push_summary to be relatively small, so we can do this update
--- synchronously without impacting Synapse's startup time too much.
-UPDATE event_push_summary SET unread_count = notif_count;
\ No newline at end of file
+-- The local_media_repository should have files which do not get quarantined,
+-- e.g. files from sticker packs.
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
new file mode 100644
index 0000000000..69db89ac0e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/08_media_safe_from_quarantine.sql.sqlite
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The local_media_repository should have files which do not get quarantined,
+-- e.g. files from sticker packs.
+ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0;