summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py7
-rw-r--r--synapse/storage/background_updates.py19
-rw-r--r--synapse/storage/database.py651
-rw-r--r--synapse/storage/databases/__init__.py11
-rw-r--r--synapse/storage/databases/main/__init__.py31
-rw-r--r--synapse/storage/databases/main/account_data.py4
-rw-r--r--synapse/storage/databases/main/appservice.py7
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py4
-rw-r--r--synapse/storage/databases/main/devices.py27
-rw-r--r--synapse/storage/databases/main/directory.py4
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py8
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py43
-rw-r--r--synapse/storage/databases/main/event_federation.py83
-rw-r--r--synapse/storage/databases/main/event_push_actions.py9
-rw-r--r--synapse/storage/databases/main/events.py37
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py46
-rw-r--r--synapse/storage/databases/main/events_worker.py343
-rw-r--r--synapse/storage/databases/main/group_server.py27
-rw-r--r--synapse/storage/databases/main/keys.py28
-rw-r--r--synapse/storage/databases/main/media_repository.py13
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py15
-rw-r--r--synapse/storage/databases/main/presence.py34
-rw-r--r--synapse/storage/databases/main/profile.py17
-rw-r--r--synapse/storage/databases/main/push_rule.py132
-rw-r--r--synapse/storage/databases/main/pusher.py108
-rw-r--r--synapse/storage/databases/main/receipts.py96
-rw-r--r--synapse/storage/databases/main/registration.py87
-rw-r--r--synapse/storage/databases/main/rejections.py5
-rw-r--r--synapse/storage/databases/main/room.py25
-rw-r--r--synapse/storage/databases/main/roommember.py21
-rw-r--r--synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql25
-rw-r--r--synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql18
-rw-r--r--synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql17
-rw-r--r--synapse/storage/databases/main/state.py9
-rw-r--r--synapse/storage/databases/main/stats.py10
-rw-r--r--synapse/storage/databases/main/stream.py401
-rw-r--r--synapse/storage/databases/main/tags.py11
-rw-r--r--synapse/storage/databases/main/ui_auth.py61
-rw-r--r--synapse/storage/databases/main/user_directory.py9
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py26
-rw-r--r--synapse/storage/presence.py69
-rw-r--r--synapse/storage/util/id_generators.py193
-rw-r--r--synapse/storage/util/sequence.py8
44 files changed, 1444 insertions, 1359 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 6814bf5fcf..ab49d227de 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,12 +19,11 @@ import random
 from abc import ABCMeta
 from typing import Any, Optional
 
-from canonicaljson import json
-
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.types import Collection, get_domain_from_id
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -99,13 +98,13 @@ def db_to_json(db_content):
     if isinstance(db_content, memoryview):
         db_content = db_content.tobytes()
 
-    # Decode it to a Unicode string before feeding it to json.loads, since
+    # Decode it to a Unicode string before feeding it to the JSON decoder, since
     # Python 3.5 does not support deserializing bytes.
     if isinstance(db_content, (bytes, bytearray)):
         db_content = db_content.decode("utf8")
 
     try:
-        return json.loads(db_content)
+        return json_decoder.decode(db_content)
     except Exception:
         logging.warning("Tried to decode '%r' as JSON and failed", db_content)
         raise
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index f43463df53..56818f4df8 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -16,11 +16,8 @@
 import logging
 from typing import Optional
 
-from canonicaljson import json
-
-from twisted.internet import defer
-
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import json_encoder
 
 from . import engines
 
@@ -308,9 +305,8 @@ class BackgroundUpdater(object):
             update_name (str): Name of update
         """
 
-        @defer.inlineCallbacks
-        def noop_update(progress, batch_size):
-            yield self._end_background_update(update_name)
+        async def noop_update(progress, batch_size):
+            await self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, noop_update)
@@ -409,12 +405,11 @@ class BackgroundUpdater(object):
         else:
             runner = create_index_sqlite
 
-        @defer.inlineCallbacks
-        def updater(progress, batch_size):
+        async def updater(progress, batch_size):
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
-                yield self.db_pool.runWithConnection(runner)
-            yield self._end_background_update(update_name)
+                await self.db_pool.runWithConnection(runner)
+            await self._end_background_update(update_name)
             return 1
 
         self.register_background_update_handler(update_name, updater)
@@ -461,7 +456,7 @@ class BackgroundUpdater(object):
             progress(dict): The progress of the update.
         """
 
-        progress_json = json.dumps(progress)
+        progress_json = json_encoder.encode(progress)
 
         self.db_pool.simple_update_one_txn(
             txn,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4ada6f5563..181c3ec249 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,9 +28,12 @@ from typing import (
     Optional,
     Tuple,
     TypeVar,
+    Union,
+    overload,
 )
 
 from prometheus_client import Histogram
+from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
 from twisted.internet import defer
@@ -125,7 +128,7 @@ class LoggingTransaction:
     method.
 
     Args:
-        txn: The database transcation object to wrap.
+        txn: The database transaction object to wrap.
         name: The name of this transactions for logging.
         database_engine
         after_callbacks: A list that callbacks will be appended to
@@ -160,7 +163,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -171,7 +174,9 @@ class LoggingTransaction:
         assert self.after_callbacks is not None
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_on_exception(
+        self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+    ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
         # is not the case.
@@ -195,7 +200,7 @@ class LoggingTransaction:
     def description(self) -> Any:
         return self.txn.description
 
-    def execute_batch(self, sql, args):
+    def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
@@ -204,17 +209,17 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)
 
-    def execute(self, sql: str, *args: Any):
+    def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
 
-    def executemany(self, sql: str, *args: Any):
+    def executemany(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.executemany, sql, *args)
 
     def _make_sql_one_line(self, sql: str) -> str:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql, *args):
+    def _do_execute(self, func, sql: str, *args: Any) -> None:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -240,7 +245,7 @@ class LoggingTransaction:
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
             sql_query_timer.labels(sql.split()[0]).observe(secs)
 
-    def close(self):
+    def close(self) -> None:
         self.txn.close()
 
 
@@ -249,13 +254,13 @@ class PerformanceCounters(object):
         self.current_counters = {}
         self.previous_counters = {}
 
-    def update(self, key, duration_secs):
+    def update(self, key: str, duration_secs: float) -> None:
         count, cum_time = self.current_counters.get(key, (0, 0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
 
-    def interval(self, interval_duration_secs, limit=3):
+    def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
         counters = []
         for name, (count, cum_time) in self.current_counters.items():
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@@ -279,6 +284,9 @@ class PerformanceCounters(object):
         return top_n_counters
 
 
+R = TypeVar("R")
+
+
 class DatabasePool(object):
     """Wraps a single physical database and connection pool.
 
@@ -327,13 +335,12 @@ class DatabasePool(object):
                 self._check_safe_to_upsert,
             )
 
-    def is_running(self):
+    def is_running(self) -> bool:
         """Is the database pool currently running
         """
         return self._db_pool.running
 
-    @defer.inlineCallbacks
-    def _check_safe_to_upsert(self):
+    async def _check_safe_to_upsert(self) -> None:
         """
         Is it safe to use native UPSERT?
 
@@ -342,7 +349,7 @@ class DatabasePool(object):
 
         If the background updates have not completed, wait 15 sec and check again.
         """
-        updates = yield self.simple_select_list(
+        updates = await self.simple_select_list(
             "background_updates",
             keyvalues=None,
             retcols=["update_name"],
@@ -364,7 +371,7 @@ class DatabasePool(object):
                 self._check_safe_to_upsert,
             )
 
-    def start_profiling(self):
+    def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
         def loop():
@@ -388,8 +395,15 @@ class DatabasePool(object):
         self._clock.looping_call(loop, 10000)
 
     def new_transaction(
-        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
-    ):
+        self,
+        conn: Connection,
+        desc: str,
+        after_callbacks: List[_CallbackListEntry],
+        exception_callbacks: List[_CallbackListEntry],
+        func: "Callable[..., R]",
+        *args: Any,
+        **kwargs: Any
+    ) -> R:
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -517,14 +531,16 @@ class DatabasePool(object):
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
-            result = yield self.runWithConnection(
-                self.new_transaction,
-                desc,
-                after_callbacks,
-                exception_callbacks,
-                func,
-                *args,
-                **kwargs
+            result = yield defer.ensureDeferred(
+                self.runWithConnection(
+                    self.new_transaction,
+                    desc,
+                    after_callbacks,
+                    exception_callbacks,
+                    func,
+                    *args,
+                    **kwargs
+                )
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
@@ -536,8 +552,9 @@ class DatabasePool(object):
 
         return result
 
-    @defer.inlineCallbacks
-    def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+    async def runWithConnection(
+        self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
         Arguments:
@@ -548,7 +565,7 @@ class DatabasePool(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
         if not parent_context:
@@ -571,18 +588,16 @@ class DatabasePool(object):
 
                 return func(conn, *args, **kwargs)
 
-        result = yield make_deferred_yieldable(
+        return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
         )
 
-        return result
-
     @staticmethod
-    def cursor_to_dict(cursor):
+    def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
         """Converts a SQL cursor into an list of dicts.
 
         Args:
-            cursor : The DBAPI cursor which has executed a query.
+            cursor: The DBAPI cursor which has executed a query.
         Returns:
             A list of dicts where the key is the column header.
         """
@@ -590,7 +605,7 @@ class DatabasePool(object):
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
-    def execute(self, desc, decoder, query, *args):
+    def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
         """Runs a single query for a result set.
 
         Args:
@@ -599,7 +614,7 @@ class DatabasePool(object):
             query - The query string to execute
             *args - Query args.
         Returns:
-            The result of decoder(results)
+            Deferred which results to the result of decoder(results)
         """
 
         def interaction(txn):
@@ -614,24 +629,28 @@ class DatabasePool(object):
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
 
-    @defer.inlineCallbacks
-    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+    async def simple_insert(
+        self,
+        table: str,
+        values: Dict[str, Any],
+        or_ignore: bool = False,
+        desc: str = "simple_insert",
+    ) -> bool:
         """Executes an INSERT query on the named table.
 
         Args:
-            table : string giving the table name
-            values : dict of new column names and values for them
-            or_ignore : bool stating whether an exception should be raised
+            table: string giving the table name
+            values: dict of new column names and values for them
+            or_ignore: bool stating whether an exception should be raised
                 when a conflicting row already exists. If True, False will be
                 returned by the function instead
-            desc : string giving a description of the transaction
+            desc: string giving a description of the transaction
 
         Returns:
-            bool: Whether the row was inserted or not. Only useful when
-            `or_ignore` is True
+             Whether the row was inserted or not. Only useful when `or_ignore` is True
         """
         try:
-            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+            await self.runInteraction(desc, self.simple_insert_txn, table, values)
         except self.engine.module.IntegrityError:
             # We have to do or_ignore flag at this layer, since we can't reuse
             # a cursor after we receive an error from the db.
@@ -641,7 +660,9 @@ class DatabasePool(object):
         return True
 
     @staticmethod
-    def simple_insert_txn(txn, table, values):
+    def simple_insert_txn(
+        txn: LoggingTransaction, table: str, values: Dict[str, Any]
+    ) -> None:
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -652,11 +673,15 @@ class DatabasePool(object):
 
         txn.execute(sql, vals)
 
-    def simple_insert_many(self, table, values, desc):
+    def simple_insert_many(
+        self, table: str, values: List[Dict[str, Any]], desc: str
+    ) -> defer.Deferred:
         return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
 
     @staticmethod
-    def simple_insert_many_txn(txn, table, values):
+    def simple_insert_many_txn(
+        txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+    ) -> None:
         if not values:
             return
 
@@ -684,16 +709,15 @@ class DatabasePool(object):
 
         txn.executemany(sql, vals)
 
-    @defer.inlineCallbacks
-    def simple_upsert(
+    async def simple_upsert(
         self,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        desc="simple_upsert",
-        lock=True,
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        desc: str = "simple_upsert",
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -707,21 +731,19 @@ class DatabasePool(object):
           this table.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key columns and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            Deferred(None or bool): Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         attempts = 0
         while True:
             try:
-                result = yield self.runInteraction(
+                return await self.runInteraction(
                     desc,
                     self.simple_upsert_txn,
                     table,
@@ -730,7 +752,6 @@ class DatabasePool(object):
                     insertion_values,
                     lock=lock,
                 )
-                return result
             except self.engine.module.IntegrityError as e:
                 attempts += 1
                 if attempts >= 5:
@@ -744,29 +765,34 @@ class DatabasePool(object):
                 )
 
     def simple_upsert_txn(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
         Pick the UPSERT method which works best on the platform. Either the
         native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
 
         Args:
             txn: The transaction to use.
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
-            return self.simple_upsert_txn_native_upsert(
+            self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
+            return None
         else:
             return self.simple_upsert_txn_emulated(
                 txn,
@@ -778,18 +804,23 @@ class DatabasePool(object):
             )
 
     def simple_upsert_txn_emulated(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> bool:
         """
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            bool: Return True if a new entry was created, False if an existing
+            Returns True if a new entry was created, False if an existing
             one was updated.
         """
         # We need to lock the table :(, unless we're *really* careful
@@ -847,19 +878,21 @@ class DatabasePool(object):
         return True
 
     def simple_upsert_txn_native_upsert(
-        self, txn, table, keyvalues, values, insertion_values={}
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+    ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-        Returns:
-            None
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
@@ -989,41 +1022,70 @@ class DatabasePool(object):
 
         return txn.execute_batch(sql, args)
 
-    def simple_select_one(
-        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
-    ):
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one",
+    ) -> Dict[str, Any]:
+        ...
+
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
+        ...
+
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcols : list of strings giving the names of the columns to return
-
-            allow_none : If true, return None instead of failing if the SELECT
-              statement returns no rows
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcols: list of strings giving the names of the columns to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
-    def simple_select_one_onecol(
+    async def simple_select_one_onecol(
         self,
-        table,
-        keyvalues,
-        retcol,
-        allow_none=False,
-        desc="simple_select_one_onecol",
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: bool = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcol : string giving the name of the column to return
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcol: string giving the name of the column to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc,
             self.simple_select_one_onecol_txn,
             table,
@@ -1034,8 +1096,13 @@ class DatabasePool(object):
 
     @classmethod
     def simple_select_one_onecol_txn(
-        cls, txn, table, keyvalues, retcol, allow_none=False
-    ):
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: bool = False,
+    ) -> Optional[Any]:
         ret = cls.simple_select_onecol_txn(
             txn, table=table, keyvalues=keyvalues, retcol=retcol
         )
@@ -1049,7 +1116,12 @@ class DatabasePool(object):
                 raise StoreError(404, "No row found")
 
     @staticmethod
-    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+    def simple_select_onecol_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+    ) -> List[Any]:
         sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
 
         if keyvalues:
@@ -1061,15 +1133,19 @@ class DatabasePool(object):
         return [r[0] for r in txn]
 
     def simple_select_onecol(
-        self, table, keyvalues, retcol, desc="simple_select_onecol"
-    ):
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcol: str,
+        desc: str = "simple_select_onecol",
+    ) -> defer.Deferred:
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
         Args:
-            table (str): table name
-            keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whos value we wish to retrieve.
+            table: table name
+            keyvalues: column names and values to select the rows with
+            retcol: column whos value we wish to retrieve.
 
         Returns:
             Deferred: Results in a list
@@ -1078,16 +1154,22 @@ class DatabasePool(object):
             desc, self.simple_select_onecol_txn, table, keyvalues, retcol
         )
 
-    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+    def simple_select_list(
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+        desc: str = "simple_select_list",
+    ) -> defer.Deferred:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
@@ -1096,17 +1178,23 @@ class DatabasePool(object):
         )
 
     @classmethod
-    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+    def simple_select_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
         """
         if keyvalues:
             sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1121,28 +1209,27 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    @defer.inlineCallbacks
-    def simple_select_many_batch(
+    async def simple_select_many_batch(
         self,
-        table,
-        column,
-        iterable,
-        retcols,
-        keyvalues={},
-        desc="simple_select_many_batch",
-        batch_size=100,
-    ):
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        retcols: Iterable[str],
+        keyvalues: Dict[str, Any] = {},
+        desc: str = "simple_select_many_batch",
+        batch_size: int = 100,
+    ) -> List[Any]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            retcols: list of strings giving the names of the columns to return
         """
         results = []  # type: List[Dict[str, Any]]
 
@@ -1156,7 +1243,7 @@ class DatabasePool(object):
             it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
         ]
         for chunk in chunks:
-            rows = yield self.runInteraction(
+            rows = await self.runInteraction(
                 desc,
                 self.simple_select_many_txn,
                 table,
@@ -1171,19 +1258,27 @@ class DatabasePool(object):
         return results
 
     @classmethod
-    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+    def simple_select_many_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            retcols: list of strings giving the names of the columns to return
         """
         if not iterable:
             return []
@@ -1204,13 +1299,24 @@ class DatabasePool(object):
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
-    def simple_update(self, table, keyvalues, updatevalues, desc):
+    def simple_update(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str,
+    ) -> defer.Deferred:
         return self.runInteraction(
             desc, self.simple_update_txn, table, keyvalues, updatevalues
         )
 
     @staticmethod
-    def simple_update_txn(txn, table, keyvalues, updatevalues):
+    def simple_update_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> int:
         if keyvalues:
             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
         else:
@@ -1227,31 +1333,32 @@ class DatabasePool(object):
         return txn.rowcount
 
     def simple_update_one(
-        self, table, keyvalues, updatevalues, desc="simple_update_one"
-    ):
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str = "simple_update_one",
+    ) -> defer.Deferred:
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            updatevalues : dict giving column names and values to update
-            retcols : optional list of column names to return
-
-        If present, retcols gives a list of column names on which to perform
-        a SELECT statement *before* performing the UPDATE statement. The values
-        of these will be returned in a dict.
-
-        These are performed within the same transaction, allowing an atomic
-        get-and-set.  This can be used to implement compare-and-set by putting
-        the update column in the 'keyvalues' dict as well.
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            updatevalues: dict giving column names and values to update
         """
         return self.runInteraction(
             desc, self.simple_update_one_txn, table, keyvalues, updatevalues
         )
 
     @classmethod
-    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+    def simple_update_one_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> None:
         rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
 
         if rowcount == 0:
@@ -1259,8 +1366,18 @@ class DatabasePool(object):
         if rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
+    # Ideally we could use the overload decorator here to specify that the
+    # return type is only optional if allow_none is True, but this does not work
+    # when you call a static method from an instance.
+    # See https://github.com/python/mypy/issues/7781
     @staticmethod
-    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+    def simple_select_one_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+    ) -> Optional[Dict[str, Any]]:
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
@@ -1279,24 +1396,28 @@ class DatabasePool(object):
 
         return dict(zip(retcols, row))
 
-    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+    def simple_delete_one(
+        self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+    ) -> defer.Deferred:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
         """
         return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_one_txn(txn, table, keyvalues):
+    def simple_delete_one_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> None:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
         """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
@@ -1309,11 +1430,13 @@ class DatabasePool(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
-    def simple_delete(self, table, keyvalues, desc):
+    def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
         return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_txn(txn, table, keyvalues):
+    def simple_delete_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> int:
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1322,26 +1445,39 @@ class DatabasePool(object):
         txn.execute(sql, list(keyvalues.values()))
         return txn.rowcount
 
-    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+    def simple_delete_many(
+        self,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        desc: str,
+    ) -> defer.Deferred:
         return self.runInteraction(
             desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
         )
 
     @staticmethod
-    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+    def simple_delete_many_txn(
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+    ) -> int:
         """Executes a DELETE query on the named table.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
 
         Returns:
-            int: Number rows deleted
+            Number rows deleted
         """
         if not iterable:
             return 0
@@ -1362,8 +1498,14 @@ class DatabasePool(object):
         return txn.rowcount
 
     def get_cache_dict(
-        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
-    ):
+        self,
+        db_conn: Connection,
+        table: str,
+        entity_column: str,
+        stream_column: str,
+        max_value: int,
+        limit: int = 100000,
+    ) -> Tuple[Dict[Any, int], int]:
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
         # do the right thing to ensure it respects the max size of cache.
@@ -1396,34 +1538,34 @@ class DatabasePool(object):
 
     def simple_select_list_paginate(
         self,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-        desc="simple_select_list_paginate",
-    ):
+        table: str,
+        orderby: str,
+        start: int,
+        limit: int,
+        retcols: Iterable[str],
+        filters: Optional[Dict[str, Any]] = None,
+        keyvalues: Optional[Dict[str, Any]] = None,
+        order_direction: str = "ASC",
+        desc: str = "simple_select_list_paginate",
+    ) -> defer.Deferred:
         """
         Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
         returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            filters (dict[str, T] | None):
+            table: the table name
+            orderby: Column to order the results by.
+            start: Index to begin the query at.
+            limit: Number of results to return.
+            retcols: the names of the columns to return
+            filters:
                 column names and values to filter the rows with, or None to not
                 apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            order_direction: Whether the results should be ordered "ASC" or "DESC".
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
@@ -1443,16 +1585,16 @@ class DatabasePool(object):
     @classmethod
     def simple_select_list_paginate_txn(
         cls,
-        txn,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-    ):
+        txn: LoggingTransaction,
+        table: str,
+        orderby: str,
+        start: int,
+        limit: int,
+        retcols: Iterable[str],
+        filters: Optional[Dict[str, Any]] = None,
+        keyvalues: Optional[Dict[str, Any]] = None,
+        order_direction: str = "ASC",
+    ) -> List[Dict[str, Any]]:
         """
         Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
@@ -1463,21 +1605,22 @@ class DatabasePool(object):
         using 'AND'.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            filters (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            orderby: Column to order the results by.
+            start: Index to begin the query at.
+            limit: Number of results to return.
+            retcols: the names of the columns to return
+            filters:
                 column names and values to filter the rows with, or None to not
                 apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            order_direction: Whether the results should be ordered "ASC" or "DESC".
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            The result as a list of dictionaries.
         """
         if order_direction not in ["ASC", "DESC"]:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1503,16 +1646,23 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+    def simple_search_list(
+        self,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+        desc="simple_search_list",
+    ):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]] or None
         """
@@ -1522,19 +1672,26 @@ class DatabasePool(object):
         )
 
     @classmethod
-    def simple_search_list_txn(cls, txn, table, term, col, retcols):
+    def simple_search_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+    ) -> Union[List[Dict[str, Any]], int]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            txn: Transaction object
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
+            0 if no term is given, otherwise a list of dictionaries.
         """
         if term:
             sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
