Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/__init__.py | 20
-rw-r--r--  synapse/storage/_base.py | 1633
-rw-r--r--  synapse/storage/background_updates.py | 31
-rw-r--r--  synapse/storage/data_stores/__init__.py | 16
-rw-r--r--  synapse/storage/data_stores/main/__init__.py | 120
-rw-r--r--  synapse/storage/data_stores/main/account_data.py | 47
-rw-r--r--  synapse/storage/data_stores/main/appservice.py | 29
-rw-r--r--  synapse/storage/data_stores/main/cache.py | 133
-rw-r--r--  synapse/storage/data_stores/main/client_ips.py | 71
-rw-r--r--  synapse/storage/data_stores/main/deviceinbox.py | 40
-rw-r--r--  synapse/storage/data_stores/main/devices.py | 108
-rw-r--r--  synapse/storage/data_stores/main/directory.py | 20
-rw-r--r--  synapse/storage/data_stores/main/e2e_room_keys.py | 244
-rw-r--r--  synapse/storage/data_stores/main/end_to_end_keys.py | 73
-rw-r--r--  synapse/storage/data_stores/main/event_federation.py | 51
-rw-r--r--  synapse/storage/data_stores/main/event_push_actions.py | 57
-rw-r--r--  synapse/storage/data_stores/main/events.py | 245
-rw-r--r--  synapse/storage/data_stores/main/events_bg_updates.py | 86
-rw-r--r--  synapse/storage/data_stores/main/events_worker.py | 66
-rw-r--r--  synapse/storage/data_stores/main/filtering.py | 6
-rw-r--r--  synapse/storage/data_stores/main/group_server.py | 164
-rw-r--r--  synapse/storage/data_stores/main/keys.py | 12
-rw-r--r--  synapse/storage/data_stores/main/media_repository.py | 67
-rw-r--r--  synapse/storage/data_stores/main/monthly_active_users.py | 19
-rw-r--r--  synapse/storage/data_stores/main/openid.py | 6
-rw-r--r--  synapse/storage/data_stores/main/presence.py | 12
-rw-r--r--  synapse/storage/data_stores/main/profile.py | 28
-rw-r--r--  synapse/storage/data_stores/main/push_rule.py | 43
-rw-r--r--  synapse/storage/data_stores/main/pusher.py | 38
-rw-r--r--  synapse/storage/data_stores/main/receipts.py | 47
-rw-r--r--  synapse/storage/data_stores/main/registration.py | 278
-rw-r--r--  synapse/storage/data_stores/main/rejections.py | 4
-rw-r--r--  synapse/storage/data_stores/main/relations.py | 12
-rw-r--r--  synapse/storage/data_stores/main/room.py | 322
-rw-r--r--  synapse/storage/data_stores/main/roommember.py | 99
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql | 21
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql | 1
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql | 4
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql | 16
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql | 17
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/room_retention.sql | 33
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql | 3
-rw-r--r--  synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql | 22
-rw-r--r--  synapse/storage/data_stores/main/search.py | 70
-rw-r--r--  synapse/storage/data_stores/main/signatures.py | 4
-rw-r--r--  synapse/storage/data_stores/main/state.py | 86
-rw-r--r--  synapse/storage/data_stores/main/state_deltas.py | 8
-rw-r--r--  synapse/storage/data_stores/main/stats.py | 75
-rw-r--r--  synapse/storage/data_stores/main/stream.py | 40
-rw-r--r--  synapse/storage/data_stores/main/tags.py | 22
-rw-r--r--  synapse/storage/data_stores/main/transactions.py | 27
-rw-r--r--  synapse/storage/data_stores/main/user_directory.py | 122
-rw-r--r--  synapse/storage/data_stores/main/user_erasure_store.py | 6
-rw-r--r--  synapse/storage/database.py | 1490
-rw-r--r--  synapse/storage/prepare_database.py | 2
55 files changed, 3559 insertions, 2757 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 0460fe8cc9..ec89f645d4 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -17,10 +17,10 @@
 """
 The storage layer is split up into multiple parts to allow Synapse to run
 against different configurations of databases (e.g. single or multiple
-databases). The `data_stores` are classes that talk directly to a single
-database and have associated schemas, background updates, etc. On top of those
-there are (or will be) classes that provide high level interfaces that combine
-calls to multiple `data_stores`.
+databases). The `Database` class represents a single physical database. The
+`data_stores` are classes that talk directly to a `Database` instance and have
+associated schemas, background updates, etc. On top of those there are classes
+that provide high level interfaces that combine calls to multiple `data_stores`.
 
 There are also schemas that get applied to every database, regardless of the
 data stores associated with them (e.g. the schema version tables), which are
@@ -49,15 +49,3 @@ class Storage(object):
         self.persistence = EventsPersistenceStorage(hs, stores)
         self.purge_events = PurgeEventsStorage(hs, stores)
         self.state = StateGroupStorage(hs, stores)
-
-
-def are_all_users_on_domain(txn, database_engine, domain):
-    sql = database_engine.convert_param_style(
-        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
-    )
-    pat = "%:" + domain
-    txn.execute(sql, (pat,))
-    num_not_matching = txn.fetchall()[0][0]
-    if num_not_matching == 0:
-        return True
-    return False
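
The rewritten module docstring above describes the new layering: a `Database` object represents one physical database, and each data store talks to it rather than owning the connection machinery itself. A minimal sketch of what that means for a store, assuming the `Database` class exposes the same `runInteraction` helper that this commit removes from `SQLBaseStore` (the store class, table and method names below are illustrative, not part of this diff):

    # Illustrative sketch only: a data store receives a Database instance and
    # delegates transaction handling to it via self.db (set up in the new
    # SQLBaseStore.__init__), instead of inheriting runInteraction directly.
    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import Database


    class ExampleStore(SQLBaseStore):  # hypothetical store, not part of this change
        def __init__(self, database: Database, db_conn, hs):
            super(ExampleStore, self).__init__(database, db_conn, hs)

        def count_users(self):
            def count_users_txn(txn):
                txn.execute("SELECT COUNT(*) FROM users")
                return txn.fetchone()[0]

            # assumes Database exposes runInteraction, mirroring the method
            # removed from SQLBaseStore below
            return self.db.runInteraction("count_users", count_users_txn)
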
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ab596fa68d..b7637b5dc0 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,1400 +14,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import itertools
 import logging
 import random
-import sys
-import threading
-import time
-from typing import Iterable, Tuple
 
-from six import PY2, iteritems, iterkeys, itervalues
-from six.moves import builtins, intern, range
+from six import PY2
+from six.moves import builtins
 
 from canonicaljson import json
-from prometheus_client import Histogram
 
-from twisted.internet import defer
-
-from synapse.api.errors import StoreError
-from synapse.logging.context import LoggingContext, make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.database import LoggingTransaction  # noqa: F401
+from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
+from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import Cache
-from synapse.util.stringutils import exception_to_unicode
-
-# import a function which will return a monotonic time, in seconds
-try:
-    # on python 3, use time.monotonic, since time.clock can go backwards
-    from time import monotonic as monotonic_time
-except ImportError:
-    # ... but python 2 doesn't have it
-    from time import clock as monotonic_time
 
 logger = logging.getLogger(__name__)
 
-try:
-    MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
-    # python 3 does not have a maximum int value
-    MAX_TXN_ID = 2 ** 63 - 1
-
-sql_logger = logging.getLogger("synapse.storage.SQL")
-transaction_logger = logging.getLogger("synapse.storage.txn")
-perf_logger = logging.getLogger("synapse.storage.TIME")
-
-sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-
-sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
-
-
-# Unique indexes which have been added in background updates. Maps from table name
-# to the name of the background update which added the unique index to that table.
-#
-# This is used by the upsert logic to figure out which tables are safe to do a proper
-# UPSERT on: until the relevant background update has completed, we
-# have to emulate an upsert by locking the table.
-#
-UNIQUE_INDEX_BACKGROUND_UPDATES = {
-    "user_ips": "user_ips_device_unique_index",
-    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
-    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
-    "event_search": "event_search_event_id_idx",
-}
-
-# This is a special cache name we use to batch multiple invalidations of caches
-# based on the current state when notifying workers over replication.
-_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-
 
-class LoggingTransaction(object):
-    """An object that almost-transparently proxies for the 'txn' object
-    passed to the constructor. Adds logging and metrics to the .execute()
-    method.
+class SQLBaseStore(object):
+    """Base class for data stores that holds helper functions.
 
-    Args:
-        txn: The database transaction object to wrap.
-        name (str): The name of this transaction, for logging.
-        database_engine (Sqlite3Engine|PostgresEngine)
-        after_callbacks(list|None): A list that callbacks will be appended to
-            that have been added by `call_after` which should be run on
-            successful completion of the transaction. None indicates that no
-            callbacks should be allowed to be scheduled to run.
-        exception_callbacks(list|None): A list that callbacks will be appended
-            to that have been added by `call_on_exception` which should be run
-            if transaction ends with an error. None indicates that no callbacks
-            should be allowed to be scheduled to run.
+    Note that multiple instances of this class will exist as there will be one
+    per data store (and not one per physical database).
     """
 
-    __slots__ = [
-        "txn",
-        "name",
-        "database_engine",
-        "after_callbacks",
-        "exception_callbacks",
-    ]
-
-    def __init__(
-        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
-    ):
-        object.__setattr__(self, "txn", txn)
-        object.__setattr__(self, "name", name)
-        object.__setattr__(self, "database_engine", database_engine)
-        object.__setattr__(self, "after_callbacks", after_callbacks)
-        object.__setattr__(self, "exception_callbacks", exception_callbacks)
-
-    def call_after(self, callback, *args, **kwargs):
-        """Call the given callback on the main twisted thread after the
-        transaction has finished. Used to invalidate the caches on the
-        correct thread.
-        """
-        self.after_callbacks.append((callback, args, kwargs))
-
-    def call_on_exception(self, callback, *args, **kwargs):
-        self.exception_callbacks.append((callback, args, kwargs))
-
-    def __getattr__(self, name):
-        return getattr(self.txn, name)
-
-    def __setattr__(self, name, value):
-        setattr(self.txn, name, value)
-
-    def __iter__(self):
-        return self.txn.__iter__()
-
-    def execute_batch(self, sql, args):
-        if isinstance(self.database_engine, PostgresEngine):
-            from psycopg2.extras import execute_batch
-
-            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
-        else:
-            for val in args:
-                self.execute(sql, val)
-
-    def execute(self, sql, *args):
-        self._do_execute(self.txn.execute, sql, *args)
-
-    def executemany(self, sql, *args):
-        self._do_execute(self.txn.executemany, sql, *args)
-
-    def _make_sql_one_line(self, sql):
-        "Strip newlines out of SQL so that the loggers in the DB are on one line"
-        return " ".join(l.strip() for l in sql.splitlines() if l.strip())
-
-    def _do_execute(self, func, sql, *args):
-        sql = self._make_sql_one_line(sql)
-
-        # TODO(paul): Maybe use 'info' and 'debug' for values?
-        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
-
-        sql = self.database_engine.convert_param_style(sql)
-        if args:
-            try:
-                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
-            except Exception:
-                # Don't let logging failures stop SQL from working
-                pass
-
-        start = time.time()
-
-        try:
-            return func(sql, *args)
-        except Exception as e:
-            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
-            raise
-        finally:
-            secs = time.time() - start
-            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
-            sql_query_timer.labels(sql.split()[0]).observe(secs)
-
-
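
A brief usage sketch for the LoggingTransaction wrapper removed above (the connection, engine and cache objects here are placeholders assumed for illustration): it behaves like the underlying DB-API cursor, but logs and times each query, and `call_after` queues work such as cache invalidation to run only once the surrounding transaction has committed.

    # Placeholders: conn, database_engine and get_user_cache are assumed to
    # exist purely for this illustration.
    after_callbacks = []
    exception_callbacks = []

    txn = LoggingTransaction(
        conn.cursor(), "add_user", database_engine, after_callbacks, exception_callbacks
    )
    txn.execute("INSERT INTO users (name) VALUES (?)", ("@alice:example.com",))
    txn.call_after(get_user_cache.invalidate, ("@alice:example.com",))
    conn.commit()

    # The caller (runInteraction, below) runs the queued callbacks only after a
    # successful commit; on failure it runs exception_callbacks instead.
    for callback, args, kwargs in after_callbacks:
        callback(*args, **kwargs)
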
-class PerformanceCounters(object):
-    def __init__(self):
-        self.current_counters = {}
-        self.previous_counters = {}
-
-    def update(self, key, duration_secs):
-        count, cum_time = self.current_counters.get(key, (0, 0))
-        count += 1
-        cum_time += duration_secs
-        self.current_counters[key] = (count, cum_time)
-
-    def interval(self, interval_duration_secs, limit=3):
-        counters = []
-        for name, (count, cum_time) in iteritems(self.current_counters):
-            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
-            counters.append(
-                (
-                    (cum_time - prev_time) / interval_duration_secs,
-                    count - prev_count,
-                    name,
-                )
-            )
-
-        self.previous_counters = dict(self.current_counters)
-
-        counters.sort(reverse=True)
-
-        top_n_counters = ", ".join(
-            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
-            for ratio, count, name in counters[:limit]
-        )
-
-        return top_n_counters
-
-
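
A worked example of the counters above, with made-up numbers: `update` accumulates per-transaction-type call counts and cumulative durations, and `interval` reports what fraction of the elapsed interval each type accounted for.

    counters = PerformanceCounters()
    counters.update("persist_events", 2.0)   # one call taking 2.0s
    counters.update("persist_events", 1.0)   # a second call taking 1.0s
    counters.update("get_event", 0.5)

    # Over a 10 second interval this reports the share of wall-clock time spent
    # in each transaction type, largest first, e.g.:
    #   "persist_events(2): 30.000%, get_event(1): 5.000%"
    report = counters.interval(10.0, limit=3)
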
-class SQLBaseStore(object):
-    _TXN_ID = 0
-
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self._db_pool = hs.get_db_pool()
-
-        self._previous_txn_total_time = 0
-        self._current_txn_total_time = 0
-        self._previous_loop_ts = 0
-
-        # TODO(paul): These can eventually be removed once the metrics code
-        #   is running in mainline, and we have some nice monitoring frontends
-        #   to watch it
-        self._txn_perf_counters = PerformanceCounters()
-
-        self._get_event_cache = Cache(
-            "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
-        )
-
-        self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
-        self._event_fetch_ongoing = 0
-
-        self._pending_ds = []
-
         self.database_engine = hs.database_engine
-
-        # A set of tables that are not safe to use native upserts in.
-        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
-
-        self._account_validity = self.hs.config.account_validity
-
-        # We add the user_directory_search table to the blacklist on SQLite
-        # because the existing search table does not have an index, making it
-        # unsafe to use native upserts.
-        if isinstance(self.database_engine, Sqlite3Engine):
-            self._unsafe_to_upsert_tables.add("user_directory_search")
-
-        if self.database_engine.can_native_upsert:
-            # Check ASAP (and then later, every 1s) to see if we have finished
-            # background updates of tables that aren't safe to update.
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "upsert_safety_check",
-                self._check_safe_to_upsert,
-            )
-
+        self.db = database
         self.rand = random.SystemRandom()
 
-        if self._account_validity.enabled:
-            self._clock.call_later(
-                0.0,
-                run_as_background_process,
-                "account_validity_set_expiration_dates",
-                self._set_expiration_date_when_missing,
-            )
-
-    @defer.inlineCallbacks
-    def _check_safe_to_upsert(self):
-        """
-        Is it safe to use native UPSERT?
-
-        If there are background updates, we will need to wait, as they may be
-        the addition of indexes that set the UNIQUE constraint that we require.
-
-        If the background updates have not completed, wait 15 sec and check again.
-        """
-        updates = yield self._simple_select_list(
-            "background_updates",
-            keyvalues=None,
-            retcols=["update_name"],
-            desc="check_background_updates",
-        )
-        updates = [x["update_name"] for x in updates]
-
-        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
-            if update_name not in updates:
-                logger.debug("Now safe to upsert in %s", table)
-                self._unsafe_to_upsert_tables.discard(table)
-
-        # If there's any updates still running, reschedule to run.
-        if updates:
-            self._clock.call_later(
-                15.0,
-                run_as_background_process,
-                "upsert_safety_check",
-                self._check_safe_to_upsert,
-            )
-
-    @defer.inlineCallbacks
-    def _set_expiration_date_when_missing(self):
-        """
-        Retrieves the list of registered users that don't have an expiration date, and
-        adds an expiration date for each of them.
-        """
-
-        def select_users_with_no_expiration_date_txn(txn):
-            """Retrieves the list of registered users with no expiration date from the
-            database, filtering out deactivated users.
-            """
-            sql = (
-                "SELECT users.name FROM users"
-                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
-                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
-            )
-            txn.execute(sql, [])
-
-            res = self.cursor_to_dict(txn)
-            if res:
-                for user in res:
-                    self.set_expiration_date_for_user_txn(
-                        txn, user["name"], use_delta=True
-                    )
-
-        yield self.runInteraction(
-            "get_users_with_no_expiration_date",
-            select_users_with_no_expiration_date_txn,
-        )
-
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
-        """Sets an expiration date to the account with the given user ID.
-
-        Args:
-             user_id (str): User ID to set an expiration date for.
-             use_delta (bool): If set to False, the expiration date for the user will be
-                now + validity period. If set to True, this expiration date will be a
-                random value in the [now + period - d ; now + period] range, d being a
-                delta equal to 10% of the validity period.
-        """
-        now_ms = self._clock.time_msec()
-        expiration_ts = now_ms + self._account_validity.period
-
-        if use_delta:
-            expiration_ts = self.rand.randrange(
-                expiration_ts - self._account_validity.startup_job_max_delta,
-                expiration_ts,
-            )
-
-        self._simple_upsert_txn(
-            txn,
-            "account_validity",
-            keyvalues={"user_id": user_id},
-            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
-        )
-
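
To make the use_delta behaviour above concrete, a small numeric illustration (the period and delta values are made up; in practice they come from the account_validity config):

    import random

    period = 30 * 24 * 3600 * 1000       # e.g. a 30 day validity period, in ms
    max_delta = period // 10              # startup_job_max_delta, 10% of the period
    now_ms = 1500000000000

    # use_delta=False: the account expires exactly one period from now
    expiration_ts = now_ms + period

    # use_delta=True: smear the expiry over the last 10% of the period so that
    # pre-existing accounts do not all expire at the same instant
    expiration_ts = random.SystemRandom().randrange(
        now_ms + period - max_delta, now_ms + period
    )
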
-    def start_profiling(self):
-        self._previous_loop_ts = monotonic_time()
-
-        def loop():
-            curr = self._current_txn_total_time
-            prev = self._previous_txn_total_time
-            self._previous_txn_total_time = curr
-
-            time_now = monotonic_time()
-            time_then = self._previous_loop_ts
-            self._previous_loop_ts = time_now
-
-            duration = time_now - time_then
-            ratio = (curr - prev) / duration
-
-            top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
-
-            perf_logger.info(
-                "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
-            )
-
-        self._clock.looping_call(loop, 10000)
-
-    def _new_transaction(
-        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
-    ):
-        start = monotonic_time()
-        txn_id = self._TXN_ID
-
-        # We don't really need these to be unique, so let's stop it from
-        # growing really large.
-        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
-
-        name = "%s-%x" % (desc, txn_id)
-
-        transaction_logger.debug("[TXN START] {%s}", name)
-
-        try:
-            i = 0
-            N = 5
-            while True:
-                try:
-                    txn = conn.cursor()
-                    txn = LoggingTransaction(
-                        txn,
-                        name,
-                        self.database_engine,
-                        after_callbacks,
-                        exception_callbacks,
-                    )
-                    r = func(txn, *args, **kwargs)
-                    conn.commit()
-                    return r
-                except self.database_engine.module.OperationalError as e:
-                    # This can happen if the database disappears mid
-                    # transaction.
-                    logger.warning(
-                        "[TXN OPERROR] {%s} %s %d/%d",
-                        name,
-                        exception_to_unicode(e),
-                        i,
-                        N,
-                    )
-                    if i < N:
-                        i += 1
-                        try:
-                            conn.rollback()
-                        except self.database_engine.module.Error as e1:
-                            logger.warning(
-                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
-                            )
-                        continue
-                    raise
-                except self.database_engine.module.DatabaseError as e:
-                    if self.database_engine.is_deadlock(e):
-                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
-                        if i < N:
-                            i += 1
-                            try:
-                                conn.rollback()
-                            except self.database_engine.module.Error as e1:
-                                logger.warning(
-                                    "[TXN EROLL] {%s} %s",
-                                    name,
-                                    exception_to_unicode(e1),
-                                )
-                            continue
-                    raise
-        except Exception as e:
-            logger.debug("[TXN FAIL] {%s} %s", name, e)
-            raise
-        finally:
-            end = monotonic_time()
-            duration = end - start
-
-            LoggingContext.current_context().add_database_transaction(duration)
-
-            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
-
-            self._current_txn_total_time += duration
-            self._txn_perf_counters.update(desc, duration)
-            sql_txn_timer.labels(desc).observe(duration)
-
-    @defer.inlineCallbacks
-    def runInteraction(self, desc, func, *args, **kwargs):
-        """Starts a transaction on the database and runs a given function
-
-        Arguments:
-            desc (str): description of the transaction, for logging and metrics
-            func (func): callback function, which will be called with a
-                database transaction (twisted.enterprise.adbapi.Transaction) as
-                its first argument, followed by `args` and `kwargs`.
-
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
-
-        Returns:
-            Deferred: The result of func
-        """
-        after_callbacks = []
-        exception_callbacks = []
-
-        if LoggingContext.current_context() == LoggingContext.sentinel:
-            logger.warning("Starting db txn '%s' from sentinel context", desc)
-
-        try:
-            result = yield self.runWithConnection(
-                self._new_transaction,
-                desc,
-                after_callbacks,
-                exception_callbacks,
-                func,
-                *args,
-                **kwargs
-            )
-
-            for after_callback, after_args, after_kwargs in after_callbacks:
-                after_callback(*after_args, **after_kwargs)
-        except:  # noqa: E722, as we reraise the exception this is fine.
-            for after_callback, after_args, after_kwargs in exception_callbacks:
-                after_callback(*after_args, **after_kwargs)
-            raise
-
-        return result
-
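
A sketch of the calling convention for runInteraction (the table and column are illustrative, and `defer` is the twisted.internet import used elsewhere in this file): the callable receives a LoggingTransaction as its first argument, runs inside a single transaction with the retry handling above, and the returned Deferred resolves to whatever the callable returned.

    @defer.inlineCallbacks
    def get_room_version(self, room_id):
        def get_room_version_txn(txn):
            # '?' placeholders are rewritten for the active engine by
            # LoggingTransaction via convert_param_style
            txn.execute("SELECT room_version FROM rooms WHERE room_id = ?", (room_id,))
            row = txn.fetchone()
            return row[0] if row else None

        version = yield self.runInteraction("get_room_version", get_room_version_txn)
        return version
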
-    @defer.inlineCallbacks
-    def runWithConnection(self, func, *args, **kwargs):
-        """Wraps the .runWithConnection() method on the underlying db_pool.
-
-        Arguments:
-            func (func): callback function, which will be called with a
-                database connection (twisted.enterprise.adbapi.Connection) as
-                its first argument, followed by `args` and `kwargs`.
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
-
-        Returns:
-            Deferred: The result of func
-        """
-        parent_context = LoggingContext.current_context()
-        if parent_context == LoggingContext.sentinel:
-            logger.warning(
-                "Starting db connection from sentinel context: metrics will be lost"
-            )
-            parent_context = None
-
-        start_time = monotonic_time()
-
-        def inner_func(conn, *args, **kwargs):
-            with LoggingContext("runWithConnection", parent_context) as context:
-                sched_duration_sec = monotonic_time() - start_time
-                sql_scheduling_timer.observe(sched_duration_sec)
-                context.add_database_scheduled(sched_duration_sec)
-
-                if self.database_engine.is_connection_closed(conn):
-                    logger.debug("Reconnecting closed database connection")
-                    conn.reconnect()
-
-                return func(conn, *args, **kwargs)
-
-        result = yield make_deferred_yieldable(
-            self._db_pool.runWithConnection(inner_func, *args, **kwargs)
-        )
-
-        return result
-
-    @staticmethod
-    def cursor_to_dict(cursor):
-        """Converts a SQL cursor into an list of dicts.
-
-        Args:
-            cursor : The DBAPI cursor which has executed a query.
-        Returns:
-            A list of dicts where the key is the column header.
-        """
-        col_headers = list(intern(str(column[0])) for column in cursor.description)
-        results = list(dict(zip(col_headers, row)) for row in cursor)
-        return results
-
-    def _execute(self, desc, decoder, query, *args):
-        """Runs a single query for a result set.
-
-        Args:
-            decoder - The function which can resolve the cursor results to
-                something meaningful.
-            query - The query string to execute
-            *args - Query args.
-        Returns:
-            The result of decoder(results)
-        """
-
-        def interaction(txn):
-            txn.execute(query, args)
-            if decoder:
-                return decoder(txn)
-            else:
-                return txn.fetchall()
-
-        return self.runInteraction(desc, interaction)
-
-    # "Simple" SQL API methods that operate on a single table with no JOINs,
-    # no complex WHERE clauses, just a dict of values for columns.
-
-    @defer.inlineCallbacks
-    def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
-        """Executes an INSERT query on the named table.
-
-        Args:
-            table : string giving the table name
-            values : dict of new column names and values for them
-            or_ignore : bool stating whether an exception should be raised
-                when a conflicting row already exists. If True, False will be
-                returned by the function instead
-            desc : string giving a description of the transaction
-
-        Returns:
-            bool: Whether the row was inserted or not. Only useful when
-            `or_ignore` is True
-        """
-        try:
-            yield self.runInteraction(desc, self._simple_insert_txn, table, values)
-        except self.database_engine.module.IntegrityError:
-            # We have to do or_ignore flag at this layer, since we can't reuse
-            # a cursor after we receive an error from the db.
-            if not or_ignore:
-                raise
-            return False
-        return True
-
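
Usage of _simple_insert looks roughly like this (the table and columns are illustrative); with or_ignore=True a duplicate-key IntegrityError is swallowed and False is returned instead of raising:

    # Generates: INSERT INTO user_threepids (user_id, medium, address) VALUES(?, ?, ?)
    inserted = yield self._simple_insert(
        table="user_threepids",
        values={
            "user_id": "@alice:example.com",
            "medium": "email",
            "address": "alice@example.com",
        },
        or_ignore=True,
        desc="add_threepid",
    )
    if not inserted:
        logger.debug("threepid already existed; insert ignored")
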
-    @staticmethod
-    def _simple_insert_txn(txn, table, values):
-        keys, vals = zip(*values.items())
-
-        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
-            table,
-            ", ".join(k for k in keys),
-            ", ".join("?" for _ in keys),
-        )
-
-        txn.execute(sql, vals)
-
-    def _simple_insert_many(self, table, values, desc):
-        return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
-
-    @staticmethod
-    def _simple_insert_many_txn(txn, table, values):
-        if not values:
-            return
-
-        # This is a *slight* abomination to get a list of tuples of key names
-        # and a list of tuples of value names.
-        #
-        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
-        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
-        #
-        # The sort is to ensure that we don't rely on dictionary iteration
-        # order.
-        keys, vals = zip(
-            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
-        )
-
-        for k in keys:
-            if k != keys[0]:
-                raise RuntimeError("All items must have the same keys")
-
-        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
-            table,
-            ", ".join(k for k in keys[0]),
-            ", ".join("?" for _ in keys[0]),
-        )
-
-        txn.executemany(sql, vals)
-
-    @defer.inlineCallbacks
-    def _simple_upsert(
-        self,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        desc="_simple_upsert",
-        lock=True,
-    ):
-        """
-
-        `lock` should generally be set to True (the default), but can be set
-        to False if either of the following are true:
-
-        * there is a UNIQUE INDEX on the key columns. In this case a conflict
-          will cause an IntegrityError in which case this function will retry
-          the update.
-
-        * we somehow know that we are the only thread which will be updating
-          this table.
-
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            Deferred(None or bool): Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        attempts = 0
-        while True:
-            try:
-                result = yield self.runInteraction(
-                    desc,
-                    self._simple_upsert_txn,
-                    table,
-                    keyvalues,
-                    values,
-                    insertion_values,
-                    lock=lock,
-                )
-                return result
-            except self.database_engine.module.IntegrityError as e:
-                attempts += 1
-                if attempts >= 5:
-                    # don't retry forever, because things other than races
-                    # can cause IntegrityErrors
-                    raise
-
-                # presumably we raced with another transaction: let's retry.
-                logger.warning(
-                    "IntegrityError when upserting into %s; retrying: %s", table, e
-                )
-
-    def _simple_upsert_txn(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
-        """
-        Pick the UPSERT method which works best on the platform. Either the
-        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
-
-        Args:
-            txn: The transaction to use.
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
-            return self._simple_upsert_txn_native_upsert(
-                txn, table, keyvalues, values, insertion_values=insertion_values
-            )
-        else:
-            return self._simple_upsert_txn_emulated(
-                txn,
-                table,
-                keyvalues,
-                values,
-                insertion_values=insertion_values,
-                lock=lock,
-            )
-
-    def _simple_upsert_txn_emulated(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
-        """
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
-        Returns:
-            bool: Return True if a new entry was created, False if an existing
-            one was updated.
-        """
-        # We need to lock the table :(, unless we're *really* careful
-        if lock:
-            self.database_engine.lock_table(txn, table)
-
-        def _getwhere(key):
-            # If the value we're passing in is None (aka NULL), we need to use
-            # IS, not =, as NULL = NULL equals NULL (False).
-            if keyvalues[key] is None:
-                return "%s IS ?" % (key,)
-            else:
-                return "%s = ?" % (key,)
-
-        if not values:
-            # If `values` is empty, then all of the values we care about are in
-            # the unique key, so there is nothing to UPDATE. We can just do a
-            # SELECT instead to see if it exists.
-            sql = "SELECT 1 FROM %s WHERE %s" % (
-                table,
-                " AND ".join(_getwhere(k) for k in keyvalues),
-            )
-            sqlargs = list(keyvalues.values())
-            txn.execute(sql, sqlargs)
-            if txn.fetchall():
-                # We have an existing record.
-                return False
-        else:
-            # First try to update.
-            sql = "UPDATE %s SET %s WHERE %s" % (
-                table,
-                ", ".join("%s = ?" % (k,) for k in values),
-                " AND ".join(_getwhere(k) for k in keyvalues),
-            )
-            sqlargs = list(values.values()) + list(keyvalues.values())
-
-            txn.execute(sql, sqlargs)
-            if txn.rowcount > 0:
-                # successfully updated at least one row.
-                return False
-
-        # We didn't find any existing rows, so insert a new one
-        allvalues = {}
-        allvalues.update(keyvalues)
-        allvalues.update(values)
-        allvalues.update(insertion_values)
-
-        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
-            table,
-            ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues),
-        )
-        txn.execute(sql, list(allvalues.values()))
-        # successfully inserted
-        return True
-
-    def _simple_upsert_txn_native_upsert(
-        self, txn, table, keyvalues, values, insertion_values={}
-    ):
-        """
-        Use the native UPSERT functionality in recent PostgreSQL versions.
-
-        Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-        Returns:
-            None
-        """
-        allvalues = {}
-        allvalues.update(keyvalues)
-        allvalues.update(insertion_values)
-
-        if not values:
-            latter = "NOTHING"
-        else:
-            allvalues.update(values)
-            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
-
-        sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
-            table,
-            ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues),
-            ", ".join(k for k in keyvalues),
-            latter,
-        )
-        txn.execute(sql, list(allvalues.values()))
-
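
As a concrete illustration of the SQL this builds (values chosen arbitrarily): upserting into account_validity with keyvalues={"user_id": user_id}, values={"expiration_ts_ms": ts, "email_sent": False} and no insertion_values produces, modulo whitespace:

    sql = (
        "INSERT INTO account_validity (user_id, expiration_ts_ms, email_sent) "
        "VALUES (?, ?, ?) "
        "ON CONFLICT (user_id) DO UPDATE SET "
        "expiration_ts_ms=EXCLUDED.expiration_ts_ms, email_sent=EXCLUDED.email_sent"
    )
    args = [user_id, ts, False]
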
-    def _simple_upsert_many_txn(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        if (
-            self.database_engine.can_native_upsert
-            and table not in self._unsafe_to_upsert_tables
-        ):
-            return self._simple_upsert_many_txn_native_upsert(
-                txn, table, key_names, key_values, value_names, value_values
-            )
-        else:
-            return self._simple_upsert_many_txn_emulated(
-                txn, table, key_names, key_values, value_names, value_values
-            )
-
-    def _simple_upsert_many_txn_emulated(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times, but without native UPSERT support or batching.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        # No value columns, therefore make a blank list so that the following
-        # zip() works correctly.
-        if not value_names:
-            value_values = [() for x in range(len(key_values))]
-
-        for keyv, valv in zip(key_values, value_values):
-            _keys = {x: y for x, y in zip(key_names, keyv)}
-            _vals = {x: y for x, y in zip(value_names, valv)}
-
-            self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
-
-    def _simple_upsert_many_txn_native_upsert(
-        self, txn, table, key_names, key_values, value_names, value_values
-    ):
-        """
-        Upsert, many times, using batching where possible.
-
-        Args:
-            table (str): The table to upsert into
-            key_names (list[str]): The key column names.
-            key_values (list[list]): A list of each row's key column values.
-            value_names (list[str]): The value column names. If empty, no
-                values will be used, even if value_values is provided.
-            value_values (list[list]): A list of each row's value column values.
-        Returns:
-            None
-        """
-        allnames = []
-        allnames.extend(key_names)
-        allnames.extend(value_names)
-
-        if not value_names:
-            # No value columns, therefore make a blank list so that the
-            # following zip() works correctly.
-            latter = "NOTHING"
-            value_values = [() for x in range(len(key_values))]
-        else:
-            latter = "UPDATE SET " + ", ".join(
-                k + "=EXCLUDED." + k for k in value_names
-            )
-
-        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
-            table,
-            ", ".join(k for k in allnames),
-            ", ".join("?" for _ in allnames),
-            ", ".join(key_names),
-            latter,
-        )
-
-        args = []
-
-        for x, y in zip(key_values, value_values):
-            args.append(tuple(x) + tuple(y))
-
-        return txn.execute_batch(sql, args)
-
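
The argument shapes for the batched form are easiest to see with a small example (table and columns are illustrative): one row is upserted per entry of key_values/value_values, and on the native path they are all sent through one statement via execute_batch.

    self._simple_upsert_many_txn(
        txn,
        table="device_lists_outbound_last_success",   # illustrative table
        key_names=["destination", "user_id"],
        key_values=[("remote.one", "@a:hs"), ("remote.two", "@b:hs")],
        value_names=["stream_id"],
        value_values=[(10,), (11,)],
    )
    # Native path builds a single statement, applied once per row:
    #   INSERT INTO device_lists_outbound_last_success (destination, user_id, stream_id)
    #   VALUES (?, ?, ?)
    #   ON CONFLICT (destination, user_id) DO UPDATE SET stream_id=EXCLUDED.stream_id
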
-    def _simple_select_one(
-        self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
-    ):
-        """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning multiple columns from it.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcols : list of strings giving the names of the columns to return
-
-            allow_none : If true, return None instead of failing if the SELECT
-              statement returns no rows
-        """
-        return self.runInteraction(
-            desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
-        )
-
-    def _simple_select_one_onecol(
-        self,
-        table,
-        keyvalues,
-        retcol,
-        allow_none=False,
-        desc="_simple_select_one_onecol",
-    ):
-        """Executes a SELECT query on the named table, which is expected to
-        return a single row, returning a single column from it.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcol : string giving the name of the column to return
-        """
-        return self.runInteraction(
-            desc,
-            self._simple_select_one_onecol_txn,
-            table,
-            keyvalues,
-            retcol,
-            allow_none=allow_none,
-        )
-
-    @classmethod
-    def _simple_select_one_onecol_txn(
-        cls, txn, table, keyvalues, retcol, allow_none=False
-    ):
-        ret = cls._simple_select_onecol_txn(
-            txn, table=table, keyvalues=keyvalues, retcol=retcol
-        )
-
-        if ret:
-            return ret[0]
-        else:
-            if allow_none:
-                return None
-            else:
-                raise StoreError(404, "No row found")
-
-    @staticmethod
-    def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
-        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
-
-        if keyvalues:
-            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
-            txn.execute(sql, list(keyvalues.values()))
-        else:
-            txn.execute(sql)
-
-        return [r[0] for r in txn]
-
-    def _simple_select_onecol(
-        self, table, keyvalues, retcol, desc="_simple_select_onecol"
-    ):
-        """Executes a SELECT query on the named table, which returns a list
-        comprising the values of the named column from the selected rows.
-
-        Args:
-            table (str): table name
-            keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whose value we wish to retrieve.
-
-        Returns:
-            Deferred: Results in a list
-        """
-        return self.runInteraction(
-            desc, self._simple_select_onecol_txn, table, keyvalues, retcol
-        )
-
-    def _simple_select_list(
-        self, table, keyvalues, retcols, desc="_simple_select_list"
-    ):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc, self._simple_select_list_txn, table, keyvalues, retcols
-        )
-
-    @classmethod
-    def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
-        """
-        if keyvalues:
-            sql = "SELECT %s FROM %s WHERE %s" % (
-                ", ".join(retcols),
-                table,
-                " AND ".join("%s = ?" % (k,) for k in keyvalues),
-            )
-            txn.execute(sql, list(keyvalues.values()))
-        else:
-            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-            txn.execute(sql)
-
-        return cls.cursor_to_dict(txn)
-
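
Usage of the select helpers follows the same pattern (table and columns illustrative); the result is a list of dicts keyed by column name, as produced by cursor_to_dict above:

    rows = yield self._simple_select_list(
        table="pushers",
        keyvalues={"user_name": "@alice:example.com"},
        retcols=["app_id", "pushkey"],
        desc="get_pushers_for_user",
    )
    # rows == [{"app_id": "m.email", "pushkey": "alice@example.com"}, ...]
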
-    @defer.inlineCallbacks
-    def _simple_select_many_batch(
-        self,
-        table,
-        column,
-        iterable,
-        retcols,
-        keyvalues={},
-        desc="_simple_select_many_batch",
-        batch_size=100,
-    ):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
-        """
-        results = []
-
-        if not iterable:
-            return results
-
-        # iterables can not be sliced, so convert it to a list first
-        it_list = list(iterable)
-
-        chunks = [
-            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
-        ]
-        for chunk in chunks:
-            rows = yield self.runInteraction(
-                desc,
-                self._simple_select_many_txn,
-                table,
-                column,
-                chunk,
-                keyvalues,
-                retcols,
-            )
-
-            results.extend(rows)
-
-        return results
-
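
For example (illustrative table and column), fetching 250 event IDs with the default batch_size of 100 runs three transactions, each building one `WHERE event_id IN (...)` clause via make_in_list_sql_clause:

    rows = yield self._simple_select_many_batch(
        table="events",
        column="event_id",
        iterable=event_ids,            # e.g. 250 event IDs -> 3 chunks of <= 100
        retcols=["event_id", "depth"],
        desc="get_event_depths",
    )
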
-    @classmethod
-    def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
-        """
-        if not iterable:
-            return []
-
-        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
-        clauses = [clause]
-
-        for key, value in iteritems(keyvalues):
-            clauses.append("%s = ?" % (key,))
-            values.append(value)
-
-        sql = "SELECT %s FROM %s WHERE %s" % (
-            ", ".join(retcols),
-            table,
-            " AND ".join(clauses),
-        )
-
-        txn.execute(sql, values)
-        return cls.cursor_to_dict(txn)
-
-    def _simple_update(self, table, keyvalues, updatevalues, desc):
-        return self.runInteraction(
-            desc, self._simple_update_txn, table, keyvalues, updatevalues
-        )
-
-    @staticmethod
-    def _simple_update_txn(txn, table, keyvalues, updatevalues):
-        if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
-        else:
-            where = ""
-
-        update_sql = "UPDATE %s SET %s %s" % (
-            table,
-            ", ".join("%s = ?" % (k,) for k in updatevalues),
-            where,
-        )
-
-        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
-
-        return txn.rowcount
-
-    def _simple_update_one(
-        self, table, keyvalues, updatevalues, desc="_simple_update_one"
-    ):
-        """Executes an UPDATE query on the named table, setting new values for
-        columns in a row matching the key values.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            updatevalues : dict giving column names and values to update
-            retcols : optional list of column names to return
-
-        If present, retcols gives a list of column names on which to perform
-        a SELECT statement *before* performing the UPDATE statement. The values
-        of these will be returned in a dict.
-
-        These are performed within the same transaction, allowing an atomic
-        get-and-set.  This can be used to implement compare-and-set by putting
-        the update column in the 'keyvalues' dict as well.
-        """
-        return self.runInteraction(
-            desc, self._simple_update_one_txn, table, keyvalues, updatevalues
-        )
-
-    @classmethod
-    def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
-        rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
-
-        if rowcount == 0:
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-    @staticmethod
-    def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
-        select_sql = "SELECT %s FROM %s WHERE %s" % (
-            ", ".join(retcols),
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(select_sql, list(keyvalues.values()))
-        row = txn.fetchone()
-
-        if not row:
-            if allow_none:
-                return None
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if txn.rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-        return dict(zip(retcols, row))
-
-    def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
-        """Executes a DELETE query on the named table, expecting to delete a
-        single row.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-        """
-        return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
-
-    @staticmethod
-    def _simple_delete_one_txn(txn, table, keyvalues):
-        """Executes a DELETE query on the named table, expecting to delete a
-        single row.
-
-        Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-        """
-        sql = "DELETE FROM %s WHERE %s" % (
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(sql, list(keyvalues.values()))
-        if txn.rowcount == 0:
-            raise StoreError(404, "No row found (%s)" % (table,))
-        if txn.rowcount > 1:
-            raise StoreError(500, "More than one row matched (%s)" % (table,))
-
-    def _simple_delete(self, table, keyvalues, desc):
-        return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
-
-    @staticmethod
-    def _simple_delete_txn(txn, table, keyvalues):
-        sql = "DELETE FROM %s WHERE %s" % (
-            table,
-            " AND ".join("%s = ?" % (k,) for k in keyvalues),
-        )
-
-        txn.execute(sql, list(keyvalues.values()))
-        return txn.rowcount
-
-    def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
-        return self.runInteraction(
-            desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
-        )
-
-    @staticmethod
-    def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
-        """Executes a DELETE query on the named table.
-
-        Filters rows by whether the value of `column` is in `iterable`.
-
-        Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-
-        Returns:
-            int: Number of rows deleted
-        """
-        if not iterable:
-            return 0
-
-        sql = "DELETE FROM %s" % table
-
-        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
-        clauses = [clause]
-
-        for key, value in iteritems(keyvalues):
-            clauses.append("%s = ?" % (key,))
-            values.append(value)
-
-        if clauses:
-            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
-        txn.execute(sql, values)
-
-        return txn.rowcount
-
-    def _get_cache_dict(
-        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
-    ):
-        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
-        # It doesn't really matter how many we get, the StreamChangeCache will
-        # do the right thing to ensure it respects the max size of cache.
-        sql = (
-            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - %(limit)s"
-            " GROUP BY %(entity)s"
-        ) % {
-            "table": table,
-            "entity": entity_column,
-            "stream": stream_column,
-            "limit": limit,
-        }
-
-        sql = self.database_engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
-        txn.execute(sql, (int(max_value),))
-
-        cache = {row[0]: int(row[1]) for row in txn}
-
-        txn.close()
-
-        if cache:
-            min_val = min(itervalues(cache))
-        else:
-            min_val = max_value
-
-        return cache, min_val
-
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
-        """Invalidates the cache and adds it to the cache stream so slaves
-        will know to invalidate their caches.
-
-        This should only be used to invalidate caches where slaves won't
-        otherwise know from other replication streams that the cache should
-        be invalidated.
-        """
-        txn.call_after(cache_func.invalidate, keys)
-        self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
-
-    def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
-        """Special case invalidation of caches based on current state.
-
-        We special case this so that we can batch the cache invalidations into a
-        single replication poke.
-
-        Args:
-            txn
-            room_id (str): Room where state changed
-            members_changed (iterable[str]): The user_ids of members that have changed
-        """
-        txn.call_after(self._invalidate_state_caches, room_id, members_changed)
-
-        if members_changed:
-            # We need to be careful that the size of the `members_changed` list
-            # isn't so large that it causes problems sending over replication, so we
-            # send them in chunks.
-            # Max line length is 16K, and max user ID length is 255, so 50 should
-            # be safe.
-            for chunk in batch_iter(members_changed, 50):
-                keys = itertools.chain([room_id], chunk)
-                self._send_invalidation_to_replication(
-                    txn, _CURRENT_STATE_CACHE_NAME, keys
-                )
-        else:
-            # if no members changed, we still need to invalidate the other caches.
-            self._send_invalidation_to_replication(
-                txn, _CURRENT_STATE_CACHE_NAME, [room_id]
-            )
-
     def _invalidate_state_caches(self, room_id, members_changed):
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
@@ -1441,226 +77,6 @@ class SQLBaseStore(object):
             # which is fine.
             pass
 
-    def _send_invalidation_to_replication(self, txn, cache_name, keys):
-        """Notifies replication that given cache has been invalidated.
-
-        Note that this does *not* invalidate the cache locally.
-
-        Args:
-            txn
-            cache_name (str)
-            keys (iterable[str])
-        """
-
-        if isinstance(self.database_engine, PostgresEngine):
-            # get_next() returns a context manager which is designed to wrap
-            # the transaction. However, we want to only get an ID when we want
-            # to use it, here, so we need to call __enter__ manually, and have
-            # __exit__ called after the transaction finishes.
-            ctx = self._cache_id_gen.get_next()
-            stream_id = ctx.__enter__()
-            txn.call_on_exception(ctx.__exit__, None, None, None)
-            txn.call_after(ctx.__exit__, None, None, None)
-            txn.call_after(self.hs.get_notifier().on_new_replication_data)
-
-            self._simple_insert_txn(
-                txn,
-                table="cache_invalidation_stream",
-                values={
-                    "stream_id": stream_id,
-                    "cache_func": cache_name,
-                    "keys": list(keys),
-                    "invalidation_ts": self.clock.time_msec(),
-                },
-            )
-
-    def get_all_updated_caches(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_caches_txn(txn):
-            # We purposefully don't bound by the current token, as we want to
-            # send across cache invalidations as quickly as possible. Cache
-            # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
-            return txn.fetchall()
-
-        return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
-
-    def get_cache_stream_token(self):
-        if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
-        else:
-            return 0
-
-    def _simple_select_list_paginate(
-        self,
-        table,
-        keyvalues,
-        orderby,
-        start,
-        limit,
-        retcols,
-        order_direction="ASC",
-        desc="_simple_select_list_paginate",
-    ):
-        """
-        Executes a SELECT query on the named table with a row-number start
-        and limit, which may return between zero and `limit` rows starting
-        at `start`, returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc,
-            self._simple_select_list_paginate_txn,
-            table,
-            keyvalues,
-            orderby,
-            start,
-            limit,
-            retcols,
-            order_direction=order_direction,
-        )
-
-    @classmethod
-    def _simple_select_list_paginate_txn(
-        cls,
-        txn,
-        table,
-        keyvalues,
-        orderby,
-        start,
-        limit,
-        retcols,
-        order_direction="ASC",
-    ):
-        """
-        Executes a SELECT query on the named table with a row-number start
-        and limit, which may return between zero and `limit` rows starting
-        at `start`, returning the result as a list of dicts.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            list[dict[str, Any]]: the selected rows
-        """
-        if order_direction not in ["ASC", "DESC"]:
-            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
-
-        if keyvalues:
-            where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
-        else:
-            where_clause = ""
-
-        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
-            ", ".join(retcols),
-            table,
-            where_clause,
-            orderby,
-            order_direction,
-        )
-        txn.execute(sql, list(keyvalues.values()) + [limit, start])
-
-        return cls.cursor_to_dict(txn)
-
-    def get_user_count_txn(self, txn):
-        """Get a total number of registered users in the users list.
-
-        Args:
-            txn : Transaction object
-        Returns:
-            int : number of users
-        """
-        sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
-        txn.execute(sql_count)
-        return txn.fetchone()[0]
-
-    def _simple_search_list(
-        self, table, term, col, retcols, desc="_simple_search_list"
-    ):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
-        """
-
-        return self.runInteraction(
-            desc, self._simple_search_list_txn, table, term, col, retcols
-        )
-
-    @classmethod
-    def _simple_search_list_txn(cls, txn, table, term, col, retcols):
-        """Executes a SELECT query on the named table, which may return zero or
-        more rows, returning the result as a list of dicts.
-
-        Args:
-            txn : Transaction object
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
-        """
-        if term:
-            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
-            termvalues = ["%%" + term + "%%"]
-            txn.execute(sql, termvalues)
-        else:
-            return 0
-
-        return cls.cursor_to_dict(txn)
-
-    @property
-    def database_engine_name(self):
-        return self.database_engine.module.__name__
-
-    def get_server_version(self):
-        """Returns a string describing the server version number"""
-        return self.database_engine.server_version
-
-
-class _RollbackButIsFineException(Exception):
-    """ This exception is used to rollback a transaction without implying
-    something went wrong.
-    """
-
-    pass
-
 
 def db_to_json(db_content):
     """
@@ -1689,30 +105,3 @@ def db_to_json(db_content):
     except Exception:
         logging.warning("Tried to decode '%r' as JSON and failed", db_content)
         raise
-
-
-def make_in_list_sql_clause(
-    database_engine, column: str, iterable: Iterable
-) -> Tuple[str, Iterable]:
-    """Returns an SQL clause that checks the given column is in the iterable.
-
-    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
-    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
-    using the `ANY` form on postgres means that it views queries with
-    different length iterables as the same, helping the query stats.
-
-    Args:
-        database_engine
-        column: Name of the column
-        iterable: The values to check the column against.
-
-    Returns:
-        A tuple of SQL query and the args
-    """
-
-    if database_engine.supports_using_any_list:
-        # This should hopefully be faster, but also makes postgres query
-        # stats easier to understand.
-        return "%s = ANY(?)" % (column,), [list(iterable)]
-    else:
-        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 37d469ffd7..4f97fd5ab6 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -22,7 +22,6 @@ from twisted.internet import defer
 from synapse.metrics.background_process_metrics import run_as_background_process
 
 from . import engines
-from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
@@ -74,7 +73,7 @@ class BackgroundUpdatePerformance(object):
             return float(self.total_item_count) / float(self.total_duration_ms)
 
 
-class BackgroundUpdateStore(SQLBaseStore):
+class BackgroundUpdater(object):
     """ Background updates are updates to the database that run in the
     background. Each update processes a batch of data at once. We attempt to
     limit the impact of each update by monitoring how long each batch takes to
@@ -86,8 +85,10 @@ class BackgroundUpdateStore(SQLBaseStore):
     BACKGROUND_UPDATE_INTERVAL_MS = 1000
     BACKGROUND_UPDATE_DURATION_MS = 100
 
-    def __init__(self, db_conn, hs):
-        super(BackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, hs, database):
+        self._clock = hs.get_clock()
+        self.db = database
+
         self._background_update_performance = {}
         self._background_update_queue = []
         self._background_update_handlers = {}
@@ -101,9 +102,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         logger.info("Starting background schema updates")
         while True:
             if sleep:
-                yield self.hs.get_clock().sleep(
-                    self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0
-                )
+                yield self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
 
             try:
                 result = yield self.do_next_background_update(
@@ -139,7 +138,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         # otherwise, check if there are updates to be run. This is important,
         # as we may be running on a worker which doesn't perform the bg updates
         # itself, but still wants to wait for them to happen.
-        updates = yield self._simple_select_onecol(
+        updates = yield self.db.simple_select_onecol(
             "background_updates",
             keyvalues=None,
             retcol="1",
@@ -161,7 +160,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         if update_name in self._background_update_queue:
             return False
 
-        update_exists = await self._simple_select_one_onecol(
+        update_exists = await self.db.simple_select_one_onecol(
             "background_updates",
             keyvalues={"update_name": update_name},
             retcol="1",
@@ -184,7 +183,7 @@ class BackgroundUpdateStore(SQLBaseStore):
             no more work to do.
         """
         if not self._background_update_queue:
-            updates = yield self._simple_select_list(
+            updates = yield self.db.simple_select_list(
                 "background_updates",
                 keyvalues=None,
                 retcols=("update_name", "depends_on"),
@@ -226,7 +225,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         else:
             batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
 
-        progress_json = yield self._simple_select_one_onecol(
+        progress_json = yield self.db.simple_select_one_onecol(
             "background_updates",
             keyvalues={"update_name": update_name},
             retcol="progress_json",
@@ -380,7 +379,7 @@ class BackgroundUpdateStore(SQLBaseStore):
             logger.debug("[SQL] %s", sql)
             c.execute(sql)
 
-        if isinstance(self.database_engine, engines.PostgresEngine):
+        if isinstance(self.db.engine, engines.PostgresEngine):
             runner = create_index_psql
         elif psql_only:
             runner = None
@@ -391,7 +390,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         def updater(progress, batch_size):
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
-                yield self.runWithConnection(runner)
+                yield self.db.runWithConnection(runner)
             yield self._end_background_update(update_name)
             return 1
 
@@ -413,7 +412,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_queue = []
         progress_json = json.dumps(progress)
 
-        return self._simple_insert(
+        return self.db.simple_insert(
             "background_updates",
             {"update_name": update_name, "progress_json": progress_json},
         )
@@ -429,7 +428,7 @@ class BackgroundUpdateStore(SQLBaseStore):
         self._background_update_queue = [
             name for name in self._background_update_queue if name != update_name
         ]
-        return self._simple_delete_one(
+        return self.db.simple_delete_one(
             "background_updates", keyvalues={"update_name": update_name}
         )
 
@@ -444,7 +443,7 @@ class BackgroundUpdateStore(SQLBaseStore):
 
         progress_json = json.dumps(progress)
 
-        self._simple_update_one_txn(
+        self.db.simple_update_one_txn(
             txn,
             "background_updates",
             keyvalues={"update_name": update_name},
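
After this change the updater is reached through `self.db.updates` from the stores, as later hunks in this commit show. A rough sketch of the registration pattern, with an illustrative update name and table that are not part of this commit:

    from twisted.internet import defer

    from synapse.storage._base import SQLBaseStore

    class ExampleBackgroundUpdateStore(SQLBaseStore):
        def __init__(self, database, db_conn, hs):
            super(ExampleBackgroundUpdateStore, self).__init__(database, db_conn, hs)
            self.db.updates.register_background_update_handler(
                "example_update", self._example_update
            )

        @defer.inlineCallbacks
        def _example_update(self, progress, batch_size):
            def run_batch(txn):
                # Illustrative per-batch work.
                txn.execute("ANALYZE example_table")

            yield self.db.runInteraction("example_update", run_batch)
            yield self.db.updates._end_background_update("example_update")
            return 1
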
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index cb184a98cc..cafedd5c0d 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from synapse.storage.database import Database
+from synapse.storage.prepare_database import prepare_database
+
 
 class DataStores(object):
     """The various data stores.
@@ -20,7 +23,14 @@ class DataStores(object):
     These are low level interfaces to physical databases.
     """
 
-    def __init__(self, main_store, db_conn, hs):
-        # Note we pass in the main store here as workers use a different main
+    def __init__(self, main_store_class, db_conn, hs):
+        # Note we pass in the main store class here as workers use a different main
         # store.
-        self.main = main_store
+        database = Database(hs)
+
+        # Check that db is correctly configured.
+        database.engine.check_database(db_conn.cursor())
+
+        prepare_database(db_conn, database.engine, config=hs.config)
+
+        self.main = main_store_class(database, db_conn, hs)
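
In other words, callers now hand DataStores the main store class rather than an instance. A sketch of the intended call site, where hs and db_conn stand in for the homeserver and an open connection:

    from synapse.storage.data_stores import DataStores
    from synapse.storage.data_stores.main import DataStore

    # Sketch only: DataStores builds the shared Database(hs), validates and
    # prepares the schema on db_conn, then instantiates the main store.
    stores = DataStores(DataStore, db_conn, hs)
    main_store = stores.main
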
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 10c940df1e..c577c0df5f 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -19,9 +19,8 @@ import calendar
 import logging
 import time
 
-from twisted.internet import defer
-
 from synapse.api.constants import PresenceState
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     ChainedIdGenerator,
@@ -32,6 +31,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .cache import CacheInvalidationStore
 from .client_ips import ClientIpStore
 from .deviceinbox import DeviceInboxStore
 from .devices import DeviceStore
@@ -110,11 +110,22 @@ class DataStore(
     MonthlyActiveUsersStore,
     StatsStore,
     RelationsStore,
+    CacheInvalidationStore,
 ):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.hs = hs
         self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
+        self.database_engine = database.engine
+
+        all_users_native = are_all_users_on_domain(
+            db_conn.cursor(), database.engine, hs.hostname
+        )
+        if not all_users_native:
+            raise Exception(
+                "Found users in database not native to %s!\n"
+                "You cannot changed a synapse server_name after it's been configured"
+                % (hs.hostname,)
+            )
 
         self._stream_id_gen = StreamIdGenerator(
             db_conn,
@@ -169,9 +180,11 @@ class DataStore(
         else:
             self._cache_id_gen = None
 
+        super(DataStore, self).__init__(database, db_conn, hs)
+
         self._presence_on_startup = self._get_active_presence(db_conn)
 
-        presence_cache_prefill, min_presence_val = self._get_cache_dict(
+        presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
             db_conn,
             "presence_stream",
             entity_column="user_id",
@@ -185,7 +198,7 @@ class DataStore(
         )
 
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
+        device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
             db_conn,
             "device_inbox",
             entity_column="user_id",
@@ -200,7 +213,7 @@ class DataStore(
         )
         # The federation outbox and the local device inbox use the same
         # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
+        device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
             db_conn,
             "device_federation_outbox",
             entity_column="destination",
@@ -226,7 +239,7 @@ class DataStore(
         )
 
         events_max = self._stream_id_gen.get_current_token()
-        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+        curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
             db_conn,
             "current_state_delta_stream",
             entity_column="room_id",
@@ -240,7 +253,7 @@ class DataStore(
             prefilled_cache=curr_state_delta_prefill,
         )
 
-        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+        _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
             db_conn,
             "local_group_updates",
             entity_column="user_id",
@@ -260,8 +273,6 @@ class DataStore(
         # Used in _generate_user_daily_visits to keep track of progress
         self._last_user_visit_update = self._get_start_of_day()
 
-        super(DataStore, self).__init__(db_conn, hs)
-
     def take_presence_startup_info(self):
         active_on_startup = self._presence_on_startup
         self._presence_on_startup = None
@@ -281,7 +292,7 @@ class DataStore(
 
         txn = db_conn.cursor()
         txn.execute(sql, (PresenceState.OFFLINE,))
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
         txn.close()
 
         for row in rows:
@@ -294,7 +305,7 @@ class DataStore(
         Counts the number of users who used this homeserver in the last 24 hours.
         """
         yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return self.runInteraction("count_daily_users", self._count_users, yesterday)
+        return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
 
     def count_monthly_users(self):
         """
@@ -304,7 +315,7 @@ class DataStore(
         amongst other things, includes a 3 day grace period before a user counts.
         """
         thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return self.runInteraction(
+        return self.db.runInteraction(
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
@@ -404,7 +415,7 @@ class DataStore(
 
             return results
 
-        return self.runInteraction("count_r30_users", _count_r30_users)
+        return self.db.runInteraction("count_r30_users", _count_r30_users)
 
     def _get_start_of_day(self):
         """
@@ -469,50 +480,73 @@ class DataStore(
             # frequently
             self._last_user_visit_update = now
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "generate_user_daily_visits", _generate_user_daily_visits
         )
 
     def get_users(self):
-        """Function to reterive a list of users in users table.
+        """Function to retrieve a list of users in users table.
 
         Args:
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="users",
             keyvalues={},
-            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin",
+                "user_type",
+                "deactivated",
+            ],
             desc="get_users",
         )
 
-    @defer.inlineCallbacks
-    def get_users_paginate(self, order, start, limit):
-        """Function to reterive a paginated list of users from
-        users list. This will return a json object, which contains
-        list of users and the total number of users in users table.
+    def get_users_paginate(
+        self, start, limit, name=None, guests=True, deactivated=False
+    ):
+        """Function to retrieve a paginated list of users from
+        users list. This will return a json list of users.
 
         Args:
-            order (str): column name to order the select by this column
             start (int): start number to begin the query from
-            limit (int): number of rows to reterive
+            limit (int): number of rows to retrieve
+            name (string): filter for user names
+            guests (bool): whether to include guest users
+            deactivated (bool): whether to include deactivated users
         Returns:
-            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+            defer.Deferred: resolves to list[dict[str, Any]]
         """
-        users = yield self.runInteraction(
-            "get_users_paginate",
-            self._simple_select_list_paginate_txn,
+        name_filter = {}
+        if name:
+            name_filter["name"] = "%" + name + "%"
+
+        attr_filter = {}
+        if not guests:
+            attr_filter["is_guest"] = False
+        if not deactivated:
+            attr_filter["deactivated"] = False
+
+        return self.db.simple_select_list_paginate(
+            desc="get_users_paginate",
             table="users",
-            keyvalues={"is_guest": False},
-            orderby=order,
+            orderby="name",
             start=start,
             limit=limit,
-            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+            filters=name_filter,
+            keyvalues=attr_filter,
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin",
+                "user_type",
+                "deactivated",
+            ],
         )
-        count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
-        retval = {"users": users, "total": count}
-        return retval
 
     def search_users(self, term):
         """Function to search users list for one or more users with
@@ -524,10 +558,22 @@ class DataStore(
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
-        return self._simple_search_list(
+        return self.db.simple_search_list(
             table="users",
             term=term,
             col="name",
             retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
             desc="search_users",
         )
+
+
+def are_all_users_on_domain(txn, database_engine, domain):
+    sql = database_engine.convert_param_style(
+        "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+    )
+    pat = "%:" + domain
+    txn.execute(sql, (pat,))
+    num_not_matching = txn.fetchall()[0][0]
+    if num_not_matching == 0:
+        return True
+    return False
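
A short usage sketch of the reworked pagination API; the store variable and filter values are illustrative:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def second_page_of_users(store):
        # Second page (rows 50-99) of non-guest, non-deactivated users whose
        # name contains "alice".
        users = yield store.get_users_paginate(
            start=50, limit=50, name="alice", guests=False, deactivated=False
        )
        return users
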
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 6afbfc0d74..46b494b334 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
         )
 
-        super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+        super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
 
     @abc.abstractmethod
     def get_max_account_data_stream_id(self):
@@ -67,7 +68,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_user_txn(txn):
-            rows = self._simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "account_data",
                 {"user_id": user_id},
@@ -78,7 +79,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: json.loads(row["content"]) for row in rows
             }
 
-            rows = self._simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "room_account_data",
                 {"user_id": user_id},
@@ -92,7 +93,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return global_account_data, by_room
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
@@ -102,7 +103,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         Returns:
             Deferred: A dict
         """
-        result = yield self._simple_select_one_onecol(
+        result = yield self.db.simple_select_one_onecol(
             table="account_data",
             keyvalues={"user_id": user_id, "account_data_type": data_type},
             retcol="content",
@@ -127,7 +128,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_room_txn(txn):
-            rows = self._simple_select_list_txn(
+            rows = self.db.simple_select_list_txn(
                 txn,
                 "room_account_data",
                 {"user_id": user_id, "room_id": room_id},
@@ -138,7 +139,7 @@ class AccountDataWorkerStore(SQLBaseStore):
                 row["account_data_type"]: json.loads(row["content"]) for row in rows
             }
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
@@ -156,7 +157,7 @@ class AccountDataWorkerStore(SQLBaseStore):
         """
 
         def get_account_data_for_room_and_type_txn(txn):
-            content_json = self._simple_select_one_onecol_txn(
+            content_json = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="room_account_data",
                 keyvalues={
@@ -170,7 +171,7 @@ class AccountDataWorkerStore(SQLBaseStore):
 
             return json.loads(content_json) if content_json else None
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
@@ -184,14 +185,14 @@ class AccountDataWorkerStore(SQLBaseStore):
             current_id(int): The position to fetch up to.
         Returns:
             A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, type string, and content string.
+            room_id string, and type string.
         """
         if last_room_id == current_id and last_global_id == current_id:
             return defer.succeed(([], []))
 
         def get_updated_account_data_txn(txn):
             sql = (
-                "SELECT stream_id, user_id, account_data_type, content"
+                "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
@@ -199,7 +200,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             global_results = txn.fetchall()
 
             sql = (
-                "SELECT stream_id, user_id, room_id, account_data_type, content"
+                "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
@@ -207,7 +208,7 @@ class AccountDataWorkerStore(SQLBaseStore):
             room_results = txn.fetchall()
             return global_results, room_results
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_account_data_txn", get_updated_account_data_txn
         )
 
@@ -250,9 +251,9 @@ class AccountDataWorkerStore(SQLBaseStore):
             user_id, int(stream_id)
         )
         if not changed:
-            return {}, {}
+            return defer.succeed(({}, {}))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
         )
 
@@ -270,12 +271,12 @@ class AccountDataWorkerStore(SQLBaseStore):
 
 
 class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self._account_data_id_gen = StreamIdGenerator(
             db_conn, "account_data_max_stream_id", "stream_id"
         )
 
-        super(AccountDataStore, self).__init__(db_conn, hs)
+        super(AccountDataStore, self).__init__(database, db_conn, hs)
 
     def get_max_account_data_stream_id(self):
         """Get the current max stream id for the private user data stream
@@ -300,9 +301,9 @@ class AccountDataStore(AccountDataWorkerStore):
 
         with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
-            # on (user_id, room_id, account_data_type) so _simple_upsert will
+            # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
-            yield self._simple_upsert(
+            yield self.db.simple_upsert(
                 desc="add_room_account_data",
                 table="room_account_data",
                 keyvalues={
@@ -346,9 +347,9 @@ class AccountDataStore(AccountDataWorkerStore):
 
         with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
-            # (user_id, account_data_type) so _simple_upsert will retry if
+            # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.
-            yield self._simple_upsert(
+            yield self.db.simple_upsert(
                 desc="add_user_account_data",
                 table="account_data",
                 keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -388,4 +389,4 @@ class AccountDataStore(AccountDataWorkerStore):
             )
             txn.execute(update_max_id_sql, (next_id, next_id))
 
-        return self.runInteraction("update_account_data_max_stream_id", _update)
+        return self.db.runInteraction("update_account_data_max_stream_id", _update)
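
The comments above lean on the unique constraints rather than table locks. A sketch of the same upsert as driven from outside the store; the names and values are illustrative, and the keyword arguments assume the renamed simple_upsert keeps the old _simple_upsert signature:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def set_example_account_data(store, user_id, content_json, next_id):
        # The unique constraint on (user_id, account_data_type) lets
        # simple_upsert retry on conflict instead of taking a table lock.
        yield store.db.simple_upsert(
            desc="add_user_account_data",
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": "m.example"},
            values={"stream_id": next_id, "content": content_json},
        )
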
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 81babf2029..b2f39649fd 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -24,6 +24,7 @@ from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
@@ -48,13 +49,13 @@ def _make_exclusive_regex(services_cache):
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
         self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
 
-        super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+        super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
 
     def get_app_services(self):
         return self.services_cache
@@ -133,7 +134,7 @@ class ApplicationServiceTransactionWorkerStore(
             A Deferred which resolves to a list of ApplicationServices, which
             may be empty.
         """
-        results = yield self._simple_select_list(
+        results = yield self.db.simple_select_list(
             "application_services_state", dict(state=state), ["as_id"]
         )
         # NB: This assumes this class is linked with ApplicationServiceStore
@@ -155,7 +156,7 @@ class ApplicationServiceTransactionWorkerStore(
         Returns:
             A Deferred which resolves to ApplicationServiceState.
         """
-        result = yield self._simple_select_one(
+        result = yield self.db.simple_select_one(
             "application_services_state",
             dict(as_id=service.id),
             ["state"],
@@ -175,7 +176,7 @@ class ApplicationServiceTransactionWorkerStore(
         Returns:
             A Deferred which resolves when the state was set successfully.
         """
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             "application_services_state", dict(as_id=service.id), dict(state=state)
         )
 
@@ -216,7 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
             )
             return AppServiceTransaction(service=service, id=new_txn_id, events=events)
 
-        return self.runInteraction("create_appservice_txn", _create_appservice_txn)
+        return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
 
     def complete_appservice_txn(self, txn_id, service):
         """Completes an application service transaction.
@@ -249,7 +250,7 @@ class ApplicationServiceTransactionWorkerStore(
                 )
 
             # Set current txn_id for AS to 'txn_id'
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 "application_services_state",
                 dict(as_id=service.id),
@@ -257,11 +258,13 @@ class ApplicationServiceTransactionWorkerStore(
             )
 
             # Delete txn
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
             )
 
-        return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
+        return self.db.runInteraction(
+            "complete_appservice_txn", _complete_appservice_txn
+        )
 
     @defer.inlineCallbacks
     def get_oldest_unsent_txn(self, service):
@@ -283,7 +286,7 @@ class ApplicationServiceTransactionWorkerStore(
                 " ORDER BY txn_id ASC LIMIT 1",
                 (service.id,),
             )
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return None
 
@@ -291,7 +294,7 @@ class ApplicationServiceTransactionWorkerStore(
 
             return entry
 
-        entry = yield self.runInteraction(
+        entry = yield self.db.runInteraction(
             "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
         )
 
@@ -321,7 +324,7 @@ class ApplicationServiceTransactionWorkerStore(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
@@ -350,7 +353,7 @@ class ApplicationServiceTransactionWorkerStore(
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.runInteraction(
+        upper_bound, event_ids = yield self.db.runInteraction(
             "get_new_events_for_appservice", get_new_events_for_appservice_txn
         )
 
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
new file mode 100644
index 0000000000..54ed8574c4
--- /dev/null
+++ b/synapse/storage/data_stores/main/cache.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import logging
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
+from synapse.util import batch_iter
+
+logger = logging.getLogger(__name__)
+
+
+# This is a special cache name we use to batch multiple invalidations of caches
+# based on the current state when notifying workers over replication.
+CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
+
+
+class CacheInvalidationStore(SQLBaseStore):
+    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+        """Invalidates the cache and adds it to the cache stream so slaves
+        will know to invalidate their caches.
+
+        This should only be used to invalidate caches where slaves won't
+        otherwise know from other replication streams that the cache should
+        be invalidated.
+        """
+        txn.call_after(cache_func.invalidate, keys)
+        self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+
+    def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+        """Special case invalidation of caches based on current state.
+
+        We special case this so that we can batch the cache invalidations into a
+        single replication poke.
+
+        Args:
+            txn
+            room_id (str): Room where state changed
+            members_changed (iterable[str]): The user_ids of members that have changed
+        """
+        txn.call_after(self._invalidate_state_caches, room_id, members_changed)
+
+        if members_changed:
+            # We need to be careful that the size of the `members_changed` list
+            # isn't so large that it causes problems sending over replication, so we
+            # send them in chunks.
+            # Max line length is 16K, and max user ID length is 255, so 50 should
+            # be safe.
+            for chunk in batch_iter(members_changed, 50):
+                keys = itertools.chain([room_id], chunk)
+                self._send_invalidation_to_replication(
+                    txn, CURRENT_STATE_CACHE_NAME, keys
+                )
+        else:
+            # if no members changed, we still need to invalidate the other caches.
+            self._send_invalidation_to_replication(
+                txn, CURRENT_STATE_CACHE_NAME, [room_id]
+            )
+
+    def _send_invalidation_to_replication(self, txn, cache_name, keys):
+        """Notifies replication that given cache has been invalidated.
+
+        Note that this does *not* invalidate the cache locally.
+
+        Args:
+            txn
+            cache_name (str)
+            keys (iterable[str])
+        """
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # get_next() returns a context manager which is designed to wrap
+            # the transaction. However, we want to only get an ID when we want
+            # to use it, here, so we need to call __enter__ manually, and have
+            # __exit__ called after the transaction finishes.
+            ctx = self._cache_id_gen.get_next()
+            stream_id = ctx.__enter__()
+            txn.call_on_exception(ctx.__exit__, None, None, None)
+            txn.call_after(ctx.__exit__, None, None, None)
+            txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
+            self.db.simple_insert_txn(
+                txn,
+                table="cache_invalidation_stream",
+                values={
+                    "stream_id": stream_id,
+                    "cache_func": cache_name,
+                    "keys": list(keys),
+                    "invalidation_ts": self.clock.time_msec(),
+                },
+            )
+
+    def get_all_updated_caches(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_caches_txn(txn):
+            # We purposefully don't bound by the current token, as we want to
+            # send across cache invalidations as quickly as possible. Cache
+            # invalidations are idempotent, so duplicates are fine.
+            sql = (
+                "SELECT stream_id, cache_func, keys, invalidation_ts"
+                " FROM cache_invalidation_stream"
+                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_updated_caches", get_all_updated_caches_txn
+        )
+
+    def get_cache_stream_token(self):
+        if self._cache_id_gen:
+            return self._cache_id_gen.get_current_token()
+        else:
+            return 0
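
The chunking described in _invalidate_state_caches_and_stream can be pictured as below; batch_iter comes from synapse.util as imported above, and the helper name is illustrative:

    import itertools

    from synapse.util import batch_iter

    def replication_key_chunks(room_id, members_changed, chunk_size=50):
        # Yield (room_id, user_id, ...) key groups of at most chunk_size
        # members, keeping each replication line well under the 16K limit.
        for chunk in batch_iter(members_changed, chunk_size):
            yield list(itertools.chain([room_id], chunk))
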
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 706c6a1f3f..320c5b0f07 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -20,9 +20,10 @@ from six import iteritems
 from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage import background_updates
-from synapse.storage._base import Cache
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.caches.descriptors import Cache
 
 logger = logging.getLogger(__name__)
 
@@ -32,41 +33,41 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
-class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
+class ClientIpBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_device_index",
             index_name="user_ips_device_id",
             table="user_ips",
             columns=["user_id", "device_id", "last_seen"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_last_seen_index",
             index_name="user_ips_last_seen",
             table="user_ips",
             columns=["user_id", "last_seen"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_last_seen_only_index",
             index_name="user_ips_last_seen_only",
             table="user_ips",
             columns=["last_seen"],
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_analyze", self._analyze_user_ip
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_remove_dupes", self._remove_user_ip_dupes
         )
 
         # Register a unique index
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "user_ips_device_unique_index",
             index_name="user_ips_user_token_ip_unique_index",
             table="user_ips",
@@ -75,12 +76,12 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
         )
 
         # Drop the old non-unique index
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
         )
 
         # Update the last seen info in devices.
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "devices_last_seen", self._devices_last_seen_update
         )
 
@@ -91,8 +92,8 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
 
-        yield self.runWithConnection(f)
-        yield self._end_background_update("user_ips_drop_nonunique_index")
+        yield self.db.runWithConnection(f)
+        yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
         return 1
 
     @defer.inlineCallbacks
@@ -106,9 +107,9 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
         def user_ips_analyze(txn):
             txn.execute("ANALYZE user_ips")
 
-        yield self.runInteraction("user_ips_analyze", user_ips_analyze)
+        yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
 
-        yield self._end_background_update("user_ips_analyze")
+        yield self.db.updates._end_background_update("user_ips_analyze")
 
         return 1
 
@@ -140,7 +141,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
                 return None
 
         # Get a last seen that has roughly `batch_size` since `begin_last_seen`
-        end_last_seen = yield self.runInteraction(
+        end_last_seen = yield self.db.runInteraction(
             "user_ips_dups_get_last_seen", get_last_seen
         )
 
@@ -271,14 +272,14 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
                     (user_id, access_token, ip, device_id, user_agent, last_seen),
                 )
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
             )
 
-        yield self.runInteraction("user_ips_dups_remove", remove)
+        yield self.db.runInteraction("user_ips_dups_remove", remove)
 
         if last:
-            yield self._end_background_update("user_ips_remove_dupes")
+            yield self.db.updates._end_background_update("user_ips_remove_dupes")
 
         return batch_size
 
@@ -344,7 +345,7 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
             txn.execute_batch(sql, rows)
 
             _, _, _, user_id, device_id = rows[-1]
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn,
                 "devices_last_seen",
                 {"last_user_id": user_id, "last_device_id": device_id},
@@ -352,24 +353,24 @@ class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
 
             return len(rows)
 
-        updated = yield self.runInteraction(
+        updated = yield self.db.runInteraction(
             "_devices_last_seen_update", _devices_last_seen_update_txn
         )
 
         if not updated:
-            yield self._end_background_update("devices_last_seen")
+            yield self.db.updates._end_background_update("devices_last_seen")
 
         return updated
 
 
 class ClientIpStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
         )
 
-        super(ClientIpStore, self).__init__(db_conn, hs)
+        super(ClientIpStore, self).__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.user_ips_max_age
 
@@ -417,12 +418,12 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         to_update = self._batch_row_update
         self._batch_row_update = {}
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
     def _update_client_ips_batch_txn(self, txn, to_update):
-        if "user_ips" in self._unsafe_to_upsert_tables or (
+        if "user_ips" in self.db._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
             self.database_engine.lock_table(txn, "user_ips")
@@ -431,7 +432,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
 
             try:
-                self._simple_upsert_txn(
+                self.db.simple_upsert_txn(
                     txn,
                     table="user_ips",
                     keyvalues={
@@ -450,7 +451,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Technically an access token might not be associated with
                 # a device so we need to check.
                 if device_id:
-                    self._simple_upsert_txn(
+                    self.db.simple_upsert_txn(
                         txn,
                         table="devices",
                         keyvalues={"user_id": user_id, "device_id": device_id},
@@ -483,7 +484,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = yield self._simple_select_list(
+        res = yield self.db.simple_select_list(
             table="devices",
             keyvalues=keyvalues,
             retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -516,7 +517,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 user_agent, _, last_seen = self._batch_row_update[key]
                 results[(access_token, ip)] = (user_agent, last_seen)
 
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="user_ips",
             keyvalues={"user_id": user_id},
             retcols=["access_token", "ip", "user_agent", "last_seen"],
@@ -546,7 +547,9 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
             # Nothing to do
             return
 
-        if not await self.has_completed_background_update("devices_last_seen"):
+        if not await self.db.updates.has_completed_background_update(
+            "devices_last_seen"
+        ):
             # Only start pruning if we have finished populating the devices
             # last seen info.
             return
@@ -577,4 +580,4 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
         def _prune_old_user_ips_txn(txn):
             txn.execute(sql, (timestamp,))
 
-        await self.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
+        await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
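
For context on the hunks above: ClientIpStore buffers writes in an in-memory dict and periodically flushes them in a single transaction. A rough sketch of that flow, with an illustrative wrapper function:

    def buffer_client_ip(store, user_id, access_token, ip, user_agent, device_id, now_ms):
        # Buffered writes are keyed by (user_id, access_token, ip); the most
        # recent (user_agent, device_id, last_seen) tuple wins until the flush.
        store._batch_row_update[(user_id, access_token, ip)] = (
            user_agent,
            device_id,
            now_ms,
        )

    # The periodic flush then runs in one interaction, as shown above:
    #     store.db.runInteraction(
    #         "_update_client_ips_batch", store._update_client_ips_batch_txn, to_update
    #     )
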
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 96cd0fb77a..85cfa16850 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_new_messages_for_device", get_new_messages_for_device_txn
         )
 
@@ -109,7 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, device_id, up_to_stream_id))
             return txn.rowcount
 
-        count = yield self.runInteraction(
+        count = yield self.db.runInteraction(
             "delete_messages_for_device", delete_messages_for_device_txn
         )
 
@@ -178,7 +178,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 stream_pos = current_stream_id
             return messages, stream_pos
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_new_device_msgs_for_remote",
             get_new_messages_for_remote_destination_txn,
         )
@@ -203,25 +203,25 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             )
             txn.execute(sql, (destination, up_to_stream_id))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
 
-class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_inbox_stream_index",
             index_name="device_inbox_stream_id_user_id",
             table="device_inbox",
             columns=["stream_id", "user_id"],
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
@@ -232,9 +232,9 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
             txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
             txn.close()
 
-        yield self.runWithConnection(reindex_txn)
+        yield self.db.runWithConnection(reindex_txn)
 
-        yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+        yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
 
         return 1
 
@@ -242,8 +242,8 @@ class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
 
-    def __init__(self, db_conn, hs):
-        super(DeviceInboxStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceInboxStore, self).__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no-op deletions.
@@ -294,7 +294,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
             )
             for user_id in local_messages_by_user_then_device.keys():
@@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
             # acknowledgement from the first time we received the message.
-            already_inserted = self._simple_select_one_txn(
+            already_inserted = self.db.simple_select_one_txn(
                 txn,
                 table="device_federation_inbox",
                 keyvalues={"origin": origin, "message_id": message_id},
@@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
             # Add an entry for this message_id so that we know we've processed
             # it.
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="device_federation_inbox",
                 values={
@@ -344,7 +344,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
         with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,
                 now_ms,
@@ -380,7 +380,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
             devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
-                sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
+                sql = "SELECT device_id FROM devices WHERE user_id = ?"
                 txn.execute(sql, (user_id,))
                 message_json = json.dumps(messages_by_device["*"])
                 for row in txn:
@@ -465,6 +465,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
 
             return rows
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_new_device_messages", get_all_new_device_messages_txn
         )
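
The deviceinbox hunks show the structural half of the refactor: stores that used to derive from BackgroundUpdateStore now subclass SQLBaseStore directly, take the Database object in their constructor, and register index builds and update handlers on self.db.updates. A sketch of that constructor shape, with the class name, update name, table and columns all illustrative rather than taken from Synapse:

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import Database


    class ExampleBackgroundUpdateStore(SQLBaseStore):
        """Hypothetical store illustrating the post-refactor constructor."""

        def __init__(self, database: Database, db_conn, hs):
            super(ExampleBackgroundUpdateStore, self).__init__(database, db_conn, hs)

            # Background index builds register on db.updates, mirroring
            # DeviceInboxBackgroundUpdateStore above.
            self.db.updates.register_background_index_update(
                "example_stream_index",
                index_name="example_stream_id_user_id",
                table="example_table",
                columns=["stream_id", "user_id"],
            )
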
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 71f62036c0..9a828231c4 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -30,16 +30,16 @@ from synapse.logging.opentracing import (
     whitelisted_homeserver,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
-    Cache,
-    SQLBaseStore,
-    db_to_json,
-    make_in_list_sql_clause,
-)
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.types import get_verify_key_from_cross_signing_key
 from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import (
+    Cache,
+    cached,
+    cachedInlineCallbacks,
+    cachedList,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Raises:
             StoreError: if the device is not found
         """
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore):
             containing "device_id", "user_id" and "display_name" for each
             device.
         """
-        devices = yield self._simple_select_list(
+        devices = yield self.db.simple_select_list(
             table="devices",
             keyvalues={"user_id": user_id, "hidden": False},
             retcols=("user_id", "device_id", "display_name"),
@@ -122,7 +122,7 @@ class DeviceWorkerStore(SQLBaseStore):
         # consider the device update to be too large, and simply skip the
         # stream_id; the rationale being that such a large device list update
         # is likely an error.
-        updates = yield self.runInteraction(
+        updates = yield self.db.runInteraction(
             "get_device_updates_by_remote",
             self._get_device_updates_by_remote_txn,
             destination,
@@ -283,7 +283,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
         """
         devices = (
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_get_e2e_device_keys_txn",
                 self._get_e2e_device_keys_txn,
                 query_map.keys(),
@@ -340,12 +340,12 @@ class DeviceWorkerStore(SQLBaseStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        return self.runInteraction("get_last_device_update_for_remote_user", f)
+        return self.db.runInteraction("get_last_device_update_for_remote_user", f)
 
     def mark_as_sent_devices_by_remote(self, destination, stream_id):
         """Mark that updates have successfully been sent to the destination.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
             destination,
@@ -399,7 +399,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
 
         with self._device_list_id_gen.get_next() as stream_id:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
                 from_user_id,
@@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore):
             from_user_id,
             stream_id,
         )
-        self._simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             "user_signature_stream",
             values={
@@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=2, tree=True)
     def _get_cached_user_device(self, user_id, device_id):
-        content = yield self._simple_select_one_onecol(
+        content = yield self.db.simple_select_one_onecol(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id, "device_id": device_id},
             retcol="content",
@@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks()
     def _get_cached_devices_for_user(self, user_id):
-        devices = yield self._simple_select_list(
+        devices = yield self.db.simple_select_list(
             table="device_lists_remote_cache",
             keyvalues={"user_id": user_id},
             retcols=("device_id", "content"),
@@ -492,7 +492,7 @@ class DeviceWorkerStore(SQLBaseStore):
         Returns:
             (stream_id, devices)
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_devices_with_keys_by_user",
             self._get_devices_with_keys_by_user_txn,
             user_id,
@@ -565,7 +565,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
             return changes
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
 
@@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 SELECT DISTINCT user_ids FROM user_signature_stream
                 WHERE from_user_id = ? AND stream_id > ?
             """
-            rows = yield self._execute(
+            rows = yield self.db.execute(
                 "get_users_whose_signatures_changed", None, sql, user_id, from_key
             )
             return set(user for row in rows for user in json.loads(row[0]))
@@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore):
             WHERE ? < stream_id AND stream_id <= ?
             GROUP BY user_id, destination
         """
-        return self._execute(
+        return self.db.execute(
             "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
         )
 
@@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             retcol="stream_id",
@@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore):
         inlineCallbacks=True,
     )
     def get_device_list_last_stream_id_for_remotes(self, user_ids):
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
             iterable=user_ids,
@@ -642,11 +642,11 @@ class DeviceWorkerStore(SQLBaseStore):
         return results
 
 
-class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
+class DeviceBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_lists_stream_idx",
             index_name="device_lists_stream_user_id",
             table="device_lists_stream",
@@ -654,7 +654,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
         )
 
         # create a unique index on device_lists_remote_cache
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_lists_remote_cache_unique_idx",
             index_name="device_lists_remote_cache_unique_id",
             table="device_lists_remote_cache",
@@ -663,7 +663,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
         )
 
         # And one on device_lists_remote_extremeties
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "device_lists_remote_extremeties_unique_idx",
             index_name="device_lists_remote_extremeties_unique_idx",
             table="device_lists_remote_extremeties",
@@ -672,7 +672,7 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
         )
 
         # once they complete, we can remove the old non-unique indexes.
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
             self._drop_device_list_streams_non_unique_indexes,
         )
@@ -685,14 +685,16 @@ class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
             txn.close()
 
-        yield self.runWithConnection(f)
-        yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
+        yield self.db.runWithConnection(f)
+        yield self.db.updates._end_background_update(
+            DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
+        )
         return 1
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(DeviceStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(DeviceStore, self).__init__(database, db_conn, hs)
 
         # Map of (user_id, device_id) -> bool. If there is an entry, that
         # implies the device exists.
@@ -722,7 +724,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return False
 
         try:
-            inserted = yield self._simple_insert(
+            inserted = yield self.db.simple_insert(
                 "devices",
                 values={
                     "user_id": user_id,
@@ -736,7 +738,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             if not inserted:
                 # if the device already exists, check if it's a real device, or
                 # if the device ID is reserved by something else
-                hidden = yield self._simple_select_one_onecol(
+                hidden = yield self.db.simple_select_one_onecol(
                     "devices",
                     keyvalues={"user_id": user_id, "device_id": device_id},
                     retcol="hidden",
@@ -771,7 +773,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             defer.Deferred
         """
-        yield self._simple_delete_one(
+        yield self.db.simple_delete_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             desc="delete_device",
@@ -789,7 +791,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             defer.Deferred
         """
-        yield self._simple_delete_many(
+        yield self.db.simple_delete_many(
             table="devices",
             column="device_id",
             iterable=device_ids,
@@ -818,7 +820,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             updates["display_name"] = new_display_name
         if not updates:
             return defer.succeed(None)
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="devices",
             keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
             updatevalues=updates,
@@ -829,7 +831,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def mark_remote_user_device_list_as_unsubscribed(self, user_id):
         """Mark that we no longer track device lists for remote user.
         """
-        yield self._simple_delete(
+        yield self.db.simple_delete(
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
             desc="mark_remote_user_device_list_as_unsubscribed",
@@ -853,7 +855,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             Deferred[None]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_remote_device_list_cache_entry",
             self._update_remote_device_list_cache_entry_txn,
             user_id,
@@ -866,7 +868,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self, txn, user_id, device_id, content, stream_id
     ):
         if content.get("deleted"):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="device_lists_remote_cache",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -874,7 +876,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
             txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
         else:
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="device_lists_remote_cache",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -890,7 +892,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
 
-        self._simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
@@ -914,7 +916,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Returns:
             Deferred[None]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_remote_device_list_cache",
             self._update_remote_device_list_cache_txn,
             user_id,
@@ -923,11 +925,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         )
 
     def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
-        self._simple_delete_txn(
+        self.db.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_remote_cache",
             values=[
@@ -946,7 +948,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
         )
 
-        self._simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="device_lists_remote_extremeties",
             keyvalues={"user_id": user_id},
@@ -962,7 +964,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         (if any) should be poked.
         """
         with self._device_list_id_gen.get_next() as stream_id:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "add_device_change_to_streams",
                 self._add_device_change_txn,
                 user_id,
@@ -995,7 +997,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             [(user_id, device_id, stream_id) for device_id in device_ids],
         )
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_stream",
             values=[
@@ -1006,7 +1008,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         context = get_active_span_text_map()
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="device_lists_outbound_pokes",
             values=[
@@ -1069,7 +1071,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         return run_as_background_process(
             "prune_old_outbound_device_pokes",
-            self.runInteraction,
+            self.db.runInteraction,
             "_prune_old_outbound_device_pokes",
             _prune_txn,
         )
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py
index 297966d9f4..c9e7de7d12 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore):
             Deferred: results in namedtuple with keys "room_id" and
             "servers" or None if no association can be found
         """
-        room_id = yield self._simple_select_one_onecol(
+        room_id = yield self.db.simple_select_one_onecol(
             "room_aliases",
             {"room_alias": room_alias.to_string()},
             "room_id",
@@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         if not room_id:
             return None
 
-        servers = yield self._simple_select_onecol(
+        servers = yield self.db.simple_select_onecol(
             "room_alias_servers",
             {"room_alias": room_alias.to_string()},
             "server",
@@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         return RoomAliasMapping(room_id, room_alias.to_string(), servers)
 
     def get_room_alias_creator(self, room_alias):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="room_aliases",
             keyvalues={"room_alias": room_alias},
             retcol="creator",
@@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore):
 
     @cached(max_entries=5000)
     def get_aliases_for_room(self, room_id):
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             "room_aliases",
             {"room_id": room_id},
             "room_alias",
@@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore):
         """
 
         def alias_txn(txn):
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 "room_aliases",
                 {
@@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore):
                 },
             )
 
-            self._simple_insert_many_txn(
+            self.db.simple_insert_many_txn(
                 txn,
                 table="room_alias_servers",
                 values=[
@@ -117,7 +117,9 @@ class DirectoryStore(DirectoryWorkerStore):
             )
 
         try:
-            ret = yield self.runInteraction("create_room_alias_association", alias_txn)
+            ret = yield self.db.runInteraction(
+                "create_room_alias_association", alias_txn
+            )
         except self.database_engine.module.IntegrityError:
             raise SynapseError(
                 409, "Room alias %s already exists" % room_alias.to_string()
@@ -126,7 +128,7 @@ class DirectoryStore(DirectoryWorkerStore):
 
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
-        room_id = yield self.runInteraction(
+        room_id = yield self.db.runInteraction(
             "delete_room_alias", self._delete_room_alias_txn, room_alias
         )
 
@@ -168,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore):
                 txn, self.get_aliases_for_room, (new_room_id,)
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "_update_aliases_for_room_txn", _update_aliases_for_room_txn
         )
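
One behavioural point in directory.py survives the rename unchanged and is worth keeping in mind when reading the hunk: runInteraction is the boundary at which database errors reach the caller, so the duplicate-alias check still works by catching the engine's IntegrityError around the interaction rather than inside the transaction function. A condensed, hypothetical sketch of that pattern; the column names in the insert are assumed from the select earlier in this file, and the creator/servers handling is omitted:

    @defer.inlineCallbacks
    def create_alias(self, room_alias, room_id):
        # Hypothetical, trimmed-down variant of create_room_alias_association.
        def alias_txn(txn):
            self.db.simple_insert_txn(
                txn,
                "room_aliases",
                {"room_alias": room_alias, "room_id": room_id},
            )

        try:
            yield self.db.runInteraction("create_room_alias_association", alias_txn)
        except self.database_engine.module.IntegrityError:
            # The unique constraint on the alias fires inside the interaction
            # and is reported to the caller as a 409.
            raise SynapseError(409, "Room alias %s already exists" % room_alias)
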
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 1cbbae5b63..84594cf0a9 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,49 +25,8 @@ from synapse.storage._base import SQLBaseStore
 
 class EndToEndRoomKeyStore(SQLBaseStore):
     @defer.inlineCallbacks
-    def get_e2e_room_key(self, user_id, version, room_id, session_id):
-        """Get the encrypted E2E room key for a given session from a given
-        backup version of room_keys.  We only store the 'best' room key for a given
-        session at a given time, as determined by the handler.
-
-        Args:
-            user_id(str): the user whose backup we're querying
-            version(str): the version ID of the backup for the set of keys we're querying
-            room_id(str): the ID of the room whose keys we're querying.
-                This is a bit redundant as it's implied by the session_id, but
-                we include for consistency with the rest of the API.
-            session_id(str): the session whose room_key we're querying.
-
-        Returns:
-            A deferred dict giving the session_data and message metadata for
-            this room key.
-        """
-
-        row = yield self._simple_select_one(
-            table="e2e_room_keys",
-            keyvalues={
-                "user_id": user_id,
-                "version": version,
-                "room_id": room_id,
-                "session_id": session_id,
-            },
-            retcols=(
-                "first_message_index",
-                "forwarded_count",
-                "is_verified",
-                "session_data",
-            ),
-            desc="get_e2e_room_key",
-        )
-
-        row["session_data"] = json.loads(row["session_data"])
-
-        return row
-
-    @defer.inlineCallbacks
-    def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
-        """Replaces or inserts the encrypted E2E room key for a given session in
-        a given backup
+    def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+        """Replaces the encrypted E2E room key for a given session in a given backup
 
         Args:
             user_id(str): the user whose backup we're setting
@@ -78,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             StoreError
         """
 
-        yield self._simple_upsert(
+        yield self.db.simple_update_one(
             table="e2e_room_keys",
             keyvalues={
                 "user_id": user_id,
@@ -86,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 "room_id": room_id,
                 "session_id": session_id,
             },
-            values={
+            updatevalues={
                 "first_message_index": room_key["first_message_index"],
                 "forwarded_count": room_key["forwarded_count"],
                 "is_verified": room_key["is_verified"],
                 "session_data": json.dumps(room_key["session_data"]),
             },
-            lock=False,
+            desc="update_e2e_room_key",
         )
-        log_kv(
-            {
-                "message": "Set room key",
-                "room_id": room_id,
-                "session_id": session_id,
-                "room_key": room_key,
-            }
+
+    @defer.inlineCallbacks
+    def add_e2e_room_keys(self, user_id, version, room_keys):
+        """Bulk add room keys to a given backup.
+
+        Args:
+            user_id (str): the user whose backup we're adding to
+            version (str): the version ID of the backup for the set of keys we're adding to
+            room_keys (iterable[(str, str, dict)]): the keys to add, in the form
+                (roomID, sessionID, keyData)
+        """
+
+        values = []
+        for (room_id, session_id, room_key) in room_keys:
+            values.append(
+                {
+                    "user_id": user_id,
+                    "version": version,
+                    "room_id": room_id,
+                    "session_id": session_id,
+                    "first_message_index": room_key["first_message_index"],
+                    "forwarded_count": room_key["forwarded_count"],
+                    "is_verified": room_key["is_verified"],
+                    "session_data": json.dumps(room_key["session_data"]),
+                }
+            )
+            log_kv(
+                {
+                    "message": "Set room key",
+                    "room_id": room_id,
+                    "session_id": session_id,
+                    "room_key": room_key,
+                }
+            )
+
+        yield self.db.simple_insert_many(
+            table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
         )
 
     @trace
@@ -110,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         room, or a given session.
 
         Args:
-            user_id(str): the user whose backup we're querying
-            version(str): the version ID of the backup for the set of keys we're querying
-            room_id(str): Optional. the ID of the room whose keys we're querying, if any.
+            user_id (str): the user whose backup we're querying
+            version (str): the version ID of the backup for the set of keys we're querying
+            room_id (str): Optional. the ID of the room whose keys we're querying, if any.
                 If not specified, we return the keys for all the rooms in the backup.
-            session_id(str): Optional. the session whose room_key we're querying, if any.
+            session_id (str): Optional. the session whose room_key we're querying, if any.
                 If specified, we also require the room_id to be specified.
                 If not specified, we return all the keys in this version of
                 the backup (or for the specified room)
@@ -135,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             if session_id:
                 keyvalues["session_id"] = session_id
 
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="e2e_room_keys",
             keyvalues=keyvalues,
             retcols=(
@@ -162,6 +152,95 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
         return sessions
 
+    def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+        """Get multiple room keys at a time.  The difference between this function and
+        get_e2e_room_keys is that this function can be used to retrieve
+        multiple specific keys at a time, whereas get_e2e_room_keys is used for
+        getting all the keys in a backup version, all the keys for a room, or a
+        specific key.
+
+        Args:
+            user_id (str): the user whose backup we're querying
+            version (str): the version ID of the backup we're querying about
+            room_keys (dict[str, dict[str, iterable[str]]]): a map from
+                room ID -> {"sessions": [session ids]} indicating the session IDs
+                that we want to query
+
+        Returns:
+           Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+        """
+
+        return self.db.runInteraction(
+            "get_e2e_room_keys_multi",
+            self._get_e2e_room_keys_multi_txn,
+            user_id,
+            version,
+            room_keys,
+        )
+
+    @staticmethod
+    def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+        if not room_keys:
+            return {}
+
+        where_clauses = []
+        params = [user_id, version]
+        for room_id, room in room_keys.items():
+            sessions = list(room["sessions"])
+            if not sessions:
+                continue
+            params.append(room_id)
+            params.extend(sessions)
+            where_clauses.append(
+                "(room_id = ? AND session_id IN (%s))"
+                % (",".join(["?" for _ in sessions]),)
+            )
+
+        # check if we're actually querying something
+        if not where_clauses:
+            return {}
+
+        sql = """
+        SELECT room_id, session_id, first_message_index, forwarded_count,
+               is_verified, session_data
+        FROM e2e_room_keys
+        WHERE user_id = ? AND version = ? AND (%s)
+        """ % (
+            " OR ".join(where_clauses)
+        )
+
+        txn.execute(sql, params)
+
+        ret = {}
+
+        for row in txn:
+            room_id = row[0]
+            session_id = row[1]
+            ret.setdefault(room_id, {})
+            ret[room_id][session_id] = {
+                "first_message_index": row[2],
+                "forwarded_count": row[3],
+                "is_verified": row[4],
+                "session_data": json.loads(row[5]),
+            }
+
+        return ret
+
+    def count_e2e_room_keys(self, user_id, version):
+        """Get the number of keys in a backup version.
+
+        Args:
+            user_id (str): the user whose backup we're querying
+            version (str): the version ID of the backup we're querying about
+        """
+
+        return self.db.simple_select_one_onecol(
+            table="e2e_room_keys",
+            keyvalues={"user_id": user_id, "version": version},
+            retcol="COUNT(*)",
+            desc="count_e2e_room_keys",
+        )
+
     @trace
     @defer.inlineCallbacks
     def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
@@ -188,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             if session_id:
                 keyvalues["session_id"] = session_id
 
-        yield self._simple_delete(
+        yield self.db.simple_delete(
             table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
         )
 
@@ -219,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 version(str)
                 algorithm(str)
                 auth_data(object): opaque dict supplied by the client
+                etag(int): tag of the keys in the backup
         """
 
         def _get_e2e_room_keys_version_info_txn(txn):
@@ -232,17 +312,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                     # it isn't there.
                     raise StoreError(404, "No row found")
 
-            result = self._simple_select_one_txn(
+            result = self.db.simple_select_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
-                retcols=("version", "algorithm", "auth_data"),
+                retcols=("version", "algorithm", "auth_data", "etag"),
             )
             result["auth_data"] = json.loads(result["auth_data"])
             result["version"] = str(result["version"])
+            if result["etag"] is None:
+                result["etag"] = 0
             return result
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
         )
 
@@ -270,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
             new_version = str(int(current_version) + 1)
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 values={
@@ -283,26 +365,38 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
             return new_version
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
         )
 
     @trace
-    def update_e2e_room_keys_version(self, user_id, version, info):
+    def update_e2e_room_keys_version(
+        self, user_id, version, info=None, version_etag=None
+    ):
         """Update a given backup version
 
         Args:
             user_id(str): the user whose backup version we're updating
             version(str): the version ID of the backup version we're updating
-            info(dict): the new backup version info to store
+            info (dict): the new backup version info to store.  If None, then
+                the backup version info is not updated
+            version_etag (Optional[int]): etag of the keys in the backup.  If
+                None, then the etag is not updated
         """
+        updatevalues = {}
 
-        return self._simple_update(
-            table="e2e_room_keys_versions",
-            keyvalues={"user_id": user_id, "version": version},
-            updatevalues={"auth_data": json.dumps(info["auth_data"])},
-            desc="update_e2e_room_keys_version",
-        )
+        if info is not None and "auth_data" in info:
+            updatevalues["auth_data"] = json.dumps(info["auth_data"])
+        if version_etag is not None:
+            updatevalues["etag"] = version_etag
+
+        if updatevalues:
+            return self.db.simple_update(
+                table="e2e_room_keys_versions",
+                keyvalues={"user_id": user_id, "version": version},
+                updatevalues=updatevalues,
+                desc="update_e2e_room_keys_version",
+            )
 
     @trace
     def delete_e2e_room_keys_version(self, user_id, version=None):
@@ -326,19 +420,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             else:
                 this_version = version
 
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="e2e_room_keys",
                 keyvalues={"user_id": user_id, "version": this_version},
             )
 
-            return self._simple_update_one_txn(
+            return self.db.simple_update_one_txn(
                 txn,
                 table="e2e_room_keys_versions",
                 keyvalues={"user_id": user_id, "version": this_version},
                 updatevalues={"deleted": 1},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
         )
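
The most substantive addition in e2e_room_keys.py is get_e2e_room_keys_multi, which builds one "(room_id = ? AND session_id IN (...))" fragment per requested room and ORs them together under the user/version filter, so arbitrary subsets of a backup can be fetched in a single query. A hedged usage sketch, with every identifier invented for illustration:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def fetch_requested_keys(store):
        # "store" is assumed to be an EndToEndRoomKeyStore; IDs are made up.
        requested = {
            "!room_a:example.org": {"sessions": ["sess1", "sess2"]},
            "!room_b:example.org": {"sessions": ["sess9"]},
        }
        keys = yield store.get_e2e_room_keys_multi(
            "@alice:example.org", "2", requested
        )
        # Result shape: {room_id: {session_id: {"first_message_index": ...,
        # "forwarded_count": ..., "is_verified": ..., "session_data": {...}}}}.
        # Rooms or sessions with no stored key are simply absent.
        return keys
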
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 073412a78d..38cd0ca9b8 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -48,7 +48,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         if not query_list:
             return {}
 
-        results = yield self.runInteraction(
+        results = yield self.db.runInteraction(
             "get_e2e_device_keys",
             self._get_e2e_device_keys_txn,
             query_list,
@@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         )
 
         txn.execute(sql, query_params)
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
 
         result = {}
         for row in rows:
@@ -138,20 +138,35 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 result.setdefault(user_id, {})[device_id] = None
 
         # get signatures on the device
-        signature_sql = (
-            "SELECT * " "  FROM e2e_cross_signing_signatures " " WHERE %s"
-        ) % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
+        signature_sql = ("SELECT *  FROM e2e_cross_signing_signatures WHERE %s") % (
+            " OR ".join("(" + q + ")" for q in signature_query_clauses)
+        )
 
         txn.execute(signature_sql, signature_query_params)
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
 
+        # add each cross-signing signature to the correct device in the result dict.
         for row in rows:
+            signing_user_id = row["user_id"]
+            signing_key_id = row["key_id"]
             target_user_id = row["target_user_id"]
             target_device_id = row["target_device_id"]
-            if target_user_id in result and target_device_id in result[target_user_id]:
-                result[target_user_id][target_device_id].setdefault(
-                    "signatures", {}
-                ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"]
+            signature = row["signature"]
+
+            target_user_result = result.get(target_user_id)
+            if not target_user_result:
+                continue
+
+            target_device_result = target_user_result.get(target_device_id)
+            if not target_device_result:
+                # note that target_device_result will be None for deleted devices.
+                continue
+
+            target_device_signatures = target_device_result.setdefault("signatures", {})
+            signing_user_signatures = target_device_signatures.setdefault(
+                signing_user_id, {}
+            )
+            signing_user_signatures[signing_key_id] = signature
 
         log_kv(result)
         return result
@@ -171,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             key_id) to json string for key
         """
 
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="e2e_one_time_keys_json",
             column="key_id",
             iterable=key_ids,
@@ -204,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             # a unique constraint. If there is a race of two calls to
             # `add_e2e_one_time_keys` then they'll conflict and we will only
             # insert one set.
-            self._simple_insert_many_txn(
+            self.db.simple_insert_many_txn(
                 txn,
                 table="e2e_one_time_keys_json",
                 values=[
@@ -223,7 +238,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )
 
@@ -246,7 +261,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 result[algorithm] = key_count
             return result
 
-        return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
+        return self.db.runInteraction(
+            "count_e2e_one_time_keys", _count_e2e_one_time_keys
+        )
 
     def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
         """Returns a user's cross-signing key.
@@ -307,7 +324,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         Returns:
             dict of the key data or None if not found
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_e2e_cross_signing_key",
             self._get_e2e_cross_signing_key_txn,
             user_id,
@@ -335,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             WHERE ? < stream_id AND stream_id <= ?
             GROUP BY user_id
         """
-        return self._execute(
+        return self.db.execute(
             "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
         )
 
@@ -352,7 +369,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             set_tag("time_now", time_now)
             set_tag("device_keys", device_keys)
 
-            old_key_json = self._simple_select_one_onecol_txn(
+            old_key_json = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="e2e_device_keys_json",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -368,7 +385,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 log_kv({"Message": "Device key already stored."})
                 return False
 
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="e2e_device_keys_json",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -377,7 +394,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             log_kv({"message": "Device keys stored."})
             return True
 
-        return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
+        return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
 
     def claim_e2e_one_time_keys(self, query_list):
         """Take a list of one time keys out of the database"""
@@ -416,7 +433,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 )
             return result
 
-        return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
+        return self.db.runInteraction(
+            "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
+        )
 
     def delete_e2e_keys_by_device(self, user_id, device_id):
         def delete_e2e_keys_by_device_txn(txn):
@@ -427,12 +446,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                     "user_id": user_id,
                 }
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="e2e_device_keys_json",
                 keyvalues={"user_id": user_id, "device_id": device_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="e2e_one_time_keys_json",
                 keyvalues={"user_id": user_id, "device_id": device_id},
@@ -441,7 +460,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
         )
 
@@ -477,7 +496,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         # The "keys" property must only have one entry, which will be the public
         # key, so we just grab the first value in there
         pubkey = next(iter(key["keys"].values()))
-        self._simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             "devices",
             values={
@@ -490,7 +509,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
 
         # and finally, store the key itself
         with self._cross_signing_id_gen.get_next() as stream_id:
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 "e2e_cross_signing_keys",
                 values={
@@ -509,7 +528,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key_type (str): the type of cross-signing key to set
             key (dict): the key data
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_e2e_cross_signing_key",
             self._set_e2e_cross_signing_key_txn,
             user_id,
@@ -524,7 +543,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             user_id (str): the user who made the signatures
             signatures (iterable[SignatureListItem]): signatures to add
         """
-        return self._simple_insert_many(
+        return self.db.simple_insert_many(
             "e2e_cross_signing_signatures",
             [
                 {
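
The rewritten signature loop in end_to_end_keys.py builds the same nesting as before, one level at a time: each cross-signing signature row is attached under result[target_user][target_device]["signatures"][signing_user][key_id], and rows whose target device has been deleted (so its entry in the result is None) are now skipped instead of raising. A small illustration of the shape the loop produces, with all identifiers and key material invented:

    # Hypothetical result entry after signatures have been attached.
    result = {
        "@alice:example.org": {
            "DEVICEID": {
                "algorithms": ["m.olm.v1.curve25519-aes-sha2"],
                "keys": {"curve25519:DEVICEID": "base64+device+key"},
                "signatures": {
                    "@alice:example.org": {
                        "ed25519:SELFSIGNINGKEY": "base64+signature"
                    }
                },
            },
            # A deleted device is present as None and is skipped by the loop.
            "OLDDEVICE": None,
        }
    }
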
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 90bef0cd2c..1f517e8fad 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -28,6 +28,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -58,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             list of event_ids
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
         )
 
@@ -90,12 +91,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         return list(results)
 
     def get_oldest_events_in_room(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
         )
 
     def get_oldest_events_with_depth_in_room(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_oldest_events_with_depth_in_room",
             self.get_oldest_events_with_depth_in_room_txn,
             room_id,
@@ -126,7 +127,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             Deferred[int]
         """
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="events",
             column="event_id",
             iterable=event_ids,
@@ -140,7 +141,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             return max(row["depth"] for row in rows)
 
     def _get_oldest_events_in_room_txn(self, txn, room_id):
-        return self._simple_select_onecol_txn(
+        return self.db.simple_select_onecol_txn(
             txn,
             table="event_backward_extremities",
             keyvalues={"room_id": room_id},
@@ -188,7 +189,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
                 where *hashes* is a map from algorithm to hash.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_latest_event_ids_and_hashes_in_room",
             self._get_latest_event_ids_and_hashes_in_room,
             room_id,
@@ -229,13 +230,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, query_args)
             return [room_id for room_id, in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
         )
 
     @cached(max_entries=5000, iterable=True)
     def get_latest_event_ids_in_room(self, room_id):
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="event_forward_extremities",
             keyvalues={"room_id": room_id},
             retcol="event_id",
@@ -266,12 +267,12 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
     def get_min_depth(self, room_id):
         """ For hte given room, get the minimum depth we have seen for it.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
 
     def _get_min_depth_interaction(self, txn, room_id):
-        min_depth = self._simple_select_one_onecol_txn(
+        min_depth = self.db.simple_select_one_onecol_txn(
             txn,
             table="room_depth",
             keyvalues={"room_id": room_id},
@@ -337,7 +338,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, (stream_ordering, room_id))
             return [event_id for event_id, in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
@@ -352,7 +353,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             limit (int)
         """
         return (
-            self.runInteraction(
+            self.db.runInteraction(
                 "get_backfill_events",
                 self._get_backfill_events,
                 room_id,
@@ -383,7 +384,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         queue = PriorityQueue()
 
         for event_id in event_list:
-            depth = self._simple_select_one_onecol_txn(
+            depth = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="events",
                 keyvalues={"event_id": event_id, "room_id": room_id},
@@ -415,7 +416,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
     @defer.inlineCallbacks
     def get_missing_events(self, room_id, earliest_events, latest_events, limit):
-        ids = yield self.runInteraction(
+        ids = yield self.db.runInteraction(
             "get_missing_events",
             self._get_missing_events,
             room_id,
@@ -468,7 +469,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         Returns:
             Deferred[list[str]]
         """
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="event_edges",
             column="prev_event_id",
             iterable=event_ids,
@@ -491,10 +492,10 @@ class EventFederationStore(EventFederationWorkerStore):
 
     EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
 
-    def __init__(self, db_conn, hs):
-        super(EventFederationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventFederationStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )
 
@@ -508,7 +509,7 @@ class EventFederationStore(EventFederationWorkerStore):
         if min_depth and depth >= min_depth:
             return
 
-        self._simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             table="room_depth",
             keyvalues={"room_id": room_id},
@@ -520,7 +521,7 @@ class EventFederationStore(EventFederationWorkerStore):
         For the given event, update the event edges table and forward and
         backward extremities tables.
         """
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_edges",
             values=[
@@ -604,13 +605,13 @@ class EventFederationStore(EventFederationWorkerStore):
 
         return run_as_background_process(
             "delete_old_forward_extrem_cache",
-            self.runInteraction,
+            self.db.runInteraction,
             "_delete_old_forward_extrem_cache",
             _delete_old_forward_extrem_cache_txn,
         )
 
     def clean_room_for_join(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
         )
 
@@ -654,17 +655,17 @@ class EventFederationStore(EventFederationWorkerStore):
                 "max_stream_id_exclusive": min_stream_id,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_AUTH_STATE_ONLY, new_progress
             )
 
             return min_stream_id >= target_min_stream_id
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_AUTH_STATE_ONLY, delete_event_auth
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+            yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
 
         return batch_size
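
Editor's note: the hunks above and below all apply one mechanical pattern: calls to self.runInteraction and the self._simple_* helpers move onto the new self.db object (a synapse.storage.database.Database), and background-update plumbing moves onto self.db.updates. As a rough, standalone sketch of what a runInteraction plus simple_select_one_onecol_txn pair amounts to — the names run_interaction and simple_select_one_onecol_txn below are simplified stand-ins inferred from the calls visible in these hunks, not the real Database implementation:

import sqlite3

def run_interaction(conn, desc, func, *args):
    # Run `func` against a cursor inside a transaction; commit on success,
    # roll back on failure (a simplified stand-in for Database.runInteraction).
    cur = conn.cursor()
    try:
        result = func(cur, *args)
        conn.commit()
        return result
    except Exception:
        conn.rollback()
        raise

def simple_select_one_onecol_txn(txn, table, keyvalues, retcol, allow_none=False):
    # SELECT <retcol> FROM <table> WHERE k1 = ? AND k2 = ?, returning one value.
    where = " AND ".join("%s = ?" % k for k in keyvalues)
    txn.execute(
        "SELECT %s FROM %s WHERE %s" % (retcol, table, where),
        tuple(keyvalues.values()),
    )
    row = txn.fetchone()
    if row is None:
        if allow_none:
            return None
        raise RuntimeError("expected one row, got none")
    return row[0]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE room_depth (room_id TEXT PRIMARY KEY, min_depth INTEGER)")
conn.execute("INSERT INTO room_depth VALUES ('!room:example.org', 5)")
print(
    run_interaction(
        conn,
        "get_min_depth",
        simple_select_one_onecol_txn,
        "room_depth",
        {"room_id": "!room:example.org"},
        "min_depth",
    )
)  # -> 5
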
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 04ce21ac66..9988a6d3fc 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
@@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
 
 
 class EventPushActionsWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
 
         # These get correctly set by _find_stream_orderings_for_times_txn
         self.stream_ordering_month_ago = None
@@ -93,7 +94,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     def get_unread_event_push_actions_by_room_for_user(
         self, room_id, user_id, last_read_event_id
     ):
-        ret = yield self.runInteraction(
+        ret = yield self.db.runInteraction(
             "get_unread_event_push_actions_by_room",
             self._get_unread_counts_by_receipt_txn,
             room_id,
@@ -177,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (min_stream_ordering, max_stream_ordering))
             return [r[0] for r in txn]
 
-        ret = yield self.runInteraction("get_push_action_users_in_range", f)
+        ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
         return ret
 
     @defer.inlineCallbacks
@@ -229,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        after_read_receipt = yield self.runInteraction(
+        after_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
         )
 
@@ -257,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        no_read_receipt = yield self.runInteraction(
+        no_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
         )
 
@@ -329,7 +330,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        after_read_receipt = yield self.runInteraction(
+        after_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
         )
 
@@ -357,7 +358,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        no_read_receipt = yield self.runInteraction(
+        no_read_receipt = yield self.db.runInteraction(
             "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
         )
 
@@ -407,7 +408,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id, min_stream_ordering))
             return bool(txn.fetchone())
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_if_maybe_push_in_range_for_user",
             _get_if_maybe_push_in_range_for_user_txn,
         )
@@ -441,7 +442,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             )
 
         def _add_push_actions_to_staging_txn(txn):
-            # We don't use _simple_insert_many here to avoid the overhead
+            # We don't use simple_insert_many here to avoid the overhead
             # of generating lists of dicts.
 
             sql = """
@@ -458,7 +459,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
                 ),
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_push_actions_to_staging", _add_push_actions_to_staging_txn
         )
 
@@ -472,7 +473,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
         """
 
         try:
-            res = yield self._simple_delete(
+            res = yield self.db.simple_delete(
                 table="event_push_actions_staging",
                 keyvalues={"event_id": event_id},
                 desc="remove_push_actions_from_staging",
@@ -489,7 +490,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
     def _find_stream_orderings_for_times(self):
         return run_as_background_process(
             "event_push_action_stream_orderings",
-            self.runInteraction,
+            self.db.runInteraction,
             "_find_stream_orderings_for_times",
             self._find_stream_orderings_for_times_txn,
         )
@@ -525,7 +526,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             Deferred[int]: stream ordering of the first event received on/after
                 the timestamp
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "_find_first_stream_ordering_after_ts_txn",
             self._find_first_stream_ordering_after_ts_txn,
             ts,
@@ -611,17 +612,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
 
-    def __init__(self, db_conn, hs):
-        super(EventPushActionsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventPushActionsStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             self.EPA_HIGHLIGHT_INDEX,
             index_name="event_push_actions_u_highlight",
             table="event_push_actions",
             columns=["user_id", "stream_ordering"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_push_actions_highlights_index",
             index_name="event_push_actions_highlights_index",
             table="event_push_actions",
@@ -677,7 +678,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             )
 
         for event, _ in events_and_contexts:
-            user_ids = self._simple_select_onecol_txn(
+            user_ids = self.db.simple_select_onecol_txn(
                 txn,
                 table="event_push_actions_staging",
                 keyvalues={"event_id": event.event_id},
@@ -727,9 +728,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
                 " LIMIT ?" % (before_clause,)
             )
             txn.execute(sql, args)
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        push_actions = yield self.runInteraction("get_push_actions_for_user", f)
+        push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
         for pa in push_actions:
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions
@@ -748,7 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             txn.execute(sql, (stream_ordering,))
             return txn.fetchone()
 
-        result = yield self.runInteraction("get_time_of_last_push_action_before", f)
+        result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
         return result[0] if result else None
 
     @defer.inlineCallbacks
@@ -757,7 +758,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
             return txn.fetchone()
 
-        result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
+        result = yield self.db.runInteraction(
+            "get_latest_push_action_stream_ordering", f
+        )
         return result[0] or 0
 
     def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
@@ -830,7 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             while True:
                 logger.info("Rotating notifications")
 
-                caught_up = yield self.runInteraction(
+                caught_up = yield self.db.runInteraction(
                     "_rotate_notifs", self._rotate_notifs_txn
                 )
                 if caught_up:
@@ -844,7 +847,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         the archiving process has caught up or not.
         """
 
-        old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+        old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
             keyvalues={},
@@ -880,7 +883,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         return caught_up
 
     def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
-        old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+        old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
             txn,
             table="event_push_summary_stream_ordering",
             keyvalues={},
@@ -912,7 +915,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
         # If the `old.user_id` above is NULL then we know there isn't already an
         # entry in the table, so we simply insert it. Otherwise we update the
         # existing table.
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_push_summary",
             values=[
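
Editor's note: the hunk above touches the notification "rotation" path, where _rotate_notifs repeatedly runs _rotate_notifs_txn until it has caught up, and _rotate_notifs_before_txn folds old per-event push actions into event_push_summary rows. A standalone approximation of that idea, heavily simplified (rotate_before and the table shapes here are illustrative, not the real schema or method):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE event_push_actions (
        user_id TEXT, room_id TEXT, stream_ordering INTEGER, notif INTEGER
    );
    CREATE TABLE event_push_summary (
        user_id TEXT, room_id TEXT, notif_count INTEGER, stream_ordering INTEGER
    );
""")
conn.executemany(
    "INSERT INTO event_push_actions VALUES (?, ?, ?, 1)",
    [("@a:hs", "!r:hs", i) for i in range(1, 6)],
)

def rotate_before(conn, rotate_to):
    # Fold per-event rows below `rotate_to` into one summary row per
    # (user, room), then delete them, so the live table stays small.
    cur = conn.cursor()
    cur.execute(
        """
        SELECT user_id, room_id, SUM(notif) FROM event_push_actions
        WHERE stream_ordering < ? GROUP BY user_id, room_id
        """,
        (rotate_to,),
    )
    for user_id, room_id, count in cur.fetchall():
        cur.execute(
            "INSERT INTO event_push_summary VALUES (?, ?, ?, ?)",
            (user_id, room_id, count, rotate_to),
        )
    cur.execute(
        "DELETE FROM event_push_actions WHERE stream_ordering < ?", (rotate_to,)
    )
    conn.commit()

rotate_before(conn, 4)
print(conn.execute("SELECT * FROM event_push_summary").fetchall())
# [('@a:hs', '!r:hs', 3, 4)]
print(conn.execute("SELECT COUNT(*) FROM event_push_actions").fetchone())
# (2,)
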
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 878f7568a6..998bba1aad 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -38,10 +38,10 @@ from synapse.logging.utils import log_function
 from synapse.metrics import BucketCollector
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.event_federation import EventFederationStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.database import Database
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -94,13 +94,10 @@ def _retry_on_integrity_error(func):
 # inherits from EventFederationStore so that we can call _update_backward_extremities
 # and _handle_mult_prev_events (though arguably those could both be moved in here)
 class EventsStore(
-    StateGroupWorkerStore,
-    EventFederationStore,
-    EventsWorkerStore,
-    BackgroundUpdateStore,
+    StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
 ):
-    def __init__(self, db_conn, hs):
-        super(EventsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsStore, self).__init__(database, db_conn, hs)
 
         # Collect metrics on the number of forward extremities that exist.
         # Counter of number of extremities to count
@@ -130,6 +127,8 @@ class EventsStore(
         if self.hs.config.redaction_retention_period is not None:
             hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
 
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def _read_forward_extremities(self):
         def fetch(txn):
@@ -141,7 +140,7 @@ class EventsStore(
             )
             return txn.fetchall()
 
-        res = yield self.runInteraction("read_forward_extremities", fetch)
+        res = yield self.db.runInteraction("read_forward_extremities", fetch)
         self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
 
     @_retry_on_integrity_error
@@ -206,7 +205,7 @@ class EventsStore(
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream
 
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
                 events_and_contexts=events_and_contexts,
@@ -279,7 +278,7 @@ class EventsStore(
             results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
             )
 
@@ -343,7 +342,7 @@ class EventsStore(
                         existing_prevs.add(prev_event_id)
 
         for chunk in batch_iter(event_ids, 100):
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
             )
 
@@ -430,7 +429,7 @@ class EventsStore(
         # event's auth chain, but its easier for now just to store them (and
         # it doesn't take much storage compared to storing the entire event
         # anyway).
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_auth",
             values=[
@@ -578,12 +577,12 @@ class EventsStore(
         self, txn, new_forward_extremities, max_stream_order
     ):
         for room_id, new_extrem in iteritems(new_forward_extremities):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
             )
             txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             values=[
@@ -596,7 +595,7 @@ class EventsStore(
         # new stream_ordering to new forward extremeties in the room.
         # This allows us to later efficiently look up the forward extremeties
         # for a room before a given stream_ordering
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="stream_ordering_to_exterm",
             values=[
@@ -713,16 +712,14 @@ class EventsStore(
 
                 metadata_json = encode_json(event.internal_metadata.get_dict())
 
-                sql = (
-                    "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
-                )
+                sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
                 txn.execute(sql, (metadata_json, event.event_id))
 
                 # Add an entry to the ex_outlier_stream table to replicate the
                 # change in outlier status to our workers.
                 stream_order = event.internal_metadata.stream_ordering
                 state_group_id = context.state_group
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="ex_outlier_stream",
                     values={
@@ -732,7 +729,7 @@ class EventsStore(
                     },
                 )
 
-                sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+                sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
                 txn.execute(sql, (False, event.event_id))
 
                 # Update the event_backward_extremities table now that this
@@ -794,7 +791,7 @@ class EventsStore(
             d.pop("redacted_because", None)
             return d
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_json",
             values=[
@@ -811,7 +808,7 @@ class EventsStore(
             ],
         )
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="events",
             values=[
@@ -841,7 +838,7 @@ class EventsStore(
                 # If we're persisting an unredacted event we go and ensure
                 # that we mark any redactions that reference this event as
                 # requiring censoring.
-                self._simple_update_txn(
+                self.db.simple_update_txn(
                     txn,
                     table="redactions",
                     keyvalues={"redacts": event.event_id},
@@ -929,6 +926,9 @@ class EventsStore(
             elif event.type == EventTypes.Redaction:
                 # Insert into the redactions table.
                 self._store_redaction(txn, event)
+            elif event.type == EventTypes.Retention:
+                # Update the room_retention table.
+                self._store_retention_policy_for_room_txn(txn, event)
 
             self._handle_event_relations(txn, event)
 
@@ -939,6 +939,12 @@ class EventsStore(
                     txn, event.event_id, labels, event.room_id, event.depth
                 )
 
+            if self._ephemeral_messages_enabled:
+                # If there's an expiry timestamp on the event, store it.
+                expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+                if isinstance(expiry_ts, int) and not event.is_state():
+                    self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
+
         # Insert into the room_memberships table.
         self._store_room_members_txn(
             txn,
@@ -974,7 +980,7 @@ class EventsStore(
 
             state_values.append(vals)
 
-        self._simple_insert_many_txn(txn, table="state_events", values=state_values)
+        self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
 
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)
@@ -1005,7 +1011,7 @@ class EventsStore(
             )
 
             txn.execute(sql + clause, args)
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             for row in rows:
                 event = ev_map[row["event_id"]]
                 if not row["rejects"] and not row["redacts"]:
@@ -1023,7 +1029,7 @@ class EventsStore(
         # invalidate the cache for the redacted event
         txn.call_after(self._invalidate_get_event_cache, event.redacts)
 
-        self._simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="redactions",
             values={
@@ -1033,20 +1039,25 @@ class EventsStore(
             },
         )
 
-    @defer.inlineCallbacks
-    def _censor_redactions(self):
+    async def _censor_redactions(self):
         """Censors all redactions older than the configured period that haven't
         been censored yet.
 
         By censor we mean update the event_json table with the redacted event.
-
-        Returns:
-            Deferred
         """
 
         if self.hs.config.redaction_retention_period is None:
             return
 
+        if not (
+            await self.db.updates.has_completed_background_update(
+                "redactions_have_censored_ts_idx"
+            )
+        ):
+            # We don't want to run this until the appropriate index has been
+            # created.
+            return
+
         before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
 
         # We fetch all redactions that:
@@ -1068,15 +1079,15 @@ class EventsStore(
             LIMIT ?
         """
 
-        rows = yield self._execute(
+        rows = await self.db.execute(
             "_censor_redactions_fetch", None, sql, before_ts, 100
         )
 
         updates = []
 
         for redaction_id, event_id in rows:
-            redaction_event = yield self.get_event(redaction_id, allow_none=True)
-            original_event = yield self.get_event(
+            redaction_event = await self.get_event(redaction_id, allow_none=True)
+            original_event = await self.get_event(
                 event_id, allow_rejected=True, allow_none=True
             )
 
@@ -1100,21 +1111,32 @@ class EventsStore(
         def _update_censor_txn(txn):
             for redaction_id, event_id, pruned_json in updates:
                 if pruned_json:
-                    self._simple_update_one_txn(
-                        txn,
-                        table="event_json",
-                        keyvalues={"event_id": event_id},
-                        updatevalues={"json": pruned_json},
-                    )
+                    self._censor_event_txn(txn, event_id, pruned_json)
 
-                self._simple_update_one_txn(
+                self.db.simple_update_one_txn(
                     txn,
                     table="redactions",
                     keyvalues={"event_id": redaction_id},
                     updatevalues={"have_censored": True},
                 )
 
-        yield self.runInteraction("_update_censor_txn", _update_censor_txn)
+        await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+    def _censor_event_txn(self, txn, event_id, pruned_json):
+        """Censor an event by replacing its JSON in the event_json table with the
+        provided pruned JSON.
+
+        Args:
+            txn (LoggingTransaction): The database transaction.
+            event_id (str): The ID of the event to censor.
+            pruned_json (str): The pruned JSON
+        """
+        self.db.simple_update_one_txn(
+            txn,
+            table="event_json",
+            keyvalues={"event_id": event_id},
+            updatevalues={"json": pruned_json},
+        )
 
     @defer.inlineCallbacks
     def count_daily_messages(self):
@@ -1135,7 +1157,7 @@ class EventsStore(
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_messages", _count_messages)
+        ret = yield self.db.runInteraction("count_messages", _count_messages)
         return ret
 
     @defer.inlineCallbacks
@@ -1156,7 +1178,7 @@ class EventsStore(
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
+        ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
         return ret
 
     @defer.inlineCallbacks
@@ -1171,7 +1193,7 @@ class EventsStore(
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_daily_active_rooms", _count)
+        ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
         return ret
 
     def get_current_backfill_token(self):
@@ -1223,7 +1245,7 @@ class EventsStore(
 
             return new_event_updates
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
         )
 
@@ -1268,7 +1290,7 @@ class EventsStore(
 
             return new_event_updates
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )
 
@@ -1361,7 +1383,7 @@ class EventsStore(
                 backward_ex_outliers,
             )
 
-        return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+        return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
 
     def purge_history(self, room_id, token, delete_local_events):
         """Deletes room history before a certain point
@@ -1381,7 +1403,7 @@ class EventsStore(
             deleted events.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "purge_history",
             self._purge_history_txn,
             room_id,
@@ -1479,7 +1501,7 @@ class EventsStore(
 
         # We do joins against events_to_purge for e.g. calculating state
         # groups to purge, etc., so lets make an index.
-        txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
+        txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
 
         txn.execute("SELECT event_id, should_delete FROM events_to_purge")
         event_rows = txn.fetchall()
@@ -1629,7 +1651,7 @@ class EventsStore(
             Deferred[List[int]]: The list of state groups to delete.
         """
 
-        return self.runInteraction("purge_room", self._purge_room_txn, room_id)
+        return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
 
     def _purge_room_txn(self, txn, room_id):
         # First we fetch all the state groups that should be deleted, before
@@ -1748,7 +1770,7 @@ class EventsStore(
                 to delete.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "purge_unreferenced_state_groups",
             self._purge_unreferenced_state_groups,
             room_id,
@@ -1760,7 +1782,7 @@ class EventsStore(
             "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )
 
-        rows = self._simple_select_many_txn(
+        rows = self.db.simple_select_many_txn(
             txn,
             table="state_group_edges",
             column="prev_state_group",
@@ -1787,15 +1809,15 @@ class EventsStore(
             curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
             curr_state = curr_state[sg]
 
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="state_groups_state", keyvalues={"state_group": sg}
             )
 
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="state_group_edges", keyvalues={"state_group": sg}
             )
 
-            self._simple_insert_many_txn(
+            self.db.simple_insert_many_txn(
                 txn,
                 table="state_groups_state",
                 values=[
@@ -1832,7 +1854,7 @@ class EventsStore(
             state group.
         """
 
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="state_group_edges",
             column="prev_state_group",
             iterable=state_groups,
@@ -1851,7 +1873,7 @@ class EventsStore(
             state_groups_to_delete (list[int]): State groups to delete
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "purge_room_state",
             self._purge_room_state_txn,
             room_id,
@@ -1862,7 +1884,7 @@ class EventsStore(
         # first we have to delete the state groups states
         logger.info("[purge] removing %s from state_groups_state", room_id)
 
-        self._simple_delete_many_txn(
+        self.db.simple_delete_many_txn(
             txn,
             table="state_groups_state",
             column="state_group",
@@ -1873,7 +1895,7 @@ class EventsStore(
         # ... and the state group edges
         logger.info("[purge] removing %s from state_group_edges", room_id)
 
-        self._simple_delete_many_txn(
+        self.db.simple_delete_many_txn(
             txn,
             table="state_group_edges",
             column="state_group",
@@ -1884,7 +1906,7 @@ class EventsStore(
         # ... and the state groups
         logger.info("[purge] removing %s from state_groups", room_id)
 
-        self._simple_delete_many_txn(
+        self.db.simple_delete_many_txn(
             txn,
             table="state_groups",
             column="id",
@@ -1901,7 +1923,7 @@ class EventsStore(
 
     @cachedInlineCallbacks(max_entries=5000)
     def _get_event_ordering(self, event_id):
-        res = yield self._simple_select_one(
+        res = yield self.db.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},
@@ -1924,7 +1946,7 @@ class EventsStore(
             txn.execute(sql, (from_token, to_token, limit))
             return txn.fetchall()
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )
@@ -1942,7 +1964,7 @@ class EventsStore(
             room_id (str): The ID of the room the event was sent to.
             topological_ordering (int): The position of the event in the room's topology.
         """
-        return self._simple_insert_many_txn(
+        return self.db.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             values=[
@@ -1956,6 +1978,101 @@ class EventsStore(
             ],
         )
 
+    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+        """Save the expiry timestamp associated with a given event ID.
+
+        Args:
+            txn (LoggingTransaction): The database transaction to use.
+            event_id (str): The event ID the expiry timestamp is associated with.
+            expiry_ts (int): The timestamp at which to expire (delete) the event.
+        """
+        return self.db.simple_insert_txn(
+            txn=txn,
+            table="event_expiry",
+            values={"event_id": event_id, "expiry_ts": expiry_ts},
+        )
+
+    @defer.inlineCallbacks
+    def expire_event(self, event_id):
+        """Retrieve and expire an event that has expired, and delete its associated
+        expiry timestamp. If the event can't be retrieved, delete its associated
+        timestamp so we don't try to expire it again in the future.
+
+        Args:
+             event_id (str): The ID of the event to delete.
+        """
+        # Try to retrieve the event's content from the database or the event cache.
+        event = yield self.get_event(event_id)
+
+        def delete_expired_event_txn(txn):
+            # Delete the expiry timestamp associated with this event from the database.
+            self._delete_event_expiry_txn(txn, event_id)
+
+            if not event:
+                # If we can't find the event, log a warning and delete the expiry date
+                # from the database so that we don't try to expire it again in the
+                # future.
+                logger.warning(
+                    "Can't expire event %s because we don't have it.", event_id
+                )
+                return
+
+            # Prune the event's dict then convert it to JSON.
+            pruned_json = encode_json(prune_event_dict(event.get_dict()))
+
+            # Update the event_json table to replace the event's JSON with the pruned
+            # JSON.
+            self._censor_event_txn(txn, event.event_id, pruned_json)
+
+            # We need to invalidate the event cache entry for this event because we
+            # changed its content in the database. We can't call
+            # self._invalidate_cache_and_stream because self._get_event_cache isn't of the
+            # right type.
+            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            # Send that invalidation to replication so that other workers also invalidate
+            # the event cache.
+            self._send_invalidation_to_replication(
+                txn, "_get_event_cache", (event.event_id,)
+            )
+
+        yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+    def _delete_event_expiry_txn(self, txn, event_id):
+        """Delete the expiry timestamp associated with an event ID without deleting the
+        actual event.
+
+        Args:
+            txn (LoggingTransaction): The transaction to use to perform the deletion.
+            event_id (str): The event ID to delete the associated expiry timestamp of.
+        """
+        return self.db.simple_delete_txn(
+            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        )
+
+    def get_next_event_to_expire(self):
+        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+        table, or None if there are no more events to expire.
+
+        Returns: Deferred[Optional[Tuple[str, int]]]
+            A tuple containing the event ID as its first element and an expiry timestamp
+            as its second one, if there's at least one row in the event_expiry table.
+            None otherwise.
+        """
+
+        def get_next_event_to_expire_txn(txn):
+            txn.execute(
+                """
+                SELECT event_id, expiry_ts FROM event_expiry
+                ORDER BY expiry_ts ASC LIMIT 1
+                """
+            )
+
+            return txn.fetchone()
+
+        return self.db.runInteraction(
+            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+        )
+
 
 AllNewEventsResult = namedtuple(
     "AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index aa87f9abc5..5177b71016 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -22,30 +22,30 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.constants import EventContentFields
-from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 
 logger = logging.getLogger(__name__)
 
 
-class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
+class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
     EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
     DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
 
-    def __init__(self, db_conn, hs):
-        super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
             self._background_reindex_fields_sender,
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_contains_url_index",
             index_name="event_contains_url_index",
             table="events",
@@ -56,7 +56,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
         # an event_id index on event_search is useful for the purge_history
         # api. Plus it means we get to enforce some integrity with a UNIQUE
         # clause
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_search_event_id_idx",
             index_name="event_search_event_id_idx",
             table="event_search",
@@ -65,16 +65,16 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
             psql_only=True,
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "redactions_received_ts", self._redactions_received_ts
         )
 
         # This index gets deleted in `event_fix_redactions_bytes` update
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "event_fix_redactions_bytes_create_index",
             index_name="redactions_censored_redacts",
             table="redactions",
@@ -82,14 +82,22 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
             where_clause="have_censored",
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "event_fix_redactions_bytes", self._event_fix_redactions_bytes
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "event_store_labels", self._event_store_labels
         )
 
+        self.db.updates.register_background_index_update(
+            "redactions_have_censored_ts_idx",
+            index_name="redactions_have_censored_ts",
+            table="redactions",
+            columns=["received_ts"],
+            where_clause="NOT have_censored",
+        )
+
     @defer.inlineCallbacks
     def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -145,18 +153,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(rows),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
             )
 
             return len(rows)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
+            )
 
         return result
 
@@ -189,7 +199,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
             for chunk in chunks:
-                ev_rows = self._simple_select_many_txn(
+                ev_rows = self.db.simple_select_many_txn(
                     txn,
                     table="event_json",
                     column="event_id",
@@ -222,18 +232,20 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(rows_to_update),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
             )
 
             return len(rows_to_update)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
+            yield self.db.updates._end_background_update(
+                self.EVENT_ORIGIN_SERVER_TS_NAME
+            )
 
         return result
 
@@ -366,7 +378,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             to_delete.intersection_update(original_set)
 
-            deleted = self._simple_delete_many_txn(
+            deleted = self.db.simple_delete_many_txn(
                 txn=txn,
                 table="event_forward_extremities",
                 column="event_id",
@@ -382,7 +394,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             if deleted:
                 # We now need to invalidate the caches of these rooms
-                rows = self._simple_select_many_txn(
+                rows = self.db.simple_select_many_txn(
                     txn,
                     table="events",
                     column="event_id",
@@ -396,7 +408,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                         self.get_latest_event_ids_in_room.invalidate, (room_id,)
                     )
 
-            self._simple_delete_many_txn(
+            self.db.simple_delete_many_txn(
                 txn=txn,
                 table="_extremities_to_check",
                 column="event_id",
@@ -406,17 +418,19 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             return len(original_set)
 
-        num_handled = yield self.runInteraction(
+        num_handled = yield self.db.runInteraction(
             "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
         )
 
         if not num_handled:
-            yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
+            yield self.db.updates._end_background_update(
+                self.DELETE_SOFT_FAILED_EXTREMITIES
+            )
 
             def _drop_table_txn(txn):
                 txn.execute("DROP TABLE _extremities_to_check")
 
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
             )
 
@@ -464,18 +478,18 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "redactions_received_ts", {"last_event_id": upper_event_id}
             )
 
             return len(rows)
 
-        count = yield self.runInteraction(
+        count = yield self.db.runInteraction(
             "_redactions_received_ts", _redactions_received_ts_txn
         )
 
         if not count:
-            yield self._end_background_update("redactions_received_ts")
+            yield self.db.updates._end_background_update("redactions_received_ts")
 
         return count
 
@@ -501,11 +515,11 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
 
             txn.execute("DROP INDEX redactions_censored_redacts")
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
         )
 
-        yield self._end_background_update("event_fix_redactions_bytes")
+        yield self.db.updates._end_background_update("event_fix_redactions_bytes")
 
         return 1
 
@@ -533,7 +547,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 try:
                     event_json = json.loads(event_json_raw)
 
-                    self._simple_insert_many_txn(
+                    self.db.simple_insert_many_txn(
                         txn=txn,
                         table="event_labels",
                         values=[
@@ -559,17 +573,17 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
                 nbrows += 1
                 last_row_event_id = event_id
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "event_store_labels", {"last_event_id": last_row_event_id}
             )
 
             return nbrows
 
-        num_rows = yield self.runInteraction(
+        num_rows = yield self.db.runInteraction(
             desc="event_store_labels", func=_event_store_labels_txn
         )
 
         if not num_rows:
-            yield self._end_background_update("event_store_labels")
+            yield self.db.updates._end_background_update("event_store_labels")
 
         return num_rows
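
Editor's note: the background-update handlers in this file share one control flow: do a bounded batch of work inside a transaction, persist progress (for example the last event ID handled), and signal completion by returning zero so the caller can call _end_background_update. A standalone approximation of that loop (run_batch and the toy schema are illustrative, not the synapse.storage.background_updates API):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id INTEGER PRIMARY KEY, labelled INTEGER DEFAULT 0)")
conn.executemany("INSERT INTO events (event_id) VALUES (?)", [(i,) for i in range(10)])

progress = {"last_event_id": -1}

def run_batch(conn, progress, batch_size):
    # One "background update" step: handle at most batch_size rows past the
    # recorded progress point and report how many were handled.
    cur = conn.execute(
        "SELECT event_id FROM events WHERE event_id > ? ORDER BY event_id LIMIT ?",
        (progress["last_event_id"], batch_size),
    )
    rows = cur.fetchall()
    for (event_id,) in rows:
        conn.execute("UPDATE events SET labelled = 1 WHERE event_id = ?", (event_id,))
        progress["last_event_id"] = event_id
    conn.commit()
    return len(rows)

while True:
    handled = run_batch(conn, progress, batch_size=3)
    if not handled:
        break   # the real code would call _end_background_update here

print(conn.execute("SELECT COUNT(*) FROM events WHERE labelled = 1").fetchone())  # (10,)
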
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 4c4b76bd93..9ee117ce0f 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -17,6 +17,7 @@ from __future__ import division
 
 import itertools
 import logging
+import threading
 from collections import namedtuple
 
 from canonicaljson import json
@@ -32,8 +33,10 @@ from synapse.events.utils import prune_event
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
+from synapse.util.caches.descriptors import Cache
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -53,6 +56,17 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
 class EventsWorkerStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+
+        self._get_event_cache = Cache(
+            "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+        )
+
+        self._event_fetch_lock = threading.Condition()
+        self._event_fetch_list = []
+        self._event_fetch_ongoing = 0
+
     def get_received_ts(self, event_id):
         """Get received_ts (when it was persisted) for the event.
 
@@ -65,7 +79,7 @@ class EventsWorkerStore(SQLBaseStore):
             Deferred[int|None]: Timestamp in milliseconds, or None for events
             that were persisted before received_ts was implemented.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="events",
             keyvalues={"event_id": event_id},
             retcol="received_ts",
@@ -104,7 +118,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             return ts
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_approximate_received_ts", _get_approximate_received_ts_txn
         )
 
@@ -439,7 +453,7 @@ class EventsWorkerStore(SQLBaseStore):
                     event_id for events, _ in event_list for event_id in events
                 )
 
-                row_dict = self._new_transaction(
+                row_dict = self.db.new_transaction(
                     conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
                 )
 
@@ -571,7 +585,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         if should_start:
             run_as_background_process(
-                "fetch_events", self.runWithConnection, self._do_fetch
+                "fetch_events", self.db.runWithConnection, self._do_fetch
             )
 
         logger.debug("Loading %d events: %s", len(events), events)
@@ -732,7 +746,7 @@ class EventsWorkerStore(SQLBaseStore):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="events",
             retcols=("event_id",),
             column="event_id",
@@ -767,42 +781,10 @@ class EventsWorkerStore(SQLBaseStore):
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
-        return results
-
-    def get_seen_events_with_rejections(self, event_ids):
-        """Given a list of event ids, check if we rejected them.
-
-        Args:
-            event_ids (list[str])
-
-        Returns:
-            Deferred[dict[str, str|None):
-                Has an entry for each event id we already have seen. Maps to
-                the rejected reason string if we rejected the event, else maps
-                to None.
-        """
-        if not event_ids:
-            return defer.succeed({})
-
-        def f(txn):
-            sql = (
-                "SELECT e.event_id, reason FROM events as e "
-                "LEFT JOIN rejections as r ON e.event_id = r.event_id "
-                "WHERE e.event_id = ?"
+            yield self.db.runInteraction(
+                "have_seen_events", have_seen_events_txn, chunk
             )
-
-            res = {}
-            for event_id in event_ids:
-                txn.execute(sql, (event_id,))
-                row = txn.fetchone()
-                if row:
-                    _, rejected = row
-                    res[event_id] = rejected
-
-            return res
-
-        return self.runInteraction("get_seen_events_with_rejections", f)
+        return results
 
     def _get_total_state_event_counts_txn(self, txn, room_id):
         """
@@ -828,7 +810,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred[int]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_total_state_event_counts",
             self._get_total_state_event_counts_txn,
             room_id,
@@ -853,7 +835,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             Deferred[int]
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_current_state_event_counts",
             self._get_current_state_event_counts_txn,
             room_id,
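
Editor's note: the new EventsWorkerStore constructor above pulls the event cache and the fetch-coalescing state (_event_fetch_lock, _event_fetch_list, _event_fetch_ongoing) into this class: callers queue their requests, and a bounded number of fetchers drain the queue in batches. A rough, threaded stand-in for that idea (BatchingFetcher is invented for illustration; the real code coordinates Twisted deferreds rather than threads):

import threading

class BatchingFetcher:
    """Coalesce concurrent lookups into batched backend fetches while capping
    the number of in-flight fetchers."""

    MAX_FETCHERS = 2

    def __init__(self, fetch_batch):
        self._fetch_batch = fetch_batch          # callable: list of keys -> dict
        self._lock = threading.Condition()
        self._pending = []                       # [(keys, results_dict, done_event)]
        self._ongoing = 0

    def get(self, keys):
        results, done = {}, threading.Event()
        with self._lock:
            self._pending.append((keys, results, done))
            if self._ongoing < self.MAX_FETCHERS:
                self._ongoing += 1
                threading.Thread(target=self._fetch_loop, daemon=True).start()
        done.wait()
        return results

    def _fetch_loop(self):
        while True:
            with self._lock:
                if not self._pending:
                    self._ongoing -= 1
                    return
                batch, self._pending = self._pending, []
            wanted = {k for keys, _, _ in batch for k in keys}
            fetched = self._fetch_batch(sorted(wanted))
            for keys, results, done in batch:
                results.update({k: fetched.get(k) for k in keys})
                done.set()

fetcher = BatchingFetcher(lambda keys: {k: k.upper() for k in keys})
print(fetcher.get(["$a", "$b"]))   # {'$a': '$A', '$b': '$B'}
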
diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py
index a2a2a67927..342d6622a4 100644
--- a/synapse/storage/data_stores/main/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
         except ValueError:
             raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
 
-        def_json = yield self._simple_select_one_onecol(
+        def_json = yield self.db.simple_select_one_onecol(
             table="user_filters",
             keyvalues={"user_id": user_localpart, "filter_id": filter_id},
             retcol="filter_json",
@@ -55,7 +55,7 @@ class FilteringStore(SQLBaseStore):
             if filter_id_response is not None:
                 return filter_id_response[0]
 
-            sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
+            sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
             txn.execute(sql, (user_localpart,))
             max_id = txn.fetchone()[0]
             if max_id is None:
@@ -71,4 +71,4 @@ class FilteringStore(SQLBaseStore):
 
             return filter_id
 
-        return self.runInteraction("add_user_filter", _do_txn)
+        return self.db.runInteraction("add_user_filter", _do_txn)
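
Editor's note: the _do_txn function referenced above allocates the next per-user filter ID with MAX(filter_id) + 1 inside a single transaction before inserting the filter. A standalone sketch of that allocation (add_user_filter here is illustrative and omits the reuse of an identical existing filter that the real code performs first):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE user_filters (user_id TEXT, filter_id INTEGER, filter_json TEXT)"
)

def add_user_filter(conn, user_id, filter_json):
    # Allocate the next per-user filter_id and insert, all in one transaction.
    cur = conn.cursor()
    cur.execute("SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?", (user_id,))
    max_id = cur.fetchone()[0]
    filter_id = 0 if max_id is None else max_id + 1
    cur.execute(
        "INSERT INTO user_filters VALUES (?, ?, ?)", (user_id, filter_id, filter_json)
    )
    conn.commit()
    return filter_id

print(add_user_filter(conn, "alice", '{"room": {}}'))      # 0
print(add_user_filter(conn, "alice", '{"presence": {}}'))  # 1
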
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 5ded539af8..6acd45e9f3 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore):
          * "invite"
          * "open"
         """
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
             updatevalues={"join_policy": join_policy},
@@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore):
         )
 
     def get_group(self, group_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="groups",
             keyvalues={"group_id": group_id},
             retcols=(
@@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore):
         if not include_private:
             keyvalues["is_public"] = True
 
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="group_users",
             keyvalues=keyvalues,
             retcols=("user_id", "is_public", "is_admin"),
@@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore):
     def get_invited_users_in_group(self, group_id):
         # TODO: Pagination
 
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id},
             retcol="user_id",
@@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore):
         if not include_private:
             keyvalues["is_public"] = True
 
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="group_rooms",
             keyvalues=keyvalues,
             retcols=("room_id", "is_public"),
@@ -153,10 +153,12 @@ class GroupServerStore(SQLBaseStore):
 
             return rooms, categories
 
-        return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
+        return self.db.runInteraction(
+            "get_rooms_for_summary", _get_rooms_for_summary_txn
+        )
 
     def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_room_to_summary",
             self._add_room_to_summary_txn,
             group_id,
@@ -180,7 +182,7 @@ class GroupServerStore(SQLBaseStore):
                 an order of 1 will put the room first. Otherwise, the room gets
                 added to the end.
         """
-        room_in_group = self._simple_select_one_onecol_txn(
+        room_in_group = self.db.simple_select_one_onecol_txn(
             txn,
             table="group_rooms",
             keyvalues={"group_id": group_id, "room_id": room_id},
@@ -193,7 +195,7 @@ class GroupServerStore(SQLBaseStore):
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
         else:
-            cat_exists = self._simple_select_one_onecol_txn(
+            cat_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_room_categories",
                 keyvalues={"group_id": group_id, "category_id": category_id},
@@ -204,7 +206,7 @@ class GroupServerStore(SQLBaseStore):
                 raise SynapseError(400, "Category doesn't exist")
 
             # TODO: Check category is part of summary already
-            cat_exists = self._simple_select_one_onecol_txn(
+            cat_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_summary_room_categories",
                 keyvalues={"group_id": group_id, "category_id": category_id},
@@ -224,7 +226,7 @@ class GroupServerStore(SQLBaseStore):
                     (group_id, category_id, group_id, category_id),
                 )
 
-        existing = self._simple_select_one_txn(
+        existing = self.db.simple_select_one_txn(
             txn,
             table="group_summary_rooms",
             keyvalues={
@@ -257,7 +259,7 @@ class GroupServerStore(SQLBaseStore):
                 to_update["room_order"] = order
             if is_public is not None:
                 to_update["is_public"] = is_public
-            self._simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="group_summary_rooms",
                 keyvalues={
@@ -271,7 +273,7 @@ class GroupServerStore(SQLBaseStore):
             if is_public is None:
                 is_public = True
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_summary_rooms",
                 values={
@@ -287,7 +289,7 @@ class GroupServerStore(SQLBaseStore):
         if category_id is None:
             category_id = _DEFAULT_CATEGORY_ID
 
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_summary_rooms",
             keyvalues={
                 "group_id": group_id,
@@ -299,7 +301,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_group_categories(self, group_id):
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="group_room_categories",
             keyvalues={"group_id": group_id},
             retcols=("category_id", "is_public", "profile"),
@@ -316,7 +318,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_group_category(self, group_id, category_id):
-        category = yield self._simple_select_one(
+        category = yield self.db.simple_select_one(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             retcols=("is_public", "profile"),
@@ -343,7 +345,7 @@ class GroupServerStore(SQLBaseStore):
         else:
             update_values["is_public"] = is_public
 
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             values=update_values,
@@ -352,7 +354,7 @@ class GroupServerStore(SQLBaseStore):
         )
 
     def remove_group_category(self, group_id, category_id):
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_room_categories",
             keyvalues={"group_id": group_id, "category_id": category_id},
             desc="remove_group_category",
@@ -360,7 +362,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_group_roles(self, group_id):
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="group_roles",
             keyvalues={"group_id": group_id},
             retcols=("role_id", "is_public", "profile"),
@@ -377,7 +379,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_group_role(self, group_id, role_id):
-        role = yield self._simple_select_one(
+        role = yield self.db.simple_select_one(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             retcols=("is_public", "profile"),
@@ -404,7 +406,7 @@ class GroupServerStore(SQLBaseStore):
         else:
             update_values["is_public"] = is_public
 
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             values=update_values,
@@ -413,14 +415,14 @@ class GroupServerStore(SQLBaseStore):
         )
 
     def remove_group_role(self, group_id, role_id):
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_roles",
             keyvalues={"group_id": group_id, "role_id": role_id},
             desc="remove_group_role",
         )
 
     def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_user_to_summary",
             self._add_user_to_summary_txn,
             group_id,
@@ -444,7 +446,7 @@ class GroupServerStore(SQLBaseStore):
                 an order of 1 will put the user first. Otherwise, the user gets
                 added to the end.
         """
-        user_in_group = self._simple_select_one_onecol_txn(
+        user_in_group = self.db.simple_select_one_onecol_txn(
             txn,
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
@@ -457,7 +459,7 @@ class GroupServerStore(SQLBaseStore):
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
         else:
-            role_exists = self._simple_select_one_onecol_txn(
+            role_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_roles",
                 keyvalues={"group_id": group_id, "role_id": role_id},
@@ -468,7 +470,7 @@ class GroupServerStore(SQLBaseStore):
                 raise SynapseError(400, "Role doesn't exist")
 
             # TODO: Check role is part of the summary already
-            role_exists = self._simple_select_one_onecol_txn(
+            role_exists = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_summary_roles",
                 keyvalues={"group_id": group_id, "role_id": role_id},
@@ -488,7 +490,7 @@ class GroupServerStore(SQLBaseStore):
                     (group_id, role_id, group_id, role_id),
                 )
 
-        existing = self._simple_select_one_txn(
+        existing = self.db.simple_select_one_txn(
             txn,
             table="group_summary_users",
             keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -517,7 +519,7 @@ class GroupServerStore(SQLBaseStore):
                 to_update["user_order"] = order
             if is_public is not None:
                 to_update["is_public"] = is_public
-            self._simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="group_summary_users",
                 keyvalues={
@@ -531,7 +533,7 @@ class GroupServerStore(SQLBaseStore):
             if is_public is None:
                 is_public = True
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_summary_users",
                 values={
@@ -547,7 +549,7 @@ class GroupServerStore(SQLBaseStore):
         if role_id is None:
             role_id = _DEFAULT_ROLE_ID
 
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_summary_users",
             keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
             desc="remove_user_from_summary",
@@ -561,7 +563,7 @@ class GroupServerStore(SQLBaseStore):
             Deferred[list[str]]: A twisted.Deferred containing a list of group ids
                 containing this room
         """
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="group_rooms",
             keyvalues={"room_id": room_id},
             retcol="group_id",
@@ -625,12 +627,12 @@ class GroupServerStore(SQLBaseStore):
 
             return users, roles
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_for_summary_by_role", _get_users_for_summary_txn
         )
 
     def is_user_in_group(self, user_id, group_id):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
@@ -639,7 +641,7 @@ class GroupServerStore(SQLBaseStore):
         ).addCallback(lambda r: bool(r))
 
     def is_user_admin_in_group(self, group_id, user_id):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="group_users",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="is_admin",
@@ -650,7 +652,7 @@ class GroupServerStore(SQLBaseStore):
     def add_group_invite(self, group_id, user_id):
         """Record that the group server has invited a user
         """
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="group_invites",
             values={"group_id": group_id, "user_id": user_id},
             desc="add_group_invite",
@@ -659,7 +661,7 @@ class GroupServerStore(SQLBaseStore):
     def is_user_invited_to_local_group(self, group_id, user_id):
         """Has the group server invited a user?
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcol="user_id",
@@ -682,7 +684,7 @@ class GroupServerStore(SQLBaseStore):
         """
 
         def _get_users_membership_in_group_txn(txn):
-            row = self._simple_select_one_txn(
+            row = self.db.simple_select_one_txn(
                 txn,
                 table="group_users",
                 keyvalues={"group_id": group_id, "user_id": user_id},
@@ -697,7 +699,7 @@ class GroupServerStore(SQLBaseStore):
                     "is_privileged": row["is_admin"],
                 }
 
-            row = self._simple_select_one_onecol_txn(
+            row = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="group_invites",
                 keyvalues={"group_id": group_id, "user_id": user_id},
@@ -710,7 +712,7 @@ class GroupServerStore(SQLBaseStore):
 
             return {}
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_membership_info_in_group", _get_users_membership_in_group_txn
         )
 
@@ -738,7 +740,7 @@ class GroupServerStore(SQLBaseStore):
         """
 
         def _add_user_to_group_txn(txn):
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="group_users",
                 values={
@@ -749,14 +751,14 @@ class GroupServerStore(SQLBaseStore):
                 },
             )
 
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_invites",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
 
             if local_attestation:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="group_attestations_renewals",
                     values={
@@ -766,7 +768,7 @@ class GroupServerStore(SQLBaseStore):
                     },
                 )
             if remote_attestation:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="group_attestations_remote",
                     values={
@@ -777,49 +779,49 @@ class GroupServerStore(SQLBaseStore):
                     },
                 )
 
-        return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
+        return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
 
     def remove_user_from_group(self, group_id, user_id):
         def _remove_user_from_group_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_users",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_invites",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_attestations_renewals",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_attestations_remote",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_summary_users",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_user_from_group", _remove_user_from_group_txn
         )
 
     def add_room_to_group(self, group_id, room_id, is_public):
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="group_rooms",
             values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
             desc="add_room_to_group",
         )
 
     def update_room_in_group_visibility(self, group_id, room_id, is_public):
-        return self._simple_update(
+        return self.db.simple_update(
             table="group_rooms",
             keyvalues={"group_id": group_id, "room_id": room_id},
             updatevalues={"is_public": is_public},
@@ -828,26 +830,26 @@ class GroupServerStore(SQLBaseStore):
 
     def remove_room_from_group(self, group_id, room_id):
         def _remove_room_from_group_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_rooms",
                 keyvalues={"group_id": group_id, "room_id": room_id},
             )
 
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="group_summary_rooms",
                 keyvalues={"group_id": group_id, "room_id": room_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_room_from_group", _remove_room_from_group_txn
         )
 
     def get_publicised_groups_for_user(self, user_id):
         """Get all groups a user is publicising
         """
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
             retcol="group_id",
@@ -857,7 +859,7 @@ class GroupServerStore(SQLBaseStore):
     def update_group_publicity(self, group_id, user_id, publicise):
         """Update whether the user is publicising their membership of the group
         """
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="local_group_membership",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"is_publicised": publicise},
@@ -893,12 +895,12 @@ class GroupServerStore(SQLBaseStore):
 
         def _register_user_group_membership_txn(txn, next_id):
             # TODO: Upsert?
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="local_group_membership",
                 keyvalues={"group_id": group_id, "user_id": user_id},
             )
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="local_group_membership",
                 values={
@@ -911,7 +913,7 @@ class GroupServerStore(SQLBaseStore):
                 },
             )
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="local_group_updates",
                 values={
@@ -930,7 +932,7 @@ class GroupServerStore(SQLBaseStore):
 
             if membership == "join":
                 if local_attestation:
-                    self._simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="group_attestations_renewals",
                         values={
@@ -940,7 +942,7 @@ class GroupServerStore(SQLBaseStore):
                         },
                     )
                 if remote_attestation:
-                    self._simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="group_attestations_remote",
                         values={
@@ -951,12 +953,12 @@ class GroupServerStore(SQLBaseStore):
                         },
                     )
             else:
-                self._simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn,
                     table="group_attestations_renewals",
                     keyvalues={"group_id": group_id, "user_id": user_id},
                 )
-                self._simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn,
                     table="group_attestations_remote",
                     keyvalues={"group_id": group_id, "user_id": user_id},
@@ -965,7 +967,7 @@ class GroupServerStore(SQLBaseStore):
             return next_id
 
         with self._group_updates_id_gen.get_next() as next_id:
-            res = yield self.runInteraction(
+            res = yield self.db.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,
                 next_id,
@@ -976,7 +978,7 @@ class GroupServerStore(SQLBaseStore):
     def create_group(
         self, group_id, user_id, name, avatar_url, short_description, long_description
     ):
-        yield self._simple_insert(
+        yield self.db.simple_insert(
             table="groups",
             values={
                 "group_id": group_id,
@@ -991,7 +993,7 @@ class GroupServerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def update_group_profile(self, group_id, profile):
-        yield self._simple_update_one(
+        yield self.db.simple_update_one(
             table="groups",
             keyvalues={"group_id": group_id},
             updatevalues=profile,
@@ -1008,16 +1010,16 @@ class GroupServerStore(SQLBaseStore):
                 WHERE valid_until_ms <= ?
             """
             txn.execute(sql, (valid_until_ms,))
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_attestations_need_renewals", _get_attestations_need_renewals_txn
         )
 
     def update_attestation_renewal(self, group_id, user_id, attestation):
         """Update an attestation that we have renewed
         """
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@@ -1027,7 +1029,7 @@ class GroupServerStore(SQLBaseStore):
     def update_remote_attestion(self, group_id, user_id, attestation):
         """Update an attestation that a remote has renewed
         """
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
             updatevalues={
@@ -1046,7 +1048,7 @@ class GroupServerStore(SQLBaseStore):
             group_id (str)
             user_id (str)
         """
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="group_attestations_renewals",
             keyvalues={"group_id": group_id, "user_id": user_id},
             desc="remove_attestation_renewal",
@@ -1057,7 +1059,7 @@ class GroupServerStore(SQLBaseStore):
         """Get the attestation that proves the remote agrees that the user is
         in the group.
         """
-        row = yield self._simple_select_one(
+        row = yield self.db.simple_select_one(
             table="group_attestations_remote",
             keyvalues={"group_id": group_id, "user_id": user_id},
             retcols=("valid_until_ms", "attestation_json"),
@@ -1072,7 +1074,7 @@ class GroupServerStore(SQLBaseStore):
         return None
 
     def get_joined_groups(self, user_id):
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="local_group_membership",
             keyvalues={"user_id": user_id, "membership": "join"},
             retcol="group_id",
@@ -1099,7 +1101,7 @@ class GroupServerStore(SQLBaseStore):
                 for row in txn
             ]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_groups_for_user", _get_all_groups_for_user_txn
         )
 
@@ -1109,7 +1111,7 @@ class GroupServerStore(SQLBaseStore):
             user_id, from_token
         )
         if not has_changed:
-            return []
+            return defer.succeed([])
 
         def _get_groups_changes_for_user_txn(txn):
             sql = """
@@ -1129,7 +1131,7 @@ class GroupServerStore(SQLBaseStore):
                 for group_id, membership, gtype, content_json in txn
             ]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_groups_changes_for_user", _get_groups_changes_for_user_txn
         )
 
@@ -1139,7 +1141,7 @@ class GroupServerStore(SQLBaseStore):
             from_token
         )
         if not has_changed:
-            return []
+            return defer.succeed([])
 
         def _get_all_groups_changes_txn(txn):
             sql = """
@@ -1154,7 +1156,7 @@ class GroupServerStore(SQLBaseStore):
                 for stream_id, group_id, user_id, gtype, content_json in txn
             ]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_groups_changes", _get_all_groups_changes_txn
         )
 
@@ -1188,8 +1190,8 @@ class GroupServerStore(SQLBaseStore):
             ]
 
             for table in tables:
-                self._simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn, table=table, keyvalues={"group_id": group_id}
                 )
 
-        return self.runInteraction("delete_group", _delete_group_txn)
+        return self.db.runInteraction("delete_group", _delete_group_txn)
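
The hunks above all apply the same mechanical change: the per-query helpers (`_simple_select_one`, `_simple_insert_txn` and friends) and the transaction runner `runInteraction` now live on the `Database` object exposed as `self.db`, rather than being inherited from `SQLBaseStore`. A minimal sketch of the before/after shape, using a hypothetical store and table (the real ones are in the hunks above):

    from synapse.storage._base import SQLBaseStore

    class ExampleStore(SQLBaseStore):
        def get_widget(self, widget_id):
            # Previously: self._simple_select_one(...)
            return self.db.simple_select_one(
                table="example_widgets",
                keyvalues={"widget_id": widget_id},
                retcols=("name", "is_public"),
                desc="get_widget",
            )

        def delete_widget(self, widget_id):
            def _delete_widget_txn(txn):
                # Transaction-scoped helpers move in the same way.
                self.db.simple_delete_txn(
                    txn, table="example_widgets", keyvalues={"widget_id": widget_id}
                )

            # Previously: self.runInteraction(...)
            return self.db.runInteraction("delete_widget", _delete_widget_txn)
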
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index ebc7db3ed6..6b12f5a75f 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore):
                 _get_keys(txn, batch)
             return keys
 
-        return self.runInteraction("get_server_verify_keys", _txn)
+        return self.db.runInteraction("get_server_verify_keys", _txn)
 
     def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
         """Stores NACL verification keys for remote servers.
@@ -127,9 +127,9 @@ class KeyStore(SQLBaseStore):
                 f((i,))
             return res
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "store_server_verify_keys",
-            self._simple_upsert_many_txn,
+            self.db.simple_upsert_many_txn,
             table="server_signature_keys",
             key_names=("server_name", "key_id"),
             key_values=key_values,
@@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore):
             ts_valid_until_ms (int): The time when this json stops being valid.
             key_json (bytes): The encoded JSON.
         """
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="server_keys_json",
             keyvalues={
                 "server_name": server_name,
@@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore):
                     keyvalues["key_id"] = key_id
                 if from_server is not None:
                     keyvalues["from_server"] = from_server
-                rows = self._simple_select_list_txn(
+                rows = self.db.simple_select_list_txn(
                     txn,
                     "server_keys_json",
                     keyvalues=keyvalues,
@@ -211,4 +211,4 @@ class KeyStore(SQLBaseStore):
                 results[(server_name, key_id, from_server)] = rows
             return results
 
-        return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
+        return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
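
The keys.py hunks also show that these helpers can be handed straight to `runInteraction` as the transaction function: `store_server_verify_keys` passes `self.db.simple_upsert_many_txn` as the callable and supplies its arguments as keywords. A short sketch of that call shape, with hypothetical table and column names:

    class ExampleKeyStore(SQLBaseStore):
        def store_widget_colours(self, colours):
            # runInteraction invokes the given function with the transaction as
            # its first argument; remaining keyword arguments are passed through.
            return self.db.runInteraction(
                "store_widget_colours",
                self.db.simple_upsert_many_txn,
                table="widget_colours",
                key_names=("widget_id",),
                key_values=[(widget_id,) for widget_id in colours],
                value_names=("colour",),
                value_values=[(colour,) for colour in colours.values()],
            )
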
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 84b5f3ad5e..80ca36dedf 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -12,14 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 
 
-class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryBackgroundUpdateStore, self).__init__(
+            database, db_conn, hs
+        )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             update_name="local_media_repository_url_idx",
             index_name="local_media_repository_url_idx",
             table="local_media_repository",
@@ -31,15 +34,15 @@ class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
 class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
     """Persistence for attachments and avatars"""
 
-    def __init__(self, db_conn, hs):
-        super(MediaRepositoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
 
     def get_local_media(self, media_id):
         """Get the metadata for a local piece of media
         Returns:
             None if the media_id doesn't exist.
         """
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             "local_media_repository",
             {"media_id": media_id},
             (
@@ -64,7 +67,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         user_id,
         url_cache=None,
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "local_media_repository",
             {
                 "media_id": media_id,
@@ -124,12 +127,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 )
             )
 
-        return self.runInteraction("get_url_cache", get_url_cache_txn)
+        return self.db.runInteraction("get_url_cache", get_url_cache_txn)
 
     def store_url_cache(
         self, url, response_code, etag, expires_ts, og, media_id, download_ts
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "local_media_repository_url_cache",
             {
                 "url": url,
@@ -144,7 +147,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     def get_local_media_thumbnails(self, media_id):
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             "local_media_repository_thumbnails",
             {"media_id": media_id},
             (
@@ -166,7 +169,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "local_media_repository_thumbnails",
             {
                 "media_id": media_id,
@@ -180,7 +183,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     def get_cached_remote_media(self, origin, media_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             "remote_media_cache",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -205,7 +208,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         upload_name,
         filesystem_id,
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "remote_media_cache",
             {
                 "media_origin": origin,
@@ -250,10 +253,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
             txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
 
-        return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+        return self.db.runInteraction(
+            "update_cached_last_access_time", update_cache_txn
+        )
 
     def get_remote_media_thumbnails(self, origin, media_id):
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             "remote_media_cache_thumbnails",
             {"media_origin": origin, "media_id": media_id},
             (
@@ -278,7 +283,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         thumbnail_method,
         thumbnail_length,
     ):
-        return self._simple_insert(
+        return self.db.simple_insert(
             "remote_media_cache_thumbnails",
             {
                 "media_origin": origin,
@@ -300,24 +305,24 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             " WHERE last_access_ts < ?"
         )
 
-        return self._execute(
-            "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+        return self.db.execute(
+            "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
         )
 
     def delete_remote_media(self, media_origin, media_id):
         def delete_remote_media_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 "remote_media_cache",
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 "remote_media_cache_thumbnails",
                 keyvalues={"media_origin": media_origin, "media_id": media_id},
             )
 
-        return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+        return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
 
     def get_expired_url_cache(self, now_ts):
         sql = (
@@ -331,18 +336,20 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (now_ts,))
             return [row[0] for row in txn]
 
-        return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+        return self.db.runInteraction(
+            "get_expired_url_cache", _get_expired_url_cache_txn
+        )
 
     def delete_url_cache(self, media_ids):
         if len(media_ids) == 0:
             return
 
-        sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
+        sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-        return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+        return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
 
     def get_url_cache_media_before(self, before_ts):
         sql = (
@@ -356,7 +363,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             txn.execute(sql, (before_ts,))
             return [row[0] for row in txn]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_url_cache_media_before", _get_url_cache_media_before_txn
         )
 
@@ -365,14 +372,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             return
 
         def _delete_url_cache_media_txn(txn):
-            sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
+            sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-            sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
+            sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
             txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn
         )
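
The media_repository changes also show the new constructor shape: stores derive from `SQLBaseStore`, take the `Database` as their first argument, and register background index updates through `self.db.updates` instead of a dedicated `BackgroundUpdateStore` base class. A minimal sketch, with hypothetical update, index and table names:

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import Database

    class ExampleBackgroundUpdateStore(SQLBaseStore):
        def __init__(self, database: Database, db_conn, hs):
            super(ExampleBackgroundUpdateStore, self).__init__(database, db_conn, hs)

            # Background updates now hang off the Database's update manager.
            self.db.updates.register_background_index_update(
                update_name="example_url_idx",
                index_name="example_url_idx",
                table="example_media",
                columns=["url"],
            )
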
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index b41c3d317a..27158534cb 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -17,6 +17,7 @@ import logging
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
@@ -27,13 +28,13 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
 
 
 class MonthlyActiveUsersStore(SQLBaseStore):
-    def __init__(self, dbconn, hs):
-        super(MonthlyActiveUsersStore, self).__init__(None, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
         self._clock = hs.get_clock()
         self.hs = hs
         # Do not add more reserved users than the total allowable number
-        self._new_transaction(
-            dbconn,
+        self.db.new_transaction(
+            db_conn,
             "initialise_mau_threepids",
             [],
             [],
@@ -146,7 +147,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
                     txn.execute(sql, query_args)
 
         reserved_users = yield self.get_registered_reserved_users()
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "reap_monthly_active_users", _reap_users, reserved_users
         )
         # It seems poor to invalidate the whole cache, Postgres supports
@@ -174,7 +175,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        return self.runInteraction("count_users", _count_users)
+        return self.db.runInteraction("count_users", _count_users)
 
     @defer.inlineCallbacks
     def get_registered_reserved_users(self):
@@ -217,7 +218,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         if is_support:
             return
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )
 
@@ -261,7 +262,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
         # never be a big table and alternative approaches (batching multiple
         # upserts into a single txn) introduced a lot of extra complexity.
         # See https://github.com/matrix-org/synapse/issues/3854 for more
-        is_insert = self._simple_upsert_txn(
+        is_insert = self.db.simple_upsert_txn(
             txn,
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
@@ -281,7 +282,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
 
         """
 
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="monthly_active_users",
             keyvalues={"user_id": user_id},
             retcol="timestamp",
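
The MonthlyActiveUsersStore hunk shows the related change for start-up queries: what used to be `self._new_transaction(dbconn, ...)` now runs through `self.db.new_transaction`, passing the raw connection, a description, and (empty) after/exception callback lists ahead of the transaction function. A sketch under the same assumptions, with a hypothetical query:

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import Database

    class ExampleStartupStore(SQLBaseStore):
        def __init__(self, database: Database, db_conn, hs):
            super(ExampleStartupStore, self).__init__(database, db_conn, hs)

            def _count_rows_txn(txn):
                # Runs once at start-up, directly against the raw connection.
                txn.execute("SELECT COUNT(*) FROM example_table")
                (count,) = txn.fetchone()
                self._example_count = count

            self.db.new_transaction(
                db_conn, "initialise_example_count", [], [], _count_rows_txn
            )
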
diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py
index 79b40044d9..cc21437e92 100644
--- a/synapse/storage/data_stores/main/openid.py
+++ b/synapse/storage/data_stores/main/openid.py
@@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore
 
 class OpenIdStore(SQLBaseStore):
     def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="open_id_tokens",
             values={
                 "token": token,
@@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore):
             else:
                 return rows[0][0]
 
-        return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
+        return self.db.runInteraction(
+            "get_user_id_for_token", get_user_id_for_token_txn
+        )
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index 523ed6575e..a2c83e0867 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -29,7 +29,7 @@ class PresenceStore(SQLBaseStore):
         )
 
         with stream_ordering_manager as stream_orderings:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "update_presence",
                 self._update_presence_txn,
                 stream_orderings,
@@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore):
             txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
 
         # Actually insert new rows
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="presence_stream",
             values=[
@@ -88,7 +88,7 @@ class PresenceStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id))
             return txn.fetchall()
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_presence_updates", get_all_presence_updates_txn
         )
 
@@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore):
         inlineCallbacks=True,
     )
     def get_presence_for_users(self, user_ids):
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
             iterable=user_ids,
@@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore):
         return self._presence_id_gen.get_current_token()
 
     def allow_presence_visible(self, observed_localpart, observer_userid):
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="presence_allow_inbound",
             values={
                 "observed_user_id": observed_localpart,
@@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore):
         )
 
     def disallow_presence_visible(self, observed_localpart, observer_userid):
-        return self._simple_delete_one(
+        return self.db.simple_delete_one(
             table="presence_allow_inbound",
             keyvalues={
                 "observed_user_id": observed_localpart,
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index e4e8a1c1d6..2b52cf9c1a 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_profileinfo(self, user_localpart):
         try:
-            profile = yield self._simple_select_one(
+            profile = yield self.db.simple_select_one(
                 table="profiles",
                 keyvalues={"user_id": user_localpart},
                 retcols=("displayname", "avatar_url"),
@@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_profile_displayname(self, user_localpart):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="displayname",
@@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_profile_avatar_url(self, user_localpart):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             retcol="avatar_url",
@@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def get_from_remote_profile_cache(self, user_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             retcols=("displayname", "avatar_url"),
@@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def create_profile(self, user_localpart):
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="profiles", values={"user_id": user_localpart}, desc="create_profile"
         )
 
     def set_profile_displayname(self, user_localpart, new_displayname):
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"displayname": new_displayname},
@@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     def set_profile_avatar_url(self, user_localpart, new_avatar_url):
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="profiles",
             keyvalues={"user_id": user_localpart},
             updatevalues={"avatar_url": new_avatar_url},
@@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore):
         This should only be called when `is_subscribed_remote_profile_for_user`
         would return true for the user.
         """
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             values={
@@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore):
         )
 
     def update_remote_profile_cache(self, user_id, displayname, avatar_url):
-        return self._simple_update(
+        return self.db.simple_update(
             table="remote_profile_cache",
             keyvalues={"user_id": user_id},
             values={
@@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore):
         """
         subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
         if not subscribed:
-            yield self._simple_delete(
+            yield self.db.simple_delete(
                 table="remote_profile_cache",
                 keyvalues={"user_id": user_id},
                 desc="delete_remote_profile_cache",
@@ -144,9 +144,9 @@ class ProfileStore(ProfileWorkerStore):
 
             txn.execute(sql, (last_checked,))
 
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_remote_profile_cache_entries_that_expire",
             _get_remote_profile_cache_entries_that_expire_txn,
         )
@@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
     def is_subscribed_remote_profile_for_user(self, user_id):
         """Check whether we are interested in a remote user's profile.
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="group_users",
             keyvalues={"user_id": user_id},
             retcol="user_id",
@@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
         if res:
             return True
 
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="group_invites",
             keyvalues={"user_id": user_id},
             retcol="user_id",
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index b520062d84..5ba13aa973 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -27,6 +27,7 @@ from synapse.storage.data_stores.main.appservice import ApplicationServiceWorker
 from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -72,10 +73,10 @@ class PushRulesWorkerStore(
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
 
-        push_rules_prefill, push_rules_id = self._get_cache_dict(
+        push_rules_prefill, push_rules_id = self.db.get_cache_dict(
             db_conn,
             "push_rules_stream",
             entity_column="user_id",
@@ -100,7 +101,7 @@ class PushRulesWorkerStore(
 
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_for_user(self, user_id):
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
             retcols=(
@@ -124,7 +125,7 @@ class PushRulesWorkerStore(
 
     @cachedInlineCallbacks(max_entries=5000)
     def get_push_rules_enabled_for_user(self, user_id):
-        results = yield self._simple_select_list(
+        results = yield self.db.simple_select_list(
             table="push_rules_enable",
             keyvalues={"user_name": user_id},
             retcols=("user_name", "rule_id", "enabled"),
@@ -146,7 +147,7 @@ class PushRulesWorkerStore(
                 (count,) = txn.fetchone()
                 return bool(count)
 
-            return self.runInteraction(
+            return self.db.runInteraction(
                 "have_push_rules_changed", have_push_rules_changed_txn
             )
 
@@ -162,7 +163,7 @@ class PushRulesWorkerStore(
 
         results = {user_id: [] for user_id in user_ids}
 
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="push_rules",
             column="user_name",
             iterable=user_ids,
@@ -320,7 +321,7 @@ class PushRulesWorkerStore(
 
         results = {user_id: {} for user_id in user_ids}
 
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="push_rules_enable",
             column="user_name",
             iterable=user_ids,
@@ -350,7 +351,7 @@ class PushRuleStore(PushRulesWorkerStore):
         with self._push_rules_stream_id_gen.get_next() as ids:
             stream_id, event_stream_ordering = ids
             if before or after:
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "_add_push_rule_relative_txn",
                     self._add_push_rule_relative_txn,
                     stream_id,
@@ -364,7 +365,7 @@ class PushRuleStore(PushRulesWorkerStore):
                     after,
                 )
             else:
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "_add_push_rule_highest_priority_txn",
                     self._add_push_rule_highest_priority_txn,
                     stream_id,
@@ -395,7 +396,7 @@ class PushRuleStore(PushRulesWorkerStore):
 
         relative_to_rule = before or after
 
-        res = self._simple_select_one_txn(
+        res = self.db.simple_select_one_txn(
             txn,
             table="push_rules",
             keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
@@ -499,7 +500,7 @@ class PushRuleStore(PushRulesWorkerStore):
         actions_json,
         update_stream=True,
     ):
-        """Specialised version of _simple_upsert_txn that picks a push_rule_id
+        """Specialised version of simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
         that the "push_rules" table is locked"""
 
@@ -518,7 +519,7 @@ class PushRuleStore(PushRulesWorkerStore):
             # We didn't update a row with the given rule_id so insert one
             push_rule_id = self._push_rule_id_gen.get_next()
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="push_rules",
                 values={
@@ -561,7 +562,7 @@ class PushRuleStore(PushRulesWorkerStore):
         """
 
         def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
-            self._simple_delete_one_txn(
+            self.db.simple_delete_one_txn(
                 txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
             )
 
@@ -571,7 +572,7 @@ class PushRuleStore(PushRulesWorkerStore):
 
         with self._push_rules_stream_id_gen.get_next() as ids:
             stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "delete_push_rule",
                 delete_push_rule_txn,
                 stream_id,
@@ -582,7 +583,7 @@ class PushRuleStore(PushRulesWorkerStore):
     def set_push_rule_enabled(self, user_id, rule_id, enabled):
         with self._push_rules_stream_id_gen.get_next() as ids:
             stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_set_push_rule_enabled_txn",
                 self._set_push_rule_enabled_txn,
                 stream_id,
@@ -596,7 +597,7 @@ class PushRuleStore(PushRulesWorkerStore):
         self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
     ):
         new_id = self._push_rules_enable_id_gen.get_next()
-        self._simple_upsert_txn(
+        self.db.simple_upsert_txn(
             txn,
             "push_rules_enable",
             {"user_name": user_id, "rule_id": rule_id},
@@ -636,7 +637,7 @@ class PushRuleStore(PushRulesWorkerStore):
                     update_stream=False,
                 )
             else:
-                self._simple_update_one_txn(
+                self.db.simple_update_one_txn(
                     txn,
                     "push_rules",
                     {"user_name": user_id, "rule_id": rule_id},
@@ -655,7 +656,7 @@ class PushRuleStore(PushRulesWorkerStore):
 
         with self._push_rules_stream_id_gen.get_next() as ids:
             stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "set_push_rule_actions",
                 set_push_rule_actions_txn,
                 stream_id,
@@ -675,7 +676,7 @@ class PushRuleStore(PushRulesWorkerStore):
         if data is not None:
             values.update(data)
 
-        self._simple_insert_txn(txn, "push_rules_stream", values=values)
+        self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
 
         txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
         txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
@@ -699,7 +700,7 @@ class PushRuleStore(PushRulesWorkerStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_push_rule_updates", get_all_push_rule_updates_txn
         )
 
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index d76861cdc0..f07309ef09 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_has_pusher(self, user_id):
-        ret = yield self._simple_select_one_onecol(
+        ret = yield self.db.simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
         return ret is not None
@@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_pushers_by(self, keyvalues):
-        ret = yield self._simple_select_list(
+        ret = yield self.db.simple_select_list(
             "pushers",
             keyvalues,
             [
@@ -100,11 +100,11 @@ class PusherWorkerStore(SQLBaseStore):
     def get_all_pushers(self):
         def get_pushers(txn):
             txn.execute("SELECT * FROM pushers")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             return self._decode_pushers_rows(rows)
 
-        rows = yield self.runInteraction("get_all_pushers", get_pushers)
+        rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
         return rows
 
     def get_all_updated_pushers(self, last_id, current_id, limit):
@@ -134,7 +134,7 @@ class PusherWorkerStore(SQLBaseStore):
 
             return updated, deleted
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_pushers", get_all_updated_pushers_txn
         )
 
@@ -177,7 +177,7 @@ class PusherWorkerStore(SQLBaseStore):
 
             return results
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
         )
 
@@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
         inlineCallbacks=True,
     )
     def get_if_users_have_pushers(self, user_ids):
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="pushers",
             column="user_name",
             iterable=user_ids,
@@ -229,8 +229,8 @@ class PusherStore(PusherWorkerStore):
     ):
         with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
-            # (app_id, pushkey, user_name) so _simple_upsert will retry
-            yield self._simple_upsert(
+            # (app_id, pushkey, user_name) so simple_upsert will retry
+            yield self.db.simple_upsert(
                 table="pushers",
                 keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
                 values={
@@ -255,7 +255,7 @@ class PusherStore(PusherWorkerStore):
 
             if user_has_pusher is not True:
                 # invalidate, since we the user might not have had a pusher before
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "add_pusher",
                     self._invalidate_cache_and_stream,
                     self.get_if_user_has_pusher,
@@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore):
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 
-            self._simple_delete_one_txn(
+            self.db.simple_delete_one_txn(
                 txn,
                 "pushers",
                 {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore):
             # it's possible for us to end up with duplicate rows for
             # (app_id, pushkey, user_id) at different stream_ids, but that
             # doesn't really matter.
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="deleted_pushers",
                 values={
@@ -290,13 +290,13 @@ class PusherStore(PusherWorkerStore):
             )
 
         with self._pushers_id_gen.get_next() as stream_id:
-            yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
+            yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
 
     @defer.inlineCallbacks
     def update_pusher_last_stream_ordering(
         self, app_id, pushkey, user_id, last_stream_ordering
     ):
-        yield self._simple_update_one(
+        yield self.db.simple_update_one(
             "pushers",
             {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             {"last_stream_ordering": last_stream_ordering},
@@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore):
         Returns:
             Deferred[bool]: True if the pusher still exists; False if it has been deleted.
         """
-        updated = yield self._simple_update(
+        updated = yield self.db.simple_update(
             table="pushers",
             keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             updatevalues={
@@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore):
 
     @defer.inlineCallbacks
     def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
-        yield self._simple_update(
+        yield self.db.simple_update(
             table="pushers",
             keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
             updatevalues={"failing_since": failing_since},
@@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore):
 
     @defer.inlineCallbacks
     def get_throttle_params_by_room(self, pusher_id):
-        res = yield self._simple_select_list(
+        res = yield self.db.simple_select_list(
             "pusher_throttle",
             {"pusher": pusher_id},
             ["room_id", "last_sent_ts", "throttle_ms"],
@@ -361,8 +361,8 @@ class PusherStore(PusherWorkerStore):
     @defer.inlineCallbacks
     def set_throttle_params(self, pusher_id, room_id, params):
         # no need to lock because `pusher_throttle` has a primary key on
-        # (pusher, room_id) so _simple_upsert will retry
-        yield self._simple_upsert(
+        # (pusher, room_id) so simple_upsert will retry
+        yield self.db.simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
             params,
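
The pattern in the pushers hunks above recurs throughout this commit: the _simple_* helpers and runInteraction move off the store classes onto the per-data-store Database object exposed as self.db, and store constructors gain a database: Database argument. A minimal sketch of a store written against the new interface (ExampleWorkerStore and get_pushers_for_user are illustrative names, not part of this diff):

    from synapse.storage._base import SQLBaseStore
    from synapse.storage.database import Database


    class ExampleWorkerStore(SQLBaseStore):
        def __init__(self, database: Database, db_conn, hs):
            # the Database object is now threaded through every store constructor
            super(ExampleWorkerStore, self).__init__(database, db_conn, hs)

        def get_pushers_for_user(self, user_id):
            # simple_select_list and friends now live on self.db; this returns a Deferred
            return self.db.simple_select_list(
                table="pushers",
                keyvalues={"user_name": user_id},
                retcols=("app_id", "pushkey"),
                desc="get_pushers_for_user",
            )
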
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 0c24430f28..96e54d145e 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -22,6 +22,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
     # the abstract methods being implemented.
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
 
         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -61,7 +62,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="receipts_linearized",
             keyvalues={"room_id": room_id, "receipt_type": receipt_type},
             retcols=("user_id", "event_id"),
@@ -70,7 +71,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     @cached(num_args=3)
     def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="receipts_linearized",
             keyvalues={
                 "room_id": room_id,
@@ -84,7 +85,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
     @cachedInlineCallbacks(num_args=2)
     def get_receipts_for_user(self, user_id, receipt_type):
-        rows = yield self._simple_select_list(
+        rows = yield self.db.simple_select_list(
             table="receipts_linearized",
             keyvalues={"user_id": user_id, "receipt_type": receipt_type},
             retcols=("room_id", "event_id"),
@@ -108,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (user_id,))
             return txn.fetchall()
 
-        rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
+        rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
         return {
             row[0]: {
                 "event_id": row[1],
@@ -187,11 +188,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
                 txn.execute(sql, (room_id, to_key))
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             return rows
 
-        rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
+        rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
 
         if not rows:
             return []
@@ -237,9 +238,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
                 txn.execute(sql + clause, [to_key] + list(args))
 
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
+        txn_results = yield self.db.runInteraction(
+            "_get_linearized_receipts_for_rooms", f
+        )
 
         results = {}
         for row in txn_results:
@@ -280,9 +283,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 args.append(limit)
             txn.execute(sql, args)
 
-            return (r[0:5] + (json.loads(r[5]),) for r in txn)
+            return list(r[0:5] + (json.loads(r[5]),) for r in txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_all_updated_receipts", get_all_updated_receipts_txn
         )
 
@@ -313,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
 
 class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, db_conn, hs):
+    def __init__(self, database: Database, db_conn, hs):
         # We instantiate this first as the ReceiptsWorkerStore constructor
         # needs to be able to call get_max_receipt_stream_id
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
 
-        super(ReceiptsStore, self).__init__(db_conn, hs)
+        super(ReceiptsStore, self).__init__(database, db_conn, hs)
 
     def get_max_receipt_stream_id(self):
         return self._receipts_id_gen.get_current_token()
@@ -335,7 +338,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
         """
-        res = self._simple_select_one_txn(
+        res = self.db.simple_select_one_txn(
             txn,
             table="events",
             retcols=["stream_ordering", "received_ts"],
@@ -388,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             (user_id, room_id, receipt_type),
         )
 
-        self._simple_delete_txn(
+        self.db.simple_delete_txn(
             txn,
             table="receipts_linearized",
             keyvalues={
@@ -398,7 +401,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             },
         )
 
-        self._simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="receipts_linearized",
             values={
@@ -453,13 +456,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 else:
                     raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
 
-            linearized_event_id = yield self.runInteraction(
+            linearized_event_id = yield self.db.runInteraction(
                 "insert_receipt_conv", graph_to_linear
             )
 
         stream_id_manager = self._receipts_id_gen.get_next()
         with stream_id_manager as stream_id:
-            event_ts = yield self.runInteraction(
+            event_ts = yield self.db.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,
                 room_id,
@@ -488,7 +491,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
         return stream_id, max_persisted_id
 
     def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
@@ -514,7 +517,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
         )
 
-        self._simple_delete_txn(
+        self.db.simple_delete_txn(
             txn,
             table="receipts_graph",
             keyvalues={
@@ -523,7 +526,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "user_id": user_id,
             },
         )
-        self._simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="receipts_graph",
             values={
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 89147ad511..5e8ecac0ea 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -19,7 +19,6 @@ import logging
 import re
 
 from six import iterkeys
-from six.moves import range
 
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
@@ -27,8 +26,8 @@ from twisted.internet.defer import Deferred
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -38,15 +37,15 @@ logger = logging.getLogger(__name__)
 
 
 class RegistrationWorkerStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
 
         self.config = hs.config
         self.clock = hs.get_clock()
 
     @cached()
     def get_user_by_id(self, user_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="users",
             keyvalues={"name": user_id},
             retcols=[
@@ -95,7 +94,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 including the keys `name`, `is_guest`, `device_id`, `token_id`,
                 `valid_until_ms`.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
         )
 
@@ -110,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 otherwise int representation of the timestamp (as a number of
                 milliseconds since epoch).
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="expiration_ts_ms",
@@ -138,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def set_account_validity_for_user_txn(txn):
-            self._simple_update_txn(
+            self.db.simple_update_txn(
                 txn=txn,
                 table="account_validity",
                 keyvalues={"user_id": user_id},
@@ -152,7 +151,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 txn, self.get_expiration_ts_for_user, (user_id,)
             )
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "set_account_validity_for_user", set_account_validity_for_user_txn
         )
 
@@ -168,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Raises:
             StoreError: The provided token is already set for another user.
         """
-        yield self._simple_update_one(
+        yield self.db.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"renewal_token": renewal_token},
@@ -185,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             defer.Deferred[str]: The ID of the user to which the token belongs.
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"renewal_token": renewal_token},
             retcol="user_id",
@@ -204,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             defer.Deferred[str]: The renewal token associated with this user ID.
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="account_validity",
             keyvalues={"user_id": user_id},
             retcol="renewal_token",
@@ -230,9 +229,9 @@ class RegistrationWorkerStore(SQLBaseStore):
             )
             values = [False, now_ms, renew_at]
             txn.execute(sql, values)
-            return self.cursor_to_dict(txn)
+            return self.db.cursor_to_dict(txn)
 
-        res = yield self.runInteraction(
+        res = yield self.db.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
             self.clock.time_msec(),
@@ -251,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             email_sent (bool): Flag which indicates whether a renewal email has been sent
                 to this user.
         """
-        yield self._simple_update_one(
+        yield self.db.simple_update_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             updatevalues={"email_sent": email_sent},
@@ -266,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Args:
             user_id (str): ID of the user to remove from the account validity table.
         """
-        yield self._simple_delete_one(
+        yield self.db.simple_delete_one(
             table="account_validity",
             keyvalues={"user_id": user_id},
             desc="delete_account_validity_for_user",
@@ -282,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns (bool):
             true iff the user is a server admin, false otherwise.
         """
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user.to_string()},
             retcol="admin",
@@ -300,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             admin (bool): true iff the user is to be a server admin,
                 false otherwise.
         """
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="users",
             keyvalues={"name": user.to_string()},
             updatevalues={"admin": 1 if admin else 0},
@@ -317,7 +316,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
         txn.execute(sql, (token,))
-        rows = self.cursor_to_dict(txn)
+        rows = self.db.cursor_to_dict(txn)
         if rows:
             return rows[0]
 
@@ -333,7 +332,9 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if the user's 'user_type' is null or an empty string
         """
-        res = yield self.runInteraction("is_real_user", self.is_real_user_txn, user_id)
+        res = yield self.db.runInteraction(
+            "is_real_user", self.is_real_user_txn, user_id
+        )
         return res
 
     @cachedInlineCallbacks()
@@ -346,13 +347,13 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if user is of type UserTypes.SUPPORT
         """
-        res = yield self.runInteraction(
+        res = yield self.db.runInteraction(
             "is_support_user", self.is_support_user_txn, user_id
         )
         return res
 
     def is_real_user_txn(self, txn, user_id):
-        res = self._simple_select_one_onecol_txn(
+        res = self.db.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -362,7 +363,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         return res is None
 
     def is_support_user_txn(self, txn, user_id):
-        res = self._simple_select_one_onecol_txn(
+        res = self.db.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -377,13 +378,11 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def f(txn):
-            sql = (
-                "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
-            )
+            sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
             txn.execute(sql, (user_id,))
             return dict(txn)
 
-        return self.runInteraction("get_users_by_id_case_insensitive", f)
+        return self.db.runInteraction("get_users_by_id_case_insensitive", f)
 
     async def get_user_by_external_id(
         self, auth_provider: str, external_id: str
@@ -397,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             str|None: the mxid of the user, or None if they are not known
         """
-        return await self._simple_select_one_onecol(
+        return await self.db.simple_select_one_onecol(
             table="user_external_ids",
             keyvalues={"auth_provider": auth_provider, "external_id": external_id},
             retcol="user_id",
@@ -411,12 +410,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         def _count_users(txn):
             txn.execute("SELECT COUNT(*) AS users FROM users")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if rows:
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.runInteraction("count_users", _count_users)
+        ret = yield self.db.runInteraction("count_users", _count_users)
         return ret
 
     def count_daily_user_type(self):
@@ -448,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 results[row[0]] = row[1]
             return results
 
-        return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+        return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
 
     @defer.inlineCallbacks
     def count_nonbridged_users(self):
@@ -462,7 +461,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             (count,) = txn.fetchone()
             return count
 
-        ret = yield self.runInteraction("count_users", _count_users)
+        ret = yield self.db.runInteraction("count_users", _count_users)
         return ret
 
     @defer.inlineCallbacks
@@ -471,12 +470,12 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         def _count_users(txn):
             txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if rows:
                 return rows[0]["users"]
             return 0
 
-        ret = yield self.runInteraction("count_real_users", _count_users)
+        ret = yield self.db.runInteraction("count_real_users", _count_users)
         return ret
 
     @defer.inlineCallbacks
@@ -484,12 +483,8 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
         Gets the localpart of the next generated user ID.
 
-        Generated user IDs are integers, and we aim for them to be as small as
-        we can. Unfortunately, it's possible some of them are already taken by
-        existing users, and there may be gaps in the already taken range. This
-        function returns the start of the first allocatable gap. This is to
-        avoid the case of ID 1000 being pre-allocated and starting at 1001 while
-        0-999 are available.
+        Generated user IDs are integers, so we find the largest integer user ID
+        already taken and return that plus one.
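+        For example, if @1:server and @5:server are already registered, this returns 6;
+        gaps below the current maximum are no longer reused.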
         """
 
         def _find_next_generated_user_id(txn):
@@ -499,19 +494,18 @@ class RegistrationWorkerStore(SQLBaseStore):
 
             regex = re.compile(r"^@(\d+):")
 
-            found = set()
+            max_found = 0
 
             for (user_id,) in txn:
                 match = regex.search(user_id)
                 if match:
-                    found.add(int(match.group(1)))
-            for i in range(len(found) + 1):
-                if i not in found:
-                    return i
+                    max_found = max(int(match.group(1)), max_found)
+
+            return max_found + 1
 
         return (
             (
-                yield self.runInteraction(
+                yield self.db.runInteraction(
                     "find_next_generated_user_id", _find_next_generated_user_id
                 )
             )
@@ -528,7 +522,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[str|None]: user id or None if no user id/threepid mapping exists
         """
-        user_id = yield self.runInteraction(
+        user_id = yield self.db.runInteraction(
             "get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
         )
         return user_id
@@ -544,7 +538,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             str|None: user id or None if no user id/threepid mapping exists
         """
-        ret = self._simple_select_one_txn(
+        ret = self.db.simple_select_one_txn(
             txn,
             "user_threepids",
             {"medium": medium, "address": address},
@@ -557,7 +551,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
-        yield self._simple_upsert(
+        yield self.db.simple_upsert(
             "user_threepids",
             {"medium": medium, "address": address},
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@@ -565,7 +559,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def user_get_threepids(self, user_id):
-        ret = yield self._simple_select_list(
+        ret = yield self.db.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
             ["medium", "address", "validated_at", "added_at"],
@@ -574,7 +568,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         return ret
 
     def user_delete_threepid(self, user_id, medium, address):
-        return self._simple_delete(
+        return self.db.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
             desc="user_delete_threepid",
@@ -587,7 +581,7 @@ class RegistrationWorkerStore(SQLBaseStore):
              user_id: The user id to delete all threepids of
 
         """
-        return self._simple_delete(
+        return self.db.simple_delete(
             "user_threepids",
             keyvalues={"user_id": user_id},
             desc="user_delete_threepids",
@@ -609,7 +603,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
         # We need to use an upsert, in case the user had already bound the
         # threepid
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -635,7 +629,7 @@ class RegistrationWorkerStore(SQLBaseStore):
                 medium (str): The medium of the threepid (e.g "email")
                 address (str): The address of the threepid (e.g "bob@example.com")
         """
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id},
             retcols=["medium", "address"],
@@ -656,7 +650,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred
         """
-        return self._simple_delete(
+        return self.db.simple_delete(
             table="user_threepid_id_server",
             keyvalues={
                 "user_id": user_id,
@@ -679,7 +673,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         Returns:
             Deferred[list[str]]: Resolves to a list of identity servers
         """
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="user_threepid_id_server",
             keyvalues={"user_id": user_id, "medium": medium, "address": address},
             retcol="id_server",
@@ -697,7 +691,7 @@ class RegistrationWorkerStore(SQLBaseStore):
             defer.Deferred(bool): The requested value.
         """
 
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="deactivated",
@@ -764,13 +758,13 @@ class RegistrationWorkerStore(SQLBaseStore):
             sql += " LIMIT 1"
 
             txn.execute(sql, list(keyvalues.values()))
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return None
 
             return rows[0]
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_threepid_validation_session", get_threepid_validation_session_txn
         )
 
@@ -784,39 +778,37 @@ class RegistrationWorkerStore(SQLBaseStore):
         """
 
         def delete_threepid_session_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="threepid_validation_token",
                 keyvalues={"session_id": session_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_threepid_session", delete_threepid_session_txn
         )
 
 
-class RegistrationBackgroundUpdateStore(
-    RegistrationWorkerStore, background_updates.BackgroundUpdateStore
-):
-    def __init__(self, db_conn, hs):
-        super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
+class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.clock = hs.get_clock()
         self.config = hs.config
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "access_tokens_device_index",
             index_name="access_tokens_device_id",
             table="access_tokens",
             columns=["user_id", "device_id"],
         )
 
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "users_creation_ts",
             index_name="users_creation_ts",
             table="users",
@@ -826,13 +818,13 @@ class RegistrationBackgroundUpdateStore(
         # we no longer use refresh tokens, but it's possible that some people
         # might have a background update queued to build this index. Just
         # clear the background update.
-        self.register_noop_background_update("refresh_tokens_device_index")
+        self.db.updates.register_noop_background_update("refresh_tokens_device_index")
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "user_threepids_grandfather", self._bg_user_threepids_grandfather
         )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
@@ -865,7 +857,7 @@ class RegistrationBackgroundUpdateStore(
                 (last_user, batch_size),
             )
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
 
             if not rows:
                 return True, 0
@@ -879,7 +871,7 @@ class RegistrationBackgroundUpdateStore(
 
             logger.info("Marked %d rows as deactivated", rows_processed_nb)
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
             )
 
@@ -888,12 +880,12 @@ class RegistrationBackgroundUpdateStore(
             else:
                 return False, len(rows)
 
-        end, nb_processed = yield self.runInteraction(
+        end, nb_processed = yield self.db.runInteraction(
             "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
         )
 
         if end:
-            yield self._end_background_update("users_set_deactivated_flag")
+            yield self.db.updates._end_background_update("users_set_deactivated_flag")
 
         return nb_processed
 
@@ -919,21 +911,29 @@ class RegistrationBackgroundUpdateStore(
             txn.executemany(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
             )
 
-        yield self._end_background_update("user_threepids_grandfather")
+        yield self.db.updates._end_background_update("user_threepids_grandfather")
 
         return 1
 
 
 class RegistrationStore(RegistrationBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RegistrationStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RegistrationStore, self).__init__(database, db_conn, hs)
 
         self._account_validity = hs.config.account_validity
 
+        if self._account_validity.enabled:
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "account_validity_set_expiration_dates",
+                self._set_expiration_date_when_missing,
+            )
+
         # Create a background job for culling expired 3PID validity tokens
         def start_cull():
             # run as a background process to make sure that the database transactions
@@ -961,7 +961,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
         next_id = self._access_tokens_id_gen.get_next()
 
-        yield self._simple_insert(
+        yield self.db.simple_insert(
             "access_tokens",
             {
                 "id": next_id,
@@ -1003,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         Raises:
             StoreError if the user_id could not be registered.
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "register_user",
             self._register_user,
             user_id,
@@ -1037,7 +1037,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 # Ensure that the guest user actually exists
                 # ``allow_none=False`` makes this raise an exception
                 # if the row isn't in the database.
-                self._simple_select_one_txn(
+                self.db.simple_select_one_txn(
                     txn,
                     "users",
                     keyvalues={"name": user_id, "is_guest": 1},
@@ -1045,7 +1045,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                     allow_none=False,
                 )
 
-                self._simple_update_one_txn(
+                self.db.simple_update_one_txn(
                     txn,
                     "users",
                     keyvalues={"name": user_id, "is_guest": 1},
@@ -1059,7 +1059,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                     },
                 )
             else:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     "users",
                     values={
@@ -1114,7 +1114,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             external_id: id on that system
             user_id: complete mxid that it is mapped to
         """
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="user_external_ids",
             values={
                 "auth_provider": auth_provider,
@@ -1132,12 +1132,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
 
         def user_set_password_hash_txn(txn):
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn, "users", {"name": user_id}, {"password_hash": password_hash}
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
+        return self.db.runInteraction(
+            "user_set_password_hash", user_set_password_hash_txn
+        )
 
     def user_set_consent_version(self, user_id, consent_version):
         """Updates the user table to record privacy policy consent
@@ -1152,7 +1154,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
 
         def f(txn):
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="users",
                 keyvalues={"name": user_id},
@@ -1160,7 +1162,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_consent_version", f)
+        return self.db.runInteraction("user_set_consent_version", f)
 
     def user_set_consent_server_notice_sent(self, user_id, consent_version):
         """Updates the user table to record that we have sent the user a server
@@ -1176,7 +1178,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
 
         def f(txn):
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="users",
                 keyvalues={"name": user_id},
@@ -1184,7 +1186,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
             self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
 
-        return self.runInteraction("user_set_consent_server_notice_sent", f)
+        return self.db.runInteraction("user_set_consent_server_notice_sent", f)
 
     def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
         """
@@ -1230,11 +1232,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
             return tokens_and_devices
 
-        return self.runInteraction("user_delete_access_tokens", f)
+        return self.db.runInteraction("user_delete_access_tokens", f)
 
     def delete_access_token(self, access_token):
         def f(txn):
-            self._simple_delete_one_txn(
+            self.db.simple_delete_one_txn(
                 txn, table="access_tokens", keyvalues={"token": access_token}
             )
 
@@ -1242,11 +1244,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 txn, self.get_user_by_access_token, (access_token,)
             )
 
-        return self.runInteraction("delete_access_token", f)
+        return self.db.runInteraction("delete_access_token", f)
 
     @cachedInlineCallbacks()
     def is_guest(self, user_id):
-        res = yield self._simple_select_one_onecol(
+        res = yield self.db.simple_select_one_onecol(
             table="users",
             keyvalues={"name": user_id},
             retcol="is_guest",
@@ -1261,7 +1263,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         Adds a user to the table of users who need to be parted from all the rooms they're
         in
         """
-        return self._simple_insert(
+        return self.db.simple_insert(
             "users_pending_deactivation",
             values={"user_id": user_id},
             desc="add_user_pending_deactivation",
@@ -1274,7 +1276,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         """
         # XXX: This should be simple_delete_one but we failed to put a unique index on
         # the table, so somehow duplicate entries have ended up in it.
-        return self._simple_delete(
+        return self.db.simple_delete(
             "users_pending_deactivation",
             keyvalues={"user_id": user_id},
             desc="del_user_pending_deactivation",
@@ -1285,7 +1287,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         Gets one user from the table of users waiting to be parted from all the rooms
         they're in.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             "users_pending_deactivation",
             keyvalues={},
             retcol="user_id",
@@ -1315,7 +1317,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         # Insert everything into a transaction in order to run atomically
         def validate_threepid_session_txn(txn):
-            row = self._simple_select_one_txn(
+            row = self.db.simple_select_one_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1333,7 +1335,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                     400, "This client_secret does not match the provided session_id"
                 )
 
-            row = self._simple_select_one_txn(
+            row = self.db.simple_select_one_txn(
                 txn,
                 table="threepid_validation_token",
                 keyvalues={"session_id": session_id, "token": token},
@@ -1358,7 +1360,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 )
 
             # Looks good. Validate the session
-            self._simple_update_txn(
+            self.db.simple_update_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1368,7 +1370,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             return next_link
 
         # Return next_link if it exists
-        return self.runInteraction(
+        return self.db.runInteraction(
             "validate_threepid_session_txn", validate_threepid_session_txn
         )
 
@@ -1401,7 +1403,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         if validated_at:
             insertion_values["validated_at"] = validated_at
 
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="threepid_validation_session",
             keyvalues={"session_id": session_id},
             values={"last_send_attempt": send_attempt},
@@ -1439,7 +1441,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         def start_or_continue_validation_session_txn(txn):
             # Create or update a validation session
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
@@ -1452,7 +1454,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             )
 
             # Create a new validation token with this session ID
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="threepid_validation_token",
                 values={
@@ -1463,7 +1465,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 },
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "start_or_continue_validation_session",
             start_or_continue_validation_session_txn,
         )
@@ -1478,7 +1480,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             """
             return txn.execute(sql, (ts,))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "cull_expired_threepid_validation_tokens",
             cull_expired_threepid_validation_tokens_txn,
             self.clock.time_msec(),
@@ -1493,7 +1495,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             deactivated (bool): The value to set for `deactivated`.
         """
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "set_user_deactivated_status",
             self.set_user_deactivated_status_txn,
             user_id,
@@ -1501,7 +1503,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         )
 
     def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
-        self._simple_update_one_txn(
+        self.db.simple_update_one_txn(
             txn=txn,
             table="users",
             keyvalues={"name": user_id},
@@ -1510,3 +1512,59 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
         self._invalidate_cache_and_stream(
             txn, self.get_user_deactivated_status, (user_id,)
         )
+
+    @defer.inlineCallbacks
+    def _set_expiration_date_when_missing(self):
+        """
+        Retrieves the list of registered users that don't have an expiration date, and
+        adds an expiration date for each of them.
+        """
+
+        def select_users_with_no_expiration_date_txn(txn):
+            """Retrieves the list of registered users with no expiration date from the
+            database, filtering out deactivated users.
+            """
+            sql = (
+                "SELECT users.name FROM users"
+                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+            )
+            txn.execute(sql, [])
+
+            res = self.db.cursor_to_dict(txn)
+            if res:
+                for user in res:
+                    self.set_expiration_date_for_user_txn(
+                        txn, user["name"], use_delta=True
+                    )
+
+        yield self.db.runInteraction(
+            "get_users_with_no_expiration_date",
+            select_users_with_no_expiration_date_txn,
+        )
+
+    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+        """Sets an expiration date for the account with the given user ID.
+
+        Args:
+             user_id (str): User ID to set an expiration date for.
+             use_delta (bool): If set to False, the expiration date for the user will be
+                now + validity period. If set to True, this expiration date will be a
+                random value in the [now + period - d ; now + period] range, d being a
+                delta equal to 10% of the validity period.
+        """
+        now_ms = self._clock.time_msec()
+        expiration_ts = now_ms + self._account_validity.period
+
+        if use_delta:
+            expiration_ts = self.rand.randrange(
+                expiration_ts - self._account_validity.startup_job_max_delta,
+                expiration_ts,
+            )
+
+        self.db.simple_upsert_txn(
+            txn,
+            "account_validity",
+            keyvalues={"user_id": user_id},
+            values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+        )
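
For a concrete feel of the use_delta branch added in set_expiration_date_for_user_txn above: with a 30-day validity period and the 10% delta described in its docstring, the startup job spreads expiry timestamps over the last three days of the period. A standalone sketch (the numbers are illustrative; the real values come from the account_validity configuration):

    import random

    now_ms = 1576000000000                # "now", in ms since the epoch
    period = 30 * 24 * 60 * 60 * 1000     # 30-day validity period, in ms
    max_delta = period // 10              # 10% of the period, per the docstring

    # mirrors self.rand.randrange(expiration_ts - startup_job_max_delta, expiration_ts)
    expiration_ts = random.SystemRandom().randrange(
        now_ms + period - max_delta,      # lower bound: now + 27 days
        now_ms + period,                  # upper bound: now + 30 days
    )
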
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py
index 7d5de0ea2e..1c07c7a425 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 
 class RejectionsStore(SQLBaseStore):
     def _store_rejections_txn(self, txn, event_id, reason):
-        self._simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="rejections",
             values={
@@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore):
         )
 
     def get_rejection_reason(self, event_id):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="rejections",
             retcol="reason",
             keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
index 858f65582b..046c2b4845 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/data_stores/main/relations.py
@@ -129,7 +129,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )
 
@@ -223,7 +223,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
@@ -268,7 +268,7 @@ class RelationsWorkerStore(SQLBaseStore):
             if row:
                 return row[0]
 
-        edit_id = yield self.runInteraction(
+        edit_id = yield self.db.runInteraction(
             "get_applicable_edit", _get_applicable_edit_txn
         )
 
@@ -318,7 +318,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             return bool(txn.fetchone())
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
@@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore):
 
         aggregation_key = relation.get("key")
 
-        self._simple_insert_txn(
+        self.db.simple_insert_txn(
             txn,
             table="event_relations",
             values={
@@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore):
             redacted_event_id (str): The event that was redacted.
         """
 
-        self._simple_delete_txn(
+        self.db.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 67bb1b6f60..aa476d0fbf 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -19,13 +19,17 @@ import logging
 import re
 from typing import Optional, Tuple
 
+from six import integer_types
+
 from canonicaljson import json
 
 from twisted.internet import defer
 
+from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.search import SearchStore
+from synapse.storage.database import Database
 from synapse.types import ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
@@ -42,6 +46,11 @@ RatelimitOverride = collections.namedtuple(
 
 
 class RoomWorkerStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+
+        self.config = hs.config
+
     def get_room(self, room_id):
         """Retrieve a room.
 
@@ -50,7 +59,7 @@ class RoomWorkerStore(SQLBaseStore):
         Returns:
             A dict containing the room information, or None if the room is unknown.
         """
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="rooms",
             keyvalues={"room_id": room_id},
             retcols=("room_id", "is_public", "creator"),
@@ -59,7 +68,7 @@ class RoomWorkerStore(SQLBaseStore):
         )
 
     def get_public_room_ids(self):
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="rooms",
             keyvalues={"is_public": True},
             retcol="room_id",
@@ -116,7 +125,7 @@ class RoomWorkerStore(SQLBaseStore):
             txn.execute(sql, query_args)
             return txn.fetchone()[0]
 
-        return self.runInteraction("count_public_rooms", _count_public_rooms_txn)
+        return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
 
     @defer.inlineCallbacks
     def get_largest_public_rooms(
@@ -249,21 +258,21 @@ class RoomWorkerStore(SQLBaseStore):
         def _get_largest_public_rooms_txn(txn):
             txn.execute(sql, query_args)
 
-            results = self.cursor_to_dict(txn)
+            results = self.db.cursor_to_dict(txn)
 
             if not forwards:
                 results.reverse()
 
             return results
 
-        ret_val = yield self.runInteraction(
+        ret_val = yield self.db.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
         defer.returnValue(ret_val)
 
     @cached(max_entries=10000)
     def is_room_blocked(self, room_id):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             retcol="1",
@@ -284,7 +293,7 @@ class RoomWorkerStore(SQLBaseStore):
             of RatelimitOverride are None or 0 then ratelimiting has been
             disabled for that user entirely.
         """
-        row = yield self._simple_select_one(
+        row = yield self.db.simple_select_one(
             table="ratelimit_override",
             keyvalues={"user_id": user_id},
             retcols=("messages_per_second", "burst_count"),
@@ -300,8 +309,148 @@ class RoomWorkerStore(SQLBaseStore):
         else:
             return None
 
+    @cachedInlineCallbacks()
+    def get_retention_policy_for_room(self, room_id):
+        """Get the retention policy for a given room.
+
+        If no retention policy has been found for this room, returns the server's
+        configured default policy (with both 'min_lifetime' and 'max_lifetime' set to
+        None if no default policy has been defined in the server's configuration).
+
+        Args:
+            room_id (str): The ID of the room to get the retention policy of.
+
+        Returns:
+            dict[str, int|None]: "min_lifetime" and "max_lifetime" for this room.
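+            For example {"min_lifetime": None, "max_lifetime": 2592000000} for a room
+            that only caps message lifetime at 30 days (lifetimes are in milliseconds).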
+        """
+
+        def get_retention_policy_for_room_txn(txn):
+            txn.execute(
+                """
+                SELECT min_lifetime, max_lifetime FROM room_retention
+                INNER JOIN current_state_events USING (event_id, room_id)
+                WHERE room_id = ?;
+                """,
+                (room_id,),
+            )
+
+            return self.db.cursor_to_dict(txn)
+
+        ret = yield self.db.runInteraction(
+            "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+        )
+
+        # If we don't know this room ID, ret will be an empty list; in this case
+        # return the default policy.
+        if not ret:
+            defer.returnValue(
+                {
+                    "min_lifetime": self.config.retention_default_min_lifetime,
+                    "max_lifetime": self.config.retention_default_max_lifetime,
+                }
+            )
+
+        row = ret[0]
+
+        # If one of the room's policy's attributes isn't defined, use the matching
+        # attribute from the default policy.
+        # The default values will be None if no default policy has been defined, or if one
+        # of the attributes is missing from the default policy.
+        if row["min_lifetime"] is None:
+            row["min_lifetime"] = self.config.retention_default_min_lifetime
+
+        if row["max_lifetime"] is None:
+            row["max_lifetime"] = self.config.retention_default_max_lifetime
+
+        defer.returnValue(row)
+
+
+class RoomBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+        self.config = hs.config
+
+        self.db.updates.register_background_update_handler(
+            "insert_room_retention", self._background_insert_retention,
+        )
+
+    @defer.inlineCallbacks
+    def _background_insert_retention(self, progress, batch_size):
+        """Retrieves a list of all rooms within a range and inserts an entry for each of
+        them into the room_retention table.
+        Sets a property's column to NULL if that property is missing from the room's
+        retention event (or sets them all to NULL if the room has no retention event in
+        its state), so that we fall back to the server's default retention policy.
+        """
+
+        last_room = progress.get("room_id", "")
+
+        def _background_insert_retention_txn(txn):
+            txn.execute(
+                """
+                SELECT state.room_id, state.event_id, events.json
+                FROM current_state_events as state
+                LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
+                WHERE state.room_id > ? AND state.type = '%s'
+                ORDER BY state.room_id ASC
+                LIMIT ?;
+                """
+                % EventTypes.Retention,
+                (last_room, batch_size),
+            )
+
+            rows = self.db.cursor_to_dict(txn)
+
+            if not rows:
+                return True
+
+            for row in rows:
+                if not row["json"]:
+                    retention_policy = {}
+                else:
+                    ev = json.loads(row["json"])
+                    retention_policy = ev["content"]
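+                    # keep the event content as a dict so the .get() lookups below work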
+
+                self.db.simple_insert_txn(
+                    txn=txn,
+                    table="room_retention",
+                    values={
+                        "room_id": row["room_id"],
+                        "event_id": row["event_id"],
+                        "min_lifetime": retention_policy.get("min_lifetime"),
+                        "max_lifetime": retention_policy.get("max_lifetime"),
+                    },
+                )
+
+            logger.info("Inserted %d rows into room_retention", len(rows))
+
+            self.db.updates._background_update_progress_txn(
+                txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+            )
+
+            if batch_size > len(rows):
+                return True
+            else:
+                return False
+
+        end = yield self.db.runInteraction(
+            "insert_room_retention", _background_insert_retention_txn,
+        )
+
+        if end:
+            yield self.db.updates._end_background_update("insert_room_retention")
+
+        defer.returnValue(batch_size)
+
+
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomStore, self).__init__(database, db_conn, hs)
+
+        self.config = hs.config
 
-class RoomStore(RoomWorkerStore, SearchStore):
     @defer.inlineCallbacks
     def store_room(self, room_id, room_creator_user_id, is_public):
         """Stores a room.
@@ -317,7 +466,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
         try:
 
             def store_room_txn(txn, next_id):
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     "rooms",
                     {
@@ -327,7 +476,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                     },
                 )
                 if is_public:
-                    self._simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="public_room_list_stream",
                         values={
@@ -338,7 +487,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                     )
 
             with self._public_room_id_gen.get_next() as next_id:
-                yield self.runInteraction("store_room_txn", store_room_txn, next_id)
+                yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
         except Exception as e:
             logger.error("store_room with room_id=%s failed: %s", room_id, e)
             raise StoreError(500, "Problem creating room.")
@@ -346,14 +495,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
     @defer.inlineCallbacks
     def set_room_is_public(self, room_id, is_public):
         def set_room_is_public_txn(txn, next_id):
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="rooms",
                 keyvalues={"room_id": room_id},
                 updatevalues={"is_public": is_public},
             )
 
-            entries = self._simple_select_list_txn(
+            entries = self.db.simple_select_list_txn(
                 txn,
                 table="public_room_list_stream",
                 keyvalues={
@@ -371,7 +520,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 add_to_stream = bool(entries[-1]["visibility"]) != is_public
 
             if add_to_stream:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="public_room_list_stream",
                     values={
@@ -384,7 +533,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 )
 
         with self._public_room_id_gen.get_next() as next_id:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
         self.hs.get_notifier().on_new_replication_data()
@@ -411,7 +560,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
         def set_room_is_public_appservice_txn(txn, next_id):
             if is_public:
                 try:
-                    self._simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="appservice_room_list",
                         values={
@@ -424,7 +573,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                     # We've already inserted, nothing to do.
                     return
             else:
-                self._simple_delete_txn(
+                self.db.simple_delete_txn(
                     txn,
                     table="appservice_room_list",
                     keyvalues={
@@ -434,7 +583,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                     },
                 )
 
-            entries = self._simple_select_list_txn(
+            entries = self.db.simple_select_list_txn(
                 txn,
                 table="public_room_list_stream",
                 keyvalues={
@@ -452,7 +601,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 add_to_stream = bool(entries[-1]["visibility"]) != is_public
 
             if add_to_stream:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="public_room_list_stream",
                     values={
@@ -465,7 +614,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 )
 
         with self._public_room_id_gen.get_next() as next_id:
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,
                 next_id,
@@ -482,7 +631,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
             row = txn.fetchone()
             return row[0] or 0
 
-        return self.runInteraction("get_rooms", f)
+        return self.db.runInteraction("get_rooms", f)
 
     def _store_room_topic_txn(self, txn, event):
         if hasattr(event, "content") and "topic" in event.content:
@@ -502,11 +651,40 @@ class RoomStore(RoomWorkerStore, SearchStore):
                 txn, event, "content.body", event.content["body"]
             )
 
+    def _store_retention_policy_for_room_txn(self, txn, event):
+        if hasattr(event, "content") and (
+            "min_lifetime" in event.content or "max_lifetime" in event.content
+        ):
+            if (
+                "min_lifetime" in event.content
+                and not isinstance(event.content.get("min_lifetime"), integer_types)
+            ) or (
+                "max_lifetime" in event.content
+                and not isinstance(event.content.get("max_lifetime"), integer_types)
+            ):
+                # Ignore the event if one of the values isn't an integer.
+                return
+
+            self.db.simple_insert_txn(
+                txn=txn,
+                table="room_retention",
+                values={
+                    "room_id": event.room_id,
+                    "event_id": event.event_id,
+                    "min_lifetime": event.content.get("min_lifetime"),
+                    "max_lifetime": event.content.get("max_lifetime"),
+                },
+            )
+
+            self._invalidate_cache_and_stream(
+                txn, self.get_retention_policy_for_room, (event.room_id,)
+            )
+
     def add_event_report(
         self, room_id, event_id, user_id, reason, content, received_ts
     ):
         next_id = self._event_reports_id_gen.get_next()
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="event_reports",
             values={
                 "id": next_id,
@@ -539,7 +717,9 @@ class RoomStore(RoomWorkerStore, SearchStore):
         if prev_id == current_id:
             return defer.succeed([])
 
-        return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
+        return self.db.runInteraction(
+            "get_all_new_public_rooms", get_all_new_public_rooms
+        )
 
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
@@ -552,14 +732,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
         Returns:
             Deferred
         """
-        yield self._simple_upsert(
+        yield self.db.simple_upsert(
             table="blocked_rooms",
             keyvalues={"room_id": room_id},
             values={},
             insertion_values={"user_id": user_id},
             desc="block_room",
         )
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "block_room_invalidation",
             self._invalidate_cache_and_stream,
             self.is_room_blocked,
@@ -590,7 +770,9 @@ class RoomStore(RoomWorkerStore, SearchStore):
 
             return local_media_mxcs, remote_media_mxcs
 
-        return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+        return self.db.runInteraction(
+            "get_media_ids_in_room", _get_media_mxcs_in_room_txn
+        )
 
     def quarantine_media_ids_in_room(self, room_id, quarantined_by):
         """For a room loops through all events with media and quarantines
@@ -629,7 +811,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
 
             return total_media_quarantined
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "quarantine_media_in_room", _quarantine_media_in_room_txn
         )
 
@@ -683,3 +865,89 @@ class RoomStore(RoomWorkerStore, SearchStore):
                             remote_media_mxcs.append((hostname, media_id))
 
         return local_media_mxcs, remote_media_mxcs
+
+    @defer.inlineCallbacks
+    def get_rooms_for_retention_period_in_range(
+        self, min_ms, max_ms, include_null=False
+    ):
+        """Retrieves all of the rooms within the given retention range.
+
+        Optionally includes the rooms which don't have a retention policy.
+
+        Args:
+            min_ms (int|None): Duration in milliseconds that defines the lower limit of
+                the range to handle (exclusive). If None, doesn't set a lower limit.
+            max_ms (int|None): Duration in milliseconds that defines the upper limit of
+                the range to handle (inclusive). If None, doesn't set an upper limit.
+            include_null (bool): Whether to include rooms whose retention policy is NULL
+                in the returned set.
+
+        Returns:
+            dict[str, dict]: The rooms within this range, along with their retention
+                policy. Each key is a room ID, mapping to a dict describing the
+                retention policy for that room. The keys of this nested dict are
+                "min_lifetime" (int|None), and "max_lifetime" (int|None).
+        """
+
+        def get_rooms_for_retention_period_in_range_txn(txn):
+            range_conditions = []
+            args = []
+
+            if min_ms is not None:
+                range_conditions.append("max_lifetime > ?")
+                args.append(min_ms)
+
+            if max_ms is not None:
+                range_conditions.append("max_lifetime <= ?")
+                args.append(max_ms)
+
+            # Do a first query which will retrieve the rooms that have a retention policy
+            # in their current state.
+            sql = """
+                SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+                INNER JOIN current_state_events USING (event_id, room_id)
+                """
+
+            if len(range_conditions):
+                sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+                if include_null:
+                    sql += " OR max_lifetime IS NULL"
+
+            txn.execute(sql, args)
+
+            rows = self.db.cursor_to_dict(txn)
+            rooms_dict = {}
+
+            for row in rows:
+                rooms_dict[row["room_id"]] = {
+                    "min_lifetime": row["min_lifetime"],
+                    "max_lifetime": row["max_lifetime"],
+                }
+
+            if include_null:
+                # If required, do a second query that retrieves all of the rooms we know
+                # of so we can handle rooms with no retention policy.
+                sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+                txn.execute(sql)
+
+                rows = self.db.cursor_to_dict(txn)
+
+                # If a room isn't already in the dict (i.e. it doesn't have a retention
+                # policy in its state), add it with a null policy.
+                for row in rows:
+                    if row["room_id"] not in rooms_dict:
+                        rooms_dict[row["room_id"]] = {
+                            "min_lifetime": None,
+                            "max_lifetime": None,
+                        }
+
+            return rooms_dict
+
+        rooms = yield self.db.runInteraction(
+            "get_rooms_for_retention_period_in_range",
+            get_rooms_for_retention_period_in_range_txn,
+        )
+
+        defer.returnValue(rooms)
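
As an aside for reviewers, here is a minimal sketch of how a caller might combine the two retention helpers added above. The purge_rooms_in_range function, the store and logger objects, and the millisecond bounds are illustrative assumptions, not part of this change.

    # Hypothetical caller, for illustration only (not part of this patch).
    import logging

    from twisted.internet import defer

    logger = logging.getLogger(__name__)


    @defer.inlineCallbacks
    def purge_rooms_in_range(store, min_ms, max_ms):
        # Rooms whose max_lifetime falls within (min_ms, max_ms]; rooms without a
        # retention policy are only returned when include_null=True.
        rooms = yield store.get_rooms_for_retention_period_in_range(
            min_ms, max_ms, include_null=False
        )

        for room_id in rooms:
            # Per-room values fall back to the server-wide defaults when the room's
            # policy leaves min_lifetime or max_lifetime unset.
            policy = yield store.get_retention_policy_for_room(room_id)
            logger.info(
                "room %s: min_lifetime=%s, max_lifetime=%s",
                room_id,
                policy["min_lifetime"],
                policy["max_lifetime"],
            )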
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 2af24a20b7..92e3b9c512 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import Iterable, List
 
 from six import iteritems, itervalues
 
@@ -25,9 +26,13 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import (
+    LoggingTransaction,
+    SQLBaseStore,
+    make_in_list_sql_clause,
+)
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
     GetRoomsForUserWithStreamOrdering,
@@ -50,8 +55,8 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
 class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
 
         # Is the current_state_events.membership up to date? Or is the
         # background update still running?
@@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(query)
             return list(txn)[0][0]
 
-        count = yield self.runInteraction("get_known_servers", _transact)
+        count = yield self.db.runInteraction("get_known_servers", _transact)
 
         # We always know about ourselves, even if we have nothing in
         # room_memberships (for example, the server is new).
@@ -127,7 +132,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         membership column is up to date
         """
 
-        pending_update = self._simple_select_one_txn(
+        pending_update = self.db.simple_select_one_txn(
             txn,
             table="background_updates",
             keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
@@ -143,7 +148,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 15.0,
                 run_as_background_process,
                 "_check_safe_current_state_events_membership_updated",
-                self.runInteraction,
+                self.db.runInteraction,
                 "_check_safe_current_state_events_membership_updated",
                 self._check_safe_current_state_events_membership_updated_txn,
             )
@@ -160,7 +165,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
@@ -268,7 +273,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
             return res
 
-        return self.runInteraction("get_room_summary", _get_room_summary_txn)
+        return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
 
     def _get_user_counts_in_room_txn(self, txn, room_id):
         """
@@ -338,7 +343,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not membership_list:
             return defer.succeed(None)
 
-        rooms = yield self.runInteraction(
+        rooms = yield self.db.runInteraction(
             "get_rooms_for_user_where_membership_is",
             self._get_rooms_for_user_where_membership_is_txn,
             user_id,
@@ -391,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 )
 
             txn.execute(sql, (user_id, *args))
-            results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
+            results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
 
         if do_invite:
             sql = (
@@ -411,7 +416,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                     stream_ordering=r["stream_ordering"],
                     membership=Membership.INVITE,
                 )
-                for r in self.cursor_to_dict(txn)
+                for r in self.db.cursor_to_dict(txn)
             )
 
         return results
@@ -602,7 +607,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             to `user_id` and ProfileInfo (or None if not join event).
         """
 
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
@@ -642,7 +647,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # the returned user actually has the correct domain.
         like_clause = "%:" + host
 
-        rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
+        rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
 
         if not rows:
             return False
@@ -682,7 +687,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # the returned user actually has the correct domain.
         like_clause = "%:" + host
 
-        rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
+        rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
 
         if not rows:
             return False
@@ -752,7 +757,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        count = yield self.runInteraction("did_forget_membership", f)
+        count = yield self.db.runInteraction("did_forget_membership", f)
         return count == 0
 
     @cached()
@@ -789,7 +794,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             txn.execute(sql, (user_id,))
             return set(row[0] for row in txn if row[1] == 0)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
@@ -804,7 +809,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             Deferred[set[str]]: Set of room IDs.
         """
 
-        room_ids = yield self._simple_select_onecol(
+        room_ids = yield self.db.simple_select_onecol(
             table="room_memberships",
             keyvalues={"membership": Membership.JOIN, "user_id": user_id},
             retcol="room_id",
@@ -813,18 +818,34 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
+    def get_membership_from_event_ids(
+        self, member_event_ids: Iterable[str]
+    ) -> List[dict]:
+        """Get user_id and membership of a set of event IDs.
+        """
+
+        return self.db.simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=member_event_ids,
+            retcols=("user_id", "membership", "event_id"),
+            keyvalues={},
+            batch_size=500,
+            desc="get_membership_from_event_ids",
+        )
+
 
-class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
+class RoomMemberBackgroundUpdateStore(SQLBaseStore):
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        self.db.updates.register_background_update_handler(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
             self._background_current_state_membership,
         )
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             "room_membership_forgotten_idx",
             index_name="room_memberships_user_room_forgotten",
             table="room_memberships",
@@ -857,7 +878,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
 
             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
 
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return 0
 
@@ -892,18 +913,20 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
                 "max_stream_id_exclusive": min_stream_id,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
             )
 
             return len(rows)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
         )
 
         if not result:
-            yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                _MEMBERSHIP_PROFILE_UPDATE_NAME
+            )
 
         return result
 
@@ -942,7 +965,7 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
 
                 last_processed_room = next_room
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn,
                 _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
                 {"last_processed_room": last_processed_room},
@@ -954,26 +977,28 @@ class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
         # string, which will compare before all room IDs correctly.
         last_processed_room = progress.get("last_processed_room", "")
 
-        row_count, finished = yield self.runInteraction(
+        row_count, finished = yield self.db.runInteraction(
             "_background_current_state_membership_update",
             _background_current_state_membership_txn,
             last_processed_room,
         )
 
         if finished:
-            yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
+            )
 
         return row_count
 
 
 class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
         """
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="room_memberships",
             values=[
@@ -1011,7 +1036,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
             is_mine = self.hs.is_mine_id(event.state_key)
             if is_new_state and is_mine:
                 if event.membership == Membership.INVITE:
-                    self._simple_insert_txn(
+                    self.db.simple_insert_txn(
                         txn,
                         table="local_invites",
                         values={
@@ -1051,7 +1076,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
             txn.execute(sql, (stream_ordering, True, room_id, user_id))
 
         with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+            yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
 
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
@@ -1074,7 +1099,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
                 txn, self.get_forgotten_rooms_for_user, (user_id,)
             )
 
-        return self.runInteraction("forget_membership", f)
+        return self.db.runInteraction("forget_membership", f)
 
 
 class _JoinedHostsCache(object):
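
A small usage sketch for the new get_membership_from_event_ids helper above; the log_memberships function and the member_event_ids argument are illustrative assumptions only.

    # Hypothetical usage, for illustration only (not part of this patch).
    import logging

    from twisted.internet import defer

    logger = logging.getLogger(__name__)


    @defer.inlineCallbacks
    def log_memberships(store, member_event_ids):
        # Each returned dict carries user_id, membership and event_id; the lookup
        # is batched 500 event IDs at a time by simple_select_many_batch.
        rows = yield store.get_membership_from_event_ids(member_event_ids)
        for row in rows:
            logger.info(
                "%s is %s in event %s",
                row["user_id"],
                row["membership"],
                row["event_id"],
            )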
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
new file mode 100644
index 0000000000..81a36a8b1d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+    event_id TEXT PRIMARY KEY,
+    expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
index fe51b02309..ea95db0ed7 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -14,4 +14,3 @@
  */
 
 ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
-CREATE INDEX redactions_have_censored ON redactions(event_id) WHERE not have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
index 77a5eca499..49ce35d794 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -14,7 +14,9 @@
  */
 
 ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
-CREATE INDEX redactions_have_censored_ts ON redactions(received_ts) WHERE not have_censored;
 
 INSERT INTO background_updates (update_name, progress_json) VALUES
   ('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
new file mode 100644
index 0000000000..7d70dd071e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- store the current etag of backup version
+ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
new file mode 100644
index 0000000000..ee6cdf7a14
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
@@ -0,0 +1,33 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks the retention policy of a room.
+-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in
+-- the room's retention policy state event.
+-- If a room doesn't have a retention policy state event in its state, both max_lifetime
+-- and min_lifetime are NULL.
+CREATE TABLE IF NOT EXISTS room_retention(
+    room_id TEXT,
+    event_id TEXT,
+    min_lifetime BIGINT,
+    max_lifetime BIGINT,
+
+    PRIMARY KEY(room_id, event_id)
+);
+
+CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime);
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('insert_room_retention', '{}');
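
The comments above spell out how NULL columns should be interpreted; the sketch below mirrors the fallback performed by get_retention_policy_for_room earlier in this patch. The config attribute names are taken from that code; the helper itself is an illustrative assumption.

    # Hypothetical helper, for illustration only (not part of this patch).
    def resolve_retention_policy(row, config):
        # A NULL (None) lifetime in room_retention means the property is not set
        # in the room's retention state event, so fall back to the server-wide
        # default.
        return {
            "min_lifetime": row["min_lifetime"]
            if row["min_lifetime"] is not None
            else config.retention_default_min_lifetime,
            "max_lifetime": row["max_lifetime"]
            if row["max_lifetime"] is not None
            else config.retention_default_max_lifetime,
        }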
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
index 27a96123e3..5c5fffcafb 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
@@ -40,7 +40,8 @@ CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures (
     signature TEXT NOT NULL
 );
 
-CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
+-- replaced by the index created in signing_keys_nonunique_signatures.sql
+-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
 
 -- stream of user signature updates
 CREATE TABLE IF NOT EXISTS user_signature_stream (
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
new file mode 100644
index 0000000000..0aa90ebf0c
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* The cross-signing signatures index should not be a unique index, because a
+ * user may upload multiple signatures for the same target user. The previous
+ * index was unique, so delete it if it's there and create a new non-unique
+ * index. */
+
+DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx;
+CREATE INDEX IF NOT EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index d1d7c6863d..4eec2fae5e 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -24,8 +24,8 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
-from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
 logger = logging.getLogger(__name__)
@@ -36,23 +36,23 @@ SearchEntry = namedtuple(
 )
 
 
-class SearchBackgroundUpdateStore(BackgroundUpdateStore):
+class SearchBackgroundUpdateStore(SQLBaseStore):
 
     EVENT_SEARCH_UPDATE_NAME = "event_search"
     EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
     EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
     EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
 
-    def __init__(self, db_conn, hs):
-        super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         if not hs.config.enable_search:
             return
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
         )
 
@@ -61,9 +61,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
         # a GIN index. However, it's possible that some people might still have
         # the background update queued, so we register a handler to clear the
         # background update.
-        self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+        self.db.updates.register_noop_background_update(
+            self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
+        )
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
         )
 
@@ -93,7 +95,7 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
             # store_search_entries_txn with a generator function, but that
             # would mean having two cursors open on the database at once.
             # Instead we just build a list of results.
-            rows = self.cursor_to_dict(txn)
+            rows = self.db.cursor_to_dict(txn)
             if not rows:
                 return 0
 
@@ -153,18 +155,18 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                 "rows_inserted": rows_inserted + len(event_search_rows),
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_SEARCH_UPDATE_NAME, progress
             )
 
             return len(event_search_rows)
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
         )
 
         if not result:
-            yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
+            yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
 
         return result
 
@@ -206,9 +208,11 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                 conn.set_session(autocommit=False)
 
         if isinstance(self.database_engine, PostgresEngine):
-            yield self.runWithConnection(create_index)
+            yield self.db.runWithConnection(create_index)
 
-        yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
+        yield self.db.updates._end_background_update(
+            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
+        )
         return 1
 
     @defer.inlineCallbacks
@@ -237,14 +241,14 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                 )
                 conn.set_session(autocommit=False)
 
-            yield self.runWithConnection(create_index)
+            yield self.db.runWithConnection(create_index)
 
             pg = dict(progress)
             pg["have_added_indexes"] = True
 
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                 pg,
             )
@@ -274,18 +278,20 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
                 "have_added_indexes": True,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
             )
 
             return len(rows), True
 
-        num_rows, finished = yield self.runInteraction(
+        num_rows, finished = yield self.db.runInteraction(
             self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
         )
 
         if not finished:
-            yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
+            yield self.db.updates._end_background_update(
+                self.EVENT_SEARCH_ORDER_UPDATE_NAME
+            )
 
         return num_rows
 
@@ -337,8 +343,8 @@ class SearchBackgroundUpdateStore(BackgroundUpdateStore):
 
 
 class SearchStore(SearchBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(SearchStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(SearchStore, self).__init__(database, db_conn, hs)
 
     def store_event_search_txn(self, txn, event, key, value):
         """Add event to the search table
@@ -441,7 +447,9 @@ class SearchStore(SearchBackgroundUpdateStore):
         # entire table from the database.
         sql += " ORDER BY rank DESC LIMIT 500"
 
-        results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
+        results = yield self.db.execute(
+            "search_msgs", self.db.cursor_to_dict, sql, *args
+        )
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
@@ -455,8 +463,8 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = yield self._execute(
-            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        count_results = yield self.db.execute(
+            "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
         )
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -586,7 +594,9 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         args.append(limit)
 
-        results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
+        results = yield self.db.execute(
+            "search_rooms", self.db.cursor_to_dict, sql, *args
+        )
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
@@ -600,8 +610,8 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         count_sql += " GROUP BY room_id"
 
-        count_results = yield self._execute(
-            "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+        count_results = yield self.db.execute(
+            "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
         )
 
         count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
@@ -686,7 +696,7 @@ class SearchStore(SearchBackgroundUpdateStore):
 
             return highlight_words
 
-        return self.runInteraction("_find_highlights", f)
+        return self.db.runInteraction("_find_highlights", f)
 
 
 def _to_postgres_options(options_dict):
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py
index 556191b76f..563216b63c 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -48,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore):
                 for event_id in event_ids
             }
 
-        return self.runInteraction("get_event_reference_hashes", f)
+        return self.db.runInteraction("get_event_reference_hashes", f)
 
     @defer.inlineCallbacks
     def add_event_hashes(self, event_ids):
@@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore):
                 }
             )
 
-        self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+        self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 6a90daea31..9ef7b48c74 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -27,8 +27,8 @@ from synapse.api.errors import NotFoundError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
 from synapse.util.caches import get_cache_factor_for, intern_string
@@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             count = 0
 
             while next_group:
-                next_group = self._simple_select_one_onecol_txn(
+                next_group = self.db.simple_select_one_onecol_txn(
                     txn,
                     table="state_group_edges",
                     keyvalues={"state_group": next_group},
@@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                     ):
                         break
 
-                    next_group = self._simple_select_one_onecol_txn(
+                    next_group = self.db.simple_select_one_onecol_txn(
                         txn,
                         table="state_group_edges",
                         keyvalues={"state_group": next_group},
@@ -214,8 +214,8 @@ class StateGroupWorkerStore(
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
 
-    def __init__(self, db_conn, hs):
-        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
 
         # Originally the state store used a single DictionaryCache to cache the
         # event IDs for the state types in a given state group to avoid hammering
@@ -348,7 +348,9 @@ class StateGroupWorkerStore(
                 (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
             }
 
-        return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
+        return self.db.runInteraction(
+            "get_current_state_ids", _get_current_state_ids_txn
+        )
 
     # FIXME: how should this be cached?
     def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@@ -392,7 +394,7 @@ class StateGroupWorkerStore(
 
             return results
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
 
@@ -431,7 +433,7 @@ class StateGroupWorkerStore(
         """
 
         def _get_state_group_delta_txn(txn):
-            prev_group = self._simple_select_one_onecol_txn(
+            prev_group = self.db.simple_select_one_onecol_txn(
                 txn,
                 table="state_group_edges",
                 keyvalues={"state_group": state_group},
@@ -442,7 +444,7 @@ class StateGroupWorkerStore(
             if not prev_group:
                 return _GetStateGroupDelta(None, None)
 
-            delta_ids = self._simple_select_list_txn(
+            delta_ids = self.db.simple_select_list_txn(
                 txn,
                 table="state_groups_state",
                 keyvalues={"state_group": state_group},
@@ -454,7 +456,9 @@ class StateGroupWorkerStore(
                 {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
             )
 
-        return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
+        return self.db.runInteraction(
+            "get_state_group_delta", _get_state_group_delta_txn
+        )
 
     @defer.inlineCallbacks
     def get_state_groups_ids(self, _room_id, event_ids):
@@ -540,7 +544,7 @@ class StateGroupWorkerStore(
 
         chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:
-            res = yield self.runInteraction(
+            res = yield self.db.runInteraction(
                 "_get_state_groups_from_groups",
                 self._get_state_groups_from_groups_txn,
                 chunk,
@@ -644,7 +648,7 @@ class StateGroupWorkerStore(
 
     @cached(max_entries=50000)
     def _get_state_group_for_event(self, event_id):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="event_to_state_groups",
             keyvalues={"event_id": event_id},
             retcol="state_group",
@@ -661,7 +665,7 @@ class StateGroupWorkerStore(
     def _get_state_group_for_events(self, event_ids):
         """Returns mapping event_id -> state_group
         """
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="event_to_state_groups",
             column="event_id",
             iterable=event_ids,
@@ -902,7 +906,7 @@ class StateGroupWorkerStore(
 
             state_group = self.database_engine.get_next_state_group_id(txn)
 
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="state_groups",
                 values={"id": state_group, "room_id": room_id, "event_id": event_id},
@@ -911,7 +915,7 @@ class StateGroupWorkerStore(
             # We persist as a delta if we can, while also ensuring the chain
             # of deltas isn't tooo long, as otherwise read performance degrades.
             if prev_group:
-                is_in_db = self._simple_select_one_onecol_txn(
+                is_in_db = self.db.simple_select_one_onecol_txn(
                     txn,
                     table="state_groups",
                     keyvalues={"id": prev_group},
@@ -926,13 +930,13 @@ class StateGroupWorkerStore(
 
                 potential_hops = self._count_state_group_hops_txn(txn, prev_group)
             if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                self._simple_insert_txn(
+                self.db.simple_insert_txn(
                     txn,
                     table="state_group_edges",
                     values={"state_group": state_group, "prev_state_group": prev_group},
                 )
 
-                self._simple_insert_many_txn(
+                self.db.simple_insert_many_txn(
                     txn,
                     table="state_groups_state",
                     values=[
@@ -947,7 +951,7 @@ class StateGroupWorkerStore(
                     ],
                 )
             else:
-                self._simple_insert_many_txn(
+                self.db.simple_insert_many_txn(
                     txn,
                     table="state_groups_state",
                     values=[
@@ -993,7 +997,7 @@ class StateGroupWorkerStore(
 
             return state_group
 
-        return self.runInteraction("store_state_group", _store_state_group_txn)
+        return self.db.runInteraction("store_state_group", _store_state_group_txn)
 
     @defer.inlineCallbacks
     def get_referenced_state_groups(self, state_groups):
@@ -1007,7 +1011,7 @@ class StateGroupWorkerStore(
             referenced.
         """
 
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="event_to_state_groups",
             column="state_group",
             iterable=state_groups,
@@ -1019,32 +1023,30 @@ class StateGroupWorkerStore(
         return set(row["state_group"] for row in rows)
 
 
-class StateBackgroundUpdateStore(
-    StateGroupBackgroundUpdateStore, BackgroundUpdateStore
-):
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
 
     STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
     EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
 
-    def __init__(self, db_conn, hs):
-        super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+        self.db.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
             self._background_deduplicate_state,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
         )
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             self.CURRENT_STATE_INDEX_UPDATE_NAME,
             index_name="current_state_events_member_index",
             table="current_state_events",
             columns=["state_key"],
             where_clause="type='m.room.member'",
         )
-        self.register_background_index_update(
+        self.db.updates.register_background_index_update(
             self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
             index_name="event_to_state_groups_sg_index",
             table="event_to_state_groups",
@@ -1065,7 +1067,7 @@ class StateBackgroundUpdateStore(
         batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
 
         if max_group is None:
-            rows = yield self._execute(
+            rows = yield self.db.execute(
                 "_background_deduplicate_state",
                 None,
                 "SELECT coalesce(max(id), 0) FROM state_groups",
@@ -1135,13 +1137,13 @@ class StateBackgroundUpdateStore(
                             if prev_state.get(key, None) != value
                         }
 
-                        self._simple_delete_txn(
+                        self.db.simple_delete_txn(
                             txn,
                             table="state_group_edges",
                             keyvalues={"state_group": state_group},
                         )
 
-                        self._simple_insert_txn(
+                        self.db.simple_insert_txn(
                             txn,
                             table="state_group_edges",
                             values={
@@ -1150,13 +1152,13 @@ class StateBackgroundUpdateStore(
                             },
                         )
 
-                        self._simple_delete_txn(
+                        self.db.simple_delete_txn(
                             txn,
                             table="state_groups_state",
                             keyvalues={"state_group": state_group},
                         )
 
-                        self._simple_insert_many_txn(
+                        self.db.simple_insert_many_txn(
                             txn,
                             table="state_groups_state",
                             values=[
@@ -1177,18 +1179,18 @@ class StateBackgroundUpdateStore(
                 "max_group": max_group,
             }
 
-            self._background_update_progress_txn(
+            self.db.updates._background_update_progress_txn(
                 txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
             )
 
             return False, batch_size
 
-        finished, result = yield self.runInteraction(
+        finished, result = yield self.db.runInteraction(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
         )
 
         if finished:
-            yield self._end_background_update(
+            yield self.db.updates._end_background_update(
                 self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
             )
 
@@ -1218,9 +1220,9 @@ class StateBackgroundUpdateStore(
                 )
                 txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
 
-        yield self.runWithConnection(reindex_txn)
+        yield self.db.runWithConnection(reindex_txn)
 
-        yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+        yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
 
         return 1
 
@@ -1244,8 +1246,8 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
       * `state_groups_state`: Maps state group to state events.
     """
 
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StateStore, self).__init__(database, db_conn, hs)
 
     def _store_event_state_mappings_txn(
         self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
@@ -1263,7 +1265,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
 
             state_groups[event.event_id] = context.state_group
 
-        self._simple_insert_many_txn(
+        self.db.simple_insert_many_txn(
             txn,
             table="event_to_state_groups",
             values=[
diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 28f33ec18f..12c982cb26 100644
--- a/synapse/storage/data_stores/main/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -98,14 +98,14 @@ class StateDeltasStore(SQLBaseStore):
                 ORDER BY stream_id ASC
             """
             txn.execute(sql, (prev_stream_id, clipped_stream_id))
-            return clipped_stream_id, self.cursor_to_dict(txn)
+            return clipped_stream_id, self.db.cursor_to_dict(txn)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn
         )
 
     def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
-        return self._simple_select_one_onecol_txn(
+        return self.db.simple_select_one_onecol_txn(
             txn,
             table="current_state_delta_stream",
             keyvalues={},
@@ -113,7 +113,7 @@ class StateDeltasStore(SQLBaseStore):
         )
 
     def get_max_stream_id_in_current_state_deltas(self):
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_max_stream_id_in_current_state_deltas",
             self._get_max_stream_id_in_current_state_deltas_txn,
         )
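
The change repeated throughout this patch is mechanical: each store's constructor gains a database: Database argument which it forwards to its superclass, and the helpers that previously lived on the store (_simple_*, runInteraction, cursor_to_dict) are reached through self.db instead. A minimal sketch of a migrated store, assuming self.db is populated by SQLBaseStore; the ExampleStore class, example_table and the query are hypothetical, while the helper names are the ones used in this diff:

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import Database


class ExampleStore(SQLBaseStore):
    def __init__(self, database: Database, db_conn, hs):
        # The Database object is now passed in explicitly; it is assumed here
        # that SQLBaseStore exposes it as self.db.
        super(ExampleStore, self).__init__(database, db_conn, hs)

    def _store_rows_txn(self, txn, rows):
        # Bulk insert inside an existing transaction ("example_table" is made up).
        self.db.simple_insert_many_txn(
            txn,
            table="example_table",
            values=[{"key": k, "value": v} for k, v in rows],
        )

    def get_rows_after(self, min_key):
        # Read path: the transaction function is unchanged, but scheduling and
        # row conversion go through self.db.
        def _get_rows_after_txn(txn):
            txn.execute(
                "SELECT key, value FROM example_table WHERE key > ? ORDER BY key",
                (min_key,),
            )
            return self.db.cursor_to_dict(txn)

        return self.db.runInteraction("get_rows_after", _get_rows_after_txn)
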
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 45b3de7d56..7bc186e9a1 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -22,6 +22,7 @@ from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.util.caches.descriptors import cached
 
@@ -58,8 +59,8 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
 
 
 class StatsStore(StateDeltasStore):
-    def __init__(self, db_conn, hs):
-        super(StatsStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StatsStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
         self.clock = self.hs.get_clock()
@@ -68,17 +69,17 @@ class StatsStore(StateDeltasStore):
 
         self.stats_delta_processing_lock = DeferredLock()
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_stats_process_rooms", self._populate_stats_process_rooms
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_stats_process_users", self._populate_stats_process_users
         )
         # we no longer need to perform clean-up, but we will give ourselves
         # the potential to reintroduce it in the future – so documentation
         # will still encourage the use of this no-op handler.
-        self.register_noop_background_update("populate_stats_cleanup")
-        self.register_noop_background_update("populate_stats_prepare")
+        self.db.updates.register_noop_background_update("populate_stats_cleanup")
+        self.db.updates.register_noop_background_update("populate_stats_prepare")
 
     def quantise_stats_time(self, ts):
         """
@@ -102,7 +103,7 @@ class StatsStore(StateDeltasStore):
         This is a background update which regenerates statistics for users.
         """
         if not self.stats_enabled:
-            yield self._end_background_update("populate_stats_process_users")
+            yield self.db.updates._end_background_update("populate_stats_process_users")
             return 1
 
         last_user_id = progress.get("last_user_id", "")
@@ -117,22 +118,22 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, (last_user_id, batch_size))
             return [r for r, in txn]
 
-        users_to_work_on = yield self.runInteraction(
+        users_to_work_on = yield self.db.runInteraction(
             "_populate_stats_process_users", _get_next_batch
         )
 
         # No more users -- complete the transaction.
         if not users_to_work_on:
-            yield self._end_background_update("populate_stats_process_users")
+            yield self.db.updates._end_background_update("populate_stats_process_users")
             return 1
 
         for user_id in users_to_work_on:
             yield self._calculate_and_set_initial_state_for_user(user_id)
             progress["last_user_id"] = user_id
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_stats_process_users",
-            self._background_update_progress_txn,
+            self.db.updates._background_update_progress_txn,
             "populate_stats_process_users",
             progress,
         )
@@ -145,7 +146,7 @@ class StatsStore(StateDeltasStore):
         This is a background update which regenerates statistics for rooms.
         """
         if not self.stats_enabled:
-            yield self._end_background_update("populate_stats_process_rooms")
+            yield self.db.updates._end_background_update("populate_stats_process_rooms")
             return 1
 
         last_room_id = progress.get("last_room_id", "")
@@ -160,22 +161,22 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, (last_room_id, batch_size))
             return [r for r, in txn]
 
-        rooms_to_work_on = yield self.runInteraction(
+        rooms_to_work_on = yield self.db.runInteraction(
             "populate_stats_rooms_get_batch", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
-            yield self._end_background_update("populate_stats_process_rooms")
+            yield self.db.updates._end_background_update("populate_stats_process_rooms")
             return 1
 
         for room_id in rooms_to_work_on:
             yield self._calculate_and_set_initial_state_for_room(room_id)
             progress["last_room_id"] = room_id
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "_populate_stats_process_rooms",
-            self._background_update_progress_txn,
+            self.db.updates._background_update_progress_txn,
             "populate_stats_process_rooms",
             progress,
         )
@@ -186,7 +187,7 @@ class StatsStore(StateDeltasStore):
         """
         Returns the stats processor positions.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="stats_incremental_position",
             keyvalues={},
             retcol="stream_id",
@@ -215,7 +216,7 @@ class StatsStore(StateDeltasStore):
             if field and "\0" in field:
                 fields[col] = None
 
-        return self._simple_upsert(
+        return self.db.simple_upsert(
             table="room_stats_state",
             keyvalues={"room_id": room_id},
             values=fields,
@@ -236,7 +237,7 @@ class StatsStore(StateDeltasStore):
             Deferred[list[dict]], where the dict has the keys of
             ABSOLUTE_STATS_FIELDS[stats_type],  and "bucket_size" and "end_ts".
         """
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_statistics_for_subject",
             self._get_statistics_for_subject_txn,
             stats_type,
@@ -257,14 +258,14 @@ class StatsStore(StateDeltasStore):
             ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
         )
 
-        slice_list = self._simple_select_list_paginate_txn(
+        slice_list = self.db.simple_select_list_paginate_txn(
             txn,
             table + "_historical",
-            {id_col: stats_id},
             "end_ts",
             start,
             size,
             retcols=selected_columns + ["bucket_size", "end_ts"],
+            keyvalues={id_col: stats_id},
             order_direction="DESC",
         )
 
@@ -282,7 +283,7 @@ class StatsStore(StateDeltasStore):
                 "name", "topic", "canonical_alias", "avatar", "join_rules",
                 "history_visibility"
         """
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             "room_stats_state",
             {"room_id": room_id},
             retcols=(
@@ -308,7 +309,7 @@ class StatsStore(StateDeltasStore):
         """
         table, id_col = TYPE_TO_TABLE[stats_type]
 
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             "%s_current" % (table,),
             keyvalues={id_col: id},
             retcol="completed_delta_stream_id",
@@ -344,14 +345,14 @@ class StatsStore(StateDeltasStore):
                         complete_with_stream_id=stream_id,
                     )
 
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 table="stats_incremental_position",
                 keyvalues={},
                 updatevalues={"stream_id": stream_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "bulk_update_stats_delta", _bulk_update_stats_delta_txn
         )
 
@@ -382,7 +383,7 @@ class StatsStore(StateDeltasStore):
                 Does not work with per-slice fields.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_stats_delta",
             self._update_stats_delta_txn,
             ts,
@@ -517,17 +518,17 @@ class StatsStore(StateDeltasStore):
         else:
             self.database_engine.lock_table(txn, table)
             retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
-            current_row = self._simple_select_one_txn(
+            current_row = self.db.simple_select_one_txn(
                 txn, table, keyvalues, retcols, allow_none=True
             )
             if current_row is None:
                 merged_dict = {**keyvalues, **absolutes, **additive_relatives}
-                self._simple_insert_txn(txn, table, merged_dict)
+                self.db.simple_insert_txn(txn, table, merged_dict)
             else:
                 for (key, val) in additive_relatives.items():
                     current_row[key] += val
                 current_row.update(absolutes)
-                self._simple_update_one_txn(txn, table, keyvalues, current_row)
+                self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
 
     def _upsert_copy_from_table_with_additive_relatives_txn(
         self,
@@ -614,11 +615,11 @@ class StatsStore(StateDeltasStore):
             txn.execute(sql, qargs)
         else:
             self.database_engine.lock_table(txn, into_table)
-            src_row = self._simple_select_one_txn(
+            src_row = self.db.simple_select_one_txn(
                 txn, src_table, keyvalues, copy_columns
             )
             all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
-            dest_current_row = self._simple_select_one_txn(
+            dest_current_row = self.db.simple_select_one_txn(
                 txn,
                 into_table,
                 keyvalues=all_dest_keyvalues,
@@ -634,11 +635,11 @@ class StatsStore(StateDeltasStore):
                     **src_row,
                     **additive_relatives,
                 }
-                self._simple_insert_txn(txn, into_table, merged_dict)
+                self.db.simple_insert_txn(txn, into_table, merged_dict)
             else:
                 for (key, val) in additive_relatives.items():
                     src_row[key] = dest_current_row[key] + val
-                self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+                self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
 
     def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
         """Fetches the counts of events in the given range of stream IDs.
@@ -652,7 +653,7 @@ class StatsStore(StateDeltasStore):
             changes.
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "stats_incremental_total_events_and_bytes",
             self.get_changes_room_total_events_and_bytes_txn,
             min_pos,
@@ -735,7 +736,7 @@ class StatsStore(StateDeltasStore):
         def _fetch_current_state_stats(txn):
             pos = self.get_room_max_stream_ordering()
 
-            rows = self._simple_select_many_txn(
+            rows = self.db.simple_select_many_txn(
                 txn,
                 table="current_state_events",
                 column="type",
@@ -791,7 +792,7 @@ class StatsStore(StateDeltasStore):
             current_state_events_count,
             users_in_room,
             pos,
-        ) = yield self.runInteraction(
+        ) = yield self.db.runInteraction(
             "get_initial_state_for_room", _fetch_current_state_stats
         )
 
@@ -866,7 +867,7 @@ class StatsStore(StateDeltasStore):
             (count,) = txn.fetchone()
             return count, pos
 
-        joined_rooms, pos = yield self.runInteraction(
+        joined_rooms, pos = yield self.db.runInteraction(
             "calculate_and_set_initial_state_for_user",
             _calculate_and_set_initial_state_for_user_txn,
         )
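
Background updates now hang off the Database object's BackgroundUpdater, reached as self.db.updates, rather than being mixed into each store. A condensed sketch of the register/progress/finish cycle used by the stats handlers above; ExampleStatsStore, the update name example_backfill and the query are hypothetical, while the self.db.updates and self.db.runInteraction calls are the ones this patch introduces:

from twisted.internet import defer

from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
from synapse.storage.database import Database


class ExampleStatsStore(StateDeltasStore):
    def __init__(self, database: Database, db_conn, hs):
        super(ExampleStatsStore, self).__init__(database, db_conn, hs)
        # Registration now happens on the Database's BackgroundUpdater,
        # reached as self.db.updates.
        self.db.updates.register_background_update_handler(
            "example_backfill", self._example_backfill
        )

    @defer.inlineCallbacks
    def _example_backfill(self, progress, batch_size):
        last_id = progress.get("last_id", "")

        def _get_batch(txn):
            txn.execute(
                "SELECT name FROM users WHERE name > ? ORDER BY name LIMIT ?",
                (last_id, batch_size),
            )
            return [r for r, in txn]

        batch = yield self.db.runInteraction("example_backfill_batch", _get_batch)

        if not batch:
            # Nothing left to do: mark the background update as finished.
            yield self.db.updates._end_background_update("example_backfill")
            return 1

        # Record how far we got so the next run resumes from here.
        progress["last_id"] = batch[-1]
        yield self.db.runInteraction(
            "example_backfill_progress",
            self.db.updates._background_update_progress_txn,
            "example_backfill",
            progress,
        )

        return len(batch)
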
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 8780fdd989..140da8dad6 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -1,5 +1,8 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -44,6 +47,7 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -248,11 +252,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     __metaclass__ = abc.ABCMeta
 
-    def __init__(self, db_conn, hs):
-        super(StreamWorkerStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(StreamWorkerStore, self).__init__(database, db_conn, hs)
 
         events_max = self.get_room_max_stream_ordering()
-        event_cache_prefill, min_event_val = self._get_cache_dict(
+        event_cache_prefill, min_event_val = self.db.get_cache_dict(
             db_conn,
             "events",
             entity_column="room_id",
@@ -397,7 +401,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
             return rows
 
-        rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+        rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
 
         ret = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
@@ -447,7 +451,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return rows
 
-        rows = yield self.runInteraction("get_membership_changes_for_user", f)
+        rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
 
         ret = yield self.get_events_as_list(
             [r.event_id for r in rows], get_prev_content=True
@@ -508,7 +512,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         end_token = RoomStreamToken.parse(end_token)
 
-        rows, token = yield self.runInteraction(
+        rows, token = yield self.db.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
@@ -545,7 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn.execute(sql, (room_id, stream_ordering))
             return txn.fetchone()
 
-        return self.runInteraction("get_room_event_after_stream_ordering", _f)
+        return self.db.runInteraction("get_room_event_after_stream_ordering", _f)
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, room_id=None):
@@ -559,7 +563,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if room_id is None:
             return "s%d" % (token,)
         else:
-            topo = yield self.runInteraction(
+            topo = yield self.db.runInteraction(
                 "_get_max_topological_txn", self._get_max_topological_txn, room_id
             )
             return "t%d-%d" % (topo, token)
@@ -573,7 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A deferred "s%d" stream token.
         """
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
         ).addCallback(lambda row: "s%d" % (row,))
 
@@ -586,7 +590,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         Returns:
             A deferred "t%d-%d" topological token.
         """
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="events",
             keyvalues={"event_id": event_id},
             retcols=("stream_ordering", "topological_ordering"),
@@ -610,13 +614,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "SELECT coalesce(max(topological_ordering), 0) FROM events"
             " WHERE room_id = ? AND stream_ordering < ?"
         )
-        return self._execute(
+        return self.db.execute(
             "get_max_topological_token", None, sql, room_id, stream_key
         ).addCallback(lambda r: r[0][0] if r else 0)
 
     def _get_max_topological_txn(self, txn, room_id):
         txn.execute(
-            "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+            "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
             (room_id,),
         )
 
@@ -664,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        results = yield self.runInteraction(
+        results = yield self.db.runInteraction(
             "get_events_around",
             self._get_events_around_txn,
             room_id,
@@ -706,7 +710,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             dict
         """
 
-        results = self._simple_select_one_txn(
+        results = self.db.simple_select_one_txn(
             txn,
             "events",
             keyvalues={"event_id": event_id, "room_id": room_id},
@@ -785,7 +789,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
             return upper_bound, [row[1] for row in rows]
 
-        upper_bound, event_ids = yield self.runInteraction(
+        upper_bound, event_ids = yield self.db.runInteraction(
             "get_all_new_events_stream", get_all_new_events_stream_txn
         )
 
@@ -794,7 +798,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return upper_bound, events
 
     def get_federation_out_pos(self, typ):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="federation_stream_position",
             retcol="stream_id",
             keyvalues={"type": typ},
@@ -802,7 +806,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     def update_federation_out_pos(self, typ, stream_id):
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="federation_stream_position",
             keyvalues={"type": typ},
             updatevalues={"stream_id": stream_id},
@@ -953,7 +957,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if to_key:
             to_key = RoomStreamToken.parse(to_key)
 
-        rows, token = yield self.runInteraction(
+        rows, token = yield self.db.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 10d1887f75..2aa1bafd48 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             tag strings to tag content.
         """
 
-        deferred = self._simple_select_list(
+        deferred = self.db.simple_select_list(
             "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
         )
 
@@ -78,14 +78,12 @@ class TagsWorkerStore(AccountDataWorkerStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()
 
-        tag_ids = yield self.runInteraction(
+        tag_ids = yield self.db.runInteraction(
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
         def get_tag_content(txn, tag_ids):
-            sql = (
-                "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
-            )
+            sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
             results = []
             for stream_id, user_id, room_id in tag_ids:
                 txn.execute(sql, (user_id, room_id))
@@ -100,7 +98,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         batch_size = 50
         results = []
         for i in range(0, len(tag_ids), batch_size):
-            tags = yield self.runInteraction(
+            tags = yield self.db.runInteraction(
                 "get_all_updated_tag_content",
                 get_tag_content,
                 tag_ids[i : i + batch_size],
@@ -137,7 +135,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
         if not changed:
             return {}
 
-        room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
+        room_ids = yield self.db.runInteraction(
+            "get_updated_tags", get_updated_tags_txn
+        )
 
         results = {}
         if room_ids:
@@ -155,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         Returns:
             A deferred list of string tags.
         """
-        return self._simple_select_list(
+        return self.db.simple_select_list(
             table="room_tags",
             keyvalues={"user_id": user_id, "room_id": room_id},
             retcols=("tag", "content"),
@@ -180,7 +180,7 @@ class TagsStore(TagsWorkerStore):
         content_json = json.dumps(content)
 
         def add_tag_txn(txn, next_id):
-            self._simple_upsert_txn(
+            self.db.simple_upsert_txn(
                 txn,
                 table="room_tags",
                 keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@@ -189,7 +189,7 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction("add_tag", add_tag_txn, next_id)
+            yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
 
@@ -212,7 +212,7 @@ class TagsStore(TagsWorkerStore):
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
         with self._account_data_id_gen.get_next() as next_id:
-            yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
+            yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
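
List lookups move to self.db.simple_select_list, which returns a Deferred list of dicts (one per row) for the requested retcols. A small sketch of a store method assuming the room_tags schema shown above; the helper name itself is made up:

    @defer.inlineCallbacks
    def get_tag_names_for_room(self, user_id, room_id):
        # Hypothetical helper; room_tags and its columns are the real names
        # used by the queries above.
        rows = yield self.db.simple_select_list(
            table="room_tags",
            keyvalues={"user_id": user_id, "room_id": room_id},
            retcols=("tag",),
            desc="get_tag_names_for_room",
        )
        # Each row comes back as a dict keyed by column name.
        return [row["tag"] for row in rows]
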
 
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index 01b1be5e14..5b07c2fbc0 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import Database
 from synapse.util.caches.expiringcache import ExpiringCache
 
 # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
@@ -52,8 +53,8 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """
 
-    def __init__(self, db_conn, hs):
-        super(TransactionStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(TransactionStore, self).__init__(database, db_conn, hs)
 
         self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
 
@@ -77,7 +78,7 @@ class TransactionStore(SQLBaseStore):
             this transaction or a 2-tuple of (int, dict)
         """
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "get_received_txn_response",
             self._get_received_txn_response,
             transaction_id,
@@ -85,7 +86,7 @@ class TransactionStore(SQLBaseStore):
         )
 
     def _get_received_txn_response(self, txn, transaction_id, origin):
-        result = self._simple_select_one_txn(
+        result = self.db.simple_select_one_txn(
             txn,
             table="received_transactions",
             keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -119,7 +120,7 @@ class TransactionStore(SQLBaseStore):
             response_json (str)
         """
 
-        return self._simple_insert(
+        return self.db.simple_insert(
             table="received_transactions",
             values={
                 "transaction_id": transaction_id,
@@ -148,7 +149,7 @@ class TransactionStore(SQLBaseStore):
         if result is not SENTINEL:
             return result
 
-        result = yield self.runInteraction(
+        result = yield self.db.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings,
             destination,
@@ -160,7 +161,7 @@ class TransactionStore(SQLBaseStore):
         return result
 
     def _get_destination_retry_timings(self, txn, destination):
-        result = self._simple_select_one_txn(
+        result = self.db.simple_select_one_txn(
             txn,
             table="destinations",
             keyvalues={"destination": destination},
@@ -187,7 +188,7 @@ class TransactionStore(SQLBaseStore):
         """
 
         self._destination_retry_cache.pop(destination, None)
-        return self.runInteraction(
+        return self.db.runInteraction(
             "set_destination_retry_timings",
             self._set_destination_retry_timings,
             destination,
@@ -227,7 +228,7 @@ class TransactionStore(SQLBaseStore):
         # We need to be careful here as the data may have changed from under us
         # due to a worker setting the timings.
 
-        prev_row = self._simple_select_one_txn(
+        prev_row = self.db.simple_select_one_txn(
             txn,
             table="destinations",
             keyvalues={"destination": destination},
@@ -236,7 +237,7 @@ class TransactionStore(SQLBaseStore):
         )
 
         if not prev_row:
-            self._simple_insert_txn(
+            self.db.simple_insert_txn(
                 txn,
                 table="destinations",
                 values={
@@ -247,7 +248,7 @@ class TransactionStore(SQLBaseStore):
                 },
             )
         elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
-            self._simple_update_one_txn(
+            self.db.simple_update_one_txn(
                 txn,
                 "destinations",
                 keyvalues={"destination": destination},
@@ -270,4 +271,6 @@ class TransactionStore(SQLBaseStore):
         def _cleanup_transactions_txn(txn):
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
 
-        return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
+        return self.db.runInteraction(
+            "_cleanup_transactions", _cleanup_transactions_txn
+        )
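
The *_txn helpers keep their signatures and simply move onto self.db, so read-modify-write logic such as _set_destination_retry_timings keeps its shape. A reduced, simplified sketch of that pattern against the destinations table; the method name and the two-column insert are illustrative only:

    def _bump_retry_interval_txn(self, txn, destination, retry_interval):
        # Look up the current row, then insert or update it, using the same
        # self.db.simple_*_txn helpers as the hunk above.
        prev_row = self.db.simple_select_one_txn(
            txn,
            table="destinations",
            keyvalues={"destination": destination},
            retcols=("retry_interval",),
            allow_none=True,
        )

        if not prev_row:
            self.db.simple_insert_txn(
                txn,
                table="destinations",
                values={"destination": destination, "retry_interval": retry_interval},
            )
        elif prev_row["retry_interval"] < retry_interval:
            self.db.simple_update_one_txn(
                txn,
                "destinations",
                keyvalues={"destination": destination},
                updatevalues={"retry_interval": retry_interval},
            )
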
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 652abe0e6a..90c180ec6d 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -19,9 +19,9 @@ import re
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
-from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.data_stores.main.state import StateFilter
 from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
@@ -32,30 +32,30 @@ logger = logging.getLogger(__name__)
 TEMP_TABLE = "_temp_populate_user_directory"
 
 
-class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore):
+class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryBackgroundUpdateStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
 
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_createtables",
             self._populate_user_directory_createtables,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_process_rooms",
             self._populate_user_directory_process_rooms,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_process_users",
             self._populate_user_directory_process_users,
         )
-        self.register_background_update_handler(
+        self.db.updates.register_background_update_handler(
             "populate_user_directory_cleanup", self._populate_user_directory_cleanup
         )
 
@@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
             """
             txn.execute(sql)
             rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
-            self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+            self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
             del rooms
 
             # If search all users is on, get all the users we want to add.
@@ -100,15 +100,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                 txn.execute("SELECT name FROM users")
                 users = [{"user_id": x[0]} for x in txn.fetchall()]
 
-                self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+                self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
 
         new_pos = yield self.get_max_stream_id_in_current_state_deltas()
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_user_directory_temp_build", _make_staging_area
         )
-        yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+        yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
 
-        yield self._end_background_update("populate_user_directory_createtables")
+        yield self.db.updates._end_background_update(
+            "populate_user_directory_createtables"
+        )
         return 1
 
     @defer.inlineCallbacks
@@ -116,7 +118,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         """
         Update the user directory stream position, then clean up the old tables.
         """
-        position = yield self._simple_select_one_onecol(
+        position = yield self.db.simple_select_one_onecol(
             TEMP_TABLE + "_position", None, "position"
         )
         yield self.update_user_directory_stream_pos(position)
@@ -126,11 +128,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
             txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
 
-        yield self.runInteraction(
+        yield self.db.runInteraction(
             "populate_user_directory_cleanup", _delete_staging_area
         )
 
-        yield self._end_background_update("populate_user_directory_cleanup")
+        yield self.db.updates._end_background_update("populate_user_directory_cleanup")
         return 1
 
     @defer.inlineCallbacks
@@ -170,13 +172,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
 
             return rooms_to_work_on
 
-        rooms_to_work_on = yield self.runInteraction(
+        rooms_to_work_on = yield self.db.runInteraction(
             "populate_user_directory_temp_read", _get_next_batch
         )
 
         # No more rooms -- complete the transaction.
         if not rooms_to_work_on:
-            yield self._end_background_update("populate_user_directory_process_rooms")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_rooms"
+            )
             return 1
 
         logger.info(
@@ -243,12 +247,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                         to_insert.clear()
 
             # We've finished a room. Delete it from the table.
-            yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+            yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
             # Update the remaining counter.
             progress["remaining"] -= 1
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "populate_user_directory",
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 "populate_user_directory_process_rooms",
                 progress,
             )
@@ -267,7 +271,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         If search_all_users is enabled, add all of the users to the user directory.
         """
         if not self.hs.config.user_directory_search_all_users:
-            yield self._end_background_update("populate_user_directory_process_users")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_users"
+            )
             return 1
 
         def _get_next_batch(txn):
@@ -291,13 +297,15 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
 
             return users_to_work_on
 
-        users_to_work_on = yield self.runInteraction(
+        users_to_work_on = yield self.db.runInteraction(
             "populate_user_directory_temp_read", _get_next_batch
         )
 
         # No more users -- complete the transaction.
         if not users_to_work_on:
-            yield self._end_background_update("populate_user_directory_process_users")
+            yield self.db.updates._end_background_update(
+                "populate_user_directory_process_users"
+            )
             return 1
 
         logger.info(
@@ -312,12 +320,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
             )
 
             # We've finished processing a user. Delete it from the table.
-            yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+            yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
             # Update the remaining counter.
             progress["remaining"] -= 1
-            yield self.runInteraction(
+            yield self.db.runInteraction(
                 "populate_user_directory",
-                self._background_update_progress_txn,
+                self.db.updates._background_update_progress_txn,
                 "populate_user_directory_process_users",
                 progress,
             )
@@ -361,7 +369,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         """
 
         def _update_profile_in_user_dir_txn(txn):
-            new_entry = self._simple_upsert_txn(
+            new_entry = self.db.simple_upsert_txn(
                 txn,
                 table="user_directory",
                 keyvalues={"user_id": user_id},
@@ -435,7 +443,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                         )
             elif isinstance(self.database_engine, Sqlite3Engine):
                 value = "%s %s" % (user_id, display_name) if display_name else user_id
-                self._simple_upsert_txn(
+                self.db.simple_upsert_txn(
                     txn,
                     table="user_directory_search",
                     keyvalues={"user_id": user_id},
@@ -448,7 +456,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
 
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )
 
@@ -462,7 +470,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         """
 
         def _add_users_who_share_room_txn(txn):
-            self._simple_upsert_many_txn(
+            self.db.simple_upsert_many_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 key_names=["user_id", "other_user_id", "room_id"],
@@ -474,7 +482,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                 value_values=None,
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
@@ -489,7 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
 
         def _add_users_in_public_rooms_txn(txn):
 
-            self._simple_upsert_many_txn(
+            self.db.simple_upsert_many_txn(
                 txn,
                 table="users_in_public_rooms",
                 key_names=["user_id", "room_id"],
@@ -498,7 +506,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
                 value_values=None,
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "add_users_in_public_rooms", _add_users_in_public_rooms_txn
         )
 
@@ -513,13 +521,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
             txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
         )
 
     @cached()
     def get_user_in_directory(self, user_id):
-        return self._simple_select_one(
+        return self.db.simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
             retcols=("display_name", "avatar_url"),
@@ -528,7 +536,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
         )
 
     def update_user_directory_stream_pos(self, stream_id):
-        return self._simple_update_one(
+        return self.db.simple_update_one(
             table="user_directory_stream_pos",
             keyvalues={},
             updatevalues={"stream_id": stream_id},
@@ -542,47 +550,47 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
     # add_users_who_share_private_rooms?
     SHARE_PRIVATE_WORKING_SET = 500
 
-    def __init__(self, db_conn, hs):
-        super(UserDirectoryStore, self).__init__(db_conn, hs)
+    def __init__(self, database: Database, db_conn, hs):
+        super(UserDirectoryStore, self).__init__(database, db_conn, hs)
 
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="user_directory", keyvalues={"user_id": user_id}
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="user_directory_search", keyvalues={"user_id": user_id}
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"user_id": user_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"other_user_id": user_id},
             )
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
 
-        return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+        return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
 
     @defer.inlineCallbacks
     def get_users_in_dir_due_to_room(self, room_id):
         """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
-        user_ids_share_pub = yield self._simple_select_onecol(
+        user_ids_share_pub = yield self.db.simple_select_onecol(
             table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
             retcol="user_id",
             desc="get_users_in_dir_due_to_room",
         )
 
-        user_ids_share_priv = yield self._simple_select_onecol(
+        user_ids_share_priv = yield self.db.simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"room_id": room_id},
             retcol="other_user_id",
@@ -605,23 +613,23 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         """
 
         def _remove_user_who_share_room_txn(txn):
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_who_share_private_rooms",
                 keyvalues={"other_user_id": user_id, "room_id": room_id},
             )
-            self._simple_delete_txn(
+            self.db.simple_delete_txn(
                 txn,
                 table="users_in_public_rooms",
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
 
-        return self.runInteraction(
+        return self.db.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
@@ -636,14 +644,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         Returns:
             list: user_id
         """
-        rows = yield self._simple_select_onecol(
+        rows = yield self.db.simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"user_id": user_id},
             retcol="room_id",
             desc="get_rooms_user_is_in",
         )
 
-        pub_rows = yield self._simple_select_onecol(
+        pub_rows = yield self.db.simple_select_onecol(
             table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
             retcol="room_id",
@@ -674,14 +682,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             ) f2 USING (room_id)
         """
 
-        rows = yield self._execute(
+        rows = yield self.db.execute(
             "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
         )
 
         return [room_id for room_id, in rows]
 
     def get_user_directory_stream_pos(self):
-        return self._simple_select_one_onecol(
+        return self.db.simple_select_one_onecol(
             table="user_directory_stream_pos",
             keyvalues={},
             retcol="stream_id",
@@ -786,8 +794,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
             # This should be unreachable.
             raise Exception("Unrecognized database engine")
 
-        results = yield self._execute(
-            "search_user_dir", self.cursor_to_dict, sql, *args
+        results = yield self.db.execute(
+            "search_user_dir", self.db.cursor_to_dict, sql, *args
         )
 
         limited = len(results) > limit
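
Raw queries go through self.db.execute, which takes a description, an optional result transformer and the SQL plus its arguments, and returns a Deferred list of rows; passing self.db.cursor_to_dict as the transformer yields dicts, as search_user_dir does above, while passing None returns plain tuples, as get_rooms_in_common_for_users does. A hypothetical store method in the same style:

    @defer.inlineCallbacks
    def get_users_in_public_room(self, room_id):
        # Illustrative only; users_in_public_rooms is a real table from the
        # hunks above.
        sql = "SELECT user_id, room_id FROM users_in_public_rooms WHERE room_id = ?"
        rows = yield self.db.execute(
            "get_users_in_public_room", self.db.cursor_to_dict, sql, room_id
        )
        return [row["user_id"] for row in rows]
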
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index aa4f0da5f0..af8025bc17 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
         Returns:
             Deferred[bool]: True if the user has requested erasure
         """
-        return self._simple_select_onecol(
+        return self.db.simple_select_onecol(
             table="erased_users",
             keyvalues={"user_id": user_id},
             retcol="1",
@@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore):
         # iterate it multiple times, and (b) avoiding duplicates.
         user_ids = tuple(set(user_ids))
 
-        rows = yield self._simple_select_many_batch(
+        rows = yield self.db.simple_select_many_batch(
             table="erased_users",
             column="user_id",
             iterable=user_ids,
@@ -88,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore):
 
             self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
 
-        return self.runInteraction("mark_user_erased", f)
+        return self.db.runInteraction("mark_user_erased", f)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
new file mode 100644
index 0000000000..ec19ae1d9d
--- /dev/null
+++ b/synapse/storage/database.py
@@ -0,0 +1,1490 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import sys
+import time
+from typing import Iterable, Tuple
+
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
+
+from prometheus_client import Histogram
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdater
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.util.stringutils import exception_to_unicode
+
+# import a function which will return a monotonic time, in seconds
+try:
+    # on python 3, use time.monotonic, since time.clock can go backwards
+    from time import monotonic as monotonic_time
+except ImportError:
+    # ... but python 2 doesn't have it
+    from time import clock as monotonic_time
+
+logger = logging.getLogger(__name__)
+
+try:
+    MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+    # python 3 does not have a maximum int value
+    MAX_TXN_ID = 2 ** 63 - 1
+
+sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
+
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
+
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+
+
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+    "user_ips": "user_ips_device_unique_index",
+    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+    "event_search": "event_search_event_id_idx",
+}
+
+
+class LoggingTransaction(object):
+    """An object that almost-transparently proxies for the 'txn' object
+    passed to the constructor. Adds logging and metrics to the .execute()
+    method.
+
+    Args:
+        txn: The database transaction object to wrap.
+        name (str): The name of this transaction, for logging.
+        database_engine (Sqlite3Engine|PostgresEngine)
+        after_callbacks(list|None): A list that callbacks will be appended to
+            that have been added by `call_after` which should be run on
+            successful completion of the transaction. None indicates that no
+            callbacks should be allowed to be scheduled to run.
+        exception_callbacks(list|None): A list that callbacks will be appended
+            to that have been added by `call_on_exception` which should be run
+            if transaction ends with an error. None indicates that no callbacks
+            should be allowed to be scheduled to run.
+    """
+
+    __slots__ = [
+        "txn",
+        "name",
+        "database_engine",
+        "after_callbacks",
+        "exception_callbacks",
+    ]
+
+    def __init__(
+        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
+    ):
+        object.__setattr__(self, "txn", txn)
+        object.__setattr__(self, "name", name)
+        object.__setattr__(self, "database_engine", database_engine)
+        object.__setattr__(self, "after_callbacks", after_callbacks)
+        object.__setattr__(self, "exception_callbacks", exception_callbacks)
+
+    def call_after(self, callback, *args, **kwargs):
+        """Call the given callback on the main twisted thread after the
+        transaction has finished. Used to invalidate the caches on the
+        correct thread.
+        """
+        self.after_callbacks.append((callback, args, kwargs))
+
+    def call_on_exception(self, callback, *args, **kwargs):
+        self.exception_callbacks.append((callback, args, kwargs))
+
+    def __getattr__(self, name):
+        return getattr(self.txn, name)
+
+    def __setattr__(self, name, value):
+        setattr(self.txn, name, value)
+
+    def __iter__(self):
+        return self.txn.__iter__()
+
+    def execute_batch(self, sql, args):
+        if isinstance(self.database_engine, PostgresEngine):
+            from psycopg2.extras import execute_batch
+
+            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+        else:
+            for val in args:
+                self.execute(sql, val)
+
+    def execute(self, sql, *args):
+        self._do_execute(self.txn.execute, sql, *args)
+
+    def executemany(self, sql, *args):
+        self._do_execute(self.txn.executemany, sql, *args)
+
+    def _make_sql_one_line(self, sql):
+        "Strip newlines out of SQL so that the loggers in the DB are on one line"
+        return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
+    def _do_execute(self, func, sql, *args):
+        sql = self._make_sql_one_line(sql)
+
+        # TODO(paul): Maybe use 'info' and 'debug' for values?
+        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+
+        sql = self.database_engine.convert_param_style(sql)
+        if args:
+            try:
+                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
+            except Exception:
+                # Don't let logging failures stop SQL from working
+                pass
+
+        start = time.time()
+
+        try:
+            return func(sql, *args)
+        except Exception as e:
+            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+            raise
+        finally:
+            secs = time.time() - start
+            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+            sql_query_timer.labels(sql.split()[0]).observe(secs)
+
+
+class PerformanceCounters(object):
+    def __init__(self):
+        self.current_counters = {}
+        self.previous_counters = {}
+
+    def update(self, key, duration_secs):
+        count, cum_time = self.current_counters.get(key, (0, 0))
+        count += 1
+        cum_time += duration_secs
+        self.current_counters[key] = (count, cum_time)
+
+    def interval(self, interval_duration_secs, limit=3):
+        counters = []
+        for name, (count, cum_time) in iteritems(self.current_counters):
+            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
+            counters.append(
+                (
+                    (cum_time - prev_time) / interval_duration_secs,
+                    count - prev_count,
+                    name,
+                )
+            )
+
+        self.previous_counters = dict(self.current_counters)
+
+        counters.sort(reverse=True)
+
+        top_n_counters = ", ".join(
+            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
+            for ratio, count, name in counters[:limit]
+        )
+
+        return top_n_counters
+
+
+class Database(object):
+    """Wraps a single physical database and connection pool.
+
+    A single database may be used by multiple data stores.
+    """
+
+    _TXN_ID = 0
+
+    def __init__(self, hs):
+        self.hs = hs
+        self._clock = hs.get_clock()
+        self._db_pool = hs.get_db_pool()
+
+        self.updates = BackgroundUpdater(hs, self)
+
+        self._previous_txn_total_time = 0
+        self._current_txn_total_time = 0
+        self._previous_loop_ts = 0
+
+        # TODO(paul): These can eventually be removed once the metrics code
+        #   is running in mainline, and we have some nice monitoring frontends
+        #   to watch it
+        self._txn_perf_counters = PerformanceCounters()
+
+        self.engine = hs.database_engine
+
+        # A set of tables that are not safe to use native upserts in.
+        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+        # We add the user_directory_search table to the blacklist on SQLite
+        # because the existing search table does not have an index, making it
+        # unsafe to use native upserts.
+        if isinstance(self.engine, Sqlite3Engine):
+            self._unsafe_to_upsert_tables.add("user_directory_search")
+
+        if self.engine.can_native_upsert:
+            # Check ASAP (and then later, every 1s) to see if we have finished
+            # background updates of tables that aren't safe to update.
+            self._clock.call_later(
+                0.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert,
+            )
+
+    @defer.inlineCallbacks
+    def _check_safe_to_upsert(self):
+        """
+        Is it safe to use native UPSERT?
+
+        If there are background updates, we will need to wait, as they may
+        include the addition of indexes that create the UNIQUE constraint we
+        require.
+
+        If the background updates have not completed, wait 15 sec and check again.
+        """
+        updates = yield self.simple_select_list(
+            "background_updates",
+            keyvalues=None,
+            retcols=["update_name"],
+            desc="check_background_updates",
+        )
+        updates = [x["update_name"] for x in updates]
+
+        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+            if update_name not in updates:
+                logger.debug("Now safe to upsert in %s", table)
+                self._unsafe_to_upsert_tables.discard(table)
+
+        # If there are any updates still running, reschedule another check.
+        if updates:
+            self._clock.call_later(
+                15.0,
+                run_as_background_process,
+                "upsert_safety_check",
+                self._check_safe_to_upsert,
+            )
+
+    def start_profiling(self):
+        self._previous_loop_ts = monotonic_time()
+
+        def loop():
+            curr = self._current_txn_total_time
+            prev = self._previous_txn_total_time
+            self._previous_txn_total_time = curr
+
+            time_now = monotonic_time()
+            time_then = self._previous_loop_ts
+            self._previous_loop_ts = time_now
+
+            duration = time_now - time_then
+            ratio = (curr - prev) / duration
+
+            top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
+
+            perf_logger.info(
+                "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
+            )
+
+        self._clock.looping_call(loop, 10000)
+
+    def new_transaction(
+        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+    ):
+        start = monotonic_time()
+        txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so let's stop the counter
+        # from growing really large.
+        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
+
+        name = "%s-%x" % (desc, txn_id)
+
+        transaction_logger.debug("[TXN START] {%s}", name)
+
+        try:
+            i = 0
+            N = 5
+            while True:
+                cursor = LoggingTransaction(
+                    conn.cursor(),
+                    name,
+                    self.engine,
+                    after_callbacks,
+                    exception_callbacks,
+                )
+                try:
+                    r = func(cursor, *args, **kwargs)
+                    conn.commit()
+                    return r
+                except self.engine.module.OperationalError as e:
+                    # This can happen if the database disappears
+                    # mid-transaction.
+                    logger.warning(
+                        "[TXN OPERROR] {%s} %s %d/%d",
+                        name,
+                        exception_to_unicode(e),
+                        i,
+                        N,
+                    )
+                    if i < N:
+                        i += 1
+                        try:
+                            conn.rollback()
+                        except self.engine.module.Error as e1:
+                            logger.warning(
+                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
+                            )
+                        continue
+                    raise
+                except self.engine.module.DatabaseError as e:
+                    if self.engine.is_deadlock(e):
+                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        if i < N:
+                            i += 1
+                            try:
+                                conn.rollback()
+                            except self.engine.module.Error as e1:
+                                logger.warning(
+                                    "[TXN EROLL] {%s} %s",
+                                    name,
+                                    exception_to_unicode(e1),
+                                )
+                            continue
+                    raise
+                finally:
+                    # we're either about to retry with a new cursor, or we're about to
+                    # release the connection. Once we release the connection, it could
+                    # get used for another query, which might do a conn.rollback().
+                    #
+                    # In the latter case, even though that probably wouldn't affect the
+                    # results of this transaction, python's sqlite will reset all
+                    # statements on the connection [1], which will make our cursor
+                    # invalid [2].
+                    #
+                    # In any case, continuing to read rows after commit()ing seems
+                    # dubious from the PoV of ACID transactional semantics
+                    # (sqlite explicitly says that once you commit, you may see rows
+                    # from subsequent updates.)
+                    #
+                    # In psycopg2, cursors are essentially a client-side fabrication -
+                    # all the data is transferred to the client side when the statement
+                    # finishes executing - so in theory we could go on streaming results
+                    # from the cursor, but attempting to do so would make us
+                    # incompatible with sqlite, so let's make sure we're not doing that
+                    # by closing the cursor.
+                    #
+                    # (*named* cursors in psycopg2 are different and are proper server-
+                    # side things, but (a) we don't use them and (b) they are implicitly
+                    # closed by ending the transaction anyway.)
+                    #
+                    # In short, if we haven't finished with the cursor yet, that's a
+                    # problem waiting to bite us.
+                    #
+                    # TL;DR: we're done with the cursor, so we can close it.
+                    #
+                    # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+                    # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+                    cursor.close()
+        except Exception as e:
+            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            raise
+        finally:
+            end = monotonic_time()
+            duration = end - start
+
+            LoggingContext.current_context().add_database_transaction(duration)
+
+            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
+
+            self._current_txn_total_time += duration
+            self._txn_perf_counters.update(desc, duration)
+            sql_txn_timer.labels(desc).observe(duration)
+
+    @defer.inlineCallbacks
+    def runInteraction(self, desc, func, *args, **kwargs):
+        """Starts a transaction on the database and runs a given function
+
+        Arguments:
+            desc (str): description of the transaction, for logging and metrics
+            func (func): callback function, which will be called with a
+                database transaction (a LoggingTransaction) as its first
+                argument, followed by `args` and `kwargs`.
+
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        after_callbacks = []
+        exception_callbacks = []
+
+        if LoggingContext.current_context() == LoggingContext.sentinel:
+            logger.warning("Starting db txn '%s' from sentinel context", desc)
+
+        try:
+            result = yield self.runWithConnection(
+                self.new_transaction,
+                desc,
+                after_callbacks,
+                exception_callbacks,
+                func,
+                *args,
+                **kwargs
+            )
+
+            for after_callback, after_args, after_kwargs in after_callbacks:
+                after_callback(*after_args, **after_kwargs)
+        except:  # noqa: E722, as we reraise the exception this is fine.
+            for after_callback, after_args, after_kwargs in exception_callbacks:
+                after_callback(*after_args, **after_kwargs)
+            raise
+
+        return result
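+
+    # A minimal usage sketch (hypothetical store code, assuming the calling
+    # store keeps a reference to this Database as `self.db`): callers pass a
+    # synchronous function that receives the transaction as its first argument.
+    #
+    #     def _count_users_txn(txn):
+    #         txn.execute("SELECT COUNT(*) FROM users")
+    #         (count,) = txn.fetchone()
+    #         return count
+    #
+    #     count = yield self.db.runInteraction("count_users", _count_users_txn)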
+
+    @defer.inlineCallbacks
+    def runWithConnection(self, func, *args, **kwargs):
+        """Wraps the .runWithConnection() method on the underlying db_pool.
+
+        Arguments:
+            func (func): callback function, which will be called with a
+                database connection (twisted.enterprise.adbapi.Connection) as
+                its first argument, followed by `args` and `kwargs`.
+            args (list): positional args to pass to `func`
+            kwargs (dict): named args to pass to `func`
+
+        Returns:
+            Deferred: The result of func
+        """
+        parent_context = LoggingContext.current_context()
+        if parent_context == LoggingContext.sentinel:
+            logger.warning(
+                "Starting db connection from sentinel context: metrics will be lost"
+            )
+            parent_context = None
+
+        start_time = monotonic_time()
+
+        def inner_func(conn, *args, **kwargs):
+            with LoggingContext("runWithConnection", parent_context) as context:
+                sched_duration_sec = monotonic_time() - start_time
+                sql_scheduling_timer.observe(sched_duration_sec)
+                context.add_database_scheduled(sched_duration_sec)
+
+                if self.engine.is_connection_closed(conn):
+                    logger.debug("Reconnecting closed database connection")
+                    conn.reconnect()
+
+                return func(conn, *args, **kwargs)
+
+        result = yield make_deferred_yieldable(
+            self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+        )
+
+        return result
+
+    @staticmethod
+    def cursor_to_dict(cursor):
+        """Converts a SQL cursor into an list of dicts.
+
+        Args:
+            cursor : The DBAPI cursor which has executed a query.
+        Returns:
+            A list of dicts where the key is the column header.
+        """
+        col_headers = list(intern(str(column[0])) for column in cursor.description)
+        results = list(dict(zip(col_headers, row)) for row in cursor)
+        return results
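+
+    # Illustrative example (hypothetical query and row): a cursor that has
+    # just executed "SELECT user_id, device_id FROM devices" and yields the
+    # row ("@alice:example.com", "ABCDEF") is converted to
+    # [{"user_id": "@alice:example.com", "device_id": "ABCDEF"}].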
+
+    def execute(self, desc, decoder, query, *args):
+        """Runs a single query for a result set.
+
+        Args:
+            desc - A description of the query, used for logging and metrics.
+            decoder - The function which can resolve the cursor results to
+                something meaningful.
+            query - The query string to execute
+            *args - Query args.
+        Returns:
+            The result of decoder(results)
+        """
+
+        def interaction(txn):
+            txn.execute(query, args)
+            if decoder:
+                return decoder(txn)
+            else:
+                return txn.fetchall()
+
+        return self.runInteraction(desc, interaction)
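+
+    # A usage sketch (hypothetical call, assuming the caller holds this
+    # Database as `self.db`): run an ad-hoc query and decode the rows into
+    # dicts using the cursor_to_dict helper above.
+    #
+    #     rows = yield self.db.execute(
+    #         "get_user_names", self.db.cursor_to_dict,
+    #         "SELECT name FROM users WHERE creation_ts > ?", since_ts,
+    #     )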
+
+    # "Simple" SQL API methods that operate on a single table with no JOINs,
+    # no complex WHERE clauses, just a dict of values for columns.
+
+    @defer.inlineCallbacks
+    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+        """Executes an INSERT query on the named table.
+
+        Args:
+            table : string giving the table name
+            values : dict of new column names and values for them
+            or_ignore : bool stating whether to ignore a conflicting row
+                rather than raising an exception. If True, the function
+                returns False when the row already exists instead of raising.
+            desc : string giving a description of the transaction
+
+        Returns:
+            bool: Whether the row was inserted or not. Only useful when
+            `or_ignore` is True
+        """
+        try:
+            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+        except self.engine.module.IntegrityError:
+            # We have to handle the or_ignore flag at this layer, since we
+            # can't reuse a cursor after we receive an error from the db.
+            if not or_ignore:
+                raise
+            return False
+        return True
+
+    @staticmethod
+    def simple_insert_txn(txn, table, values):
+        keys, vals = zip(*values.items())
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys),
+            ", ".join("?" for _ in keys),
+        )
+
+        txn.execute(sql, vals)
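+
+    # For illustration (hypothetical table and values):
+    # simple_insert_txn(txn, "profiles",
+    #                   {"user_id": "@alice:example.com", "displayname": "Alice"})
+    # executes roughly
+    #
+    #     INSERT INTO profiles (user_id, displayname) VALUES(?, ?)
+    #
+    # with ("@alice:example.com", "Alice") as the bound parameters (dict
+    # ordering determines the column order).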
+
+    def simple_insert_many(self, table, values, desc):
+        return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+
+    @staticmethod
+    def simple_insert_many_txn(txn, table, values):
+        if not values:
+            return
+
+        # This is a *slight* abomination to get a list of tuples of key names
+        # and a list of tuples of values.
+        #
+        # i.e. [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
+        #         => [("a", "b",), ("a", "b",)] and [(1, 2,), (3, 4,)]
+        #
+        # The sort is to ensure that we don't rely on dictionary iteration
+        # order (all rows must share the same keys).
+        keys, vals = zip(
+            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+        )
+
+        for k in keys:
+            if k != keys[0]:
+                raise RuntimeError("All items must have the same keys")
+
+        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+            table,
+            ", ".join(k for k in keys[0]),
+            ", ".join("?" for _ in keys[0]),
+        )
+
+        txn.executemany(sql, vals)
+
+    @defer.inlineCallbacks
+    def simple_upsert(
+        self,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        desc="simple_upsert",
+        lock=True,
+    ):
+        """
+
+        `lock` should generally be set to True (the default), but can be set
+        to False if either of the following are true:
+
+        * there is a UNIQUE INDEX on the key columns. In this case a conflict
+          will cause an IntegrityError in which case this function will retry
+          the update.
+
+        * we somehow know that we are the only thread which will be updating
+          this table.
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            Deferred(None or bool): Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        attempts = 0
+        while True:
+            try:
+                result = yield self.runInteraction(
+                    desc,
+                    self.simple_upsert_txn,
+                    table,
+                    keyvalues,
+                    values,
+                    insertion_values,
+                    lock=lock,
+                )
+                return result
+            except self.engine.module.IntegrityError as e:
+                attempts += 1
+                if attempts >= 5:
+                    # don't retry forever, because things other than races
+                    # can cause IntegrityErrors
+                    raise
+
+                # presumably we raced with another transaction: let's retry.
+                logger.warning(
+                    "IntegrityError when upserting into %s; retrying: %s", table, e
+                )
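+
+    # A usage sketch (the table and columns here are hypothetical, purely for
+    # illustration): record the latest sync token for a device, inserting the
+    # row if it does not exist yet.
+    #
+    #     yield self.db.simple_upsert(
+    #         table="device_sync_tokens",
+    #         keyvalues={"user_id": user_id, "device_id": device_id},
+    #         values={"last_token": token},
+    #         desc="update_device_sync_token",
+    #     )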
+
+    def simple_upsert_txn(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
+        """
+        Pick the UPSERT method which works best on the platform: either the
+        native one (PostgreSQL 9.5+, recent SQLite versions), or fall back to
+        an emulated method.
+
+        Args:
+            txn: The transaction to use.
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            None or bool: Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+            return self.simple_upsert_txn_native_upsert(
+                txn, table, keyvalues, values, insertion_values=insertion_values
+            )
+        else:
+            return self.simple_upsert_txn_emulated(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+                lock=lock,
+            )
+
+    def simple_upsert_txn_emulated(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
+        """
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            bool: Return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        # We need to lock the table :(, unless we're *really* careful
+        if lock:
+            self.engine.lock_table(txn, table)
+
+        def _getwhere(key):
+            # If the value we're passing in is None (aka NULL), we need to use
+            # IS, not =, as NULL = NULL evaluates to NULL rather than TRUE.
+            if keyvalues[key] is None:
+                return "%s IS ?" % (key,)
+            else:
+                return "%s = ?" % (key,)
+
+        if not values:
+            # If `values` is empty, then all of the values we care about are in
+            # the unique key, so there is nothing to UPDATE. We can just do a
+            # SELECT instead to see if it exists.
+            sql = "SELECT 1 FROM %s WHERE %s" % (
+                table,
+                " AND ".join(_getwhere(k) for k in keyvalues),
+            )
+            sqlargs = list(keyvalues.values())
+            txn.execute(sql, sqlargs)
+            if txn.fetchall():
+                # We have an existing record.
+                return False
+        else:
+            # First try to update.
+            sql = "UPDATE %s SET %s WHERE %s" % (
+                table,
+                ", ".join("%s = ?" % (k,) for k in values),
+                " AND ".join(_getwhere(k) for k in keyvalues),
+            )
+            sqlargs = list(values.values()) + list(keyvalues.values())
+
+            txn.execute(sql, sqlargs)
+            if txn.rowcount > 0:
+                # successfully updated at least one row.
+                return False
+
+        # We didn't find any existing rows, so insert a new one
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
+
+        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+        )
+        txn.execute(sql, list(allvalues.values()))
+        # successfully inserted
+        return True
+
+    def simple_upsert_txn_native_upsert(
+        self, txn, table, keyvalues, values, insertion_values={}
+    ):
+        """
+        Use the native UPSERT functionality in recent PostgreSQL versions.
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+        Returns:
+            None
+        """
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(insertion_values)
+
+        if not values:
+            latter = "NOTHING"
+        else:
+            allvalues.update(values)
+            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+
+        sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+            ", ".join(k for k in keyvalues),
+            latter,
+        )
+        txn.execute(sql, list(allvalues.values()))
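+
+    # For illustration (hypothetical arguments): with table="user_ips",
+    # keyvalues={"user_id": u, "device_id": d} and values={"last_seen": ts},
+    # this builds roughly
+    #
+    #     INSERT INTO user_ips (user_id, device_id, last_seen)
+    #     VALUES (?, ?, ?)
+    #     ON CONFLICT (user_id, device_id) DO UPDATE SET last_seen=EXCLUDED.last_seen
+    #
+    # with the key values first and the non-key values after them.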
+
+    def simple_upsert_many_txn(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+            return self.simple_upsert_many_txn_native_upsert(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+        else:
+            return self.simple_upsert_many_txn_emulated(
+                txn, table, key_names, key_values, value_names, value_values
+            )
+
+    def simple_upsert_many_txn_emulated(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, but without native UPSERT support or batching.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        # No value columns, therefore make a blank list so that the following
+        # zip() works correctly.
+        if not value_names:
+            value_values = [() for x in range(len(key_values))]
+
+        for keyv, valv in zip(key_values, value_values):
+            _keys = {x: y for x, y in zip(key_names, keyv)}
+            _vals = {x: y for x, y in zip(value_names, valv)}
+
+            self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+    def simple_upsert_many_txn_native_upsert(
+        self, txn, table, key_names, key_values, value_names, value_values
+    ):
+        """
+        Upsert, many times, using batching where possible.
+
+        Args:
+            table (str): The table to upsert into
+            key_names (list[str]): The key column names.
+            key_values (list[list]): A list of each row's key column values.
+            value_names (list[str]): The value column names. If empty, no
+                values will be used, even if value_values is provided.
+            value_values (list[list]): A list of each row's value column values.
+        Returns:
+            None
+        """
+        allnames = []
+        allnames.extend(key_names)
+        allnames.extend(value_names)
+
+        if not value_names:
+            # No value columns, therefore make a blank list so that the
+            # following zip() works correctly.
+            latter = "NOTHING"
+            value_values = [() for x in range(len(key_values))]
+        else:
+            latter = "UPDATE SET " + ", ".join(
+                k + "=EXCLUDED." + k for k in value_names
+            )
+
+        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+            table,
+            ", ".join(k for k in allnames),
+            ", ".join("?" for _ in allnames),
+            ", ".join(key_names),
+            latter,
+        )
+
+        args = []
+
+        for x, y in zip(key_values, value_values):
+            args.append(tuple(x) + tuple(y))
+
+        return txn.execute_batch(sql, args)
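+
+    # For illustration (hypothetical arguments): key_names=["user_id"],
+    # key_values=[["@a:hs"], ["@b:hs"]], value_names=["display_name"],
+    # value_values=[["A"], ["B"]] produces a single statement of the form
+    #
+    #     INSERT INTO <table> (user_id, display_name) VALUES (?, ?)
+    #     ON CONFLICT (user_id) DO UPDATE SET display_name=EXCLUDED.display_name
+    #
+    # executed as a batch with args [("@a:hs", "A"), ("@b:hs", "B")].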
+
+    def simple_select_one(
+        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
+    ):
+        """Executes a SELECT query on the named table, which is expected to
+        return a single row, returning multiple columns from it.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            retcols : list of strings giving the names of the columns to return
+
+            allow_none : If true, return None instead of failing if the SELECT
+              statement returns no rows
+        """
+        return self.runInteraction(
+            desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+        )
+
+    def simple_select_one_onecol(
+        self,
+        table,
+        keyvalues,
+        retcol,
+        allow_none=False,
+        desc="simple_select_one_onecol",
+    ):
+        """Executes a SELECT query on the named table, which is expected to
+        return a single row, returning a single column from it.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            retcol : string giving the name of the column to return
+            allow_none : If true, return None instead of raising StoreError if
+              no row matches
+        """
+        return self.runInteraction(
+            desc,
+            self.simple_select_one_onecol_txn,
+            table,
+            keyvalues,
+            retcol,
+            allow_none=allow_none,
+        )
+
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls, txn, table, keyvalues, retcol, allow_none=False
+    ):
+        ret = cls.simple_select_onecol_txn(
+            txn, table=table, keyvalues=keyvalues, retcol=retcol
+        )
+
+        if ret:
+            return ret[0]
+        else:
+            if allow_none:
+                return None
+            else:
+                raise StoreError(404, "No row found")
+
+    @staticmethod
+    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
+
+        if keyvalues:
+            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+            txn.execute(sql, list(keyvalues.values()))
+        else:
+            txn.execute(sql)
+
+        return [r[0] for r in txn]
+
+    def simple_select_onecol(
+        self, table, keyvalues, retcol, desc="simple_select_onecol"
+    ):
+        """Executes a SELECT query on the named table, which returns a list
+        comprising of the values of the named column from the selected rows.
+
+        Args:
+            table (str): table name
+            keyvalues (dict|None): column names and values to select the rows with
+            retcol (str): column whose value we wish to retrieve.
+
+        Returns:
+            Deferred: Results in a list
+        """
+        return self.runInteraction(
+            desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+        )
+
+    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            keyvalues (dict[str, Any] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.runInteraction(
+            desc, self.simple_select_list_txn, table, keyvalues, retcols
+        )
+
+    @classmethod
+    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            retcols (iterable[str]): the names of the columns to return
+        """
+        if keyvalues:
+            sql = "SELECT %s FROM %s WHERE %s" % (
+                ", ".join(retcols),
+                table,
+                " AND ".join("%s = ?" % (k,) for k in keyvalues),
+            )
+            txn.execute(sql, list(keyvalues.values()))
+        else:
+            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+            txn.execute(sql)
+
+        return cls.cursor_to_dict(txn)
+
+    @defer.inlineCallbacks
+    def simple_select_many_batch(
+        self,
+        table,
+        column,
+        iterable,
+        retcols,
+        keyvalues={},
+        desc="simple_select_many_batch",
+        batch_size=100,
+    ):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        results = []
+
+        if not iterable:
+            return results
+
+        # iterables cannot be sliced, so convert it to a list first
+        it_list = list(iterable)
+
+        chunks = [
+            it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
+        ]
+        for chunk in chunks:
+            rows = yield self.runInteraction(
+                desc,
+                self.simple_select_many_txn,
+                table,
+                column,
+                chunk,
+                keyvalues,
+                retcols,
+            )
+
+            results.extend(rows)
+
+        return results
+
+    @classmethod
+    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        if not iterable:
+            return []
+
+        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+        clauses = [clause]
+
+        for key, value in iteritems(keyvalues):
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        sql = "SELECT %s FROM %s WHERE %s" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join(clauses),
+        )
+
+        txn.execute(sql, values)
+        return cls.cursor_to_dict(txn)
+
+    def simple_update(self, table, keyvalues, updatevalues, desc):
+        return self.runInteraction(
+            desc, self.simple_update_txn, table, keyvalues, updatevalues
+        )
+
+    @staticmethod
+    def simple_update_txn(txn, table, keyvalues, updatevalues):
+        if keyvalues:
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+        else:
+            where = ""
+
+        update_sql = "UPDATE %s SET %s %s" % (
+            table,
+            ", ".join("%s = ?" % (k,) for k in updatevalues),
+            where,
+        )
+
+        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
+
+        return txn.rowcount
+
+    def simple_update_one(
+        self, table, keyvalues, updatevalues, desc="simple_update_one"
+    ):
+        """Executes an UPDATE query on the named table, setting new values for
+        columns in a row matching the key values.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+            updatevalues : dict giving column names and values to update
+
+        Raises StoreError if no row matches the key values, or if more than
+        one row matches.
+
+        The check and the update are performed within the same transaction,
+        allowing an atomic get-and-set. This can be used to implement
+        compare-and-set by putting the column whose old value must match in
+        the 'keyvalues' dict as well.
+        """
+        return self.runInteraction(
+            desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+        )
+
+    @classmethod
+    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+        rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
+
+        if rowcount == 0:
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+    @staticmethod
+    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+        select_sql = "SELECT %s FROM %s WHERE %s" % (
+            ", ".join(retcols),
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(select_sql, list(keyvalues.values()))
+        row = txn.fetchone()
+
+        if not row:
+            if allow_none:
+                return None
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+        return dict(zip(retcols, row))
+
+    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+        return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+
+    @staticmethod
+    def simple_delete_one_txn(txn, table, keyvalues):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+        sql = "DELETE FROM %s WHERE %s" % (
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(sql, list(keyvalues.values()))
+        if txn.rowcount == 0:
+            raise StoreError(404, "No row found (%s)" % (table,))
+        if txn.rowcount > 1:
+            raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+    def simple_delete(self, table, keyvalues, desc):
+        return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+
+    @staticmethod
+    def simple_delete_txn(txn, table, keyvalues):
+        sql = "DELETE FROM %s WHERE %s" % (
+            table,
+            " AND ".join("%s = ?" % (k,) for k in keyvalues),
+        )
+
+        txn.execute(sql, list(keyvalues.values()))
+        return txn.rowcount
+
+    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+        return self.runInteraction(
+            desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+        )
+
+    @staticmethod
+    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+        """Executes a DELETE query on the named table.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+
+        Returns:
+            int: Number of rows deleted
+        """
+        if not iterable:
+            return 0
+
+        sql = "DELETE FROM %s" % table
+
+        clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+        clauses = [clause]
+
+        for key, value in iteritems(keyvalues):
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+        txn.execute(sql, values)
+
+        return txn.rowcount
+
+    def get_cache_dict(
+        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+    ):
+        # Fetch a mapping of entity (e.g. room_id) -> max stream position for
+        # "recent" entities. It doesn't really matter how many we get: the
+        # StreamChangeCache will do the right thing to ensure it respects the
+        # max size of the cache.
+        sql = (
+            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+            " WHERE %(stream)s > ? - %(limit)s"
+            " GROUP BY %(entity)s"
+        ) % {
+            "table": table,
+            "entity": entity_column,
+            "stream": stream_column,
+            "limit": limit,
+        }
+
+        sql = self.engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (int(max_value),))
+
+        cache = {row[0]: int(row[1]) for row in txn}
+
+        txn.close()
+
+        if cache:
+            min_val = min(itervalues(cache))
+        else:
+            min_val = max_value
+
+        return cache, min_val
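+
+    # For illustration (hypothetical arguments): called with table="events",
+    # entity_column="room_id", stream_column="stream_ordering" and
+    # max_value=1000, this returns something like
+    # ({"!room1:hs": 990, "!room2:hs": 997}, 990): a map from entity to its
+    # highest stream position among rows with a stream position above
+    # max_value - limit, plus the smallest value in that map (or max_value
+    # itself if the map is empty).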
+
+    def simple_select_list_paginate(
+        self,
+        table,
+        orderby,
+        start,
+        limit,
+        retcols,
+        filters=None,
+        keyvalues=None,
+        order_direction="ASC",
+        desc="simple_select_list_paginate",
+    ):
+        """
+        Executes a SELECT query on the named table with start and limit
+        applied to the row numbers, returning between zero and `limit` rows
+        starting at `start` as a list of dicts.
+
+        Args:
+            table (str): the table name
+            filters (dict[str, T] | None):
+                column names and values to filter the rows with, or None to not
+                apply a WHERE <column> LIKE ? clause.
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            orderby (str): Column to order the results by.
+            start (int): Index to begin the query at.
+            limit (int): Number of results to return.
+            retcols (iterable[str]): the names of the columns to return
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.runInteraction(
+            desc,
+            self.simple_select_list_paginate_txn,
+            table,
+            orderby,
+            start,
+            limit,
+            retcols,
+            filters=filters,
+            keyvalues=keyvalues,
+            order_direction=order_direction,
+        )
+
+    @classmethod
+    def simple_select_list_paginate_txn(
+        cls,
+        txn,
+        table,
+        orderby,
+        start,
+        limit,
+        retcols,
+        filters=None,
+        keyvalues=None,
+        order_direction="ASC",
+    ):
+        """
+        Executes a SELECT query on the named table with start and limit
+        applied to the row numbers, returning between zero and `limit` rows
+        starting at `start` as a list of dicts.
+
+        Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
+        select attributes with exact matches. All constraints are joined together
+        using 'AND'.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            orderby (str): Column to order the results by.
+            start (int): Index to begin the query at.
+            limit (int): Number of results to return.
+            retcols (iterable[str]): the names of the columns to return
+            filters (dict[str, T] | None):
+                column names and values to filter the rows with, or None to not
+                apply a WHERE <column> LIKE ? clause.
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+        Returns:
+            list[dict[str, Any]]
+        """
+        if order_direction not in ["ASC", "DESC"]:
+            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
+        where_clause = "WHERE " if filters or keyvalues else ""
+        arg_list = []
+        if filters:
+            where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
+            arg_list += list(filters.values())
+        where_clause += " AND " if filters and keyvalues else ""
+        if keyvalues:
+            where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
+            arg_list += list(keyvalues.values())
+
+        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+            ", ".join(retcols),
+            table,
+            where_clause,
+            orderby,
+            order_direction,
+        )
+        txn.execute(sql, arg_list + [limit, start])
+
+        return cls.cursor_to_dict(txn)
+
+    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            term (str | None):
+                term for searching the table matched to a column.
+            col (str): column to query term should be matched to
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]], or 0 if no
+            search term is given
+        """
+
+        return self.runInteraction(
+            desc, self.simple_search_list_txn, table, term, col, retcols
+        )
+
+    @classmethod
+    def simple_search_list_txn(cls, txn, table, term, col, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            term (str | None):
+                term for searching the table matched to a column.
+            col (str): column to query term should be matched to
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            list[dict[str, Any]], or 0 if no search term is given
+        """
+        if term:
+            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
+            termvalues = ["%%" + term + "%%"]
+            txn.execute(sql, termvalues)
+        else:
+            return 0
+
+        return cls.cursor_to_dict(txn)
+
+
+def make_in_list_sql_clause(
+    database_engine, column: str, iterable: Iterable
+) -> Tuple[str, Iterable]:
+    """Returns an SQL clause that checks the given column is in the iterable.
+
+    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+    using the `ANY` form on postgres means that it views queries with
+    different length iterables as the same, helping the query stats.
+
+    Args:
+        database_engine
+        column: Name of the column
+        iterable: The values to check the column against.
+
+    Returns:
+        A tuple of SQL query and the args
+    """
+
+    if database_engine.supports_using_any_list:
+        # This should hopefully be faster, but also makes postgres query
+        # stats easier to understand.
+        return "%s = ANY(?)" % (column,), [list(iterable)]
+    else:
+        return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 2e7753820e..731e1c9d9c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -447,7 +447,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
         # Mark as done.
         cur.execute(
             database_engine.convert_param_style(
-                "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+                "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
             ),
             (modname, name),
         )