summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-07-28 16:09:53 -0400
committerGitHub <noreply@github.com>2020-07-28 16:09:53 -0400
commit3345c166a45cb4a8f87c583ee0476c2bca5c41bd (patch)
tree8bc5a87a123313c2da008e4a82d0621459b717c5
parentAdd an option to disable purge in delete room admin API (#7964) (diff)
downloadsynapse-3345c166a45cb4a8f87c583ee0476c2bca5c41bd.tar.xz
Convert storage layer to async/await. (#7963)
-rw-r--r--changelog.d/7963.misc1
-rw-r--r--synapse/storage/persist_events.py40
-rw-r--r--synapse/storage/purge_events.py38
-rw-r--r--synapse/storage/state.py207
-rw-r--r--tests/storage/test_purge.py8
-rw-r--r--tests/storage/test_room.py6
-rw-r--r--tests/storage/test_state.py64
-rw-r--r--tests/test_visibility.py14
-rw-r--r--tests/utils.py16
-rw-r--r--tox.ini1
10 files changed, 210 insertions, 185 deletions
diff --git a/changelog.d/7963.misc b/changelog.d/7963.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/7963.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 78fbdcdee8..4a164834d9 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.events import FrozenEvent
+from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
         self._event_persist_queue = _EventPeristenceQueue()
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
-    @defer.inlineCallbacks
-    def persist_events(
+    async def persist_events(
         self,
-        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ):
+    ) -> int:
         """
         Write events to the database
         Args:
@@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
                 which might update the current state etc.
 
         Returns:
-            Deferred[int]: the stream ordering of the latest persisted event
+            the stream ordering of the latest persisted event
         """
         partitioned = {}
         for event, ctx in events_and_contexts:
@@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
         for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        yield make_deferred_yieldable(
+        await make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-        max_persisted_id = yield self.main_store.get_current_events_token()
-
-        return max_persisted_id
+        return self.main_store.get_current_events_token()
 
-    @defer.inlineCallbacks
-    def persist_event(
-        self, event: FrozenEvent, context: EventContext, backfilled: bool = False
-    ):
+    async def persist_event(
+        self, event: EventBase, context: EventContext, backfilled: bool = False
+    ) -> Tuple[int, int]:
         """
         Returns:
-            Deferred[Tuple[int, int]]: the stream ordering of ``event``,
-            and the stream ordering of the latest persisted event
+            The stream ordering of `event`, and the stream ordering of the
+            latest persisted event
         """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)], backfilled=backfilled
@@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):
 
         self._maybe_start_persisting(event.room_id)
 
-        yield make_deferred_yieldable(deferred)
+        await make_deferred_yieldable(deferred)
 
-        max_persisted_id = yield self.main_store.get_current_events_token()
+        max_persisted_id = self.main_store.get_current_events_token()
         return (event.internal_metadata.stream_ordering, max_persisted_id)
 
     def _maybe_start_persisting(self, room_id: str):
@@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):
 
     async def _persist_events(
         self,
-        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
     ):
         """Calculates the change to current state and forward extremities, and
@@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
     async def _calculate_new_extremities(
         self,
         room_id: str,
-        event_contexts: List[Tuple[FrozenEvent, EventContext]],
+        event_contexts: List[Tuple[EventBase, EventContext]],
         latest_event_ids: List[str],
     ):
         """Calculates the new forward extremities for a room given events to