@@ -1547,7 +1704,7 @@ class DatabasePool(object):
 
 
 def make_in_list_sql_clause(
-    database_engine, column: str, iterable: Iterable
+    database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
 ) -> Tuple[str, list]:
     """Returns an SQL clause that checks the given column is in the iterable.
 
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 4406e58273..0ac854aee2 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -87,12 +87,21 @@ class Databases(object):
 
                 logger.info("Database %r prepared", db_name)
 
+            # Closing the context manager doesn't close the connection.
+            # psycopg will close the connection when the object gets GCed, but *only*
+            # if the PID is the same as when the connection was opened [1], and
+            # it may not be if we fork in the meantime.
+            #
+            # [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
+
+            db_conn.close()
+
         # Sanity check that we have actually configured all the required stores.
         if not main:
             raise Exception("No 'main' data store configured")
 
         if not state:
-            raise Exception("No 'main' data store configured")
+            raise Exception("No 'state' data store configured")
 
         # We use local variables here to ensure that the databases do not have
         # optional types.
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 17fa470919..0934ae276c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -498,7 +498,7 @@ class DataStore(
         )
 
     def get_users_paginate(
-        self, start, limit, name=None, guests=True, deactivated=False
+        self, start, limit, user_id=None, name=None, guests=True, deactivated=False
     ):
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
@@ -507,7 +507,8 @@ class DataStore(
         Args:
             start (int): start number to begin the query from
             limit (int): number of rows to retrieve
-            name (string): filter for user names
+            user_id (string): search for user_id. ignored if name is not None
+            name (string): search for local part of user_id or display name
             guests (bool): whether to in include guest users
             deactivated (bool): whether to include deactivated users
         Returns:
@@ -516,11 +517,14 @@ class DataStore(
 
         def get_users_paginate_txn(txn):
             filters = []
-            args = []
+            args = [self.hs.config.server_name]
 
             if name:
+                filters.append("(name LIKE ? OR displayname LIKE ?)")
+                args.extend(["@%" + name + "%:%", "%" + name + "%"])
+            elif user_id:
                 filters.append("name LIKE ?")
-                args.append("%" + name + "%")
+                args.extend(["%" + user_id + "%"])
 
             if not guests:
                 filters.append("is_guest = 0")
@@ -530,20 +534,23 @@ class DataStore(
 
             where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
 
-            sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
-            txn.execute(sql, args)
-            count = txn.fetchone()[0]
-
-            args = [self.hs.config.server_name] + args + [limit, start]
-            sql = """
-                SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+            sql_base = """
                 FROM users as u
                 LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
                 {}
-                ORDER BY u.name LIMIT ? OFFSET ?
                 """.format(
                 where_clause
             )
+            sql = "SELECT COUNT(*) as total_users " + sql_base
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = (
+                "SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+                + sql_base
+                + " ORDER BY u.name LIMIT ? OFFSET ?"
+            )
+            args += [limit, start]
             txn.execute(sql, args)
             users = self.db_pool.cursor_to_dict(txn)
             return users, count
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 82aac2bbf3..04042a2c98 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5cf1a88399..77723f7d4d 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -16,13 +16,12 @@
 import logging
 import re
 
-from canonicaljson import json
-
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util import json_encoder
 
 logger = logging.getLogger(__name__)
 
