summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/database.py89
-rw-r--r--synapse/storage/databases/__init__.py2
-rw-r--r--synapse/storage/databases/main/__init__.py192
-rw-r--r--synapse/storage/databases/main/account_data.py9
-rw-r--r--synapse/storage/databases/main/censor_events.py21
-rw-r--r--synapse/storage/databases/main/client_ips.py109
-rw-r--r--synapse/storage/databases/main/devices.py280
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py111
-rw-r--r--synapse/storage/databases/main/event_federation.py60
-rw-r--r--synapse/storage/databases/main/event_push_actions.py267
-rw-r--r--synapse/storage/databases/main/events.py34
-rw-r--r--synapse/storage/databases/main/events_worker.py71
-rw-r--r--synapse/storage/databases/main/metrics.py207
-rw-r--r--synapse/storage/databases/main/monthly_active_users.py110
-rw-r--r--synapse/storage/databases/main/registration.py215
-rw-r--r--synapse/storage/databases/main/room.py24
-rw-r--r--synapse/storage/databases/main/roommember.py18
-rw-r--r--synapse/storage/databases/main/schema/delta/20/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/25/fts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/27/ts.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/30/as_users.py6
-rw-r--r--synapse/storage/databases/main/schema/delta/31/pushers.py19
-rw-r--r--synapse/storage/databases/main/schema/delta/31/search_update.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/event_fields.py2
-rw-r--r--synapse/storage/databases/main/schema/delta/33/remote_media_ts.py5
-rw-r--r--synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py7
-rw-r--r--synapse/storage/databases/main/schema/delta/57/local_current_membership.py1
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11dehydration.sql20
-rw-r--r--synapse/storage/databases/main/schema/delta/58/11fallback.sql24
-rw-r--r--synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres25
-rw-r--r--synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql17
-rw-r--r--synapse/storage/databases/main/stream.py295
-rw-r--r--synapse/storage/databases/main/transactions.py42
-rw-r--r--synapse/storage/databases/main/ui_auth.py6
-rw-r--r--synapse/storage/persist_events.py2
-rw-r--r--synapse/storage/prepare_database.py33
-rw-r--r--synapse/storage/types.py6
-rw-r--r--synapse/storage/util/id_generators.py20
-rw-r--r--synapse/storage/util/sequence.py17
39 files changed, 1480 insertions, 911 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6116191b16..0ba3a025cf 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -32,6 +32,7 @@ from typing import (
     overload,
 )
 
+import attr
 from prometheus_client import Histogram
 from typing_extensions import Literal
 
@@ -90,13 +91,17 @@ def make_pool(
     return adbapi.ConnectionPool(
         db_config.config["name"],
         cp_reactor=reactor,
-        cp_openfun=engine.on_new_connection,
+        cp_openfun=lambda conn: engine.on_new_connection(
+            LoggingDatabaseConnection(conn, engine, "on_new_connection")
+        ),
         **db_config.config.get("args", {})
     )
 
 
 def make_conn(
-    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    db_config: DatabaseConnectionConfig,
+    engine: BaseDatabaseEngine,
+    default_txn_name: str,
 ) -> Connection:
     """Make a new connection to the database and return it.
 
@@ -109,11 +114,60 @@ def make_conn(
         for k, v in db_config.config.get("args", {}).items()
         if not k.startswith("cp_")
     }
-    db_conn = engine.module.connect(**db_params)
+    native_db_conn = engine.module.connect(**db_params)
+    db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
+
     engine.on_new_connection(db_conn)
     return db_conn
 
 
+@attr.s(slots=True)
+class LoggingDatabaseConnection:
+    """A wrapper around a database connection that returns `LoggingTransaction`
+    as its cursor class.
+
+    This is mainly used on startup to ensure that queries get logged correctly
+    """
+
+    conn = attr.ib(type=Connection)
+    engine = attr.ib(type=BaseDatabaseEngine)
+    default_txn_name = attr.ib(type=str)
+
+    def cursor(
+        self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+    ) -> "LoggingTransaction":
+        if not txn_name:
+            txn_name = self.default_txn_name
+
+        return LoggingTransaction(
+            self.conn.cursor(),
+            name=txn_name,
+            database_engine=self.engine,
+            after_callbacks=after_callbacks,
+            exception_callbacks=exception_callbacks,
+        )
+
+    def close(self) -> None:
+        self.conn.close()
+
+    def commit(self) -> None:
+        self.conn.commit()
+
+    def rollback(self, *args, **kwargs) -> None:
+        self.conn.rollback(*args, **kwargs)
+
+    def __enter__(self) -> "Connection":
+        self.conn.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        return self.conn.__exit__(exc_type, exc_value, traceback)
+
+    # Proxy through any unknown lookups to the DB conn class.
+    def __getattr__(self, name):
+        return getattr(self.conn, name)
+
+
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
 #
 # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
@@ -247,6 +301,12 @@ class LoggingTransaction:
     def close(self) -> None:
         self.txn.close()
 
+    def __enter__(self) -> "LoggingTransaction":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
 
 class PerformanceCounters:
     def __init__(self):
@@ -395,7 +455,7 @@ class DatabasePool:
 
     def new_transaction(
         self,
-        conn: Connection,
+        conn: LoggingDatabaseConnection,
         desc: str,
         after_callbacks: List[_CallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
@@ -436,12 +496,10 @@ class DatabasePool:
             i = 0
             N = 5
             while True:
-                cursor = LoggingTransaction(
-                    conn.cursor(),
-                    name,
-                    self.engine,
-                    after_callbacks,
-                    exception_callbacks,
+                cursor = conn.cursor(
+                    txn_name=name,
+                    after_callbacks=after_callbacks,
+                    exception_callbacks=exception_callbacks,
                 )
                 try:
                     r = func(cursor, *args, **kwargs)
@@ -638,7 +696,10 @@ class DatabasePool:
                     if db_autocommit:
                         self.engine.attempt_to_set_autocommit(conn, True)
 
-                    return func(conn, *args, **kwargs)
+                    db_conn = LoggingDatabaseConnection(
+                        conn, self.engine, "runWithConnection"
+                    )
+                    return func(db_conn, *args, **kwargs)
                 finally:
                     if db_autocommit:
                         self.engine.attempt_to_set_autocommit(conn, False)
@@ -1678,7 +1739,7 @@ class DatabasePool:
 
     def get_cache_dict(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         entity_column: str,
         stream_column: str,
@@ -1699,9 +1760,7 @@ class DatabasePool:
             "limit": limit,
         }
 
-        sql = self.engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
+        txn = db_conn.cursor(txn_name="get_cache_dict")
         txn.execute(sql, (int(max_value),))
 
         cache = {row[0]: int(row[1]) for row in txn}
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index aa5d490624..0c24325011 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -46,7 +46,7 @@ class Databases:
             db_name = database_config.name
             engine = create_engine(database_config.config)
 
-            with make_conn(database_config, engine) as db_conn:
+            with make_conn(database_config, engine, "startup") as db_conn:
                 logger.info("[database config %r]: Checking database server", db_name)
                 engine.check_database(db_conn)
 
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 0cb12f4c61..9b16f45f3e 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,9 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import calendar
 import logging
-import time
 from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import PresenceState
@@ -268,9 +266,6 @@ class DataStore(
         self._stream_order_on_start = self.get_room_max_stream_ordering()
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
-        # Used in _generate_user_daily_visits to keep track of progress
-        self._last_user_visit_update = self._get_start_of_day()
-
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()
 
@@ -289,7 +284,6 @@ class DataStore(
             " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
             " WHERE state != ?"
         )
-        sql = self.database_engine.convert_param_style(sql)
 
         txn = db_conn.cursor()
         txn.execute(sql, (PresenceState.OFFLINE,))
@@ -301,192 +295,6 @@ class DataStore(
 
         return [UserPresenceState(**row) for row in rows]
 
-    async def count_daily_users(self) -> int:
-        """
-        Counts the number of users who used this homeserver in the last 24 hours.
-        """
-        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return await self.db_pool.runInteraction(
-            "count_daily_users", self._count_users, yesterday
-        )
-
-    async def count_monthly_users(self) -> int:
-        """
-        Counts the number of users who used this homeserver in the last 30 days.
-        Note this method is intended for phonehome metrics only and is different
-        from the mau figure in synapse.storage.monthly_active_users which,
-        amongst other things, includes a 3 day grace period before a user counts.
-        """
-        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return await self.db_pool.runInteraction(
-            "count_monthly_users", self._count_users, thirty_days_ago
-        )
-
-    def _count_users(self, txn, time_from):
-        """
-        Returns number of users seen in the past time_from period
-        """
-        sql = """
-            SELECT COALESCE(count(*), 0) FROM (
-                SELECT user_id FROM user_ips
-                WHERE last_seen > ?
-                GROUP BY user_id
-            ) u
-        """
-        txn.execute(sql, (time_from,))
-        (count,) = txn.fetchone()
-        return count
-
-    async def count_r30_users(self) -> Dict[str, int]:
-        """
-        Counts the number of 30 day retained users, defined as:-
-         * Users who have created their accounts more than 30 days ago
-         * Where last seen at most 30 days ago
-         * Where account creation and last_seen are > 30 days apart
-
-        Returns:
-             A mapping of counts globally as well as broken out by platform.
-        """
-
-        def _count_r30_users(txn):
-            thirty_days_in_secs = 86400 * 30
-            now = int(self._clock.time())
-            thirty_days_ago_in_secs = now - thirty_days_in_secs
-
-            sql = """
-                SELECT platform, COALESCE(count(*), 0) FROM (
-                     SELECT
-                        users.name, platform, users.creation_ts * 1000,
-                        MAX(uip.last_seen)
-                     FROM users
-                     INNER JOIN (
-                         SELECT
-                         user_id,
-                         last_seen,
-                         CASE
-                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
-                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
-                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
-                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
-                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
-                             ELSE 'unknown'
-                         END
-                         AS platform
-                         FROM user_ips
-                     ) uip
-                     ON users.name = uip.user_id
-                     AND users.appservice_id is NULL
-                     AND users.creation_ts < ?
-                     AND uip.last_seen/1000 > ?
-                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                     GROUP BY users.name, platform, users.creation_ts
-                ) u GROUP BY platform
-            """
-
-            results = {}
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            for row in txn:
-                if row[0] == "unknown":
-                    pass
-                results[row[0]] = row[1]
-
-            sql = """
-                SELECT COALESCE(count(*), 0) FROM (
-                    SELECT users.name, users.creation_ts * 1000,
-                                                        MAX(uip.last_seen)
-                    FROM users
-                    INNER JOIN (
-                        SELECT
-                        user_id,
-                        last_seen
-                        FROM user_ips
-                    ) uip
-                    ON users.name = uip.user_id
-                    AND appservice_id is NULL
-                    AND users.creation_ts < ?
-                    AND uip.last_seen/1000 > ?
-                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                    GROUP BY users.name, users.creation_ts
-                ) u
-            """
-
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            (count,) = txn.fetchone()
-            results["all"] = count
-
-            return results
-
-        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
-
-    def _get_start_of_day(self):
-        """
-        Returns millisecond unixtime for start of UTC day.
-        """
-        now = time.gmtime()
-        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
-        return today_start * 1000
-
-    async def generate_user_daily_visits(self) -> None:
-        """
-        Generates daily visit data for use in cohort/ retention analysis
-        """
-
-        def _generate_user_daily_visits(txn):
-            logger.info("Calling _generate_user_daily_visits")
-            today_start = self._get_start_of_day()
-            a_day_in_milliseconds = 24 * 60 * 60 * 1000
-            now = self.clock.time_msec()
-
-            sql = """
-                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
-                    SELECT u.user_id, u.device_id, ?
-                    FROM user_ips AS u
-                    LEFT JOIN (
-                      SELECT user_id, device_id, timestamp FROM user_daily_visits
-                      WHERE timestamp = ?
-                    ) udv
-                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
-                    INNER JOIN users ON users.name=u.user_id
-                    WHERE last_seen > ? AND last_seen <= ?
-                    AND udv.timestamp IS NULL AND users.is_guest=0
-                    AND users.appservice_id IS NULL
-                    GROUP BY u.user_id, u.device_id
-            """
-
-            # This means that the day has rolled over but there could still
-            # be entries from the previous day. There is an edge case
-            # where if the user logs in at 23:59 and overwrites their
-            # last_seen at 00:01 then they will not be counted in the
-            # previous day's stats - it is important that the query is run
-            # often to minimise this case.
-            if today_start > self._last_user_visit_update:
-                yesterday_start = today_start - a_day_in_milliseconds
-                txn.execute(
-                    sql,
-                    (
-                        yesterday_start,
-                        yesterday_start,
-                        self._last_user_visit_update,
-                        today_start,
-                    ),
-                )
-                self._last_user_visit_update = today_start
-
-            txn.execute(
-                sql, (today_start, today_start, self._last_user_visit_update, now)
-            )
-            # Update _last_user_visit_update to now. The reason to do this
-            # rather just clamping to the beginning of the day is to limit
-            # the size of the join - meaning that the query can be run more
-            # frequently
-            self._last_user_visit_update = now
-
-        await self.db_pool.runInteraction(
-            "generate_user_daily_visits", _generate_user_daily_visits
-        )
-
     async def get_users(self) -> List[Dict[str, Any]]:
         """Function to retrieve a list of users in users table.
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef81d73573..49ee23470d 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -18,6 +18,7 @@ import abc
 import logging
 from typing import Dict, List, Optional, Tuple
 
+from synapse.api.constants import AccountDataTypes
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
@@ -291,14 +292,18 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
         self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
     ) -> bool:
         ignored_account_data = await self.get_global_account_data_by_type_for_user(
-            "m.ignored_user_list",
+            AccountDataTypes.IGNORED_USER_LIST,
             ignorer_user_id,
             on_invalidate=cache_context.invalidate,
         )
         if not ignored_account_data:
             return False
 
