summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/database.py150
-rw-r--r--synapse/storage/databases/__init__.py2
-rw-r--r--synapse/storage/databases/main/__init__.py202
-rw-r--r--synapse/storage/databases/main/account_data.py25
-rw-r--r--synapse/storage/databases/main/appservice.py2
-rw-r--r--synapse/storage/databases/main/censor_events.py6
-rw-r--r--synapse/storage/databases/main/client_ips.py113
-rw-r--r--synapse/storage/databases/main/deviceinbox.py8
-rw-r--r--synapse/storage/databases/main/devices.py94
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py109
-rw-r--r--synapse/storage/databases/main/event_federation.py2
-rw-r--r--synapse/storage/databases/main/event_push_actions.py12
-rw-r--r--synapse/storage/databases/main/events.py38
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py47
-rw-r--r--synapse/storage/databases/main/group_server.py2
-rw-r--r--synapse/storage/databases/main/media_repository.py6
-rw-r--r--synapse/storage/databases/main/metrics.py260
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py119
-rw-r--r--synapse/storage/databases/main/presence.py4
-rw-r--r--synapse/storage/databases/main/purge_events.py8
-rw-r--r--synapse/storage/databases/main/push_rule.py17
-rw-r--r--synapse/storage/databases/main/pusher.py4
-rw-r--r--synapse/storage/databases/main/receipts.py14
-rw-r--r--synapse/storage/databases/main/registration.py243
-rw-r--r--synapse/storage/databases/main/room.py131
-rw-r--r--synapse/storage/databases/main/roommember.py38
-rw-r--r--synapse/storage/databases/main/schema/delta/20/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/25/fts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/27/ts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/30/as_users.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/31/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/31/search_update.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/event_fields.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/remote_media_ts.py5
-rw-r--r--synapse/storage/databases/main/schema/delta/56/event_labels.sql2
-rw-r--r--synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py7
-rw-r--r--synapse/storage/databases/main/schema/delta/57/local_current_membership.py1
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11dehydration.sql20
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11fallback.sql24
-rw-r--r--synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres4
-rw-r--r--synapse/storage/databases/main/schema/delta/58/18stream_positions.sql22
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres25
-rw-r--r--synapse/storage/databases/main/search.py4
-rw-r--r--synapse/storage/databases/main/state.py6
-rw-r--r--synapse/storage/databases/main/stats.py4
-rw-r--r--synapse/storage/databases/main/stream.py350
-rw-r--r--synapse/storage/databases/main/tags.py4
-rw-r--r--synapse/storage/databases/main/transactions.py110
-rw-r--r--synapse/storage/databases/main/ui_auth.py6
-rw-r--r--synapse/storage/databases/main/user_directory.py4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py2
-rw-r--r--synapse/storage/databases/state/bg_updates.py2
-rw-r--r--synapse/storage/databases/state/store.py9
-rw-r--r--synapse/storage/engines/_base.py17
-rw-r--r--synapse/storage/engines/postgres.py10
-rw-r--r--synapse/storage/engines/sqlite.py10
-rw-r--r--synapse/storage/persist_events.py19
-rw-r--r--synapse/storage/prepare_database.py33
-rw-r--r--synapse/storage/roommember.py2
-rw-r--r--synapse/storage/state.py6
-rw-r--r--synapse/storage/types.py6
-rw-r--r--synapse/storage/util/id_generators.py323
-rw-r--r--synapse/storage/util/sequence.py99
64 files changed, 1980 insertions, 866 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 79ec8f119d..0ba3a025cf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
     overload,
 )
 
+import attr
 from prometheus_client import Histogram
 from typing_extensions import Literal
 
@@ -90,13 +91,17 @@ def make_pool(
     return adbapi.ConnectionPool(
         db_config.config["name"],
         cp_reactor=reactor,
-        cp_openfun=engine.on_new_connection,
+        cp_openfun=lambda conn: engine.on_new_connection(
+            LoggingDatabaseConnection(conn, engine, "on_new_connection")
+        ),
         **db_config.config.get("args", {})
     )
 
 
 def make_conn(
-    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
+    default_txn_name: str,
 ) -> Connection:
     """Make a new connection to the database and return it.
 
@@ -109,11 +114,60 @@ def make_conn(
         for k, v in db_config.config.get("args", {}).items()
         if not k.startswith("cp_")
     }
-    db_conn = engine.module.connect(**db_params)
+    native_db_conn = engine.module.connect(**db_params)
+    db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
     engine.on_new_connection(db_conn)
     return db_conn
 
 
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+    """A wrapper around a database connection that returns `LoggingTransaction`
+    as its cursor class.
+
+    This is mainly used on startup to ensure that queries get logged correctly
+    """
+
+    conn = attr.ib(type=Connection)
+    engine = attr.ib(type=BaseDatabaseEngine)
+    default_txn_name = attr.ib(type=str)
+
+    def cursor(
+        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+    ) -> "LoggingTransaction":
+        if not txn_name:
+            txn_name = self.default_txn_name
+
+        return LoggingTransaction(
+            self.conn.cursor(),
+            name=txn_name,
+            database_engine=self.engine,
+            after_callbacks=after_callbacks,
+            exception_callbacks=exception_callbacks,
+        )
+
+    def close(self) -> None:
+        self.conn.close()
+
+    def commit(self) -> None:
+        self.conn.commit()
+
+    def rollback(self, *args, **kwargs) -> None:
+        self.conn.rollback(*args, **kwargs)
+
+    def __enter__(self) -> "Connection":
+        self.conn.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        return self.conn.__exit__(exc_type, exc_value, traceback)
+
+    # Proxy through any unknown lookups to the DB conn class.
+    def __getattr__(self, name):
+        return getattr(self.conn, name)
+
+
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 #
 # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +301,12 @@ class LoggingTransaction:
     def close(self) -> None:
         self.txn.close()
 
+    def __enter__(self) -> "LoggingTransaction":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 class PerformanceCounters:
     def __init__(self):
@@ -395,7 +455,7 @@ class DatabasePool:
 
     def new_transaction(
         self,
-        conn: Connection,
+        conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
@@ -403,6 +463,24 @@ class DatabasePool:
         *args: Any,
         **kwargs: Any
     ) -> R:
+        """Start a new database transaction with the given connection.
+
+        Note: The given func may be called multiple times under certain
+        failure modes. This is normally fine when in a standard transaction,
+        but care must be taken if the connection is in `autocommit` mode that
+        the function will correctly handle being aborted and retried half way
+        through its execution.
+
+        Args:
+            conn
+            desc
+            after_callbacks
+            exception_callbacks
+            func
+            *args
+            **kwargs
+        """
+
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -418,12 +496,10 @@ class DatabasePool:
             i = 0
             N = 5
             while True:
-                cursor = LoggingTransaction(
-                    conn.cursor(),
-                    name,
-                    self.engine,
-                    after_callbacks,
-                    exception_callbacks,
+                cursor = conn.cursor(
+                    txn_name=name,
+                    after_callbacks=after_callbacks,
+                    exception_callbacks=exception_callbacks,
                 )
                 try:
                     r = func(cursor, *args, **kwargs)
@@ -508,7 +584,12 @@ class DatabasePool:
             sql_txn_timer.labels(desc).observe(duration)
 
     async def runInteraction(
-        self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+        self,
+        desc: str,
+        func: "Callable[..., R]",
+        *args: Any,
+        db_autocommit: bool = False,
+        **kwargs: Any
     ) -> R:
         """Starts a transaction on the database and runs a given function
 
@@ -518,6 +599,18 @@ class DatabasePool:
                 database transaction (twisted.enterprise.adbapi.Transaction) as
                 its first argument, followed by `args` and `kwargs`.
 
+            db_autocommit: Whether to run the function in "autocommit" mode,
+                i.e. outside of a transaction. This is useful for transactions
+                that are only a single query.
+
+                Currently, this is only implemented for Postgres. SQLite will still
+                run the function inside a transaction.
+
+                WARNING: This means that if func fails half way through then
+                the changes will *not* be rolled back. `func` may also get
+                called multiple times if the transaction is retried, so must
+                correctly handle that case.
+
             args: positional args to pass to `func`
             kwargs: named args to pass to `func`
 
@@ -538,6 +631,7 @@ class DatabasePool:
                 exception_callbacks,
                 func,
                 *args,
+                db_autocommit=db_autocommit,
                 **kwargs
             )
 
@@ -551,7 +645,11 @@ class DatabasePool:
         return cast(R, result)
 
     async def runWithConnection(
-        self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+        self,
+        func: "Callable[..., R]",
+        *args: Any,
+        db_autocommit: bool = False,
+        **kwargs: Any
     ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
@@ -560,6 +658,9 @@ class DatabasePool:
                 database connection (twisted.enterprise.adbapi.Connection) as
                 its first argument, followed by `args` and `kwargs`.
             args: positional args to pass to `func`
+            db_autocommit: Whether to run the function in "autocommit" mode,
+                i.e. outside of a transaction. This is useful for transaction
+                that are only a single query. Currently only affects postgres.
             kwargs: named args to pass to `func`
 
         Returns:
@@ -575,6 +676,13 @@ class DatabasePool:
         start_time = monotonic_time()
 
         def inner_func(conn, *args, **kwargs):
+            # We shouldn't be in a transaction. If we are then something
+            # somewhere hasn't committed after doing work. (This is likely only
+            # possible during startup, as `run*` will ensure changes are
+            # committed/rolled back before putting the connection back in the
+            # pool).
+            assert not self.engine.in_transaction(conn)
+
             with LoggingContext("runWithConnection", parent_context) as context:
                 sched_duration_sec = monotonic_time() - start_time
                 sql_scheduling_timer.observe(sched_duration_sec)
@@ -584,7 +692,17 @@ class DatabasePool:
                     logger.debug("Reconnecting closed database connection")
                     conn.reconnect()
 
-                return func(conn, *args, **kwargs)
+                try:
+                    if db_autocommit:
+                        self.engine.attempt_to_set_autocommit(conn, True)
+
+                    db_conn = LoggingDatabaseConnection(
+                        conn, self.engine, "runWithConnection"
+                    )
+                    return func(db_conn, *args, **kwargs)
+                finally:
+                    if db_autocommit:
+                        self.engine.attempt_to_set_autocommit(conn, False)
 
         return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
@@ -1621,7 +1739,7 @@ class DatabasePool:
 
     def get_cache_dict(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         entity_column: str,
         stream_column: str,
@@ -1642,9 +1760,7 @@ class DatabasePool:
             "limit": limit,
         }
 
-        sql = self.engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
+        txn = db_conn.cursor(txn_name="get_cache_dict")
         txn.execute(sql, (int(max_value),))
 
         cache = {row[0]: int(row[1]) for row in txn}
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index aa5d490624..0c24325011 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -46,7 +46,7 @@ class Databases:
             db_name = database_config.name
             engine = create_engine(database_config.config)
 
-            with make_conn(database_config, engine) as db_conn:
+            with make_conn(database_config, engine, "startup") as db_conn:
                 logger.info("[database config %r]: Checking database server", db_name)
                 engine.check_database(db_conn)
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 2ae2fbd5d7..9b16f45f3e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,9 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import calendar
 import logging
-import time
 from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
@@ -160,19 +158,25 @@ class DataStore(
         )
 
         if isinstance(self.database_engine, PostgresEngine):
+            # We set the `writers` to an empty list here as we don't care about
+            # missing updates over restarts, as we'll not have anything in our
+            # caches to invalidate. (This reduces the amount of writes to the DB
+            # that happen).
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
-                instance_name="master",
+                stream_name="caches",
+                instance_name=hs.get_instance_name(),
                 table="cache_invalidation_stream_by_instance",
                 instance_column="instance_name",
                 id_column="stream_id",
                 sequence_name="cache_invalidation_stream_seq",
+                writers=[],
             )
         else:
             self._cache_id_gen = None
 
-        super(DataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._presence_on_startup = self._get_active_presence(db_conn)
 
@@ -262,9 +266,6 @@ class DataStore(
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
-        # Used in _generate_user_daily_visits to keep track of progress
-        self._last_user_visit_update = self._get_start_of_day()
-
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
@@ -283,7 +284,6 @@ class DataStore(
             " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
             " WHERE state != ?"
         )
-        sql = self.database_engine.convert_param_style(sql)
 
         txn = db_conn.cursor()
         txn.execute(sql, (PresenceState.OFFLINE,))
@@ -295,192 +295,6 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    async def count_daily_users(self) -> int:
-        """
-        Counts the number of users who used this homeserver in the last 24 hours.
-        """
-        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return await self.db_pool.runInteraction(
-            "count_daily_users", self._count_users, yesterday
-        )
-
-    async def count_monthly_users(self) -> int:
-        """
-        Counts the number of users who used this homeserver in the last 30 days.
-        Note this method is intended for phonehome metrics only and is different
-        from the mau figure in synapse.storage.monthly_active_users which,
-        amongst other things, includes a 3 day grace period before a user counts.
-        """
-        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return await self.db_pool.runInteraction(
-            "count_monthly_users", self._count_users, thirty_days_ago
-        )
-
-    def _count_users(self, txn, time_from):
-        """
-        Returns number of users seen in the past time_from period
-        """
-        sql = """
-            SELECT COALESCE(count(*), 0) FROM (
-                SELECT user_id FROM user_ips
-                WHERE last_seen > ?
-                GROUP BY user_id
-            ) u
-        """
-        txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
-        return count
-
-    async def count_r30_users(self) -> Dict[str, int]:
-        """
-        Counts the number of 30 day retained users, defined as:-
-         * Users who have created their accounts more than 30 days ago
-         * Where last seen at most 30 days ago
-         * Where account creation and last_seen are > 30 days apart
-
-        Returns:
-             A mapping of counts globally as well as broken out by platform.
-        """
-
-        def _count_r30_users(txn):
-            thirty_days_in_secs = 86400 * 30
-            now = int(self._clock.time())
-            thirty_days_ago_in_secs = now - thirty_days_in_secs
-
-            sql = """
-                SELECT platform, COALESCE(count(*), 0) FROM (
-                     SELECT
-                        users.name, platform, users.creation_ts * 1000,
-                        MAX(uip.last_seen)
-                     FROM users
-                     INNER JOIN (
-                         SELECT
-                         user_id,
-                         last_seen,
-                         CASE
-                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
-                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
-                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
-                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
-                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
-                             ELSE 'unknown'
-                         END
-                         AS platform
-                         FROM user_ips
-                     ) uip
-                     ON users.name = uip.user_id
-                     AND users.appservice_id is NULL
-                     AND users.creation_ts < ?
-                     AND uip.last_seen/1000 > ?
-                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                     GROUP BY users.name, platform, users.creation_ts
-                ) u GROUP BY platform
-            """
-
-            results = {}
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            for row in txn:
-                if row[0] == "unknown":
-                    pass
-                results[row[0]] = row[1]
-
-            sql = """
-                SELECT COALESCE(count(*), 0) FROM (
-                    SELECT users.name, users.creation_ts * 1000,
-                                                        MAX(uip.last_seen)
-                    FROM users
-                    INNER JOIN (
-                        SELECT
-                        user_id,
-                        last_seen
-                        FROM user_ips
-                    ) uip
-                    ON users.name = uip.user_id
-                    AND appservice_id is NULL
-                    AND users.creation_ts < ?
-                    AND uip.last_seen/1000 > ?
-                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                    GROUP BY users.name, users.creation_ts
-                ) u
-            """
-
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            (count,) = txn.fetchone()
-            results["all"] = count
-
-            return results
-
-        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
-
-    def _get_start_of_day(self):
-        """
-        Returns millisecond unixtime for start of UTC day.
-        """
-        now = time.gmtime()
-        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
-        return today_start * 1000
-
-    async def generate_user_daily_visits(self) -> None:
-        """
-        Generates daily visit data for use in cohort/ retention analysis
-        """
-
-        def _generate_user_daily_visits(txn):
-            logger.info("Calling _generate_user_daily_visits")
-            today_start = self._get_start_of_day()
-            a_day_in_milliseconds = 24 * 60 * 60 * 1000
-            now = self.clock.time_msec()
-
-            sql = """
-                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
-                    SELECT u.user_id, u.device_id, ?
-                    FROM user_ips AS u
-                    LEFT JOIN (
-                      SELECT user_id, device_id, timestamp FROM user_daily_visits
-                      WHERE timestamp = ?
-                    ) udv
-                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
-                    INNER JOIN users ON users.name=u.user_id
-                    WHERE last_seen > ? AND last_seen <= ?
-                    AND udv.timestamp IS NULL AND users.is_guest=0
-                    AND users.appservice_id IS NULL
-                    GROUP BY u.user_id, u.device_id
-            """
-
-            # This means that the day has rolled over but there could still
-            # be entries from the previous day. There is an edge case
-            # where if the user logs in at 23:59 and overwrites their
-            # last_seen at 00:01 then they will not be counted in the
-            # previous day's stats - it is important that the query is run
-            # often to minimise this case.
-            if today_start > self._last_user_visit_update:
-                yesterday_start = today_start - a_day_in_milliseconds
-                txn.execute(
-                    sql,
-                    (
-                        yesterday_start,
-                        yesterday_start,
-                        self._last_user_visit_update,
-                        today_start,
-                    ),
-                )
-                self._last_user_visit_update = today_start
-
-            txn.execute(
-                sql, (today_start, today_start, self._last_user_visit_update, now)
-            )
-            # Update _last_user_visit_update to now. The reason to do this
-            # rather just clamping to the beginning of the day is to limit
-            # the size of the join - meaning that the query can be run more
-            # frequently
-            self._last_user_visit_update = now
-
-        await self.db_pool.runInteraction(
-            "generate_user_daily_visits", _generate_user_daily_visits
-        )
-
     async def get_users(self) -> List[Dict[str, Any]]:
         """Function to retrieve a list of users in users table.
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 4436b1a83d..49ee23470d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -18,6 +18,7 @@ import abc
 import logging
 from typing import Dict, List, Optional, Tuple
 
+from synapse.api.constants import AccountDataTypes
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
@@ -29,22 +30,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-class AccountDataWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
     """This is an abstract base class where subclasses must implement
     `get_max_account_data_stream_id` which can be called in the initializer.
     """
 
-    # This ABCMeta metaclass ensures that we cannot be instantiated without
-    # the abstract methods being implemented.
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs):
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     @abc.abstractmethod
     def get_max_account_data_stream_id(self):
@@ -293,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore):
         self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
     ) -> bool:
         ignored_account_data = await self.get_global_account_data_by_type_for_user(
-            "m.ignored_user_list",
+            AccountDataTypes.IGNORED_USER_LIST,
             ignorer_user_id,
             on_invalidate=cache_context.invalidate,
         )
         if not ignored_account_data:
             return False
 
-        return ignored_user_id in ignored_account_data.get("ignored_users", {})
+        try:
+            return ignored_user_id in ignored_account_data.get("ignored_users", {})
+        except TypeError:
+            # The type of the ignored_users field is invalid.
+            return False
 
 
 class AccountDataStore(AccountDataWorkerStore):
@@ -315,7 +318,7 @@ class AccountDataStore(AccountDataWorkerStore):
             ],
         )
 
-        super(AccountDataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream id for the private user data stream
@@ -341,7 +344,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -389,7 +392,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 454c0bc50c..85f6b1e3fd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
-        super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_app_services(self):
         return self.services_cache
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..4bb2b9c28c 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -21,8 +21,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.frozenutils import frozendict_json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -105,7 +105,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 and original_event.internal_metadata.is_redacted()
             ):
                 # Redaction was allowed
-                pruned_json = encode_json(
+                pruned_json = frozendict_json_encoder.encode(
                     prune_event_dict(
                         original_event.room_version, original_event.get_dict()
                     )
@@ -171,7 +171,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 return
 
             # Prune the event's dict then convert it to JSON.
-            pruned_json = encode_json(
+            pruned_json = frozendict_json_encoder.encode(
                 prune_event_dict(event.room_version, event.get_dict())
             )
 
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c2fc847fbc..a25a888443 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
 
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             "user_ips_device_index",
@@ -351,16 +351,70 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         return updated
 
 
-class ClientIpStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.user_ips_max_age = hs.config.user_ips_max_age
+
+        if hs.config.run_background_tasks and self.user_ips_max_age:
+            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+    @wrap_as_background_process("prune_old_user_ips")
+    async def _prune_old_user_ips(self):
+        """Removes entries in user IPs older than the configured period.
+        """
+
+        if self.user_ips_max_age is None:
+            # Nothing to do
+            return
+
+        if not await self.db_pool.updates.has_completed_background_update(
+            "devices_last_seen"
+        ):
+            # Only start pruning if we have finished populating the devices
+            # last seen info.
+            return
+
+        # We do a slightly funky SQL delete to ensure we don't try and delete
+        # too much at once (as the table may be very large from before we
+        # started pruning).
+        #
+        # This works by finding the max last_seen that is less than the given
+        # time, but has no more than N rows before it, deleting all rows with
+        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+        # returns exactly one row).
+        sql = """
+            DELETE FROM user_ips
+            WHERE last_seen <= (
+                SELECT COALESCE(MAX(last_seen), -1)
+                FROM (
+                    SELECT last_seen FROM user_ips
+                    WHERE last_seen <= ?
+                    ORDER BY last_seen ASC
+                    LIMIT 5000
+                ) AS u
+            )
+        """
+
+        timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+        def _prune_old_user_ips_txn(txn):
+            txn.execute(sql, (timestamp,))
+
+        await self.db_pool.runInteraction(
+            "_prune_old_user_ips", _prune_old_user_ips_txn
+        )
+
+
+class ClientIpStore(ClientIpWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000
         )
 
-        super(ClientIpStore, self).__init__(database, db_conn, hs)
-
-        self.user_ips_max_age = hs.config.user_ips_max_age
+        super().__init__(database, db_conn, hs)
 
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
         self._batch_row_update = {}
@@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             "before", "shutdown", self._update_client_ips_batch
         )
 
-        if self.user_ips_max_age:
-            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
-
     async def insert_client_ip(
         self, user_id, access_token, ip, user_agent, device_id, now=None
     ):
@@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             }
             for (access_token, ip), (user_agent, last_seen) in results.items()
         ]
-
-    @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
-        """Removes entries in user IPs older than the configured period.
-        """
-
-        if self.user_ips_max_age is None:
-            # Nothing to do
-            return
-
-        if not await self.db_pool.updates.has_completed_background_update(
-            "devices_last_seen"
-        ):
-            # Only start pruning if we have finished populating the devices
-            # last seen info.
-            return
-
-        # We do a slightly funky SQL delete to ensure we don't try and delete
-        # too much at once (as the table may be very large from before we
-        # started pruning).
-        #
-        # This works by finding the max last_seen that is less than the given
-        # time, but has no more than N rows before it, deleting all rows with
-        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
-        # returns exactly one row).
-        sql = """
-            DELETE FROM user_ips
-            WHERE last_seen <= (
-                SELECT COALESCE(MAX(last_seen), -1)
-                FROM (
-                    SELECT last_seen FROM user_ips
-                    WHERE last_seen <= ?
-                    ORDER BY last_seen ASC
-                    LIMIT 5000
-                ) AS u
-            )
-        """
-
-        timestamp = self.clock.time_msec() - self.user_ips_max_age
-
-        def _prune_old_user_ips_txn(txn):
-            txn.execute(sql, (timestamp,))
-
-        await self.db_pool.runInteraction(
-            "_prune_old_user_ips", _prune_old_user_ips_txn
-        )
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 0044433110..d42faa3f1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             "device_inbox_stream_index",
@@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(DeviceInboxStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 306fc6947c..2d0a6408b5 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import Cache, cached, cachedList
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
             THe new stream ID.
         """
 
-        with await self._device_list_id_gen.get_next() as stream_id:
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -698,10 +698,84 @@ class DeviceWorkerStore(SQLBaseStore):
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
 
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        # FIXME: make sure device ID still exists in devices table
+        row = await self.db_pool.simple_select_one(
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            retcols=["device_id", "device_data"],
+            allow_none=True,
+        )
+        return (
+            (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
+        )
+
+    def _store_dehydrated_device_txn(
+        self, txn, user_id: str, device_id: str, device_data: str
+    ) -> Optional[str]:
+        old_device_id = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            retcol="device_id",
+            allow_none=True,
+        )
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            values={"device_id": device_id, "device_data": device_data},
+        )
+        return old_device_id
+
+    async def store_dehydrated_device(
+        self, user_id: str, device_id: str, device_data: JsonDict
+    ) -> Optional[str]:
+        """Store a dehydrated device for a user.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_id: the ID of the dehydrated device
+            device_data: the dehydrated device information
+        Returns:
+            device id of the user's previous dehydrated device, if any
+        """
+        return await self.db_pool.runInteraction(
+            "store_dehydrated_device_txn",
+            self._store_dehydrated_device_txn,
+            user_id,
+            device_id,
+            json_encoder.encode(device_data),
+        )
+
+    async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
+        """Remove a dehydrated device.
+
+        Args:
+            user_id: the user that the dehydrated device belongs to
+            device_id: the ID of the dehydrated device
+        """
+        count = await self.db_pool.simple_delete(
+            "dehydrated_devices",
+            {"user_id": user_id, "device_id": device_id},
+            desc="remove_dehydrated_device",
+        )
+        return count >= 1
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             "device_lists_stream_idx",
@@ -826,7 +900,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(DeviceStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
@@ -837,7 +911,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
 
     async def store_device(
-        self, user_id: str, device_id: str, initial_device_display_name: str
+        self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
     ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
@@ -955,7 +1029,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def update_remote_device_list_cache_entry(
-        self, user_id: str, device_id: str, content: JsonDict, stream_id: int
+        self, user_id: str, device_id: str, content: JsonDict, stream_id: str
     ) -> None:
         """Updates a single device in the cache of a remote user's devicelist.
 
@@ -983,7 +1057,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_id: str,
         content: JsonDict,
-        stream_id: int,
+        stream_id: str,
     ) -> None:
         if content.get("deleted"):
             self.db_pool.simple_delete_txn(
@@ -1093,7 +1167,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
@@ -1108,7 +1182,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c8df0bcb3f..359dc6e968 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
+    async def set_e2e_fallback_keys(
+        self, user_id: str, device_id: str, fallback_keys: JsonDict
+    ) -> None:
+        """Set the user's e2e fallback keys.
+
+        Args:
+            user_id: the user whose keys are being set
+            device_id: the device whose keys are being set
+            fallback_keys: the keys to set.  This is a map from key ID (which is
+                of the form "algorithm:id") to key data.
+        """
+        # fallback_keys will usually only have one item in it, so using a for
+        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+        # FIXME: make sure that only one key per algorithm is uploaded
+        for key_id, fallback_key in fallback_keys.items():
+            algorithm, key_id = key_id.split(":", 1)
+            await self.db_pool.simple_upsert(
+                "e2e_fallback_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm,
+                },
+                values={
+                    "key_id": key_id,
+                    "key_json": json_encoder.encode(fallback_key),
+                    "used": False,
+                },
+                desc="set_e2e_fallback_key",
+            )
+
+    @cached(max_entries=10000)
+    async def get_e2e_unused_fallback_key_types(
+        self, user_id: str, device_id: str
+    ) -> List[str]:
+        """Returns the fallback key types that have an unused key.
+
+        Args:
+            user_id: the user whose keys are being queried
+            device_id: the device whose keys are being queried
+
+        Returns:
+            a list of key types
+        """
+        return await self.db_pool.simple_select_onecol(
+            "e2e_fallback_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+            retcol="algorithm",
+            desc="get_e2e_unused_fallback_key_types",
+        )
+
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
     ) -> Optional[dict]:
@@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                 " LIMIT 1"
             )
+            fallback_sql = (
+                "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " LIMIT 1"
+            )
             result = {}
             delete = []
+            used_fallbacks = []
             for user_id, device_id, algorithm in query_list:
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn:
+                otk_row = txn.fetchone()
+                if otk_row is not None:
+                    key_id, key_json = otk_row
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
+                else:
+                    # no one-time key available, so see if there's a fallback
+                    # key
+                    txn.execute(fallback_sql, (user_id, device_id, algorithm))
+                    fallback_row = txn.fetchone()
+                    if fallback_row is not None:
+                        key_id, key_json, used = fallback_row
+                        device_result[algorithm + ":" + key_id] = key_json
+                        if not used:
+                            used_fallbacks.append(
+                                (user_id, device_id, algorithm, key_id)
+                            )
+
+            # drop any one-time keys that were claimed
             sql = (
                 "DELETE FROM e2e_one_time_keys_json"
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 self._invalidate_cache_and_stream(
                     txn, self.count_e2e_one_time_keys, (user_id, device_id)
                 )
+            # mark fallback keys as used
+            for user_id, device_id, algorithm, key_id in used_fallbacks:
+                self.db_pool.simple_update_txn(
+                    txn,
+                    "e2e_fallback_keys_json",
+                    {
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                        "key_id": key_id,
+                    },
+                    {"used": True},
+                )
+                self._invalidate_cache_and_stream(
+                    txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+                )
+
             return result
 
         return await self.db_pool.runInteraction(
@@ -754,6 +844,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="dehydrated_devices",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="e2e_fallback_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+            )
 
         await self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
@@ -831,7 +934,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key (dict): the key data
         """
 
-        with await self._cross_signing_id_gen.get_next() as stream_id:
+        async with self._cross_signing_id_gen.get_next() as stream_id:
             return await self.db_pool.runInteraction(
                 "add_e2e_cross_signing_key",
                 self._set_e2e_cross_signing_key_txn,
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4c3c162acf..6d3689c09e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore):
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventFederationStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 7805fb814e..80f3b4d740 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -20,7 +20,7 @@ from typing import Dict, List, Optional, Tuple, Union
 import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -68,17 +68,13 @@ def _deserialize_action(actions, is_highlight):
 
 class EventPushActionsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
         self.stream_ordering_month_ago = None
         self.stream_ordering_day_ago = None
 
-        cur = LoggingTransaction(
-            db_conn.cursor(),
-            name="_find_stream_orderings_for_times_txn",
-            database_engine=self.database_engine,
-        )
+        cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
         self._find_stream_orderings_for_times_txn(cur)
         cur.close()
 
@@ -661,7 +657,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             self.EPA_HIGHLIGHT_INDEX,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 9a80f419e3..b4abd961b9 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
 import itertools
 import logging
 from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from prometheus_client import Counter
@@ -52,16 +52,6 @@ event_counter = Counter(
 )
 
 
-def encode_json(json_object):
-    """
-    Encode a Python object as JSON and return it in a Unicode string.
-    """
-    out = frozendict_json_encoder.encode(json_object)
-    if isinstance(out, bytes):
-        out = out.decode("utf8")
-    return out
-
-
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
@@ -156,15 +146,15 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
         # the same time for the same room.
         if backfilled:
-            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
@@ -341,6 +331,10 @@ class PersistEventsStore:
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
+        # stream orderings should have been assigned by now
+        assert min_stream_order
+        assert max_stream_order
+
         self._update_forward_extremities_txn(
             txn,
             new_forward_extremities=new_forward_extremeties,
@@ -743,7 +737,9 @@ class PersistEventsStore:
                     logger.exception("")
                     raise
 
-                metadata_json = encode_json(event.internal_metadata.get_dict())
+                metadata_json = frozendict_json_encoder.encode(
+                    event.internal_metadata.get_dict()
+                )
 
                 sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
                 txn.execute(sql, (metadata_json, event.event_id))
@@ -797,10 +793,10 @@ class PersistEventsStore:
                 {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
-                    "internal_metadata": encode_json(
+                    "internal_metadata": frozendict_json_encoder.encode(
                         event.internal_metadata.get_dict()
                     ),
-                    "json": encode_json(event_dict(event)),
+                    "json": frozendict_json_encoder.encode(event_dict(event)),
                     "format_version": event.format_version,
                 }
                 for event, _ in events_and_contexts
@@ -1108,6 +1104,10 @@ class PersistEventsStore:
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
         """
+
+        def str_or_none(val: Any) -> Optional[str]:
+            return val if isinstance(val, str) else None
+
         self.db_pool.simple_insert_many_txn(
             txn,
             table="room_memberships",
@@ -1118,8 +1118,8 @@ class PersistEventsStore:
                     "sender": event.user_id,
                     "room_id": event.room_id,
                     "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
+                    "display_name": str_or_none(event.content.get("displayname")),
+                    "avatar_url": str_or_none(event.content.get("avatar_url")),
                 }
                 for event in events
             ],
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e53c6373a8..5e4af2eb51 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
     DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 17f5997b89..b7ed8ca6ab 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import division
-
 import itertools
 import logging
 import threading
@@ -76,8 +74,15 @@ class EventRedactBehaviour(Names):
 
 
 class EventsWorkerStore(SQLBaseStore):
+    # Whether to use dedicated DB threads for event fetching. This is only used
+    # if there are multiple DB threads available. When used will lock the DB
+    # thread for periods of time (so unit tests want to disable this when they
+    # run DB transactions on the main thread). See EVENT_QUEUE_* for more
+    # options controlling this.
+    USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
+
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
@@ -85,21 +90,25 @@ class EventsWorkerStore(SQLBaseStore):
             self._stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                stream_name="events",
                 instance_name=hs.get_instance_name(),
                 table="events",
                 instance_column="instance_name",
                 id_column="stream_ordering",
                 sequence_name="events_stream_seq",
+                writers=hs.config.worker.writers.events,
             )
             self._backfill_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                stream_name="backfill",
                 instance_name=hs.get_instance_name(),
                 table="events",
                 instance_column="instance_name",
                 id_column="stream_ordering",
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
+                writers=hs.config.worker.writers.events,
             )
         else:
             # We shouldn't be running in worker mode with SQLite, but its useful
@@ -520,7 +529,11 @@ class EventsWorkerStore(SQLBaseStore):
 
                 if not event_list:
                     single_threaded = self.database_engine.single_threaded
-                    if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+                    if (
+                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                        or single_threaded
+                        or i > EVENT_QUEUE_ITERATIONS
+                    ):
                         self._event_fetch_ongoing -= 1
                         return
                     else:
@@ -710,6 +723,7 @@ class EventsWorkerStore(SQLBaseStore):
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
+            original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
 
             event_map[event_id] = original_ev
 
@@ -777,6 +791,8 @@ class EventsWorkerStore(SQLBaseStore):
 
          * event_id (str)
 
+         * stream_ordering (int): stream ordering for this event
+
          * json (str): json-encoded event structure
 
          * internal_metadata (str): json-encoded internal metadata dict
@@ -809,13 +825,15 @@ class EventsWorkerStore(SQLBaseStore):
             sql = """\
                 SELECT
                   e.event_id,
-                  e.internal_metadata,
-                  e.json,
-                  e.format_version,
+                  e.stream_ordering,
+                  ej.internal_metadata,
+                  ej.json,
+                  ej.format_version,
                   r.room_version,
                   rej.reason
-                FROM event_json as e
-                  LEFT JOIN rooms r USING (room_id)
+                FROM events AS e
+                  JOIN event_json AS ej USING (event_id)
+                  LEFT JOIN rooms r ON r.room_id = e.room_id
                   LEFT JOIN rejections as rej USING (event_id)
                 WHERE """
 
@@ -829,11 +847,12 @@ class EventsWorkerStore(SQLBaseStore):
                 event_id = row[0]
                 event_dict[event_id] = {
                     "event_id": event_id,
-                    "internal_metadata": row[1],
-                    "json": row[2],
-                    "format_version": row[3],
-                    "room_version_id": row[4],
-                    "rejected_reason": row[5],
+                    "stream_ordering": row[1],
+                    "internal_metadata": row[2],
+                    "json": row[3],
+                    "format_version": row[4],
+                    "room_version_id": row[5],
+                    "rejected_reason": row[6],
                     "redactions": [],
                 }
 
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
 
             return next_id
 
-        with await self._group_updates_id_gen.get_next() as next_id:
+        async with self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 1d76c761a6..cc538c5c10 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -24,9 +24,7 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
 
 class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MediaRepositoryBackgroundUpdateStore, self).__init__(
-            database, db_conn, hs
-        )
+        super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
             update_name="local_media_repository_url_idx",
@@ -94,7 +92,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 686052bd83..0acf0617ca 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -12,17 +12,41 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import typing
-from collections import Counter
+import calendar
+import logging
+import time
+from typing import Dict
 
-from synapse.metrics import BucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics import GaugeBucketCollector
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
 
+logger = logging.getLogger(__name__)
+
+# Collect metrics on the number of forward extremities that exist.
+_extremities_collecter = GaugeBucketCollector(
+    "synapse_forward_extremities",
+    "Number of rooms on the server with the given number of forward extremities"
+    " or fewer",
+    buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500],
+)
+
+# we also expose metrics on the "number of excess extremity events", which is
+# (E-1)*N, where E is the number of extremities and N is the number of state
+# events in the room. This is an approximation to the number of state events
+# we could remove from state resolution by reducing the graph to a single
+# forward extremity.
+_excess_state_events_collecter = GaugeBucketCollector(
+    "synapse_excess_extremity_events",
+    "Number of rooms on the server with the given number of excess extremity "
+    "events, or fewer",
+    buckets=[0] + [1 << n for n in range(12)],
+)
+
 
 class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     """Functions to pull various metrics from the DB, for e.g. phone home
@@ -32,40 +56,37 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        # Collect metrics on the number of forward extremities that exist.
-        # Counter of number of extremities to count
-        self._current_forward_extremities_amount = (
-            Counter()
-        )  # type: typing.Counter[int]
-
-        BucketCollector(
-            "synapse_forward_extremities",
-            lambda: self._current_forward_extremities_amount,
-            buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
-        )
-
         # Read the extrems every 60 minutes
-        def read_forward_extremities():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "read_forward_extremities", self._read_forward_extremities
-            )
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
 
-        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+        # Used in _generate_user_daily_visits to keep track of progress
+        self._last_user_visit_update = self._get_start_of_day()
 
+    @wrap_as_background_process("read_forward_extremities")
     async def _read_forward_extremities(self):
         def fetch(txn):
             txn.execute(
                 """
-                select count(*) c from event_forward_extremities
-                group by room_id
+                SELECT t1.c, t2.c
+                FROM (
+                    SELECT room_id, COUNT(*) c FROM event_forward_extremities
+                    GROUP BY room_id
+                ) t1 LEFT JOIN (
+                    SELECT room_id, COUNT(*) c FROM current_state_events
+                    GROUP BY room_id
+                ) t2 ON t1.room_id = t2.room_id
                 """
             )
             return txn.fetchall()
 
         res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
-        self._current_forward_extremities_amount = Counter([x[0] for x in res])
+
+        _extremities_collecter.update_data(x[0] for x in res)
+
+        _excess_state_events_collecter.update_data(
+            (x[0] - 1) * x[1] for x in res if x[1]
+        )
 
     async def count_daily_messages(self):
         """
@@ -120,3 +141,190 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             return count
 
         return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
+
+    async def count_daily_users(self) -> int:
+        """
+        Counts the number of users who used this homeserver in the last 24 hours.
+        """
+        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+        return await self.db_pool.runInteraction(
+            "count_daily_users", self._count_users, yesterday
+        )
+
+    async def count_monthly_users(self) -> int:
+        """
+        Counts the number of users who used this homeserver in the last 30 days.
+        Note this method is intended for phonehome metrics only and is different
+        from the mau figure in synapse.storage.monthly_active_users which,
+        amongst other things, includes a 3 day grace period before a user counts.
+        """
+        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+        return await self.db_pool.runInteraction(
+            "count_monthly_users", self._count_users, thirty_days_ago
+        )
+
+    def _count_users(self, txn, time_from):
+        """
+        Returns number of users seen in the past time_from period
+        """
+        sql = """
+            SELECT COALESCE(count(*), 0) FROM (
+                SELECT user_id FROM user_ips
+                WHERE last_seen > ?
+                GROUP BY user_id
+            ) u
+        """
+        txn.execute(sql, (time_from,))
+        (count,) = txn.fetchone()
+        return count
+
+    async def count_r30_users(self) -> Dict[str, int]:
+        """
+        Counts the number of 30 day retained users, defined as:-
+         * Users who have created their accounts more than 30 days ago
+         * Where last seen at most 30 days ago
+         * Where account creation and last_seen are > 30 days apart
+
+        Returns:
+             A mapping of counts globally as well as broken out by platform.
+        """
+
+        def _count_r30_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+            sql = """
+                SELECT platform, COALESCE(count(*), 0) FROM (
+                     SELECT
+                        users.name, platform, users.creation_ts * 1000,
+                        MAX(uip.last_seen)
+                     FROM users
+                     INNER JOIN (
+                         SELECT
+                         user_id,
+                         last_seen,
+                         CASE
+                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
+                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+                             ELSE 'unknown'
+                         END
+                         AS platform
+                         FROM user_ips
+                     ) uip
+                     ON users.name = uip.user_id
+                     AND users.appservice_id is NULL
+                     AND users.creation_ts < ?
+                     AND uip.last_seen/1000 > ?
+                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                     GROUP BY users.name, platform, users.creation_ts
+                ) u GROUP BY platform
+            """
+
+            results = {}
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            for row in txn:
+                if row[0] == "unknown":
+                    pass
+                results[row[0]] = row[1]
+
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT users.name, users.creation_ts * 1000,
+                                                        MAX(uip.last_seen)
+                    FROM users
+                    INNER JOIN (
+                        SELECT
+                        user_id,
+                        last_seen
+                        FROM user_ips
+                    ) uip
+                    ON users.name = uip.user_id
+                    AND appservice_id is NULL
+                    AND users.creation_ts < ?
+                    AND uip.last_seen/1000 > ?
+                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                    GROUP BY users.name, users.creation_ts
+                ) u
+            """
+
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            (count,) = txn.fetchone()
+            results["all"] = count
+
+            return results
+
+        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+
+    def _get_start_of_day(self):
+        """
+        Returns millisecond unixtime for start of UTC day.
+        """
+        now = time.gmtime()
+        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+        return today_start * 1000
+
+    @wrap_as_background_process("generate_user_daily_visits")
+    async def generate_user_daily_visits(self) -> None:
+        """
+        Generates daily visit data for use in cohort/ retention analysis
+        """
+
+        def _generate_user_daily_visits(txn):
+            logger.info("Calling _generate_user_daily_visits")
+            today_start = self._get_start_of_day()
+            a_day_in_milliseconds = 24 * 60 * 60 * 1000
+            now = self._clock.time_msec()
+
+            sql = """
+                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+                    SELECT u.user_id, u.device_id, ?
+                    FROM user_ips AS u
+                    LEFT JOIN (
+                      SELECT user_id, device_id, timestamp FROM user_daily_visits
+                      WHERE timestamp = ?
+                    ) udv
+                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+                    INNER JOIN users ON users.name=u.user_id
+                    WHERE last_seen > ? AND last_seen <= ?
+                    AND udv.timestamp IS NULL AND users.is_guest=0
+                    AND users.appservice_id IS NULL
+                    GROUP BY u.user_id, u.device_id
+            """
+
+            # This means that the day has rolled over but there could still
+            # be entries from the previous day. There is an edge case
+            # where if the user logs in at 23:59 and overwrites their
+            # last_seen at 00:01 then they will not be counted in the
+            # previous day's stats - it is important that the query is run
+            # often to minimise this case.
+            if today_start > self._last_user_visit_update:
+                yesterday_start = today_start - a_day_in_milliseconds
+                txn.execute(
+                    sql,
+                    (
+                        yesterday_start,
+                        yesterday_start,
+                        self._last_user_visit_update,
+                        today_start,
+                    ),
+                )
+                self._last_user_visit_update = today_start
+
+            txn.execute(
+                sql, (today_start, today_start, self._last_user_visit_update, now)
+            )
+            # Update _last_user_visit_update to now. The reason to do this
+            # rather just clamping to the beginning of the day is to limit
+            # the size of the join - meaning that the query can be run more
+            # frequently
+            self._last_user_visit_update = now
+
+        await self.db_pool.runInteraction(
+            "generate_user_daily_visits", _generate_user_daily_visits
+        )
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 1d793d3deb..c66f558567 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -28,10 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 class MonthlyActiveUsersWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
 
+        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+        self._max_mau_value = hs.config.max_mau_value
+
     @cached(num_args=0)
     async def get_monthly_active_count(self) -> int:
         """Generates current count of monthly active users
@@ -41,7 +44,14 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         """
 
         def _count_users(txn):
-            sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+            # Exclude app service users
+            sql = """
+                SELECT COALESCE(count(*), 0)
+                FROM monthly_active_users
+                    LEFT JOIN users
+                    ON monthly_active_users.user_id=users.name
+                WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
+            """
             txn.execute(sql)
             (count,) = txn.fetchone()
             return count
@@ -117,60 +127,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             desc="user_last_seen_monthly_active",
         )
 
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
-
-        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
-        self._mau_stats_only = hs.config.mau_stats_only
-        self._max_mau_value = hs.config.max_mau_value
-
-        # Do not add more reserved users than the total allowable number
-        # cur = LoggingTransaction(
-        self.db_pool.new_transaction(
-            db_conn,
-            "initialise_mau_threepids",
-            [],
-            [],
-            self._initialise_reserved_users,
-            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
-        )
-
-    def _initialise_reserved_users(self, txn, threepids):
-        """Ensures that reserved threepids are accounted for in the MAU table, should
-        be called on start up.
-
-        Args:
-            txn (cursor):
-            threepids (list[dict]): List of threepid dicts to reserve
-        """
-
-        # XXX what is this function trying to achieve?  It upserts into
-        # monthly_active_users for each *registered* reserved mau user, but why?
-        #
-        #  - shouldn't there already be an entry for each reserved user (at least
-        #    if they have been active recently)?
-        #
-        #  - if it's important that the timestamp is kept up to date, why do we only
-        #    run this at startup?
-
-        for tp in threepids:
-            user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
-
-            if user_id:
-                is_support = self.is_support_user_txn(txn, user_id)
-                if not is_support:
-                    # We do this manually here to avoid hitting #6791
-                    self.db_pool.simple_upsert_txn(
-                        txn,
-                        table="monthly_active_users",
-                        keyvalues={"user_id": user_id},
-                        values={"timestamp": int(self._clock.time_msec())},
-                    )
-            else:
-                logger.warning("mau limit reserved threepid %s not found in db" % tp)
-
     async def reap_monthly_active_users(self):
         """Cleans out monthly active user table to ensure that no stale
         entries exist.
@@ -250,6 +206,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             "reap_monthly_active_users", _reap_users, reserved_users
         )
 
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._mau_stats_only = hs.config.mau_stats_only
+
+        # Do not add more reserved users than the total allowable number
+        self.db_pool.new_transaction(
+            db_conn,
+            "initialise_mau_threepids",
+            [],
+            [],
+            self._initialise_reserved_users,
+            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+        )
+
+    def _initialise_reserved_users(self, txn, threepids):
+        """Ensures that reserved threepids are accounted for in the MAU table, should
+        be called on start up.
+
+        Args:
+            txn (cursor):
+            threepids (list[dict]): List of threepid dicts to reserve
+        """
+
+        # XXX what is this function trying to achieve?  It upserts into
+        # monthly_active_users for each *registered* reserved mau user, but why?
+        #
+        #  - shouldn't there already be an entry for each reserved user (at least
+        #    if they have been active recently)?
+        #
+        #  - if it's important that the timestamp is kept up to date, why do we only
+        #    run this at startup?
+
+        for tp in threepids:
+            user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
+
+            if user_id:
+                is_support = self.is_support_user_txn(txn, user_id)
+                if not is_support:
+                    # We do this manually here to avoid hitting #6791
+                    self.db_pool.simple_upsert_txn(
+                        txn,
+                        table="monthly_active_users",
+                        keyvalues={"user_id": user_id},
+                        values={"timestamp": int(self._clock.time_msec())},
+                    )
+            else:
+                logger.warning("mau limit reserved threepid %s not found in db" % tp)
+
     async def upsert_monthly_active_user(self, user_id: str) -> None:
         """Updates or inserts the user into the monthly active user table, which
         is used to track the current MAU usage of the server
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             await self.db_pool.runInteraction(
                 "update_presence",
                 self._update_presence_txn,
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index d7a03cbf7d..ecfc6717b3 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             The set of state groups that are referenced by deleted events.
         """
 
+        parsed_token = await RoomStreamToken.parse(self, token)
+
         return await self.db_pool.runInteraction(
             "purge_history",
             self._purge_history_txn,
             room_id,
-            token,
+            parsed_token,
             delete_local_events,
         )
 
-    def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
-        token = RoomStreamToken.parse(token_str)
-
+    def _purge_history_txn(self, txn, room_id, token, delete_local_events):
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 9790a31998..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -61,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
     return rules
 
 
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
 class PushRulesWorkerStore(
     ApplicationServiceWorkerStore,
     ReceiptsWorkerStore,
@@ -68,17 +70,14 @@ class PushRulesWorkerStore(
     RoomMemberWorkerStore,
     EventsWorkerStore,
     SQLBaseStore,
+    metaclass=abc.ABCMeta,
 ):
     """This is an abstract base class where subclasses must implement
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    # This ABCMeta metaclass ensures that we cannot be instantiated without
-    # the abstract methods being implemented.
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
             self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -339,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             if before or after:
@@ -586,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -617,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             NotFoundError if the rule does not exist.
         """
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
@@ -755,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
         last_stream_ordering,
         profile_tag="",
     ) -> None:
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
             await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 4a0d5a320e..c79ddff680 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -31,17 +31,15 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)
 
 
-class ReceiptsWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
     """This is an abstract base class where subclasses must implement
     `get_max_receipt_stream_id` which can be called in the initializer.
     """
 
-    # This ABCMeta metaclass ensures that we cannot be instantiated without
-    # the abstract methods being implemented.
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -388,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(ReceiptsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
@@ -526,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "insert_receipt_conv", graph_to_linear
             )
 
-        with await self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 01f20c03c2..a85867936f 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,14 +14,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import re
 from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Cursor
@@ -36,15 +38,33 @@ logger = logging.getLogger(__name__)
 
 class RegistrationWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
 
+        # Note: we don't check this sequence for consistency as we'd have to
+        # call `find_max_generated_user_id_localpart` each time, which is
+        # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
             database.engine, find_max_generated_user_id_localpart, "user_id_seq",
         )
 
+        self._account_validity = hs.config.account_validity
+        if hs.config.run_background_tasks and self._account_validity.enabled:
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "account_validity_set_expiration_dates",
+                self._set_expiration_date_when_missing,
+            )
+
+        # Create a background job for culling expired 3PID validity tokens
+        if hs.config.run_background_tasks:
+            self.clock.looping_call(
+                self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+            )
+
     @cached()
     async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
@@ -116,6 +136,20 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_expiration_ts_for_user",
         )
 
+    async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+        """
+        Returns whether an user account is expired.
+
+        Args:
+            user_id: The user's ID
+            current_ts: The current timestamp
+
+        Returns:
+            Whether the user account has expired
+        """
+        expiration_ts = await self.get_expiration_ts_for_user(user_id)
+        return expiration_ts is not None and current_ts >= expiration_ts
+
     async def set_account_validity_for_user(
         self,
         user_id: str,
@@ -379,7 +413,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
-    ) -> str:
+    ) -> Optional[str]:
         """Look up a user by their external auth id
 
         Args:
@@ -387,7 +421,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             external_id: id on that system
 
         Returns:
-            str|None: the mxid of the user, or None if they are not known
+            the mxid of the user, or None if they are not known
         """
         return await self.db_pool.simple_select_one_onecol(
             table="user_external_ids",
@@ -761,10 +795,82 @@ class RegistrationWorkerStore(SQLBaseStore):
             "delete_threepid_session", delete_threepid_session_txn
         )
 
+    @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+    async def cull_expired_threepid_validation_tokens(self) -> None:
+        """Remove threepid validation tokens with expiry dates that have passed"""
+
+        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+            sql = """
+            DELETE FROM threepid_validation_token WHERE
+            expires < ?
+            """
+            txn.execute(sql, (ts,))
+
+        await self.db_pool.runInteraction(
+            "cull_expired_threepid_validation_tokens",
+            cull_expired_threepid_validation_tokens_txn,
+            self.clock.time_msec(),
+        )
+
+    async def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database, filtering out deactivated users.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+            )
+            txn.execute(sql, [])
+
+            res = self.db_pool.cursor_to_dict(txn)
+            if res:
+                for user in res:
+                    self.set_expiration_date_for_user_txn(
+                        txn, user["name"], use_delta=True
+                    )
+
+        await self.db_pool.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+        """Sets an expiration date to the account with the given user ID.
+
+        Args:
+             user_id (str): User ID to set an expiration date for.
+             use_delta (bool): If set to False, the expiration date for the user will be
+                now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period - d ; now + period] range, d being a
+                delta equal to 10% of the validity period.
+        """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
+
+        if use_delta:
+            expiration_ts = self.rand.randrange(
+                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts,
+            )
+
+        self.db_pool.simple_upsert_txn(
+            txn,
+            "account_validity",
+            keyvalues={"user_id": user_id},
+            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
         self.config = hs.config
@@ -892,30 +998,10 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 class RegistrationStore(RegistrationBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RegistrationStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
-        self._account_validity = hs.config.account_validity
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
-        if self._account_validity.enabled:
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "account_validity_set_expiration_dates",
-                self._set_expiration_date_when_missing,
-            )
-
-        # Create a background job for culling expired 3PID validity tokens
-        def start_cull():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "cull_expired_threepid_validation_tokens",
-                self.cull_expired_threepid_validation_tokens,
-            )
-
-        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -947,6 +1033,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_access_token_to_user",
         )
 
+    def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+        old_device_id = self.db_pool.simple_select_one_onecol_txn(
+            txn, "access_tokens", {"token": token}, "device_id"
+        )
+
+        self.db_pool.simple_update_txn(
+            txn, "access_tokens", {"token": token}, {"device_id": device_id}
+        )
+
+        self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+        return old_device_id
+
+    async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+        """Sets the device ID associated with an access token.
+
+        Args:
+            token: The access token to modify.
+            device_id: The new device ID.
+        Returns:
+            The old device ID associated with the access token.
+        """
+
+        return await self.db_pool.runInteraction(
+            "set_device_for_access_token",
+            self._set_device_for_access_token_txn,
+            token,
+            device_id,
+        )
+
     async def register_user(
         self,
         user_id: str,
@@ -1430,22 +1546,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
-    async def cull_expired_threepid_validation_tokens(self) -> None:
-        """Remove threepid validation tokens with expiry dates that have passed"""
-
-        def cull_expired_threepid_validation_tokens_txn(txn, ts):
-            sql = """
-            DELETE FROM threepid_validation_token WHERE
-            expires < ?
-            """
-            txn.execute(sql, (ts,))
-
-        await self.db_pool.runInteraction(
-            "cull_expired_threepid_validation_tokens",
-            cull_expired_threepid_validation_tokens_txn,
-            self.clock.time_msec(),
-        )
-
     async def set_user_deactivated_status(
         self, user_id: str, deactivated: bool
     ) -> None:
@@ -1475,61 +1575,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         )
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    async def _set_expiration_date_when_missing(self):
-        """
-        Retrieves the list of registered users that don't have an expiration date, and
-        adds an expiration date for each of them.
-        """
-
-        def select_users_with_no_expiration_date_txn(txn):
-            """Retrieves the list of registered users with no expiration date from the
-            database, filtering out deactivated users.
-            """
-            sql = (
-                "SELECT users.name FROM users"
-                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
-                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
-            )
-            txn.execute(sql, [])
-
-            res = self.db_pool.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
-
-        await self.db_pool.runInteraction(
-            "get_users_with_no_expiration_date",
-            select_users_with_no_expiration_date_txn,
-        )
-
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
-        """Sets an expiration date to the account with the given user ID.
-
-        Args:
-             user_id (str): User ID to set an expiration date for.
-             use_delta (bool): If set to False, the expiration date for the user will be
-                now + validity period. If set to True, this expiration date will be a
-                random value in the [now + period - d ; now + period] range, d being a
-                delta equal to 10% of the validity period.
-        """
-        now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
-
-        if use_delta:
-            expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
-                expiration_ts,
-            )
-
-        self.db_pool.simple_upsert_txn(
-            txn,
-            "account_validity",
-            keyvalues={"user_id": user_id},
-            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
-        )
-
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 127588ce4c..c0f2af0785 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -69,7 +69,7 @@ class RoomSortOrder(Enum):
 
 class RoomWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore):
             "count_public_rooms", _count_public_rooms_txn
         )
 
+    async def get_room_count(self) -> int:
+        """Retrieve the total number of rooms.
+        """
+
+        def f(txn):
+            sql = "SELECT count(*)  FROM rooms"
+            txn.execute(sql)
+            row = txn.fetchone()
+            return row[0] or 0
+
+        return await self.db_pool.runInteraction("get_rooms", f)
+
     async def get_largest_public_rooms(
         self,
         network_tuple: Optional[ThirdPartyInstanceID],
@@ -863,7 +875,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
     ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -1074,7 +1086,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.config = hs.config
 
@@ -1137,7 +1149,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                         },
                     )
 
-            with await self._public_room_id_gen.get_next() as next_id:
+            async with self._public_room_id_gen.get_next() as next_id:
                 await self.db_pool.runInteraction(
                     "store_room_txn", store_room_txn, next_id
                 )
@@ -1204,7 +1216,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1284,7 +1296,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     },
                 )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
@@ -1292,18 +1304,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    async def get_room_count(self) -> int:
-        """Retrieve the total number of rooms.
-        """
-
-        def f(txn):
-            sql = "SELECT count(*)  FROM rooms"
-            txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
-
-        return await self.db_pool.runInteraction("get_rooms", f)
-
     async def add_event_report(
         self,
         room_id: str,
@@ -1328,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             desc="add_event_report",
         )
 
+    async def get_event_reports_paginate(
+        self,
+        start: int,
+        limit: int,
+        direction: str = "b",
+        user_id: Optional[str] = None,
+        room_id: Optional[str] = None,
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Retrieve a paginated list of event reports
+
+        Args:
+            start: event offset to begin the query from
+            limit: number of rows to retrieve
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`)
+            user_id: search for user_id. Ignored if user_id is None
+            room_id: search for room_id. Ignored if room_id is None
+        Returns:
+            event_reports: json list of event reports
+            count: total number of event reports matching the filter criteria
+        """
+
+        def _get_event_reports_paginate_txn(txn):
+            filters = []
+            args = []
+
+            if user_id:
+                filters.append("er.user_id LIKE ?")
+                args.extend(["%" + user_id + "%"])
+            if room_id:
+                filters.append("er.room_id LIKE ?")
+                args.extend(["%" + room_id + "%"])
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+            sql = """
+                SELECT COUNT(*) as total_event_reports
+                FROM event_reports AS er
+                {}
+                """.format(
+                where_clause
+            )
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.reason,
+                    er.content,
+                    events.sender,
+                    room_aliases.room_alias,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN room_aliases
+                    ON room_aliases.room_id = er.room_id
+                JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                {where_clause}
+                ORDER BY er.received_ts {order}
+                LIMIT ?
+                OFFSET ?
+            """.format(
+                where_clause=where_clause, order=order,
+            )
+
+            args += [limit, start]
+            txn.execute(sql, args)
+            event_reports = self.db_pool.cursor_to_dict(txn)
+
+            if count > 0:
+                for row in event_reports:
+                    try:
+                        row["content"] = db_to_json(row["content"])
+                        row["event_json"] = db_to_json(row["event_json"])
+                    except Exception:
+                        continue
+
+            return event_reports, count
+
+        return await self.db_pool.runInteraction(
+            "get_event_reports_paginate", _get_event_reports_paginate_txn
+        )
+
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 91a8b43da3..20fcdaa529 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
 
@@ -22,12 +21,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
-    LoggingTransaction,
-    SQLBaseStore,
-    db_to_json,
-    make_in_list_sql_clause,
-)
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
@@ -37,7 +31,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -55,21 +49,22 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 class RoomMemberWorkerStore(EventsWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # Is the current_state_events.membership up to date? Or is the
         # background update still running?
         self._current_state_events_membership_up_to_date = False
 
-        txn = LoggingTransaction(
-            db_conn.cursor(),
-            name="_check_safe_current_state_events_membership_updated",
-            database_engine=self.database_engine,
+        txn = db_conn.cursor(
+            txn_name="_check_safe_current_state_events_membership_updated"
         )
         self._check_safe_current_state_events_membership_updated_txn(txn)
         txn.close()
 
-        if self.hs.config.metrics_flags.known_servers:
+        if (
+            self.hs.config.run_background_tasks
+            and self.hs.config.metrics_flags.known_servers
+        ):
             self._known_servers_count = 1
             self.hs.get_clock().looping_call(
                 run_as_background_process,
@@ -387,7 +382,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # for rooms the server is participating in.
         if self._current_state_events_membership_up_to_date:
             sql = """
-                SELECT room_id, e.stream_ordering
+                SELECT room_id, e.instance_name, e.stream_ordering
                 FROM current_state_events AS c
                 INNER JOIN events AS e USING (room_id, event_id)
                 WHERE
@@ -397,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
         else:
             sql = """
-                SELECT room_id, e.stream_ordering
+                SELECT room_id, e.instance_name, e.stream_ordering
                 FROM current_state_events AS c
                 INNER JOIN room_memberships AS m USING (room_id, event_id)
                 INNER JOIN events AS e USING (room_id, event_id)
@@ -408,7 +403,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             """
 
         txn.execute(sql, (user_id, Membership.JOIN))
-        return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+        return frozenset(
+            GetRoomsForUserWithStreamOrdering(
+                room_id, PersistedEventPosition(instance, stream_id)
+            )
+            for room_id, instance, stream_id in txn
+        )
 
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
@@ -819,7 +819,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
 class RoomMemberBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
         )
@@ -973,7 +973,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(RoomMemberStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644
--- a/synapse/storage/databases/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
         row[8] = bytes(row[8]).decode("utf-8")
         row[11] = bytes(row[11]).decode("utf-8")
         cur.execute(
-            database_engine.convert_param_style(
-                """
-            INSERT into pushers2 (
-            id, user_name, access_token, profile_tag, kind,
-            app_id, app_display_name, device_display_name,
-            pushkey, ts, lang, data, last_token, last_success,
-            failing_since
-            ) values (%s)"""
-                % (",".join(["?" for _ in range(len(row))]))
-            ),
+            """
+                INSERT into pushers2 (
+                id, user_name, access_token, profile_tag, kind,
+                app_id, app_display_name, device_display_name,
+                pushkey, ts, lang, data, last_token, last_success,
+                failing_since
+                ) values (%s)
+            """
+            % (",".join(["?" for _ in range(len(row))])),
             row,
         )
         count += 1
diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644
--- a/synapse/storage/databases/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_search", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644
--- a/synapse/storage/databases/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_origin_server_ts", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644
--- a/synapse/storage/databases/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
         user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
         for chunk in user_chunks:
             cur.execute(
-                database_engine.convert_param_style(
-                    "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
-                    % (",".join("?" for _ in chunk),)
-                ),
+                "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+                % (",".join("?" for _ in chunk),),
                 [as_id] + chunk,
             )
diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644
--- a/synapse/storage/databases/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
         row = list(row)
         row[12] = token_to_stream_ordering(row[12])
         cur.execute(
-            database_engine.convert_param_style(
-                """
-            INSERT into pushers2 (
-            id, user_name, access_token, profile_tag, kind,
-            app_id, app_display_name, device_display_name,
-            pushkey, ts, lang, data, last_stream_ordering, last_success,
-            failing_since
-            ) values (%s)"""
-                % (",".join(["?" for _ in range(len(row))]))
-            ),
+            """
+                INSERT into pushers2 (
+                id, user_name, access_token, profile_tag, kind,
+                app_id, app_display_name, device_display_name,
+                pushkey, ts, lang, data, last_stream_ordering, last_success,
+                failing_since
+                ) values (%s)
+            """
+            % (",".join(["?" for _ in range(len(row))])),
             row,
         )
         count += 1
diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644
--- a/synapse/storage/databases/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_search_order", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644
--- a/synapse/storage/databases/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_fields_sender_url", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
 
 def run_upgrade(cur, database_engine, *args, **kwargs):
     cur.execute(
-        database_engine.convert_param_style(
-            "UPDATE remote_media_cache SET last_access_ts = ?"
-        ),
-        (int(time.time() * 1000),),
+        "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
     )
diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
index 5e29c1da19..ccf287971c 100644
--- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
@@ -13,7 +13,7 @@
  * limitations under the License.
  */
 
--- room_id and topoligical_ordering are denormalised from the events table in order to
+-- room_id and topological_ordering are denormalised from the events table in order to
 -- make the index work.
 CREATE TABLE IF NOT EXISTS event_labels (
     event_id TEXT,
diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644
--- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@
 import logging
+from io import StringIO
 
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
         select_clause,
     )
 
-    if isinstance(database_engine, PostgresEngine):
-        cur.execute(sql)
-    else:
-        cur.executescript(sql)
+    execute_statements_from_stream(cur, StringIO(sql))
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644
--- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
                 INNER JOIN room_memberships AS r USING (event_id)
                 WHERE type = 'm.room.member' AND state_key LIKE ?
         """
-    sql = database_engine.convert_param_style(sql)
     cur.execute(sql, ("%:" + config.server_name,))
 
     cur.execute(
diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
new file mode 100644
index 0000000000..7851a0a825
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,20 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS dehydrated_devices(
+    user_id TEXT NOT NULL PRIMARY KEY,
+    device_id TEXT NOT NULL,
+    device_data TEXT NOT NULL -- JSON-encoded client-defined data
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+    user_id TEXT NOT NULL, -- The user this fallback key is for.
+    device_id TEXT NOT NULL, -- The device this fallback key is for.
+    algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+    key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+    key_json TEXT NOT NULL, -- The key as a JSON blob.
+    used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+    CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
index 97c1e6a0c5..c31f9af82a 100644
--- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -21,6 +21,8 @@ SELECT setval('events_stream_seq', (
 
 CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
 
+-- If the server has never backfilled a room then doing `-MIN(...)` will give
+-- a negative result, hence why we do `GREATEST(...)`
 SELECT setval('events_backfill_stream_seq', (
-    SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+    SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events
 ));
diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
new file mode 100644
index 0000000000..985fd949a2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stream_positions (
+    stream_name TEXT NOT NULL,
+    instance_name TEXT NOT NULL,
+    stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+    instance_id SERIAL PRIMARY KEY,
+    instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index f01cf2fd02..e34fce6281 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         if not hs.config.enable_search:
             return
@@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
 
 class SearchStore(SearchBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(SearchStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def search_msgs(self, room_ids, search_term, keys):
         """Performs a full text search over events with given keys.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5c6168e301..3c1e33819b 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def get_room_version(self, room_id: str) -> RoomVersion:
         """Get the room_version of a given room
@@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
@@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 9c1bf3c289..bc8e78e1f1 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -62,7 +62,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 
 class StatsStore(StateDeltasStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StatsStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
@@ -211,6 +211,7 @@ class StatsStore(StateDeltasStore):
         * topic
         * avatar
         * canonical_alias
+        * guest_access
 
         A is_federatable key can also be included with a boolean value.
 
@@ -235,6 +236,7 @@ class StatsStore(StateDeltasStore):
             "topic",
             "avatar",
             "canonical_alias",
+            "guest_access",
         ):
             field = fields.get(col, sentinel)
             if field is not sentinel and (not isinstance(field, str) or "\0" in field):
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 2e95518752..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -35,7 +35,6 @@ what sort order was used:
     - topological tokems: "t%d-%d", where the integers map to the topological
       and stream ordering columns respectively.
 """
-
 import abc
 import logging
 from collections import namedtuple
@@ -54,7 +53,9 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import Collection, RoomStreamToken
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 if TYPE_CHECKING:
@@ -209,6 +210,55 @@ def _make_generic_sql_bound(
     )
 
 
+def _filter_results(
+    lower_token: Optional[RoomStreamToken],
+    upper_token: Optional[RoomStreamToken],
+    instance_name: str,
+    topological_ordering: int,
+    stream_ordering: int,
+) -> bool:
+    """Returns True if the event persisted by the given instance at the given
+    topological/stream_ordering falls between the two tokens (taking a None
+    token to mean unbounded).
+
+    Used to filter results from fetching events in the DB against the given
+    tokens. This is necessary to handle the case where the tokens include
+    position maps, which we handle by fetching more than necessary from the DB
+    and then filtering (rather than attempting to construct a complicated SQL
+    query).
+    """
+
+    event_historical_tuple = (
+        topological_ordering,
+        stream_ordering,
+    )
+
+    if lower_token:
+        if lower_token.topological is not None:
+            # If these are historical tokens we compare the `(topological, stream)`
+            # tuples.
+            if event_historical_tuple <= lower_token.as_historical_tuple():
+                return False
+
+        else:
+            # If these are live tokens we compare the stream ordering against the
+            # writers stream position.
+            if stream_ordering <= lower_token.get_stream_pos_for_instance(
+                instance_name
+            ):
+                return False
+
+    if upper_token:
+        if upper_token.topological is not None:
+            if upper_token.as_historical_tuple() < event_historical_tuple:
+                return False
+        else:
+            if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+                return False
+
+    return True
+
+
 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
@@ -259,16 +309,14 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     return " AND ".join(clauses), args
 
 
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     """This is an abstract base class where subclasses must implement
     `get_room_max_stream_ordering` and `get_room_min_stream_ordering`
     which can be called in the initializer.
     """
 
-    __metaclass__ = abc.ABCMeta
-
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
-        super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
         self._send_federation = hs.should_send_federation()
@@ -307,6 +355,33 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()
 
+    def get_room_max_token(self) -> RoomStreamToken:
+        """Get a `RoomStreamToken` that marks the current maximum persisted
+        position of the events stream. Useful to get a token that represents
+        "now".
+
+        The token returned is a "live" token that may have an instance_map
+        component.
+        """
+
+        min_pos = self._stream_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._stream_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return RoomStreamToken(None, min_pos, positions)
+
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if from_key == to_key:
             return [], from_key
 
-        from_id = from_key.stream
-        to_id = to_key.stream
-
-        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+        has_changed = self._events_stream_cache.has_entity_changed(
+            room_id, from_key.stream
+        )
 
         if not has_changed:
             return [], from_key
 
         def f(txn):
-            sql = (
-                "SELECT event_id, stream_ordering FROM events WHERE"
-                " room_id = ?"
-                " AND not outlier"
-                " AND stream_ordering > ? AND stream_ordering <= ?"
-                " ORDER BY stream_ordering %s LIMIT ?"
-            ) % (order,)
-            txn.execute(sql, (room_id, from_id, to_id, limit))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT event_id, instance_name, topological_ordering, stream_ordering
+                FROM events
+                WHERE
+                    room_id = ?
+                    AND not outlier
+                    AND stream_ordering > ? AND stream_ordering <= ?
+                ORDER BY stream_ordering %s LIMIT ?
+            """ % (
+                order,
+            )
+            txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ][:limit]
             return rows
 
         rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             [r.event_id for r in rows], get_prev_content=True
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._set_before_and_after(ret, rows, topo_order=False)
 
         if order.lower() == "desc":
             ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = from_key.stream
-        to_id = to_key.stream
-
         if from_key == to_key:
             return []
 
-        if from_id:
+        if from_key:
             has_changed = self._membership_stream_cache.has_entity_changed(
-                user_id, int(from_id)
+                user_id, int(from_key.stream)
             )
             if not has_changed:
                 return []
 
         def f(txn):
-            sql = (
-                "SELECT m.event_id, stream_ordering FROM events AS e,"
-                " room_memberships AS m"
-                " WHERE e.event_id = m.event_id"
-                " AND m.user_id = ?"
-                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-            )
-            txn.execute(sql, (user_id, from_id, to_id))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+                FROM events AS e, room_memberships AS m
+                WHERE e.event_id = m.event_id
+                    AND m.user_id = ?
+                    AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                ORDER BY e.stream_ordering ASC
+            """
+            txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ]
 
             return rows
 
@@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     async def get_room_event_before_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> Tuple[int, int, str]:
+    ) -> Optional[Tuple[int, int, str]]:
         """Gets details of the first event in a room at or before a stream ordering
 
         Args:
@@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             )
             return "t%d-%d" % (topo, token)
 
-    async def get_stream_id_for_event(self, event_id: str) -> int:
-        """The stream ID for an event
-        Args:
-            event_id: The id of the event to look up a stream token for.
-        Raises:
-            StoreError if the event wasn't in the database.
-        Returns:
-            A stream ID.
-        """
-        return await self.db_pool.runInteraction(
-            "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
-        )
-
     def get_stream_id_for_event_txn(
         self, txn: LoggingTransaction, event_id: str, allow_none=False,
     ) -> int:
@@ -613,26 +705,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             allow_none=allow_none,
         )
 
-    async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
-        """The stream token for an event
-        Args:
-            event_id: The id of the event to look up a stream token for.
-        Raises:
-            StoreError if the event wasn't in the database.
-        Returns:
-            A stream token.
+    async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
+        """Get the persisted position for an event
         """
-        stream_id = await self.get_stream_id_for_event(event_id)
-        return RoomStreamToken(None, stream_id)
+        row = await self.db_pool.simple_select_one(
+            table="events",
+            keyvalues={"event_id": event_id},
+            retcols=("stream_ordering", "instance_name"),
+            desc="get_position_for_event",
+        )
+
+        return PersistedEventPosition(
+            row["instance_name"] or "master", row["stream_ordering"]
+        )
 
-    async def get_topological_token_for_event(self, event_id: str) -> str:
+    async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
         Args:
             event_id: The id of the event to look up a stream token for.
         Raises:
             StoreError if the event wasn't in the database.
         Returns:
-            A "t%d-%d" topological token.
+            A `RoomStreamToken` topological token.
         """
         row = await self.db_pool.simple_select_one(
             table="events",
@@ -640,25 +734,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
         )
-        return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
+        return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
 
-    async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
-        """Get the max topological token in a room before the given stream
+    async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
+        """Gets the topological token in a room after or at the given stream
         ordering.
 
         Args:
             room_id
             stream_key
-
-        Returns:
-            The maximum topological token.
         """
         sql = (
-            "SELECT coalesce(max(topological_ordering), 0) FROM events"
-            " WHERE room_id = ? AND stream_ordering < ?"
+            "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
+            " WHERE room_id = ? AND stream_ordering >= ?"
         )
         row = await self.db_pool.execute(
-            "get_max_topological_token", None, sql, room_id, stream_key
+            "get_current_topological_token", None, sql, room_id, stream_key
         )
         return row[0][0] if row else 0
 
@@ -692,8 +783,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             else:
                 topo = None
             internal = event.internal_metadata
-            internal.before = str(RoomStreamToken(topo, stream - 1))
-            internal.after = str(RoomStreamToken(topo, stream))
+            internal.before = RoomStreamToken(topo, stream - 1)
+            internal.after = RoomStreamToken(topo, stream)
             internal.order = (int(topo) if topo else 0, int(stream))
 
     async def get_events_around(
@@ -980,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         else:
             order = "ASC"
 
+        # The bounds for the stream tokens are complicated by the fact
+        # that we need to handle the instance_map part of the tokens. We do this
+        # by fetching all events between the min stream token and the maximum
+        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+        # then filtering the results.
+        if from_token.topological is not None:
+            from_bound = (
+                from_token.as_historical_tuple()
+            )  # type: Tuple[Optional[int], int]
+        elif direction == "b":
+            from_bound = (
+                None,
+                from_token.get_max_stream_pos(),
+            )
+        else:
+            from_bound = (
+                None,
+                from_token.stream,
+            )
+
+        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        if to_token:
+            if to_token.topological is not None:
+                to_bound = to_token.as_historical_tuple()
+            elif direction == "b":
+                to_bound = (
+                    None,
+                    to_token.stream,
+                )
+            else:
+                to_bound = (
+                    None,
+                    to_token.get_max_stream_pos(),
+                )
+
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.as_tuple(),
-            to_token=to_token.as_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
 
@@ -994,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             bounds += " AND " + filter_clause
             args.extend(filter_args)
 
-        args.append(int(limit))
+        # We fetch more events as we'll filter the result set
+        args.append(int(limit) * 2)
 
         select_keywords = "SELECT"
         join_clause = ""
@@ -1016,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 select_keywords += "DISTINCT"
 
         sql = """
-            %(select_keywords)s event_id, topological_ordering, stream_ordering
+            %(select_keywords)s
+                event_id, instance_name,
+                topological_ordering, stream_ordering
             FROM events
             %(join_clause)s
             WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1031,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         txn.execute(sql, args)
 
-        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+        # Filter the result set.
+        rows = [
+            _EventDictReturn(event_id, topological_ordering, stream_ordering)
+            for event_id, instance_name, topological_ordering, stream_ordering in txn
+            if _filter_results(
+                lower_token=to_token if direction == "b" else from_token,
+                upper_token=from_token if direction == "b" else to_token,
+                instance_name=instance_name,
+                topological_ordering=topological_ordering,
+                stream_ordering=stream_ordering,
+            )
+        ][:limit]
 
         if rows:
             topo = rows[-1].topological_ordering
@@ -1096,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return (events, token)
 
+    @cached()
+    async def get_id_for_instance(self, instance_name: str) -> int:
+        """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+        """
+
+        def _get_id_for_instance_txn(txn):
+            instance_id = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+                allow_none=True,
+            )
+            if instance_id is not None:
+                return instance_id
+
+            # If we don't have an entry upsert one.
+            #
+            # We could do this before the first check, and rely on the cache for
+            # efficiency, but each UPSERT causes the next ID to increment which
+            # can quickly bloat the size of the generated IDs for new instances.
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                values={},
+            )
+
+            return self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+            )
+
+        return await self.db_pool.runInteraction(
+            "get_id_for_instance", _get_id_for_instance_txn
+        )
+
+    @cached()
+    async def get_name_from_instance_id(self, instance_id: int) -> str:
+        """Get the instance name from an ID previously returned by
+        `get_id_for_instance`.
+        """
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="instance_map",
+            keyvalues={"instance_id": instance_id},
+            retcol="instance_name",
+            desc="get_name_from_instance_id",
+        )
+
 
 class StreamStore(StreamWorkerStore):
     def get_room_max_stream_ordering(self) -> int:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 091367006e..7d46090267 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -43,14 +43,32 @@ _UpdateTransactionRow = namedtuple(
 SENTINEL = object()
 
 
-class TransactionStore(SQLBaseStore):
+class TransactionWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+
+    @wrap_as_background_process("cleanup_transactions")
+    async def _cleanup_transactions(self) -> None:
+        now = self._clock.time_msec()
+        month_ago = now - 30 * 24 * 60 * 60 * 1000
+
+        def _cleanup_transactions_txn(txn):
+            txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+
+        await self.db_pool.runInteraction(
+            "_cleanup_transactions", _cleanup_transactions_txn
+        )
+
+
+class TransactionStore(TransactionWorkerStore):
     """A collection of queries for handling PDUs.
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(TransactionStore, self).__init__(database, db_conn, hs)
-
-        self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
+        super().__init__(database, db_conn, hs)
 
         self._destination_retry_cache = ExpiringCache(
             cache_name="get_destination_retry_timings",
@@ -218,6 +236,7 @@ class TransactionStore(SQLBaseStore):
                         retry_interval = EXCLUDED.retry_interval
                     WHERE
                         EXCLUDED.retry_interval = 0
+                        OR destinations.retry_interval IS NULL
                         OR destinations.retry_interval < EXCLUDED.retry_interval
             """
 
@@ -249,7 +268,11 @@ class TransactionStore(SQLBaseStore):
                     "retry_interval": retry_interval,
                 },
             )
-        elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
+        elif (
+            retry_interval == 0
+            or prev_row["retry_interval"] is None
+            or prev_row["retry_interval"] < retry_interval
+        ):
             self.db_pool.simple_update_one_txn(
                 txn,
                 "destinations",
@@ -261,22 +284,6 @@ class TransactionStore(SQLBaseStore):
                 },
             )
 
-    def _start_cleanup_transactions(self):
-        return run_as_background_process(
-            "cleanup_transactions", self._cleanup_transactions
-        )
-
-    async def _cleanup_transactions(self) -> None:
-        now = self._clock.time_msec()
-        month_ago = now - 30 * 24 * 60 * 60 * 1000
-
-        def _cleanup_transactions_txn(txn):
-            txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
-
-        await self.db_pool.runInteraction(
-            "_cleanup_transactions", _cleanup_transactions_txn
-        )
-
     async def store_destination_rooms_entries(
         self, destinations: Iterable[str], room_id: str, stream_ordering: int,
     ) -> None:
@@ -397,7 +404,7 @@ class TransactionStore(SQLBaseStore):
 
     @staticmethod
     def _get_catch_up_room_event_ids_txn(
-        txn, destination: str, last_successful_stream_ordering: int,
+        txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
     ) -> List[str]:
         q = """
                 SELECT event_id FROM destination_rooms
@@ -412,3 +419,60 @@ class TransactionStore(SQLBaseStore):
         )
         event_ids = [row[0] for row in txn]
         return event_ids
+
+    async def get_catch_up_outstanding_destinations(
+        self, after_destination: Optional[str]
+    ) -> List[str]:
+        """
+        Gets at most 25 destinations which have outstanding PDUs to be caught up,
+        and are not being backed off from
+        Args:
+            after_destination:
+                If provided, all destinations must be lexicographically greater
+                than this one.
+
+        Returns:
+            list of up to 25 destinations with outstanding catch-up.
+                These are the lexicographically first destinations which are
+                lexicographically greater than after_destination (if provided).
+        """
+        time = self.hs.get_clock().time_msec()
+
+        return await self.db_pool.runInteraction(
+            "get_catch_up_outstanding_destinations",
+            self._get_catch_up_outstanding_destinations_txn,
+            time,
+            after_destination,
+        )
+
+    @staticmethod
+    def _get_catch_up_outstanding_destinations_txn(
+        txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
+    ) -> List[str]:
+        q = """
+            SELECT destination FROM destinations
+                WHERE destination IN (
+                    SELECT destination FROM destination_rooms
+                        WHERE destination_rooms.stream_ordering >
+                            destinations.last_successful_stream_ordering
+                )
+                AND destination > ?
+                AND (
+                    retry_last_ts IS NULL OR
+                    retry_last_ts + retry_interval < ?
+                )
+                ORDER BY destination
+                LIMIT 25
+        """
+        txn.execute(
+            q,
+            (
+                # everything is lexicographically greater than "" so this gives
+                # us the first batch of up to 25.
+                after_destination or "",
+                now_time_ms,
+            ),
+        )
+
+        destinations = [row[0] for row in txn]
+        return destinations
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 3b9211a6d2..79b7ece330 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore):
         )
         return [(row["user_agent"], row["ip"]) for row in rows]
 
-
-class UIAuthStore(UIAuthWorkerStore):
     async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
         """
         Remove sessions which were last used earlier than the expiration time.
@@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore):
             iterable=session_ids,
             keyvalues={},
         )
+
+
+class UIAuthStore(UIAuthWorkerStore):
+    pass
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f2f9a5799a..5a390ff2f6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     SHARE_PRIVATE_WORKING_SET = 500
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
@@ -564,7 +564,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     SHARE_PRIVATE_WORKING_SET = 500
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(UserDirectoryStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
     async def remove_from_user_dir(self, user_id: str) -> None:
         def _remove_from_user_dir_txn(txn):
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 2f7c95fc74..f9575b1f1f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
                 return
 
             # They are there, delete them.
-            self.simple_delete_one_txn(
+            self.db_pool.simple_delete_one_txn(
                 txn, "erased_users", keyvalues={"user_id": user_id}
             )
 
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 139085b672..acb24e33af 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
             self._background_deduplicate_state,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e924f1ca3b..0e31cc811a 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -24,7 +24,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor
 from synapse.storage.state import StateFilter
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
@@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
-        super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+        super().__init__(database, db_conn, hs)
 
         # Originally the state store used a single DictionaryCache to cache the
         # event IDs for the state types in a given state group to avoid hammering
@@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         self._state_group_seq_gen = build_sequence_generator(
             self.database_engine, get_max_state_group_txn, "state_group_id_seq"
         )
+        self._state_group_seq_gen.check_consistency(
+            db_conn, table="state_groups", id_column="id"
+        )
 
     @cached(max_entries=10000, iterable=True)
     async def get_state_group_delta(self, state_group):
@@ -205,7 +208,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
     async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
-    ) -> Dict[int, StateMap[str]]:
+    ) -> Dict[int, MutableStateMap[str]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index 908cbc79e3..d6d632dc10 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -97,3 +97,20 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
         """Gets a string giving the server version. For example: '3.22.0'
         """
         ...
+
+    @abc.abstractmethod
+    def in_transaction(self, conn: Connection) -> bool:
+        """Whether the connection is currently in a transaction.
+        """
+        ...
+
+    @abc.abstractmethod
+    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+        """Attempt to set the connections autocommit mode.
+
+        When True queries are run outside of transactions.
+
+        Note: This has no effect on SQLite3, so callers still need to
+        commit/rollback the connections.
+        """
+        ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index ff39281f85..7719ac32f7 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,7 +15,8 @@
 
 import logging
 
-from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
+from synapse.storage.engines._base import BaseDatabaseEngine, IncorrectDatabaseSetup
+from synapse.storage.types import Connection
 
 logger = logging.getLogger(__name__)
 
@@ -119,6 +120,7 @@ class PostgresEngine(BaseDatabaseEngine):
             cursor.execute("SET synchronous_commit TO OFF")
 
         cursor.close()
+        db_conn.commit()
 
     @property
     def can_native_upsert(self):
@@ -171,3 +173,9 @@ class PostgresEngine(BaseDatabaseEngine):
             return "%i.%i" % (numver / 10000, numver % 10000)
         else:
             return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)
+
+    def in_transaction(self, conn: Connection) -> bool:
+        return conn.status != self.module.extensions.STATUS_READY  # type: ignore
+
+    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+        return conn.set_session(autocommit=autocommit)  # type: ignore
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 8a0f8c89d1..5db0f0b520 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -17,6 +17,7 @@ import threading
 import typing
 
 from synapse.storage.engines import BaseDatabaseEngine
+from synapse.storage.types import Connection
 
 if typing.TYPE_CHECKING:
     import sqlite3  # noqa: F401
@@ -86,6 +87,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
 
         db_conn.create_function("rank", 1, _rank)
         db_conn.execute("PRAGMA foreign_keys = ON;")
+        db_conn.commit()
 
     def is_deadlock(self, error):
         return False
@@ -105,6 +107,14 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
         """
         return "%i.%i.%i" % self.module.sqlite_version_info
 
+    def in_transaction(self, conn: Connection) -> bool:
+        return conn.in_transaction  # type: ignore
+
+    def attempt_to_set_autocommit(self, conn: Connection, autocommit: bool):
+        # Twisted doesn't let us set attributes on the connections, so we can't
+        # set the connection to autocommit mode.
+        pass
+
 
 # Following functions taken from: https://github.com/coleifer/peewee
 
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index d89f6ed128..4d2d88d1f0 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import Collection, StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
 
@@ -190,15 +190,16 @@ class EventsPersistenceStorage:
         self.persist_events_store = stores.persist_events
 
         self._clock = hs.get_clock()
+        self._instance_name = hs.get_instance_name()
         self.is_mine_id = hs.is_mine_id
         self._event_persist_queue = _EventPeristenceQueue()
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
     async def persist_events(
         self,
-        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> int:
+    ) -> RoomStreamToken:
         """
         Write events to the database
         Args:
@@ -228,11 +229,11 @@ class EventsPersistenceStorage:
             defer.gatherResults(deferreds, consumeErrors=True)
         )
 
-        return self.main_store.get_current_events_token()
+        return self.main_store.get_room_max_token()
 
     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ) -> Tuple[int, int]:
+    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
         """
         Returns:
             The stream ordering of `event`, and the stream ordering of the
@@ -246,8 +247,12 @@ class EventsPersistenceStorage:
 
         await make_deferred_yieldable(deferred)
 
-        max_persisted_id = self.main_store.get_current_events_token()
-        return (event.internal_metadata.stream_ordering, max_persisted_id)
+        event_stream_id = event.internal_metadata.stream_ordering
+        # stream ordering should have been assigned by now
+        assert event_stream_id
+
+        pos = PersistedEventPosition(self._instance_name, event_stream_id)
+        return pos, self.main_store.get_room_max_token()
 
     def _maybe_start_persisting(self, room_id: str):
         async def persisting_queue(item):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 77de025069..9e3dfe4805 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import imp
 import logging
 import os
@@ -24,9 +23,10 @@ from typing import Optional, TextIO
 import attr
 
 from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
 from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
@@ -64,7 +64,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
 
 
 def prepare_database(
-    db_conn: Connection,
+    db_conn: LoggingDatabaseConnection,
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
     databases: Collection[str] = ["main", "state"],
@@ -86,7 +86,7 @@ def prepare_database(
     """
 
     try:
-        cur = db_conn.cursor()
+        cur = db_conn.cursor(txn_name="prepare_database")
 
         # sqlite does not automatically start transactions for DDL / SELECT statements,
         # so we start one before running anything. This ensures that any upgrades
@@ -255,9 +255,7 @@ def _setup_new_database(cur, database_engine, databases):
             executescript(cur, entry.absolute_path)
 
     cur.execute(
-        database_engine.convert_param_style(
-            "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
-        ),
+        "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
         (max_current_ver, False),
     )
 
@@ -483,17 +481,13 @@ def _upgrade_existing_database(
 
             # Mark as done.
             cur.execute(
-                database_engine.convert_param_style(
-                    "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
-                ),
+                "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
                 (v, relative_path),
             )
 
             cur.execute("DELETE FROM schema_version")
             cur.execute(
-                database_engine.convert_param_style(
-                    "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
-                ),
+                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
                 (v, True),
             )
 
@@ -529,10 +523,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
             schemas to be applied
     """
     cur.execute(
-        database_engine.convert_param_style(
-            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
-        ),
-        (modname,),
+        "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
     )
     applied_deltas = {d for d, in cur}
     for (name, stream) in names_and_streams:
@@ -550,9 +541,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
 
         # Mark as done.
         cur.execute(
-            database_engine.convert_param_style(
-                "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
-            ),
+            "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
             (modname, name),
         )
 
@@ -624,9 +613,7 @@ def _get_or_create_schema_state(txn, database_engine):
 
     if current_version:
         txn.execute(
-            database_engine.convert_param_style(
-                "SELECT file FROM applied_schema_deltas WHERE version >= ?"
-            ),
+            "SELECT file FROM applied_schema_deltas WHERE version >= ?",
             (current_version,),
         )
         applied_deltas = [d for d, in txn]
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c4a83a840..f152f63321 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
 )
 
 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
+    "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
 )
 
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 8f68d968f0..08a69f2f96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,7 +20,7 @@ import attr
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -349,7 +349,7 @@ class StateGroupStorage:
 
     async def get_state_groups_ids(
         self, _room_id: str, event_ids: Iterable[str]
-    ) -> Dict[int, StateMap[str]]:
+    ) -> Dict[int, MutableStateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
         Args:
@@ -532,7 +532,7 @@ class StateGroupStorage:
 
     def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
-    ) -> Awaitable[Dict[int, StateMap[str]]]:
+    ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key
 
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 2d2b560e74..970bb1b9da 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -61,3 +61,9 @@ class Connection(Protocol):
 
     def rollback(self, *args, **kwargs) -> None:
         ...
+
+    def __enter__(self) -> "Connection":
+        ...
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        ...
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 1de2b91587..d7e40aaa8b 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,17 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import contextlib
 import heapq
 import logging
 import threading
 from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
 
+import attr
 from typing_extensions import Deque
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 logger = logging.getLogger(__name__)
@@ -53,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1):
     """
     # debug logging for https://github.com/matrix-org/synapse/issues/7968
     logger.info("initialising stream generator for %s(%s)", table, column)
-    cur = db_conn.cursor()
+    cur = db_conn.cursor(txn_name="_load_current_id")
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
     else:
@@ -86,7 +88,7 @@ class StreamIdGenerator:
             upwards, -1 to grow downwards.
 
     Usage:
-        with await stream_id_gen.get_next() as stream_id:
+        async with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -101,10 +103,10 @@ class StreamIdGenerator:
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -113,7 +115,7 @@ class StreamIdGenerator:
 
             self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_id
@@ -121,12 +123,12 @@ class StreamIdGenerator:
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
-    async def get_next_mult(self, n):
+    def get_next_mult(self, n):
         """
         Usage:
-            with await stream_id_gen.get_next(n) as stream_ids:
+            async with stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -140,7 +142,7 @@ class StreamIdGenerator:
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_ids
@@ -149,7 +151,7 @@ class StreamIdGenerator:
                     for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +186,16 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
+        stream_name: A name for the stream.
         instance_name: The name of this instance.
         table: Database table associated with stream.
         instance_column: Column that stores the row's writer's instance name
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        writers: A list of known writers to use to populate current positions
+            on startup. Can be empty if nothing uses `get_current_token` or
+            `get_positions` (e.g. caches stream).
         positive: Whether the IDs are positive (true) or negative (false).
             When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
@@ -198,16 +204,20 @@ class MultiWriterIdGenerator:
         self,
         db_conn,
         db: DatabasePool,
+        stream_name: str,
         instance_name: str,
         table: str,
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        writers: List[str],
         positive: bool = True,
     ):
         self._db = db
+        self._stream_name = stream_name
         self._instance_name = instance_name
         self._positive = positive
+        self._writers = writers
         self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
@@ -216,9 +226,7 @@ class MultiWriterIdGenerator:
         # Note: If we are a negative stream then we still store all the IDs as
         # positive to make life easier for us, and simply negate the IDs when we
         # return them.
-        self._current_positions = self._load_current_ids(
-            db_conn, table, instance_column, id_column
-        )
+        self._current_positions = {}  # type: Dict[str, int]
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
@@ -251,30 +259,98 @@ class MultiWriterIdGenerator:
 
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
+        # We check that the table and sequence haven't diverged.
+        self._sequence_gen.check_consistency(
+            db_conn, table=table, id_column=id_column, positive=positive
+        )
+
+        # This goes and fills out the above state from the database.
+        self._load_current_ids(db_conn, table, instance_column, id_column)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
-    ) -> Dict[str, int]:
-        # If positive stream aggregate via MAX. For negative stream use MIN
-        # *and* negate the result to get a positive number.
-        sql = """
-            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
-            GROUP BY %(instance)s
-        """ % {
-            "instance": instance_column,
-            "id": id_column,
-            "table": table,
-            "agg": "MAX" if self._positive else "-MIN",
-        }
+    ):
+        cur = db_conn.cursor(txn_name="_load_current_ids")
+
+        # Load the current positions of all writers for the stream.
+        if self._writers:
+            # We delete any stale entries in the positions table. This is
+            # important if we add back a writer after a long time; we want to
+            # consider that a "new" writer, rather than using the old stale
+            # entry here.
+            sql = """
+                DELETE FROM stream_positions
+                WHERE
+                    stream_name = ?
+                    AND instance_name != ALL(?)
+            """
+            cur.execute(sql, (self._stream_name, self._writers))
+
+            sql = """
+                SELECT instance_name, stream_id FROM stream_positions
+                WHERE stream_name = ?
+            """
+            cur.execute(sql, (self._stream_name,))
+
+            self._current_positions = {
+                instance: stream_id * self._return_factor
+                for instance, stream_id in cur
+                if instance in self._writers
+            }
 
-        cur = db_conn.cursor()
-        cur.execute(sql)
+        # We set the `_persisted_upto_position` to be the minimum of all current
+        # positions. If empty we use the max stream ID from the DB table.
+        min_stream_id = min(self._current_positions.values(), default=None)
+
+        if min_stream_id is None:
+            # We add a GREATEST here to ensure that the result is always
+            # positive. (This can be a problem for e.g. backfill streams where
+            # the server has never backfilled).
+            sql = """
+                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                FROM %(table)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "agg": "MAX" if self._positive else "-MIN",
+            }
+            cur.execute(sql)
+            (stream_id,) = cur.fetchone()
+            self._persisted_upto_position = stream_id
+        else:
+            # If we have a min_stream_id then we pull out everything greater
+            # than it from the DB so that we can prefill
+            # `_known_persisted_positions` and get a more accurate
+            # `_persisted_upto_position`.
+            #
+            # We also check if any of the later rows are from this instance, in
+            # which case we use that for this instance's current position. This
+            # is to handle the case where we didn't finish persisting to the
+            # stream positions table before restart (or the stream position
+            # table otherwise got out of date).
+
+            sql = """
+                SELECT %(instance)s, %(id)s FROM %(table)s
+                WHERE ? %(cmp)s %(id)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "instance": instance_column,
+                "cmp": "<=" if self._positive else ">=",
+            }
+            cur.execute(sql, (min_stream_id * self._return_factor,))
 
-        # `cur` is an iterable over returned rows, which are 2-tuples.
-        current_positions = dict(cur)
+            self._persisted_upto_position = min_stream_id
 
-        cur.close()
+            with self._lock:
+                for (instance, stream_id,) in cur:
+                    stream_id = self._return_factor * stream_id
+                    self._add_persisted_position(stream_id)
 
-        return current_positions
+                    if instance == self._instance_name:
+                        self._current_positions[instance] = stream_id
+
+        cur.close()
 
     def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
@@ -282,59 +358,23 @@ class MultiWriterIdGenerator:
     def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
-        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
-        # Assert the fetched ID is actually greater than what we currently
-        # believe the ID to be. If not, then the sequence and table have got
-        # out of sync somehow.
-        with self._lock:
-            assert self._current_positions.get(self._instance_name, 0) < next_id
 
-            self._unfinished_ids.add(next_id)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                # Multiply by the return factor so that the ID has correct sign.
-                yield self._return_factor * next_id
-            finally:
-                self._mark_id_as_finished(next_id)
+        return _MultiWriterCtxManager(self)
 
-        return manager()
-
-    async def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int):
         """
         Usage:
-            with await stream_id_gen.get_next_mult(5) as stream_ids:
+            async with stream_id_gen.get_next_mult(5) as stream_ids:
                 # ... persist events ...
         """
-        next_ids = await self._db.runInteraction(
-            "_load_next_mult_id", self._load_next_mult_id_txn, n
-        )
 
-        # Assert the fetched ID is actually greater than any ID we've already
-        # seen. If not, then the sequence and table have got out of sync
-        # somehow.
-        with self._lock:
-            assert max(self._current_positions.values(), default=0) < min(next_ids)
-
-            self._unfinished_ids.update(next_ids)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield [self._return_factor * i for i in next_ids]
-            finally:
-                for i in next_ids:
-                    self._mark_id_as_finished(i)
-
-        return manager()
+        return _MultiWriterCtxManager(self, n)
 
     def get_next_txn(self, txn: LoggingTransaction):
         """
@@ -352,6 +392,21 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persited row with the correct instance name.
+        if self._writers:
+            txn.call_after(
+                run_as_background_process,
+                "MultiWriterIdGenerator._update_table",
+                self._db.runInteraction,
+                "MultiWriterIdGenerator._update_table",
+                self._update_stream_positions_table_txn,
+            )
+
         return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
@@ -363,7 +418,7 @@ class MultiWriterIdGenerator:
             self._unfinished_ids.discard(next_id)
             self._finished_ids.add(next_id)
 
-            new_cur = None
+            new_cur = None  # type: Optional[int]
 
             if self._unfinished_ids:
                 # If there are unfinished IDs then the new position will be the
@@ -408,11 +463,22 @@ class MultiWriterIdGenerator:
         """Returns the position of the given writer.
         """
 
+        # If we don't have an entry for the given instance name, we assume it's a
+        # new writer.
+        #
+        # For new writers we assume their initial position to be the current
+        # persisted up to position. This stops Synapse from doing a full table
+        # scan when a new writer announces itself over replication.
         with self._lock:
-            return self._return_factor * self._current_positions.get(instance_name, 0)
+            return self._return_factor * self._current_positions.get(
+                instance_name, self._persisted_upto_position
+            )
 
     def get_positions(self) -> Dict[str, int]:
         """Get a copy of the current positon map.
+
+        Note that this won't necessarily include all configured writers if some
+        writers haven't written anything yet.
         """
 
         with self._lock:
@@ -482,3 +548,104 @@ class MultiWriterIdGenerator:
                 # There was a gap in seen positions, so there is nothing more to
                 # do.
                 break
+
+    def _update_stream_positions_table_txn(self, txn: Cursor):
+        """Update the `stream_positions` table with newly persisted position.
+        """
+
+        if not self._writers:
+            return
+
+        # We upsert the value, ensuring on conflict that we always increase the
+        # value (or decrease if stream goes backwards).
+        sql = """
+            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+            VALUES (?, ?, ?)
+            ON CONFLICT (stream_name, instance_name)
+            DO UPDATE SET
+                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+        """ % {
+            "agg": "GREATEST" if self._positive else "LEAST",
+        }
+
+        pos = (self.get_current_token_for_writer(self._instance_name),)
+        txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+    """Helper class to convert a plain context manager to an async one.
+
+    This is mainly useful if you have a plain context manager but the interface
+    requires an async one.
+    """
+
+    inner = attr.ib()
+
+    async def __aenter__(self):
+        return self.inner.__enter__()
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+    """Async context manager returned by MultiWriterIdGenerator
+    """
+
+    id_gen = attr.ib(type=MultiWriterIdGenerator)
+    multiple_ids = attr.ib(type=Optional[int], default=None)
+    stream_ids = attr.ib(type=List[int], factory=list)
+
+    async def __aenter__(self) -> Union[int, List[int]]:
+        # It's safe to run this in autocommit mode as fetching values from a
+        # sequence ignores transaction semantics anyway.
+        self.stream_ids = await self.id_gen._db.runInteraction(
+            "_load_next_mult_id",
+            self.id_gen._load_next_mult_id_txn,
+            self.multiple_ids or 1,
+            db_autocommit=True,
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        with self.id_gen._lock:
+            assert max(self.id_gen._current_positions.values(), default=0) < min(
+                self.stream_ids
+            )
+
+            self.id_gen._unfinished_ids.update(self.stream_ids)
+
+        if self.multiple_ids is None:
+            return self.stream_ids[0] * self.id_gen._return_factor
+        else:
+            return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+    async def __aexit__(self, exc_type, exc, tb):
+        for i in self.stream_ids:
+            self.id_gen._mark_id_as_finished(i)
+
+        if exc_type is not None:
+            return False
+
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        #
+        # We do this in autocommit mode as a) the upsert works correctly outside
+        # transactions and b) reduces the amount of time the rows are locked
+        # for. If we don't do this then we'll often hit serialization errors due
+        # to the fact we default to REPEATABLE READ isolation levels.
+        if self.id_gen._writers:
+            await self.id_gen._db.runInteraction(
+                "MultiWriterIdGenerator._update_table",
+                self.id_gen._update_stream_positions_table_txn,
+                db_autocommit=True,
+            )
+
+        return False
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index ffc1894748..ff2d038ad2 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -13,11 +13,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+import logging
 import threading
 from typing import Callable, List, Optional
 
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.engines import (
+    BaseDatabaseEngine,
+    IncorrectDatabaseSetup,
+    PostgresEngine,
+)
+from synapse.storage.types import Connection, Cursor
+
+logger = logging.getLogger(__name__)
+
+
+_INCONSISTENT_SEQUENCE_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated
+table '%(table)s'. This can happen if Synapse has been downgraded and
+then upgraded again, or due to a bad migration.
+
+To fix this error, shut down Synapse (including any and all workers)
+and run the following SQL:
+
+    SELECT setval('%(seq)s', (
+        %(max_id_sql)s
+    ));
+
+See docs/postgres.md for more information.
+"""
 
 
 class SequenceGenerator(metaclass=abc.ABCMeta):
@@ -28,6 +52,23 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         """Gets the next ID in the sequence"""
         ...
 
+    @abc.abstractmethod
+    def check_consistency(
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        id_column: str,
+        positive: bool = True,
+    ):
+        """Should be called during start up to test that the current value of
+        the sequence is greater than or equal to the maximum ID in the table.
+
+        This is to handle various cases where the sequence value can get out
+        of sync with the table, e.g. if Synapse gets rolled back to a previous
+        version and the rolled forwards again.
+        """
+        ...
+
 
 class PostgresSequenceGenerator(SequenceGenerator):
     """An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -45,6 +86,54 @@ class PostgresSequenceGenerator(SequenceGenerator):
         )
         return [i for (i,) in txn]
 
+    def check_consistency(
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        id_column: str,
+        positive: bool = True,
+    ):
+        txn = db_conn.cursor(txn_name="sequence.check_consistency")
+
+        # First we get the current max ID from the table.
+        table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
+            "id": id_column,
+            "table": table,
+            "agg": "MAX" if positive else "-MIN",
+        }
+
+        txn.execute(table_sql)
+        row = txn.fetchone()
+        if not row:
+            # Table is empty, so nothing to do.
+            txn.close()
+            return
+
+        # Now we fetch the current value from the sequence and compare with the
+        # above.
+        max_stream_id = row[0]
+        txn.execute(
+            "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
+        )
+        last_value, is_called = txn.fetchone()
+        txn.close()
+
+        # If `is_called` is False then `last_value` is actually the value that
+        # will be generated next, so we decrement to get the true "last value".
+        if not is_called:
+            last_value -= 1
+
+        if max_stream_id > last_value:
+            logger.warning(
+                "Postgres sequence %s is behind table %s: %d < %d",
+                last_value,
+                max_stream_id,
+            )
+            raise IncorrectDatabaseSetup(
+                _INCONSISTENT_SEQUENCE_ERROR
+                % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
+            )
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
@@ -81,6 +170,12 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def check_consistency(
+        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+    ):
+        # There is nothing to do for in memory sequences
+        pass
+
 
 def build_sequence_generator(
     database_engine: BaseDatabaseEngine,