@@ -169,7 +168,7 @@ class ApplicationServiceTransactionWorkerStore(
             service(ApplicationService): The service whose state to set.
             state(ApplicationServiceState): The connectivity state to apply.
         Returns:
-            A Deferred which resolves when the state was set successfully.
+            An Awaitable which resolves when the state was set successfully.
         """
         return self.db_pool.simple_upsert(
             "application_services_state", {"as_id": service.id}, {"state": state}
@@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore(
             new_txn_id = max(highest_txn_id, last_txn_id) + 1
 
             # Insert new txn into txn table
-            event_ids = json.dumps([e.event_id for e in events])
+            event_ids = json_encoder.encode([e.event_id for e in events])
             txn.execute(
                 "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
                 "VALUES(?,?,?)",
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 10de446065..1e7637a6f5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                 },
             )
 
-    def get_cache_stream_token(self, instance_name):
+    def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
         if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token(instance_name)
+            return self._cache_id_gen.get_current_token_for_writer(instance_name)
         else:
             return 0
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 1f6e995c4f..bb85637a95 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with self._device_inbox_id_gen.get_next() as stream_id:
+        with await self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with self._device_inbox_id_gen.get_next() as stream_id:
+        with await self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 2b33060480..a811a39eb5 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
-    def get_device(self, user_id: str, device_id: str):
+    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
 
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to retrieve
         Returns:
-            defer.Deferred for a dict containing the device information
+            A dict containing the device information
         Raises:
             StoreError: if the device is not found
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
             THe new stream ID.
         """
 
-        with self._device_list_id_gen.get_next() as stream_id:
+        with await self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     @cached(max_entries=10000)
-    def get_device_list_last_stream_id_for_remote(self, user_id: str):
+    async def get_device_list_last_stream_id_for_remote(
+        self, user_id: str
+    ) -> Optional[Any]:
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             retcol="stream_id",
@@ -671,10 +673,9 @@ class DeviceWorkerStore(SQLBaseStore):
     @cachedList(
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
-        inlineCallbacks=True,
     )
-    def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
+        rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
             iterable=user_ids,
@@ -1147,7 +1148,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+        with await self._device_list_id_gen.get_next_mult(
+            len(device_ids)
+        ) as stream_ids:
             await self.db_pool.runInteraction(
                 "add_device_change_to_stream",
                 self._add_device_change_to_stream_txn,
@@ -1160,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with self._device_list_id_gen.get_next_mult(
+        with await self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 037e02603c..301d5d845a 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
 
         return RoomAliasMapping(room_id, room_alias.to_string(), servers)
 
-    def get_room_alias_creator(self, room_alias):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_room_alias_creator(self, room_alias: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="room_aliases",
             keyvalues={"room_alias": room_alias},
             retcol="creator",
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index 2eeb9f97dc..46c3e33cc6 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return ret
 
-    def count_e2e_room_keys(self, user_id, version):
+    async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
         """Get the number of keys in a backup version.
 
         Args:
-            user_id (str): the user whose backup we're querying
-            version (str): the version ID of the backup we're querying about
+            user_id: the user whose backup we're querying
+            version: the version ID of the backup we're querying about
         """
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="e2e_room_keys",
             keyvalues={"user_id": user_id, "version": version},
             retcol="COUNT(*)",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index f93e0d320d..385868bdab 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
-    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+    def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
         """Set a user's cross-signing key.
 
         Args:
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
             key (dict): the key data
+            stream_id (int)
         """
         # the 'key' dict will look something like:
         # {
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             )
 
         # and finally, store the key itself
-        with self._cross_signing_id_gen.get_next() as stream_id:
-            self.db_pool.simple_insert_txn(
-                txn,
-                "e2e_cross_signing_keys",
-                values={
-                    "user_id": user_id,
-                    "keytype": key_type,
-                    "keydata": json_encoder.encode(key),
-                    "stream_id": stream_id,
-                },
-            )
+        self.db_pool.simple_insert_txn(
+            txn,
+            "e2e_cross_signing_keys",
+            values={
+                "user_id": user_id,
+                "keytype": key_type,
+                "keydata": json_encoder.encode(key),
+                "stream_id": stream_id,
+            },
+        )
 
         self._invalidate_cache_and_stream(
             txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
         )
 
-    def set_e2e_cross_signing_key(self, user_id, key_type, key):
+    async def set_e2e_cross_signing_key(self, user_id, key_type, key):
         """Set a user's cross-signing key.
 
         Args:
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key_type (str): the type of cross-signing key to set
             key (dict): the key data
         """
-        return self.db_pool.runInteraction(
-            "add_e2e_cross_signing_key",
-            self._set_e2e_cross_signing_key_txn,
-            user_id,
-            key_type,
-            key,
-        )
+
+        with await self._cross_signing_id_gen.get_next() as stream_id:
+            return await self.db_pool.runInteraction(
+                "add_e2e_cross_signing_key",
+                self._set_e2e_cross_signing_key_txn,
+                user_id,
+                key_type,
+                key,
+                stream_id,
+            )
 
     def store_e2e_cross_signing_signatures(self, user_id, signatures):
         """Stores cross-signing signatures.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 484875f989..e6a97b018c 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -15,14 +15,16 @@
 import itertools
 import logging
 from queue import Empty, PriorityQueue
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Set, Tuple
 
 from synapse.api.errors import StoreError
+from synapse.events import EventBase
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
+from synapse.types import Collection
 from synapse.util.caches.descriptors import cached
 from synapse.util.iterutils import batch_iter
 
@@ -30,57 +32,51 @@ logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def get_auth_chain(self, event_ids, include_given=False):
+    async def get_auth_chain(
+        self, event_ids: Collection[str], include_given: bool = False
+    ) -> List[EventBase]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
-            event_ids (list): state events
-            include_given (bool): include the given events in result
+            event_ids: state events
+            include_given: include the given events in result
 
         Returns:
             list of events
         """
-        return self.get_auth_chain_ids(
+        event_ids = await self.get_auth_chain_ids(
             event_ids, include_given=include_given
-        ).addCallback(self.get_events_as_list)
-
-    def get_auth_chain_ids(
-        self,
-        event_ids: List[str],
-        include_given: bool = False,
-        ignore_events: Optional[Set[str]] = None,
-    ):
+        )
+        return await self.get_events_as_list(event_ids)
+
+    async def get_auth_chain_ids(
+        self, event_ids: Collection[str], include_given: bool = False,
+    ) -> List[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
             event_ids: state events
             include_given: include the given events in result
-            ignore_events: Set of events to exclude from the returned auth
-                chain. This is useful if the caller will just discard the
-                given events anyway, and saves us from figuring out their auth
-                chains if not required.
 
         Returns:
             list of event_ids
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
             self._get_auth_chain_ids_txn,
             event_ids,
             include_given,
-            ignore_events,
         )
 
-    def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
-        if ignore_events is None:
-            ignore_events = set()
-
+    def _get_auth_chain_ids_txn(
+        self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
+    ) -> List[str]:
         if include_given:
             results = set(event_ids)
         else:
             results = set()
 
-        base_sql = "SELECT auth_id FROM event_auth WHERE "
+        base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
 
         front = set(event_ids)
         while front:
@@ -92,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 txn.execute(base_sql + clause, args)
                 new_front.update(r[0] for r in txn)
 
-            new_front -= ignore_events
             new_front -= results
 
             front = new_front
@@ -257,11 +252,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
-    def get_oldest_events_in_room(self, room_id):
-        return self.db_pool.runInteraction(
-            "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
-        )
-
     def get_oldest_events_with_depth_in_room(self, room_id):
         return self.db_pool.runInteraction(
             "get_oldest_events_with_depth_in_room",
@@ -303,14 +293,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         else:
             return max(row["depth"] for row in rows)
 
-    def _get_oldest_events_in_room_txn(self, txn, room_id):
-        return self.db_pool.simple_select_onecol_txn(
-            txn,
-            table="event_backward_extremities",
-            keyvalues={"room_id": room_id},
-            retcol="event_id",
-        )
-
     def get_prev_events_for_room(self, room_id: str):
         """
         Gets a subset of the current forward extremities in the given room.
@@ -472,7 +454,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
-    def get_backfill_events(self, room_id, event_list, limit):
+    async def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
 
@@ -482,17 +464,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             event_list (list)
             limit (int)
         """
-        return (
-            self.db_pool.runInteraction(
-                "get_backfill_events",
-                self._get_backfill_events,
-                room_id,
-                event_list,
-                limit,
-            )
-            .addCallback(self.get_events_as_list)
-            .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
+        event_ids = await self.db_pool.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events,
+            room_id,
+            event_list,
+            limit,
         )
+        events = await self.get_events_as_list(event_ids)
+        return sorted(events, key=lambda e: -e.depth)
 
     def _get_backfill_events(self, txn, room_id, event_list, limit):
         logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -553,8 +533,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             latest_events,
             limit,
         )
-        events = await self.get_events_as_list(ids)
-        return events
+        return await self.get_events_as_list(ids)
 
     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
 
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7c246d3e4c..e8834b2162 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self._rotate_delay = 3
         self._rotate_count = 10000
 
-    @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
-    def get_unread_event_push_actions_by_room_for_user(
+    @cached(num_args=3, tree=True, max_entries=5000)
+    async def get_unread_event_push_actions_by_room_for_user(
         self, room_id, user_id, last_read_event_id
     ):
-        ret = yield self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
             room_id,
             user_id,
             last_read_event_id,
         )
-        return ret
 
     def _get_unread_counts_by_receipt_txn(
         self, txn, room_id, user_id, last_read_event_id
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1a68bf32cb..6313b41eef 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,13 +17,11 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
 
 import attr
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 import synapse.metrics
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.room_versions import RoomVersions
@@ -113,15 +111,14 @@ class PersistEventsStore:
             hs.config.worker.writers.events == hs.get_instance_name()
         ), "Can only instantiate EventsStore on master"
 
-    @defer.inlineCallbacks
-    def _persist_events_and_state_updates(
+    async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         current_state_for_room: Dict[str, StateMap[str]],
         state_delta_for_room: Dict[str, DeltaState],
         new_forward_extremeties: Dict[str, List[str]],
         backfilled: bool = False,
-    ):
+    ) -> None:
         """Persist a set of events alongside updates to the current state and
         forward extremities tables.
 
@@ -136,7 +133,7 @@ class PersistEventsStore:
             backfilled
 
         Returns:
-            Deferred: resolves when the events have been persisted
+            Resolves when the events have been persisted
         """
 
         # We want to calculate the stream orderings as late as possible, as
@@ -156,11 +153,11 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
         if backfilled:
-            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
@@ -168,7 +165,7 @@ class PersistEventsStore:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
@@ -206,16 +203,15 @@ class PersistEventsStore:
                     (room_id,), list(latest_event_ids)
                 )
 
-    @defer.inlineCallbacks
-    def _get_events_which_are_prevs(self, event_ids):
+    async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.
 
         Args:
-            event_ids (Iterable[str]): event ids to filter
+            event_ids: event ids to filter
 
         Returns:
-            Deferred[List[str]]: filtered event ids
+            Filtered event ids
         """
         results = []
 
@@ -240,14 +236,13 @@ class PersistEventsStore:
             results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
             )
 
         return results
 
-    @defer.inlineCallbacks
-    def _get_prevs_before_rejected(self, event_ids):
+    async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
         """Get soft-failed ancestors to remove from the extremities.
 
         Given a set of events, find all those that have been soft-failed or
@@ -259,11 +254,11 @@ class PersistEventsStore:
         are separated by soft failed events.
 
         Args:
-            event_ids (Iterable[str]): Events to find prev events for. Note
-                that these must have already been persisted.
+            event_ids: Events to find prev events for. Note that these must have
+                already been persisted.
 
         Returns:
-            Deferred[set[str]]
+            The previous events.
         """
 
         # The set of event_ids to return. This includes all soft-failed events
@@ -304,7 +299,7 @@ class PersistEventsStore:
                         existing_prevs.add(prev_event_id)
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
             )
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 35a0e09e3c..e53c6373a8 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventContentFields
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
@@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             where_clause="NOT have_censored",
         )
 
-    @defer.inlineCallbacks
-    def _background_reindex_fields_sender(self, progress, batch_size):
+    async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows)
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _background_reindex_origin_server_ts(self, progress, batch_size):
+    async def _background_reindex_origin_server_ts(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows_to_update)
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.EVENT_ORIGIN_SERVER_TS_NAME
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _cleanup_extremities_bg_update(self, progress, batch_size):
+    async def _cleanup_extremities_bg_update(self, progress, batch_size):
         """Background update to clean out extremities that should have been
         deleted previously.
 
@@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(original_set)
 
-        num_handled = yield self.db_pool.runInteraction(
+        num_handled = await self.db_pool.runInteraction(
             "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
         )
 
         if not num_handled:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 self.DELETE_SOFT_FAILED_EXTREMITIES
             )
 
             def _drop_table_txn(txn):
                 txn.execute("DROP TABLE _extremities_to_check")
 
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
             )
 
         return num_handled
 
-    @defer.inlineCallbacks
-    def _redactions_received_ts(self, progress, batch_size):
+    async def _redactions_received_ts(self, progress, batch_size):
         """Handles filling out the `received_ts` column in redactions.
         """
         last_event_id = progress.get("last_event_id", "")
@@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return len(rows)
 
-        count = yield self.db_pool.runInteraction(
+        count = await self.db_pool.runInteraction(
             "_redactions_received_ts", _redactions_received_ts_txn
         )
 
         if not count:
-            yield self.db_pool.updates._end_background_update("redactions_received_ts")
+            await self.db_pool.updates._end_background_update("redactions_received_ts")
 
         return count
 
-    @defer.inlineCallbacks
-    def _event_fix_redactions_bytes(self, progress, batch_size):
+    async def _event_fix_redactions_bytes(self, progress, batch_size):
         """Undoes hex encoded censored redacted event JSON.
         """
 
@@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             txn.execute("DROP INDEX redactions_censored_redacts")
 
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
         )
 
-        yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
+        await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
 
         return 1
 
-    @defer.inlineCallbacks
-    def _event_store_labels(self, progress, batch_size):
+    async def _event_store_labels(self, progress, batch_size):
         """Background update handler which will store labels for existing events."""
         last_event_id = progress.get("last_event_id", "")
 
@@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             return nbrows
 
