diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 20636fc400..fb8f5bc255 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
- with await id_gen.get_next() as stream_id:
+ async with id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
@@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
ctx3 = self.get_success(id_gen.get_next())
ctx4 = self.get_success(id_gen.get_next())
- s1 = ctx1.__enter__()
- s2 = ctx2.__enter__()
- s3 = ctx3.__enter__()
- s4 = ctx4.__enter__()
+ s1 = self.get_success(ctx1.__aenter__())
+ s2 = self.get_success(ctx2.__aenter__())
+ s3 = self.get_success(ctx3.__aenter__())
+ s4 = self.get_success(ctx4.__aenter__())
self.assertEqual(s1, 8)
self.assertEqual(s2, 9)
@@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- ctx2.__exit__(None, None, None)
+ self.get_success(ctx2.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
- ctx1.__exit__(None, None, None)
+ self.get_success(ctx1.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
- ctx4.__exit__(None, None, None)
+ self.get_success(ctx4.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 9})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
- ctx3.__exit__(None, None, None)
+ self.get_success(ctx3.__aexit__(None, None, None))
self.assertEqual(id_gen.get_positions(), {"master": 11})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
@@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# advanced after we leave the context manager.
async def _get_next_async():
- with await first_id_gen.get_next() as stream_id:
+ async with first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
@@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# stream ID
async def _get_next_async():
- with await second_id_gen.get_next() as stream_id:
+ async with second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
self.assertEqual(
@@ -305,9 +305,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
- with self.get_success(id_gen.get_next()) as stream_id:
- self.assertEqual(stream_id, 6)
- self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self.assertEqual(stream_id, 6)
+ self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
@@ -373,16 +377,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""
id_gen = self._create_id_generator()
- with self.get_success(id_gen.get_next()) as stream_id:
- self._insert_row("master", stream_id)
+ async def _get_next_async():
+ async with id_gen.get_next() as stream_id:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": -1})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
- with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
- for stream_id in stream_ids:
- self._insert_row("master", stream_id)
+ async def _get_next_async2():
+ async with id_gen.get_next_mult(3) as stream_ids:
+ for stream_id in stream_ids:
+ self._insert_row("master", stream_id)
+
+ self.get_success(_get_next_async2())
self.assertEqual(id_gen.get_positions(), {"master": -4})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
@@ -402,18 +412,24 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen_1 = self._create_id_generator("first")
id_gen_2 = self._create_id_generator("second")
- with self.get_success(id_gen_1.get_next()) as stream_id:
- self._insert_row("first", stream_id)
- id_gen_2.advance("first", stream_id)
+ async def _get_next_async():
+ async with id_gen_1.get_next() as stream_id:
+ self._insert_row("first", stream_id)
+ id_gen_2.advance("first", stream_id)
+
+ self.get_success(_get_next_async())
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
- with self.get_success(id_gen_2.get_next()) as stream_id:
- self._insert_row("second", stream_id)
- id_gen_1.advance("second", stream_id)
+ async def _get_next_async2():
+ async with id_gen_2.get_next() as stream_id:
+ self._insert_row("second", stream_id)
+ id_gen_1.advance("second", stream_id)
+
+ self.get_success(_get_next_async2())
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|