summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_id_generators.py119
1 files changed, 111 insertions, 8 deletions
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index fb8f5bc255..d4ff55fbff 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
             """
         )
 
-    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+    def _create_id_generator(
+        self, instance_name="master", writers=["master"]
+    ) -> MultiWriterIdGenerator:
         def _create(conn):
             return MultiWriterIdGenerator(
                 conn,
                 self.db_pool,
+                stream_name="test_stream",
                 instance_name=instance_name,
                 table="foobar",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="foobar_seq",
+                writers=writers,
             )
 
         return self.get_success(self.db_pool.runWithConnection(_create))
@@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                     "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
                     (instance_name,),
                 )
+                txn.execute(
+                    """
+                    INSERT INTO stream_positions VALUES ('test_stream', ?,  lastval())
+                    ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+                    """,
+                    (instance_name,),
+                )
 
         self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
 
@@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
             )
             txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+            txn.execute(
+                """
+                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                """,
+                (instance_name, stream_id, stream_id),
+            )
 
         self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
 
@@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_rows("first", 3)
         self._insert_rows("second", 4)
 
-        first_id_gen = self._create_id_generator("first")
-        second_id_gen = self._create_id_generator("second")
+        first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+        second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
         self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
@@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("first", 3)
         self._insert_row_with_id("second", 5)
 
-        id_gen = self._create_id_generator("first")
+        id_gen = self._create_id_generator("first", writers=["first", "second"])
 
         self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
 
@@ -300,7 +318,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self._insert_row_with_id("first", 3)
         self._insert_row_with_id("second", 5)
 
-        id_gen = self._create_id_generator("first")
+        id_gen = self._create_id_generator("first", writers=["first", "second"])
 
         self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
 
@@ -319,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # `persisted_upto_position` in this case, then it will be correct in the
         # other cases that are tested above (since they'll hit the same code).
 
+    def test_restart_during_out_of_order_persistence(self):
+        """Test that restarting a process while another process is writing out
+        of order updates are handled correctly.
+        """
+
+        # Prefill table with 7 rows written by 'master'
+        self._insert_rows("master", 7)
+
+        id_gen = self._create_id_generator()
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+        # Persist two rows at once
+        ctx1 = self.get_success(id_gen.get_next())
+        ctx2 = self.get_success(id_gen.get_next())
+
+        s1 = self.get_success(ctx1.__aenter__())
+        s2 = self.get_success(ctx2.__aenter__())
+
+        self.assertEqual(s1, 8)
+        self.assertEqual(s2, 9)
+
+        self.assertEqual(id_gen.get_positions(), {"master": 7})
+        self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+        # We finish persisting the second row before restart
+        self.get_success(ctx2.__aexit__(None, None, None))
+
+        # We simulate a restart of another worker by just creating a new ID gen.
+        id_gen_worker = self._create_id_generator("worker")
+
+        # Restarted worker should not see the second persisted row
+        self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+        self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+
+        # Now if we persist the first row then both instances should jump ahead
+        # correctly.
+        self.get_success(ctx1.__aexit__(None, None, None))
+
+        self.assertEqual(id_gen.get_positions(), {"master": 9})
+        id_gen_worker.advance("master", 9)
+        self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+
+    def test_writer_config_change(self):
+        """Test that changing the writer config correctly works.
+        """
+
+        self._insert_row_with_id("first", 3)
+        self._insert_row_with_id("second", 5)
+
+        # Initial config has two writers
+        id_gen = self._create_id_generator("first", writers=["first", "second"])
+        self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        # New config removes one of the configs. Note that if the writer is
+        # removed from config we assume that it has been shut down and has
+        # finished persisting, hence why the persisted upto position is 5.
+        id_gen_2 = self._create_id_generator("second", writers=["second"])
+        self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
+
+        # This config points to a single, previously unused writer.
+        id_gen_3 = self._create_id_generator("third", writers=["third"])
+        self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
+
+        # Check that we get a sane next stream ID with this new config.
+
+        async def _get_next_async():
+            async with id_gen_3.get_next() as stream_id:
+                self.assertEqual(stream_id, 6)
+
+        self.get_success(_get_next_async())
+        self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+
 
 class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
     """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
@@ -345,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
             """
         )
 
-    def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+    def _create_id_generator(
+        self, instance_name="master", writers=["master"]
+    ) -> MultiWriterIdGenerator:
         def _create(conn):
             return MultiWriterIdGenerator(
                 conn,
                 self.db_pool,
+                stream_name="test_stream",
                 instance_name=instance_name,
                 table="foobar",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="foobar_seq",
+                writers=writers,
                 positive=False,
             )
 
@@ -368,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
             txn.execute(
                 "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
             )
+            txn.execute(
+                """
+                INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+                ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+                """,
+                (instance_name, -stream_id, -stream_id),
+            )
 
         self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
 
@@ -409,8 +512,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         """Tests that having multiple instances that get advanced over
         federation works corretly.
         """
-        id_gen_1 = self._create_id_generator("first")
-        id_gen_2 = self._create_id_generator("second")
+        id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
+        id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
 
         async def _get_next_async():
             async with id_gen_1.get_next() as stream_id: