diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index fb8f5bc255..d4ff55fbff 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -43,16 +43,20 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
)
return self.get_success(self.db_pool.runWithConnection(_create))
@@ -68,6 +72,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
(instance_name,),
)
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
+ """,
+ (instance_name,),
+ )
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
@@ -81,6 +92,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
txn.execute("SELECT setval('foobar_seq', ?)", (stream_id,))
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, stream_id, stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert))
@@ -179,8 +197,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_rows("first", 3)
self._insert_rows("second", 4)
- first_id_gen = self._create_id_generator("first")
- second_id_gen = self._create_id_generator("second")
+ first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+ second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
@@ -262,7 +280,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -300,7 +318,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self._insert_row_with_id("first", 3)
self._insert_row_with_id("second", 5)
- id_gen = self._create_id_generator("first")
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
@@ -319,6 +337,80 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
+ def test_restart_during_out_of_order_persistence(self):
+ """Test that restarting a process while another process is writing out
+ of order updates are handled correctly.
+ """
+
+ # Prefill table with 7 rows written by 'master'
+ self._insert_rows("master", 7)
+
+ id_gen = self._create_id_generator()
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # Persist two rows at once
+ ctx1 = self.get_success(id_gen.get_next())
+ ctx2 = self.get_success(id_gen.get_next())
+
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+
+ self.assertEqual(s1, 8)
+ self.assertEqual(s2, 9)
+
+ self.assertEqual(id_gen.get_positions(), {"master": 7})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
+
+ # We finish persisting the second row before restart
+ self.get_success(ctx2.__aexit__(None, None, None))
+
+ # We simulate a restart of another worker by just creating a new ID gen.
+ id_gen_worker = self._create_id_generator("worker")
+
+ # Restarted worker should not see the second persisted row
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 7})
+ self.assertEqual(id_gen_worker.get_current_token_for_writer("master"), 7)
+
+ # Now if we persist the first row then both instances should jump ahead
+ # correctly.
+ self.get_success(ctx1.__aexit__(None, None, None))
+
+ self.assertEqual(id_gen.get_positions(), {"master": 9})
+ id_gen_worker.advance("master", 9)
+ self.assertEqual(id_gen_worker.get_positions(), {"master": 9})
+
+ def test_writer_config_change(self):
+ """Test that changing the writer config correctly works.
+ """
+
+ self._insert_row_with_id("first", 3)
+ self._insert_row_with_id("second", 5)
+
+ # Initial config has two writers
+ id_gen = self._create_id_generator("first", writers=["first", "second"])
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ # New config removes one of the configs. Note that if the writer is
+ # removed from config we assume that it has been shut down and has
+ # finished persisting, hence why the persisted upto position is 5.
+ id_gen_2 = self._create_id_generator("second", writers=["second"])
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
+
+ # This config points to a single, previously unused writer.
+ id_gen_3 = self._create_id_generator("third", writers=["third"])
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
+
+ # Check that we get a sane next stream ID with this new config.
+
+ async def _get_next_async():
+ async with id_gen_3.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+
+ self.get_success(_get_next_async())
+ self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
@@ -345,16 +437,20 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
)
- def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create_id_generator(
+ self, instance_name="master", writers=["master"]
+ ) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
+ stream_name="test_stream",
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
+ writers=writers,
positive=False,
)
@@ -368,6 +464,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
txn.execute(
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, -stream_id, -stream_id),
+ )
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
@@ -409,8 +512,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests that having multiple instances that get advanced over
federation works corretly.
"""
- id_gen_1 = self._create_id_generator("first")
- id_gen_2 = self._create_id_generator("second")
+ id_gen_1 = self._create_id_generator("first", writers=["first", "second"])
+ id_gen_2 = self._create_id_generator("second", writers=["first", "second"])
async def _get_next_async():
async with id_gen_1.get_next() as stream_id:
|