-        return ignored_user_id in ignored_account_data.get("ignored_users", {})
+        try:
+            return ignored_user_id in ignored_account_data.get("ignored_users", {})
+        except TypeError:
+            # The type of the ignored_users field is invalid.
+            return False
 
 
 class AccountDataStore(AccountDataWorkerStore):
diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py
index f211ddbaf8..849bd5ba7a 100644
--- a/synapse/storage/databases/main/censor_events.py
+++ b/synapse/storage/databases/main/censor_events.py
@@ -17,12 +17,12 @@ import logging
 from typing import TYPE_CHECKING
 
 from synapse.events.utils import prune_event_dict
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
-from synapse.storage.databases.main.events import encode_json
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.util.frozenutils import frozendict_json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -35,14 +35,13 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
-        def _censor_redactions():
-            return run_as_background_process(
-                "_censor_redactions", self._censor_redactions
-            )
-
-        if self.hs.config.redaction_retention_period is not None:
-            hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+        if (
+            hs.config.run_background_tasks
+            and self.hs.config.redaction_retention_period is not None
+        ):
+            hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000)
 
+    @wrap_as_background_process("_censor_redactions")
     async def _censor_redactions(self):
         """Censors all redactions older than the configured period that haven't
         been censored yet.
@@ -105,7 +104,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 and original_event.internal_metadata.is_redacted()
             ):
                 # Redaction was allowed
-                pruned_json = encode_json(
+                pruned_json = frozendict_json_encoder.encode(
                     prune_event_dict(
                         original_event.room_version, original_event.get_dict()
                     )
@@ -171,7 +170,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
                 return
 
             # Prune the event's dict then convert it to JSON.
-            pruned_json = encode_json(
+            pruned_json = frozendict_json_encoder.encode(
                 prune_event_dict(event.room_version, event.get_dict())
             )
 
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 239c7a949c..a25a888443 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -351,7 +351,63 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         return updated
 
 
-class ClientIpStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.user_ips_max_age = hs.config.user_ips_max_age
+
+        if hs.config.run_background_tasks and self.user_ips_max_age:
+            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+
+    @wrap_as_background_process("prune_old_user_ips")
+    async def _prune_old_user_ips(self):
+        """Removes entries in user IPs older than the configured period.
+        """
+
+        if self.user_ips_max_age is None:
+            # Nothing to do
+            return
+
+        if not await self.db_pool.updates.has_completed_background_update(
+            "devices_last_seen"
+        ):
+            # Only start pruning if we have finished populating the devices
+            # last seen info.
+            return
+
+        # We do a slightly funky SQL delete to ensure we don't try and delete
+        # too much at once (as the table may be very large from before we
+        # started pruning).
+        #
+        # This works by finding the max last_seen that is less than the given
+        # time, but has no more than N rows before it, deleting all rows with
+        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+        # returns exactly one row).
+        sql = """
+            DELETE FROM user_ips
+            WHERE last_seen <= (
+                SELECT COALESCE(MAX(last_seen), -1)
+                FROM (
+                    SELECT last_seen FROM user_ips
+                    WHERE last_seen <= ?
+                    ORDER BY last_seen ASC
+                    LIMIT 5000
+                ) AS u
+            )
+        """
+
+        timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+        def _prune_old_user_ips_txn(txn):
+            txn.execute(sql, (timestamp,))
+
+        await self.db_pool.runInteraction(
+            "_prune_old_user_ips", _prune_old_user_ips_txn
+        )
+
+
+class ClientIpStore(ClientIpWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
 
         self.client_ip_last_seen = Cache(
@@ -360,8 +416,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
 
         super().__init__(database, db_conn, hs)
 
-        self.user_ips_max_age = hs.config.user_ips_max_age
-
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
         self._batch_row_update = {}
 
@@ -372,9 +426,6 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             "before", "shutdown", self._update_client_ips_batch
         )
 
-        if self.user_ips_max_age:
-            self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
-
     async def insert_client_ip(
         self, user_id, access_token, ip, user_agent, device_id, now=None
     ):
@@ -525,49 +576,3 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             }
             for (access_token, ip), (user_agent, last_seen) in results.items()
         ]
-
-    @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
-        """Removes entries in user IPs older than the configured period.
-        """
-
-        if self.user_ips_max_age is None:
-            # Nothing to do
-            return
-
-        if not await self.db_pool.updates.has_completed_background_update(
-            "devices_last_seen"
-        ):
-            # Only start pruning if we have finished populating the devices
-            # last seen info.
-            return
-
-        # We do a slightly funky SQL delete to ensure we don't try and delete
-        # too much at once (as the table may be very large from before we
-        # started pruning).
-        #
-        # This works by finding the max last_seen that is less than the given
-        # time, but has no more than N rows before it, deleting all rows with
-        # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
-        # returns exactly one row).
-        sql = """
-            DELETE FROM user_ips
-            WHERE last_seen <= (
-                SELECT COALESCE(MAX(last_seen), -1)
-                FROM (
-                    SELECT last_seen FROM user_ips
-                    WHERE last_seen <= ?
-                    ORDER BY last_seen ASC
-                    LIMIT 5000
-                ) AS u
-            )
-        """
-
-        timestamp = self.clock.time_msec() - self.user_ips_max_age
-
-        def _prune_old_user_ips_txn(txn):
-            txn.execute(sql, (timestamp,))
-
-        await self.db_pool.runInteraction(
-            "_prune_old_user_ips", _prune_old_user_ips_txn
-        )
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index fdf394c612..88fd97e1df 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@ from synapse.logging.opentracing import (
     trace,
     whitelisted_homeserver,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
@@ -33,7 +33,7 @@ from synapse.storage.database import (
     make_tuple_comparison_clause,
 )
 from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import Cache, cached, cachedList
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
@@ -48,6 +48,14 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
 class DeviceWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(
+                self._prune_old_outbound_device_pokes, 60 * 60 * 1000
+            )
+
     async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
         """Retrieve a device. Only returns devices that are not marked as
         hidden.
@@ -698,6 +706,172 @@ class DeviceWorkerStore(SQLBaseStore):
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
 
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        # FIXME: make sure device ID still exists in devices table
+        row = await self.db_pool.simple_select_one(
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            retcols=["device_id", "device_data"],
+            allow_none=True,
+        )
+        return (
+            (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
+        )
+
+    def _store_dehydrated_device_txn(
+        self, txn, user_id: str, device_id: str, device_data: str
+    ) -> Optional[str]:
+        old_device_id = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            retcol="device_id",
+            allow_none=True,
+        )
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="dehydrated_devices",
+            keyvalues={"user_id": user_id},
+            values={"device_id": device_id, "device_data": device_data},
+        )
+        return old_device_id
+
+    async def store_dehydrated_device(
+        self, user_id: str, device_id: str, device_data: JsonDict
+    ) -> Optional[str]:
+        """Store a dehydrated device for a user.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_id: the ID of the dehydrated device
+            device_data: the dehydrated device information
+        Returns:
+            device id of the user's previous dehydrated device, if any
+        """
+        return await self.db_pool.runInteraction(
+            "store_dehydrated_device_txn",
+            self._store_dehydrated_device_txn,
+            user_id,
+            device_id,
+            json_encoder.encode(device_data),
+        )
+
+    async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
+        """Remove a dehydrated device.
+
+        Args:
+            user_id: the user that the dehydrated device belongs to
+            device_id: the ID of the dehydrated device
+        """
+        count = await self.db_pool.simple_delete(
+            "dehydrated_devices",
+            {"user_id": user_id, "device_id": device_id},
+            desc="remove_dehydrated_device",
+        )
+        return count >= 1
+
+    @wrap_as_background_process("prune_old_outbound_device_pokes")
+    async def _prune_old_outbound_device_pokes(
+        self, prune_age: int = 24 * 60 * 60 * 1000
+    ) -> None:
+        """Delete old entries out of the device_lists_outbound_pokes to ensure
+        that we don't fill up due to dead servers.
+
+        Normally, we try to send device updates as a delta since a previous known point:
+        this is done by setting the prev_id in the m.device_list_update EDU. However,
+        for that to work, we have to have a complete record of each change to
+        each device, which can add up to quite a lot of data.
+
+        An alternative mechanism is that, if the remote server sees that it has missed
+        an entry in the stream_id sequence for a given user, it will request a full
+        list of that user's devices. Hence, we can reduce the amount of data we have to
+        store (and transmit in some future transaction), by clearing almost everything
+        for a given destination out of the database, and having the remote server
+        resync.
+
+        All we need to do is make sure we keep at least one row for each
+        (user, destination) pair, to remind us to send a m.device_list_update EDU for
+        that user when the destination comes back. It doesn't matter which device
+        we keep.
+        """
+        yesterday = self._clock.time_msec() - prune_age
+
+        def _prune_txn(txn):
+            # look for (user, destination) pairs which have an update older than
+            # the cutoff.
+            #
+            # For each pair, we also need to know the most recent stream_id, and
+            # an arbitrary device_id at that stream_id.
+            select_sql = """
+            SELECT
+                dlop1.destination,
+                dlop1.user_id,
+                MAX(dlop1.stream_id) AS stream_id,
+                (SELECT MIN(dlop2.device_id) AS device_id FROM
+                    device_lists_outbound_pokes dlop2
+                    WHERE dlop2.destination = dlop1.destination AND
+                      dlop2.user_id=dlop1.user_id AND
+                      dlop2.stream_id=MAX(dlop1.stream_id)
+                )
+            FROM device_lists_outbound_pokes dlop1
+                GROUP BY destination, user_id
+                HAVING min(ts) < ? AND count(*) > 1
+            """
+
+            txn.execute(select_sql, (yesterday,))
+            rows = txn.fetchall()
+
+            if not rows:
+                return
+
+            logger.info(
+                "Pruning old outbound device list updates for %i users/destinations: %s",
+                len(rows),
+                shortstr((row[0], row[1]) for row in rows),
+            )
+
+            # we want to keep the update with the highest stream_id for each user.
+            #
+            # there might be more than one update (with different device_ids) with the
+            # same stream_id, so we also delete all but one rows with the max stream id.
+            delete_sql = """
+                DELETE FROM device_lists_outbound_pokes
+                WHERE destination = ? AND user_id = ? AND (
+                    stream_id < ? OR
+                    (stream_id = ? AND device_id != ?)
+                )
+            """
+            count = 0
+            for (destination, user_id, stream_id, device_id) in rows:
+                txn.execute(
+                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+                )
+                count += txn.rowcount
+
+            # Since we've deleted unsent deltas, we need to remove the entry
+            # of last successful sent so that the prev_ids are correctly set.
+            sql = """
+                DELETE FROM device_lists_outbound_last_success
+                WHERE destination = ? AND user_id = ?
+            """
+            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+
+            logger.info("Pruned %d device list outbound pokes", count)
+
+        await self.db_pool.runInteraction(
+            "_prune_old_outbound_device_pokes", _prune_txn,
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -834,10 +1008,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             name="device_id_exists", keylen=2, max_entries=10000
         )
 
-        self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
-
     async def store_device(
-        self, user_id: str, device_id: str, initial_device_display_name: str
+        self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
     ) -> bool:
         """Ensure the given device is known; add it to the store if not
 
@@ -955,7 +1127,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     async def update_remote_device_list_cache_entry(
-        self, user_id: str, device_id: str, content: JsonDict, stream_id: int
+        self, user_id: str, device_id: str, content: JsonDict, stream_id: str
     ) -> None:
         """Updates a single device in the cache of a remote user's devicelist.
 
@@ -983,7 +1155,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_id: str,
         content: JsonDict,
-        stream_id: int,
+        stream_id: str,
     ) -> None:
         if content.get("deleted"):
             self.db_pool.simple_delete_txn(
@@ -1193,95 +1365,3 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 for device_id in device_ids
             ],
         )
-
-    def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
-        """Delete old entries out of the device_lists_outbound_pokes to ensure
-        that we don't fill up due to dead servers.
-
-        Normally, we try to send device updates as a delta since a previous known point:
-        this is done by setting the prev_id in the m.device_list_update EDU. However,
-        for that to work, we have to have a complete record of each change to
-        each device, which can add up to quite a lot of data.
-
-        An alternative mechanism is that, if the remote server sees that it has missed
-        an entry in the stream_id sequence for a given user, it will request a full
-        list of that user's devices. Hence, we can reduce the amount of data we have to
-        store (and transmit in some future transaction), by clearing almost everything
-        for a given destination out of the database, and having the remote server
-        resync.
-
-        All we need to do is make sure we keep at least one row for each
-        (user, destination) pair, to remind us to send a m.device_list_update EDU for
-        that user when the destination comes back. It doesn't matter which device
-        we keep.
-        """
-        yesterday = self._clock.time_msec() - prune_age
-
-        def _prune_txn(txn):
-            # look for (user, destination) pairs which have an update older than
-            # the cutoff.
-            #
-            # For each pair, we also need to know the most recent stream_id, and
-            # an arbitrary device_id at that stream_id.
-            select_sql = """
-            SELECT
-                dlop1.destination,
-                dlop1.user_id,
-                MAX(dlop1.stream_id) AS stream_id,
-                (SELECT MIN(dlop2.device_id) AS device_id FROM
-                    device_lists_outbound_pokes dlop2
-                    WHERE dlop2.destination = dlop1.destination AND
-                      dlop2.user_id=dlop1.user_id AND
-                      dlop2.stream_id=MAX(dlop1.stream_id)
-                )
-            FROM device_lists_outbound_pokes dlop1
-                GROUP BY destination, user_id
-                HAVING min(ts) < ? AND count(*) > 1
-            """
-
-            txn.execute(select_sql, (yesterday,))
-            rows = txn.fetchall()
-
-            if not rows:
-                return
-
-            logger.info(
-                "Pruning old outbound device list updates for %i users/destinations: %s",
-                len(rows),
-                shortstr((row[0], row[1]) for row in rows),
-            )
-
-            # we want to keep the update with the highest stream_id for each user.
-            #
-            # there might be more than one update (with different device_ids) with the
-            # same stream_id, so we also delete all but one rows with the max stream id.
-            delete_sql = """
-                DELETE FROM device_lists_outbound_pokes
-                WHERE destination = ? AND user_id = ? AND (
-                    stream_id < ? OR
-                    (stream_id = ? AND device_id != ?)
-                )
-            """
-            count = 0
-            for (destination, user_id, stream_id, device_id) in rows:
-                txn.execute(
-                    delete_sql, (destination, user_id, stream_id, stream_id, device_id)
-                )
-                count += txn.rowcount
-
-            # Since we've deleted unsent deltas, we need to remove the entry
-            # of last successful sent so that the prev_ids are correctly set.
-            sql = """
-                DELETE FROM device_lists_outbound_last_success
-                WHERE destination = ? AND user_id = ?
-            """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
-
-            logger.info("Pruned %d device list outbound pokes", count)
-
-        return run_as_background_process(
-            "prune_old_outbound_device_pokes",
-            self.db_pool.runInteraction,
-            "_prune_old_outbound_device_pokes",
-            _prune_txn,
-        )
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 22e1ed15d0..4415909414 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -367,6 +367,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
 
+    async def set_e2e_fallback_keys(
+        self, user_id: str, device_id: str, fallback_keys: JsonDict
+    ) -> None:
+        """Set the user's e2e fallback keys.
+
+        Args:
+            user_id: the user whose keys are being set
+            device_id: the device whose keys are being set
+            fallback_keys: the keys to set.  This is a map from key ID (which is
+                of the form "algorithm:id") to key data.
+        """
+        # fallback_keys will usually only have one item in it, so using a for
+        # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
+        # FIXME: make sure that only one key per algorithm is uploaded
+        for key_id, fallback_key in fallback_keys.items():
+            algorithm, key_id = key_id.split(":", 1)
+            await self.db_pool.simple_upsert(
+                "e2e_fallback_keys_json",
+                keyvalues={
+                    "user_id": user_id,
+                    "device_id": device_id,
+                    "algorithm": algorithm,
+                },
+                values={
+                    "key_id": key_id,
+                    "key_json": json_encoder.encode(fallback_key),
+                    "used": False,
+                },
+                desc="set_e2e_fallback_key",
+            )
+
+        await self.invalidate_cache_and_stream(
+            "get_e2e_unused_fallback_key_types", (user_id, device_id)
+        )
+
+    @cached(max_entries=10000)
+    async def get_e2e_unused_fallback_key_types(
+        self, user_id: str, device_id: str
+    ) -> List[str]:
+        """Returns the fallback key types that have an unused key.
+
+        Args:
+            user_id: the user whose keys are being queried
+            device_id: the device whose keys are being queried
+
+        Returns:
+            a list of key types
+        """
+        return await self.db_pool.simple_select_onecol(
+            "e2e_fallback_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
+            retcol="algorithm",
+            desc="get_e2e_unused_fallback_key_types",
+        )
+
     async def get_e2e_cross_signing_key(
         self, user_id: str, key_type: str, from_user_id: Optional[str] = None
     ) -> Optional[dict]:
@@ -701,15 +756,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                 " LIMIT 1"
             )
+            fallback_sql = (
+                "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
+                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+                " LIMIT 1"
+            )
             result = {}
             delete = []
+            used_fallbacks = []
             for user_id, device_id, algorithm in query_list:
                 user_result = result.setdefault(user_id, {})
                 device_result = user_result.setdefault(device_id, {})
                 txn.execute(sql, (user_id, device_id, algorithm))
-                for key_id, key_json in txn:
+                otk_row = txn.fetchone()
+                if otk_row is not None:
+                    key_id, key_json = otk_row
                     device_result[algorithm + ":" + key_id] = key_json
                     delete.append((user_id, device_id, algorithm, key_id))
+                else:
+                    # no one-time key available, so see if there's a fallback
+                    # key
+                    txn.execute(fallback_sql, (user_id, device_id, algorithm))
+                    fallback_row = txn.fetchone()
+                    if fallback_row is not None:
+                        key_id, key_json, used = fallback_row
+                        device_result[algorithm + ":" + key_id] = key_json
+                        if not used:
+                            used_fallbacks.append(
+                                (user_id, device_id, algorithm, key_id)
+                            )
+
+            # drop any one-time keys that were claimed
             sql = (
                 "DELETE FROM e2e_one_time_keys_json"
                 " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@@ -726,6 +803,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 self._invalidate_cache_and_stream(
                     txn, self.count_e2e_one_time_keys, (user_id, device_id)
                 )
+            # mark fallback keys as used
+            for user_id, device_id, algorithm, key_id in used_fallbacks:
+                self.db_pool.simple_update_txn(
+                    txn,
+                    "e2e_fallback_keys_json",
+                    {
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        "algorithm": algorithm,
+                        "key_id": key_id,
+                    },
+                    {"used": True},
+                )
+                self._invalidate_cache_and_stream(
+                    txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+                )
+
             return result
 
         return await self.db_pool.runInteraction(
@@ -754,6 +848,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="dehydrated_devices",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="e2e_fallback_keys_json",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+            )
 
         await self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6d3689c09e..a6279a6c13 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -19,7 +19,7 @@ from typing import Dict, Iterable, List, Set, Tuple
 
 from synapse.api.errors import StoreError
 from synapse.events import EventBase
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
@@ -32,6 +32,14 @@ logger = logging.getLogger(__name__)
 
 
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            hs.get_clock().looping_call(
+                self._delete_old_forward_extrem_cache, 60 * 60 * 1000
+            )
+
     async def get_auth_chain(
         self, event_ids: Collection[str], include_given: bool = False
     ) -> List[EventBase]:
@@ -586,6 +594,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return [row["event_id"] for row in rows]
 
+    @wrap_as_background_process("delete_old_forward_extrem_cache")
+    async def _delete_old_forward_extrem_cache(self) -> None:
+        def _delete_old_forward_extrem_cache_txn(txn):
+            # Delete entries older than a month, while making sure we don't delete
+            # the only entries for a room.
+            sql = """
+                DELETE FROM stream_ordering_to_exterm
+                WHERE
+                room_id IN (
+                    SELECT room_id
+                    FROM stream_ordering_to_exterm
+                    WHERE stream_ordering > ?
+                ) AND stream_ordering < ?
+            """
+            txn.execute(
+                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
+            )
+
+        await self.db_pool.runInteraction(
+            "_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
+        )
+
 
 class EventFederationStore(EventFederationWorkerStore):
     """ Responsible for storing and serving up the various graphs associated
@@ -606,34 +636,6 @@ class EventFederationStore(EventFederationWorkerStore):
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
-        hs.get_clock().looping_call(
-            self._delete_old_forward_extrem_cache, 60 * 60 * 1000
-        )
-
-    def _delete_old_forward_extrem_cache(self):
-        def _delete_old_forward_extrem_cache_txn(txn):
-            # Delete entries older than a month, while making sure we don't delete
-            # the only entries for a room.
-            sql = """
-                DELETE FROM stream_ordering_to_exterm
-                WHERE
-                room_id IN (
-                    SELECT room_id
-                    FROM stream_ordering_to_exterm
-                    WHERE stream_ordering > ?
-                ) AND stream_ordering < ?
-            """
-            txn.execute(
-                sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
-            )
-
-        return run_as_background_process(
-            "delete_old_forward_extrem_cache",
-            self.db_pool.runInteraction,
-            "_delete_old_forward_extrem_cache",
-            _delete_old_forward_extrem_cache_txn,
-        )
-
     async def clean_room_for_join(self, room_id):
         return await self.db_pool.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 62f1738732..2e56dfaf31 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -13,15 +13,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import Dict, List, Optional, Tuple, Union
 
 import attr
 
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 
@@ -74,19 +73,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         self.stream_ordering_month_ago = None
         self.stream_ordering_day_ago = None
 
-        cur = LoggingTransaction(
-            db_conn.cursor(),
-            name="_find_stream_orderings_for_times_txn",
-            database_engine=self.database_engine,
-        )
+        cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
         self._find_stream_orderings_for_times_txn(cur)
         cur.close()
 
         self.find_stream_orderings_looping_call = self._clock.looping_call(
             self._find_stream_orderings_for_times, 10 * 60 * 1000
         )
+
         self._rotate_delay = 3
         self._rotate_count = 10000
+        self._doing_notif_rotation = False
+        if hs.config.run_background_tasks:
+            self._rotate_notif_loop = self._clock.looping_call(
+                self._rotate_notifs, 30 * 60 * 1000
+            )
 
     @cached(num_args=3, tree=True, max_entries=5000)
     async def get_unread_event_push_actions_by_room_for_user(
@@ -518,15 +519,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 "Error removing push actions after event persistence failure"
             )
 
-    def _find_stream_orderings_for_times(self):
-        return run_as_background_process(
-            "event_push_action_stream_orderings",
-            self.db_pool.runInteraction,
+    @wrap_as_background_process("event_push_action_stream_orderings")
+    async def _find_stream_orderings_for_times(self) -> None:
+        await self.db_pool.runInteraction(
             "_find_stream_orderings_for_times",
             self._find_stream_orderings_for_times_txn,
         )
 
-    def _find_stream_orderings_for_times_txn(self, txn):
+    def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None:
         logger.info("Searching for stream ordering 1 month ago")
         self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn(
             txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
@@ -656,129 +656,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         )
         return result[0] if result else None
 
-
-class EventPushActionsStore(EventPushActionsWorkerStore):
-    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
-
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self.db_pool.updates.register_background_index_update(
-            self.EPA_HIGHLIGHT_INDEX,
-            index_name="event_push_actions_u_highlight",
-            table="event_push_actions",
-            columns=["user_id", "stream_ordering"],
-        )
-
-        self.db_pool.updates.register_background_index_update(
-            "event_push_actions_highlights_index",
-            index_name="event_push_actions_highlights_index",
-            table="event_push_actions",
-            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
-            where_clause="highlight=1",
-        )
-
-        self._doing_notif_rotation = False
-        self._rotate_notif_loop = self._clock.looping_call(
-            self._start_rotate_notifs, 30 * 60 * 1000
-        )
-
-    async def get_push_actions_for_user(
-        self, user_id, before=None, limit=50, only_highlight=False
-    ):
-        def f(txn):
-            before_clause = ""
-            if before:
-                before_clause = "AND epa.stream_ordering < ?"
-                args = [user_id, before, limit]
-            else:
-                args = [user_id, limit]
-
-            if only_highlight:
-                if len(before_clause) > 0:
-                    before_clause += " "
-                before_clause += "AND epa.highlight = 1"
-
-            # NB. This assumes event_ids are globally unique since
-            # it makes the query easier to index
-            sql = (
-                "SELECT epa.event_id, epa.room_id,"
-                " epa.stream_ordering, epa.topological_ordering,"
-                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
-                " FROM event_push_actions epa, events e"
-                " WHERE epa.event_id = e.event_id"
-                " AND epa.user_id = ? %s"
-                " AND epa.notif = 1"
-                " ORDER BY epa.stream_ordering DESC"
-                " LIMIT ?" % (before_clause,)
-            )
-            txn.execute(sql, args)
-            return self.db_pool.cursor_to_dict(txn)
-
-        push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
-        for pa in push_actions:
-            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
-        return push_actions
-
-    async def get_latest_push_action_stream_ordering(self):
-        def f(txn):
-            txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
-            return txn.fetchone()
-
-        result = await self.db_pool.runInteraction(
-            "get_latest_push_action_stream_ordering", f
-        )
-        return result[0] or 0
-
-    def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transcation
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                                  not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
-
-    def _start_rotate_notifs(self):
-        return run_as_background_process("rotate_notifs", self._rotate_notifs)
-
+    @wrap_as_background_process("rotate_notifs")
     async def _rotate_notifs(self):
         if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
             return
@@ -958,6 +836,121 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         )
 
 
+class EventPushActionsStore(EventPushActionsWorkerStore):
+    EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
+
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self.db_pool.updates.register_background_index_update(
+            self.EPA_HIGHLIGHT_INDEX,
+            index_name="event_push_actions_u_highlight",
+            table="event_push_actions",
+            columns=["user_id", "stream_ordering"],
+        )
+
+        self.db_pool.updates.register_background_index_update(
+            "event_push_actions_highlights_index",
+            index_name="event_push_actions_highlights_index",
+            table="event_push_actions",
+            columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
+            where_clause="highlight=1",
+        )
+
+    async def get_push_actions_for_user(
+        self, user_id, before=None, limit=50, only_highlight=False
+    ):
+        def f(txn):
+            before_clause = ""
+            if before:
+                before_clause = "AND epa.stream_ordering < ?"
+                args = [user_id, before, limit]
+            else:
+                args = [user_id, limit]
+
+            if only_highlight:
+                if len(before_clause) > 0:
+                    before_clause += " "
+                before_clause += "AND epa.highlight = 1"
+
+            # NB. This assumes event_ids are globally unique since
+            # it makes the query easier to index
+            sql = (
+                "SELECT epa.event_id, epa.room_id,"
+                " epa.stream_ordering, epa.topological_ordering,"
+                " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
+                " FROM event_push_actions epa, events e"
+                " WHERE epa.event_id = e.event_id"
+                " AND epa.user_id = ? %s"
+                " AND epa.notif = 1"
+                " ORDER BY epa.stream_ordering DESC"
+                " LIMIT ?" % (before_clause,)
+            )
+            txn.execute(sql, args)
+            return self.db_pool.cursor_to_dict(txn)
+
+        push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
+        for pa in push_actions:
+            pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
+        return push_actions
+
+    async def get_latest_push_action_stream_ordering(self):
+        def f(txn):
+            txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
+            return txn.fetchone()
+
+        result = await self.db_pool.runInteraction(
+            "get_latest_push_action_stream_ordering", f
+        )
+        return result[0] or 0
+
+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
+        """
+        Purges old push actions for a user and room before a given
+        stream_ordering.
+
+        We however keep a months worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transcation
+            room_id: Room ID to delete from
+            user_id: user ID to delete for
+            stream_ordering: The lowest stream ordering which will
+                                  not be deleted.
+        """
+        txn.call_after(
+            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id, user_id),
+        )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id
+        # we we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_odering before that.
+        txn.execute(
+            "DELETE FROM event_push_actions "
+            " WHERE user_id = ? AND room_id = ? AND "
+            " stream_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        )
+
+        txn.execute(
+            """
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """,
+            (room_id, user_id, stream_ordering),
+        )
+
+
 def _action_has_highlight(actions):
     for action in actions:
         try:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 18def01f50..b19c424ba9 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -52,16 +52,6 @@ event_counter = Counter(
 )
 
 
-def encode_json(json_object):
-    """
-    Encode a Python object as JSON and return it in a Unicode string.
-    """
-    out = frozendict_json_encoder.encode(json_object)
-    if isinstance(out, bytes):
-        out = out.decode("utf8")
-    return out
-
-
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
@@ -341,6 +331,10 @@ class PersistEventsStore:
         min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
         max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
 
+        # stream orderings should have been assigned by now
+        assert min_stream_order
+        assert max_stream_order
+
         self._update_forward_extremities_txn(
             txn,
             new_forward_extremities=new_forward_extremeties,
@@ -432,12 +426,12 @@ class PersistEventsStore:
                 # so that async background tasks get told what happened.
                 sql = """
                     INSERT INTO current_state_delta_stream
-                        (stream_id, room_id, type, state_key, event_id, prev_event_id)
-                    SELECT ?, room_id, type, state_key, null, event_id
+                        (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+                    SELECT ?, ?, room_id, type, state_key, null, event_id
                         FROM current_state_events
                         WHERE room_id = ?
                 """
-                txn.execute(sql, (stream_id, room_id))
+                txn.execute(sql, (stream_id, self._instance_name, room_id))
 
                 self.db_pool.simple_delete_txn(
                     txn, table="current_state_events", keyvalues={"room_id": room_id},
@@ -458,8 +452,8 @@ class PersistEventsStore:
                 #
                 sql = """
                     INSERT INTO current_state_delta_stream
-                    (stream_id, room_id, type, state_key, event_id, prev_event_id)
-                    SELECT ?, ?, ?, ?, ?, (
+                    (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
+                    SELECT ?, ?, ?, ?, ?, ?, (
                         SELECT event_id FROM current_state_events
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
@@ -469,6 +463,7 @@ class PersistEventsStore:
                     (
                         (
                             stream_id,
+                            self._instance_name,
                             room_id,
                             etype,
                             state_key,
@@ -743,7 +738,9 @@ class PersistEventsStore:
                     logger.exception("")
                     raise
 
-                metadata_json = encode_json(event.internal_metadata.get_dict())
+                metadata_json = frozendict_json_encoder.encode(
+                    event.internal_metadata.get_dict()
+                )
 
                 sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
                 txn.execute(sql, (metadata_json, event.event_id))
@@ -759,6 +756,7 @@ class PersistEventsStore:
                         "event_stream_ordering": stream_order,
                         "event_id": event.event_id,
                         "state_group": state_group_id,
+                        "instance_name": self._instance_name,
                     },
                 )
 
@@ -797,10 +795,10 @@ class PersistEventsStore:
                 {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
-                    "internal_metadata": encode_json(
+                    "internal_metadata": frozendict_json_encoder.encode(
                         event.internal_metadata.get_dict()
                     ),
-                    "json": encode_json(event_dict(event)),
+                    "json": frozendict_json_encoder.encode(event_dict(event)),
                     "format_version": event.format_version,
                 }
                 for event, _ in events_and_contexts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index f95679ebc4..4e74fafe43 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -74,6 +74,13 @@ class EventRedactBehaviour(Names):
 
 
 class EventsWorkerStore(SQLBaseStore):
+    # Whether to use dedicated DB threads for event fetching. This is only used
+    # if there are multiple DB threads available. When used will lock the DB
+    # thread for periods of time (so unit tests want to disable this when they
+    # run DB transactions on the main thread). See EVENT_QUEUE_* for more
+    # options controlling this.
+    USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
+
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
@@ -522,7 +529,11 @@ class EventsWorkerStore(SQLBaseStore):
 
                 if not event_list:
                     single_threaded = self.database_engine.single_threaded
-                    if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+                    if (
+                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                        or single_threaded
+                        or i > EVENT_QUEUE_ITERATIONS
+                    ):
                         self._event_fetch_ongoing -= 1
                         return
                     else:
@@ -712,6 +723,7 @@ class EventsWorkerStore(SQLBaseStore):
                 internal_metadata_dict=internal_metadata,
                 rejected_reason=rejected_reason,
             )
+            original_ev.internal_metadata.stream_ordering = row["stream_ordering"]
 
             event_map[event_id] = original_ev
 
@@ -779,6 +791,8 @@ class EventsWorkerStore(SQLBaseStore):
 
          * event_id (str)
 
+         * stream_ordering (int): stream ordering for this event
+
          * json (str): json-encoded event structure
 
          * internal_metadata (str): json-encoded internal metadata dict
@@ -811,13 +825,15 @@ class EventsWorkerStore(SQLBaseStore):
             sql = """\
                 SELECT
                   e.event_id,
-                  e.internal_metadata,
-                  e.json,
-                  e.format_version,
+                  e.stream_ordering,
+                  ej.internal_metadata,
+                  ej.json,
+                  ej.format_version,
                   r.room_version,
                   rej.reason
-                FROM event_json as e
-                  LEFT JOIN rooms r USING (room_id)
+                FROM events AS e
+                  JOIN event_json AS ej USING (event_id)
+                  LEFT JOIN rooms r ON r.room_id = e.room_id
                   LEFT JOIN rejections as rej USING (event_id)
                 WHERE """
 
@@ -831,11 +847,12 @@ class EventsWorkerStore(SQLBaseStore):
                 event_id = row[0]
                 event_dict[event_id] = {
                     "event_id": event_id,
-                    "internal_metadata": row[1],
-                    "json": row[2],
-                    "format_version": row[3],
-                    "room_version_id": row[4],
-                    "rejected_reason": row[5],
+                    "stream_ordering": row[1],
+                    "internal_metadata": row[2],
+                    "json": row[3],
+                    "format_version": row[4],
+                    "room_version_id": row[5],
+                    "rejected_reason": row[6],
                     "redactions": [],
                 }
 
@@ -1017,16 +1034,12 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_backfill_token(self):
-        """The current minimum token that backfilled events have reached"""
-        return -self._backfill_id_gen.get_current_token()
-
     def get_current_events_token(self):
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
     async def get_all_new_forward_event_rows(
-        self, last_id: int, current_id: int, limit: int
+        self, instance_name: str, last_id: int, current_id: int, limit: int
     ) -> List[Tuple]:
         """Returns new events, for the Events replication stream
 
@@ -1050,10 +1063,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " AND instance_name = ?"
                 " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
-            txn.execute(sql, (last_id, current_id, limit))
+            txn.execute(sql, (last_id, current_id, instance_name, limit))
             return txn.fetchall()
 
         return await self.db_pool.runInteraction(
@@ -1061,7 +1075,7 @@ class EventsWorkerStore(SQLBaseStore):
         )
 
     async def get_ex_outlier_stream_rows(
-        self, last_id: int, current_id: int
+        self, instance_name: str, last_id: int, current_id: int
     ) -> List[Tuple]:
         """Returns de-outliered events, for the Events replication stream
 
@@ -1080,16 +1094,17 @@ class EventsWorkerStore(SQLBaseStore):
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id"
                 " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? < event_stream_ordering"
                 " AND event_stream_ordering <= ?"
+                " AND out.instance_name = ?"
                 " ORDER BY event_stream_ordering ASC"
             )
 
-            txn.execute(sql, (last_id, current_id))
+            txn.execute(sql, (last_id, current_id, instance_name))
             return txn.fetchall()
 
         return await self.db_pool.runInteraction(
@@ -1102,6 +1117,9 @@ class EventsWorkerStore(SQLBaseStore):
         """Get updates for backfill replication stream, including all new
         backfilled events and events that have gone from being outliers to not.
 
+        NOTE: The IDs given here are from replication, and so should be
+        *positive*.
+
         Args:
             instance_name: The writer we want to fetch updates from. Unused
                 here since there is only ever one writer.
@@ -1132,10 +1150,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                "  AND instance_name = ?"
                 " ORDER BY stream_ordering ASC"
                 " LIMIT ?"
             )
-            txn.execute(sql, (-last_id, -current_id, limit))
+            txn.execute(sql, (-last_id, -current_id, instance_name, limit))
             new_event_updates = [(row[0], row[1:]) for row in txn]
 
             limited = False
@@ -1149,15 +1168,16 @@ class EventsWorkerStore(SQLBaseStore):
                 "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id"
                 " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
                 " WHERE ? > event_stream_ordering"
                 " AND event_stream_ordering >= ?"
+                " AND out.instance_name = ?"
                 " ORDER BY event_stream_ordering DESC"
             )
-            txn.execute(sql, (-last_id, -upper_bound))
+            txn.execute(sql, (-last_id, -upper_bound, instance_name))
             new_event_updates.extend((row[0], row[1:]) for row in txn)
 
             if len(new_event_updates) >= limit:
@@ -1171,7 +1191,7 @@ class EventsWorkerStore(SQLBaseStore):
         )
 
     async def get_all_updated_current_state_deltas(
-        self, from_token: int, to_token: int, target_row_count: int
+        self, instance_name: str, from_token: int, to_token: int, target_row_count: int
     ) -> Tuple[List[Tuple], int, bool]:
         """Fetch updates from current_state_delta_stream
 
@@ -1197,9 +1217,10 @@ class EventsWorkerStore(SQLBaseStore):
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE ? < stream_id AND stream_id <= ?
+                    AND instance_name = ?
                 ORDER BY stream_id ASC LIMIT ?
             """
-            txn.execute(sql, (from_token, to_token, target_row_count))
+            txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
             return txn.fetchall()
 
         def get_deltas_for_stream_id_txn(txn, stream_id):
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 92099f95ce..0acf0617ca 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -12,15 +12,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import calendar
+import logging
+import time
+from typing import Dict
 
 from synapse.metrics import GaugeBucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
 
+logger = logging.getLogger(__name__)
+
 # Collect metrics on the number of forward extremities that exist.
 _extremities_collecter = GaugeBucketCollector(
     "synapse_forward_extremities",
@@ -51,15 +57,13 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         super().__init__(database, db_conn, hs)
 
         # Read the extrems every 60 minutes
-        def read_forward_extremities():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "read_forward_extremities", self._read_forward_extremities
-            )
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000)
 
-        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+        # Used in _generate_user_daily_visits to keep track of progress
+        self._last_user_visit_update = self._get_start_of_day()
 
+    @wrap_as_background_process("read_forward_extremities")
     async def _read_forward_extremities(self):
         def fetch(txn):
             txn.execute(
@@ -137,3 +141,190 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             return count
 
         return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
+
+    async def count_daily_users(self) -> int:
+        """
+        Counts the number of users who used this homeserver in the last 24 hours.
+        """
+        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+        return await self.db_pool.runInteraction(
+            "count_daily_users", self._count_users, yesterday
+        )
+
+    async def count_monthly_users(self) -> int:
+        """
+        Counts the number of users who used this homeserver in the last 30 days.
+        Note this method is intended for phonehome metrics only and is different
+        from the mau figure in synapse.storage.monthly_active_users which,
+        amongst other things, includes a 3 day grace period before a user counts.
+        """
+        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+        return await self.db_pool.runInteraction(
+            "count_monthly_users", self._count_users, thirty_days_ago
+        )
+
+    def _count_users(self, txn, time_from):
+        """
+        Returns number of users seen in the past time_from period
+        """
+        sql = """
+            SELECT COALESCE(count(*), 0) FROM (
+                SELECT user_id FROM user_ips
+                WHERE last_seen > ?
+                GROUP BY user_id
+            ) u
+        """
+        txn.execute(sql, (time_from,))
+        (count,) = txn.fetchone()
+        return count
+
+    async def count_r30_users(self) -> Dict[str, int]:
+        """
+        Counts the number of 30 day retained users, defined as:-
+         * Users who have created their accounts more than 30 days ago
+         * Where last seen at most 30 days ago
+         * Where account creation and last_seen are > 30 days apart
+
+        Returns:
+             A mapping of counts globally as well as broken out by platform.
+        """
+
+        def _count_r30_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+            sql = """
+                SELECT platform, COALESCE(count(*), 0) FROM (
+                     SELECT
+                        users.name, platform, users.creation_ts * 1000,
+                        MAX(uip.last_seen)
+                     FROM users
+                     INNER JOIN (
+                         SELECT
+                         user_id,
+                         last_seen,
+                         CASE
+                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
+                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+                             ELSE 'unknown'
+                         END
+                         AS platform
+                         FROM user_ips
+                     ) uip
+                     ON users.name = uip.user_id
+                     AND users.appservice_id is NULL
+                     AND users.creation_ts < ?
+                     AND uip.last_seen/1000 > ?
+                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                     GROUP BY users.name, platform, users.creation_ts
+                ) u GROUP BY platform
+            """
+
+            results = {}
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            for row in txn:
+                if row[0] == "unknown":
+                    pass
+                results[row[0]] = row[1]
+
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT users.name, users.creation_ts * 1000,
+                                                        MAX(uip.last_seen)
+                    FROM users
+                    INNER JOIN (
+                        SELECT
+                        user_id,
+                        last_seen
+                        FROM user_ips
+                    ) uip
+                    ON users.name = uip.user_id
+                    AND appservice_id is NULL
+                    AND users.creation_ts < ?
+                    AND uip.last_seen/1000 > ?
+                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                    GROUP BY users.name, users.creation_ts
+                ) u
+            """
+
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            (count,) = txn.fetchone()
+            results["all"] = count
+
+            return results
+
+        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+
+    def _get_start_of_day(self):
+        """
+        Returns millisecond unixtime for start of UTC day.
+        """
+        now = time.gmtime()
+        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+        return today_start * 1000
+
+    @wrap_as_background_process("generate_user_daily_visits")
+    async def generate_user_daily_visits(self) -> None:
+        """
+        Generates daily visit data for use in cohort/ retention analysis
+        """
+
+        def _generate_user_daily_visits(txn):
+            logger.info("Calling _generate_user_daily_visits")
+            today_start = self._get_start_of_day()
+            a_day_in_milliseconds = 24 * 60 * 60 * 1000
+            now = self._clock.time_msec()
+
+            sql = """
+                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+                    SELECT u.user_id, u.device_id, ?
+                    FROM user_ips AS u
+                    LEFT JOIN (
+                      SELECT user_id, device_id, timestamp FROM user_daily_visits
+                      WHERE timestamp = ?
+                    ) udv
+                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+                    INNER JOIN users ON users.name=u.user_id
+                    WHERE last_seen > ? AND last_seen <= ?
+                    AND udv.timestamp IS NULL AND users.is_guest=0
+                    AND users.appservice_id IS NULL
+                    GROUP BY u.user_id, u.device_id
+            """
+
+            # This means that the day has rolled over but there could still
+            # be entries from the previous day. There is an edge case
+            # where if the user logs in at 23:59 and overwrites their
+            # last_seen at 00:01 then they will not be counted in the
+            # previous day's stats - it is important that the query is run
+            # often to minimise this case.
+            if today_start > self._last_user_visit_update:
+                yesterday_start = today_start - a_day_in_milliseconds
+                txn.execute(
+                    sql,
+                    (
+                        yesterday_start,
+                        yesterday_start,
+                        self._last_user_visit_update,
+                        today_start,
+                    ),
+                )
+                self._last_user_visit_update = today_start
+
+            txn.execute(
+                sql, (today_start, today_start, self._last_user_visit_update, now)
+            )
+            # Update _last_user_visit_update to now. The reason to do this
+            # rather just clamping to the beginning of the day is to limit
+            # the size of the join - meaning that the query can be run more
+            # frequently
+            self._last_user_visit_update = now
+
+        await self.db_pool.runInteraction(
+            "generate_user_daily_visits", _generate_user_daily_visits
+        )
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e93aad33cd..d788dc0fc6 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,6 +15,7 @@
 import logging
 from typing import Dict, List
 
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached
@@ -32,6 +33,9 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
         self._clock = hs.get_clock()
         self.hs = hs
 
+        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+        self._max_mau_value = hs.config.max_mau_value
+
     @cached(num_args=0)
     async def get_monthly_active_count(self) -> int:
         """Generates current count of monthly active users
@@ -124,60 +128,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
             desc="user_last_seen_monthly_active",
         )
 
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        super().__init__(database, db_conn, hs)
-
-        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
-        self._mau_stats_only = hs.config.mau_stats_only
-        self._max_mau_value = hs.config.max_mau_value
-
-        # Do not add more reserved users than the total allowable number
-        # cur = LoggingTransaction(
-        self.db_pool.new_transaction(
-            db_conn,
-            "initialise_mau_threepids",
-            [],
-            [],
-            self._initialise_reserved_users,
-            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
-        )
-
-    def _initialise_reserved_users(self, txn, threepids):
-        """Ensures that reserved threepids are accounted for in the MAU table, should
-        be called on start up.
-
-        Args:
-            txn (cursor):
-            threepids (list[dict]): List of threepid dicts to reserve
-        """
-
-        # XXX what is this function trying to achieve?  It upserts into
-        # monthly_active_users for each *registered* reserved mau user, but why?
-        #
-        #  - shouldn't there already be an entry for each reserved user (at least
-        #    if they have been active recently)?
-        #
-        #  - if it's important that the timestamp is kept up to date, why do we only
-        #    run this at startup?
-
-        for tp in threepids:
-            user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
-
-            if user_id:
-                is_support = self.is_support_user_txn(txn, user_id)
-                if not is_support:
-                    # We do this manually here to avoid hitting #6791
-                    self.db_pool.simple_upsert_txn(
-                        txn,
-                        table="monthly_active_users",
-                        keyvalues={"user_id": user_id},
-                        values={"timestamp": int(self._clock.time_msec())},
-                    )
-            else:
-                logger.warning("mau limit reserved threepid %s not found in db" % tp)
-
+    @wrap_as_background_process("reap_monthly_active_users")
     async def reap_monthly_active_users(self):
         """Cleans out monthly active user table to ensure that no stale
         entries exist.
@@ -257,6 +208,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             "reap_monthly_active_users", _reap_users, reserved_users
         )
 
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        self._mau_stats_only = hs.config.mau_stats_only
+
+        # Do not add more reserved users than the total allowable number
+        self.db_pool.new_transaction(
+            db_conn,
+            "initialise_mau_threepids",
+            [],
+            [],
+            self._initialise_reserved_users,
+            hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
+        )
+
+    def _initialise_reserved_users(self, txn, threepids):
+        """Ensures that reserved threepids are accounted for in the MAU table, should
+        be called on start up.
+
+        Args:
+            txn (cursor):
+            threepids (list[dict]): List of threepid dicts to reserve
+        """
+
+        # XXX what is this function trying to achieve?  It upserts into
+        # monthly_active_users for each *registered* reserved mau user, but why?
+        #
+        #  - shouldn't there already be an entry for each reserved user (at least
+        #    if they have been active recently)?
+        #
+        #  - if it's important that the timestamp is kept up to date, why do we only
+        #    run this at startup?
+
+        for tp in threepids:
+            user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
+
+            if user_id:
+                is_support = self.is_support_user_txn(txn, user_id)
+                if not is_support:
+                    # We do this manually here to avoid hitting #6791
+                    self.db_pool.simple_upsert_txn(
+                        txn,
+                        table="monthly_active_users",
+                        keyvalues={"user_id": user_id},
+                        values={"timestamp": int(self._clock.time_msec())},
+                    )
+            else:
+                logger.warning("mau limit reserved threepid %s not found in db" % tp)
+
     async def upsert_monthly_active_user(self, user_id: str) -> None:
         """Updates or inserts the user into the monthly active user table, which
         is used to track the current MAU usage of the server
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index a83df7759d..236d3cdbe3 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017-2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,14 +14,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import re
 from typing import Any, Dict, List, Optional, Tuple
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Cursor
@@ -48,6 +47,18 @@ class RegistrationWorkerStore(SQLBaseStore):
             database.engine, find_max_generated_user_id_localpart, "user_id_seq",
         )
 
+        self._account_validity = hs.config.account_validity
+        if hs.config.run_background_tasks and self._account_validity.enabled:
+            self._clock.call_later(
+                0.0, self._set_expiration_date_when_missing,
+            )
+
+        # Create a background job for culling expired 3PID validity tokens
+        if hs.config.run_background_tasks:
+            self.clock.looping_call(
+                self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
+            )
+
     @cached()
     async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
@@ -778,6 +789,79 @@ class RegistrationWorkerStore(SQLBaseStore):
             "delete_threepid_session", delete_threepid_session_txn
         )
 
+    @wrap_as_background_process("cull_expired_threepid_validation_tokens")
+    async def cull_expired_threepid_validation_tokens(self) -> None:
+        """Remove threepid validation tokens with expiry dates that have passed"""
+
+        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+            sql = """
+            DELETE FROM threepid_validation_token WHERE
+            expires < ?
+            """
+            txn.execute(sql, (ts,))
+
+        await self.db_pool.runInteraction(
+            "cull_expired_threepid_validation_tokens",
+            cull_expired_threepid_validation_tokens_txn,
+            self.clock.time_msec(),
+        )
+
+    @wrap_as_background_process("account_validity_set_expiration_dates")
+    async def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database, filtering out deactivated users.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+            )
+            txn.execute(sql, [])
+
+            res = self.db_pool.cursor_to_dict(txn)
+            if res:
+                for user in res:
+                    self.set_expiration_date_for_user_txn(
+                        txn, user["name"], use_delta=True
+                    )
+
+        await self.db_pool.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+        """Sets an expiration date to the account with the given user ID.
+
+        Args:
+             user_id (str): User ID to set an expiration date for.
+             use_delta (bool): If set to False, the expiration date for the user will be
+                now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period - d ; now + period] range, d being a
+                delta equal to 10% of the validity period.
+        """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
+
+        if use_delta:
+            expiration_ts = self.rand.randrange(
+                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts,
+            )
+
+        self.db_pool.simple_upsert_txn(
+            txn,
+            "account_validity",
+            keyvalues={"user_id": user_id},
+            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+        )
+
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
@@ -911,28 +995,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        self._account_validity = hs.config.account_validity
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
-        if self._account_validity.enabled:
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "account_validity_set_expiration_dates",
-                self._set_expiration_date_when_missing,
-            )
-
-        # Create a background job for culling expired 3PID validity tokens
-        def start_cull():
-            # run as a background process to make sure that the database transactions
-            # have a logcontext to report to
-            return run_as_background_process(
-                "cull_expired_threepid_validation_tokens",
-                self.cull_expired_threepid_validation_tokens,
-            )
-
-        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
-
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -964,6 +1028,36 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="add_access_token_to_user",
         )
 
+    def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+        old_device_id = self.db_pool.simple_select_one_onecol_txn(
+            txn, "access_tokens", {"token": token}, "device_id"
+        )
+
+        self.db_pool.simple_update_txn(
+            txn, "access_tokens", {"token": token}, {"device_id": device_id}
+        )
+
+        self._invalidate_cache_and_stream(txn, self.get_user_by_access_token, (token,))
+
+        return old_device_id
+
+    async def set_device_for_access_token(self, token: str, device_id: str) -> str:
+        """Sets the device ID associated with an access token.
+
+        Args:
+            token: The access token to modify.
+            device_id: The new device ID.
+        Returns:
+            The old device ID associated with the access token.
+        """
+
+        return await self.db_pool.runInteraction(
+            "set_device_for_access_token",
+            self._set_device_for_access_token_txn,
+            token,
+            device_id,
+        )
+
     async def register_user(
         self,
         user_id: str,
@@ -1121,7 +1215,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             desc="record_user_external_id",
         )
 
-    async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
+    async def user_set_password_hash(
+        self, user_id: str, password_hash: Optional[str]
+    ) -> None:
         """
         NB. This does *not* evict any cache because the one use for this
             removes most of the entries subsequently anyway so it would be
@@ -1447,22 +1543,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
-    async def cull_expired_threepid_validation_tokens(self) -> None:
-        """Remove threepid validation tokens with expiry dates that have passed"""
-
-        def cull_expired_threepid_validation_tokens_txn(txn, ts):
-            sql = """
-            DELETE FROM threepid_validation_token WHERE
-            expires < ?
-            """
-            txn.execute(sql, (ts,))
-
-        await self.db_pool.runInteraction(
-            "cull_expired_threepid_validation_tokens",
-            cull_expired_threepid_validation_tokens_txn,
-            self.clock.time_msec(),
-        )
-
     async def set_user_deactivated_status(
         self, user_id: str, deactivated: bool
     ) -> None:
@@ -1492,61 +1572,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         )
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
-    async def _set_expiration_date_when_missing(self):
-        """
-        Retrieves the list of registered users that don't have an expiration date, and
-        adds an expiration date for each of them.
-        """
-
-        def select_users_with_no_expiration_date_txn(txn):
-            """Retrieves the list of registered users with no expiration date from the
-            database, filtering out deactivated users.
-            """
-            sql = (
-                "SELECT users.name FROM users"
-                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
-                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
-            )
-            txn.execute(sql, [])
-
-            res = self.db_pool.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
-
-        await self.db_pool.runInteraction(
-            "get_users_with_no_expiration_date",
-            select_users_with_no_expiration_date_txn,
-        )
-
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
-        """Sets an expiration date to the account with the given user ID.
-
-        Args:
-             user_id (str): User ID to set an expiration date for.
-             use_delta (bool): If set to False, the expiration date for the user will be
-                now + validity period. If set to True, this expiration date will be a
-                random value in the [now + period - d ; now + period] range, d being a
-                delta equal to 10% of the validity period.
-        """
-        now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
-
-        if use_delta:
-            expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
-                expiration_ts,
-            )
-
-        self.db_pool.simple_upsert_txn(
-            txn,
-            "account_validity",
-            keyvalues={"user_id": user_id},
-            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
-        )
-
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 3c7630857f..c0f2af0785 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -192,6 +192,18 @@ class RoomWorkerStore(SQLBaseStore):
             "count_public_rooms", _count_public_rooms_txn
         )
 
+    async def get_room_count(self) -> int:
+        """Retrieve the total number of rooms.
+        """
+
+        def f(txn):
+            sql = "SELECT count(*)  FROM rooms"
+            txn.execute(sql)
+            row = txn.fetchone()
+            return row[0] or 0
+
+        return await self.db_pool.runInteraction("get_rooms", f)
+
     async def get_largest_public_rooms(
         self,
         network_tuple: Optional[ThirdPartyInstanceID],
@@ -1292,18 +1304,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             )
         self.hs.get_notifier().on_new_replication_data()
 
-    async def get_room_count(self) -> int:
-        """Retrieve the total number of rooms.
-        """
-
-        def f(txn):
-            sql = "SELECT count(*)  FROM rooms"
-            txn.execute(sql)
-            row = txn.fetchone()
-            return row[0] or 0
-
-        return await self.db_pool.runInteraction("get_rooms", f)
-
     async def add_event_report(
         self,
         room_id: str,
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 86ffe2479e..20fcdaa529 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,12 +21,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
-    LoggingTransaction,
-    SQLBaseStore,
-    db_to_json,
-    make_in_list_sql_clause,
-)
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
@@ -60,15 +55,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # background update still running?
         self._current_state_events_membership_up_to_date = False
 
-        txn = LoggingTransaction(
-            db_conn.cursor(),
-            name="_check_safe_current_state_events_membership_updated",
-            database_engine=self.database_engine,
+        txn = db_conn.cursor(
+            txn_name="_check_safe_current_state_events_membership_updated"
         )
         self._check_safe_current_state_events_membership_updated_txn(txn)
         txn.close()
 
-        if self.hs.config.metrics_flags.known_servers:
+        if (
+            self.hs.config.run_background_tasks
+            and self.hs.config.metrics_flags.known_servers
+        ):
             self._known_servers_count = 1
             self.hs.get_clock().looping_call(
                 run_as_background_process,
diff --git a/synapse/storage/databases/main/schema/delta/20/pushers.py b/synapse/storage/databases/main/schema/delta/20/pushers.py
index 3edfcfd783..45b846e6a7 100644
--- a/synapse/storage/databases/main/schema/delta/20/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/20/pushers.py
@@ -66,16 +66,15 @@ def run_create(cur, database_engine, *args, **kwargs):
         row[8] = bytes(row[8]).decode("utf-8")
         row[11] = bytes(row[11]).decode("utf-8")
         cur.execute(
-            database_engine.convert_param_style(
-                """
-            INSERT into pushers2 (
-            id, user_name, access_token, profile_tag, kind,
-            app_id, app_display_name, device_display_name,
-            pushkey, ts, lang, data, last_token, last_success,
-            failing_since
-            ) values (%s)"""
-                % (",".join(["?" for _ in range(len(row))]))
-            ),
+            """
+                INSERT into pushers2 (
+                id, user_name, access_token, profile_tag, kind,
+                app_id, app_display_name, device_display_name,
+                pushkey, ts, lang, data, last_token, last_success,
+                failing_since
+                ) values (%s)
+            """
+            % (",".join(["?" for _ in range(len(row))])),
             row,
         )
         count += 1
diff --git a/synapse/storage/databases/main/schema/delta/25/fts.py b/synapse/storage/databases/main/schema/delta/25/fts.py
index ee675e71ff..21f57825d4 100644
--- a/synapse/storage/databases/main/schema/delta/25/fts.py
+++ b/synapse/storage/databases/main/schema/delta/25/fts.py
@@ -71,8 +71,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_search", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/27/ts.py b/synapse/storage/databases/main/schema/delta/27/ts.py
index b7972cfa8e..1c6058063f 100644
--- a/synapse/storage/databases/main/schema/delta/27/ts.py
+++ b/synapse/storage/databases/main/schema/delta/27/ts.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_origin_server_ts", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/30/as_users.py b/synapse/storage/databases/main/schema/delta/30/as_users.py
index b42c02710a..7f08fabe9f 100644
--- a/synapse/storage/databases/main/schema/delta/30/as_users.py
+++ b/synapse/storage/databases/main/schema/delta/30/as_users.py
@@ -59,9 +59,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
         user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
         for chunk in user_chunks:
             cur.execute(
-                database_engine.convert_param_style(
-                    "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
-                    % (",".join("?" for _ in chunk),)
-                ),
+                "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+                % (",".join("?" for _ in chunk),),
                 [as_id] + chunk,
             )
diff --git a/synapse/storage/databases/main/schema/delta/31/pushers.py b/synapse/storage/databases/main/schema/delta/31/pushers.py
index 9bb504aad5..5be81c806a 100644
--- a/synapse/storage/databases/main/schema/delta/31/pushers.py
+++ b/synapse/storage/databases/main/schema/delta/31/pushers.py
@@ -65,16 +65,15 @@ def run_create(cur, database_engine, *args, **kwargs):
         row = list(row)
         row[12] = token_to_stream_ordering(row[12])
         cur.execute(
-            database_engine.convert_param_style(
-                """
-            INSERT into pushers2 (
-            id, user_name, access_token, profile_tag, kind,
-            app_id, app_display_name, device_display_name,
-            pushkey, ts, lang, data, last_stream_ordering, last_success,
-            failing_since
-            ) values (%s)"""
-                % (",".join(["?" for _ in range(len(row))]))
-            ),
+            """
+                INSERT into pushers2 (
+                id, user_name, access_token, profile_tag, kind,
+                app_id, app_display_name, device_display_name,
+                pushkey, ts, lang, data, last_stream_ordering, last_success,
+                failing_since
+                ) values (%s)
+            """
+            % (",".join(["?" for _ in range(len(row))])),
             row,
         )
         count += 1
diff --git a/synapse/storage/databases/main/schema/delta/31/search_update.py b/synapse/storage/databases/main/schema/delta/31/search_update.py
index 63b757ade6..b84c844e3a 100644
--- a/synapse/storage/databases/main/schema/delta/31/search_update.py
+++ b/synapse/storage/databases/main/schema/delta/31/search_update.py
@@ -55,8 +55,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_search_order", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/33/event_fields.py b/synapse/storage/databases/main/schema/delta/33/event_fields.py
index a3e81eeac7..e928c66a8f 100644
--- a/synapse/storage/databases/main/schema/delta/33/event_fields.py
+++ b/synapse/storage/databases/main/schema/delta/33/event_fields.py
@@ -50,8 +50,6 @@ def run_create(cur, database_engine, *args, **kwargs):
             " VALUES (?, ?)"
         )
 
-        sql = database_engine.convert_param_style(sql)
-
         cur.execute(sql, ("event_fields_sender_url", progress_json))
 
 
diff --git a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
index a26057dfb6..ad875c733a 100644
--- a/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/databases/main/schema/delta/33/remote_media_ts.py
@@ -23,8 +23,5 @@ def run_create(cur, database_engine, *args, **kwargs):
 
 def run_upgrade(cur, database_engine, *args, **kwargs):
     cur.execute(
-        database_engine.convert_param_style(
-            "UPDATE remote_media_cache SET last_access_ts = ?"
-        ),
-        (int(time.time() * 1000),),
+        "UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
     )
diff --git a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
index 1de8b54961..bb7296852a 100644
--- a/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/databases/main/schema/delta/56/unique_user_filter_index.py
@@ -1,6 +1,8 @@
 import logging
+from io import StringIO
 
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +48,4 @@ def run_create(cur, database_engine, *args, **kwargs):
         select_clause,
     )
 
-    if isinstance(database_engine, PostgresEngine):
-        cur.execute(sql)
-    else:
-        cur.executescript(sql)
+    execute_statements_from_stream(cur, StringIO(sql))
diff --git a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
index 63b5acdcf7..44917f0a2e 100644
--- a/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
+++ b/synapse/storage/databases/main/schema/delta/57/local_current_membership.py
@@ -68,7 +68,6 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
                 INNER JOIN room_memberships AS r USING (event_id)
                 WHERE type = 'm.room.member' AND state_key LIKE ?
         """
-    sql = database_engine.convert_param_style(sql)
     cur.execute(sql, ("%:" + config.server_name,))
 
     cur.execute(
diff --git a/synapse/storage/databases/main/schema/delta/58/11dehydration.sql b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
new file mode 100644
index 0000000000..7851a0a825
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11dehydration.sql
@@ -0,0 +1,20 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS dehydrated_devices(
+    user_id TEXT NOT NULL PRIMARY KEY,
+    device_id TEXT NOT NULL,
+    device_data TEXT NOT NULL -- JSON-encoded client-defined data
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/11fallback.sql b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
new file mode 100644
index 0000000000..4ed981dbf8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/11fallback.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
+    user_id TEXT NOT NULL, -- The user this fallback key is for.
+    device_id TEXT NOT NULL, -- The device this fallback key is for.
+    algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
+    key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
+    key_json TEXT NOT NULL, -- The key as a JSON blob.
+    used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
+    CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
+);
diff --git a/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
new file mode 100644
index 0000000000..841186b826
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+    instance_id SERIAL PRIMARY KEY,
+    instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);
diff --git a/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
new file mode 100644
index 0000000000..ad1f481428
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/20instance_name_event_tables.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE current_state_delta_stream ADD COLUMN instance_name TEXT;
+ALTER TABLE ex_outlier_stream ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 37249f1e3f..e3b9ff5ca6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -53,7 +53,9 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
     )
 
 
+def _filter_results(
+    lower_token: Optional[RoomStreamToken],
+    upper_token: Optional[RoomStreamToken],
+    instance_name: str,
+    topological_ordering: int,
+    stream_ordering: int,
+) -> bool:
+    """Returns True if the event persisted by the given instance at the given
+    topological/stream_ordering falls between the two tokens (taking a None
+    token to mean unbounded).
+
+    Used to filter results from fetching events in the DB against the given
+    tokens. This is necessary to handle the case where the tokens include
+    position maps, which we handle by fetching more than necessary from the DB
+    and then filtering (rather than attempting to construct a complicated SQL
+    query).
+    """
+
+    event_historical_tuple = (
+        topological_ordering,
+        stream_ordering,
+    )
+
+    if lower_token:
+        if lower_token.topological is not None:
+            # If these are historical tokens we compare the `(topological, stream)`
+            # tuples.
+            if event_historical_tuple <= lower_token.as_historical_tuple():
+                return False
+
+        else:
+            # If these are live tokens we compare the stream ordering against the
+            # writers stream position.
+            if stream_ordering <= lower_token.get_stream_pos_for_instance(
+                instance_name
+            ):
+                return False
+
+    if upper_token:
+        if upper_token.topological is not None:
+            if upper_token.as_historical_tuple() < event_historical_tuple:
+                return False
+        else:
+            if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+                return False
+
+    return True
+
+
 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     def get_room_max_token(self) -> RoomStreamToken:
-        return RoomStreamToken(None, self.get_room_max_stream_ordering())
+        """Get a `RoomStreamToken` that marks the current maximum persisted
+        position of the events stream. Useful to get a token that represents
+        "now".
+
+        The token returned is a "live" token that may have an instance_map
+        component.
+        """
+
+        min_pos = self._stream_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._stream_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return RoomStreamToken(None, min_pos, positions)
 
     async def get_room_events_stream_for_rooms(
         self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if from_key == to_key:
             return [], from_key
 
-        from_id = from_key.stream
-        to_id = to_key.stream
-
-        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+        has_changed = self._events_stream_cache.has_entity_changed(
+            room_id, from_key.stream
+        )
 
         if not has_changed:
             return [], from_key
 
         def f(txn):
-            sql = (
-                "SELECT event_id, stream_ordering FROM events WHERE"
-                " room_id = ?"
-                " AND not outlier"
-                " AND stream_ordering > ? AND stream_ordering <= ?"
-                " ORDER BY stream_ordering %s LIMIT ?"
-            ) % (order,)
-            txn.execute(sql, (room_id, from_id, to_id, limit))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT event_id, instance_name, topological_ordering, stream_ordering
+                FROM events
+                WHERE
+                    room_id = ?
+                    AND not outlier
+                    AND stream_ordering > ? AND stream_ordering <= ?
+                ORDER BY stream_ordering %s LIMIT ?
+            """ % (
+                order,
+            )
+            txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ][:limit]
             return rows
 
         rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             [r.event_id for r in rows], get_prev_content=True
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._set_before_and_after(ret, rows, topo_order=False)
 
         if order.lower() == "desc":
             ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = from_key.stream
-        to_id = to_key.stream
-
         if from_key == to_key:
             return []
 
-        if from_id:
+        if from_key:
             has_changed = self._membership_stream_cache.has_entity_changed(
-                user_id, int(from_id)
+                user_id, int(from_key.stream)
             )
             if not has_changed:
                 return []
 
         def f(txn):
-            sql = (
-                "SELECT m.event_id, stream_ordering FROM events AS e,"
-                " room_memberships AS m"
-                " WHERE e.event_id = m.event_id"
-                " AND m.user_id = ?"
-                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-            )
-            txn.execute(sql, (user_id, from_id, to_id))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+                FROM events AS e, room_memberships AS m
+                WHERE e.event_id = m.event_id
+                    AND m.user_id = ?
+                    AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                ORDER BY e.stream_ordering ASC
+            """
+            txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ]
 
             return rows
 
@@ -546,7 +651,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
     async def get_room_event_before_stream_ordering(
         self, room_id: str, stream_ordering: int
-    ) -> Tuple[int, int, str]:
+    ) -> Optional[Tuple[int, int, str]]:
         """Gets details of the first event in a room at or before a stream ordering
 
         Args:
@@ -589,19 +694,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             )
             return "t%d-%d" % (topo, token)
 
-    async def get_stream_id_for_event(self, event_id: str) -> int:
-        """The stream ID for an event
-        Args:
-            event_id: The id of the event to look up a stream token for.
-        Raises:
-            StoreError if the event wasn't in the database.
-        Returns:
-            A stream ID.
-        """
-        return await self.db_pool.runInteraction(
-            "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
-        )
-
     def get_stream_id_for_event_txn(
         self, txn: LoggingTransaction, event_id: str, allow_none=False,
     ) -> int:
@@ -979,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         else:
             order = "ASC"
 
+        # The bounds for the stream tokens are complicated by the fact
+        # that we need to handle the instance_map part of the tokens. We do this
+        # by fetching all events between the min stream token and the maximum
+        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+        # then filtering the results.
+        if from_token.topological is not None:
+            from_bound = (
+                from_token.as_historical_tuple()
+            )  # type: Tuple[Optional[int], int]
+        elif direction == "b":
+            from_bound = (
+                None,
+                from_token.get_max_stream_pos(),
+            )
+        else:
+            from_bound = (
+                None,
+                from_token.stream,
+            )
+
+        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        if to_token:
+            if to_token.topological is not None:
+                to_bound = to_token.as_historical_tuple()
+            elif direction == "b":
+                to_bound = (
+                    None,
+                    to_token.stream,
+                )
+            else:
+                to_bound = (
+                    None,
+                    to_token.get_max_stream_pos(),
+                )
+
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.as_tuple(),
-            to_token=to_token.as_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
 
@@ -993,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             bounds += " AND " + filter_clause
             args.extend(filter_args)
 
-        args.append(int(limit))
+        # We fetch more events as we'll filter the result set
+        args.append(int(limit) * 2)
 
         select_keywords = "SELECT"
         join_clause = ""
@@ -1015,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 select_keywords += "DISTINCT"
 
         sql = """
-            %(select_keywords)s event_id, topological_ordering, stream_ordering
+            %(select_keywords)s
+                event_id, instance_name,
+                topological_ordering, stream_ordering
             FROM events
             %(join_clause)s
             WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1030,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         txn.execute(sql, args)
 
-        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+        # Filter the result set.
+        rows = [
+            _EventDictReturn(event_id, topological_ordering, stream_ordering)
+            for event_id, instance_name, topological_ordering, stream_ordering in txn
+            if _filter_results(
+                lower_token=to_token if direction == "b" else from_token,
+                upper_token=from_token if direction == "b" else to_token,
+                instance_name=instance_name,
+                topological_ordering=topological_ordering,
+                stream_ordering=stream_ordering,
+            )
+        ][:limit]
 
         if rows:
             topo = rows[-1].topological_ordering
@@ -1095,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         return (events, token)
 
+    @cached()
+    async def get_id_for_instance(self, instance_name: str) -> int:
+        """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+        """
+
+        def _get_id_for_instance_txn(txn):
+            instance_id = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+                allow_none=True,
+            )
+            if instance_id is not None:
+                return instance_id
+
+            # If we don't have an entry upsert one.
+            #
+            # We could do this before the first check, and rely on the cache for
+            # efficiency, but each UPSERT causes the next ID to increment which
+            # can quickly bloat the size of the generated IDs for new instances.
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                values={},
+            )
+
+            return self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+            )
+
+        return await self.db_pool.runInteraction(
+            "get_id_for_instance", _get_id_for_instance_txn
+        )
+
+    @cached()
+    async def get_name_from_instance_id(self, instance_id: int) -> str:
+        """Get the instance name from an ID previously returned by
+        `get_id_for_instance`.
+        """
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="instance_map",
+            keyvalues={"instance_id": instance_id},
+            retcol="instance_name",
+            desc="get_name_from_instance_id",
+        )
+
 
 class StreamStore(StreamWorkerStore):
     def get_room_max_stream_ordering(self) -> int:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 97aed1500e..7d46090267 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -19,7 +19,7 @@ from typing import Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -43,15 +43,33 @@ _UpdateTransactionRow = namedtuple(
 SENTINEL = object()
 
 
-class TransactionStore(SQLBaseStore):
+class TransactionWorkerStore(SQLBaseStore):
+    def __init__(self, database: DatabasePool, db_conn, hs):
+        super().__init__(database, db_conn, hs)
+
+        if hs.config.run_background_tasks:
+            self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000)
+
+    @wrap_as_background_process("cleanup_transactions")
+    async def _cleanup_transactions(self) -> None:
+        now = self._clock.time_msec()
+        month_ago = now - 30 * 24 * 60 * 60 * 1000
+
+        def _cleanup_transactions_txn(txn):
+            txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
+
+        await self.db_pool.runInteraction(
+            "_cleanup_transactions", _cleanup_transactions_txn
+        )
+
+
+class TransactionStore(TransactionWorkerStore):
     """A collection of queries for handling PDUs.
     """
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
-        self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
-
         self._destination_retry_cache = ExpiringCache(
             cache_name="get_destination_retry_timings",
             clock=self._clock,
@@ -266,22 +284,6 @@ class TransactionStore(SQLBaseStore):
                 },
             )
 
-    def _start_cleanup_transactions(self):
-        return run_as_background_process(
-            "cleanup_transactions", self._cleanup_transactions
-        )
-
-    async def _cleanup_transactions(self) -> None:
-        now = self._clock.time_msec()
-        month_ago = now - 30 * 24 * 60 * 60 * 1000
-
-        def _cleanup_transactions_txn(txn):
-            txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
-
-        await self.db_pool.runInteraction(
-            "_cleanup_transactions", _cleanup_transactions_txn
-        )
-
     async def store_destination_rooms_entries(
         self, destinations: Iterable[str], room_id: str, stream_ordering: int,
     ) -> None:
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 3b9211a6d2..79b7ece330 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -288,8 +288,6 @@ class UIAuthWorkerStore(SQLBaseStore):
         )
         return [(row["user_agent"], row["ip"]) for row in rows]
 
-
-class UIAuthStore(UIAuthWorkerStore):
     async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
         """
         Remove sessions which were last used earlier than the expiration time.
@@ -339,3 +337,7 @@ class UIAuthStore(UIAuthWorkerStore):
             iterable=session_ids,
             keyvalues={},
         )
+
+
+class UIAuthStore(UIAuthWorkerStore):
+    pass
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 72939f3984..4d2d88d1f0 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -248,6 +248,8 @@ class EventsPersistenceStorage:
         await make_deferred_yieldable(deferred)
 
         event_stream_id = event.internal_metadata.stream_ordering
+        # stream ordering should have been assigned by now
+        assert event_stream_id
 
         pos = PersistedEventPosition(self._instance_name, event_stream_id)
         return pos, self.main_store.get_room_max_token()
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 4957e77f4c..459754feab 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import imp
 import logging
 import os
@@ -24,9 +23,10 @@ from typing import Optional, TextIO
 import attr
 
 from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import BaseDatabaseEngine
 from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
 from synapse.types import Collection
 
 logger = logging.getLogger(__name__)
@@ -67,7 +67,7 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
 
 
 def prepare_database(
-    db_conn: Connection,
+    db_conn: LoggingDatabaseConnection,
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
     databases: Collection[str] = ["main", "state"],
@@ -89,7 +89,7 @@ def prepare_database(
     """
 
     try:
-        cur = db_conn.cursor()
+        cur = db_conn.cursor(txn_name="prepare_database")
 
         # sqlite does not automatically start transactions for DDL / SELECT statements,
         # so we start one before running anything. This ensures that any upgrades
@@ -258,9 +258,7 @@ def _setup_new_database(cur, database_engine, databases):
             executescript(cur, entry.absolute_path)
 
     cur.execute(
-        database_engine.convert_param_style(
-            "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
-        ),
+        "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
         (max_current_ver, False),
     )
 
@@ -486,17 +484,13 @@ def _upgrade_existing_database(
 
             # Mark as done.
             cur.execute(
-                database_engine.convert_param_style(
-                    "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
-                ),
+                "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)",
                 (v, relative_path),
             )
 
             cur.execute("DELETE FROM schema_version")
             cur.execute(
-                database_engine.convert_param_style(
-                    "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
-                ),
+                "INSERT INTO schema_version (version, upgraded) VALUES (?,?)",
                 (v, True),
             )
 
@@ -532,10 +526,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
             schemas to be applied
     """
     cur.execute(
-        database_engine.convert_param_style(
-            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
-        ),
-        (modname,),
+        "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
     )
     applied_deltas = {d for d, in cur}
     for (name, stream) in names_and_streams:
@@ -553,9 +544,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
 
         # Mark as done.
         cur.execute(
-            database_engine.convert_param_style(
-                "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
-            ),
+            "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)",
             (modname, name),
         )
 
@@ -627,9 +616,7 @@ def _get_or_create_schema_state(txn, database_engine):
 
     if current_version:
         txn.execute(
-            database_engine.convert_param_style(
-                "SELECT file FROM applied_schema_deltas WHERE version >= ?"
-            ),
+            "SELECT file FROM applied_schema_deltas WHERE version >= ?",
             (current_version,),
         )
         applied_deltas = [d for d, in txn]
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 2d2b560e74..970bb1b9da 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -61,3 +61,9 @@ class Connection(Protocol):
 
     def rollback(self, *args, **kwargs) -> None:
         ...
+
+    def __enter__(self) -> "Connection":
+        ...
+
+    def __exit__(self, exc_type, exc_value, traceback) -> bool:
+        ...
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ad017207aa..3d8da48f2d 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -55,7 +55,7 @@ def _load_current_id(db_conn, table, column, step=1):
     """
     # debug logging for https://github.com/matrix-org/synapse/issues/7968
     logger.info("initialising stream generator for %s(%s)", table, column)
-    cur = db_conn.cursor()
+    cur = db_conn.cursor(txn_name="_load_current_id")
     if step == 1:
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
     else:
@@ -270,7 +270,7 @@ class MultiWriterIdGenerator:
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
     ):
-        cur = db_conn.cursor()
+        cur = db_conn.cursor(txn_name="_load_current_ids")
 
         # Load the current positions of all writers for the stream.
         if self._writers:
@@ -284,15 +284,12 @@ class MultiWriterIdGenerator:
                     stream_name = ?
                     AND instance_name != ALL(?)
             """
-            sql = self._db.engine.convert_param_style(sql)
             cur.execute(sql, (self._stream_name, self._writers))
 
             sql = """
                 SELECT instance_name, stream_id FROM stream_positions
                 WHERE stream_name = ?
             """
-            sql = self._db.engine.convert_param_style(sql)
-
             cur.execute(sql, (self._stream_name,))
 
             self._current_positions = {
@@ -341,7 +338,6 @@ class MultiWriterIdGenerator:
                 "instance": instance_column,
                 "cmp": "<=" if self._positive else ">=",
             }
-            sql = self._db.engine.convert_param_style(sql)
             cur.execute(sql, (min_stream_id * self._return_factor,))
 
             self._persisted_upto_position = min_stream_id
@@ -422,7 +418,7 @@ class MultiWriterIdGenerator:
             self._unfinished_ids.discard(next_id)
             self._finished_ids.add(next_id)
 
-            new_cur = None
+            new_cur = None  # type: Optional[int]
 
             if self._unfinished_ids:
                 # If there are unfinished IDs then the new position will be the
@@ -528,6 +524,16 @@ class MultiWriterIdGenerator:
 
         heapq.heappush(self._known_persisted_positions, new_id)
 
+        # If we're a writer and we don't have any active writes we update our
+        # current position to the latest position seen. This allows the instance
+        # to report a recent position when asked, rather than a potentially old
+        # one (if this instance hasn't written anything for a while).
+        our_current_position = self._current_positions.get(self._instance_name)
+        if our_current_position and not self._unfinished_ids:
+            self._current_positions[self._instance_name] = max(
+                our_current_position, new_id
+            )
+
         # We move the current min position up if the minimum current positions
         # of all instances is higher (since by definition all positions less
         # that that have been persisted).
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 2dd95e2709..4386b6101e 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -17,6 +17,7 @@ import logging
 import threading
 from typing import Callable, List, Optional
 
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
@@ -53,7 +54,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def check_consistency(
-        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        id_column: str,
+        positive: bool = True,
     ):
         """Should be called during start up to test that the current value of
         the sequence is greater than or equal to the maximum ID in the table.
@@ -82,9 +87,13 @@ class PostgresSequenceGenerator(SequenceGenerator):
         return [i for (i,) in txn]
 
     def check_consistency(
-        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        id_column: str,
+        positive: bool = True,
     ):
-        txn = db_conn.cursor()
+        txn = db_conn.cursor(txn_name="sequence.check_consistency")
 
         # First we get the current max ID from the table.
         table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
@@ -117,6 +126,8 @@ class PostgresSequenceGenerator(SequenceGenerator):
         if max_stream_id > last_value:
             logger.warning(
                 "Postgres sequence %s is behind table %s: %d < %d",
+                self._sequence_name,
+                table,
                 last_value,
                 max_stream_id,
             )