-        num_rows = yield self.db_pool.runInteraction(
+        num_rows = await self.db_pool.runInteraction(
             desc="event_store_labels", func=_event_store_labels_txn
         )
 
         if not num_rows:
-            yield self.db_pool.updates._end_background_update("event_store_labels")
+            await self.db_pool.updates._end_background_update("event_store_labels")
 
         return num_rows
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 755b7a2a85..e6247d682d 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -19,9 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload
 
 from constantly import NamedConstant, Names
+from typing_extensions import Literal
 
 from twisted.internet import defer
 
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersions,
 )
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -112,69 +113,58 @@ class EventsWorkerStore(SQLBaseStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == EventsStream.NAME:
-            self._stream_id_gen.advance(token)
+            self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:
-            self._backfill_id_gen.advance(-token)
+            self._backfill_id_gen.advance(instance_name, -token)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def get_received_ts(self, event_id):
+    async def get_received_ts(self, event_id: str) -> Optional[int]:
         """Get received_ts (when it was persisted) for the event.
 
         Raises an exception for unknown events.
 
         Args:
-            event_id (str)
+            event_id: The event ID to query.
 
         Returns:
-            Deferred[int|None]: Timestamp in milliseconds, or None for events
-            that were persisted before received_ts was implemented.
+            Timestamp in milliseconds, or None for events that were persisted
+            before received_ts was implemented.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="events",
             keyvalues={"event_id": event_id},
             retcol="received_ts",
             desc="get_received_ts",
         )
 
-    def get_received_ts_by_stream_pos(self, stream_ordering):
-        """Given a stream ordering get an approximate timestamp of when it
-        happened.
-
-        This is done by simply taking the received ts of the first event that
-        has a stream ordering greater than or equal to the given stream pos.
-        If none exists returns the current time, on the assumption that it must
-        have happened recently.
-
-        Args:
-            stream_ordering (int)
-
-        Returns:
-            Deferred[int]
-        """
-
-        def _get_approximate_received_ts_txn(txn):
-            sql = """
-                SELECT received_ts FROM events
-                WHERE stream_ordering >= ?
-                LIMIT 1
-            """
-
-            txn.execute(sql, (stream_ordering,))
-            row = txn.fetchone()
-            if row and row[0]:
-                ts = row[0]
-            else:
-                ts = self.clock.time_msec()
-
-            return ts
+    # Inform mypy that if allow_none is False (the default) then get_event
+    # always returns an EventBase.
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[False] = False,
+        check_room_id: Optional[str] = None,
+    ) -> EventBase:
+        ...
 
-        return self.db_pool.runInteraction(
-            "get_approximate_received_ts", _get_approximate_received_ts_txn
-        )
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[True] = False,
+        check_room_id: Optional[str] = None,
+    ) -> Optional[EventBase]:
+        ...
 
-    @defer.inlineCallbacks
-    def get_event(
+    async def get_event(
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -182,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
         allow_rejected: bool = False,
         allow_none: bool = False,
         check_room_id: Optional[str] = None,
-    ):
+    ) -> Optional[EventBase]:
         """Get an event from the database by event_id.
 
         Args:
@@ -207,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
-            Deferred[EventBase|None]
+            The event, or None if the event was not found.
         """
         if not isinstance(event_id, str):
             raise TypeError("Invalid event event_id %r" % (event_id,))
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [event_id],
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -230,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event
 
-    @defer.inlineCallbacks
-    def get_events(
+    async def get_events(
         self,
-        event_ids: List[str],
+        event_ids: Iterable[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> Dict[str, EventBase]:
         """Get events from the database
 
         Args:
@@ -256,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejeted events from the response.
 
         Returns:
-            Deferred : Dict from event_id to event.
+            A mapping from event_id to event.
         """
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             event_ids,
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -267,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {e.event_id: e for e in events}
 
-    @defer.inlineCallbacks
-    def get_events_as_list(
+    async def get_events_as_list(
         self,
-        event_ids: List[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> List[EventBase]:
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
@@ -295,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
                 omits rejected events from the response.
 
         Returns:
-            Deferred[list[EventBase]]: List of events fetched from the database. The
-            events are in the same order as `event_ids` arg.
+            List of events fetched from the database. The events are in the same
+            order as `event_ids` arg.
 
             Note that the returned list may be smaller than the list of event
             IDs if not all events could be fetched.
@@ -306,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
             return []
 
         # there may be duplicates so we cast the list to a set
-        event_entry_map = yield self._get_events_from_cache_or_db(
+        event_entry_map = await self._get_events_from_cache_or_db(
             set(event_ids), allow_rejected=allow_rejected
         )
 
@@ -341,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
                     continue
 
                 redacted_event_id = entry.event.redacts
-                event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+                event_map = await self._get_events_from_cache_or_db([redacted_event_id])
                 original_event_entry = event_map.get(redacted_event_id)
                 if not original_event_entry:
                     # we don't have the redacted event (or it was rejected).
@@ -407,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             if get_prev_content:
                 if "replaces_state" in event.unsigned:
-                    prev = yield self.get_event(
+                    prev = await self.get_event(
                         event.unsigned["replaces_state"],
                         get_prev_content=False,
                         allow_none=True,
@@ -419,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
-    @defer.inlineCallbacks
-    def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.
@@ -435,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result
         """
         event_entry_map = self._get_events_from_cache(
@@ -453,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
             # the events have been redacted, and if so pulling the redaction event out
             # of the database to check it.
             #
-            missing_events = yield self._get_events_from_db(
+            missing_events = await self._get_events_from_db(
                 missing_events_ids, allow_rejected=allow_rejected
             )
 
@@ -561,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
 
-    @defer.inlineCallbacks
-    def _get_events_from_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the database.
 
         Returned events will be added to the cache for future lookups.
@@ -576,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
                 rejected events are omitted from the response.
 
         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result. May return extra events which
                 weren't asked for.
         """
@@ -584,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
         events_to_fetch = event_ids
 
         while events_to_fetch:
-            row_map = yield self._enqueue_events(events_to_fetch)
+            row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
             redaction_ids = set()
@@ -610,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
             if not allow_rejected and rejected_reason:
                 continue
 
-            d = db_to_json(row["json"])
-            internal_metadata = db_to_json(row["internal_metadata"])
+            # If the event or metadata cannot be parsed, log the error and act
+            # as if the event is unknown.
+            try:
+                d = db_to_json(row["json"])
+            except ValueError:
+                logger.error("Unable to parse json from event: %s", event_id)
+                continue
+            try:
+                internal_metadata = db_to_json(row["internal_metadata"])
+            except ValueError:
+                logger.error(
+                    "Unable to parse internal_metadata from event: %s", event_id
+                )
+                continue
 
             format_version = row["format_version"]
             if format_version is None:
@@ -622,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore):
             room_version_id = row["room_version_id"]
 
             if not room_version_id:
-                # this should only happen for out-of-band membership events
-                if not internal_metadata.get("out_of_band_membership"):
-                    logger.warning(
-                        "Room %s for event %s is unknown", d["room_id"], event_id
+                # this should only happen for out-of-band membership events which
+                # arrived before #6983 landed. For all other events, we should have
+                # an entry in the 'rooms' table.
+                #
+                # However, the 'out_of_band_membership' flag is unreliable for older
+                # invites, so just accept it for all membership events.
+                #
+                if d["type"] != EventTypes.Member:
+                    raise Exception(
+                        "Room %s for event %s is unknown" % (d["room_id"], event_id)
                     )
-                    continue
 
-                # take a wild stab at the room version based on the event format
+                # so, assuming this is an out-of-band-invite that arrived before #6983
+                # landed, we know that the room version must be v5 or earlier (because
+                # v6 hadn't been invented at that point, so invites from such rooms
+                # would have been rejected.)
+                #
+                # The main reason we need to know the room version here (other than
+                # choosing the right python Event class) is in case the event later has
+                # to be redacted - and all the room versions up to v5 used the same
+                # redaction algorithm.
+                #
+                # So, the following approximations should be adequate.
+
                 if format_version == EventFormatVersions.V1:
+                    # if it's event format v1 then it must be room v1 or v2
                     room_version = RoomVersions.V1
                 elif format_version == EventFormatVersions.V2:
+                    # if it's event format v2 then it must be room v3
                     room_version = RoomVersions.V3
                 else:
+                    # if it's event format v3 then it must be room v4 or v5
                     room_version = RoomVersions.V5
             else:
                 room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
@@ -686,8 +703,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    @defer.inlineCallbacks
-    def _enqueue_events(self, events):
+    async def _enqueue_events(self, events):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -696,7 +712,7 @@ class EventsWorkerStore(SQLBaseStore):
             events (Iterable[str]): events to be fetched.
 
         Returns:
-            Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+            Dict[str, Dict]: map from event id to row data from the database.
                 May contain events that weren't requested.
         """
 
@@ -719,7 +735,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():
-            row_map = yield events_d
+            row_map = await events_d
         logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
 
         return row_map
@@ -878,12 +894,11 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    @defer.inlineCallbacks
-    def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="events",
             retcols=("event_id",),
             column="event_id",
@@ -894,15 +909,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {r["event_id"] for r in rows}
 
-    @defer.inlineCallbacks
-    def have_seen_events(self, event_ids):
+    async def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.
 
         Args:
             event_ids (iterable[str]):
 
         Returns:
-            Deferred[set[str]]: The events we have already seen.
+            set[str]: The events we have already seen.
         """
         results = set()
 
@@ -918,41 +932,11 @@ class EventsWorkerStore(SQLBaseStore):
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )
         return results
 
-    def _get_total_state_event_counts_txn(self, txn, room_id):
-        """
-        See get_total_state_event_counts.
-        """
-        # We join against the events table as that has an index on room_id
-        sql = """
-            SELECT COUNT(*) FROM state_events
-            INNER JOIN events USING (room_id, event_id)
-            WHERE room_id=?
-        """
-        txn.execute(sql, (room_id,))
-        row = txn.fetchone()
-        return row[0] if row else 0
-
-    def get_total_state_event_counts(self, room_id):
-        """
-        Gets the total number of state events in a room.
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[int]
-        """
-        return self.db_pool.runInteraction(
-            "get_total_state_event_counts",
-            self._get_total_state_event_counts_txn,
-            room_id,
-        )
-
     def _get_current_state_event_counts_txn(self, txn, room_id):
         """
         See get_current_state_event_counts.
@@ -978,8 +962,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    @defer.inlineCallbacks
-    def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id):
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -990,9 +973,9 @@ class EventsWorkerStore(SQLBaseStore):
             room_id (str)
 
         Returns:
-            Deferred[dict[str:int]] of complexity version to complexity.
+            dict[str:int] of complexity version to complexity.
         """
-        state_events = yield self.get_current_state_event_counts(room_id)
+        state_events = await self.get_current_state_event_counts(room_id)
 
         # Call this one "v1", so we can introduce new ones as we want to develop
         # it.
@@ -1222,97 +1205,6 @@ class EventsWorkerStore(SQLBaseStore):
 
         return rows, to_token, True
 
-    @cached(num_args=5, max_entries=10)
-    def get_all_new_events(
-        self,
-        last_backfill_id,
-        last_forward_id,
-        current_backfill_id,
-        current_forward_id,
-        limit,
-    ):
-        """Get all the new events that have arrived at the server either as
-        new events or as backfilled events"""
-        have_backfill_events = last_backfill_id != current_backfill_id
-        have_forward_events = last_forward_id != current_forward_id
-
-        if not have_backfill_events and not have_forward_events:
-            return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
-        def get_all_new_events_txn(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            if have_forward_events:
-                txn.execute(sql, (last_forward_id, current_forward_id, limit))
-                new_forward_events = txn.fetchall()
-
-                if len(new_forward_events) == limit:
-                    upper_bound = new_forward_events[-1][0]
-                else:
-                    upper_bound = current_forward_id
-
-                sql = (
-                    "SELECT event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                forward_ex_outliers = txn.fetchall()
-            else:
-                new_forward_events = []
-                forward_ex_outliers = []
-
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering DESC"
-                " LIMIT ?"
-            )
-            if have_backfill_events:
-                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
-                new_backfill_events = txn.fetchall()
-
-                if len(new_backfill_events) == limit:
-                    upper_bound = new_backfill_events[-1][0]
-                else:
-                    upper_bound = current_backfill_id
-
-                sql = (
-                    "SELECT -event_stream_ordering, event_id, state_group"
-                    " FROM ex_outlier_stream"
-                    " WHERE ? > event_stream_ordering"
-                    " AND event_stream_ordering >= ?"
-                    " ORDER BY event_stream_ordering DESC"
-                )
-                txn.execute(sql, (-last_backfill_id, -upper_bound))
-                backward_ex_outliers = txn.fetchall()
-            else:
-                new_backfill_events = []
-                backward_ex_outliers = []
-
-            return AllNewEventsResult(
-                new_forward_events,
-                new_backfill_events,
-                forward_ex_outliers,
-                backward_ex_outliers,
-            )
-
-        return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
-
     async def is_event_after(self, event_id1, event_id2):
         """Returns True if event_id1 is after event_id2 in the stream
         """
@@ -1320,9 +1212,9 @@ class EventsWorkerStore(SQLBaseStore):
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_event_ordering(self, event_id):
-        res = yield self.db_pool.simple_select_one(
+    @cached(max_entries=5000)
+    async def get_event_ordering(self, event_id):
+        res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
@@ -1357,14 +1249,3 @@ class EventsWorkerStore(SQLBaseStore):
         return self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
-
-
-AllNewEventsResult = namedtuple(
-    "AllNewEventsResult",
-    [
-        "new_forward_events",
-        "new_backfill_events",
-        "forward_ex_outliers",
-        "backward_ex_outliers",
-    ],
-)
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 380db3a3f3..c39864f59f 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
 
 
 class GroupServerWorkerStore(SQLBaseStore):
-    def get_group(self, group_id):
-        return self.db_pool.simple_select_one(
+    async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="groups",
             keyvalues={"group_id": group_id},
             retcols=(
@@ -341,17 +341,20 @@ class GroupServerWorkerStore(SQLBaseStore):
             "get_users_for_summary_by_role", _get_users_for_summary_txn
         )
 
-    def is_user_in_group(self, user_id, group_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
+        result = await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
             allow_none=True,
             desc="is_user_in_group",
-        ).addCallback(lambda r: bool(r))
+        )
+        return bool(result)
 
-    def is_user_admin_in_group(self, group_id, user_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_user_admin_in_group(
+        self, group_id: str, user_id: str
+    ) -> Optional[bool]:
+        return await self.db_pool.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="is_admin",
@@ -359,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="is_user_admin_in_group",
         )
 
-    def is_user_invited_to_local_group(self, group_id, user_id):
+    async def is_user_invited_to_local_group(
+        self, group_id: str, user_id: str
+    ) -> Optional[bool]:
         """Has the group server invited a user?
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
@@ -1181,7 +1186,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        with self._group_updates_id_gen.get_next() as next_id:
+        with await self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 384e9c5eb0..fadcad51e7 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,6 +16,7 @@
 
 import itertools
 import logging
+from typing import Iterable, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
@@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore):
 
         return self.db_pool.runInteraction("get_server_verify_keys", _txn)
 
-    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+    async def store_server_verify_keys(
+        self,
+        from_server: str,
+        ts_added_ms: int,
+        verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+    ) -> None:
         """Stores NACL verification keys for remote servers.
         Args:
-            from_server (str): Where the verification keys were looked up
-            ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+            from_server: Where the verification keys were looked up
+            ts_added_ms: The time to record that the key was added
+            verify_keys:
                 keys to be stored. Each entry is a triplet of
                 (server_name, key_id, key).
         """
@@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore):
             # param, which is itself the 2-tuple (server_name, key_id).
             invalidations.append((server_name, key_id))
 
-        def _invalidate(res):
-            f = self._get_server_verify_key.invalidate
-            for i in invalidations:
-                f((i,))
-            return res
-
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "store_server_verify_keys",
             self.db_pool.simple_upsert_many_txn,
             table="server_signature_keys",
@@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore):
                 "verify_key",
             ),
             value_values=value_values,
-        ).addCallback(_invalidate)
+        )
+
+        invalidate = self._get_server_verify_key.invalidate
+        for i in invalidations:
+            invalidate((i,))
 
     def store_server_keys_json(
         self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 80fc1cd009..4ae255ebd8 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Optional
+
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
 
-    def get_local_media(self, media_id):
+    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
+
         Returns:
             None if the media_id doesn't exist.
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="store_local_thumbnail",
         )
 
-    def get_cached_remote_media(self, origin, media_id):
-        return self.db_pool.simple_select_one(
+    async def get_cached_remote_media(
+        self, origin, media_id: str
+    ) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e71cdd2cb4..fe30552c08 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         return users
 
     @cached(num_args=1)
-    def user_last_seen_monthly_active(self, user_id):
+    async def user_last_seen_monthly_active(self, user_id: str) -> int:
         """
-            Checks if a given user is part of the monthly active user group
-            Arguments:
-                user_id (str): user to add/update
-            Return:
-                Deferred[int] : timestamp since last seen, None if never seen
+        Checks if a given user is part of the monthly active user group
 
+        Arguments:
+            user_id: user to add/update
+
+        Return:
+            Timestamp since last seen, None if never seen
         """
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
             retcol="timestamp",
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 59ba12820a..c9f655dfb7 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -15,15 +15,15 @@
 
 from typing import List, Tuple
 
+from synapse.api.presence import UserPresenceState
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.presence import UserPresenceState
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
 
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_presence_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
     )
-    def get_presence_for_users(self, user_ids):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_presence_for_users(self, user_ids):
+        rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
             iterable=user_ids,
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
 
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
-
-    def allow_presence_visible(self, observed_localpart, observer_userid):
-        return self.db_pool.simple_insert(
-            table="presence_allow_inbound",
-            values={
-                "observed_user_id": observed_localpart,
-                "observer_user_id": observer_userid,
-            },
-            desc="allow_presence_visible",
-            or_ignore=True,
-        )
-
-    def disallow_presence_visible(self, observed_localpart, observer_userid):
-        return self.db_pool.simple_delete_one(
-            table="presence_allow_inbound",
-            keyvalues={
-                "observed_user_id": observed_localpart,
-                "observer_user_id": observer_userid,
-            },
-            desc="disallow_presence_visible",
-        )
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index b8261357d4..b8233c4848 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Dict, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
 
 
 class ProfileWorkerStore(SQLBaseStore):
-    async def get_profileinfo(self, user_localpart):
+    async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
         try:
             profile = await self.db_pool.simple_select_one(
                 table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
             avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
-    def get_profile_displayname(self, user_localpart):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_profile_displayname(self, user_localpart: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="displayname",
             desc="get_profile_displayname",
         )
 
-    def get_profile_avatar_url(self, user_localpart):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_profile_avatar_url(self, user_localpart: str) -> str:
+        return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="avatar_url",
             desc="get_profile_avatar_url",
         )
 