@@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
     async def _get_new_state_after_events(
         self,
         room_id: str,
-        events_context: List[Tuple[FrozenEvent, EventContext]],
+        events_context: List[Tuple[EventBase, EventContext]],
         old_latest_event_ids: Iterable[str],
         new_latest_event_ids: Iterable[str],
     ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
     async def _is_server_still_joined(
         self,
         room_id: str,
-        ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
+        ev_ctx_rm: List[Tuple[EventBase, EventContext]],
         delta: DeltaState,
         current_state: Optional[StateMap[str]],
         potentially_left_users: Set[str],
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index fdc0abf5cf..79d9f06e2e 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,8 +15,7 @@
 
 import itertools
 import logging
-
-from twisted.internet import defer
+from typing import Set
 
 logger = logging.getLogger(__name__)
 
@@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
     def __init__(self, hs, stores):
         self.stores = stores
 
-    @defer.inlineCallbacks
-    def purge_room(self, room_id: str):
+    async def purge_room(self, room_id: str):
         """Deletes all record of a room
         """
 
-        state_groups_to_delete = yield self.stores.main.purge_room(room_id)
-        yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+        state_groups_to_delete = await self.stores.main.purge_room(room_id)
+        await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
 
-    @defer.inlineCallbacks
-    def purge_history(self, room_id, token, delete_local_events):
+    async def purge_history(
+        self, room_id: str, token: str, delete_local_events: bool
+    ) -> None:
         """Deletes room history before a certain point
 
         Args:
-            room_id (str):
+            room_id: The room ID
 
-            token (str): A topological token to delete events before
+            token: A topological token to delete events before
 
-            delete_local_events (bool):
+            delete_local_events:
                 if True, we will delete local events as well as remote ones
                 (instead of just marking them as outliers and deleting their
                 state groups).
         """
-        state_groups = yield self.stores.main.purge_history(
+        state_groups = await self.stores.main.purge_history(
             room_id, token, delete_local_events
         )
 
         logger.info("[purge] finding state groups that can be deleted")
 
-        sg_to_delete = yield self._find_unreferenced_groups(state_groups)
+        sg_to_delete = await self._find_unreferenced_groups(state_groups)
 
-        yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+        await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
 
-    @defer.inlineCallbacks
-    def _find_unreferenced_groups(self, state_groups):
+    async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
         """Used when purging history to figure out which state groups can be
         deleted.
 
         Args:
-            state_groups (set[int]): Set of state groups referenced by events
+            state_groups: Set of state groups referenced by events
                 that are going to be deleted.
 
         Returns:
-            Deferred[set[int]] The set of state groups that can be deleted.
+            The set of state groups that can be deleted.
         """
         # Graph of state group -> previous group
         graph = {}
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
                 current_search = set(itertools.islice(next_to_search, 100))
                 next_to_search -= current_search
 
-            referenced = yield self.stores.main.get_referenced_state_groups(
+            referenced = await self.stores.main.get_referenced_state_groups(
                 current_search
             )
             referenced_groups |= referenced
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
             # groups that are referenced.
             current_search -= referenced
 
-            edges = yield self.stores.state.get_previous_state_groups(current_search)
+            edges = await self.stores.state.get_previous_state_groups(current_search)
 
             prevs = set(edges.values())
             # We don't bother re-handling groups we've already seen
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dc568476f4..49ee9c9a74 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,13 +14,12 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, List, TypeVar
+from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar
 
 import attr
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
 from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
@@ -34,16 +33,16 @@ class StateFilter(object):
     """A filter used when querying for state.
 
     Attributes:
-        types (dict[str, set[str]|None]): Map from type to set of state keys (or
-            None). This specifies which state_keys for the given type to fetch
-            from the DB. If None then all events with that type are fetched. If
-            the set is empty then no events with that type are fetched.
-        include_others (bool): Whether to fetch events with types that do not
+        types: Map from type to set of state keys (or None). This specifies
+            which state_keys for the given type to fetch from the DB. If None
+            then all events with that type are fetched. If the set is empty
+            then no events with that type are fetched.
+        include_others: Whether to fetch events with types that do not
             appear in `types`.
     """
 
-    types = attr.ib()
-    include_others = attr.ib(default=False)
+    types = attr.ib(type=Dict[str, Optional[Set[str]]])
+    include_others = attr.ib(default=False, type=bool)
 
     def __attrs_post_init__(self):
         # If `include_others` is set we canonicalise the filter by removing
@@ -52,36 +51,35 @@ class StateFilter(object):
             self.types = {k: v for k, v in self.types.items() if v is not None}
 
     @staticmethod
-    def all():
+    def all() -> "StateFilter":
         """Creates a filter that fetches everything.
 
         Returns:
-            StateFilter
+            The new state filter.
         """
         return StateFilter(types={}, include_others=True)
 
     @staticmethod
-    def none():
+    def none() -> "StateFilter":
         """Creates a filter that fetches nothing.
 
         Returns:
-            StateFilter
+            The new state filter.
         """
         return StateFilter(types={}, include_others=False)
 
     @staticmethod
-    def from_types(types):
+    def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
         """Creates a filter that only fetches the given types
 
         Args:
-            types (Iterable[tuple[str, str|None]]): A list of type and state
-                keys to fetch. A state_key of None fetches everything for
-                that type
+            types: A list of type and state keys to fetch. A state_key of None
+                fetches everything for that type
 
         Returns:
-            StateFilter
+            The new state filter.
         """
-        type_dict = {}
+        type_dict = {}  # type: Dict[str, Optional[Set[str]]]
         for typ, s in types:
             if typ in type_dict:
                 if type_dict[typ] is None:
@@ -91,24 +89,24 @@ class StateFilter(object):
                 type_dict[typ] = None
                 continue
 
-            type_dict.setdefault(typ, set()).add(s)
+            type_dict.setdefault(typ, set()).add(s)  # type: ignore
 
         return StateFilter(types=type_dict)
 
     @staticmethod
-    def from_lazy_load_member_list(members):
+    def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
         """Creates a filter that returns all non-member events, plus the member
         events for the given users
 
         Args:
-            members (iterable[str]): Set of user IDs
+            members: Set of user IDs
 
         Returns:
-            StateFilter
+            The new state filter
         """
         return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
 
-    def return_expanded(self):
+    def return_expanded(self) -> "StateFilter":
         """Creates a new StateFilter where type wild cards have been removed
         (except for memberships). The returned filter is a superset of the
         current one, i.e. anything that passes the current filter will pass
@@ -130,7 +128,7 @@ class StateFilter(object):
                return all non-member events
 
         Returns:
-            StateFilter
+            The new state filter.
         """
 
         if self.is_full():
@@ -167,7 +165,7 @@ class StateFilter(object):
                 include_others=True,
             )
 
-    def make_sql_filter_clause(self):
+    def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
         """Converts the filter to an SQL clause.
 
         For example:
@@ -179,13 +177,12 @@ class StateFilter(object):
 
 
         Returns:
-            tuple[str, list]: The SQL string (may be empty) and arguments. An
-            empty SQL string is returned when the filter matches everything
-            (i.e. is "full").
+            The SQL string (may be empty) and arguments. An empty SQL string is
+            returned when the filter matches everything (i.e. is "full").
         """
 
         where_clause = ""
-        where_args = []
+        where_args = []  # type: List[str]
 
         if self.is_full():
             return where_clause, where_args
@@ -221,7 +218,7 @@ class StateFilter(object):
 
         return where_clause, where_args
 
-    def max_entries_returned(self):
+    def max_entries_returned(self) -> Optional[int]:
         """Returns the maximum number of entries this filter will return if
         known, otherwise returns None.
 
@@ -260,33 +257,33 @@ class StateFilter(object):
 
         return filtered_state
 
-    def is_full(self):
+    def is_full(self) -> bool:
         """Whether this filter fetches everything or not
 
         Returns:
-            bool
+            True if the filter fetches everything.
         """
         return self.include_others and not self.types
 
-    def has_wildcards(self):
+    def has_wildcards(self) -> bool:
         """Whether the filter includes wildcards or is attempting to fetch
         specific state.
 
         Returns:
-            bool
+            True if the filter includes wildcards.
         """
 
         return self.include_others or any(
             state_keys is None for state_keys in self.types.values()
         )
 
-    def concrete_types(self):
+    def concrete_types(self) -> List[Tuple[str, str]]:
         """Returns a list of concrete type/state_keys (i.e. not None) that
         will be fetched. This will be a complete list if `has_wildcards`
         returns False, but otherwise will be a subset (or even empty).
 
         Returns:
-            list[tuple[str,str]]
+            A list of type/state_keys tuples.
         """
         return [
             (t, s)
@@ -295,7 +292,7 @@ class StateFilter(object):
             for s in state_keys
         ]
 
-    def get_member_split(self):
+    def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
         """Return the filter split into two: one which assumes it's exclusively
         matching against member state, and one which assumes it's matching
         against non member state.
@@ -307,7 +304,7 @@ class StateFilter(object):
         state caches).
 
         Returns:
-            tuple[StateFilter, StateFilter]: The member and non member filters
+            The member and non member filters
         """
 
         if EventTypes.Member in self.types:
@@ -340,6 +337,9 @@ class StateGroupStorage(object):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
+        Args:
+            state_group: The state group used to retrieve state deltas.
+
         Returns:
             Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
                 (prev_group, delta_ids)
@@ -347,55 +347,59 @@ class StateGroupStorage(object):
 
         return self.stores.state.get_state_group_delta(state_group)
 
-    @defer.inlineCallbacks
-    def get_state_groups_ids(self, _room_id, event_ids):
+    async def get_state_groups_ids(
+        self, _room_id: str, event_ids: Iterable[str]
+    ) -> Dict[int, StateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
         Args:
-            _room_id (str): id of the room for these events
-            event_ids (iterable[str]): ids of the events
+            _room_id: id of the room for these events
+            event_ids: ids of the events
 
         Returns:
-            Deferred[dict[int, StateMap[str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
             return {}
 
-        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
-        group_to_state = yield self.stores.state._get_state_for_groups(groups)
+        group_to_state = await self.stores.state._get_state_for_groups(groups)
 
         return group_to_state
 
-    @defer.inlineCallbacks
-    def get_state_ids_for_group(self, state_group):
+    async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
         """Get the event IDs of all the state in the given state group
 
         Args:
-            state_group (int)
+            state_group: A state group for which we want to get the state IDs.
 
         Returns:
-            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+            Resolves to a map of (type, state_key) -> event_id
         """
-        group_to_state = yield self._get_state_for_groups((state_group,))
+        group_to_state = await self._get_state_for_groups((state_group,))
 
         return group_to_state[state_group]
 
-    @defer.inlineCallbacks
-    def get_state_groups(self, room_id, event_ids):
+    async def get_state_groups(
+        self, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[int, List[EventBase]]:
         """ Get the state groups for the given list of event_ids
+
+        Args:
+            room_id: ID of the room for these events.
+            event_ids: The event IDs to retrieve state for.
+
         Returns:
-            Deferred[dict[int, list[EventBase]]]:
-                dict of state_group_id -> list of state events.
+            dict of state_group_id -> list of state events.
         """
         if not event_ids:
             return {}
 
-        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+        group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
 
-        state_event_map = yield self.stores.main.get_events(
+        state_event_map = await self.stores.main.get_events(
             [
                 ev_id
                 for group_ids in group_to_ids.values()
@@ -423,31 +427,34 @@ class StateGroupStorage(object):
             groups: list of state group IDs to query
             state_filter: The state filter used to fetch state
                 from the database.
+
         Returns:
             Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
 
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
-    @defer.inlineCallbacks
-    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+    async def get_state_for_events(
+        self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+    ):
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
+
         Args:
-            event_ids (list[string])
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            event_ids: The events to fetch the state of.
+            state_filter: The state filter used to fetch state.
+
         Returns:
-            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+            A dict of (event_id) -> (type, state_key) -> [state_events]
         """
-        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
-        group_to_state = yield self.stores.state._get_state_for_groups(
+        group_to_state = await self.stores.state._get_state_for_groups(
             groups, state_filter
         )
 
-        state_event_map = yield self.stores.main.get_events(
+        state_event_map = await self.stores.main.get_events(
             [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
             get_prev_content=False,
         )
@@ -463,24 +470,24 @@ class StateGroupStorage(object):
 
         return {event: event_to_state[event] for event in event_ids}
 
-    @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+    async def get_state_ids_for_events(
+        self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+    ):
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
         of the state events (as opposed to the events themselves)
 
         Args:
-            event_ids(list(str)): events whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            event_ids: events whose state should be returned
+            state_filter: The state filter used to fetch state from the database.
 
         Returns:
-            A deferred dict from event_id -> (type, state_key) -> event_id
+            A dict from event_id -> (type, state_key) -> event_id
         """
-        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
-        group_to_state = yield self.stores.state._get_state_for_groups(
+        group_to_state = await self.stores.state._get_state_for_groups(
             groups, state_filter
         )
 
@@ -491,36 +498,36 @@ class StateGroupStorage(object):
 
         return {event: event_to_state[event] for event in event_ids}
 
-    @defer.inlineCallbacks
-    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+    async def get_state_for_event(
+        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+    ):
         """
         Get the state dict corresponding to a particular event
 
         Args:
-            event_id(str): event whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            event_id: event whose state should be returned
+            state_filter: The state filter used to fetch state from the database.
 
         Returns:
-            A deferred dict from (type, state_key) -> state_event
+            A dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events([event_id], state_filter)
         return state_map[event_id]
 
-    @defer.inlineCallbacks
-    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+    async def get_state_ids_for_event(
+        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+    ):
         """
         Get the state dict corresponding to a particular event
 
         Args:
-            event_id(str): event whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            event_id: event whose state should be returned
+            state_filter: The state filter used to fetch state from the database.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
 
     def _get_state_for_groups(
@@ -530,9 +537,8 @@ class StateGroupStorage(object):
         filtering by type/state_key
 
         Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state groups for which we want to get the state.
+            state_filter: The state filter used to fetch state.
                 from the database.
         Returns:
             Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
@@ -540,18 +546,23 @@ class StateGroupStorage(object):
         return self.stores.state._get_state_for_groups(groups, state_filter)
 
     def store_state_group(
-        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+        self,
+        event_id: str,
+        room_id: str,
+        prev_group: Optional[int],
+        delta_ids: Optional[dict],
+        current_state_ids: dict,
     ):
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
-            event_id (str): The event ID for which the state was calculated
-            room_id (str)
-            prev_group (int|None): A previous state group for the room, optional.
-            delta_ids (dict|None): The delta between state at `prev_group` and
+            event_id: The event ID for which the state was calculated.
+            room_id: ID of the room for which the state was calculated.
+            prev_group: A previous state group for the room, optional.
+            delta_ids: The delta between state at `prev_group` and
                 `current_state_ids`, if `prev_group` was given. Same format as
                 `current_state_ids`.
-            current_state_ids (dict): The state to store. Map of (type, state_key)
+            current_state_ids: The state to store. Map of (type, state_key)
                 to event_id.
 
         Returns:
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index b9fafaa1a6..a6012c973d 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 from synapse.rest.client.v1 import room
 
 from tests.unittest import HomeserverTestCase
@@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
         event = self.successResultOf(event)
 
         # Purge everything before this topological token
-        purge = storage.purge_events.purge_history(self.room_id, event, True)
+        purge = defer.ensureDeferred(
+            storage.purge_events.purge_history(self.room_id, event, True)
+        )
         self.pump()
         self.assertEqual(self.successResultOf(purge), None)
 
@@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
         )
 
         # Purge everything before this topological token
-        purge = storage.purge_history(self.room_id, event, True)
+        purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
         self.pump()
         f = self.failureResultOf(purge)
         self.assertIn("greater than forward", f.value.args[0])
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index 1d77b4a2d6..a5f250d477 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def inject_room_event(self, **kwargs):
-        yield self.storage.persistence.persist_event(
-            self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(
+                self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
+            )
         )
 
     @defer.inlineCallbacks
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index a0e133cd4a..6a48b9d3b3 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
 
         return event
 
@@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups_ids(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_map = list(state_group_map.values())[0]
@@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
             self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
         )
 
-        state_group_map = yield self.storage.state.get_state_groups(
-            self.room, [e2.event_id]
+        state_group_map = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups(self.room, [e2.event_id])
         )
         self.assertEqual(len(state_group_map), 1)
         state_list = list(state_group_map.values())[0]
@@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we get the full state as of the final event
-        state = yield self.storage.state.get_state_for_event(e5.event_id)
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(e5.event_id)
+        )
 
         self.assertIsNotNone(e4)
 
@@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check we can filter to the m.room.name event (with a '' state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can filter to the m.room.name event (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
+            )
         )
 
         self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
 
         # check we can grab the m.room.member events (with a wildcard None state key)
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
+            )
         )
 
         self.assertStateMapEqual(
@@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
 
         # check we can grab a specific room member without filtering out the
         # other event types
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: {self.u_alice.to_string()}},
-                include_others=True,
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    include_others=True,
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
         )
 
         # check that we can grab everything except members
-        state = yield self.storage.state.get_state_for_event(
-            e5.event_id,
-            state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
-            ),
+        state = yield defer.ensureDeferred(
+            self.storage.state.get_state_for_event(
+                e5.event_id,
+                state_filter=StateFilter(
+                    types={EventTypes.Member: set()}, include_others=True
+                ),
+            )
         )
 
         self.assertStateMapEqual(
@@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
         #######################################################
 
         room_id = self.room.to_string()
-        group_ids = yield self.storage.state.get_state_groups_ids(
-            room_id, [e5.event_id]
+        group_ids = yield defer.ensureDeferred(
+            self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
         )
         group = list(group_ids.keys())[0]
 
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index a7a36174ea..531a9b9118 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         self.store = self.hs.get_datastore()
         self.storage = self.hs.get_storage()
 
-        yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
+        yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
 
     @defer.inlineCallbacks
     def test_filtering(self):
@@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
         event, context = yield defer.ensureDeferred(
             self.event_creation_handler.create_new_client_event(builder)
         )
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
         return event
 
     @defer.inlineCallbacks
@@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
         return event
 
     @defer.inlineCallbacks
@@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
             self.event_creation_handler.create_new_client_event(builder)
         )
 
-        yield self.storage.persistence.persist_event(event, context)
+        yield defer.ensureDeferred(
+            self.storage.persistence.persist_event(event, context)
+        )
         return event
 
     @defer.inlineCallbacks
diff --git a/tests/utils.py b/tests/utils.py
index ac643679aa..b33b6860d4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -638,14 +638,8 @@ class DeferredMockCallable(object):
             )
 
 
-@defer.inlineCallbacks
-def create_room(hs, room_id, creator_id):
+async def create_room(hs, room_id: str, creator_id: str):
     """Creates and persist a creation event for the given room
-
-    Args:
-        hs
-        room_id (str)
-        creator_id (str)
     """
 
     persistence_store = hs.get_storage().persistence
@@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id):
     event_builder_factory = hs.get_event_builder_factory()
     event_creation_handler = hs.get_event_creation_handler()
 
-    yield store.store_room(
+    await store.store_room(
         room_id=room_id,
         room_creator_user_id=creator_id,
         is_public=False,
@@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id):
         },
     )
 
-    event, context = yield defer.ensureDeferred(
-        event_creation_handler.create_new_client_event(builder)
-    )
+    event, context = await event_creation_handler.create_new_client_event(builder)
 
-    yield persistence_store.persist_event(event, context)
+    await persistence_store.persist_event(event, context)
diff --git a/tox.ini b/tox.ini
index 595ab3ba66..a394f6eadc 100644
--- a/tox.ini
+++ b/tox.ini
@@ -206,6 +206,7 @@ commands = mypy \
             synapse/storage/data_stores/main/ui_auth.py \
             synapse/storage/database.py \
             synapse/storage/engines \
+            synapse/storage/state.py \
             synapse/storage/util \
             synapse/streams \
             synapse/util/caches/stream_change_cache.py \