summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7884.misc1
-rw-r--r--synapse/handlers/message.py288
-rw-r--r--tests/events/test_snapshot.py36
-rw-r--r--tests/replication/tcp/streams/test_events.py76
-rw-r--r--tests/storage/test_roommember.py56
-rw-r--r--tests/storage/test_state.py4
-rw-r--r--tests/test_utils/event_injection.py28
-rw-r--r--tests/test_visibility.py14
-rw-r--r--tests/unittest.py4
-rw-r--r--tests/utils.py4
10 files changed, 273 insertions, 238 deletions
diff --git a/changelog.d/7884.misc b/changelog.d/7884.misc
new file mode 100644
index 0000000000..36c7d4de67
--- /dev/null
+++ b/changelog.d/7884.misc
@@ -0,0 +1 @@
+Convert the message handler to async/await.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c47764a4ce..172a7214b2 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,12 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json, json
 
-from twisted.internet import defer
-from twisted.internet.defer import succeed
 from twisted.internet.interfaces import IDelayedCall
 
 from synapse import event_auth
@@ -41,13 +39,22 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.api.urls import ConsentURIBuilder
 from synapse.events import EventBase
+from synapse.events.builder import EventBuilder
+from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import Collection, RoomAlias, UserID, create_requester
+from synapse.types import (
+    Collection,
+    Requester,
+    RoomAlias,
+    StreamToken,
+    UserID,
+    create_requester,
+)
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.metrics import measure_func
@@ -84,14 +91,22 @@ class MessageHandler(object):
                 "_schedule_next_expiry", self._schedule_next_expiry
             )
 
-    @defer.inlineCallbacks
-    def get_room_data(
-        self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
-    ):
+    async def get_room_data(
+        self,
+        user_id: str = None,
+        room_id: str = None,
+        event_type: Optional[str] = None,
+        state_key: str = "",
+        is_guest: bool = False,
+    ) -> dict:
         """ Get data from a room.
 
         Args:
-            event : The room path event
+            user_id
+            room_id
+            event_type
+            state_key
+            is_guest
         Returns:
             The path data content.
         Raises:
@@ -100,30 +115,29 @@ class MessageHandler(object):
         (
             membership,
             membership_event_id,
-        ) = yield self.auth.check_user_in_room_or_world_readable(
+        ) = await self.auth.check_user_in_room_or_world_readable(
             room_id, user_id, allow_departed_users=True
         )
 
         if membership == Membership.JOIN:
-            data = yield self.state.get_current_state(room_id, event_type, state_key)
+            data = await self.state.get_current_state(room_id, event_type, state_key)
         elif membership == Membership.LEAVE:
             key = (event_type, state_key)
-            room_state = yield self.state_store.get_state_for_events(
+            room_state = await self.state_store.get_state_for_events(
                 [membership_event_id], StateFilter.from_types([key])
             )
             data = room_state[membership_event_id].get(key)
 
         return data
 
-    @defer.inlineCallbacks
-    def get_state_events(
+    async def get_state_events(
         self,
-        user_id,
-        room_id,
-        state_filter=StateFilter.all(),
-        at_token=None,
-        is_guest=False,
-    ):
+        user_id: str,
+        room_id: str,
+        state_filter: StateFilter = StateFilter.all(),
+        at_token: Optional[StreamToken] = None,
+        is_guest: bool = False,
+    ) -> List[dict]:
         """Retrieve all state events for a given room. If the user is
         joined to the room then return the current state. If the user has
         left the room return the state events from when they left. If an explicit
@@ -131,15 +145,14 @@ class MessageHandler(object):
         visible.
 
         Args:
-            user_id(str): The user requesting state events.
-            room_id(str): The room ID to get all state events from.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-            at_token(StreamToken|None): the stream token of the at which we are requesting
+            user_id: The user requesting state events.
+            room_id: The room ID to get all state events from.
+            state_filter: The state filter used to fetch state from the database.
+            at_token: the stream token of the at which we are requesting
                 the stats. If the user is not allowed to view the state as of that
                 stream token, we raise a 403 SynapseError. If None, returns the current
                 state based on the current_state_events table.
