summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17309.misc1
-rw-r--r--synapse/storage/util/id_generators.py23
-rw-r--r--tests/storage/test_id_generators.py301
3 files changed, 147 insertions, 178 deletions
diff --git a/changelog.d/17309.misc b/changelog.d/17309.misc
new file mode 100644
index 0000000000..cb6b9504b3
--- /dev/null
+++ b/changelog.d/17309.misc
@@ -0,0 +1 @@
+When rolling back to a previous Synapse version and then forwards again to this release, don't require server operators to manually run SQL.
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 59c8e05c39..48f88a6f8a 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -276,9 +276,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # no active writes in progress.
         self._max_position_of_local_instance = self._max_seen_allocated_stream_id
 
-        # This goes and fills out the above state from the database.
-        self._load_current_ids(db_conn, tables)
-
         self._sequence_gen = build_sequence_generator(
             db_conn=db_conn,
             database_engine=db.engine,
@@ -303,6 +300,13 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                 positive=positive,
             )
 
+        # This goes and fills out the above state from the database.
+        # This may read on the PostgreSQL sequence, and
+        # SequenceGenerator.check_consistency might have fixed up the sequence, which
+        # means the SequenceGenerator needs to be setup before we read the value from
+        # the sequence.
+        self._load_current_ids(db_conn, tables, sequence_name)
+
         self._max_seen_allocated_stream_id = max(
             self._current_positions.values(), default=1
         )
@@ -327,6 +331,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         self,
         db_conn: LoggingDatabaseConnection,
         tables: List[Tuple[str, str, str]],
+        sequence_name: str,
     ) -> None:
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
@@ -360,6 +365,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
                 if instance in self._writers
             }
 
+        # If we're a writer, we can assume we're at the end of the stream
+        # Usually, we would get that from the stream_positions, but in some cases,
+        # like if we rolled back Synapse, the stream_positions table might not be up to
+        # date. If we're using Postgres for the sequences, we can just use the current
+        # sequence value as our own position.
+        if self._instance_name in self._writers:
+            if isinstance(self._db.engine, PostgresEngine):
+                cur.execute(f"SELECT last_value FROM {sequence_name}")
+                row = cur.fetchone()
+                assert row is not None
+                self._current_positions[self._instance_name] = row[0]
+
         # We set the `_persisted_upto_position` to be the minimum of all current
         # positions. If empty we use the max stream ID from the DB table.
         min_stream_id = min(self._current_positions.values(), default=None)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9be2923e6f..12b89cecb6 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -18,7 +18,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -42,9 +42,13 @@ from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
 class MultiWriterIdGeneratorBase(HomeserverTestCase):
+    positive: bool = True
+    tables: List[str] = ["foobar"]
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
+        self.instances: Dict[str, MultiWriterIdGenerator] = {}
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
@@ -57,18 +61,22 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
         if USE_POSTGRES_FOR_TESTS:
             txn.execute("CREATE SEQUENCE foobar_seq")
 
