summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/cache.py1
-rw-r--r--synapse/storage/data_stores/main/event_push_actions.py96
-rw-r--r--synapse/storage/data_stores/main/events.py48
-rw-r--r--synapse/storage/data_stores/main/events_worker.py86
-rw-r--r--synapse/storage/data_stores/main/purge_events.py2
-rw-r--r--synapse/storage/data_stores/main/push_rule.py2
-rw-r--r--synapse/storage/data_stores/main/room.py98
-rw-r--r--synapse/storage/data_stores/main/roommember.py2
-rw-r--r--synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql18
-rw-r--r--synapse/storage/data_stores/main/state.py57
-rw-r--r--synapse/storage/data_stores/main/stats.py53
-rw-r--r--synapse/storage/data_stores/main/stream.py2
-rw-r--r--synapse/storage/data_stores/main/user_directory.py4
-rw-r--r--synapse/storage/data_stores/main/user_erasure_store.py26
-rw-r--r--synapse/storage/data_stores/state/store.py37
-rw-r--r--synapse/storage/database.py18
-rw-r--r--synapse/storage/persist_events.py45
-rw-r--r--synapse/storage/purge_events.py38
-rw-r--r--synapse/storage/state.py216
19 files changed, 503 insertions, 346 deletions
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index f39f556c20..edc3624fed 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
 
+        self.get_unread_message_count_for_user.invalidate_many((room_id,))
         self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
 
         if not backfilled:
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 504babaa7e..ad82838901 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -15,11 +15,10 @@
 # limitations under the License.
 
 import logging
+from typing import List
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
 from synapse.storage.database import Database
@@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         return {"notify_count": notify_count, "highlight_count": highlight_count}
 
-    @defer.inlineCallbacks
-    def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
+    async def get_push_action_users_in_range(
+        self, min_stream_ordering, max_stream_ordering
+    ):
         def f(txn):
             sql = (
                 "SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
@@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
             return [r[0] for r in txn]
 
-        ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
+        ret = await self.db.runInteraction("get_push_action_users_in_range", f)
         return ret
 
-    @defer.inlineCallbacks
-    def get_unread_push_actions_for_user_in_range_for_http(
-        self, user_id, min_stream_ordering, max_stream_ordering, limit=20
-    ):
+    async def get_unread_push_actions_for_user_in_range_for_http(
+        self,
+        user_id: str,
+        min_stream_ordering: int,
+        max_stream_ordering: int,
+        limit: int = 20,
+    ) -> List[dict]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the httppusher.
 
         Args:
-            user_id (str): The user to fetch push actions for.
-            min_stream_ordering(int): The exclusive lower bound on the
+            user_id: The user to fetch push actions for.
+            min_stream_ordering: The exclusive lower bound on the
                 stream ordering of event push actions to fetch.
-            max_stream_ordering(int): The inclusive upper bound on the
+            max_stream_ordering: The inclusive upper bound on the
                 stream ordering of event push actions to fetch.
-            limit (int): The maximum number of rows to return.
+            limit: The maximum number of rows to return.
         Returns:
-            A promise which resolves to a list of dicts with the keys "event_id",
-            "room_id", "stream_ordering", "actions".
+            A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
             The list will be ordered by ascending stream_ordering.
             The list will have between 0~limit entries.
         """
@@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        after_read_receipt = yield self.db.runInteraction(
+        after_read_receipt = await self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
         )
 
@@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        no_read_receipt = yield self.db.runInteraction(
+        no_read_receipt = await self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
         )
 
@@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         # one of the subqueries may have hit the limit.
         return notifs[:limit]
 
-    @defer.inlineCallbacks
-    def get_unread_push_actions_for_user_in_range_for_email(
-        self, user_id, min_stream_ordering, max_stream_ordering, limit=20
-    ):
+    async def get_unread_push_actions_for_user_in_range_for_email(
+        self,
+        user_id: str,
+        min_stream_ordering: int,
+        max_stream_ordering: int,
+        limit: int = 20,
+    ) -> List[dict]:
         """Get a list of the most recent unread push actions for a given user,
         within the given stream ordering range. Called by the emailpusher
 
         Args:
-            user_id (str): The user to fetch push actions for.
-            min_stream_ordering(int): The exclusive lower bound on the
+            user_id: The user to fetch push actions for.
+            min_stream_ordering: The exclusive lower bound on the
                 stream ordering of event push actions to fetch.
-            max_stream_ordering(int): The inclusive upper bound on the
+            max_stream_ordering: The inclusive upper bound on the
                 stream ordering of event push actions to fetch.
-            limit (int): The maximum number of rows to return.
+            limit: The maximum number of rows to return.
         Returns:
-            A promise which resolves to a list of dicts with the keys "event_id",
-            "room_id", "stream_ordering", "actions", "received_ts".
+            A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
             The list will be ordered by descending received_ts.
             The list will have between 0~limit entries.
         """
@@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        after_read_receipt = yield self.db.runInteraction(
+        after_read_receipt = await self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
         )
 
@@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        no_read_receipt = yield self.db.runInteraction(
+        no_read_receipt = await self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
         )
 
@@ -411,7 +415,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             _get_if_maybe_push_in_range_for_user_txn,
         )
 
-    def add_push_actions_to_staging(self, event_id, user_id_actions):
+    async def add_push_actions_to_staging(self, event_id, user_id_actions):
         """Add the push actions for the event to the push action staging area.
 
         Args:
@@ -457,21 +461,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 ),
             )
 
