2 files changed, 14 insertions, 4 deletions
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 508aeba078..a298cc0fd3 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,6 +17,7 @@
"""
Utilities for running the unit tests
"""
+from asyncio import Future
from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -38,6 +39,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
-async def make_awaitable(result: Any):
- """Create an awaitable that just returns a result."""
- return result
+def make_awaitable(result: Any) -> Awaitable[Any]:
+ """
+ Makes an awaitable, suitable for mocking an `async` function.
+ This uses Futures as they can be awaited multiple times so can be returned
+ to multiple callers.
+ """
+ future = Future() # type: ignore
+ future.set_result(result)
+ return future
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index fb1ca90336..e93aa84405 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -71,7 +71,10 @@ async def inject_event(
"""
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- await hs.get_storage().persistence.persist_event(event, context)
+ persistence = hs.get_storage().persistence
+ assert persistence is not None
+
+ await persistence.persist_event(event, context)
return event
|