diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 14ce21c786..f0a8e32f1e 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
+
+
+class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+ """Tests MultiWriterIdGenerator that produce *negative* stream IDs.
+ """
+
+ if not USE_POSTGRES_FOR_TESTS:
+ skip = "Requires Postgres"
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.db_pool = self.store.db_pool # type: DatabasePool
+
+ self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
+
+ def _setup_db(self, txn):
+ txn.execute("CREATE SEQUENCE foobar_seq")
+ txn.execute(
+ """
+ CREATE TABLE foobar (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ )
+
+ def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
+ def _create(conn):
+ return MultiWriterIdGenerator(
+ conn,
+ self.db_pool,
+ instance_name=instance_name,
+ table="foobar",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="foobar_seq",
+ positive=False,
+ )
+
+ return self.get_success(self.db_pool.runWithConnection(_create))
+
+ def _insert_row(self, instance_name: str, stream_id: int):
+ """Insert one row as the given instance with given stream_id.
+ """
+
+ def _insert(txn):
+ txn.execute(
+ "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+ def test_single_instance(self):
+ """Test that reads and writes from a single process are handled
+ correctly.
+ """
+ id_gen = self._create_id_generator()
+
+ with self.get_success(id_gen.get_next()) as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -1})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.assertEqual(id_gen.get_positions(), {"master": -4})
+ self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(id_gen.get_persisted_upto_position(), -4)
+
+ # Test loading from DB by creating a second ID gen
+ second_id_gen = self._create_id_generator()
+
+ self.assertEqual(second_id_gen.get_positions(), {"master": -4})
+ self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
+ self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
+
+ def test_multiple_instance(self):
+ """Tests that having multiple instances that get advanced over
+ federation works corretly.
+ """
+ id_gen_1 = self._create_id_generator("first")
+ id_gen_2 = self._create_id_generator("second")
+
+ with self.get_success(id_gen_1.get_next()) as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
+
+ with self.get_success(id_gen_2.get_next()) as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
+ self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
+ self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
|