-        txn.execute(
-            """
-            CREATE TABLE foobar (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
+        for table in self.tables:
+            txn.execute(
+                """
+                CREATE TABLE %s (
+                    stream_id BIGINT NOT NULL,
+                    instance_name TEXT NOT NULL,
+                    data TEXT
+                );
+                """
+                % (table,)
+            )
 
     def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
+        self,
+        instance_name: str = "master",
+        writers: Optional[List[str]] = None,
     ) -> MultiWriterIdGenerator:
         def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
             return MultiWriterIdGenerator(
@@ -77,36 +85,93 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
                 notifier=self.hs.get_replication_notifier(),
                 stream_name="test_stream",
                 instance_name=instance_name,
-                tables=[("foobar", "instance_name", "stream_id")],
+                tables=[(table, "instance_name", "stream_id") for table in self.tables],
                 sequence_name="foobar_seq",
                 writers=writers or ["master"],
+                positive=self.positive,
             )
 
-        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+        self.instances[instance_name] = self.get_success_or_raise(
+            self.db_pool.runWithConnection(_create)
+        )
+        return self.instances[instance_name]
+
+    def _replicate(self, instance_name: str) -> None:
+        """Similate a replication event for the given instance."""
+
+        writer = self.instances[instance_name]
+        token = writer.get_current_token_for_writer(instance_name)
+        for generator in self.instances.values():
+            if writer != generator:
+                generator.advance(instance_name, token)
+
+    def _replicate_all(self) -> None:
+        """Similate a replication event for all instances."""
 
-    def _insert_rows(self, instance_name: str, number: int) -> None:
+        for instance_name in self.instances:
+            self._replicate(instance_name)
+
+    def _insert_row(
+        self, instance_name: str, stream_id: int, table: Optional[str] = None
+    ) -> None:
+        """Insert one row as the given instance with given stream_id."""
+
+        if table is None:
+            table = self.tables[0]
+
+        factor = 1 if self.positive else -1
+
+        def _insert(txn: LoggingTransaction) -> None:
+            txn.execute(
+                "INSERT INTO %s VALUES (?, ?)" % (table,),
+                (
+                    stream_id,
+                    instance_name,
+                ),
+            )
+            txn.execute(
+                """
+                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                """,
+                (instance_name, stream_id * factor, stream_id * factor),
+            )
+
+        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+    def _insert_rows(
+        self,
+        instance_name: str,
+        number: int,
+        table: Optional[str] = None,
+        update_stream_table: bool = True,
+    ) -> None:
         """Insert N rows as the given instance, inserting with stream IDs pulled
         from the postgres sequence.
         """
 
+        if table is None:
+            table = self.tables[0]
+
+        factor = 1 if self.positive else -1
+
         def _insert(txn: LoggingTransaction) -> None:
             for _ in range(number):
                 next_val = self.seq_gen.get_next_id_txn(txn)
                 txn.execute(
-                    "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
-                    (
-                        next_val,
-                        instance_name,
-                    ),
+                    "INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
+                    % (table,),
+                    (next_val, instance_name),
                 )
 
-                txn.execute(
-                    """
-                    INSERT INTO stream_positions VALUES ('test_stream', ?,  ?)
-                    ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
-                    """,
-                    (instance_name, next_val, next_val),
-                )
+                if update_stream_table:
+                    txn.execute(
+                        """
+                        INSERT INTO stream_positions VALUES ('test_stream', ?,  ?)
+                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                        """,
+                        (instance_name, next_val * factor, next_val * factor),
+                    )
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
@@ -353,7 +418,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
 
         id_gen = self._create_id_generator("first", writers=["first", "second"])
 
-        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+        # When the writer is created, it assumes its own position is the current head of
+        # the sequence
+        self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
@@ -375,11 +442,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         correctly.
         """
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
+        self._replicate_all()
+
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
@@ -398,6 +467,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
                 self.assertEqual(
                     first_id_gen.get_positions(), {"first": 3, "second": 7}
                 )
+                self.assertEqual(
+                    second_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
                 self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
         self.get_success(_get_next_async())
@@ -432,11 +504,11 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         """
         # Insert some rows for two out of three of the ID gens.
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator(
             "first", writers=["first", "second", "third"]
         )
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator(
             "second", writers=["first", "second", "third"]
         )
@@ -444,6 +516,8 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
             "third", writers=["first", "second", "third"]
         )
 
+        self._replicate_all()
+
         self.assertEqual(
             first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
         )
@@ -546,11 +620,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
 
     def test_minimal_local_token(self) -> None:
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
+        self._replicate_all()
+
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
 
@@ -562,15 +638,17 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         token when there are no writes.
         """
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator(
             "first", writers=["first", "second", "third"]
         )
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator(
             "second", writers=["first", "second", "third"]
         )
 
+        self._replicate_all()
+
         self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(second_id_gen.get_current_token(), 7)
@@ -609,68 +687,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         self.assertEqual(second_id_gen.get_current_token(), 7)
 
 
-class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
     """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
 
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastores().main
-        self.db_pool: DatabasePool = self.store.db_pool
-
-        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
-    def _setup_db(self, txn: LoggingTransaction) -> None:
-        txn.execute("CREATE SEQUENCE foobar_seq")
-        txn.execute(
-            """
-            CREATE TABLE foobar (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-    def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
-    ) -> MultiWriterIdGenerator:
-        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
-            return MultiWriterIdGenerator(
-                conn,
-                self.db_pool,
-                notifier=self.hs.get_replication_notifier(),
-                stream_name="test_stream",
-                instance_name=instance_name,
-                tables=[("foobar", "instance_name", "stream_id")],
-                sequence_name="foobar_seq",
-                writers=writers or ["master"],
-                positive=False,
-            )
-
-        return self.get_success(self.db_pool.runWithConnection(_create))
-
-    def _insert_row(self, instance_name: str, stream_id: int) -> None:
-        """Insert one row as the given instance with given stream_id."""
-
-        def _insert(txn: LoggingTransaction) -> None:
-            txn.execute(
-                "INSERT INTO foobar VALUES (?, ?)",
-                (
-                    stream_id,
-                    instance_name,
-                ),
-            )
-            txn.execute(
-                """
-                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
-                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
-                """,
-                (instance_name, -stream_id, -stream_id),
-            )
-
-        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+    positive = False
 
     def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
@@ -716,7 +739,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         async def _get_next_async() -> None:
             async with id_gen_1.get_next() as stream_id:
                 self._insert_row("first", stream_id)
-                id_gen_2.advance("first", stream_id)
+            self._replicate("first")
 
         self.get_success(_get_next_async())
 
@@ -728,7 +751,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         async def _get_next_async2() -> None:
             async with id_gen_2.get_next() as stream_id:
                 self._insert_row("second", stream_id)
-                id_gen_1.advance("second", stream_id)
+            self._replicate("second")
 
         self.get_success(_get_next_async2())
 
@@ -738,98 +761,26 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
 
 
-class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastores().main
-        self.db_pool: DatabasePool = self.store.db_pool
-
-        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
-    def _setup_db(self, txn: LoggingTransaction) -> None:
-        txn.execute("CREATE SEQUENCE foobar_seq")
-        txn.execute(
-            """
-            CREATE TABLE foobar1 (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-        txn.execute(
-            """
-            CREATE TABLE foobar2 (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-    def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
-    ) -> MultiWriterIdGenerator:
-        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
-            return MultiWriterIdGenerator(
-                conn,
-                self.db_pool,
-                notifier=self.hs.get_replication_notifier(),
-                stream_name="test_stream",
-                instance_name=instance_name,
-                tables=[
-                    ("foobar1", "instance_name", "stream_id"),
-                    ("foobar2", "instance_name", "stream_id"),
-                ],
-                sequence_name="foobar_seq",
-                writers=writers or ["master"],
-            )
-
-        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
-
-    def _insert_rows(
-        self,
-        table: str,
-        instance_name: str,
-        number: int,
-        update_stream_table: bool = True,
-    ) -> None:
-        """Insert N rows as the given instance, inserting with stream IDs pulled
-        from the postgres sequence.
-        """
-
-        def _insert(txn: LoggingTransaction) -> None:
-            for _ in range(number):
-                txn.execute(
-                    "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
-                    (instance_name,),
-                )
-                if update_stream_table:
-                    txn.execute(
-                        """
-                        INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
-                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
-                        """,
-                        (instance_name,),
-                    )
-
-        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+    tables = ["foobar1", "foobar2"]
 
     def test_load_existing_stream(self) -> None:
         """Test creating ID gens with multiple tables that have rows from after
         the position in `stream_positions` table.
         """
-        self._insert_rows("foobar1", "first", 3)
-        self._insert_rows("foobar2", "second", 3)
-        self._insert_rows("foobar2", "second", 1, update_stream_table=False)
-
+        self._insert_rows("first", 3, table="foobar1")
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 3, table="foobar2")
+        self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
+        self._replicate_all()
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)