author     Erik Johnston <erik@matrix.org>  2020-01-22 16:53:28 +0000
committer  Erik Johnston <erik@matrix.org>  2020-01-22 16:53:28 +0000
commit     57a60365da0c47f286ea4608d766abbca5762233 (patch)
tree       aaef0948f26f3352092b787d32e1dda0743d697e /synapse/storage
parent     Pull out more info about room key requests (diff)
parent     Remove unnecessary abstractions in admin handler (#6751) (diff)
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/debug_direct_message_checks
Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/_base.py                                                      |  18
-rw-r--r--  synapse/storage/data_stores/__init__.py                                       |   2
-rw-r--r--  synapse/storage/data_stores/main/cache.py                                     |  29
-rw-r--r--  synapse/storage/data_stores/main/devices.py                                   |   2
-rw-r--r--  synapse/storage/data_stores/main/events.py                                    | 119
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py                             |   2
-rw-r--r--  synapse/storage/data_stores/main/keys.py                                      |   2
-rw-r--r--  synapse/storage/data_stores/main/monthly_active_users.py                      | 165
-rw-r--r--  synapse/storage/data_stores/main/presence.py                                  |   2
-rw-r--r--  synapse/storage/data_stores/main/registration.py                              |   2
-rw-r--r--  synapse/storage/data_stores/main/room.py                                      | 239
-rw-r--r--  synapse/storage/data_stores/main/roommember.py                                | 189
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py |  98
-rw-r--r--  synapse/storage/data_stores/main/state.py                                     |  11
-rw-r--r--  synapse/storage/data_stores/main/stream.py                                    |  10
-rw-r--r--  synapse/storage/data_stores/state/store.py                                    |  52
-rw-r--r--  synapse/storage/engines/postgres.py                                           |  34
-rw-r--r--  synapse/storage/engines/sqlite.py                                             |   7
-rw-r--r--  synapse/storage/persist_events.py                                             | 123
-rw-r--r--  synapse/storage/prepare_database.py                                           |   2
-rw-r--r--  synapse/storage/purge_events.py                                               |   2
-rw-r--r--  synapse/storage/state.py                                                      |  35
22 files changed, 798 insertions, 347 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3bb9381663..da3b99f93d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,6 +17,7 @@
 import logging
 import random
 from abc import ABCMeta
+from typing import Any, Optional
 
 from six import PY2
 from six.moves import builtins
@@ -26,7 +27,7 @@ from canonicaljson import json
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import Database
-from synapse.types import get_domain_from_id
+from synapse.types import Collection, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
@@ -63,17 +64,24 @@ class SQLBaseStore(metaclass=ABCMeta):
         self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
         self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
 
-    def _attempt_to_invalidate_cache(self, cache_name, key):
+    def _attempt_to_invalidate_cache(
+        self, cache_name: str, key: Optional[Collection[Any]]
+    ):
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
 
         Args:
-            cache_name (str)
-            key (tuple)
+            cache_name
+            key: Entry to invalidate. If None then invalidates the entire
+                cache.
         """
+
         try:
-            getattr(self, cache_name).invalidate(key)
+            if key is None:
+                getattr(self, cache_name).invalidate_all()
+            else:
+                getattr(self, cache_name).invalidate(tuple(key))
         except AttributeError:
             # We probably haven't pulled in the cache in this worker,
             # which is fine.
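A minimal usage sketch of the new signature (the `store` instance and the cache names here are assumptions for illustration; only the method itself comes from the diff):

    # Invalidate a single entry, as before; the key may now be any Collection.
    store._attempt_to_invalidate_cache("get_rooms_for_user", ("@alice:example.com",))

    # Passing None now wipes the whole cache rather than a single entry.
    store._attempt_to_invalidate_cache("user_last_seen_monthly_active", None)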
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index 092e803799..e1d03429ca 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -47,7 +47,7 @@ class DataStores(object):
             with make_conn(database_config, engine) as db_conn:
                 logger.info("Preparing database %r...", db_name)
 
-                engine.check_database(db_conn.cursor())
+                engine.check_database(db_conn)
                 prepare_database(
                     db_conn, engine, hs.config, data_stores=database_config.data_stores,
                 )
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index 54ed8574c4..afa2b41c98 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -16,12 +16,13 @@
 
 import itertools
 import logging
+from typing import Any, Iterable, Optional
 
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from synapse.util import batch_iter
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
@@ -43,6 +44,14 @@ class CacheInvalidationStore(SQLBaseStore):
         txn.call_after(cache_func.invalidate, keys)
         self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
 
+    def _invalidate_all_cache_and_stream(self, txn, cache_func):
+        """Invalidates the entire cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+        """
+
+        txn.call_after(cache_func.invalidate_all)
+        self._send_invalidation_to_replication(txn, cache_func.__name__, None)
+
     def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
         """Special case invalidation of caches based on current state.
 
@@ -73,17 +82,24 @@ class CacheInvalidationStore(SQLBaseStore):
                 txn, CURRENT_STATE_CACHE_NAME, [room_id]
             )
 
-    def _send_invalidation_to_replication(self, txn, cache_name, keys):
+    def _send_invalidation_to_replication(
+        self, txn, cache_name: str, keys: Optional[Iterable[Any]]
+    ):
         """Notifies replication that given cache has been invalidated.
 
         Note that this does *not* invalidate the cache locally.
 
         Args:
             txn
-            cache_name (str)
-            keys (iterable[str])
+            cache_name
+            keys: Entry to invalidate. If None then invalidates the entire cache.
         """
 
+        if cache_name == CURRENT_STATE_CACHE_NAME and keys is None:
+            raise Exception(
+                "Can't stream invalidate all with magic current state cache"
+            )
+
         if isinstance(self.database_engine, PostgresEngine):
             # get_next() returns a context manager which is designed to wrap
             # the transaction. However, we want to only get an ID when we want
@@ -95,13 +111,16 @@ class CacheInvalidationStore(SQLBaseStore):
             txn.call_after(ctx.__exit__, None, None, None)
             txn.call_after(self.hs.get_notifier().on_new_replication_data)
 
+            if keys is not None:
+                keys = list(keys)
+
             self.db.simple_insert_txn(
                 txn,
                 table="cache_invalidation_stream",
                 values={
                     "stream_id": stream_id,
                     "cache_func": cache_name,
-                    "keys": list(keys),
+                    "keys": keys,
                     "invalidation_ts": self.clock.time_msec(),
                 },
             )
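With this change a row in `cache_invalidation_stream` may carry NULL keys, meaning "invalidate everything". A sketch of how a consumer of that stream could interpret the convention (the handler below is illustrative, not Synapse's actual replication code):

    def on_cache_invalidation(store, cache_name, keys):
        cache = getattr(store, cache_name, None)
        if cache is None:
            # This worker never pulled in the cache, which is fine.
            return
        if keys is None:
            cache.invalidate_all()  # keys IS NULL => full invalidation
        else:
            cache.invalidate(tuple(keys))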
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 9a828231c4..f0a7962dd0 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -33,13 +33,13 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import Database
 from synapse.types import get_verify_key_from_cross_signing_key
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import (
     Cache,
     cached,
     cachedInlineCallbacks,
     cachedList,
 )
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 58f35d7f56..596daf8909 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -19,6 +19,7 @@ import itertools
 import logging
 from collections import Counter as c_counter, OrderedDict, namedtuple
 from functools import wraps
+from typing import Dict, List, Tuple
 
 from six import iteritems, text_type
 from six.moves import range
@@ -41,11 +42,12 @@ from synapse.storage._base import make_in_list_sql_clause
 from synapse.storage.data_stores.main.event_federation import EventFederationStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
-from synapse.storage.database import Database
-from synapse.types import RoomStreamToken, get_domain_from_id
-from synapse.util import batch_iter
+from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.persist_events import DeltaState
+from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.frozenutils import frozendict_json_encoder
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
@@ -128,6 +130,7 @@ class EventsStore(
             hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+        self.is_mine_id = hs.is_mine_id
 
     @defer.inlineCallbacks
     def _read_forward_extremities(self):
@@ -147,30 +150,26 @@ class EventsStore(
     @defer.inlineCallbacks
     def _persist_events_and_state_updates(
         self,
-        events_and_contexts,
-        current_state_for_room,
-        state_delta_for_room,
-        new_forward_extremeties,
-        backfilled=False,
-        delete_existing=False,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        current_state_for_room: Dict[str, StateMap[str]],
+        state_delta_for_room: Dict[str, DeltaState],
+        new_forward_extremeties: Dict[str, List[str]],
+        backfilled: bool = False,
+        delete_existing: bool = False,
     ):
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
 
         Args:
-            events_and_contexts (list[(EventBase, EventContext)]):
-            current_state_for_room (dict[str, dict]): Map from room_id to the
-                current state of the room based on forward extremities
-            state_delta_for_room (dict[str, tuple]): Map from room_id to tuple
-                of `(to_delete, to_insert)` where to_delete is a list
-                of type/state keys to remove from current state, and to_insert
-                is a map (type,key)->event_id giving the state delta in each
-                room.
-            new_forward_extremities (dict[str, list[str]]): Map from room_id
-                to list of event IDs that are the new forward extremities of
-                the room.
-            backfilled (bool)
-            delete_existing (bool):
+            events_and_contexts:
+            current_state_for_room: Map from room_id to the current state of
+                the room based on forward extremities
+            state_delta_for_room: Map from room_id to the delta to apply to
+                room state
+            new_forward_extremities: Map from room_id to list of event IDs
+                that are the new forward extremities of the room.
+            backfilled
+            delete_existing
 
         Returns:
             Deferred: resolves when the events have been persisted
@@ -351,12 +350,12 @@ class EventsStore(
     @log_function
     def _persist_events_txn(
         self,
-        txn,
-        events_and_contexts,
-        backfilled,
-        delete_existing=False,
-        state_delta_for_room={},
-        new_forward_extremeties={},
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+        delete_existing: bool = False,
+        state_delta_for_room: Dict[str, DeltaState] = {},
+        new_forward_extremeties: Dict[str, List[str]] = {},
     ):
         """Insert some number of room events into the necessary database tables.
 
@@ -365,21 +364,16 @@ class EventsStore(
         whether the event was rejected.
 
         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]):
-                events to persist
-            backfilled (bool): True if the events were backfilled
-            delete_existing (bool): True to purge existing table rows for the
-                events from the database. This is useful when retrying due to
+            txn
+            events_and_contexts: events to persist
+            backfilled: True if the events were backfilled
+            delete_existing: True to purge existing table rows for the events
+                from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room (dict[str, (list, dict)]):
-                The current-state delta for each room. For each room, a tuple
-                (to_delete, to_insert), being a list of type/state keys to be
-                removed from the current state, and a state set to be added to
-                the current state.
-            new_forward_extremeties (dict[str, list[str]]):
-                The new forward extremities for each room. For each room, a
-                list of the event ids which are the forward extremities.
+            state_delta_for_room: The current-state delta for each room.
+            new_forward_extremeties: The new forward extremities for each room.
+                For each room, a list of the event ids which are the forward
+                extremities.
 
         """
         all_events_and_contexts = events_and_contexts
@@ -464,9 +458,15 @@ class EventsStore(
         # room_memberships, where applicable.
         self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
 
-    def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
-        for room_id, current_state_tuple in iteritems(state_delta_by_room):
-            to_delete, to_insert = current_state_tuple
+    def _update_current_state_txn(
+        self,
+        txn: LoggingTransaction,
+        state_delta_by_room: Dict[str, DeltaState],
+        stream_id: int,
+    ):
+        for room_id, delta_state in iteritems(state_delta_by_room):
+            to_delete = delta_state.to_delete
+            to_insert = delta_state.to_insert
 
             # First we add entries to the current_state_delta_stream. We
             # do this before updating the current_state_events table so
@@ -547,6 +547,34 @@ class EventsStore(
                 ],
             )
 
+            # Note: Do we really want to delete rows here (that we do not
+            # subsequently reinsert below)? While technically correct it means
+            # we have no record of the fact the user *was* a member of the
+            # room but got, say, state reset out of it.
+            if to_delete or to_insert:
+                txn.executemany(
+                    "DELETE FROM local_current_membership"
+                    " WHERE room_id = ? AND user_id = ?",
+                    (
+                        (room_id, state_key)
+                        for etype, state_key in itertools.chain(to_delete, to_insert)
+                        if etype == EventTypes.Member and self.is_mine_id(state_key)
+                    ),
+                )
+
+            if to_insert:
+                txn.executemany(
+                    """INSERT INTO local_current_membership
+                        (room_id, user_id, event_id, membership)
+                    VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+                    """,
+                    [
+                        (room_id, key[1], ev_id, ev_id)
+                        for key, ev_id in to_insert.items()
+                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+                    ],
+                )
+
             txn.call_after(
                 self._curr_state_delta_stream_cache.entity_has_changed,
                 room_id,
@@ -1724,6 +1752,7 @@ class EventsStore(
             "local_invites",
             "room_account_data",
             "room_tags",
+            "local_current_membership",
         ):
             logger.info("[purge] removing %s from %s", room_id, table)
             txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
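The tuple-based `(to_delete, to_insert)` state delta is replaced by the `DeltaState` type imported from `synapse.storage.persist_events`. A sketch of the shape `_update_current_state_txn` now consumes (the construction and the IDs below are illustrative):

    from synapse.storage.persist_events import DeltaState

    delta = DeltaState(
        # (event type, state key) pairs to remove from current state
        to_delete=[("m.room.member", "@bob:example.com")],
        # (event type, state key) -> event_id entries to add
        to_insert={("m.room.member", "@alice:example.com"): "$join_event_id"},
    )
    state_delta_for_room = {"!room:example.com": delta}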
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 0cce5232f5..3b93e0597a 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -37,8 +37,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import Cache
+from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index 6b12f5a75f..ba89c68c9f 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -23,8 +23,8 @@ from signedjson.key import decode_verify_key_bytes
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.keys import FetchKeyResult
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 27158534cb..89a41542a3 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -27,12 +27,76 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
-class MonthlyActiveUsersStore(SQLBaseStore):
+class MonthlyActiveUsersWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+        super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
+
+    @cached(num_args=0)
+    def get_monthly_active_count(self):
+        """Generates current count of monthly active users
+
+        Returns:
+            Deferred[int]: Number of current monthly active users
+        """
+
+        def _count_users(txn):
+            sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+
+            txn.execute(sql)
+            (count,) = txn.fetchone()
+            return count
+
+        return self.db.runInteraction("count_users", _count_users)
+
+    @defer.inlineCallbacks
+    def get_registered_reserved_users(self):
+        """Of the reserved threepids defined in config, which are associated
+        with registered users?
+
+        Returns:
+            Deferred[list]: Real reserved users
+        """
+        users = []
+
+        for tp in self.hs.config.mau_limits_reserved_threepids[
+            : self.hs.config.max_mau_value
+        ]:
+            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+                tp["medium"], tp["address"]
+            )
+            if user_id:
+                users.append(user_id)
+
+        return users
+
+    @cached(num_args=1)
+    def user_last_seen_monthly_active(self, user_id):
+        """
+            Checks if a given user is part of the monthly active user group
+            Arguments:
+                user_id (str): user to check
+            Return:
+                Deferred[int]: timestamp the user was last seen, or None if
+                    never seen
+
+        """
+
+        return self.db.simple_select_one_onecol(
+            table="monthly_active_users",
+            keyvalues={"user_id": user_id},
+            retcol="timestamp",
+            allow_none=True,
+            desc="user_last_seen_monthly_active",
+        )
+
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+
         # Do not add more reserved users than the total allowable number
         self.db.new_transaction(
             db_conn,
             "initialise_mau_threepids",
@@ -146,57 +210,22 @@ class MonthlyActiveUsersStore(SQLBaseStore):
 
                     txn.execute(sql, query_args)
 
+            # It seems poor to invalidate the whole cache, Postgres supports
+            # 'Returning' which would allow me to invalidate only the
+            # specific users, but sqlite has no way to do this and instead
+            # I would need to SELECT and then DELETE, which without locking
+            # is racy.
+            # Have resolved to invalidate the whole cache for now and do
+            # something about it if and when the perf becomes significant
+            self._invalidate_all_cache_and_stream(
+                txn, self.user_last_seen_monthly_active
+            )
+            self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+
         reserved_users = yield self.get_registered_reserved_users()
         yield self.db.runInteraction(
             "reap_monthly_active_users", _reap_users, reserved_users
         )
-        # It seems poor to invalidate the whole cache, Postgres supports
-        # 'Returning' which would allow me to invalidate only the
-        # specific users, but sqlite has no way to do this and instead
-        # I would need to SELECT and the DELETE which without locking
-        # is racy.
-        # Have resolved to invalidate the whole cache for now and do
-        # something about it if and when the perf becomes significant
-        self.user_last_seen_monthly_active.invalidate_all()
-        self.get_monthly_active_count.invalidate_all()
-
-    @cached(num_args=0)
-    def get_monthly_active_count(self):
-        """Generates current count of monthly active users
-
-        Returns:
-            Defered[int]: Number of current monthly active users
-        """
-
-        def _count_users(txn):
-            sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
-
-            txn.execute(sql)
-            (count,) = txn.fetchone()
-            return count
-
-        return self.db.runInteraction("count_users", _count_users)
-
-    @defer.inlineCallbacks
-    def get_registered_reserved_users(self):
-        """Of the reserved threepids defined in config, which are associated
-        with registered users?
-
-        Returns:
-            Defered[list]: Real reserved users
-        """
-        users = []
-
-        for tp in self.hs.config.mau_limits_reserved_threepids[
-            : self.hs.config.max_mau_value
-        ]:
-            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
-                tp["medium"], tp["address"]
-            )
-            if user_id:
-                users.append(user_id)
-
-        return users
 
     @defer.inlineCallbacks
     def upsert_monthly_active_user(self, user_id):
@@ -222,23 +251,9 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )
 
-        user_in_mau = self.user_last_seen_monthly_active.cache.get(
-            (user_id,), None, update_metrics=False
-        )
-        if user_in_mau is None:
-            self.get_monthly_active_count.invalidate(())
-
-        self.user_last_seen_monthly_active.invalidate((user_id,))
-
     def upsert_monthly_active_user_txn(self, txn, user_id):
         """Updates or inserts monthly active user member
 
-        Note that, after calling this method, it will generally be necessary
-        to invalidate the caches on user_last_seen_monthly_active and
-        get_monthly_active_count. We can't do that here, because we are running
-        in a database thread rather than the main thread, and we can't call
-        txn.call_after because txn may not be a LoggingTransaction.
-
         We consciously do not call is_support_txn from this method because it
         is not possible to cache the response. is_support_txn will be false in
         almost all cases, so it seems reasonable to call it only for
@@ -269,27 +284,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             values={"timestamp": int(self._clock.time_msec())},
         )
 
-        return is_insert
-
-    @cached(num_args=1)
-    def user_last_seen_monthly_active(self, user_id):
-        """
-            Checks if a given user is part of the monthly active user group
-            Arguments:
-                user_id (str): user to add/update
-            Return:
-                Deferred[int] : timestamp since last seen, None if never seen
-
-        """
-
-        return self.db.simple_select_one_onecol(
-            table="monthly_active_users",
-            keyvalues={"user_id": user_id},
-            retcol="timestamp",
-            allow_none=True,
-            desc="user_last_seen_monthly_active",
+        self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+        self._invalidate_cache_and_stream(
+            txn, self.user_last_seen_monthly_active, (user_id,)
         )
 
+        return is_insert
+
     @defer.inlineCallbacks
     def populate_monthly_active_users(self, user_id):
         """Checks on the state of monthly active user limits and optionally
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index a2c83e0867..604c8b7ddd 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -17,8 +17,8 @@ from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.presence import UserPresenceState
-from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 
 class PresenceStore(SQLBaseStore):
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index cb4b2b39a0..49306642ed 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -291,7 +291,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="is_server_admin",
         )
 
-        return res if res else False
+        return bool(res)
 
     def set_server_admin(self, user, admin):
         """Sets whether a user is an admin of this homeserver.
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 8636d75030..d968803ad2 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -18,7 +18,8 @@ import collections
 import logging
 import re
 from abc import abstractmethod
-from typing import Optional, Tuple
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
 
 from six import integer_types
 
@@ -46,6 +47,18 @@ RatelimitOverride = collections.namedtuple(
 )
 
 
+class RoomSortOrder(Enum):
+    """
+    Enum to define the sorting method used when returning rooms with get_rooms_paginate
+
+    ALPHABETICAL = sort rooms alphabetically by name
+    SIZE = sort rooms by membership size, highest to lowest
+    """
+
+    ALPHABETICAL = "alphabetical"
+    SIZE = "size"
+
+
 class RoomWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(RoomWorkerStore, self).__init__(database, db_conn, hs)
@@ -281,6 +294,116 @@ class RoomWorkerStore(SQLBaseStore):
             desc="is_room_blocked",
         )
 
+    async def get_rooms_paginate(
+        self,
+        start: int,
+        limit: int,
+        order_by: RoomSortOrder,
+        reverse_order: bool,
+        search_term: Optional[str],
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Function to retrieve a paginated list of rooms as json.
+
+        Args:
+            start: offset in the list
+            limit: maximum amount of rooms to retrieve
+            order_by: the sort order of the returned list
+            reverse_order: whether to reverse the room list
+            search_term: a string to filter room names by
+        Returns:
+            A list of room dicts and an integer representing the total number of
+            rooms that exist given this query
+        """
+        # Filter room names by a string
+        where_statement = ""
+        if search_term:
+            where_statement = "WHERE state.name LIKE ?"
+
+            # Our postgres db driver converts ? -> %s in SQL strings as that's the
+            # placeholder for postgres.
+            # HOWEVER, if you put a % into your SQL then everything goes wibbly.
+            # To get around this, we're going to surround search_term with %'s
+            # before giving it to the database in python instead
+            search_term = "%" + search_term + "%"
+
+        # Set ordering
+        if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
+            order_by_column = "curr.joined_members"
+            order_by_asc = False
+        elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
+            # Sort alphabetically
+            order_by_column = "state.name"
+            order_by_asc = True
+        else:
+            raise StoreError(
+                500, "Incorrect value for order_by provided: %s" % order_by
+            )
+
+        # Whether to return the list in reverse order
+        if reverse_order:
+            # Flip the boolean
+            order_by_asc = not order_by_asc
+
+        # Create one query for getting the limited number of rooms that the user
+        # asked for, and another for the total number of rooms that could be
+        # returned, so we can tell whether there are more rooms to paginate through
+        info_sql = """
+            SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members
+            FROM room_stats_state state
+            INNER JOIN room_stats_current curr USING (room_id)
+            %s
+            ORDER BY %s %s
+            LIMIT ?
+            OFFSET ?
+        """ % (
+            where_statement,
+            order_by_column,
+            "ASC" if order_by_asc else "DESC",
+        )
+
+        # Use a nested SELECT so the count covers every matching room, unaffected
+        # by the LIMIT/OFFSET applied above
+        count_sql = """
+            SELECT count(*) FROM (
+              SELECT room_id FROM room_stats_state state
+              %s
+            ) AS get_room_ids
+        """ % (
+            where_statement,
+        )
+
+        def _get_rooms_paginate_txn(txn):
+            # Execute the data query
+            sql_values = (limit, start)
+            if search_term:
+                # Add the search term into the WHERE clause
+                sql_values = (search_term,) + sql_values
+            txn.execute(info_sql, sql_values)
+
+            # Refactor room query data into a structured dictionary
+            rooms = []
+            for room in txn:
+                rooms.append(
+                    {
+                        "room_id": room[0],
+                        "name": room[1],
+                        "canonical_alias": room[2],
+                        "joined_members": room[3],
+                    }
+                )
+
+            # Execute the count query
+
+            # Add the search term into the WHERE clause if present
+            sql_values = (search_term,) if search_term else ()
+            txn.execute(count_sql, sql_values)
+
+            room_count = txn.fetchone()
+            return rooms, room_count[0]
+
+        return await self.db.runInteraction(
+            "get_rooms_paginate", _get_rooms_paginate_txn,
+        )
+
     @cachedInlineCallbacks(max_entries=10000)
     def get_ratelimit_for_user(self, user_id):
         """Check if there are any overrides for ratelimiting for the given
@@ -399,6 +522,8 @@ class RoomWorkerStore(SQLBaseStore):
         the associated media
         """
 
+        logger.info("Quarantining media in room: %s", room_id)
+
         def _quarantine_media_in_room_txn(txn):
             local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
             total_media_quarantined = 0
@@ -494,6 +619,118 @@ class RoomWorkerStore(SQLBaseStore):
 
         return local_media_mxcs, remote_media_mxcs
 
+    def quarantine_media_by_id(
+        self, server_name: str, media_id: str, quarantined_by: str,
+    ):
+        """quarantines a single local or remote media id
+
+        Args:
+            server_name: The name of the server that holds this media
+            media_id: The ID of the media to be quarantined
+            quarantined_by: The user ID that initiated the quarantine request
+        """
+        logger.info("Quarantining media: %s/%s", server_name, media_id)
+        is_local = server_name == self.config.server_name
+
+        def _quarantine_media_by_id_txn(txn):
+            local_mxcs = [media_id] if is_local else []
+            remote_mxcs = [(server_name, media_id)] if not is_local else []
+
+            return self._quarantine_media_txn(
+                txn, local_mxcs, remote_mxcs, quarantined_by
+            )
+
+        return self.db.runInteraction(
+            "quarantine_media_by_user", _quarantine_media_by_id_txn
+        )
+
+    def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+        """quarantines all local media associated with a single user
+
+        Args:
+            user_id: The ID of the user to quarantine media of
+            quarantined_by: The ID of the user who made the quarantine request
+        """
+
+        def _quarantine_media_by_user_txn(txn):
+            local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
+            return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
+
+        return self.db.runInteraction(
+            "quarantine_media_by_user", _quarantine_media_by_user_txn
+        )
+
+    def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+        """Retrieves local media IDs by a given user
+
+        Args:
+            txn (cursor)
+            user_id: The ID of the user to retrieve media IDs of
+            filter_quarantined: Whether to exclude media that has already
+                been quarantined
+
+        Returns:
+            A list of the user's local media IDs
+        """
+        # Local media
+        sql = """
+            SELECT media_id
+            FROM local_media_repository
+            WHERE user_id = ?
+            """
+        if filter_quarantined:
+            sql += "AND quarantined_by IS NULL"
+        txn.execute(sql, (user_id,))
+
+        local_media_ids = [row[0] for row in txn]
+
+        # TODO: Figure out all remote media a user has referenced in a message
+
+        return local_media_ids
+
+    def _quarantine_media_txn(
+        self,
+        txn,
+        local_mxcs: List[str],
+        remote_mxcs: List[Tuple[str, str]],
+        quarantined_by: str,
+    ) -> int:
+        """Quarantine local and remote media items
+
+        Args:
+            txn (cursor)
+            local_mxcs: A list of local mxc URLs
+            remote_mxcs: A list of (remote server, media id) tuples representing
+                remote mxc URLs
+            quarantined_by: The ID of the user who initiated the quarantine request
+        Returns:
+            The total number of media items quarantined
+        """
+        total_media_quarantined = 0
+
+        # Update all the tables to set the quarantined_by flag
+        txn.executemany(
+            """
+            UPDATE local_media_repository
+            SET quarantined_by = ?
+            WHERE media_id = ?
+        """,
+            ((quarantined_by, media_id) for media_id in local_mxcs),
+        )
+
+        txn.executemany(
+            """
+                UPDATE remote_media_cache
+                SET quarantined_by = ?
+                WHERE media_origin = ? AND media_id = ?
+            """,
+            ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
+        )
+
+        total_media_quarantined += len(local_mxcs)
+        total_media_quarantined += len(remote_mxcs)
+
+        return total_media_quarantined
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
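A usage sketch for the new pagination API (the method is async, per the diff; the store wiring and the values are illustrative):

    from synapse.storage.data_stores.main.room import RoomSortOrder

    async def list_largest_rooms(store):
        rooms, total = await store.get_rooms_paginate(
            start=0,
            limit=10,
            order_by=RoomSortOrder.SIZE,  # or RoomSortOrder.ALPHABETICAL
            reverse_order=False,
            search_term=None,  # or e.g. "matrix" to filter on room name
        )
        # `rooms` is a list of dicts; `total` ignores LIMIT/OFFSET.
        return rooms, total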
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 70ff5751b6..9acef7c950 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return {row[0]: row[1] for row in txn}
 
     @cached()
-    def get_invited_rooms_for_user(self, user_id):
-        """ Get all the rooms the user is invited to
+    def get_invited_rooms_for_local_user(self, user_id):
+        """ Get all the rooms the *local* user is invited to
+
         Args:
             user_id (str): The user ID.
         Returns:
             A deferred list of RoomsForUser.
         """
 
-        return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
+        return self.get_rooms_for_local_user_where_membership_is(
+            user_id, [Membership.INVITE]
+        )
 
     @defer.inlineCallbacks
-    def get_invite_for_user_in_room(self, user_id, room_id):
-        """Gets the invite for the given user and room
+    def get_invite_for_local_user_in_room(self, user_id, room_id):
+        """Gets the invite for the given *local* user and room
 
         Args:
             user_id (str)
@@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             Deferred: Resolves to either a RoomsForUser or None if no invite was
                 found.
         """
-        invites = yield self.get_invited_rooms_for_user(user_id)
+        invites = yield self.get_invited_rooms_for_local_user(user_id)
         for invite in invites:
             if invite.room_id == room_id:
                 return invite
         return None
 
     @defer.inlineCallbacks
-    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
-        """ Get all the rooms for this user where the membership for this user
+    def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
+        """ Get all the rooms for this *local* user where the membership for this user
         matches one in the membership list.
 
         Filters out forgotten rooms.
@@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             return defer.succeed(None)
 
         rooms = yield self.db.runInteraction(
-            "get_rooms_for_user_where_membership_is",
-            self._get_rooms_for_user_where_membership_is_txn,
+            "get_rooms_for_local_user_where_membership_is",
+            self._get_rooms_for_local_user_where_membership_is_txn,
             user_id,
             membership_list,
         )
@@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
         return [room for room in rooms if room.room_id not in forgotten_rooms]
 
-    def _get_rooms_for_user_where_membership_is_txn(
+    def _get_rooms_for_local_user_where_membership_is_txn(
         self, txn, user_id, membership_list
     ):
+        # Paranoia check.
+        if not self.hs.is_mine_id(user_id):
+            raise Exception(
+                "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
+                % (user_id,),
+            )
 
-        do_invite = Membership.INVITE in membership_list
-        membership_list = [m for m in membership_list if m != Membership.INVITE]
-
-        results = []
-        if membership_list:
-            if self._current_state_events_membership_up_to_date:
-                clause, args = make_in_list_sql_clause(
-                    self.database_engine, "c.membership", membership_list
-                )
-                sql = """
-                    SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
-                    FROM current_state_events AS c
-                    INNER JOIN events AS e USING (room_id, event_id)
-                    WHERE
-                        c.type = 'm.room.member'
-                        AND state_key = ?
-                        AND %s
-                """ % (
-                    clause,
-                )
-            else:
-                clause, args = make_in_list_sql_clause(
-                    self.database_engine, "m.membership", membership_list
-                )
-                sql = """
-                    SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
-                    FROM current_state_events AS c
-                    INNER JOIN room_memberships AS m USING (room_id, event_id)
-                    INNER JOIN events AS e USING (room_id, event_id)
-                    WHERE
-                        c.type = 'm.room.member'
-                        AND state_key = ?
-                        AND %s
-                """ % (
-                    clause,
-                )
-
-            txn.execute(sql, (user_id, *args))
-            results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "c.membership", membership_list
+        )
 
-        if do_invite:
-            sql = (
-                "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
-                " FROM local_invites as i"
-                " INNER JOIN events as e USING (event_id)"
-                " WHERE invitee = ? AND locally_rejected is NULL"
-                " AND replaced_by is NULL"
-            )
+        sql = """
+            SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+            FROM local_current_membership AS c
+            INNER JOIN events AS e USING (room_id, event_id)
+            WHERE
+                user_id = ?
+                AND %s
+        """ % (
+            clause,
+        )
 
-            txn.execute(sql, (user_id,))
-            results.extend(
-                RoomsForUser(
-                    room_id=r["room_id"],
-                    sender=r["inviter"],
-                    event_id=r["event_id"],
-                    stream_ordering=r["stream_ordering"],
-                    membership=Membership.INVITE,
-                )
-                for r in self.db.cursor_to_dict(txn)
-            )
+        txn.execute(sql, (user_id, *args))
+        results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
 
         return results
 
-    @cachedInlineCallbacks(max_entries=500000, iterable=True)
+    @cached(max_entries=500000, iterable=True)
     def get_rooms_for_user_with_stream_ordering(self, user_id):
-        """Returns a set of room_ids the user is currently joined to
+        """Returns a set of room_ids the user is currently joined to.
+
+        If the user is remote, only returns rooms this server is currently
+        participating in.
 
         Args:
             user_id (str)
@@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             the rooms the user is in currently, along with the stream ordering
             of the most recent join for that user and room.
         """
-        rooms = yield self.get_rooms_for_user_where_membership_is(
-            user_id, membership_list=[Membership.JOIN]
-        )
-        return frozenset(
-            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
-            for r in rooms
+        return self.db.runInteraction(
+            "get_rooms_for_user_with_stream_ordering",
+            self._get_rooms_for_user_with_stream_ordering_txn,
+            user_id,
         )
 
+    def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+        # We use `current_state_events` here and not `local_current_membership`
+        # as a) this gets called with remote users and b) this only gets called
+        # for rooms the server is participating in.
+        if self._current_state_events_membership_up_to_date:
+            sql = """
+                SELECT room_id, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND state_key = ?
+                    AND c.membership = ?
+            """
+        else:
+            sql = """
+                SELECT room_id, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS m USING (room_id, event_id)
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND state_key = ?
+                    AND m.membership = ?
+            """
+
+        txn.execute(sql, (user_id, Membership.JOIN))
+        results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+
+        return results
+
     @defer.inlineCallbacks
     def get_rooms_for_user(self, user_id, on_invalidate=None):
-        """Returns a set of room_ids the user is currently joined to
+        """Returns a set of room_ids the user is currently joined to.
+
+        If the user is remote, only returns rooms this server is currently
+        participating in.
         """
         rooms = yield self.get_rooms_for_user_with_stream_ordering(
             user_id, on_invalidate=on_invalidate
@@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                 event.internal_metadata.stream_ordering,
             )
             txn.call_after(
-                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+                self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
             )
 
             # We update the local_invites table only if the event is "current",
@@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                         ),
                     )
 
+                # We also update the `local_current_membership` table with
+                # latest invite info. This will usually get updated by the
+                # `current_state_events` handling, unless it's an outlier.
+                if event.internal_metadata.is_outlier():
+                    # This should only happen for out of band memberships, so
+                    # we add a paranoia check.
+                    assert event.internal_metadata.is_out_of_band_membership()
+
+                    self.db.simple_upsert_txn(
+                        txn,
+                        table="local_current_membership",
+                        keyvalues={
+                            "room_id": event.room_id,
+                            "user_id": event.state_key,
+                        },
+                        values={
+                            "event_id": event.event_id,
+                            "membership": event.membership,
+                        },
+                    )
+
     @defer.inlineCallbacks
     def locally_reject_invite(self, user_id, room_id):
         sql = (
@@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
         def f(txn, stream_ordering):
             txn.execute(sql, (stream_ordering, True, room_id, user_id))
 
+            # We also clear this entry from `local_current_membership`.
+            # Ideally we'd point to a leave event, but we don't have one, so
+            # nevermind.
+            self.db.simple_delete_txn(
+                txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+            )
+
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
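The renamed helpers are explicitly local-only and are now backed by the new `local_current_membership` table; the underlying txn helper raises if given a remote user. A caller sketch (the user ID is illustrative):

    from twisted.internet import defer
    from synapse.api.constants import Membership

    @defer.inlineCallbacks
    def rooms_alice_is_in_or_invited_to(store):
        rooms = yield store.get_rooms_for_local_user_where_membership_is(
            "@alice:myserver.org", [Membership.JOIN, Membership.INVITE]
        )
        # A list of RoomsForUser, with forgotten rooms filtered out.
        return rooms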
 
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..63b5acdcf7
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room then it will only include
+# outstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+    # We need to do the insert in the `run_upgrade` section as we don't have access
+    # to `config` in `run_create`.
+
+    # This upgrade may take a bit of time for large servers (e.g. one minute for
+    # matrix.org) but means we avoid a lot of bookkeeping required to do it as
+    # a background update.
+
+    # We check if the `current_state_events.membership` is up to date by
+    # checking if the relevant background update has finished. If it has
+    # finished we can avoid doing a join against `room_memberships`, which
+    # speeds things up.
+    cur.execute(
+        """SELECT 1 FROM background_updates
+            WHERE update_name = 'current_state_events_membership'
+        """
+    )
+    current_state_membership_up_to_date = not bool(cur.fetchone())
+
+    # Cheekily drop and recreate indices, as that is faster.
+    cur.execute("DROP INDEX local_current_membership_idx")
+    cur.execute("DROP INDEX local_current_membership_room_idx")
+
+    if current_state_membership_up_to_date:
+        sql = """
+            INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+                SELECT c.room_id, state_key AS user_id, event_id, c.membership
+                FROM current_state_events AS c
+                WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE ?
+        """
+    else:
+        # We can't rely on the membership column, so we need to join against
+        # `room_memberships`.
+        sql = """
+            INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+                SELECT c.room_id, state_key AS user_id, event_id, r.membership
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS r USING (event_id)
+                WHERE type = 'm.room.member' AND state_key LIKE ?
+        """
+    sql = database_engine.convert_param_style(sql)
+    cur.execute(sql, ("%:" + config.server_name,))
+
+    cur.execute(
+        "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+    )
+    cur.execute(
+        "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+    )
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+    cur.execute(
+        """
+        CREATE TABLE local_current_membership (
+            room_id TEXT NOT NULL,
+            user_id TEXT NOT NULL,
+            event_id TEXT NOT NULL,
+            membership TEXT NOT NULL
+        )"""
+    )
+
+    cur.execute(
+        "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+    )
+    cur.execute(
+        "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+    )
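The migration restricts the backfill to local users with a LIKE pattern: Matrix user IDs have the form `@localpart:server_name`, so `%:server_name` matches exactly this server's rows. A standalone illustration of the pattern (the server and user IDs are made up):

    suffix = ":" + "example.org"  # mirrors the SQL pattern "%:example.org"
    for user_id in ("@alice:example.org", "@bob:other.org"):
        print(user_id, user_id.endswith(suffix))  # True for alice, False for bob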
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index d07440e3ed..33bebd1c48 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
+    def get_filtered_current_state_ids(
+        self, room_id: str, state_filter: StateFilter = StateFilter.all()
+    ):
         """Get the current state event of a given type for a room based on the
         current_state_events table.  This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
 
         Args:
-            room_id (str)
-            state_filter (StateFilter): The state filter used to fetch state
+            room_id
+            state_filter: The state filter used to fetch state
                 from the database.
 
         Returns:
-            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
-            event ID.
+            defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
         """
 
         where_clause, where_args = state_filter.make_sql_filter_clause()
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 140da8dad6..056b25b13a 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -525,8 +525,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, token
 
-    def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
-        """Gets details of the first event in a room at or after a stream ordering
+    def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+        """Gets details of the first event in a room at or before a stream ordering
 
         Args:
             room_id (str):
@@ -541,15 +541,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
-                " WHERE room_id = ? AND stream_ordering >= ?"
+                " WHERE room_id = ? AND stream_ordering <= ?"
                 " AND NOT outlier"
-                " ORDER BY stream_ordering"
+                " ORDER BY stream_ordering DESC"
                 " LIMIT 1"
             )
             txn.execute(sql, (room_id, stream_ordering))
             return txn.fetchone()
 
-        return self.db.runInteraction("get_room_event_after_stream_ordering", _f)
+        return self.db.runInteraction("get_room_event_before_stream_ordering", _f)
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, room_id=None):
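
The rename above also flips the query direction: `stream_ordering <= ?` with
`ORDER BY stream_ordering DESC LIMIT 1` selects the most recent event at or
before the given ordering. A self-contained demonstration of that semantics
against an in-memory SQLite table (the rows are made up):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (stream_ordering INTEGER, event_id TEXT)")
    conn.executemany(
        "INSERT INTO events VALUES (?, ?)", [(1, "$a"), (5, "$b"), (9, "$c")]
    )

    # "At or before 7" returns $b (ordering 5): DESC plus LIMIT 1 picks the
    # latest qualifying row rather than the earliest.
    row = conn.execute(
        "SELECT stream_ordering, event_id FROM events"
        " WHERE stream_ordering <= ? ORDER BY stream_ordering DESC LIMIT 1",
        (7,),
    ).fetchone()
    assert row == (5, "$b")
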
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index d53695f238..c4ee9b7ccb 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -15,6 +15,7 @@
 
 import logging
 from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
 
 from six import iteritems
 from six.moves import range
@@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.database import Database
 from synapse.storage.state import StateFilter
+from synapse.types import StateMap
 from synapse.util.caches import get_cache_factor_for
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, state_filter):
-        """Returns the state groups for a given set of groups, filtering on
-        types of state events.
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
+        """Returns the state groups for a given set of groups from the
+        database, filtering on types of state events.
 
         Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
         results = {}
 
@@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state_filter.filter_state(state_dict_ids), not missing_types
 
     @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
         Args:
-            groups (iterable[int]): list of state groups for which we want
+            groups: list of state groups for which we want
                 to get the state.
-            state_filter (StateFilter): The state filter used to fetch state
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
 
         member_filter, non_member_filter = state_filter.get_member_split()
@@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state
 
-    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+    def _get_state_for_groups_using_cache(
+        self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.
 
         Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            cache (DictionaryCache): the cache of group ids to state dicts which
-                we will pass through - either the normal state cache or the specific
-                members state cache.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            groups: list of state groups for which we want to get the state.
+            cache: the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the
+                specific members state cache.
+            state_filter: The state filter used to fetch state from the
+                database.
 
         Returns:
-            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
-            dict of state_group_id -> (dict of (type, state_key) -> event id)
-            of entries in the cache, and the state group ids either missing
-            from the cache or incomplete.
+            Tuple of dict of state_group_id to state map of entries in the
+            cache, and the state group ids either missing from the cache or
+            incomplete.
         """
         results = {}
         incomplete_groups = set()
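
The return type added above makes the two-tier lookup explicit: a dict of state
the cache could answer, plus the group IDs that are missing or incomplete and
still need a database query. A simplified sketch of that pattern (`cache` here
is a plain dict and `fetch_from_db` a stand-in callable; the real method also
recognises partially cached entries):

    def state_with_cache(groups, cache, fetch_from_db):
        results = {}
        incomplete_groups = set()
        for group in groups:
            state = cache.get(group)
            if state is not None:
                results[group] = state
            else:
                incomplete_groups.add(group)
        # Fall back to the database for anything the cache could not answer.
        if incomplete_groups:
            results.update(fetch_from_db(incomplete_groups))
        return results
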
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index b7c4eda338..c84cb452b0 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -32,20 +32,7 @@ class PostgresEngine(object):
         self.synchronous_commit = database_config.get("synchronous_commit", True)
         self._version = None  # unknown as yet
 
-    def check_database(self, txn):
-        txn.execute("SHOW SERVER_ENCODING")
-        rows = txn.fetchall()
-        if rows and rows[0][0] != "UTF8":
-            raise IncorrectDatabaseSetup(
-                "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
-                "See docs/postgres.rst for more information." % (rows[0][0],)
-            )
-
-    def convert_param_style(self, sql):
-        return sql.replace("?", "%s")
-
-    def on_new_connection(self, db_conn):
-
+    def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
@@ -53,9 +40,22 @@ class PostgresEngine(object):
         self._version = db_conn.server_version
 
         # Are we on a supported PostgreSQL version?
-        if self._version < 90500:
+        if not allow_outdated_version and self._version < 90500:
             raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
 
+        with db_conn.cursor() as txn:
+            txn.execute("SHOW SERVER_ENCODING")
+            rows = txn.fetchall()
+            if rows and rows[0][0] != "UTF8":
+                raise IncorrectDatabaseSetup(
+                    "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+                    "See docs/postgres.rst for more information." % (rows[0][0],)
+                )
+
+    def convert_param_style(self, sql):
+        return sql.replace("?", "%s")
+
+    def on_new_connection(self, db_conn):
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
@@ -119,8 +119,8 @@ class PostgresEngine(object):
         Returns:
             string
         """
-        # note that this is a bit of a hack because it relies on on_new_connection
-        # having been called at least once. Still, that should be a safe bet here.
+        # note that this is a bit of a hack because it relies on check_database
+        # having been called. Still, that should be a safe bet here.
         numver = self._version
         assert numver is not None
 
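
Per the comment retained above, psycopg2's `server_version` packs the version
components as two-decimal-digit fields, so the `90500` threshold means
PostgreSQL 9.5.0. A worked example of unpacking that integer (the helper name
is made up):

    def decode_pg_numver(numver):
        # 9.5.2 -> 90502; PostgreSQL 10+ drops the middle field, so 10.1
        # is reported as 100001 (major 10, minor 0, revision 1).
        major, rest = divmod(numver, 10000)
        minor, revision = divmod(rest, 100)
        return major, minor, revision

    assert decode_pg_numver(90500) == (9, 5, 0)
    assert decode_pg_numver(90502) == (9, 5, 2)
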
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index df039a072d..cbf52f5191 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -53,8 +53,11 @@ class Sqlite3Engine(object):
         """
         return False
 
-    def check_database(self, txn):
-        pass
+    def check_database(self, db_conn, allow_outdated_version: bool = False):
+        if not allow_outdated_version:
+            version = self.module.sqlite_version_info
+            if version < (3, 11, 0):
+                raise RuntimeError("Synapse requires sqlite 3.11 or above.")
 
     def convert_param_style(self, sql):
         return sql
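
The SQLite engine's `check_database` is no longer a no-op: it now enforces a
minimum library version. The check can be reproduced with the stdlib module,
which exposes the same tuple the engine compares against:

    import sqlite3

    # sqlite_version_info is the runtime SQLite library version as a tuple,
    # e.g. (3, 31, 1); the comparison mirrors check_database above.
    if sqlite3.sqlite_version_info < (3, 11, 0):
        raise RuntimeError("Synapse requires sqlite 3.11 or above.")
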
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 1ed44925fc..368c457321 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -17,19 +17,24 @@
 
 import logging
 from collections import deque, namedtuple
+from typing import Iterable, List, Optional, Tuple
 
 from six import iteritems
 from six.moves import range
 
+import attr
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateResolutionStore
 from synapse.storage.data_stores import DataStores
+from synapse.types import StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -67,6 +72,19 @@ stale_forward_extremities_counter = Histogram(
 )
 
 
+@attr.s(slots=True, frozen=True)
+class DeltaState:
+    """Deltas to use to update the `current_state_events` table.
+
+    Attributes:
+        to_delete: List of type/state_keys to delete from current state
+        to_insert: Map of state to upsert into current state
+    """
+
+    to_delete = attr.ib(type=List[Tuple[str, str]])
+    to_insert = attr.ib(type=StateMap[str])
+
+
 class _EventPeristenceQueue(object):
     """Queues up events so that they can be persisted in bulk with only one
     concurrent transaction per room.
@@ -138,13 +156,12 @@ class _EventPeristenceQueue(object):
 
         self._currently_persisting_rooms.add(room_id)
 
-        @defer.inlineCallbacks
-        def handle_queue_loop():
+        async def handle_queue_loop():
             try:
                 queue = self._get_drainining_queue(room_id)
                 for item in queue:
                     try:
-                        ret = yield per_item_callback(item)
+                        ret = await per_item_callback(item)
                     except Exception:
                         with PreserveLoggingContext():
                             item.deferred.errback()
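
This hunk is part of a mechanical migration from Twisted's
`@defer.inlineCallbacks` generators to native coroutines: the decorator is
dropped and each `yield` of a Deferred becomes an `await` (Deferreds are
awaitable inside native coroutines). A minimal before/after sketch with
hypothetical names:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def process_old(callback, item):
        result = yield callback(item)  # yields a Deferred
        return result

    async def process_new(callback, item):
        result = await callback(item)  # awaits the same Deferred
        return result
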
@@ -191,12 +208,16 @@ class EventsPersistenceStorage(object):
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
     @defer.inlineCallbacks
-    def persist_events(self, events_and_contexts, backfilled=False):
+    def persist_events(
+        self,
+        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        backfilled: bool = False,
+    ):
         """
         Write events to the database
         Args:
             events_and_contexts: list of tuples of (event, context)
-            backfilled (bool): Whether the results are retrieved from federation
+            backfilled: Whether the results are retrieved from federation
                 via backfill or not. Used to determine if they're "new" events
                 which might update the current state etc.
 
@@ -226,16 +247,12 @@ class EventsPersistenceStorage(object):
         return max_persisted_id
 
     @defer.inlineCallbacks
-    def persist_event(self, event, context, backfilled=False):
+    def persist_event(
+        self, event: FrozenEvent, context: EventContext, backfilled: bool = False
+    ):
         """
-
-        Args:
-            event (EventBase):
-            context (EventContext):
-            backfilled (bool):
-
         Returns:
-            Deferred: resolves to (int, int): the stream ordering of ``event``,
+            Deferred[Tuple[int, int]]: the stream ordering of ``event``,
             and the stream ordering of the latest persisted event
         """
         deferred = self._event_persist_queue.add_to_queue(
@@ -249,28 +266,22 @@ class EventsPersistenceStorage(object):
         max_persisted_id = yield self.main_store.get_current_events_token()
         return (event.internal_metadata.stream_ordering, max_persisted_id)
 
-    def _maybe_start_persisting(self, room_id):
-        @defer.inlineCallbacks
-        def persisting_queue(item):
+    def _maybe_start_persisting(self, room_id: str):
+        async def persisting_queue(item):
             with Measure(self._clock, "persist_events"):
-                yield self._persist_events(
+                await self._persist_events(
                     item.events_and_contexts, backfilled=item.backfilled
                 )
 
         self._event_persist_queue.handle_queue(room_id, persisting_queue)
 
-    @defer.inlineCallbacks
-    def _persist_events(self, events_and_contexts, backfilled=False):
+    async def _persist_events(
+        self,
+        events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+        backfilled: bool = False,
+    ):
         """Calculates the change to current state and forward extremities, and
         persists the given events along with those updates.
-
-        Args:
-            events_and_contexts (list[(EventBase, EventContext)]):
-            backfilled (bool):
-            delete_existing (bool):
-
-        Returns:
-            Deferred: resolves when the events have been persisted
         """
         if not events_and_contexts:
             return
@@ -315,10 +326,10 @@ class EventsPersistenceStorage(object):
                         )
 
                     for room_id, ev_ctx_rm in iteritems(events_by_room):
-                        latest_event_ids = yield self.main_store.get_latest_event_ids_in_room(
+                        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
                             room_id
                         )
-                        new_latest_event_ids = yield self._calculate_new_extremities(
+                        new_latest_event_ids = await self._calculate_new_extremities(
                             room_id, ev_ctx_rm, latest_event_ids
                         )
 
@@ -374,7 +385,7 @@ class EventsPersistenceStorage(object):
                         with Measure(
                             self._clock, "persist_events.get_new_state_after_events"
                         ):
-                            res = yield self._get_new_state_after_events(
+                            res = await self._get_new_state_after_events(
                                 room_id,
                                 ev_ctx_rm,
                                 latest_event_ids,
@@ -389,12 +400,12 @@ class EventsPersistenceStorage(object):
                             # If there is a delta we know that we've
                             # only added or replaced state, never
                             # removed keys entirely.
-                            state_delta_for_room[room_id] = ([], delta_ids)
+                            state_delta_for_room[room_id] = DeltaState([], delta_ids)
                         elif current_state is not None:
                             with Measure(
                                 self._clock, "persist_events.calculate_state_delta"
                             ):
-                                delta = yield self._calculate_state_delta(
+                                delta = await self._calculate_state_delta(
                                     room_id, current_state
                                 )
                             state_delta_for_room[room_id] = delta
@@ -404,7 +415,7 @@ class EventsPersistenceStorage(object):
                         if current_state is not None:
                             current_state_for_room[room_id] = current_state
 
-            yield self.main_store._persist_events_and_state_updates(
+            await self.main_store._persist_events_and_state_updates(
                 chunk,
                 current_state_for_room=current_state_for_room,
                 state_delta_for_room=state_delta_for_room,
@@ -412,8 +423,12 @@ class EventsPersistenceStorage(object):
                 backfilled=backfilled,
             )
 
-    @defer.inlineCallbacks
-    def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
+    async def _calculate_new_extremities(
+        self,
+        room_id: str,
+        event_contexts: List[Tuple[FrozenEvent, EventContext]],
+        latest_event_ids: List[str],
+    ):
         """Calculates the new forward extremities for a room given events to
         persist.
 
@@ -444,13 +459,13 @@ class EventsPersistenceStorage(object):
         )
 
         # Remove any events which are prev_events of any existing events.
-        existing_prevs = yield self.main_store._get_events_which_are_prevs(result)
+        existing_prevs = await self.main_store._get_events_which_are_prevs(result)
         result.difference_update(existing_prevs)
 
         # Finally handle the case where the new events have soft-failed prev
         # events. If they do we need to remove them and their prev events,
         # otherwise we end up with dangling extremities.
-        existing_prevs = yield self.main_store._get_prevs_before_rejected(
+        existing_prevs = await self.main_store._get_prevs_before_rejected(
             e_id for event in new_events for e_id in event.prev_event_ids()
         )
         result.difference_update(existing_prevs)
@@ -464,10 +479,13 @@ class EventsPersistenceStorage(object):
 
         return result
 
-    @defer.inlineCallbacks
-    def _get_new_state_after_events(
-        self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
-    ):
+    async def _get_new_state_after_events(
+        self,
+        room_id: str,
+        events_context: List[Tuple[FrozenEvent, EventContext]],
+        old_latest_event_ids: Iterable[str],
+        new_latest_event_ids: Iterable[str],
+    ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
         """Calculate the current state dict after adding some new events to
         a room
 
@@ -485,7 +503,6 @@ class EventsPersistenceStorage(object):
                 the new forward extremities for the room.
 
         Returns:
-            Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
             Returns a tuple of two state maps, the first being the full new current
             state and the second being the delta to the existing current state.
             If both are None then there has been no change.
@@ -547,7 +564,7 @@ class EventsPersistenceStorage(object):
 
         if missing_event_ids:
             # Now pull out the state groups for any missing events from DB
-            event_to_groups = yield self.main_store._get_state_group_for_events(
+            event_to_groups = await self.main_store._get_state_group_for_events(
                 missing_event_ids
             )
             event_id_to_state_group.update(event_to_groups)
@@ -588,7 +605,7 @@ class EventsPersistenceStorage(object):
         # their state IDs so we can resolve to a single state set.
         missing_state = new_state_groups - set(state_groups_map)
         if missing_state:
-            group_to_state = yield self.state_store._get_state_for_groups(missing_state)
+            group_to_state = await self.state_store._get_state_for_groups(missing_state)
             state_groups_map.update(group_to_state)
 
         if len(new_state_groups) == 1:
@@ -612,10 +629,10 @@ class EventsPersistenceStorage(object):
                 break
 
         if not room_version:
-            room_version = yield self.main_store.get_room_version(room_id)
+            room_version = await self.main_store.get_room_version(room_id)
 
         logger.debug("calling resolve_state_groups from preserve_events")
-        res = yield self._state_resolution_handler.resolve_state_groups(
+        res = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
             state_groups,
@@ -625,18 +642,14 @@ class EventsPersistenceStorage(object):
 
         return res.state, None
 
-    @defer.inlineCallbacks
-    def _calculate_state_delta(self, room_id, current_state):
+    async def _calculate_state_delta(
+        self, room_id: str, current_state: StateMap[str]
+    ) -> DeltaState:
         """Calculate the new state deltas for a room.
 
         Assumes that we are only persisting events for one room at a time.
-
-        Returns:
-            tuple[list, dict] (to_delete, to_insert): where to_delete are the
-            type/state_keys to remove from current_state_events and `to_insert`
-            are the updates to current_state_events.
         """
-        existing_state = yield self.main_store.get_current_state_ids(room_id)
+        existing_state = await self.main_store.get_current_state_ids(room_id)
 
         to_delete = [key for key in existing_state if key not in current_state]
 
@@ -646,4 +659,4 @@ class EventsPersistenceStorage(object):
             if ev_id != existing_state.get(key)
         }
 
-        return to_delete, to_insert
+        return DeltaState(to_delete=to_delete, to_insert=to_insert)
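
With `DeltaState` in place, `_calculate_state_delta` reduces to two
comprehensions over state maps keyed by (event type, state_key). A worked
example with made-up event IDs, following the logic in the hunk above:

    existing_state = {
        ("m.room.name", ""): "$old_name",
        ("m.room.topic", ""): "$topic",
    }
    current_state = {("m.room.name", ""): "$new_name"}

    # Keys absent from the new current state are deleted ...
    to_delete = [key for key in existing_state if key not in current_state]
    # ... and keys whose event ID changed (or is new) are upserted.
    to_insert = {
        key: ev_id
        for key, ev_id in current_state.items()
        if ev_id != existing_state.get(key)
    }

    assert to_delete == [("m.room.topic", "")]
    assert to_insert == {("m.room.name", ""): "$new_name"}
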
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e70026b80a..e86984cd50 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 56
+SCHEMA_VERSION = 57
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index d6a7bd7834..fdc0abf5cf 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -34,7 +34,7 @@ class PurgeEventsStorage(object):
         """
 
         state_groups_to_delete = yield self.stores.main.purge_room(room_id)
-        yield self.stores.main.purge_room_state(room_id, state_groups_to_delete)
+        yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
 
     @defer.inlineCallbacks
     def purge_history(self, room_id, token, delete_local_events):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index cbeb586014..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Iterable, List, TypeVar
 
 from six import iteritems, itervalues
 
@@ -22,9 +23,13 @@ import attr
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes
+from synapse.types import StateMap
 
 logger = logging.getLogger(__name__)
 
+# Used for generic functions below
+T = TypeVar("T")
+
 
 @attr.s(slots=True)
 class StateFilter(object):
@@ -233,14 +238,14 @@ class StateFilter(object):
 
         return len(self.concrete_types())
 
-    def filter_state(self, state_dict):
+    def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
         """Returns the state filtered with by this StateFilter
 
         Args:
-            state (dict[tuple[str, str], Any]): The state map to filter
+            state_dict: The state map to filter
 
         Returns:
-            dict[tuple[str, str], Any]: The filtered state map
+            The filtered state map
         """
         if self.is_full():
             return dict(state_dict)
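
The `TypeVar` makes `filter_state` generic over the map's values: whatever a
`StateMap[T]` holds (event IDs, full events, ...) comes back unchanged, just
with fewer keys. A short sketch with a hypothetical room state:

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    state = {
        (EventTypes.Member, "@alice:example.com"): "$member_event",
        (EventTypes.Name, ""): "$name_event",
    }

    # Keep only m.room.name; a None state_key is a wildcard for that type.
    name_filter = StateFilter.from_types([(EventTypes.Name, None)])
    assert name_filter.filter_state(state) == {(EventTypes.Name, ""): "$name_event"}
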
@@ -333,12 +338,12 @@ class StateGroupStorage(object):
     def __init__(self, hs, stores):
         self.stores = stores
 
-    def get_state_group_delta(self, state_group):
+    def get_state_group_delta(self, state_group: int):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
 
         Returns:
-            Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
+            Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
                 (prev_group, delta_ids)
         """
 
@@ -353,7 +358,7 @@ class StateGroupStorage(object):
             event_ids (iterable[str]): ids of the events
 
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
+            Deferred[dict[int, StateMap[str]]]:
                 dict of state_group_id -> (dict of (type, state_key) -> event id)
         """
         if not event_ids:
@@ -410,17 +415,18 @@ class StateGroupStorage(object):
             for group, event_id_map in iteritems(group_to_ids)
         }
 
-    def _get_state_groups_from_groups(self, groups, state_filter):
+    def _get_state_groups_from_groups(
+        self, groups: List[int], state_filter: StateFilter
+    ):
         """Returns the state groups for a given set of groups, filtering on
         types of state events.
 
         Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
+            groups: list of state group IDs to query
+            state_filter: The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
         """
 
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@@ -519,7 +525,9 @@ class StateGroupStorage(object):
         state_map = yield self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
 
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+    def _get_state_for_groups(
+        self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+    ):
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
@@ -529,8 +537,7 @@ class StateGroupStorage(object):
             state_filter (StateFilter): The state filter used to fetch state
                 from the database.
         Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
+            Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
         """
         return self.stores.state._get_state_for_groups(groups, state_filter)