summary refs log tree commit diff
path: root/tests/storage/test_id_generators.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_id_generators.py')
-rw-r--r--tests/storage/test_id_generators.py80
1 files changed, 42 insertions, 38 deletions
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py

index 6ac4b93f98..395396340b 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py
@@ -13,9 +13,13 @@ # limitations under the License. from typing import List, Optional -from synapse.storage.database import DatabasePool +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.util import Clock from tests.unittest import HomeserverTestCase from tests.utils import USE_POSTGRES_FOR_TESTS @@ -25,13 +29,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -59,12 +63,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) - def _insert_rows(self, instance_name: str, number: int): + def _insert_rows(self, instance_name: str, number: int) -> None: """Insert N rows as the given instance, inserting with stream IDs pulled from the postgres sequence. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: for _ in range(number): txn.execute( "INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)", @@ -80,12 +84,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def _insert_row_with_id(self, instance_name: str, stream_id: int): + def _insert_row_with_id(self, instance_name: str, stream_id: int) -> None: """Insert one row as the given instance with given stream_id, updating the postgres sequence position to match. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: txn.execute( "INSERT INTO foobar VALUES (?, ?)", ( @@ -104,7 +108,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_row_with_id", _insert)) - def test_empty(self): + def test_empty(self) -> None: """Test an ID generator against an empty database gives sensible current positions. """ @@ -114,7 +118,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # The table is empty so we expect an empty map for positions self.assertEqual(id_gen.get_positions(), {}) - def test_single_instance(self): + def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled correctly. """ @@ -130,7 +134,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) @@ -142,7 +146,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - def test_out_of_order_finish(self): + def test_out_of_order_finish(self) -> None: """Test that IDs persisted out of order are correctly handled""" # Prefill table with 7 rows written by 'master' @@ -191,7 +195,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) - def test_multi_instance(self): + def test_multi_instance(self) -> None: """Test that reads and writes from multiple processes are handled correctly. """ @@ -215,7 +219,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - async def _get_next_async(): + async def _get_next_async() -> None: async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) @@ -233,7 +237,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # ... but calling `get_next` on the second instance should give a unique # stream ID - async def _get_next_async(): + async def _get_next_async2() -> None: async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) @@ -241,7 +245,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen.get_positions(), {"first": 3, "second": 7} ) - self.get_success(_get_next_async()) + self.get_success(_get_next_async2()) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9}) @@ -249,7 +253,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): second_id_gen.advance("first", 8) self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) - def test_get_next_txn(self): + def test_get_next_txn(self) -> None: """Test that the `get_next_txn` function works correctly.""" # Prefill table with 7 rows written by 'master' @@ -263,7 +267,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Try allocating a new ID gen and check that we only see position # advanced after we leave the context manager. - def _get_next_txn(txn): + def _get_next_txn(txn: LoggingTransaction) -> None: stream_id = id_gen.get_next_txn(txn) self.assertEqual(stream_id, 8) @@ -275,7 +279,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 8}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 8) - def test_get_persisted_upto_position(self): + def test_get_persisted_upto_position(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to positions. """ @@ -317,7 +321,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen.advance("second", 15) self.assertEqual(id_gen.get_persisted_upto_position(), 11) - def test_get_persisted_upto_position_get_next(self): + def test_get_persisted_upto_position_get_next(self) -> None: """Test that `get_persisted_upto_position` correctly tracks updates to positions when `get_next` is called. """ @@ -331,7 +335,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_persisted_upto_position(), 5) - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 6) self.assertEqual(id_gen.get_persisted_upto_position(), 5) @@ -344,7 +348,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). - def test_restart_during_out_of_order_persistence(self): + def test_restart_during_out_of_order_persistence(self) -> None: """Test that restarting a process while another process is writing out of order updates are handled correctly. """ @@ -388,7 +392,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen_worker.advance("master", 9) self.assertEqual(id_gen_worker.get_positions(), {"master": 9}) - def test_writer_config_change(self): + def test_writer_config_change(self) -> None: """Test that changing the writer config correctly works.""" self._insert_row_with_id("first", 3) @@ -421,7 +425,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # Check that we get a sane next stream ID with this new config. - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen_3.get_next() as stream_id: self.assertEqual(stream_id, 6) @@ -435,7 +439,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_5.get_current_token_for_writer("first"), 6) self.assertEqual(id_gen_5.get_current_token_for_writer("third"), 6) - def test_sequence_consistency(self): + def test_sequence_consistency(self) -> None: """Test that we error out if the table and sequence diverges.""" # Prefill with some rows @@ -458,13 +462,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -493,10 +497,10 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): return self.get_success(self.db_pool.runWithConnection(_create)) - def _insert_row(self, instance_name: str, stream_id: int): + def _insert_row(self, instance_name: str, stream_id: int) -> None: """Insert one row as the given instance with given stream_id.""" - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: txn.execute( "INSERT INTO foobar VALUES (?, ?)", ( @@ -514,13 +518,13 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) - def test_single_instance(self): + def test_single_instance(self) -> None: """Test that reads and writes from a single process are handled correctly. """ id_gen = self._create_id_generator() - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen.get_next() as stream_id: self._insert_row("master", stream_id) @@ -530,7 +534,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - async def _get_next_async2(): + async def _get_next_async2() -> None: async with id_gen.get_next_mult(3) as stream_ids: for stream_id in stream_ids: self._insert_row("master", stream_id) @@ -548,14 +552,14 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) - def test_multiple_instance(self): + def test_multiple_instance(self) -> None: """Tests that having multiple instances that get advanced over federation works corretly. """ id_gen_1 = self._create_id_generator("first", writers=["first", "second"]) id_gen_2 = self._create_id_generator("second", writers=["first", "second"]) - async def _get_next_async(): + async def _get_next_async() -> None: async with id_gen_1.get_next() as stream_id: self._insert_row("first", stream_id) id_gen_2.advance("first", stream_id) @@ -567,7 +571,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - async def _get_next_async2(): + async def _get_next_async2() -> None: async with id_gen_2.get_next() as stream_id: self._insert_row("second", stream_id) id_gen_1.advance("second", stream_id) @@ -584,13 +588,13 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): if not USE_POSTGRES_FOR_TESTS: skip = "Requires Postgres" - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.db_pool: DatabasePool = self.store.db_pool self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) - def _setup_db(self, txn): + def _setup_db(self, txn: LoggingTransaction) -> None: txn.execute("CREATE SEQUENCE foobar_seq") txn.execute( """ @@ -642,7 +646,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): from the postgres sequence. """ - def _insert(txn): + def _insert(txn: LoggingTransaction) -> None: for _ in range(number): txn.execute( "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), @@ -659,7 +663,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) - def test_load_existing_stream(self): + def test_load_existing_stream(self) -> None: """Test creating ID gens with multiple tables that have rows from after the position in `stream_positions` table. """