summary refs log tree commit diff
path: root/tests/test_utils
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_utils')
-rw-r--r--tests/test_utils/__init__.py20
-rw-r--r--tests/test_utils/event_injection.py96
2 files changed, 116 insertions, 0 deletions
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index a7310cf12a..7b345b03bb 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,3 +17,22 @@
 """
 Utilities for running the unit tests
 """
+from typing import Awaitable, TypeVar
+
+TV = TypeVar("TV")
+
+
+def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
+    """Get the result from an Awaitable which should have completed
+
+    Asserts that the given awaitable has a result ready, and returns its value
+    """
+    i = awaitable.__await__()
+    try:
+        next(i)
+    except StopIteration as e:
+        # awaitable returned a result
+        return e.value
+
+    # if next didn't raise, the awaitable hasn't completed.
+    raise Exception("awaitable has not yet completed")
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
new file mode 100644
index 0000000000..8f6872761a
--- /dev/null
+++ b/tests/test_utils/event_injection.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import synapse.server
+from synapse.api.constants import EventTypes
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import EventBase
+from synapse.types import Collection
+
+from tests.test_utils import get_awaitable_result
+
+
+"""
+Utility functions for poking events into the storage of the server under test.
+"""
+
+
+def inject_member_event(
+    hs: synapse.server.HomeServer,
+    room_id: str,
+    sender: str,
+    membership: str,
+    target: Optional[str] = None,
+    extra_content: Optional[dict] = None,
+    **kwargs
+) -> EventBase:
+    """Inject a membership event into a room."""
+    if target is None:
+        target = sender
+
+    content = {"membership": membership}
+    if extra_content:
+        content.update(extra_content)
+
+    return inject_event(
+        hs,
+        room_id=room_id,
+        type=EventTypes.Member,
+        sender=sender,
+        state_key=target,
+        content=content,
+        **kwargs
+    )
+
+
+def inject_event(
+    hs: synapse.server.HomeServer,
+    room_version: Optional[str] = None,
+    prev_event_ids: Optional[Collection[str]] = None,
+    **kwargs
+) -> EventBase:
+    """Inject a generic event into a room
+
+    Args:
+        hs: the homeserver under test
+        room_version: the version of the room we're inserting into.
+            if not specified, will be looked up
+        prev_event_ids: prev_events for the event. If not specified, will be looked up
+        kwargs: fields for the event to be created
+    """
+    test_reactor = hs.get_reactor()
+
+    if room_version is None:
+        d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
+        test_reactor.advance(0)
+        room_version = get_awaitable_result(d)
+
+    builder = hs.get_event_builder_factory().for_room_version(
+        KNOWN_ROOM_VERSIONS[room_version], kwargs
+    )
+    d = hs.get_event_creation_handler().create_new_client_event(
+        builder, prev_event_ids=prev_event_ids
+    )
+    test_reactor.advance(0)
+    event, context = get_awaitable_result(d)
+
+    d = hs.get_storage().persistence.persist_event(event, context)
+    test_reactor.advance(0)
+    get_awaitable_result(d)
+
+    return event