diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/storage/test_id_generators.py | 66 |
1 files changed, 41 insertions, 25 deletions
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 20636fc400..fb8f5bc255 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. async def _get_next_async(): - with await id_gen.get_next() as stream_id: + async with id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual(id_gen.get_positions(), {"master": 7}) @@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): ctx3 = self.get_success(id_gen.get_next()) ctx4 = self.get_success(id_gen.get_next()) - s1 = ctx1.__enter__() - s2 = ctx2.__enter__() - s3 = ctx3.__enter__() - s4 = ctx4.__enter__() + s1 = self.get_success(ctx1.__aenter__()) + s2 = self.get_success(ctx2.__aenter__()) + s3 = self.get_success(ctx3.__aenter__()) + s4 = self.get_success(ctx4.__aenter__()) self.assertEqual(s1, 8) self.assertEqual(s2, 9) @@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx2.__exit__(None, None, None) + self.get_success(ctx2.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) - ctx1.__exit__(None, None, None) + self.get_success(ctx1.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx4.__exit__(None, None, None) + self.get_success(ctx4.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 9}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 9) - ctx3.__exit__(None, None, None) + self.get_success(ctx3.__aexit__(None, None, None)) self.assertEqual(id_gen.get_positions(), {"master": 11}) self.assertEqual(id_gen.get_current_token_for_writer("master"), 11) @@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # advanced after we leave the context manager. async def _get_next_async(): - with await first_id_gen.get_next() as stream_id: + async with first_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 8) self.assertEqual( @@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # stream ID async def _get_next_async(): - with await second_id_gen.get_next() as stream_id: + async with second_id_gen.get_next() as stream_id: self.assertEqual(stream_id, 9) self.assertEqual( @@ -305,9 +305,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5}) self.assertEqual(id_gen.get_persisted_upto_position(), 3) - with self.get_success(id_gen.get_next()) as stream_id: - self.assertEqual(stream_id, 6) - self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self.assertEqual(stream_id, 6) + self.assertEqual(id_gen.get_persisted_upto_position(), 3) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_persisted_upto_position(), 6) @@ -373,16 +377,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): """ id_gen = self._create_id_generator() - with self.get_success(id_gen.get_next()) as stream_id: - self._insert_row("master", stream_id) + async def _get_next_async(): + async with id_gen.get_next() as stream_id: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen.get_positions(), {"master": -1}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) self.assertEqual(id_gen.get_persisted_upto_position(), -1) - with self.get_success(id_gen.get_next_mult(3)) as stream_ids: - for stream_id in stream_ids: - self._insert_row("master", stream_id) + async def _get_next_async2(): + async with id_gen.get_next_mult(3) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen.get_positions(), {"master": -4}) self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) @@ -402,18 +412,24 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): id_gen_1 = self._create_id_generator("first") id_gen_2 = self._create_id_generator("second") - with self.get_success(id_gen_1.get_next()) as stream_id: - self._insert_row("first", stream_id) - id_gen_2.advance("first", stream_id) + async def _get_next_async(): + async with id_gen_1.get_next() as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.get_success(_get_next_async()) self.assertEqual(id_gen_1.get_positions(), {"first": -1}) self.assertEqual(id_gen_2.get_positions(), {"first": -1}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) - with self.get_success(id_gen_2.get_next()) as stream_id: - self._insert_row("second", stream_id) - id_gen_1.advance("second", stream_id) + async def _get_next_async2(): + async with id_gen_2.get_next() as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.get_success(_get_next_async2()) self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) |