-            is_guest(bool): whether this user is a guest
+            is_guest: whether this user is a guest
         Returns:
             A list of dicts representing state events. [{}, {}, {}]
         Raises:
@@ -153,20 +166,20 @@ class MessageHandler(object):
             # get_recent_events_for_room operates by topo ordering. This therefore
             # does not reliably give you the state at the given stream position.
             # (https://github.com/matrix-org/synapse/issues/3305)
-            last_events, _ = yield self.store.get_recent_events_for_room(
+            last_events, _ = await self.store.get_recent_events_for_room(
                 room_id, end_token=at_token.room_key, limit=1
             )
 
             if not last_events:
                 raise NotFoundError("Can't find event for token %s" % (at_token,))
 
-            visible_events = yield filter_events_for_client(
+            visible_events = await filter_events_for_client(
                 self.storage, user_id, last_events, filter_send_to_client=False
             )
 
             event = last_events[0]
             if visible_events:
-                room_state = yield self.state_store.get_state_for_events(
+                room_state = await self.state_store.get_state_for_events(
                     [event.event_id], state_filter=state_filter
                 )
                 room_state = room_state[event.event_id]
@@ -180,23 +193,23 @@ class MessageHandler(object):
             (
                 membership,
                 membership_event_id,
-            ) = yield self.auth.check_user_in_room_or_world_readable(
+            ) = await self.auth.check_user_in_room_or_world_readable(
                 room_id, user_id, allow_departed_users=True
             )
 
             if membership == Membership.JOIN:
-                state_ids = yield self.store.get_filtered_current_state_ids(
+                state_ids = await self.store.get_filtered_current_state_ids(
                     room_id, state_filter=state_filter
                 )
-                room_state = yield self.store.get_events(state_ids.values())
+                room_state = await self.store.get_events(state_ids.values())
             elif membership == Membership.LEAVE:
-                room_state = yield self.state_store.get_state_for_events(
+                room_state = await self.state_store.get_state_for_events(
                     [membership_event_id], state_filter=state_filter
                 )
                 room_state = room_state[membership_event_id]
 
         now = self.clock.time_msec()
-        events = yield self._event_serializer.serialize_events(
+        events = await self._event_serializer.serialize_events(
             room_state.values(),
             now,
             # We don't bother bundling aggregations in when asked for state
@@ -205,15 +218,14 @@ class MessageHandler(object):
         )
         return events
 
-    @defer.inlineCallbacks
-    def get_joined_members(self, requester, room_id):
+    async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
         """Get all the joined members in the room and their profile information.
 
         If the user has left the room return the state events from when they left.
 
         Args:
-            requester(Requester): The user requesting state events.
-            room_id(str): The room ID to get all state events from.
+            requester: The user requesting state events.
+            room_id: The room ID to get all state events from.
         Returns:
             A dict of user_id to profile info
         """
@@ -221,7 +233,7 @@ class MessageHandler(object):
         if not requester.app_service:
             # We check AS auth after fetching the room membership, as it
             # requires us to pull out all joined members anyway.
-            membership, _ = yield self.auth.check_user_in_room_or_world_readable(
+            membership, _ = await self.auth.check_user_in_room_or_world_readable(
                 room_id, user_id, allow_departed_users=True
             )
             if membership != Membership.JOIN:
@@ -229,7 +241,7 @@ class MessageHandler(object):
                     "Getting joined members after leaving is not implemented"
                 )
 
-        users_with_profile = yield self.state.get_current_users_in_room(room_id)
+        users_with_profile = await self.state.get_current_users_in_room(room_id)
 
         # If this is an AS, double check that they are allowed to see the members.
         # This can either be because the AS user is in the room or because there
@@ -250,7 +262,7 @@ class MessageHandler(object):
             for user_id, profile in users_with_profile.items()
         }
 
-    def maybe_schedule_expiry(self, event):
+    def maybe_schedule_expiry(self, event: EventBase):
         """Schedule the expiry of an event if there's not already one scheduled,
         or if the one running is for an event that will expire after the provided
         timestamp.
@@ -259,7 +271,7 @@ class MessageHandler(object):
         the master process, and therefore needs to be run on there.
 
         Args:
-            event (EventBase): The event to schedule the expiry of.
+            event: The event to schedule the expiry of.
         """
 
         expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
@@ -270,8 +282,7 @@ class MessageHandler(object):
         # a task scheduled for a timestamp that's sooner than the provided one.
         self._schedule_expiry_for_event(event.event_id, expiry_ts)
 
-    @defer.inlineCallbacks
-    def _schedule_next_expiry(self):
+    async def _schedule_next_expiry(self):
         """Retrieve the ID and the expiry timestamp of the next event to be expired,
         and schedule an expiry task for it.
 
@@ -279,18 +290,18 @@ class MessageHandler(object):
         future call to save_expiry_ts can schedule a new expiry task.
         """
         # Try to get the expiry timestamp of the next event to expire.
-        res = yield self.store.get_next_event_to_expire()
+        res = await self.store.get_next_event_to_expire()
         if res:
             event_id, expiry_ts = res
             self._schedule_expiry_for_event(event_id, expiry_ts)
 
-    def _schedule_expiry_for_event(self, event_id, expiry_ts):
+    def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
         """Schedule an expiry task for the provided event if there's not already one
         scheduled at a timestamp that's sooner than the provided one.
 
         Args:
-            event_id (str): The ID of the event to expire.
-            expiry_ts (int): The timestamp at which to expire the event.
+            event_id: The ID of the event to expire.
+            expiry_ts: The timestamp at which to expire the event.
         """
         if self._scheduled_expiry:
             # If the provided timestamp refers to a time before the scheduled time of the
@@ -320,8 +331,7 @@ class MessageHandler(object):
             event_id,
         )
 
-    @defer.inlineCallbacks
-    def _expire_event(self, event_id):
+    async def _expire_event(self, event_id: str):
         """Retrieve and expire an event that needs to be expired from the database.
 
         If the event doesn't exist in the database, log it and delete the expiry date
@@ -336,12 +346,12 @@ class MessageHandler(object):
         try:
             # Expire the event if we know about it. This function also deletes the expiry
             # date from the database in the same database transaction.
-            yield self.store.expire_event(event_id)
+            await self.store.expire_event(event_id)
         except Exception as e:
             logger.error("Could not expire event %s: %r", event_id, e)
 
         # Schedule the expiry of the next event to expire.
-        yield self._schedule_next_expiry()
+        await self._schedule_next_expiry()
 
 
 # The duration (in ms) after which rooms should be removed
@@ -423,16 +433,15 @@ class EventCreationHandler(object):
 
         self._dummy_events_threshold = hs.config.dummy_events_threshold
 
-    @defer.inlineCallbacks
-    def create_event(
+    async def create_event(
         self,
-        requester,
-        event_dict,
-        token_id=None,
-        txn_id=None,
+        requester: Requester,
+        event_dict: dict,
+        token_id: Optional[str] = None,
+        txn_id: Optional[str] = None,
         prev_event_ids: Optional[Collection[str]] = None,
-        require_consent=True,
-    ):
+        require_consent: bool = True,
+    ) -> Tuple[EventBase, EventContext]:
         """
         Given a dict from a client, create a new event.
 
@@ -443,31 +452,29 @@ class EventCreationHandler(object):
 
         Args:
             requester
-            event_dict (dict): An entire event
-            token_id (str)
-            txn_id (str)
-
+            event_dict: An entire event
+            token_id
+            txn_id
             prev_event_ids:
                 the forward extremities to use as the prev_events for the
                 new event.
 
                 If None, they will be requested from the database.
-
-            require_consent (bool): Whether to check if the requester has
-                consented to privacy policy.
+            require_consent: Whether to check if the requester has
+                consented to the privacy policy.
         Raises:
             ResourceLimitError if server is blocked to some resource being
             exceeded
         Returns:
-            Tuple of created event (FrozenEvent), Context
+            Tuple of created event, Context
         """
-        yield self.auth.check_auth_blocking(requester.user.to_string())
+        await self.auth.check_auth_blocking(requester.user.to_string())
 
         if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
             room_version = event_dict["content"]["room_version"]
         else:
             try:
-                room_version = yield self.store.get_room_version_id(
+                room_version = await self.store.get_room_version_id(
                     event_dict["room_id"]
                 )
             except NotFoundError:
@@ -488,15 +495,11 @@ class EventCreationHandler(object):
 
                 try:
                     if "displayname" not in content:
-                        displayname = yield defer.ensureDeferred(
-                            profile.get_displayname(target)
-                        )
+                        displayname = await profile.get_displayname(target)
                         if displayname is not None:
                             content["displayname"] = displayname
                     if "avatar_url" not in content:
-                        avatar_url = yield defer.ensureDeferred(
-                            profile.get_avatar_url(target)
-                        )
+                        avatar_url = await profile.get_avatar_url(target)
                         if avatar_url is not None:
                             content["avatar_url"] = avatar_url
                 except Exception as e:
@@ -504,9 +507,9 @@ class EventCreationHandler(object):
                         "Failed to get profile information for %r: %s", target, e
                     )
 
-        is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
+        is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
         if require_consent and not is_exempt:
-            yield self.assert_accepted_privacy_policy(requester)
+            await self.assert_accepted_privacy_policy(requester)
 
         if token_id is not None:
             builder.internal_metadata.token_id = token_id
@@ -514,7 +517,7 @@ class EventCreationHandler(object):
         if txn_id is not None:
             builder.internal_metadata.txn_id = txn_id
 
-        event, context = yield self.create_new_client_event(
+        event, context = await self.create_new_client_event(
             builder=builder, requester=requester, prev_event_ids=prev_event_ids,
         )
 
@@ -530,10 +533,10 @@ class EventCreationHandler(object):
             # federation as well as those created locally. As of room v3, aliases events
             # can be created by users that are not in the room, therefore we have to
             # tolerate them in event_auth.check().
-            prev_state_ids = yield context.get_prev_state_ids()
+            prev_state_ids = await context.get_prev_state_ids()
             prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
             prev_event = (
-                yield self.store.get_event(prev_event_id, allow_none=True)
+                await self.store.get_event(prev_event_id, allow_none=True)
                 if prev_event_id
                 else None
             )
@@ -556,37 +559,36 @@ class EventCreationHandler(object):
 
         return (event, context)
 
-    def _is_exempt_from_privacy_policy(self, builder, requester):
+    async def _is_exempt_from_privacy_policy(
+        self, builder: EventBuilder, requester: Requester
+    ) -> bool:
         """"Determine if an event to be sent is exempt from having to consent
         to the privacy policy
 
         Args:
-            builder (synapse.events.builder.EventBuilder): event being created
-            requester (Requster): user requesting this event
+            builder: event being created
+            requester: user requesting this event
 
         Returns:
-            Deferred[bool]: true if the event can be sent without the user
-                consenting
+            true if the event can be sent without the user consenting
         """
         # the only thing the user can do is join the server notices room.
         if builder.type == EventTypes.Member:
             membership = builder.content.get("membership", None)
             if membership == Membership.JOIN:
-                return self._is_server_notices_room(builder.room_id)
+                return await self._is_server_notices_room(builder.room_id)
             elif membership == Membership.LEAVE:
                 # the user is always allowed to leave (but not kick people)
                 return builder.state_key == requester.user.to_string()
-        return succeed(False)
+        return False
 
-    @defer.inlineCallbacks
-    def _is_server_notices_room(self, room_id):
+    async def _is_server_notices_room(self, room_id: str) -> bool:
         if self.config.server_notices_mxid is None:
             return False
-        user_ids = yield self.store.get_users_in_room(room_id)
+        user_ids = await self.store.get_users_in_room(room_id)
         return self.config.server_notices_mxid in user_ids
 
-    @defer.inlineCallbacks
-    def assert_accepted_privacy_policy(self, requester):
+    async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
         """Check if a user has accepted the privacy policy
 
         Called when the given user is about to do something that requires
@@ -595,12 +597,10 @@ class EventCreationHandler(object):
         raised.
 
         Args:
-            requester (synapse.types.Requester):
-                The user making the request
+            requester: The user making the request
 
         Returns:
-            Deferred[None]: returns normally if the user has consented or is
-                exempt
+            Returns normally if the user has consented or is exempt
 
         Raises:
             ConsentNotGivenError: if the user has not given consent yet
@@ -621,7 +621,7 @@ class EventCreationHandler(object):
         ):
             return
 
-        u = yield self.store.get_user_by_id(user_id)
+        u = await self.store.get_user_by_id(user_id)
         assert u is not None
         if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
             # support and bot users are not required to consent
@@ -639,16 +639,20 @@ class EventCreationHandler(object):
         raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
 
     async def send_nonmember_event(
-        self, requester, event, context, ratelimit=True
+        self,
+        requester: Requester,
+        event: EventBase,
+        context: EventContext,
+        ratelimit: bool = True,
     ) -> int:
         """
         Persists and notifies local clients and federation of an event.
 
         Args:
-            event (FrozenEvent) the event to send.
-            context (Context) the context of the event.
-            ratelimit (bool): Whether to rate limit this send.
-            is_guest (bool): Whether the sender is a guest.
+            requester
+            event the event to send.
+            context: the context of the event.
+            ratelimit: Whether to rate limit this send.
 
         Return:
             The stream_id of the persisted event.
@@ -676,19 +680,20 @@ class EventCreationHandler(object):
             requester=requester, event=event, context=context, ratelimit=ratelimit
         )
 
-    @defer.inlineCallbacks
-    def deduplicate_state_event(self, event, context):
+    async def deduplicate_state_event(
+        self, event: EventBase, context: EventContext
+    ) -> None:
         """
         Checks whether event is in the latest resolved state in context.
 
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids()
         prev_event_id = prev_state_ids.get((event.type, event.state_key))
         if not prev_event_id:
             return
-        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+        prev_event = await self.store.get_event(prev_event_id, allow_none=True)
         if not prev_event:
             return
 
@@ -700,7 +705,11 @@ class EventCreationHandler(object):
         return
 
     async def create_and_send_nonmember_event(
-        self, requester, event_dict, ratelimit=True, txn_id=None
+        self,
+        requester: Requester,
+        event_dict: EventBase,
+        ratelimit: bool = True,
+        txn_id: Optional[str] = None,
     ) -> Tuple[EventBase, int]:
         """
         Creates an event, then sends it.
@@ -730,17 +739,17 @@ class EventCreationHandler(object):
         return event, stream_id
 
     @measure_func("create_new_client_event")
-    @defer.inlineCallbacks
-    def create_new_client_event(
-        self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None
-    ):
+    async def create_new_client_event(
+        self,
+        builder: EventBuilder,
+        requester: Optional[Requester] = None,
+        prev_event_ids: Optional[Collection[str]] = None,
+    ) -> Tuple[EventBase, EventContext]:
         """Create a new event for a local client
 
         Args:
-            builder (EventBuilder):
-
-            requester (synapse.types.Requester|None):
-
+            builder:
+            requester:
             prev_event_ids:
                 the forward extremities to use as the prev_events for the
                 new event.
@@ -748,7 +757,7 @@ class EventCreationHandler(object):
                 If None, they will be requested from the database.
 
         Returns:
-            Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
+            Tuple of created event, context
         """
 
         if prev_event_ids is not None:
@@ -757,10 +766,10 @@ class EventCreationHandler(object):
                 % (len(prev_event_ids),)
             )
         else:
-            prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id)
+            prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
 
-        event = yield builder.build(prev_event_ids=prev_event_ids)
-        context = yield self.state.compute_event_context(event)
+        event = await builder.build(prev_event_ids=prev_event_ids)
+        context = await self.state.compute_event_context(event)
         if requester:
             context.app_service = requester.app_service
 
@@ -774,7 +783,7 @@ class EventCreationHandler(object):
             relates_to = relation["event_id"]
             aggregation_key = relation["key"]
 
-            already_exists = yield self.store.has_user_annotated_event(
+            already_exists = await self.store.has_user_annotated_event(
                 relates_to, event.type, aggregation_key, event.sender
             )
             if already_exists:
@@ -786,7 +795,12 @@ class EventCreationHandler(object):
 
     @measure_func("handle_new_client_event")
     async def handle_new_client_event(
-        self, requester, event, context, ratelimit=True, extra_users=[]
+        self,
+        requester: Requester,
+        event: EventBase,
+        context: EventContext,
+        ratelimit: bool = True,
+        extra_users: List[UserID] = [],
     ) -> int:
         """Processes a new event. This includes checking auth, persisting it,
         notifying users, sending to remote servers, etc.
@@ -795,11 +809,11 @@ class EventCreationHandler(object):
         processing.
 
         Args:
-            requester (Requester)
-            event (FrozenEvent)
-            context (EventContext)
-            ratelimit (bool)
-            extra_users (list(UserID)): Any extra users to notify about event
+            requester
+            event
+            context
+            ratelimit
+            extra_users: Any extra users to notify about event
 
         Return:
             The stream_id of the persisted event.
@@ -878,10 +892,9 @@ class EventCreationHandler(object):
                     self.store.remove_push_actions_from_staging, event.event_id
                 )
 
-    @defer.inlineCallbacks
-    def _validate_canonical_alias(
-        self, directory_handler, room_alias_str, expected_room_id
-    ):
+    async def _validate_canonical_alias(
+        self, directory_handler, room_alias_str: str, expected_room_id: str
+    ) -> None:
         """
         Ensure that the given room alias points to the expected room ID.
 
@@ -892,9 +905,7 @@ class EventCreationHandler(object):
         """
         room_alias = RoomAlias.from_string(room_alias_str)
         try:
-            mapping = yield defer.ensureDeferred(
-                directory_handler.get_association(room_alias)
-            )
+            mapping = await directory_handler.get_association(room_alias)
         except SynapseError as e:
             # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
             if e.errcode == Codes.NOT_FOUND:
@@ -913,7 +924,12 @@ class EventCreationHandler(object):
             )
 
     async def persist_and_notify_client_event(
-        self, requester, event, context, ratelimit=True, extra_users=[]
+        self,
+        requester: Requester,
+        event: EventBase,
+        context: EventContext,
+        ratelimit: bool = True,
+        extra_users: List[UserID] = [],
     ) -> int:
         """Called when we have fully built the event, have already
         calculated the push actions for the event, and checked auth.
@@ -1106,7 +1122,7 @@ class EventCreationHandler(object):
 
         return event_stream_id
 
-    async def _bump_active_time(self, user):
+    async def _bump_active_time(self, user: UserID) -> None:
         try:
             presence = self.hs.get_presence_handler()
             await presence.bump_presence_active_time(user)
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 640f5f3bce..3a80626224 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -41,8 +41,10 @@ class TestEventContext(unittest.HomeserverTestCase):
         serialize/deserialize.
         """
 
-        event, context = create_event(
-            self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+        event, context = self.get_success(
+            create_event(
+                self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
+            )
         )
 
         self._check_serialize_deserialize(event, context)
@@ -51,12 +53,14 @@ class TestEventContext(unittest.HomeserverTestCase):
         """Test that an EventContext for a state event (with not previous entry)
         is the same after serialize/deserialize.
         """
-        event, context = create_event(
-            self.hs,
-            room_id=self.room_id,
-            type="m.test",
-            sender=self.user_id,
-            state_key="",
+        event, context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                type="m.test",
+                sender=self.user_id,
+                state_key="",
+            )
         )
 
         self._check_serialize_deserialize(event, context)
@@ -65,13 +69,15 @@ class TestEventContext(unittest.HomeserverTestCase):
         """Test that an EventContext for a state event (which replaces a
         previous entry) is the same after serialize/deserialize.
         """
-        event, context = create_event(
-            self.hs,
-            room_id=self.room_id,
-            type="m.room.member",
-            sender=self.user_id,
-            state_key=self.user_id,
-            content={"membership": "leave"},
+        event, context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                type="m.room.member",
+                sender=self.user_id,
+                state_key=self.user_id,
+                content={"membership": "leave"},
+            )
         )
 
         self._check_serialize_deserialize(event, context)
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 097e1653b4..c9998e88e6 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -119,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
         OTHER_USER = "@other_user:localhost"
 
         # have the user join
