diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 4558bee7be..392b08832b 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -390,17 +390,28 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# Initial config has two writers
id_gen = self._create_id_generator("first", writers=["first", "second"])
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+ self.assertEqual(id_gen.get_current_token_for_writer("first"), 3)
+ self.assertEqual(id_gen.get_current_token_for_writer("second"), 5)
# New config removes one of the configs. Note that if the writer is
# removed from config we assume that it has been shut down and has
# finished persisting, hence why the persisted upto position is 5.
id_gen_2 = self._create_id_generator("second", writers=["second"])
self.assertEqual(id_gen_2.get_persisted_upto_position(), 5)
+ self.assertEqual(id_gen_2.get_current_token_for_writer("second"), 5)
# This config points to a single, previously unused writer.
id_gen_3 = self._create_id_generator("third", writers=["third"])
self.assertEqual(id_gen_3.get_persisted_upto_position(), 5)
+ # For new writers we assume their initial position to be the current
+ # persisted up to position. This stops Synapse from doing a full table
+ # scan when a new writer comes along.
+ self.assertEqual(id_gen_3.get_current_token_for_writer("third"), 5)
+
+ id_gen_4 = self._create_id_generator("fourth", writers=["third"])
+ self.assertEqual(id_gen_4.get_current_token_for_writer("third"), 5)
+
# Check that we get a sane next stream ID with this new config.
async def _get_next_async():
@@ -410,6 +421,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async())
self.assertEqual(id_gen_3.get_persisted_upto_position(), 6)
+ # If we add back the old "first" then we shouldn't see the persisted up
+ # to position revert back to 3.
+ id_gen_5 = self._create_id_generator("five", writers=["first", "third"])
+ self.assertEqual(id_gen_5.get_persisted_upto_position(), 6)
+ self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6)
+ self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6)
+
def test_sequence_consistency(self):
"""Test that we error out if the table and sequence diverges.
"""
|