From d16910ca021320f0fa09c6cf82a802ee97e22a0c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 May 2024 12:07:32 +0100 Subject: Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator` (#17229) Replaces all usages of `StreamIdGenerator` with `MultiWriterIdGenerator`, which is safer. --- tests/storage/test_id_generators.py | 140 +----------------------------------- 1 file changed, 1 insertion(+), 139 deletions(-) (limited to 'tests') diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index fad9511cea..f0307252f3 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -30,7 +30,7 @@ from synapse.storage.database import ( ) from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.types import Cursor -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import MultiWriterIdGenerator from synapse.storage.util.sequence import ( LocalSequenceGenerator, PostgresSequenceGenerator, @@ -42,144 +42,6 @@ from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS -class StreamIdGeneratorTestCase(HomeserverTestCase): - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.store = hs.get_datastores().main - self.db_pool: DatabasePool = self.store.db_pool - - self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - - def _setup_db(self, txn: LoggingTransaction) -> None: - txn.execute( - """ - CREATE TABLE foobar ( - stream_id BIGINT NOT NULL, - data TEXT - ); - """ - ) - txn.execute("INSERT INTO foobar VALUES (123, 'hello world');") - - def _create_id_generator(self) -> StreamIdGenerator: - def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator: - return StreamIdGenerator( - db_conn=conn, - notifier=self.hs.get_replication_notifier(), - table="foobar", - column="stream_id", - ) - - return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) - - def test_initial_value(self) -> None: - """Check that we read the current token from the DB.""" - id_gen = self._create_id_generator() - self.assertEqual(id_gen.get_current_token(), 123) - - def test_single_gen_next(self) -> None: - """Check that we correctly increment the current token from the DB.""" - id_gen = self._create_id_generator() - - async def test_gen_next() -> None: - async with id_gen.get_next() as next_id: - # We haven't persisted `next_id` yet; current token is still 123 - self.assertEqual(id_gen.get_current_token(), 123) - # But we did learn what the next value is - self.assertEqual(next_id, 124) - - # Once the context manager closes we assume that the `next_id` has been - # written to the DB. - self.assertEqual(id_gen.get_current_token(), 124) - - self.get_success(test_gen_next()) - - def test_multiple_gen_nexts(self) -> None: - """Check that we handle overlapping calls to gen_next sensibly.""" - id_gen = self._create_id_generator() - - async def test_gen_next() -> None: - ctx1 = id_gen.get_next() - ctx2 = id_gen.get_next() - ctx3 = id_gen.get_next() - - # Request three new stream IDs. - self.assertEqual(await ctx1.__aenter__(), 124) - self.assertEqual(await ctx2.__aenter__(), 125) - self.assertEqual(await ctx3.__aenter__(), 126) - - # None are persisted: current token unchanged. - self.assertEqual(id_gen.get_current_token(), 123) - - # Persist each in turn. - await ctx1.__aexit__(None, None, None) - self.assertEqual(id_gen.get_current_token(), 124) - await ctx2.__aexit__(None, None, None) - self.assertEqual(id_gen.get_current_token(), 125) - await ctx3.__aexit__(None, None, None) - self.assertEqual(id_gen.get_current_token(), 126) - - self.get_success(test_gen_next()) - - def test_multiple_gen_nexts_closed_in_different_order(self) -> None: - """Check that we handle overlapping calls to gen_next, even when their IDs - created and persisted in different orders.""" - id_gen = self._create_id_generator() - - async def test_gen_next() -> None: - ctx1 = id_gen.get_next() - ctx2 = id_gen.get_next() - ctx3 = id_gen.get_next() - - # Request three new stream IDs. - self.assertEqual(await ctx1.__aenter__(), 124) - self.assertEqual(await ctx2.__aenter__(), 125) - self.assertEqual(await ctx3.__aenter__(), 126) - - # None are persisted: current token unchanged. - self.assertEqual(id_gen.get_current_token(), 123) - - # Persist them in a different order, starting with 126 from ctx3. - await ctx3.__aexit__(None, None, None) - # We haven't persisted 124 from ctx1 yet---current token is still 123. - self.assertEqual(id_gen.get_current_token(), 123) - - # Now persist 124 from ctx1. - await ctx1.__aexit__(None, None, None) - # Current token is then 124, waiting for 125 to be persisted. - self.assertEqual(id_gen.get_current_token(), 124) - - # Finally persist 125 from ctx2. - await ctx2.__aexit__(None, None, None) - # Current token is then 126 (skipping over 125). - self.assertEqual(id_gen.get_current_token(), 126) - - self.get_success(test_gen_next()) - - def test_gen_next_while_still_waiting_for_persistence(self) -> None: - """Check that we handle overlapping calls to gen_next.""" - id_gen = self._create_id_generator() - - async def test_gen_next() -> None: - ctx1 = id_gen.get_next() - ctx2 = id_gen.get_next() - ctx3 = id_gen.get_next() - - # Request two new stream IDs. - self.assertEqual(await ctx1.__aenter__(), 124) - self.assertEqual(await ctx2.__aenter__(), 125) - - # Persist ctx2 first. - await ctx2.__aexit__(None, None, None) - # Still waiting on ctx1's ID to be persisted. - self.assertEqual(id_gen.get_current_token(), 123) - - # Now request a third stream ID. It should be 126 (the smallest ID that - # we've not yet handed out.) - self.assertEqual(await ctx3.__aenter__(), 126) - - self.get_success(test_gen_next()) - - class MultiWriterIdGeneratorBase(HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main -- cgit 1.4.1