-    def get_from_remote_profile_cache(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_from_remote_profile_cache(
+        self, user_id: str
+    ) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             retcols=("displayname", "avatar_url"),
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 6562db5c2b..2fb5b02d7d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -30,9 +30,9 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import ChainedIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
         super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen = ChainedIdGenerator(
-                self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-            )  # type: Union[ChainedIdGenerator, SlavedIdTracker]
+            self._push_rules_stream_id_gen = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )  # type: Union[StreamIdGenerator, SlavedIdTracker]
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
@@ -115,9 +115,9 @@ class PushRulesWorkerStore(
         """
         raise NotImplementedError()
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_push_rules_for_user(self, user_id):
-        rows = yield self.db_pool.simple_select_list(
+    @cached(max_entries=5000)
+    async def get_push_rules_for_user(self, user_id):
+        rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
             retcols=(
@@ -133,17 +133,15 @@ class PushRulesWorkerStore(
 
         rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
 
-        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+        enabled_map = await self.get_push_rules_enabled_for_user(user_id)
 
         use_new_defaults = user_id in self._users_new_default_push_rules
 
-        rules = _load_rules(rows, enabled_map, use_new_defaults)
-
-        return rules
+        return _load_rules(rows, enabled_map, use_new_defaults)
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_push_rules_enabled_for_user(self, user_id):
-        results = yield self.db_pool.simple_select_list(
+    @cached(max_entries=5000)
+    async def get_push_rules_enabled_for_user(self, user_id):
+        results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
             keyvalues={"user_name": user_id},
             retcols=("user_name", "rule_id", "enabled"),
@@ -170,18 +168,15 @@ class PushRulesWorkerStore(
             )
 
     @cachedList(
-        cached_method_name="get_push_rules_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
     )
-    def bulk_get_push_rules(self, user_ids):
+    async def bulk_get_push_rules(self, user_ids):
         if not user_ids:
             return {}
 
         results = {user_id: [] for user_id in user_ids}
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
             column="user_name",
             iterable=user_ids,
@@ -194,7 +189,7 @@ class PushRulesWorkerStore(
         for row in rows:
             results.setdefault(row["user_name"], []).append(row)
 
-        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+        enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
 
         for user_id, rules in results.items():
             use_new_defaults = user_id in self._users_new_default_push_rules
@@ -205,14 +200,15 @@ class PushRulesWorkerStore(
 
         return results
 
-    @defer.inlineCallbacks
-    def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+    async def copy_push_rule_from_room_to_room(
+        self, new_room_id: str, user_id: str, rule: dict
+    ) -> None:
         """Copy a single push rule from one room to another for a specific user.
 
         Args:
-            new_room_id (str): ID of the new room.
-            user_id (str): ID of user the push rule belongs to.
-            rule (Dict): A push rule.
+            new_room_id: ID of the new room.
+            user_id : ID of user the push rule belongs to.
+            rule: A push rule.
         """
         # Create new rule id
         rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
@@ -224,7 +220,7 @@ class PushRulesWorkerStore(
                 condition["pattern"] = new_room_id
 
         # Add the rule for the new room
-        yield self.add_push_rule(
+        await self.add_push_rule(
             user_id=user_id,
             rule_id=new_rule_id,
             priority_class=rule["priority_class"],
@@ -232,20 +228,19 @@ class PushRulesWorkerStore(
             actions=rule["actions"],
         )
 
-    @defer.inlineCallbacks
-    def copy_push_rules_from_room_to_room_for_user(
-        self, old_room_id, new_room_id, user_id
-    ):
+    async def copy_push_rules_from_room_to_room_for_user(
+        self, old_room_id: str, new_room_id: str, user_id: str
+    ) -> None:
         """Copy all of the push rules from one room to another for a specific
         user.
 
         Args:
-            old_room_id (str): ID of the old room.
-            new_room_id (str): ID of the new room.
-            user_id (str): ID of user to copy push rules for.
+            old_room_id: ID of the old room.
+            new_room_id: ID of the new room.
+            user_id: ID of user to copy push rules for.
         """
         # Retrieve push rules for this user
-        user_push_rules = yield self.get_push_rules_for_user(user_id)
+        user_push_rules = await self.get_push_rules_for_user(user_id)
 
         # Get rules relating to the old room and copy them to the new room
         for rule in user_push_rules:
@@ -254,21 +249,20 @@ class PushRulesWorkerStore(
                 (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
                 for c in conditions
             ):
-                yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+                await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
 
     @cachedList(
         cached_method_name="get_push_rules_enabled_for_user",
         list_name="user_ids",
         num_args=1,
-        inlineCallbacks=True,
     )
-    def bulk_get_push_rules_enabled(self, user_ids):
+    async def bulk_get_push_rules_enabled(self, user_ids):
         if not user_ids:
             return {}
 
         results = {user_id: {} for user_id in user_ids}
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="push_rules_enable",
             column="user_name",
             iterable=user_ids,
@@ -332,8 +326,7 @@ class PushRulesWorkerStore(
 
 
 class PushRuleStore(PushRulesWorkerStore):
-    @defer.inlineCallbacks
-    def add_push_rule(
+    async def add_push_rule(
         self,
         user_id,
         rule_id,
@@ -342,13 +335,14 @@ class PushRuleStore(PushRulesWorkerStore):
         actions,
         before=None,
         after=None,
-    ):
+    ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
             if before or after:
-                yield self.db_pool.runInteraction(
+                await self.db_pool.runInteraction(
                     "_add_push_rule_relative_txn",
                     self._add_push_rule_relative_txn,
                     stream_id,
@@ -362,7 +356,7 @@ class PushRuleStore(PushRulesWorkerStore):
                     after,
                 )
             else:
-                yield self.db_pool.runInteraction(
+                await self.db_pool.runInteraction(
                     "_add_push_rule_highest_priority_txn",
                     self._add_push_rule_highest_priority_txn,
                     stream_id,
@@ -546,16 +540,15 @@ class PushRuleStore(PushRulesWorkerStore):
                 },
             )
 
-    @defer.inlineCallbacks
-    def delete_push_rule(self, user_id, rule_id):
+    async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
         """
         Delete a push rule. Args specify the row to be deleted and can be
         any of the columns in the push_rule table, but below are the
         standard ones
 
         Args:
-            user_id (str): The matrix ID of the push rule owner
-            rule_id (str): The rule_id of the rule to be deleted
+            user_id: The matrix ID of the push rule owner
+            rule_id: The rule_id of the rule to be deleted
         """
 
         def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
@@ -567,20 +560,21 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.db_pool.runInteraction(
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
+            await self.db_pool.runInteraction(
                 "delete_push_rule",
                 delete_push_rule_txn,
                 stream_id,
                 event_stream_ordering,
             )
 
-    @defer.inlineCallbacks
-    def set_push_rule_enabled(self, user_id, rule_id, enabled):
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.db_pool.runInteraction(
+    async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
+            await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
                 stream_id,
@@ -611,8 +605,9 @@ class PushRuleStore(PushRulesWorkerStore):
             op="ENABLE" if enabled else "DISABLE",
         )
 
-    @defer.inlineCallbacks
-    def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+    async def set_push_rule_actions(
+        self, user_id, rule_id, actions, is_default_rule
+    ) -> None:
         actions_json = json_encoder.encode(actions)
 
         def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -651,9 +646,10 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.db_pool.runInteraction(
+        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+            event_stream_ordering = self._stream_id_gen.get_current_token()
+
+            await self.db_pool.runInteraction(
                 "set_push_rule_actions",
                 set_push_rule_actions_txn,
                 stream_id,
@@ -681,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_push_rules_stream_token(self):
-        """Get the position of the push rules stream.
-        Returns a pair of a stream id for the push_rules stream and the
-        room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_current_token()
-
     def get_max_push_rules_stream_id(self):
-        return self.get_push_rules_stream_token()[0]
+        return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index b5200fbe79..c388468273 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -19,10 +19,8 @@ from typing import Iterable, Iterator, List, Tuple
 
 from canonicaljson import encode_canonical_json
 
-from twisted.internet import defer
-
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
 
 logger = logging.getLogger(__name__)
 
@@ -34,23 +32,22 @@ class PusherWorkerStore(SQLBaseStore):
         Drops any rows whose data cannot be decoded
         """
         for r in rows:
-            dataJson = r["data"]
+            data_json = r["data"]
             try:
-                r["data"] = db_to_json(dataJson)
+                r["data"] = db_to_json(data_json)
             except Exception as e:
                 logger.warning(
                     "Invalid JSON in data for pusher %d: %s, %s",
                     r["id"],
-                    dataJson,
+                    data_json,
                     e.args[0],
                 )
                 continue
 
             yield r
 
-    @defer.inlineCallbacks
-    def user_has_pusher(self, user_id):
-        ret = yield self.db_pool.simple_select_one_onecol(
+    async def user_has_pusher(self, user_id):
+        ret = await self.db_pool.simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
         return ret is not None
@@ -61,9 +58,8 @@ class PusherWorkerStore(SQLBaseStore):
     def get_pushers_by_user_id(self, user_id):
         return self.get_pushers_by({"user_name": user_id})
 
-    @defer.inlineCallbacks
-    def get_pushers_by(self, keyvalues):
-        ret = yield self.db_pool.simple_select_list(
+    async def get_pushers_by(self, keyvalues):
+        ret = await self.db_pool.simple_select_list(
             "pushers",
             keyvalues,
             [
@@ -87,16 +83,14 @@ class PusherWorkerStore(SQLBaseStore):
         )
         return self._decode_pushers_rows(ret)
 
-    @defer.inlineCallbacks
-    def get_all_pushers(self):
+    async def get_all_pushers(self):
         def get_pushers(txn):
             txn.execute("SELECT * FROM pushers")
             rows = self.db_pool.cursor_to_dict(txn)
 
             return self._decode_pushers_rows(rows)
 
-        rows = yield self.db_pool.runInteraction("get_all_pushers", get_pushers)
-        return rows
+        return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
 
     async def get_all_updated_pushers_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -164,19 +158,16 @@ class PusherWorkerStore(SQLBaseStore):
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
-    @cachedInlineCallbacks(num_args=1, max_entries=15000)
-    def get_if_user_has_pusher(self, user_id):
+    @cached(num_args=1, max_entries=15000)
+    async def get_if_user_has_pusher(self, user_id):
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="get_if_user_has_pusher",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
+        cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
     )
-    def get_if_users_have_pushers(self, user_ids):
-        rows = yield self.db_pool.simple_select_many_batch(
+    async def get_if_users_have_pushers(self, user_ids):
+        rows = await self.db_pool.simple_select_many_batch(
             table="pushers",
             column="user_name",
             iterable=user_ids,
@@ -189,34 +180,38 @@ class PusherWorkerStore(SQLBaseStore):
 
         return result
 
-    @defer.inlineCallbacks
-    def update_pusher_last_stream_ordering(
+    async def update_pusher_last_stream_ordering(
         self, app_id, pushkey, user_id, last_stream_ordering
-    ):
-        yield self.db_pool.simple_update_one(
+    ) -> None:
+        await self.db_pool.simple_update_one(
             "pushers",
             {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             {"last_stream_ordering": last_stream_ordering},
             desc="update_pusher_last_stream_ordering",
         )
 
-    @defer.inlineCallbacks
-    def update_pusher_last_stream_ordering_and_success(
-        self, app_id, pushkey, user_id, last_stream_ordering, last_success
-    ):
+    async def update_pusher_last_stream_ordering_and_success(
+        self,
+        app_id: str,
+        pushkey: str,
+        user_id: str,
+        last_stream_ordering: int,
+        last_success: int,
+    ) -> bool:
         """Update the last stream ordering position we've processed up to for
         the given pusher.
 
         Args:
-            app_id (str)
-            pushkey (str)
-            last_stream_ordering (int)
-            last_success (int)
+            app_id
+            pushkey
+            user_id
+            last_stream_ordering
+            last_success
 
         Returns:
-            Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+            True if the pusher still exists; False if it has been deleted.
         """
-        updated = yield self.db_pool.simple_update(
+        updated = await self.db_pool.simple_update(
             table="pushers",
             keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             updatevalues={
@@ -228,18 +223,18 @@ class PusherWorkerStore(SQLBaseStore):
 
         return bool(updated)
 
-    @defer.inlineCallbacks
-    def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
-        yield self.db_pool.simple_update(
+    async def update_pusher_failing_since(
+        self, app_id, pushkey, user_id, failing_since
+    ) -> None:
+        await self.db_pool.simple_update(
             table="pushers",
             keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             updatevalues={"failing_since": failing_since},
             desc="update_pusher_failing_since",
         )
 
-    @defer.inlineCallbacks
-    def get_throttle_params_by_room(self, pusher_id):
-        res = yield self.db_pool.simple_select_list(
+    async def get_throttle_params_by_room(self, pusher_id):
+        res = await self.db_pool.simple_select_list(
             "pusher_throttle",
             {"pusher": pusher_id},
             ["room_id", "last_sent_ts", "throttle_ms"],
@@ -255,11 +250,10 @@ class PusherWorkerStore(SQLBaseStore):
 
         return params_by_room
 
-    @defer.inlineCallbacks
-    def set_throttle_params(self, pusher_id, room_id, params):
+    async def set_throttle_params(self, pusher_id, room_id, params) -> None:
         # no need to lock because `pusher_throttle` has a primary key on
         # (pusher, room_id) so simple_upsert will retry
-        yield self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
             params,
@@ -272,8 +266,7 @@ class PusherStore(PusherWorkerStore):
     def get_pushers_stream_token(self):
         return self._pushers_id_gen.get_current_token()
 
-    @defer.inlineCallbacks
-    def add_pusher(
+    async def add_pusher(
         self,
         user_id,
         access_token,
@@ -287,11 +280,11 @@ class PusherStore(PusherWorkerStore):
         data,
         last_stream_ordering,
         profile_tag="",
-    ):
-        with self._pushers_id_gen.get_next() as stream_id:
+    ) -> None:
+        with await self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
-            yield self.db_pool.simple_upsert(
+            await self.db_pool.simple_upsert(
                 table="pushers",
                 keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
                 values={
@@ -316,15 +309,16 @@ class PusherStore(PusherWorkerStore):
 
             if user_has_pusher is not True:
                 # invalidate, since we the user might not have had a pusher before
-                yield self.db_pool.runInteraction(
+                await self.db_pool.runInteraction(
                     "add_pusher",
                     self._invalidate_cache_and_stream,
                     self.get_if_user_has_pusher,
                     (user_id,),
                 )
 
-    @defer.inlineCallbacks
-    def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
+    async def delete_pusher_by_app_id_pushkey_user_id(
+        self, app_id, pushkey, user_id
+    ) -> None:
         def delete_pusher_txn(txn, stream_id):
             self._invalidate_cache_and_stream(
                 txn, self.get_if_user_has_pusher, (user_id,)
@@ -350,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with self._pushers_id_gen.get_next() as stream_id:
-            yield self.db_pool.runInteraction(
+        with await self._pushers_id_gen.get_next() as stream_id:
+            await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 1920a8a152..cea5ac9a68 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -16,7 +16,7 @@
 
 import abc
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 logger = logging.getLogger(__name__)
@@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """
         raise NotImplementedError()
 
-    @cachedInlineCallbacks()
-    def get_users_with_read_receipts_in_room(self, room_id):
-        receipts = yield self.get_receipts_for_room(room_id, "m.read")
+    @cached()
+    async def get_users_with_read_receipts_in_room(self, room_id):
+        receipts = await self.get_receipts_for_room(room_id, "m.read")
         return {r["user_id"] for r in receipts}
 
     @cached(num_args=2)
@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=3)
-    def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_last_receipt_event_id_for_user(
+        self, user_id: str, room_id: str, receipt_type: str
+    ) -> Optional[str]:
+        return await self.db_pool.simple_select_one_onecol(
             table="receipts_linearized",
             keyvalues={
                 "room_id": room_id,
@@ -84,9 +86,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
             allow_none=True,
         )
 
-    @cachedInlineCallbacks(num_args=2)
-    def get_receipts_for_user(self, user_id, receipt_type):
-        rows = yield self.db_pool.simple_select_list(
+    @cached(num_args=2)
+    async def get_receipts_for_user(self, user_id, receipt_type):
+        rows = await self.db_pool.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
             retcols=("room_id", "event_id"),
@@ -95,8 +97,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return {row["room_id"]: row["event_id"] for row in rows}
 
-    @defer.inlineCallbacks
-    def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
+    async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
         def f(txn):
             sql = (
                 "SELECT rl.room_id, rl.event_id,"
@@ -110,7 +111,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return txn.fetchall()
 
-        rows = yield self.db_pool.runInteraction(
+        rows = await self.db_pool.runInteraction(
             "get_receipts_for_user_with_orderings", f
         )
         return {
@@ -122,56 +123,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
             for row in rows
         }
 
-    @defer.inlineCallbacks
-    def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def get_linearized_receipts_for_rooms(
+        self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """Get receipts for multiple rooms for sending to clients.
 
         Args:
-            room_ids (list): List of room_ids.
-            to_key (int): Max stream id to fetch receipts upto.
-            from_key (int): Min stream id to fetch receipts from. None fetches
+            room_id: List of room_ids.
+            to_key: Max stream id to fetch receipts upto.
+            from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
         Returns:
-            list: A list of receipts.
+            A list of receipts.
         """
         room_ids = set(room_ids)
 
         if from_key is not None:
             # Only ask the database about rooms where there have been new
             # receipts added since `from_key`
-            room_ids = yield self._receipts_stream_cache.get_entities_changed(
+            room_ids = self._receipts_stream_cache.get_entities_changed(
                 room_ids, from_key
             )
 
-        results = yield self._get_linearized_receipts_for_rooms(
+        results = await self._get_linearized_receipts_for_rooms(
             room_ids, to_key, from_key=from_key
         )
 
         return [ev for res in results.values() for ev in res]
 
-    def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+    async def get_linearized_receipts_for_room(
+        self, room_id: str, to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """Get receipts for a single room for sending to clients.
 
         Args:
-            room_ids (str): The room id.
-            to_key (int): Max stream id to fetch receipts upto.
-            from_key (int): Min stream id to fetch receipts from. None fetches
+            room_ids: The room id.
+            to_key: Max stream id to fetch receipts upto.
+            from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
 
         Returns:
-            Deferred[list]: A list of receipts.
+            A list of receipts.
         """
         if from_key is not None:
             # Check the cache first to see if any new receipts have been added
             # since`from_key`. If not we can no-op.
             if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
-                defer.succeed([])
+                return []
 
-        return self._get_linearized_receipts_for_room(room_id, to_key, from_key)
+        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
 
-    @cachedInlineCallbacks(num_args=3, tree=True)
-    def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
+    @cached(num_args=3, tree=True)
+    async def _get_linearized_receipts_for_room(
+        self, room_id: str, to_key: int, from_key: Optional[int] = None
+    ) -> List[dict]:
         """See get_linearized_receipts_for_room
         """
 
@@ -195,7 +201,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return rows
 
-        rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
+        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
 
         if not rows:
             return []
@@ -212,9 +218,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         cached_method_name="_get_linearized_receipts_for_room",
         list_name="room_ids",
         num_args=3,
-        inlineCallbacks=True,
     )
-    def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
+    async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
         if not room_ids:
             return {}
 
@@ -243,7 +248,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
             return self.db_pool.cursor_to_dict(txn)
 
-        txn_results = yield self.db_pool.runInteraction(
+        txn_results = await self.db_pool.runInteraction(
             "_get_linearized_receipts_for_rooms", f
         )
 
@@ -346,7 +351,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         )
 
     def _invalidate_get_users_with_receipts_in_room(
-        self, room_id, receipt_type, user_id
+        self, room_id: str, receipt_type: str, user_id: str
     ):
         if receipt_type != "m.read":
             return
@@ -472,15 +477,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
 
         return rx_ts
 
-    @defer.inlineCallbacks
-    def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
+    async def insert_receipt(
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_ids: List[str],
+        data: dict,
+    ) -> Optional[Tuple[int, int]]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
         representations.
         """
         if not event_ids:
-            return
+            return None
 
         if len(event_ids) == 1:
             linearized_event_id = event_ids[0]
@@ -507,13 +518,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 else:
                     raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
 
-            linearized_event_id = yield self.db_pool.runInteraction(
+            linearized_event_id = await self.db_pool.runInteraction(
                 "insert_receipt_conv", graph_to_linear
             )
 
-        stream_id_manager = self._receipts_id_gen.get_next()
-        with stream_id_manager as stream_id:
-            event_ts = yield self.db_pool.runInteraction(
+        with await self._receipts_id_gen.get_next() as stream_id:
+            event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
                 room_id,
@@ -535,7 +545,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             now - event_ts,
         )
 
-        yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
+        await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
 
         max_persisted_id = self._receipts_id_gen.get_current_token()
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 402ae25571..eced53d470 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -17,9 +17,7 @@
 
 import logging
 import re
-from typing import Dict, List, Optional
-
-from twisted.internet.defer import Deferred
+from typing import Any, Awaitable, Dict, List, Optional
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -48,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @cached()
-    def get_user_by_id(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="users",
             keyvalues={"name": user_id},
             retcols=[
@@ -304,7 +302,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     def _query_for_auth(self, txn, token):
         sql = (
-            "SELECT users.name, users.is_guest, access_tokens.id as token_id,"
+            "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
             " access_tokens.device_id, access_tokens.valid_until_ms"
             " FROM users"
             " INNER JOIN access_tokens on users.name = access_tokens.user_id"
@@ -563,7 +561,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             id_server (str)
 
         Returns:
-            Deferred
+            Awaitable
         """
         # We need to use an upsert, in case they user had already bound the
         # threepid
@@ -891,6 +889,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         super(RegistrationStore, self).__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
+        self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
         if self._account_validity.enabled:
             self._clock.call_later(
@@ -952,6 +951,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         create_profile_with_displayname=None,
         admin=False,
         user_type=None,
+        shadow_banned=False,
     ):
         """Attempts to register an account.
 
@@ -968,6 +968,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             admin (boolean): is an admin user?
             user_type (str|None): type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
+            shadow_banned (bool): Whether the user is shadow-banned,
+                i.e. they may be told their requests succeeded but we ignore them.
 
         Raises:
             StoreError if the user_id could not be registered.
@@ -986,6 +988,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             create_profile_with_displayname,
             admin,
             user_type,
+            shadow_banned,
         )
 
     def _register_user(
@@ -999,6 +1002,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         create_profile_with_displayname,
         admin,
         user_type,
+        shadow_banned,
     ):
         user_id_obj = UserID.from_string(user_id)
 
@@ -1028,6 +1032,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
+                        "shadow_banned": shadow_banned,
                     },
                 )
             else:
@@ -1042,6 +1047,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                         "appservice_id": appservice_id,
                         "admin": 1 if admin else 0,
                         "user_type": user_type,
+                        "shadow_banned": shadow_banned,
                     },
                 )
 
@@ -1077,7 +1083,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
     def record_user_external_id(
         self, auth_provider: str, external_id: str, user_id: str
-    ) -> Deferred:
+    ) -> Awaitable:
         """Record a mapping from an external user id to a mxid
 
         Args:
@@ -1253,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="del_user_pending_deactivation",
         )
 
-    def get_user_pending_deactivation(self):
+    async def get_user_pending_deactivation(self) -> Optional[str]:
         """
         Gets one user from the table of users waiting to be parted from all the rooms
         they're in.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             "users_pending_deactivation",
             keyvalues={},
             retcol="user_id",
@@ -1297,15 +1303,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
 
             if not row:
-                raise ThreepidValidationError(400, "Unknown session_id")
+                if self._ignore_unknown_session_error:
+                    # If we need to inhibit the error caused by an incorrect session ID,
+                    # use None as placeholder values for the client secret and the
+                    # validation timestamp.
+                    # It shouldn't be an issue because they're both only checked after
+                    # the token check, which should fail. And if it doesn't for some
+                    # reason, the next check is on the client secret, which is NOT NULL,
+                    # so we don't have to worry about the client secret matching by
+                    # accident.
+                    row = {"client_secret": None, "validated_at": None}
+                else:
+                    raise ThreepidValidationError(400, "Unknown session_id")
+
             retrieved_client_secret = row["client_secret"]
             validated_at = row["validated_at"]
 
-            if retrieved_client_secret != client_secret:
-                raise ThreepidValidationError(
-                    400, "This client_secret does not match the provided session_id"
-                )
-
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="threepid_validation_token",
@@ -1321,6 +1334,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             expires = row["expires"]
             next_link = row["next_link"]
 
+            if retrieved_client_secret != client_secret:
+                raise ThreepidValidationError(
+                    400, "This client_secret does not match the provided session_id"
+                )
+
             # If the session is already validated, no need to revalidate
             if validated_at:
                 return next_link
@@ -1345,43 +1363,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             "validate_threepid_session_txn", validate_threepid_session_txn
         )
 
-    def upsert_threepid_validation_session(
-        self,
-        medium,
-        address,
-        client_secret,
-        send_attempt,
-        session_id,
-        validated_at=None,
-    ):
-        """Upsert a threepid validation session
-        Args:
-            medium (str): The medium of the 3PID
-            address (str): The address of the 3PID
-            client_secret (str): A unique string provided by the client to
-                help identify this validation attempt
-            send_attempt (int): The latest send_attempt on this session
-            session_id (str): The id of this validation session
-            validated_at (int|None): The unix timestamp in milliseconds of
-                when the session was marked as valid
-        """
-        insertion_values = {
-            "medium": medium,
-            "address": address,
-            "client_secret": client_secret,
-        }
-
-        if validated_at:
-            insertion_values["validated_at"] = validated_at
-
-        return self.db_pool.simple_upsert(
-            table="threepid_validation_session",
-            keyvalues={"session_id": session_id},
-            values={"last_send_attempt": send_attempt},
-            insertion_values=insertion_values,
-            desc="upsert_threepid_validation_session",
-        )
-
     def start_or_continue_validation_session(
         self,
         medium,
diff --git a/synapse/storage/databases/main/rejections.py b/synapse/storage/databases/main/rejections.py
index cf9ba51205..1e361aaa9a 100644
--- a/synapse/storage/databases/main/rejections.py
+++ b/synapse/storage/databases/main/rejections.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from synapse.storage._base import SQLBaseStore
 
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
 
 
 class RejectionsStore(SQLBaseStore):
-    def get_rejection_reason(self, event_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_rejection_reason(self, event_id: str) -> Optional[str]:
+        return await self.db_pool.simple_select_one_onecol(
             table="rejections",
             retcol="reason",
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index f4008e6221..97ecdb16e4 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -21,8 +21,6 @@ from abc import abstractmethod
 from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple
 
-from canonicaljson import json
-
 from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -30,15 +28,12 @@ from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchStore
 from synapse.types import ThirdPartyInstanceID
+from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
 
-OpsLevel = collections.namedtuple(
-    "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
 RatelimitOverride = collections.namedtuple(
     "RatelimitOverride", ("messages_per_second", "burst_count")
 )
@@ -78,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
 
         self.config = hs.config
 
-    def get_room(self, room_id):
+    async def get_room(self, room_id: str) -> dict:
         """Retrieve a room.
 
         Args:
-            room_id (str): The ID of the room to retrieve.
+            room_id: The ID of the room to retrieve.
         Returns:
             A dict containing the room information, or None if the room is unknown.
         """
-        return self.db_pool.simple_select_one(
+        return await self.db_pool.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
             retcols=("room_id", "is_public", "creator"),
@@ -335,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
         return ret_val
 
     @cached(max_entries=10000)
-    def is_room_blocked(self, room_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def is_room_blocked(self, room_id: str) -> Optional[bool]:
+        return await self.db_pool.simple_select_one_onecol(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             retcol="1",
@@ -1134,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         },
                     )
 
-            with self._public_room_id_gen.get_next() as next_id:
+            with await self._public_room_id_gen.get_next() as next_id:
                 await self.db_pool.runInteraction(
                     "store_room_txn", store_room_txn, next_id
                 )
@@ -1201,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with self._public_room_id_gen.get_next() as next_id:
+        with await self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1281,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with self._public_room_id_gen.get_next() as next_id:
+        with await self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
@@ -1314,7 +1309,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 "event_id": event_id,
                 "user_id": user_id,
                 "reason": reason,
-                "content": json.dumps(content),
+                "content": json_encoder.encode(content),
             },
             desc="add_event_report",
         )
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index b2fcfc9bfe..161edbeccb 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -17,8 +17,6 @@
 import logging
 from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 lambda: self._known_servers_count,
             )
 
-    @defer.inlineCallbacks
-    def _count_known_servers(self):
+    async def _count_known_servers(self):
         """
         Count the servers that this server knows about.
 
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(query)
             return list(txn)[0][0]
 
-        count = yield self.db_pool.runInteraction("get_known_servers", _transact)
+        count = await self.db_pool.runInteraction("get_known_servers", _transact)
 
         # We always know about ourselves, even if we have nothing in
         # room_memberships (for example, the server is new).
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
-        list_name="event_ids",
-        inlineCallbacks=True,
+        cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
     )
-    def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+    async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_ids: The member event IDs to lookup
 
         Returns:
-            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+            dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
             to `user_id` and ProfileInfo (or None if not join event).
         """
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
@@ -772,13 +767,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
-    def get_membership_from_event_ids(
+    async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
     ) -> List[dict]:
         """Get user_id and membership of a set of event IDs.
         """
 
-        return self.db_pool.simple_select_many_batch(
+        return await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
new file mode 100644
index 0000000000..4cc96a5341
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A table of the IP address and user-agent used to complete each step of a
+-- user-interactive authentication session.
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
+    session_id TEXT NOT NULL,
+    ip TEXT NOT NULL,
+    user_agent TEXT NOT NULL,
+    UNIQUE (session_id, ip, user_agent),
+    FOREIGN KEY (session_id)
+        REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
new file mode 100644
index 0000000000..260b009b48
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/09shadow_ban.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- A shadow-banned user may be told that their requests succeeded when they were
+-- actually ignored.
+ALTER TABLE users ADD COLUMN shadow_banned BOOLEAN;
diff --git a/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
new file mode 100644
index 0000000000..15421b99ac
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/13remove_presence_allow_inbound.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This table is no longer used.
+DROP TABLE IF EXISTS presence_allow_inbound;
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 96e0378e50..458f169617 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return event.content.get("canonical_alias")
 
     @cached(max_entries=50000)
-    def _get_state_group_for_event(self, event_id):
-        return self.db_pool.simple_select_one_onecol(
+    async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
+        return await self.db_pool.simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={"event_id": event_id},
             retcol="state_group",
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         cached_method_name="_get_state_group_for_event",
         list_name="event_ids",
         num_args=1,
-        inlineCallbacks=True,
     )
-    def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids):
         """Returns mapping event_id -> state_group
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
             iterable=event_ids,
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 802c9019b9..9fe97af56a 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
 
         return len(rooms_to_work_on)
 
-    def get_stats_positions(self):
+    async def get_stats_positions(self) -> int:
         """
         Returns the stats processor positions.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="stats_incremental_position",
             keyvalues={},
             retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
         return slice_list
 
     @cached()
-    def get_earliest_token_for_stats(self, stats_type, id):
+    async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
         """
         Fetch the "earliest token". This is used by the room stats delta
         processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
         being calculated.
 
         Returns:
-            Deferred[int]
+            The earliest token.
         """
         table, id_col = TYPE_TO_TABLE[stats_type]
 
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             "%s_current" % (table,),
             keyvalues={id_col: id},
             retcol="completed_delta_stream_id",
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index aaf225894e..497f607703 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -39,15 +39,17 @@ what sort order was used:
 import abc
 import logging
 from collections import namedtuple
-from typing import Optional
+from typing import Dict, Iterable, List, Optional, Tuple
 
 from twisted.internet import defer
 
+from synapse.api.filtering import Filter
+from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -68,8 +70,12 @@ _EventDictReturn = namedtuple(
 
 
 def generate_pagination_where_clause(
-    direction, column_names, from_token, to_token, engine
-):
+    direction: str,
+    column_names: Tuple[str, str],
+    from_token: Optional[Tuple[int, int]],
+    to_token: Optional[Tuple[int, int]],
+    engine: BaseDatabaseEngine,
+) -> str:
     """Creates an SQL expression to bound the columns by the pagination
     tokens.
 
@@ -90,21 +96,19 @@ def generate_pagination_where_clause(
           token, but include those that match the to token.
 
     Args:
-        direction (str): Whether we're paginating backwards("b") or
-            forwards ("f").
-        column_names (tuple[str, str]): The column names to bound. Must *not*
-            be user defined as these get inserted directly into the SQL
-            statement without escapes.
-        from_token (tuple[int, int]|None): The start point for the pagination.
-            This is an exclusive minimum bound if direction is "f", and an
-            inclusive maximum bound if direction is "b".
-        to_token (tuple[int, int]|None): The endpoint point for the pagination.
-            This is an inclusive maximum bound if direction is "f", and an
-            exclusive minimum bound if direction is "b".
+        direction: Whether we're paginating backwards("b") or forwards ("f").
+        column_names: The column names to bound. Must *not* be user defined as
+            these get inserted directly into the SQL statement without escapes.
+        from_token: The start point for the pagination. This is an exclusive
+            minimum bound if direction is "f", and an inclusive maximum bound if
+            direction is "b".
+        to_token: The endpoint point for the pagination. This is an inclusive
+            maximum bound if direction is "f", and an exclusive minimum bound if
+            direction is "b".
         engine: The database engine to generate the clauses for
 
     Returns:
-        str: The sql expression
+        The sql expression
     """
     assert direction in ("b", "f")
 
@@ -132,7 +136,12 @@ def generate_pagination_where_clause(
     return " AND ".join(where_clause)
 
 
-def _make_generic_sql_bound(bound, column_names, values, engine):
+def _make_generic_sql_bound(
+    bound: str,
+    column_names: Tuple[str, str],
+    values: Tuple[Optional[int], int],
+    engine: BaseDatabaseEngine,
+) -> str:
     """Create an SQL expression that bounds the given column names by the
     values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
 
@@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
     out manually.
 
     Args:
-        bound (str): The comparison operator to use. One of ">", "<", ">=",
+        bound: The comparison operator to use. One of ">", "<", ">=",
             "<=", where the values are on the left and columns on the right.
-        names (tuple[str, str]): The column names. Must *not* be user defined
+        names: The column names. Must *not* be user defined
             as these get inserted directly into the SQL statement without
             escapes.
-        values (tuple[int|None, int]): The values to bound the columns by. If
+        values: The values to bound the columns by. If
             the first value is None then only creates a bound on the second
             column.
         engine: The database engine to generate the SQL for
 
     Returns:
-        str
+        The SQL statement
     """
 
     assert bound in (">", "<", ">=", "<=")
@@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
     )
 
 
-def filter_to_clause(event_filter):
+def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
     # "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     def get_room_min_stream_ordering(self):
         raise NotImplementedError()
 
-    @defer.inlineCallbacks
-    def get_room_events_stream_for_rooms(
-        self, room_ids, from_key, to_key, limit=0, order="DESC"
-    ):
+    async def get_room_events_stream_for_rooms(
+        self,
+        room_ids: Iterable[str],
+        from_key: str,
+        to_key: str,
+        limit: int = 0,
+        order: str = "DESC",
+    ) -> Dict[str, Tuple[List[EventBase], str]]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
-            room_id (str)
-            from_key (str): Token from which no events are returned before
-            to_key (str): Token from which no events are returned after. (This
+            room_ids
+            from_key: Token from which no events are returned before
+            to_key: Token from which no events are returned after. (This
                 is typically the current stream token)
-            limit (int): Maximum number of events to return
-            order (str): Either "DESC" or "ASC". Determines which events are
+            limit: Maximum number of events to return
+            order: Either "DESC" or "ASC". Determines which events are
                 returned when the result is limited. If "DESC" then the most
                 recent `limit` events are returned, otherwise returns the
                 oldest `limit` events.
 
         Returns:
-            Deferred[dict[str,tuple[list[FrozenEvent], str]]]
-                A map from room id to a tuple containing:
-                    - list of recent events in the room
-                    - stream ordering key for the start of the chunk of events returned.
+            A map from room id to a tuple containing:
+                - list of recent events in the room
+                - stream ordering key for the start of the chunk of events returned.
         """
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
 
-        room_ids = yield self._events_stream_cache.get_entities_changed(
-            room_ids, from_id
-        )
+        room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
 
         if not room_ids:
             return {}
@@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         results = {}
         room_ids = list(room_ids)
         for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
-            res = yield make_deferred_yieldable(
+            res = await make_deferred_yieldable(
                 defer.gatherResults(
                     [
                         run_in_background(
@@ -361,28 +371,30 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             if self._events_stream_cache.has_entity_changed(room_id, from_key)
         }
 
-    @defer.inlineCallbacks
-    def get_room_events_stream_for_room(
-        self, room_id, from_key, to_key, limit=0, order="DESC"
-    ):
-
+    async def get_room_events_stream_for_room(
+        self,
+        room_id: str,
+        from_key: str,
+        to_key: str,
+        limit: int = 0,
+        order: str = "DESC",
+    ) -> Tuple[List[EventBase], str]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
-            room_id (str)
-            from_key (str): Token from which no events are returned before
-            to_key (str): Token from which no events are returned after. (This
+            room_id
+            from_key: Token from which no events are returned before
+            to_key: Token from which no events are returned after. (This
                 is typically the current stream token)
-            limit (int): Maximum number of events to return
-            order (str): Either "DESC" or "ASC". Determines which events are
+            limit: Maximum number of events to return
+            order: Either "DESC" or "ASC". Determines which events are
                 returned when the result is limited. If "DESC" then the most
                 recent `limit` events are returned, otherwise returns the
                 oldest `limit` events.
 
         Returns:
-            Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
-            events (in ascending order) and the token from the start of
-            the chunk of events returned.
+            The list of events (in ascending order) and the token from the start
+            of the chunk of events returned.
         """
         if from_key == to_key:
             return [], from_key
@@ -390,9 +402,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
-        has_changed = yield self._events_stream_cache.has_entity_changed(
-            room_id, from_id
-        )
+        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
 
         if not has_changed:
             return [], from_key
@@ -410,9 +420,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
             return rows
 
-        rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f)
+        rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
 
-        ret = yield self.get_events_as_list(
+        ret = await self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
@@ -430,8 +440,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
-    @defer.inlineCallbacks
-    def get_membership_changes_for_user(self, user_id, from_key, to_key):
+    async def get_membership_changes_for_user(self, user_id, from_key, to_key):
         from_id = RoomStreamToken.parse_stream_token(from_key).stream
         to_id = RoomStreamToken.parse_stream_token(to_key).stream
 
@@ -460,9 +469,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return rows
 
-        rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f)
+        rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
 
-        ret = yield self.get_events_as_list(
+        ret = await self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
@@ -470,27 +479,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_recent_events_for_room(self, room_id, limit, end_token):
+    async def get_recent_events_for_room(
+        self, room_id: str, limit: int, end_token: str
+    ) -> Tuple[List[EventBase], str]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
-            room_id (str)
-            limit (int)
-            end_token (str): The stream token representing now.
+            room_id
+            limit
+            end_token: The stream token representing now.
 
         Returns:
-            Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
-            events and a token pointing to the start of the returned
-            events.
-            The events returned are in ascending order.
+            A list of events and a token pointing to the start of the returned
+            events. The events returned are in ascending order.
         """
 
-        rows, token = yield self.get_recent_event_ids_for_room(
+        rows, token = await self.get_recent_event_ids_for_room(
             room_id, limit, end_token
         )
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
@@ -498,20 +506,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return (events, token)
 
-    @defer.inlineCallbacks
-    def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+    async def get_recent_event_ids_for_room(
+        self, room_id: str, limit: int, end_token: str
+    ) -> Tuple[List[_EventDictReturn], str]:
         """Get the most recent events in the room in topological ordering.
 
         Args:
-            room_id (str)
-            limit (int)
-            end_token (str): The stream token representing now.
+            room_id
+            limit
+            end_token: The stream token representing now.
 
         Returns:
-            Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
-            _EventDictReturn and a token pointing to the start of the returned
-            events.
-            The events returned are in ascending order.
+            A list of _EventDictReturn and a token pointing to the start of the
+            returned events. The events returned are in ascending order.
         """
         # Allow a zero limit here, and no-op.
         if limit == 0:
@@ -519,7 +526,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         end_token = RoomStreamToken.parse(end_token)
 
-        rows, token = yield self.db_pool.runInteraction(
+        rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
@@ -532,12 +539,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, token
 
-    def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+    def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
         """Gets details of the first event in a room at or before a stream ordering
 
         Args:
-            room_id (str):
-            stream_ordering (int):
+            room_id:
+            stream_ordering:
 
         Returns:
             Deferred[(int, int, str)]:
@@ -574,55 +581,67 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             )
             return "t%d-%d" % (topo, token)
 
-    def get_stream_token_for_event(self, event_id):
-        """The stream token for an event
+    async def get_stream_id_for_event(self, event_id: str) -> int:
+        """The stream ID for an event
         Args:
-            event_id(str): The id of the event to look up a stream token for.
+            event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A deferred "s%d" stream token.
+            A stream ID.
         """
-        return self.db_pool.simple_select_one_onecol(
+        return await self.db_pool.simple_select_one_onecol(
             table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
-        ).addCallback(lambda row: "s%d" % (row,))
+        )
 
-    def get_topological_token_for_event(self, event_id):
+    async def get_stream_token_for_event(self, event_id: str) -> str:
         """The stream token for an event
         Args:
-            event_id(str): The id of the event to look up a stream token for.
+            event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A deferred "t%d-%d" topological token.
+            A "s%d" stream token.
         """
-        return self.db_pool.simple_select_one(
+        stream_id = await self.get_stream_id_for_event(event_id)
+        return "s%d" % (stream_id,)
+
+    async def get_topological_token_for_event(self, event_id: str) -> str:
+        """The stream token for an event
+        Args:
+            event_id: The id of the event to look up a stream token for.
+        Raises:
+            StoreError if the event wasn't in the database.
+        Returns:
+            A "t%d-%d" topological token.
+        """
+        row = await self.db_pool.simple_select_one(
             table="events",
             keyvalues={"event_id": event_id},
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
-        ).addCallback(
-            lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
         )
+        return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
 
-    def get_max_topological_token(self, room_id, stream_key):
+    async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
         """Get the max topological token in a room before the given stream
         ordering.
 
         Args:
-            room_id (str)
-            stream_key (int)
+            room_id
+            stream_key
 
         Returns:
-            Deferred[int]
+            The maximum topological token.
         """
         sql = (
             "SELECT coalesce(max(topological_ordering), 0) FROM events"
             " WHERE room_id = ? AND stream_ordering < ?"
         )
-        return self.db_pool.execute(
+        row = await self.db_pool.execute(
             "get_max_topological_token", None, sql, room_id, stream_key
-        ).addCallback(lambda r: r[0][0] if r else 0)
+        )
+        return row[0][0] if row else 0
 
     def _get_max_topological_txn(self, txn, room_id):
         txn.execute(
@@ -634,16 +653,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return rows[0][0] if rows else 0
 
     @staticmethod
-    def _set_before_and_after(events, rows, topo_order=True):
+    def _set_before_and_after(
+        events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
+    ):
         """Inserts ordering information to events' internal metadata from
         the DB rows.
 
         Args:
-            events (list[FrozenEvent])
-            rows (list[_EventDictReturn])
-            topo_order (bool): Whether the events were ordered topologically
-                or by stream ordering. If true then all rows should have a non
-                null topological_ordering.
+            events
+            rows
+            topo_order: Whether the events were ordered topologically or by stream
+                ordering. If true then all rows should have a non null
+                topological_ordering.
         """
         for event, row in zip(events, rows):
             stream = row.stream_ordering
@@ -656,25 +677,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             internal.after = str(RoomStreamToken(topo, stream))
             internal.order = (int(topo) if topo else 0, int(stream))
 
-    @defer.inlineCallbacks
-    def get_events_around(
-        self, room_id, event_id, before_limit, after_limit, event_filter=None
-    ):
+    async def get_events_around(
+        self,
+        room_id: str,
+        event_id: str,
+        before_limit: int,
+        after_limit: int,
+        event_filter: Optional[Filter] = None,
+    ) -> dict:
         """Retrieve events and pagination tokens around a given event in a
         room.
-
-        Args:
-            room_id (str)
-            event_id (str)
-            before_limit (int)
-            after_limit (int)
-            event_filter (Filter|None)
-
-        Returns:
-            dict
         """
 
-        results = yield self.db_pool.runInteraction(
+        results = await self.db_pool.runInteraction(
             "get_events_around",
             self._get_events_around_txn,
             room_id,
@@ -684,11 +699,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             event_filter,
         )
 
-        events_before = yield self.get_events_as_list(
+        events_before = await self.get_events_as_list(
             list(results["before"]["event_ids"]), get_prev_content=True
         )
 
-        events_after = yield self.get_events_as_list(
+        events_after = await self.get_events_as_list(
             list(results["after"]["event_ids"]), get_prev_content=True
         )
 
@@ -700,17 +715,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         }
 
     def _get_events_around_txn(
-        self, txn, room_id, event_id, before_limit, after_limit, event_filter
-    ):
+        self,
+        txn,
+        room_id: str,
+        event_id: str,
+        before_limit: int,
+        after_limit: int,
+        event_filter: Optional[Filter],
+    ) -> dict:
         """Retrieves event_ids and pagination tokens around a given event in a
         room.
 
         Args:
-            room_id (str)
-            event_id (str)
-            before_limit (int)
-            after_limit (int)
-            event_filter (Filter|None)
+            room_id
+            event_id
+            before_limit
+            after_limit
+            event_filter
 
         Returns:
             dict
@@ -758,22 +779,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "after": {"event_ids": events_after, "token": end_token},
         }
 
-    @defer.inlineCallbacks
-    def get_all_new_events_stream(self, from_id, current_id, limit):
+    async def get_all_new_events_stream(
+        self, from_id: int, current_id: int, limit: int
+    ) -> Tuple[int, List[EventBase]]:
         """Get all new events
 
          Returns all events with from_id < stream_ordering <= current_id.
 
          Args:
-             from_id (int):  the stream_ordering of the last event we processed
-             current_id (int):  the stream_ordering of the most recently processed event
-             limit (int): the maximum number of events to return
+             from_id:  the stream_ordering of the last event we processed
+             current_id:  the stream_ordering of the most recently processed event
+             limit: the maximum number of events to return
 
          Returns:
-             Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where
-             `next_id` is the next value to pass as `from_id` (it will either be the
-             stream_ordering of the last returned event, or, if fewer than `limit` events
-             were found, `current_id`.
+             A tuple of (next_id, events), where `next_id` is the next value to
+             pass as `from_id` (it will either be the stream_ordering of the
+             last returned event, or, if fewer than `limit` events were found,
+             the `current_id`).
          """
 
         def get_all_new_events_stream_txn(txn):
@@ -795,11 +817,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.db_pool.runInteraction(
+        upper_bound, event_ids = await self.db_pool.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
-        events = yield self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(event_ids)
 
         return upper_bound, events
 
@@ -817,21 +839,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             desc="get_federation_out_pos",
         )
 
-    async def update_federation_out_pos(self, typ, stream_id):
+    async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
         if self._need_to_reset_federation_stream_positions:
             await self.db_pool.runInteraction(
                 "_reset_federation_positions_txn", self._reset_federation_positions_txn
             )
             self._need_to_reset_federation_stream_positions = False
 
-        return await self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="federation_stream_position",
             keyvalues={"type": typ, "instance_name": self._instance_name},
             updatevalues={"stream_id": stream_id},
             desc="update_federation_out_pos",
         )
 
-    def _reset_federation_positions_txn(self, txn):
+    def _reset_federation_positions_txn(self, txn) -> None:
         """Fiddles with the `federation_stream_position` table to make it match
         the configured federation sender instances during start up.
         """
@@ -892,39 +914,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 values={"stream_id": stream_id},
             )
 
-    def has_room_changed_since(self, room_id, stream_id):
+    def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
         return self._events_stream_cache.has_entity_changed(room_id, stream_id)
 
     def _paginate_room_events_txn(
         self,
         txn,
-        room_id,
-        from_token,
-        to_token=None,
-        direction="b",
-        limit=-1,
-        event_filter=None,
-    ):
+        room_id: str,
+        from_token: RoomStreamToken,
+        to_token: Optional[RoomStreamToken] = None,
+        direction: str = "b",
+        limit: int = -1,
+        event_filter: Optional[Filter] = None,
+    ) -> Tuple[List[_EventDictReturn], str]:
         """Returns list of events before or after a given token.
 
         Args:
             txn
-            room_id (str)
-            from_token (RoomStreamToken): The token used to stream from
-            to_token (RoomStreamToken|None): A token which if given limits the
-                results to only those before
-            direction(char): Either 'b' or 'f' to indicate whether we are
-                paginating forwards or backwards from `from_key`.
-            limit (int): The maximum number of events to return.
-            event_filter (Filter|None): If provided filters the events to
+            room_id
+            from_token: The token used to stream from
+            to_token: A token which if given limits the results to only those before
+            direction: Either 'b' or 'f' to indicate whether we are paginating
+                forwards or backwards from `from_key`.
+            limit: The maximum number of events to return.
+            event_filter: If provided filters the events to
                 those that match the filter.
 
         Returns:
-            Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
-            as a list of _EventDictReturn and a token that points to the end
-            of the result set. If no events are returned then the end of the
-            stream has been reached (i.e. there are no events between
-            `from_token` and `to_token`), or `limit` is zero.
+            A list of _EventDictReturn and a token that points to the end of the
+            result set. If no events are returned then the end of the stream has
+            been reached (i.e. there are no events between `from_token` and
+            `to_token`), or `limit` is zero.
         """
 
         assert int(limit) >= 0
@@ -1008,35 +1028,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return rows, str(next_token)
 
-    @defer.inlineCallbacks
-    def paginate_room_events(
-        self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
-    ):
+    async def paginate_room_events(
+        self,
+        room_id: str,
+        from_key: str,
+        to_key: Optional[str] = None,
+        direction: str = "b",
+        limit: int = -1,
+        event_filter: Optional[Filter] = None,
+    ) -> Tuple[List[EventBase], str]:
         """Returns list of events before or after a given token.
 
         Args:
-            room_id (str)
-            from_key (str): The token used to stream from
-            to_key (str|None): A token which if given limits the results to
-                only those before
-            direction(char): Either 'b' or 'f' to indicate whether we are
-                paginating forwards or backwards from `from_key`.
-            limit (int): The maximum number of events to return.
-            event_filter (Filter|None): If provided filters the events to
-                those that match the filter.
+            room_id
+            from_key: The token used to stream from
+            to_key: A token which if given limits the results to only those before
+            direction: Either 'b' or 'f' to indicate whether we are paginating
+                forwards or backwards from `from_key`.
+            limit: The maximum number of events to return.
+            event_filter: If provided filters the events to those that match the filter.
 
         Returns:
-            tuple[list[FrozenEvent], str]: Returns the results as a list of
-            events and a token that points to the end of the result set. If no
-            events are returned then the end of the stream has been reached
-            (i.e. there are no events between `from_key` and `to_key`).
+            The results as a list of events and a token that points to the end
+            of the result set. If no events are returned then the end of the
+            stream has been reached (i.e. there are no events between `from_key`
+            and `to_key`).
         """
 
         from_key = RoomStreamToken.parse(from_key)
         if to_key:
             to_key = RoomStreamToken.parse(to_key)
 
-        rows, token = yield self.db_pool.runInteraction(
+        rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
@@ -1047,7 +1070,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             event_filter,
         )
 
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
         )
 
@@ -1057,8 +1080,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
 
 class StreamStore(StreamWorkerStore):
-    def get_room_max_stream_ordering(self):
+    def get_room_max_stream_ordering(self) -> int:
         return self._stream_id_gen.get_current_token()
 
-    def get_room_min_stream_ordering(self):
+    def get_room_min_stream_ordering(self) -> int:
         return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index e4e0a0c433..0c34bbf21a 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -17,11 +17,10 @@
 import logging
 from typing import Dict, List, Tuple
 
-from canonicaljson import json
-
 from synapse.storage._base import db_to_json
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.types import JsonDict
+from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
                 txn.execute(sql, (user_id, room_id))
                 tags = []
                 for tag, content in txn:
-                    tags.append(json.dumps(tag) + ":" + content)
+                    tags.append(json_encoder.encode(tag) + ":" + content)
                 tag_json = "{" + ",".join(tags) + "}"
                 results.append((stream_id, (user_id, room_id, tag_json)))
 
@@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore):
         Returns:
             The next account data ID.
         """
-        content_json = json.dumps(content)
+        content_json = json_encoder.encode(content)
 
         def add_tag_txn(txn, next_id):
             self.db_pool.simple_upsert_txn(
@@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with self._account_data_id_gen.get_next() as next_id:
+        with await self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..9eef8e57c5 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -12,15 +12,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import attr
-from canonicaljson import json
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict
-from synapse.util import stringutils as stringutils
+from synapse.util import json_encoder, stringutils
 
 
 @attr.s
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
             StoreError if a unique session ID cannot be generated.
         """
         # The clientdict gets stored as JSON.
-        clientdict_json = json.dumps(clientdict)
+        clientdict_json = json_encoder.encode(clientdict)
 
         # autogen a session ID and try to create it. We may clash, so just
         # try a few times till one goes through, giving up eventually.
@@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore):
             await self.db_pool.simple_upsert(
                 table="ui_auth_sessions_credentials",
                 keyvalues={"session_id": session_id, "stage_type": stage_type},
-                values={"result": json.dumps(result)},
+                values={"result": json_encoder.encode(result)},
                 desc="mark_ui_auth_stage_complete",
             )
         except self.db_pool.engine.module.IntegrityError:
@@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore):
                 The dictionary from the client root level, not the 'auth' key.
         """
         # The clientdict gets stored as JSON.
-        clientdict_json = json.dumps(clientdict)
+        clientdict_json = json_encoder.encode(clientdict)
 
         await self.db_pool.simple_update_one(
             table="ui_auth_sessions",
@@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore):
             value,
         )
 
-    def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+    def _set_ui_auth_session_data_txn(
+        self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+    ):
         # Get the current value.
         result = self.db_pool.simple_select_one_txn(
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
             retcols=("serverdict",),
-        )
+        )  # type: Dict[str, Any]  # type: ignore
 
         # Update it and add it back to the database.
         serverdict = db_to_json(result["serverdict"])
@@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore):
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
-            updatevalues={"serverdict": json.dumps(serverdict)},
+            updatevalues={"serverdict": json_encoder.encode(serverdict)},
         )
 
     async def get_ui_auth_session_data(
@@ -258,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
 
         return serverdict.get(key, default)
 
+    async def add_user_agent_ip_to_ui_auth_session(
+        self, session_id: str, user_agent: str, ip: str,
+    ):
+        """Add the given user agent / IP to the tracking table
+        """
+        await self.db_pool.simple_upsert(
+            table="ui_auth_sessions_ips",
+            keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
+            values={},
+            desc="add_user_agent_ip_to_ui_auth_session",
+        )
+
+    async def get_user_agents_ips_to_ui_auth_session(
+        self, session_id: str,
+    ) -> List[Tuple[str, str]]:
+        """Get the given user agents / IPs used during the ui auth process
+
+        Returns:
+            List of user_agent/ip pairs
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="ui_auth_sessions_ips",
+            keyvalues={"session_id": session_id},
+            retcols=("user_agent", "ip"),
+            desc="get_user_agents_ips_to_ui_auth_session",
+        )
+        return [(row["user_agent"], row["ip"]) for row in rows]
+
 
 class UIAuthStore(UIAuthWorkerStore):
     def delete_old_ui_auth_sessions(self, expiration_time: int):
@@ -275,12 +305,23 @@ class UIAuthStore(UIAuthWorkerStore):
             expiration_time,
         )
 
-    def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+    def _delete_old_ui_auth_sessions_txn(
+        self, txn: LoggingTransaction, expiration_time: int
+    ):
         # Get the expired sessions.
         sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
         txn.execute(sql, [expiration_time])
         session_ids = [r[0] for r in txn.fetchall()]
 
+        # Delete the corresponding IP/user agents.
+        self.db_pool.simple_delete_many_txn(
+            txn,
+            table="ui_auth_sessions_ips",
+            column="session_id",
+            iterable=session_ids,
+            keyvalues={},
+        )
+
         # Delete the corresponding completed credentials.
         self.db_pool.simple_delete_many_txn(
             txn,
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..20cbcd851c 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -15,6 +15,7 @@
 
 import logging
 import re
+from typing import Any, Dict, Optional
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         )
 
     @cached()
-    def get_user_in_directory(self, user_id):
-        return self.db_pool.simple_select_one(
+    async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
+        return await self.db_pool.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
             retcols=("display_name", "avatar_url"),
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
-    def get_user_directory_stream_pos(self):
-        return self.db_pool.simple_select_one_onecol(
+    async def get_user_directory_stream_pos(self) -> int:
+        return await self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
             retcol="stream_id",
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index ab6cb2c1f6..e3547e53b3 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -13,35 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import operator
-
 from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
 
 class UserErasureWorkerStore(SQLBaseStore):
     @cached()
-    def is_user_erased(self, user_id):
+    async def is_user_erased(self, user_id: str) -> bool:
         """
         Check if the given user id has requested erasure
 
         Args:
-            user_id (str): full user id to check
+            user_id: full user id to check
 
         Returns:
-            Deferred[bool]: True if the user has requested erasure
+            True if the user has requested erasure
         """
-        return self.db_pool.simple_select_onecol(
+        result = await self.db_pool.simple_select_onecol(
             table="erased_users",
             keyvalues={"user_id": user_id},
             retcol="1",
             desc="is_user_erased",
-        ).addCallback(operator.truth)
+        )
+        return bool(result)
 
-    @cachedList(
-        cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
-    )
-    def are_users_erased(self, user_ids):
+    @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
+    async def are_users_erased(self, user_ids):
         """
         Checks which users in a list have requested erasure
 
@@ -49,14 +46,14 @@ class UserErasureWorkerStore(SQLBaseStore):
             user_ids (iterable[str]): full user id to check
 
         Returns:
-            Deferred[dict[str, bool]]:
+            dict[str, bool]:
                 for each user, whether the user has requested erasure.
         """
         # this serves the dual purpose of (a) making sure we can do len and
         # iterate it multiple times, and (b) avoiding duplicates.
         user_ids = tuple(set(user_ids))
 
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="erased_users",
             column="user_id",
             iterable=user_ids,
@@ -65,8 +62,7 @@ class UserErasureWorkerStore(SQLBaseStore):
         )
         erased_users = {row["user_id"] for row in rows}
 
-        res = {u: u in erased_users for u in user_ids}
-        return res
+        return {u: u in erased_users for u in user_ids}
 
 
 class UserErasureStore(UserErasureWorkerStore):
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
deleted file mode 100644
index 18a462f0ee..0000000000
--- a/synapse/storage/presence.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import namedtuple
-
-from synapse.api.constants import PresenceState
-
-
-class UserPresenceState(
-    namedtuple(
-        "UserPresenceState",
-        (
-            "user_id",
-            "state",
-            "last_active_ts",
-            "last_federation_update_ts",
-            "last_user_sync_ts",
-            "status_msg",
-            "currently_active",
-        ),
-    )
-):
-    """Represents the current presence state of the user.
-
-    user_id (str)
-    last_active (int): Time in msec that the user last interacted with server.
-    last_federation_update (int): Time in msec since either a) we sent a presence
-        update to other servers or b) we received a presence update, depending
-        on if is a local user or not.
-    last_user_sync (int): Time in msec that the user last *completed* a sync
-        (or event stream).
-    status_msg (str): User set status message.
-    """
-
-    def as_dict(self):
-        return dict(self._asdict())
-
-    @staticmethod
-    def from_dict(d):
-        return UserPresenceState(**d)
-
-    def copy_and_replace(self, **kwargs):
-        return self._replace(**kwargs)
-
-    @classmethod
-    def default(cls, user_id):
-        """Returns a default presence state.
-        """
-        return cls(
-            user_id=user_id,
-            state=PresenceState.OFFLINE,
-            last_active_ts=0,
-            last_federation_update_ts=0,
-            last_user_sync_ts=0,
-            status_msg=None,
-            currently_active=False,
-        )
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e2ddd01290..5b07847773 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,9 +14,10 @@
 # limitations under the License.
 
 import contextlib
+import heapq
 import threading
 from collections import deque
-from typing import Dict, Set, Tuple
+from typing import Dict, List, Set
 
 from typing_extensions import Deque
 
@@ -80,7 +81,7 @@ class StreamIdGenerator(object):
             upwards, -1 to grow downwards.
 
     Usage:
-        with stream_id_gen.get_next() as stream_id:
+        with await stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -95,10 +96,10 @@ class StreamIdGenerator(object):
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    def get_next(self):
+    async def get_next(self):
         """
         Usage:
-            with stream_id_gen.get_next() as stream_id:
+            with await stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -117,10 +118,10 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_next_mult(self, n):
+    async def get_next_mult(self, n):
         """
         Usage:
-            with stream_id_gen.get_next(n) as stream_ids:
+            with await stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -158,63 +159,13 @@ class StreamIdGenerator(object):
 
             return self._current
 
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
 
-class ChainedIdGenerator(object):
-    """Used to generate new stream ids where the stream must be kept in sync
-    with another stream. It generates pairs of IDs, the first element is an
-    integer ID for this stream, the second element is the ID for the stream
-    that this stream needs to be kept in sync with."""
-
-    def __init__(self, chained_generator, db_conn, table, column):
-        self.chained_generator = chained_generator
-        self._table = table
-        self._lock = threading.Lock()
-        self._current_max = _load_current_id(db_conn, table, column)
-        self._unfinished_ids = deque()  # type: Deque[Tuple[int, int]]
-
-    def get_next(self):
-        """
-        Usage:
-            with stream_id_gen.get_next() as (stream_id, chained_id):
-                # ... persist event ...
-        """
-        with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
-            chained_id = self.chained_generator.get_current_token()
-
-            self._unfinished_ids.append((next_id, chained_id))
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield (next_id, chained_id)
-            finally:
-                with self._lock:
-                    self._unfinished_ids.remove((next_id, chained_id))
-
-        return manager()
-
-    def get_current_token(self):
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-        with self._lock:
-            if self._unfinished_ids:
-                stream_id, chained_id = self._unfinished_ids[0]
-                return stream_id - 1, chained_id
-
-            return self._current_max, self.chained_generator.get_current_token()
-
-    def advance(self, token: int):
-        """Stub implementation for advancing the token when receiving updates
-        over replication; raises an exception as this instance should be the
-        only source of updates.
+        For streams with single writers this is equivalent to
+        `get_current_token`.
         """
-
-        raise Exception(
-            "Attempted to advance token on source for table %r", self._table
-        )
+        return self.get_current_token()
 
 
 class MultiWriterIdGenerator:
@@ -260,6 +211,23 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        # We track the max position where we know everything before has been
+        # persisted. This is done by a) looking at the min across all instances
+        # and b) noting that if we have seen a run of persisted positions
+        # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+        #
+        # Note: There is no guarentee that the IDs generated by the sequence
+        # will be gapless; gaps can form when e.g. a transaction was rolled
+        # back. This means that sometimes we won't be able to skip forward the
+        # position even though everything has been persisted. However, since
+        # gaps should be relatively rare it's still worth doing the book keeping
+        # that allows us to skip forwards when there are gapless runs of
+        # positions.
+        self._persisted_upto_position = (
+            min(self._current_positions.values()) if self._current_positions else 0
+        )
+        self._known_persisted_positions = []  # type: List[int]
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
     def _load_current_ids(
@@ -284,9 +252,12 @@ class MultiWriterIdGenerator:
 
         return current_positions
 
-    def _load_next_id_txn(self, txn):
+    def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
 
+    def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+        return self._sequence_gen.get_next_mult_txn(txn, n)
+
     async def get_next(self):
         """
         Usage:
@@ -298,7 +269,7 @@ class MultiWriterIdGenerator:
         # Assert the fetched ID is actually greater than what we currently
         # believe the ID to be. If not, then the sequence and table have got
         # out of sync somehow.
-        assert self.get_current_token() < next_id
+        assert self.get_current_token_for_writer(self._instance_name) < next_id
 
         with self._lock:
             self._unfinished_ids.add(next_id)
@@ -312,6 +283,34 @@ class MultiWriterIdGenerator:
 
         return manager()
 
+    async def get_next_mult(self, n: int):
+        """
+        Usage:
+            with await stream_id_gen.get_next_mult(5) as stream_ids:
+                # ... persist events ...
+        """
+        next_ids = await self._db.runInteraction(
+            "_load_next_mult_id", self._load_next_mult_id_txn, n
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        assert max(self.get_positions().values(), default=0) < min(next_ids)
+
+        with self._lock:
+            self._unfinished_ids.update(next_ids)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield next_ids
+            finally:
+                for i in next_ids:
+                    self._mark_id_as_finished(i)
+
+        return manager()
+
     def get_next_txn(self, txn: LoggingTransaction):
         """
         Usage:
@@ -344,16 +343,18 @@ class MultiWriterIdGenerator:
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, next_id)
 
-    def get_current_token(self, instance_name: str = None) -> int:
-        """Gets the current position of a named writer (defaults to current
-        instance).
-
-        Returns 0 if we don't have a position for the named writer (likely due
-        to it being a new writer).
+    def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
         """
 
-        if instance_name is None:
-            instance_name = self._instance_name
+        # Currently we don't support this operation, as it's not obvious how to
+        # condense the stream positions of multiple writers into a single int.
+        raise NotImplementedError()
+
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+        """
 
         with self._lock:
             return self._current_positions.get(instance_name, 0)
@@ -374,3 +375,53 @@ class MultiWriterIdGenerator:
             self._current_positions[instance_name] = max(
                 new_id, self._current_positions.get(instance_name, 0)
             )
+
+            self._add_persisted_position(new_id)
+
+    def get_persisted_upto_position(self) -> int:
+        """Get the max position where all previous positions have been
+        persisted.
+
+        Note: In the worst case scenario this will be equal to the minimum
+        position across writers. This means that the returned position here can
+        lag if one writer doesn't write very often.
+        """
+
+        with self._lock:
+            return self._persisted_upto_position
+
+    def _add_persisted_position(self, new_id: int):
+        """Record that we have persisted a position.
+
+        This is used to keep the `_current_positions` up to date.
+        """
+
+        # We require that the lock is locked by caller
+        assert self._lock.locked()
+
+        heapq.heappush(self._known_persisted_positions, new_id)
+
+        # We move the current min position up if the minimum current positions
+        # of all instances is higher (since by definition all positions less
+        # that that have been persisted).
+        min_curr = min(self._current_positions.values())
+        self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+        # We now iterate through the seen positions, discarding those that are
+        # less than the current min positions, and incrementing the min position
+        # if its exactly one greater.
+        #
+        # This is also where we discard items from `_known_persisted_positions`
+        # (to ensure the list doesn't infinitely grow).
+        while self._known_persisted_positions:
+            if self._known_persisted_positions[0] <= self._persisted_upto_position:
+                heapq.heappop(self._known_persisted_positions)
+            elif (
+                self._known_persisted_positions[0] == self._persisted_upto_position + 1
+            ):
+                heapq.heappop(self._known_persisted_positions)
+                self._persisted_upto_position += 1
+            else:
+                # There was a gap in seen positions, so there is nothing more to
+                # do.
+                break
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import abc
 import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
 
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
         return txn.fetchone()[0]
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        txn.execute(
+            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+        )
+        return [i for (i,) in txn]
+
 
 GetFirstCallbackType = Callable[[Cursor], int]