summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorQuentin Gliech <quenting@element.io>2024-06-17 13:50:00 +0200
committerGitHub <noreply@github.com>2024-06-17 11:50:00 +0000
commitf983a77ab070eac03f0eafe8dc6b990c43c3e89b (patch)
tree2728bbff222ebc487676de6cbf0c25b870588143 /tests
parentUse the release branch for sytest in release-branch PRs (#17306) (diff)
downloadsynapse-f983a77ab070eac03f0eafe8dc6b990c43c3e89b.tar.xz
Set our own stream position from the current sequence value on startup (#17309)
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_id_generators.py301
1 files changed, 126 insertions, 175 deletions
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9be2923e6f..12b89cecb6 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -18,7 +18,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -42,9 +42,13 @@ from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
 class MultiWriterIdGeneratorBase(HomeserverTestCase):
+    positive: bool = True
+    tables: List[str] = ["foobar"]
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
+        self.instances: Dict[str, MultiWriterIdGenerator] = {}
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
@@ -57,18 +61,22 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
         if USE_POSTGRES_FOR_TESTS:
             txn.execute("CREATE SEQUENCE foobar_seq")
 
-        txn.execute(
-            """
-            CREATE TABLE foobar (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
+        for table in self.tables:
+            txn.execute(
+                """
+                CREATE TABLE %s (
+                    stream_id BIGINT NOT NULL,
+                    instance_name TEXT NOT NULL,
+                    data TEXT
+                );
+                """
+                % (table,)
+            )
 
     def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
+        self,
+        instance_name: str = "master",
+        writers: Optional[List[str]] = None,
     ) -> MultiWriterIdGenerator:
         def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
             return MultiWriterIdGenerator(
@@ -77,36 +85,93 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
                 notifier=self.hs.get_replication_notifier(),
                 stream_name="test_stream",
                 instance_name=instance_name,
-                tables=[("foobar", "instance_name", "stream_id")],
+                tables=[(table, "instance_name", "stream_id") for table in self.tables],
                 sequence_name="foobar_seq",
                 writers=writers or ["master"],
+                positive=self.positive,
             )
 
-        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+        self.instances[instance_name] = self.get_success_or_raise(
+            self.db_pool.runWithConnection(_create)
+        )
+        return self.instances[instance_name]
+
+    def _replicate(self, instance_name: str) -> None:
+        """Similate a replication event for the given instance."""
+
+        writer = self.instances[instance_name]
+        token = writer.get_current_token_for_writer(instance_name)
+        for generator in self.instances.values():
+            if writer != generator:
+                generator.advance(instance_name, token)
+
+    def _replicate_all(self) -> None:
+        """Similate a replication event for all instances."""
 
-    def _insert_rows(self, instance_name: str, number: int) -> None:
+        for instance_name in self.instances:
+            self._replicate(instance_name)
+
+    def _insert_row(
+        self, instance_name: str, stream_id: int, table: Optional[str] = None
+    ) -> None:
+        """Insert one row as the given instance with given stream_id."""
+
+        if table is None:
+            table = self.tables[0]
+
+        factor = 1 if self.positive else -1
+
+        def _insert(txn: LoggingTransaction) -> None:
+            txn.execute(
+                "INSERT INTO %s VALUES (?, ?)" % (table,),
+                (
+                    stream_id,
+                    instance_name,
+                ),
+            )
+            txn.execute(
+                """
+                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                """,
+                (instance_name, stream_id * factor, stream_id * factor),
+            )
+
+        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+    def _insert_rows(
+        self,
+        instance_name: str,
+        number: int,
+        table: Optional[str] = None,
+        update_stream_table: bool = True,
+    ) -> None:
         """Insert N rows as the given instance, inserting with stream IDs pulled
         from the postgres sequence.
         """
 
+        if table is None:
+            table = self.tables[0]
+
+        factor = 1 if self.positive else -1
+
         def _insert(txn: LoggingTransaction) -> None:
             for _ in range(number):
                 next_val = self.seq_gen.get_next_id_txn(txn)
                 txn.execute(
-                    "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
-                    (
-                        next_val,
-                        instance_name,
-                    ),
+                    "INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
+                    % (table,),
+                    (next_val, instance_name),
                 )
 
-                txn.execute(
-                    """
-                    INSERT INTO stream_positions VALUES ('test_stream', ?,  ?)
-                    ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
-                    """,
-                    (instance_name, next_val, next_val),
-                )
+                if update_stream_table:
+                    txn.execute(
+                        """
+                        INSERT INTO stream_positions VALUES ('test_stream', ?,  ?)
+                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                        """,
+                        (instance_name, next_val * factor, next_val * factor),
+                    )
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
@@ -353,7 +418,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
 
         id_gen = self._create_id_generator("first", writers=["first", "second"])
 
-        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+        # When the writer is created, it assumes its own position is the current head of
+        # the sequence
+        self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
@@ -375,11 +442,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         correctly.
         """
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
+        self._replicate_all()
+
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
@@ -398,6 +467,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
                 self.assertEqual(
                     first_id_gen.get_positions(), {"first": 3, "second": 7}
                 )
+                self.assertEqual(
+                    second_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
                 self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
         self.get_success(_get_next_async())
@@ -432,11 +504,11 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         """
         # Insert some rows for two out of three of the ID gens.
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator(
             "first", writers=["first", "second", "third"]
         )
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator(
             "second", writers=["first", "second", "third"]
         )
@@ -444,6 +516,8 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
             "third", writers=["first", "second", "third"]
         )
 
+        self._replicate_all()
+
         self.assertEqual(
             first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
         )
@@ -546,11 +620,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
 
     def test_minimal_local_token(self) -> None:
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
+        self._replicate_all()
+
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
 
@@ -562,15 +638,17 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         token when there are no writes.
         """
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator(
             "first", writers=["first", "second", "third"]
         )
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator(
             "second", writers=["first", "second", "third"]
         )
 
+        self._replicate_all()
+
         self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(second_id_gen.get_current_token(), 7)
@@ -609,68 +687,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         self.assertEqual(second_id_gen.get_current_token(), 7)
 
 
-class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
     """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
 
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastores().main
-        self.db_pool: DatabasePool = self.store.db_pool
-
-        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
-    def _setup_db(self, txn: LoggingTransaction) -> None:
-        txn.execute("CREATE SEQUENCE foobar_seq")
-        txn.execute(
-            """
-            CREATE TABLE foobar (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-    def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
-    ) -> MultiWriterIdGenerator:
-        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
-            return MultiWriterIdGenerator(
-                conn,
-                self.db_pool,
-                notifier=self.hs.get_replication_notifier(),
-                stream_name="test_stream",
-                instance_name=instance_name,
-                tables=[("foobar", "instance_name", "stream_id")],
-                sequence_name="foobar_seq",
-                writers=writers or ["master"],
-                positive=False,
-            )
-
-        return self.get_success(self.db_pool.runWithConnection(_create))
-
-    def _insert_row(self, instance_name: str, stream_id: int) -> None:
-        """Insert one row as the given instance with given stream_id."""
-
-        def _insert(txn: LoggingTransaction) -> None:
-            txn.execute(
-                "INSERT INTO foobar VALUES (?, ?)",
-                (
-                    stream_id,
-                    instance_name,
-                ),
-            )
-            txn.execute(
-                """
-                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
-                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
-                """,
-                (instance_name, -stream_id, -stream_id),
-            )
-
-        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+    positive = False
 
     def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
@@ -716,7 +739,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         async def _get_next_async() -> None:
             async with id_gen_1.get_next() as stream_id:
                 self._insert_row("first", stream_id)
-                id_gen_2.advance("first", stream_id)
+            self._replicate("first")
 
         self.get_success(_get_next_async())
 
@@ -728,7 +751,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         async def _get_next_async2() -> None:
             async with id_gen_2.get_next() as stream_id:
                 self._insert_row("second", stream_id)
-                id_gen_1.advance("second", stream_id)
+            self._replicate("second")
 
         self.get_success(_get_next_async2())
 
@@ -738,98 +761,26 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
 
 
-class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastores().main
-        self.db_pool: DatabasePool = self.store.db_pool
-
-        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
-    def _setup_db(self, txn: LoggingTransaction) -> None:
-        txn.execute("CREATE SEQUENCE foobar_seq")
-        txn.execute(
-            """
-            CREATE TABLE foobar1 (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-        txn.execute(
-            """
-            CREATE TABLE foobar2 (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-    def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
-    ) -> MultiWriterIdGenerator:
-        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
-            return MultiWriterIdGenerator(
-                conn,
-                self.db_pool,
-                notifier=self.hs.get_replication_notifier(),
-                stream_name="test_stream",
-                instance_name=instance_name,
-                tables=[
-                    ("foobar1", "instance_name", "stream_id"),
-                    ("foobar2", "instance_name", "stream_id"),
-                ],
-                sequence_name="foobar_seq",
-                writers=writers or ["master"],
-            )
-
-        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
-
-    def _insert_rows(
-        self,
-        table: str,
-        instance_name: str,
-        number: int,
-        update_stream_table: bool = True,
-    ) -> None:
-        """Insert N rows as the given instance, inserting with stream IDs pulled
-        from the postgres sequence.
-        """
-
-        def _insert(txn: LoggingTransaction) -> None:
-            for _ in range(number):
-                txn.execute(
-                    "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
-                    (instance_name,),
-                )
-                if update_stream_table:
-                    txn.execute(
-                        """
-                        INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
-                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
-                        """,
-                        (instance_name,),
-                    )
-
-        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+    tables = ["foobar1", "foobar2"]
 
     def test_load_existing_stream(self) -> None:
         """Test creating ID gens with multiple tables that have rows from after
         the position in `stream_positions` table.
         """
-        self._insert_rows("foobar1", "first", 3)
-        self._insert_rows("foobar2", "second", 3)
-        self._insert_rows("foobar2", "second", 1, update_stream_table=False)
-
+        self._insert_rows("first", 3, table="foobar1")
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 3, table="foobar2")
+        self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
+        self._replicate_all()
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)