-        inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+        self.get_success(
+            inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+        )
 
         # Update existing power levels with mod at PL50
         pls = self.helper.get_state(
@@ -157,14 +159,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
         # roll back all the state by de-modding the user
         prev_events = fork_point
         pls["users"][OTHER_USER] = 0
-        pl_event = inject_event(
-            self.hs,
-            prev_event_ids=prev_events,
-            type=EventTypes.PowerLevels,
-            state_key="",
-            sender=self.user_id,
-            room_id=self.room_id,
-            content=pls,
+        pl_event = self.get_success(
+            inject_event(
+                self.hs,
+                prev_event_ids=prev_events,
+                type=EventTypes.PowerLevels,
+                state_key="",
+                sender=self.user_id,
+                room_id=self.room_id,
+                content=pls,
+            )
         )
 
         # one more bit of state that doesn't get rolled back
@@ -268,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         # have the users join
         for u in user_ids:
-            inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+            self.get_success(
+                inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+            )
 
         # Update existing power levels with mod at PL50
         pls = self.helper.get_state(
@@ -306,14 +312,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
         pl_events = []
         for u in user_ids:
             pls["users"][u] = 0
-            e = inject_event(
-                self.hs,
-                prev_event_ids=prev_events,
-                type=EventTypes.PowerLevels,
-                state_key="",
-                sender=self.user_id,
-                room_id=self.room_id,
-                content=pls,
+            e = self.get_success(
+                inject_event(
+                    self.hs,
+                    prev_event_ids=prev_events,
+                    type=EventTypes.PowerLevels,
+                    state_key="",
+                    sender=self.user_id,
+                    room_id=self.room_id,
+                    content=pls,
+                )
             )
             prev_events = [e.event_id]
             pl_events.append(e)
@@ -434,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase):
             body = "event %i" % (self.event_count,)
             self.event_count += 1
 
-        return inject_event(
-            self.hs,
-            room_id=self.room_id,
-            sender=sender,
-            type="test_event",
-            content={"body": body},
-            **kwargs
+        return self.get_success(
+            inject_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=sender,
+                type="test_event",
+                content={"body": body},
+                **kwargs
+            )
         )
 
     def _inject_state_event(
@@ -459,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase):
         if body is None:
             body = "state event %s" % (state_key,)
 
-        return inject_event(
-            self.hs,
-            room_id=self.room_id,
-            sender=sender,
-            type="test_state_event",
-            state_key=state_key,
-            content={"body": body},
+        return self.get_success(
+            inject_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=sender,
+                type="test_state_event",
+                state_key=state_key,
+                content={"body": body},
+            )
         )
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 5dd46005e6..f282921538 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -118,18 +118,22 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
 
     def test_get_joined_users_from_context(self):
         room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
-        bob_event = event_injection.inject_member_event(
-            self.hs, room, self.u_bob, Membership.JOIN
+        bob_event = self.get_success(
+            event_injection.inject_member_event(
+                self.hs, room, self.u_bob, Membership.JOIN
+            )
         )
 
         # first, create a regular event
-        event, context = event_injection.create_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_alice,
-            prev_event_ids=[bob_event.event_id],
-            type="m.test.1",
-            content={},
+        event, context = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_alice,
+                prev_event_ids=[bob_event.event_id],
+                type="m.test.1",
+                content={},
+            )
         )
 
         users = self.get_success(
@@ -140,22 +144,26 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
         # Regression test for #7376: create a state event whose key matches bob's
         # user_id, but which is *not* a membership event, and persist that; then check
         # that `get_joined_users_from_context` returns the correct users for the next event.
-        non_member_event = event_injection.inject_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_bob,
-            prev_event_ids=[bob_event.event_id],
-            type="m.test.2",
-            state_key=self.u_bob,
-            content={},
+        non_member_event = self.get_success(
+            event_injection.inject_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_bob,
+                prev_event_ids=[bob_event.event_id],
+                type="m.test.2",
+                state_key=self.u_bob,
+                content={},
+            )
         )
-        event, context = event_injection.create_event(
-            self.hs,
-            room_id=room,
-            sender=self.u_alice,
-            prev_event_ids=[non_member_event.event_id],
-            type="m.test.3",
-            content={},
+        event, context = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room,
+                sender=self.u_alice,
+                prev_event_ids=[non_member_event.event_id],
+                type="m.test.3",
+                content={},
+            )
         )
         users = self.get_success(
             self.store.get_joined_users_from_context(event, context)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 0b88308ff4..a0e133cd4a 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -64,8 +64,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
         yield self.storage.persistence.persist_event(event, context)
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 43297b530c..8522c6fc09 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -22,14 +22,12 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import Collection
 
-from tests.test_utils import get_awaitable_result
-
 """
 Utility functions for poking events into the storage of the server under test.
 """
 
 
-def inject_member_event(
+async def inject_member_event(
     hs: synapse.server.HomeServer,
     room_id: str,
     sender: str,
@@ -46,7 +44,7 @@ def inject_member_event(
     if extra_content:
         content.update(extra_content)
 
-    return inject_event(
+    return await inject_event(
         hs,
         room_id=room_id,
         type=EventTypes.Member,
@@ -57,7 +55,7 @@ def inject_member_event(
     )
 
 
-def inject_event(
+async def inject_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[Collection[str]] = None,
@@ -72,37 +70,27 @@ def inject_event(
         prev_event_ids: prev_events for the event. If not specified, will be looked up
         kwargs: fields for the event to be created
     """
-    test_reactor = hs.get_reactor()
-
-    event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
+    event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
 
-    d = hs.get_storage().persistence.persist_event(event, context)
-    test_reactor.advance(0)
-    get_awaitable_result(d)
+    await hs.get_storage().persistence.persist_event(event, context)
 
     return event
 
 
-def create_event(
+async def create_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[Collection[str]] = None,
     **kwargs
 ) -> Tuple[EventBase, EventContext]:
-    test_reactor = hs.get_reactor()
-
     if room_version is None:
-        d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
-        test_reactor.advance(0)
-        room_version = get_awaitable_result(d)
+        room_version = await hs.get_datastore().get_room_version_id(kwargs["room_id"])
 
     builder = hs.get_event_builder_factory().for_room_version(
         KNOWN_ROOM_VERSIONS[room_version], kwargs
     )
-    d = hs.get_event_creation_handler().create_new_client_event(
+    event, context = await hs.get_event_creation_handler().create_new_client_event(
         builder, prev_event_ids=prev_event_ids
     )
-    test_reactor.advance(0)
-    event, context = get_awaitable_result(d)
 
     return event, context
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index f7381b2885..b371efc0df 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -53,7 +53,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         #
 
         # before we do that, we persist some other events to act as state.
-        self.inject_visibility("@admin:hs", "joined")
+        yield self.inject_visibility("@admin:hs", "joined")
         for i in range(0, 10):
             yield self.inject_room_member("@resident%i:hs" % i)
 
@@ -137,8 +137,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
         yield self.storage.persistence.persist_event(event, context)
         return event
@@ -158,8 +158,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
         yield self.storage.persistence.persist_event(event, context)
@@ -179,8 +179,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
-            builder
+        event, context = yield defer.ensureDeferred(
+            self.event_creation_handler.create_new_client_event(builder)
         )
 
         yield self.storage.persistence.persist_event(event, context)
diff --git a/tests/unittest.py b/tests/unittest.py
index 3175a3fa02..68d2586efd 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -603,7 +603,9 @@ class HomeserverTestCase(TestCase):
             user: MXID of the user to inject the membership for.
             membership: The membership type.
         """
-        event_injection.inject_member_event(self.hs, room, user, membership)
+        self.get_success(
+            event_injection.inject_member_event(self.hs, room, user, membership)
+        )
 
 
 class FederatingHomeserverTestCase(HomeserverTestCase):
diff --git a/tests/utils.py b/tests/utils.py
index 4d17355a5c..ac643679aa 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -671,6 +671,8 @@ def create_room(hs, room_id, creator_id):
         },
     )
 
-    event, context = yield event_creation_handler.create_new_client_event(builder)
+    event, context = yield defer.ensureDeferred(
+        event_creation_handler.create_new_client_event(builder)
+    )
 
     yield persistence_store.persist_event(event, context)