summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/storage/test_id_generators.py66
1 files changed, 41 insertions, 25 deletions
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 20636fc400..fb8f5bc255 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # advanced after we leave the context manager.
 
         async def _get_next_async():
-            with await id_gen.get_next() as stream_id:
+            async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(id_gen.get_positions(), {"master": 7})
@@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         ctx3 = self.get_success(id_gen.get_next())
         ctx4 = self.get_success(id_gen.get_next())
 
-        s1 = ctx1.__enter__()
-        s2 = ctx2.__enter__()
-        s3 = ctx3.__enter__()
-        s4 = ctx4.__enter__()
+        s1 = self.get_success(ctx1.__aenter__())
+        s2 = self.get_success(ctx2.__aenter__())
+        s3 = self.get_success(ctx3.__aenter__())
+        s4 = self.get_success(ctx4.__aenter__())
 
         self.assertEqual(s1, 8)
         self.assertEqual(s2, 9)
@@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 7})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        ctx2.__exit__(None, None, None)
+        self.get_success(ctx2.__aexit__(None, None, None))
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        ctx1.__exit__(None, None, None)
+        self.get_success(ctx1.__aexit__(None, None, None))
 
         self.assertEqual(id_gen.get_positions(), {"master": 9})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
 
-        ctx4.__exit__(None, None, None)
+        self.get_success(ctx4.__aexit__(None, None, None))
 
         self.assertEqual(id_gen.get_positions(), {"master": 9})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
 
-        ctx3.__exit__(None, None, None)
+        self.get_success(ctx3.__aexit__(None, None, None))
 
         self.assertEqual(id_gen.get_positions(), {"master": 11})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
@@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # advanced after we leave the context manager.
 
         async def _get_next_async():
-            with await first_id_gen.get_next() as stream_id:
+            async with first_id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(
@@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # stream ID
 
         async def _get_next_async():
-            with await second_id_gen.get_next() as stream_id:
+            async with second_id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 9)
 
                 self.assertEqual(
@@ -305,9 +305,13 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 3)
-        with self.get_success(id_gen.get_next()) as stream_id:
-            self.assertEqual(stream_id, 6)
-            self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        async def _get_next_async():
+            async with id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 6)
+                self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        self.get_success(_get_next_async())
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 6)
 
@@ -373,16 +377,22 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         """
         id_gen = self._create_id_generator()
 
-        with self.get_success(id_gen.get_next()) as stream_id:
-            self._insert_row("master", stream_id)
+        async def _get_next_async():
+            async with id_gen.get_next() as stream_id:
+                self._insert_row("master", stream_id)
+
+        self.get_success(_get_next_async())
 
         self.assertEqual(id_gen.get_positions(), {"master": -1})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
         self.assertEqual(id_gen.get_persisted_upto_position(), -1)
 
-        with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
-            for stream_id in stream_ids:
-                self._insert_row("master", stream_id)
+        async def _get_next_async2():
+            async with id_gen.get_next_mult(3) as stream_ids:
+                for stream_id in stream_ids:
+                    self._insert_row("master", stream_id)
+
+        self.get_success(_get_next_async2())
 
         self.assertEqual(id_gen.get_positions(), {"master": -4})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
@@ -402,18 +412,24 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen_1 = self._create_id_generator("first")
         id_gen_2 = self._create_id_generator("second")
 
-        with self.get_success(id_gen_1.get_next()) as stream_id:
-            self._insert_row("first", stream_id)
-            id_gen_2.advance("first", stream_id)
+        async def _get_next_async():
+            async with id_gen_1.get_next() as stream_id:
+                self._insert_row("first", stream_id)
+                id_gen_2.advance("first", stream_id)
+
+        self.get_success(_get_next_async())
 
         self.assertEqual(id_gen_1.get_positions(), {"first": -1})
         self.assertEqual(id_gen_2.get_positions(), {"first": -1})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
 
-        with self.get_success(id_gen_2.get_next()) as stream_id:
-            self._insert_row("second", stream_id)
-            id_gen_1.advance("second", stream_id)
+        async def _get_next_async2():
+            async with id_gen_2.get_next() as stream_id:
+                self._insert_row("second", stream_id)
+                id_gen_1.advance("second", stream_id)
+
+        self.get_success(_get_next_async2())
 
         self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})