summary refs log tree commit diff
path: root/tests/test_utils
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_utils')
-rw-r--r--tests/test_utils/__init__.py7
-rw-r--r--tests/test_utils/event_injection.py28
2 files changed, 14 insertions, 21 deletions
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03bb..508aeba078 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
 """
 Utilities for running the unit tests
 """
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
 
 TV = TypeVar("TV")
 
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
 
     # if next didn't raise, the awaitable hasn't completed.
     raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+    """Create an awaitable that just returns a result."""
+    return result
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 43297b530c..8522c6fc09 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -22,14 +22,12 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import Collection
 
-from tests.test_utils import get_awaitable_result
-
 """
 Utility functions for poking events into the storage of the server under test.
 """
 
 
-def inject_member_event(
+async def inject_member_event(
     hs: synapse.server.HomeServer,
     room_id: str,
     sender: str,
@@ -46,7 +44,7 @@ def inject_member_event(
     if extra_content:
         content.update(extra_content)
 
-    return inject_event(
+    return await inject_event(
         hs,
         room_id=room_id,
         type=EventTypes.Member,
@@ -57,7 +55,7 @@ def inject_member_event(
     )
 
 
-def inject_event(
+async def inject_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[Collection[str]] = None,
@@ -72,37 +70,27 @@ def inject_event(
         prev_event_ids: prev_events for the event. If not specified, will be looked up
         kwargs: fields for the event to be created
     """
-    test_reactor = hs.get_reactor()
-
-    event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
+    event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
 
-    d = hs.get_storage().persistence.persist_event(event, context)
-    test_reactor.advance(0)
-    get_awaitable_result(d)
+    await hs.get_storage().persistence.persist_event(event, context)
 
     return event
 
 
-def create_event(
+async def create_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[Collection[str]] = None,
     **kwargs
 ) -> Tuple[EventBase, EventContext]:
-    test_reactor = hs.get_reactor()
-
     if room_version is None:
-        d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
-        test_reactor.advance(0)
-        room_version = get_awaitable_result(d)
+        room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
 
     builder = hs.get_event_builder_factory().for_room_version(
         KNOWN_ROOM_VERSIONS[room_version], kwargs
     )
-    d = hs.get_event_creation_handler().create_new_client_event(
+    event, context = await hs.get_event_creation_handler().create_new_client_event(
         builder, prev_event_ids=prev_event_ids
     )
-    test_reactor.advance(0)
-    event, context = get_awaitable_result(d)
 
     return event, context