-        return self.db.runInteraction(
+        return await self.db.runInteraction(
             "add_push_actions_to_staging", _add_push_actions_to_staging_txn
         )
 
-    @defer.inlineCallbacks
-    def remove_push_actions_from_staging(self, event_id):
+    async def remove_push_actions_from_staging(self, event_id: str) -> None:
         """Called if we failed to persist the event to ensure that stale push
         actions don't build up in the DB
-
-        Args:
-            event_id (str)
         """
 
         try:
-            res = yield self.db.simple_delete(
+            res = await self.db.simple_delete(
                 table="event_push_actions_staging",
                 keyvalues={"event_id": event_id},
                 desc="remove_push_actions_from_staging",
@@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 
         return range_end
 
-    @defer.inlineCallbacks
-    def get_time_of_last_push_action_before(self, stream_ordering):
+    async def get_time_of_last_push_action_before(self, stream_ordering):
         def f(txn):
             sql = (
                 "SELECT e.received_ts"
@@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (stream_ordering,))
             return txn.fetchone()
 
-        result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
+        result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
         return result[0] if result else None
 
 
@@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             self._start_rotate_notifs, 30 * 60 * 1000
         )
 
-    @defer.inlineCallbacks
-    def get_push_actions_for_user(
+    async def get_push_actions_for_user(
         self, user_id, before=None, limit=50, only_highlight=False
     ):
         def f(txn):
@@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             txn.execute(sql, args)
             return self.db.cursor_to_dict(txn)
 
-        push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
+        push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
         for pa in push_actions:
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
 
-    @defer.inlineCallbacks
-    def get_latest_push_action_stream_ordering(self):
+    async def get_latest_push_action_stream_ordering(self):
         def f(txn):
             txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
             return txn.fetchone()
 
-        result = yield self.db.runInteraction(
+        result = await self.db.runInteraction(
             "get_latest_push_action_stream_ordering", f
         )
         return result[0] or 0
@@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
     def _start_rotate_notifs(self):
         return run_as_background_process("rotate_notifs", self._rotate_notifs)
 
-    @defer.inlineCallbacks
-    def _rotate_notifs(self):
+    async def _rotate_notifs(self):
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
             return
         self._doing_notif_rotation = True
@@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             while True:
                 logger.info("Rotating notifications")
 
-                caught_up = yield self.db.runInteraction(
+                caught_up = await self.db.runInteraction(
                     "_rotate_notifs", self._rotate_notifs_txn
                 )
                 if caught_up:
                     break
-                yield self.hs.get_clock().sleep(self._rotate_delay)
+                await self.hs.get_clock().sleep(self._rotate_delay)
         finally:
             self._doing_notif_rotation = False
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 6f2e0d15cc..0c9c02afa1 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -53,6 +53,47 @@ event_counter = Counter(
     ["type", "origin_type", "origin_entity"],
 )
 
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+    EventTypes.Topic,
+    EventTypes.Name,
+    EventTypes.RoomAvatar,
+    EventTypes.Tombstone,
+}
+
+
+def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+    # Exclude rejected and soft-failed events.
+    if context.rejected or event.internal_metadata.is_soft_failed():
+        return False
+
+    # Exclude notices.
+    if (
+        not event.is_state()
+        and event.type == EventTypes.Message
+        and event.content.get("msgtype") == "m.notice"
+    ):
+        return False
+
+    # Exclude edits.
+    relates_to = event.content.get("m.relates_to", {})
+    if relates_to.get("rel_type") == RelationTypes.REPLACE:
+        return False
+
+    # Mark events that have a non-empty string body as unread.
+    body = event.content.get("body")
+    if isinstance(body, str) and body:
+        return True
+
+    # Mark some state events as unread.
+    if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+        return True
+
+    # Mark encrypted events as unread.
+    if not event.is_state() and event.type == EventTypes.Encrypted:
+        return True
+
+    return False
+
 
 def encode_json(json_object):
     """
@@ -196,6 +237,10 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
+                self.store.get_unread_message_count_for_user.invalidate_many(
+                    (event.room_id,),
+                )
+
             for room_id, new_state in current_state_for_room.items():
                 self.store.get_current_state_ids.prefill((room_id,), new_state)
 
@@ -817,8 +862,9 @@ class PersistEventsStore:
                     "contains_url": (
                         "url" in event.content and isinstance(event.content["url"], str)
                     ),
+                    "count_as_unread": should_count_as_unread(event, context),
                 }
-                for event, _ in events_and_contexts
+                for event, context in events_and_contexts
             ],
         )
 
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e812c67078..b03b259636 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import Database
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import (
+    Cache,
+    _CacheContext,
+    cached,
+    cachedInlineCallbacks,
+)
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
 
+    @cached(tree=True, cache_context=True)
+    async def get_unread_message_count_for_user(
+        self, room_id: str, user_id: str, cache_context: _CacheContext,
+    ) -> int:
+        """Retrieve the count of unread messages for the given room and user.
+
+        Args:
+            room_id: The ID of the room to count unread messages in.
+            user_id: The ID of the user to count unread messages for.
+
+        Returns:
+            The number of unread messages for the given user in the given room.
+        """
+        with Measure(self._clock, "get_unread_message_count_for_user"):
+            last_read_event_id = await self.get_last_receipt_event_id_for_user(
+                user_id=user_id,
+                room_id=room_id,
+                receipt_type="m.read",
+                on_invalidate=cache_context.invalidate,
+            )
+
+            return await self.db.runInteraction(
+                "get_unread_message_count_for_user",
+                self._get_unread_message_count_for_user_txn,
+                user_id,
+                room_id,
+                last_read_event_id,
+            )
+
+    def _get_unread_message_count_for_user_txn(
+        self,
+        txn: Cursor,
+        user_id: str,
+        room_id: str,
+        last_read_event_id: Optional[str],
+    ) -> int:
+        if last_read_event_id:
+            # Get the stream ordering for the last read event.
+            stream_ordering = self.db.simple_select_one_onecol_txn(
+                txn=txn,
+                table="events",
+                keyvalues={"room_id": room_id, "event_id": last_read_event_id},
+                retcol="stream_ordering",
+            )
+        else:
+            # If there's no read receipt for that room, it probably means the user hasn't
+            # opened it yet, in which case use the stream ID of their join event.
+            # We can't just set it to 0 otherwise messages from other local users from
+            # before this user joined will be counted as well.
+            txn.execute(
+                """
+                SELECT stream_ordering FROM local_current_membership
+                LEFT JOIN events USING (event_id, room_id)
+                WHERE membership = 'join'
+                    AND user_id = ?
+                    AND room_id = ?
+                """,
+                (user_id, room_id),
+            )
+            row = txn.fetchone()
+
+            if row is None:
+                return 0
+
+            stream_ordering = row[0]
+
+        # Count the messages that qualify as unread after the stream ordering we've just
+        # retrieved.
+        sql = """
+            SELECT COUNT(*) FROM events
+            WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
+        """
+
+        txn.execute(sql, (user_id, room_id, stream_ordering))
+        row = txn.fetchone()
+
+        return row[0] if row else 0
+
 
 AllNewEventsResult = namedtuple(
     "AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
index 6546569139..b53fe35c33 100644
--- a/synapse/storage/data_stores/main/purge_events.py
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         #     event_json
         #     event_push_actions
         #     event_reference_hashes
+        #     event_relations
         #     event_search
         #     event_to_state_groups
         #     events
@@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "event_edges",
             "event_forward_extremities",
             "event_reference_hashes",
+            "event_relations",
             "event_search",
             "rejections",
         ):
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index c10da245d2..861050814d 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -284,7 +284,7 @@ class PushRulesWorkerStore(
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         result = yield self._bulk_get_push_rules_for_room(
             event.room_id, state_group, current_state_ids, event=event
         )
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index d2e1e36e7f..ab48052cdc 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.data_stores.main.search import SearchStore
 from synapse.storage.database import Database, LoggingTransaction
 from synapse.types import ThirdPartyInstanceID
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
@@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore):
 
         return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
 
-    @defer.inlineCallbacks
-    def get_largest_public_rooms(
+    async def get_largest_public_rooms(
         self,
         network_tuple: Optional[ThirdPartyInstanceID],
         search_filter: Optional[dict],
@@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore):
 
             return results
 
-        ret_val = yield self.db.runInteraction(
+        ret_val = await self.db.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
-        defer.returnValue(ret_val)
+        return ret_val
 
     @cached(max_entries=10000)
     def is_room_blocked(self, room_id):
@@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore):
             "get_rooms_paginate", _get_rooms_paginate_txn,
         )
 
-    @cachedInlineCallbacks(max_entries=10000)
-    def get_ratelimit_for_user(self, user_id):
+    @cached(max_entries=10000)
+    async def get_ratelimit_for_user(self, user_id):
         """Check if there are any overrides for ratelimiting for the given
         user
 
@@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore):
             of RatelimitOverride are None or 0 then ratelimitng has been
             disabled for that user entirely.
         """
-        row = yield self.db.simple_select_one(
+        row = await self.db.simple_select_one(
             table="ratelimit_override",
             keyvalues={"user_id": user_id},
             retcols=("messages_per_second", "burst_count"),
@@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore):
         else:
             return None
 
-    @cachedInlineCallbacks()
-    def get_retention_policy_for_room(self, room_id):
+    @cached()
+    async def get_retention_policy_for_room(self, room_id):
         """Get the retention policy for a given room.
 
         If no retention policy has been found for this room, returns a policy defined
@@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore):
 
             return self.db.cursor_to_dict(txn)
 
-        ret = yield self.db.runInteraction(
+        ret = await self.db.runInteraction(
             "get_retention_policy_for_room", get_retention_policy_for_room_txn,
         )
 
         # If we don't know this room ID, ret will be None, in this case return the default
         # policy.
         if not ret:
-            defer.returnValue(
-                {
-                    "min_lifetime": self.config.retention_default_min_lifetime,
-                    "max_lifetime": self.config.retention_default_max_lifetime,
-                }
-            )
+            return {
+                "min_lifetime": self.config.retention_default_min_lifetime,
+                "max_lifetime": self.config.retention_default_max_lifetime,
+            }
 
         row = ret[0]
 
@@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore):
         if row["max_lifetime"] is None:
             row["max_lifetime"] = self.config.retention_default_max_lifetime
 
-        defer.returnValue(row)
+        return row
 
     def get_media_mxcs_in_room(self, room_id):
         """Retrieves all the local and remote media MXC URIs in a given room
@@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_add_rooms_room_version_column,
         )
 
-    @defer.inlineCallbacks
-    def _background_insert_retention(self, progress, batch_size):
+    async def _background_insert_retention(self, progress, batch_size):
         """Retrieves a list of all rooms within a range and inserts an entry for each of
         them into the room_retention table.
         NULLs the property's columns if missing from the retention event in the room's
@@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             else:
                 return False
 
-        end = yield self.db.runInteraction(
+        end = await self.db.runInteraction(
             "insert_room_retention", _background_insert_retention_txn,
         )
 
         if end:
-            yield self.db.updates._end_background_update("insert_room_retention")
+            await self.db.updates._end_background_update("insert_room_retention")
 
-        defer.returnValue(batch_size)
+        return batch_size
 
     async def _background_add_rooms_room_version_column(
         self, progress: dict, batch_size: int
@@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             lock=False,
         )
 
-    @defer.inlineCallbacks
-    def store_room(
+    async def store_room(
         self,
         room_id: str,
         room_creator_user_id: str,
@@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     )
 
             with self._public_room_id_gen.get_next() as next_id:
-                yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
+                await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
         except Exception as e:
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
@@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             lock=False,
         )
 
-    @defer.inlineCallbacks
-    def set_room_is_public(self, room_id, is_public):
+    async def set_room_is_public(self, room_id, is_public):
         def set_room_is_public_txn(txn, next_id):
             self.db.simple_update_one_txn(
                 txn,
@@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 )
 
         with self._public_room_id_gen.get_next() as next_id:
-            yield self.db.runInteraction(
+            await self.db.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    @defer.inlineCallbacks
-    def set_room_is_public_appservice(
+    async def set_room_is_public_appservice(
         self, room_id, appservice_id, network_id, is_public
     ):
         """Edit the appservice/network specific public room list.
@@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 )
 
         with self._public_room_id_gen.get_next() as next_id:
-            yield self.db.runInteraction(
+            await self.db.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
                 next_id,
@@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def block_room(self, room_id, user_id):
+    async def block_room(self, room_id: str, user_id: str) -> None:
         """Marks the room as blocked. Can be called multiple times.
 
         Args:
-            room_id (str): Room to block
-            user_id (str): Who blocked it
-
-        Returns:
-            Deferred
+            room_id: Room to block
+            user_id: Who blocked it
         """
-        yield self.db.simple_upsert(
+        await self.db.simple_upsert(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             values={},
             insertion_values={"user_id": user_id},
             desc="block_room",
         )
-        yield self.db.runInteraction(
+        await self.db.runInteraction(
             "block_room_invalidation",
             self._invalidate_cache_and_stream,
             self.is_room_blocked,
             (room_id,),
         )
 
-    @defer.inlineCallbacks
-    def get_rooms_for_retention_period_in_range(
-        self, min_ms, max_ms, include_null=False
-    ):
+    async def get_rooms_for_retention_period_in_range(
+        self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+    ) -> Dict[str, dict]:
         """Retrieves all of the rooms within the given retention range.
 
         Optionally includes the rooms which don't have a retention policy.
 
         Args:
-            min_ms (int|None): Duration in milliseconds that define the lower limit of
+            min_ms: Duration in milliseconds that define the lower limit of
                 the range to handle (exclusive). If None, doesn't set a lower limit.
-            max_ms (int|None): Duration in milliseconds that define the upper limit of
+            max_ms: Duration in milliseconds that define the upper limit of
                 the range to handle (inclusive). If None, doesn't set an upper limit.
-            include_null (bool): Whether to include rooms which retention policy is NULL
+            include_null: Whether to include rooms which retention policy is NULL
                 in the returned set.
 
         Returns:
-            dict[str, dict]: The rooms within this range, along with their retention
-                policy. The key is "room_id", and maps to a dict describing the retention
-                policy associated with this room ID. The keys for this nested dict are
-                "min_lifetime" (int|None), and "max_lifetime" (int|None).
+            The rooms within this range, along with their retention
+            policy. The key is "room_id", and maps to a dict describing the retention
+            policy associated with this room ID. The keys for this nested dict are
+            "min_lifetime" (int|None), and "max_lifetime" (int|None).
         """
 
         def get_rooms_for_retention_period_in_range_txn(txn):
@@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
             return rooms_dict
 
-        rooms = yield self.db.runInteraction(
+        rooms = await self.db.runInteraction(
             "get_rooms_for_retention_period_in_range",
             get_rooms_for_retention_period_in_range_txn,
         )
 
-        defer.returnValue(rooms)
+        return rooms
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 29765890ee..a92e401e88 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         result = yield self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
new file mode 100644
index 0000000000..531b532c73
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Store a boolean value in the events table for whether the event should be counted in
+-- the unread_count property of sync responses.
+ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index bb38a04ede..a360699408 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -16,12 +16,12 @@
 import collections.abc
 import logging
 from collections import namedtuple
-
-from twisted.internet import defer
+from typing import Iterable, Optional, Set
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         create_event = await self.get_create_event_for_room(room_id)
         return create_event.content.get("room_version", "1")
 
-    @defer.inlineCallbacks
-    def get_room_predecessor(self, room_id):
+    async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
         """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
 
         Args:
-            room_id (str)
+            room_id: The room ID.
 
         Returns:
-            Deferred[dict|None]: A dictionary containing the structure of the predecessor
-                field from the room's create event. The structure is subject to other servers,
-                but it is expected to be:
-                    * room_id (str): The room ID of the predecessor room
-                    * event_id (str): The ID of the tombstone event in the predecessor room
+            A dictionary containing the structure of the predecessor
+            field from the room's create event. The structure is subject to other servers,
+            but it is expected to be:
+                * room_id (str): The room ID of the predecessor room
+                * event_id (str): The ID of the tombstone event in the predecessor room
 
-                None if a predecessor key is not found, or is not a dictionary.
+            None if a predecessor key is not found, or is not a dictionary.
 
         Raises:
             NotFoundError if the given room is unknown
         """
         # Retrieve the room's create event
-        create_event = yield self.get_create_event_for_room(room_id)
+        create_event = await self.get_create_event_for_room(room_id)
 
         # Retrieve the predecessor key of the create event
         predecessor = create_event.content.get("predecessor", None)
@@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return predecessor
 
-    @defer.inlineCallbacks
-    def get_create_event_for_room(self, room_id):
+    async def get_create_event_for_room(self, room_id: str) -> EventBase:
         """Get the create state event for a room.
 
         Args:
-            room_id (str)
+            room_id: The room ID.
 
         Returns:
-            Deferred[EventBase]: The room creation event.
+            The room creation event.
 
         Raises:
             NotFoundError if the room is unknown
         """
-        state_ids = yield self.get_current_state_ids(room_id)
+        state_ids = await self.get_current_state_ids(room_id)
         create_id = state_ids.get((EventTypes.Create, ""))
 
         # If we can't find the create event, assume we've hit a dead end
@@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             raise NotFoundError("Unknown room %s" % (room_id,))
 
         # Retrieve the room's create event and return
-        create_event = yield self.get_event(create_id)
+        create_event = await self.get_event(create_id)
         return create_event
 
     @cached(max_entries=100000, iterable=True)
@@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
 
-    @defer.inlineCallbacks
-    def get_canonical_alias_for_room(self, room_id):
+    async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
         """Get canonical alias for room, if any
 
         Args:
-            room_id (str)
+            room_id: The room ID
 
         Returns:
-            Deferred[str|None]: The canonical alias, if any
+            The canonical alias, if any
         """
 
-        state = yield self.get_filtered_current_state_ids(
+        state = await self.get_filtered_current_state_ids(
             room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
         )
 
@@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         if not event_id:
             return
 
-        event = yield self.get_event(event_id, allow_none=True)
+        event = await self.get_event(event_id, allow_none=True)
         if not event:
             return
 
@@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return {row["event_id"]: row["state_group"] for row in rows}
 
-    @defer.inlineCallbacks
-    def get_referenced_state_groups(self, state_groups):
+    async def get_referenced_state_groups(
+        self, state_groups: Iterable[int]
+    ) -> Set[int]:
         """Check if the state groups are referenced by events.
 
         Args:
-            state_groups (Iterable[int])
+            state_groups
 
         Returns:
-            Deferred[set[int]]: The subset of state groups that are
-            referenced.
+            The subset of state groups that are referenced.
         """
 
-        rows = yield self.db.simple_select_many_batch(
+        rows = await self.db.simple_select_many_batch(
             table="event_to_state_groups",
             column="state_group",
             iterable=state_groups,
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 380c1ec7da..922400a7c3 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -16,8 +16,8 @@
 
 import logging
 from itertools import chain
+from typing import Tuple
 
-from twisted.internet import defer
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
@@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore):
         """
         return (ts // self.stats_bucket_size) * self.stats_bucket_size
 
-    @defer.inlineCallbacks
-    def _populate_stats_process_users(self, progress, batch_size):
+    async def _populate_stats_process_users(self, progress, batch_size):
         """
         This is a background update which regenerates statistics for users.
         """
         if not self.stats_enabled:
-            yield self.db.updates._end_background_update("populate_stats_process_users")
+            await self.db.updates._end_background_update("populate_stats_process_users")
             return 1
 
         last_user_id = progress.get("last_user_id", "")
@@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, (last_user_id, batch_size))
             return [r for r, in txn]
 
-        users_to_work_on = yield self.db.runInteraction(
+        users_to_work_on = await self.db.runInteraction(
             "_populate_stats_process_users", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not users_to_work_on:
-            yield self.db.updates._end_background_update("populate_stats_process_users")
+            await self.db.updates._end_background_update("populate_stats_process_users")
             return 1
 
         for user_id in users_to_work_on:
-            yield self._calculate_and_set_initial_state_for_user(user_id)
+            await self._calculate_and_set_initial_state_for_user(user_id)
             progress["last_user_id"] = user_id
 
-        yield self.db.runInteraction(
+        await self.db.runInteraction(
             "populate_stats_process_users",
             self.db.updates._background_update_progress_txn,
             "populate_stats_process_users",
@@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore):
 
         return len(users_to_work_on)
 
-    @defer.inlineCallbacks
-    def _populate_stats_process_rooms(self, progress, batch_size):
+    async def _populate_stats_process_rooms(self, progress, batch_size):
         """
         This is a background update which regenerates statistics for rooms.
         """
         if not self.stats_enabled:
-            yield self.db.updates._end_background_update("populate_stats_process_rooms")
+            await self.db.updates._end_background_update("populate_stats_process_rooms")
             return 1
 
         last_room_id = progress.get("last_room_id", "")
@@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, (last_room_id, batch_size))
             return [r for r, in txn]
 
-        rooms_to_work_on = yield self.db.runInteraction(
+        rooms_to_work_on = await self.db.runInteraction(
             "populate_stats_rooms_get_batch", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
-            yield self.db.updates._end_background_update("populate_stats_process_rooms")
+            await self.db.updates._end_background_update("populate_stats_process_rooms")
             return 1
 
         for room_id in rooms_to_work_on:
-            yield self._calculate_and_set_initial_state_for_room(room_id)
+            await self._calculate_and_set_initial_state_for_room(room_id)
             progress["last_room_id"] = room_id
 
-        yield self.db.runInteraction(
+        await self.db.runInteraction(
             "_populate_stats_process_rooms",
             self.db.updates._background_update_progress_txn,
             "populate_stats_process_rooms",
@@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore):
 
         return room_deltas, user_deltas
 
-    @defer.inlineCallbacks
-    def _calculate_and_set_initial_state_for_room(self, room_id):
+    async def _calculate_and_set_initial_state_for_room(
+        self, room_id: str
+    ) -> Tuple[dict, dict, int]:
         """Calculate and insert an entry into room_stats_current.
 
         Args:
-            room_id (str)
+            room_id: The room ID under calculation.
 
         Returns:
-            Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
-            counts and stream position.
+            A tuple of room state, membership counts and stream position.
         """
 
         def _fetch_current_state_stats(txn):
@@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore):
             current_state_events_count,
             users_in_room,
             pos,
-        ) = yield self.db.runInteraction(
+        ) = await self.db.runInteraction(
             "get_initial_state_for_room", _fetch_current_state_stats
         )
 
-        state_event_map = yield self.get_events(event_ids, get_prev_content=False)
+        state_event_map = await self.get_events(event_ids, get_prev_content=False)
 
         room_state = {
             "join_rules": None,
@@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore):
                     event.content.get("m.federate", True) is True
                 )
 
-        yield self.update_room_state(room_id, room_state)
+        await self.update_room_state(room_id, room_state)
 
         local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
 
-        yield self.update_stats_delta(
+        await self.update_stats_delta(
             ts=self.clock.time_msec(),
             stats_type="room",
             stats_id=room_id,
@@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore):
             },
         )
 
-    @defer.inlineCallbacks
-    def _calculate_and_set_initial_state_for_user(self, user_id):
+    async def _calculate_and_set_initial_state_for_user(self, user_id):
         def _calculate_and_set_initial_state_for_user_txn(txn):
             pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
 
@@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore):
             (count,) = txn.fetchone()
             return count, pos
 
-        joined_rooms, pos = yield self.db.runInteraction(
+        joined_rooms, pos = await self.db.runInteraction(
             "calculate_and_set_initial_state_for_user",
             _calculate_and_set_initial_state_for_user_txn,
         )
 
-        yield self.update_stats_delta(
+        await self.update_stats_delta(
             ts=self.clock.time_msec(),
             stats_type="user",
             stats_id=user_id,
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 5e32c7aa1e..10d39b3699 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -255,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         self._instance_name = hs.get_instance_name()
         self._send_federation = hs.should_send_federation()
-        self._federation_shard_config = hs.config.federation.federation_shard_config
+        self._federation_shard_config = hs.config.worker.federation_shard_config
 
         # If we're a process that sends federation we may need to reset the
         # `federation_stream_position` table to match the current sharding
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 6b8130bf0f..942e51fd3a 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                     room_id
                 )
 
-                users_with_profile = yield state.get_current_users_in_room(room_id)
+                users_with_profile = yield defer.ensureDeferred(
+                    state.get_current_users_in_room(room_id)
+                )
                 user_ids = set(users_with_profile)
 
                 # Update each user in the user directory.
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index ec6b8a4ffd..d3038ff06d 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -70,11 +70,11 @@ class UserErasureWorkerStore(SQLBaseStore):
 
 
 class UserErasureStore(UserErasureWorkerStore):
-    def mark_user_erased(self, user_id):
+    def mark_user_erased(self, user_id: str) -> None:
         """Indicate that user_id wishes their message history to be erased.
 
         Args:
-            user_id (str): full user_id to be erased
+            user_id: full user_id to be erased
         """
 
         def f(txn):
@@ -89,3 +89,25 @@ class UserErasureStore(UserErasureWorkerStore):
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
         return self.db.runInteraction("mark_user_erased", f)
+
+    def mark_user_not_erased(self, user_id: str) -> None:
+        """Indicate that user_id is no longer erased.
+
+        Args:
+            user_id: full user_id to be un-erased
+        """
+
+        def f(txn):
+            # first check if they are already in the list
+            txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
+            if not txn.fetchone():
+                return
+
+            # They are there, delete them.
+            self.simple_delete_one_txn(
+                txn, "erased_users", keyvalues={"user_id": user_id}
+            )
+
+            self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
+
+        return self.db.runInteraction("mark_user_not_erased", f)
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 128c09a2cf..7dada7f75f 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             "get_state_group_delta", _get_state_group_delta_txn
         )
 
-    @defer.inlineCallbacks
-    def _get_state_groups_from_groups(
+    async def _get_state_groups_from_groups(
         self, groups: List[int], state_filter: StateFilter
-    ):
+    ) -> Dict[int, StateMap[str]]:
         """Returns the state groups for a given set of groups from the
         database, filtering on types of state events.
 
@@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+            Dict of state group to state map.
         """
         results = {}
 
         chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:
-            res = yield self.db.runInteraction(
+            res = await self.db.runInteraction(
                 "_get_state_groups_from_groups",
                 self._get_state_groups_from_groups_txn,
                 chunk,
@@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state_filter.filter_state(state_dict_ids), not missing_types
 
-    @defer.inlineCallbacks
-    def _get_state_for_groups(
+    async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> Dict[int, StateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
@@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+            Dict of state group to state map.
         """
 
         member_filter, non_member_filter = state_filter.get_member_split()
@@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         (
             non_member_state,
             incomplete_groups_nm,
-        ) = yield self._get_state_for_groups_using_cache(
+        ) = self._get_state_for_groups_using_cache(
             groups, self._state_group_cache, state_filter=non_member_filter
         )
 
-        (
-            member_state,
-            incomplete_groups_m,
-        ) = yield self._get_state_for_groups_using_cache(
+        (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
             groups, self._state_group_members_cache, state_filter=member_filter
         )
 
@@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         # Help the cache hit ratio by expanding the filter a bit
         db_state_filter = state_filter.return_expanded()
 
-        group_to_state_dict = yield self._get_state_groups_from_groups(
+        group_to_state_dict = await self._get_state_groups_from_groups(
             list(incomplete_groups), state_filter=db_state_filter
         )
 
@@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             ((sg,) for sg in state_groups_to_delete),
         )
 
-    @defer.inlineCallbacks
-    def get_previous_state_groups(self, state_groups):
+    async def get_previous_state_groups(
+        self, state_groups: Iterable[int]
+    ) -> Dict[int, int]:
         """Fetch the previous groups of the given state groups.
 
         Args:
-            state_groups (Iterable[int])
+            state_groups
 
         Returns:
-            Deferred[dict[int, int]]: mapping from state group to previous
-            state group.
+            A mapping from state group to previous state group.
         """
 
-        rows = yield self.db.simple_select_many_batch(
+        rows = await self.db.simple_select_many_batch(
             table="state_group_edges",
             column="prev_state_group",
             iterable=state_groups,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3be20c866a..ce8757a400 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
 from synapse.storage.types import Connection, Cursor
 from synapse.types import Collection
 
-logger = logging.getLogger(__name__)
-
 # python 3 does not have a maximum int value
 MAX_TXN_ID = 2 ** 63 - 1
 
+logger = logging.getLogger(__name__)
+
 sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
 perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -233,7 +233,7 @@ class LoggingTransaction:
         try:
             return func(sql, *args)
         except Exception as e:
-            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+            sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
             raise
         finally:
             secs = time.time() - start
@@ -419,7 +419,7 @@ class Database(object):
                 except self.engine.module.OperationalError as e:
                     # This can happen if the database disappears mid
                     # transaction.
-                    logger.warning(
+                    transaction_logger.warning(
                         "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
                     )
                     if i < N:
@@ -427,18 +427,20 @@ class Database(object):
                         try:
                             conn.rollback()
                         except self.engine.module.Error as e1:
-                            logger.warning("[TXN EROLL] {%s} %s", name, e1)
+                            transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
                         continue
                     raise
                 except self.engine.module.DatabaseError as e:
                     if self.engine.is_deadlock(e):
-                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        transaction_logger.warning(
+                            "[TXN DEADLOCK] {%s} %d/%d", name, i, N
+                        )
                         if i < N:
                             i += 1
                             try:
                                 conn.rollback()
                             except self.engine.module.Error as e1:
-                                logger.warning(
+                                transaction_logger.warning(
                                     "[TXN EROLL] {%s} %s", name, e1,
                                 )
                             continue
@@ -478,7 +480,7 @@ class Database(object):
                     # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
                     cursor.close()
         except Exception as e:
-            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
             raise
         finally:
             end = monotonic_time()
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index fa46041676..4a164834d9 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -25,11 +25,10 @@ from prometheus_client import Counter, Histogram
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.events import FrozenEvent
+from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
 from synapse.storage.data_stores import DataStores
 from synapse.storage.data_stores.main.events import DeltaState
 from synapse.types import StateMap
@@ -193,12 +192,11 @@ class EventsPersistenceStorage(object):
         self._event_persist_queue = _EventPeristenceQueue()
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
-    @defer.inlineCallbacks
-    def persist_events(
+    async def persist_events(
         self,
-        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ):
+    ) -> int:
         """
         Write events to the database
         Args:
@@ -208,7 +206,7 @@ class EventsPersistenceStorage(object):
                 which might update the current state etc.
 
         Returns:
-            Deferred[int]: the stream ordering of the latest persisted event
+            the stream ordering of the latest persisted event
         """
         partitioned = {}
         for event, ctx in events_and_contexts:
@@ -224,22 +222,19 @@ class EventsPersistenceStorage(object):
         for room_id in partitioned:
             self._maybe_start_persisting(room_id)
 
-        yield make_deferred_yieldable(
+        await make_deferred_yieldable(
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-        max_persisted_id = yield self.main_store.get_current_events_token()
-
-        return max_persisted_id
+        return self.main_store.get_current_events_token()
 
-    @defer.inlineCallbacks
-    def persist_event(
-        self, event: FrozenEvent, context: EventContext, backfilled: bool = False
-    ):
+    async def persist_event(
+        self, event: EventBase, context: EventContext, backfilled: bool = False
+    ) -> Tuple[int, int]:
         """
         Returns:
-            Deferred[Tuple[int, int]]: the stream ordering of ``event``,
-            and the stream ordering of the latest persisted event
+            The stream ordering of `event`, and the stream ordering of the
+            latest persisted event
         """
         deferred = self._event_persist_queue.add_to_queue(
             event.room_id, [(event, context)], backfilled=backfilled
@@ -247,9 +242,9 @@ class EventsPersistenceStorage(object):
 
         self._maybe_start_persisting(event.room_id)
 
-        yield make_deferred_yieldable(deferred)
+        await make_deferred_yieldable(deferred)
 
-        max_persisted_id = yield self.main_store.get_current_events_token()
+        max_persisted_id = self.main_store.get_current_events_token()
         return (event.internal_metadata.stream_ordering, max_persisted_id)
 
     def _maybe_start_persisting(self, room_id: str):
@@ -263,7 +258,7 @@ class EventsPersistenceStorage(object):
 
     async def _persist_events(
         self,
-        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
     ):
         """Calculates the change to current state and forward extremities, and
@@ -440,7 +435,7 @@ class EventsPersistenceStorage(object):
     async def _calculate_new_extremities(
         self,
         room_id: str,
-        event_contexts: List[Tuple[FrozenEvent, EventContext]],
+        event_contexts: List[Tuple[EventBase, EventContext]],
         latest_event_ids: List[str],
     ):
         """Calculates the new forward extremities for a room given events to
@@ -498,7 +493,7 @@ class EventsPersistenceStorage(object):
     async def _get_new_state_after_events(
         self,
         room_id: str,
-        events_context: List[Tuple[FrozenEvent, EventContext]],
+        events_context: List[Tuple[EventBase, EventContext]],
         old_latest_event_ids: Iterable[str],
         new_latest_event_ids: Iterable[str],
     ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@@ -648,6 +643,10 @@ class EventsPersistenceStorage(object):
             room_version = await self.main_store.get_room_version_id(room_id)
 
         logger.debug("calling resolve_state_groups from preserve_events")
+
+        # Avoid a circular import.
+        from synapse.state import StateResolutionStore
+
         res = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
@@ -680,7 +679,7 @@ class EventsPersistenceStorage(object):
     async def _is_server_still_joined(
         self,
         room_id: str,
-        ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
+        ev_ctx_rm: List[Tuple[EventBase, EventContext]],
         delta: DeltaState,
         current_state: Optional[StateMap[str]],
         potentially_left_users: Set[str],
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index fdc0abf5cf..79d9f06e2e 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,8 +15,7 @@
 
 import itertools
 import logging
-
-from twisted.internet import defer
+from typing import Set
 
 logger = logging.getLogger(__name__)
 
@@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
     def __init__(self, hs, stores):
         self.stores = stores
 
-    @defer.inlineCallbacks
-    def purge_room(self, room_id: str):
+    async def purge_room(self, room_id: str):
         """Deletes all record of a room
         """
 
-        state_groups_to_delete = yield self.stores.main.purge_room(room_id)
-        yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+        state_groups_to_delete = await self.stores.main.purge_room(room_id)
+        await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
 
-    @defer.inlineCallbacks
-    def purge_history(self, room_id, token, delete_local_events):
+    async def purge_history(
+        self, room_id: str, token: str, delete_local_events: bool
+    ) -> None:
         """Deletes room history before a certain point
 
         Args:
-            room_id (str):
+            room_id: The room ID
 
-            token (str): A topological token to delete events before
+            token: A topological token to delete events before
 
-            delete_local_events (bool):
+            delete_local_events:
                 if True, we will delete local events as well as remote ones
                 (instead of just marking them as outliers and deleting their
                 state groups).
         """
-        state_groups = yield self.stores.main.purge_history(
+        state_groups = await self.stores.main.purge_history(
             room_id, token, delete_local_events
         )
 
         logger.info("[purge] finding state groups that can be deleted")
 
-        sg_to_delete = yield self._find_unreferenced_groups(state_groups)
+        sg_to_delete = await self._find_unreferenced_groups(state_groups)
 
-        yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+        await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
 
-    @defer.inlineCallbacks
-    def _find_unreferenced_groups(self, state_groups):
+    async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
         """Used when purging history to figure out which state groups can be
         deleted.
 
         Args:
-            state_groups (set[int]): Set of state groups referenced by events
+            state_groups: Set of state groups referenced by events
                 that are going to be deleted.
 
         Returns:
-            Deferred[set[int]] The set of state groups that can be deleted.
+            The set of state groups that can be deleted.
         """
         # Graph of state group -> previous group
         graph = {}
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
                 current_search = set(itertools.islice(next_to_search, 100))
                 next_to_search -= current_search
 
-            referenced = yield self.stores.main.get_referenced_state_groups(
+            referenced = await self.stores.main.get_referenced_state_groups(
                 current_search
             )
             referenced_groups |= referenced
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
             # groups that are referenced.
             current_search -= referenced
 
-            edges = yield self.stores.state.get_previous_state_groups(current_search)
+            edges = await self.stores.state.get_previous_state_groups(current_search)
 
             prevs = set(edges.values())
             # We don't bother re-handling groups we've already seen
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dc568476f4..534883361f 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,13 +14,12 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, List, TypeVar
+from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
 
 import attr
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
+from synapse.events import EventBase
 from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
@@ -34,16 +33,16 @@ class StateFilter(object):
     """A filter used when querying for state.
 
     Attributes:
-        types (dict[str, set[str]|None]): Map from type to set of state keys (or
-            None). This specifies which state_keys for the given type to fetch
-            from the DB. If None then all events with that type are fetched. If
-            the set is empty then no events with that type are fetched.
-        include_others (bool): Whether to fetch events with types that do not
+        types: Map from type to set of state keys (or None). This specifies
+            which state_keys for the given type to fetch from the DB. If None
+            then all events with that type are fetched. If the set is empty
+            then no events with that type are fetched.
+        include_others: Whether to fetch events with types that do not
             appear in `types`.
     """
 
-    types = attr.ib()
-    include_others = attr.ib(default=False)
+    types = attr.ib(type=Dict[str, Optional[Set[str]]])
+    include_others = attr.ib(default=False, type=bool)
 
     def __attrs_post_init__(self):
         # If `include_others` is set we canonicalise the filter by removing
@@ -52,36 +51,35 @@ class StateFilter(object):
             self.types = {k: v for k, v in self.types.items() if v is not None}
 
     @staticmethod
-    def all():
+    def all() -> "StateFilter":
         """Creates a filter that fetches everything.
 
         Returns:
-            StateFilter
+            The new state filter.
         """
         return StateFilter(types={}, include_others=True)
 
     @staticmethod
-    def none():
+    def none() -> "StateFilter":
         """Creates a filter that fetches nothing.
 
         Returns:
-            StateFilter
+            The new state filter.
         """
         return StateFilter(types={}, include_others=False)
 
     @staticmethod
-    def from_types(types):
+    def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
         """Creates a filter that only fetches the given types
 
         Args:
-            types (Iterable[tuple[str, str|None]]): A list of type and state
-                keys to fetch. A state_key of None fetches everything for
-                that type
+            types: A list of type and state keys to fetch. A state_key of None
+                fetches everything for that type
 
         Returns:
-            StateFilter
+            The new state filter.
         """
-        type_dict = {}
+        type_dict = {}  # type: Dict[str, Optional[Set[str]]]
         for typ, s in types:
             if typ in type_dict:
                 if type_dict[typ] is None:
@@ -91,24 +89,24 @@ class StateFilter(object):
                 type_dict[typ] = None
                 continue
 
-            type_dict.setdefault(typ, set()).add(s)
+            type_dict.setdefault(typ, set()).add(s)  # type: ignore
 
         return StateFilter(types=type_dict)
 
     @staticmethod
-    def from_lazy_load_member_list(members):
+    def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
         """Creates a filter that returns all non-member events, plus the member
         events for the given users
 
         Args:
-            members (iterable[str]): Set of user IDs
+            members: Set of user IDs
 
         Returns:
-            StateFilter
+            The new state filter
         """
         return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
 
-    def return_expanded(self):
+    def return_expanded(self) -> "StateFilter":
         """Creates a new StateFilter where type wild cards have been removed
         (except for memberships). The returned filter is a superset of the
         current one, i.e. anything that passes the current filter will pass
@@ -130,7 +128,7 @@ class StateFilter(object):
                return all non-member events
 
         Returns:
-            StateFilter
+            The new state filter.
         """
 
         if self.is_full():
@@ -167,7 +165,7 @@ class StateFilter(object):
                 include_others=True,
             )
 
-    def make_sql_filter_clause(self):
+    def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
         """Converts the filter to an SQL clause.
 
         For example:
@@ -179,13 +177,12 @@ class StateFilter(object):
 
 
         Returns:
-            tuple[str, list]: The SQL string (may be empty) and arguments. An
-            empty SQL string is returned when the filter matches everything
-            (i.e. is "full").
+            The SQL string (may be empty) and arguments. An empty SQL string is
+            returned when the filter matches everything (i.e. is "full").
         """
 
         where_clause = ""
-        where_args = []
+        where_args = []  # type: List[str]
 
         if self.is_full():
             return where_clause, where_args
@@ -221,7 +218,7 @@ class StateFilter(object):
 
         return where_clause, where_args
 
-    def max_entries_returned(self):
+    def max_entries_returned(self) -> Optional[int]:
         """Returns the maximum number of entries this filter will return if
         known, otherwise returns None.
 
@@ -260,33 +257,33 @@ class StateFilter(object):
 
         return filtered_state
 
-    def is_full(self):
+    def is_full(self) -> bool:
         """Whether this filter fetches everything or not
 
         Returns:
-            bool
+            True if the filter fetches everything.
         """
         return self.include_others and not self.types
 
-    def has_wildcards(self):
+    def has_wildcards(self) -> bool:
         """Whether the filter includes wildcards or is attempting to fetch
         specific state.
 
         Returns:
-            bool
+            True if the filter includes wildcards.
         """
 
         return self.include_others or any(
             state_keys is None for state_keys in self.types.values()
         )
 
-    def concrete_types(self):
+    def concrete_types(self) -> List[Tuple[str, str]]:
         """Returns a list of concrete type/state_keys (i.e. not None) that
         will be fetched. This will be a complete list if `has_wildcards`
         returns False, but otherwise will be a subset (or even empty).
 
         Returns:
-            list[tuple[str,str]]
+            A list of type/state_keys tuples.
         """
         return [
             (t, s)
@@ -295,7 +292,7 @@ class StateFilter(object):
             for s in state_keys
         ]
 
-    def get_member_split(self):
+    def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
         """Return the filter split into two: one which assumes it's exclusively
         matching against member state, and one which assumes it's matching
         against non member state.
@@ -307,7 +304,7 @@ class StateFilter(object):
         state caches).
 
         Returns:
-            tuple[StateFilter, StateFilter]: The member and non member filters
+            The member and non member filters
         """
 
         if EventTypes.Member in self.types:
@@ -340,6 +337,9 @@ class StateGroupStorage(object):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
+        Args:
+            state_group: The state group used to retrieve state deltas.
+
         Returns:
             Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
                 (prev_group, delta_ids)
@@ -347,55 +347,59 @@ class StateGroupStorage(object):
 
         return self.stores.state.get_state_group_delta(state_group)
 
-    @defer.inlineCallbacks
-    def get_state_groups_ids(self, _room_id, event_ids):
+    async def get_state_groups_ids(
+        self, _room_id: str, event_ids: Iterable[str]
+    ) -> Dict[int, StateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
         Args:
-            _room_id (str): id of the room for these events
-            event_ids (iterable[str]): ids of the events
+            _room_id: id of the room for these events
+            event_ids: ids of the events
 
         Returns:
-            Deferred[dict[int, StateMap[str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
             return {}
 
-        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
-        group_to_state = yield self.stores.state._get_state_for_groups(groups)
+        group_to_state = await self.stores.state._get_state_for_groups(groups)
 
         return group_to_state
 
-    @defer.inlineCallbacks
-    def get_state_ids_for_group(self, state_group):
+    async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
         """Get the event IDs of all the state in the given state group
 
         Args:
-            state_group (int)
+            state_group: A state group for which we want to get the state IDs.
 
         Returns:
-            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+            Resolves to a map of (type, state_key) -> event_id
         """
-        group_to_state = yield self._get_state_for_groups((state_group,))
+        group_to_state = await self._get_state_for_groups((state_group,))
 
         return group_to_state[state_group]
 
-    @defer.inlineCallbacks
-    def get_state_groups(self, room_id, event_ids):
+    async def get_state_groups(
+        self, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[int, List[EventBase]]:
         """ Get the state groups for the given list of event_ids
+
+        Args:
+            room_id: ID of the room for these events.
+            event_ids: The event IDs to retrieve state for.
+
         Returns:
-            Deferred[dict[int, list[EventBase]]]:
-                dict of state_group_id -> list of state events.
+            dict of state_group_id -> list of state events.
         """
         if not event_ids:
             return {}
 
-        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+        group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
 
-        state_event_map = yield self.stores.main.get_events(
+        state_event_map = await self.stores.main.get_events(
             [
                 ev_id
                 for group_ids in group_to_ids.values()
@@ -415,7 +419,7 @@ class StateGroupStorage(object):
 
     def _get_state_groups_from_groups(
         self, groups: List[int], state_filter: StateFilter
-    ):
+    ) -> Awaitable[Dict[int, StateMap[str]]]:
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
@@ -423,31 +427,34 @@ class StateGroupStorage(object):
             groups: list of state group IDs to query
             state_filter: The state filter used to fetch state
                 from the database.
+
         Returns:
-            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+            Dict of state group to state map.
         """
 
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
-    @defer.inlineCallbacks
-    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+    async def get_state_for_events(
+        self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+    ):
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
+
         Args:
-            event_ids (list[string])
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            event_ids: The events to fetch the state of.
+            state_filter: The state filter used to fetch state.
+
         Returns:
-            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+            A dict of (event_id) -> (type, state_key) -> [state_events]
         """
-        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
-        group_to_state = yield self.stores.state._get_state_for_groups(
+        group_to_state = await self.stores.state._get_state_for_groups(
             groups, state_filter
         )
 
-        state_event_map = yield self.stores.main.get_events(
+        state_event_map = await self.stores.main.get_events(
             [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
             get_prev_content=False,
         )
@@ -463,24 +470,24 @@ class StateGroupStorage(object):
 
         return {event: event_to_state[event] for event in event_ids}
 
-    @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+    async def get_state_ids_for_events(
+        self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+    ):
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
         of the state events (as opposed to the events themselves)
 
         Args:
-            event_ids(list(str)): events whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            event_ids: events whose state should be returned
+            state_filter: The state filter used to fetch state from the database.
 
         Returns:
-            A deferred dict from event_id -> (type, state_key) -> event_id
+            A dict from event_id -> (type, state_key) -> event_id
         """
-        event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+        event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
-        group_to_state = yield self.stores.state._get_state_for_groups(
+        group_to_state = await self.stores.state._get_state_for_groups(
             groups, state_filter
         )
 
@@ -491,67 +498,72 @@ class StateGroupStorage(object):
 
         return {event: event_to_state[event] for event in event_ids}
 
-    @defer.inlineCallbacks
-    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+    async def get_state_for_event(
+        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+    ):
         """
         Get the state dict corresponding to a particular event
 
         Args:
-            event_id(str): event whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            event_id: event whose state should be returned
+            state_filter: The state filter used to fetch state from the database.
 
         Returns:
-            A deferred dict from (type, state_key) -> state_event
+            A dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_for_events([event_id], state_filter)
+        state_map = await self.get_state_for_events([event_id], state_filter)
         return state_map[event_id]
 
-    @defer.inlineCallbacks
-    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+    async def get_state_ids_for_event(
+        self, event_id: str, state_filter: StateFilter = StateFilter.all()
+    ):
         """
         Get the state dict corresponding to a particular event
 
         Args:
-            event_id(str): event whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            event_id: event whose state should be returned
+            state_filter: The state filter used to fetch state from the database.
 
         Returns:
             A deferred dict from (type, state_key) -> state_event
         """
-        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+        state_map = await self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
 
     def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> Awaitable[Dict[int, StateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
         Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state groups for which we want to get the state.
+            state_filter: The state filter used to fetch state.
                 from the database.
+
         Returns:
-            Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
+            Dict of state group to state map.
         """
         return self.stores.state._get_state_for_groups(groups, state_filter)
 
     def store_state_group(
-        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+        self,
+        event_id: str,
+        room_id: str,
+        prev_group: Optional[int],
+        delta_ids: Optional[dict],
+        current_state_ids: dict,
     ):
         """Store a new set of state, returning a newly assigned state group.
 
         Args:
-            event_id (str): The event ID for which the state was calculated
-            room_id (str)
-            prev_group (int|None): A previous state group for the room, optional.
-            delta_ids (dict|None): The delta between state at `prev_group` and
+            event_id: The event ID for which the state was calculated.
+            room_id: ID of the room for which the state was calculated.
+            prev_group: A previous state group for the room, optional.
+            delta_ids: The delta between state at `prev_group` and
                 `current_state_ids`, if `prev_group` was given. Same format as
                 `current_state_ids`.
-            current_state_ids (dict): The state to store. Map of (type, state_key)
+            current_state_ids: The state to store. Map of (type, state_key)
                 to event_id.
 
         Returns: