diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03bb..508aeba078 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
"""
Utilities for running the unit tests
"""
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+ """Create an awaitable that just returns a result."""
+ return result
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 431e9f8e5e..fb1ca90336 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -13,25 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Collection
-
-from tests.test_utils import get_awaitable_result
-
"""
Utility functions for poking events into the storage of the server under test.
"""
-def inject_member_event(
+async def inject_member_event(
hs: synapse.server.HomeServer,
room_id: str,
sender: str,
@@ -48,7 +43,7 @@ def inject_member_event(
if extra_content:
content.update(extra_content)
- return inject_event(
+ return await inject_event(
hs,
room_id=room_id,
type=EventTypes.Member,
@@ -59,10 +54,10 @@ def inject_member_event(
)
-def inject_event(
+async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
@@ -74,37 +69,27 @@ def inject_event(
prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created
"""
- test_reactor = hs.get_reactor()
+ event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
- event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
-
- d = hs.get_storage().persistence.persist_event(event, context)
- test_reactor.advance(0)
- get_awaitable_result(d)
+ await hs.get_storage().persistence.persist_event(event, context)
return event
-def create_event(
+async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
- prev_event_ids: Optional[Collection[str]] = None,
+ prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
- test_reactor = hs.get_reactor()
-
if room_version is None:
- d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
- test_reactor.advance(0)
- room_version = get_awaitable_result(d)
+ room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
- d = hs.get_event_creation_handler().create_new_client_event(
+ event, context = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
- test_reactor.advance(0)
- event, context = get_awaitable_result(d)
return event, context
|