summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_id_generators.py315
1 files changed, 135 insertions, 180 deletions
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index f0307252f3..12b89cecb6 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -18,7 +18,7 @@
 # [This file includes modifications made by New Vector Limited]
 #
 #
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -28,7 +28,6 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.engines import IncorrectDatabaseSetup
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.storage.util.sequence import (
@@ -43,9 +42,13 @@ from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
 class MultiWriterIdGeneratorBase(HomeserverTestCase):
+    positive: bool = True
+    tables: List[str] = ["foobar"]
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.db_pool: DatabasePool = self.store.db_pool
+        self.instances: Dict[str, MultiWriterIdGenerator] = {}
 
         self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
 
@@ -58,18 +61,22 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
         if USE_POSTGRES_FOR_TESTS:
             txn.execute("CREATE SEQUENCE foobar_seq")
 
-        txn.execute(
-            """
-            CREATE TABLE foobar (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
+        for table in self.tables:
+            txn.execute(
+                """
+                CREATE TABLE %s (
+                    stream_id BIGINT NOT NULL,
+                    instance_name TEXT NOT NULL,
+                    data TEXT
+                );
+                """
+                % (table,)
+            )
 
     def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
+        self,
+        instance_name: str = "master",
+        writers: Optional[List[str]] = None,
     ) -> MultiWriterIdGenerator:
         def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
             return MultiWriterIdGenerator(
@@ -78,36 +85,93 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
                 notifier=self.hs.get_replication_notifier(),
                 stream_name="test_stream",
                 instance_name=instance_name,
-                tables=[("foobar", "instance_name", "stream_id")],
+                tables=[(table, "instance_name", "stream_id") for table in self.tables],
                 sequence_name="foobar_seq",
                 writers=writers or ["master"],
+                positive=self.positive,
+            )
+
+        self.instances[instance_name] = self.get_success_or_raise(
+            self.db_pool.runWithConnection(_create)
+        )
+        return self.instances[instance_name]
+
+    def _replicate(self, instance_name: str) -> None:
+        """Similate a replication event for the given instance."""
+
+        writer = self.instances[instance_name]
+        token = writer.get_current_token_for_writer(instance_name)
+        for generator in self.instances.values():
+            if writer != generator:
+                generator.advance(instance_name, token)
+
+    def _replicate_all(self) -> None:
+        """Similate a replication event for all instances."""
+
+        for instance_name in self.instances:
+            self._replicate(instance_name)
+
+    def _insert_row(
+        self, instance_name: str, stream_id: int, table: Optional[str] = None
+    ) -> None:
+        """Insert one row as the given instance with given stream_id."""
+
+        if table is None:
+            table = self.tables[0]
+
+        factor = 1 if self.positive else -1
+
+        def _insert(txn: LoggingTransaction) -> None:
+            txn.execute(
+                "INSERT INTO %s VALUES (?, ?)" % (table,),
+                (
+                    stream_id,
+                    instance_name,
+                ),
+            )
+            txn.execute(
+                """
+                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                """,
+                (instance_name, stream_id * factor, stream_id * factor),
             )
 
-        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
 
-    def _insert_rows(self, instance_name: str, number: int) -> None:
+    def _insert_rows(
+        self,
+        instance_name: str,
+        number: int,
+        table: Optional[str] = None,
+        update_stream_table: bool = True,
+    ) -> None:
         """Insert N rows as the given instance, inserting with stream IDs pulled
         from the postgres sequence.
         """
 
+        if table is None:
+            table = self.tables[0]
+
+        factor = 1 if self.positive else -1
+
         def _insert(txn: LoggingTransaction) -> None:
             for _ in range(number):
                 next_val = self.seq_gen.get_next_id_txn(txn)
                 txn.execute(
-                    "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
-                    (
-                        next_val,
-                        instance_name,
-                    ),
+                    "INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
+                    % (table,),
+                    (next_val, instance_name),
                 )
 
-                txn.execute(
-                    """
-                    INSERT INTO stream_positions VALUES ('test_stream', ?,  ?)
-                    ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
-                    """,
-                    (instance_name, next_val, next_val),
-                )
+                if update_stream_table:
+                    txn.execute(
+                        """
+                        INSERT INTO stream_positions VALUES ('test_stream', ?,  ?)
+                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                        """,
+                        (instance_name, next_val * factor, next_val * factor),
+                    )
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
@@ -354,7 +418,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
 
         id_gen = self._create_id_generator("first", writers=["first", "second"])
 
-        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+        # When the writer is created, it assumes its own position is the current head of
+        # the sequence
+        self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
@@ -376,11 +442,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         correctly.
         """
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
+        self._replicate_all()
+
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
@@ -399,6 +467,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
                 self.assertEqual(
                     first_id_gen.get_positions(), {"first": 3, "second": 7}
                 )
+                self.assertEqual(
+                    second_id_gen.get_positions(), {"first": 3, "second": 7}
+                )
                 self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
         self.get_success(_get_next_async())
@@ -433,11 +504,11 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         """
         # Insert some rows for two out of three of the ID gens.
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator(
             "first", writers=["first", "second", "third"]
         )
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator(
             "second", writers=["first", "second", "third"]
         )
@@ -445,6 +516,8 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
             "third", writers=["first", "second", "third"]
         )
 
+        self._replicate_all()
+
         self.assertEqual(
             first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
         )
@@ -525,7 +598,7 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
 
     def test_sequence_consistency(self) -> None:
-        """Test that we error out if the table and sequence diverges."""
+        """Test that we correct the sequence if the table and sequence diverges."""
 
         # Prefill with some rows
         self._insert_row_with_id("master", 3)
@@ -536,17 +609,24 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
 
         self.get_success(self.db_pool.runInteraction("_insert", _insert))
 
-        # Creating the ID gen should error
-        with self.assertRaises(IncorrectDatabaseSetup):
-            self._create_id_generator("first")
+        # Creating the ID gen should now fix the inconsistency
+        id_gen = self._create_id_generator()
+
+        async def _get_next_async() -> None:
+            async with id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 27)
+
+        self.get_success(_get_next_async())
 
     def test_minimal_local_token(self) -> None:
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
+        self._replicate_all()
+
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
 
@@ -558,15 +638,17 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         token when there are no writes.
         """
         self._insert_rows("first", 3)
-        self._insert_rows("second", 4)
-
         first_id_gen = self._create_id_generator(
             "first", writers=["first", "second", "third"]
         )
+
+        self._insert_rows("second", 4)
         second_id_gen = self._create_id_generator(
             "second", writers=["first", "second", "third"]
         )
 
+        self._replicate_all()
+
         self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(second_id_gen.get_current_token(), 7)
@@ -605,68 +687,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
         self.assertEqual(second_id_gen.get_current_token(), 7)
 
 
-class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
     """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
 
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastores().main
-        self.db_pool: DatabasePool = self.store.db_pool
-
-        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
-    def _setup_db(self, txn: LoggingTransaction) -> None:
-        txn.execute("CREATE SEQUENCE foobar_seq")
-        txn.execute(
-            """
-            CREATE TABLE foobar (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-    def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
-    ) -> MultiWriterIdGenerator:
-        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
-            return MultiWriterIdGenerator(
-                conn,
-                self.db_pool,
-                notifier=self.hs.get_replication_notifier(),
-                stream_name="test_stream",
-                instance_name=instance_name,
-                tables=[("foobar", "instance_name", "stream_id")],
-                sequence_name="foobar_seq",
-                writers=writers or ["master"],
-                positive=False,
-            )
-
-        return self.get_success(self.db_pool.runWithConnection(_create))
-
-    def _insert_row(self, instance_name: str, stream_id: int) -> None:
-        """Insert one row as the given instance with given stream_id."""
-
-        def _insert(txn: LoggingTransaction) -> None:
-            txn.execute(
-                "INSERT INTO foobar VALUES (?, ?)",
-                (
-                    stream_id,
-                    instance_name,
-                ),
-            )
-            txn.execute(
-                """
-                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
-                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
-                """,
-                (instance_name, -stream_id, -stream_id),
-            )
-
-        self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+    positive = False
 
     def test_single_instance(self) -> None:
         """Test that reads and writes from a single process are handled
@@ -712,7 +739,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         async def _get_next_async() -> None:
             async with id_gen_1.get_next() as stream_id:
                 self._insert_row("first", stream_id)
-                id_gen_2.advance("first", stream_id)
+            self._replicate("first")
 
         self.get_success(_get_next_async())
 
@@ -724,7 +751,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         async def _get_next_async2() -> None:
             async with id_gen_2.get_next() as stream_id:
                 self._insert_row("second", stream_id)
-                id_gen_1.advance("second", stream_id)
+            self._replicate("second")
 
         self.get_success(_get_next_async2())
 
@@ -734,98 +761,26 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
 
 
-class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
     if not USE_POSTGRES_FOR_TESTS:
         skip = "Requires Postgres"
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
-        self.store = hs.get_datastores().main
-        self.db_pool: DatabasePool = self.store.db_pool
-
-        self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
-    def _setup_db(self, txn: LoggingTransaction) -> None:
-        txn.execute("CREATE SEQUENCE foobar_seq")
-        txn.execute(
-            """
-            CREATE TABLE foobar1 (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-        txn.execute(
-            """
-            CREATE TABLE foobar2 (
-                stream_id BIGINT NOT NULL,
-                instance_name TEXT NOT NULL,
-                data TEXT
-            );
-            """
-        )
-
-    def _create_id_generator(
-        self, instance_name: str = "master", writers: Optional[List[str]] = None
-    ) -> MultiWriterIdGenerator:
-        def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
-            return MultiWriterIdGenerator(
-                conn,
-                self.db_pool,
-                notifier=self.hs.get_replication_notifier(),
-                stream_name="test_stream",
-                instance_name=instance_name,
-                tables=[
-                    ("foobar1", "instance_name", "stream_id"),
-                    ("foobar2", "instance_name", "stream_id"),
-                ],
-                sequence_name="foobar_seq",
-                writers=writers or ["master"],
-            )
-
-        return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
-
-    def _insert_rows(
-        self,
-        table: str,
-        instance_name: str,
-        number: int,
-        update_stream_table: bool = True,
-    ) -> None:
-        """Insert N rows as the given instance, inserting with stream IDs pulled
-        from the postgres sequence.
-        """
-
-        def _insert(txn: LoggingTransaction) -> None:
-            for _ in range(number):
-                txn.execute(
-                    "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
-                    (instance_name,),
-                )
-                if update_stream_table:
-                    txn.execute(
-                        """
-                        INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
-                        ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
-                        """,
-                        (instance_name,),
-                    )
-
-        self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+    tables = ["foobar1", "foobar2"]
 
     def test_load_existing_stream(self) -> None:
         """Test creating ID gens with multiple tables that have rows from after
         the position in `stream_positions` table.
         """
-        self._insert_rows("foobar1", "first", 3)
-        self._insert_rows("foobar2", "second", 3)
-        self._insert_rows("foobar2", "second", 1, update_stream_table=False)
-
+        self._insert_rows("first", 3, table="foobar1")
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+        self._insert_rows("second", 3, table="foobar2")
+        self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
+        self._replicate_all()
+
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
         self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)