Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/schema/delta/56/unique_user_filter_index.py  58
-rw-r--r--  synapse/storage/state_deltas.py                              38
2 files changed, 61 insertions(+), 35 deletions(-)
diff --git a/synapse/storage/schema/delta/56/unique_user_filter_index.py b/synapse/storage/schema/delta/56/unique_user_filter_index.py
index 60031f23ca..1de8b54961 100644
--- a/synapse/storage/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/schema/delta/56/unique_user_filter_index.py
@@ -5,42 +5,48 @@ from synapse.storage.engines import PostgresEngine
 logger = logging.getLogger(__name__)
 
 
+"""
+This migration updates the user_filters table as follows:
+
+ - drops any (user_id, filter_id) duplicates
+ - makes the columns NOT NULL
+ - turns the index into a UNIQUE index
+"""
+
+
 def run_upgrade(cur, database_engine, *args, **kwargs):
+    pass
+
+
+def run_create(cur, database_engine, *args, **kwargs):
     if isinstance(database_engine, PostgresEngine):
         select_clause = """
-        CREATE TEMPORARY TABLE user_filters_migration AS
             SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
-            FROM user_filters;
+            FROM user_filters
         """
     else:
         select_clause = """
-        CREATE TEMPORARY TABLE user_filters_migration AS
-            SELECT * FROM user_filters GROUP BY user_id, filter_id;
+            SELECT * FROM user_filters GROUP BY user_id, filter_id
         """
-    sql = (
-        """
-        BEGIN;
-            %s
-            DROP INDEX user_filters_by_user_id_filter_id;
-            DELETE FROM user_filters;
-            ALTER TABLE user_filters
-               ALTER COLUMN user_id SET NOT NULL,
-               ALTER COLUMN filter_id SET NOT NULL,
-               ALTER COLUMN filter_json SET NOT NULL;
-            INSERT INTO user_filters(user_id, filter_id, filter_json)
-                SELECT * FROM user_filters_migration;
-            DROP TABLE user_filters_migration;
-            CREATE UNIQUE INDEX user_filters_by_user_id_filter_id_unique
-                ON user_filters(user_id, filter_id);
-        END;
-    """
-        % select_clause
+    sql = """
+            DROP TABLE IF EXISTS user_filters_migration;
+            DROP INDEX IF EXISTS user_filters_unique;
+            CREATE TABLE user_filters_migration (
+                user_id TEXT NOT NULL,
+                filter_id BIGINT NOT NULL,
+                filter_json BYTEA NOT NULL
+            );
+            INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
+                %s;
+            CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
+                (user_id, filter_id);
+            DROP TABLE user_filters;
+            ALTER TABLE user_filters_migration RENAME TO user_filters;
+        """ % (
+        select_clause,
     )
+
     if isinstance(database_engine, PostgresEngine):
         cur.execute(sql)
     else:
         cur.executescript(sql)
-
-
-def run_create(cur, database_engine, *args, **kwargs):
-    pass
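The rewritten run_create rebuilds user_filters from scratch: copy one row per
(user_id, filter_id) into a fresh table whose columns are NOT NULL, add the
unique index, then swap the new table in. A minimal sketch of the same
rebuild-and-dedup pattern against a throwaway SQLite database (illustration
only, not Synapse code; the PostgreSQL branch in the diff uses SELECT DISTINCT
ON instead of the bare-column GROUP BY):

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()

# A legacy table with a (user_id, filter_id) duplicate and no constraints.
cur.executescript("""
    CREATE TABLE user_filters (user_id TEXT, filter_id BIGINT, filter_json BLOB);
    INSERT INTO user_filters VALUES
        ('@a:example.com', 0, X'00'),
        ('@a:example.com', 0, X'01'),
        ('@b:example.com', 1, X'02');
""")

# Same shape as the migration above; SQLite's bare-column GROUP BY keeps an
# arbitrary row per group, which is all the dedup needs.
cur.executescript("""
    CREATE TABLE user_filters_migration (
        user_id TEXT NOT NULL,
        filter_id BIGINT NOT NULL,
        filter_json BLOB NOT NULL
    );
    INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
        SELECT * FROM user_filters GROUP BY user_id, filter_id;
    CREATE UNIQUE INDEX user_filters_unique
        ON user_filters_migration (user_id, filter_id);
    DROP TABLE user_filters;
    ALTER TABLE user_filters_migration RENAME TO user_filters;
""")

# One row per (user_id, filter_id): [('@a:example.com', 0), ('@b:example.com', 1)]
print(cur.execute("SELECT user_id, filter_id FROM user_filters").fetchall())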
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
index 5fdb442104..28f33ec18f 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/state_deltas.py
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 
 
 class StateDeltasStore(SQLBaseStore):
-    def get_current_state_deltas(self, prev_stream_id):
+    def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
         """Fetch a list of room state changes since the given stream id
 
         Each entry in the result contains the following fields:
@@ -36,15 +36,27 @@ class StateDeltasStore(SQLBaseStore):
 
         Args:
             prev_stream_id (int): point to get changes since (exclusive)
+            max_stream_id (int): the point that we know has been correctly persisted
+               - i.e., an upper limit on the changes to return.
 
         Returns:
-            Deferred[list[dict]]: results
+            Deferred[tuple[int, list[dict]]]: A tuple consisting of:
+               - the stream id which these results go up to
+               - list of current_state_delta_stream rows. If it is empty, we are
+                 up to date.
         """
         prev_stream_id = int(prev_stream_id)
+
+        # check we're not going backwards
+        assert prev_stream_id <= max_stream_id
+
         if not self._curr_state_delta_stream_cache.has_any_entity_changed(
             prev_stream_id
         ):
-            return []
+            # if the current state deltas (CSDs) haven't changed between
+            # prev_stream_id and now, we know for certain that they haven't
+            # changed between prev_stream_id and max_stream_id either.
+            return max_stream_id, []
 
         def get_current_state_deltas_txn(txn):
             # First we calculate the max stream id that will give us less than
@@ -54,21 +66,29 @@ class StateDeltasStore(SQLBaseStore):
             sql = """
                 SELECT stream_id, count(*)
                 FROM current_state_delta_stream
-                WHERE stream_id > ?
+                WHERE stream_id > ? AND stream_id <= ?
                 GROUP BY stream_id
                 ORDER BY stream_id ASC
                 LIMIT 100
             """
-            txn.execute(sql, (prev_stream_id,))
+            txn.execute(sql, (prev_stream_id, max_stream_id))
 
             total = 0
-            max_stream_id = prev_stream_id
-            for max_stream_id, count in txn:
+
+            for stream_id, count in txn:
                 total += count
                 if total > 100:
                     # We arbitrarily limit to 100 entries to ensure we don't
                     # select too many.
+                    logger.debug(
+                        "Clipping current_state_delta_stream rows to stream_id %i",
+                        stream_id,
+                    )
+                    clipped_stream_id = stream_id
                     break
+            else:
+                # the loop completed without clipping, so we may as well go
+                # right up to the max_stream_id
+                clipped_stream_id = max_stream_id
 
             # Now actually get the deltas
             sql = """
@@ -77,8 +97,8 @@ class StateDeltasStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id))
-            return self.cursor_to_dict(txn)
+            txn.execute(sql, (prev_stream_id, clipped_stream_id))
+            return clipped_stream_id, self.cursor_to_dict(txn)
 
         return self.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn